Skip to content

Commit d837949

Browse files
committed
complete overhaul
1 parent 0b64ca6 commit d837949

8 files changed

+850
-358
lines changed

Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperationsTBLIS"
22
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
3-
authors = ["lkdvos <lukas.devos@ugent.be>"]
4-
version = "0.2.0"
3+
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
4+
version = "0.3.0"
55

66
[deps]
77
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1111
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"
1212

1313
[compat]
14+
Random = "1"
1415
TensorOperations = "5"
1516
TupleTools = "1"
17+
Test = "1"
1618
julia = "1.8"
1719
tblis_jll = "1.2"
1820

1921
[extras]
2022
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
23+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2124
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2225

2326
[targets]
24-
test = ["Test", "LinearAlgebra"]
27+
test = ["Test", "LinearAlgebra", "Random"]

src/LibTBLIS.jl

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
module LibTBLIS
2+
3+
using tblis_jll
4+
using LinearAlgebra: BlasFloat
5+
6+
export tblis_scalar, tblis_tensor
7+
export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot
8+
export tblis_set_num_threads, tblis_get_num_threads
9+
10+
const ptrdiff_t = Cptrdiff_t
11+
12+
const scomplex = ComplexF32
13+
const dcomplex = ComplexF64
14+
15+
const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET)
16+
if Sys.isapple() && Sys.ARCH === :aarch64
17+
include("lib/aarch64-apple-darwin20.jl")
18+
elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL
19+
include("lib/aarch64-linux-gnu.jl")
20+
elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL
21+
include("lib/aarch64-linux-musl.jl")
22+
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL
23+
include("lib/armv7l-linux-gnueabihf.jl")
24+
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL
25+
include("lib/armv7l-linux-musleabihf.jl")
26+
elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL
27+
include("lib/i686-linux-gnu.jl")
28+
elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL
29+
include("lib/i686-linux-musl.jl")
30+
elseif Sys.iswindows() && Sys.ARCH === :i686
31+
include("lib/i686-w64-mingw32.jl")
32+
elseif Sys.islinux() && Sys.ARCH === :powerpc64le
33+
include("lib/powerpc64le-linux-gnu.jl")
34+
elseif Sys.isapple() && Sys.ARCH === :x86_64
35+
include("lib/x86_64-apple-darwin14.jl")
36+
elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL
37+
include("lib/x86_64-linux-gnu.jl")
38+
elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL
39+
include("lib/x86_64-linux-musl.jl")
40+
elseif Sys.isbsd() && !Sys.isapple()
41+
include("lib/x86_64-unknown-freebsd.jl")
42+
elseif Sys.iswindows() && Sys.ARCH === :x86_64
43+
include("lib/x86_64-w64-mingw32.jl")
44+
else
45+
error("Unknown platform: $(Base.BUILD_TRIPLET)")
46+
end
47+
48+
# tblis_scalar and tblis_tensor
49+
# -----------------------------
50+
"""
51+
tblis_scalar(s::Number)
52+
53+
Initializes a tblis scalar from a number.
54+
"""
55+
function tblis_scalar end
56+
57+
"""
58+
tblis_tensor(A::AbstractArray{T<:BlasFloat}, [szA::Vector{Int}, strA::Vector{Int}, scalar::Number])
59+
60+
Initializes a tblis tensor from an array that should be strided and admit a pointer to its
61+
data. This operation is deemed unsafe, in the sense that the user is responsible for ensuring
62+
that the reference to the array and the sizes and strides stays alive during the lifetime of
63+
the tensor.
64+
"""
65+
function tblis_tensor end
66+
67+
for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in
68+
((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s),
69+
(:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d),
70+
(:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c),
71+
(:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z))
72+
@eval begin
73+
function tblis_scalar(s::$T)
74+
t = Ref{tblis_scalar}()
75+
$tblis_init_scalar(t, s)
76+
return t[]
77+
end
78+
function tblis_tensor(A::AbstractArray{$T,N},
79+
s::Number=one(T),
80+
szA::Vector{len_type}=collect(len_type, size(A)),
81+
strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N}
82+
t = Ref{tblis_tensor}()
83+
if isone(s)
84+
$tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA))
85+
else
86+
$tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A),
87+
pointer(strA))
88+
end
89+
return t[]
90+
end
91+
end
92+
end
93+
94+
# tensor operations
95+
# ------------------
96+
function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB)
97+
return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB)
98+
end
99+
100+
function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor,
101+
idxC)
102+
return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC)
103+
end
104+
105+
function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar)
106+
return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C))
107+
end
108+
109+
end

