diff --git a/src/HYPRE.jl b/src/HYPRE.jl index 6b94f2d..f46f9ff 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -3,7 +3,7 @@ module HYPRE using MPI: MPI -using PartitionedArrays: own_length,tuple_of_arrays, own_to_global, global_length, +using PartitionedArrays: own_length, tuple_of_arrays, own_to_global, global_length, own_to_local, local_to_global, global_to_own, global_to_local, MPIArray, PSparseMatrix, PVector, PartitionedArrays, AbstractLocalIndices, local_values, own_values, partition @@ -336,6 +336,14 @@ end # PartitionedArrays.PSparseMatrix -> HYPREMatrix # ################################################## +function subarray_unsafe_supported() + # Wrapping of SubArrays as raw pointers may or may not be supported + # depending on the Julia version. If this is not supported, we have to fall + # back to allocation of an intermediate buffer. This logic can be removed if + # HYPRE.jl drops support for Julia < 1.9. + return @static Int(VERSION.minor) > 8 || Int(VERSION.major) > 1 +end + # TODO: This has some duplicated code with to_hypre_data(::SparseMatrixCSC, ilower, iupper) function Internals.to_hypre_data(A::SparseMatrixCSC, r::AbstractLocalIndices, c::AbstractLocalIndices) g_to_l_rows = global_to_local(r) # Not sure about this assert @@ -404,7 +412,6 @@ function Internals.to_hypre_data(A::SparseMatrixCSR, r::AbstractLocalIndices, c: @assert g_to_l_rows.own_to_local isa UnitRange && g_to_l_rows.own_to_local.start == 1 n_local_rows = own_length(r) - n_local_cols = own_length(c) ilower = l_to_g_rows[1] iupper = l_to_g_rows[n_local_rows] @@ -546,7 +553,16 @@ function Base.copy!(dst::PVector, src::HYPREVector) iu_src_part = o_to_g[end] nvalues = HYPRE_Int(iu_src_part - il_src_part + 1) indices = collect(HYPRE_BigInt, il_src_part:iu_src_part) - @check HYPRE_IJVectorGetValues(src, nvalues, indices, ov) + values = collect(HYPRE_Complex, ov) + if subarray_unsafe_supported() + values = ov + else + values = collect(HYPRE_Complex, ov) + end + @check HYPRE_IJVectorGetValues(src, nvalues, indices, values) + if !subarray_unsafe_supported() + @. ov = values + end end return dst end @@ -561,7 +577,12 @@ function Base.copy!(dst::HYPREVector, src::PVector) iupper_src_part = o_to_g[end] nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1) indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part) - @check HYPRE_IJVectorSetValues(dst, nvalues, indices, ov) + if subarray_unsafe_supported() + values = ov + else + values = collect(HYPRE_Complex, ov) + end + @check HYPRE_IJVectorSetValues(dst, nvalues, indices, values) end # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?) # @check HYPRE_IJVectorAssemble(dst)