diff --git a/src/solvers.jl b/src/solvers.jl index 7e8f78c..9d1ca66 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -7,6 +7,30 @@ Abstract super type of all the wrapped HYPRE solvers. """ abstract type HYPRESolver end +# Generic fallback allocating a zero vector as initial guess +# TODO: This should allocate x using the owned cols instead of rows of A/b, but currently +# it is assumed these are always equivalent. +""" + solve(solver::HYPRESolver, A::HYPREMatrix, b::HYPREVector) + +Solve the linear system `A x = b` using `solver` and return the solution vector. + +This method allocates the initial guess/output vector `x`, initialized to 0. + +See also [`solve!`](@ref). +""" +solve(solver::HYPRESolver, A::HYPREMatrix, b::HYPREVector) = solve!(solver, zero(b), A, b) + +""" + solve!(solver::HYPRESolver, x::HYPREVector, A::HYPREMatrix, b::HYPREVector) + +Solve the linear system `A x = b` using `solver` with `x` as the initial guess. + +See also [`solve`](@ref). +""" +solve!(pcg::HYPRESolver, x::HYPREVector, A::HYPREMatrix, ::HYPREVector) + + ############# # BoomerAMG # ############# diff --git a/test/runtests.jl b/test/runtests.jl index bf3cee1..99b06aa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using SparseMatricesCSR using Test MPI.Init() +HYPRE_Init() @testset "HYPREMatrix" begin H = HYPREMatrix() @@ -264,6 +265,10 @@ end copy!(x, x_h) # Test result with direct solver @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(amg, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol end @testset "(ParCSR)PCG" begin @@ -282,6 +287,10 @@ end copy!(x, x_h) # Test result with direct solver @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(pcg, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol # Solve with AMG preconditioner precond = HYPRE.BoomerAMG() pcg = HYPRE.PCG(; Tol = tol, Precond = precond) @@ -289,4 +298,8 @@ end copy!(x, x_h) # Test result with direct solver @test x ≈ A \ b atol=tol + # Test without passing initial guess + x_h = HYPRE.solve(pcg, A_h, b_h) + copy!(x, x_h) + @test x ≈ A \ b atol=tol end