Reverse rule receiving incorrect values when closure argument captures an iterated variable #2304

Open
danielwe opened this issue Feb 14, 2025 · 3 comments

Comments

@danielwe
Contributor

I was getting incorrect gradients and finally found the culprit. See the following MWE:

using Enzyme

call(f::F, x) where {F} = _call(f, x)
_call(f, x) = f(x)

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfig, ::Const{typeof(call)}, ::Type{<:Active}, f, x
)
    println("forward: f.val = $(repr(f.val))")
    fx = call(f.val, x.val)
    primal = EnzymeRules.needs_primal(config) ? fx : nothing
    shadow = EnzymeRules.needs_shadow(config) ? make_zero(fx) : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
    ::EnzymeRules.RevConfig, ::Const{typeof(call)}, shadow::Active, _, f, x
)
    println("reverse: f.val = $(repr(f.val))")
    fwd, rev = autodiff_thunk(
        ReverseSplitNoPrimal, Const{typeof(_call)}, Active, typeof(f), typeof(x)
    )
    tape, _, _ = fwd(Const(_call), f, x)
    return only(rev(Const(_call), f, x, shadow.val, tape))
end

function mappedclosure(y)
    return map(0.0:0.25:1.0) do z
        call(x -> sum(x .* y .* z), [1.0])
    end
end

yd = Duplicated([1.0], [0.0])
fwd, rev = autodiff_thunk(
    ReverseSplitNoPrimal,
    Const{typeof(mappedclosure)},
    Duplicated,
    typeof(yd),
)

tape, _, shadow = fwd(Const(mappedclosure), yd)
shadow .= 1
rev(Const(mappedclosure), yd, tape)

# output

forward: f.val = var"#2#4"{Float64, Vector{Float64}}(0.0, [1.0])
forward: f.val = var"#2#4"{Float64, Vector{Float64}}(0.25, [1.0])
forward: f.val = var"#2#4"{Float64, Vector{Float64}}(0.5, [1.0])
forward: f.val = var"#2#4"{Float64, Vector{Float64}}(0.75, [1.0])
forward: f.val = var"#2#4"{Float64, Vector{Float64}}(1.0, [1.0])
reverse: f.val = var"#2#4"{Float64, Vector{Float64}}(1.0, [1.0])
reverse: f.val = var"#2#4"{Float64, Vector{Float64}}(1.0, [1.0])
reverse: f.val = var"#2#4"{Float64, Vector{Float64}}(1.0, [1.0])
reverse: f.val = var"#2#4"{Float64, Vector{Float64}}(1.0, [1.0])
reverse: f.val = var"#2#4"{Float64, Vector{Float64}}(0.0, [1.0])

Notice how, in the reverse pass, the same captured value 1.0 is received 4 times in a row. The correct behavior would be for the sequence of values in the reverse pass to mirror those in the forward pass: 1.0, 0.75, 0.5, 0.25, 0.0.
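To make the expected ordering concrete, here is a minimal sketch in plain Julia (no Enzyme involved) of a toy tape: the forward sweep pushes each captured value and the reverse sweep pops them. The names `forward_sweep` and `reverse_sweep` are hypothetical and only illustrate the last-in, first-out order that the reverse pass would be expected to follow.

```julia
# Toy model of a forward/reverse sweep; this is not Enzyme's actual machinery.
function forward_sweep(zs)
    tape = Float64[]
    for z in zs
        push!(tape, z)            # record the value captured in this iteration
    end
    return tape
end

function reverse_sweep(tape)
    seen = Float64[]
    while !isempty(tape)
        push!(seen, pop!(tape))   # the most recent iteration is revisited first
    end
    return seen
end

tape = forward_sweep(0.0:0.25:1.0)
seen = reverse_sweep(tape)        # values in the order a reverse pass should see them
```

With the range from the MWE, `seen` holds the forward values in reverse order, which is the sequence the `reverse:` lines above should have printed.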

Observations

  • The problem goes away if you remove the type annotation in the call definition, that is,
    -call(f::F, x) where {F} = _call(f, x)
    +call(f, x) = _call(f, x)
    Once again, there seems to be an unfortunate interaction with Julia's specialization heuristics.
  • This also happens when mapping over a vector, like collect(0.0:0.25:1.0), but not when mapping over a tuple like (0.0, 0.25, 0.5, 0.75, 1.0), which you can observe by making the following change:
    -    return map(0.0:0.25:1.0) do z
    +    return map((0.0, 0.25, 0.5, 0.75, 1.0)) do z
             call(x -> sum(x .* y .* z), [1.0])
    -    end
    +    end |> collect  # collect so we still have a Duplicated return
  • The MWE creates MixedDuplicated closures, but the bug also appears for Duplicated closures, which you can observe by making the following change:
    -        call(x -> sum(x .* y .* z), [1.0])
    +        t = [z]
    +        call(x -> sum(x .* y .* t), [1.0])
  • The numbers 0.0, 1.0, and 4 are not special; this example just happens to map over a range of length 5 from 0.0 to 1.0. If I use a range of length 17 from 0.27 to 1.23, I get 16 repetitions of 1.23, followed by a final iteration with 0.27.
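The difference between the two closure shapes in the third observation can be inspected without Enzyme. The sketch below uses two hypothetical helpers, `make_scalar_closure` and `make_array_closure`, that mirror the closures from the MWE and lists the types of their captured fields. The first closure captures a `Float64` alongside a `Vector{Float64}`, while the second captures only arrays, which is presumably why Enzyme treats the former as MixedDuplicated and the latter as Duplicated.

```julia
# Hypothetical helpers mirroring the two closure shapes from the MWE.
function make_scalar_closure(y, z)
    return x -> sum(x .* y .* z)   # captures y (a Vector) and z (a Float64)
end

function make_array_closure(y, z)
    t = [z]
    return x -> sum(x .* y .* t)   # captures only arrays: y and t
end

f_scalar = make_scalar_closure([1.0], 0.25)
f_array = make_array_closure([1.0], 0.25)

# Closures are structs whose fields are the captured variables.
# The field order is an implementation detail, so compare as sets.
captured_scalar = Set(fieldtypes(typeof(f_scalar)))
captured_array = Set(fieldtypes(typeof(f_array)))
```

Both closures compute the same value for the same input; only the storage of the captured `z` differs.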
@danielwe changed the title from "Reverse rule receiving incorrect argument values when primal is mapped over a range/vector" to "Reverse rule receiving incorrect values when primal is mapped over a range/vector and captures iterated value" on Feb 14, 2025
@danielwe changed the title from "Reverse rule receiving incorrect values when primal is mapped over a range/vector and captures iterated value" to "Reverse rule receiving incorrect values when closure argument captures an iterated variable" on Feb 14, 2025
@danielwe
Contributor Author

Oof, this is a bit of a showstopper. I'm trying to work around it by storing the required argument on the tape, but to get correct gradients the reverse pass still needs to accumulate derivatives into the shadow of the argument. However, the argument often captures an array whose size differs between iterations, so the primal from the tape and the shadow passed to the reverse rule end up incompatible.

Any hints as to how to get to the bottom of this and figure out a fix? The function enzyme_custom_setup_args in customrules.jl looks like a good place to start, but it's a bit daunting.

@danielwe
Contributor Author

Partial mea culpa: I had forgotten to check EnzymeRules.overwritten. Taking that into account fixes the MWE.

However, my real-world examples still run into the issue where corresponding arrays in the tape's primal and the argument's shadow have incompatible sizes. Will make a new MWE shortly.
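As a loose analogy for why the overwritten flag matters (an assumption about the mechanism, not a description of Enzyme's internals), consider a storage slot that is reused in every iteration. Keeping only a reference to the slot shows the last value written on every later read, whereas saving a copy per iteration preserves the history. The names `slot`, `saved_refs`, and `saved_copies` are hypothetical.

```julia
slot = Ref(0.0)                # stands in for storage that is reused across iterations
saved_refs = typeof(slot)[]    # keeps only references to the slot
saved_copies = Float64[]       # keeps a snapshot of the value per iteration

for z in 0.0:0.25:1.0
    slot[] = z                 # overwrite the slot, as each new iteration would
    push!(saved_refs, slot)
    push!(saved_copies, slot[])
end

# Read everything back in reverse order, as a reverse pass would.
from_refs = [r[] for r in reverse(saved_refs)]
from_copies = reverse(saved_copies)
```

Here `from_refs` contains only the final value, while `from_copies` mirrors the forward sequence. This does not reproduce the exact pattern from the first MWE (where the last reverse call saw 0.0), but it shows why a rule needs to cache an argument on the tape when it may be overwritten.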

@danielwe
Contributor Author

danielwe commented Feb 15, 2025

Here's a revised MWE that uses overwritten by the book and thus obtains a shadow array smaller than the corresponding primal array. In a handwritten rule, this would usually trigger something like DimensionMismatch: array could not be broadcast to match destination. However, the MWE rule calls back into autodiff_thunk, which seems to skip bounds checks and trigger a non-deterministic segfault instead. (I wonder if this could be the cause of some rare segfaults I've been experiencing, though the more pressing issue is that I'm unable to write correct rules to get correct gradients.)

Happy to do anything I can to help get to the bottom of this and produce a fix.
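For reference, this is a minimal sketch of the error a handwritten rule would typically hit when accumulating into a shadow that is smaller than the primal. It uses plain Julia arrays with sizes (2,) and (1,) chosen for illustration; the variable names are hypothetical.

```julia
primal = [2.0, 1.0]   # size (2,)
shadow = [0.0]        # size (1,)

# A bounds-checked in-place broadcast refuses to write a length-2 result
# into a length-1 destination.
err = try
    shadow .+= primal
    nothing
catch e
    e
end
```

The caught exception is a `DimensionMismatch` and `shadow` is left untouched. Code that skips such checks gets no error at this point, which fits the memory corruption and segfault described above.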

using Enzyme

call(f::F, x) where {F} = _call(f, x)
_call(f, x) = f(x)

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfig, ::Const{typeof(call)}, ::Type{<:Active}, f, x::Active
)
    fx = call(f.val, x.val)
    primal = EnzymeRules.needs_primal(config) ? fx : nothing
    shadow = EnzymeRules.needs_shadow(config) ? make_zero(fx) : nothing
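    # Cache a copy of f on the tape if it may be overwritten before the reverse pass
    # (index 1 refers to the function `call` itself, index 2 to `f`)
    tape = EnzymeRules.overwritten(config)[2] ? (deepcopy(f.val),) : nothing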
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    config::EnzymeRules.RevConfig, ::Const{typeof(call)}, shadow::Active, tape, f::F, x::X
) where {T,F<:Duplicated{T},X<:Active}
    ff = EnzymeRules.overwritten(config)[2] ? Duplicated(tape[1], f.dval) : f
    checkshadowsize(ff)
    fwd, rev = autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(_call)}, Active, F, X)
    innertape, _, _ = fwd(Const(_call), ff, x)
    return only(rev(Const(_call), ff, x, shadow.val, innertape))
end

function checkshadowsize(x::Duplicated{T}) where {T}
    foreach(fieldnames(T)) do name
        checkshadowsize(Duplicated(getfield.((x.val, x.dval), name)...))
    end
end

function checkshadowsize(x::Duplicated{<:Array})
    psize, ssize = size.((x.val, x.dval))
    if psize != ssize
        @warn """primal-shadow array size mismatch:
            size(primal) = $psize
            size(shadow) = $ssize"""
    end
end

function foo(y)
    return sum(10:-1:1) do j
        t = [j / i for i in 1:j]
        call(x -> sum(x .* t), y)
    end
end

function run(y::AbstractFloat, n)
    for _ in 1:n
        autodiff(Reverse, foo, Active, Active(y))
    end
    return autodiff(Reverse, foo, Active, Active(y))
end

run(1.0, 1000)

Output:

┌ Warning: primal-shadow array size mismatch:
│ size(primal) = (2,)
│ size(shadow) = (1,)
└ @ Main ~/src/scratch/shadowmismatch.jl:35
[...]

[46289] signal (11.2): Segmentation fault: 11
in expression starting at /Users/daniel/src/scratch/shadowmismatch.jl:56
gc_mark_outrefs at /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-honeycrisp-XC9YQX9HH2.0/build/default-honeycrisp-XC9YQX9HH2-0/julialang/julia-release-1-dot-10/src/gc.c:2520 [inlined]
[...]
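The size mismatch in the warning is plausible given how foo builds its closures: the captured array t has length j, and j counts down from 10 to 1. A quick sketch in plain Julia (the name `captured_lengths` is hypothetical) lists the lengths per iteration:

```julia
# Length of the array `t` captured by the closure in each iteration of `foo`.
captured_lengths = [length([j / i for i in 1:j]) for j in 10:-1:1]
```

Every iteration captures an array of a different length, so pairing the primal from one iteration with the shadow from another always produces a mismatch. The reported sizes (2,) and (1,) would be consistent with a primal from j = 2 paired with a shadow from j = 1, though the thread does not confirm which iterations were involved.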
