Skip to content

Commit c7bf2ec

Browse files
Optional garbage collection and CheckpointManager._global_deps (#187)
* Optional gc_collect * Add global_deps * Add _adj_deps_cleaned into TimeStep --------- Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
1 parent 86571c1 commit c7bf2ec

File tree

2 files changed

+81
-13
lines changed

2 files changed

+81
-13
lines changed

pyadjoint/checkpointing.py

+55-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from enum import Enum
22
import sys
3+
import gc
4+
import logging
35
from functools import singledispatchmethod
46
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, \
57
Forward, Reverse, StorageType, SingleMemoryStorageSchedule
@@ -37,6 +39,14 @@ class CheckpointManager:
3739
Args:
3840
schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
3941
tape (Tape): A list of blocks :class:`Block` instances.
42+
gc_timestep_frequency (None or int): The number of timesteps between garbage collections. The default
43+
is `None`, which means the garbage collector is not invoked during execution. If an integer is
44+
provided, the garbage collector is invoked every `gc_timestep_frequency` timesteps. This is useful when
45+
Python fails to track and clean up all checkpoint objects in memory properly.
46+
gc_generation (int): The generation for garbage collection. The default is 2, which runs a full collection.
47+
For more information about garbage collector generations,
48+
please refer to the `documentation
49+
<https://docs.python.org/3/library/gc.html#gc.collect>`_.
4050
4151
Attributes:
4252
tape (Tape): A list of blocks :class:`Block` instances.
@@ -52,7 +62,7 @@ class CheckpointManager:
5262
_current_action (checkpoint_schedules.CheckpointAction): The current `checkpoint_schedules` action.
5363
5464
"""
55-
def __init__(self, schedule, tape):
65+
def __init__(self, schedule, tape, gc_timestep_frequency=None, gc_generation=2):
5666
if (
5767
schedule.uses_storage_type(StorageType.DISK)
5868
and not tape._package_data
@@ -78,9 +88,15 @@ def __init__(self, schedule, tape):
7888
self.forward_schedule.append(self._current_action)
7989
# Tell the tape to only checkpoint input data until told otherwise.
8090
self.tape.latest_checkpoint = 0
81-
self.end_timestep(-1)
8291
self._keep_init_state_in_work = False
83-
self._adj_deps_cleaned = False
92+
self._gc_timestep_frequency = gc_timestep_frequency
93+
self._gc_generation = gc_generation
94+
# ``self._global_deps`` stores checkpoint dependencies that remain unchanged across
95+
# timesteps (``self.tape.timesteps``). During the forward taping process, the code
96+
# checks whether a dependency is in ``self._global_deps`` to avoid unnecessary clearing
97+
# and recreation of its checkpoint data.
98+
self._global_deps = set()
99+
self.end_timestep(-1)
84100

85101
def end_timestep(self, timestep):
86102
"""Mark the end of one timestep when taping the forward model.
@@ -164,11 +180,30 @@ def _(self, cp_action, timestep):
164180
):
165181
for package in self.tape._package_data.values():
166182
package.continue_checkpointing()
183+
if timestep == 1:
184+
# Store the possible global dependencies.
185+
for deps in self.tape.timesteps[timestep - 1].checkpointable_state:
186+
self._global_deps.add(deps)
187+
else:
188+
# Check if the block variables stored in `self._global_deps` are still
189+
# dependencies in the previous timestep. If not, remove them from the
190+
# global dependencies.
191+
deps_to_clear = self._global_deps.difference(
192+
self.tape.timesteps[timestep - 1].checkpointable_state)
193+
194+
# Remove the block variables that are not global dependencies.
195+
self._global_deps.difference_update(deps_to_clear)
196+
197+
# For non-global dependencies, checkpoint storage occurs at a self.tape
198+
# timestep only when required by an action from the schedule. Thus, we
199+
# have to clear the checkpoint of block variables excluded from the self._global_deps.
200+
for deps in deps_to_clear:
201+
deps._checkpoint = None
167202

168203
self.tape.timesteps[timestep - 1].checkpoint(
169-
_store_checkpointable_state, _store_adj_dependencies)
204+
_store_checkpointable_state, _store_adj_dependencies, self._global_deps)
170205
# Remove unnecessary variables in working memory from previous steps.
171-
for var in self.tape.timesteps[timestep - 1].checkpointable_state:
206+
for var in self.tape.timesteps[timestep - 1].checkpointable_state - self._global_deps:
172207
var._checkpoint = None
173208
for block in self.tape.timesteps[timestep - 1]:
174209
for out in block.get_outputs():
@@ -182,6 +217,10 @@ def _(self, cp_action, timestep):
182217
# Activate disk checkpointing only in the checkpointing process.
183218
for package in self.tape._package_data.values():
184219
package.pause_checkpointing()
220+
221+
if isinstance(self._gc_timestep_frequency, int) and timestep % self._gc_timestep_frequency == 0:
222+
logging.info("Running a garbage collection cycle")
223+
gc.collect(self._gc_generation)
185224
return True
186225
else:
187226
return False
@@ -300,7 +339,7 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
300339
for package in self.tape._package_data.values():
301340
package.continue_checkpointing()
302341
current_step.checkpoint(
303-
_store_checkpointable_state, _store_adj_dependencies)
342+
_store_checkpointable_state, _store_adj_dependencies, self._global_deps)
304343

