Browse Source

Julia interface to HYPRE solvers

This commits contains the solver interface:

    A = HYPREMatrix(...)
    b = HYPREVector(...)
    x = HYPREVector(...)

    solver = HYPRESolver(; options...)

    solve!(solver, x, A, b)

where the abstract type HYPRESolver is replaced by a concrete solver
implementation (this commit includes the concrete
implementation/wrapping of BoomerAMG <: HYPRESolver).

Solver settings are passed as keyword arguments to the solver
constructor, cf. SetXXX functions in HYPRE.

For example, to create a BoomerAMG solver, and setting the tolerance:

    solver = BoomerAMG(Tol = 1e-7)

Keyword argument names correspond directly to the solvers SetXXX
function in HYPRE; passing Tol corresponds to
HYPRE_BoomerAMGSetTol(solver, 1e-7).
fe/wip
Fredrik Ekre 3 years ago
parent
commit
5370aa77c7
  1. 40
      gen/solver_options.jl
  2. 4
      src/HYPRE.jl
  3. 3
      src/Internals.jl
  4. 260
      src/solver_options.jl
  5. 36
      src/solvers.jl
  6. 18
      test/runtests.jl

40
gen/solver_options.jl

@ -0,0 +1,40 @@
using HYPRE.LibHYPRE
function generate_options(io, structname, prefix)
println(io, "")
println(io, "function Internals.set_options(s::$(structname), kwargs)")
println(io, " solver = s.solver")
println(io, " for (k, v) in kwargs")
r = Regex("^" * prefix * "([A-Z].*)\$")
ns = sort!(filter!(x -> occursin(r, string(x)), names(LibHYPRE)))
first = true
for n in ns
m = get(methods(getfield(LibHYPRE, n)), 1, nothing)
m === nothing && continue
nargs = m.nargs - 1
k = String(match(r, string(n))[1])
print(io, " $(first ? "" : "else")if k === :$(k)")
println(io)
if nargs == 1
println(io, " @check ", n, "(solver)")
elseif nargs == 2
println(io, " @check ", n, "(solver, v)")
else # nargs >= 3
println(io, " @check ", n, "(solver, v...)")
end
first = false
end
println(io, " end")
println(io, " end")
println(io, "end")
end
open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io
println(io, "# SPDX-License-Identifier: MIT")
println(io, "")
println(io, "# This file is automatically generated by gen/solver_options.jl")
println(io, "")
println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing")
generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet")
end

4
src/HYPRE.jl

@ -437,4 +437,8 @@ function Base.copy!(v::PVector, h::HYPREVector)
return v return v
end end
# Solver interface
include("solvers.jl")
include("solver_options.jl")
end # module HYPRE end # module HYPRE

3
src/Internals.jl

@ -10,5 +10,8 @@ function init_matrix end
function init_vector end function init_vector end
function assemble_matrix end function assemble_matrix end
function assemble_vector end function assemble_vector end
function set_options end
function solve_func end
function setup_func end
end # module Internals end # module Internals

260
src/solver_options.jl

