From d5d4d7273b752932e100742a387af911c431fe67 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Fri, 22 Jul 2022 19:26:30 +0200 Subject: [PATCH] Misc cleanup and TODO notes. --- src/HYPRE.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/HYPRE.jl b/src/HYPRE.jl index de6af4d..bde2a59 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -199,6 +199,7 @@ function Internals.to_hypre_data(A::SparseMatrixCSR, ilower, iupper) return nrows, ncols, rows, cols, values end +# TODO: Default to ilower = 1, iupper = size(B, 1)? function HYPREMatrix(B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper, comm::MPI.Comm=MPI.COMM_WORLD) A = Internals.init_matrix(comm, ilower, iupper) nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) @@ -219,6 +220,7 @@ function Internals.to_hypre_data(x::Vector, ilower, iupper) end # TODO: Internals.to_hypre_data(x::SparseVector, ilower, iupper) (?) +# TODO: Default to ilower = 1, iupper = length(x)? function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD) b = Internals.init_vector(comm, ilower, iupper) nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper) @@ -382,7 +384,7 @@ function HYPREVector(v::PVector) # Create the IJ vector b = Internals.init_vector(comm, ilower, iupper) # Set all the values - map_parts(v.values, v.owned_values, v.rows.partition) do vv, vo, vr + map_parts(v.values, v.owned_values, v.rows.partition) do _, vo, vr ilower_part = vr.lid_to_gid[vr.oid_to_lid.start] iupper_part = vr.lid_to_gid[vr.oid_to_lid.stop] @@ -414,19 +416,16 @@ end function Base.copy!(v::PVector, h::HYPREVector) if eltype(v) !== HYPRE_Complex + # TODO: This could be allowed with a temporary buffer. throw(ArgumentError("mismatching element types")) end ilower_v, iupper_v = Internals.get_proc_rows(v) - ilower_h_ref = Ref{HYPRE_BigInt}(0) - iupper_h_ref = Ref{HYPRE_BigInt}(0) - @check HYPRE_IJVectorGetLocalRange(h.IJVector, ilower_h_ref, iupper_h_ref) - ilower_h = ilower_h_ref[] - iupper_h = iupper_h_ref[] + ilower_h, iupper_h = Internals.get_proc_rows(h) if ilower_v != ilower_h && iupper_v != iupper_h # TODO: Why require this? throw(ArgumentError("mismatch in owned rows")) end - map_parts(v.values, v.owned_values, v.rows.partition) do vv, vo, vr + map_parts(v.values, v.owned_values, v.rows.partition) do vv, _, vr ilower_v_part = vr.lid_to_gid[vr.oid_to_lid.start] iupper_v_part = vr.lid_to_gid[vr.oid_to_lid.stop] nvalues = HYPRE_Int(iupper_v_part - ilower_v_part + 1)