@@ -214,8 +214,7 @@ def evaluate_adj(self, last_block, markings):
214
214
if self .mode not in (Mode .EVALUATED , Mode .FINISHED_RECORDING ):
215
215
raise CheckpointError ("Evaluate Functional before calling gradient." )
216
216
217
- with self .tape .progress_bar ("Evaluating Adjoint" ,
218
- max = self .timesteps ) as bar :
217
+ with self .tape .progress_bar ("Evaluating Adjoint" , max = self .timesteps ) as bar :
219
218
if self .adjoint_evaluated :
220
219
reverse_iterator = iter (self .reverse_schedule )
221
220
while not isinstance (self ._current_action , EndReverse ):
@@ -257,7 +256,8 @@ def process_operation(self, cp_action, bar, **kwargs):
257
256
def _ (self , cp_action , bar , functional = None , ** kwargs ):
258
257
for step in cp_action :
259
258
if self .mode == Mode .RECOMPUTE :
260
- bar .next ()
259
+ if bar :
260
+ bar .next ()
261
261
# Get the blocks of the current step.
262
262
current_step = self .tape .timesteps [step ]
263
263
for block in current_step :
@@ -290,7 +290,8 @@ def _(self, cp_action, bar, functional=None, **kwargs):
290
290
@process_operation .register (Reverse )
291
291
def _ (self , cp_action , bar , markings , functional = None , ** kwargs ):
292
292
for step in cp_action :
293
- bar .next ()
293
+ if bar :
294
+ bar .next ()
294
295
# Get the blocks of the current step.
295
296
current_step = self .tape .timesteps [step ]
296
297
for block in reversed (current_step ):
0 commit comments