Browse Source

Wrap (ParCSR)BiCGSTAB solver.

fe/copyto
Fredrik Ekre 3 years ago
parent
commit
3bea1b5e86
  1. 1
      gen/solver_options.jl
  2. 25
      src/solver_options.jl
  3. 43
      src/solvers.jl
  4. 34
      test/runtests.jl

1
gen/solver_options.jl

@ -50,6 +50,7 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io @@ -50,6 +50,7 @@ open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io
println(io, "")
println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing")
generate_options(io, "BiCGSTAB", "HYPRE_ParCSRBiCGSTABSet", "HYPRE_BiCGSTABSet")
generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet")
generate_options(io, "GMRES", "HYPRE_ParCSRGMRESSet", "HYPRE_GMRESSet")
generate_options(io, "PCG", "HYPRE_ParCSRPCGSet", "HYPRE_PCGSet")

25
src/solver_options.jl

@ -4,6 +4,31 @@ @@ -4,6 +4,31 @@
Internals.set_options(::HYPRESolver, kwargs) = nothing
function Internals.set_options(s::BiCGSTAB, kwargs)
solver = s.solver
for (k, v) in kwargs
if k === :ConvergenceFactorTol
@check HYPRE_BiCGSTABSetConvergenceFactorTol(solver, v)
elseif k === :AbsoluteTol
@check HYPRE_ParCSRBiCGSTABSetAbsoluteTol(solver, v)
elseif k === :Logging
@check HYPRE_ParCSRBiCGSTABSetLogging(solver, v)
elseif k === :MaxIter
@check HYPRE_ParCSRBiCGSTABSetMaxIter(solver, v)
elseif k === :MinIter
@check HYPRE_ParCSRBiCGSTABSetMinIter(solver, v)
elseif k === :Precond
Internals.set_precond(s, v)
elseif k === :PrintLevel
@check HYPRE_ParCSRBiCGSTABSetPrintLevel(solver, v)
elseif k === :StopCrit
@check HYPRE_ParCSRBiCGSTABSetStopCrit(solver, v)
elseif k === :Tol
@check HYPRE_ParCSRBiCGSTABSetTol(solver, v)
end
end
end
function Internals.set_options(s::BoomerAMG, kwargs)
solver = s.solver
for (k, v) in kwargs

43
src/solvers.jl

@ -52,6 +52,49 @@ function solve!(solver::HYPRESolver, x::PVector, A::PSparseMatrix, b::PVector) @@ -52,6 +52,49 @@ function solve!(solver::HYPRESolver, x::PVector, A::PSparseMatrix, b::PVector)
end
#####################################
## Concrete solver implementations ##
#####################################
####################
# (ParCSR)BiCGSTAB #
####################
mutable struct BiCGSTAB <: HYPRESolver
solver::HYPRE_Solver
function BiCGSTAB(comm::MPI.Comm=MPI.COMM_WORLD; kwargs...)
solver = new(C_NULL)
solver_ref = Ref{HYPRE_Solver}(C_NULL)
@check HYPRE_ParCSRBiCGSTABCreate(comm, solver_ref)
solver.solver = solver_ref[]
# Attach a finalizer
finalizer(x -> HYPRE_ParCSRBiCGSTABDestroy(x.solver), solver)
# Set the options
Internals.set_options(solver, kwargs)
return solver
end
end
const ParCSRBiCGSTAB = BiCGSTAB
function solve!(bicg::BiCGSTAB, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_ParCSRBiCGSTABSetup(bicg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector)
@check HYPRE_ParCSRBiCGSTABSolve(bicg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector)
return x
end
Internals.solve_func(::BiCGSTAB) = HYPRE_ParCSRBiCGSTABSolve
Internals.setup_func(::BiCGSTAB) = HYPRE_ParCSRBiCGSTABSetup
function Internals.set_precond(bicg::BiCGSTAB, p::HYPRESolver)
solve_f = Internals.solve_func(p)
setup_f = Internals.setup_func(p)
@check HYPRE_ParCSRBiCGSTABSetPrecond(bicg.solver, solve_f, setup_f, p.solver)
return nothing
end
#############
# BoomerAMG #
#############

34
test/runtests.jl

@ -275,6 +275,40 @@ end @@ -275,6 +275,40 @@ end
@test tomain(pbc) == tomain(pb)
end
@testset "BiCGSTAB" begin
# Setup
A = sprand(100, 100, 0.05); A = A'A + 5I
b = rand(100)
x = zeros(100)
A_h = HYPREMatrix(A)
b_h = HYPREVector(b)
x_h = HYPREVector(x)
# Solve
tol = 1e-9
bicg = HYPRE.BiCGSTAB(; Tol = tol)
HYPRE.solve!(bicg, x_h, A_h, b_h)
copy!(x, x_h)
# Test result with direct solver
@test x A \ b atol=tol
# Test without passing initial guess
x_h = HYPRE.solve(bicg, A_h, b_h)
copy!(x, x_h)
@test x A \ b atol=tol
# Solve with preconditioner
precond = HYPRE.BoomerAMG(; MaxIter = 1, Tol = 0.0)
bicg = HYPRE.BiCGSTAB(; Tol = tol, Precond = precond)
x_h = HYPREVector(zeros(100))
HYPRE.solve!(bicg, x_h, A_h, b_h)
copy!(x, x_h)
# Test result with direct solver
@test x A \ b atol=tol
# Test without passing initial guess
x_h = HYPRE.solve(bicg, A_h, b_h)
copy!(x, x_h)
@test x A \ b atol=tol
end
@testset "BoomerAMG" begin
# Setup
A = sprand(100, 100, 0.05); A = A'A + 5I

Loading…
Cancel
Save