From 8dcedf68305b7f6d6232a4ce6e2e2bf135459760 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Fri, 22 Jul 2022 20:19:34 +0200 Subject: [PATCH] Wrap (ParCSR)PCG solver. Settings are passed as keyword arguments, just like BoomerAMG. The Precond argument (corresponding to PCGSetPrecond) is handled separately, and lets you pass another solver directly, instead of the solver pointer, the setup and solve functions, as in the SetPrecond C function. Example: precond = BoomerAMG(; options...) solver = PCG(; Precond = precond, options...) --- gen/solver_options.jl | 5 ++++- src/Internals.jl | 1 + src/solver_options.jl | 35 +++++++++++++++++++++++++++++++++++ src/solvers.jl | 36 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 25 +++++++++++++++++++++++++ 5 files changed, 101 insertions(+), 1 deletion(-) diff --git a/gen/solver_options.jl b/gen/solver_options.jl index 8931872..79f480c 100644 --- a/gen/solver_options.jl +++ b/gen/solver_options.jl @@ -15,7 +15,9 @@ function generate_options(io, structname, prefix) k = String(match(r, string(n))[1]) print(io, " $(first ? "" : "else")if k === :$(k)") println(io) - if nargs == 1 + if k == "Precond" + println(io, " Internals.set_precond(s, v)") + elseif nargs == 1 println(io, " @check ", n, "(solver)") elseif nargs == 2 println(io, " @check ", n, "(solver, v)") @@ -37,4 +39,5 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing") generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet") + generate_options(io, "PCG", "HYPRE_PCGSet") end diff --git a/src/Internals.jl b/src/Internals.jl index 20e17a9..ec43ace 100644 --- a/src/Internals.jl +++ b/src/Internals.jl @@ -13,5 +13,6 @@ function assemble_vector end function set_options end function solve_func end function setup_func end +function set_precond end end # module Internals diff --git a/src/solver_options.jl b/src/solver_options.jl index c1170c2..4c30ea8 100644 --- a/src/solver_options.jl +++ b/src/solver_options.jl @@ -258,3 +258,38 @@ function Internals.set_options(s::BoomerAMG, kwargs) end end end + +function Internals.set_options(s::PCG, kwargs) + solver = s.solver + for (k, v) in kwargs + if k === :AbsoluteTol + @check HYPRE_PCGSetAbsoluteTol(solver, v) + elseif k === :AbsoluteTolFactor + @check HYPRE_PCGSetAbsoluteTolFactor(solver, v) + elseif k === :ConvergenceFactorTol + @check HYPRE_PCGSetConvergenceFactorTol(solver, v) + elseif k === :Logging + @check HYPRE_PCGSetLogging(solver, v) + elseif k === :MaxIter + @check HYPRE_PCGSetMaxIter(solver, v) + elseif k === :Precond + Internals.set_precond(s, v) + elseif k === :PrintLevel + @check HYPRE_PCGSetPrintLevel(solver, v) + elseif k === :RecomputeResidual + @check HYPRE_PCGSetRecomputeResidual(solver, v) + elseif k === :RecomputeResidualP + @check HYPRE_PCGSetRecomputeResidualP(solver, v) + elseif k === :RelChange + @check HYPRE_PCGSetRelChange(solver, v) + elseif k === :ResidualTol + @check HYPRE_PCGSetResidualTol(solver, v) + elseif k === :StopCrit + @check HYPRE_PCGSetStopCrit(solver, v) + elseif k === :Tol + @check HYPRE_PCGSetTol(solver, v) + elseif k === :TwoNorm + @check HYPRE_PCGSetTwoNorm(solver, v) + end + end +end diff --git a/src/solvers.jl b/src/solvers.jl index 188ae26..7e8f78c 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -34,3 +34,39 @@ end Internals.solve_func(::BoomerAMG) = HYPRE_BoomerAMGSolve Internals.setup_func(::BoomerAMG) = HYPRE_BoomerAMGSetup + + +############### +# (ParCSR)PCG # +############### + +mutable struct PCG <: HYPRESolver + solver::HYPRE_Solver + function PCG(comm::MPI.Comm=MPI.COMM_WORLD; kwargs...) + solver = new(C_NULL) + solver_ref = Ref{HYPRE_Solver}(C_NULL) + @check HYPRE_ParCSRPCGCreate(comm, solver_ref) + solver.solver = solver_ref[] + # Attach a finalizer + finalizer(x -> HYPRE_ParCSRPCGDestroy(x.solver), solver) + # Set the options + Internals.set_options(solver, kwargs) + return solver + end +end + +const ParCSRPCG = PCG + +function solve!(pcg::PCG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + @check HYPRE_ParCSRPCGSetup(pcg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + @check HYPRE_ParCSRPCGSolve(pcg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + return x +end + +Internals.solve_func(::PCG) = HYPRE_ParCSRPCGSolve +Internals.setup_func(::PCG) = HYPRE_ParCSRPCGSetup + +function Internals.set_precond(pcg::PCG, p::HYPRESolver) + @check HYPRE_PCGSetPrecond(pcg.solver, Internals.solve_func(p), Internals.setup_func(p), p.solver) + return nothing +end diff --git a/test/runtests.jl b/test/runtests.jl index 1053b6d..bf3cee1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -265,3 +265,28 @@ end # Test result with direct solver @test x ≈ A \ b atol=tol end + +@testset "(ParCSR)PCG" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + ilower, iupper = 1, size(A, 1) + A_h = HYPREMatrix(A, ilower, iupper) + b_h = HYPREVector(b, ilower, iupper) + x_h = HYPREVector(b, ilower, iupper) + # Solve + tol = 1e-9 + pcg = HYPRE.PCG(; Tol = tol) + HYPRE.solve!(pcg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Solve with AMG preconditioner + precond = HYPRE.BoomerAMG() + pcg = HYPRE.PCG(; Tol = tol, Precond = precond) + HYPRE.solve!(pcg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol +end