Browse Source

Define `unsafe_convert` methods for `HYPRE(Matrix|Vector|Solver)`

This patch adds methods for `Base.unsafe_convert` for `HYPREMatrix`,
`HYPREVector`, and `HYPRESolver`. This means that those objects can be
passed directly to `ccall` and be "converted" (i.e. extracting the
pointer that is stored in the structs) to the appropriate type expected
by the HYPRE C-library. The advantage is that `ccall` then guarantees
that the objects are kept alive for the duration of the call.
pull/13/head
Fredrik Ekre 3 years ago
parent
commit
cf9c815e85
  1. 16
      docs/src/libhypre.md
  2. 5
      gen/solver_options.jl
  3. 73
      src/HYPRE.jl
  4. 34
      src/solver_options.jl
  5. 48
      src/solvers.jl

16
docs/src/libhypre.md

@ -15,15 +15,13 @@ directly.
Functions from the `LibHYPRE` submodule can be used together with the high level interface. Functions from the `LibHYPRE` submodule can be used together with the high level interface.
This is useful when you need some functionality from the library which isn't exposed in the This is useful when you need some functionality from the library which isn't exposed in the
high level interface. Many functions require passing a reference to a matrix/vector or a high level interface. Many functions require passing a reference to a matrix/vector or a
solver. These can be obtained as follows: solver. HYPRE.jl defines the appropriate conversion methods used by `ccall` such that
- `A::HYPREMatrix` can be passed to `HYPRE_*` functions with `HYPRE_IJMatrix` or
| C type signature | Argument to pass | `HYPRE_ParCSRMatrix` in the signature
|:---------------------|:-------------------------------------| - `b::HYPREVector` can be passed to `HYPRE_*` functions with `HYPRE_IJVector` or
| `HYPRE_IJMatrix` | `A.ijmatrix` where `A::HYPREMatrix` | `HYPRE_ParVector` in the signature
| `HYPRE_ParCSRMatrix` | `A.parmatrix` where `A::HYPREMatrix` | - `s::HYPRESolver` can be passed to `HYPRE_*` functions with `HYPRE_Solver` in the
| `HYPRE_IJVector` | `b.ijvector` where `b::HYPREVector` | signature
| `HYPRE_ParVector` | `b.parvector` where `b::HYPREVector` |
| `HYPRE_Solver` | `s.solver` where `s::HYPRESolver` |
[^1]: Bindings are generated using [^1]: Bindings are generated using
[Clang.jl](https://github.com/JuliaInterop/Clang.jl), see [Clang.jl](https://github.com/JuliaInterop/Clang.jl), see

5
gen/solver_options.jl

@ -2,8 +2,7 @@ using HYPRE.LibHYPRE
function generate_options(io, structname, prefixes...) function generate_options(io, structname, prefixes...)
println(io, "") println(io, "")
println(io, "function Internals.set_options(s::$(structname), kwargs)") println(io, "function Internals.set_options(solver::$(structname), kwargs)")
println(io, " solver = s.solver")
println(io, " for (k, v) in kwargs") println(io, " for (k, v) in kwargs")
ns = Tuple{Symbol,String}[] ns = Tuple{Symbol,String}[]
@ -29,7 +28,7 @@ function generate_options(io, structname, prefixes...)
println(io) println(io)
if k == "Precond" if k == "Precond"
println(io, " Internals.set_precond_defaults(v)") println(io, " Internals.set_precond_defaults(v)")
println(io, " Internals.set_precond(s, v)") println(io, " Internals.set_precond(solver, v)")
elseif nargs == 1 elseif nargs == 1
println(io, " @check ", n, "(solver)") println(io, " @check ", n, "(solver)")
elseif nargs == 2 elseif nargs == 2

73
src/HYPRE.jl

@ -63,6 +63,12 @@ mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex}
parmatrix::HYPRE_ParCSRMatrix parmatrix::HYPRE_ParCSRMatrix
end end
# Defining unsafe_convert enables ccall to automatically convert A::HYPREMatrix to
# HYPRE_IJMatrix and HYPRE_ParCSRMatrix while also making sure A won't be GC'd and
# finalized.
Base.unsafe_convert(::Type{HYPRE_IJMatrix}, A::HYPREMatrix) = A.ijmatrix
Base.unsafe_convert(::Type{HYPRE_ParCSRMatrix}, A::HYPREMatrix) = A.parmatrix
function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer,
jlower::Integer=ilower, jupper::Integer=iupper) jlower::Integer=ilower, jupper::Integer=iupper)
# Create the IJ matrix # Create the IJ matrix
@ -73,15 +79,15 @@ function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer,
# Attach a finalizer # Attach a finalizer
finalizer(A) do x finalizer(A) do x
if x.ijmatrix != C_NULL if x.ijmatrix != C_NULL
HYPRE_IJMatrixDestroy(x.ijmatrix) HYPRE_IJMatrixDestroy(x)
x.ijmatrix = x.parmatrix = C_NULL x.ijmatrix = x.parmatrix = C_NULL
end end
end end
push!(Internals.HYPRE_OBJECTS, A => nothing) push!(Internals.HYPRE_OBJECTS, A => nothing)
# Set storage type # Set storage type
@check HYPRE_IJMatrixSetObjectType(A.ijmatrix, HYPRE_PARCSR) @check HYPRE_IJMatrixSetObjectType(A, HYPRE_PARCSR)
# Initialize to make ready for setting values # Initialize to make ready for setting values
@check HYPRE_IJMatrixInitialize(A.ijmatrix) @check HYPRE_IJMatrixInitialize(A)
return A return A
end end
@ -89,10 +95,10 @@ end
# This should be called after setting all the values # This should be called after setting all the values
function Internals.assemble_matrix(A::HYPREMatrix) function Internals.assemble_matrix(A::HYPREMatrix)
# Finalize after setting all values # Finalize after setting all values
@check HYPRE_IJMatrixAssemble(A.ijmatrix) @check HYPRE_IJMatrixAssemble(A)
# Fetch the assembled CSR matrix # Fetch the assembled CSR matrix
parmatrix_ref = Ref{Ptr{Cvoid}}(C_NULL) parmatrix_ref = Ref{Ptr{Cvoid}}(C_NULL)
@check HYPRE_IJMatrixGetObject(A.ijmatrix, parmatrix_ref) @check HYPRE_IJMatrixGetObject(A, parmatrix_ref)
A.parmatrix = convert(Ptr{HYPRE_ParCSRMatrix}, parmatrix_ref[]) A.parmatrix = convert(Ptr{HYPRE_ParCSRMatrix}, parmatrix_ref[])
return A return A
end end
@ -109,6 +115,11 @@ mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex}
parvector::HYPRE_ParVector parvector::HYPRE_ParVector
end end
# Defining unsafe_convert enables ccall to automatically convert b::HYPREVector to
# HYPRE_IJVector and HYPRE_ParVector while also making sure b won't be GC'd and finalized.
Base.unsafe_convert(::Type{HYPRE_IJVector}, b::HYPREVector) = b.ijvector
Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector
function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer)
# Create the IJ vector # Create the IJ vector
b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL) b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL)
@ -118,24 +129,24 @@ function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer)
# Attach a finalizer # Attach a finalizer
finalizer(b) do x finalizer(b) do x
if x.ijvector != C_NULL if x.ijvector != C_NULL
HYPRE_IJVectorDestroy(x.ijvector) HYPRE_IJVectorDestroy(x)
x.ijvector = x.parvector = C_NULL x.ijvector = x.parvector = C_NULL
end end
end end
push!(Internals.HYPRE_OBJECTS, b => nothing) push!(Internals.HYPRE_OBJECTS, b => nothing)
# Set storage type # Set storage type
@check HYPRE_IJVectorSetObjectType(b.ijvector, HYPRE_PARCSR) @check HYPRE_IJVectorSetObjectType(b, HYPRE_PARCSR)
# Initialize to make ready for setting values # Initialize to make ready for setting values
@check HYPRE_IJVectorInitialize(b.ijvector) @check HYPRE_IJVectorInitialize(b)
return b return b
end end
function Internals.assemble_vector(b::HYPREVector) function Internals.assemble_vector(b::HYPREVector)
# Finalize after setting all values # Finalize after setting all values
@check HYPRE_IJVectorAssemble(b.ijvector) @check HYPRE_IJVectorAssemble(b)
# Fetch the assembled vector # Fetch the assembled vector
parvector_ref = Ref{Ptr{Cvoid}}(C_NULL) parvector_ref = Ref{Ptr{Cvoid}}(C_NULL)
@check HYPRE_IJVectorGetObject(b.ijvector, parvector_ref) @check HYPRE_IJVectorGetObject(b, parvector_ref)
b.parvector = convert(Ptr{HYPRE_ParVector}, parvector_ref[]) b.parvector = convert(Ptr{HYPRE_ParVector}, parvector_ref[])
return b return b
end end
@ -143,7 +154,7 @@ end
function Internals.get_proc_rows(b::HYPREVector) function Internals.get_proc_rows(b::HYPREVector)
# ilower_ref = Ref{HYPRE_BigInt}() # ilower_ref = Ref{HYPRE_BigInt}()
# iupper_ref = Ref{HYPRE_BigInt}() # iupper_ref = Ref{HYPRE_BigInt}()
# @check HYPRE_IJVectorGetLocalRange(b.ijvector, ilower_ref, iupper_ref) # @check HYPRE_IJVectorGetLocalRange(b, ilower_ref, iupper_ref)
# ilower = ilower_ref[] # ilower = ilower_ref[]
# iupper = iupper_ref[] # iupper = iupper_ref[]
# return ilower, iupper # return ilower, iupper
@ -169,7 +180,7 @@ function Base.zero(b::HYPREVector)
nvalues = jupper - jlower + 1 nvalues = jupper - jlower + 1
indices = collect(HYPRE_BigInt, jlower:jupper) indices = collect(HYPRE_BigInt, jlower:jupper)
values = zeros(HYPRE_Complex, nvalues) values = zeros(HYPRE_Complex, nvalues)
@check HYPRE_IJVectorSetValues(x.ijvector, nvalues, indices, values) @check HYPRE_IJVectorSetValues(x, nvalues, indices, values)
# Finalize and return # Finalize and return
Internals.assemble_vector(x) Internals.assemble_vector(x)
return x return x
@ -258,7 +269,7 @@ end
function HYPREMatrix(comm::MPI.Comm, B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper) function HYPREMatrix(comm::MPI.Comm, B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper)
A = HYPREMatrix(comm, ilower, iupper) A = HYPREMatrix(comm, ilower, iupper)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper)
@check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values) @check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values)
Internals.assemble_matrix(A) Internals.assemble_matrix(A)
return A return A
end end
@ -281,7 +292,7 @@ end
function HYPREVector(comm::MPI.Comm, x::Vector, ilower, iupper) function HYPREVector(comm::MPI.Comm, x::Vector, ilower, iupper)
b = HYPREVector(comm, ilower, iupper) b = HYPREVector(comm, ilower, iupper)
nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper) nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) @check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
Internals.assemble_vector(b) Internals.assemble_vector(b)
return b return b
end end
@ -297,7 +308,7 @@ function Base.copy!(dst::Vector{HYPRE_Complex}, src::HYPREVector)
throw(ArgumentError("length of dst and src does not match")) throw(ArgumentError("length of dst and src does not match"))
end end
indices = collect(HYPRE_BigInt, ilower:iupper) indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, dst) @check HYPRE_IJVectorGetValues(src, nvalues, indices, dst)
return dst return dst
end end
@ -308,12 +319,12 @@ function Base.copy!(dst::HYPREVector, src::Vector{HYPRE_Complex})
throw(ArgumentError("length of dst and src does not match")) throw(ArgumentError("length of dst and src does not match"))
end end
# Re-initialize the vector # Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector) @check HYPRE_IJVectorInitialize(dst)
# Set all the values # Set all the values
indices = collect(HYPRE_BigInt, ilower:iupper) indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, src) @check HYPRE_IJVectorSetValues(dst, nvalues, indices, src)
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?) # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector) # @check HYPRE_IJVectorAssemble(dst)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is # TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed. # not needed.
return dst return dst
@ -445,7 +456,7 @@ function HYPREMatrix(B::PSparseMatrix)
# Set all the values # Set all the values
map_parts(B.values, B.rows.partition, B.cols.partition) do Bv, Br, Bc map_parts(B.values, B.rows.partition, B.cols.partition) do Bv, Br, Bc
nrows, ncols, rows, cols, values = Internals.to_hypre_data(Bv, Br, Bc) nrows, ncols, rows, cols, values = Internals.to_hypre_data(Bv, Br, Bc)
@check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values) @check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values)
return nothing return nothing
end end
# Finalize # Finalize
@ -487,7 +498,7 @@ function HYPREVector(v::PVector)
# end # end
# nvalues = length(indices) # nvalues = length(indices)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) @check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
return nothing return nothing
end end
# Finalize # Finalize
@ -521,7 +532,7 @@ function Base.copy!(dst::PVector{HYPRE_Complex}, src::HYPREVector)
fill!(vv, 0) fill!(vv, 0)
# TODO: Safe to use vv here? Owned values are always first? # TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, vv) @check HYPRE_IJVectorGetValues(src, nvalues, indices, vv)
end end
return dst return dst
end end
@ -529,17 +540,17 @@ end
function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex}) function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex})
Internals.copy_check(dst, src) Internals.copy_check(dst, src)
# Re-initialize the vector # Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector) @check HYPRE_IJVectorInitialize(dst)
map_parts(src.values, src.owned_values, src.rows.partition) do vv, _, vr map_parts(src.values, src.owned_values, src.rows.partition) do vv, _, vr
ilower_src_part = vr.lid_to_gid[vr.oid_to_lid.start] ilower_src_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_src_part = vr.lid_to_gid[vr.oid_to_lid.stop] iupper_src_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1) nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1)
indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part) indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part)
# TODO: Safe to use vv here? Owned values are always first? # TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, vv) @check HYPRE_IJVectorSetValues(dst, nvalues, indices, vv)
end end
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?) # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector) # @check HYPRE_IJVectorAssemble(dst)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is # TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed. # not needed.
return dst return dst
@ -585,9 +596,9 @@ start_assemble!
function start_assemble!(A::HYPREMatrix) function start_assemble!(A::HYPREMatrix)
if A.parmatrix != C_NULL if A.parmatrix != C_NULL
# This matrix have been assembled before, reset to 0 # This matrix have been assembled before, reset to 0
@check HYPRE_IJMatrixSetConstantValues(A.ijmatrix, 0) @check HYPRE_IJMatrixSetConstantValues(A, 0)
end end
@check HYPRE_IJMatrixInitialize(A.ijmatrix) @check HYPRE_IJMatrixInitialize(A)
return HYPREMatrixAssembler(A, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[]) return HYPREMatrixAssembler(A, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[])
end end
@ -595,14 +606,14 @@ function start_assemble!(b::HYPREVector)
if b.parvector != C_NULL if b.parvector != C_NULL
# This vector have been assembled before, reset to 0 # This vector have been assembled before, reset to 0
# See https://github.com/hypre-space/hypre/pull/689 # See https://github.com/hypre-space/hypre/pull/689
# @check HYPRE_IJVectorSetConstantValues(b.ijvector, 0) # @check HYPRE_IJVectorSetConstantValues(b, 0)
end end
@check HYPRE_IJVectorInitialize(b.ijvector) @check HYPRE_IJVectorInitialize(b)
if b.parvector != C_NULL if b.parvector != C_NULL
nvalues = HYPRE_Int(b.iupper - b.ilower + 1) nvalues = HYPRE_Int(b.iupper - b.ilower + 1)
indices = collect(HYPRE_BigInt, b.ilower:b.iupper) indices = collect(HYPRE_BigInt, b.ilower:b.iupper)
values = zeros(HYPRE_Complex, nvalues) values = zeros(HYPRE_Complex, nvalues)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) @check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
# TODO: Do I need to assemble here? # TODO: Do I need to assemble here?
end end
return HYPREVectorAssembler(b, HYPRE_BigInt[], HYPRE_Complex[]) return HYPREVectorAssembler(b, HYPRE_BigInt[], HYPRE_Complex[])
@ -635,14 +646,14 @@ assemble!
function assemble!(A::HYPREMatrixAssembler, i::Vector, j::Vector, a::Matrix) function assemble!(A::HYPREMatrixAssembler, i::Vector, j::Vector, a::Matrix)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, i, j) nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, i, j)
@check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values) @check HYPRE_IJMatrixAddToValues(A.A, nrows, ncols, rows, cols, values)
return A return A
end end
@deprecate assemble!(A::HYPREMatrixAssembler, ij::Vector, a::Matrix) assemble!(A, ij, ij, a) false @deprecate assemble!(A::HYPREMatrixAssembler, ij::Vector, a::Matrix) assemble!(A, ij, ij, a) false
function assemble!(A::HYPREVectorAssembler, ij::Vector, a::Vector) function assemble!(A::HYPREVectorAssembler, ij::Vector, a::Vector)
nvalues, indices, values = Internals.to_hypre_data(A, a, ij) nvalues, indices, values = Internals.to_hypre_data(A, a, ij)
@check HYPRE_IJVectorAddToValues(A.b.ijvector, nvalues, indices, values) @check HYPRE_IJVectorAddToValues(A.b, nvalues, indices, values)
return A return A
end end

