Browse Source

WIP: Assembly directly to HYPREMatrix.

Fredrik Ekre 3 years ago
parent
commit
f0366fb228
  1. 72
      src/HYPRE.jl
  2. 44
      test/runtests.jl

72
src/HYPRE.jl

@ -524,6 +524,78 @@ function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex}) @@ -524,6 +524,78 @@ function Base.copy!(dst::HYPREVector, src::PVector{HYPRE_Complex})
return dst
end
####################
## HYPREAssembler ##
####################
struct HYPREAssembler
A::HYPREMatrix
b::Union{HYPREVector, Nothing}
ncols::Vector{HYPRE_Int}
rows::Vector{HYPRE_BigInt}
cols::Vector{HYPRE_BigInt}
values::Vector{HYPRE_Complex}
end
function start_assemble!(A::HYPREMatrix, b::Union{HYPREVector,Nothing}=nothing)
if A.parmatrix != C_NULL
# This matrix have been assembled before, reset to 0
@check HYPRE_IJMatrixSetConstantValues(A.ijmatrix, 0)
end
@check HYPRE_IJMatrixInitialize(A.ijmatrix)
return HYPREAssembler(A, b, HYPRE_Int[], HYPRE_BigInt[], HYPRE_BigInt[], HYPRE_Complex[])
end
function assemble!(A::HYPREAssembler, ij::Vector, a::Matrix, ::Union{Vector,Nothing}=nothing)
nrows, ncols, rows, cols, values = Internals.to_hypre_data(A, a, ij, ij)
@check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values)
return A
end
function assemble!(A::HYPREAssembler, ij::Vector, a::SparseMatrixCSC, ::Union{Vector,Nothing}=nothing)
size(a, 1) == size(a, 2) == length(ij) || error("mismatch in number of rows/cols")
# Reuse single-core functionality and recompute to global rows/cols after
nrows, ncols, rows, cols, values = Internals.to_hypre_data(a, 1, size(a, 1))
for i in eachindex(rows)
rows[i] = ij[rows[i]]
end
for i in eachindex(cols)
cols[i] = ij[cols[i]]
end
@check HYPRE_IJMatrixAddToValues(A.A.ijmatrix, nrows, ncols, rows, cols, values)
return A
end
function finish_assemble!(A::HYPREAssembler)
Internals.assemble_matrix(A.A)
return nothing
end
function Internals.to_hypre_data(A::HYPREAssembler, a::Matrix, I::Vector, J::Vector)
size(a, 1) == length(I) || error("mismatching number of rows")
size(a, 2) == length(J) || error("mismatch number of cols")
nrows = HYPRE_Int(length(I))
# Resize cache vectors
ncols = resize!(A.ncols, nrows)
rows = resize!(A.rows, length(a))
cols = resize!(A.cols, length(a))
values = resize!(A.values, length(a))
# Fill vectors
ncols = fill!(ncols, HYPRE_Int(length(J)))
copyto!(rows, I)
idx = 0
for i in 1:length(I), j in 1:length(J)
idx += 1
cols[idx] = J[j]
values[idx] = a[i, j]
end
@assert idx == length(a)
return nrows, ncols, rows, cols, values
end
# Solver interface
include("solvers.jl")
include("solver_options.jl")

44
test/runtests.jl

@ -3,6 +3,7 @@ @@ -3,6 +3,7 @@
using HYPRE
using HYPRE.Internals
using HYPRE.LibHYPRE
using HYPRE.LibHYPRE: @check
using LinearAlgebra
using MPI
using PartitionedArrays
@ -296,6 +297,49 @@ end @@ -296,6 +297,49 @@ end
@test tomain(pbc) == tomain(pb)
end
function getindex_debug(A::HYPREMatrix, i, j)
nrows = HYPRE_Int(length(i))
ncols = fill(HYPRE_Int(length(j)), length(i))
rows = convert(Vector{HYPRE_BigInt}, i)
cols = convert(Vector{HYPRE_BigInt}, repeat(j, length(i)))
values = Vector{HYPRE_Complex}(undef, length(i) * length(j))
@check HYPRE_IJMatrixGetValues(A.ijmatrix, nrows, ncols, rows, cols, values)
return permutedims(reshape(values, (length(j), length(i))))
end
@testset "HYPREMatrix assembler" begin
# Assembly from ::Matrix
A = HYPREMatrix(MPI.COMM_WORLD, 1, 3)
M = zeros(3, 3)
for i in 1:2
assembler = HYPRE.start_assemble!(A)
fill!(M, 0)
for idx in ([1, 2], [3, 1])
a = rand(2, 2)
HYPRE.assemble!(assembler, idx, a)
M[idx, idx] += a
end
HYPRE.finish_assemble!(assembler)
@test getindex_debug(A, 1:3, 1:3) == M
end
# Assembly from ::SparseMatrixCSC
A = HYPREMatrix(MPI.COMM_WORLD, 1, 10)
M = zeros(10, 10)
a = sprand(5, 5, 0.2) # Keep outside to avoid sparsity pattern change
for i in 1:2
@show i
assembler = HYPRE.start_assemble!(A)
fill!(M, 0)
for idx in ([1, 2, 4, 6, 7], [3, 1, 10, 9, 8])
HYPRE.assemble!(assembler, idx, a)
M[idx, idx] += a
end
HYPRE.finish_assemble!(assembler)
@test getindex_debug(A, 1:10, 1:10) == M
end
end
@testset "BiCGSTAB" begin
# Solver constructor and options
@test_throws(

Loading…
Cancel
Save