|
1 | 1 | module TensorOperationsTBLIS
|
2 | 2 |
|
3 | 3 | using TensorOperations
|
| 4 | +using TensorOperations: StridedView, DefaultAllocator, IndexError |
| 5 | +using TensorOperations: istrivialpermutation, BlasFloat, linearize |
| 6 | +using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd, |
| 7 | + argcheck_tensortrace, dimcheck_tensortrace, |
| 8 | + argcheck_tensorcontract, dimcheck_tensorcontract |
4 | 9 | using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError
|
5 |
| -using LinearAlgebra: BlasFloat, rmul! |
| 10 | +using LinearAlgebra: BlasFloat |
6 | 11 | using TupleTools
|
7 | 12 |
|
8 |
| -include("LibTblis.jl") |
9 |
| -using .LibTblis |
| 13 | +include("LibTBLIS.jl") |
| 14 | +using .LibTBLIS |
| 15 | +using .LibTBLIS: LibTBLIS, len_type, stride_type |
10 | 16 |
|
11 |
| -export tblis_set_num_threads, tblis_get_num_threads |
12 |
| -export tblisBackend |
| 17 | +export TBLIS |
| 18 | +export get_num_tblis_threads, set_num_tblis_threads |
| 19 | + |
| 20 | +get_num_tblis_threads() = convert(Int, LibTBLIS.tblis_get_num_threads()) |
| 21 | +set_num_tblis_threads(n) = LibTBLIS.tblis_set_num_threads(convert(Cuint, n)) |
13 | 22 |
|
14 | 23 | # TensorOperations
|
15 | 24 | #------------------
|
16 | 25 |
|
17 |
| -struct tblisBackend <: TensorOperations.AbstractBackend end |
18 |
| - |
19 |
| -function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T}, |
20 |
| - pA::Index2Tuple, conjA::Bool, |
21 |
| - α::Number, β::Number, |
22 |
| - ::tblisBackend) where {T<:BlasFloat} |
23 |
| - TensorOperations.argcheck_tensoradd(C, A, pA) |
24 |
| - TensorOperations.dimcheck_tensoradd(C, A, pA) |
25 |
| - |
26 |
| - szC = collect(size(C)) |
27 |
| - strC = collect(strides(C)) |
28 |
| - C_tblis = tblis_tensor(C, szC, strC, β) |
29 |
| - |
30 |
| - szA = collect(size(A)) |
31 |
| - strA = collect(strides(A)) |
32 |
| - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
33 |
| - |
34 |
| - einA, einC = TensorOperations.add_labels(pA) |
35 |
| - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) |
36 |
| - |
37 |
| - return C |
38 |
| -end |
39 |
| - |
40 |
| -function TensorOperations.tensorcontract!(C::StridedArray{T}, |
41 |
| - A::StridedArray{T}, pA::Index2Tuple, |
42 |
| - conjA::Bool, B::StridedArray{T}, |
43 |
| - pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, |
44 |
| - α::Number, β::Number, |
45 |
| - ::tblisBackend) where {T<:BlasFloat} |
46 |
| - TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB) |
47 |
| - TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB) |
48 |
| - |
49 |
| - rmul!(C, β) # TODO: is it possible to use tblis scaling here? |
50 |
| - szC = ndims(C) == 0 ? Int[] : collect(size(C)) |
51 |
| - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) |
52 |
| - C_tblis = tblis_tensor(C, szC, strC) |
53 |
| - |
54 |
| - szA = collect(size(A)) |
55 |
| - strA = collect(strides(A)) |
56 |
| - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
57 |
| - |
58 |
| - szB = collect(size(B)) |
59 |
| - strB = collect(strides(B)) |
60 |
| - B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1) |
61 |
| - |
62 |
| - einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB) |
63 |
| - tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis, |
64 |
| - string(einC...)) |
65 |
| - |
66 |
| - return C |
67 |
| -end |
68 |
| - |
69 |
| -function TensorOperations.tensortrace!(C::StridedArray{T}, |
70 |
| - A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple, |
71 |
| - conjA::Bool, |
72 |
| - α::Number, β::Number, |
73 |
| - ::tblisBackend) where {T<:BlasFloat} |
74 |
| - TensorOperations.argcheck_tensortrace(C, A, p, q) |
75 |
| - TensorOperations.dimcheck_tensortrace(C, A, p, q) |
76 |
| - |
77 |
| - rmul!(C, β) # TODO: is it possible to use tblis scaling here? |
78 |
| - szC = ndims(C) == 0 ? Int[] : collect(size(C)) |
79 |
| - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) |
80 |
| - C_tblis = tblis_tensor(C, szC, strC) |
81 |
| - |
82 |
| - szA = collect(size(A)) |
83 |
| - strA = collect(strides(A)) |
84 |
| - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
85 |
| - |
86 |
| - einA, einC = TensorOperations.trace_labels(p, q) |
| 26 | +struct TBLIS <: TensorOperations.AbstractBackend end |
87 | 27 |
|
88 |
| - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) |
| 28 | +Base.@deprecate(tblisBackend(), TBLIS()) |
| 29 | +Base.@deprecate(tblis_get_num_threads(), get_num_tblis_threads()) |
| 30 | +Base.@deprecate(tblis_set_num_threads(n), set_num_tblis_threads(n)) |
89 | 31 |
|
90 |
| - return C |
91 |
| -end |
| 32 | +include("strided.jl") |
92 | 33 |
|
93 | 34 | end # module TensorOperationsTBLIS
|
0 commit comments