Browse Source

Define `unsafe_convert` methods for `HYPRE(Matrix|Vector|Solver)`

This patch adds methods for `Base.unsafe_convert` for `HYPREMatrix`,
`HYPREVector`, and `HYPRESolver`. This means that those objects can be
passed directly to `ccall` and be "converted" (i.e. extracting the
pointer that is stored in the structs) to the appropriate type expected
by the HYPRE C-library. The advantage is that `ccall` then guarantees
that the objects are kept alive for the duration of the call.
pull/13/head
Fredrik Ekre 3 years ago
parent
commit
cf9c815e85
  1. 16
      docs/src/libhypre.md
  2. 5
      gen/solver_options.jl
  3. 73
      src/HYPRE.jl
  4. 34
      src/solver_options.jl
  5. 48
      src/solvers.jl

16
docs/src/libhypre.md

@ -15,15 +15,13 @@ directly. @@ -15,15 +15,13 @@ directly.
Functions from the `LibHYPRE` submodule can be used together with the high level interface.
This is useful when you need some functionality from the library which isn't exposed in the
high level interface. Many functions require passing a reference to a matrix/vector or a
solver. These can be obtained as follows:
| C type signature | Argument to pass |
|:---------------------|:-------------------------------------|
| `HYPRE_IJMatrix` | `A.ijmatrix` where `A::HYPREMatrix` |
| `HYPRE_ParCSRMatrix` | `A.parmatrix` where `A::HYPREMatrix` |
| `HYPRE_IJVector` | `b.ijvector` where `b::HYPREVector` |
| `HYPRE_ParVector` | `b.parvector` where `b::HYPREVector` |
| `HYPRE_Solver` | `s.solver` where `s::HYPRESolver` |
solver. HYPRE.jl defines the appropriate conversion methods used by `ccall` such that
- `A::HYPREMatrix` can be passed to `HYPRE_*` functions with `HYPRE_IJMatrix` or
`HYPRE_ParCSRMatrix` in the signature
- `b::HYPREVector` can be passed to `HYPRE_*` functions with `HYPRE_IJVector` or
`HYPRE_ParVector` in the signature
- `s::HYPRESolver` can be passed to `HYPRE_*` functions with `HYPRE_Solver` in the
signature
[^1]: Bindings are generated using
[Clang.jl](https://github.com/JuliaInterop/Clang.jl), see

5
gen/solver_options.jl

@ -2,8 +2,7 @@ using HYPRE.LibHYPRE @@ -2,8 +2,7 @@ using HYPRE.LibHYPRE
function generate_options(io, structname, prefixes...)
println(io, "")
println(io, "function Internals.set_options(s::$(structname), kwargs)")
println(io, " solver = s.solver")
println(io, "function Internals.set_options(solver::$(structname), kwargs)")
println(io, " for (k, v) in kwargs")
ns = Tuple{Symbol,String}[]
@ -29,7 +28,7 @@ function generate_options(io, structname, prefixes...) @@ -29,7 +28,7 @@ function generate_options(io, structname, prefixes...)
println(io)
if k == "Precond"
println(io, " Internals.set_precond_defaults(v)")
println(io, " Internals.set_precond(s, v)")
println(io, " Internals.set_precond(solver, v)")
elseif nargs == 1
println(io, " @check ", n, "(solver)")
elseif nargs == 2

73
src/HYPRE.jl

@ -63,6 +63,12 @@ mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex} @@ -63,6 +63,12 @@ mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex}
parmatrix::HYPRE_ParCSRMatrix
end
# Defining unsafe_convert enables ccall to automatically convert A::HYPREMatrix to
# HYPRE_IJMatrix and HYPRE_ParCSRMatrix while also making sure A won't be GC'd and
# finalized.
Base.unsafe_convert(::Type{HYPRE_IJMatrix}, A::HYPREMatrix) = A.ijmatrix
Base.unsafe_convert(::Type{HYPRE_ParCSRMatrix}, A::HYPREMatrix) = A.parmatrix
function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer,
jlower::Integer=ilower, jupper::Integer=iupper)
# Create the IJ matrix
@ -73,15 +79,15 @@ function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, @@ -73,15 +79,15 @@ function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer,
# Attach a finalizer
finalizer(A) do x
if x.ijmatrix != C_NULL
HYPRE_IJMatrixDestroy(x.ijmatrix)
HYPRE_IJMatrixDestroy(x)
x.ijmatrix = x.parmatrix = C_NULL
end
end
push!(Internals.HYPRE_OBJECTS, A => nothing)
# Set storage type
@check HYPRE_IJMatrixSetObjectType(A.ijmatrix, HYPRE_PARCSR)
@check HYPRE_IJMatrixSetObjectType(A, HYPRE_PARCSR)
# Initialize to make ready for setting values
@check HYPRE_IJMatrixInitialize(A.ijmatrix)
@check HYPRE_IJMatrixInitialize(A)
return A
end
@ -89,10 +95,10 @@ end @@ -89,10 +95,10 @@ end
# This should be called after setting all the values
function Internals.assemble_matrix(A::HYPREMatrix)
# Finalize after setting all values
@check HYPRE_IJMatrixAssemble(A.ijmatrix)
@check HYPRE_IJMatrixAssemble(A)
# Fetch the assembled CSR matrix
parmatrix_ref = Ref{Ptr{Cvoid}}(C_NULL)
@check HYPRE_IJMatrixGetObject(A.ijmatrix, parmatrix_ref)
@check HYPRE_IJMatrixGetObject(A, parmatrix_ref)
A.parmatrix = convert(Ptr{HYPRE_ParCSRMatrix}, parmatrix_ref[])
return A
end
@ -109,6 +115,11 @@ mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex} @@ -109,6 +115,11 @@ mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex}
parvector::HYPRE_ParVector
end
# Defining unsafe_convert enables ccall to automatically convert b::HYPREVector to
# HYPRE_IJVector and HYPRE_ParVector while also making sure b won't be GC'd and finalized.
Base.unsafe_convert(::Type{HYPRE_IJVector}, b::HYPREVector) = b.ijvector
Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector
function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer)
# Create the IJ vector
b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL)
@ -118,24 +129,24 @@ function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) @@ -118,24 +129,24 @@ function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer)
# Attach a finalizer
finalizer(b) do x
if x.ijvector != C_NULL
HYPRE_IJVectorDestroy(x.ijvector)
HYPRE_IJVectorDestroy(x)
x.ijvector = x.parvector = C_NULL
end
end
push!(Internals.HYPRE_OBJECTS, b => nothing)
# Set storage type
@check HYPRE_IJVectorSetObjectType(b.ijvector, HYPRE_PARCSR)
@check HYPRE_IJVectorSetObjectType(b, HYPRE_PARCSR)
# Initialize to make ready for setting values
@check HYPRE_IJVectorInitialize(b.ijvector)
@check HYPRE_IJVectorInitialize(b)
return b
end
function Internals.assemble_vector(b::HYPREVector)
# Finalize after setting all values
@check HYPRE_IJVectorAssemble(b.ijvector)
@check HYPRE_IJVectorAssemble(b)
# Fetch the assembled vector
parvector_ref = Ref{Ptr{Cvoid}}(C_NULL)
@check HYPRE_IJVectorGetObject(b.ijvector, parvector_ref)
@check HYPRE_IJVectorGetObject(b, parvector_ref)
b.parvector = convert(Ptr{HYPRE_ParVector}, parvector_ref[])
return b
end
@ -143,7 +154,7 @@ end @@ -143,7 +154,7 @@ end
function Internals.get_proc_rows(b::HYPREVector)
# ilower_ref = Ref{HYPRE_BigInt}()
# iupper_ref = Ref{HYPRE_BigInt}()
# @check HYPRE_IJVectorGetLocalRange(b.ijvector, ilower_ref, iupper_ref)
# @check HYPRE_IJVectorGetLocalRange(b, ilower_ref, iupper_ref)
# ilower = ilower_ref[]
# iupper = iupper_ref[]
# return ilower, iupper
@ -169,7 +180,7 @@ function Base.zero(b::HYPREVector) @@ -169,7 +180,7 @@ function Base.zero(b::HYPREVector)
nvalues = jupper - jlower + 1
indices = collect(HYPRE_BigInt, jlower:jupper)
values = zeros(HYPRE_Complex, nvalues)
@check HYPRE_IJVectorSetValues(x.ijvector, nvalues, indices, values)
@check HYPRE_IJVectorSetValues(x, nvalues, indices, values)
# Finalize and return
Internals.assemble_vector(x)
return x
@ -258,7 +269,7 @@ end @@ -258,7 +269,7 @@ end
function HYPREMatrix(comm::MPI.Comm, B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper)
A = HYPREMatrix(comm, ilower, iupper)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper)
@check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values)
@check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values)
Internals.assemble_matrix(A)
return A
end
@ -281,7 +292,7 @@ end @@ -281,7 +292,7 @@ end
function HYPREVector(comm::MPI.Comm, x::Vector, ilower, iupper)
b = HYPREVector(comm, ilower, iupper)
nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values)
@check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
Internals.assemble_vector(b)
return b
end
@ -297,7 +308,7 @@ function Base.copy!(dst::Vector{HYPRE_Complex}, src::HYPREVector) @@ -297,7 +308,7 @@ function Base.copy!(dst::Vector{HYPRE_Complex}, src::HYPREVector)
throw(ArgumentError("length of dst and src does not match"))
end
indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, dst)
@check HYPRE_IJVectorGetValues(src, nvalues, indices, dst)
return dst
end
@ -308,12 +319,12 @@ function Base.copy!(dst::HYPREVector, src::Vector{HYPRE_Complex}) @@ -308,12 +319,12 @@ function Base.copy!(dst::HYPREVector, src::Vector{HYPRE_Complex})
throw(ArgumentError("length of dst and src does not match"))
end
# Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector)
@check HYPRE_IJVectorInitialize(dst)
# Set all the values
indices = collect(HYPRE_BigInt, ilower:iupper)
@check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, src)
@check HYPRE_IJVectorSetValues(dst, nvalues, indices, src)
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector)
# @check HYPRE_IJVectorAssemble(dst)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed.
return dst
@ -445,7 +456,7 @@ function HYPREMatrix(B::PSparseMatrix) @@ -445,7 +456,7 @@ function HYPREMatrix(B::PSparseMatrix)
# Set all the values
map_parts(B.values, B.rows.partition, B.cols.partition) do Bv, Br, Bc
nrows, ncols, rows, cols, values = Internals.to_hypre_data(Bv, Br, Bc)
@check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values)
@check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values)
return nothing
end
# Finalize
@ -487,7 +498,7 @@ function HYPREVector(v::PVector) @@ -487,7 +498,7 @@ function HYPREVector(v::PVector)
# end
# nvalues = length(indices)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values)
@check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
return nothing
end
# Finalize
@ -521,7 +532,7 @@ function Base.copy!(dst::PVector{HYPRE_Complex}, src::HYPREVector) @@ -521,7 +532,7 @@ function Base.copy!(dst::PVector{HYPRE_Complex}, src::HYPREVector)
fill!(vv, 0)
# TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, vv)
@check HYPRE_IJVectorGetValues(src, nvalues, indices, vv)
end
return dst
end
@ -529,17 +540,17 @@ end @@ -529,17 +540,17 @@ end
function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex})
Internals.copy_check(dst, src)
# Re-initialize the vector
@check HYPRE_IJVectorInitialize(dst.ijvector)
@check HYPRE_IJVectorInitialize(dst)
map_parts(src.values, src.owned_values, src.rows.partition) do vv, _, vr
ilower_src_part = vr.lid_to_gid[vr.oid_to_lid.start]
iupper_src_part = vr.lid_to_gid[vr.oid_to_lid.stop]
nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1)
indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part)
# TODO: Safe to use vv here? Owned values are always first?
@check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, vv)
@check HYPRE_IJVectorSetValues(dst, nvalues, indices, vv)
end
# TODO: It shouldn't be necessary to assemble here since we only set owned rows (?)
# @check HYPRE_IJVectorAssemble(dst.ijvector)
# @check HYPRE_IJVectorAssemble(dst)
# TODO: Necessary to recreate the ParVector? Running some examples it seems like it is
# not needed.
return dst
@ -585,9 +596,9 @@ start_assemble! @@ -585,9 +596,9 @@ start_assemble!
function start_assemble!(A::HYPREMatrix)
if A.parmatrix != C_NULL
# This matrix have been assembled before, reset to 0
@check HYPRE_IJMatrixSetConstantValues(A.ijmatrix, 0)
@check HYPRE_IJMatrixSetConstantValues(A, 0)
end
@check HYPRE_IJMatrixInitialize(A.ijmatrix)
@check HYPRE_IJMatrixInitialize(A)
return HYPREMatrixAssembler(A, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[])
end
@ -595,14 +606,14 @@ function start_assemble!(b::HYPREVector) @@ -595,14 +606,14 @@ function start_assemble!(b::HYPREVector)
if b.parvector != C_NULL
# This vector have been assembled before, reset to 0
# See https://github.com/hypre-space/hypre/pull/689
# @check HYPRE_IJVectorSetConstantValues(b.ijvector, 0)
# @check HYPRE_IJVectorSetConstantValues(b, 0)
end
@check HYPRE_IJVectorInitialize(b.ijvector)
@check HYPRE_IJVectorInitialize(b)
if b.parvector != C_NULL
nvalues = HYPRE_Int(b.iupper - b.ilower + 1)
indices = collect(HYPRE_BigInt, b.ilower:b.iupper)
values = zeros(HYPRE_Complex, nvalues)
@check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values)
@check HYPRE_IJVectorSetValues(b, nvalues, indices, values)
# TODO: Do I need to assemble here?
end
return HYPREVectorAssembler(b, HYPRE_BigInt[], HYPRE_Complex[])
@ -635,14 +646,14 @@ assemble! @@ -635,14 +646,14 @@ assemble!
function assemble!(A::HYPREMatrixAssembler, i::Vector, j::Vector, a::Matrix)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, i, j)
@check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values)
@check HYPRE_IJMatrixAddToValues(A.A, nrows, ncols, rows, cols, values)
return A
end
@deprecate assemble!(A::HYPREMatrixAssembler, ij::Vector, a::Matrix) assemble!(A, ij, ij, a) false
function assemble!(A::HYPREVectorAssembler, ij::Vector, a::Vector)
nvalues, indices, values = Internals.to_hypre_data(A, a, ij)
@check HYPRE_IJVectorAddToValues(A.b.ijvector, nvalues, indices, values)
@check HYPRE_IJVectorAddToValues(A.b, nvalues, indices, values)
return A
end

