Skip to content

Commit ad54f62

Browse files
committed
TensorOperations v5 compatibility
1 parent 50f0ef2 commit ad54f62

File tree

4 files changed

+48
-41
lines changed

4 files changed

+48
-41
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
fail-fast: false
2222
matrix:
2323
version:
24-
- '1.6'
24+
- '1.8'
2525
- '1' # automatically expands to the latest stable 1.x release of Julia
2626
os:
2727
- ubuntu-latest

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperationsTBLIS"
22
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
33
authors = ["lkdvos <lukas.devos@ugent.be>"]
4-
version = "0.1.1"
4+
version = "0.2.0"
55

66
[deps]
77
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -11,9 +11,9 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1111
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"
1212

1313
[compat]
14-
TensorOperations = "4"
14+
TensorOperations = "5"
1515
TupleTools = "1"
16-
julia = "1.6"
16+
julia = "1.8"
1717
tblis_jll = "1.2"
1818

1919
[extras]

src/TensorOperationsTBLIS.jl

+25-22
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,42 @@ include("LibTblis.jl")
99
using .LibTblis
1010

1111
export tblis_set_num_threads, tblis_get_num_threads
12+
export tblisBackend
1213

1314
# TensorOperations
1415
#------------------
1516

16-
const tblisBackend = TensorOperations.Backend{:tblis}
17+
struct tblisBackend <: TensorOperations.AbstractBackend end
1718

18-
function TensorOperations.tensoradd!(C::StridedArray{T}, pC::Index2Tuple,
19-
A::StridedArray{T}, conjA::Symbol,
19+
function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T},
20+
pA::Index2Tuple, conjA::Bool,
2021
α::Number, β::Number,
2122
::tblisBackend) where {T<:BlasFloat}
22-
TensorOperations.argcheck_tensoradd(C, pC, A)
23-
TensorOperations.dimcheck_tensoradd(C, pC, A)
23+
TensorOperations.argcheck_tensoradd(C, A, pA)
24+
TensorOperations.dimcheck_tensoradd(C, A, pA)
2425

2526
szC = collect(size(C))
2627
strC = collect(strides(C))
2728
C_tblis = tblis_tensor(C, szC, strC, β)
2829

2930
szA = collect(size(A))
3031
strA = collect(strides(A))
31-
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
32+
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
3233

33-
einA, einC = TensorOperations.add_labels(pC)
34+
einA, einC = TensorOperations.add_labels(pA)
3435
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
3536

3637
return C
3738
end
3839

39-
function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
40+
function TensorOperations.tensorcontract!(C::StridedArray{T},
4041
A::StridedArray{T}, pA::Index2Tuple,
41-
conjA::Symbol, B::StridedArray{T},
42-
pB::Index2Tuple, conjB::Symbol, α::Number,
43-
β::Number, ::tblisBackend) where {T<:BlasFloat}
44-
TensorOperations.argcheck_tensorcontract(C, pC, A, pA, B, pB)
45-
TensorOperations.dimcheck_tensorcontract(C, pC, A, pA, B, pB)
42+
conjA::Bool, B::StridedArray{T},
43+
pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple,
44+
α::Number, β::Number,
45+
::tblisBackend) where {T<:BlasFloat}
46+
TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB)
47+
TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
4648

4749
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
4850
szC = ndims(C) == 0 ? Int[] : collect(size(C))
@@ -51,25 +53,26 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
5153

5254
szA = collect(size(A))
5355
strA = collect(strides(A))
54-
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
56+
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
5557

5658
szB = collect(size(B))
5759
strB = collect(strides(B))
58-
B_tblis = tblis_tensor(conjB == :C ? conj(B) : B, szB, strB, 1)
60+
B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1)
5961

60-
einA, einB, einC = TensorOperations.contract_labels(pA, pB, pC)
62+
einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB)
6163
tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis,
6264
string(einC...))
6365

6466
return C
6567
end
6668

67-
function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,
68-
A::StridedArray{T}, pA::Index2Tuple, conjA::Symbol,
69+
function TensorOperations.tensortrace!(C::StridedArray{T},
70+
A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple,
71+
conjA::Bool,
6972
α::Number, β::Number,
7073
::tblisBackend) where {T<:BlasFloat}
71-
TensorOperations.argcheck_tensortrace(C, pC, A, pA)
72-
TensorOperations.dimcheck_tensortrace(C, pC, A, pA)
74+
TensorOperations.argcheck_tensortrace(C, A, p, q)
75+
TensorOperations.dimcheck_tensortrace(C, A, p, q)
7376

7477
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
7578
szC = ndims(C) == 0 ? Int[] : collect(size(C))
@@ -78,9 +81,9 @@ function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,
7881

7982
szA = collect(size(A))
8083
strA = collect(strides(A))
81-
A_tblis = tblis_tensor(conjA == :C ? conj(A) : A, szA, strA, α)
84+
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
8285

83-
einA, einC = TensorOperations.trace_labels(pC, pA...)
86+
einA, einC = TensorOperations.trace_labels(p, q)
8487

8588
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
8689

test/runtests.jl

+19-15
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ using TensorOperationsTBLIS
33
using Test
44
using LinearAlgebra: norm
55

6+
const tblisbackend = tblisBackend()
67
@testset "elementary operations" verbose = true begin
78
@testset "tensorcopy" begin
89
A = randn(Float32, (3, 5, 4, 6))
910
@tensor C1[4, 1, 3, 2] := A[1, 2, 3, 4]
10-
@tensor backend = tblis C2[4, 1, 3, 2] := A[1, 2, 3, 4]
11+
@tensor backend = tblisbackend C2[4, 1, 3, 2] := A[1, 2, 3, 4]
1112
@test C2 ≈ C1
1213
end
1314

