Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start adding some typing and documentation #751

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 7 additions & 5 deletions ffcx/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ class UFLData(typing.NamedTuple):

def analyze_ufl_objects(
ufl_objects: list[
ufl.form.Form
| ufl.AbstractFiniteElement
| ufl.Mesh
| tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]]
typing.Union[
ufl.form.Form,
ufl.AbstractFiniteElement,
ufl.Mesh,
tuple[ufl.core.expr.Expr, npt.NDArray[np.floating]],
]
],
scalar_type: npt.DTypeLike,
) -> UFLData:
Expand Down Expand Up @@ -246,7 +248,7 @@ def _analyze_form(


def _has_custom_integrals(
o: ufl.integral.Integral | ufl.classes.Form | list | tuple,
o: typing.Union[ufl.integral.Integral, ufl.classes.Form, list, tuple],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How was this not causing an error previously? Or is 3.9 not running on CI so we just weren't seeing the error?

Is it worth considering bumping the minimun version to 3.10 six months early so we can use |?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends on what functions are typed. Mypy ignores that are typed (if they are within an untyped function signature).

) -> bool:
"""Check for custom integrals."""
if isinstance(o, ufl.integral.Integral):
Expand Down
3 changes: 2 additions & 1 deletion ffcx/codegeneration/C/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import logging
import typing

import numpy as np

Expand Down Expand Up @@ -38,7 +39,7 @@ def generator(ir: ExpressionIR, options):
backend = FFCXBackend(ir, options)
eg = ExpressionGenerator(ir, backend)

d: dict[str, str | int] = {}
d: dict[str, typing.Union[str, int]] = {}
d["name_from_uflfile"] = ir.name_from_uflfile
d["factory_name"] = factory_name
parts = eg.generate()
Expand Down
3 changes: 2 additions & 1 deletion ffcx/codegeneration/C/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import annotations

import logging
import typing

import numpy as np

Expand All @@ -28,7 +29,7 @@ def generator(ir: FormIR, options):
logger.info(f"--- rank: {ir.rank}")
logger.info(f"--- name: {ir.name}")

d: dict[str, int | str] = {}
d: dict[str, typing.Union[int, str]] = {}
d["factory_name"] = ir.name
d["name_from_uflfile"] = ir.name_from_uflfile
d["signature"] = f'"{ir.signature}"'
Expand Down
12 changes: 9 additions & 3 deletions ffcx/codegeneration/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
from ffcx.ir.elementtables import UniqueTableReferenceT
from ffcx.ir.representationutils import QuadratureRule
Expand All @@ -23,7 +24,9 @@
class FFCXBackendAccess:
"""FFCx specific formatter class."""

def __init__(self, entity_type: str, integral_type: str, symbols, options):
entity_type: entity_types

def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
"""Initialise."""
# Store ir and options
self.entity_type = entity_type
Expand Down Expand Up @@ -88,6 +91,8 @@ def coefficient(

num_dofs = tabledata.values.shape[3]
begin = tabledata.offset
assert begin is not None
assert tabledata.block_size is not None
end = begin + tabledata.block_size * (num_dofs - 1) + 1

if ttype == "ones" and (end - begin) == 1:
Expand Down Expand Up @@ -406,7 +411,7 @@ def _pass(self, *args, **kwargs):
def table_access(
self,
tabledata: UniqueTableReferenceT,
entity_type: str,
entity_type: entity_types,
restriction: str,
quadrature_index: L.MultiIndex,
dof_index: L.MultiIndex,
Expand All @@ -415,7 +420,7 @@ def table_access(

Args:
tabledata: Table data object
entity_type: Entity type ("cell", "facet", "vertex")
entity_type: Entity type
restriction: Restriction ("+", "-")
quadrature_index: Quadrature index
dof_index: Dof index
Expand Down Expand Up @@ -446,6 +451,7 @@ def table_access(
], symbols
else:
FE = []
assert tabledata.tensor_factors is not None
for i in range(dof_index.dim):
factor = tabledata.tensor_factors[i]
iq_i = quadrature_index.local_index(i)
Expand Down
4 changes: 3 additions & 1 deletion ffcx/codegeneration/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from __future__ import annotations

import typing

from ffcx.codegeneration.access import FFCXBackendAccess
from ffcx.codegeneration.definitions import FFCXBackendDefinitions
from ffcx.codegeneration.symbols import FFCXBackendSymbols
Expand All @@ -16,7 +18,7 @@
class FFCXBackend:
"""Class collecting all aspects of the FFCx backend."""

def __init__(self, ir: IntegralIR | ExpressionIR, options):
def __init__(self, ir: typing.Union[IntegralIR, ExpressionIR], options):
"""Initialise."""
coefficient_numbering = ir.expression.coefficient_numbering
coefficient_offsets = ir.expression.coefficient_offsets
Expand Down
4 changes: 3 additions & 1 deletion ffcx/codegeneration/codegeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class CodeBlocks(typing.NamedTuple):
file_post: list[tuple[str, str]]


def generate_code(ir: DataIR, options: dict[str, int | float | npt.DTypeLike]) -> CodeBlocks:
def generate_code(
ir: DataIR, options: dict[str, typing.Union[int, float, npt.DTypeLike]]
) -> CodeBlocks:
"""Generate code blocks from intermediate representation."""
logger.info(79 * "*")
logger.info("Compiler stage 3: Generating code")
Expand Down
7 changes: 6 additions & 1 deletion ffcx/codegeneration/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
from ffcx.ir.elementtables import UniqueTableReferenceT
from ffcx.ir.representationutils import QuadratureRule
Expand Down Expand Up @@ -50,7 +51,9 @@ def create_dof_index(tabledata, dof_index_symbol):
class FFCXBackendDefinitions:
"""FFCx specific code definitions."""

def __init__(self, entity_type: str, integral_type: str, access, options):
entity_type: entity_types

def __init__(self, entity_type: entity_types, integral_type: str, access, options):
"""Initialise."""
# Store ir and options
self.integral_type = integral_type
Expand Down Expand Up @@ -130,6 +133,8 @@ def coefficient(
num_dofs = tabledata.values.shape[3]
bs = tabledata.block_size
begin = tabledata.offset
assert bs is not None
assert begin is not None
end = begin + bs * (num_dofs - 1) + 1

if ttype == "zeros":
Expand Down
5 changes: 3 additions & 2 deletions ffcx/codegeneration/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sysconfig
import tempfile
import time
import typing
from contextlib import redirect_stdout
from pathlib import Path

Expand Down Expand Up @@ -152,7 +153,7 @@ def _compilation_signature(cffi_extra_compile_args, cffi_debug):
def compile_forms(
forms: list[ufl.Form],
options: dict = {},
cache_dir: Path | None = None,
cache_dir: typing.Optional[Path] = None,
timeout: int = 10,
cffi_extra_compile_args: list[str] = [],
cffi_verbose: bool = False,
Expand Down Expand Up @@ -231,7 +232,7 @@ def compile_forms(
def compile_expressions(
expressions: list[tuple[ufl.Expr, npt.NDArray[np.floating]]],
options: dict = {},
cache_dir: Path | None = None,
cache_dir: typing.Optional[Path] = None,
timeout: int = 10,
cffi_extra_compile_args: list[str] = [],
cffi_verbose: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions ffcx/codegeneration/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ufl

import ffcx.codegeneration.lnodes as L
from ffcx.definitions import entity_types

logger = logging.getLogger("ffcx")

Expand Down Expand Up @@ -95,7 +96,7 @@ def __init__(self, coefficient_numbering, coefficient_offsets, original_constant
# Table for chunk of custom quadrature points (physical coordinates).
self.custom_points_table = L.Symbol("points_chunk", dtype=L.DataType.REAL)

def entity(self, entity_type, restriction):
def entity(self, entity_type: entity_types, restriction):
"""Entity index for lookup in element tables."""
if entity_type == "cell":
# Always 0 for cells (even with restriction)
Expand Down Expand Up @@ -176,7 +177,7 @@ def constant_index_access(self, constant, index):
return c[offset + index]

# TODO: Remove this, use table_access instead
def element_table(self, tabledata, entity_type, restriction):
def element_table(self, tabledata, entity_type: entity_types, restriction):
"""Get an element table."""
entity = self.entity(entity_type, restriction)

Expand Down
5 changes: 5 additions & 0 deletions ffcx/definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Module for storing type definitions used in the FFCx code base."""

from typing import Literal

entity_types = Literal["cell", "facet", "vertex"]
5 changes: 3 additions & 2 deletions ffcx/ir/analysis/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Linearized data structure for the computational graph."""

import logging
import typing

import numpy as np
import ufl
Expand Down Expand Up @@ -73,7 +74,7 @@ def build_graph_vertices(expressions, skip_terminal_modifiers=False):
return G


def build_scalar_graph(expression):
def build_scalar_graph(expression) -> ExpressionGraph:
"""Build list representation of expression graph covering the given expressions."""
# Populate with vertices
G = build_graph_vertices([expression], skip_terminal_modifiers=False)
Expand All @@ -86,7 +87,7 @@ def build_scalar_graph(expression):
G = build_graph_vertices(scalar_expressions, skip_terminal_modifiers=True)

# Compute graph edges
V_deps = []
V_deps: list[typing.Union[tuple[()], list[int]]] = []
for i, v in G.nodes.items():
expr = v["expression"]
if expr._ufl_is_terminal_ or expr._ufl_is_terminal_modifier_:
Expand Down
Loading