2
2
from functools import reduce
3
3
import numpy
4
4
from firedrake import Function , FunctionSpace , MixedVectorSpaceBasis , Constant
5
- from ufl .algorithms .analysis import extract_type , has_exact_type
6
- from ufl .algorithms .map_integrands import map_integrand_dags
7
- from ufl .classes import CoefficientDerivative , Zero
8
- from ufl .constantvalue import as_ufl
9
- from ufl .corealg .multifunction import MultiFunction
10
- from ufl .tensors import as_tensor
5
+ from ufl .algorithms .analysis import extract_type
6
+ from ufl import as_tensor , zero
7
+ from ufl import replace as ufl_replace
11
8
12
9
from irksome .deriv import TimeDerivative
13
10
@@ -54,58 +51,10 @@ def getNullspace(V, Vbig, num_stages, nullspace):
54
51
return nspnew
55
52
56
53
57
- # Update for UFL's replace that performs post-order traversal and hence replaces
58
- # more complicated expressions first.
59
- class MyReplacer (MultiFunction ):
60
- def __init__ (self , mapping ):
61
- super ().__init__ ()
62
- self .replacements = mapping
63
- if not all (k .ufl_shape == v .ufl_shape for k , v in mapping .items ()):
64
- raise ValueError ("Replacement expressions must have the same shape as what they replace." )
65
-
66
- def expr (self , o ):
67
- if o in self .replacements :
68
- return self .replacements [o ]
69
- else :
70
- return self .reuse_if_untouched (o , * map (self , o .ufl_operands ))
71
-
72
-
73
54
def replace (e , mapping ):
74
- """Replace subexpressions in expression.
75
-
76
- @param e:
77
- An Expr or Form.
78
- @param mapping:
79
- A dict with from:to replacements to perform.
80
- """
81
- mapping2 = dict ((k , as_ufl (v )) for (k , v ) in mapping .items ())
82
-
83
- # Workaround for problem with delayed derivative evaluation
84
- # The problem is that J = derivative(f(g, h), g) does not evaluate immediately
85
- # So if we subsequently do replace(J, {g: h}) we end up with an expression:
86
- # derivative(f(h, h), h)
87
- # rather than what were were probably thinking of:
88
- # replace(derivative(f(g, h), g), {g: h})
89
- #
90
- # To fix this would require one to expand derivatives early (which
91
- # is not attractive), or make replace lazy too.
92
- if has_exact_type (e , CoefficientDerivative ):
93
- # Hack to avoid circular dependencies
94
- from ufl .algorithms .ad import expand_derivatives
95
- e = expand_derivatives (e )
96
-
97
- return map_integrand_dags (MyReplacer (mapping2 ), e )
98
-
99
-
100
- def component_replace (e , mapping ):
101
- """Replace, recurring on components"""
102
- cmapping = {}
103
- for key , value in mapping .items ():
104
- cmapping [key ] = as_tensor (value )
105
- if key .ufl_shape :
106
- for j in numpy .ndindex (key .ufl_shape ):
107
- cmapping [key [j ]] = value [j ]
108
- return replace (e , cmapping )
55
+ """A wrapper for ufl.replace that allows numpy arrays."""
56
+ cmapping = {k : as_tensor (v ) for k , v in mapping .items ()}
57
+ return ufl_replace (e , cmapping )
109
58
110
59
111
60
# Utility functions that help us refactor
@@ -120,11 +69,13 @@ def IA(A):
120
69
def is_ode (f , u ):
121
70
"""Given a form defined over a function `u`, checks if
122
71
(each bit of) u appears under a time derivative."""
123
- blah = extract_type (f , TimeDerivative )
124
-
125
- Dtbits = set (b .ufl_operands [0 ] for b in blah )
126
- ubits = set (u [i ] for i in numpy .ndindex (u .ufl_shape ))
127
- return Dtbits == ubits
72
+ derivs = extract_type (f , TimeDerivative )
73
+ Dtbits = []
74
+ for k in derivs :
75
+ op , = k .ufl_operands
76
+ Dtbits .extend (op [i ] for i in numpy .ndindex (op .ufl_shape ))
77
+ ubits = [u [i ] for i in numpy .ndindex (u .ufl_shape )]
78
+ return set (Dtbits ) == set (ubits )
128
79
129
80
130
81
# Utility class for constants on a mesh
@@ -139,7 +90,7 @@ def Constant(self, val=0.0):
139
90
140
91
def ConstantOrZero (x , MC = None ):
141
92
const = MC .Constant if MC else Constant
142
- return Zero () if abs (complex (x )) < 1.e-10 else const (x )
93
+ return zero () if abs (complex (x )) < 1.e-10 else const (x )
143
94
144
95
145
96
vecconst = numpy .vectorize (ConstantOrZero )
0 commit comments