@ -0,0 +1,260 @@
# SPDX-License-Identifier: MIT
# This file is automatically generated by gen/solver_options.jl
Internals.set_options(::HYPRESolver, kwargs) = nothing
function Internals.set_options(s::BoomerAMG, kwargs)
solver = s.solver
for (k, v) in kwargs
if k === :ADropTol
@check HYPRE_BoomerAMGSetADropTol(solver, v)
elseif k === :ADropType
@check HYPRE_BoomerAMGSetADropType(solver, v)
elseif k === :AddLastLvl
@check HYPRE_BoomerAMGSetAddLastLvl(solver, v)
elseif k === :AddRelaxType
@check HYPRE_BoomerAMGSetAddRelaxType(solver, v)
elseif k === :AddRelaxWt
@check HYPRE_BoomerAMGSetAddRelaxWt(solver, v)
elseif k === :Additive
@check HYPRE_BoomerAMGSetAdditive(solver, v)
elseif k === :AggInterpType
@check HYPRE_BoomerAMGSetAggInterpType(solver, v)
elseif k === :AggNumLevels
@check HYPRE_BoomerAMGSetAggNumLevels(solver, v)
elseif k === :AggP12MaxElmts
@check HYPRE_BoomerAMGSetAggP12MaxElmts(solver, v)
elseif k === :AggP12TruncFactor
@check HYPRE_BoomerAMGSetAggP12TruncFactor(solver, v)
elseif k === :AggPMaxElmts
@check HYPRE_BoomerAMGSetAggPMaxElmts(solver, v)
elseif k === :AggTruncFactor
@check HYPRE_BoomerAMGSetAggTruncFactor(solver, v)
elseif k === :CGCIts
@check HYPRE_BoomerAMGSetCGCIts(solver, v)
elseif k === :CPoints
@check HYPRE_BoomerAMGSetCPoints(solver, v...)
elseif k === :CRRate
@check HYPRE_BoomerAMGSetCRRate(solver, v)
elseif k === :CRStrongTh
@check HYPRE_BoomerAMGSetCRStrongTh(solver, v)
elseif k === :CRUseCG
@check HYPRE_BoomerAMGSetCRUseCG(solver, v)
elseif k === :ChebyEigEst
@check HYPRE_BoomerAMGSetChebyEigEst(solver, v)
elseif k === :ChebyFraction
@check HYPRE_BoomerAMGSetChebyFraction(solver, v)
elseif k === :ChebyOrder
@check HYPRE_BoomerAMGSetChebyOrder(solver, v)
elseif k === :ChebyScale
@check HYPRE_BoomerAMGSetChebyScale(solver, v)
elseif k === :ChebyVariant
@check HYPRE_BoomerAMGSetChebyVariant(solver, v)
elseif k === :CoarsenCutFactor
@check HYPRE_BoomerAMGSetCoarsenCutFactor(solver, v)
elseif k === :CoarsenType
@check HYPRE_BoomerAMGSetCoarsenType(solver, v)
elseif k === :ConvergeType
@check HYPRE_BoomerAMGSetConvergeType(solver, v)
elseif k === :CoordDim
@check HYPRE_BoomerAMGSetCoordDim(solver, v)
elseif k === :Coordinates
@check HYPRE_BoomerAMGSetCoordinates(solver, v)
elseif k === :CpointsToKeep
@check HYPRE_BoomerAMGSetCpointsToKeep(solver, v...)
elseif k === :CycleNumSweeps
@check HYPRE_BoomerAMGSetCycleNumSweeps(solver, v...)
elseif k === :CycleRelaxType
@check HYPRE_BoomerAMGSetCycleRelaxType(solver, v...)
elseif k === :CycleType
@check HYPRE_BoomerAMGSetCycleType(solver, v)
elseif k === :DebugFlag
@check HYPRE_BoomerAMGSetDebugFlag(solver, v)
elseif k === :DofFunc
@check HYPRE_BoomerAMGSetDofFunc(solver, v)
elseif k === :DomainType
@check HYPRE_BoomerAMGSetDomainType(solver, v)
elseif k === :DropTol
@check HYPRE_BoomerAMGSetDropTol(solver, v)
elseif k === :EuBJ
@check HYPRE_BoomerAMGSetEuBJ(solver, v)
elseif k === :EuLevel
@check HYPRE_BoomerAMGSetEuLevel(solver, v)
elseif k === :EuSparseA
@check HYPRE_BoomerAMGSetEuSparseA(solver, v)
elseif k === :EuclidFile
@check HYPRE_BoomerAMGSetEuclidFile(solver, v)
elseif k === :FCycle
@check HYPRE_BoomerAMGSetFCycle(solver, v)
elseif k === :FPoints
@check HYPRE_BoomerAMGSetFPoints(solver, v...)
elseif k === :Filter
@check HYPRE_BoomerAMGSetFilter(solver, v)
elseif k === :FilterThresholdR
@check HYPRE_BoomerAMGSetFilterThresholdR(solver, v)
elseif k === :GMRESSwitchR
@check HYPRE_BoomerAMGSetGMRESSwitchR(solver, v)
elseif k === :GSMG
@check HYPRE_BoomerAMGSetGSMG(solver, v)
elseif k === :GridRelaxPoints
@check HYPRE_BoomerAMGSetGridRelaxPoints(solver, v)
elseif k === :GridRelaxType
@check HYPRE_BoomerAMGSetGridRelaxType(solver, v)
elseif k === :ILUDroptol
@check HYPRE_BoomerAMGSetILUDroptol(solver, v)
elseif k === :ILULevel
@check HYPRE_BoomerAMGSetILULevel(solver, v)
elseif k === :ILUMaxIter
@check HYPRE_BoomerAMGSetILUMaxIter(solver, v)
elseif k === :ILUMaxRowNnz
@check HYPRE_BoomerAMGSetILUMaxRowNnz(solver, v)
elseif k === :ILUType
@check HYPRE_BoomerAMGSetILUType(solver, v)
elseif k === :ISType
@check HYPRE_BoomerAMGSetISType(solver, v)
elseif k === :InterpType
@check HYPRE_BoomerAMGSetInterpType(solver, v)
elseif k === :InterpVecAbsQTrunc
@check HYPRE_BoomerAMGSetInterpVecAbsQTrunc(solver, v)
elseif k === :InterpVecQMax
@check HYPRE_BoomerAMGSetInterpVecQMax(solver, v)
elseif k === :InterpVecVariant
@check HYPRE_BoomerAMGSetInterpVecVariant(solver, v)
elseif k === :InterpVectors
@check HYPRE_BoomerAMGSetInterpVectors(solver, v...)
elseif k === :IsTriangular
@check HYPRE_BoomerAMGSetIsTriangular(solver, v)
elseif k === :IsolatedFPoints
@check HYPRE_BoomerAMGSetIsolatedFPoints(solver, v...)
elseif k === :JacobiTruncThreshold
@check HYPRE_BoomerAMGSetJacobiTruncThreshold(solver, v)
elseif k === :KeepSameSign
@check HYPRE_BoomerAMGSetKeepSameSign(solver, v)
elseif k === :KeepTranspose
@check HYPRE_BoomerAMGSetKeepTranspose(solver, v)
elseif k === :Level
@check HYPRE_BoomerAMGSetLevel(solver, v)
elseif k === :LevelNonGalerkinTol
@check HYPRE_BoomerAMGSetLevelNonGalerkinTol(solver, v...)
elseif k === :LevelOuterWt
@check HYPRE_BoomerAMGSetLevelOuterWt(solver, v...)
elseif k === :LevelRelaxWt
@check HYPRE_BoomerAMGSetLevelRelaxWt(solver, v...)
elseif k === :Logging
@check HYPRE_BoomerAMGSetLogging(solver, v)
elseif k === :MaxCoarseSize
@check HYPRE_BoomerAMGSetMaxCoarseSize(solver, v)
elseif k === :MaxIter
@check HYPRE_BoomerAMGSetMaxIter(solver, v)
elseif k === :MaxLevels
@check HYPRE_BoomerAMGSetMaxLevels(solver, v)
elseif k === :MaxNzPerRow
@check HYPRE_BoomerAMGSetMaxNzPerRow(solver, v)
elseif k === :MaxRowSum
@check HYPRE_BoomerAMGSetMaxRowSum(solver, v)
elseif k === :MeasureType
@check HYPRE_BoomerAMGSetMeasureType(solver, v)
elseif k === :MinCoarseSize
@check HYPRE_BoomerAMGSetMinCoarseSize(solver, v)
elseif k === :MinIter
@check HYPRE_BoomerAMGSetMinIter(solver, v)
elseif k === :ModuleRAP2
@check HYPRE_BoomerAMGSetModuleRAP2(solver, v)
elseif k === :MultAddPMaxElmts
@check HYPRE_BoomerAMGSetMultAddPMaxElmts(solver, v)
elseif k === :MultAddTruncFactor
@check HYPRE_BoomerAMGSetMultAddTruncFactor(solver, v)
elseif k === :MultAdditive
@check HYPRE_BoomerAMGSetMultAdditive(solver, v)
elseif k === :Nodal
@check HYPRE_BoomerAMGSetNodal(solver, v)
elseif k === :NodalDiag
@check HYPRE_BoomerAMGSetNodalDiag(solver, v)
elseif k === :NonGalerkTol
@check HYPRE_BoomerAMGSetNonGalerkTol(solver, v...)
elseif k === :NonGalerkinTol
@check HYPRE_BoomerAMGSetNonGalerkinTol(solver, v)
elseif k === :NumCRRelaxSteps
@check HYPRE_BoomerAMGSetNumCRRelaxSteps(solver, v)
elseif k === :NumFunctions
@check HYPRE_BoomerAMGSetNumFunctions(solver, v)
elseif k === :NumGridSweeps
@check HYPRE_BoomerAMGSetNumGridSweeps(solver, v)
elseif k === :NumPaths
@check HYPRE_BoomerAMGSetNumPaths(solver, v)
elseif k === :NumSamples
@check HYPRE_BoomerAMGSetNumSamples(solver, v)
elseif k === :NumSweeps
@check HYPRE_BoomerAMGSetNumSweeps(solver, v)
elseif k === :OldDefault
@check HYPRE_BoomerAMGSetOldDefault(solver)
elseif k === :Omega
@check HYPRE_BoomerAMGSetOmega(solver, v)
elseif k === :OuterWt
@check HYPRE_BoomerAMGSetOuterWt(solver, v)
elseif k === :Overlap
@check HYPRE_BoomerAMGSetOverlap(solver, v)
elseif k === :PMaxElmts
@check HYPRE_BoomerAMGSetPMaxElmts(solver, v)
elseif k === :PlotFileName
@check HYPRE_BoomerAMGSetPlotFileName(solver, v)
elseif k === :PlotGrids
@check HYPRE_BoomerAMGSetPlotGrids(solver, v)
elseif k === :PostInterpType
@check HYPRE_BoomerAMGSetPostInterpType(solver, v)
elseif k === :PrintFileName
@check HYPRE_BoomerAMGSetPrintFileName(solver, v)
elseif k === :PrintLevel
@check HYPRE_BoomerAMGSetPrintLevel(solver, v)
elseif k === :RAP2
@check HYPRE_BoomerAMGSetRAP2(solver, v)
elseif k === :Redundant
@check HYPRE_BoomerAMGSetRedundant(solver, v)
elseif k === :RelaxOrder
@check HYPRE_BoomerAMGSetRelaxOrder(solver, v)
elseif k === :RelaxType
@check HYPRE_BoomerAMGSetRelaxType(solver, v)
elseif k === :RelaxWeight
@check HYPRE_BoomerAMGSetRelaxWeight(solver, v)
elseif k === :RelaxWt
@check HYPRE_BoomerAMGSetRelaxWt(solver, v)
elseif k === :Restriction
@check HYPRE_BoomerAMGSetRestriction(solver, v)
elseif k === :SCommPkgSwitch
@check HYPRE_BoomerAMGSetSCommPkgSwitch(solver, v)
elseif k === :Sabs
@check HYPRE_BoomerAMGSetSabs(solver, v)
elseif k === :SchwarzRlxWeight
@check HYPRE_BoomerAMGSetSchwarzRlxWeight(solver, v)
elseif k === :SchwarzUseNonSymm
@check HYPRE_BoomerAMGSetSchwarzUseNonSymm(solver, v)
elseif k === :SepWeight
@check HYPRE_BoomerAMGSetSepWeight(solver, v)
elseif k === :SeqThreshold
@check HYPRE_BoomerAMGSetSeqThreshold(solver, v)
elseif k === :Simple
@check HYPRE_BoomerAMGSetSimple(solver, v)
elseif k === :SmoothNumLevels
@check HYPRE_BoomerAMGSetSmoothNumLevels(solver, v)
elseif k === :SmoothNumSweeps
@check HYPRE_BoomerAMGSetSmoothNumSweeps(solver, v)
elseif k === :SmoothType
@check HYPRE_BoomerAMGSetSmoothType(solver, v)
elseif k === :StrongThreshold
@check HYPRE_BoomerAMGSetStrongThreshold(solver, v)
elseif k === :StrongThresholdR
@check HYPRE_BoomerAMGSetStrongThresholdR(solver, v)
elseif k === :Sym
@check HYPRE_BoomerAMGSetSym(solver, v)
elseif k === :Threshold
@check HYPRE_BoomerAMGSetThreshold(solver, v)
elseif k === :Tol
@check HYPRE_BoomerAMGSetTol(solver, v)
elseif k === :TruncFactor
@check HYPRE_BoomerAMGSetTruncFactor(solver, v)
elseif k === :Variant
@check HYPRE_BoomerAMGSetVariant(solver, v)
end
end
end

