diff --git a/docs/src/libhypre.md b/docs/src/libhypre.md index 802547d..161dc60 100644 --- a/docs/src/libhypre.md +++ b/docs/src/libhypre.md @@ -15,15 +15,13 @@ directly. Functions from the `LibHYPRE` submodule can be used together with the high level interface. This is useful when you need some functionality from the library which isn't exposed in the high level interface. Many functions require passing a reference to a matrix/vector or a -solver. These can be obtained as follows: - -| C type signature | Argument to pass | -|:---------------------|:-------------------------------------| -| `HYPRE_IJMatrix` | `A.ijmatrix` where `A::HYPREMatrix` | -| `HYPRE_ParCSRMatrix` | `A.parmatrix` where `A::HYPREMatrix` | -| `HYPRE_IJVector` | `b.ijvector` where `b::HYPREVector` | -| `HYPRE_ParVector` | `b.parvector` where `b::HYPREVector` | -| `HYPRE_Solver` | `s.solver` where `s::HYPRESolver` | +solver. HYPRE.jl defines the appropriate conversion methods used by `ccall` such that + - `A::HYPREMatrix` can be passed to `HYPRE_*` functions with `HYPRE_IJMatrix` or + `HYPRE_ParCSRMatrix` in the signature + - `b::HYPREVector` can be passed to `HYPRE_*` functions with `HYPRE_IJVector` or + `HYPRE_ParVector` in the signature + - `s::HYPRESolver` can be passed to `HYPRE_*` functions with `HYPRE_Solver` in the + signature [^1]: Bindings are generated using [Clang.jl](https://github.com/JuliaInterop/Clang.jl), see diff --git a/gen/solver_options.jl b/gen/solver_options.jl index aefd3ed..cc18b33 100644 --- a/gen/solver_options.jl +++ b/gen/solver_options.jl @@ -2,8 +2,7 @@ using HYPRE.LibHYPRE function generate_options(io, structname, prefixes...) println(io, "") - println(io, "function Internals.set_options(s::$(structname), kwargs)") - println(io, " solver = s.solver") + println(io, "function Internals.set_options(solver::$(structname), kwargs)") println(io, " for (k, v) in kwargs") ns = Tuple{Symbol,String}[] @@ -29,7 +28,7 @@ function generate_options(io, structname, prefixes...) println(io) if k == "Precond" println(io, " Internals.set_precond_defaults(v)") - println(io, " Internals.set_precond(s, v)") + println(io, " Internals.set_precond(solver, v)") elseif nargs == 1 println(io, " @check ", n, "(solver)") elseif nargs == 2 diff --git a/src/HYPRE.jl b/src/HYPRE.jl index 746effe..dcaa12e 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -63,6 +63,12 @@ mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex} parmatrix::HYPRE_ParCSRMatrix end +# Defining unsafe_convert enables ccall to automatically convert A::HYPREMatrix to +# HYPRE_IJMatrix and HYPRE_ParCSRMatrix while also making sure A won't be GC'd and +# finalized. +Base.unsafe_convert(::Type{HYPRE_IJMatrix}, A::HYPREMatrix) = A.ijmatrix +Base.unsafe_convert(::Type{HYPRE_ParCSRMatrix}, A::HYPREMatrix) = A.parmatrix + function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, jlower::Integer=ilower, jupper::Integer=iupper) # Create the IJ matrix @@ -73,15 +79,15 @@ function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, # Attach a finalizer finalizer(A) do x if x.ijmatrix != C_NULL - HYPRE_IJMatrixDestroy(x.ijmatrix) + HYPRE_IJMatrixDestroy(x) x.ijmatrix = x.parmatrix = C_NULL end end push!(Internals.HYPRE_OBJECTS, A => nothing) # Set storage type - @check HYPRE_IJMatrixSetObjectType(A.ijmatrix, HYPRE_PARCSR) + @check HYPRE_IJMatrixSetObjectType(A, HYPRE_PARCSR) # Initialize to make ready for setting values - @check HYPRE_IJMatrixInitialize(A.ijmatrix) + @check HYPRE_IJMatrixInitialize(A) return A end @@ -89,10 +95,10 @@ end # This should be called after setting all the values function Internals.assemble_matrix(A::HYPREMatrix) # Finalize after setting all values - @check HYPRE_IJMatrixAssemble(A.ijmatrix) + @check HYPRE_IJMatrixAssemble(A) # Fetch the assembled CSR matrix parmatrix_ref = Ref{Ptr{Cvoid}}(C_NULL) - @check HYPRE_IJMatrixGetObject(A.ijmatrix, parmatrix_ref) + @check HYPRE_IJMatrixGetObject(A, parmatrix_ref) A.parmatrix = convert(Ptr{HYPRE_ParCSRMatrix}, parmatrix_ref[]) return A end @@ -109,6 +115,11 @@ mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex} parvector::HYPRE_ParVector end +# Defining unsafe_convert enables ccall to automatically convert b::HYPREVector to +# HYPRE_IJVector and HYPRE_ParVector while also making sure b won't be GC'd and finalized. +Base.unsafe_convert(::Type{HYPRE_IJVector}, b::HYPREVector) = b.ijvector +Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector + function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) # Create the IJ vector b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL) @@ -118,24 +129,24 @@ function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) # Attach a finalizer finalizer(b) do x if x.ijvector != C_NULL - HYPRE_IJVectorDestroy(x.ijvector) + HYPRE_IJVectorDestroy(x) x.ijvector = x.parvector = C_NULL end end push!(Internals.HYPRE_OBJECTS, b => nothing) # Set storage type - @check HYPRE_IJVectorSetObjectType(b.ijvector, HYPRE_PARCSR) + @check HYPRE_IJVectorSetObjectType(b, HYPRE_PARCSR) # Initialize to make ready for setting values - @check HYPRE_IJVectorInitialize(b.ijvector) + @check HYPRE_IJVectorInitialize(b) return b end function Internals.assemble_vector(b::HYPREVector) # Finalize after setting all values - @check HYPRE_IJVectorAssemble(b.ijvector) + @check HYPRE_IJVectorAssemble(b) # Fetch the assembled vector parvector_ref = Ref{Ptr{Cvoid}}(C_NULL) - @check HYPRE_IJVectorGetObject(b.ijvector, parvector_ref) + @check HYPRE_IJVectorGetObject(b, parvector_ref) b.parvector = convert(Ptr{HYPRE_ParVector}, parvector_ref[]) return b end @@ -143,7 +154,7 @@ end function Internals.get_proc_rows(b::HYPREVector) # ilower_ref = Ref{HYPRE_BigInt}() # iupper_ref = Ref{HYPRE_BigInt}() - # @check HYPRE_IJVectorGetLocalRange(b.ijvector, ilower_ref, iupper_ref) + # @check HYPRE_IJVectorGetLocalRange(b, ilower_ref, iupper_ref) # ilower = ilower_ref[] # iupper = iupper_ref[] # return ilower, iupper @@ -169,7 +180,7 @@ function Base.zero(b::HYPREVector) nvalues = jupper - jlower + 1 indices = collect(HYPRE_BigInt, jlower:jupper) values = zeros(HYPRE_Complex, nvalues) - @check HYPRE_IJVectorSetValues(x.ijvector, nvalues, indices, values) + @check HYPRE_IJVectorSetValues(x, nvalues, indices, values) # Finalize and return Internals.assemble_vector(x) return x @@ -258,7 +269,7 @@ end function HYPREMatrix(comm::MPI.Comm, B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper) A = HYPREMatrix(comm, ilower, iupper) nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) - @check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values) + @check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values) Internals.assemble_matrix(A) return A end @@ -281,7 +292,7 @@ end function HYPREVector(comm::MPI.Comm, x::Vector, ilower, iupper) b = HYPREVector(comm, ilower, iupper) nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper) - @check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) + @check HYPRE_IJVectorSetValues(b, nvalues, indices, values) Internals.assemble_vector(b) return b end @@ -297,7 +308,7 @@ function Base.copy!(dst::Vector{HYPRE_Complex}, src::HYPREVector) throw(ArgumentError("length of dst and src does not match")) end indices = collect(HYPRE_BigInt, ilower:iupper) - @check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, dst) + @check HYPRE_IJVectorGetValues(src, nvalues, indices, dst) return dst end @@ -308,12 +319,12 @@ function Base.copy!(dst::HYPREVector, src::Vector{HYPRE_Complex}) throw(ArgumentError("length of dst and src does not match")) end # Re-initialize the vector - @check HYPRE_IJVectorInitialize(dst.ijvector) + @check HYPRE_IJVectorInitialize(dst) # Set all the values indices = collect(HYPRE_BigInt, ilower:iupper) - @check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, src) + @check HYPRE_IJVectorSetValues(dst, nvalues, indices, src) # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?) - # @check HYPRE_IJVectorAssemble(dst.ijvector) + # @check HYPRE_IJVectorAssemble(dst) # TODO: Necessary to recreate the ParVector? Running some examples it seems like it is # not needed. return dst @@ -445,7 +456,7 @@ function HYPREMatrix(B::PSparseMatrix) # Set all the values map_parts(B.values, B.rows.partition, B.cols.partition) do Bv, Br, Bc nrows, ncols, rows, cols, values = Internals.to_hypre_data(Bv, Br, Bc) - @check HYPRE_IJMatrixSetValues(A.ijmatrix, nrows, ncols, rows, cols, values) + @check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values) return nothing end # Finalize @@ -487,7 +498,7 @@ function HYPREVector(v::PVector) # end # nvalues = length(indices) - @check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) + @check HYPRE_IJVectorSetValues(b, nvalues, indices, values) return nothing end # Finalize @@ -521,7 +532,7 @@ function Base.copy!(dst::PVector{HYPRE_Complex}, src::HYPREVector) fill!(vv, 0) # TODO: Safe to use vv here? Owned values are always first? - @check HYPRE_IJVectorGetValues(src.ijvector, nvalues, indices, vv) + @check HYPRE_IJVectorGetValues(src, nvalues, indices, vv) end return dst end @@ -529,17 +540,17 @@ end function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex}) Internals.copy_check(dst, src) # Re-initialize the vector - @check HYPRE_IJVectorInitialize(dst.ijvector) + @check HYPRE_IJVectorInitialize(dst) map_parts(src.values, src.owned_values, src.rows.partition) do vv, _, vr ilower_src_part = vr.lid_to_gid[vr.oid_to_lid.start] iupper_src_part = vr.lid_to_gid[vr.oid_to_lid.stop] nvalues = HYPRE_Int(iupper_src_part - ilower_src_part + 1) indices = collect(HYPRE_BigInt, ilower_src_part:iupper_src_part) # TODO: Safe to use vv here? Owned values are always first? - @check HYPRE_IJVectorSetValues(dst.ijvector, nvalues, indices, vv) + @check HYPRE_IJVectorSetValues(dst, nvalues, indices, vv) end # TODO: It shouldn't be necessary to assemble here since we only set owned rows (?) - # @check HYPRE_IJVectorAssemble(dst.ijvector) + # @check HYPRE_IJVectorAssemble(dst) # TODO: Necessary to recreate the ParVector? Running some examples it seems like it is # not needed. return dst @@ -585,9 +596,9 @@ start_assemble! function start_assemble!(A::HYPREMatrix) if A.parmatrix != C_NULL # This matrix have been assembled before, reset to 0 - @check HYPRE_IJMatrixSetConstantValues(A.ijmatrix, 0) + @check HYPRE_IJMatrixSetConstantValues(A, 0) end - @check HYPRE_IJMatrixInitialize(A.ijmatrix) + @check HYPRE_IJMatrixInitialize(A) return HYPREMatrixAssembler(A, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[]) end @@ -595,14 +606,14 @@ function start_assemble!(b::HYPREVector) if b.parvector != C_NULL # This vector have been assembled before, reset to 0 # See https://github.com/hypre-space/hypre/pull/689 - # @check HYPRE_IJVectorSetConstantValues(b.ijvector, 0) + # @check HYPRE_IJVectorSetConstantValues(b, 0) end - @check HYPRE_IJVectorInitialize(b.ijvector) + @check HYPRE_IJVectorInitialize(b) if b.parvector != C_NULL nvalues = HYPRE_Int(b.iupper - b.ilower + 1) indices = collect(HYPRE_BigInt, b.ilower:b.iupper) values = zeros(HYPRE_Complex, nvalues) - @check HYPRE_IJVectorSetValues(b.ijvector, nvalues, indices, values) + @check HYPRE_IJVectorSetValues(b, nvalues, indices, values) # TODO: Do I need to assemble here? end return HYPREVectorAssembler(b, HYPRE_BigInt[], HYPRE_Complex[]) @@ -635,14 +646,14 @@ assemble! function assemble!(A::HYPREMatrixAssembler, i::Vector, j::Vector, a::Matrix) nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, i, j) - @check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values) + @check HYPRE_IJMatrixAddToValues(A.A, nrows, ncols, rows, cols, values) return A end @deprecate assemble!(A::HYPREMatrixAssembler, ij::Vector, a::Matrix) assemble!(A, ij, ij, a) false function assemble!(A::HYPREVectorAssembler, ij::Vector, a::Vector) nvalues, indices, values = Internals.to_hypre_data(A, a, ij) - @check HYPRE_IJVectorAddToValues(A.b.ijvector, nvalues, indices, values) + @check HYPRE_IJVectorAddToValues(A.b, nvalues, indices, values) return A end diff --git a/src/solver_options.jl b/src/solver_options.jl index 2885d59..cf2c896 100644 --- a/src/solver_options.jl +++ b/src/solver_options.jl @@ -4,8 +4,7 @@ Internals.set_options(::HYPRESolver, kwargs) = nothing -function Internals.set_options(s::BiCGSTAB, kwargs) - solver = s.solver +function Internals.set_options(solver::BiCGSTAB, kwargs) for (k, v) in kwargs if k === :ConvergenceFactorTol @check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v) @@ -19,7 +18,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs) @check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v) elseif k === :Precond Internals.set_precond_defaults(v) - Internals.set_precond(s, v) + Internals.set_precond(solver, v) elseif k === :PrintLevel @check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v) elseif k === :StopCrit @@ -32,8 +31,7 @@ function Internals.set_options(s::BiCGSTAB, kwargs) end end -function Internals.set_options(s::BoomerAMG, kwargs) - solver = s.solver +function Internals.set_options(solver::BoomerAMG, kwargs) for (k, v) in kwargs if k === :ADropTol @check HYPRE_BoomerAMGSetADropTol(solver, v) @@ -289,8 +287,7 @@ function Internals.set_options(s::BoomerAMG, kwargs) end end -function Internals.set_options(s::FlexGMRES, kwargs) - solver = s.solver +function Internals.set_options(solver::FlexGMRES, kwargs) for (k, v) in kwargs if k === :ConvergenceFactorTol @check HYPRE_FlexGMRESSetConvergenceFactorTol(solver, v) @@ -308,7 +305,7 @@ function Internals.set_options(s::FlexGMRES, kwargs) @check HYPRE_ParCSRFlexGMRESSetModifyPC(solver, v) elseif k === :Precond Internals.set_precond_defaults(v) - Internals.set_precond(s, v) + Internals.set_precond(solver, v) elseif k === :PrintLevel @check HYPRE_ParCSRFlexGMRESSetPrintLevel(solver, v) elseif k === :Tol @@ -319,8 +316,7 @@ function Internals.set_options(s::FlexGMRES, kwargs) end end -function Internals.set_options(s::GMRES, kwargs) - solver = s.solver +function Internals.set_options(solver::GMRES, kwargs) for (k, v) in kwargs if k === :ConvergenceFactorTol @check HYPRE_GMRESSetConvergenceFactorTol(solver, v) @@ -340,7 +336,7 @@ function Internals.set_options(s::GMRES, kwargs) @check HYPRE_ParCSRGMRESSetMinIter(solver, v) elseif k === :Precond Internals.set_precond_defaults(v) - Internals.set_precond(s, v) + Internals.set_precond(solver, v) elseif k === :PrintLevel @check HYPRE_ParCSRGMRESSetPrintLevel(solver, v) elseif k === :StopCrit @@ -353,8 +349,7 @@ function Internals.set_options(s::GMRES, kwargs) end end -function Internals.set_options(s::Hybrid, kwargs) - solver = s.solver +function Internals.set_options(solver::Hybrid, kwargs) for (k, v) in kwargs if k === :AbsoluteTol @check HYPRE_ParCSRHybridSetAbsoluteTol(solver, v) @@ -424,7 +419,7 @@ function Internals.set_options(s::Hybrid, kwargs) @check HYPRE_ParCSRHybridSetPMaxElmts(solver, v) elseif k === :Precond Internals.set_precond_defaults(v) - Internals.set_precond(s, v) + Internals.set_precond(solver, v) elseif k === :PrintLevel @check HYPRE_ParCSRHybridSetPrintLevel(solver, v) elseif k === :RecomputeResidual @@ -463,8 +458,7 @@ function Internals.set_options(s::Hybrid, kwargs) end end -function Internals.set_options(s::ILU, kwargs) - solver = s.solver +function Internals.set_options(solver::ILU, kwargs) for (k, v) in kwargs if k === :DropThreshold @check HYPRE_ILUSetDropThreshold(solver, v) @@ -498,8 +492,7 @@ function Internals.set_options(s::ILU, kwargs) end end -function Internals.set_options(s::ParaSails, kwargs) - solver = s.solver +function Internals.set_options(solver::ParaSails, kwargs) for (k, v) in kwargs if k === :Filter @check HYPRE_ParCSRParaSailsSetFilter(solver, v) @@ -519,8 +512,7 @@ function Internals.set_options(s::ParaSails, kwargs) end end -function Internals.set_options(s::PCG, kwargs) - solver = s.solver +function Internals.set_options(solver::PCG, kwargs) for (k, v) in kwargs if k === :AbsoluteTolFactor @check HYPRE_PCGSetAbsoluteTolFactor(solver, v) @@ -540,7 +532,7 @@ function Internals.set_options(s::PCG, kwargs) @check HYPRE_ParCSRPCGSetMaxIter(solver, v) elseif k === :Precond Internals.set_precond_defaults(v) - Internals.set_precond(s, v) + Internals.set_precond(solver, v) elseif k === :PrintLevel @check HYPRE_ParCSRPCGSetPrintLevel(solver, v) elseif k === :RelChange diff --git a/src/solvers.jl b/src/solvers.jl index 337c672..563c98d 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -13,12 +13,16 @@ function Internals.safe_finalizer(Destroy, solver) # Add a finalizer that only calls Destroy if pointer not C_NULL finalizer(solver) do s if s.solver != C_NULL - Destroy(s.solver) + Destroy(s) s.solver = C_NULL end end end +# Defining unsafe_convert enables ccall to automatically convert solver::HYPRESolver to +# HYPRE_Solver while also making sure solver won't be GC'd and finalized. +Base.unsafe_convert(::Type{HYPRE_Solver}, solver::HYPRESolver) = solver.solver + # Fallback for the solvers that doesn't have required defaults Internals.set_precond_defaults(::HYPRESolver) = nothing @@ -122,8 +126,8 @@ end const ParCSRBiCGSTAB = BiCGSTAB function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ParCSRBiCGSTABSetup(bicg.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ParCSRBiCGSTABSolve(bicg.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ParCSRBiCGSTABSetup(bicg, A, b, x) + @check HYPRE_ParCSRBiCGSTABSolve(bicg, A, b, x) return x end @@ -134,7 +138,7 @@ function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver) bicg.precond = p solve_f = Internals.solve_func(p) setup_f = Internals.setup_func(p) - @check HYPRE_ParCSRBiCGSTABSetPrecond(bicg.solver, solve_f, setup_f, p.solver) + @check HYPRE_ParCSRBiCGSTABSetPrecond(bicg, solve_f, setup_f, p) return nothing end @@ -169,8 +173,8 @@ mutable struct BoomerAMG <: HYPRESolver end function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_BoomerAMGSetup(amg.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_BoomerAMGSolve(amg.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_BoomerAMGSetup(amg, A, b, x) + @check HYPRE_BoomerAMGSolve(amg, A, b, x) return x end @@ -215,8 +219,8 @@ mutable struct FlexGMRES <: HYPRESolver end function solve!(flex::FlexGMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ParCSRFlexGMRESSetup(flex.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ParCSRFlexGMRESSolve(flex.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ParCSRFlexGMRESSetup(flex, A, b, x) + @check HYPRE_ParCSRFlexGMRESSolve(flex, A, b, x) return x end @@ -227,7 +231,7 @@ function Internals.set_precond(flex::FlexGMRES, p::HYPRESolver) flex.precond = p solve_f = Internals.solve_func(p) setup_f = Internals.setup_func(p) - @check HYPRE_ParCSRFlexGMRESSetPrecond(flex.solver, solve_f, setup_f, p.solver) + @check HYPRE_ParCSRFlexGMRESSetPrecond(flex, solve_f, setup_f, p) return nothing end @@ -254,8 +258,8 @@ end #end #function solve!(fsai::FSAI, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) -# @check HYPRE_FSAISetup(fsai.solver, A.parmatrix, b.parvector, x.parvector) -# @check HYPRE_FSAISolve(fsai.solver, A.parmatrix, b.parvector, x.parvector) +# @check HYPRE_FSAISetup(fsai, A, b, x) +# @check HYPRE_FSAISolve(fsai, A, b, x) # return x #end @@ -300,8 +304,8 @@ mutable struct GMRES <: HYPRESolver end function solve!(gmres::GMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ParCSRGMRESSetup(gmres.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ParCSRGMRESSolve(gmres.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ParCSRGMRESSetup(gmres, A, b, x) + @check HYPRE_ParCSRGMRESSolve(gmres, A, b, x) return x end @@ -312,7 +316,7 @@ function Internals.set_precond(gmres::GMRES, p::HYPRESolver) gmres.precond = p solve_f = Internals.solve_func(p) setup_f = Internals.setup_func(p) - @check HYPRE_ParCSRGMRESSetPrecond(gmres.solver, solve_f, setup_f, p.solver) + @check HYPRE_ParCSRGMRESSetPrecond(gmres, solve_f, setup_f, p) return nothing end @@ -347,8 +351,8 @@ mutable struct Hybrid <: HYPRESolver end function solve!(hybrid::Hybrid, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ParCSRHybridSetup(hybrid.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ParCSRHybridSolve(hybrid.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ParCSRHybridSetup(hybrid, A, b, x) + @check HYPRE_ParCSRHybridSolve(hybrid, A, b, x) return x end @@ -362,7 +366,7 @@ function Internals.set_precond(hybrid::Hybrid, p::HYPRESolver) # Deactivate the finalizer of p since the HYBRIDDestroy function does this, # see https://github.com/hypre-space/hypre/issues/699 finalizer(x -> (x.solver = C_NULL), p) - @check HYPRE_ParCSRHybridSetPrecond(hybrid.solver, solve_f, setup_f, p.solver) + @check HYPRE_ParCSRHybridSetPrecond(hybrid, solve_f, setup_f, p) return nothing end @@ -397,8 +401,8 @@ mutable struct ILU <: HYPRESolver end function solve!(ilu::ILU, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ILUSetup(ilu.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ILUSolve(ilu.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ILUSetup(ilu, A, b, x) + @check HYPRE_ILUSolve(ilu, A, b, x) return x end @@ -482,8 +486,8 @@ end const ParCSRPCG = PCG function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) - @check HYPRE_ParCSRPCGSetup(pcg.solver, A.parmatrix, b.parvector, x.parvector) - @check HYPRE_ParCSRPCGSolve(pcg.solver, A.parmatrix, b.parvector, x.parvector) + @check HYPRE_ParCSRPCGSetup(pcg, A, b, x) + @check HYPRE_ParCSRPCGSolve(pcg, A, b, x) return x end @@ -494,6 +498,6 @@ function Internals.set_precond(pcg::PCG, p::HYPRESolver) pcg.precond = p solve_f = Internals.solve_func(p) setup_f = Internals.setup_func(p) - @check HYPRE_ParCSRPCGSetPrecond(pcg.solver, solve_f, setup_f, p.solver) + @check HYPRE_ParCSRPCGSetPrecond(pcg, solve_f, setup_f, p) return nothing end