From f0366fb228a1ed86b36f63f58398477c54440906 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Tue, 11 Oct 2022 16:51:56 +0200 Subject: [PATCH] WIP: Assembly directly to HYPREMatrix. --- src/HYPRE.jl | 72 ++++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 44 +++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/src/HYPRE.jl b/src/HYPRE.jl index 7c973f5..f2846d3 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -524,6 +524,78 @@ function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex}) return dst end + +#################### +## HYPREAssembler ## +#################### + +struct HYPREAssembler + A::HYPREMatrix + b::Union{HYPREVector, Nothing} + ncols::Vector{HYPRE_Int} + rows::Vector{HYPRE_BigInt} + cols::Vector{HYPRE_BigInt} + values::Vector{HYPRE_Complex} +end + +function start_assemble!(A::HYPREMatrix, b::Union{HYPREVector,Nothing}=nothing) + if A.parmatrix != C_NULL + # This matrix have been assembled before, reset to 0 + @check HYPRE_IJMatrixSetConstantValues(A.ijmatrix, 0) + end + @check HYPRE_IJMatrixInitialize(A.ijmatrix) + return HYPREAssembler(A, b, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[]) +end +function assemble!(A::HYPREAssembler, ij::Vector, a::Matrix, ::Union{Vector,Nothing}=nothing) + nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, ij, ij) + @check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values) + return A +end +function assemble!(A::HYPREAssembler, ij::Vector, a::SparseMatrixCSC, ::Union{Vector,Nothing}=nothing) + size(a, 1) == size(a, 2) == length(ij) || error("mismatch in number of rows/cols") + # Reuse single-core functionality and recompute to global rows/cols after + nrows, ncols, rows, cols, values = Internals.to_hypre_data(a, 1, size(a, 1)) + for i in eachindex(rows) + rows[i] = ij[rows[i]] + end + for i in eachindex(cols) + cols[i] = ij[cols[i]] + end + @check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values) + return A +end +function finish_assemble!(A::HYPREAssembler) + Internals.assemble_matrix(A.A) + return nothing +end + +function Internals.to_hypre_data(A::HYPREAssembler, a::Matrix, I::Vector, J::Vector) + size(a, 1) == length(I) || error("mismatching number of rows") + size(a, 2) == length(J) || error("mismatch number of cols") + + nrows = HYPRE_Int(length(I)) + + # Resize cache vectors + ncols = resize!(A.ncols, nrows) + rows = resize!(A.rows, length(a)) + cols = resize!(A.cols, length(a)) + values = resize!(A.values, length(a)) + + # Fill vectors + ncols = fill!(ncols, HYPRE_Int(length(J))) + copyto!(rows, I) + idx = 0 + for i in 1:length(I), j in 1:length(J) + idx += 1 + cols[idx] = J[j] + values[idx] = a[i, j] + end + @assert idx == length(a) + + return nrows, ncols, rows, cols, values +end + + # Solver interface include("solvers.jl") include("solver_options.jl") diff --git a/test/runtests.jl b/test/runtests.jl index b6baa88..1c86f7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using HYPRE using HYPRE.Internals using HYPRE.LibHYPRE +using HYPRE.LibHYPRE: @check using LinearAlgebra using MPI using PartitionedArrays @@ -296,6 +297,49 @@ end @test tomain(pbc) == tomain(pb) end + +function getindex_debug(A::HYPREMatrix, i, j) + nrows = HYPRE_Int(length(i)) + ncols = fill(HYPRE_Int(length(j)), length(i)) + rows = convert(Vector{HYPRE_BigInt}, i) + cols = convert(Vector{HYPRE_BigInt}, repeat(j, length(i))) + values = Vector{HYPRE_Complex}(undef, length(i) * length(j)) + @check HYPRE_IJMatrixGetValues(A.ijmatrix, nrows, ncols, rows, cols, values) + return permutedims(reshape(values, (length(j), length(i)))) +end + +@testset "HYPREMatrix assembler" begin + # Assembly from ::Matrix + A = HYPREMatrix(MPI.COMM_WORLD, 1, 3) + M = zeros(3, 3) + for i in 1:2 + assembler = HYPRE.start_assemble!(A) + fill!(M, 0) + for idx in ([1, 2], [3, 1]) + a = rand(2, 2) + HYPRE.assemble!(assembler, idx, a) + M[idx, idx] += a + end + HYPRE.finish_assemble!(assembler) + @test getindex_debug(A, 1:3, 1:3) == M + end + # Assembly from ::SparseMatrixCSC + A = HYPREMatrix(MPI.COMM_WORLD, 1, 10) + M = zeros(10, 10) + a = sprand(5, 5, 0.2) # Keep outside to avoid sparsity pattern change + for i in 1:2 + @show i + assembler = HYPRE.start_assemble!(A) + fill!(M, 0) + for idx in ([1, 2, 4, 6, 7], [3, 1, 10, 9, 8]) + HYPRE.assemble!(assembler, idx, a) + M[idx, idx] += a + end + HYPRE.finish_assemble!(assembler) + @test getindex_debug(A, 1:10, 1:10) == M + end +end + @testset "BiCGSTAB" begin # Solver constructor and options @test_throws(