36
src/solvers.jl

@ -0,0 +1,36 @@
# SPDX-License-Identifier: MIT
"""
HYPRESolver
Abstract super type of all the wrapped HYPRE solvers.
"""
abstract type HYPRESolver end
#############
# BoomerAMG #
#############
mutable struct BoomerAMG <: HYPRESolver
solver::HYPRE_Solver
function BoomerAMG(; kwargs...)
solver = new(C_NULL)
solver_ref = Ref{HYPRE_Solver}(C_NULL)
@check HYPRE_BoomerAMGCreate(solver_ref)
solver.solver = solver_ref[]
# Attach a finalizer
finalizer(x -> HYPRE_BoomerAMGDestroy(x.solver), solver)
# Set the options
Internals.set_options(solver, kwargs)
return solver
end
end
function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector)
@check HYPRE_BoomerAMGSetup(amg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector)
@check HYPRE_BoomerAMGSolve(amg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector)
return x
end
Internals.solve_func(::BoomerAMG) = HYPRE_BoomerAMGSolve
Internals.setup_func(::BoomerAMG) = HYPRE_BoomerAMGSetup

18
test/runtests.jl

@ -247,3 +247,21 @@ end
copy!(pbc, H) copy!(pbc, H)
@test tomain(copy(pbc)) == tomain(copy(pb)) @test tomain(copy(pbc)) == tomain(copy(pb))
end end
@testset "BoomerAMG" begin
# Setup
A = sprand(100, 100, 0.05); A = A'A + 5I
b = rand(100)
x = zeros(100)
ilower, iupper = 1, size(A, 1)
A_h = HYPREMatrix(A, ilower, iupper)
b_h = HYPREVector(b, ilower, iupper)
x_h = HYPREVector(b, ilower, iupper)
# Solve
tol = 1e-9
amg = HYPRE.BoomerAMG(; Tol = tol)
HYPRE.solve!(amg, x_h, A_h, b_h)
copy!(x, x_h)
# Test result with direct solver
@test x A \ b atol=tol
end

Loading…
Cancel
Save