Skip to content

Commit a8ee848

Browse files
Ig-dolcidham
andauthored
Disk checkpointing. (#173)
* Disk checkpointing. --------- Co-authored-by: David A. Ham <david.ham@imperial.ac.uk>
1 parent c7939a4 commit a8ee848

File tree

5 files changed

+106
-32
lines changed

5 files changed

+106
-32
lines changed

pyadjoint/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
annotate_tape, stop_annotating, pause_annotation, continue_annotation)
1414
from .adjfloat import AdjFloat, exp, log
1515
from .reduced_functional import ReducedFunctional
16+
from .checkpointing import disk_checkpointing_callback
1617
from .drivers import compute_gradient, compute_hessian, solve_adjoint
1718
from .verification import taylor_test, taylor_to_dict
1819
from .overloaded_type import OverloadedType, create_overloaded_object

pyadjoint/checkpointing.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import sys
33
from functools import singledispatchmethod
44
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType
5+
# A callback interface allowing the user to provide a
6+
# custom error message when disk checkpointing is not configured.
7+
disk_checkpointing_callback = {}
58

69

710
class CheckpointError(RuntimeError):
@@ -54,8 +57,8 @@ def __init__(self, schedule, tape):
5457
and not tape._package_data
5558
):
5659
raise CheckpointError(
57-
"The schedule employs disk checkpointing but it is not configured."
58-
)
60+
"The schedule employs disk checkpointing but it is not configured.\n"
61+
+ "\n".join(disk_checkpointing_callback.values()))
5962
self.tape = tape
6063
self._schedule = schedule
6164
self.forward_schedule = []
@@ -152,6 +155,13 @@ def _(self, cp_action, timestep):
152155
# Store the checkpoint data. This is the required data for
153156
# computing the adjoint model from the step `n1`.
154157
_store_adj_dependencies = True
158+
if (
159+
(_store_checkpointable_state or _store_adj_dependencies)
160+
and cp_action.storage == StorageType.DISK
161+
):
162+
for package in self.tape._package_data.values():
163+
package.continue_checkpointing()
164+
155165
self.tape.timesteps[timestep - 1].checkpoint(
156166
_store_checkpointable_state, _store_adj_dependencies)
157167
# Remove unnecessary variables in working memory from previous steps.
@@ -164,6 +174,11 @@ def _(self, cp_action, timestep):
164174
self.tape.get_blocks().append_step()
165175
if cp_action.write_ics:
166176
self.tape.latest_checkpoint = cp_action.n0
177+
178+
if cp_action.storage == StorageType.DISK:
179+
# Activate disk checkpointing only in the checkpointing process.
180+
for package in self.tape._package_data.values():
181+
package.pause_checkpointing()
167182
return True
168183
else:
169184
return False
@@ -186,11 +201,15 @@ def recompute(self, functional=None):
186201
if self.mode == Mode.RECORD:
187202
# Finalise the taping process.
188203
self.end_taping()
204+
if self._schedule.uses_storage_type(StorageType.DISK):
205+
# Clear the data of the current state before recomputing.
206+
for package in self.tape._package_data.values():
207+
package.reset()
189208
self.mode = Mode.RECOMPUTE
190209
with self.tape.progress_bar("Evaluating Functional", max=self.total_timesteps) as progress_bar:
191210
# Restore the initial condition to advance the forward model from the step 0.
192211
current_step = self.tape.timesteps[self.forward_schedule[0].n0]
193-
current_step.restore_from_checkpoint()
212+
current_step.restore_from_checkpoint(self.forward_schedule[0].storage)
194213
for cp_action in self.forward_schedule:
195214
self._current_action = cp_action
196215
self.process_operation(cp_action, progress_bar, functional=functional)
@@ -271,6 +290,12 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
271290
_store_checkpointable_state = True
272291
if cp_action.write_adj_deps:
273292
_store_adj_dependencies = True
293+
if (
294+
(_store_checkpointable_state or _store_adj_dependencies)
295+
and cp_action.storage == StorageType.DISK
296+
):
297+
for package in self.tape._package_data.values():
298+
package.continue_checkpointing()
274299
current_step.checkpoint(
275300
_store_checkpointable_state, _store_adj_dependencies)
276301

@@ -294,6 +319,10 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
294319
for var in (current_step.checkpointable_state - to_keep):
295320
var._checkpoint = None
296321
step += 1
322+
if cp_action.storage == StorageType.DISK:
323+
# Activate disk checkpointing only in the checkpointing process.
324+
for package in self.tape._package_data.values():
325+
package.pause_checkpointing()
297326

298327
@process_operation.register(Reverse)
299328
def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
@@ -324,12 +353,12 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
324353
@process_operation.register(Copy)
325354
def _(self, cp_action, progress_bar, **kwargs):
326355
current_step = self.tape.timesteps[cp_action.n]
327-
current_step.restore_from_checkpoint()
356+
current_step.restore_from_checkpoint(cp_action.from_storage)
328357