@@ -16,49 +17,50 @@ using LinearAlgebra: norm
1617
B = randn(Float32, (5, 6, 3, 4))
1718
α = randn(Float32)
1819
@tensor C1[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
19-
@tensor backend = tblis C2[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
20+
@tensor backend = tblisbackend C2[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d]
2021
@test collect(C2) ≈ C1
2122

2223
C = randn(ComplexF32, (5, 6, 3, 4))
2324
D = randn(ComplexF32, (5, 3, 4, 6))
2425
β = randn(ComplexF32)
2526
@tensor E1[a, b, c, d] := C[a, b, c, d] + β * conj(D[a, c, d, b])
26-
@tensor backend = tblis E2[a, b, c, d] := C[a, b, c, d] + β * conj(D[a, c, d, b])
27+
@tensor backend = tblisbackend E2[a, b, c, d] := C[a, b, c, d] +
28+
β * conj(D[a, c, d, b])
2729
@test collect(E2) ≈ E1
2830
end
2931

3032
@testset "tensortrace" begin
3133
A = randn(Float32, (5, 10, 10))
3234
@tensor B1[a] := A[a, b′, b′]
33-
@tensor backend = tblis B2[a] := A[a, b′, b′]
35+
@tensor backend = tblisbackend B2[a] := A[a, b′, b′]
3436
@test B2 ≈ B1
3537

3638
C = randn(ComplexF32, (3, 20, 5, 3, 20, 4, 5))
3739
@tensor D1[e, a, d] := C[a, b, c, d, b, e, c]
38-
@tensor backend = tblis D2[e, a, d] := C[a, b, c, d, b, e, c]
40+
@tensor backend = tblisbackend D2[e, a, d] := C[a, b, c, d, b, e, c]
3941
@test D2 ≈ D1
4042

4143
@tensor D3[a, e, d] := conj(C[a, b, c, d, b, e, c])
42-
@tensor backend = tblis D4[a, e, d] := conj(C[a, b, c, d, b, e, c])
44+
@tensor backend = tblisbackend D4[a, e, d] := conj(C[a, b, c, d, b, e, c])
4345
@test D4 ≈ D3
4446

4547
α = randn(ComplexF32)
4648
@tensor D5[d, e, a] := α * C[a, b, c, d, b, e, c]
47-
@tensor backend = tblis D6[d, e, a] := α * C[a, b, c, d, b, e, c]
49+
@tensor backend = tblisbackend D6[d, e, a] := α * C[a, b, c, d, b, e, c]
4850
@test D6 ≈ D5
4951
end
5052

5153
@testset "tensorcontract" begin
5254
A = randn(Float32, (3, 20, 5, 3, 4))
5355
B = randn(Float32, (5, 6, 20, 3))
5456
@tensor C1[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
55-
@tensor backend = tblis C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
57+
@tensor backend = tblisbackend C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g]
5658
@test C2 ≈ C1
5759

5860
D = randn(ComplexF64, (3, 3, 3))
5961
E = rand(ComplexF64, (3, 3, 3))
6062
@tensor F1[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
61-
@tensor backend = tblis F2[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
63+
@tensor backend = tblisbackend F2[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f])
6264
@test F2 ≈ F1 atol = 1e-12
6365
end
6466
end
@@ -72,12 +74,14 @@ end
7274
# α = 1
7375

7476
@tensor D1[d, f, h] := A[c, a, f, a, e, b, b, g] * B[c, h, g, e, d] + α * C[d, h, f]
75-
@tensor backend = tblis D2[d, f, h] := A[c, a, f, a, e, b, b, g] * B[c, h, g, e, d] +
76-
α * C[d, h, f]
77+
@tensor backend = tblisbackend D2[d, f, h] := A[c, a, f, a, e, b, b, g] *
78+
B[c, h, g, e, d] +
79+
α * C[d, h, f]
7780
@test D2 ≈ D1 rtol = 1e-8
7881

7982
@test norm(vec(D1)) ≈ sqrt(abs(@tensor D1[d, f, h] * conj(D1[d, f, h])))
80-
@test norm(D2) sqrt(abs(@tensor backend = tblis D2[d, f, h] * conj(D2[d, f, h])))
83+
@test norm(D2) ≈
84+
sqrt(abs(@tensor backend = tblisbackend D2[d, f, h] * conj(D2[d, f, h])))
8185

8286
@testset "readme example" begin
8387
α = randn()
@@ -90,7 +94,7 @@ end
9094
D[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
9195
E[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
9296
end
93-
@tensor backend = tblis begin
97+
@tensor backend = tblisbackend begin
9498
D2[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
9599
E2[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b]
96100
end
@@ -113,7 +117,7 @@ end
113117
HrA12[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] *
114118
H[s1, s2, t1, t2]
115119
end
116-
@tensor backend = tblis begin
120+
@tensor backend = tblisbackend begin
117121
HrA12′[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] *
118122
H[s1, s2, t1, t2]
119123
end
@@ -123,7 +127,7 @@ end
123127
E1 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] *
124128
conj(A1[a', t, b']) * conj(A2[b', t', c'])
125129
end
126-
@tensor backend = tblis begin
130+
@tensor backend = tblisbackend begin
127131
E2 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] *
128132
conj(A1[a', t, b']) * conj(A2[b', t', c'])
129133
end

0 commit comments

Comments
 (0)