Skip to content

Commit 16e6434

Browse files
authored
Fix MixedCheckpointSchedule (#180)
* Test if checkpoints are cleaned correctly for multistep
1 parent 5f46e16 commit 16e6434

File tree

4 files changed

+139
-38
lines changed

4 files changed

+139
-38
lines changed

pyadjoint/checkpointing.py

+60-30
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from enum import Enum
22
import sys
33
from functools import singledispatchmethod
4-
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType
4+
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, \
5+
Forward, Reverse, StorageType, SingleMemoryStorageSchedule
56
# A callback interface allowing the user to provide a
67
# custom error message when disk checkpointing is not configured.
78
disk_checkpointing_callback = {}
@@ -78,6 +79,8 @@ def __init__(self, schedule, tape):
7879
# Tell the tape to only checkpoint input data until told otherwise.
7980
self.tape.latest_checkpoint = 0
8081
self.end_timestep(-1)
82+
self._keep_init_state_in_work = False
83+
self._adj_deps_cleaned = False
8184

8285
def end_timestep(self, timestep):
8386
"""Mark the end of one timestep when taping the forward model.
@@ -299,25 +302,50 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
299302
current_step.checkpoint(
300303
_store_checkpointable_state, _store_adj_dependencies)
301304

302-
if (
303-
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
304-
or not cp_action.write_adj_deps
305-
):
306-
to_keep = set()
307-
if step < (self.total_timesteps - 1):
308-
next_step = self.tape.timesteps[step + 1]
309-
# The checkpointable state set of the current step.
310-
to_keep = next_step.checkpointable_state
311-
if functional:
312-
to_keep = to_keep.union([functional.block_variable])
313-
for block in current_step:
314-
# Remove unnecessary variables from previous steps.
315-
for bv in block.get_outputs():
305+
to_keep = set()
306+
if step < (self.total_timesteps - 1):
307+
next_step = self.tape.timesteps[step + 1]
308+
# The checkpointable state set of the next step.
309+
to_keep = next_step.checkpointable_state
310+
if functional:
311+
to_keep = to_keep.union([functional.block_variable])
312+
313+
for var in current_step.checkpointable_state - to_keep:
314+
# Handle the case where step is 0
315+
if step == 0 and var not in current_step._checkpoint:
316+
# Ensure initialisation state is kept.
317+
self._keep_init_state_in_work = True
318+
break
319+
320+
# Handle the case for SingleMemoryStorageSchedule
321+
if isinstance(self._schedule, SingleMemoryStorageSchedule):
322+
if step > 1 and var not in self.tape.timesteps[step - 1].adjoint_dependencies:
323+
var._checkpoint = None
324+
continue
325+
326+
# Handle variables in the initial timestep
327+
if (
328+
var in self.tape.timesteps[0].checkpointable_state
329+
and self._keep_init_state_in_work
330+
):
331+
continue
332+
333+
# Clear the checkpoint for other cases
334+
var._checkpoint = None
335+
336+
for block in current_step:
337+
# Remove unnecessary variables from previous steps.
338+
for bv in block.get_outputs():
339+
if (
340+
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
341+
or not cp_action.write_adj_deps
342+
):
316343
if bv not in to_keep:
317344
bv._checkpoint = None
318-
# Remove unnecessary variables from previous steps.
319-
for var in (current_step.checkpointable_state - to_keep):
320-
var._checkpoint = None
345+
else:
346+
if bv not in current_step.adjoint_dependencies.union(to_keep):
347+
bv._checkpoint = None
348+
321349
step += 1
322350
if cp_action.storage == StorageType.DISK:
323351
# Activate disk checkpointing only in the checkpointing process.
@@ -333,22 +361,24 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
333361
current_step = self.tape.timesteps[step]
334362
for block in reversed(current_step):
335363
block.evaluate_adj(markings=markings)
364+
if not self._adj_deps_cleaned:
365+
for out in block._outputs:
366+
if not out.marked_in_path:
367+
current_step.adjoint_dependencies.discard(out)
368+
self._adj_deps_cleaned = True
336369
# Output variables are used for the last time when running
337370
# backwards.
371+
to_keep = current_step.checkpointable_state
372+
if functional:
373+
to_keep = to_keep.union([functional.block_variable])
338374
for block in current_step:
339375
block.reset_adjoint_state()
340-
for var in block.get_outputs():
341-
var.checkpoint = None
342-
var.reset_variables(("tlm",))
343-
if not var.is_control:
344-
var.reset_variables(("adjoint", "hessian"))
345-
if cp_action.clear_adj_deps:
346-
to_keep = current_step.checkpointable_state
347-
if functional:
348-
to_keep = to_keep.union([functional.block_variable])
349-
for output in block.get_outputs():
350-
if output not in to_keep:
351-
output._checkpoint = None
376+
for out in block.get_outputs():
377+
out.reset_variables(("tlm",))
378+
if not out.is_control:
379+
out.reset_variables(("adjoint", "hessian"))
380+
if cp_action.clear_adj_deps and out not in to_keep:
381+
out._checkpoint = None
352382

353383
@process_operation.register(Copy)
354384
def _(self, cp_action, progress_bar, **kwargs):

tests/firedrake_adjoint/test_burgers_newton.py

+70-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
set_log_level(CRITICAL)
1414
continue_annotation()
1515

16+
1617
def basics():
1718
n = 30
1819
mesh = UnitIntervalMesh(n)
@@ -21,13 +22,67 @@ def basics():
2122
steps = int(end/float(timestep)) + 1
2223
return mesh, timestep, steps
2324

25+
2426
def Dt(u, u_, timestep):
2527
return (u - u_)/timestep
2628

2729

30+
def _check_forward(tape):
31+
for current_step in tape.timesteps[1:-1]:
32+
for block in current_step:
33+
for deps in block.get_dependencies():
34+
if (
35+
deps not in tape.timesteps[0].checkpointable_state
36+
and deps not in tape.timesteps[-1].checkpointable_state
37+
):
38+
assert deps._checkpoint is None
39+
for out in block.get_outputs():
40+
if out not in tape.timesteps[-1].checkpointable_state:
41+
assert out._checkpoint is None
42+
43+
44+
def _check_recompute(tape):
45+
for current_step in tape.timesteps[1:-1]:
46+
for block in current_step:
47+
for deps in block.get_dependencies():
48+
if deps not in tape.timesteps[0].checkpointable_state:
49+
assert deps._checkpoint is None
50+
for out in block.get_outputs():
51+
assert out._checkpoint is None
52+
53+
for block in tape.timesteps[0]:
54+
for out in block.get_outputs():
55+
assert out._checkpoint is None
56+
for block in tape.timesteps[len(tape.timesteps)-1]:
57+
for deps in block.get_dependencies():
58+
if (
59+
deps not in tape.timesteps[0].checkpointable_state
60+
and deps not in tape.timesteps[len(tape.timesteps)-1].adjoint_dependencies
61+
):
62+
assert deps._checkpoint is None
63+
64+
65+
def _check_reverse(tape):
66+
for step, current_step in enumerate(tape.timesteps):
67+
if step > 0:
68+
for block in current_step:
69+
for deps in block.get_dependencies():
70+
if deps not in tape.timesteps[0].checkpointable_state:
71+
assert deps._checkpoint is None
72+
73+
for out in block.get_outputs():
74+
assert out._checkpoint is None
75+
assert out.adj_value is None
76+
77+
for block in current_step:
78+
for out in block.get_outputs():
79+
assert out._checkpoint is None
80+
81+
2882
def J(ic, solve_type, timestep, steps, V):
29-
u_ = Function(V)
30-
u = Function(V)
83+
84+
u_ = Function(V, name="u_")
85+
u = Function(V, name="u")
3186
v = TestFunction(V)
3287
u_.assign(ic)
3388
nu = Constant(0.0001)
@@ -84,17 +139,28 @@ def test_burgers_newton(solve_type, checkpointing):
84139
mesh = checkpointable_mesh(mesh)
85140
x, = SpatialCoordinate(mesh)
86141
V = FunctionSpace(mesh, "CG", 2)
87-
ic = project(sin(2. * pi * x), V)
142+
ic = project(sin(2. * pi * x), V, name="ic")
88143
val = J(ic, solve_type, timestep, steps, V)
89144
if checkpointing:
90145
assert len(tape.timesteps) == steps
146+
if checkpointing == "Revolve" or checkpointing == "Mixed":
147+
_check_forward(tape)
148+
91149
Jhat = ReducedFunctional(val, Control(ic))
92150
if checkpointing != "NoneAdjoint":
93151
dJ = Jhat.derivative()
152+
if checkpointing is not None:
153+
# Check if the reverse checkpointing is working correctly.
154+
if checkpointing == "Revolve" or checkpointing == "Mixed":
155+
_check_reverse(tape)
94156

95157
# Recomputing the functional with a modified control variable
96158
# before the recompute test.
97159
Jhat(project(sin(pi*x), V))
160+
if checkpointing:
161+
# Check is the checkpointing is working correctly.
162+
if checkpointing == "Revolve" or checkpointing == "Mixed":
163+
_check_recompute(tape)
98164

99165
# Recompute test
100166
assert(np.allclose(Jhat(ic), val))
@@ -143,4 +209,4 @@ def test_checkpointing_validity(solve_type, checkpointing):
143209
Jhat = ReducedFunctional(val1, Control(ic))
144210
assert len(tape.timesteps) == steps
145211
assert np.allclose(val0, val1)
146-
assert np.allclose(dJ0.dat.data_ro[:], Jhat.derivative().dat.data_ro[:])
212+
assert np.allclose(dJ0.dat.data_ro[:], Jhat.derivative().dat.data_ro[:])

tests/firedrake_adjoint/test_checkpointing_multistep.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from firedrake import *
55
from firedrake.adjoint import *
6-
from checkpoint_schedules import Revolve
6+
from tests.firedrake_adjoint.test_burgers_newton import _check_forward, \
7+
_check_recompute, _check_reverse
8+
from checkpoint_schedules import MixedCheckpointSchedule, StorageType
79
import numpy as np
810
from collections import deque
911
continue_annotation()
@@ -42,15 +44,18 @@ def J(displacement_0):
4244
def test_multisteps():
4345
tape = get_working_tape()
4446
tape.progress_bar = ProgressBar
45-
tape.enable_checkpointing(Revolve(total_steps, 2))
47+
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
4648
displacement_0 = Function(V).assign(1.0)
4749
val = J(displacement_0)
50+
_check_forward(tape)
4851
c = Control(displacement_0)
4952
J_hat = ReducedFunctional(val, c)
5053
dJ = J_hat.derivative()
54+
_check_reverse(tape)
5155
# Recomputing the functional with a modified control variable
5256
# before the recompute test.
5357
J_hat(Function(V).assign(0.5))
58+
_check_recompute(tape)
5459
# Recompute test
5560
assert(np.allclose(J_hat(displacement_0), val))
5661
# Test recompute adjoint-based gradient
@@ -70,7 +75,7 @@ def test_validity():
7075
tape.clear_tape()
7176

7277
# With checkpointing.
73-
tape.enable_checkpointing(Revolve(total_steps, 2))
78+
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
7479
val = J(displacement_0)
7580
J_hat = ReducedFunctional(val, Control(displacement_0))
7681
dJ = J_hat.derivative()

tests/firedrake_adjoint/test_disk_checkpointing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,5 @@ def test_disk_checkpointing_error():
8787
# check the raise of the exception
8888
with pytest.raises(RuntimeError):
8989
tape.enable_checkpointing(SingleDiskStorageSchedule())
90-
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
90+
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
9191
"before checkpointing on the disk."

0 commit comments

Comments
 (0)