Skip to content

Commit 118e5d3

Browse files
committed
improve tests and fix other comments
1 parent d837949 commit 118e5d3

File tree

4 files changed

+58
-188
lines changed

4 files changed

+58
-188
lines changed

src/TensorOperationsTBLIS.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using TensorOperations: istrivialpermutation, BlasFloat, linearize
66
using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd,
77
argcheck_tensortrace, dimcheck_tensortrace,
88
argcheck_tensorcontract, dimcheck_tensorcontract
9+
using TensorOperations: stridedtensoradd!, stridedtensortrace!, stridedtensorcontract!
910
using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError
1011
using LinearAlgebra: BlasFloat
1112
using TupleTools

src/strided.jl

+20-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#-------------------------------------------------------------------------------------------
2-
# Force strided implementation on AbstractArray instances with HPTTBLAS backend
2+
# Force strided implementation on AbstractArray instances with TBLIS backend
33
#-------------------------------------------------------------------------------------------
44
const SV = StridedView
55
function TensorOperations.tensoradd!(C::AbstractArray,
@@ -69,52 +69,51 @@ function tblis_tensor(A::StridedView,
6969
end
7070
end
7171

72-
function stridedtensoradd!(C::StridedView{T},
73-
A::StridedView{T}, pA::Index2Tuple,
74-
α::Number, β::Number,
75-
backend::TBLIS,
76-
allocator=DefaultAllocator()) where {T<:BlasFloat}
72+
function TensorOperations.stridedtensoradd!(C::StridedView{T},
73+
A::StridedView{T}, pA::Index2Tuple,
74+
α::Number, β::Number,
75+
backend::TBLIS,
76+
allocator=DefaultAllocator()) where {T<:BlasFloat}
7777
argcheck_tensoradd(C, A, pA)
7878
dimcheck_tensoradd(C, A, pA)
7979
if Base.mightalias(C, A)
8080
throw(ArgumentError("output tensor must not be aliased with input tensor"))
8181
end
8282

83-
# directly use TBLIS types to avoid additional conversion step
8483
C_tblis = tblis_tensor(C, β)
8584
A_tblis = tblis_tensor(A, α)
8685
einA, einC = TensorOperations.add_labels(pA)
8786
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
8887
return C
8988
end
9089

91-
function stridedtensortrace!(C::StridedView{T},
92-
A::StridedView{T}, p::Index2Tuple, q::Index2Tuple,
93-
α::Number, β::Number,
94-
backend::TBLIS,
95-
allocator=DefaultAllocator()) where {T<:BlasFloat}
90+
function TensorOperations.stridedtensortrace!(C::StridedView{T},
91+
A::StridedView{T},
92+
p::Index2Tuple,
93+
q::Index2Tuple,
94+
α::Number, β::Number,
95+
backend::TBLIS,
96+
allocator=DefaultAllocator()) where {T<:BlasFloat}
9697
argcheck_tensortrace(C, A, p, q)
9798
dimcheck_tensortrace(C, A, p, q)
9899

99100
Base.mightalias(C, A) &&
100101
throw(ArgumentError("output tensor must not be aliased with input tensor"))
101102

102-
# directly use TBLIS types to avoid additional conversion step
103-
# isone(β) || rmul!(C, β) # TODO: check if TBLIS handles the scaling correctly
104103
C_tblis = tblis_tensor(C, β)
105104
A_tblis = tblis_tensor(A, α)
106105
einA, einC = TensorOperations.trace_labels(p, q)
107106
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
108107
return C
109108
end
110109

111-
function stridedtensorcontract!(C::StridedView{T},
112-
A::StridedView{T}, pA::Index2Tuple,
113-
B::StridedView{T}, pB::Index2Tuple,
114-
pAB::Index2Tuple,
115-
α::Number, β::Number,
116-
backend::TBLIS,
117-
allocator=DefaultAllocator()) where {T<:BlasFloat}
110+
function TensorOperations.stridedtensorcontract!(C::StridedView{T},
111+
A::StridedView{T}, pA::Index2Tuple,
112+
B::StridedView{T}, pB::Index2Tuple,
113+
pAB::Index2Tuple,
114+
α::Number, β::Number,
115+
backend::TBLIS,
116+
allocator=DefaultAllocator()) where {T<:BlasFloat}
118117
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
119118
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
120119

test/methods.jl

+14-56
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ typelist = (Float32, Float64, ComplexF32, ComplexF64)
1313
@test C1 C2
1414
@test C2 == C3
1515
@test C1 ncon(Any[A], Any[[-2, -4, -1, -3]])
16-
@test_throws IndexError tensorcopy(1:4, A, 1:3)
17-
@test_throws IndexError tensorcopy(1:4, A, [1, 2, 2, 4])
1816
end
1917

2018
@testset "tensoradd" begin
@@ -26,45 +24,29 @@ typelist = (Float32, Float64, ComplexF32, ComplexF64)
2624
C3 = @inferred tensoradd(A, ((1:4...,), ()), false, B, (p, ()), false, 1, 1, b)
2725
@test C1 C2
2826
@test C2 == C3
29-
@test C1 A + ncon(Any[B], Any[[-2, -4, -1, -3]]; backend=b)
30-
@test_throws DimensionMismatch tensoradd(A, 1:4, B, 1:4)
3127
end
3228

3329
@testset "tensortrace" begin
3430
A = randn(Float64, (50, 100, 100))
35-
C1 = zeros(50)
36-
for i in 1:50
37-
for j in 1:100
38-
C1[i] += A[i, j, j]
39-
end
40-
end
31+
C1 = tensortrace(A, [:a, :b, :b])
4132
C2 = tensortrace(A, [:a, :b, :b]; backend=b)
4233
C3 = ncon(Any[A], Any[[-1, 1, 1]]; backend=b)
4334
@test C1 C2
4435
@test C2 == C3
4536
A = randn(Float64, (3, 20, 5, 3, 20, 4, 5))
46-
C1 = zeros(4, 3, 3)
47-
for i1 in 1:4, i2 in 1:3, i3 in 1:3
48-
for j1 in 1:20, j2 in 1:5
49-
C1[i1, i2, i3] += A[i2, j1, j2, i3, j1, i1, j2]
50-
end
51-
end
37+
C1 = tensortrace((:e, :a, :d), A, (:a, :b, :c, :d, :b, :e, :c))
5238
C2 = @inferred tensortrace((:e, :a, :d), A, (:a, :b, :c, :d, :b, :e, :c); backend=b)
5339
C3 = @inferred tensortrace(A, ((6, 1, 4), ()), ((2, 3), (5, 7)), false, 1.0, b)
5440
C4 = ncon(Any[A], Any[[-2, 1, 2, -3, 1, -1, 2]]; backend=b)
5541
@test C1 C2
5642
@test C2 == C3 == C4
57-
@test_throws IndexError tensortrace(randn(2, 2, 2, 2, 2, 2, 2), ((1,), (3, 2)),
58-
((1, 5), (2, 6)), false)
5943
end
6044

6145
@testset "tensorcontract" begin
6246
A = randn(T, (3, 20, 5, 3, 4))
6347
B = randn(T, (5, 6, 20, 3))
64-
C1 = zeros(T, (3, 3, 4, 3, 6))
65-
for a in 1:3, b in 1:20, c in 1:5, d in 1:3, e in 1:4, f in 1:6, g in 1:3
66-
C1[a, g, e, d, f] += A[a, b, c, d, e] * B[c, f, b, g]
67-
end
48+
C1 = tensorcontract((:a, :g, :e, :d, :f), A, (:a, :b, :c, :d, :e), B,
49+
(:c, :f, :b, :g))
6850
C2 = @inferred tensorcontract((:a, :g, :e, :d, :f),
6951
A, (:a, :b, :c, :d, :e), B, (:c, :f, :b, :g);
7052
backend=b)
@@ -78,8 +60,6 @@ typelist = (Float32, Float64, ComplexF32, ComplexF64)
7860

7961
@test C1 C2
8062
@test C2 == C3 == C4 == C5
81-
@test_throws IndexError tensorcontract(A, [:a, :b, :c, :d], B, [:c, :f, :b, :g])
82-
@test_throws IndexError tensorcontract(A, [:a, :b, :c, :a, :e], B, [:c, :f, :b, :g])
8363
end
8464

8565
@testset "tensorproduct" begin
@@ -90,17 +70,10 @@ typelist = (Float32, Float64, ComplexF32, ComplexF64)
9070
A, (1, 2, 3, 4), B, (5, 6, 7, 8); backend=b)),
9171
(5 * 5 * 5 * 5, 5 * 5 * 5 * 5))
9272
@test C1 C2
93-
@test_throws IndexError tensorproduct(A, [:a, :b, :c, :d],
94-
B, [:d, :e, :f, :g])
95-
@test_throws IndexError tensorproduct([:a, :b, :c, :d, :e, :f, :g, :i],
96-
A, [:a, :b, :c, :d], B, [:e, :f, :g, :h])
9773

