Skip to content

Commit 1e432a9

Browse files
authored
Clean up UFL manipulation (#130)
* Clean up UFL manipulation * unfiy replace
1 parent 6867184 commit 1e432a9

9 files changed

+41
-103
lines changed

irksome/deriv.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
from ufl.algorithms.map_integrands import map_integrand_dags, map_expr_dag
66
from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset
77
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
8-
from ufl.tensors import ListTensor
9-
from ufl.indexed import Indexed
10-
from ufl.core.multiindex import FixedIndex
118

129

1310
@ufl_type(num_ops=1,
@@ -21,9 +18,6 @@ class TimeDerivative(Derivative):
2118
__slots__ = ()
2219

2320
def __new__(cls, f):
24-
if isinstance(f, ListTensor):
25-
# Push TimeDerivative inside ListTensor
26-
return ListTensor(*map(TimeDerivative, f.ufl_operands))
2721
return Derivative.__new__(cls)
2822

2923
def __init__(self, f):
@@ -32,14 +26,6 @@ def __init__(self, f):
3226
def __str__(self):
3327
return "d{%s}/dt" % (self.ufl_operands[0],)
3428

35-
def _simplify_indexed(self, multiindex):
36-
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
37-
# Push Indexed inside TimeDerivative
38-
if all(isinstance(i, FixedIndex) for i in multiindex):
39-
f, = self.ufl_operands
40-
return TimeDerivative(Indexed(f, multiindex))
41-
return Derivative._simplify_indexed(self, multiindex)
42-
4329

4430
def Dt(f, order=1):
4531
"""Short-hand function to produce a :class:`TimeDerivative` of a given order."""
@@ -72,9 +58,10 @@ def time_derivative(self, o, f):
7258
else:
7359
return map_expr_dag(self, f)
7460

75-
def _linear_op(self, o, f):
76-
return o._ufl_expr_reconstruct_(f)
61+
def _linear_op(self, o, *operands):
62+
return o._ufl_expr_reconstruct_(*operands)
7763

64+
indexed = _linear_op
7865
derivative = _linear_op
7966
grad = _linear_op
8067
curl = _linear_op

irksome/dirk_stepper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ufl.constantvalue import as_ufl
66

77
from .deriv import TimeDerivative, expand_time_derivatives
8-
from .tools import component_replace, replace, MeshConstant, vecconst
8+
from .tools import replace, MeshConstant, vecconst
99
from .bcs import bc2space
1010

1111

@@ -36,7 +36,7 @@ def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None):
3636
repl = {t: t + c * dt,
3737
u0: g + k * (a * dt),
3838
TimeDerivative(u0): k}
39-
stage_F = component_replace(F, repl)
39+
stage_F = replace(F, repl)
4040

4141
bcnew = []
4242

irksome/discontinuous_galerkin_stepper.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .bcs import stage2spaces4bc
88
from .deriv import expand_time_derivatives
99
from .manipulation import extract_terms, strip_dt_form
10-
from .tools import component_replace, replace, vecconst
10+
from .tools import replace, vecconst
1111
import numpy as np
1212
from firedrake import TestFunction
1313

@@ -97,21 +97,21 @@ def getFormDiscGalerkin(F, L, Q, t, dt, u0, stages, bcs=None):
9797
# Jump terms
9898
repl = {u0: u_np[0] - u0,
9999
v: v_np[0]}
100-
Fnew = component_replace(F_dtless, repl)
100+
Fnew = replace(F_dtless, repl)
101101

102102
# Terms with time derivatives
103103
for q in range(len(qpts)):
104104
repl = {t: t + qpts[q] * dt,
105105
v: vsub[q] * dt,
106106
u0: dtu0sub[q] / dt}
107-
Fnew += component_replace(F_dtless, repl)
107+
Fnew += replace(F_dtless, repl)
108108

109109
# Handle the rest of the terms
110110
for q in range(len(qpts)):
111111
repl = {t: t + qpts[q] * dt,
112112
v: vsub[q] * dt,
113113
u0: usub[q]}
114-
Fnew += component_replace(F_remainder, repl)
114+
Fnew += replace(F_remainder, repl)
115115

116116
# Oh, honey, is it the boundary conditions?
117117
if bcs is None:

irksome/galerkin_stepper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .base_time_stepper import StageCoupledTimeStepper
77
from .bcs import bc2space, stage2spaces4bc
88
from .deriv import TimeDerivative, expand_time_derivatives
9-
from .tools import component_replace, replace, vecconst
9+
from .tools import replace, vecconst
1010
import numpy as np
1111
from firedrake import TestFunction
1212

@@ -91,7 +91,7 @@ def getFormGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, bcs=None):
9191
v: vsub[q] * dt,
9292
u0: usub[q],
9393
dtu0: dtu0sub[q] / dt}
94-
Fnew += component_replace(F, repl)
94+
Fnew += replace(F, repl)
9595

