Skip to content

Commit 2652ad1

Browse files
committed
Add TBLIS implementation for tensortrace
1 parent 96c4261 commit 2652ad1

File tree

2 files changed

+37
-23
lines changed

2 files changed

+37
-23
lines changed

src/TensorOperationsTBLIS.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,27 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
6464
return C
6565
end
6666

67-
# partial traces do not exist in tblis afaik -> use default implementation
68-
# TODO: implement full trace
6967
function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,
7068
A::StridedArray{T}, pA::Index2Tuple, conjA::Symbol,
7169
α::Number, β::Number,
7270
::tblisBackend) where {T<:BlasFloat}
73-
return tensortrace!(C, pC, A, pA, conjA, α, β)
71+
TensorOperations.argcheck_tensortrace(C, pC, A, pA)
72+
TensorOperations.dimcheck_tensortrace(C, pC, A, pA)
73+
74+
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
75+
szC = ndims(C) == 0 ? Int[] : collect(size(C))
76+
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
77+
C_tblis = tblis_tensor(C, szC, strC)
78+
79+
szA = collect(size(A))
80+
strA = collect(strides(A))
81+
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
82+
83+
einA, einC = TensorOperations.trace_labels(pC, pA...)
84+
85+
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
86+
87+
return C
7488
end
7589

7690
end # module TensorOperationsTBLIS

test/runtests.jl

+20-20
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,26 @@ using LinearAlgebra: norm
2727
@test collect(E2) E1
2828
end
2929

30-
# @testset "tensortrace" begin
31-
# A = randn(Float32, (5, 10, 10))
32-
# @tensor B1[a] := A[a, b′, b′]
33-
# @tensor B2[a] := CuArray(A)[a, b′, b′]
34-
# @test collect(B2) ≈ B1
35-
36-
# C = randn(ComplexF32, (3, 20, 5, 3, 20, 4, 5))
37-
# @tensor D1[e, a, d] := C[a, b, c, d, b, e, c]
38-
# @tensor D2[e, a, d] := CuArray(C)[a, b, c, d, b, e, c]
39-
# @test collect(D2) ≈ D1
40-
41-
# @tensor D3[a, e, d] := conj(C[a, b, c, d, b, e, c])
42-
# @tensor D4[a, e, d] := conj(CuArray(C)[a, b, c, d, b, e, c])
43-
# @test collect(D4) ≈ D3
44-
45-
# α = randn(ComplexF32)
46-
# @tensor D5[d, e, a] := α * C[a, b, c, d, b, e, c]
47-
# @tensor D6[d, e, a] := α * CuArray(C)[a, b, c, d, b, e, c]
48-
# @test collect(D6) ≈ D5
49-
# end
30+
@testset "tensortrace" begin
31+
A = randn(Float32, (5, 10, 10))
32+
@tensor B1[a] := A[a, b′, b′]
33+
@tensor backend = tblis B2[a] := A[a, b′, b′]
34+
@test B2 B1
35+
36+
C = randn(ComplexF32, (3, 20, 5, 3, 20, 4, 5))
37+
@tensor D1[e, a, d] := C[a, b, c, d, b, e, c]
38+
@tensor backend = tblis D2[e, a, d] := C[a, b, c, d, b, e, c]
39+
@test D2 D1
40+
41+
@tensor D3[a, e, d] := conj(C[a, b, c, d, b, e, c])
42+
@tensor backend = tblis D4[a, e, d] := conj(C[a, b, c, d, b, e, c])
43+
@test D4 D3
44+
45+
α = randn(ComplexF32)
46+
@tensor D5[d, e, a] := α * C[a, b, c, d, b, e, c]
47+
@tensor backend = tblis D6[d, e, a] := α * C[a, b, c, d, b, e, c]
48+
@test D6 D5
49+
end
5050

5151
@testset "tensorcontract" begin
5252
A = randn(Float32, (3, 20, 5, 3, 4))

0 commit comments

Comments
 (0)