diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 0000000..6201ea2 --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,2 @@ +ignore: + - "src/lib" \ No newline at end of file diff --git a/Project.toml b/Project.toml index 07b8e9d..05955a0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperationsTBLIS" uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865" -authors = ["lkdvos "] -version = "0.2.0" +authors = ["Lukas Devos ", "Jutho Haegeman "] +version = "0.3.0" [deps] Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0" [compat] +Random = "1" TensorOperations = "5" TupleTools = "1" +Test = "1" julia = "1.8" tblis_jll = "1.2" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "LinearAlgebra"] +test = ["Test", "LinearAlgebra", "Random"] diff --git a/README.md b/README.md index 90ce28d..1342a9d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # TensorOperationsTBLIS.jl -Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl). +Julia wrapper for [TBLIS](https://github.com/devinamatthews/tblis) with [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl). [![CI][ci-img]][ci-url] [![CI (Julia nightly)][ci-julia-nightly-img]][ci-julia-nightly-url] [![][codecov-img]][codecov-url] @@ -13,15 +13,14 @@ Julia wrapper for [tblis](https://github.com/devinamatthews/tblis) with [TensorO [codecov-img]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl/graph/badge.svg?token=R86L0S70VT [codecov-url]: https://codecov.io/gh/lkdvos/TensorOperationsTBLIS.jl -Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for -`StridedArray{<:BlasFloat}`. These can be accessed through the backend system of -TensorOperations, i.e. +Currently provides implementations of `tensorcontract!`, `tensoradd!` and `tensortrace!` for array types compatible with Strided.jl, i.e. `StridedView{<:BlasFloat}`. +These can be accessed through the backend system of TensorOperations, i.e. ```julia using TensorOperations using TensorOperationsTBLIS -tblisbackend = tblisBackend() +tblisbackend = TBLIS() α = randn() A = randn(5, 5, 5, 5, 5, 5) B = randn(5, 5, 5) @@ -34,17 +33,16 @@ D = zeros(5, 5, 5) end ``` -Additionally, the number of threads used by tblis can be set by: +Additionally, the number of threads used by TBLIS can be set by: ```julia -using TensorOperationsTBLIS -tblis_set_num_threads(4) -@show tblis_get_num_threads() +TensorOperationsTBLIS.set_num_threads(4) +@show TensorOperationsTBLIS.get_num_threads() ``` ## Notes -- This implementation of tblis for TensorOperations.jl is only supported from v5 of +- This implementation of TBLIS for TensorOperations.jl is only supported from v5 of TensorOperations.jl onwards. For v4, an earlier version of this package exists. 
For older versions, you could look for [BliContractor.jl](https://github.com/xrq-phys/BliContractor.jl) or
diff --git a/src/LibTBLIS.jl b/src/LibTBLIS.jl
new file mode 100644
index 0000000..3954b80
--- /dev/null
+++ b/src/LibTBLIS.jl
@@ -0,0 +1,109 @@
+module LibTBLIS
+
+using tblis_jll
+using LinearAlgebra: BlasFloat
+
+export tblis_scalar, tblis_tensor
+export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot
+export tblis_set_num_threads, tblis_get_num_threads
+
+const ptrdiff_t = Cptrdiff_t
+
+const scomplex = ComplexF32
+const dcomplex = ComplexF64
+
+const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET)
+if Sys.isapple() && Sys.ARCH === :aarch64
+    include("lib/aarch64-apple-darwin20.jl")
+elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL
+    include("lib/aarch64-linux-gnu.jl")
+elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL
+    include("lib/aarch64-linux-musl.jl")
+elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL
+    include("lib/armv7l-linux-gnueabihf.jl")
+elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL
+    include("lib/armv7l-linux-musleabihf.jl")
+elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL
+    include("lib/i686-linux-gnu.jl")
+elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL
+    include("lib/i686-linux-musl.jl")
+elseif Sys.iswindows() && Sys.ARCH === :i686
+    include("lib/i686-w64-mingw32.jl")
+elseif Sys.islinux() && Sys.ARCH === :powerpc64le
+    include("lib/powerpc64le-linux-gnu.jl")
+elseif Sys.isapple() && Sys.ARCH === :x86_64
+    include("lib/x86_64-apple-darwin14.jl")
+elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL
+    include("lib/x86_64-linux-gnu.jl")
+elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL
+    include("lib/x86_64-linux-musl.jl")
+elseif Sys.isbsd() && !Sys.isapple()
+    include("lib/x86_64-unknown-freebsd.jl")
+elseif Sys.iswindows() && Sys.ARCH === :x86_64
+    include("lib/x86_64-w64-mingw32.jl")
+else
+    error("Unknown platform: $(Base.BUILD_TRIPLET)")
+end
+
+# tblis_scalar and tblis_tensor
+# -----------------------------
+"""
+    tblis_scalar(s::Number)
+
+Initializes a tblis scalar from a number.
+"""
+function tblis_scalar end
+
+"""
+    tblis_tensor(A::AbstractArray{T<:BlasFloat}, [s::Number=one(T), szA::Vector{len_type}, strA::Vector{stride_type}])
+
+Initializes a tblis tensor from an array, which should be strided and admit a pointer to its
+data. This operation is unsafe, in the sense that the user is responsible for ensuring that
+the references to the array and to the size and stride vectors stay alive during the lifetime
+of the tensor.
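+
+For example (an illustrative sketch, not a doctest; since the tensor wraps the array data
+without copying, the array must be kept alive, e.g. with `GC.@preserve`, while it is used):
+
+    A = rand(Float64, 3, 4)
+    GC.@preserve A begin
+        t = tblis_tensor(A)  # sizes and strides are taken from `A` by default
+        # ... pass `t` to `tblis_tensor_add`, `tblis_tensor_mult` or `tblis_tensor_dot` ...
+    end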
+""" +function tblis_tensor end + +for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in + ((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s), + (:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d), + (:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c), + (:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z)) + @eval begin + function tblis_scalar(s::$T) + t = Ref{tblis_scalar}() + $tblis_init_scalar(t, s) + return t[] + end + function tblis_tensor(A::AbstractArray{$T,N}, + s::Number=one(T), + szA::Vector{len_type}=collect(len_type, size(A)), + strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N} + t = Ref{tblis_tensor}() + if isone(s) + $tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA)) + else + $tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A), + pointer(strA)) + end + return t[] + end + end +end + +# tensor operations +# ------------------ +function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB) + return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB) +end + +function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor, + idxC) + return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC) +end + +function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar) + return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C)) +end + +end diff --git a/src/LibTblis.jl b/src/LibTblis.jl deleted file mode 100644 index 013ec79..0000000 --- a/src/LibTblis.jl +++ /dev/null @@ -1,146 +0,0 @@ -module LibTblis - -using tblis_jll -using LinearAlgebra: BlasFloat - -export tblis_scalar, tblis_matrix, tblis_tensor -export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot -export tblis_set_num_threads, tblis_get_num_threads - -const ptrdiff_t = Cptrdiff_t - -const scomplex = ComplexF32 -const dcomplex = ComplexF64 - -const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET) -if Sys.isapple() && Sys.ARCH === :aarch64 - include("lib/aarch64-apple-darwin20.jl") -elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL - include("lib/aarch64-linux-gnu.jl") -elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL - include("lib/aarch64-linux-musl.jl") -elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL - include("lib/armv7l-linux-gnueabihf.jl") -elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL - include("lib/armv7l-linux-musleabihf.jl") -elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL - include("lib/i686-linux-gnu.jl") -elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL - include("lib/i686-linux-musl.jl") -elseif Sys.iswindows() && Sys.ARCH === :i686 - include("lib/i686-w64-mingw32.jl") -elseif Sys.islinux() && Sys.ARCH === :powerpc64le - include("lib/powerpc64le-linux-gnu.jl") -elseif Sys.isapple() && Sys.ARCH === :x86_64 - include("lib/x86_64-apple-darwin14.jl") -elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL - include("lib/x86_64-linux-gnu.jl") -elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL - include("lib/x86_64-linux-musl.jl") -elseif Sys.isbsd() && !Sys.isapple() - include("lib/x86_64-unknown-freebsd.jl") -elseif Sys.iswindows() && Sys.ARCH === :x86_64 - include("lib/x86_64-w64-mingw32.jl") -else - error("Unknown platform: $(Base.BUILD_TRIPLET)") -end - -# tblis_scalar -# 
------------ -""" - tblis_scalar(s::Number) - -Initializes a tblis scalar from a number. -""" -function tblis_scalar end -function tblis_scalar(s::Float32) - t = Ref{tblis_scalar}() - tblis_init_scalar_s(t, s) - return t[] -end -function tblis_scalar(s::Float64) - t = Ref{tblis_scalar}() - tblis_init_scalar_d(t, s) - return t[] -end -function tblis_scalar(s::ComplexF32) - t = Ref{tblis_scalar}() - tblis_init_scalar_c(t, s) - return t[] -end -function tblis_scalar(s::ComplexF64) - t = Ref{tblis_scalar}() - tblis_init_scalar_z(t, s) - return t[] -end - -# tblis_tensor -# ------------ -""" - tblis_tensor(A::StridedArray{T<:BlasFloat}, szA::Vector{Int}, strA::Vector{Int}, s=one(T)) - -Initializes a tblis tensor from a strided array. This operation is deemed unsafe, in the -sense that the user is responsible for ensuring that the reference to the array and the -sizes and strides stays alive during the lifetime of the tensor. -""" -function tblis_tensor end - -function tblis_tensor(A::StridedArray{Float32,N}, szA::Vector{Int}, strA::Vector{Int}, - s::Number=one(Float32)) where {N} - t = Ref{tblis_tensor}() - if isone(s) - tblis_init_tensor_s(t, N, pointer(szA), pointer(A), pointer(strA)) - else - tblis_init_tensor_scaled_s(t, Float32(s), N, pointer(szA), pointer(A), - pointer(strA)) - end - return t[] -end -function tblis_tensor(A::StridedArray{Float64,N}, szA::Vector{Int}, strA::Vector{Int}, - s::Number=one(Float64)) where {N} - t = Ref{tblis_tensor}() - if isone(s) - tblis_init_tensor_d(t, N, pointer(szA), pointer(A), pointer(strA)) - else - tblis_init_tensor_scaled_d(t, Float64(s), N, pointer(szA), pointer(A), - pointer(strA)) - end - return t[] -end -function tblis_tensor(A::StridedArray{ComplexF32,N}, szA::Vector{Int}, strA::Vector{Int}, - s::Number=one(ComplexF32)) where {N} - t = Ref{tblis_tensor}() - if isone(s) - tblis_init_tensor_c(t, N, pointer(szA), pointer(A), pointer(strA)) - else - tblis_init_tensor_scaled_c(t, ComplexF32(s), N, pointer(szA), pointer(A), - pointer(strA)) - end - return t[] -end -function tblis_tensor(A::StridedArray{ComplexF64,N}, szA::Vector{Int}, strA::Vector{Int}, - s::Number=one(ComplexF64)) where {N} - t = Ref{tblis_tensor}() - if isone(s) - tblis_init_tensor_z(t, N, pointer(szA), pointer(A), pointer(strA)) - else - tblis_init_tensor_scaled_z(t, ComplexF64(s), N, pointer(szA), pointer(A), - pointer(strA)) - end - return t[] -end - -function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB) - return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB) -end - -function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor, - idxC) - return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC) -end - -function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar) - return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C)) -end - -end diff --git a/src/TensorOperationsTBLIS.jl b/src/TensorOperationsTBLIS.jl index 83192ed..7a6b66d 100644 --- a/src/TensorOperationsTBLIS.jl +++ b/src/TensorOperationsTBLIS.jl @@ -1,93 +1,41 @@ module TensorOperationsTBLIS using TensorOperations +using TensorOperations: StridedView, DefaultAllocator, IndexError +using TensorOperations: istrivialpermutation, BlasFloat, linearize +using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd, + argcheck_tensortrace, dimcheck_tensortrace, + argcheck_tensorcontract, dimcheck_tensorcontract +using TensorOperations: stridedtensoradd!, stridedtensortrace!, stridedtensorcontract! 
using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError -using LinearAlgebra: BlasFloat, rmul! +using LinearAlgebra: BlasFloat using TupleTools -include("LibTblis.jl") -using .LibTblis +include("LibTBLIS.jl") +using .LibTBLIS +using .LibTBLIS: LibTBLIS, len_type, stride_type -export tblis_set_num_threads, tblis_get_num_threads -export tblisBackend +export TBLIS + +get_num_threads() = convert(Int, LibTBLIS.tblis_get_num_threads()) +set_num_threads(n) = LibTBLIS.tblis_set_num_threads(convert(Cuint, n)) # TensorOperations #------------------ +""" + TBLIS() -struct tblisBackend <: TensorOperations.AbstractBackend end - -function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T}, - pA::Index2Tuple, conjA::Bool, - α::Number, β::Number, - ::tblisBackend) where {T<:BlasFloat} - TensorOperations.argcheck_tensoradd(C, A, pA) - TensorOperations.dimcheck_tensoradd(C, A, pA) - - szC = collect(size(C)) - strC = collect(strides(C)) - C_tblis = tblis_tensor(C, szC, strC, β) - - szA = collect(size(A)) - strA = collect(strides(A)) - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) - - einA, einC = TensorOperations.add_labels(pA) - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) - - return C -end - -function TensorOperations.tensorcontract!(C::StridedArray{T}, - A::StridedArray{T}, pA::Index2Tuple, - conjA::Bool, B::StridedArray{T}, - pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, - α::Number, β::Number, - ::tblisBackend) where {T<:BlasFloat} - TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB) - TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB) - - rmul!(C, β) # TODO: is it possible to use tblis scaling here? - szC = ndims(C) == 0 ? Int[] : collect(size(C)) - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) - C_tblis = tblis_tensor(C, szC, strC) - - szA = collect(size(A)) - strA = collect(strides(A)) - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) - - szB = collect(size(B)) - strB = collect(strides(B)) - B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1) - - einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB) - tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis, - string(einC...)) - - return C -end - -function TensorOperations.tensortrace!(C::StridedArray{T}, - A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple, - conjA::Bool, - α::Number, β::Number, - ::tblisBackend) where {T<:BlasFloat} - TensorOperations.argcheck_tensortrace(C, A, p, q) - TensorOperations.dimcheck_tensortrace(C, A, p, q) - - rmul!(C, β) # TODO: is it possible to use tblis scaling here? - szC = ndims(C) == 0 ? Int[] : collect(size(C)) - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) - C_tblis = tblis_tensor(C, szC, strC) - - szA = collect(size(A)) - strA = collect(strides(A)) - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) - - einA, einC = TensorOperations.trace_labels(p, q) +Backend for tensor operations on strided arrays (arrays that can be cast into a `StridedView`) +that uses [TBLIS](https://github.com/devinamatthews/tblis) library. This library can perform +tensor contractions without additional intermediate tensors. It does not use LinearAlgebra +or BLAS. 
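+
+Usage goes through the backend mechanism of TensorOperations.jl, for example (an
+illustrative sketch):
+
+    using TensorOperations, TensorOperationsTBLIS
+    A = randn(4, 4); B = randn(4, 4)
+    @tensor backend = TBLIS() C[i, k] := A[i, j] * B[j, k]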
+""" +struct TBLIS <: TensorOperations.AbstractBackend end - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) +Base.@deprecate(tblisBackend(), TBLIS()) +Base.@deprecate(tblis_get_num_threads(), get_num_threads()) +Base.@deprecate(tblis_set_num_threads(n), set_num_threads(n)) - return C -end +include("strided.jl") end # module TensorOperationsTBLIS diff --git a/src/strided.jl b/src/strided.jl new file mode 100644 index 0000000..66f79b8 --- /dev/null +++ b/src/strided.jl @@ -0,0 +1,171 @@ +#------------------------------------------------------------------------------------------- +# Force strided implementation on AbstractArray instances with TBLIS backend +#------------------------------------------------------------------------------------------- +const SV = StridedView +function TensorOperations.tensoradd!(C::AbstractArray, + A::AbstractArray, pA::Index2Tuple, conjA::Bool, + α::Number, β::Number, + backend::TBLIS, allocator=DefaultAllocator()) + # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on + if conjA + stridedtensoradd!(SV(C), conj(SV(A)), pA, α, β, backend, allocator) + else + stridedtensoradd!(SV(C), SV(A), pA, α, β, backend, allocator) + end + return C +end + +function TensorOperations.tensortrace!(C::AbstractArray, + A::AbstractArray, p::Index2Tuple, q::Index2Tuple, + conjA::Bool, + α::Number, β::Number, + backend::TBLIS, allocator=DefaultAllocator()) + # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on + if conjA + stridedtensortrace!(SV(C), conj(SV(A)), p, q, α, β, backend, allocator) + else + stridedtensortrace!(SV(C), SV(A), p, q, α, β, backend, allocator) + end + return C +end + +function TensorOperations.tensorcontract!(C::AbstractArray, + A::AbstractArray, pA::Index2Tuple, conjA::Bool, + B::AbstractArray, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, + α::Number, β::Number, + backend::TBLIS, allocator=DefaultAllocator()) + # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on + if conjA && conjB + stridedtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B)), pB, pAB, α, β, + backend, allocator) + elseif conjA + stridedtensorcontract!(SV(C), conj(SV(A)), pA, SV(B), pB, pAB, α, β, + backend, allocator) + elseif conjB + stridedtensorcontract!(SV(C), SV(A), pA, conj(SV(B)), pB, pAB, α, β, + backend, allocator) + else + stridedtensorcontract!(SV(C), SV(A), pA, SV(B), pB, pAB, α, β, + backend, allocator) + end + return C +end + +#------------------------------------------------------------------------------------------- +# StridedView implementation +#------------------------------------------------------------------------------------------- +function tblis_tensor(A::StridedView, + s::Number=one(eltype(A)), + szA::Vector{len_type}=collect(len_type, size(A)), + strA::Vector{stride_type}=collect(stride_type, strides(A))) + t₁ = LibTBLIS.tblis_tensor(A, s, szA, strA) + if A.op == conj + t₂ = LibTBLIS.tblis_tensor(t₁.type, Cint(1), t₁.scalar, t₁.data, t₁.ndim, t₁.len, + t₁.stride) + return t₂ + else + return t₁ + end +end + +function TensorOperations.stridedtensoradd!(C::StridedView{T}, + A::StridedView{T}, pA::Index2Tuple, + α::Number, β::Number, + backend::TBLIS, + allocator=DefaultAllocator()) where {T<:BlasFloat} + argcheck_tensoradd(C, A, pA) + dimcheck_tensoradd(C, A, pA) + if Base.mightalias(C, A) + throw(ArgumentError("output tensor must not be aliased with input tensor")) + end + + C_tblis = tblis_tensor(C, β) + 
A_tblis = tblis_tensor(A, α) + einA, einC = TensorOperations.add_labels(pA) + tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) + return C +end + +function TensorOperations.stridedtensortrace!(C::StridedView{T}, + A::StridedView{T}, + p::Index2Tuple, + q::Index2Tuple, + α::Number, β::Number, + backend::TBLIS, + allocator=DefaultAllocator()) where {T<:BlasFloat} + argcheck_tensortrace(C, A, p, q) + dimcheck_tensortrace(C, A, p, q) + + Base.mightalias(C, A) && + throw(ArgumentError("output tensor must not be aliased with input tensor")) + + C_tblis = tblis_tensor(C, β) + A_tblis = tblis_tensor(A, α) + einA, einC = TensorOperations.trace_labels(p, q) + tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) + return C +end + +function TensorOperations.stridedtensorcontract!(C::StridedView{T}, + A::StridedView{T}, pA::Index2Tuple, + B::StridedView{T}, pB::Index2Tuple, + pAB::Index2Tuple, + α::Number, β::Number, + backend::TBLIS, + allocator=DefaultAllocator()) where {T<:BlasFloat} + argcheck_tensorcontract(C, A, pA, B, pB, pAB) + dimcheck_tensorcontract(C, A, pA, B, pB, pAB) + + (Base.mightalias(C, A) || Base.mightalias(C, B)) && + throw(ArgumentError("output tensor must not be aliased with input tensor")) + + einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB) + + # tblis_tensor_mult ignores conjugation flags in A and B (and C) + if A.op == conj && B.op == conj + iszero(β) || conj!(C) + + C_tblis = tblis_tensor(C, conj(β)) + A_tblis = tblis_tensor(A, conj(α)) + B_tblis = tblis_tensor(B) + + tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), + C_tblis, string(einC...)) + conj!(C) + elseif A.op == conj + pA2 = TensorOperations.trivialpermutation(pA) + A2 = StridedView(TensorOperations.tensoralloc_add(eltype(A), A, pA2, false, + Val(true), + allocator)) + A2 = tensorcopy!(A2, A, pA2, false, α, backend, allocator) + + C_tblis = tblis_tensor(C, β) + A_tblis = tblis_tensor(A2) + B_tblis = tblis_tensor(B) + + tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), + C_tblis, string(einC...)) + elseif B.op == conj + pB2 = TensorOperations.trivialpermutation(pB) + B2 = StridedView(TensorOperations.tensoralloc_add(eltype(B), B, pB2, false, + Val(true), + allocator)) + B2 = tensorcopy!(B2, B, pB2, false, α, backend, allocator) + + C_tblis = tblis_tensor(C, β) + A_tblis = tblis_tensor(A) + B_tblis = tblis_tensor(B2) + + tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), + C_tblis, string(einC...)) + else + C_tblis = tblis_tensor(C, β) + A_tblis = tblis_tensor(A, α) + B_tblis = tblis_tensor(B) + + tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), + C_tblis, string(einC...)) + end + return C +end diff --git a/test/methods.jl b/test/methods.jl new file mode 100644 index 0000000..8dc23dc --- /dev/null +++ b/test/methods.jl @@ -0,0 +1,202 @@ +b = TBLIS() + +# test simple methods +#--------------------- +typelist = (Float32, Float64, ComplexF32, ComplexF64) +@testset "simple methods with eltype = $T" for T in typelist + @testset "tensorcopy" begin + A = randn(T, (3, 5, 4, 6)) + p = (3, 1, 4, 2) + C1 = permutedims(A, p) + C2 = @inferred tensorcopy((p...,), A, (1:4...,); backend=b) + C3 = @inferred tensorcopy(A, (p, ()), false, 1, b) + @test C1 ≈ C2 + @test C2 == C3 + @test C1 ≈ ncon(Any[A], Any[[-2, -4, -1, -3]]) + end + + @testset "tensoradd" begin + A = randn(T, (3, 5, 4, 6)) + B = randn(T, (5, 6, 3, 4)) + p = (3, 1, 4, 2) + C1 = A + permutedims(B, p) + C2 = @inferred tensoradd(A, p, B, 
(1:4...,); backend=b) + C3 = @inferred tensoradd(A, ((1:4...,), ()), false, B, (p, ()), false, 1, 1, b) + @test C1 ≈ C2 + @test C2 == C3 + end + + @testset "tensortrace" begin + A = randn(Float64, (50, 100, 100)) + C1 = tensortrace(A, [:a, :b, :b]) + C2 = tensortrace(A, [:a, :b, :b]; backend=b) + C3 = ncon(Any[A], Any[[-1, 1, 1]]; backend=b) + @test C1 ≈ C2 + @test C2 == C3 + A = randn(Float64, (3, 20, 5, 3, 20, 4, 5)) + C1 = tensortrace((:e, :a, :d), A, (:a, :b, :c, :d, :b, :e, :c)) + C2 = @inferred tensortrace((:e, :a, :d), A, (:a, :b, :c, :d, :b, :e, :c); backend=b) + C3 = @inferred tensortrace(A, ((6, 1, 4), ()), ((2, 3), (5, 7)), false, 1.0, b) + C4 = ncon(Any[A], Any[[-2, 1, 2, -3, 1, -1, 2]]; backend=b) + @test C1 ≈ C2 + @test C2 == C3 == C4 + end + + @testset "tensorcontract" begin + A = randn(T, (3, 20, 5, 3, 4)) + B = randn(T, (5, 6, 20, 3)) + C1 = tensorcontract((:a, :g, :e, :d, :f), A, (:a, :b, :c, :d, :e), B, + (:c, :f, :b, :g)) + C2 = @inferred tensorcontract((:a, :g, :e, :d, :f), + A, (:a, :b, :c, :d, :e), B, (:c, :f, :b, :g); + backend=b) + C3 = @inferred tensorcontract(A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), + false, ((1, 5, 3, 2, 4), ()), 1, b) + C4 = @inferred tensorcontract(A, ((1, 4, 5), (2, 3)), false, B, ((3, 1), (2, 4)), + false, ((1, 5, 3, 2, 4), ()), 1, b, + ManualAllocator()) + C5 = ncon(Any[A, B], Any[[-1, 1, 2, -4, -3], [2, -5, 1, -2]]; backend=b, + allocator=ManualAllocator()) + + @test C1 ≈ C2 + @test C2 == C3 == C4 == C5 + end + + @testset "tensorproduct" begin + A = randn(T, (5, 5, 5, 5)) + B = rand(T, (5, 5, 5, 5)) + C1 = kron(reshape(B, (25, 25)), reshape(A, (25, 25))) + C2 = reshape((@inferred tensorproduct((1, 2, 5, 6, 3, 4, 7, 8), + A, (1, 2, 3, 4), B, (5, 6, 7, 8); backend=b)), + (5 * 5 * 5 * 5, 5 * 5 * 5 * 5)) + @test C1 ≈ C2 + + A = rand(1, 2) + B = rand(4, 5) + C1 = tensorcontract((-1, -2, -3, -4), A, (-3, -1), false, B, (-2, -4), false) + C2 = tensorcontract((-1, -2, -3, -4), A, (-3, -1), false, B, (-2, -4), false; + backend=b) + C3 = tensorproduct(A, ((1, 2), ()), false, B, ((), (1, 2)), false, + ((2, 3, 1, 4), ()), 1, b) + C4 = tensorproduct(A, ((1, 2), ()), false, B, ((), (1, 2)), false, + ((2, 3, 1, 4), ()), 1, b, ManualAllocator()) + @test C1 ≈ C2 + @test C2 == C3 == C4 + end +end + +# test in-place methods +#----------------------- +# test different versions of in-place methods, +# with changing element type and with nontrivial strides +@testset "in-place methods" begin + @testset "tensorcopy!" begin + Abig = randn(Float64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 * (0:9), 2 .+ 2 * (0:6), 5 .+ 3 * (0:6), 4 .+ 3 * (0:8)) + p = (3, 1, 4, 2) + Cbig = zeros(Float64, (50, 50, 50, 50)) + C = view(Cbig, 13 .+ (0:6), 11 .+ 2 * (0:9), 7 .+ 5 * (0:8), 4 .+ 5 * (0:6)) + Acopy = tensorcopy(A, 1:4) + Ccopy = tensorcopy(C, 1:4) + pA = (p, ()) + α = randn(Float64) + tensorcopy!(C, A, pA, false, α, b) + tensorcopy!(Ccopy, Acopy, pA, false, 1.0, b) + @test C ≈ α * Ccopy + @test_throws IndexError tensorcopy!(C, A, ((1, 2, 3), ()), false, 1.0, b) + @test_throws DimensionMismatch tensorcopy!(C, A, ((1, 2, 3, 4), ()), false, 1.0, b) + @test_throws IndexError tensorcopy!(C, A, ((1, 2, 2, 3), ()), false, 1.0, b) + end + + @testset "tensoradd!" 
begin + Abig = randn(ComplexF32, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 * (0:9), 2 .+ 2 * (0:6), 5 .+ 4 * (0:6), 4 .+ 3 * (0:8)) + p = (3, 1, 4, 2) + Cbig = zeros(ComplexF32, (50, 50, 50, 50)) + C = view(Cbig, 13 .+ (0:6), 11 .+ 4 * (0:9), 15 .+ 4 * (0:8), 4 .+ 3 * (0:6)) + Ccopy = tensorcopy(1:4, C, 1:4) + α = randn(ComplexF32) + β = randn(ComplexF32) + tensoradd!(C, A, (p, ()), false, α, β, b) + tensoradd!(Ccopy, A, (p, ()), false, α, β) # default backend + @test C ≈ Ccopy + @test_throws IndexError tensoradd!(C, A, ((1, 2, 3), ()), false, 1.2, 0.5, b) + @test_throws DimensionMismatch tensoradd!(C, A, ((1, 2, 3, 4), ()), false, 1.2, 0.5, + b) + @test_throws IndexError tensoradd!(C, A, ((1, 1, 2, 3), ()), false, 1.2, 0.5, b) + end + + @testset "tensortrace!" begin + Abig = rand(ComplexF64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 * (0:8), 2 .+ 2 * (0:14), 5 .+ 4 * (0:6), 7 .+ 2 * (0:8)) + Bbig = rand(ComplexF64, (50, 50)) + B = view(Bbig, 13 .+ (0:14), 3 .+ 5 * (0:6)) + Acopy = tensorcopy(A, 1:4) + Bcopy = tensorcopy(B, 1:2) + α = randn(Float64) + β = randn(Float64) + tensortrace!(B, A, ((2, 3), ()), ((1,), (4,)), true, α, β, b) + tensortrace!(Bcopy, A, ((2, 3), ()), ((1,), (4,)), true, α, β) # default backend + @test B ≈ Bcopy + @test_throws IndexError tensortrace!(B, A, ((1,), ()), ((2,), (3,)), false, α, β, b) + @test_throws DimensionMismatch tensortrace!(B, A, ((1, 4), ()), ((2,), (3,)), false, + α, β, b) + @test_throws IndexError tensortrace!(B, A, ((1, 4), ()), ((1, 1), (4,)), false, α, + β, b) + @test_throws IndexError tensortrace!(B, A, ((1, 4), ()), ((1,), (3,)), false, + α, β, b) + end + + bref = TensorOperations.DefaultBackend() # reference backend + @testset "tensorcontract! with allocator = $allocator" for allocator in + (DefaultAllocator(), + ManualAllocator()) + Abig = rand(ComplexF64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 * (0:8), 2 .+ 2 * (0:14), 5 .+ 4 * (0:6), 7 .+ 2 * (0:8)) + Bbig = rand(ComplexF64, (50, 50, 50)) + B = view(Bbig, 3 .+ 5 * (0:6), 7 .+ 2 * (0:7), 13 .+ (0:14)) + Cbig = rand(ComplexF64, (40, 40, 40)) + C = view(Cbig, 3 .+ 2 * (0:8), 13 .+ (0:8), 7 .+ 3 * (0:7)) + Acopy = tensorcopy(A, 1:4) + Bcopy = tensorcopy(B, 1:3) + Ccopy = tensorcopy(C, 1:3) + α = randn(ComplexF64) + β = randn(ComplexF64) + tensorcontract!(C, A, ((4, 1), (2, 3)), false, B, ((3, 1), (2,)), true, + ((1, 2, 3), ()), α, β, b, allocator) + tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), false, B, ((3, 1), (2,)), true, + ((1, 2, 3), ()), α, β, bref, allocator) + @test C ≈ Ccopy + + Ccopy = tensorcopy(C, 1:3) + tensorcontract!(C, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), true, + ((1, 2, 3), ()), α, β, b, allocator) + tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), true, + ((1, 2, 3), ()), α, β, bref, allocator) + @test C ≈ Ccopy + + Ccopy = tensorcopy(C, 1:3) + tensorcontract!(C, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), false, + ((1, 2, 3), ()), α, β, b, allocator) + tensorcontract!(Ccopy, A, ((4, 1), (2, 3)), true, B, ((3, 1), (2,)), false, + ((1, 2, 3), ()), α, β, bref, allocator) + @test C ≈ Ccopy + + @test_throws IndexError tensorcontract!(C, + A, ((4, 1), (2, 4)), false, + B, ((1, 3), (2,)), false, + ((1, 2, 3), ()), α, β, b) + @test_throws IndexError tensorcontract!(C, + A, ((4, 1), (2, 3)), false, + B, ((1, 3), ()), false, + ((1, 2, 3), ()), α, β, b) + @test_throws IndexError tensorcontract!(C, + A, ((4, 1), (2, 3)), false, + B, ((1, 3), (2,)), false, + ((1, 2), ()), α, β, b) + @test_throws DimensionMismatch tensorcontract!(C, + A, ((4, 1), (2, 3)), 
false, + B, ((1, 3), (2,)), false, + ((1, 3, 2), ()), α, β, b) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4efbc95..74de148 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,136 +1,20 @@ -using TensorOperations -using TensorOperationsTBLIS -using Test -using LinearAlgebra: norm - -const tblisbackend = tblisBackend() -@testset "elementary operations" verbose = true begin - @testset "tensorcopy" begin - A = randn(Float32, (3, 5, 4, 6)) - @tensor C1[4, 1, 3, 2] := A[1, 2, 3, 4] - @tensor backend = tblisbackend C2[4, 1, 3, 2] := A[1, 2, 3, 4] - @test C2 ≈ C1 - end - - @testset "tensoradd" begin - A = randn(Float32, (5, 6, 3, 4)) - B = randn(Float32, (5, 6, 3, 4)) - α = randn(Float32) - @tensor C1[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d] - @tensor backend = tblisbackend C2[a, b, c, d] := A[a, b, c, d] + α * B[a, b, c, d] - @test collect(C2) ≈ C1 - - C = randn(ComplexF32, (5, 6, 3, 4)) - D = randn(ComplexF32, (5, 3, 4, 6)) - β = randn(ComplexF32) - @tensor E1[a, b, c, d] := C[a, b, c, d] + β * conj(D[a, c, d, b]) - @tensor backend = tblisbackend E2[a, b, c, d] := C[a, b, c, d] + - β * conj(D[a, c, d, b]) - @test collect(E2) ≈ E1 - end - - @testset "tensortrace" begin - A = randn(Float32, (5, 10, 10)) - @tensor B1[a] := A[a, b′, b′] - @tensor backend = tblisbackend B2[a] := A[a, b′, b′] - @test B2 ≈ B1 - - C = randn(ComplexF32, (3, 20, 5, 3, 20, 4, 5)) - @tensor D1[e, a, d] := C[a, b, c, d, b, e, c] - @tensor backend = tblisbackend D2[e, a, d] := C[a, b, c, d, b, e, c] - @test D2 ≈ D1 - - @tensor D3[a, e, d] := conj(C[a, b, c, d, b, e, c]) - @tensor backend = tblisbackend D4[a, e, d] := conj(C[a, b, c, d, b, e, c]) - @test D4 ≈ D3 - - α = randn(ComplexF32) - @tensor D5[d, e, a] := α * C[a, b, c, d, b, e, c] - @tensor backend = tblisbackend D6[d, e, a] := α * C[a, b, c, d, b, e, c] - @test D6 ≈ D5 - end - - @testset "tensorcontract" begin - A = randn(Float32, (3, 20, 5, 3, 4)) - B = randn(Float32, (5, 6, 20, 3)) - @tensor C1[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g] - @tensor backend = tblisbackend C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g] - @test C2 ≈ C1 - - D = randn(ComplexF64, (3, 3, 3)) - E = rand(ComplexF64, (3, 3, 3)) - @tensor F1[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f]) - @tensor backend = tblisbackend F2[a, b, c, d, e, f] := D[a, b, c] * conj(E[d, e, f]) - @test F2 ≈ F1 atol = 1e-12 +using TensorOperations, TensorOperationsTBLIS +using TensorOperations: IndexError +using TensorOperations: DefaultAllocator, ManualAllocator +using Test, Random, LinearAlgebra +Random.seed!(1234567) + +@testset "TensorOperationsTBLIS.jl" begin + TensorOperationsTBLIS.set_num_threads(1) + @testset "method syntax" verbose = true begin + include("methods.jl") end -end - -@testset "more complicated expressions" verbose = true begin - Da, Db, Dc, Dd, De, Df, Dg, Dh = 10, 15, 4, 8, 6, 7, 3, 2 - A = rand(ComplexF64, (Dc, Da, Df, Da, De, Db, Db, Dg)) - B = rand(ComplexF64, (Dc, Dh, Dg, De, Dd)) - C = rand(ComplexF64, (Dd, Dh, Df)) - α = rand(ComplexF64) - # α = 1 - - @tensor D1[d, f, h] := A[c, a, f, a, e, b, b, g] * B[c, h, g, e, d] + α * C[d, h, f] - @tensor backend = tblisbackend D2[d, f, h] := A[c, a, f, a, e, b, b, g] * - B[c, h, g, e, d] + - α * C[d, h, f] - @test D2 ≈ D1 rtol = 1e-8 - @test norm(vec(D1)) ≈ sqrt(abs(@tensor D1[d, f, h] * conj(D1[d, f, h]))) - @test norm(D2) ≈ - sqrt(abs(@tensor backend = tblisbackend D2[d, f, h] * conj(D2[d, f, h]))) - - @testset "readme example" begin - α = randn() - A = randn(5, 5, 5, 5, 5, 5) - B = 
randn(5, 5, 5) - C = randn(5, 5, 5) - D = zeros(5, 5, 5) - D2 = copy(D) - @tensor begin - D[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] - E[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] - end - @tensor backend = tblisbackend begin - D2[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] - E2[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] - end - @test D2 ≈ D - @test E2 ≈ E + @test TensorOperationsTBLIS.get_num_threads() == 1 + TensorOperationsTBLIS.set_num_threads(2) + @testset "macro with index notation" verbose = true begin + include("tensor.jl") end - @testset "tensor network examples ($T)" for T in - (Float32, Float64, ComplexF32, ComplexF64) - D1, D2, D3 = 30, 40, 20 - d1, d2 = 2, 3 - - A1 = randn(T, D1, d1, D2) - A2 = randn(T, D2, d2, D3) - ρₗ = randn(T, D1, D1) - ρᵣ = randn(T, D3, D3) - H = randn(T, d1, d2, d1, d2) - - @tensor begin - HrA12[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] * - H[s1, s2, t1, t2] - end - @tensor backend = tblisbackend begin - HrA12′[a, s1, s2, c] := ρₗ[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * ρᵣ[c', c] * - H[s1, s2, t1, t2] - end - @test HrA12′ ≈ HrA12 - - @tensor begin - E1 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] * - conj(A1[a', t, b']) * conj(A2[b', t', c']) - end - @tensor backend = tblisbackend begin - E2 = ρₗ[a', a] * A1[a, s, b] * A2[b, s', c] * ρᵣ[c, c'] * H[t, t', s, s'] * - conj(A1[a', t, b']) * conj(A2[b', t', c']) - end - @test E2 ≈ E1 - end + @test TensorOperationsTBLIS.get_num_threads() == 2 end diff --git a/test/tensor.jl b/test/tensor.jl new file mode 100644 index 0000000..74b0b80 --- /dev/null +++ b/test/tensor.jl @@ -0,0 +1,197 @@ +# test index notation using @tensor macro +#----------------------------------------- +b = TBLIS() + +@testset "tensorcontract 1" begin + A = randn(Float64, (3, 5, 4, 6)) + p = (4, 1, 3, 2) + C1 = permutedims(A, p) + @tensor backend = b C2[4, 1, 3, 2] := A[1, 2, 3, 4] + @test C1 ≈ C2 + + B = randn(Float64, (5, 6, 3, 4)) + p = [3, 1, 4, 2] + @tensor backend = b C1[3, 1, 4, 2] := A[3, 1, 4, 2] + B[1, 2, 3, 4] + C2 = A + permutedims(B, p) + @test C1 ≈ C2 + @test_throws DimensionMismatch begin + @tensor backend = b C[1, 2, 3, 4] := A[1, 2, 3, 4] + B[1, 2, 3, 4] + end + + A = randn(Float64, (50, 100, 100)) + @tensor backend = b C1[a] := A[a, b', b'] + @tensor C2[a] := A[a, b', b'] + @test C1 ≈ C2 + + A = randn(Float64, (3, 20, 5, 3, 20, 4, 5)) + @tensor backend = b C1[e, a, d] := A[a, b, c, d, b, e, c] + @tensor C2[e, a, d] := A[a, b, c, d, b, e, c] + @test C1 ≈ C2 + + A = randn(Float64, (3, 20, 5, 3, 4)) + B = randn(Float64, (5, 6, 20, 3)) + @tensor backend = b C1[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g] + @tensor C2[a, g, e, d, f] := A[a, b, c, d, e] * B[c, f, b, g] + @test C1 ≈ C2 +end + +@testset "tensorcontract 2" begin + A = randn(ComplexF32, (5, 5, 5, 5)) + B = rand(ComplexF32, (5, 5, 5, 5)) + @tensor backend = b C1[1, 2, 5, 6, 3, 4, 7, 8] := A[1, 2, 3, 4] * B[5, 6, 7, 8] + @tensor C2[1, 2, 5, 6, 3, 4, 7, 8] := A[1, 2, 3, 4] * B[5, 6, 7, 8] + @test C1 ≈ C2 + @test_throws IndexError begin + @tensor backend = b C[a, b, c, d, e, f, g, i] := A[a, b, c, d] * B[e, f, g, h] + end +end + +@testset "tensorcontract 3" begin + Da, Db, Dc, Dd, De, Df, Dg, Dh = 10, 15, 4, 8, 6, 7, 3, 2 + A = rand(ComplexF64, (Da, Dc, Df, Da, De, Db, Db, Dg)) + B = rand(ComplexF64, (Dc, Dh, Dg, De, Dd)) + C = rand(ComplexF64, (Dd, Dh, Df)) + @tensor backend = b D1[d, f, h] := A[a, c, f, a, e, b, b, g] * B[c, h, g, e, 
d] + + 0.5 * C[d, h, f] + @tensor D2[d, f, h] := A[a, c, f, a, e, b, b, g] * B[c, h, g, e, d] + 0.5 * C[d, h, f] + @test D1 ≈ D2 + @test norm(vec(D1)) ≈ + sqrt(abs((@tensor backend = b tensorscalar(D1[d, f, h] * conj(D1[d, f, h]))))) +end + +@testset "views" begin + p = [3, 1, 4, 2] + Abig = randn(Float32, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 .* (0:9), 2 .+ 2 .* (0:6), 5 .+ 4 .* (0:6), 4 .+ 3 .* (0:8)) + Cbig = zeros(Float32, (50, 50, 50, 50)) + C = view(Cbig, 13 .+ (0:6), 11 .+ 4 .* (0:9), 15 .+ 4 .* (0:8), 4 .+ 3 .* (0:6)) + Ccopy = copy(C) + @tensor backend = b C[3, 1, 4, 2] = A[1, 2, 3, 4] + @tensor Ccopy[3, 1, 4, 2] = A[1, 2, 3, 4] + @test C ≈ Ccopy +end + +@testset "views 2" begin + p = [3, 1, 4, 2] + Abig = randn(ComplexF64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 .* (0:9), 2 .+ 2 .* (0:6), 5 .+ 4 .* (0:6), 4 .+ 3 .* (0:8)) + Cbig = zeros(ComplexF64, (50, 50, 50, 50)) + C = view(Cbig, 13 .+ (0:6), 11 .+ 4 .* (0:9), 15 .+ 4 .* (0:8), 4 .+ 3 .* (0:6)) + α = randn(Float64) + β = randn(Float64) + @tensor backend = b D[3, 1, 4, 2] := β * C[3, 1, 4, 2] + α * A[1, 2, 3, 4] + @tensor Dcopy[3, 1, 4, 2] := β * C[3, 1, 4, 2] + α * A[1, 2, 3, 4] + @test D ≈ Dcopy +end + +@testset "views 3" begin + Abig = rand(Float64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 .* (0:8), 2 .+ 2 .* (0:14), 5 .+ 4 .* (0:6), 7 .+ 2 .* (0:8)) + Bbig = rand(Float64, (50, 50)) + B = view(Bbig, 13 .+ (0:14), 3 .+ 5 .* (0:6)) + Bcopy = copy(B) + α = randn(Float64) + @tensor backend = b B[b, c] += α * A[a, b, c, a] + @tensor Bcopy[b, c] += α * A[a, b, c, a] + @test B ≈ Bcopy +end + +@testset "views 4" begin + Abig = rand(ComplexF64, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 .* (0:8), 2 .+ 2 .* (0:14), 5 .+ 4 .* (0:6), 7 .+ 2 .* (0:8)) + Bbig = rand(ComplexF64, (50, 50, 50)) + B = view(Bbig, 3 .+ 5 .* (0:6), 7 .+ 2 .* (0:7), 13 .+ (0:14)) + Cbig = rand(ComplexF64, (40, 40, 40)) + C = view(Cbig, 3 .+ 2 .* (0:8), 13 .+ (0:8), 7 .+ 3 .* (0:7)) + Ccopy = copy(C) + α = randn(Float64) + @tensor backend = b C[d, a, e] -= α * A[a, b, c, d] * conj(B[c, e, b]) + @tensor Ccopy[d, a, e] -= α * A[a, b, c, d] * conj(B[c, e, b]) + @test C ≈ Ccopy +end + +@testset "Float32 views" begin + α = randn(Float64) + Abig = rand(ComplexF32, (30, 30, 30, 30)) + A = view(Abig, 1 .+ 3 .* (0:8), 2 .+ 2 .* (0:14), 5 .+ 4 .* (0:6), 7 .+ 2 .* (0:8)) + Bbig = rand(ComplexF32, (50, 50, 50)) + B = view(Bbig, 3 .+ 5 .* (0:6), 7 .+ 2 .* (0:7), 13 .+ (0:14)) + Cbig = rand(ComplexF32, (40, 40, 40)) + C = view(Cbig, 3 .+ 2 .* (0:8), 13 .+ (0:8), 7 .+ 3 .* (0:7)) + Ccopy = copy(C) + @tensor backend = b C[d, a, e] += α * A[a, b, c, d] * conj(B[c, e, b]) + @tensor Ccopy[d, a, e] += α * A[a, b, c, d] * conj(B[c, e, b]) + @test C ≈ Ccopy +end + +# Simple function example +# @tensor function f(A, b) +# w[x] := (1 // 2) * A[x, y] * b[y] +# return w +# end +# for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat) +# A = rand(T, 10, 10) +# b = rand(T, 10) +# @test f(A, b) ≈ (1 // 2) * A * b +# end + +# Example from README.md +@testset "README example" begin + α = randn() + A = randn(5, 5, 5, 5, 5, 5) + B = randn(5, 5, 5) + C = randn(5, 5, 5) + D = zeros(5, 5, 5) + @tensor backend = b begin + D[a, b, c] = A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] + E[a, b, c] := A[a, e, f, c, f, g] * B[g, b, e] + α * C[c, a, b] + end + @test D == E +end + +# Some tensor network examples +scalartypelist = (Float32, Float64, ComplexF32, ComplexF64) +@testset "tensor network examples ($T)" for T in scalartypelist + D1, D2, D3 = 30, 40, 20 + d1, d2 = 2, 3 + A1 = rand(T, 
D1, d1, D2) .- 1 // 2 + A2 = rand(T, D2, d2, D3) .- 1 // 2 + rhoL = rand(T, D1, D1) .- 1 // 2 + rhoR = rand(T, D3, D3) .- 1 // 2 + H = rand(T, d1, d2, d1, d2) .- 1 // 2 + A12 = reshape(reshape(A1, D1 * d1, D2) * reshape(A2, D2, d2 * D3), (D1, d1, d2, D3)) + @tensor HrA12[a, s1, s2, c] := rhoL[a, a'] * A1[a', t1, b] * A2[b, t2, c'] * + rhoR[c', c] * H[s1, s2, t1, t2] + E = dot(A12, HrA12) + @tensor backend = b HrA12′[a, s1, s2, c] := rhoL[a, a'] * A1[a', t1, b] * + A2[b, t2, c'] * + rhoR[c', c] * H[s1, s2, t1, t2] + @tensor backend = b HrA12′′[:] := rhoL[-1, 1] * H[-2, -3, 4, 5] * A2[2, 5, 3] * + rhoR[3, -4] * + A1[1, 4, 2] # should be contracted in exactly same order + @tensor backend = b order = (a', b, c', t1, t2) begin + HrA12′′′[a, s1, s2, c] := rhoL[a, a'] * H[s1, s2, t1, t2] * A2[b, t2, c'] * + rhoR[c', c] * A1[a', t1, b] # should be contracted in exactly same order + end + @tensor backend = b opt = true HrA12′′′′[:] := rhoL[-1, 1] * H[-2, -3, 4, 5] * + A2[2, 5, 3] * + rhoR[3, -4] * A1[1, 4, 2] + + @test HrA12′ == HrA12′′ == HrA12′′′ # should be exactly equal + @test HrA12 ≈ HrA12′ + @test HrA12 ≈ HrA12′′′′ + @test HrA12′′ ≈ ncon([rhoL, H, A2, rhoR, A1], + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; backend=b) + @test HrA12′′ == @ncon([rhoL, H, A2, rhoR, A1], + [[-1, 1], [-2, -3, 4, 5], [2, 5, 3], [3, -4], [1, 4, 2]]; + order=[1, 2, 3, 4, 5], output=[-1, -2, -3, -4], backend=b) + @test E ≈ + @tensor backend = b tensorscalar(rhoL[a', a] * A1[a, s, b] * A2[b, s', c] * + rhoR[c, c'] * + H[t, t', s, s'] * conj(A1[a', t, b']) * + conj(A2[b', t', c'])) + @test E ≈ @ncon([rhoL, A1, A2, rhoR, H, conj(A1), conj(A2)], + [[5, 1], [1, 2, 3], [3, 4, 9], [9, 10], [6, 8, 2, 4], [5, 6, 7], + [7, 8, 10]]; backend=b) + # this implicitly tests that `ncon` returns a scalar if no open indices +end
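+
+    # For reference (illustrative): with no open indices, `ncon` contracts everything
+    # down to a scalar, e.g. `ncon([A1, conj(A1)], [[1, 2, 3], [1, 2, 3]]; backend=b)`
+    # equals `norm(A1)^2` up to floating-point error.
+end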