diff --git a/gen/solver_options.jl b/gen/solver_options.jl index 05c2736..11c8194 100644 --- a/gen/solver_options.jl +++ b/gen/solver_options.jl @@ -50,6 +50,7 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io println(io, "") println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing") + generate_options(io, "BiCGSTAB", "HYPRE_ParCSRBiCGSTABSet", "HYPRE_BiCGSTABSet") generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet") generate_options(io, "GMRES", "HYPRE_ParCSRGMRESSet", "HYPRE_GMRESSet") generate_options(io, "PCG", "HYPRE_ParCSRPCGSet", "HYPRE_PCGSet") diff --git a/src/solver_options.jl b/src/solver_options.jl index 6c25840..0874ce1 100644 --- a/src/solver_options.jl +++ b/src/solver_options.jl @@ -4,6 +4,31 @@ Internals.set_options(::HYPRESolver, kwargs) = nothing +function Internals.set_options(s::BiCGSTAB, kwargs) + solver = s.solver + for (k, v) in kwargs + if k === :ConvergenceFactorTol + @check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v) + elseif k === :AbsoluteTol + @check HYPRE_ParCSRBiCGSTABSetAbsoluteTol(solver, v) + elseif k === :Logging + @check HYPRE_ParCSRBiCGSTABSetLogging(solver, v) + elseif k === :MaxIter + @check HYPRE_ParCSRBiCGSTABSetMaxIter(solver, v) + elseif k === :MinIter + @check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v) + elseif k === :Precond + Internals.set_precond(s, v) + elseif k === :PrintLevel + @check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v) + elseif k === :StopCrit + @check HYPRE_ParCSRBiCGSTABSetStopCrit(solver, v) + elseif k === :Tol + @check HYPRE_ParCSRBiCGSTABSetTol(solver, v) + end + end +end + function Internals.set_options(s::BoomerAMG, kwargs) solver = s.solver for (k, v) in kwargs diff --git a/src/solvers.jl b/src/solvers.jl index 29549ae..e134981 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -52,6 +52,49 @@ function solve!(solver::HYPRESolver, x::PVector, A::PSparseMatrix, b::PVector) end +##################################### +## Concrete solver implementations ## +##################################### + + +#################### +# (ParCSR)BiCGSTAB # +#################### + +mutable struct BiCGSTAB <: HYPRESolver + solver::HYPRE_Solver + function BiCGSTAB(comm::MPI.Comm=MPI.COMM_WORLD; kwargs...) + solver = new(C_NULL) + solver_ref = Ref{HYPRE_Solver}(C_NULL) + @check HYPRE_ParCSRBiCGSTABCreate(comm, solver_ref) + solver.solver = solver_ref[] + # Attach a finalizer + finalizer(x -> HYPRE_ParCSRBiCGSTABDestroy(x.solver), solver) + # Set the options + Internals.set_options(solver, kwargs) + return solver + end +end + +const ParCSRBiCGSTAB = BiCGSTAB + +function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + @check HYPRE_ParCSRBiCGSTABSetup(bicg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + @check HYPRE_ParCSRBiCGSTABSolve(bicg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + return x +end + +Internals.solve_func(::BiCGSTAB) = HYPRE_ParCSRBiCGSTABSolve +Internals.setup_func(::BiCGSTAB) = HYPRE_ParCSRBiCGSTABSetup + +function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver) + solve_f = Internals.solve_func(p) + setup_f = Internals.setup_func(p) + @check HYPRE_ParCSRBiCGSTABSetPrecond(bicg.solver, solve_f, setup_f, p.solver) + return nothing +end + + ############# # BoomerAMG # ############# diff --git a/test/runtests.jl b/test/runtests.jl index 765e75d..199033c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -275,6 +275,40 @@ end @test tomain(pbc) == tomain(pb) end +@testset "BiCGSTAB" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + A_h = HYPREMatrix(A) + b_h = HYPREVector(b) + x_h = HYPREVector(x) + # Solve + tol = 1e-9 + bicg = HYPRE.BiCGSTAB(; Tol = tol) + HYPRE.solve!(bicg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(bicg, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol + + # Solve with preconditioner + precond = HYPRE.BoomerAMG(; MaxIter = 1, Tol = 0.0) + bicg = HYPRE.BiCGSTAB(; Tol = tol, Precond = precond) + x_h = HYPREVector(zeros(100)) + HYPRE.solve!(bicg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(bicg, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol +end + @testset "BoomerAMG" begin # Setup A = sprand(100, 100, 0.05); A = A'A + 5I