From ea1588c085b923f279f4d1765d41fe0022d199f0 Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Wed, 1 Feb 2023 16:36:06 +0100 Subject: [PATCH] Make HYPRE(Vector|Matrix) subtype of AbstractArray. --- src/HYPRE.jl | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/HYPRE.jl b/src/HYPRE.jl index dcaa12e..5cf34af 100644 --- a/src/HYPRE.jl +++ b/src/HYPRE.jl @@ -53,12 +53,13 @@ end # HYPREMatrix # ############### -mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex} +mutable struct HYPREMatrix <: AbstractMatrix{HYPRE_Complex} #= const =# comm::MPI.Comm #= const =# ilower::HYPRE_BigInt #= const =# iupper::HYPRE_BigInt #= const =# jlower::HYPRE_BigInt #= const =# jupper::HYPRE_BigInt + #= const =# n::HYPRE_BigInt ijmatrix::HYPRE_IJMatrix parmatrix::HYPRE_ParCSRMatrix end @@ -71,8 +72,10 @@ Base.unsafe_convert(::Type{HYPRE_ParCSRMatrix}, A::HYPREMatrix) = A.parmatrix function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, jlower::Integer=ilower, jupper::Integer=iupper) + # Compute total size (assumes square matrix) + n = MPI.Allreduce(HYPRE_BigInt(iupper - ilower + 1), +, comm) # Create the IJ matrix - A = HYPREMatrix(comm, ilower, iupper, jlower, jupper, C_NULL, C_NULL) + A = HYPREMatrix(comm, ilower, iupper, jlower, jupper, n, C_NULL, C_NULL) ijmatrix_ref = Ref{HYPRE_IJMatrix}(C_NULL) @check HYPRE_IJMatrixCreate(comm, ilower, iupper, ilower, iupper, ijmatrix_ref) A.ijmatrix = ijmatrix_ref[] @@ -107,10 +110,11 @@ end # HYPREVector # ############### -mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex} +mutable struct HYPREVector <: AbstractVector{HYPRE_Complex} #= const =# comm::MPI.Comm #= const =# ilower::HYPRE_BigInt #= const =# iupper::HYPRE_BigInt + #= const =# n::HYPRE_BigInt ijvector::HYPRE_IJVector parvector::HYPRE_ParVector end @@ -121,8 +125,10 @@ Base.unsafe_convert(::Type{HYPRE_IJVector}, b::HYPREVector) = b.ijvector Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) + # Compute total length + n = MPI.Allreduce(HYPRE_BigInt(iupper - ilower + 1), +, comm) # Create the IJ vector - b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL) + b = HYPREVector(comm, ilower, iupper, n, C_NULL, C_NULL) ijvector_ref = Ref{HYPRE_IJVector}(C_NULL) @check HYPRE_IJVectorCreate(comm, ilower, iupper, ijvector_ref) b.ijvector = ijvector_ref[] @@ -186,6 +192,30 @@ function Base.zero(b::HYPREVector) return x end +########################### +# AbstractArray interface # +########################### + +# Need to define hash since HYPRE(Matrix|Vector) are stored in the internal WeakKeyDict +function Base.hash(A::Union{HYPREVector, HYPREMatrix}, h::UInt) + return invoke(hash, Tuple{Any, UInt}, A, h) +end + +# Define some useful methods to pretend to be <: AbstractArray +function Base.eltype(::Union{HYPREMatrix, Type{HYPREMatrix}, HYPREVector, Type{HYPREVector}}) + return HYPRE_Complex +end +Base.size(b::HYPREVector) = (b.n,) +Base.size(A::HYPREMatrix) = (A.n, A.n) + +Base.show(io::IO, ::MIME"text/plain", A::Union{HYPREVector,HYPREMatrix}) = show(io, A) +function Base.show(io::IO, b::HYPREVector) + print(io, "$(b.n)-element HYPREVector{$(eltype(b))} with local indices $(b.ilower):$(b.iupper)") +end +function Base.show(io::IO, b::HYPREMatrix) + print(io, "$(b.n)×$(b.n) HYPREMatrix{$(eltype(b))} with local rows $(b.ilower):$(b.iupper)") +end + ###################################### # SparseMatrixCS(C|R) -> HYPREMatrix # ######################################