9874
A = rand(1, 2)
9975
B = rand(4, 5)
100-
C1 = zeros(T, (2, 4, 1, 5))
101-
for i in axes(C1, 1), j in axes(C1, 2), k in axes(C1, 3), l in axes(C1, 4)
102-
C1[i, j, k, l] = A[k, i] * B[j, l]
103-
end
76+
C1 = tensorcontract((-1, -2, -3, -4), A, (-3, -1), false, B, (-2, -4), false)
10477
C2 = tensorcontract((-1, -2, -3, -4), A, (-3, -1), false, B, (-2, -4), false;
10578
backend=b)
10679
C3 = tensorproduct(A, ((1, 2), ()), false, B, ((), (1, 2)), false,
@@ -141,12 +114,11 @@ end
141114
p = (3, 1, 4, 2)
142115
Cbig = zeros(ComplexF32, (50, 50, 50, 50))
143116
C = view(Cbig, 13 .+ (0:6), 11 .+ 4 * (0:9), 15 .+ 4 * (0:8), 4 .+ 3 * (0:6))
144-
Acopy = tensorcopy(p, A, 1:4)
145117
Ccopy = tensorcopy(1:4, C, 1:4)
146118
α = randn(ComplexF32)
147119
β = randn(ComplexF32)
148120
tensoradd!(C, A, (p, ()), false, α, β, b)
149-
Ccopy = β * Ccopy + α * Acopy
121+
tensoradd!(Ccopy, A, (p, ()), false, α, β) # default backend
150122
@test C Ccopy
151123
@test_throws IndexError tensoradd!(C, A, ((1, 2, 3), ()), false, 1.2, 0.5, b)
152124
@test_throws DimensionMismatch tensoradd!(C, A, ((1, 2, 3, 4), ()), false, 1.2, 0.5,
@@ -164,10 +136,7 @@ end
164136
α = randn(Float64)
165137
β = randn(Float64)
166138
tensortrace!(B, A, ((2, 3), ()), ((1,), (4,)), true, α, β, b)
167-
Bcopy = β * Bcopy
168-
for i in 1 .+ (0:8)
169-
Bcopy += α * conj(view(A, i, :, :, i))
170-
end
139+
tensortrace!(Bcopy, A, ((2, 3), ()), ((1,), (4,)), true, α, β) # default backend
171140
@test B Bcopy
172141
@test_throws IndexError tensortrace!(B, A, ((1,), ()), ((2,), (3,)), false, α, β, b)
173142
@test_throws DimensionMismatch tensortrace!(B, A, ((1, 4), ()), ((2,), (3,)), false,
@@ -178,6 +147,7 @@ end
178147
α, β, b)
179148
end
180149

150+
bref = TensorOperations.DefaultBackend() # reference backend
181151
@testset "tensorcontract! with allocator = $allocator" for allocator in
182152
(DefaultAllocator(),
183153
ManualAllocator())
@@ -192,36 +162,24 @@ end
192162
Ccopy = tensorcopy(C, 1:3)
193163
α = randn(ComplexF64)
194164
β = randn(ComplexF64)
195-
Ccopy = β * Ccopy
196-
for d in 1 .+ (0:8), a in 1 .+ (0:8), e in 1 .+ (0:7)
197-
for b in 1 .+ (0:14), c in 1 .+ (0:6)
198-
Ccopy[d, a, e] += α * A[a, b, c, d] * conj(B[c, e, b])
199-
end
200-
end
201165
tensorcontract!(C, A, ((4, 1), (2, 3)), false, B, ((3, 1), (2,)), true,
202166
((1, 2, 3), ()), α, β, b, allocator)
167+
tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), false, B, ((3, 1), (2,)), true,
168+
((1, 2, 3), ()), α, β, bref, allocator)
203169
@test C Ccopy
204170

