From 204a1616dfeea30652ba3b37a8726de2e0680de1 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 14:03:39 +0000 Subject: [PATCH 1/9] Start adding some typing and documentation --- ffcx/codegeneration/access.py | 9 +++-- ffcx/codegeneration/definitions.py | 5 ++- ffcx/codegeneration/symbols.py | 5 +-- ffcx/definitions.py | 5 +++ ffcx/ir/elementtables.py | 55 +++++++++++++++++++----------- ffcx/ir/integral.py | 5 ++- 6 files changed, 57 insertions(+), 27 deletions(-) create mode 100644 ffcx/definitions.py diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index c74560e3a..f9120dc34 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -13,6 +13,7 @@ import ufl import ffcx.codegeneration.lnodes as L +from ffcx.definitions import entity_types from ffcx.ir.analysis.modified_terminals import ModifiedTerminal from ffcx.ir.elementtables import UniqueTableReferenceT from ffcx.ir.representationutils import QuadratureRule @@ -23,7 +24,9 @@ class FFCXBackendAccess: """FFCx specific formatter class.""" - def __init__(self, entity_type: str, integral_type: str, symbols, options): + entity_type: entity_types + + def __init__(self, entity_type: entity_types, integral_type: str, symbols, options): """Initialise.""" # Store ir and options self.entity_type = entity_type @@ -399,7 +402,7 @@ def _pass(self, *args, **kwargs): def table_access( self, tabledata: UniqueTableReferenceT, - entity_type: str, + entity_type: entity_types, restriction: str, quadrature_index: L.MultiIndex, dof_index: L.MultiIndex, @@ -408,7 +411,7 @@ def table_access( Args: tabledata: Table data object - entity_type: Entity type ("cell", "facet", "vertex") + entity_type: Entity type restriction: Restriction ("+", "-") quadrature_index: Quadrature index dof_index: Dof index diff --git a/ffcx/codegeneration/definitions.py b/ffcx/codegeneration/definitions.py index 92ea16bbb..c08618d5b 100644 --- a/ffcx/codegeneration/definitions.py +++ b/ffcx/codegeneration/definitions.py @@ -11,6 +11,7 @@ import ufl import ffcx.codegeneration.lnodes as L +from ffcx.definitions import entity_types from ffcx.ir.analysis.modified_terminals import ModifiedTerminal from ffcx.ir.elementtables import UniqueTableReferenceT from ffcx.ir.representationutils import QuadratureRule @@ -50,7 +51,9 @@ def create_dof_index(tabledata, dof_index_symbol): class FFCXBackendDefinitions: """FFCx specific code definitions.""" - def __init__(self, entity_type: str, integral_type: str, access, options): + entity_type: entity_types + + def __init__(self, entity_type: entity_types, integral_type: str, access, options): """Initialise.""" # Store ir and options self.integral_type = integral_type diff --git a/ffcx/codegeneration/symbols.py b/ffcx/codegeneration/symbols.py index f28c911dc..a309ed662 100644 --- a/ffcx/codegeneration/symbols.py +++ b/ffcx/codegeneration/symbols.py @@ -10,6 +10,7 @@ import ufl import ffcx.codegeneration.lnodes as L +from ffcx.definitions import entity_types logger = logging.getLogger("ffcx") @@ -95,7 +96,7 @@ def __init__(self, coefficient_numbering, coefficient_offsets, original_constant # Table for chunk of custom quadrature points (physical coordinates). self.custom_points_table = L.Symbol("points_chunk", dtype=L.DataType.REAL) - def entity(self, entity_type, restriction): + def entity(self, entity_type: entity_types, restriction): """Entity index for lookup in element tables.""" if entity_type == "cell": # Always 0 for cells (even with restriction) @@ -175,7 +176,7 @@ def constant_index_access(self, constant, index): return c[offset + index] # TODO: Remove this, use table_access instead - def element_table(self, tabledata, entity_type, restriction): + def element_table(self, tabledata, entity_type: entity_types, restriction): """Get an element table.""" entity = self.entity(entity_type, restriction) diff --git a/ffcx/definitions.py b/ffcx/definitions.py new file mode 100644 index 000000000..11571e85c --- /dev/null +++ b/ffcx/definitions.py @@ -0,0 +1,5 @@ +"""Module for storing type definitions used in the FFCx code base.""" + +from typing import Literal + +entity_types = Literal["cell", "facet", "vertex"] diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 60cd13ea5..4170964d6 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -13,8 +13,11 @@ import numpy.typing as npt import ufl +from ffcx.definitions import entity_types from ffcx.element_interface import basix_index +from ffcx.ir.analysis.modified_terminals import ModifiedTerminal from ffcx.ir.representationutils import ( + QuadratureRule, create_quadrature_points_and_weights, integral_type_to_entity_dim, map_integral_points, @@ -82,7 +85,7 @@ def get_ffcx_table_values( integral_type, element, avg, - entity_type, + entity_type: entity_types, derivative_counts, flat_component, codim, @@ -175,7 +178,12 @@ def get_ffcx_table_values( def generate_psi_table_name( - quadrature_rule, element_counter, averaged: str, entity_type, derivative_counts, flat_component + quadrature_rule: QuadratureRule, + element_counter, + averaged: str, + entity_type: entity_types, + derivative_counts, + flat_component, ): """Generate a name for the psi table. @@ -293,26 +301,33 @@ def permute_quadrature_quadrilateral(points, reflections=0, rotations=0): def build_optimized_tables( - quadrature_rule, - cell, - integral_type, - entity_type, - modified_terminals, - existing_tables, - use_sum_factorization, - is_mixed_dim, - rtol=default_rtol, - atol=default_atol, -): + quadrature_rule: QuadratureRule, + cell: ufl.Cell, + integral_type: str, + entity_type: entity_types, + modified_terminals: ModifiedTerminal, + existing_tables: dict[str, np.ndarray], + use_sum_factorization: bool, + is_mixed_dim: bool, + rtol: float = default_rtol, + atol: float = default_atol, +) -> dict[ModifiedTerminal, UniqueTableReferenceT]: """Build the element tables needed for a list of modified terminals. - Input: - entity_type - str - modified_terminals - ordered sequence of unique modified terminals - FIXME: Document - - Output: - mt_tables - dict(ModifiedTerminal: table data) + Args: + quadrature_rule: The quadrature rule relating to the tables. + cell: The cell type of the domain the tables will be used with. + entity_type: On what entity (vertex,edge,facet,cell) the tables are evaluated at. + integral_type: The type of integral the tables are used for. + modified_terminals: ordered sequence of unique modified terminals + existing_tables: Register of tables that already exist and reused. + use_sum_factorization: Use sum factorization for tensor product elements. + is_mixed_dim: Mixed dimensionality of the domain. + rtol: Relative tolerance for comparing tables. + atol: Absolute tolerance for comparing tables. + + Returns: + mt_tables - Dictionary mapping each modified terminal to the a unique table reference. """ # Add to element tables analysis = {} diff --git a/ffcx/ir/integral.py b/ffcx/ir/integral.py index f399526c8..ab344c75a 100644 --- a/ffcx/ir/integral.py +++ b/ffcx/ir/integral.py @@ -16,6 +16,7 @@ from ufl.checks import is_cellwise_constant from ufl.classes import QuadratureWeight +from ffcx.definitions import entity_types from ffcx.ir.analysis.factorization import compute_argument_factorization from ffcx.ir.analysis.graph import build_scalar_graph from ffcx.ir.analysis.modified_terminals import analyse_modified_terminal, is_modified_terminal @@ -46,7 +47,9 @@ class BlockDataT(typing.NamedTuple): is_permuted: bool # Do quad points on facets need to be permuted? -def compute_integral_ir(cell, integral_type, entity_type, integrands, argument_shape, p, visualise): +def compute_integral_ir( + cell, integral_type: str, entity_type: entity_types, integrands, argument_shape, p, visualise +): """Compute intermediate representation for an integral.""" # The intermediate representation dict we're building and returning # here From 52443df081a2191f162416f46b259b5200c644bc Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 16:37:15 +0000 Subject: [PATCH 2/9] Type hints + various improvements and one bugfix --- ffcx/codegeneration/access.py | 3 +++ ffcx/codegeneration/definitions.py | 2 ++ ffcx/ir/elementtables.py | 43 +++++++++++++++++------------- ffcx/ir/integral.py | 43 +++++++++++++++++------------- ffcx/ir/representation.py | 3 ++- 5 files changed, 56 insertions(+), 38 deletions(-) diff --git a/ffcx/codegeneration/access.py b/ffcx/codegeneration/access.py index f9120dc34..7e39c1cf4 100644 --- a/ffcx/codegeneration/access.py +++ b/ffcx/codegeneration/access.py @@ -91,6 +91,8 @@ def coefficient( num_dofs = tabledata.values.shape[3] begin = tabledata.offset + assert begin is not None + assert tabledata.block_size is not None end = begin + tabledata.block_size * (num_dofs - 1) + 1 if ttype == "ones" and (end - begin) == 1: @@ -442,6 +444,7 @@ def table_access( ], symbols else: FE = [] + assert tabledata.tensor_factors is not None for i in range(dof_index.dim): factor = tabledata.tensor_factors[i] iq_i = quadrature_index.local_index(i) diff --git a/ffcx/codegeneration/definitions.py b/ffcx/codegeneration/definitions.py index c08618d5b..8e0bbf0bb 100644 --- a/ffcx/codegeneration/definitions.py +++ b/ffcx/codegeneration/definitions.py @@ -133,6 +133,8 @@ def coefficient( num_dofs = tabledata.values.shape[3] bs = tabledata.block_size begin = tabledata.offset + assert bs is not None + assert begin is not None end = begin + bs * (num_dofs - 1) + 1 if ttype == "zeros": diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 4170964d6..07b9a4af6 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -47,15 +47,15 @@ class UniqueTableReferenceT(typing.NamedTuple): name: str values: npt.NDArray[np.float64] - offset: int - block_size: int - ttype: str + offset: int | None + block_size: int | None + ttype: str | None is_piecewise: bool is_uniform: bool is_permuted: bool has_tensor_factorisation: bool - tensor_factors: list[typing.Any] - tensor_permutation: np.typing.NDArray[np.int32] + tensor_factors: list[typing.Any] | None + tensor_permutation: np.typing.NDArray[np.int32] | None def equal_tables(a, b, rtol=default_rtol, atol=default_atol): @@ -127,7 +127,7 @@ def get_ffcx_table_values( elif avg == "facet": integral_type = "exterior_facet" - if isinstance(element, basix.ufl.QuadratureElement): + if isinstance(element, basix.ufl._QuadratureElement): points = element._points weights = element._weights else: @@ -305,13 +305,13 @@ def build_optimized_tables( cell: ufl.Cell, integral_type: str, entity_type: entity_types, - modified_terminals: ModifiedTerminal, + modified_terminals: typing.Iterable[ModifiedTerminal], existing_tables: dict[str, np.ndarray], use_sum_factorization: bool, is_mixed_dim: bool, rtol: float = default_rtol, atol: float = default_atol, -) -> dict[ModifiedTerminal, UniqueTableReferenceT]: +) -> dict[ModifiedTerminal | str, UniqueTableReferenceT]: """Build the element tables needed for a list of modified terminals. Args: @@ -327,7 +327,12 @@ def build_optimized_tables( atol: Absolute tolerance for comparing tables. Returns: - mt_tables - Dictionary mapping each modified terminal to the a unique table reference. + mt_tables: + Dictionary mapping each modified terminal to the a unique table reference. + If ``use_sum_factorization`` is turned on, the map also contains the map + from the unique table reference for the tensor product factorization + to the name of the modified terminal. + """ # Add to element tables analysis = {} @@ -343,11 +348,11 @@ def build_optimized_tables( set(ufl.algorithms.analysis.extract_sub_elements(all_elements)) ) element_numbers = {element: i for i, element in enumerate(unique_elements)} - mt_tables = {} + mt_tables: dict[ModifiedTerminal | str, UniqueTableReferenceT] = {} _existing_tables = existing_tables.copy() - all_tensor_factors = [] + all_tensor_factors: list[UniqueTableReferenceT] = [] tensor_n = 0 for mt in modified_terminals: @@ -483,15 +488,15 @@ def build_optimized_tables( tbl = tbl[:1, :, :, :] # Check for existing identical table - new_table = True + is_new_table = True for table_name in _existing_tables: if equal_tables(tbl, _existing_tables[table_name]): name = table_name tbl = _existing_tables[name] - new_table = False + is_new_table = False break - if new_table: + if is_new_table: _existing_tables[name] = tbl cell_offset = 0 @@ -499,7 +504,7 @@ def build_optimized_tables( if use_sum_factorization and (not quadrature_rule.has_tensor_factors): raise RuntimeError("Sum factorization not available for this quadrature rule.") - tensor_factors = None + tensor_factors: list[UniqueTableReferenceT] | None = None tensor_perm = None if ( use_sum_factorization @@ -515,9 +520,11 @@ def build_optimized_tables( d = local_derivatives[i] sub_tbl = j.tabulate(d, pts)[d] sub_tbl = sub_tbl.reshape(1, 1, sub_tbl.shape[0], sub_tbl.shape[1]) - for i in all_tensor_factors: - if i.values.shape == sub_tbl.shape and np.allclose(i.values, sub_tbl): - tensor_factors.append(i) + for tensor_factor in all_tensor_factors: + if tensor_factor.values.shape == sub_tbl.shape and np.allclose( + tensor_factor.values, sub_tbl + ): + tensor_factors.append(tensor_factor) break else: ut = UniqueTableReferenceT( diff --git a/ffcx/ir/integral.py b/ffcx/ir/integral.py index ab344c75a..3b8ea0247 100644 --- a/ffcx/ir/integral.py +++ b/ffcx/ir/integral.py @@ -19,7 +19,11 @@ from ffcx.definitions import entity_types from ffcx.ir.analysis.factorization import compute_argument_factorization from ffcx.ir.analysis.graph import build_scalar_graph -from ffcx.ir.analysis.modified_terminals import analyse_modified_terminal, is_modified_terminal +from ffcx.ir.analysis.modified_terminals import ( + ModifiedTerminal, + analyse_modified_terminal, + is_modified_terminal, +) from ffcx.ir.analysis.visualise import visualise_graph from ffcx.ir.elementtables import UniqueTableReferenceT, build_optimized_tables @@ -53,7 +57,7 @@ def compute_integral_ir( """Compute intermediate representation for an integral.""" # The intermediate representation dict we're building and returning # here - ir = {} + ir: dict[str, typing.Any] = {} # Shared unique tables for all quadrature loops ir["unique_tables"] = {} @@ -79,7 +83,7 @@ def compute_integral_ir( # efficiently before argument factorization. We can build # terminal_data again after factorization if that's necessary. - initial_terminals = { + initial_terminals: dict[int, ModifiedTerminal] = { i: analyse_modified_terminal(v["expression"]) for i, v in S.nodes.items() if is_modified_terminal(v["expression"]) @@ -133,7 +137,7 @@ def compute_integral_ir( for comp in S.nodes[target]["component"]: assert expressions[comp] is None expressions[comp] = S.nodes[target]["expression"] - expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) + expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) # type: ignore # Rebuild scalar list-based graph representation S = build_scalar_graph(expression) @@ -148,14 +152,12 @@ def compute_integral_ir( # Get the 'target' nodes that are factors of arguments, and insert in dict FV_targets = [i for i, v in F.nodes.items() if v.get("target", False)] - argument_factorization = {} - + argument_factorization: dict[tuple[int, ...], list[tuple[int, int]]] = {} for fi in FV_targets: # Number of blocks using this factor must agree with number of components # to which this factor contributes. I.e. there are more blocks iff there are more # components assert len(F.nodes[fi]["target"]) == len(F.nodes[fi]["component"]) - k = 0 for w in F.nodes[fi]["target"]: comp = F.nodes[fi]["component"][k] @@ -166,10 +168,10 @@ def compute_integral_ir( k += 1 # Get list of indices in F which are the arguments (should be at start) - argkeys = set() + _argkeys: set[int] = set() for w in argument_factorization: - argkeys = argkeys | set(w) - argkeys = list(argkeys) + _argkeys = _argkeys | set(w) + argkeys = list(_argkeys) # Build set of modified_terminals for each mt factorized vertex in F # and attach tables, if appropriate @@ -200,27 +202,28 @@ def compute_integral_ir( ttypes = tuple(tr.ttype for tr in trs) assert not any(tt == "zeros" for tt in ttypes) - blockmap = [] + _blockmap: list[tuple[int, ...]] = [] for tr in trs: + assert tr is not None begin = tr.offset + assert begin is not None num_dofs = tr.values.shape[3] + assert tr.block_size is not None dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs)) - blockmap.append(dofmap) + _blockmap.append(dofmap) - blockmap = tuple(blockmap) + blockmap = tuple(_blockmap) block_is_uniform = all(tr.is_uniform for tr in trs) # Collect relevant restrictions to identify blocks correctly # in interior facet integrals - block_restrictions = [] + _block_restrictions: list[str] = [] for i, ai in enumerate(ma_indices): - if trs[i].is_uniform: - r = None - else: + if not trs[i].is_uniform: r = F.nodes[ai]["mt"].restriction + _block_restrictions.append(r) - block_restrictions.append(r) - block_restrictions = tuple(block_restrictions) + block_restrictions: tuple[str, ...] = tuple(_block_restrictions) # Check if each *each* factor corresponding to this argument is piecewise all_factors_piecewise = all(F.nodes[ifi[0]]["status"] == "piecewise" for ifi in fi_ci) @@ -255,6 +258,7 @@ def compute_integral_ir( tr = v.get("tr") if tr is not None and F.nodes[i]["status"] != "inactive": if tr.has_tensor_factorisation: + assert tr.tensor_factors is not None for t in tr.tensor_factors: active_table_names.add(t.name) else: @@ -265,6 +269,7 @@ def compute_integral_ir( for blockdata in contributions: for mad in blockdata.ma_data: if mad.tabledata.has_tensor_factorisation: + assert mad.tabledata.tensor_factors is not None for t in mad.tabledata.tensor_factors: active_table_names.add(t.name) else: diff --git a/ffcx/ir/representation.py b/ffcx/ir/representation.py index 7855c13a1..8731d8db3 100644 --- a/ffcx/ir/representation.py +++ b/ffcx/ir/representation.py @@ -32,6 +32,7 @@ from ffcx import naming from ffcx.analysis import UFLData +from ffcx.definitions import entity_types from ffcx.ir.integral import compute_integral_ir from ffcx.ir.representationutils import QuadratureRule, create_quadrature_points_and_weights @@ -68,7 +69,7 @@ class CommonExpressionIR(typing.NamedTuple): """Common-ground for IntegralIR and ExpressionIR.""" integral_type: str - entity_type: str + entity_type: entity_types tensor_shape: list[int] coefficient_numbering: dict[ufl.Coefficient, int] coefficient_offsets: dict[ufl.Coefficient, int] From a821fa924768c6210cf193961cfb818d04555f65 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 16:52:04 +0000 Subject: [PATCH 3/9] Revert | to typing unions --- ffcx/analysis.py | 12 +++++++----- ffcx/codegeneration/C/expressions.py | 3 ++- ffcx/codegeneration/C/form.py | 3 ++- ffcx/codegeneration/backend.py | 4 +++- ffcx/codegeneration/codegeneration.py | 2 +- ffcx/codegeneration/jit.py | 5 +++-- ffcx/ir/elementtables.py | 12 ++++++------ ffcx/ir/representation.py | 2 +- 8 files changed, 25 insertions(+), 18 deletions(-) diff --git a/ffcx/analysis.py b/ffcx/analysis.py index ca7fa11eb..da4dfc4ab 100644 --- a/ffcx/analysis.py +++ b/ffcx/analysis.py @@ -41,10 +41,12 @@ class UFLData(typing.NamedTuple): def analyze_ufl_objects( ufl_objects: list[ - ufl.form.Form - | ufl.AbstractFiniteElement - | ufl.Mesh - | tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]] + typing.Union[ + ufl.form.Form, + ufl.AbstractFiniteElement, + ufl.Mesh, + tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]], + ] ], scalar_type: npt.DTypeLike, ) -> UFLData: @@ -246,7 +248,7 @@ def _analyze_form( def _has_custom_integrals( - o: ufl.integral.Integral | ufl.classes.Form | list | tuple, + o: typing.Union[ufl.integral.Integral, ufl.classes.Form, list, tuple], ) -> bool: """Check for custom integrals.""" if isinstance(o, ufl.integral.Integral): diff --git a/ffcx/codegeneration/C/expressions.py b/ffcx/codegeneration/C/expressions.py index 853f46240..d1da34646 100644 --- a/ffcx/codegeneration/C/expressions.py +++ b/ffcx/codegeneration/C/expressions.py @@ -8,6 +8,7 @@ from __future__ import annotations import logging +import typing import numpy as np @@ -38,7 +39,7 @@ def generator(ir: ExpressionIR, options): backend = FFCXBackend(ir, options) eg = ExpressionGenerator(ir, backend) - d: dict[str, str | int] = {} + d: dict[str, typing.Union[str, int]] = {} d["name_from_uflfile"] = ir.name_from_uflfile d["factory_name"] = factory_name parts = eg.generate() diff --git a/ffcx/codegeneration/C/form.py b/ffcx/codegeneration/C/form.py index eab91c59d..71748fc6e 100644 --- a/ffcx/codegeneration/C/form.py +++ b/ffcx/codegeneration/C/form.py @@ -13,6 +13,7 @@ from __future__ import annotations import logging +import typing import numpy as np @@ -28,7 +29,7 @@ def generator(ir: FormIR, options): logger.info(f"--- rank: {ir.rank}") logger.info(f"--- name: {ir.name}") - d: dict[str, int | str] = {} + d: dict[str, typing.Union[int, str]] = {} d["factory_name"] = ir.name d["name_from_uflfile"] = ir.name_from_uflfile d["signature"] = f'"{ir.signature}"' diff --git a/ffcx/codegeneration/backend.py b/ffcx/codegeneration/backend.py index ef6963987..2adcc2928 100644 --- a/ffcx/codegeneration/backend.py +++ b/ffcx/codegeneration/backend.py @@ -7,6 +7,8 @@ from __future__ import annotations +import typing + from ffcx.codegeneration.access import FFCXBackendAccess from ffcx.codegeneration.definitions import FFCXBackendDefinitions from ffcx.codegeneration.symbols import FFCXBackendSymbols @@ -16,7 +18,7 @@ class FFCXBackend: """Class collecting all aspects of the FFCx backend.""" - def __init__(self, ir: IntegralIR | ExpressionIR, options): + def __init__(self, ir: typing.Union[IntegralIR, ExpressionIR], options): """Initialise.""" coefficient_numbering = ir.expression.coefficient_numbering coefficient_offsets = ir.expression.coefficient_offsets diff --git a/ffcx/codegeneration/codegeneration.py b/ffcx/codegeneration/codegeneration.py index 9564c3837..80b5fb42d 100644 --- a/ffcx/codegeneration/codegeneration.py +++ b/ffcx/codegeneration/codegeneration.py @@ -39,7 +39,7 @@ class CodeBlocks(typing.NamedTuple): file_post: list[tuple[str, str]] -def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) -> CodeBlocks: +def generate_code(ir: DataIR, options: typing.Union[str, int, float, npt.DTypeLike]) -> CodeBlocks: """Generate code blocks from intermediate representation.""" logger.info(79 * "*") logger.info("Compiler stage 3: Generating code") diff --git a/ffcx/codegeneration/jit.py b/ffcx/codegeneration/jit.py index 6eb5dbb8f..be17a9ce3 100644 --- a/ffcx/codegeneration/jit.py +++ b/ffcx/codegeneration/jit.py @@ -16,6 +16,7 @@ import sysconfig import tempfile import time +import typing from contextlib import redirect_stdout from pathlib import Path @@ -152,7 +153,7 @@ def _compilation_signature(cffi_extra_compile_args, cffi_debug): def compile_forms( forms: list[ufl.Form], options: dict = {}, - cache_dir: Path | None = None, + cache_dir: typing.Optional[Path] = None, timeout: int = 10, cffi_extra_compile_args: list[str] = [], cffi_verbose: bool = False, @@ -231,7 +232,7 @@ def compile_forms( def compile_expressions( expressions: list[tuple[ufl.Expr, npt.NDArray[np.floating]]], options: dict = {}, - cache_dir: Path | None = None, + cache_dir: typing.Optional[Path] = None, timeout: int = 10, cffi_extra_compile_args: list[str] = [], cffi_verbose: bool = False, diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 07b9a4af6..2a8bf66ab 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -47,15 +47,15 @@ class UniqueTableReferenceT(typing.NamedTuple): name: str values: npt.NDArray[np.float64] - offset: int | None - block_size: int | None - ttype: str | None + offset: typing.Optional[int] + block_size: typing.Optional[int] + ttype: typing.Optional[str] is_piecewise: bool is_uniform: bool is_permuted: bool has_tensor_factorisation: bool - tensor_factors: list[typing.Any] | None - tensor_permutation: np.typing.NDArray[np.int32] | None + tensor_factors: typing.Optional[list[typing.Any]] + tensor_permutation: typing.Optional[np.typing.NDArray[np.int32]] def equal_tables(a, b, rtol=default_rtol, atol=default_atol): @@ -348,7 +348,7 @@ def build_optimized_tables( set(ufl.algorithms.analysis.extract_sub_elements(all_elements)) ) element_numbers = {element: i for i, element in enumerate(unique_elements)} - mt_tables: dict[ModifiedTerminal | str, UniqueTableReferenceT] = {} + mt_tables: dict[typing.Union[ModifiedTerminal, str], UniqueTableReferenceT] = {} _existing_tables = existing_tables.copy() diff --git a/ffcx/ir/representation.py b/ffcx/ir/representation.py index 8731d8db3..b29b35884 100644 --- a/ffcx/ir/representation.py +++ b/ffcx/ir/representation.py @@ -113,7 +113,7 @@ def compute_ir( analysis: UFLData, object_names: dict[int, str], prefix: str, - options: dict[str, npt.DTypeLike | int | float], + options: dict[str, typing.Union[npt.DTypeLike, int, float]], visualise: bool, ) -> DataIR: """Compute intermediate representation.""" From 228798b01640557b3e6e566244fe2d8bc793ab06 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 16:55:36 +0000 Subject: [PATCH 4/9] Fix typing --- ffcx/codegeneration/codegeneration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ffcx/codegeneration/codegeneration.py b/ffcx/codegeneration/codegeneration.py index 80b5fb42d..40918e8f1 100644 --- a/ffcx/codegeneration/codegeneration.py +++ b/ffcx/codegeneration/codegeneration.py @@ -39,7 +39,9 @@ class CodeBlocks(typing.NamedTuple): file_post: list[tuple[str, str]] -def generate_code(ir: DataIR, options: typing.Union[str, int, float, npt.DTypeLike]) -> CodeBlocks: +def generate_code( + ir: DataIR, options: dict[str, typing.Union[int, float, npt.DTypeLike]] +) -> CodeBlocks: """Generate code blocks from intermediate representation.""" logger.info(79 * "*") logger.info("Compiler stage 3: Generating code") From 86927195e72758c0c385c548256d846526b6ab86 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 17:08:00 +0000 Subject: [PATCH 5/9] Add flag --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 82061ab0d..2724c5bf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,5 +107,9 @@ ignore = ["RUF005", "RUF012", "RUF015"] [tool.ruff.lint.per-file-ignores] "test/*" = ["D"] +[tool.ruff.lint.pyupgrade] +# Remove once target version hits 3.10 +keep-runtime-typing = true + [tool.ruff.lint.pydocstyle] convention = "google" From 8bc9bfb7a32405fca57f22131c5d9b5e33b227be Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 17:14:24 +0000 Subject: [PATCH 6/9] More legacy typesetting --- ffcx/ir/elementtables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 2a8bf66ab..4d0f5d426 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -311,7 +311,7 @@ def build_optimized_tables( is_mixed_dim: bool, rtol: float = default_rtol, atol: float = default_atol, -) -> dict[ModifiedTerminal | str, UniqueTableReferenceT]: +) -> dict[typing.Union[ModifiedTerminal ,str], UniqueTableReferenceT]: """Build the element tables needed for a list of modified terminals. Args: @@ -504,7 +504,7 @@ def build_optimized_tables( if use_sum_factorization and (not quadrature_rule.has_tensor_factors): raise RuntimeError("Sum factorization not available for this quadrature rule.") - tensor_factors: list[UniqueTableReferenceT] | None = None + tensor_factors: typing.Optional[list[UniqueTableReferenceT]] = None tensor_perm = None if ( use_sum_factorization From a8f672d6f86e91e8df88a2d7bf2798a4f3877941 Mon Sep 17 00:00:00 2001 From: jorgensd Date: Sun, 23 Feb 2025 17:17:13 +0000 Subject: [PATCH 7/9] Ruff format --- ffcx/ir/elementtables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 4d0f5d426..980855165 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -311,7 +311,7 @@ def build_optimized_tables( is_mixed_dim: bool, rtol: float = default_rtol, atol: float = default_atol, -) -> dict[typing.Union[ModifiedTerminal ,str], UniqueTableReferenceT]: +) -> dict[typing.Union[ModifiedTerminal, str], UniqueTableReferenceT]: """Build the element tables needed for a list of modified terminals. Args: From 606134869041f30075ccddceca3e43e87ef8575f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20S=2E=20Dokken?= Date: Tue, 18 Mar 2025 09:25:18 +0000 Subject: [PATCH 8/9] Fixes --- ffcx/ir/analysis/graph.py | 5 +++-- ffcx/ir/integral.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/ffcx/ir/analysis/graph.py b/ffcx/ir/analysis/graph.py index 0f52bf63e..8bf8c6606 100644 --- a/ffcx/ir/analysis/graph.py +++ b/ffcx/ir/analysis/graph.py @@ -6,6 +6,7 @@ """Linearized data structure for the computational graph.""" import logging +import typing import numpy as np import ufl @@ -73,7 +74,7 @@ def build_graph_vertices(expressions, skip_terminal_modifiers=False): return G -def build_scalar_graph(expression): +def build_scalar_graph(expression) -> ExpressionGraph: """Build list representation of expression graph covering the given expressions.""" # Populate with vertices G = build_graph_vertices([expression], skip_terminal_modifiers=False) @@ -86,7 +87,7 @@ def build_scalar_graph(expression): G = build_graph_vertices(scalar_expressions, skip_terminal_modifiers=True) # Compute graph edges - V_deps = [] + V_deps: list[typing.Union[tuple[()], list[int]]] = [] for i, v in G.nodes.items(): expr = v["expression"] if expr._ufl_is_terminal_ or expr._ufl_is_terminal_modifier_: diff --git a/ffcx/ir/integral.py b/ffcx/ir/integral.py index 0056eb480..ddda4179e 100644 --- a/ffcx/ir/integral.py +++ b/ffcx/ir/integral.py @@ -140,7 +140,7 @@ def compute_integral_ir( for comp in S.nodes[target]["component"]: assert expressions[comp] is None expressions[comp] = S.nodes[target]["expression"] - expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) + expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) # type: ignore # Rebuild scalar list-based graph representation S = build_scalar_graph(expression) @@ -173,10 +173,10 @@ def compute_integral_ir( k += 1 # Get list of indices in F which are the arguments (should be at start) - argkeys: set[int] = set() + _argkeys: set[int] = set() for w in argument_factorization: - argkeys = argkeys | set(w) - argkeys = list(argkeys) + _argkeys = _argkeys | set(w) + argkeys = list(_argkeys) # Build set of modified_terminals for each mt factorized vertex in F # and attach tables, if appropriate @@ -216,12 +216,13 @@ def compute_integral_ir( assert tr.block_size is not None dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs)) _blockmap.append(dofmap) + blockmap = tuple(_blockmap) block_is_uniform = all(tr.is_uniform for tr in trs) # Collect relevant restrictions to identify blocks correctly # in interior facet integrals - + # Collect relevant restrictions to identify blocks correctly # in interior facet integrals _block_restrictions: list[str] = [] @@ -267,8 +268,9 @@ def compute_integral_ir( tr = v.get("tr") if tr is not None and F.nodes[i]["status"] != "inactive": if tr.has_tensor_factorisation: + assert tr.tensor_factors is not None for t in tr.tensor_factors: - active_table_names.add(t.name) + active_table_names.add(t.name) else: active_table_names.add(tr.name) @@ -277,6 +279,7 @@ def compute_integral_ir( for blockdata in contributions: for mad in blockdata.ma_data: if mad.tabledata.has_tensor_factorisation: + assert mad.tabledata.tensor_factors is not None for t in mad.tabledata.tensor_factors: active_table_names.add(t.name) else: From acf72f521ce723092d3e68dbd62c358c888ae3bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Wed, 19 Mar 2025 09:27:45 +0100 Subject: [PATCH 9/9] Apply suggestions from code review Co-authored-by: Matthew Scroggs --- ffcx/ir/elementtables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ffcx/ir/elementtables.py b/ffcx/ir/elementtables.py index 980855165..89a18074f 100644 --- a/ffcx/ir/elementtables.py +++ b/ffcx/ir/elementtables.py @@ -317,9 +317,9 @@ def build_optimized_tables( Args: quadrature_rule: The quadrature rule relating to the tables. cell: The cell type of the domain the tables will be used with. - entity_type: On what entity (vertex,edge,facet,cell) the tables are evaluated at. + entity_type: The entity type (vertex,edge,facet,cell) that the tables are evaluated for. integral_type: The type of integral the tables are used for. - modified_terminals: ordered sequence of unique modified terminals + modified_terminals: Ordered sequence of unique modified terminals existing_tables: Register of tables that already exist and reused. use_sum_factorization: Use sum factorization for tensor product elements. is_mixed_dim: Mixed dimensionality of the domain.