34
src/solver_options.jl

@ -4,8 +4,7 @@ @@ -4,8 +4,7 @@
Internals.set_options(::HYPRESolver, kwargs) = nothing
function Internals.set_options(s::BiCGSTAB, kwargs)
solver = s.solver
function Internals.set_options(solver::BiCGSTAB, kwargs)
for (k, v) in kwargs
if k === :ConvergenceFactorTol
@check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v)
@ -19,7 +18,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs) @@ -19,7 +18,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs)
@check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v)
elseif k === :Precond
Internals.set_precond_defaults(v)
Internals.set_precond(s, v)
Internals.set_precond(solver, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v)
elseif k === :StopCrit
@ -32,8 +31,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs) @@ -32,8 +31,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs)
end
end
function Internals.set_options(s::BoomerAMG, kwargs)
solver = s.solver
function Internals.set_options(solver::BoomerAMG, kwargs)
for (k, v) in kwargs
if k === :ADropTol
@check HYPRE_BoomerAMGSetADropTol(solver, v)
@ -289,8 +287,7 @@ function Internals.set_options(s::BoomerAMG, kwargs) @@ -289,8 +287,7 @@ function Internals.set_options(s::BoomerAMG, kwargs)
end
end
function Internals.set_options(s::FlexGMRES, kwargs)
solver = s.solver
function Internals.set_options(solver::FlexGMRES, kwargs)
for (k, v) in kwargs
if k === :ConvergenceFactorTol
@check HYPRE_FlexGMRESSetConvergenceFactorTol(solver, v)
@ -308,7 +305,7 @@ function Internals.set_options(s::FlexGMRES, kwargs) @@ -308,7 +305,7 @@ function Internals.set_options(s::FlexGMRES, kwargs)
@check HYPRE_ParCSRFlexGMRESSetModifyPC(solver, v)
elseif k === :Precond
Internals.set_precond_defaults(v)
Internals.set_precond(s, v)
Internals.set_precond(solver, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRFlexGMRESSetPrintLevel(solver, v)
elseif k === :Tol
@ -319,8 +316,7 @@ function Internals.set_options(s::FlexGMRES, kwargs) @@ -319,8 +316,7 @@ function Internals.set_options(s::FlexGMRES, kwargs)
end
end
function Internals.set_options(s::GMRES, kwargs)
solver = s.solver
function Internals.set_options(solver::GMRES, kwargs)
for (k, v) in kwargs
if k === :ConvergenceFactorTol
@check HYPRE_GMRESSetConvergenceFactorTol(solver, v)
@ -340,7 +336,7 @@ function Internals.set_options(s::GMRES, kwargs) @@ -340,7 +336,7 @@ function Internals.set_options(s::GMRES, kwargs)
@check HYPRE_ParCSRGMRESSetMinIter(solver, v)
elseif k === :Precond
Internals.set_precond_defaults(v)
Internals.set_precond(s, v)
Internals.set_precond(solver, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRGMRESSetPrintLevel(solver, v)
elseif k === :StopCrit
@ -353,8 +349,7 @@ function Internals.set_options(s::GMRES, kwargs) @@ -353,8 +349,7 @@ function Internals.set_options(s::GMRES, kwargs)
end
end
function Internals.set_options(s::Hybrid, kwargs)
solver = s.solver
function Internals.set_options(solver::Hybrid, kwargs)
for (k, v) in kwargs
if k === :AbsoluteTol
@check HYPRE_ParCSRHybridSetAbsoluteTol(solver, v)
@ -424,7 +419,7 @@ function Internals.set_options(s::Hybrid, kwargs) @@ -424,7 +419,7 @@ function Internals.set_options(s::Hybrid, kwargs)
@check HYPRE_ParCSRHybridSetPMaxElmts(solver, v)
elseif k === :Precond
Internals.set_precond_defaults(v)
Internals.set_precond(s, v)
Internals.set_precond(solver, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRHybridSetPrintLevel(solver, v)
elseif k === :RecomputeResidual
@ -463,8 +458,7 @@ function Internals.set_options(s::Hybrid, kwargs) @@ -463,8 +458,7 @@ function Internals.set_options(s::Hybrid, kwargs)
end
end
function Internals.set_options(s::ILU, kwargs)
solver = s.solver
function Internals.set_options(solver::ILU, kwargs)
for (k, v) in kwargs
if k === :DropThreshold
@check HYPRE_ILUSetDropThreshold(solver, v)
@ -498,8 +492,7 @@ function Internals.set_options(s::ILU, kwargs) @@ -498,8 +492,7 @@ function Internals.set_options(s::ILU, kwargs)
end
end
function Internals.set_options(s::ParaSails, kwargs)
solver = s.solver
function Internals.set_options(solver::ParaSails, kwargs)
for (k, v) in kwargs
if k === :Filter
@check HYPRE_ParCSRParaSailsSetFilter(solver, v)
@ -519,8 +512,7 @@ function Internals.set_options(s::ParaSails, kwargs) @@ -519,8 +512,7 @@ function Internals.set_options(s::ParaSails, kwargs)
end
end
function Internals.set_options(s::PCG, kwargs)
solver = s.solver
function Internals.set_options(solver::PCG, kwargs)
for (k, v) in kwargs
if k === :AbsoluteTolFactor
@check HYPRE_PCGSetAbsoluteTolFactor(solver, v)
@ -540,7 +532,7 @@ function Internals.set_options(s::PCG, kwargs) @@ -540,7 +532,7 @@ function Internals.set_options(s::PCG, kwargs)
@check HYPRE_ParCSRPCGSetMaxIter(solver, v)
elseif k === :Precond
Internals.set_precond_defaults(v)
Internals.set_precond(s, v)
Internals.set_precond(solver, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRPCGSetPrintLevel(solver, v)
elseif k === :RelChange

48
src/solvers.jl

@ -13,12 +13,16 @@ function Internals.safe_finalizer(Destroy, solver) @@ -13,12 +13,16 @@ function Internals.safe_finalizer(Destroy, solver)
# Add a finalizer that only calls Destroy if pointer not C_NULL
finalizer(solver) do s
if s.solver != C_NULL
Destroy(s.solver)
Destroy(s)
s.solver = C_NULL
end
end
end
# Defining unsafe_convert enables ccall to automatically convert solver::HYPRESolver to
# HYPRE_Solver while also making sure solver won't be GC'd and finalized.
Base.unsafe_convert(::Type{HYPRE_Solver}, solver::HYPRESolver) = solver.solver
# Fallback for the solvers that doesn't have required defaults
Internals.set_precond_defaults(::HYPRESolver) = nothing
@ -122,8 +126,8 @@ end @@ -122,8 +126,8 @@ end
const ParCSRBiCGSTAB = BiCGSTAB
function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRBiCGSTABSetup(bicg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRBiCGSTABSolve(bicg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRBiCGSTABSetup(bicg, A, b, x)
@check HYPRE_ParCSRBiCGSTABSolve(bicg, A, b, x)
return x
end
@ -134,7 +138,7 @@ function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver) @@ -134,7 +138,7 @@ function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver)
bicg.precond = p
solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRBiCGSTABSetPrecond(bicg.solver, solve_f, setup_f, p.solver)
@check HYPRE_ParCSRBiCGSTABSetPrecond(bicg, solve_f, setup_f, p)
return nothing
end
@ -169,8 +173,8 @@ mutable struct BoomerAMG <: HYPRESolver @@ -169,8 +173,8 @@ mutable struct BoomerAMG <: HYPRESolver
end
function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_BoomerAMGSetup(amg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_BoomerAMGSolve(amg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_BoomerAMGSetup(amg, A, b, x)
@check HYPRE_BoomerAMGSolve(amg, A, b, x)
return x
end
@ -215,8 +219,8 @@ mutable struct FlexGMRES <: HYPRESolver @@ -215,8 +219,8 @@ mutable struct FlexGMRES <: HYPRESolver
end
function solve!(flex::FlexGMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRFlexGMRESSetup(flex.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRFlexGMRESSolve(flex.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRFlexGMRESSetup(flex, A, b, x)
@check HYPRE_ParCSRFlexGMRESSolve(flex, A, b, x)
return x
end
@ -227,7 +231,7 @@ function Internals.set_precond(flex::FlexGMRES, p::HYPRESolver) @@ -227,7 +231,7 @@ function Internals.set_precond(flex::FlexGMRES, p::HYPRESolver)
flex.precond = p
solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRFlexGMRESSetPrecond(flex.solver, solve_f, setup_f, p.solver)
@check HYPRE_ParCSRFlexGMRESSetPrecond(flex, solve_f, setup_f, p)
return nothing
end
@ -254,8 +258,8 @@ end @@ -254,8 +258,8 @@ end
#end
#function solve!(fsai::FSAI, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
# @check HYPRE_FSAISetup(fsai.solver, A.parmatrix, b.parvector, x.parvector)
# @check HYPRE_FSAISolve(fsai.solver, A.parmatrix, b.parvector, x.parvector)
# @check HYPRE_FSAISetup(fsai, A, b, x)
# @check HYPRE_FSAISolve(fsai, A, b, x)
# return x
#end
@ -300,8 +304,8 @@ mutable struct GMRES <: HYPRESolver @@ -300,8 +304,8 @@ mutable struct GMRES <: HYPRESolver
end
function solve!(gmres::GMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRGMRESSetup(gmres.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRGMRESSolve(gmres.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRGMRESSetup(gmres, A, b, x)
@check HYPRE_ParCSRGMRESSolve(gmres, A, b, x)
return x
end
@ -312,7 +316,7 @@ function Internals.set_precond(gmres::GMRES, p::HYPRESolver) @@ -312,7 +316,7 @@ function Internals.set_precond(gmres::GMRES, p::HYPRESolver)
gmres.precond = p
solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRGMRESSetPrecond(gmres.solver, solve_f, setup_f, p.solver)
@check HYPRE_ParCSRGMRESSetPrecond(gmres, solve_f, setup_f, p)
return nothing
end
@ -347,8 +351,8 @@ mutable struct Hybrid <: HYPRESolver @@ -347,8 +351,8 @@ mutable struct Hybrid <: HYPRESolver
end
function solve!(hybrid::Hybrid, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRHybridSetup(hybrid.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRHybridSolve(hybrid.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRHybridSetup(hybrid, A, b, x)
@check HYPRE_ParCSRHybridSolve(hybrid, A, b, x)
return x
end
@ -362,7 +366,7 @@ function Internals.set_precond(hybrid::Hybrid, p::HYPRESolver) @@ -362,7 +366,7 @@ function Internals.set_precond(hybrid::Hybrid, p::HYPRESolver)
# Deactivate the finalizer of p since the HYBRIDDestroy function does this,
# see https://github.com/hypre-space/hypre/issues/699
finalizer(x -> (x.solver = C_NULL), p)
@check HYPRE_ParCSRHybridSetPrecond(hybrid.solver, solve_f, setup_f, p.solver)
@check HYPRE_ParCSRHybridSetPrecond(hybrid, solve_f, setup_f, p)
return nothing
end
@ -397,8 +401,8 @@ mutable struct ILU <: HYPRESolver @@ -397,8 +401,8 @@ mutable struct ILU <: HYPRESolver
end
function solve!(ilu::ILU, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ILUSetup(ilu.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ILUSolve(ilu.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ILUSetup(ilu, A, b, x)
@check HYPRE_ILUSolve(ilu, A, b, x)
return x
end
@ -482,8 +486,8 @@ end @@ -482,8 +486,8 @@ end
const ParCSRPCG = PCG
function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRPCGSetup(pcg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRPCGSolve(pcg.solver, A.parmatrix, b.parvector, x.parvector)
@check HYPRE_ParCSRPCGSetup(pcg, A, b, x)
@check HYPRE_ParCSRPCGSolve(pcg, A, b, x)
return x
end
@ -494,6 +498,6 @@ function Internals.set_precond(pcg::PCG, p::HYPRESolver) @@ -494,6 +498,6 @@ function Internals.set_precond(pcg::PCG, p::HYPRESolver)
pcg.precond = p
solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRPCGSetPrecond(pcg.solver, solve_f, setup_f, p.solver)
@check HYPRE_ParCSRPCGSetPrecond(pcg, solve_f, setup_f, p)
return nothing
end

Loading…
Cancel
Save