Commit c08b962

Committed Feb 8, 2025
[better_errors] Continue adding debug info to Jaxprs (step 6)
This follows in a series, starting with jax-ml#26078 and jax-ml#26313, adding debug_info to more calls to lu.wrap_init. Here I changed `custom_jvp_call` to replace the parameter `jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun` that can carry debug info). Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
Parent: fd1b7cc. Commit: c08b962

20 files changed: +188 -106 lines
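The core mechanical change: the `custom_jvp_call` primitive params used to carry a bare callable (`jvp_jaxpr_thunk`); they now carry a `lu.WrappedFun` (`jvp_jaxpr_fun`), so debug info travels with the function. A minimal, hypothetical sketch of the new convention (the names `my_thunk` and `dbg` are illustrative, not from the diff; the real values are built in partial_eval.py and consumed in custom_derivatives.py below):

    from jax._src import api_util
    from jax._src import linear_util as lu

    def my_thunk(*zeros):
      # Stand-in for a jvp-jaxpr thunk; the real one returns (jvp_jaxpr, consts, out_zeros).
      return zeros

    # Producer side: wrap the thunk so debug info rides along with it.
    dbg = api_util.debug_info("example", my_thunk, (False,), {})
    jvp_jaxpr_fun = lu.wrap_init(my_thunk, debug_info=dbg)

    # Consumer side: call through call_wrapped instead of calling the thunk
    # directly, and reuse the attached debug info when wrapping derived functions.
    out = jvp_jaxpr_fun.call_wrapped(False)
    rewrapped = lu.wrap_init(my_thunk, debug_info=jvp_jaxpr_fun.debug_info)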
 

‎jax/_src/api_util.py

+5 -4

@@ -27,7 +27,7 @@
 from jax._src.state.types import AbstractRef
 from jax._src.tree_util import (
     PyTreeDef, tree_flatten, tree_unflatten, tree_map,
-    treedef_children, generate_key_paths, keystr, broadcast_prefix,
+    treedef_children, generate_key_paths, broadcast_prefix,
     prefix_errors)
 from jax._src.tree_util import _replace_nones
 from jax._src import linear_util as lu
@@ -664,12 +664,13 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
   except (ValueError, TypeError):
     pass
   else:
-    return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
+    return tuple(f'{name}{lu._clean_keystr_arg_names(path)}'
+                 for name, x in ba.arguments.items()
                  for path, l in generate_key_paths(x) if l is not static)
-  args_arg_names = tuple(f'args{keystr(path)}'
+  args_arg_names = tuple(f'args{lu._clean_keystr_arg_names(path)}'
                          for path, l in generate_key_paths(args_)
                          if l is not static)
-  kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
+  kwargs_arg_names = tuple(f'kwargs{lu._clean_keystr_arg_names(path)}'
                            for path, l in generate_key_paths(kwargs_)
                            if l is not static)
   arg_names = args_arg_names + kwargs_arg_names

‎jax/_src/checkify.py

+20 -14

@@ -833,7 +833,7 @@ def new_body_f(*c_consts_and_vals):
     # This checks if the next cond application will error
     _ = cond_f(*c_consts, *out)
     return out
-  new_body_f_ = lu.wrap_init(new_body_f)
+  new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
   c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
   jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
                                                              *body_jaxpr.in_avals])
@@ -952,7 +952,8 @@ def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):


 def shard_map_error_check(
-    error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
+    error: Error, enabled_errors, *vals_in,
+    jaxpr: core.Jaxpr, in_names, out_names, **kwargs
 ):
   if (mesh := kwargs.get('mesh')) is None:
     raise ValueError('Mesh must be provided for shard_map with checkify.')
@@ -976,7 +977,6 @@
   )
   num_out_error_vals = out_tree.num_leaves - len(out_names)

-  @lu.wrap_init
   def expand_errors_leading_dim(*xs):
     outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs)
     errs, outs = split_list(outs, [num_out_error_vals])
@@ -985,15 +985,18 @@

   with core.extend_axis_env_nd(mesh.shape.items()):
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
-        expand_errors_leading_dim, checked_jaxpr.in_avals
+        lu.wrap_init(expand_errors_leading_dim,
+                     debug_info=checked_jaxpr.jaxpr.debug_info),
+        checked_jaxpr.in_avals
     )
   checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)

   # Update shard_map params to account for extra error values.
   # Use fully sharded partitioning for out errors.
   new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names)
   subfun = lu.hashable_partial(
-      lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts
+      lu.wrap_init(core.eval_jaxpr, debug_info=checked_jaxpr.jaxpr.debug_info),
+      checked_jaxpr.jaxpr, checked_jaxpr.consts
   )
   new_params = dict(
       jaxpr=checked_jaxpr.jaxpr,
@@ -1007,8 +1010,10 @@ def expand_errors_leading_dim(*xs):
   return tree_unflatten(out_tree, err_and_out)
 error_checks[shard_map.shard_map_p] = shard_map_error_check

-def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
-                         jvp_jaxpr_thunk, call_jaxpr, **params):
+def custom_jvp_call_rule(in_err: Error,
+                         enabled_errors: set, *in_vals, num_consts,
+                         jvp_jaxpr_fun: lu.WrappedFun,
+                         call_jaxpr: core.ClosedJaxpr, **params):
   # The types to have in mind are:
   # jvp : (a -> b) -> (a, T a) -> (b, T b)
   # checkify : (a -> b) -> a -> Err b
@@ -1021,10 +1026,11 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
   err_vals, err_tree = jtu.tree_flatten(in_err)
   partial_checkify = lu.wrap_init(
       functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
-                        call_jaxpr.consts, enabled_errors, err_tree))
+                        call_jaxpr.consts, enabled_errors, err_tree),
+      debug_info=call_jaxpr.jaxpr.debug_info)
   partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
       partial_checkify)
-  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
+  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_fun)
   jvp, jvp_out_tree = flatten_fun_output(jvp)
   all_outs = custom_derivatives.custom_jvp_call_p.bind(
       partial_checkify, jvp, *err_vals, *in_vals, **params)
@@ -1041,17 +1047,17 @@ def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,

 # Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
 # outputs that checkify adds (just forwarding the error data's primal and
-# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
+# tangent components). The jaxpr in jvp_jaxpr_fun doesn't expect those.
 # TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
 # Adding another layer of lu.transformation was tricky, though maybe doable.
-def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
-  @lu.wrap_init
+def lift_jvp(num_errs: int, num_consts: int,
+             jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
   def jvp(*xs):
     n, ragged = divmod(len(xs), 2)
     assert not ragged
     primals, tangents = xs[num_consts+num_errs:n], xs[n+num_consts+num_errs:]
     zeros = [type(t) is SymbolicZero for t in tangents]
-    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
+    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
     nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
     out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
     out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@@ -1063,7 +1069,7 @@ def jvp(*xs):
     primal_errs = xs[num_consts:num_consts+num_errs]
     tangent_errs = xs[n+num_consts:n+num_consts+num_errs]
     return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
-  return jvp
+  return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

 def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals,
                                fun_jaxpr: core.ClosedJaxpr,
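Most of the remaining call-site fixes in this commit follow one pattern, visible above in shard_map_error_check: when a helper is re-traced from an existing jaxpr, the wrapper passed to lu.wrap_init reuses that jaxpr's debug_info rather than carrying none. A small self-contained sketch of that pattern, under the assumption of this commit's wrap_init/trace_to_jaxpr_dynamic API (the function g and the example jaxpr are illustrative, not from the diff):

    import jax
    import jax.numpy as jnp
    from jax._src import core
    from jax._src import linear_util as lu
    from jax._src.interpreters import partial_eval as pe

    closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)

    def g(*args):
      # Wrapper that re-runs the existing jaxpr (compare expand_errors_leading_dim above).
      return core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)

    # Reuse the original jaxpr's debug_info when re-tracing the wrapper.
    wrapped = lu.wrap_init(g, debug_info=closed_jaxpr.jaxpr.debug_info)
    new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, closed_jaxpr.in_avals)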

‎jax/_src/custom_derivatives.py

+13 -13

@@ -378,20 +378,19 @@ def get_bind_params(self, params):
     new_params = dict(params)
     call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
     num_consts: int = new_params.pop('num_consts')
-    jvp_jaxpr_thunk = new_params.pop('jvp_jaxpr_thunk')
+    jvp_jaxpr_fun = new_params.pop('jvp_jaxpr_fun')
     fun = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
                        debug_info=call_jaxpr.jaxpr.debug_info)
-    jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
+    jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
     return [fun, jvp], new_params

-def lift_jvp(num_consts: int, jvp_jaxpr_thunk: Callable,
-             debug_info: core.DebugInfo | None) -> lu.WrappedFun:
+def lift_jvp(num_consts: int, jvp_jaxpr_fun: lu.WrappedFun) -> lu.WrappedFun:
   def jvp(*xs):
     n, ragged = divmod(len(xs), 2)
     assert not ragged
     primals, tangents = xs[num_consts:n], xs[n+num_consts:]
     zeros = [type(t) is SymbolicZero for t in tangents]
-    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
+    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*zeros)
     nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
     out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
     out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
@@ -401,26 +400,26 @@ def jvp(*xs):
                     for p, z in zip(out_primals, out_zeros)]
     assert next(nz_out_tangents_, None) is None
     return [*out_primals, *out_tangents]
-  return lu.wrap_init(jvp, debug_info=debug_info)
+  return lu.wrap_init(jvp, debug_info=jvp_jaxpr_fun.debug_info)

 effects.custom_derivatives_allowed_effects.add_type(lax.InOutFeedEffect)

 custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

-def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_thunk,
+def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
                                num_consts, symbolic_zeros):
   # TODO(mattjj): could do more checking here...
-  del in_avals, jvp_jaxpr_thunk, num_consts
+  del in_avals, jvp_jaxpr_fun, num_consts
   disallowed_effects = effects.custom_derivatives_allowed_effects.filter_not_in(call_jaxpr.effects)
   if disallowed_effects:
     raise NotImplementedError(
         f'Effects not supported in `custom_jvp`: {disallowed_effects}')
   return call_jaxpr.out_avals, call_jaxpr.effects
 core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck

-def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
+def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_fun,
                                       num_consts, symbolic_zeros):
-  del jvp_jaxpr_thunk, num_consts, symbolic_zeros
+  del jvp_jaxpr_fun, num_consts, symbolic_zeros
   consts = mlir._ir_consts(call_jaxpr.consts)
   out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
                                    ctx.name_stack, ctx.tokens_in, consts,
@@ -452,7 +451,7 @@ def _custom_jvp_call_dce(
     return [False] * len(eqn.invars), None

   call_jaxpr = eqn.params["call_jaxpr"]
-  jvp_jaxpr_thunk = eqn.params["jvp_jaxpr_thunk"]
+  jvp_jaxpr_fun = eqn.params["jvp_jaxpr_fun"]
   # We must set instantiate=True because some inputs that are unused by the
   # DCE'ed primal might be used in the JVP rule.
   dce_call_jaxpr, used_ins = _cached_closed_call_dce_instantiate(
@@ -461,7 +460,7 @@

   @pe._memoize
   def dce_jvp_jaxpr_thunk(*in_zeros):
-    jvp_jaxpr, consts, out_zeros = jvp_jaxpr_thunk(*in_zeros)
+    jvp_jaxpr, consts, out_zeros = jvp_jaxpr_fun.call_wrapped(*in_zeros)
     dce_jvp_jaxpr, _ = pe.dce_jaxpr(jvp_jaxpr, [*used_outs, *used_outs], True)
     dce_out_zeros = [v for used, v in zip(used_outs, out_zeros) if used]
     return dce_jvp_jaxpr, consts, dce_out_zeros
@@ -470,7 +469,8 @@ def dce_jvp_jaxpr_thunk(*in_zeros):
   new_params = dict(
       eqn.params,
       call_jaxpr=dce_call_jaxpr,
-      jvp_jaxpr_thunk=dce_jvp_jaxpr_thunk,
+      jvp_jaxpr_fun=lu.wrap_init(dce_jvp_jaxpr_thunk,
+                                 debug_info=jvp_jaxpr_fun.debug_info)
   )
   new_eqn = pe.new_jaxpr_eqn(
       eqn.invars, outvars, eqn.primitive, new_params, dce_call_jaxpr.effects,

‎jax/_src/interpreters/ad.py

+3 -1

@@ -725,7 +725,9 @@ def f_tangent(*args):

     nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
     nz_tangents_out = call_primitive.bind_with_trace(
-        self.tangent_trace, (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), new_params)
+        self.tangent_trace, (lu.wrap_init(f_tangent,
+                                          debug_info=lin_jaxpr.debug_info),
+                             *residuals, *nz_tangents_in), new_params)
     nz_tangents_out_iter = iter(nz_tangents_out)
     tangents_out = [next(nz_tangents_out_iter) if nz else Zero.from_primal_value(primal)
                     for nz, primal in zip(nzs_out, primals_out)]

‎jax/_src/interpreters/partial_eval.py

+6 -3

@@ -502,7 +502,7 @@ def _closed_call_param_updater(params, _, __):
   return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
 call_param_updaters[core.closed_call_p] = _closed_call_param_updater

-def abstract_eval_fun(fun, *avals, debug_info=None, **params):
+def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
   _, avals_out, _, () = trace_to_jaxpr_dynamic(
       lu.wrap_init(fun, params, debug_info=debug_info), avals)
   assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@@ -1992,7 +1992,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun,
     self.frame.add_eqn(eqn)
     return out_tracers

-  def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros):
+  def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
+                              jvp: lu.WrappedFun, tracers,
+                              symbolic_zeros: bool):
     tracers = map(self.to_jaxpr_tracer, tracers)
     in_avals = [t.aval for t in tracers]
     in_tangent_avals = [t.to_tangent_aval() for t in in_avals]
@@ -2014,7 +2016,8 @@ def jvp_jaxpr_thunk(*in_zeros):
     outvars = map(self.makevar, out_tracers)
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                         dict(call_jaxpr=closed_fun_jaxpr,
-                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
+                             jvp_jaxpr_fun=lu.wrap_init(jvp_jaxpr_thunk,
+                                                        debug_info=jvp.debug_info),
                              num_consts=len(consts),
                              symbolic_zeros=symbolic_zeros),
                         fun_jaxpr.effects,

‎jax/_src/lax/control_flow/loops.py

+8 -7

@@ -1255,7 +1255,9 @@ def arrange_jaxpr_args_for_wrapped(args):
   )
   # TODO(cperivol): avoid tracing the jaxpr twice. When doing so don't
   # forget to manage the effects.
-  new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
+  new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(wrapped, debug_info=discharged_jaxpr.debug_info),
+      avals_for_wrapped_no_refs)
   all_out = scan_p.bind(*args_for_wrapped,
                         jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
                         length=length,
@@ -1922,9 +1924,9 @@ def new_body(*consts_refs_carry):
     carry, refs_out = split_list(carry_refs, [num_carry])
     return [*refs_out, *carry]
   new_body_jaxpr, _, new_body_consts, () = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a
-                                                              in ref_avals],
-                               *carry_avals])
+      lu.wrap_init(new_body, debug_info=discharged_body_jaxpr.debug_info),
+      [*remaining_body_const_avals, *[a.inner_aval for a in ref_avals],
+       *carry_avals])
   if new_body_consts: raise NotImplementedError

   # Since some `Ref`s that were previously consts are now carries, we need to
@@ -1936,9 +1938,8 @@ def new_cond(*consts_refs_carry):
     del refs  # We don't use them here!
     return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry)
   new_cond_jaxpr, _, new_cond_consts, () = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(new_cond), [*cond_consts_avals,
-                               *[a.inner_aval for a in ref_avals],
-                               *carry_avals])
+      lu.wrap_init(new_cond, debug_info=cond_jaxpr.debug_info),
+      [*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals])
   if new_cond_consts: raise NotImplementedError

   out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry,

‎jax/_src/lax/control_flow/solves.py

+2 -1

@@ -159,7 +159,8 @@ def _root_jvp(const_lengths, jaxprs, primals, tangents):
   linearize_and_solve = partial(
       core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
   f_at_solution = lambda *params: f(*params, *solution)
-  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
+  _, rhs = ad.jvp(lu.wrap_init(f_at_solution,
+                               debug_info=jaxprs.f.jaxpr.debug_info)).call_wrapped(
       params.f, params_dot.f)
   solution_dot = _map(
       operator.neg, linearize_and_solve(*solution, *rhs))

‎jax/_src/linear_util.py

+9 -2

@@ -65,13 +65,14 @@ def trans1(static_arg, *dynamic_args, **kwargs):

 from collections.abc import Callable, Sequence
 from functools import partial
+import re
 from typing import Any, NamedTuple
 import weakref

 from jax._src import config
 from jax._src import core
 from jax._src import traceback_util
-from jax._src.tree_util import keystr, generate_key_paths
+from jax._src.tree_util import keystr, KeyPath, generate_key_paths
 from jax._src.util import curry, cache_clearing_funs, HashableFunction


@@ -329,10 +330,16 @@ def wrap_init(f: Callable, params=None, *,
   return fun


+# We replace <flat index 0> with 0
+_re_clean_keystr_arg_names = re.compile(r"<flat index ([^>]+)>")
+def _clean_keystr_arg_names(k: KeyPath) -> str:
+  res = keystr(k)
+  return _re_clean_keystr_arg_names.sub(r"\1", res)
+
 @transformation_with_aux2
 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
   ans = _fun(*args, **kwargs)
-  result_paths = [keystr(path) for path, _ in generate_key_paths(ans)]
+  result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)]
   if _store:
     # In some instances a lu.WrappedFun is called multiple times, e.g.,
     # the bwd function in a custom_vjp
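The new `_clean_keystr_arg_names` helper only post-processes the string produced by keystr: every `<flat index N>` component collapses to the bare index `N`, which is exactly what the updated expectations in tests/debug_info_test.py below show. A standalone illustration of the regex (the sample path string is hypothetical):

    import re

    # Same pattern as the helper above: "<flat index N>" -> "N".
    _re = re.compile(r"<flat index ([^>]+)>")

    def clean(path_str: str) -> str:
      return _re.sub(r"\1", path_str)

    print(clean("[1][<flat index 0>][0][<flat index 0>][0][0]"))
    # prints: [1][0][0][0][0][0]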

‎jax/_src/pallas/mosaic/lowering.py

+2 -2

@@ -3026,11 +3026,11 @@ def _custom_jvp_call_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
     call_jaxpr: jax_core.Jaxpr,
-    jvp_jaxpr_thunk: Callable,
+    jvp_jaxpr_fun: lu.WrappedFun,
     num_consts: int,
     symbolic_zeros: bool,
 ):
-  del jvp_jaxpr_thunk
+  del jvp_jaxpr_fun
   if symbolic_zeros: raise NotImplementedError
   if num_consts: raise NotImplementedError
   if call_jaxpr.consts: raise NotImplementedError

‎jax/_src/pallas/pallas_call.py

+3 -1

@@ -242,7 +242,9 @@ def _block_map_function(new_idx, *args):

   with grid_mapping.trace_env():
     block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(_block_map_function), idx_avals)
+        lu.wrap_init(_block_map_function,
+                     debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
+        idx_avals)
   shape = block_mapping.block_shape
   if dim is batching.not_mapped:
     new_block_shape = shape

‎jax/_src/state/utils.py

+1 -1

@@ -73,7 +73,7 @@ def _hoist(*consts_args):
     return core.eval_jaxpr(jaxpr, all_consts, *args0, *args1)

   hoisted_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(_hoist), in_avals)
+      lu.wrap_init(_hoist, debug_info=jaxpr.debug_info), in_avals)
   assert not consts, "All consts should have been converted to refs"
   return hoisted_jaxpr

‎jax/experimental/attrs.py

+11 -5

@@ -14,7 +14,7 @@

 from __future__ import annotations

-from typing import Any
+from typing import Any, Callable

 from jax._src import core
 from jax._src import source_info_util
@@ -90,7 +90,9 @@ def jvp(f, primals, tangents, attr_tangents):
   primals_flat, in_tree = tree_flatten((attr_primals, *primals))
   tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
   if in_tree != in_tree_: raise Exception
-  f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree)
+  dbg = api_util.debug_info("attrs_jvp", f, primals, {})
+  f_, out_tree = flatten_fun_nokwargs(
+      _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), in_tree)
   out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
       primals_flat, tangents_flat)
   out_primals = tree_unflatten(out_tree(), out_primals_flat)
@@ -151,12 +153,14 @@ def _getattr_jvp(trace, obj, attr):
 ad.LinearizeTrace.process_setattr = _setattr_jvp
 ad.LinearizeTrace.process_getattr = _getattr_jvp

-def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
+def linearize(f: Callable, *primals, attrs: list[tuple[Any, str]] = []):
   attr_primals = [jax_getattr(o, a) for o, a in attrs]
   attr_avals = [core.get_aval(p) for p in attr_primals]
   primals_flat, in_tree = tree_flatten(primals)
   tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
-  f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
+  dbg = api_util.debug_info("attrs linearize", f, primals, {})
+  f_, out_tree = flatten_fun_nokwargs(
+      _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
   primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
       f_, *attr_primals, *primals_flat)
   f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
@@ -206,7 +210,9 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
   attr_primals = [jax_getattr(o, a) for o, a in attrs]
   primals_flat, in_tree = tree_flatten(primals)
   tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
-  f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
+  dbg = api_util.debug_info("attrs vjp", f, primals, {})
+  f_, out_tree = flatten_fun_nokwargs(
+      _set_attrs(lu.wrap_init(f, debug_info=dbg), attrs), tree)
   primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
       f_, *attr_primals, *primals_flat)
   attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval()

‎jax/experimental/jax2tf/jax2tf.py

+22 -9

@@ -599,6 +599,8 @@ def __init__(self, fun_jax, *,
     name_stack = util.wrap_name(fun_name, "jax2tf")
     self.name_stack = name_stack
     self.args_flat_tf = args_flat_tf
+    self.debug = api_util.debug_info("jax2tf", fun_jax,
+                                     args_specs, kwargs_specs)

   def before_conversion(self):
     prev_enable_xla = _thread_local_state.enable_xla
@@ -623,7 +625,10 @@ def _restore_context():
     dim_values, _ = _interpret_fun_jax(
         partial(shape_poly.compute_dim_vars_from_arg_shapes,
                 self.args_avals_flat, args_kwargs_tree=self.in_tree),
-        self.args_flat_tf, self.args_avals_flat, self.name_stack)
+        self.args_flat_tf, self.args_avals_flat, self.name_stack,
+        debug_info=api_util.debug_info("jax2tf dim_vars",
+                                       shape_poly.compute_dim_vars_from_arg_shapes,
+                                       self.args_specs, self.kwargs_specs))

     _thread_local_state.shape_env = zip(dim_vars, dim_values)

@@ -639,7 +644,8 @@ def run_fun_tf(self,
         fun_flat_jax,
         args_flat_tf, self.args_avals_flat,
         self.name_stack,
-        fresh_constant_cache=True)
+        fresh_constant_cache=True,
+        debug_info=self.debug)
     return outs_tf, self.outs_avals, out_tree_thunk()

   def get_vjp_fun(self) -> tuple[Callable,
@@ -849,10 +855,12 @@ def _interpret_fun_jax(
     fun_jax: Callable,
     args_tf: Sequence[TfVal],
     args_avals: Sequence[core.ShapedArray],
-    extra_name_stack: str | None,
+    extra_name_stack: str | None, *,
     fresh_constant_cache: bool = False,
+    debug_info: core.DebugInfo,
 ) -> tuple[tuple[TfVal, ...], tuple[core.ShapedArray, ...]]:
-  subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), args_avals)
+  subtrace_fun = _interpret_subtrace(
+      lu.wrap_init(fun_jax, debug_info=debug_info), args_avals)
   with _extended_name_stack(extra_name_stack):
     out_vals: Sequence[tuple[TfVal, core.ShapedArray]] = \
         _call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
@@ -1033,7 +1041,9 @@ def impl_multiple_results_jax(*args_jax):

     results_tf, _ = _interpret_fun_jax(
         impl_multiple_results_jax, args_tf, _in_avals,
-        extra_name_stack)
+        extra_name_stack,
+        debug_info=api_util.debug_info("jax2tf", impl_jax,
+                                       args_tf, kwargs))
     return results_tf if multiple_results else results_tf[0]

   return wrapped_tf
@@ -1066,7 +1076,8 @@ def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
   """
   outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr),
                                   args_tf, jaxpr.in_avals, extra_name_stack,
-                                  fresh_constant_cache=fresh_constant_cache)
+                                  fresh_constant_cache=fresh_constant_cache,
+                                  debug_info=jaxpr.jaxpr.debug_info)
   return outs_tf


@@ -1197,7 +1208,9 @@ def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
   dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
   shape_values_tf, _ = _interpret_fun_jax(
       partial(core.evaluate_shape, shape, dim_vars),
-      dim_values, [core.dim_value_aval()] * len(dim_values), "")  # type: ignore
+      dim_values, [core.dim_value_aval()] * len(dim_values), "",  # type: ignore
+      debug_info=api_util.debug_info("jax2tf evaluate_shape", core.evaluate_shape,
+                                     (0, 0, *dim_values), {}))
   # Keep only the non-constant dimensions
   return tuple(operator.index(d) if core.is_constant_dim(d) else d_tf  # type: ignore
               for d, d_tf in zip(shape, shape_values_tf))
@@ -3431,10 +3444,10 @@ def _tridiagonal_solve(*args: TfVal, _in_avals, _out_aval, **params):
 tf_impl_with_avals[lax.linalg.tridiagonal_solve_p] = _tridiagonal_solve

 def _custom_jvp_call(*args: TfVal, call_jaxpr: core.ClosedJaxpr,
-                     jvp_jaxpr_thunk: Callable,
+                     jvp_jaxpr_fun: Callable,
                      num_consts: int) -> Sequence[TfVal]:
   # TODO(necula): ensure that there is no AD transformation in scope
-  del jvp_jaxpr_thunk, num_consts
+  del jvp_jaxpr_fun, num_consts
   return _interpret_jaxpr(call_jaxpr, *args, extra_name_stack="custom_jvp",
                           fresh_constant_cache=False)
34403453

‎jax/experimental/jax2tf/tests/shape_poly_test.py

+5 -1

@@ -38,6 +38,7 @@
 import jax.numpy as jnp
 from jax import random
 from jax import tree_util
+from jax._src import api_util
 from jax._src import config
 from jax._src import core
 from jax._src import test_util as jtu
@@ -442,7 +443,10 @@ def f_tf(*args_tf):
         partial(shape_poly.compute_dim_vars_from_arg_shapes,
                 avals,
                 args_kwargs_tree=tree_util.tree_flatten((avals, {}))[1]),
-        args_tf, avals, "")
+        args_tf, avals, "",
+        debug_info=api_util.debug_info("jax2tf dim_vars",
+                                       shape_poly.compute_dim_vars_from_arg_shapes,
+                                       avals, {}))
     if expected_shapes is not None:
       expected_avals = tree_util.tree_map(
           lambda shape_str: core.ShapedArray(

‎jax/experimental/shard_map.py

+10 -7

@@ -1387,7 +1387,7 @@ def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs):


 @register_check(custom_derivatives.custom_jvp_call_p)
-def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk,
+def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_fun,
                            num_consts, symbolic_zeros):
   return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)

@@ -1780,31 +1780,33 @@ def _partial_eval_jaxpr_custom_rule(
 pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
     _partial_eval_jaxpr_custom_rule

-def _add_reshapes(which, jaxpr_known, jaxpr_staged):
+def _add_reshapes(which: Sequence[bool],
+                  jaxpr_known: core.Jaxpr,
+                  jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
   # add singleton axes to residuals which are from jaxpr_known and are scalars
   which_ = [w and not v.aval.shape  # pytype: disable=attribute-error
             for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
   if not any(which_): return jaxpr_known, jaxpr_staged
   assert not jaxpr_known.constvars and not jaxpr_staged.constvars

-  @lu.wrap_init
   def known(*args):
     out = core.eval_jaxpr(jaxpr_known, (), *args)
     out_known, res = split_list(out, [len(out) - sum(which)])
     res = [_add_singleton(x) if not x.shape else x for x in res]
     return [*out_known, *res]
   avals_in = [v.aval for v in jaxpr_known.invars]
-  jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(known, avals_in)
+  jaxpr_known, _, (), () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(known, debug_info=jaxpr_known.debug_info), avals_in)

-  @lu.wrap_init
   def staged(*args):
     res_, ins = split_list(args, [len(which)])
     res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
     return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
   res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
                for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
   avals_in = [*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]]
-  jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(staged, avals_in)
+  jaxpr_staged, _, (), () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(staged, debug_info=jaxpr_staged.debug_info), avals_in)

   return jaxpr_known, jaxpr_staged

@@ -2070,7 +2072,8 @@ def _replication_rewrite_match(
     in_rep: Sequence[set[AxisName]],
     out_rep_dst: Sequence[set[AxisName]],
 ) -> core.ClosedJaxpr:
-  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
+  f = lu.wrap_init(partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
+                   debug_info=jaxpr.jaxpr.debug_info)
   f, out_rep = _efficient_transpose_rewrite_nomatch(f, mesh, in_rep)
   f = _match_rep(f, mesh, out_rep, out_rep_dst)
   jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)

‎jax/experimental/sparse/bcoo.py

+8 -2

@@ -1042,8 +1042,14 @@ def _bcoo_dot_general_sampled_impl(A, B, indices, *, dimension_numbers):

 @bcoo_dot_general_sampled_p.def_abstract_eval
 def _bcoo_dot_general_sampled_abstract_eval(A, B, indices, *, dimension_numbers):
-  dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B)
-  sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result)
+  dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
+                            lax.dot_general, (A, B), dict(dimension_numbers=dimension_numbers))
+  dense_result, = pe.abstract_eval_fun(lambda *args: [lax.dot_general(*args, dimension_numbers=dimension_numbers)], A, B,
+                                       debug_info=dbg)
+  dbg = api_util.debug_info("bcoo_dot_general_sampled_abstract_eval",
+                            _bcoo_extract, (indices, dense_result), {})
+  sparse_result, = pe.abstract_eval_fun(lambda *args: [_bcoo_extract(*args)], indices, dense_result,
+                                        debug_info=dbg)
   return sparse_result

 def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers):

‎jax/experimental/sparse/transform.py

+22 -11

@@ -55,6 +55,7 @@

 import jax
 from jax import lax
+from jax._src import api_util
 from jax._src import config
 from jax._src import core
 from jax._src.custom_derivatives import lift_jvp
@@ -365,12 +366,15 @@ def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]):
   spenv = SparsifyEnv(out_bufs)
   return spvalues_to_arrays(spenv, out_spvalues())

-def _sparsify_with_tracer(fun):
+def _sparsify_with_tracer(fun: Callable):
   """Implementation of sparsify() using tracers."""
   @functools.wraps(fun)
   def _wrapped(*args):
     args_flat, in_tree = tree_flatten(args, is_leaf=_is_sparse_obj)
-    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
+    wrapped_fun, out_tree = flatten_fun_nokwargs(
+        lu.wrap_init(fun,
+                     debug_info=api_util.debug_info("sparsify", fun, args, {})),
+        in_tree)
     out = sparsify_fun(wrapped_fun, args_flat)
     return tree_unflatten(out_tree(), out)
   return _wrapped
@@ -439,7 +443,12 @@ def wrapped(
   ) -> tuple[Sequence[SparsifyValue], pytree.PyTreeDef]:
     spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
     in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
-    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
+    wrapped_fun, out_tree = flatten_fun_nokwargs(
+        lu.wrap_init(
+            f, params,
+            debug_info=api_util.debug_info("sparsify", f,
+                                           spvalues_to_arrays(spenv, spvalues), {})),
+        in_tree)
     jaxpr, out_avals_flat, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
     result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
     if len(out_avals_flat) != len(result):
@@ -716,14 +725,14 @@ def _gather_sparse_rule(spenv, *args, dimension_numbers, slice_sizes, unique_ind

 sparse_rules_bcoo[lax.gather_p] = _gather_sparse_rule

-def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
+def _sparsify_jaxpr(spenv: SparsifyEnv,
+                    jaxpr: core.ClosedJaxpr, *spvalues):
   # TODO(jakevdp): currently this approach discards all information about
   # shared data & indices when generating the sparsified jaxpr. The
   # current approach produces valid sparsified while loops, but they
   # don't work in corner cases (see associated TODO in sparsify_test.py)
   out_tree: pytree.PyTreeDef | None = None

-  @lu.wrap_init
   def wrapped(*args_flat):
     # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
     # in buffers from the "outer scope" as constants. Is this a
@@ -740,7 +749,8 @@ def wrapped(*args_flat):
   args = spvalues_to_arrays(spenv, spvalues)
   args_flat, in_tree = tree_flatten(args)
   avals_flat = [core.get_aval(arg) for arg in args_flat]
-  sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
+  sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(wrapped, debug_info=jaxpr.jaxpr.debug_info), avals_flat)
   sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
   assert out_tree is not None
   return sp_jaxpr, out_tree
@@ -866,17 +876,18 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):

 def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
   call_jaxpr: core.ClosedJaxpr = params.pop('call_jaxpr')
-  jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
-  num_consts = params.pop('num_consts')
+  jvp_jaxpr_fun: lu.WrappedFun = params.pop('jvp_jaxpr_fun')
+  num_consts: int = params.pop('num_consts')
   sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
-  @lu.wrap_init
   def fun(*arrs):
     sparrs = arrays_to_spvalues(spenv, arrs)
     out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
     return spvalues_to_arrays(spenv, out)
-  jvp = lift_jvp(num_consts, jvp_jaxpr_thunk, call_jaxpr.jaxpr.debug_info)
+  jvp = lift_jvp(num_consts, jvp_jaxpr_fun)
   invals = spvalues_to_arrays(spenv, spvalues)
-  outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
+  outvals = jax.custom_derivatives.custom_jvp_call_p.bind(
+      lu.wrap_init(fun, debug_info=call_jaxpr.jaxpr.debug_info),
+      jvp, *invals, **params)
   return arrays_to_spvalues(spenv, outvals)

 sparse_rules_bcoo[jax.custom_derivatives.custom_jvp_call_p] = _custom_jvp_sparse_rule

‎tests/core_test.py

+24 -15

@@ -26,9 +26,8 @@
 from jax import lax
 from jax import numpy as jnp
 from jax import jvp, linearize, vjp, jit, make_jaxpr
-from jax.api_util import flatten_fun_nokwargs
+from jax.api_util import flatten_fun_nokwargs, debug_info
 from jax._src import config
-
 from jax._src import core
 from jax._src import linear_util as lu
 from jax._src import util
@@ -48,14 +47,16 @@ def call(f, *args):
 @util.curry
 def core_call(f, *args):
   args, in_tree = jax.tree.flatten(args)
-  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
+  dbg = debug_info("core_call_test", f, args, {})
+  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
   out = core.call_p.bind(f, *args)
   return jax.tree.unflatten(out_tree(), out)

 @util.curry
 def core_closed_call(f, *args):
   args, in_tree = jax.tree.flatten(args)
-  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
+  dbg = debug_info("core_closed_call_test", f, args, {})
+  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, debug_info=dbg), in_tree)
   out = core.closed_call_p.bind(f, *args)
   return jax.tree.unflatten(out_tree(), out)

@@ -362,7 +363,10 @@ def body(c, _):

     aval = core.ShapedArray((), jnp.dtype('int32'))
     pval = pe.PartialVal.unknown(aval)
-    jaxpr, _, _ = pe.trace_to_jaxpr_nounits(lu.wrap_init(f), [pval], False)
+    jaxpr, _, _ = pe.trace_to_jaxpr_nounits(
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (0,), {})),
+        [pval], False)
     dropvar, b = jaxpr.eqns[0].outvars
     self.assertEqual(dropvar.aval, aval)

@@ -548,12 +552,13 @@ def test_staging_basic(self):
     a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
     b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

-    @lu.wrap_init
     def f(x, y):
       return x, y

     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        f, [n, a, b], keep_inputs=[False, True, True])
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (1, 2), {})),
+        [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 3)
     self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -569,15 +574,16 @@ def test_staging_nested(self):
     a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
     b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

-    @lu.wrap_init
     def f(x, y):
       @jax.jit
       def g(x, y, z, w):
         return (x, w)
       return g(x, y, x, y)

     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        f, [n, a, b], keep_inputs=[False, True, True])
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (0, 1), {})),
+        [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
     self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
@@ -605,15 +611,16 @@ def test_staging_nested_including_shape_arg(self):
     a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
     b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

-    @lu.wrap_init
     def f(x, y):
       @jax.jit
       def g(_, x, y, z, w):
         return (x, w)
       return g(x.shape[0], x, y, x, y)

     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        f, [n, a, b], keep_inputs=[False, True, True])
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (1, 2), {})),
+        [n, a, b], keep_inputs=[False, True, True])

     # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
     #     d:f32[a] e:f32[a] = xla_call[
@@ -641,15 +648,16 @@ def test_staging_primitive_applications(self):
     a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
     b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)

-    @lu.wrap_init
     def f(x, y):
       z = lax.mul(x, y)
       w = lax.sin(z)
       u = lax_internal._reduce_sum(w, [0])
       return (u,)

     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        f, [n, a, b], keep_inputs=[False, True, True])
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (1, 2), {})),
+        [n, a, b], keep_inputs=[False, True, True])

     self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
     self.assertLen(jaxpr.eqns, 3)
@@ -667,14 +675,15 @@ def test_typecheck_staging_nested(self):
     a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False)
     b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False)

-    @lu.wrap_init
     def f(a, b):
       @jax.jit
       def g(x): return x
       return g(a),

     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        f, [n, m, a, b], keep_inputs=[False, False, True, True])
+        lu.wrap_init(f,
+                     debug_info=debug_info("test", f, (1, 2), {})),
+        [n, m, a, b], keep_inputs=[False, False, True, True])
     # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
     #     e:f32[a] = xla_call[
     #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }

‎tests/debug_info_test.py

+4 -5

@@ -881,9 +881,9 @@ def to_remat(x):
        tracer_spy=tracer_spy,
        expected_jaxpr_debug_infos=[
            # TODO(necula): what are these flat_index components?
-           "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][<flat index 0>][0][<flat index 0>][0][0]",
-           re.compile(r"traced_for=custom_jvp fun, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
-           re.compile(r"traced_for=jit, fun=relu at .*/nn/functions.py:.*, arg_names=x, result_paths="),
+           "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
+           re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
+           re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
        ],
        check_tracer_arg_name=True,
        expected_tracer_debug_infos=[
@@ -1071,8 +1071,7 @@ def fn_tp(r, t):
        expected_jaxpr_debug_infos=[
            "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=",
            "traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=",
-           # TODO(necula): flat_index?
-           "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[<flat index 0>][0][0],[<flat index 0>][0][1]",
+           "traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0][0][0],[0][0][1]",
        ],
        check_tracer_arg_name=True,
        expected_tracer_debug_infos=[
expected_tracer_debug_infos=[

‎tests/shard_map_test.py

+10 -2

@@ -29,6 +29,7 @@

 import jax
 import jax.ad_checkpoint
+from jax import api_util
 from jax import lax
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -1329,7 +1330,11 @@ def foo(x):

   def test_rewrite_process_call(self):
     def f(x):
-      return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
+      return core.call_p.bind(
+          lu.wrap_init(lambda x: [2. * x],
+                       debug_info=api_util.debug_info("test", lambda x: [2. * x],
+                                                      (x,), {})),
+          x)[0] * x

     mesh = jtu.create_mesh((4,), ('x',))
     g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
@@ -1345,7 +1350,10 @@ def test_rewrite_post_process_call(self):
     @jax.jit
     @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
     def f(x):
-      return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x
+      return core.call_p.bind(
+          lu.wrap_init(lambda: [2. * x],
+                       debug_info=api_util.debug_info("test", lambda: [2. * x],
+                                                      (), {})))[0] * x

     x = jnp.arange(4.)
     y = f(x)
