Skip to content

Commit 92121af

Browse files
authored
Merge pull request #159 from jrmaddison/jrmaddison/numpy_2.0
NumPy 2.0 fix
2 parents 271343e + 4485fe6 commit 92121af

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

numpy_adjoint/array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, *args, **kwargs):
1111

1212
@classmethod
1313
def _ad_init_object(cls, obj):
14-
return cls(obj.shape, numpy.float_, buffer=obj)
14+
return cls(obj.shape, obj.dtype, buffer=obj)
1515

1616
def _ad_create_checkpoint(self):
1717
return self.copy()

tests/pyadjoint/test_numpy.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import numpy as np
2+
from pyadjoint import *
3+
from numpy_adjoint import *
4+
5+
6+
def test_ndarray_getitem_single():
7+
a = create_overloaded_object(np.array([-2.0]))
8+
J = ReducedFunctional(a[0], Control(a))
9+
dJ = J.derivative()
10+
assert dJ == 1.0

0 commit comments

Comments
 (0)