1
1
from enum import Enum
2
2
import sys
3
3
from functools import singledispatchmethod
4
- from checkpoint_schedules import Copy , Move , EndForward , EndReverse , Forward , Reverse , StorageType
4
+ from checkpoint_schedules import Copy , Move , EndForward , EndReverse , \
5
+ Forward , Reverse , StorageType , SingleMemoryStorageSchedule
5
6
# A callback interface allowing the user to provide a
6
7
# custom error message when disk checkpointing is not configured.
7
8
disk_checkpointing_callback = {}
@@ -78,6 +79,8 @@ def __init__(self, schedule, tape):
78
79
# Tell the tape to only checkpoint input data until told otherwise.
79
80
self .tape .latest_checkpoint = 0
80
81
self .end_timestep (- 1 )
82
+ self ._keep_init_state_in_work = False
83
+ self ._adj_deps_cleaned = False
81
84
82
85
def end_timestep (self , timestep ):
83
86
"""Mark the end of one timestep when taping the forward model.
@@ -299,25 +302,50 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
299
302
current_step .checkpoint (
300
303
_store_checkpointable_state , _store_adj_dependencies )
301
304
302
- if (
303
- (cp_action .write_adj_deps and cp_action .storage != StorageType .WORK )
304
- or not cp_action .write_adj_deps
305
- ):
306
- to_keep = set ()
307
- if step < (self .total_timesteps - 1 ):
308
- next_step = self .tape .timesteps [step + 1 ]
309
- # The checkpointable state set of the current step.
310
- to_keep = next_step .checkpointable_state
311
- if functional :
312
- to_keep = to_keep .union ([functional .block_variable ])
313
- for block in current_step :
314
- # Remove unnecessary variables from previous steps.
315
- for bv in block .get_outputs ():
305
+ to_keep = set ()
306
+ if step < (self .total_timesteps - 1 ):
307
+ next_step = self .tape .timesteps [step + 1 ]
308
+ # Keep the checkpointable state of the next step.
309
+ to_keep = next_step .checkpointable_state
310
+ if functional :
311
+ to_keep = to_keep .union ([functional .block_variable ])
312
+
313
+ for var in current_step .checkpointable_state - to_keep :
314
+ # Handle the case where step is 0
315
+ if step == 0 and var not in current_step ._checkpoint :
316
+ # Ensure initialisation state is kept.
317
+ self ._keep_init_state_in_work = True
318
+ break
319
+
320
+ # Handle the case for SingleMemoryStorageSchedule
321
+ if isinstance (self ._schedule , SingleMemoryStorageSchedule ):
322
+ if step > 1 and var not in self .tape .timesteps [step - 1 ].adjoint_dependencies :
323
+ var ._checkpoint = None
324
+ continue
325
+
326
+ # Handle variables in the initial timestep
327
+ if (
328
+ var in self .tape .timesteps [0 ].checkpointable_state
329
+ and self ._keep_init_state_in_work
330
+ ):
331
+ continue
332
+
333
+ # Clear the checkpoint for other cases
334
+ var ._checkpoint = None
335
+
336
+ for block in current_step :
337
+ # Remove unnecessary variables from previous steps.
338
+ for bv in block .get_outputs ():
339
+ if (
340
+ (cp_action .write_adj_deps and cp_action .storage != StorageType .WORK )
341
+ or not cp_action .write_adj_deps
342
+ ):
316
343
if bv not in to_keep :
317
344
bv ._checkpoint = None
318
- # Remove unnecessary variables from previous steps.
319
- for var in (current_step .checkpointable_state - to_keep ):
320
- var ._checkpoint = None
345
+ else :
346
+ if bv not in current_step .adjoint_dependencies .union (to_keep ):
347
+ bv ._checkpoint = None
348
+
321
349
step += 1
322
350
if cp_action .storage == StorageType .DISK :
323
351
# Activate disk checkpointing only in the checkpointing process.
@@ -333,22 +361,24 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
333
361
current_step = self .tape .timesteps [step ]
334
362
for block in reversed (current_step ):
335
363
block .evaluate_adj (markings = markings )
364
+ if not self ._adj_deps_cleaned :
365
+ for out in block ._outputs :
366
+ if not out .marked_in_path :
367
+ current_step .adjoint_dependencies .discard (out )
368
+ self ._adj_deps_cleaned = True
336
369
# Output variables are used for the last time when running
337
370
# backwards.
371
+ to_keep = current_step .checkpointable_state
372
+ if functional :
373
+ to_keep = to_keep .union ([functional .block_variable ])
338
374
for block in current_step :
339
375
block .reset_adjoint_state ()
340
- for var in block .get_outputs ():
341
- var .checkpoint = None
342
- var .reset_variables (("tlm" ,))
343
- if not var .is_control :
344
- var .reset_variables (("adjoint" , "hessian" ))
345
- if cp_action .clear_adj_deps :
346
- to_keep = current_step .checkpointable_state
347
- if functional :
348
- to_keep = to_keep .union ([functional .block_variable ])
349
- for output in block .get_outputs ():
350
- if output not in to_keep :
351
- output ._checkpoint = None
376
+ for out in block .get_outputs ():
377
+ out .reset_variables (("tlm" ,))
378
+ if not out .is_control :
379
+ out .reset_variables (("adjoint" , "hessian" ))
380
+ if cp_action .clear_adj_deps and out not in to_keep :
381
+ out ._checkpoint = None
352
382
353
383
@process_operation .register (Copy )
354
384
def _ (self , cp_action , progress_bar , ** kwargs ):
0 commit comments