329358
@process_operation.register(Move)
330359
def _(self, cp_action, progress_bar, **kwargs):
331360
current_step = self.tape.timesteps[cp_action.n]
332-
current_step.restore_from_checkpoint()
361+
current_step.restore_from_checkpoint(cp_action.from_storage)
333362
current_step.delete_checkpoint()
334363

335364
@process_operation.register(EndForward)

pyadjoint/tape.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from itertools import chain
88
from typing import Optional, Iterable
99
from abc import ABC, abstractmethod
10-
from .checkpointing import CheckpointManager, CheckpointError
10+
from .checkpointing import CheckpointManager, CheckpointError, StorageType
1111

1212
_working_tape = None
1313
_annotation_enabled = False
@@ -293,7 +293,6 @@ def enable_checkpointing(self, schedule):
293293
Args:
294294
schedule (checkpoint_schedules.schedule): A schedule provided by the
295295
checkpoint_schedules package.
296-
max_n (int, optional): The number of total steps.
297296
"""
298297
if self._blocks:
299298
raise CheckpointError(
@@ -775,23 +774,31 @@ def checkpoint(self, checkpointable_state, adj_dependencies):
775774
Args:
776775
checkpointable_state (bool): If True, store the checkpointable state
777776
required to restart from the start of a timestep.
778-
adj_dependencies): (bool): If True, store the adjoint dependencies required
777+
adj_dependencies (bool): If True, store the adjoint dependencies required
779778
to compute the adjoint of a timestep.
780779
"""
781780
with stop_annotating():
782781
if checkpointable_state:
783782
for var in self.checkpointable_state:
784-
self._checkpoint[var] = var.checkpoint
783+
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
785784

786785
if adj_dependencies:
787786
for var in self.adjoint_dependencies:
788-
self._checkpoint[var] = var.checkpoint
787+
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
789788

790-
def restore_from_checkpoint(self):
789+
def restore_from_checkpoint(self, from_storage):
791790
"""Restore the block var checkpoints from the timestep checkpoint."""
792-
793-
for var in self._checkpoint:
794-
var.checkpoint = self._checkpoint[var]
791+
from .overloaded_type import OverloadedType
792+
for var, checkpoint in self._checkpoint.items():
793+
if (
794+
from_storage == StorageType.DISK
795+
and isinstance(checkpoint, OverloadedType)
796+
):
797+
# checkpoint._ad_restore_checkpoint should be able to restore
798+
# from disk.
799+
var.checkpoint = checkpoint._ad_restore_at_checkpoint(checkpoint)
800+
else:
801+
var.checkpoint = checkpoint
795802

796803
def delete_checkpoint(self):
797804
"""Delete the stored checkpoint references."""
@@ -881,10 +888,22 @@ def checkpoint(self):
881888

882889
@abstractmethod
883890
def restore_from_checkpoint(self, state):
884-
"""Restore state from a previously stored checkpioint."""
891+
"""Restore state from a previously stored checkpoint."""
885892
pass
886893

887894
@abstractmethod
888895
def copy(self):
889896
"""Produce a new copy of state to be passed to a copy of the tape."""
890897
pass
898+
899+
@abstractmethod
900+
def continue_checkpointing(self):
901+
"""Continue the checkpointing process on disk.
902+
"""
903+
pass
904+
905+
@abstractmethod
906+
def pause_checkpointing(self):
907+
"""Pause the checkpointing process on disk.
908+
"""
909+
pass

tests/firedrake_adjoint/test_burgers_newton.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212
import numpy as np
1313
set_log_level(CRITICAL)
1414
continue_annotation()
15-
n = 30
16-
mesh = UnitIntervalMesh(n)
17-
V = FunctionSpace(mesh, "CG", 2)
18-
end = 0.3
19-
timestep = Constant(1.0/n)
20-
steps = int(end/float(timestep)) + 1
2115

16+
def basics():
17+
n = 30
18+
mesh = UnitIntervalMesh(n)
19+
end = 0.3
20+
timestep = Constant(1.0/n)
21+
steps = int(end/float(timestep)) + 1
22+
return mesh, timestep, steps
2223

2324
def Dt(u, u_, timestep):
2425
return (u - u_)/timestep
2526

2627

