Skip to content

Commit 204a161

Browse files
committed
Start adding some typing and documentation
1 parent c78d7ea commit 204a161

File tree

6 files changed

+57
-27
lines changed

6 files changed

+57
-27
lines changed

ffcx/codegeneration/access.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import ufl
1414

1515
import ffcx.codegeneration.lnodes as L
16+
from ffcx.definitions import entity_types
1617
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1718
from ffcx.ir.elementtables import UniqueTableReferenceT
1819
from ffcx.ir.representationutils import QuadratureRule
@@ -23,7 +24,9 @@
2324
class FFCXBackendAccess:
2425
"""FFCx specific formatter class."""
2526

26-
def __init__(self, entity_type: str, integral_type: str, symbols, options):
27+
entity_type: entity_types
28+
29+
def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
2730
"""Initialise."""
2831
# Store ir and options
2932
self.entity_type = entity_type
@@ -399,7 +402,7 @@ def _pass(self, *args, **kwargs):
399402
def table_access(
400403
self,
401404
tabledata: UniqueTableReferenceT,
402-
entity_type: str,
405+
entity_type: entity_types,
403406
restriction: str,
404407
quadrature_index: L.MultiIndex,
405408
dof_index: L.MultiIndex,
@@ -408,7 +411,7 @@ def table_access(
408411
409412
Args:
410413
tabledata: Table data object
411-
entity_type: Entity type ("cell", "facet", "vertex")
414+
entity_type: Entity type
412415
restriction: Restriction ("+", "-")
413416
quadrature_index: Quadrature index
414417
dof_index: Dof index

ffcx/codegeneration/definitions.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import ufl
1212

1313
import ffcx.codegeneration.lnodes as L
14+
from ffcx.definitions import entity_types
1415
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1516
from ffcx.ir.elementtables import UniqueTableReferenceT
1617
from ffcx.ir.representationutils import QuadratureRule
@@ -50,7 +51,9 @@ def create_dof_index(tabledata, dof_index_symbol):
5051
class FFCXBackendDefinitions:
5152
"""FFCx specific code definitions."""
5253

53-
def __init__(self, entity_type: str, integral_type: str, access, options):
54+
entity_type: entity_types
55+
56+
def __init__(self, entity_type: entity_types, integral_type: str, access, options):
5457
"""Initialise."""
5558
# Store ir and options
5659
self.integral_type = integral_type

ffcx/codegeneration/symbols.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ufl
1111

1212
import ffcx.codegeneration.lnodes as L
13+
from ffcx.definitions import entity_types
1314

1415
logger = logging.getLogger("ffcx")
1516

@@ -95,7 +96,7 @@ def __init__(self, coefficient_numbering, coefficient_offsets, original_constant
9596
# Table for chunk of custom quadrature points (physical coordinates).
9697
self.custom_points_table = L.Symbol("points_chunk", dtype=L.DataType.REAL)
9798

98-
def entity(self, entity_type, restriction):
99+
def entity(self, entity_type: entity_types, restriction):
99100
"""Entity index for lookup in element tables."""
100101
if entity_type == "cell":
101102
# Always 0 for cells (even with restriction)
@@ -175,7 +176,7 @@ def constant_index_access(self, constant, index):
175176
return c[offset + index]
176177

177178
# TODO: Remove this, use table_access instead
178-
def element_table(self, tabledata, entity_type, restriction):
179+
def element_table(self, tabledata, entity_type: entity_types, restriction):
179180
"""Get an element table."""
180181
entity = self.entity(entity_type, restriction)
181182

ffcx/definitions.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module for storing type definitions used in the FFCx code base."""
2+
3+
from typing import Literal
4+
5+
entity_types = Literal["cell", "facet", "vertex"]

ffcx/ir/elementtables.py

+35-20
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
import numpy.typing as npt
1414
import ufl
1515

16+
from ffcx.definitions import entity_types
1617
from ffcx.element_interface import basix_index
18+
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1719
from ffcx.ir.representationutils import (
20+
QuadratureRule,
1821
create_quadrature_points_and_weights,
1922
integral_type_to_entity_dim,
2023
map_integral_points,
@@ -82,7 +85,7 @@ def get_ffcx_table_values(
8285
integral_type,
8386
element,
8487
avg,
85-
entity_type,
88+
entity_type: entity_types,
8689
derivative_counts,
8790
flat_component,
8891
codim,
@@ -175,7 +178,12 @@ def get_ffcx_table_values(
175178

176179

177180
def generate_psi_table_name(
178-
quadrature_rule, element_counter, averaged: str, entity_type, derivative_counts, flat_component
181+
quadrature_rule: QuadratureRule,
182+
element_counter,
183+
averaged: str,
184+
entity_type: entity_types,
185+
derivative_counts,
186+
flat_component,
179187
):
180188
"""Generate a name for the psi table.
181189
@@ -293,26 +301,33 @@ def permute_quadrature_quadrilateral(points, reflections=0, rotations=0):
293301

294302

295303
def build_optimized_tables(
296-
quadrature_rule,
297-
cell,
298-
integral_type,
299-
entity_type,
300-
modified_terminals,
301-
existing_tables,
302-
use_sum_factorization,
303-
is_mixed_dim,
304-
rtol=default_rtol,
305-
atol=default_atol,
306-
):
304+
quadrature_rule: QuadratureRule,
305+
cell: ufl.Cell,
306+
integral_type: str,
307+
entity_type: entity_types,
308+
modified_terminals: ModifiedTerminal,
309+
existing_tables: dict[str, np.ndarray],
310+
use_sum_factorization: bool,
311+
is_mixed_dim: bool,
312+
rtol: float = default_rtol,
313+
atol: float = default_atol,
314+
) -> dict[ModifiedTerminal, UniqueTableReferenceT]:
307315
"""Build the element tables needed for a list of modified terminals.
308316
309-
Input:
310-
entity_type - str
311-
modified_terminals - ordered sequence of unique modified terminals
312-
FIXME: Document
313-
314-
Output:
315-
mt_tables - dict(ModifiedTerminal: table data)
317+
Args:
318+
quadrature_rule: The quadrature rule relating to the tables.
319+
cell: The cell type of the domain the tables will be used with.
320+
entity_type: On what entity (vertex,edge,facet,cell) the tables are evaluated at.
321+
integral_type: The type of integral the tables are used for.
322+
modified_terminals: ordered sequence of unique modified terminals
323+
existing_tables: Register of tables that already exist and reused.
324+
use_sum_factorization: Use sum factorization for tensor product elements.
325+
is_mixed_dim: Mixed dimensionality of the domain.
326+
rtol: Relative tolerance for comparing tables.
327+
atol: Absolute tolerance for comparing tables.
328+
329+
Returns:
330+
mt_tables - Dictionary mapping each modified terminal to the a unique table reference.
316331
"""
317332
# Add to element tables
318333
analysis = {}

ffcx/ir/integral.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ufl.checks import is_cellwise_constant
1717
from ufl.classes import QuadratureWeight
1818

19+
from ffcx.definitions import entity_types
1920
from ffcx.ir.analysis.factorization import compute_argument_factorization
2021
from ffcx.ir.analysis.graph import build_scalar_graph
2122
from ffcx.ir.analysis.modified_terminals import analyse_modified_terminal, is_modified_terminal
@@ -46,7 +47,9 @@ class BlockDataT(typing.NamedTuple):
4647
is_permuted: bool # Do quad points on facets need to be permuted?
4748

4849

49-
def compute_integral_ir(cell, integral_type, entity_type, integrands, argument_shape, p, visualise):
50+
def compute_integral_ir(
51+
cell, integral_type: str, entity_type: entity_types, integrands, argument_shape, p, visualise
52+
):
5053
"""Compute intermediate representation for an integral."""
5154
# The intermediate representation dict we're building and returning
5255
# here

0 commit comments

Comments
 (0)