34
src/solver_options.jl

@ -4,8 +4,7 @@
Internals.set_options(::HYPRESolver, kwargs) = nothing Internals.set_options(::HYPRESolver, kwargs) = nothing
function Internals.set_options(s::BiCGSTAB, kwargs) function Internals.set_options(solver::BiCGSTAB, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :ConvergenceFactorTol if k === :ConvergenceFactorTol
@check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v) @check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v)
@ -19,7 +18,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs)
@check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v) @check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v)
elseif k === :Precond elseif k === :Precond
Internals.set_precond_defaults(v) Internals.set_precond_defaults(v)
Internals.set_precond(s, v) Internals.set_precond(solver, v)
elseif k === :PrintLevel elseif k === :PrintLevel
@check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v) @check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v)
elseif k === :StopCrit elseif k === :StopCrit
@ -32,8 +31,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs)
end end
end end
function Internals.set_options(s::BoomerAMG, kwargs) function Internals.set_options(solver::BoomerAMG, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :ADropTol if k === :ADropTol
@check HYPRE_BoomerAMGSetADropTol(solver, v) @check HYPRE_BoomerAMGSetADropTol(solver, v)
@ -289,8 +287,7 @@ function Internals.set_options(s::BoomerAMG, kwargs)
end end
end end
function Internals.set_options(s::FlexGMRES, kwargs) function Internals.set_options(solver::FlexGMRES, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :ConvergenceFactorTol if k === :ConvergenceFactorTol
@check HYPRE_FlexGMRESSetConvergenceFactorTol(solver, v) @check HYPRE_FlexGMRESSetConvergenceFactorTol(solver, v)
@ -308,7 +305,7 @@ function Internals.set_options(s::FlexGMRES, kwargs)
@check HYPRE_ParCSRFlexGMRESSetModifyPC(solver, v) @check HYPRE_ParCSRFlexGMRESSetModifyPC(solver, v)
elseif k === :Precond elseif k === :Precond
Internals.set_precond_defaults(v) Internals.set_precond_defaults(v)
Internals.set_precond(s, v) Internals.set_precond(solver, v)
elseif k === :PrintLevel elseif k === :PrintLevel
@check HYPRE_ParCSRFlexGMRESSetPrintLevel(solver, v) @check HYPRE_ParCSRFlexGMRESSetPrintLevel(solver, v)
elseif k === :Tol elseif k === :Tol
@ -319,8 +316,7 @@ function Internals.set_options(s::FlexGMRES, kwargs)
end end
end end
function Internals.set_options(s::GMRES, kwargs) function Internals.set_options(solver::GMRES, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :ConvergenceFactorTol if k === :ConvergenceFactorTol
@check HYPRE_GMRESSetConvergenceFactorTol(solver, v) @check HYPRE_GMRESSetConvergenceFactorTol(solver, v)
@ -340,7 +336,7 @@ function Internals.set_options(s::GMRES, kwargs)
@check HYPRE_ParCSRGMRESSetMinIter(solver, v) @check HYPRE_ParCSRGMRESSetMinIter(solver, v)
elseif k === :Precond elseif k === :Precond
Internals.set_precond_defaults(v) Internals.set_precond_defaults(v)
Internals.set_precond(s, v) Internals.set_precond(solver, v)
elseif k === :PrintLevel elseif k === :PrintLevel
@check HYPRE_ParCSRGMRESSetPrintLevel(solver, v) @check HYPRE_ParCSRGMRESSetPrintLevel(solver, v)
elseif k === :StopCrit elseif k === :StopCrit
@ -353,8 +349,7 @@ function Internals.set_options(s::GMRES, kwargs)
end end
end end
function Internals.set_options(s::Hybrid, kwargs) function Internals.set_options(solver::Hybrid, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :AbsoluteTol if k === :AbsoluteTol
@check HYPRE_ParCSRHybridSetAbsoluteTol(solver, v) @check HYPRE_ParCSRHybridSetAbsoluteTol(solver, v)
@ -424,7 +419,7 @@ function Internals.set_options(s::Hybrid, kwargs)
@check HYPRE_ParCSRHybridSetPMaxElmts(solver, v) @check HYPRE_ParCSRHybridSetPMaxElmts(solver, v)
elseif k === :Precond elseif k === :Precond
Internals.set_precond_defaults(v) Internals.set_precond_defaults(v)
Internals.set_precond(s, v) Internals.set_precond(solver, v)
elseif k === :PrintLevel elseif k === :PrintLevel
@check HYPRE_ParCSRHybridSetPrintLevel(solver, v) @check HYPRE_ParCSRHybridSetPrintLevel(solver, v)
elseif k === :RecomputeResidual elseif k === :RecomputeResidual
@ -463,8 +458,7 @@ function Internals.set_options(s::Hybrid, kwargs)
end end
end end
function Internals.set_options(s::ILU, kwargs) function Internals.set_options(solver::ILU, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :DropThreshold if k === :DropThreshold
@check HYPRE_ILUSetDropThreshold(solver, v) @check HYPRE_ILUSetDropThreshold(solver, v)
@ -498,8 +492,7 @@ function Internals.set_options(s::ILU, kwargs)
end end
end end
function Internals.set_options(s::ParaSails, kwargs) function Internals.set_options(solver::ParaSails, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :Filter if k === :Filter
@check HYPRE_ParCSRParaSailsSetFilter(solver, v) @check HYPRE_ParCSRParaSailsSetFilter(solver, v)
@ -519,8 +512,7 @@ function Internals.set_options(s::ParaSails, kwargs)
end end
end end
function Internals.set_options(s::PCG, kwargs) function Internals.set_options(solver::PCG, kwargs)
solver = s.solver
for (k, v) in kwargs for (k, v) in kwargs
if k === :AbsoluteTolFactor if k === :AbsoluteTolFactor
@check HYPRE_PCGSetAbsoluteTolFactor(solver, v) @check HYPRE_PCGSetAbsoluteTolFactor(solver, v)
@ -540,7 +532,7 @@ function Internals.set_options(s::PCG, kwargs)
@check HYPRE_ParCSRPCGSetMaxIter(solver, v) @check HYPRE_ParCSRPCGSetMaxIter(solver, v)
elseif k === :Precond elseif k === :Precond
Internals.set_precond_defaults(v) Internals.set_precond_defaults(v)
Internals.set_precond(s, v) Internals.set_precond(solver, v)
elseif k === :PrintLevel elseif k === :PrintLevel
@check HYPRE_ParCSRPCGSetPrintLevel(solver, v) @check HYPRE_ParCSRPCGSetPrintLevel(solver, v)
elseif k === :RelChange elseif k === :RelChange

48
src/solvers.jl

@ -13,12 +13,16 @@ function Internals.safe_finalizer(Destroy, solver)
# Add a finalizer that only calls Destroy if pointer not C_NULL # Add a finalizer that only calls Destroy if pointer not C_NULL
finalizer(solver) do s finalizer(solver) do s
if s.solver != C_NULL if s.solver != C_NULL
Destroy(s.solver) Destroy(s)
s.solver = C_NULL s.solver = C_NULL
end end
end end
end end
# Defining unsafe_convert enables ccall to automatically convert solver::HYPRESolver to
# HYPRE_Solver while also making sure solver won't be GC'd and finalized.
Base.unsafe_convert(::Type{HYPRE_Solver}, solver::HYPRESolver) = solver.solver
# Fallback for the solvers that doesn't have required defaults # Fallback for the solvers that doesn't have required defaults
Internals.set_precond_defaults(::HYPRESolver) = nothing Internals.set_precond_defaults(::HYPRESolver) = nothing
@ -122,8 +126,8 @@ end
const ParCSRBiCGSTAB = BiCGSTAB const ParCSRBiCGSTAB = BiCGSTAB
function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRBiCGSTABSetup(bicg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRBiCGSTABSetup(bicg, A, b, x)
@check HYPRE_ParCSRBiCGSTABSolve(bicg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRBiCGSTABSolve(bicg, A, b, x)
return x return x
end end
@ -134,7 +138,7 @@ function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver)
bicg.precond = p bicg.precond = p
solve_f = Internals.solve_func(p) solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p) setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRBiCGSTABSetPrecond(bicg.solver, solve_f, setup_f, p.solver) @check HYPRE_ParCSRBiCGSTABSetPrecond(bicg, solve_f, setup_f, p)
return nothing return nothing
end end
@ -169,8 +173,8 @@ mutable struct BoomerAMG <: HYPRESolver
end end
function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_BoomerAMGSetup(amg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_BoomerAMGSetup(amg, A, b, x)
@check HYPRE_BoomerAMGSolve(amg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_BoomerAMGSolve(amg, A, b, x)
return x return x
end end
@ -215,8 +219,8 @@ mutable struct FlexGMRES <: HYPRESolver
end end
function solve!(flex::FlexGMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(flex::FlexGMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRFlexGMRESSetup(flex.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRFlexGMRESSetup(flex, A, b, x)
@check HYPRE_ParCSRFlexGMRESSolve(flex.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRFlexGMRESSolve(flex, A, b, x)
return x return x
end end
@ -227,7 +231,7 @@ function Internals.set_precond(flex::FlexGMRES, p::HYPRESolver)
flex.precond = p flex.precond = p
solve_f = Internals.solve_func(p) solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p) setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRFlexGMRESSetPrecond(flex.solver, solve_f, setup_f, p.solver) @check HYPRE_ParCSRFlexGMRESSetPrecond(flex, solve_f, setup_f, p)
return nothing return nothing
end end
@ -254,8 +258,8 @@ end
#end #end
#function solve!(fsai::FSAI, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) #function solve!(fsai::FSAI, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
# @check HYPRE_FSAISetup(fsai.solver, A.parmatrix, b.parvector, x.parvector) # @check HYPRE_FSAISetup(fsai, A, b, x)
# @check HYPRE_FSAISolve(fsai.solver, A.parmatrix, b.parvector, x.parvector) # @check HYPRE_FSAISolve(fsai, A, b, x)
# return x # return x
#end #end
@ -300,8 +304,8 @@ mutable struct GMRES <: HYPRESolver
end end
function solve!(gmres::GMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(gmres::GMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRGMRESSetup(gmres.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRGMRESSetup(gmres, A, b, x)
@check HYPRE_ParCSRGMRESSolve(gmres.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRGMRESSolve(gmres, A, b, x)
return x return x
end end
@ -312,7 +316,7 @@ function Internals.set_precond(gmres::GMRES, p::HYPRESolver)
gmres.precond = p gmres.precond = p
solve_f = Internals.solve_func(p) solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p) setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRGMRESSetPrecond(gmres.solver, solve_f, setup_f, p.solver) @check HYPRE_ParCSRGMRESSetPrecond(gmres, solve_f, setup_f, p)
return nothing return nothing
end end
@ -347,8 +351,8 @@ mutable struct Hybrid <: HYPRESolver
end end
function solve!(hybrid::Hybrid, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(hybrid::Hybrid, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRHybridSetup(hybrid.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRHybridSetup(hybrid, A, b, x)
@check HYPRE_ParCSRHybridSolve(hybrid.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRHybridSolve(hybrid, A, b, x)
return x return x
end end
@ -362,7 +366,7 @@ function Internals.set_precond(hybrid::Hybrid, p::HYPRESolver)
# Deactivate the finalizer of p since the HYBRIDDestroy function does this, # Deactivate the finalizer of p since the HYBRIDDestroy function does this,
# see https://github.com/hypre-space/hypre/issues/699 # see https://github.com/hypre-space/hypre/issues/699
finalizer(x -> (x.solver = C_NULL), p) finalizer(x -> (x.solver = C_NULL), p)
@check HYPRE_ParCSRHybridSetPrecond(hybrid.solver, solve_f, setup_f, p.solver) @check HYPRE_ParCSRHybridSetPrecond(hybrid, solve_f, setup_f, p)
return nothing return nothing
end end
@ -397,8 +401,8 @@ mutable struct ILU <: HYPRESolver
end end
function solve!(ilu::ILU, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(ilu::ILU, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ILUSetup(ilu.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ILUSetup(ilu, A, b, x)
@check HYPRE_ILUSolve(ilu.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ILUSolve(ilu, A, b, x)
return x return x
end end
@ -482,8 +486,8 @@ end
const ParCSRPCG = PCG const ParCSRPCG = PCG
function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRPCGSetup(pcg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRPCGSetup(pcg, A, b, x)
@check HYPRE_ParCSRPCGSolve(pcg.solver, A.parmatrix, b.parvector, x.parvector) @check HYPRE_ParCSRPCGSolve(pcg, A, b, x)
return x return x
end end
@ -494,6 +498,6 @@ function Internals.set_precond(pcg::PCG, p::HYPRESolver)
pcg.precond = p pcg.precond = p
solve_f = Internals.solve_func(p) solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p) setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRPCGSetPrecond(pcg.solver, solve_f, setup_f, p.solver) @check HYPRE_ParCSRPCGSetPrecond(pcg, solve_f, setup_f, p)
return nothing return nothing
end end

Loading…
Cancel
Save