Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ds integrals on prism/pyramids generate kernels for each facet type #739

Merged
merged 31 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
361358d
working on quadrature
mscroggs Dec 18, 2024
24a51de
working on prism facet integrals
mscroggs Dec 18, 2024
c7107d2
progress towards ds on prisms
mscroggs Dec 19, 2024
033475b
store tables by integrand
mscroggs Dec 20, 2024
54ef0c4
correct test
mscroggs Dec 20, 2024
d76950c
update expressions
mscroggs Dec 20, 2024
3356abb
ruff
mscroggs Dec 20, 2024
c8b70a2
mypy
mscroggs Dec 20, 2024
fba972b
mypy
mscroggs Dec 20, 2024
d33f390
ufcx_vertex
mscroggs Dec 20, 2024
2c2a939
ruff
mscroggs Dec 20, 2024
5f502a2
ruff
mscroggs Dec 20, 2024
aebad24
more ruff
mscroggs Dec 20, 2024
50c959b
don't overwrite cell
mscroggs Dec 20, 2024
61bbc6d
Merge branch 'main' into mscroggs/prism-ds
mscroggs Dec 21, 2024
9289781
remove ufcx_cell_type enum
mscroggs Feb 14, 2025
0e6dc3c
ruff
mscroggs Feb 14, 2025
4bbdd14
use basix cell type throughout rather than converting to/from string
mscroggs Feb 14, 2025
2948735
Merge branch 'main' into mscroggs/prism-ds
mscroggs Feb 14, 2025
4cc88cc
mypy
mscroggs Feb 14, 2025
a9533aa
Merge branch 'mscroggs/prism-ds' of github.com:FEniCS/ffcx into mscro…
mscroggs Feb 14, 2025
6306a60
set()
mscroggs Feb 14, 2025
60f4f33
make demo xfail
mscroggs Feb 14, 2025
c0fd0fa
handle "vertex"
mscroggs Feb 14, 2025
6f7dec1
Merge branch 'main' into mscroggs/prism-ds
mscroggs Feb 14, 2025
d9481a6
merge
mscroggs Feb 14, 2025
f1b0a9c
Merge branch 'mscroggs/prism-ds' of github.com:FEniCS/ffcx into mscro…
mscroggs Feb 14, 2025
54b9fd0
typing
mscroggs Feb 14, 2025
9881fb7
remove testing ipython embed
mscroggs Feb 14, 2025
edbe915
Merge branch 'main' into mscroggs/prism-ds
mscroggs Mar 17, 2025
2c5b2af
fix merge
mscroggs Mar 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ffcx/codegeneration/C/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def generator(ir: ExpressionIR, options):
"""Generate UFC code for an expression."""
logger.info("Generating code for expression:")
assert len(ir.expression.integrand) == 1, "Expressions only support single quadrature rule"
points = next(iter(ir.expression.integrand)).points
points = next(iter(ir.expression.integrand))[1].points
logger.info(f"--- points: {points}")
factory_name = ir.expression.name
logger.info(f"--- name: {factory_name}")
Expand Down
27 changes: 21 additions & 6 deletions ffcx/codegeneration/C/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,29 +86,44 @@ def generator(ir: FormIR, options):
integrals = []
integral_ids = []
integral_offsets = [0]
integral_domains = []
# Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h
for itg_type in ("cell", "exterior_facet", "interior_facet"):
unsorted_integrals = []
unsorted_ids = []
for name, id in zip(ir.integral_names[itg_type], ir.subdomain_ids[itg_type]):
unsorted_domains = []
for name, domains, id in zip(
ir.integral_names[itg_type],
ir.integral_domains[itg_type],
ir.subdomain_ids[itg_type],
):
unsorted_integrals += [f"&{name}"]
unsorted_ids += [id]
unsorted_domains += [domains]

id_sort = np.argsort(unsorted_ids)
integrals += [unsorted_integrals[i] for i in id_sort]
integral_ids += [unsorted_ids[i] for i in id_sort]
integral_domains += [unsorted_domains[i] for i in id_sort]

integral_offsets.append(len(integrals))
integral_offsets.append(sum(len(d) for d in integral_domains))

if len(integrals) > 0:
sizes = len(integrals)
values = ", ".join(integrals)
sizes = sum(len(domains) for domains in integral_domains)
values = ", ".join(
[
f"{i}_{domain.name}"
for i, domains in zip(integrals, integral_domains)
for domain in domains
]
)
d["form_integrals_init"] = (
f"static ufcx_integral* form_integrals_{ir.name}[{sizes}] = {{{values}}};"
)
d["form_integrals"] = f"form_integrals_{ir.name}"
sizes = len(integral_ids)
values = ", ".join(str(i) for i in integral_ids)
values = ", ".join(
f"{i}" for i, domains in zip(integral_ids, integral_domains) for _ in domains
)
d["form_integral_ids_init"] = f"int form_integral_ids_{ir.name}[{sizes}] = {{{values}}};"
d["form_integral_ids"] = f"form_integral_ids_{ir.name}"
else:
Expand Down
13 changes: 7 additions & 6 deletions ffcx/codegeneration/C/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import sys

import basix
import numpy as np

from ffcx.codegeneration.backend import FFCXBackend
Expand All @@ -20,14 +21,13 @@
logger = logging.getLogger("ffcx")


