mirror of https://github.com/fredrikekre/HYPRE.jl
6 changed files with 107 additions and 88 deletions
@ -0,0 +1,86 @@
@@ -0,0 +1,86 @@
|
||||
module HYPRESparseArrays |
||||
|
||||
using HYPRE.LibHYPRE: @check, HYPRE_BigInt, HYPRE_Complex, HYPRE_Int |
||||
using HYPRE: |
||||
HYPRE, HYPREMatrix, HYPRESolver, HYPREVector, HYPRE_IJMatrixSetValues, Internals |
||||
using MPI: MPI |
||||
using SparseArrays: SparseArrays, SparseMatrixCSC, nonzeros, nzrange, rowvals |
||||
|
||||
################################## |
||||
# SparseMatrixCSC -> HYPREMatrix # |
||||
################################## |
||||
|
||||
function Internals.to_hypre_data(A::SparseMatrixCSC, ilower, iupper) |
||||
Internals.check_n_rows(A, ilower, iupper) |
||||
nnz = SparseArrays.nnz(A) |
||||
A_rows = rowvals(A) |
||||
A_vals = nonzeros(A) |
||||
|
||||
# Initialize the data buffers HYPRE wants |
||||
nrows = HYPRE_Int(iupper - ilower + 1) # Total number of rows |
||||
ncols = zeros(HYPRE_Int, nrows) # Number of colums for each row |
||||
rows = collect(HYPRE_BigInt, ilower:iupper) # The row indices |
||||
cols = Vector{HYPRE_BigInt}(undef, nnz) # The column indices |
||||
values = Vector{HYPRE_Complex}(undef, nnz) # The values |
||||
|
||||
# First pass to count nnz per row |
||||
@inbounds for j in 1:size(A, 2) |
||||
for i in nzrange(A, j) |
||||
row = A_rows[i] |
||||
ncols[row] += 1 |
||||
end |
||||
end |
||||
|
||||
# Keep track of the last index used for every row |
||||
lastinds = zeros(Int, nrows) |
||||
cumsum!((@view lastinds[2:end]), (@view ncols[1:end-1])) |
||||
|
||||
# Second pass to populate the output |
||||
@inbounds for j in 1:size(A, 2) |
||||
for i in nzrange(A, j) |
||||
row = A_rows[i] |
||||
k = lastinds[row] += 1 |
||||
val = A_vals[i] |
||||
cols[k] = j |
||||
values[k] = val |
||||
end |
||||
end |
||||
@assert nrows == length(ncols) == length(rows) |
||||
return nrows, ncols, rows, cols, values |
||||
end |
||||
|
||||
# Note: keep in sync with the SparseMatrixCSR method |
||||
function HYPRE.HYPREMatrix(comm::MPI.Comm, B::SparseMatrixCSC, ilower, iupper) |
||||
A = HYPREMatrix(comm, ilower, iupper) |
||||
nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) |
||||
@check HYPRE_IJMatrixSetValues(A, nrows, ncols, rows, cols, values) |
||||
Internals.assemble_matrix(A) |
||||
return A |
||||
end |
||||
|
||||
# Note: keep in sync with the SparseMatrixCSC method |
||||
function HYPRE.HYPREMatrix(B::SparseMatrixCSC, ilower=1, iupper=size(B, 1)) |
||||
return HYPREMatrix(MPI.COMM_SELF, B, ilower, iupper) |
||||
end |
||||
|
||||
|
||||
#################################### |
||||
# SparseMatrixCSC solver interface # |
||||
#################################### |
||||
|
||||
# Note: keep in sync with the SparseMatrixCSR method |
||||
function HYPRE.solve(solver::HYPRESolver, A::SparseMatrixCSC, b::Vector) |
||||
hypre_x = HYPRE.solve(solver, HYPREMatrix(A), HYPREVector(b)) |
||||
x = copy!(similar(b, HYPRE_Complex), hypre_x) |
||||
return x |
||||
end |
||||
|
||||
# Note: keep in sync with the SparseMatrixCSR method |
||||
function HYPRE.solve!(solver::HYPRESolver, x::Vector, A::SparseMatrixCSC, b::Vector) |
||||
hypre_x = HYPREVector(x) |
||||
HYPRE.solve!(solver, hypre_x, HYPREMatrix(A), HYPREVector(b)) |
||||
copy!(x, hypre_x) |
||||
return x |
||||
end |
||||
|
||||
end # module HYPRESparseMatricesCSR |
||||
Loading…
Reference in new issue