Skip to content

Commit ba50d5d

Browse files
authored
TimeDerivativeRuleDispatcher: handle generic UFL types (#132)
* Fix TimeDerivativeRuleDispatcher * add a test
1 parent 73967a7 commit ba50d5d

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

irksome/deriv.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,7 @@ def time_derivative(self, o):
7979
f, = o.ufl_operands
8080
return map_expr_dag(self.rules, f)
8181

82-
expr = MultiFunction.reuse_if_untouched
83-
terminal = expr
84-
derivative = expr
85-
grad = expr
86-
curl = expr
87-
div = expr
82+
ufl_type = MultiFunction.reuse_if_untouched
8883

8984

9085
def apply_time_derivatives(expression, t=None, timedep_coeffs=None):

tests/test_differentiation.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from ufl.algorithms import expand_derivatives
33
from irksome import MeshConstant, Dt, expand_time_derivatives
4-
from firedrake import Constant, diff, dot, FunctionSpace, Function, sin, UnitIntervalMesh, VectorFunctionSpace
4+
from firedrake import Cofunction, Constant, diff, dot, dx, FunctionSpace, Function, inner, sin, TestFunction, UnitIntervalMesh, VectorFunctionSpace
55

66

77
@pytest.fixture
@@ -79,3 +79,16 @@ def test_expand_second_derivative_product_rule(V):
7979
+ dot(u, Dt(w, 2)))
8080
# UFL equality is failing here due to different index numbers
8181
assert str(expr) == str(expected)
82+
83+
84+
def test_cofunction(V):
85+
u = Function(V)
86+
v = TestFunction(V)
87+
c = Cofunction(V.dual())
88+
expr = inner(Dt(u), v)*dx + c
89+
90+
# TimeDerivatives are already expanded
91+
# This is just to test that it can handle Cofunction properly
92+
expected = expr
93+
expr = expand_time_derivatives(expr)
94+
assert expr == expected

0 commit comments

Comments
 (0)