From 496bf99f9898a5d4e5256f90290303e89cc42f69 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Fri, 22 Jul 2022 19:25:16 +0200 Subject: [PATCH] Implement zero(::HYPREVector) and copy!(::Vector, ::HYPREVector). --- src/HYPRE.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 8 ++++++++ 2 files changed, 51 insertions(+) diff --git a/src/HYPRE.jl b/src/HYPRE.jl index 6219116..de6af4d 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -88,6 +88,39 @@ function Internals.assemble_vector(b::HYPREVector) return b end +function Internals.get_proc_rows(b::HYPREVector) + jlower_ref = Ref{HYPRE_BigInt}() + jupper_ref = Ref{HYPRE_BigInt}() + @check HYPRE_IJVectorGetLocalRange(b.IJVector, jlower_ref, jupper_ref) + jlower = jlower_ref[] + jupper = jupper_ref[] + return jlower, jupper +end + +function Internals.get_comm(b::HYPREVector) + # The MPI communicator is (currently) the first field of the struct: + # https://github.com/hypre-space/hypre/blob/48de53e675af0e23baf61caa73d89fd9f478f453/src/IJ_mv/IJ_vector.h#L23 + # Fingers crossed this doesn't change! + @assert b.IJVector != C_NULL + comm = unsafe_load(Ptr{MPI.Comm}(b.IJVector)) + return comm +end + +function Base.zero(b::HYPREVector) + jlower, jupper = Internals.get_proc_rows(b) + comm = Internals.get_comm(b) + x = Internals.init_vector(comm, jlower, jupper) + # TODO All values 0 by default? Looks like it... Work in progress patch to hypre to + # support IJVectorSetConstantValues analoguous to IJMatrixSetConstantValues. + nvalues = jupper - jlower + 1 + indices = collect(HYPRE_BigInt, jlower:jupper) + values = zeros(HYPRE_Complex, nvalues) + @check HYPRE_IJVectorSetValues(x.IJVector, nvalues, indices, values) + # Finalize and return + Internals.assemble_vector(x) + return x +end + ###################################### # SparseMatrixCS(C|R) -> HYPREMatrix # ###################################### @@ -194,6 +227,16 @@ function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD) return b end +function Base.copy!(x::Vector, h::HYPREVector) + ilower, iupper = Internals.get_proc_rows(h) + nvalues = iupper - ilower + 1 + if length(x) != nvalues + throw(ArgumentError("different lengths")) + end + indices = collect(HYPRE_BigInt, ilower:iupper) + @check HYPRE_IJVectorGetValues(h.IJVector, nvalues, indices, x) + return x +end ################################################## # PartitionedArrays.PSparseMatrix -> HYPREMatrix # diff --git a/test/runtests.jl b/test/runtests.jl index ac1bc73..b8b0cc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -173,6 +173,14 @@ end Internals.assemble_vector(h) @test h.IJVector != HYPRE_IJVector(C_NULL) @test h.ParVector != HYPRE_ParVector(C_NULL) + + # Base.zero(::HYPREVector) and Base.copy!(::Vector, HYPREVector) + b = rand(10) + h = HYPREVector(b, 1, 10) + z = zero(h) + b′ = copy!(b, z) + @test b === b′ + @test iszero(b) end @testset "HYPREVector(::Vector)" begin