Skip to content

Commit 3415139

Browse files
authored
Return AdjFloat of an AdjFloat summation (#181)
* Return AdjFloat of an AdjFloat summation and check the control update in rf
1 parent 16e6434 commit 3415139

File tree

6 files changed

+26
-8
lines changed

6 files changed

+26
-8
lines changed

pyadjoint/adjfloat.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ def __init__(self, *args):
342342
self.add_dependency(dep)
343343

344344
def recompute_component(self, inputs, block_variable, idx, prepared):
345-
return self.operator(*(term.saved_output for term in self.terms))
345+
output = self.operator(*(term.saved_output for term in self.terms))
346+
return self._outputs[0].saved_output._ad_convert_type(output)
346347

347348
def __str__(self):
348349
return f"{self.terms[0]} {self.symbol} {self.terms[1]}"

pyadjoint/reduced_functional.py

+16
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .enlisting import Enlist
33
from .tape import get_working_tape, stop_annotating, no_annotations
44
from .overloaded_type import OverloadedType, create_overloaded_object
5+
from .adjfloat import AdjFloat
56

67

78
def _get_extract_derivative_components(derivative_components):
@@ -196,6 +197,21 @@ def __call__(self, values):
196197
if len(values) != len(self.controls):
197198
raise ValueError("values should be a list of same length as controls.")
198199

200+
for i, value in enumerate(values):
201+
control_type = type(self.controls[i].control)
202+
if isinstance(value, (int, float)) and control_type is AdjFloat:
203+
value = self.controls[i].control._ad_convert_type(value)
204+
elif not isinstance(value, control_type):
205+
if len(values) == 1:
206+
raise TypeError(
207+
"Control value must be an `OverloadedType` object with the same "
208+
f"type as the control, which is {control_type}"
209+
)
210+
else:
211+
raise TypeError(
212+
f"The control at index {i} must be an `OverloadedType` object "
213+
f"with the same type as the control, which is {control_type}"
214+
)
199215
# Call callback.
200216
self.eval_cb_pre(self.controls.delist(values))
201217

pyadjoint/tape.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def continue_annotation():
2929

3030
class set_working_tape(ContextDecorator):
3131
"""Set a new tape as the working tape.
32-
32+
3333
This class can be used in three ways:
3434
1) as a free function to replace the working tape,
3535
2) as a context manager within which a new tape is set as the working tape,

tests/firedrake_adjoint/test_burgers_newton.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,15 @@ def J(ic, solve_type, timestep, steps, V):
9595
solver = NonlinearVariationalSolver(problem)
9696

9797
tape = get_working_tape()
98+
J = 0.0
9899
for _ in tape.timestepper(range(steps)):
99100
if solve_type == "NLVS":
100101
solver.solve()
101102
else:
102103
solve(F == 0, u, bc)
103104
u_.assign(u)
104-
105-
return assemble(u_*u_*dx + ic*ic*dx)
105+
J += assemble(u_*u_*dx + ic*ic*dx)
106+
return J
106107

107108

108109
@pytest.mark.parametrize("solve_type, checkpointing",

tests/firedrake_adjoint/test_external_modification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_external_modification():
1414
v1 = Function(fs)
1515
v2 = Function(fs)
1616

17-
u.assign(1.)
17+
u.interpolate(1.)
1818
v1.project(u)
1919
with stop_annotating(modifies=u):
2020
u.dat.data[:] = 2.
@@ -23,4 +23,4 @@ def test_external_modification():
2323
J = assemble(v1*dx + v2*dx)
2424
Jhat = ReducedFunctional(J, Control(u))
2525

26-
assert np.allclose(J, Jhat(2))
26+
assert np.allclose(J, Jhat(Function(fs).interpolate(2.)))

tests/firedrake_adjoint/test_tlm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_tlm_bc():
6666
c.block_variable.tlm_value = Function(R, val=1)
6767
tape.evaluate_tlm()
6868

69-
assert (taylor_test(Jhat, Constant(c), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9)
69+
assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
7070

7171

7272
def test_tlm_func():
@@ -234,7 +234,7 @@ def test_projection():
234234

235235
k.block_variable.tlm_value = Constant(1)
236236
tape.evaluate_tlm()
237-
assert(taylor_test(Jhat, Constant(k), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9)
237+
assert(taylor_test(Jhat, k, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)
238238

239239

240240
def test_projection_function():

0 commit comments

Comments
 (0)