Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c0d1493

Browse files
committedFeb 8, 2025
[better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Fixes in jet, stateful code, key_reuse, ode, tests.
1 parent c08b962 commit c0d1493

13 files changed

+117
-72
lines changed
 

‎jax/_src/api_util.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,17 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
610610
except (ValueError, TypeError):
611611
return None
612612

613-
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable) -> None:
613+
def save_wrapped_fun_sourceinfo(wrapper: Callable,
614+
wrapped: Callable | core.DebugInfo | None) -> None:
614615
# Prefer this to functools.wraps because it does not create a reference to
615616
# the wrapped function.
616-
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
617+
if isinstance(wrapped, core.DebugInfo):
618+
func_src_info = wrapped.func_src_info
619+
elif callable(wrapped):
620+
func_src_info = fun_sourceinfo(wrapped)
621+
else:
622+
return
623+
setattr(wrapper, "__fun_sourceinfo__", func_src_info)
617624

618625
_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
619626

‎jax/_src/core.py

+2
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
141141
self._eqns = list(eqns)
142142
self._effects = effects
143143
self._debug_info = debug_info and debug_info.resolve_result_paths()
144+
if debug_info is None:
145+
assert False # DO_NOT_SUBMIT
144146
# TODO(necula): re-enable these safety checks
145147
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
146148
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

‎jax/_src/interpreters/pxla.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,7 @@ def _move_mutable_consts(
18241824
invars = (*jaxpr.invars, *mutvars)
18251825
effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
18261826
jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
1827-
effects, None)
1827+
effects, closed_jaxpr.jaxpr.debug_info)
18281828
return core.ClosedJaxpr(jaxpr, consts), in_mut
18291829

18301830
@weakref_lru_cache

‎jax/_src/linear_util.py

+2
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def __init__(self, f: Callable,
166166
self.params = params
167167
self.in_type = in_type
168168
self.debug_info = debug_info
169+
if debug_info is None:
170+
assert False # DO_NOT_SUBMIT
169171

170172
@property
171173
def __name__(self):

‎jax/_src/state/discharge.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree):
460460
return _prepend_scatter(x, indexer, val, add=True)
461461

462462
@weakref_lru_cache
463-
def _cached_closed_jaxpr_discharge(closed_jaxpr):
463+
def _cached_closed_jaxpr_discharge(closed_jaxpr: core.ClosedJaxpr):
464464
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.consts
465465
num_outs = len(jaxpr.outvars)
466466
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
467467
discharged_closed_jaxpr = core.ClosedJaxpr(discharged_jaxpr, discharged_consts)
468-
fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr))
468+
fun = lu.wrap_init(core.jaxpr_as_fun(discharged_closed_jaxpr),
469+
debug_info=discharged_jaxpr.debug_info)
469470
return discharged_closed_jaxpr, num_outs, fun
470471

471472
@register_discharge_rule(core.closed_call_p)
@@ -598,7 +599,6 @@ def _convert_outputs_to_writes(
598599
assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
599600

600601
in_avals = [v.aval for v in jaxpr.invars]
601-
@lu.wrap_init
602602
def eval_jaxpr(*refs):
603603
# We split the refs into the original input refs and the dummy residual
604604
# refs.
@@ -610,14 +610,15 @@ def eval_jaxpr(*refs):
610610
res_ref_avals = [AbstractRef(v.aval) if not isinstance(v.aval, AbstractRef)
611611
else v.aval for v in jaxpr.outvars]
612612
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
613-
eval_jaxpr, [*in_avals, *res_ref_avals])
613+
lu.wrap_init(eval_jaxpr,
614+
debug_info=jaxpr.debug_info),
615+
[*in_avals, *res_ref_avals])
614616
assert not consts
615617
return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
616618

617619
def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr:
618620
assert not jaxpr.constvars, "Jaxpr should not have constvars"
619621

620-
@lu.wrap_init
621622
def eval_jaxpr(*refs):
622623
residual_refs, orig_refs = split_list(refs, [num_res])
623624
residual_vals = [r[...] for r in residual_refs]
@@ -629,7 +630,9 @@ def eval_jaxpr(*refs):
629630
res_ref_avals = [AbstractRef(aval) if not isinstance(aval, AbstractRef) else
630631
aval for aval in res_val_avals]
631632
jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
632-
eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
633+
lu.wrap_init(eval_jaxpr,
634+
debug_info=jaxpr.debug_info),
635+
[*res_ref_avals, *orig_ref_avals])
633636
return jaxpr
634637

