From 5370aa77c743ae18473121647adfaaa8425c2e85 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Fri, 22 Jul 2022 19:56:44 +0200 Subject: [PATCH] Julia interface to HYPRE solvers This commits contains the solver interface: A = HYPREMatrix(...) b = HYPREVector(...) x = HYPREVector(...) solver = HYPRESolver(; options...) solve!(solver, x, A, b) where the abstract type HYPRESolver is replaced by a concrete solver implementation (this commit includes the concrete implementation/wrapping of BoomerAMG <: HYPRESolver). Solver settings are passed as keyword arguments to the solver constructor, cf. SetXXX functions in HYPRE. For example, to create a BoomerAMG solver, and setting the tolerance: solver = BoomerAMG(Tol = 1e-7) Keyword argument names correspond directly to the solvers SetXXX function in HYPRE; passing Tol corresponds to HYPRE_BoomerAMGSetTol(solver, 1e-7). --- gen/solver_options.jl | 40 +++++++ src/HYPRE.jl | 4 + src/Internals.jl | 3 + src/solver_options.jl | 260 ++++++++++++++++++++++++++++++++++++++++++ src/solvers.jl | 36 ++++++ test/runtests.jl | 18 +++ 6 files changed, 361 insertions(+) create mode 100644 gen/solver_options.jl create mode 100644 src/solver_options.jl create mode 100644 src/solvers.jl diff --git a/gen/solver_options.jl b/gen/solver_options.jl new file mode 100644 index 0000000..8931872 --- /dev/null +++ b/gen/solver_options.jl @@ -0,0 +1,40 @@ +using HYPRE.LibHYPRE + +function generate_options(io, structname, prefix) + println(io, "") + println(io, "function Internals.set_options(s::$(structname), kwargs)") + println(io, " solver = s.solver") + println(io, " for (k, v) in kwargs") + r = Regex("^" * prefix * "([A-Z].*)\$") + ns = sort!(filter!(x -> occursin(r, string(x)), names(LibHYPRE))) + first = true + for n in ns + m = get(methods(getfield(LibHYPRE, n)), 1, nothing) + m === nothing && continue + nargs = m.nargs - 1 + k = String(match(r, string(n))[1]) + print(io, " $(first ? "" : "else")if k === :$(k)") + println(io) + if nargs == 1 + println(io, " @check ", n, "(solver)") + elseif nargs == 2 + println(io, " @check ", n, "(solver, v)") + else # nargs >= 3 + println(io, " @check ", n, "(solver, v...)") + end + first = false + end + println(io, " end") + println(io, " end") + println(io, "end") +end + +open(joinpath(@__DIR__, "..", "src", "solver_options.jl"), "w") do io + println(io, "# SPDX-License-Identifier: MIT") + println(io, "") + println(io, "# This file is automatically generated by gen/solver_options.jl") + println(io, "") + println(io, "Internals.set_options(::HYPRESolver, kwargs) = nothing") + + generate_options(io, "BoomerAMG", "HYPRE_BoomerAMGSet") +end diff --git a/src/HYPRE.jl b/src/HYPRE.jl index bde2a59..cf7d66f 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -437,4 +437,8 @@ function Base.copy!(v::PVector, h::HYPREVector) return v end +# Solver interface +include("solvers.jl") +include("solver_options.jl") + end # module HYPRE diff --git a/src/Internals.jl b/src/Internals.jl index a7f7ec7..20e17a9 100644 --- a/src/Internals.jl +++ b/src/Internals.jl @@ -10,5 +10,8 @@ function init_matrix end function init_vector end function assemble_matrix end function assemble_vector end +function set_options end +function solve_func end +function setup_func end end # module Internals diff --git a/src/solver_options.jl b/src/solver_options.jl new file mode 100644 index 0000000..c1170c2 --- /dev/null +++ b/src/solver_options.jl @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: MIT + +# This file is automatically generated by gen/solver_options.jl + +Internals.set_options(::HYPRESolver, kwargs) = nothing + +function Internals.set_options(s::BoomerAMG, kwargs) + solver = s.solver + for (k, v) in kwargs + if k === :ADropTol + @check HYPRE_BoomerAMGSetADropTol(solver, v) + elseif k === :ADropType + @check HYPRE_BoomerAMGSetADropType(solver, v) + elseif k === :AddLastLvl + @check HYPRE_BoomerAMGSetAddLastLvl(solver, v) + elseif k === :AddRelaxType + @check HYPRE_BoomerAMGSetAddRelaxType(solver, v) + elseif k === :AddRelaxWt + @check HYPRE_BoomerAMGSetAddRelaxWt(solver, v) + elseif k === :Additive + @check HYPRE_BoomerAMGSetAdditive(solver, v) + elseif k === :AggInterpType + @check HYPRE_BoomerAMGSetAggInterpType(solver, v) + elseif k === :AggNumLevels + @check HYPRE_BoomerAMGSetAggNumLevels(solver, v) + elseif k === :AggP12MaxElmts + @check HYPRE_BoomerAMGSetAggP12MaxElmts(solver, v) + elseif k === :AggP12TruncFactor + @check HYPRE_BoomerAMGSetAggP12TruncFactor(solver, v) + elseif k === :AggPMaxElmts + @check HYPRE_BoomerAMGSetAggPMaxElmts(solver, v) + elseif k === :AggTruncFactor + @check HYPRE_BoomerAMGSetAggTruncFactor(solver, v) + elseif k === :CGCIts + @check HYPRE_BoomerAMGSetCGCIts(solver, v) + elseif k === :CPoints + @check HYPRE_BoomerAMGSetCPoints(solver, v...) + elseif k === :CRRate + @check HYPRE_BoomerAMGSetCRRate(solver, v) + elseif k === :CRStrongTh + @check HYPRE_BoomerAMGSetCRStrongTh(solver, v) + elseif k === :CRUseCG + @check HYPRE_BoomerAMGSetCRUseCG(solver, v) + elseif k === :ChebyEigEst + @check HYPRE_BoomerAMGSetChebyEigEst(solver, v) + elseif k === :ChebyFraction + @check HYPRE_BoomerAMGSetChebyFraction(solver, v) + elseif k === :ChebyOrder + @check HYPRE_BoomerAMGSetChebyOrder(solver, v) + elseif k === :ChebyScale + @check HYPRE_BoomerAMGSetChebyScale(solver, v) + elseif k === :ChebyVariant + @check HYPRE_BoomerAMGSetChebyVariant(solver, v) + elseif k === :CoarsenCutFactor + @check HYPRE_BoomerAMGSetCoarsenCutFactor(solver, v) + elseif k === :CoarsenType + @check HYPRE_BoomerAMGSetCoarsenType(solver, v) + elseif k === :ConvergeType + @check HYPRE_BoomerAMGSetConvergeType(solver, v) + elseif k === :CoordDim + @check HYPRE_BoomerAMGSetCoordDim(solver, v) + elseif k === :Coordinates + @check HYPRE_BoomerAMGSetCoordinates(solver, v) + elseif k === :CpointsToKeep + @check HYPRE_BoomerAMGSetCpointsToKeep(solver, v...) + elseif k === :CycleNumSweeps + @check HYPRE_BoomerAMGSetCycleNumSweeps(solver, v...) + elseif k === :CycleRelaxType + @check HYPRE_BoomerAMGSetCycleRelaxType(solver, v...) + elseif k === :CycleType + @check HYPRE_BoomerAMGSetCycleType(solver, v) + elseif k === :DebugFlag + @check HYPRE_BoomerAMGSetDebugFlag(solver, v) + elseif k === :DofFunc + @check HYPRE_BoomerAMGSetDofFunc(solver, v) + elseif k === :DomainType + @check HYPRE_BoomerAMGSetDomainType(solver, v) + elseif k === :DropTol + @check HYPRE_BoomerAMGSetDropTol(solver, v) + elseif k === :EuBJ + @check HYPRE_BoomerAMGSetEuBJ(solver, v) + elseif k === :EuLevel + @check HYPRE_BoomerAMGSetEuLevel(solver, v) + elseif k === :EuSparseA + @check HYPRE_BoomerAMGSetEuSparseA(solver, v) + elseif k === :EuclidFile + @check HYPRE_BoomerAMGSetEuclidFile(solver, v) + elseif k === :FCycle + @check HYPRE_BoomerAMGSetFCycle(solver, v) + elseif k === :FPoints + @check HYPRE_BoomerAMGSetFPoints(solver, v...) + elseif k === :Filter + @check HYPRE_BoomerAMGSetFilter(solver, v) + elseif k === :FilterThresholdR + @check HYPRE_BoomerAMGSetFilterThresholdR(solver, v) + elseif k === :GMRESSwitchR + @check HYPRE_BoomerAMGSetGMRESSwitchR(solver, v) + elseif k === :GSMG + @check HYPRE_BoomerAMGSetGSMG(solver, v) + elseif k === :GridRelaxPoints + @check HYPRE_BoomerAMGSetGridRelaxPoints(solver, v) + elseif k === :GridRelaxType + @check HYPRE_BoomerAMGSetGridRelaxType(solver, v) + elseif k === :ILUDroptol + @check HYPRE_BoomerAMGSetILUDroptol(solver, v) + elseif k === :ILULevel + @check HYPRE_BoomerAMGSetILULevel(solver, v) + elseif k === :ILUMaxIter + @check HYPRE_BoomerAMGSetILUMaxIter(solver, v) + elseif k === :ILUMaxRowNnz + @check HYPRE_BoomerAMGSetILUMaxRowNnz(solver, v) + elseif k === :ILUType + @check HYPRE_BoomerAMGSetILUType(solver, v) + elseif k === :ISType + @check HYPRE_BoomerAMGSetISType(solver, v) + elseif k === :InterpType + @check HYPRE_BoomerAMGSetInterpType(solver, v) + elseif k === :InterpVecAbsQTrunc + @check HYPRE_BoomerAMGSetInterpVecAbsQTrunc(solver, v) + elseif k === :InterpVecQMax + @check HYPRE_BoomerAMGSetInterpVecQMax(solver, v) + elseif k === :InterpVecVariant + @check HYPRE_BoomerAMGSetInterpVecVariant(solver, v) + elseif k === :InterpVectors + @check HYPRE_BoomerAMGSetInterpVectors(solver, v...) + elseif k === :IsTriangular + @check HYPRE_BoomerAMGSetIsTriangular(solver, v) + elseif k === :IsolatedFPoints + @check HYPRE_BoomerAMGSetIsolatedFPoints(solver, v...) + elseif k === :JacobiTruncThreshold + @check HYPRE_BoomerAMGSetJacobiTruncThreshold(solver, v) + elseif k === :KeepSameSign + @check HYPRE_BoomerAMGSetKeepSameSign(solver, v) + elseif k === :KeepTranspose + @check HYPRE_BoomerAMGSetKeepTranspose(solver, v) + elseif k === :Level + @check HYPRE_BoomerAMGSetLevel(solver, v) + elseif k === :LevelNonGalerkinTol + @check HYPRE_BoomerAMGSetLevelNonGalerkinTol(solver, v...) + elseif k === :LevelOuterWt + @check HYPRE_BoomerAMGSetLevelOuterWt(solver, v...) + elseif k === :LevelRelaxWt + @check HYPRE_BoomerAMGSetLevelRelaxWt(solver, v...) + elseif k === :Logging + @check HYPRE_BoomerAMGSetLogging(solver, v) + elseif k === :MaxCoarseSize + @check HYPRE_BoomerAMGSetMaxCoarseSize(solver, v) + elseif k === :MaxIter + @check HYPRE_BoomerAMGSetMaxIter(solver, v) + elseif k === :MaxLevels + @check HYPRE_BoomerAMGSetMaxLevels(solver, v) + elseif k === :MaxNzPerRow + @check HYPRE_BoomerAMGSetMaxNzPerRow(solver, v) + elseif k === :MaxRowSum + @check HYPRE_BoomerAMGSetMaxRowSum(solver, v) + elseif k === :MeasureType + @check HYPRE_BoomerAMGSetMeasureType(solver, v) + elseif k === :MinCoarseSize + @check HYPRE_BoomerAMGSetMinCoarseSize(solver, v) + elseif k === :MinIter + @check HYPRE_BoomerAMGSetMinIter(solver, v) + elseif k === :ModuleRAP2 + @check HYPRE_BoomerAMGSetModuleRAP2(solver, v) + elseif k === :MultAddPMaxElmts + @check HYPRE_BoomerAMGSetMultAddPMaxElmts(solver, v) + elseif k === :MultAddTruncFactor + @check HYPRE_BoomerAMGSetMultAddTruncFactor(solver, v) + elseif k === :MultAdditive + @check HYPRE_BoomerAMGSetMultAdditive(solver, v) + elseif k === :Nodal + @check HYPRE_BoomerAMGSetNodal(solver, v) + elseif k === :NodalDiag + @check HYPRE_BoomerAMGSetNodalDiag(solver, v) + elseif k === :NonGalerkTol + @check HYPRE_BoomerAMGSetNonGalerkTol(solver, v...) + elseif k === :NonGalerkinTol + @check HYPRE_BoomerAMGSetNonGalerkinTol(solver, v) + elseif k === :NumCRRelaxSteps + @check HYPRE_BoomerAMGSetNumCRRelaxSteps(solver, v) + elseif k === :NumFunctions + @check HYPRE_BoomerAMGSetNumFunctions(solver, v) + elseif k === :NumGridSweeps + @check HYPRE_BoomerAMGSetNumGridSweeps(solver, v) + elseif k === :NumPaths + @check HYPRE_BoomerAMGSetNumPaths(solver, v) + elseif k === :NumSamples + @check HYPRE_BoomerAMGSetNumSamples(solver, v) + elseif k === :NumSweeps + @check HYPRE_BoomerAMGSetNumSweeps(solver, v) + elseif k === :OldDefault + @check HYPRE_BoomerAMGSetOldDefault(solver) + elseif k === :Omega + @check HYPRE_BoomerAMGSetOmega(solver, v) + elseif k === :OuterWt + @check HYPRE_BoomerAMGSetOuterWt(solver, v) + elseif k === :Overlap + @check HYPRE_BoomerAMGSetOverlap(solver, v) + elseif k === :PMaxElmts + @check HYPRE_BoomerAMGSetPMaxElmts(solver, v) + elseif k === :PlotFileName + @check HYPRE_BoomerAMGSetPlotFileName(solver, v) + elseif k === :PlotGrids + @check HYPRE_BoomerAMGSetPlotGrids(solver, v) + elseif k === :PostInterpType + @check HYPRE_BoomerAMGSetPostInterpType(solver, v) + elseif k === :PrintFileName + @check HYPRE_BoomerAMGSetPrintFileName(solver, v) + elseif k === :PrintLevel + @check HYPRE_BoomerAMGSetPrintLevel(solver, v) + elseif k === :RAP2 + @check HYPRE_BoomerAMGSetRAP2(solver, v) + elseif k === :Redundant + @check HYPRE_BoomerAMGSetRedundant(solver, v) + elseif k === :RelaxOrder + @check HYPRE_BoomerAMGSetRelaxOrder(solver, v) + elseif k === :RelaxType + @check HYPRE_BoomerAMGSetRelaxType(solver, v) + elseif k === :RelaxWeight + @check HYPRE_BoomerAMGSetRelaxWeight(solver, v) + elseif k === :RelaxWt + @check HYPRE_BoomerAMGSetRelaxWt(solver, v) + elseif k === :Restriction + @check HYPRE_BoomerAMGSetRestriction(solver, v) + elseif k === :SCommPkgSwitch + @check HYPRE_BoomerAMGSetSCommPkgSwitch(solver, v) + elseif k === :Sabs + @check HYPRE_BoomerAMGSetSabs(solver, v) + elseif k === :SchwarzRlxWeight + @check HYPRE_BoomerAMGSetSchwarzRlxWeight(solver, v) + elseif k === :SchwarzUseNonSymm + @check HYPRE_BoomerAMGSetSchwarzUseNonSymm(solver, v) + elseif k === :SepWeight + @check HYPRE_BoomerAMGSetSepWeight(solver, v) + elseif k === :SeqThreshold + @check HYPRE_BoomerAMGSetSeqThreshold(solver, v) + elseif k === :Simple + @check HYPRE_BoomerAMGSetSimple(solver, v) + elseif k === :SmoothNumLevels + @check HYPRE_BoomerAMGSetSmoothNumLevels(solver, v) + elseif k === :SmoothNumSweeps + @check HYPRE_BoomerAMGSetSmoothNumSweeps(solver, v) + elseif k === :SmoothType + @check HYPRE_BoomerAMGSetSmoothType(solver, v) + elseif k === :StrongThreshold + @check HYPRE_BoomerAMGSetStrongThreshold(solver, v) + elseif k === :StrongThresholdR + @check HYPRE_BoomerAMGSetStrongThresholdR(solver, v) + elseif k === :Sym + @check HYPRE_BoomerAMGSetSym(solver, v) + elseif k === :Threshold + @check HYPRE_BoomerAMGSetThreshold(solver, v) + elseif k === :Tol + @check HYPRE_BoomerAMGSetTol(solver, v) + elseif k === :TruncFactor + @check HYPRE_BoomerAMGSetTruncFactor(solver, v) + elseif k === :Variant + @check HYPRE_BoomerAMGSetVariant(solver, v) + end + end +end diff --git a/src/solvers.jl b/src/solvers.jl new file mode 100644 index 0000000..188ae26 --- /dev/null +++ b/src/solvers.jl @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: MIT + +""" + HYPRESolver + +Abstract super type of all the wrapped HYPRE solvers. +""" +abstract type HYPRESolver end + +############# +# BoomerAMG # +############# + +mutable struct BoomerAMG <: HYPRESolver + solver::HYPRE_Solver + function BoomerAMG(; kwargs...) + solver = new(C_NULL) + solver_ref = Ref{HYPRE_Solver}(C_NULL) + @check HYPRE_BoomerAMGCreate(solver_ref) + solver.solver = solver_ref[] + # Attach a finalizer + finalizer(x -> HYPRE_BoomerAMGDestroy(x.solver), solver) + # Set the options + Internals.set_options(solver, kwargs) + return solver + end +end + +function solve!(amg::BoomerAMG, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + @check HYPRE_BoomerAMGSetup(amg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + @check HYPRE_BoomerAMGSolve(amg.solver, A.ParCSRMatrix, b.ParVector, x.ParVector) + return x +end + +Internals.solve_func(::BoomerAMG) = HYPRE_BoomerAMGSolve +Internals.setup_func(::BoomerAMG) = HYPRE_BoomerAMGSetup diff --git a/test/runtests.jl b/test/runtests.jl index b8b0cc8..1053b6d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -247,3 +247,21 @@ end copy!(pbc, H) @test tomain(copy(pbc)) == tomain(copy(pb)) end + +@testset "BoomerAMG" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + ilower, iupper = 1, size(A, 1) + A_h = HYPREMatrix(A, ilower, iupper) + b_h = HYPREVector(b, ilower, iupper) + x_h = HYPREVector(b, ilower, iupper) + # Solve + tol = 1e-9 + amg = HYPRE.BoomerAMG(; Tol = tol) + HYPRE.solve!(amg, x_h, A_h, b_h) + copy!(x, x_h) + # Test result with direct solver + @test x ≈ A \ b atol=tol +end