Skip to content

Commit 27dce3b

Browse files
authored
Make ds integrals on prism/pyramids generate kernels for each facet type (FEniCS#739)
* working on quadrature * working on prism facet integrals * progress towards ds on prisms * store tables by integrand * correct test * update expressions * ruff * mypy * mypy * ufcx_vertex * ruff * ruff * more ruff * don't overwrite cell * remove ufcx_cell_type enum * ruff * use basix cell type throughout rather than converting to/from string * mypy * set() * make demo xfail * handle "vertex" * typing * remove testing ipython embed * fix merge
1 parent 008bd94 commit 27dce3b

13 files changed

+586
-357
lines changed

ffcx/codegeneration/C/expressions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def generator(ir: ExpressionIR, options):
2525
"""Generate UFC code for an expression."""
2626
logger.info("Generating code for expression:")
2727
assert len(ir.expression.integrand) == 1, "Expressions only support single quadrature rule"
28-
points = next(iter(ir.expression.integrand)).points
28+
points = next(iter(ir.expression.integrand))[1].points
2929
logger.info(f"--- points: {points}")
3030
factory_name = ir.expression.name
3131
logger.info(f"--- name: {factory_name}")

ffcx/codegeneration/C/form.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -86,29 +86,44 @@ def generator(ir: FormIR, options):
8686
integrals = []
8787
integral_ids = []
8888
integral_offsets = [0]
89+
integral_domains = []
8990
# Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h
9091
for itg_type in ("cell", "exterior_facet", "interior_facet"):
9192
unsorted_integrals = []
9293
unsorted_ids = []
93-
for name, id in zip(ir.integral_names[itg_type], ir.subdomain_ids[itg_type]):
94+
unsorted_domains = []
95+
for name, domains, id in zip(
96+
ir.integral_names[itg_type],
97+
ir.integral_domains[itg_type],
98+
ir.subdomain_ids[itg_type],
99+
):
94100
unsorted_integrals += [f"&{name}"]
95101
unsorted_ids += [id]
102+
unsorted_domains += [domains]
96103

97104
id_sort = np.argsort(unsorted_ids)
98105
integrals += [unsorted_integrals[i] for i in id_sort]
99106
integral_ids += [unsorted_ids[i] for i in id_sort]
107+
integral_domains += [unsorted_domains[i] for i in id_sort]
100108

101-
integral_offsets.append(len(integrals))
109+
integral_offsets.append(sum(len(d) for d in integral_domains))
102110

103111
if len(integrals) > 0:
104-
sizes = len(integrals)
105-
values = ", ".join(integrals)
112+
sizes = sum(len(domains) for domains in integral_domains)
113+
values = ", ".join(
114+
[
115+
f"{i}_{domain.name}"
116+
for i, domains in zip(integrals, integral_domains)
117+
for domain in domains
118+
]
119+
)
106120
d["form_integrals_init"] = (
107121
f"static ufcx_integral* form_integrals_{ir.name}[{sizes}] = {{{values}}};"
108122
)
109123
d["form_integrals"] = f"form_integrals_{ir.name}"
110-
sizes = len(integral_ids)
111-
values = ", ".join(str(i) for i in integral_ids)
124+
values = ", ".join(
125+
f"{i}" for i, domains in zip(integral_ids, integral_domains) for _ in domains
126+
)
112127
d["form_integral_ids_init"] = f"int form_integral_ids_{ir.name}[{sizes}] = {{{values}}};"
113128
d["form_integral_ids"] = f"form_integral_ids_{ir.name}"
114129
else:

ffcx/codegeneration/C/integrals.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import logging
99
import sys
1010

11+
import basix
1112
import numpy as np
1213

1314
from ffcx.codegeneration.backend import FFCXBackend
@@ -20,14 +21,13 @@
2021
logger = logging.getLogger("ffcx")
2122

2223

23-
def generator(ir: IntegralIR, options):
24+
def generator(ir: IntegralIR, domain: basix.CellType, options):
2425
"""Generate C code for an integral."""
2526
logger.info("Generating code for integral:")
2627
logger.info(f"--- type: {ir.expression.integral_type}")
2728
logger.info(f"--- name: {ir.expression.name}")
2829

29-
"""Generate code for an integral."""
30-
factory_name = ir.expression.name
30+
factory_name = f"{ir.expression.name}_{domain.name}"
3131

3232
# Format declaration
3333
declaration = ufcx_integrals.declaration.format(factory_name=factory_name)
@@ -39,7 +39,7 @@ def generator(ir: IntegralIR, options):
3939
ig = IntegralGenerator(ir, backend)
4040

4141
# Generate code ast for the tabulate_tensor body
42-
parts = ig.generate()
42+
parts = ig.generate(domain)
4343

4444
# Format code as string
4545
CF = CFormatter(options["scalar_type"])
@@ -52,9 +52,9 @@ def generator(ir: IntegralIR, options):
5252
values = ", ".join("1" if i else "0" for i in ir.enabled_coefficients)
5353
sizes = len(ir.enabled_coefficients)
5454
code["enabled_coefficients_init"] = (
55-
f"bool enabled_coefficients_{ir.expression.name}[{sizes}] = {{{values}}};"
55+
f"bool enabled_coefficients_{ir.expression.name}_{domain.name}[{sizes}] = {{{values}}};"
5656
)
57-
code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}"
57+
code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}_{domain.name}"
5858
else:
5959
code["enabled_coefficients_init"] = ""
6060
code["enabled_coefficients"] = "NULL"
@@ -88,6 +88,7 @@ def generator(ir: IntegralIR, options):
8888
tabulate_tensor_float64=code["tabulate_tensor_float64"],
8989
tabulate_tensor_complex64=code["tabulate_tensor_complex64"],
9090
tabulate_tensor_complex128=code["tabulate_tensor_complex128"],
91+
domain=int(domain),
9192
)
9293