305344
to_keep = set()
306345
if step < (self.total_timesteps - 1):
@@ -310,7 +349,7 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
310349
if functional:
311350
to_keep = to_keep.union([functional.block_variable])
312351

313-
for var in current_step.checkpointable_state - to_keep:
352+
for var in current_step.checkpointable_state - to_keep.union(self._global_deps):
314353
# Handle the case where step is 0
315354
if step == 0 and var not in current_step._checkpoint:
316355
# Ensure initialisation state is kept.
@@ -346,6 +385,10 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
346385
if bv not in current_step.adjoint_dependencies.union(to_keep):
347386
bv._checkpoint = None
348387

388+
if self._gc_timestep_frequency and step % self._gc_timestep_frequency == 0:
389+
logging.info("Running a garbage collection cycle")
390+
gc.collect(self._gc_generation)
391+
349392
step += 1
350393
if cp_action.storage == StorageType.DISK:
351394
# Activate disk checkpointing only in the checkpointing process.
@@ -361,11 +404,11 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
361404
current_step = self.tape.timesteps[step]
362405
for block in reversed(current_step):
363406
block.evaluate_adj(markings=markings)
364-
if not self._adj_deps_cleaned:
407+
if not current_step._adj_deps_cleaned:
365408
for out in block._outputs:
366409
if not out.marked_in_path:
367410
current_step.adjoint_dependencies.discard(out)
368-
self._adj_deps_cleaned = True
411+
current_step._adj_deps_cleaned = True
369412
# Output variables are used for the last time when running
370413
# backwards.
371414
to_keep = current_step.checkpointable_state
@@ -379,6 +422,9 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
379422
out.reset_variables(("adjoint", "hessian"))
380423
if cp_action.clear_adj_deps and out not in to_keep:
381424
out._checkpoint = None
425+
if self._gc_timestep_frequency and step % self._gc_timestep_frequency == 0:
426+
logging.info("Running a garbage collection cycle")
427+
gc.collect(self._gc_generation)
382428

383429
@process_operation.register(Copy)
384430
def _(self, cp_action, progress_bar, **kwargs):

pyadjoint/tape.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def add_to_adjoint_dependencies(self, block_var, last_used):
304304
for step in self.timesteps[last_used + 1:]:
305305
step.adjoint_dependencies.add(block_var)
306306

307-
def enable_checkpointing(self, schedule):
307+
def enable_checkpointing(self, schedule, gc_timestep_frequency=None, gc_generation=2):
308308
"""Enable checkpointing on the adjoint evaluation.
309309
310310
A checkpoint manager able to execute the forward and adjoint computations
@@ -313,12 +313,23 @@ def enable_checkpointing(self, schedule):
313313
Args:
314314
schedule (checkpoint_schedules.schedule): A schedule provided by the
315315
checkpoint_schedules package.
316+
gc_timestep_frequency (None or int): The timestep frequency for garbage collection.
317+
For additional information, please refer to the :class:`CheckpointManager`
318+
documentation.
319+
gc_generation (int): The generation for garbage collection. For additional
320+
information, please refer to the :class:`CheckpointManager` documentation.
316321
"""
317322
if self._blocks:
318323
raise CheckpointError(
319324
"Checkpointing must be enabled before any blocks are added to the tape."
320325
)
321-
self._checkpoint_manager = CheckpointManager(schedule, self)
326+
327+
if gc_timestep_frequency is not None and not isinstance(gc_timestep_frequency, int):
328+
raise CheckpointError("gc_timestep_frequency must be an integer.")
329+
330+
self._checkpoint_manager = CheckpointManager(
331+
schedule, self, gc_timestep_frequency=gc_timestep_frequency,
332+
gc_generation=gc_generation)
322333

323334
def get_blocks(self, tag=None):
324335
"""Returns a list of the blocks on the tape.
@@ -782,25 +793,36 @@ def __init__(self, blocks=()):
782793
# A dictionary mapping the block variables in the checkpointable state
783794
# to their checkpoint values.
784795
self._checkpoint = {}
796+
# A flag to indicate whether the adjoint dependencies have been cleaned
797+
# from the outputs not marked in the path.
798+
self._adj_deps_cleaned = False
785799

786800
def copy(self, blocks=None):
787801
out = TimeStep(blocks or self)
788802
out.checkpointable_state = self.checkpointable_state
789803
return out
790804

791-
def checkpoint(self, checkpointable_state, adj_dependencies):
805+
def checkpoint(self, checkpointable_state, adj_dependencies, global_deps):
792806
"""Store a copy of the checkpoints in the checkpointable state.
793807
794808
Args:
795809
checkpointable_state (bool): If True, store the checkpointable state
796810
required to restart from the start of a timestep.
797811
adj_dependencies (bool): If True, store the adjoint dependencies required
798812
to compute the adjoint of a timestep.
813+
global_deps (set): This set stores the common dependencies for all timesteps.
814+
For additional information, please refer to the :class:`CheckpointManager`
815+
documentation.
799816
"""
800817
with stop_annotating():
801818
if checkpointable_state:
802819
for var in self.checkpointable_state:
803-
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
820+
if var in global_deps:
821+
# Creating a new checkpoint object is not necessary here
822+
# because the global dependencies do not change.
823+
self._checkpoint[var] = var._checkpoint
824+
else:
825+
self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
804826

805827
if adj_dependencies:
806828
for var in self.adjoint_dependencies:

0 commit comments

Comments
 (0)