@@ -13,17 +13,15 @@ export tblis_set_num_threads, tblis_get_num_threads
13
13
# TensorOperations
14
14
# ------------------
15
15
16
- const TblisBackend = TensorOperations. Backend{:tblis }
16
+ const tblisBackend = TensorOperations. Backend{:tblis }
17
17
18
18
function TensorOperations. tensoradd! (C:: StridedArray{T} , pC:: Index2Tuple ,
19
19
A:: StridedArray{T} , conjA:: Symbol ,
20
20
α:: Number , β:: Number ,
21
- :: TblisBackend ) where {T<: BlasFloat }
21
+ :: tblisBackend ) where {T<: BlasFloat }
22
22
TensorOperations. argcheck_tensoradd (C, pC, A)
23
- # check dimensions
24
- size (C) == getindex .(Ref (size (A)), linearize (pC)) ||
25
- throw (DimensionMismatch (" incompatible sizes" ))
26
-
23
+ TensorOperations. dimcheck_tensoradd (C, pC, A)
24
+
27
25
szC = collect (size (C))
28
26
strC = collect (strides (C))
29
27
C_tblis = tblis_tensor (C, szC, strC, β)
@@ -42,11 +40,11 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
42
40
A:: StridedArray{T} , pA:: Index2Tuple ,
43
41
conjA:: Symbol , B:: StridedArray{T} ,
44
42
pB:: Index2Tuple , conjB:: Symbol , α:: Number ,
45
- β:: Number , :: TblisBackend ) where {T<: BlasFloat }
43
+ β:: Number , :: tblisBackend ) where {T<: BlasFloat }
46
44
TensorOperations. argcheck_tensorcontract (C, pC, A, pA, B, pB)
47
45
TensorOperations. dimcheck_tensorcontract (C, pC, A, pA, B, pB)
48
46
49
- rmul! (C, β)
47
+ rmul! (C, β) # TODO : is it possible to use tblis scaling here?
50
48
szC = ndims (C) == 0 ? Int[] : collect (size (C))
51
49
strC = ndims (C) == 0 ? Int[] : collect (strides (C))
52
50
C_tblis = tblis_tensor (C, szC, strC)
@@ -67,10 +65,11 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
67
65
end
68
66
69
67
# partial traces do not exist in tblis afaik -> use default implementation
68
+ # TODO : implement full trace
70
69
function TensorOperations. tensortrace! (C:: StridedArray{T} , pC:: Index2Tuple ,
71
70
A:: StridedArray{T} , pA:: Index2Tuple , conjA:: Symbol ,
72
71
α:: Number , β:: Number ,
73
- :: TblisBackend ) where {T<: BlasFloat }
72
+ :: tblisBackend ) where {T<: BlasFloat }
74
73
return tensortrace! (C, pC, A, pA, conjA, α, β)
75
74
end
76
75
0 commit comments