Browse Source

Fix to MPI PArray copy

pull/16/head
Olav Møyner 3 years ago
parent
commit
85f060123c
  1. 12
      src/HYPRE.jl

12
src/HYPRE.jl

@ -459,15 +459,16 @@ function Internals.get_proc_rows(A::Union{PSparseMatrix, PVector})
end end
ilower::HYPRE_BigInt = typemax(HYPRE_BigInt) ilower::HYPRE_BigInt = typemax(HYPRE_BigInt)
iupper::HYPRE_BigInt = typemin(HYPRE_BigInt) iupper::HYPRE_BigInt = typemin(HYPRE_BigInt)
low_high = map(r) do a map(r) do a
# This is a map over the local process' owned indices. For MPI it will
# be a single value but for DebugArray / Array it will have multiple
# values.
o_to_g = own_to_global(a) o_to_g = own_to_global(a)
ilower_part = o_to_g[1] ilower_part = o_to_g[1]
iupper_part = o_to_g[end] iupper_part = o_to_g[end]
return ilower_part, iupper_part ilower = min(ilower, convert(HYPRE_BigInt, ilower_part))
iupper = max(iupper, convert(HYPRE_BigInt, iupper_part))
end end
low, high = tuple_of_arrays(low_high)
ilower = convert(HYPRE_BigInt, reduce(min, low))
iupper = convert(HYPRE_BigInt, reduce(max, high))
return ilower, iupper return ilower, iupper
end end
@ -553,7 +554,6 @@ function Base.copy!(dst::PVector, src::HYPREVector)
iu_src_part = o_to_g[end] iu_src_part = o_to_g[end]
nvalues = HYPRE_Int(iu_src_part - il_src_part + 1) nvalues = HYPRE_Int(iu_src_part - il_src_part + 1)
indices = collect(HYPRE_BigInt, il_src_part:iu_src_part) indices = collect(HYPRE_BigInt, il_src_part:iu_src_part)
values = collect(HYPRE_Complex, ov)
if subarray_unsafe_supported() if subarray_unsafe_supported()
values = ov values = ov
else else

Loading…
Cancel
Save