Skip to content

Commit 9fbb0b1

Browse files
authored
Dolci/tape recompute count (#172)
* Add a tape computed counter.
1 parent 87862e1 commit 9fbb0b1

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

pyadjoint/reduced_functional.py

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __call__(self, values):
204204

205205
self.tape.reset_blocks()
206206
blocks = self.tape.get_blocks()
207+
self.tape._recompute_count += 1
207208
with self.marked_controls():
208209
with stop_annotating():
209210
if self.tape._checkpoint_manager:

pyadjoint/tape.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class Tape(object):
163163
__slots__ = ["_blocks", "_tf_tensors", "_tf_added_blocks", "_nodes",
164164
"_tf_registered_blocks", "_bar", "_package_data",
165165
"_checkpoint_manager", "latest_checkpoint",
166-
"_eagerly_checkpoint_outputs"]
166+
"_eagerly_checkpoint_outputs", "_recompute_count"]
167167

168168
def __init__(self, blocks=None, package_data=None):
169169
# Initialize the list of blocks on the tape.
@@ -182,6 +182,8 @@ def __init__(self, blocks=None, package_data=None):
182182
self._checkpoint_manager = None
183183
# Whether to store the adjoint dependencies.
184184
self._eagerly_checkpoint_outputs = False
185+
# A counter for the number of tape recomputations.
186+
self._recompute_count = 0
185187

186188
def clear_tape(self):
187189
"""Clear the tape."""
@@ -196,6 +198,11 @@ def latest_timestep(self):
196198
"""The current time step to which blocks will be added."""
197199
return max(len(self._blocks.steps) - 1, 0)
198200

201+
@property
202+
def recompute_count(self):
203+
"""The number of times the tape has been recomputed."""
204+
return self._recompute_count
205+
199206
def end_timestep(self):
200207
"""Mark the end of a timestep when taping the forward model."""
201208
if self._checkpoint_manager:

tests/firedrake_adjoint/test_solving.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def J(f):
2828
J0 = J(f)
2929
rf = ReducedFunctional(J0, Control(f))
3030
assert_approx_equal(rf(f), J0)
31-
31+
assert rf.tape.recompute_count == 1
3232
_test_adjoint(J, f)
3333

3434

@@ -298,6 +298,8 @@ def test_two_nonlinear_solves():
298298
J = assemble(dot(u1, u1)*dx)
299299
rf = ReducedFunctional(J, c)
300300
assert taylor_test(rf, ui, Constant(0.1)) > 1.95
301+
# Taylor test recomputes the functional 5 times.
302+
assert rf.tape.recompute_count == 5
301303

302304

303305
def convergence_rates(E_values, eps_values):

0 commit comments

Comments
 (0)