src/LibTblis.jl

-146
This file was deleted.

src/TensorOperationsTBLIS.jl

+19-78
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,34 @@
11
module TensorOperationsTBLIS
22

33
using TensorOperations
4+
using TensorOperations: StridedView, DefaultAllocator, IndexError
5+
using TensorOperations: istrivialpermutation, BlasFloat, linearize
6+
using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd,
7+
argcheck_tensortrace, dimcheck_tensortrace,
8+
argcheck_tensorcontract, dimcheck_tensorcontract
49
using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError
5-
using LinearAlgebra: BlasFloat, rmul!
10+
using LinearAlgebra: BlasFloat
611
using TupleTools
712

8-
include("LibTblis.jl")
9-
using .LibTblis
13+
include("LibTBLIS.jl")
14+
using .LibTBLIS
15+
using .LibTBLIS: LibTBLIS, len_type, stride_type
1016

11-
export tblis_set_num_threads, tblis_get_num_threads
12-
export tblisBackend
17+
export TBLIS
18+
export get_num_tblis_threads, set_num_tblis_threads
19+
20+
get_num_tblis_threads() = convert(Int, LibTBLIS.tblis_get_num_threads())
21+
set_num_tblis_threads(n) = LibTBLIS.tblis_set_num_threads(convert(Cuint, n))
1322

1423
# TensorOperations
1524
#------------------
1625

17-
struct tblisBackend <: TensorOperations.AbstractBackend end
18-
19-
function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T},
20-
pA::Index2Tuple, conjA::Bool,
21-
α::Number, β::Number,
22-
::tblisBackend) where {T<:BlasFloat}
23-
TensorOperations.argcheck_tensoradd(C, A, pA)
24-
TensorOperations.dimcheck_tensoradd(C, A, pA)
25-
26-
szC = collect(size(C))
27-
strC = collect(strides(C))
28-
C_tblis = tblis_tensor(C, szC, strC, β)
29-
30-
szA = collect(size(A))
31-
strA = collect(strides(A))
32-
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
33-
34-
einA, einC = TensorOperations.add_labels(pA)
35-
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
36-
37-
return C
38-
end
39-
40-
function TensorOperations.tensorcontract!(C::StridedArray{T},
41-
A::StridedArray{T}, pA::Index2Tuple,
42-
conjA::Bool, B::StridedArray{T},
43-
pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple,
44-
α::Number, β::Number,
45-
::tblisBackend) where {T<:BlasFloat}
46-
TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB)
47-
TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB)
48-
49-
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
50-
szC = ndims(C) == 0 ? Int[] : collect(size(C))
51-
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
52-
C_tblis = tblis_tensor(C, szC, strC)
53-
54-
szA = collect(size(A))
55-
strA = collect(strides(A))
56-
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
57-
58-
szB = collect(size(B))
59-
strB = collect(strides(B))
60-
B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1)
61-
62-
einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB)
63-
tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis,
64-
string(einC...))
65-
66-
return C
67-
end
68-
69-
function TensorOperations.tensortrace!(C::StridedArray{T},
70-
A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple,
71-
conjA::Bool,
72-
α::Number, β::Number,
73-
::tblisBackend) where {T<:BlasFloat}
74-
TensorOperations.argcheck_tensortrace(C, A, p, q)
75-
TensorOperations.dimcheck_tensortrace(C, A, p, q)
76-
77-
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
78-
szC = ndims(C) == 0 ? Int[] : collect(size(C))
79-
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
80-
C_tblis = tblis_tensor(C, szC, strC)
81-
82-
szA = collect(size(A))
83-
strA = collect(strides(A))
84-
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)
85-
86-
einA, einC = TensorOperations.trace_labels(p, q)
26+
struct TBLIS <: TensorOperations.AbstractBackend end
8727

88-
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
28+
Base.@deprecate(tblisBackend(), TBLIS())
29+
Base.@deprecate(tblis_get_num_threads(), get_num_tblis_threads())
30+
Base.@deprecate(tblis_set_num_threads(n), set_num_tblis_threads(n))
8931

90-
return C
91-
end
32+
include("strided.jl")
9233

9334
end # module TensorOperationsTBLIS

0 commit comments

Comments
 (0)