Browse Source

Implement zero(::HYPREVector) and copy!(::Vector, ::HYPREVector).

fe/wip
Fredrik Ekre 3 years ago
parent
commit
496bf99f98
  1. 43
      src/HYPRE.jl
  2. 8
      test/runtests.jl

43
src/HYPRE.jl

@ -88,6 +88,39 @@ function Internals.assemble_vector(b::HYPREVector)
return b return b
end end
function Internals.get_proc_rows(b::HYPREVector)
jlower_ref = Ref{HYPRE_BigInt}()
jupper_ref = Ref{HYPRE_BigInt}()
@check HYPRE_IJVectorGetLocalRange(b.IJVector, jlower_ref, jupper_ref)
jlower = jlower_ref[]
jupper = jupper_ref[]
return jlower, jupper
end
function Internals.get_comm(b::HYPREVector)
# The MPI communicator is (currently) the first field of the struct:
# https://github.com/hypre-space/hypre/blob/48de53e675af0e23baf61caa73d89fd9f478f453/src/IJ_mv/IJ_vector.h#L23
# Fingers crossed this doesn't change!
@assert b.IJVector != C_NULL
comm = unsafe_load(Ptr{MPI.Comm}(b.IJVector))
return comm
end
function Base.zero(b::HYPREVector)
jlower, jupper = Internals.get_proc_rows(b)
comm = Internals.get_comm(b)
x = Internals.init_vector(comm, jlower, jupper)
# TODO All values 0 by default? Looks like it... Work in progress patch to hypre to
# support IJVectorSetConstantValues analoguous to IJMatrixSetConstantValues.
nvalues = jupper - jlower + 1
indices = collect(HYPRE_BigInt, jlower:jupper)
values = zeros(HYPRE_Complex, nvalues)
@check HYPRE_IJVectorSetValues(x.IJVector, nvalues, indices, values)
# Finalize and return
Internals.assemble_vector(x)
return x
end
###################################### ######################################
# SparseMatrixCS(C|R) -> HYPREMatrix # # SparseMatrixCS(C|R) -> HYPREMatrix #
###################################### ######################################
@ -194,6 +227,16 @@ function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD)
return b return b
end end
function Base.copy!(x::Vector, h::HYPREVector)
ilower, iupper = Internals.get_proc_rows(h)
nvalues = iupper - ilower + 1
if length(x) != nvalues
throw(ArgumentError("different lengths"))
end
indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorGetValues(h.IJVector, nvalues, indices, x)
return x
end
################################################## ##################################################
# PartitionedArrays.PSparseMatrix -> HYPREMatrix # # PartitionedArrays.PSparseMatrix -> HYPREMatrix #

8
test/runtests.jl

@ -173,6 +173,14 @@ end
Internals.assemble_vector(h) Internals.assemble_vector(h)
@test h.IJVector != HYPRE_IJVector(C_NULL) @test h.IJVector != HYPRE_IJVector(C_NULL)
@test h.ParVector != HYPRE_ParVector(C_NULL) @test h.ParVector != HYPRE_ParVector(C_NULL)
# Base.zero(::HYPREVector) and Base.copy!(::Vector, HYPREVector)
b = rand(10)
h = HYPREVector(b, 1, 10)
z = zero(h)
b′ = copy!(b, z)
@test b === b′
@test iszero(b)
end end
@testset "HYPREVector(::Vector)" begin @testset "HYPREVector(::Vector)" begin

Loading…
Cancel
Save