Skip to content

Commit 86a4c75

Browse files
committed
small updates
1 parent 86df987 commit 86a4c75

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# TensorOperationsTBLIS.jl
2-
[tblis](https://github.com/devinamatthews/tblis) wrapper for [TensorOperations.jl]()
2+
3+
[tblis](https://github.com/devinamatthews/tblis) wrapper for [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl)
34

45
[![CI][ci-img]][ci-url] [![CI (Julia nightly)][ci-julia-nightly-img]][ci-julia-nightly-url] [![][codecov-img]][codecov-url]
56

src/TensorOperationsTBLIS.jl

+8-9
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@ export tblis_set_num_threads, tblis_get_num_threads
1313
# TensorOperations
1414
#------------------
1515

16-
const TblisBackend = TensorOperations.Backend{:tblis}
16+
const tblisBackend = TensorOperations.Backend{:tblis}
1717

1818
function TensorOperations.tensoradd!(C::StridedArray{T}, pC::Index2Tuple,
1919
A::StridedArray{T}, conjA::Symbol,
2020
α::Number, β::Number,
21-
::TblisBackend) where {T<:BlasFloat}
21+
::tblisBackend) where {T<:BlasFloat}
2222
TensorOperations.argcheck_tensoradd(C, pC, A)
23-
# check dimensions
24-
size(C) == getindex.(Ref(size(A)), linearize(pC)) ||
25-
throw(DimensionMismatch("incompatible sizes"))
26-
23+
TensorOperations.dimcheck_tensoradd(C, pC, A)
24+
2725
szC = collect(size(C))
2826
strC = collect(strides(C))
2927
C_tblis = tblis_tensor(C, szC, strC, β)
@@ -42,11 +40,11 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
4240
A::StridedArray{T}, pA::Index2Tuple,
4341
conjA::Symbol, B::StridedArray{T},
4442
pB::Index2Tuple, conjB::Symbol, α::Number,
45-
β::Number, ::TblisBackend) where {T<:BlasFloat}
43+
β::Number, ::tblisBackend) where {T<:BlasFloat}
4644
TensorOperations.argcheck_tensorcontract(C, pC, A, pA, B, pB)
4745
TensorOperations.dimcheck_tensorcontract(C, pC, A, pA, B, pB)
4846

49-
rmul!(C, β)
47+
rmul!(C, β) # TODO: is it possible to use tblis scaling here?
5048
szC = ndims(C) == 0 ? Int[] : collect(size(C))
5149
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
5250
C_tblis = tblis_tensor(C, szC, strC)
@@ -67,10 +65,11 @@ function TensorOperations.tensorcontract!(C::StridedArray{T}, pC::Index2Tuple,
6765
end
6866

6967
# partial traces do not exist in tblis afaik -> use default implementation
68+
# TODO: implement full trace
7069
function TensorOperations.tensortrace!(C::StridedArray{T}, pC::Index2Tuple,
7170
A::StridedArray{T}, pA::Index2Tuple, conjA::Symbol,
7271
α::Number, β::Number,
73-
::TblisBackend) where {T<:BlasFloat}
72+
::tblisBackend) where {T<:BlasFloat}
7473
return tensortrace!(C, pC, A, pA, conjA, α, β)
7574
end
7675

0 commit comments

Comments
 (0)