Skip to content

Commit 49430e8

Browse files
authored
Small fixing for single disk checkpointing (#195)
* pause_checkpointing() every time step * Do not clean the step 0 if not in timestep checkpoint
1 parent dbf923e commit 49430e8

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

pyadjoint/checkpointing.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -205,22 +205,24 @@ def _(self, cp_action, timestep):
205205
# Remove unnecessary variables in working memory from previous steps.
206206
for var in self.tape.timesteps[timestep - 1].checkpointable_state - self._global_deps:
207207
var._checkpoint = None
208+
208209
for block in self.tape.timesteps[timestep - 1]:
209210
for out in block.get_outputs():
210211
out._checkpoint = None
212+
213+
if cp_action.storage == StorageType.DISK:
214+
# Activate disk checkpointing only in the checkpointing process.
215+
for package in self.tape._package_data.values():
216+
package.pause_checkpointing()
217+
218+
if isinstance(self._gc_timestep_frequency, int) and timestep % self._gc_timestep_frequency == 0:
219+
logging.info("Running a garbage collection cycle")
220+
gc.collect(self._gc_generation)
221+
211222
if timestep in cp_action and timestep < self.total_timesteps:
212223
self.tape.get_blocks().append_step()
213224
if cp_action.write_ics:
214225
self.tape.latest_checkpoint = cp_action.n0
215-
216-
if cp_action.storage == StorageType.DISK:
217-
# Activate disk checkpointing only in the checkpointing process.
218-
for package in self.tape._package_data.values():
219-
package.pause_checkpointing()
220-
221-
if isinstance(self._gc_timestep_frequency, int) and timestep % self._gc_timestep_frequency == 0:
222-
logging.info("Running a garbage collection cycle")
223-
gc.collect(self._gc_generation)
224226
return True
225227
else:
226228
return False

0 commit comments

Comments
 (0)