diff --git a/Project.toml b/Project.toml index 4e1c7eb..81c1eee 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "1.1.0" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" HYPRE_jll = "0a602bbd-b08b-5d75-8d32-0de6eef44785" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" PartitionedArrays = "5a9dfac6-5c52-46f7-8278-5e2210713be9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -13,6 +14,7 @@ SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" [compat] CEnum = "0.4" +LinearAlgebra = "1" MPI = "0.19, 0.20" PartitionedArrays = "0.2" SparseMatricesCSR = "0.6" diff --git a/src/HYPRE.jl b/src/HYPRE.jl index 7c973f5..0d0b792 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -5,7 +5,7 @@ module HYPRE using MPI: MPI using PartitionedArrays: IndexRange, MPIData, PSparseMatrix, PVector, PartitionedArrays, SequentialData, map_parts -using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, nonzeros, nzrange, rowvals +using SparseArrays: SparseArrays, AbstractSparseMatrixCSC, SparseMatrixCSC, nnz, nonzeros, nzrange, rowvals using SparseMatricesCSR: SparseMatrixCSR, colvals, getrowptr export HYPREMatrix, HYPREVector @@ -528,4 +528,7 @@ end include("solvers.jl") include("solver_options.jl") +# LinearSolve preconditioner interface +include("precs.jl") + end # module HYPRE diff --git a/src/precs.jl b/src/precs.jl new file mode 100644 index 0000000..f105c88 --- /dev/null +++ b/src/precs.jl @@ -0,0 +1,35 @@ +import LinearAlgebra + +struct BoomerAMGPrecWrapper{MatType} + P::HYPRE.BoomerAMG + A::MatType +end + +function LinearAlgebra.ldiv!(y::AbstractVector, prec::BoomerAMGPrecWrapper, x::AbstractVector) + fill!(y, eltype(y)(0.0)) + HYPRE.solve!(prec.P, y, prec.A, x) +end + +""" + BoomerAMGPrecBuilder(settings_fun; kwargs...) +""" +struct BoomerAMGPrecBuilder{SFun, Tk} + settings_fun!::SFun + kwargs::Tk +end + +# Syntactic sugar wth some defaults +function BoomerAMGPrecBuilder(settings_fun! = (amg, A, p) -> nothing; MaxIter = 1, Tol = 0.0, kwargs...) + return construct_boomeramg_prec_builder(settings_fun!; MaxIter, Tol, kwargs) +end + +# Helper to package kwargs +function construct_boomeramg_prec_builder(settings_fun!; kwargs...) + return BoomerAMGPrecBuilder(settings_fun!, kwargs) +end + +function (b::BoomerAMGPrecBuilder)(A::AbstractSparseMatrixCSC, p) + amg = HYPRE.BoomerAMG(; b.kwargs) + settings_fun!(amg, A, p) + return (BoomerAMGPrecWrapper(amg, A), I) +end diff --git a/test/runtests.jl b/test/runtests.jl index b6baa88..5d7ad5a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using PartitionedArrays using SparseArrays using SparseMatricesCSR using Test +using LinearSolve # Init HYPRE and MPI HYPRE.Init() @@ -18,6 +19,28 @@ HYPRE.Init() @test LibHYPRE.VERSION.major == 2 end +@testset "use as LinearSolve.jl preconditioner" begin + # Setup + A = sprand(100, 100, 0.05); A = A'A + 5I + b = rand(100) + x = zeros(100) + # Solve + tol = 1e-9 + function set_debug_printlevel(amg, A, p) + HYPRE.HYPRE_BoomerAMGSetPrintLevel(amg, 3) + end + bamg = HYPRE.BoomerAMGPrecBuilder( + (amg, A, p) -> nothing; + MaxIter = 1, + Tol = tol, + ) + prob = LinearProblem(A, b) + solver = KrylovJL_CG(precs = bamg) + x = solve(prob, solver, atol=1.0e-14) + @test x ≈ A \ b atol=tol +end + + @testset "HYPREMatrix" begin H = HYPREMatrix(MPI.COMM_WORLD, 1, 5) @test H.ijmatrix != HYPRE_IJMatrix(C_NULL) @@ -654,3 +677,4 @@ end x = HYPRE.solve(pcg, A, b) @test x ≈ A \ b atol=tol end +