Skip to content

Commit 6061348

Browse files
committed
Fixes
1 parent 8ae39ea commit 6061348

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

ffcx/ir/analysis/graph.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""Linearized data structure for the computational graph."""
77

88
import logging
9+
import typing
910

1011
import numpy as np
1112
import ufl
@@ -73,7 +74,7 @@ def build_graph_vertices(expressions, skip_terminal_modifiers=False):
7374
return G
7475

7576

76-
def build_scalar_graph(expression):
77+
def build_scalar_graph(expression) -> ExpressionGraph:
7778
"""Build list representation of expression graph covering the given expressions."""
7879
# Populate with vertices
7980
G = build_graph_vertices([expression], skip_terminal_modifiers=False)
@@ -86,7 +87,7 @@ def build_scalar_graph(expression):
8687
G = build_graph_vertices(scalar_expressions, skip_terminal_modifiers=True)
8788

8889
# Compute graph edges
89-
V_deps = []
90+
V_deps: list[typing.Union[tuple[()], list[int]]] = []
9091
for i, v in G.nodes.items():
9192
expr = v["expression"]
9293
if expr._ufl_is_terminal_ or expr._ufl_is_terminal_modifier_:

ffcx/ir/integral.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def compute_integral_ir(
140140
for comp in S.nodes[target]["component"]:
141141
assert expressions[comp] is None
142142
expressions[comp] = S.nodes[target]["expression"]
143-
expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape))
143+
expression = ufl.as_tensor(np.reshape(expressions, expression.ufl_shape)) # type: ignore
144144

145145
# Rebuild scalar list-based graph representation
146146
S = build_scalar_graph(expression)
@@ -173,10 +173,10 @@ def compute_integral_ir(
173173
k += 1
174174

175175
# Get list of indices in F which are the arguments (should be at start)
176-
argkeys: set[int] = set()
176+
_argkeys: set[int] = set()
177177
for w in argument_factorization:
178-
argkeys = argkeys | set(w)
179-
argkeys = list(argkeys)
178+
_argkeys = _argkeys | set(w)
179+
argkeys = list(_argkeys)
180180

181181
# Build set of modified_terminals for each mt factorized vertex in F
182182
# and attach tables, if appropriate
@@ -216,12 +216,13 @@ def compute_integral_ir(
216216
assert tr.block_size is not None
217217
dofmap = tuple(begin + i * tr.block_size for i in range(num_dofs))
218218
_blockmap.append(dofmap)
219+
blockmap = tuple(_blockmap)
219220

220221
block_is_uniform = all(tr.is_uniform for tr in trs)
221222

222223
# Collect relevant restrictions to identify blocks correctly
223224
# in interior facet integrals
224-
225+
225226
# Collect relevant restrictions to identify blocks correctly
226227
# in interior facet integrals
227228
_block_restrictions: list[str] = []
@@ -267,8 +268,9 @@ def compute_integral_ir(
267268
tr = v.get("tr")
268269
if tr is not None and F.nodes[i]["status"] != "inactive":
269270
if tr.has_tensor_factorisation:
271+
assert tr.tensor_factors is not None
270272
for t in tr.tensor_factors:
271-
active_table_names.add(t.name)
273+
active_table_names.add(t.name)
272274
else:
273275
active_table_names.add(tr.name)
274276

@@ -277,6 +279,7 @@ def compute_integral_ir(
277279
for blockdata in contributions:
278280
for mad in blockdata.ma_data:
279281
if mad.tabledata.has_tensor_factorisation:
282+
assert mad.tabledata.tensor_factors is not None
280283
for t in mad.tabledata.tensor_factors:
281284
active_table_names.add(t.name)
282285
else:

0 commit comments

Comments
 (0)