diff --git a/Project.toml b/Project.toml index fa6d2c1..6dfcf67 100644 --- a/Project.toml +++ b/Project.toml @@ -7,3 +7,5 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" HYPRE_jll = "0a602bbd-b08b-5d75-8d32-0de6eef44785" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1" diff --git a/src/HYPRE.jl b/src/HYPRE.jl index f6049bc..abc7741 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -61,4 +61,153 @@ module LibHYPRE end end +using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, rowvals, nonzeros, nzrange +using SparseMatricesCSR: SparseMatrixCSR, colvals +using MPI: MPI +using .LibHYPRE + +module Internals + function check_n_rows end + function to_hypre_data end +end + +function Internals.check_n_rows(A, ilower, iupper) + if size(A, 1) != (iupper - ilower + 1) + throw(ArgumentError("number of rows in matrix does not match global start/end rows ilower and iupper")) + end +end + +function Internals.to_hypre_data(A::SparseMatrixCSC, ilower, iupper) + Internals.check_n_rows(A, ilower, iupper) + nnz = SparseArrays.nnz(A) + A_rows = rowvals(A) + A_vals = nonzeros(A) + + # Initialize data as HYPRE expects + nrows = HYPRE_Int(iupper - ilower + 1) # Total number of rows + ncols = zeros(HYPRE_Int, nrows) # Number of colums for each row + rows = collect(HYPRE_BigInt, ilower:iupper) # The row indices + cols = Vector{HYPRE_BigInt}(undef, nnz) # The column indices + values = Vector{HYPRE_Complex}(undef, nnz) # The values + + # First pass to count nnz per row + @inbounds for j in 1:size(A, 2) + for i in nzrange(A, j) + row = A_rows[i] + ncols[row] += 1 + end + end + + # Keep track of the last index used for every row + lastinds = zeros(Int, nrows) + cumsum!((@view lastinds[2:end]), (@view ncols[1:end-1])) + + # Second pass to populate the output + @inbounds for j in 1:size(A, 2) + for i in nzrange(A, j) + row = A_rows[i] + k = lastinds[row] += 1 + val = A_vals[i] + cols[k] = j + values[k] = val + end + end + return nrows, ncols, rows, cols, values +end + +function Internals.to_hypre_data(A::SparseMatrixCSR, ilower, iupper) + Internals.check_n_rows(A, ilower, iupper) + nnz = SparseArrays.nnz(A) + A_cols = colvals(A) + A_vals = nonzeros(A) + + # Initialize data as HYPRE expects + nrows = HYPRE_Int(iupper - ilower + 1) # Total number of rows + ncols = Vector{HYPRE_Int}(undef, nrows) # Number of colums for each row + rows = collect(HYPRE_BigInt, ilower:iupper) # The row indices + cols = Vector{HYPRE_BigInt}(undef, nnz) # The column indices + values = Vector{HYPRE_Complex}(undef, nnz) # The values + + # Loop over the rows and collect all values + k = 0 + @inbounds for i in 1:size(A, 1) + nzr = nzrange(A, i) + ncols[i] = length(nzr) + for j in nzr + k += 1 + col = A_cols[j] + val = A_vals[j] + cols[k] = col + values[k] = val + end + end + @assert nnz == k + return nrows, ncols, rows, cols, values +end + +mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex} + IJMatrix::HYPRE_IJMatrix + ParCSRMatrix::HYPRE_ParCSRMatrix + HYPREMatrix() = new(C_NULL, C_NULL) +end + +function HYPREMatrix(B::Union{SparseMatrixCSC,SparseMatrixCSR}, ilower, iupper, comm::MPI.Comm=MPI.COMM_WORLD) + # Compute indices/values in the format SetValues expect + nrows, ncols, rows, cols, values = Internals.to_hypre_data(B, ilower, iupper) + # Create the IJ matrix + A = HYPREMatrix() + IJMatrixRef = Ref{HYPRE_IJMatrix}(C_NULL) + HYPRE_IJMatrixCreate(comm, ilower, iupper, ilower, iupper, IJMatrixRef) + A.IJMatrix = IJMatrixRef[] + # Attach a finalizer + finalizer(x -> HYPRE_IJMatrixDestroy(x.IJMatrix), A) + # Set storage type + HYPRE_IJMatrixSetObjectType(A.IJMatrix, HYPRE_PARCSR) + # Initialize to make ready for setting values + HYPRE_IJMatrixInitialize(A.IJMatrix) + # Set all the values + HYPRE_IJMatrixSetValues(A.IJMatrix, nrows, ncols, rows, cols, values) + # Finalize + HYPRE_IJMatrixAssemble(A.IJMatrix) + # Fetch the assembled CSR matrix + ParCSRMatrixRef = Ref{Ptr{Cvoid}}(C_NULL) + HYPRE_IJMatrixGetObject(A.IJMatrix, ParCSRMatrixRef) + A.ParCSRMatrix = convert(Ptr{HYPRE_ParCSRMatrix}, ParCSRMatrixRef[]) + return A +end + +mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex} + IJVector::HYPRE_IJVector + ParVector::HYPRE_ParVector + HYPREVector() = new(C_NULL, C_NULL) +end + +function Internals.to_hypre_data(x::Vector, ilower, iupper) + Internals.check_n_rows(x, ilower, iupper) + indices = collect(HYPRE_BigInt, ilower:iupper) + values = convert(Vector{HYPRE_Complex}, x) + return HYPRE_Int(length(indices)), indices, values +end + +function HYPREVector(x::Vector, ilower, iupper, comm=MPI.COMM_WORLD) + nvalues, indices, values = Internals.to_hypre_data(x, ilower, iupper) + b = HYPREVector() + b_ref = Ref{HYPRE_IJVector}(C_NULL) + HYPRE_IJVectorCreate(comm, ilower, iupper, b_ref) + b.IJVector = b_ref[] + finalizer(x -> HYPRE_IJVectorDestroy(x.IJVector), b) # Set storage type + HYPRE_IJVectorSetObjectType(b.IJVector, HYPRE_PARCSR) + # Initialize to make ready for setting values + HYPRE_IJVectorInitialize(b.IJVector) + # Set the values + HYPRE_IJVectorSetValues(b.IJVector, nvalues, indices, values) + # Finalize + HYPRE_IJVectorAssemble(b.IJVector) + # Fetch the assembled object + par_b_ref = Ref{Ptr{Cvoid}}(C_NULL) + HYPRE_IJVectorGetObject(b.IJVector, par_b_ref) + b.ParVector = convert(Ptr{HYPRE_ParVector}, par_b_ref[]) + return b +end + end # module HYPRE