635638
def _run_state_partial_eval(trace: pe.JaxprTrace, *tracers: pe.JaxprTracer,
@@ -845,12 +848,13 @@ def _run_state_partial_eval_custom(
845848
*[v.aval for v in res_staged_invars], **staged_params)
846849
_, staged_outvars = partition_list(in_unknowns, eqn.outvars)
847850
if num_res:
848-
@lu.wrap_init
851+
849852
def staged(*args):
850853
out = run_state_p.bind(*args, **staged_params)
851854
return out[num_res:]
852-
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(staged,
853-
[v.aval for v in res_staged_invars])
855+
staged_call_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
856+
lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info),
857+
[v.aval for v in res_staged_invars])
854858
eqn_staged = pe.new_jaxpr_eqn(res_staged_invars,
855859
staged_outvars,
856860
core.closed_call_p,
@@ -918,7 +922,9 @@ def trans(*args):
918922
ad.backward_pass(tangent_jaxpr, False, (), (*primals_args, *ct_args), ())
919923
return []
920924
jaxpr_trans, _, consts, () = pe.trace_to_jaxpr_dynamic(
921-
lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
925+
lu.wrap_init(trans,
926+
debug_info=jaxpr.debug_info),
927+
[v.aval for v in jaxpr.invars])
922928
return jaxpr_trans, consts
923929

924930
def _run_state_transpose(in_cts, *args, jaxpr: core.Jaxpr,

‎jax/experimental/jet.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import numpy as np
6161

6262
from jax import lax
63+
from jax import api_util
6364
import jax.numpy as jnp
6465
from jax.experimental import pjit
6566
from jax.tree_util import (register_pytree_node, tree_structure,
@@ -147,7 +148,9 @@ def flatten_fun_output(f, store, *args):
147148
store.store(tree)
148149
return ans
149150

150-
f, out_tree = flatten_fun_output(lu.wrap_init(fun))
151+
f, out_tree = flatten_fun_output(
152+
lu.wrap_init(fun,
153+
debug_info=api_util.debug_info("jet", fun, primals, {})))
151154
out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
152155
return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
153156

@@ -723,7 +726,7 @@ def _scatter_add_rule(primals_in, series_in, *, update_jaxpr, update_consts,
723726
def _jet_jaxpr(
724727
jaxpr: core.ClosedJaxpr, order: int, primals_and_series_avals, in_tree_def
725728
) -> tuple[core.ClosedJaxpr, Any]:
726-
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
729+
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
727730
f_jet, out_tree_def = traceable(jet_fun(jet_subtrace(f), order), in_tree_def)
728731
jaxpr_jet, _, consts, () = pe.trace_to_jaxpr_dynamic(
729732
f_jet, primals_and_series_avals)

‎jax/experimental/key_reuse/_core.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,10 @@ def is_consumed(var: core.Atom):
397397
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
398398
args_flat, in_tree = tree_util.tree_flatten(args)
399399
in_avals_flat = [core.get_aval(arg) for arg in args_flat]
400-
wrapped_fun, _ = api_util.flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
400+
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
401+
lu.wrap_init(fun,
402+
debug_info=api_util.debug_info("key_reuse", fun, args, {})),
403+
in_tree)
401404
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
402405
return jaxpr_type_signature(jaxpr)
403406

‎jax/experimental/ode.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828

2929
from functools import partial
3030
import operator as op
31+
from typing import Callable
3132

3233
import jax
34+
from jax import api_util
3335
import jax.numpy as jnp
3436
from jax._src import core
3537
from jax import custom_derivatives
@@ -44,8 +46,9 @@
4446
zip = safe_zip
4547

4648

47-
def ravel_first_arg(f, unravel):
48-
return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped
49+
def ravel_first_arg(f: Callable, unravel, debug_info: core.DebugInfo):
50+
return ravel_first_arg_(lu.wrap_init(f, debug_info=debug_info),
51+
unravel).call_wrapped
4952

5053
@lu.transformation2
5154
def ravel_first_arg_(f, unravel, y_flat, *args):
@@ -179,9 +182,10 @@ def odeint(func, y0, t, *args, rtol=1.4e-8, atol=1.4e-8, mxstep=jnp.inf, hmax=jn
179182
return _odeint_wrapper(converted, rtol, atol, mxstep, hmax, y0, t, *args, *consts)
180183

181184
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
182-
def _odeint_wrapper(func, rtol, atol, mxstep, hmax, y0, ts, *args):
185+
def _odeint_wrapper(func: Callable, rtol, atol, mxstep, hmax, y0, ts, *args):
183186
y0, unravel = ravel_pytree(y0)
184-
func = ravel_first_arg(func, unravel)
187+
debug = api_util.debug_info("odeint", func, args, {})
188+
func = ravel_first_arg(func, unravel, debug)
185189
out = _odeint(func, rtol, atol, mxstep, hmax, y0, ts, *args)
186190
return jax.vmap(unravel)(out)
187191

‎tests/debug_info_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def to_remat(x):
881881
tracer_spy=tracer_spy,
882882
expected_jaxpr_debug_infos=[
883883
# TODO(necula): what are these flat_index components?
884-
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
884+
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
885885
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
886886
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
887887
],

‎tests/jaxpr_effects_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from absl.testing import absltest
1919
import jax
20+
from jax import api_util
2021
import jax.numpy as jnp
2122
from jax import lax
2223
from jax.experimental import pjit
@@ -178,12 +179,13 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
178179
def test_core_call_primitive_inherits_effects(self):
179180

180181
def f(x):
181-
@lu.wrap_init
182182
def f_(x):
183183
effect_p.bind(effect=foo_effect)
184184
effect_p.bind(effect=bar_effect)
185185
return [x]
186-
return core.call(f_, x)[0]
186+
dbg = api_util.debug_info("test", f_, (2.,), {})
187+
return core.call(
188+
lu.wrap_init(f_, debug_info=dbg), x)[0]
187189
jaxpr = jax.make_jaxpr(f)(2.)
188190
self.assertIn(foo_effect, jaxpr.jaxpr.effects)
189191
self.assertIn(bar_effect, jaxpr.jaxpr.effects)

‎tests/name_stack_test.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from absl.testing import absltest
1717
import jax
18+
from jax import api_util
1819
import jax.numpy as jnp
1920
from jax._src import core
2021
from jax import lax
@@ -85,11 +86,12 @@ def f(x):
8586
def test_call_primitive_jaxpr_should_not_store_outer_name_stack(self):
8687
@jax.named_scope('foo')
8788
def f(x):
88-
@lu.wrap_init
8989
@jax.named_scope('bar')
9090
def _f(x):
9191
return [x + 1]
92-
return core.call(_f, x)[0]
92+
return core.call(lu.wrap_init(
93+
_f,
94+
debug_info=api_util.debug_info("test", _f, (0,), {})), x)[0]
9395

9496
jaxpr = jax.make_jaxpr(f)(2).jaxpr
9597
self.assertEqual(str(jaxpr.eqns[0].params['call_jaxpr'].eqns[0].source_info.name_stack), 'bar')

‎tests/state_test.py

+56-46
Large diffs are not rendered by default.

‎tests/util_test.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818

1919
import jax
20+
from jax import api_util
2021
from jax._src import linear_util as lu
2122
from jax._src import test_util as jtu
2223
from jax._src import util
@@ -62,7 +63,10 @@ def kw_to_positional(f, store, factor, *args, **kwargs):
6263
store.store(aux_output)
6364
return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):])))
6465

65-
wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
66+
# Wraps `f` as a `WrappedFun`.
67+
wf = lu.wrap_init(
68+
f,
69+
debug_info=api_util.debug_info("test", f, (1, 2), dict(three=3, four=4)))
6670
wf, out_thunk = kw_to_positional(wf, 2)
6771
# Call the transformed function.
6872
scaled_positional, scaled_kwargs = wf.call_wrapped(1, 2, three=3, four=4)

0 commit comments

Comments
 (0)
Please sign in to comment.