Browse Source

Make HYPRE(Vector|Matrix) subtype of AbstractArray.

fe/abstractarray
Fredrik Ekre 3 years ago
parent
commit
ea1588c085
  1. 38
      src/HYPRE.jl

38
src/HYPRE.jl

@ -53,12 +53,13 @@ end
# HYPREMatrix # # HYPREMatrix #
############### ###############
mutable struct HYPREMatrix # <: AbstractMatrix{HYPRE_Complex} mutable struct HYPREMatrix <: AbstractMatrix{HYPRE_Complex}
#= const =# comm::MPI.Comm #= const =# comm::MPI.Comm
#= const =# ilower::HYPRE_BigInt #= const =# ilower::HYPRE_BigInt
#= const =# iupper::HYPRE_BigInt #= const =# iupper::HYPRE_BigInt
#= const =# jlower::HYPRE_BigInt #= const =# jlower::HYPRE_BigInt
#= const =# jupper::HYPRE_BigInt #= const =# jupper::HYPRE_BigInt
#= const =# n::HYPRE_BigInt
ijmatrix::HYPRE_IJMatrix ijmatrix::HYPRE_IJMatrix
parmatrix::HYPRE_ParCSRMatrix parmatrix::HYPRE_ParCSRMatrix
end end
@ -71,8 +72,10 @@ Base.unsafe_convert(::Type{HYPRE_ParCSRMatrix}, A::HYPREMatrix) = A.parmatrix
function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer, function HYPREMatrix(comm::MPI.Comm, ilower::Integer, iupper::Integer,
jlower::Integer=ilower, jupper::Integer=iupper) jlower::Integer=ilower, jupper::Integer=iupper)
# Compute total size (assumes square matrix)
n = MPI.Allreduce(HYPRE_BigInt(iupper - ilower + 1), +, comm)
# Create the IJ matrix # Create the IJ matrix
A = HYPREMatrix(comm, ilower, iupper, jlower, jupper, C_NULL, C_NULL) A = HYPREMatrix(comm, ilower, iupper, jlower, jupper, n, C_NULL, C_NULL)
ijmatrix_ref = Ref{HYPRE_IJMatrix}(C_NULL) ijmatrix_ref = Ref{HYPRE_IJMatrix}(C_NULL)
@check HYPRE_IJMatrixCreate(comm, ilower, iupper, ilower, iupper, ijmatrix_ref) @check HYPRE_IJMatrixCreate(comm, ilower, iupper, ilower, iupper, ijmatrix_ref)
A.ijmatrix = ijmatrix_ref[] A.ijmatrix = ijmatrix_ref[]
@ -107,10 +110,11 @@ end
# HYPREVector # # HYPREVector #
############### ###############
mutable struct HYPREVector # <: AbstractVector{HYPRE_Complex} mutable struct HYPREVector <: AbstractVector{HYPRE_Complex}
#= const =# comm::MPI.Comm #= const =# comm::MPI.Comm
#= const =# ilower::HYPRE_BigInt #= const =# ilower::HYPRE_BigInt
#= const =# iupper::HYPRE_BigInt #= const =# iupper::HYPRE_BigInt
#= const =# n::HYPRE_BigInt
ijvector::HYPRE_IJVector ijvector::HYPRE_IJVector
parvector::HYPRE_ParVector parvector::HYPRE_ParVector
end end
@ -121,8 +125,10 @@ Base.unsafe_convert(::Type{HYPRE_IJVector}, b::HYPREVector) = b.ijvector
Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector Base.unsafe_convert(::Type{HYPRE_ParVector}, b::HYPREVector) = b.parvector
function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer) function HYPREVector(comm::MPI.Comm, ilower::Integer, iupper::Integer)
# Compute total length
n = MPI.Allreduce(HYPRE_BigInt(iupper - ilower + 1), +, comm)
# Create the IJ vector # Create the IJ vector
b = HYPREVector(comm, ilower, iupper, C_NULL, C_NULL) b = HYPREVector(comm, ilower, iupper, n, C_NULL, C_NULL)
ijvector_ref = Ref{HYPRE_IJVector}(C_NULL) ijvector_ref = Ref{HYPRE_IJVector}(C_NULL)
@check HYPRE_IJVectorCreate(comm, ilower, iupper, ijvector_ref) @check HYPRE_IJVectorCreate(comm, ilower, iupper, ijvector_ref)
b.ijvector = ijvector_ref[] b.ijvector = ijvector_ref[]
@ -186,6 +192,30 @@ function Base.zero(b::HYPREVector)
return x return x
end end
###########################
# AbstractArray interface #
###########################
# Need to define hash since HYPRE(Matrix|Vector) are stored in the internal WeakKeyDict
function Base.hash(A::Union{HYPREVector, HYPREMatrix}, h::UInt)
return invoke(hash, Tuple{Any, UInt}, A, h)
end
# Define some useful methods to pretend to be <: AbstractArray
function Base.eltype(::Union{HYPREMatrix, Type{HYPREMatrix}, HYPREVector, Type{HYPREVector}})
return HYPRE_Complex
end
Base.size(b::HYPREVector) = (b.n,)
Base.size(A::HYPREMatrix) = (A.n, A.n)
Base.show(io::IO, ::MIME"text/plain", A::Union{HYPREVector,HYPREMatrix}) = show(io, A)
function Base.show(io::IO, b::HYPREVector)
print(io, "$(b.n)-element HYPREVector{$(eltype(b))} with local indices $(b.ilower):$(b.iupper)")
end
function Base.show(io::IO, b::HYPREMatrix)
print(io, "$(b.n)×$(b.n) HYPREMatrix{$(eltype(b))} with local rows $(b.ilower):$(b.iupper)")
end
###################################### ######################################
# SparseMatrixCS(C|R) -> HYPREMatrix # # SparseMatrixCS(C|R) -> HYPREMatrix #
###################################### ######################################

Loading…
Cancel
Save