Browse Source

Implement copy!(v::PVector, h::HYPREVector).

fe/wip
Fredrik Ekre 3 years ago
parent
commit
85270fd297
  1. 28
      src/HYPRE.jl
  2. 6
      test/runtests.jl

28
src/HYPRE.jl

@ -330,7 +330,7 @@ end
############################################ ############################################
# PartitionedArrays.PVector -> HYPREVector # # PartitionedArrays.PVector -> HYPREVector #
############################################ ############################################
#
function HYPREVector(v::PVector) function HYPREVector(v::PVector)
# Use the same communicator as the matrix # Use the same communicator as the matrix
comm = Internals.get_comm(v) comm = Internals.get_comm(v)
@ -369,4 +369,30 @@ function HYPREVector(v::PVector)
return b return b
end end
function Base.copy!(v::PVector, h::HYPREVector)
if eltype(v) !== HYPRE_Complex
throw(ArgumentError("mismatching element types"))
end
ilower_v, iupper_v = Internals.get_proc_rows(v)
ilower_h_ref = Ref{HYPRE_BigInt}(0)
iupper_h_ref = Ref{HYPRE_BigInt}(0)
@check HYPRE_IJVectorGetLocalRange(h.IJVector, ilower_h_ref, iupper_h_ref)
ilower_h = ilower_h_ref[]
iupper_h = iupper_h_ref[]
if ilower_v != ilower_h && iupper_v != iupper_h
# TODO: Why require this?
throw(ArgumentError("mismatch in owned rows"))
end
map_parts(v.values, v.owned_values, v.rows.partition) do vv, vo, vr
ilower_v_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_v_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_v_part - ilower_v_part + 1)
indices = collect(HYPRE_BigInt, ilower_v_part:iupper_v_part)
# TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorGetValues(h.IJVector, nvalues, indices, vv)
end
return v
end
end # module HYPRE end # module HYPRE

6
test/runtests.jl

@ -218,6 +218,9 @@ end
H = HYPREVector(pb) H = HYPREVector(pb)
@test H.IJVector != HYPRE_IJVector(C_NULL) @test H.IJVector != HYPRE_IJVector(C_NULL)
@test H.ParVector != HYPRE_ParVector(C_NULL) @test H.ParVector != HYPRE_ParVector(C_NULL)
pbc = fill!(copy(pb), 0)
copy!(pbc, H)
@test tomain(copy(pbc)) == tomain(copy(pb))
# MPI backend # MPI backend
backend = MPIBackend() backend = MPIBackend()
parts = get_part_ids(backend, 1) parts = get_part_ids(backend, 1)
@ -232,4 +235,7 @@ end
H = HYPREVector(pb) H = HYPREVector(pb)
@test H.IJVector != HYPRE_IJVector(C_NULL) @test H.IJVector != HYPRE_IJVector(C_NULL)
@test H.ParVector != HYPRE_ParVector(C_NULL) @test H.ParVector != HYPRE_ParVector(C_NULL)
pbc = fill!(copy(pb), 0)
copy!(pbc, H)
@test tomain(copy(pbc)) == tomain(copy(pb))
end end

Loading…
Cancel
Save