diff --git a/src/Runic.jl b/src/Runic.jl index 8961221..ce88357 100644 --- a/src/Runic.jl +++ b/src/Runic.jl @@ -23,6 +23,11 @@ mutable struct Context # User settings verbose::Bool debug::Bool + # Current state + # node::Union{JuliaSyntax.GreenNode, Nothing} + prev_sibling::Union{JuliaSyntax.GreenNode, Nothing} + next_sibling::Union{JuliaSyntax.GreenNode, Nothing} + # parent::Union{JuliaSyntax.GreenNode, Nothing} end function Context(src_str; debug::Bool = false, verbose::Bool = debug) @@ -30,7 +35,15 @@ function Context(src_str; debug::Bool = false, verbose::Bool = debug) src_tree = JuliaSyntax.parseall(JuliaSyntax.GreenNode, src_str; ignore_warnings=true) fmt_io = IOBuffer() fmt_tree = nothing - return Context(src_str, src_tree, src_io, fmt_io, fmt_tree, verbose, debug) + return Context( + src_str, src_tree, src_io, fmt_io, fmt_tree, verbose, debug, + nothing, nothing, + ) +end + +function next_sibling_kind(ctx::Context)::Union{JuliaSyntax.Kind, Nothing} + next = ctx.next_sibling + return next === nothing ? nothing : JuliaSyntax.kind(next) end # Read the bytes of the current node from the output io @@ -66,16 +79,24 @@ function format_node_with_children!(ctx::Context, node::JuliaSyntax.GreenNode) if !JuliaSyntax.haschildren(node) return node end + # Keep track of the siblings on this stack + prev_sibling = ctx.prev_sibling + next_sibling = ctx.next_sibling + ctx.prev_sibling = nothing + ctx.next_sibling = nothing # @assert JuliaSyntax.haschildren(node) span_sum = 0 original_bytes = node_bytes(ctx, node) # TODO: Read into reusable buffer children = JuliaSyntax.children(node) # The new node parts head′ = JuliaSyntax.head(node) - children′ = () + children′ = children # This aliases until the need to copy below # Keep track of changes; if no child changes the original node can be returned any_child_changed = false for (i, child) in pairs(children) + # Set the siblings: previous from children′, next from children + ctx.prev_sibling = get(children′, i - 1, nothing) + ctx.next_sibling = get(children, i + 1, nothing) child′ = child span_sum += JuliaSyntax.span(child) this_child_changed = false @@ -86,6 +107,8 @@ function format_node_with_children!(ctx::Context, node::JuliaSyntax.GreenNode) child′′ = format_node!(ctx, child′) if child′′ === nullnode this_child_changed = true + # TODO: When this is fixed the sibling setting above needs to handle this + # too error("TODO: handle removed children") elseif child′′ === child′ child′ = child′′ @@ -113,13 +136,17 @@ function format_node_with_children!(ctx::Context, node::JuliaSyntax.GreenNode) end any_child_changed |= this_child_changed if any_child_changed - # Promote children from tuple to array and copy older siblings into it - if children′ === () + # De-alias the children if needed + if children′ === children children′ = eltype(children)[children[j] for j in 1:(i-1)] end push!(children′, child′) end end + # Reset the siblings + ctx.prev_sibling = prev_sibling + ctx.next_sibling = next_sibling + # Return a new node if any of the children changed if any_child_changed span′ = mapreduce(JuliaSyntax.span, +, children′; init=0) return JuliaSyntax.GreenNode(head′, span′, children′) @@ -137,6 +164,12 @@ function format_node!(ctx::Context, node::JuliaSyntax.GreenNode) @assert !JuliaSyntax.haschildren(node) str = String(node_bytes(ctx, node)) str′ = replace(str, r"\h*(\r\n|\r|\n)" => '\n') + # If the next sibling is also a NewlineWs we can trim trailing + # whitespace from this node too + next_kind = next_sibling_kind(ctx) + if next_kind === K"NewlineWs" + str′ = replace(str′, r"(\r\n|\r|\n)\h*" => '\n') + end if str != str′ # Write new bytes and reset the stream nb = write_and_reset(ctx, str′) @@ -393,7 +426,7 @@ function format_tree!(ctx::Context) # Reset IOs so that the offsets are correct seek(ctx.src_io, src_pos) seek(ctx.fmt_io, fmt_pos) - # Keep track of the depth to break out of infinite loops + # Set the root to the current node root′ = root itr = 0 while true diff --git a/test/runtests.jl b/test/runtests.jl index 105e074..374849e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using Test: println(io, " ") # Trailing space on consecutive lines println(io, " ") str = String(take!(io)) - @test_broken format_string(str) == "a = 1\nb = 2\n\n\n" + @test format_string(str) == "a = 1\nb = 2\n\n\n" end @testset "Hex/oct/bin literal integers" begin