def generator(ir: IntegralIR, options):
def generator(ir: IntegralIR, domain: basix.CellType, options):
"""Generate C code for an integral."""
logger.info("Generating code for integral:")
logger.info(f"--- type: {ir.expression.integral_type}")
logger.info(f"--- name: {ir.expression.name}")

"""Generate code for an integral."""
factory_name = ir.expression.name
factory_name = f"{ir.expression.name}_{domain.name}"

# Format declaration
declaration = ufcx_integrals.declaration.format(factory_name=factory_name)
Expand All @@ -39,7 +39,7 @@ def generator(ir: IntegralIR, options):
ig = IntegralGenerator(ir, backend)

# Generate code ast for the tabulate_tensor body
parts = ig.generate()
parts = ig.generate(domain)

# Format code as string
CF = CFormatter(options["scalar_type"])
Expand All @@ -52,9 +52,9 @@ def generator(ir: IntegralIR, options):
values = ", ".join("1" if i else "0" for i in ir.enabled_coefficients)
sizes = len(ir.enabled_coefficients)
code["enabled_coefficients_init"] = (
f"bool enabled_coefficients_{ir.expression.name}[{sizes}] = {{{values}}};"
f"bool enabled_coefficients_{ir.expression.name}_{domain.name}[{sizes}] = {{{values}}};"
)
code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}"
code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}_{domain.name}"
else:
code["enabled_coefficients_init"] = ""
code["enabled_coefficients"] = "NULL"
Expand Down Expand Up @@ -88,6 +88,7 @@ def generator(ir: IntegralIR, options):
tabulate_tensor_float64=code["tabulate_tensor_float64"],
tabulate_tensor_complex64=code["tabulate_tensor_complex64"],
tabulate_tensor_complex128=code["tabulate_tensor_complex128"],
domain=int(domain),
)

return declaration, implementation
1 change: 1 addition & 0 deletions ffcx/codegeneration/C/integrals_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
{tabulate_tensor_complex128}
.needs_facet_permutations = {needs_facet_permutations},
.coordinate_element_hash = {coordinate_element_hash},
.domain = {domain},
}};

// End of code for integral {factory_name}
Expand Down
9 changes: 8 additions & 1 deletion ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,14 @@ def reference_normal(self, mt, tabledata, access):
def cell_facet_jacobian(self, mt, tabledata, num_points):
"""Access a cell facet jacobian."""
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
if cellname in (
"triangle",
"tetrahedron",
"quadrilateral",
"hexahedron",
"prism",
"pyramid",
):
table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
facet = self.symbols.entity("facet", mt.restriction)
return table[facet][mt.component[0]][mt.component[1]]
Expand Down
6 changes: 5 additions & 1 deletion ffcx/codegeneration/codegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) -
logger.info("Compiler stage 3: Generating code")
logger.info(79 * "*")

code_integrals = [integral_generator(integral_ir, options) for integral_ir in ir.integrals]
code_integrals = [
integral_generator(integral_ir, domain, options)
for integral_ir in ir.integrals
for domain in set(i[0] for i in integral_ir.expression.integrand.keys())
]
code_forms = [form_generator(form_ir, options) for form_ir in ir.forms]
code_expressions = [
expression_generator(expression_ir, options) for expression_ir in ir.expressions
Expand Down
12 changes: 6 additions & 6 deletions ffcx/codegeneration/expression_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def generate_element_tables(self):
"""Generate tables of FE basis evaluated at specified points."""
parts = []

tables = self.ir.expression.unique_tables
tables = self.ir.expression.unique_tables[self.quadrature_rule[0]]
table_names = sorted(tables)

for name in table_names:
Expand Down Expand Up @@ -125,7 +125,7 @@ def generate_quadrature_loop(self):
# Generate varying partition
body = self.generate_varying_partition()
body = L.commented_code_list(
body, f"Points loop body setup quadrature loop {self.quadrature_rule.id()}"
body, f"Points loop body setup quadrature loop {self.quadrature_rule[1].id()}"
)

# Generate dofblock parts, some of this
Expand All @@ -139,7 +139,7 @@ def generate_quadrature_loop(self):
quadparts = []
else:
iq = self.backend.symbols.quadrature_loop_index
num_points = self.quadrature_rule.points.shape[0]
num_points = self.quadrature_rule[1].points.shape[0]
quadparts = [L.ForRange(iq, 0, num_points, body=body)]
return preparts, quadparts

Expand All @@ -148,11 +148,11 @@ def generate_varying_partition(self):
# Get annotated graph of factorisation
F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]

arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR)
arraysymbol = L.Symbol(f"sv_{self.quadrature_rule[1].id()}", dtype=L.DataType.SCALAR)
parts = self.generate_partition(arraysymbol, F, "varying")
parts = L.commented_code_list(
parts,
f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}",
f"Unstructured varying computations for quadrature rule {self.quadrature_rule[1].id()}",
)
return parts

Expand Down Expand Up @@ -216,7 +216,7 @@ def generate_block_parts(self, blockmap, blockdata):
assert not blockdata.transposed, "Not handled yet"
components = ufl.product(self.ir.expression.shape)

num_points = self.quadrature_rule.points.shape[0]
num_points = self.quadrature_rule[1].points.shape[0]
A_shape = [num_points, components] + self.ir.expression.tensor_shape
A = self.backend.symbols.element_tensor
iq = self.backend.symbols.quadrature_loop_index
Expand Down
Loading