Browse Source

Misc cleanup and TODO notes.

fe/wip
Fredrik Ekre 3 years ago
parent
commit
d5d4d7273b
  1. 13
      src/HYPRE.jl

13
src/HYPRE.jl

@ -199,6 +199,7 @@ function Internals.to_hypre_data(A::SparseMatrixCSR, ilower, iupper)
return nrows, ncols, rows, cols, values return nrows, ncols, rows, cols, values
end end
# TODO: Default to ilower = 1, iupper = size(B, 1)?
function HYPREMatrix(B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper, comm::MPI.Comm=MPI.COMM_WORLD) function HYPREMatrix(B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper, comm::MPI.Comm=MPI.COMM_WORLD)
A = Internals.init_matrix(comm, ilower, iupper) A = Internals.init_matrix(comm, ilower, iupper)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper)
@ -219,6 +220,7 @@ function Internals.to_hypre_data(x::Vector, ilower, iupper)
end end
# TODO: Internals.to_hypre_data(x::SparseVector, ilower, iupper) (?) # TODO: Internals.to_hypre_data(x::SparseVector, ilower, iupper) (?)
# TODO: Default to ilower = 1, iupper = length(x)?
function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD) function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD)
b = Internals.init_vector(comm, ilower, iupper) b = Internals.init_vector(comm, ilower, iupper)
nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper) nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper)
@ -382,7 +384,7 @@ function HYPREVector(v::PVector)
# Create the IJ vector # Create the IJ vector
b = Internals.init_vector(comm, ilower, iupper) b = Internals.init_vector(comm, ilower, iupper)
# Set all the values # Set all the values
map_parts(v.values, v.owned_values, v.rows.partition) do vv, vo, vr map_parts(v.values, v.owned_values, v.rows.partition) do _, vo, vr
ilower_part = vr.lid_to_gid[vr.oid_to_lid.start] ilower_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_part = vr.lid_to_gid[vr.oid_to_lid.stop] iupper_part = vr.lid_to_gid[vr.oid_to_lid.stop]
@ -414,19 +416,16 @@ end
function Base.copy!(v::PVector, h::HYPREVector) function Base.copy!(v::PVector, h::HYPREVector)
if eltype(v) !== HYPRE_Complex if eltype(v) !== HYPRE_Complex
# TODO: This could be allowed with a temporary buffer.
throw(ArgumentError("mismatching element types")) throw(ArgumentError("mismatching element types"))
end end
ilower_v, iupper_v = Internals.get_proc_rows(v) ilower_v, iupper_v = Internals.get_proc_rows(v)
ilower_h_ref = Ref{HYPRE_BigInt}(0) ilower_h, iupper_h = Internals.get_proc_rows(h)
iupper_h_ref = Ref{HYPRE_BigInt}(0)
@check HYPRE_IJVectorGetLocalRange(h.IJVector, ilower_h_ref, iupper_h_ref)
ilower_h = ilower_h_ref[]
iupper_h = iupper_h_ref[]
if ilower_v != ilower_h && iupper_v != iupper_h if ilower_v != ilower_h && iupper_v != iupper_h
# TODO: Why require this? # TODO: Why require this?
throw(ArgumentError("mismatch in owned rows")) throw(ArgumentError("mismatch in owned rows"))
end end
map_parts(v.values, v.owned_values, v.rows.partition) do vv, vo, vr map_parts(v.values, v.owned_values, v.rows.partition) do vv, _, vr
ilower_v_part = vr.lid_to_gid[vr.oid_to_lid.start] ilower_v_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_v_part = vr.lid_to_gid[vr.oid_to_lid.stop] iupper_v_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_v_part - ilower_v_part + 1) nvalues = HYPRE_Int(iupper_v_part - ilower_v_part + 1)

Loading…
Cancel
Save