9696
# Oh, honey, is it the boundary conditions?
9797
if bcs is None:

irksome/imex.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .ButcherTableaux import RadauIIA
1010
from .deriv import TimeDerivative, expand_time_derivatives
1111
from .stage_value import getFormStage
12-
from .tools import AI, ConstantOrZero, IA, MeshConstant, replace, component_replace, getNullspace, get_stage_space
12+
from .tools import AI, ConstantOrZero, IA, MeshConstant, replace, getNullspace, get_stage_space
1313
from .bcs import bc2space
1414

1515

@@ -66,15 +66,15 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None):
6666
for i in range(num_stages):
6767
# replace test function
6868
repl = {v: v_np[i]}
69-
Ftmp = component_replace(Fexp, repl)
69+
Ftmp = replace(Fexp, repl)
7070

7171
# replace the solution with stage values
7272
for j in range(num_stages):
7373
repl = {t: t + C[j] * dt,
7474
u0: u_np[j]}
7575

7676
# and sum the contribution
77-
replF = component_replace(Ftmp, repl)
77+
replF = replace(Ftmp, repl)
7878
Fit += Ait[i, j] * dt * replF
7979
Fprop += Aprop[i, j] * dt * replF
8080
elif splitting == IA:
@@ -84,23 +84,23 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None):
8484
u0: u_np[i],
8585
v: v_np[i]}
8686

87-
Fit += dt * component_replace(Fexp, repl)
87+
Fit += dt * replace(Fexp, repl)
8888

8989
# dense contribution to propagator
9090
AinvAexp = vecconst(np.linalg.solve(butch.A, Aexp))
9191

9292
for i in range(num_stages):
9393
# replace test function
9494
repl = {v: v_np[i]}
95-
Ftmp = component_replace(Fexp, repl)
95+
Ftmp = replace(Fexp, repl)
9696

9797
# replace the solution with stage values
9898
for j in range(num_stages):
9999
repl = {t: t + C[j] * dt,
100100
u0: u_np[j]}
101101

102102
# and sum the contribution
103-
Fprop += AinvAexp[i, j] * dt * component_replace(Ftmp, repl)
103+
Fprop += AinvAexp[i, j] * dt * replace(Ftmp, repl)
104104
else:
105105
raise NotImplementedError(
106106
"Must specify splitting to either IA or AI")
@@ -329,13 +329,13 @@ def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None):
329329
repl = {t: t + c * dt,
330330
u0: g + dt * a * k,
331331
TimeDerivative(u0): k}
332-
stage_F = component_replace(F, repl)
332+
stage_F = replace(F, repl)
333333

334334
# Explicit replacement, solve at time t + chat * dt, for khat
335335
replhat = {t: t + chat * dt,
336336
u0: ghat}
337337

338-
Fhat = inner(khat, vhat)*dx + component_replace(Fexp, replhat)
338+
Fhat = inner(khat, vhat)*dx + replace(Fexp, replhat)
339339

