diff --git a/gen/solver_options.jl b/gen/solver_options.jl index 8931872..79f480c 100644 --- a/gen/solver_options.jl +++ b/gen/solver_options.jl @@ -15,7 +15,9 @@ function generate_options(io, structname, prefix) k = String(match(r, string(n))[1]) print(io, " $(first ? "" : "else")if k === :$(k)") println(io) - if nargs == 1 + if k == "Precond" + println(io, " Internals.set_precond(s, v)") + elseif nargs == 1 println(io, " @check ", n, "(solver)") elseif nargs == 2 println(io, " @check ", n, "(solver, v)") @@ -37,4 +39,5 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing") generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet") + generate_options(io, "PCG", "HYPRE_PCGSet") end diff --git a/src/Internals.jl b/src/Internals.jl index 20e17a9..ec43ace 100644 --- a/src/Internals.jl +++ b/src/Internals.jl @@ -13,5 +13,6 @@ function assemble_vector end function set_options end function solve_func end function setup_func end +function set_precond end end # module Internals diff --git a/src/solver_options.jl b/src/solver_options.jl index c1170c2..4c30ea8 100644 --- a/src/solver_options.jl +++ b/src/solver_options.jl @@ -258,3 +258,38 @@ function Internals.set_options(s::BoomerAMG, kwargs) end end end + +function Internals.set_options(s::PCG, kwargs) + solver = s.solver + for (k, v) in kwargs + if k === :AbsoluteTol + @check HYPRE_PCGSetAbsoluteTol(solver, v) + elseif k === :AbsoluteTolFactor + @check HYPRE_PCGSetAbsoluteTolFactor(solver, v) + elseif k === :ConvergenceFactorTol + @check HYPRE_PCGSetConvergenceFactorTol(solver, v) + elseif k === :Logging + @check HYPRE_PCGSetLogging(solver, v) + elseif k === :MaxIter + @check HYPRE_PCGSetMaxIter(solver, v) + elseif k === :Precond + Internals.set_precond(s, v) + elseif k === :PrintLevel + @check HYPRE_PCGSetPrintLevel(solver, v) + elseif k === :RecomputeResidual + @check HYPRE_PCGSetRecomputeResidual(solver, v) + elseif k === :RecomputeResidualP + @check HYPRE_PCGSetRecomputeResidualP(solver, v) + elseif k === :RelChange + @check HYPRE_PCGSetRelChange(solver, v) + elseif k === :ResidualTol + @check HYPRE_PCGSetResidualTol(solver, v) + elseif k === :StopCrit + @check HYPRE_PCGSetStopCrit(solver, v) + elseif k === :Tol + @check HYPRE_PCGSetTol(solver, v) + elseif k === :TwoNorm + @check HYPRE_PCGSetTwoNorm(solver, v) + end + end +end diff --git a/src/solvers.jl b/src/solvers.jl index 188ae26..7e8f78c 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -34,3 +34,39 @@ end Internals.solve_func(::BoomerAMG) = HYPRE_BoomerAMGSolve Internals.setup_func(::BoomerAMG) = HYPRE_BoomerAMGSetup + + +############### +# (ParCSR)PCG # +############### + +mutable struct PCG <: HYPRESolver + solver::HYPRE_Solver + function PCG(comm::MPI.Comm=MPI.COMM_WORLD; kwargs...) + solver = new(C_NULL) + solver_ref = Ref{HYPRE_Solver}(C_NULL) + @check HYPRE_ParCSRPCGCreate(comm, solver_ref) + solver.solver = solver_ref[] + # Attach a finalizer + finalizer(x -> HYPRE_ParCSRPCGDestroy(x.solver), solver) + # Set the options + Internals.set_options(solver, kwargs) + return solver + end +end + +const ParCSRPCG = PCG + +function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + @check HYPRE_ParCSRPCGSetup(pcg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + @check HYPRE_ParCSRPCGSolve(pcg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + return x +end + +Internals.solve_func(::PCG) = HYPRE_ParCSRPCGSolve +Internals.setup_func(::PCG) = HYPRE_ParCSRPCGSetup + +function Internals.set_precond(pcg::PCG, p::HYPRESolver) + @check HYPRE_PCGSetPrecond(pcg.solver, Internals.solve_func(p), Internals.setup_func(p), p.solver) + return nothing +end diff --git a/test/runtests.jl b/test/runtests.jl index 1053b6d..bf3cee1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -265,3 +265,28 @@ end # Test result with direct solver @test x ≈ A \ b atol=tol end + +@testset "(ParCSR)PCG" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + ilower, iupper = 1, size(A, 1) + A_h = HYPREMatrix(A, ilower, iupper) + b_h = HYPREVector(b, ilower, iupper) + x_h = HYPREVector(b, ilower, iupper) + # Solve + tol = 1e-9 + pcg = HYPRE.PCG(; Tol = tol) + HYPRE.solve!(pcg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Solve with AMG preconditioner + precond = HYPRE.BoomerAMG() + pcg = HYPRE.PCG(; Tol = tol, Precond = precond) + HYPRE.solve!(pcg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol +end