diff --git a/ffcx/codegeneration/C/expressions.py b/ffcx/codegeneration/C/expressions.py index 853f46240..177e8a0be 100644 --- a/ffcx/codegeneration/C/expressions.py +++ b/ffcx/codegeneration/C/expressions.py @@ -25,7 +25,7 @@ def generator(ir: ExpressionIR, options): """Generate UFC code for an expression.""" logger.info("Generating code for expression:") assert len(ir.expression.integrand) == 1, "Expressions only support single quadrature rule" - points = next(iter(ir.expression.integrand)).points + points = next(iter(ir.expression.integrand))[1].points logger.info(f"--- points: {points}") factory_name = ir.expression.name logger.info(f"--- name: {factory_name}") diff --git a/ffcx/codegeneration/C/form.py b/ffcx/codegeneration/C/form.py index eab91c59d..a0fbaa831 100644 --- a/ffcx/codegeneration/C/form.py +++ b/ffcx/codegeneration/C/form.py @@ -86,29 +86,44 @@ def generator(ir: FormIR, options): integrals = [] integral_ids = [] integral_offsets = [0] + integral_domains = [] # Note: the order of this list is defined by the enum ufcx_integral_type in ufcx.h for itg_type in ("cell", "exterior_facet", "interior_facet"): unsorted_integrals = [] unsorted_ids = [] - for name, id in zip(ir.integral_names[itg_type], ir.subdomain_ids[itg_type]): + unsorted_domains = [] + for name, domains, id in zip( + ir.integral_names[itg_type], + ir.integral_domains[itg_type], + ir.subdomain_ids[itg_type], + ): unsorted_integrals += [f"&{name}"] unsorted_ids += [id] + unsorted_domains += [domains] id_sort = np.argsort(unsorted_ids) integrals += [unsorted_integrals[i] for i in id_sort] integral_ids += [unsorted_ids[i] for i in id_sort] + integral_domains += [unsorted_domains[i] for i in id_sort] - integral_offsets.append(len(integrals)) + integral_offsets.append(sum(len(d) for d in integral_domains)) if len(integrals) > 0: - sizes = len(integrals) - values = ", ".join(integrals) + sizes = sum(len(domains) for domains in integral_domains) + values = ", ".join( + [ + f"{i}_{domain.name}" + for i, domains in zip(integrals, integral_domains) + for domain in domains + ] + ) d["form_integrals_init"] = ( f"static ufcx_integral* form_integrals_{ir.name}[{sizes}] = {{{values}}};" ) d["form_integrals"] = f"form_integrals_{ir.name}" - sizes = len(integral_ids) - values = ", ".join(str(i) for i in integral_ids) + values = ", ".join( + f"{i}" for i, domains in zip(integral_ids, integral_domains) for _ in domains + ) d["form_integral_ids_init"] = f"int form_integral_ids_{ir.name}[{sizes}] = {{{values}}};" d["form_integral_ids"] = f"form_integral_ids_{ir.name}" else: diff --git a/ffcx/codegeneration/C/integrals.py b/ffcx/codegeneration/C/integrals.py index 6c636a520..9fb64ccc7 100644 --- a/ffcx/codegeneration/C/integrals.py +++ b/ffcx/codegeneration/C/integrals.py @@ -8,6 +8,7 @@ import logging import sys +import basix import numpy as np from ffcx.codegeneration.backend import FFCXBackend @@ -20,14 +21,13 @@ logger = logging.getLogger("ffcx") -def generator(ir: IntegralIR, options): +def generator(ir: IntegralIR, domain: basix.CellType, options): """Generate C code for an integral.""" logger.info("Generating code for integral:") logger.info(f"--- type: {ir.expression.integral_type}") logger.info(f"--- name: {ir.expression.name}") - """Generate code for an integral.""" - factory_name = ir.expression.name + factory_name = f"{ir.expression.name}_{domain.name}" # Format declaration declaration = ufcx_integrals.declaration.format(factory_name=factory_name) @@ -39,7 +39,7 @@ def generator(ir: IntegralIR, options): ig = IntegralGenerator(ir, backend) # Generate code ast for the tabulate_tensor body - parts = ig.generate() + parts = ig.generate(domain) # Format code as string CF = CFormatter(options["scalar_type"]) @@ -52,9 +52,9 @@ def generator(ir: IntegralIR, options): values = ", ".join("1" if i else "0" for i in ir.enabled_coefficients) sizes = len(ir.enabled_coefficients) code["enabled_coefficients_init"] = ( - f"bool enabled_coefficients_{ir.expression.name}[{sizes}] = {{{values}}};" + f"bool enabled_coefficients_{ir.expression.name}_{domain.name}[{sizes}] = {{{values}}};" ) - code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}" + code["enabled_coefficients"] = f"enabled_coefficients_{ir.expression.name}_{domain.name}" else: code["enabled_coefficients_init"] = "" code["enabled_coefficients"] = "NULL" @@ -88,6 +88,7 @@ def generator(ir: IntegralIR, options): tabulate_tensor_float64=code["tabulate_tensor_float64"], tabulate_tensor_complex64=code["tabulate_tensor_complex64"], tabulate_tensor_complex128=code["tabulate_tensor_complex128"], + domain=int(domain), ) return declaration, implementation diff --git a/ffcx/codegeneration/C/integrals_template.py b/ffcx/codegeneration/C/integrals_template.py index 2bb1568ec..6645e185c 100644 --- a/ffcx/codegeneration/C/integrals_template.py +++ b/ffcx/codegeneration/C/integrals_template.py @@ -32,6 +32,7 @@ {tabulate_tensor_complex128} .needs_facet_permutations = {needs_facet_permutations}, .coordinate_element_hash = {coordinate_element_hash}, + .domain = {domain}, }}; // End of code for integral {factory_name} diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index c74560e3a..b47d04b4e 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -237,7 +237,14 @@ def reference_normal(self, mt, tabledata, access): def cell_facet_jacobian(self, mt, tabledata, num_points): """Access a cell facet jacobian.""" cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname() - if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"): + if cellname in ( + "triangle", + "tetrahedron", + "quadrilateral", + "hexahedron", + "prism", + "pyramid", + ): table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL) facet = self.symbols.entity("facet", mt.restriction) return table[facet][mt.component[0]][mt.component[1]] diff --git a/ffcx/codegeneration/codegeneration.py b/ffcx/codegeneration/codegeneration.py index 9564c3837..bc1ecfe2a 100644 --- a/ffcx/codegeneration/codegeneration.py +++ b/ffcx/codegeneration/codegeneration.py @@ -45,7 +45,11 @@ def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) - logger.info("Compiler stage 3: Generating code") logger.info(79 * "*") - code_integrals = [integral_generator(integral_ir, options) for integral_ir in ir.integrals] + code_integrals = [ + integral_generator(integral_ir, domain, options) + for integral_ir in ir.integrals + for domain in set(i[0] for i in integral_ir.expression.integrand.keys()) + ] code_forms = [form_generator(form_ir, options) for form_ir in ir.forms] code_expressions = [ expression_generator(expression_ir, options) for expression_ir in ir.expressions diff --git a/ffcx/codegeneration/expression_generator.py b/ffcx/codegeneration/expression_generator.py index fe8717adb..15d5c56a2 100644 --- a/ffcx/codegeneration/expression_generator.py +++ b/ffcx/codegeneration/expression_generator.py @@ -97,7 +97,7 @@ def generate_element_tables(self): """Generate tables of FE basis evaluated at specified points.""" parts = [] - tables = self.ir.expression.unique_tables + tables = self.ir.expression.unique_tables[self.quadrature_rule[0]] table_names = sorted(tables) for name in table_names: @@ -125,7 +125,7 @@ def generate_quadrature_loop(self): # Generate varying partition body = self.generate_varying_partition() body = L.commented_code_list( - body, f"Points loop body setup quadrature loop {self.quadrature_rule.id()}" + body, f"Points loop body setup quadrature loop {self.quadrature_rule[1].id()}" ) # Generate dofblock parts, some of this @@ -139,7 +139,7 @@ def generate_quadrature_loop(self): quadparts = [] else: iq = self.backend.symbols.quadrature_loop_index - num_points = self.quadrature_rule.points.shape[0] + num_points = self.quadrature_rule[1].points.shape[0] quadparts = [L.ForRange(iq, 0, num_points, body=body)] return preparts, quadparts @@ -148,11 +148,11 @@ def generate_varying_partition(self): # Get annotated graph of factorisation F = self.ir.expression.integrand[self.quadrature_rule]["factorization"] - arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR) + arraysymbol = L.Symbol(f"sv_{self.quadrature_rule[1].id()}", dtype=L.DataType.SCALAR) parts = self.generate_partition(arraysymbol, F, "varying") parts = L.commented_code_list( parts, - f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}", + f"Unstructured varying computations for quadrature rule {self.quadrature_rule[1].id()}", ) return parts @@ -216,7 +216,7 @@ def generate_block_parts(self, blockmap, blockdata): assert not blockdata.transposed, "Not handled yet" components = ufl.product(self.ir.expression.shape) - num_points = self.quadrature_rule.points.shape[0] + num_points = self.quadrature_rule[1].points.shape[0] A_shape = [num_points, components] + self.ir.expression.tensor_shape A = self.backend.symbols.element_tensor iq = self.backend.symbols.quadrature_loop_index diff --git a/ffcx/codegeneration/integral_generator.py b/ffcx/codegeneration/integral_generator.py index 92fe3ee59..7c5768957 100644 --- a/ffcx/codegeneration/integral_generator.py +++ b/ffcx/codegeneration/integral_generator.py @@ -12,6 +12,7 @@ from numbers import Integral from typing import Any +import basix import ufl import ffcx.codegeneration.lnodes as L @@ -75,9 +76,9 @@ def init_scopes(self): self.scopes = { quadrature_rule: {} for quadrature_rule in self.ir.expression.integrand.keys() } - self.scopes[None] = {} + self.scopes[(None, None)] = {} - def set_var(self, quadrature_rule, v, vaccess): + def set_var(self, quadrature_rule, domain, v, vaccess): """Set a new variable in variable scope dicts. Scope is determined by quadrature_rule which identifies the @@ -85,12 +86,13 @@ def set_var(self, quadrature_rule, v, vaccess): Args: quadrature_rule: Quadrature rule + domain: The domain of the integral v: the ufl expression vaccess: the LNodes expression to access the value in the code """ - self.scopes[quadrature_rule][v] = vaccess + self.scopes[(domain, quadrature_rule)][v] = vaccess - def get_var(self, quadrature_rule, v): + def get_var(self, quadrature_rule, domain, v): """Lookup ufl expression v in variable scope dicts. Scope is determined by quadrature rule which identifies the @@ -105,11 +107,11 @@ def get_var(self, quadrature_rule, v): return L.ufl_to_lnodes(v) # quadrature loop scope - f = self.scopes[quadrature_rule].get(v) + f = self.scopes[(domain, quadrature_rule)].get(v) # piecewise scope if f is None: - f = self.scopes[None].get(v) + f = self.scopes[(None, None)].get(v) return f def new_temp_symbol(self, basename): @@ -128,7 +130,7 @@ def get_temp_symbol(self, tempname, key): self.temp_symbols[key] = s return s, defined - def generate(self): + def generate(self, domain: basix.CellType): """Generate entire tabulate_tensor body. Assumes that the code returned from here will be wrapped in a @@ -142,11 +144,11 @@ def generate(self): parts = [] # Generate the tables of quadrature points and weights - parts += self.generate_quadrature_tables() + parts += self.generate_quadrature_tables(domain) # Generate the tables of basis function values and # pre-integrated blocks - parts += self.generate_element_tables() + parts += self.generate_element_tables(domain) # Generate the tables of geometry data that are needed parts += self.generate_geometry_tables() @@ -159,13 +161,14 @@ def generate(self): # Pre-definitions are collected across all quadrature loops to # improve re-use and avoid name clashes - for rule in self.ir.expression.integrand.keys(): - # Generate code to compute piecewise constant scalar factors - all_preparts += self.generate_piecewise_partition(rule) + for cell, rule in self.ir.expression.integrand.keys(): + if domain == cell: + # Generate code to compute piecewise constant scalar factors + all_preparts += self.generate_piecewise_partition(rule, cell) - # Generate code to integrate reusable blocks of final - # element tensor - all_quadparts += self.generate_quadrature_loop(rule) + # Generate code to integrate reusable blocks of final + # element tensor + all_quadparts += self.generate_quadrature_loop(rule, cell) # Collect parts before, during, and after quadrature loops parts += all_preparts @@ -173,9 +176,9 @@ def generate(self): return L.StatementList(parts) - def generate_quadrature_tables(self): + def generate_quadrature_tables(self, domain: basix.CellType): """Generate static tables of quadrature points and weights.""" - parts = [] + parts: list[L.LNode] = [] # No quadrature tables for custom (given argument) or point # (evaluation in single vertex) @@ -184,10 +187,11 @@ def generate_quadrature_tables(self): return parts # Loop over quadrature rules - for quadrature_rule, _ in self.ir.expression.integrand.items(): - # Generate quadrature weights array - wsym = self.backend.symbols.weights_table(quadrature_rule) - parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)] + for (cell, quadrature_rule), _ in self.ir.expression.integrand.items(): + if domain == cell: + # Generate quadrature weights array + wsym = self.backend.symbols.weights_table(quadrature_rule) + parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)] # Add leading comment if there are any tables parts = L.commented_code_list(parts, "Quadrature rules") @@ -224,14 +228,14 @@ def generate_geometry_tables(self): return parts - def generate_element_tables(self): + def generate_element_tables(self, domain: basix.CellType): """Generate static tables. With precomputed element basis function values in quadrature points. """ parts = [] - tables = self.ir.expression.unique_tables - table_types = self.ir.expression.unique_table_types + tables = self.ir.expression.unique_tables[domain] + table_types = self.ir.expression.unique_table_types[domain] if self.ir.expression.integral_type in ufl.custom_integral_types: # Define only piecewise tables table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes] @@ -264,13 +268,13 @@ def declare_table(self, name, table): self.backend.symbols.element_tables[name] = table_symbol return [L.ArrayDecl(table_symbol, values=table, const=True)] - def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): + def generate_quadrature_loop(self, quadrature_rule: QuadratureRule, domain: basix.CellType): """Generate quadrature loop with for this quadrature_rule.""" # Generate varying partition - definitions, intermediates_0 = self.generate_varying_partition(quadrature_rule) + definitions, intermediates_0 = self.generate_varying_partition(quadrature_rule, domain) # Generate dofblock parts, some of this will be placed before or after quadloop - tensor_comp, intermediates_fw = self.generate_dofblock_partition(quadrature_rule) + tensor_comp, intermediates_fw = self.generate_dofblock_partition(quadrature_rule, domain) assert all([isinstance(tc, L.Section) for tc in tensor_comp]) # Check if we only have Section objects @@ -297,21 +301,21 @@ def generate_quadrature_loop(self, quadrature_rule: QuadratureRule): return [L.create_nested_for_loops([iq], code)] - def generate_piecewise_partition(self, quadrature_rule): + def generate_piecewise_partition(self, quadrature_rule, domain: basix.CellType): """Generate a piecewise partition.""" # Get annotated graph of factorisation - F = self.ir.expression.integrand[quadrature_rule]["factorization"] + F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"] arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) - return self.generate_partition(arraysymbol, F, "piecewise", None) + return self.generate_partition(arraysymbol, F, "piecewise", None, None) - def generate_varying_partition(self, quadrature_rule): + def generate_varying_partition(self, quadrature_rule, domain: basix.CellType): """Generate a varying partition.""" # Get annotated graph of factorisation - F = self.ir.expression.integrand[quadrature_rule]["factorization"] + F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"] arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR) - return self.generate_partition(arraysymbol, F, "varying", quadrature_rule) + return self.generate_partition(arraysymbol, F, "varying", quadrature_rule, domain) - def generate_partition(self, symbol, F, mode, quadrature_rule): + def generate_partition(self, symbol, F, mode, quadrature_rule, domain): """Generate a partition.""" definitions = [] intermediates = [] @@ -322,7 +326,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): v = attr["expression"] # Generate code only if the expression is not already in cache - if not self.get_var(quadrature_rule, v): + if not self.get_var(quadrature_rule, domain, v): if v._ufl_is_literal_: vaccess = L.ufl_to_lnodes(v) elif mt := attr.get("mt"): @@ -340,7 +344,7 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): definitions += [vdef] else: # Get previously visited operands - vops = [self.get_var(quadrature_rule, op) for op in v.ufl_operands] + vops = [self.get_var(quadrature_rule, domain, op) for op in v.ufl_operands] dtype = extract_dtype(v, vops) # Mapping UFL operator to target language @@ -352,15 +356,17 @@ def generate_partition(self, symbol, F, mode, quadrature_rule): intermediates.append(L.VariableDecl(vaccess, vexpr)) # Store access node for future reference - self.set_var(quadrature_rule, v, vaccess) + self.set_var(quadrature_rule, domain, v, vaccess) # Optimize definitions definitions = optimize(definitions, quadrature_rule) return definitions, intermediates - def generate_dofblock_partition(self, quadrature_rule: QuadratureRule): + def generate_dofblock_partition(self, quadrature_rule: QuadratureRule, domain: basix.CellType): """Generate a dofblock partition.""" - block_contributions = self.ir.expression.integrand[quadrature_rule]["block_contributions"] + block_contributions = self.ir.expression.integrand[(domain, quadrature_rule)][ + "block_contributions" + ] quadparts = [] blocks = [ (blockmap, blockdata) @@ -385,7 +391,7 @@ def generate_dofblock_partition(self, quadrature_rule: QuadratureRule): intermediates = [] for blockmap in block_groups: block_quadparts, intermediate = self.generate_block_parts( - quadrature_rule, blockmap, block_groups[blockmap] + quadrature_rule, domain, blockmap, block_groups[blockmap] ) intermediates += intermediate @@ -394,14 +400,14 @@ def generate_dofblock_partition(self, quadrature_rule: QuadratureRule): return quadparts, intermediates - def get_arg_factors(self, blockdata, block_rank, quadrature_rule, iq, indices): + def get_arg_factors(self, blockdata, block_rank, quadrature_rule, domain, iq, indices): """Get arg factors.""" arg_factors = [] tables = [] for i in range(block_rank): mad = blockdata.ma_data[i] td = mad.tabledata - scope = self.ir.expression.integrand[quadrature_rule]["modified_arguments"] + scope = self.ir.expression.integrand[(domain, quadrature_rule)]["modified_arguments"] mt = scope[mad.ma_index] arg_tables = [] @@ -426,7 +432,11 @@ def get_arg_factors(self, blockdata, block_rank, quadrature_rule, iq, indices): return arg_factors, tables def generate_block_parts( - self, quadrature_rule: QuadratureRule, blockmap: tuple, blocklist: list[BlockDataT] + self, + quadrature_rule: QuadratureRule, + domain: basix.CellType, + blockmap: tuple, + blocklist: list[BlockDataT], ): """Generate and return code parts for a given block. @@ -470,10 +480,10 @@ def generate_block_parts( factor_index = blockdata.factor_indices_comp_indices[0][0] # Get factor expression - F = self.ir.expression.integrand[quadrature_rule]["factorization"] + F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"] v = F.nodes[factor_index]["expression"] - f = self.get_var(quadrature_rule, v) + f = self.get_var(quadrature_rule, domain, v) # Quadrature weight was removed in representation, add it back now if self.ir.expression.integral_type in ufl.custom_integral_types: @@ -509,7 +519,7 @@ def generate_block_parts( # Fetch code to access modified arguments arg_factors, table = self.get_arg_factors( - blockdata, block_rank, quadrature_rule, iq, B_indices + blockdata, block_rank, quadrature_rule, domain, iq, B_indices ) tables += table diff --git a/ffcx/codegeneration/ufcx.h b/ffcx/codegeneration/ufcx.h index 73352ef61..59ae20434 100644 --- a/ffcx/codegeneration/ufcx.h +++ b/ffcx/codegeneration/ufcx.h @@ -139,6 +139,8 @@ extern "C" /// Hash of the coordinate element associated with the geometry of the mesh. uint64_t coordinate_element_hash; + + uint8_t domain; } ufcx_integral; typedef struct ufcx_expression diff --git a/ffcx/ir/integral.py b/ffcx/ir/integral.py index f399526c8..bc2eb38a6 100644 --- a/ffcx/ir/integral.py +++ b/ffcx/ir/integral.py @@ -50,7 +50,7 @@ def compute_integral_ir(cell, integral_type, entity_type, integrands, argument_s """Compute intermediate representation for an integral.""" # The intermediate representation dict we're building and returning # here - ir = {} + ir = {"needs_facet_permutations": False} # Shared unique tables for all quadrature loops ir["unique_tables"] = {} @@ -58,239 +58,245 @@ def compute_integral_ir(cell, integral_type, entity_type, integrands, argument_s ir["integrand"] = {} - for quadrature_rule, integrand in integrands.items(): - expression = integrand - - # Rebalance order of nested terminal modifiers - expression = balance_modifiers(expression) - - # Remove QuadratureWeight terminals from expression and replace with 1.0 - expression = replace_quadratureweight(expression) - - # Build initial scalar list-based graph representation - S = build_scalar_graph(expression) - - # Build terminal_data from V here before factorization. Then we - # can use it to derive table properties for all modified - # terminals, and then use that to rebuild the scalar graph more - # efficiently before argument factorization. We can build - # terminal_data again after factorization if that's necessary. - - initial_terminals = { - i: analyse_modified_terminal(v["expression"]) - for i, v in S.nodes.items() - if is_modified_terminal(v["expression"]) - } - - # Check if we have a mixed-dimensional integral - is_mixed_dim = False - for domain in ufl.domain.extract_domains(integrand): - if domain.topological_dimension() != cell.topological_dimension(): - is_mixed_dim = True - - mt_table_reference = build_optimized_tables( - quadrature_rule, - cell, - integral_type, - entity_type, - initial_terminals.values(), - ir["unique_tables"], - use_sum_factorization=p["sum_factorization"], - is_mixed_dim=is_mixed_dim, - rtol=p["table_rtol"], - atol=p["table_atol"], - ) - - # Fetch unique tables for this quadrature rule - table_types = {v.name: v.ttype for v in mt_table_reference.values()} - tables = {v.name: v.values for v in mt_table_reference.values()} - - S_targets = [i for i, v in S.nodes.items() if v.get("target", False)] - num_components = np.int32(np.prod(expression.ufl_shape)) - - if "zeros" in table_types.values(): - # If there are any 'zero' tables, replace symbolically and rebuild graph - for i, mt in initial_terminals.items(): - # Set modified terminals with zero tables to zero - tr = mt_table_reference.get(mt) - if tr is not None and tr.ttype == "zeros": - S.nodes[i]["expression"] = ufl.as_ufl(0.0) - - # Propagate expression changes using dependency list - for i, v in S.nodes.items(): - deps = [S.nodes[j]["expression"] for j in S.out_edges[i]] - if deps: - v["expression"] = v["expression"]._ufl_expr_reconstruct_(*deps) - - # Recreate expression with correct ufl_shape - expressions = [ - None, - ] * num_components - for target in S_targets: - for comp in S.nodes[target]["component"]: - assert expressions[comp] is None - expressions[comp] = S.nodes[target]["expression"] - expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) - - # Rebuild scalar list-based graph representation + for integral_domain, integrands_on_domain in integrands.items(): + ir["unique_tables"][integral_domain] = {} + ir["unique_table_types"][integral_domain] = {} + for quadrature_rule, integrand in integrands_on_domain.items(): + expression = integrand + + # Rebalance order of nested terminal modifiers + expression = balance_modifiers(expression) + + # Remove QuadratureWeight terminals from expression and replace with 1.0 + expression = replace_quadratureweight(expression) + + # Build initial scalar list-based graph representation S = build_scalar_graph(expression) - # Output diagnostic graph as pdf - if visualise: - visualise_graph(S, "S.pdf") - - # Compute factorization of arguments - rank = len(argument_shape) - F = compute_argument_factorization(S, rank) - - # Get the 'target' nodes that are factors of arguments, and insert in dict - FV_targets = [i for i, v in F.nodes.items() if v.get("target", False)] - argument_factorization = {} - - for fi in FV_targets: - # Number of blocks using this factor must agree with number of components - # to which this factor contributes. I.e. there are more blocks iff there are more - # components - assert len(F.nodes[fi]["target"]) == len(F.nodes[fi]["component"]) - - k = 0 - for w in F.nodes[fi]["target"]: - comp = F.nodes[fi]["component"][k] - argument_factorization[w] = argument_factorization.get(w, []) - - # Store tuple of (factor index, component index) - argument_factorization[w].append((fi, comp)) - k += 1 - - # Get list of indices in F which are the arguments (should be at start) - argkeys = set() - for w in argument_factorization: - argkeys = argkeys | set(w) - argkeys = list(argkeys) - - # Build set of modified_terminals for each mt factorized vertex in F - # and attach tables, if appropriate - for i, v in F.nodes.items(): - expr = v["expression"] - if is_modified_terminal(expr): - mt = analyse_modified_terminal(expr) - F.nodes[i]["mt"] = mt - tr = mt_table_reference.get(mt) - if tr is not None: - F.nodes[i]["tr"] = tr - - # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying' - analyse_dependencies(F, mt_table_reference) - - # Output diagnostic graph as pdf - if visualise: - visualise_graph(F, "F.pdf") - - # Loop over factorization terms - block_contributions = collections.defaultdict(list) - for ma_indices, fi_ci in sorted(argument_factorization.items()): - # Get a bunch of information about this term - assert rank == len(ma_indices) - trs = tuple(F.nodes[ai]["tr"] for ai in ma_indices) - - unames = tuple(tr.name for tr in trs) - ttypes = tuple(tr.ttype for tr in trs) - assert not any(tt == "zeros" for tt in ttypes) - - blockmap = [] - for tr in trs: - begin = tr.offset - num_dofs = tr.values.shape[3] - dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs)) - blockmap.append(dofmap) - - blockmap = tuple(blockmap) - block_is_uniform = all(tr.is_uniform for tr in trs) - - # Collect relevant restrictions to identify blocks correctly - # in interior facet integrals - block_restrictions = [] - for i, ai in enumerate(ma_indices): - if trs[i].is_uniform: - r = None - else: - r = F.nodes[ai]["mt"].restriction - - block_restrictions.append(r) - block_restrictions = tuple(block_restrictions) - - # Check if each *each* factor corresponding to this argument is piecewise - all_factors_piecewise = all(F.nodes[ifi[0]]["status"] == "piecewise" for ifi in fi_ci) - block_is_permuted = False - for name in unames: - if tables[name].shape[0] > 1: - block_is_permuted = True - ma_data = [] - for i, ma in enumerate(ma_indices): - ma_data.append(ModifiedArgumentDataT(ma, trs[i])) - - block_is_transposed = False # FIXME: Handle transposes for these block types - block_unames = unames - blockdata = BlockDataT( - ttypes, - fi_ci, - all_factors_piecewise, - block_unames, - block_restrictions, - block_is_transposed, - block_is_uniform, - tuple(ma_data), - block_is_permuted, + # Build terminal_data from V here before factorization. Then we + # can use it to derive table properties for all modified + # terminals, and then use that to rebuild the scalar graph more + # efficiently before argument factorization. We can build + # terminal_data again after factorization if that's necessary. + + initial_terminals = { + i: analyse_modified_terminal(v["expression"]) + for i, v in S.nodes.items() + if is_modified_terminal(v["expression"]) + } + + # Check if we have a mixed-dimensional integral + is_mixed_dim = False + for domain in ufl.domain.extract_domains(integrand): + if domain.topological_dimension() != cell.topological_dimension(): + is_mixed_dim = True + + mt_table_reference = build_optimized_tables( + quadrature_rule, + cell, + integral_type, + entity_type, + initial_terminals.values(), + ir["unique_tables"][integral_domain], + use_sum_factorization=p["sum_factorization"], + is_mixed_dim=is_mixed_dim, + rtol=p["table_rtol"], + atol=p["table_atol"], ) - # Insert in expr_ir for this quadrature loop - block_contributions[blockmap].append(blockdata) - - # Figure out which table names are referenced - active_table_names = set() - for i, v in F.nodes.items(): - tr = v.get("tr") - if tr is not None and F.nodes[i]["status"] != "inactive": - if tr.has_tensor_factorisation: - for t in tr.tensor_factors: - active_table_names.add(t.name) - else: - active_table_names.add(tr.name) - - # Figure out which table names are referenced in blocks - for blockmap, contributions in itertools.chain(block_contributions.items()): - for blockdata in contributions: - for mad in blockdata.ma_data: - if mad.tabledata.has_tensor_factorisation: - for t in mad.tabledata.tensor_factors: + # Fetch unique tables for this quadrature rule + table_types = {v.name: v.ttype for v in mt_table_reference.values()} + tables = {v.name: v.values for v in mt_table_reference.values()} + + S_targets = [i for i, v in S.nodes.items() if v.get("target", False)] + num_components = np.int32(np.prod(expression.ufl_shape)) + + if "zeros" in table_types.values(): + # If there are any 'zero' tables, replace symbolically and rebuild graph + for i, mt in initial_terminals.items(): + # Set modified terminals with zero tables to zero + tr = mt_table_reference.get(mt) + if tr is not None and tr.ttype == "zeros": + S.nodes[i]["expression"] = ufl.as_ufl(0.0) + + # Propagate expression changes using dependency list + for i, v in S.nodes.items(): + deps = [S.nodes[j]["expression"] for j in S.out_edges[i]] + if deps: + v["expression"] = v["expression"]._ufl_expr_reconstruct_(*deps) + + # Recreate expression with correct ufl_shape + expressions = [ + None, + ] * num_components + for target in S_targets: + for comp in S.nodes[target]["component"]: + assert expressions[comp] is None + expressions[comp] = S.nodes[target]["expression"] + expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) + + # Rebuild scalar list-based graph representation + S = build_scalar_graph(expression) + + # Output diagnostic graph as pdf + if visualise: + visualise_graph(S, "S.pdf") + + # Compute factorization of arguments + rank = len(argument_shape) + F = compute_argument_factorization(S, rank) + + # Get the 'target' nodes that are factors of arguments, and insert in dict + FV_targets = [i for i, v in F.nodes.items() if v.get("target", False)] + argument_factorization = {} + + for fi in FV_targets: + # Number of blocks using this factor must agree with number of components + # to which this factor contributes. I.e. there are more blocks iff there are more + # components + assert len(F.nodes[fi]["target"]) == len(F.nodes[fi]["component"]) + + k = 0 + for w in F.nodes[fi]["target"]: + comp = F.nodes[fi]["component"][k] + argument_factorization[w] = argument_factorization.get(w, []) + + # Store tuple of (factor index, component index) + argument_factorization[w].append((fi, comp)) + k += 1 + + # Get list of indices in F which are the arguments (should be at start) + argkeys = set() + for w in argument_factorization: + argkeys = argkeys | set(w) + argkeys = list(argkeys) + + # Build set of modified_terminals for each mt factorized vertex in F + # and attach tables, if appropriate + for i, v in F.nodes.items(): + expr = v["expression"] + if is_modified_terminal(expr): + mt = analyse_modified_terminal(expr) + F.nodes[i]["mt"] = mt + tr = mt_table_reference.get(mt) + if tr is not None: + F.nodes[i]["tr"] = tr + + # Attach 'status' to each node: 'inactive', 'piecewise' or 'varying' + analyse_dependencies(F, mt_table_reference) + + # Output diagnostic graph as pdf + if visualise: + visualise_graph(F, "F.pdf") + + # Loop over factorization terms + block_contributions = collections.defaultdict(list) + for ma_indices, fi_ci in sorted(argument_factorization.items()): + # Get a bunch of information about this term + assert rank == len(ma_indices) + trs = tuple(F.nodes[ai]["tr"] for ai in ma_indices) + + unames = tuple(tr.name for tr in trs) + ttypes = tuple(tr.ttype for tr in trs) + assert not any(tt == "zeros" for tt in ttypes) + + blockmap = [] + for tr in trs: + begin = tr.offset + num_dofs = tr.values.shape[3] + dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs)) + blockmap.append(dofmap) + + blockmap = tuple(blockmap) + block_is_uniform = all(tr.is_uniform for tr in trs) + + # Collect relevant restrictions to identify blocks correctly + # in interior facet integrals + block_restrictions = [] + for i, ai in enumerate(ma_indices): + if trs[i].is_uniform: + r = None + else: + r = F.nodes[ai]["mt"].restriction + + block_restrictions.append(r) + block_restrictions = tuple(block_restrictions) + + # Check if each *each* factor corresponding to this argument is piecewise + all_factors_piecewise = all( + F.nodes[ifi[0]]["status"] == "piecewise" for ifi in fi_ci + ) + block_is_permuted = False + for name in unames: + if tables[name].shape[0] > 1: + block_is_permuted = True + ma_data = [] + for i, ma in enumerate(ma_indices): + ma_data.append(ModifiedArgumentDataT(ma, trs[i])) + + block_is_transposed = False # FIXME: Handle transposes for these block types + block_unames = unames + blockdata = BlockDataT( + ttypes, + fi_ci, + all_factors_piecewise, + block_unames, + block_restrictions, + block_is_transposed, + block_is_uniform, + tuple(ma_data), + block_is_permuted, + ) + + # Insert in expr_ir for this quadrature loop + block_contributions[blockmap].append(blockdata) + + # Figure out which table names are referenced + active_table_names = set() + for i, v in F.nodes.items(): + tr = v.get("tr") + if tr is not None and F.nodes[i]["status"] != "inactive": + if tr.has_tensor_factorisation: + for t in tr.tensor_factors: active_table_names.add(t.name) else: - active_table_names.add(mad.tabledata.name) - - active_tables = {} - active_table_types = {} - - for name in active_table_names: - # Drop tables not referenced from modified terminals - if table_types[name] not in ("zeros", "ones"): - active_tables[name] = tables[name] - active_table_types[name] = table_types[name] - - # Add tables and types for this quadrature rule to global tables dict - ir["unique_tables"].update(active_tables) - ir["unique_table_types"].update(active_table_types) - # Build IR dict for the given expressions - # Store final ir for this num_points - ir["integrand"][quadrature_rule] = { - "factorization": F, - "modified_arguments": [F.nodes[i]["mt"] for i in argkeys], - "block_contributions": block_contributions, - } - - restrictions = [i.restriction for i in initial_terminals.values()] - ir["needs_facet_permutations"] = ( - "+" in restrictions and "-" in restrictions - ) or is_mixed_dim + active_table_names.add(tr.name) + + # Figure out which table names are referenced in blocks + for blockmap, contributions in itertools.chain(block_contributions.items()): + for blockdata in contributions: + for mad in blockdata.ma_data: + if mad.tabledata.has_tensor_factorisation: + for t in mad.tabledata.tensor_factors: + active_table_names.add(t.name) + else: + active_table_names.add(mad.tabledata.name) + + active_tables = {} + active_table_types = {} + + for name in active_table_names: + # Drop tables not referenced from modified terminals + if table_types[name] not in ("zeros", "ones"): + active_tables[name] = tables[name] + active_table_types[name] = table_types[name] + + # Add tables and types for this quadrature rule to global tables dict + ir["unique_tables"][integral_domain].update(active_tables) + ir["unique_table_types"][integral_domain].update(active_table_types) + # Build IR dict for the given expressions + # Store final ir for this num_points + ir["integrand"][(integral_domain, quadrature_rule)] = { + "factorization": F, + "modified_arguments": [F.nodes[i]["mt"] for i in argkeys], + "block_contributions": block_contributions, + } + + restrictions = [i.restriction for i in initial_terminals.values()] + if not ir["needs_facet_permutations"]: + ir["needs_facet_permutations"] = ( + "+" in restrictions and "-" in restrictions + ) or is_mixed_dim return ir diff --git a/ffcx/ir/representation.py b/ffcx/ir/representation.py index 7855c13a1..cd35c2f6b 100644 --- a/ffcx/ir/representation.py +++ b/ffcx/ir/representation.py @@ -38,6 +38,13 @@ logger = logging.getLogger("ffcx") +def basix_cell_from_string(string: str) -> basix.CellType: + """Convert a string to a Basix CellType.""" + if string == "vertex": + return basix.CellType.point + return getattr(basix.CellType, string) + + class FormIR(typing.NamedTuple): """Intermediate representation of a form.""" @@ -53,6 +60,7 @@ class FormIR(typing.NamedTuple): constant_names: list[str] finite_element_hashes: list[int] integral_names: dict[str, list[str]] + integral_domains: dict[str, list[basix.CellType]] subdomain_ids: dict[str, list[int]] @@ -73,9 +81,9 @@ class CommonExpressionIR(typing.NamedTuple): coefficient_numbering: dict[ufl.Coefficient, int] coefficient_offsets: dict[ufl.Coefficient, int] original_constant_offsets: dict[ufl.Constant, int] - unique_tables: dict[str, npt.NDArray[np.float64]] - unique_table_types: dict[str, str] - integrand: dict[QuadratureRule, dict] + unique_tables: dict[str, dict[basix.CellType, npt.NDArray[np.float64]]] + unique_table_types: dict[basix.CellType, dict[str, str]] + integrand: dict[tuple[basix.CellType, QuadratureRule], dict] name: str needs_facet_permutations: bool shape: list[int] @@ -146,6 +154,10 @@ def compute_ir( ] ir_integrals = list(itertools.chain(*irs)) + integral_domains = { + i.expression.name: set(j[0] for j in i.expression.integrand.keys()) for a in irs for i in a + } + ir_forms = [ _compute_form_ir( fd, @@ -153,6 +165,7 @@ def compute_ir( prefix, form_names, integral_names, + integral_domains, object_names, ) for (i, fd) in enumerate(analysis.form_data) @@ -203,9 +216,9 @@ def _compute_integral_ir( # Compute representation entity_type = _entity_types[itg_data.integral_type] - cell = itg_data.domain.ufl_cell() - cellname = cell.cellname() - tdim = cell.topological_dimension() + ufl_cell = itg_data.domain.ufl_cell() + cell_type = basix_cell_from_string(ufl_cell.cellname()) + tdim = ufl_cell.topological_dimension() assert all(tdim == itg.ufl_domain().topological_dimension() for itg in itg_data.integrals) expression_ir = { @@ -238,18 +251,21 @@ def _compute_integral_ir( expression_ir["tensor_shape"] = argument_dimensions integral_type = itg_data.integral_type - cell = itg_data.domain.ufl_cell() # Group integrands with the same quadrature rule - grouped_integrands: dict[QuadratureRule, list[ufl.core.expr.Expr]] = {} + grouped_integrands: dict[ + basix.CellType, dict[QuadratureRule, list[ufl.core.expr.Expr]] + ] = {} use_sum_factorization = options["sum_factorization"] and itg_data.integral_type == "cell" for integral in itg_data.integrals: md = integral.metadata() or {} scheme = md["quadrature_rule"] tensor_factors = None + rules = {} if scheme == "custom": points = md["quadrature_points"] weights = md["quadrature_weights"] + rules[cell_type] = (points, weights, None) elif scheme == "vertex": # The vertex scheme, i.e., averaging the function value in the # vertices and multiplying with the simplex volume, is only of @@ -261,50 +277,65 @@ def _compute_integral_ir( degree = md["quadrature_degree"] if integral_type != "cell": - facet_types = cell.facet_types() - assert len(facet_types) == 1 - cellname = facet_types[0].cellname() + facet_types = basix.cell.subentity_types(cell_type)[-2] + assert len(set(facet_types)) == 1 + cell_type = facet_types[0] if degree > 1: warnings.warn( "Explicitly selected vertex quadrature (degree 1), " f"but requested degree is {degree}." ) - points = basix.cell.geometry(getattr(basix.CellType, cellname)) - cell_volume = basix.cell.volume(getattr(basix.CellType, cellname)) + points = basix.cell.geometry(cell_type) + cell_volume = basix.cell.volume(cell_type) weights = np.full( points.shape[0], cell_volume / points.shape[0], dtype=points.dtype ) + rules[cell_type] = (points, weights, None) else: degree = md["quadrature_degree"] points, weights, tensor_factors = create_quadrature_points_and_weights( integral_type, - cell, + ufl_cell, degree, scheme, form_data.argument_elements, use_sum_factorization, ) - - points = np.asarray(points) - weights = np.asarray(weights) - rule = QuadratureRule(points, weights, tensor_factors) - - if rule not in grouped_integrands: - grouped_integrands[rule] = [] - grouped_integrands[rule].append(integral.integrand()) - sorted_integrals: dict[QuadratureRule, Integral] = {} - for rule, integrands in grouped_integrands.items(): - integrands_summed = sorted_expr_sum(integrands) - - integral_new = Integral( - integrands_summed, - itg_data.integral_type, - itg_data.domain, - itg_data.subdomain_id, - {}, - None, - ) - sorted_integrals[rule] = integral_new + rules = { + basix_cell_from_string(i): ( + points[i], + weights[i], + tensor_factors[i] if i in tensor_factors else None, + ) + for i in points + } + + for cell_type, (points, weights, tensor_factors) in rules.items(): + points = np.asarray(points) + weights = np.asarray(weights) + rule = QuadratureRule(points, weights, tensor_factors) + + if cell_type not in grouped_integrands: + grouped_integrands[cell_type] = {} + if rule not in grouped_integrands: + grouped_integrands[cell_type][rule] = [] + grouped_integrands[cell_type][rule].append(integral.integrand()) + sorted_integrals: dict[basix.CellType, dict[QuadratureRule, Integral]] = { + cell_type: {} for cell_type in grouped_integrands + } + for cell_type, integrands_by_cell in grouped_integrands.items(): + for rule, integrands in integrands_by_cell.items(): + integrands_summed = sorted_expr_sum(integrands) + + integral_new = Integral( + integrands_summed, + itg_data.integral_type, + itg_data.domain, + itg_data.subdomain_id, + {}, + None, + ) + sorted_integrals[cell_type][rule] = integral_new # TODO: See if coefficient_numbering can be removed # Build coefficient numbering for UFC interface here, to avoid @@ -337,8 +368,9 @@ def _compute_integral_ir( expression_ir["original_constant_offsets"] = original_constant_offsets # Create map from number of quadrature points -> integrand - integrand_map: dict[QuadratureRule, ufl.core.expr.Expr] = { - rule: integral.integrand() for rule, integral in sorted_integrals.items() + integrand_map: dict[basix.CellType, dict[QuadratureRule, ufl.core.expr.Expr]] = { + cell_type: {rule: integral.integrand() for rule, integral in cell_integrals.items()} + for cell_type, cell_integrals in sorted_integrals.items() } # Build more specific intermediate representation @@ -368,6 +400,7 @@ def _compute_form_ir( prefix, form_names, integral_names, + integral_domains, object_names, ) -> FormIR: """Compute intermediate representation of form.""" @@ -407,11 +440,10 @@ def _compute_form_ir( # Store names of integrals and subdomain_ids for this form, grouped # by integral types since form points to all integrals it contains, # it has to know their names for codegen phase - ir["integral_names"] = {} - ir["subdomain_ids"] = {} ufcx_integral_types = ("cell", "exterior_facet", "interior_facet") ir["subdomain_ids"] = {itg_type: [] for itg_type in ufcx_integral_types} ir["integral_names"] = {itg_type: [] for itg_type in ufcx_integral_types} + ir["integral_domains"] = {itg_type: [] for itg_type in ufcx_integral_types} for itg_index, itg_data in enumerate(form_data.integral_data): # UFL is using "otherwise" for default integrals (over whole mesh) # but FFCx needs integers, so otherwise = -1 @@ -422,7 +454,9 @@ def _compute_form_ir( raise ValueError("Integral subdomain IDs must be non-negative.") ir["subdomain_ids"][integral_type] += subdomain_ids for _ in range(len(subdomain_ids)): - ir["integral_names"][integral_type] += [integral_names[(form_id, itg_index)]] + iname = integral_names[(form_id, itg_index)] + ir["integral_names"][integral_type] += [iname] + ir["integral_domains"][integral_type] += [integral_domains[iname]] return FormIR(**ir) @@ -544,7 +578,7 @@ def _compute_expression_ir( weights = np.array([1.0] * points.shape[0]) rule = QuadratureRule(points, weights) - integrands = {rule: expr} + integrands = {"": {rule: expr}} if cell is None: assert ( diff --git a/ffcx/ir/representationutils.py b/ffcx/ir/representationutils.py index 5fdf80ba8..b0926ad7e 100644 --- a/ffcx/ir/representationutils.py +++ b/ffcx/ir/representationutils.py @@ -53,34 +53,41 @@ def create_quadrature_points_and_weights( integral_type, cell, degree, rule, elements, use_tensor_product=False ): """Create quadrature rule and return points and weights.""" - pts = None - wts = None - tensor_factors = None - + pts = {} + wts = {} + tensor_factors = {} if integral_type == "cell": - if cell.cellname() in ["quadrilateral", "hexahedron"] and use_tensor_product: - if cell.cellname() == "quadrilateral": - tensor_factors = [ + cell_name = cell.cellname() + if cell_name in ["quadrilateral", "hexahedron"] and use_tensor_product: + if cell_name == "quadrilateral": + tensor_factors[cell_name] = [ create_quadrature("interval", degree, rule, elements) for _ in range(2) ] - elif cell.cellname() == "hexahedron": - tensor_factors = [ + elif cell_name == "hexahedron": + tensor_factors[cell_name] = [ create_quadrature("interval", degree, rule, elements) for _ in range(3) ] - pts = np.array( - [tuple(i[0] for i in p) for p in itertools.product(*[f[0] for f in tensor_factors])] + pts[cell_name] = np.array( + [ + tuple(i[0] for i in p) + for p in itertools.product(*[f[0] for f in tensor_factors[cell_name]]) + ] + ) + wts[cell_name] = np.array( + [np.prod(p) for p in itertools.product(*[f[1] for f in tensor_factors[cell_name]])] ) - wts = np.array([np.prod(p) for p in itertools.product(*[f[1] for f in tensor_factors])]) else: - pts, wts = create_quadrature(cell.cellname(), degree, rule, elements) + pts[cell_name], wts[cell_name] = create_quadrature(cell_name, degree, rule, elements) elif integral_type in ufl.measure.facet_integral_types: - facet_types = cell.facet_types() - # Raise exception for cells with more than one facet type e.g. prisms - if len(facet_types) > 1: - raise Exception(f"Cell type {cell} not supported for integral type {integral_type}.") - pts, wts = create_quadrature(facet_types[0].cellname(), degree, rule, elements) + for ft in cell.facet_types(): + pts[ft.cellname()], wts[ft.cellname()] = create_quadrature( + ft.cellname(), + degree, + rule, + elements, + ) elif integral_type in ufl.measure.point_integral_types: - pts, wts = create_quadrature("vertex", degree, rule, elements) + pts["vertex"], wts["vertex"] = create_quadrature("vertex", degree, rule, elements) elif integral_type == "expression": pass else: @@ -115,8 +122,13 @@ def map_integral_points(points, integral_type, cell, entity): assert entity == 0 return np.asarray(points) elif entity_dim == tdim - 1: - assert points.shape[1] == tdim - 1 - return np.asarray(map_facet_points(points, entity, cell.cellname())) + if isinstance(points, dict): + for p in points.values(): + assert p.shape[1] == tdim - 1 + return np.asarray(map_facet_points(points, entity, cell.cellname())) + else: + assert points.shape[1] == tdim - 1 + return np.asarray(map_facet_points(points, entity, cell.cellname())) elif entity_dim == 0: return np.asarray([reference_cell_vertices(cell.cellname())[entity]]) else: diff --git a/test/test_jit_forms.py b/test/test_jit_forms.py index 34fd8859e..3df662a41 100644 --- a/test/test_jit_forms.py +++ b/test/test_jit_forms.py @@ -1248,3 +1248,140 @@ def tabulate_tensor(ele_type, V_cell_type, W_cell_type, coeffs): A_ref[:, [4, 5]] = A_ref[:, [5, 4]] assert np.allclose(A, A_ref) + + +@pytest.mark.parametrize("dtype", ["float64"]) +def test_ds_prism(compile_args, dtype): + element = basix.ufl.element("Lagrange", "prism", 1) + domain = ufl.Mesh(basix.ufl.element("Lagrange", "prism", 1, shape=(3,))) + space = ufl.FunctionSpace(domain, element) + u, v = ufl.TrialFunction(space), ufl.TestFunction(space) + + a = ufl.inner(u, v) * ufl.ds + forms = [a] + compiled_forms, module, code = ffcx.codegeneration.jit.compile_forms( + forms, options={"scalar_type": dtype}, cffi_extra_compile_args=compile_args + ) + + for f, compiled_f in zip(forms, compiled_forms): + assert compiled_f.rank == len(f.arguments()) + + ffi = module.ffi + form0 = compiled_forms[0] + + offsets = form0.form_integral_offsets + cell = module.lib.cell + exterior_facet = module.lib.exterior_facet + interior_facet = module.lib.interior_facet + assert offsets[cell + 1] - offsets[cell] == 0 + assert offsets[exterior_facet + 1] - offsets[exterior_facet] == 2 + assert offsets[interior_facet + 1] - offsets[interior_facet] == 0 + + integral_id0 = form0.form_integral_ids[offsets[exterior_facet]] + integral_id1 = form0.form_integral_ids[offsets[exterior_facet] + 1] + assert integral_id0 == integral_id1 == -1 + + integral0 = form0.form_integrals[offsets[exterior_facet]] + integral1 = form0.form_integrals[offsets[exterior_facet] + 1] + + if basix.CellType(integral0.domain) == basix.CellType.triangle: + assert basix.CellType(integral1.domain) == basix.CellType.quadrilateral + integral_tri = integral0 + integral_quad = integral1 + else: + assert basix.CellType(integral0.domain) == basix.CellType.quadrilateral + assert basix.CellType(integral1.domain) == basix.CellType.triangle + integral_tri = integral1 + integral_quad = integral0 + + w = np.array([], dtype=dtype) + c = np.array([], dtype=dtype) + entity_perm = np.array([0], dtype=np.uint8) + + # Test integral over triangle (facet 0) + A = np.zeros((6, 6), dtype=dtype) + entity_index = np.array([0], dtype=int) + + xdtype = dtype_to_scalar_dtype(dtype) + coords = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ], + dtype=xdtype, + ) + + c_type, c_xtype = dtype_to_c_type(dtype), dtype_to_c_type(xdtype) + + kernel = getattr(integral_tri, f"tabulate_tensor_{dtype}") + + kernel( + ffi.cast(f"{c_type} *", A.ctypes.data), + ffi.cast(f"{c_type} *", w.ctypes.data), + ffi.cast(f"{c_type} *", c.ctypes.data), + ffi.cast(f"{c_xtype} *", coords.ctypes.data), + ffi.cast("int *", entity_index.ctypes.data), + ffi.cast("uint8_t *", entity_perm.ctypes.data), + ) + + assert np.allclose( + A, + np.array( + [ + [1 / 12, 1 / 24, 1 / 24, 0, 0, 0], + [1 / 24, 1 / 12, 1 / 24, 0, 0, 0], + [1 / 24, 1 / 24, 1 / 12, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + ) + + # Test integral over quadrilateral (facet 1) + A = np.zeros((6, 6), dtype=dtype) + entity_index = np.array([1], dtype=np.int64) + + xdtype = dtype_to_scalar_dtype(dtype) + coords = np.array( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 1.0], + ], + dtype=xdtype, + ) + + c_type, c_xtype = dtype_to_c_type(dtype), dtype_to_c_type(xdtype) + + kernel = getattr(integral_quad, f"tabulate_tensor_{dtype}") + + kernel( + ffi.cast(f"{c_type} *", A.ctypes.data), + ffi.cast(f"{c_type} *", w.ctypes.data), + ffi.cast(f"{c_type} *", c.ctypes.data), + ffi.cast(f"{c_xtype} *", coords.ctypes.data), + ffi.cast("int *", entity_index.ctypes.data), + ffi.cast("uint8_t *", entity_perm.ctypes.data), + ) + + assert np.allclose( + A, + np.array( + [ + [1 / 9, 1 / 18, 0, 1 / 18, 1 / 36, 0], + [1 / 18, 1 / 9, 0, 1 / 36, 1 / 18, 0], + [0, 0, 0, 0, 0, 0], + [1 / 18, 1 / 36, 0, 1 / 9, 1 / 18, 0], + [1 / 36, 1 / 18, 0, 1 / 18, 1 / 9, 0], + [0, 0, 0, 0, 0, 0], + ] + ), + )