27-
def J(ic, solve_type, checkpointing):
28+
def J(ic, solve_type, timestep, steps, V):
2829
u_ = Function(V)
2930
u = Function(V)
3031
v = TestFunction(V)
@@ -65,6 +66,7 @@ def J(ic, solve_type, checkpointing):
6566
def test_burgers_newton(solve_type, checkpointing):
6667
"""Adjoint-based gradient tests with and without checkpointing.
6768
"""
69+
mesh, timestep, steps = basics()
6870
tape = get_working_tape()
6971
tape.progress_bar = ProgressBar
7072
if checkpointing:
@@ -73,13 +75,17 @@ def test_burgers_newton(solve_type, checkpointing):
7375
if checkpointing == "SingleMemory":
7476
schedule = SingleMemoryStorageSchedule()
7577
if checkpointing == "Mixed":
76-
schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM)
78+
enable_disk_checkpointing()
79+
schedule = MixedCheckpointSchedule(steps, steps//3, storage=StorageType.DISK)
7780
if checkpointing == "NoneAdjoint":
7881
schedule = NoneCheckpointSchedule()
7982
tape.enable_checkpointing(schedule)
83+
if schedule.uses_storage_type(StorageType.DISK):
84+
mesh = checkpointable_mesh(mesh)
8085
x, = SpatialCoordinate(mesh)
86+
V = FunctionSpace(mesh, "CG", 2)
8187
ic = project(sin(2. * pi * x), V)
82-
val = J(ic, solve_type, checkpointing)
88+
val = J(ic, solve_type, timestep, steps, V)
8389
if checkpointing:
8490
assert len(tape.timesteps) == steps
8591
Jhat = ReducedFunctional(val, Control(ic))
@@ -109,13 +115,15 @@ def test_burgers_newton(solve_type, checkpointing):
109115
def test_checkpointing_validity(solve_type, checkpointing):
110116
"""Compare forward and backward results with and without checkpointing.
111117
"""
118+
mesh, timestep, steps = basics()
119+
V = FunctionSpace(mesh, "CG", 2)
112120
# Without checkpointing
113121
tape = get_working_tape()
114122
tape.progress_bar = ProgressBar
115123
x, = SpatialCoordinate(mesh)
116124
ic = project(sin(2.*pi*x), V)
117125

118-
val0 = J(ic, solve_type, False)
126+
val0 = J(ic, solve_type, timestep, steps, V)
119127
Jhat = ReducedFunctional(val0, Control(ic))
120128
dJ0 = Jhat.derivative()
121129
tape.clear_tape()
@@ -125,8 +133,13 @@ def test_checkpointing_validity(solve_type, checkpointing):
125133
if checkpointing == "Revolve":
126134
tape.enable_checkpointing(Revolve(steps, steps//3))
127135
if checkpointing == "Mixed":
128-
tape.enable_checkpointing(MixedCheckpointSchedule(steps, steps//3, storage=StorageType.RAM))
129-
val1 = J(ic, solve_type, True)
136+
enable_disk_checkpointing()
137+
tape.enable_checkpointing(MixedCheckpointSchedule(steps, steps//3, storage=StorageType.DISK))
138+
mesh = checkpointable_mesh(mesh)
139+
V = FunctionSpace(mesh, "CG", 2)
140+
x, = SpatialCoordinate(mesh)
141+
ic = project(sin(2.*pi*x), V)
142+
val1 = J(ic, solve_type, timestep, steps, V)
130143
Jhat = ReducedFunctional(val1, Control(ic))
131144
assert len(tape.timesteps) == steps
132145
assert np.allclose(val0, val1)

tests/firedrake_adjoint/test_disk_checkpointing.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from firedrake import *
55
from firedrake.__future__ import *
66
from firedrake.adjoint import *
7+
from firedrake.adjoint_utils.checkpointing import disk_checkpointing
78
import numpy as np
89
import os
10+
from checkpoint_schedules import SingleDiskStorageSchedule
911

1012

1113
def adjoint_example(fine, coarse):
@@ -54,19 +56,24 @@ def adjoint_example(fine, coarse):
5456
return Jnew, grad_Jnew
5557

5658

57-
def test_disk_checkpointing():
59+
@pytest.mark.parametrize("checkpoint_schedule", [True, False])
60+
def test_disk_checkpointing(checkpoint_schedule):
5861
# Use a Firedrake Tape subclass that supports disk checkpointing.
5962
set_working_tape(Tape())
6063
tape = get_working_tape()
6164
tape.clear_tape()
6265
enable_disk_checkpointing()
63-
66+
if checkpoint_schedule:
67+
tape.enable_checkpointing(SingleDiskStorageSchedule())
6468
fine = checkpointable_mesh(UnitSquareMesh(10, 10, name="fine"))
6569
coarse = checkpointable_mesh(UnitSquareMesh(4, 4, name="coarse"))
6670
J_disk, grad_J_disk = adjoint_example(fine, coarse)
6771

72+
if checkpoint_schedule:
73+
assert disk_checkpointing() is False
6874
tape.clear_tape()
69-
pause_disk_checkpointing()
75+
if not checkpoint_schedule:
76+
pause_disk_checkpointing()
7077

7178
J_mem, grad_J_mem = adjoint_example(fine, coarse)
7279

@@ -75,5 +82,10 @@ def test_disk_checkpointing():
7582
tape.clear_tape()
7683

7784

78-
if __name__ == "__main__":
79-
test_disk_checkpointing()
85+
def test_disk_checkpointing_error():
86+
tape = get_working_tape()
87+
# check the raise of the exception
88+
with pytest.raises(RuntimeError):
89+
tape.enable_checkpointing(SingleDiskStorageSchedule())
90+
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
91+
"before checkpointing on the disk."

0 commit comments

Comments
 (0)