Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and add StridedView support #5

Merged
merged 7 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ignore:
- "src/lib"
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperationsTBLIS"
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
authors = ["lkdvos <lukas.devos@ugent.be>"]
version = "0.2.0"
authors = ["Lukas Devos <ldevos98@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
version = "0.3.0"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"

[compat]
Random = "1"
TensorOperations = "5"
TupleTools = "1"
Test = "1"
julia = "1.8"
tblis_jll = "1.2"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LinearAlgebra"]
test = ["Test", "LinearAlgebra", "Random"]
18 changes: 8 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TensorOperationsTBLIS.jl

Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl).
Julia wrapper for [TBLIS](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl).

[![CI][ci-img]][ci-url] [![CI (Julia nightly)][ci-julia-nightly-img]][ci-julia-nightly-url] [![][codecov-img]][codecov-url]

Expand All @@ -13,15 +13,14 @@ Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorO
[codecov-img]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl/graph/badge.svg?token=R86L0S70VT
[codecov-url]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl

Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for
`StridedArray{<:BlasFloat}`. These can be accessed through the backend system of
TensorOperations, i.e.
Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for array types compatible with Strided.jl, i.e. `StridedView{<:BlasFloat}`.
These can be accessed through the backend system of TensorOperations, i.e.

```julia
using TensorOperations
using TensorOperationsTBLIS

tblisbackend = tblisBackend()
tblisbackend = TBLIS()
α = randn()
A = randn(5, 5, 5, 5, 5, 5)
B = randn(5, 5, 5)
Expand All @@ -34,17 +33,16 @@ D = zeros(5, 5, 5)
end
```

Additionally, the number of threads used by tblis can be set by:
Additionally, the number of threads used by TBLIS can be set by:

```julia
using TensorOperationsTBLIS
tblis_set_num_threads(4)
@show tblis_get_num_threads()
TensorOperationsTBLIS.set_num_threads(4)
@show TensorOperationsTBLIS.get_num_threads()
```

## Notes

- This implementation of tblis for TensorOperations.jl is only supported from v5 of
- This implementation of TBLIS for TensorOperations.jl is only supported from v5 of
TensorOperations.jl onwards. For v4, an earlier version of this package exists.
For older versions, you could look for
[BliContractor.jl](https://github.com/xrq-phys/BliContractor.jl) or
Expand Down
109 changes: 109 additions & 0 deletions src/LibTBLIS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
module LibTBLIS

using tblis_jll
using LinearAlgebra: BlasFloat

export tblis_scalar, tblis_tensor
export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot
export tblis_set_num_threads, tblis_get_num_threads

const ptrdiff_t = Cptrdiff_t

const scomplex = ComplexF32
const dcomplex = ComplexF64

const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET)
if Sys.isapple() && Sys.ARCH === :aarch64
include("lib/aarch64-apple-darwin20.jl")
elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL
include("lib/aarch64-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL
include("lib/aarch64-linux-musl.jl")
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL
include("lib/armv7l-linux-gnueabihf.jl")
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL
include("lib/armv7l-linux-musleabihf.jl")
elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL
include("lib/i686-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL
include("lib/i686-linux-musl.jl")
elseif Sys.iswindows() && Sys.ARCH === :i686
include("lib/i686-w64-mingw32.jl")
elseif Sys.islinux() && Sys.ARCH === :powerpc64le
include("lib/powerpc64le-linux-gnu.jl")
elseif Sys.isapple() && Sys.ARCH === :x86_64
include("lib/x86_64-apple-darwin14.jl")
elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL
include("lib/x86_64-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL
include("lib/x86_64-linux-musl.jl")
elseif Sys.isbsd() && !Sys.isapple()
include("lib/x86_64-unknown-freebsd.jl")
elseif Sys.iswindows() && Sys.ARCH === :x86_64
include("lib/x86_64-w64-mingw32.jl")
else
error("Unknown platform: $(Base.BUILD_TRIPLET)")
end

# tblis_scalar and tblis_tensor
# -----------------------------
"""
tblis_scalar(s::Number)

Initializes a tblis scalar from a number.
"""
function tblis_scalar end

"""
tblis_tensor(A::AbstractArray{T<:BlasFloat}, [szA::Vector{Int}, strA::Vector{Int}, scalar::Number])

Initializes a tblis tensor from an array that should be strided and admit a pointer to its
data. This operation is deemed unsafe, in the sense that the user is responsible for ensuring
that the reference to the array and the sizes and strides stays alive during the lifetime of
the tensor.
"""
function tblis_tensor end

for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in
((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s),
(:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d),
(:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c),
(:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z))
@eval begin
function tblis_scalar(s::$T)
t = Ref{tblis_scalar}()
$tblis_init_scalar(t, s)
return t[]
end
function tblis_tensor(A::AbstractArray{$T,N},
s::Number=one(T),
szA::Vector{len_type}=collect(len_type, size(A)),
strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N}
t = Ref{tblis_tensor}()
if isone(s)
$tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA))
else
$tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A),
pointer(strA))
end
return t[]
end
end
end

# tensor operations
# ------------------
function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB)
return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB)
end

function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor,
idxC)
return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC)
end

function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar)
return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C))
end

end
146 changes: 0 additions & 146 deletions src/LibTblis.jl

This file was deleted.

Loading