@@ -343,26 +343,27 @@ def test_TGV_CPU_vs_GPU(self):
343
343
print ("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" )
344
344
345
345
346
+ # set parameters
346
347
# set parameters
347
348
pars = {'algorithm' : TGV , \
348
- 'input' : u0 ,\
349
- 'regularisation_parameter' :0.02 , \
350
- 'alpha1' :1.0 ,\
351
- 'alpha0' :2.0 ,\
352
- 'number_of_iterations' :1000 ,\
353
- 'LipshitzConstant' :12 ,\
354
- 'tolerance_constant' :0.0 }
349
+ 'input' : u0 ,\
350
+ 'regularisation_parameter' :0.02 , \
351
+ 'alpha1' :1.0 ,\
352
+ 'alpha0' :2.0 ,\
353
+ 'number_of_iterations' :1000 ,\
354
+ 'LipshitzConstant' :12 ,\
355
+ 'tolerance_constant' :0.0 }
355
356
356
357
print ("#############TGV CPU####################" )
357
358
start_time = timeit .default_timer ()
358
- # infovector = np.zeros((2,), dtype='float32')
359
- tgv_cpu = TGV (pars ['input' ],
359
+ infovector = np .zeros ((2 ,), dtype = 'float32' )
360
+ tgv_cpu = TGV (pars ['input' ],
360
361
pars ['regularisation_parameter' ],
361
362
pars ['alpha1' ],
362
363
pars ['alpha0' ],
363
364
pars ['number_of_iterations' ],
364
365
pars ['LipshitzConstant' ],
365
- pars ['tolerance_constant' ], device = 'cpu' )
366
+ pars ['tolerance_constant' ],device = 'cpu' , infovector = infovector )
366
367
367
368
rms = rmse (Im , tgv_cpu )
368
369
pars ['rmse' ] = rms
@@ -373,13 +374,13 @@ def test_TGV_CPU_vs_GPU(self):
373
374
374
375
print ("##############TGV GPU##################" )
375
376
start_time = timeit .default_timer ()
376
- tgv_gpu = TGV (pars ['input' ],
377
+ tgv_gpu = TGV (pars ['input' ],
377
378
pars ['regularisation_parameter' ],
378
379
pars ['alpha1' ],
379
380
pars ['alpha0' ],
380
381
pars ['number_of_iterations' ],
381
382
pars ['LipshitzConstant' ],
382
- pars ['tolerance_constant' ], device = 'gpu' )
383
+ pars ['tolerance_constant' ], device = 'gpu' , infovector = infovector )
383
384
384
385
rms = rmse (Im , tgv_gpu )
385
386
pars ['rmse' ] = rms
@@ -388,7 +389,7 @@ def test_TGV_CPU_vs_GPU(self):
388
389
txtstr += "%s = %.3fs" % ('elapsed time' ,timeit .default_timer () - start_time )
389
390
print (txtstr )
390
391
print ("--------Compare the results--------" )
391
- tolerance = 1e-05
392
+ tolerance = 1e-02
392
393
diff_im = np .zeros (np .shape (tgv_gpu ))
393
394
diff_im = abs (tgv_cpu - tgv_gpu )
394
395
diff_im [diff_im > tolerance ] = 1
0 commit comments