205171
Ccopy = tensorcopy(C, 1:3)
206-
Ccopy = β * Ccopy
207-
for d in 1 .+ (0:8), a in 1 .+ (0:8), e in 1 .+ (0:7)
208-
for b in 1 .+ (0:14), c in 1 .+ (0:6)
209-
Ccopy[d, a, e] += α * conj(A[a, b, c, d]) * conj(B[c, e, b])
210-
end
211-
end
212172
tensorcontract!(C, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), true,
213173
((1, 2, 3), ()), α, β, b, allocator)
174+
tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), true,
175+
((1, 2, 3), ()), α, β, bref, allocator)
214176
@test C Ccopy
215177

216178
Ccopy = tensorcopy(C, 1:3)
217-
Ccopy = β * Ccopy
218-
for d in 1 .+ (0:8), a in 1 .+ (0:8), e in 1 .+ (0:7)
219-
for b in 1 .+ (0:14), c in 1 .+ (0:6)
220-
Ccopy[d, a, e] += α * conj(A[a, b, c, d]) * B[c, e, b]
221-
end
222-
end
223179
tensorcontract!(C, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), false,
224180
((1, 2, 3), ()), α, β, b, allocator)
181+
tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), false,
182+
((1, 2, 3), ()), α, β, bref, allocator)
225183
@test C Ccopy
226184

227185
@test_throws IndexError tensorcontract!(C,

0 commit comments

Comments
 (0)