Skip to content

Commit e292142

Browse files
Jutholkdvos
andauthored
Refactor and add StridedView support (#5)
* complete overhaul * improve tests and fix other comments * add TBLIS docstring * Update author email * `get_num_tblis_threads` to unexported `get_num_threads` * Update README * Add codecov ignore for `lib` --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 0b64ca6 commit e292142

10 files changed

+736
-368
lines changed

.github/codecov.yml

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ignore:
2+
- "src/lib"

Project.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperationsTBLIS"
22
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
3-
authors = ["lkdvos <lukas.devos@ugent.be>"]
4-
version = "0.2.0"
3+
authors = ["Lukas Devos <ldevos98@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
4+
version = "0.3.0"
55

66
[deps]
77
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
@@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1111
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"
1212

1313
[compat]
14+
Random = "1"
1415
TensorOperations = "5"
1516
TupleTools = "1"
17+
Test = "1"
1618
julia = "1.8"
1719
tblis_jll = "1.2"
1820

1921
[extras]
2022
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
23+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2124
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2225

2326
[targets]
24-
test = ["Test", "LinearAlgebra"]
27+
test = ["Test", "LinearAlgebra", "Random"]

README.md

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TensorOperationsTBLIS.jl
22

3-
Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl).
3+
Julia wrapper for [TBLIS](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl).
44

55
[![CI][ci-img]][ci-url] [![CI (Julia nightly)][ci-julia-nightly-img]][ci-julia-nightly-url] [![][codecov-img]][codecov-url]
66

@@ -13,15 +13,14 @@ Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorO
1313
[codecov-img]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl/graph/badge.svg?token=R86L0S70VT
1414
[codecov-url]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl
1515

16-
Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for
17-
`StridedArray{<:BlasFloat}`. These can be accessed through the backend system of
18-
TensorOperations, i.e.
16+
Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for array types compatible with Strided.jl, i.e. `StridedView{<:BlasFloat}`.
17+
These can be accessed through the backend system of TensorOperations, i.e.
1918

2019
```julia
2120
using TensorOperations
2221
using TensorOperationsTBLIS
2322

24-
tblisbackend = tblisBackend()
23+
tblisbackend = TBLIS()
2524
α = randn()
2625
A = randn(5, 5, 5, 5, 5, 5)
2726
B = randn(5, 5, 5)
@@ -34,17 +33,16 @@ D = zeros(5, 5, 5)
3433
end
3534
```
3635

37-
Additionally, the number of threads used by tblis can be set by:
36+
Additionally, the number of threads used by TBLIS can be set by:
3837

3938
```julia
40-
using TensorOperationsTBLIS
41-
tblis_set_num_threads(4)
42-
@show tblis_get_num_threads()
39+
TensorOperationsTBLIS.set_num_threads(4)
40+
@show TensorOperationsTBLIS.get_num_threads()
4341
```
4442

4543
## Notes
4644

47-
- This implementation of tblis for TensorOperations.jl is only supported from v5 of
45+
- This implementation of TBLIS for TensorOperations.jl is only supported from v5 of
4846
TensorOperations.jl onwards. For v4, an earlier version of this package exists.
4947
For older versions, you could look for
5048
[BliContractor.jl](https://github.com/xrq-phys/BliContractor.jl) or

src/LibTBLIS.jl

+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
module LibTBLIS
2+
3+
using tblis_jll
4+
using LinearAlgebra: BlasFloat
5+
6+
export tblis_scalar, tblis_tensor
7+
export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot
8+
export tblis_set_num_threads, tblis_get_num_threads
9+
10+
const ptrdiff_t = Cptrdiff_t
11+
12+
const scomplex = ComplexF32
13+
const dcomplex = ComplexF64
14+
15+
const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET)
16+
if Sys.isapple() && Sys.ARCH === :aarch64
17+
include("lib/aarch64-apple-darwin20.jl")
18+
elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL
19+
include("lib/aarch64-linux-gnu.jl")
20+
elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL
21+
include("lib/aarch64-linux-musl.jl")
22+
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL
23+
include("lib/armv7l-linux-gnueabihf.jl")
24+
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL
25+
include("lib/armv7l-linux-musleabihf.jl")
26+
elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL
27+
include("lib/i686-linux-gnu.jl")
28+
elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL
29+
include("lib/i686-linux-musl.jl")
30+
elseif Sys.iswindows() && Sys.ARCH === :i686
31+
include("lib/i686-w64-mingw32.jl")
32+
elseif Sys.islinux() && Sys.ARCH === :powerpc64le
33+
include("lib/powerpc64le-linux-gnu.jl")
34+
elseif Sys.isapple() && Sys.ARCH === :x86_64
35+
include("lib/x86_64-apple-darwin14.jl")
36+
elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL
37+
include("lib/x86_64-linux-gnu.jl")
38+
elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL
39+
include("lib/x86_64-linux-musl.jl")
40+
elseif Sys.isbsd() && !Sys.isapple()
41+
include("lib/x86_64-unknown-freebsd.jl")
42+
elseif Sys.iswindows() && Sys.ARCH === :x86_64
43+
include("lib/x86_64-w64-mingw32.jl")
44+
else
45+
error("Unknown platform: $(Base.BUILD_TRIPLET)")
46+
end
47+
48+
# tblis_scalar and tblis_tensor
49+
# -----------------------------
50+
"""
51+
tblis_scalar(s::Number)
52+
53+
Initializes a tblis scalar from a number.
54+
"""
55+
function tblis_scalar end
56+
57+
"""
58+
tblis_tensor(A::AbstractArray{T<:BlasFloat}, [szA::Vector{Int}, strA::Vector{Int}, scalar::Number])
59+
60+
Initializes a tblis tensor from an array that should be strided and admit a pointer to its
61+
data. This operation is deemed unsafe, in the sense that the user is responsible for ensuring
62+
that the reference to the array and the sizes and strides stays alive during the lifetime of
63+
the tensor.
64+
"""
65+
function tblis_tensor end
66+
67+
for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in
68+
((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s),
69+
(:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d),
70+
(:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c),
71+
(:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z))
72+
@eval begin
73+
function tblis_scalar(s::$T)
74+
t = Ref{tblis_scalar}()
75+
$tblis_init_scalar(t, s)
76+
return t[]
77+
end
78+
function tblis_tensor(A::AbstractArray{$T,N},
79+
s::Number=one(T),
80+
szA::Vector{len_type}=collect(len_type, size(A)),
81+
strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N}
82+
t = Ref{tblis_tensor}()
83+
if isone(s)
84+
$tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA))
85+
else
86+
$tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A),
87+
pointer(strA))
88+
end
89+
return t[]
90+
end
91+
end
92+
end
93+
94+
# tensor operations
95+
# ------------------
96+
function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB)
97+
return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB)
98+
end
99+
100+
function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor,
101+
idxC)
102+
return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC)
103+
end
104+
105+
function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar)
106+
return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C))
107+
end
108+
109+
end

src/LibTblis.jl

-146
This file was deleted.

0 commit comments

Comments
 (0)