Browse Source

PArrays: Collect views for own_values if needed

Required for Julia < 1.9
pull/16/head
Olav Møyner 3 years ago
parent
commit
824dd52e50
  1. 29
      src/HYPRE.jl

29
src/HYPRE.jl

@ -3,7 +3,7 @@ @@ -3,7 +3,7 @@
module HYPRE
using MPI: MPI
using PartitionedArrays: own_length,tuple_of_arrays, own_to_global, global_length,
using PartitionedArrays: own_length, tuple_of_arrays, own_to_global, global_length,
own_to_local, local_to_global, global_to_own, global_to_local,
MPIArray, PSparseMatrix, PVector, PartitionedArrays, AbstractLocalIndices,
local_values, own_values, partition
@ -336,6 +336,14 @@ end @@ -336,6 +336,14 @@ end
# PartitionedArrays.PSparseMatrix -> HYPREMatrix #
##################################################
function subarray_unsafe_supported()
# Wrapping of SubArrays as raw pointers may or may not be supported
# depending on the Julia version. If this is not supported, we have to fall
# back to allocation of an intermediate buffer. This logic can be removed if
# HYPRE.jl drops support for Julia < 1.9.
return @static Int(VERSION.minor) > 8 || Int(VERSION.major) > 1
end
# TODO: This has some duplicated code with to_hypre_data(::SparseMatrixCSC, ilower, iupper)
function Internals.to_hypre_data(A::SparseMatrixCSC, r::AbstractLocalIndices, c::AbstractLocalIndices)
g_to_l_rows = global_to_local(r) # Not sure about this assert
@ -404,7 +412,6 @@ function Internals.to_hypre_data(A::SparseMatrixCSR, r::AbstractLocalIndices, c: @@ -404,7 +412,6 @@ function Internals.to_hypre_data(A::SparseMatrixCSR, r::AbstractLocalIndices, c:
@assert g_to_l_rows.own_to_local isa UnitRange && g_to_l_rows.own_to_local.start == 1
n_local_rows = own_length(r)
n_local_cols = own_length(c)
ilower = l_to_g_rows[1]
iupper = l_to_g_rows[n_local_rows]
@ -546,7 +553,16 @@ function Base.copy!(dst::PVector, src::HYPREVector) @@ -546,7 +553,16 @@ function Base.copy!(dst::PVector, src::HYPREVector)
iu_src_part = o_to_g[end]
nvalues = HYPRE_Int(iu_src_part - il_src_part + 1)
indices = collect(HYPRE_BigInt, il_src_part:iu_src_part)
@check HYPRE_IJVectorGetValues(src, nvalues, indices, ov)
values = collect(HYPRE_Complex, ov)
if subarray_unsafe_supported()
values = ov
else
values = collect(HYPRE_Complex, ov)
end
@check HYPRE_IJVectorGetValues(src, nvalues, indices, values)
if !subarray_unsafe_supported()
@. ov = values
end
end
return dst
end
@ -561,7 +577,12 @@ function Base.copy!(dst::HYPREVector, src::PVector) @@ -561,7 +577,12 @@ function Base.copy!(dst::HYPREVector, src::PVector)
iupper_src_part = o_to_g[end]
nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1)
indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part)
@check HYPRE_IJVectorSetValues(dst, nvalues, indices, ov)
if subarray_unsafe_supported()
values = ov
else
values = collect(HYPRE_Complex, ov)
end
@check HYPRE_IJVectorSetValues(dst, nvalues, indices, values)
end
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst)

Loading…
Cancel
Save