340340
bcnew = []
341341

irksome/nystrom_stepper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .base_time_stepper import StageCoupledTimeStepper
22
from .bcs import BCStageData, bc2space
33
from .deriv import Dt, TimeDerivative, expand_time_derivatives
4-
from .tools import component_replace, replace, vecconst
4+
from .tools import replace, vecconst
55
from firedrake import TestFunction, as_ufl
66
import numpy
77
from ufl import zero
@@ -106,7 +106,7 @@ def getFormNystrom(F, tableau, t, dt, u0, ut0, stages,
106106
u0: u0 + ut0 * (c[i] * dt) + Abark[i] * dt**2,
107107
dtu: ut0 + Ak[i] * dt,
108108
dt2u: k_np[i]}
109-
Fnew += component_replace(F, repl)
109+
Fnew += replace(F, repl)
110110

111111
if bcs is None:
112112
bcs = []

irksome/stage_derivative.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from firedrake import assemble, dx, inner, norm
66

77
from ufl.constantvalue import as_ufl, zero
8-
from .tools import component_replace, replace, AI, vecconst
8+
from .tools import AI, replace, vecconst
99
from .deriv import Dt, TimeDerivative, expand_time_derivatives
1010
from .bcs import EmbeddedBCData, BCStageData, bc2space
1111
from .manipulation import extract_terms
@@ -84,7 +84,7 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):
8484
v: v_np[i],
8585
u0: u0 + A1w[i] * dt,
8686
dtu: A2invw[i]}
87-
Fnew += component_replace(F, repl)
87+
Fnew += replace(F, repl)
8888

8989
if bcs is None:
9090
bcs = []

irksome/stage_value.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .ButcherTableaux import CollocationButcherTableau
1212
from .deriv import expand_time_derivatives
1313
from .manipulation import extract_terms, strip_dt_form
14-
from .tools import AI, is_ode, replace, component_replace, vecconst
14+
from .tools import AI, is_ode, replace, vecconst
1515
from .base_time_stepper import StageCoupledTimeStepper
1616

1717

@@ -117,15 +117,15 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo
117117
repl = {t: t + c[i] * dt,
118118
v: A2invTv[i],
119119
u0: w_np[i] - u0}
120-
Fnew += component_replace(F_dtless, repl)
120+
Fnew += replace(F_dtless, repl)
121121

122122
# Handle the rest of the terms
123123
for i in range(num_stages):
124124
# replace the solution with stage values
125125
repl = {t: t + c[i] * dt,
126126
v: A1Tv[i] * dt,
127127
u0: w_np[i]}
128-
Fnew += component_replace(F_remainder, repl)
128+
Fnew += replace(F_remainder, repl)
129129

130130
if bcs is None:
131131
bcs = []
@@ -212,7 +212,7 @@ def get_update_solver(self, update_solver_parameters):
212212
for i in range(self.num_stages):
213213
repl = {t: t + C[i] * dt,
214214
u0: u_np[i]}
215-
Fupdate += dt * B[i] * component_replace(split_form.remainder, repl)
215+
Fupdate += dt * B[i] * replace(split_form.remainder, repl)
216216

217217
# And the BC's for the update -- just the original BC at t+dt
218218
update_bcs = []

irksome/tools.py

+14-63
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@
22
from functools import reduce
33
import numpy
44
from firedrake import Function, FunctionSpace, MixedVectorSpaceBasis, Constant
5-
from ufl.algorithms.analysis import extract_type, has_exact_type
6-
from ufl.algorithms.map_integrands import map_integrand_dags
7-
from ufl.classes import CoefficientDerivative, Zero
8-
from ufl.constantvalue import as_ufl
9-
from ufl.corealg.multifunction import MultiFunction
10-
from ufl.tensors import as_tensor
5+
from ufl.algorithms.analysis import extract_type
6+
from ufl import as_tensor, zero
7+
from ufl import replace as ufl_replace
118

129
from irksome.deriv import TimeDerivative
1310

