Browse Source

Implement copy!(::HYPREVector, ::(P)Vector) for reusing an allocated HYPREVector.

pull/5/head
Fredrik Ekre 3 years ago
parent
commit
a8881f6adc
  1. 86
      src/HYPRE.jl
  2. 11
      src/Internals.jl
  3. 16
      test/runtests.jl

86
src/HYPRE.jl

@ -269,15 +269,33 @@ HYPREVector(x::Vector, ilower=1, iupper=length(x)) =
HYPREVector(MPI.COMM_SELF, x, ilower, iupper) HYPREVector(MPI.COMM_SELF, x, ilower, iupper)
# TODO: Other eltypes could be support by using a intermediate buffer # TODO: Other eltypes could be support by using a intermediate buffer
function Base.copy!(x::Vector{HYPRE_Complex}, h::HYPREVector) function Base.copy!(dst::Vector{HYPRE_Complex}, src::HYPREVector)
ilower, iupper = Internals.get_proc_rows(h) ilower, iupper = Internals.get_proc_rows(src)
nvalues = iupper - ilower + 1 nvalues = iupper - ilower + 1
if length(x) != nvalues if length(dst) != nvalues
throw(ArgumentError("different lengths")) throw(ArgumentError("length of dst and src does not match"))
end end
indices = collect(HYPRE_BigInt, ilower:iupper) indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorGetValues(h.ijvector, nvalues, indices, x) @check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, dst)
return x return dst
end
function Base.copy!(dst::HYPREVector, src::Vector{HYPRE_Complex})
ilower, iupper = Internals.get_proc_rows(dst)
nvalues = iupper - ilower + 1
if length(src) != nvalues
throw(ArgumentError("length of dst and src does not match"))
end
# Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector)
# Set all the values
indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, src)
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed.
return dst
end end
################################################## ##################################################
@ -454,24 +472,54 @@ function HYPREVector(v::PVector)
return b return b
end end
# TODO: Other eltypes could be support by using a intermediate buffer function Internals.copy_check(dst::HYPREVector, src::PVector)
function Base.copy!(v::PVector{HYPRE_Complex}, h::HYPREVector) il_dst, iu_dst = Internals.get_proc_rows(dst)
ilower_v, iupper_v = Internals.get_proc_rows(v) il_src, iu_src = Internals.get_proc_rows(src)
ilower_h, iupper_h = Internals.get_proc_rows(h) if il_dst != il_src && iu_dst != iu_src
if ilower_v != ilower_h && iupper_v != iupper_h
# TODO: Why require this? # TODO: Why require this?
throw(ArgumentError("mismatch in owned rows")) throw(ArgumentError(
"row owner mismatch between dst ($(il_dst:iu_dst)) and src ($(il_dst:iu_dst))"
))
end
end
# TODO: Other eltypes could be support by using a intermediate buffer
function Base.copy!(dst::PVector{HYPRE_Complex}, src::HYPREVector)
Internals.copy_check(src, dst)
map_parts(dst.values, dst.owned_values, dst.rows.partition) do vv, _, vr
il_src_part = vr.lid_to_gid[vr.oid_to_lid.start]
iu_src_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iu_src_part - il_src_part + 1)
indices = collect(HYPRE_BigInt, il_src_part:iu_src_part)
# Assumption: the dst vector is assembled, and should thus have 0s on the ghost
# entries (??). If this is not true, we must call fill!(vv, 0) here. This should be
# fairly cheap anyway, so might as well do it...
fill!(vv, 0)
# TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, vv)
end
return dst
end end
map_parts(v.values, v.owned_values, v.rows.partition) do vv, _, vr
ilower_v_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_v_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_v_part - ilower_v_part + 1)
indices = collect(HYPRE_BigInt, ilower_v_part:iupper_v_part)
function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex})
Internals.copy_check(dst, src)
# Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector)
map_parts(src.values, src.owned_values, src.rows.partition) do vv, _, vr
ilower_src_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_src_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1)
indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part)
# TODO: Safe to use vv here? Owned values are always first? # TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorGetValues(h.ijvector, nvalues, indices, vv) @check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, vv)
end end
return v # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed.
return dst
end end
# Solver interface # Solver interface

11
src/Internals.jl

@ -2,16 +2,17 @@
module Internals module Internals
function assemble_matrix end
function assemble_vector end
function check_n_rows end function check_n_rows end
function to_hypre_data end function copy_check end
function get_comm end function get_comm end
function get_proc_rows end function get_proc_rows end
function assemble_matrix end
function assemble_vector end
function set_options end function set_options end
function solve_func end
function setup_func end
function set_precond end function set_precond end
function set_precond_defaults end function set_precond_defaults end
function setup_func end
function solve_func end
function to_hypre_data end
end # module Internals end # module Internals

16
test/runtests.jl

@ -231,6 +231,15 @@ end
@test indices::Vector{HYPRE_Int} == collect(1:10) @test indices::Vector{HYPRE_Int} == collect(1:10)
@test values::Vector{HYPRE_Complex} == b # == for other eltype @test values::Vector{HYPRE_Complex} == b # == for other eltype
@test_throws ArgumentError Internals.to_hypre_data([1, 2], ilower, iupper) @test_throws ArgumentError Internals.to_hypre_data([1, 2], ilower, iupper)
# Copying Vector -> HYPREVector
b = rand(10)
b2 = zeros(10)
h = HYPREVector(b2)
h′ = copy!(h, b)
@test h === h′
copy!(b2, h)
@test b == b2
end end
@testset "HYPREVector(::PVector)" begin @testset "HYPREVector(::PVector)" begin
@ -256,6 +265,13 @@ end
pbc = fill!(copy(pb), 0) pbc = fill!(copy(pb), 0)
copy!(pbc, H) copy!(pbc, H)
@test tomain(pbc) == tomain(pb) @test tomain(pbc) == tomain(pb)
pb2 = 2 * pb
H′ = copy!(H, pb2)
@test H === H′
copy!(pbc, H)
@test tomain(pbc) == 2 * tomain(pb)
# MPI backend # MPI backend
backend = MPIBackend() backend = MPIBackend()
parts = get_part_ids(backend, 1) parts = get_part_ids(backend, 1)

Loading…
Cancel
Save