Skip to content

Commit 04de41a

Browse files
committed
Start fixing 'check_untyped_defs' for 'ffcx.codegeneration.*'
1 parent c054d85 commit 04de41a

File tree

6 files changed

+26
-15
lines changed

6 files changed

+26
-15
lines changed

ffcx/codegeneration/geometry.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,19 @@ def reference_facet_volume(tablename, cellname):
7676
"""Write a reference facet volume."""
7777
celltype = getattr(basix.CellType, cellname)
7878
volumes = basix.cell.facet_reference_volumes(celltype)
79-
for i in volumes[1:]:
80-
if not np.isclose(i, volumes[0]):
79+
for i in volumes[1:]: # type: ignore
80+
if not np.isclose(i, volumes[0]): # type: ignore
8181
raise ValueError("Reference facet volume not supported for this cell type.")
8282
symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL)
83-
return L.VariableDecl(symbol, volumes[0])
83+
return L.VariableDecl(symbol, volumes[0]) # type: ignore
8484

8585

8686
def reference_cell_edge_vectors(tablename, cellname):
8787
"""Write reference edge vectors."""
8888
celltype = getattr(basix.CellType, cellname)
8989
topology = basix.topology(celltype)
9090
geometry = basix.geometry(celltype)
91-
edge_vectors = [geometry[j] - geometry[i] for i, j in topology[1]]
91+
edge_vectors = [geometry[j] - geometry[i] for i, j in topology[1]] # type: ignore
9292
out = np.array(edge_vectors)
9393
symbol = L.Symbol(f"{cellname}_{tablename}", dtype=L.DataType.REAL)
9494
return L.ArrayDecl(symbol, values=out, const=True)
@@ -108,10 +108,11 @@ def reference_facet_edge_vectors(tablename, cellname):
108108
edge_vectors = []
109109
for facet in topology[-2]:
110110
if len(facet) == 3:
111-
edge_vectors += [geometry[facet[j]] - geometry[facet[i]] for i, j in triangle_edges]
111+
edge_vectors += [geometry[facet[j]] - geometry[facet[i]] for i, j in triangle_edges] # type: ignore
112112
elif len(facet) == 4:
113113
edge_vectors += [
114-
geometry[facet[j]] - geometry[facet[i]] for i, j in quadrilateral_edges
114+
geometry[facet[j]] - geometry[facet[i]]
115+
for i, j in quadrilateral_edges # type: ignore
115116
]
116117
else:
117118
raise ValueError("Only triangular and quadrilateral faces supported.")

ffcx/codegeneration/integral_generator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def generate_geometry_tables(self):
209209
ufl.geometry.ReferenceNormal: "reference_normals",
210210
ufl.geometry.FacetOrientation: "facet_orientation",
211211
}
212-
cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()} # type: ignore
212+
cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()}
213213

214214
for integrand in self.ir.expression.integrand.values():
215215
for attr in integrand["factorization"].nodes.values():

ffcx/codegeneration/jit.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import importlib
11+
import importlib.util
1112
import io
1213
import logging
1314
import os
@@ -109,7 +110,7 @@ def get_cached_module(module_name, object_names, cache_dir, timeout):
109110
for i in range(timeout):
110111
if os.path.exists(ready_name):
111112
spec = finder.find_spec(module_name)
112-
if spec is None:
113+
if spec is None or spec.loader is None:
113114
raise ModuleNotFoundError("Unable to find JIT module.")
114115
compiled_module = importlib.util.module_from_spec(spec)
115116
spec.loader.exec_module(compiled_module)
@@ -409,11 +410,12 @@ def _load_objects(cache_dir, module_name, object_names):
409410
# (new) modules are found
410411
finder.invalidate_caches()
411412
spec = finder.find_spec(module_name)
412-
if spec is None:
413+
if spec is None or spec.loader is None:
413414
raise ModuleNotFoundError("Unable to find JIT module.")
414415

415416
# Load module
416417
compiled_module = importlib.util.module_from_spec(spec)
418+
417419
spec.loader.exec_module(compiled_module)
418420

419421
compiled_objects = []

ffcx/codegeneration/lnodes.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ def __eq__(self, other):
298298

299299
def __float__(self):
300300
"""Convert to float."""
301+
if isinstance(self.value, complex):
302+
raise RuntimeError("__float__ of a complex value not defined.")
303+
301304
return float(self.value)
302305

303306
def __repr__(self):
@@ -316,6 +319,10 @@ def __init__(self, value):
316319
self.value = value
317320
self.dtype = DataType.INT
318321

322+
def __int__(self):
323+
"""Overwrites int()."""
324+
return self.value
325+
319326
def __eq__(self, other):
320327
"""Check equality."""
321328
return isinstance(other, LiteralInt) and self.value == other.value
@@ -447,6 +454,8 @@ def __eq__(self, other):
447454
class BinOp(LExprOperator):
448455
"""A binary operator."""
449456

457+
op = ""
458+
450459
def __init__(self, lhs, rhs):
451460
"""Initialise."""
452461
self.lhs = as_lexpr(lhs)
@@ -955,7 +964,6 @@ def __eq__(self, other):
955964
"""Check equality."""
956965
return (
957966
isinstance(other, type(self))
958-
and self.typename == other.typename
959967
and self.symbol == other.symbol
960968
and self.value == other.value
961969
)

ffcx/codegeneration/symbols.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,7 @@ def element_table(self, tabledata, entity_type, restriction):
185185
else:
186186
entity = self.entity(entity_type, restriction)
187187

188-
if tabledata.is_piecewise:
189-
iq = 0
190-
else:
191-
iq = self.quadrature_loop_index
188+
iq = 0 if tabledata.is_piecewise else self.quadrature_loop_index
192189

193190
if tabledata.is_permuted:
194191
qp = self.quadrature_permutation[0]

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,12 @@ disallow_any_unimported = false # most of these come from UFL
8383

8484
[[tool.mypy.overrides]]
8585
module = ["ffcx.ir.*", "ffcx.codegeneration.*"]
86-
check_untyped_defs = false
8786
disallow_untyped_defs = false
8887

88+
[[tool.mypy.overrides]]
89+
module = ["ffcx.ir.*"]
90+
check_untyped_defs = false
91+
8992
[tool.ruff]
9093
line-length = 100
9194
indent-width = 4

0 commit comments

Comments
 (0)