@@ -54,58 +51,10 @@ def getNullspace(V, Vbig, num_stages, nullspace):
5451
return nspnew
5552

5653

57-
# Update for UFL's replace that performs post-order traversal and hence replaces
58-
# more complicated expressions first.
59-
class MyReplacer(MultiFunction):
60-
def __init__(self, mapping):
61-
super().__init__()
62-
self.replacements = mapping
63-
if not all(k.ufl_shape == v.ufl_shape for k, v in mapping.items()):
64-
raise ValueError("Replacement expressions must have the same shape as what they replace.")
65-
66-
def expr(self, o):
67-
if o in self.replacements:
68-
return self.replacements[o]
69-
else:
70-
return self.reuse_if_untouched(o, *map(self, o.ufl_operands))
71-
72-
7354
def replace(e, mapping):
74-
"""Replace subexpressions in expression.
75-
76-
@param e:
77-
An Expr or Form.
78-
@param mapping:
79-
A dict with from:to replacements to perform.
80-
"""
81-
mapping2 = dict((k, as_ufl(v)) for (k, v) in mapping.items())
82-
83-
# Workaround for problem with delayed derivative evaluation
84-
# The problem is that J = derivative(f(g, h), g) does not evaluate immediately
85-
# So if we subsequently do replace(J, {g: h}) we end up with an expression:
86-
# derivative(f(h, h), h)
87-
# rather than what were were probably thinking of:
88-
# replace(derivative(f(g, h), g), {g: h})
89-
#
90-
# To fix this would require one to expand derivatives early (which
91-
# is not attractive), or make replace lazy too.
92-
if has_exact_type(e, CoefficientDerivative):
93-
# Hack to avoid circular dependencies
94-
from ufl.algorithms.ad import expand_derivatives
95-
e = expand_derivatives(e)
96-
97-
return map_integrand_dags(MyReplacer(mapping2), e)
98-
99-
100-
def component_replace(e, mapping):
101-
"""Replace, recurring on components"""
102-
cmapping = {}
103-
for key, value in mapping.items():
104-
cmapping[key] = as_tensor(value)
105-
if key.ufl_shape:
106-
for j in numpy.ndindex(key.ufl_shape):
107-
cmapping[key[j]] = value[j]
108-
return replace(e, cmapping)
55+
"""A wrapper for ufl.replace that allows numpy arrays."""
56+
cmapping = {k: as_tensor(v) for k, v in mapping.items()}
57+
return ufl_replace(e, cmapping)
10958

11059

11160
# Utility functions that help us refactor
@@ -120,11 +69,13 @@ def IA(A):
12069
def is_ode(f, u):
12170
"""Given a form defined over a function `u`, checks if
12271
(each bit of) u appears under a time derivative."""
123-
blah = extract_type(f, TimeDerivative)
124-
125-
Dtbits = set(b.ufl_operands[0] for b in blah)
126-
ubits = set(u[i] for i in numpy.ndindex(u.ufl_shape))
127-
return Dtbits == ubits
72+
derivs = extract_type(f, TimeDerivative)
73+
Dtbits = []
74+
for k in derivs:
75+
op, = k.ufl_operands
76+
Dtbits.extend(op[i] for i in numpy.ndindex(op.ufl_shape))
77+
ubits = [u[i] for i in numpy.ndindex(u.ufl_shape)]
78+
return set(Dtbits) == set(ubits)
12879

12980

13081
# Utility class for constants on a mesh
@@ -139,7 +90,7 @@ def Constant(self, val=0.0):
13990

14091
def ConstantOrZero(x, MC=None):
14192
const = MC.Constant if MC else Constant
142-
return Zero() if abs(complex(x)) < 1.e-10 else const(x)
93+
return zero() if abs(complex(x)) < 1.e-10 else const(x)
14394

14495

14596
vecconst = numpy.vectorize(ConstantOrZero)

0 commit comments

Comments
 (0)