From 0fb1326e45592645e1ac88b747fe3cad728b2bec Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Sun, 24 Jul 2022 14:51:46 +0200 Subject: [PATCH] Wrap (ParCSR)GMRES solver. --- gen/solver_options.jl | 1 + src/solver_options.jl | 31 +++++++++++++++++++++++++++++++ src/solvers.jl | 36 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+) diff --git a/gen/solver_options.jl b/gen/solver_options.jl index 8d1768b..05c2736 100644 --- a/gen/solver_options.jl +++ b/gen/solver_options.jl @@ -51,5 +51,6 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing") generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet") + generate_options(io, "GMRES", "HYPRE_ParCSRGMRESSet", "HYPRE_GMRESSet") generate_options(io, "PCG", "HYPRE_ParCSRPCGSet", "HYPRE_PCGSet") end diff --git a/src/solver_options.jl b/src/solver_options.jl index af97578..6c25840 100644 --- a/src/solver_options.jl +++ b/src/solver_options.jl @@ -259,6 +259,37 @@ function Internals.set_options(s::BoomerAMG, kwargs) end end +function Internals.set_options(s::GMRES, kwargs) + solver = s.solver + for (k, v) in kwargs + if k === :ConvergenceFactorTol + @check HYPRE_GMRESSetConvergenceFactorTol(solver, v) + elseif k === :RelChange + @check HYPRE_GMRESSetRelChange(solver, v) + elseif k === :SkipRealResidualCheck + @check HYPRE_GMRESSetSkipRealResidualCheck(solver, v) + elseif k === :AbsoluteTol + @check HYPRE_ParCSRGMRESSetAbsoluteTol(solver, v) + elseif k === :KDim + @check HYPRE_ParCSRGMRESSetKDim(solver, v) + elseif k === :Logging + @check HYPRE_ParCSRGMRESSetLogging(solver, v) + elseif k === :MaxIter + @check HYPRE_ParCSRGMRESSetMaxIter(solver, v) + elseif k === :MinIter + @check HYPRE_ParCSRGMRESSetMinIter(solver, v) + elseif k === :Precond + Internals.set_precond(s, v) + elseif k === :PrintLevel + @check HYPRE_ParCSRGMRESSetPrintLevel(solver, v) + elseif k === :StopCrit + @check HYPRE_ParCSRGMRESSetStopCrit(solver, v) + elseif k === :Tol + @check HYPRE_ParCSRGMRESSetTol(solver, v) + end + end +end + function Internals.set_options(s::PCG, kwargs) solver = s.solver for (k, v) in kwargs diff --git a/src/solvers.jl b/src/solvers.jl index beac282..29549ae 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -81,6 +81,42 @@ Internals.solve_func(::BoomerAMG) = HYPRE_BoomerAMGSolve Internals.setup_func(::BoomerAMG) = HYPRE_BoomerAMGSetup +######### +# GMRES # +######### + +mutable struct GMRES <: HYPRESolver + solver::HYPRE_Solver + function GMRES(comm::MPI.Comm=MPI.COMM_WORLD; kwargs...) + solver = new(C_NULL) + solver_ref = Ref{HYPRE_Solver}(C_NULL) + @check HYPRE_ParCSRGMRESCreate(comm, solver_ref) + solver.solver = solver_ref[] + # Attach a finalizer + finalizer(x -> HYPRE_ParCSRGMRESDestroy(x.solver), solver) + # Set the options + Internals.set_options(solver, kwargs) + return solver + end +end + +function solve!(gmres::GMRES, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + @check HYPRE_ParCSRGMRESSetup(gmres.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + @check HYPRE_ParCSRGMRESSolve(gmres.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + return x +end + +Internals.solve_func(::GMRES) = HYPRE_ParCSRGMRESSetup +Internals.setup_func(::GMRES) = HYPRE_ParCSRGMRESSolve + +function Internals.set_precond(gmres::GMRES, p::HYPRESolver) + solve_f = Internals.solve_func(p) + setup_f = Internals.setup_func(p) + @check HYPRE_ParCSRGMRESSetPrecond(gmres.solver, solve_f, setup_f, p.solver) + return nothing +end + + ############### # (ParCSR)PCG # ############### diff --git a/test/runtests.jl b/test/runtests.jl index 8ab94a0..765e75d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -297,6 +297,40 @@ end @test x ≈ A \ b atol=tol end +@testset "GMRES" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + A_h = HYPREMatrix(A) + b_h = HYPREVector(b) + x_h = HYPREVector(x) + # Solve + tol = 1e-9 + gmres = HYPRE.GMRES(; Tol = tol) + HYPRE.solve!(gmres, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(gmres, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol + + # Solve with preconditioner + precond = HYPRE.BoomerAMG(; MaxIter = 1, Tol = 0.0) + gmres = HYPRE.GMRES(; Tol = tol, Precond = precond) + x_h = HYPREVector(zeros(100)) + HYPRE.solve!(gmres, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(gmres, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol +end + @testset "(ParCSR)PCG" begin # Setup A = sprand(100, 100, 0.05); A = A'A + 5I