1
1
from enum import Enum
2
2
import sys
3
+ import gc
4
+ import logging
3
5
from functools import singledispatchmethod
4
6
from checkpoint_schedules import Copy , Move , EndForward , EndReverse , \
5
7
Forward , Reverse , StorageType , SingleMemoryStorageSchedule
@@ -37,6 +39,14 @@ class CheckpointManager:
37
39
Args:
38
40
schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
39
41
tape (Tape): A list of blocks :class:`Block` instances.
42
+ gc_timestep_frequency (None or int): The number of timesteps between garbage collections. The default
43
+ is `None`, which means the garbage collector is not invoked during execution. If an integer is
44
+ provided, the garbage collector is run every `gc_timestep_frequency` timesteps. This is useful when
45
+ Python fails to track and clean up all checkpoint objects in memory properly.
46
+ gc_generation (int): The generation for garbage collection. The default is 2, which runs a full collection.
47
+ For more information about garbage collector generations,
48
+ please refer to the `documentation
49
+ <https://docs.python.org/3/library/gc.html#gc.collect>`_.
40
50
41
51
Attributes:
42
52
tape (Tape): A list of blocks :class:`Block` instances.
@@ -52,7 +62,7 @@ class CheckpointManager:
52
62
_current_action (checkpoint_schedules.CheckpointAction): The current `checkpoint_schedules` action.
53
63
54
64
"""
55
- def __init__ (self , schedule , tape ):
65
+ def __init__ (self , schedule , tape , gc_timestep_frequency = None , gc_generation = 2 ):
56
66
if (
57
67
schedule .uses_storage_type (StorageType .DISK )
58
68
and not tape ._package_data
@@ -78,9 +88,15 @@ def __init__(self, schedule, tape):
78
88
self .forward_schedule .append (self ._current_action )
79
89
# Tell the tape to only checkpoint input data until told otherwise.
80
90
self .tape .latest_checkpoint = 0
81
- self .end_timestep (- 1 )
82
91
self ._keep_init_state_in_work = False
83
- self ._adj_deps_cleaned = False
92
+ self ._gc_timestep_frequency = gc_timestep_frequency
93
+ self ._gc_generation = gc_generation
94
+ # ``self._global_deps`` stores checkpoint dependencies that remain unchanged across
95
+ # timesteps (``self.tape.timesteps``). During the forward taping process, the code
96
+ # checks whether a dependency is in ``self._global_deps`` to avoid unnecessary clearing
97
+ # and recreation of its checkpoint data.
98
+ self ._global_deps = set ()
99
+ self .end_timestep (- 1 )
84
100
85
101
def end_timestep (self , timestep ):
86
102
"""Mark the end of one timestep when taping the forward model.
@@ -164,11 +180,30 @@ def _(self, cp_action, timestep):
164
180
):
165
181
for package in self .tape ._package_data .values ():
166
182
package .continue_checkpointing ()
183
+ if timestep == 1 :
184
+ # Store the possible global dependencies.
185
+ for deps in self .tape .timesteps [timestep - 1 ].checkpointable_state :
186
+ self ._global_deps .add (deps )
187
+ else :
188
+ # Check if the block variables stored in `self._global_deps` are still
189
+ # dependencies in the previous timestep. If not, remove them from the
190
+ # global dependencies.
191
+ deps_to_clear = self ._global_deps .difference (
192
+ self .tape .timesteps [timestep - 1 ].checkpointable_state )
193
+
194
+ # Remove the block variables that are not global dependencies.
195
+ self ._global_deps .difference_update (deps_to_clear )
196
+
197
+ # For non-global dependencies, checkpoint storage occurs at a self.tape
198
+ # timestep only when required by an action from the schedule. Thus, we
199
+ # have to clear the checkpoint of block variables excluded from self._global_deps.
200
+ for deps in deps_to_clear :
201
+ deps ._checkpoint = None
167
202
168
203
self .tape .timesteps [timestep - 1 ].checkpoint (
169
- _store_checkpointable_state , _store_adj_dependencies )
204
+ _store_checkpointable_state , _store_adj_dependencies , self . _global_deps )
170
205
# Remove unnecessary variables in working memory from previous steps.
171
- for var in self .tape .timesteps [timestep - 1 ].checkpointable_state :
206
+ for var in self .tape .timesteps [timestep - 1 ].checkpointable_state - self . _global_deps :
172
207
var ._checkpoint = None
173
208
for block in self .tape .timesteps [timestep - 1 ]:
174
209
for out in block .get_outputs ():
@@ -182,6 +217,10 @@ def _(self, cp_action, timestep):
182
217
# Activate disk checkpointing only in the checkpointing process.
183
218
for package in self .tape ._package_data .values ():
184
219
package .pause_checkpointing ()
220
+
221
+ if isinstance (self ._gc_timestep_frequency , int ) and timestep % self ._gc_timestep_frequency == 0 :
222
+ logging .info ("Running a garbage collection cycle" )
223
+ gc .collect (self ._gc_generation )
185
224
return True
186
225
else :
187
226
return False
@@ -300,7 +339,7 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
300
339
for package in self .tape ._package_data .values ():
301
340
package .continue_checkpointing ()
302
341
current_step .checkpoint (
303
- _store_checkpointable_state , _store_adj_dependencies )
342
+ _store_checkpointable_state , _store_adj_dependencies , self . _global_deps )
304
343
305
344
to_keep = set ()
306
345
if step < (self .total_timesteps - 1 ):
@@ -310,7 +349,7 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
310
349
if functional :
311
350
to_keep = to_keep .union ([functional .block_variable ])
312
351
313
- for var in current_step .checkpointable_state - to_keep :
352
+ for var in current_step .checkpointable_state - to_keep . union ( self . _global_deps ) :
314
353
# Handle the case where step is 0
315
354
if step == 0 and var not in current_step ._checkpoint :
316
355
# Ensure initialisation state is kept.
@@ -346,6 +385,10 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
346
385
if bv not in current_step .adjoint_dependencies .union (to_keep ):
347
386
bv ._checkpoint = None
348
387
388
+ if self ._gc_timestep_frequency and step % self ._gc_timestep_frequency == 0 :
389
+ logging .info ("Running a garbage collection cycle" )
390
+ gc .collect (self ._gc_generation )
391
+
349
392
step += 1
350
393
if cp_action .storage == StorageType .DISK :
351
394
# Activate disk checkpointing only in the checkpointing process.
@@ -361,11 +404,11 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
361
404
current_step = self .tape .timesteps [step ]
362
405
for block in reversed (current_step ):
363
406
block .evaluate_adj (markings = markings )
364
- if not self ._adj_deps_cleaned :
407
+ if not current_step ._adj_deps_cleaned :
365
408
for out in block ._outputs :
366
409
if not out .marked_in_path :
367
410
current_step .adjoint_dependencies .discard (out )
368
- self ._adj_deps_cleaned = True
411
+ current_step ._adj_deps_cleaned = True
369
412
# Output variables are used for the last time when running
370
413
# backwards.
371
414
to_keep = current_step .checkpointable_state
@@ -379,6 +422,9 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
379
422
out .reset_variables (("adjoint" , "hessian" ))
380
423
if cp_action .clear_adj_deps and out not in to_keep :
381
424
out ._checkpoint = None
425
+ if self ._gc_timestep_frequency and step % self ._gc_timestep_frequency == 0 :
426
+ logging .info ("Running a garbage collection cycle" )
427
+ gc .collect (self ._gc_generation )
382
428
383
429
@process_operation .register (Copy )
384
430
def _ (self , cp_action , progress_bar , ** kwargs ):
0 commit comments