9394
return declaration, implementation

ffcx/codegeneration/C/integrals_template.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
{tabulate_tensor_complex128}
3333
.needs_facet_permutations = {needs_facet_permutations},
3434
.coordinate_element_hash = {coordinate_element_hash},
35+
.domain = {domain},
3536
}};
3637
3738
// End of code for integral {factory_name}

ffcx/codegeneration/access.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,14 @@ def reference_normal(self, mt, tabledata, access):
237237
def cell_facet_jacobian(self, mt, tabledata, num_points):
238238
"""Access a cell facet jacobian."""
239239
cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
240-
if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
240+
if cellname in (
241+
"triangle",
242+
"tetrahedron",
243+
"quadrilateral",
244+
"hexahedron",
245+
"prism",
246+
"pyramid",
247+
):
241248
table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
242249
facet = self.symbols.entity("facet", mt.restriction)
243250
return table[facet][mt.component[0]][mt.component[1]]

ffcx/codegeneration/codegeneration.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) -
4545
logger.info("Compiler stage 3: Generating code")
4646
logger.info(79 * "*")
4747

48-
code_integrals = [integral_generator(integral_ir, options) for integral_ir in ir.integrals]
48+
code_integrals = [
49+
integral_generator(integral_ir, domain, options)
50+
for integral_ir in ir.integrals
51+
for domain in set(i[0] for i in integral_ir.expression.integrand.keys())
52+
]
4953
code_forms = [form_generator(form_ir, options) for form_ir in ir.forms]
5054
code_expressions = [
5155
expression_generator(expression_ir, options) for expression_ir in ir.expressions

ffcx/codegeneration/expression_generator.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def generate_element_tables(self):
9797
"""Generate tables of FE basis evaluated at specified points."""
9898
parts = []
9999

100-
tables = self.ir.expression.unique_tables
100+
tables = self.ir.expression.unique_tables[self.quadrature_rule[0]]
101101
table_names = sorted(tables)
102102

103103
for name in table_names:
@@ -125,7 +125,7 @@ def generate_quadrature_loop(self):
125125
# Generate varying partition
126126
body = self.generate_varying_partition()
127127
body = L.commented_code_list(
128-
body, f"Points loop body setup quadrature loop {self.quadrature_rule.id()}"
128+
body, f"Points loop body setup quadrature loop {self.quadrature_rule[1].id()}"
129129
)
130130

131131
# Generate dofblock parts, some of this
@@ -139,7 +139,7 @@ def generate_quadrature_loop(self):
139139
quadparts = []
140140
else:
141141
iq = self.backend.symbols.quadrature_loop_index
142-
num_points = self.quadrature_rule.points.shape[0]
142+
num_points = self.quadrature_rule[1].points.shape[0]
143143
quadparts = [L.ForRange(iq, 0, num_points, body=body)]
144144
return preparts, quadparts
145145

@@ -148,11 +148,11 @@ def generate_varying_partition(self):
148148
# Get annotated graph of factorisation
149149
F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]
150150

151-
arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR)
151+
arraysymbol = L.Symbol(f"sv_{self.quadrature_rule[1].id()}", dtype=L.DataType.SCALAR)
152152
parts = self.generate_partition(arraysymbol, F, "varying")
153153
parts = L.commented_code_list(
154154
parts,
155-
f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}",
155+
f"Unstructured varying computations for quadrature rule {self.quadrature_rule[1].id()}",
156156
)
157157
return parts
158158

@@ -216,7 +216,7 @@ def generate_block_parts(self, blockmap, blockdata):
216216
assert not blockdata.transposed, "Not handled yet"
217217
components = ufl.product(self.ir.expression.shape)
218218

219-
num_points = self.quadrature_rule.points.shape[0]
219+
num_points = self.quadrature_rule[1].points.shape[0]
220220
A_shape = [num_points, components] + self.ir.expression.tensor_shape
221221
A = self.backend.symbols.element_tensor
222222
iq = self.backend.symbols.quadrature_loop_index

0 commit comments

Comments
 (0)