diff --git a/src/EnumX.jl b/src/EnumX.jl index 2886e94..1c54447 100644 --- a/src/EnumX.jl +++ b/src/EnumX.jl @@ -7,11 +7,21 @@ export @enumx abstract type Enum{T} <: Base.Enum{T} end macro enumx(args...) - return enumx(args...) + return enumx(__module__, args...) end -function enumx(name, args...) - namemap = Dict{Int32,Symbol}() +function enumx(_module_, name, args...) + if name isa Symbol + modname = name + baseT = Int32 + elseif name isa Expr && name.head == :(::) && name.args[1] isa Symbol && length(name.args) == 2 + modname = name.args[1] + baseT = Core.eval(_module_, name.args[2]) + else + throw(ArgumentError("invalid EnumX.@enumx type specification: $(name)")) + end + name = modname + namemap = Dict{baseT,Symbol}() next = 0 for arg in args @assert arg isa Symbol # TODO @@ -19,13 +29,13 @@ function enumx(name, args...) next += 1 end module_block = quote - primitive type Type <: Enum{Int32} 32 end + primitive type Type <: Enum{$(baseT)} $(sizeof(baseT) * 8) end let namemap = $(namemap) check_valid(x) = x in keys(namemap) || - throw(ArgumentError("invalid value $(x) for Enum $($(QuoteNode(name)))")) + throw(ArgumentError("invalid value $(x) for Enum $($(QuoteNode(modname)))")) global function $(esc(:Type))(x::Integer) check_valid(x) - return Base.bitcast($(esc(:Type)), convert(Int32, x)) + return Base.bitcast($(esc(:Type)), convert($(baseT), x)) end Base.Enums.namemap(::Base.Type{$(esc(:Type))}) = namemap Base.Enums.instances(::Base.Type{$(esc(:Type))}) = @@ -37,7 +47,7 @@ function enumx(name, args...) Expr(:const, Expr(:(=), esc(v), Expr(:call, esc(:Type), k))) ) end - return Expr(:toplevel, Expr(:module, false, esc(name), module_block), nothing) + return Expr(:toplevel, Expr(:module, false, esc(modname), module_block), nothing) end function Base.show(io::IO, ::MIME"text/plain", x::E) where E <: Enum diff --git a/test/runtests.jl b/test/runtests.jl index f78ba98..6f92726 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,9 @@ using EnumX, Test +const T16 = Int16 +getInt64() = Int64 + @testset "EnumX" begin # Basic @@ -68,4 +71,29 @@ let io = IOBuffer() @test str == "Fruit.Banana = 1" end + +# Base type specification +@enumx Fruit8::Int8 Apple +@test Fruit8.Type <: EnumX.Enum{Int8} <: Base.Enum{Int8} +@test Base.Enums.basetype(Fruit8.Type) === Int8 +@test Integer(Fruit8.Apple) === Int8(0) + +@enumx Fruit16::T16 Apple +@test Fruit16.Type <: EnumX.Enum{Int16} <: Base.Enum{Int16} +@test Base.Enums.basetype(Fruit16.Type) === Int16 +@test Integer(Fruit16.Apple) === Int16(0) + +@enumx Fruit64::getInt64() Apple +@test Fruit64.Type <: EnumX.Enum{Int64} <: Base.Enum{Int64} +@test Base.Enums.basetype(Fruit64.Type) === Int64 +@test Integer(Fruit64.Apple) == Int64(0) + +try + @macroexpand @enumx (Fr + uit) Apple +catch err + err isa LoadError && (err = err.error) + @test err isa ArgumentError + @test err.msg == "invalid EnumX.@enumx type specification: Fr + uit" +end + end # testset