Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 23a3c1a

Browse files
committedFeb 5, 2025
[better_errors] Add debug info to the Jaxprs formed for AD
Following jax-ml#26078 , we add debug info to more calls of lu.wrap_init.
1 parent 414449e commit 23a3c1a

22 files changed

+480
-168
lines changed
 

‎jax/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ pytype_strict_library(
564564
srcs = ["_src/interpreters/mlir.py"],
565565
deps = [
566566
":ad_util",
567+
":api_util",
567568
":config",
568569
":core",
569570
":dtypes",

‎jax/_src/ad_checkpoint.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -682,12 +682,13 @@ def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
682682
return _transpose_jaxpr(jaxpr, tuple(in_linear), tuple(out_zeros))
683683

684684
@weakref_lru_cache
685-
def _transpose_jaxpr(jaxpr, in_lin, out_zeros):
685+
def _transpose_jaxpr(jaxpr: core.ClosedJaxpr,
686+
in_lin: Sequence[bool],
687+
out_zeros: Sequence[bool]):
686688
in_avals = ([a for a, lin in zip(jaxpr.in_avals, in_lin ) if not lin] +
687689
[a for a, zero in zip(jaxpr.out_avals, out_zeros) if not zero])
688690
cell = lambda: None
689691

690-
@lu.wrap_init
691692
def transposed(*args_flat):
692693
ins_flat, out_cts_flat = split_list(args_flat, [len(in_lin) - sum(in_lin)])
693694

@@ -715,7 +716,10 @@ def transposed(*args_flat):
715716
in_cts_nz, _ = partition_list(in_zeros, in_cts)
716717
return in_cts_nz
717718

718-
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
719+
transposed_wrapped = lu.wrap_init(transposed,
720+
debug_info=jaxpr.jaxpr.debug_info)
721+
transposed_jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(
722+
transposed_wrapped, in_avals)
719723
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
720724
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error
721725

‎jax/_src/api.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ def vmap_f(*args, **kwargs):
983983
"to the positional arguments passed to the function, "
984984
f"but got {len(in_axes)=}, {len(args)=}")
985985
args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
986-
f = lu.wrap_init(fun)
986+
f = lu.wrap_init(fun, debug_info=debug_info("vmap", fun, args, kwargs))
987987
flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
988988
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
989989
axis_size_ = (axis_size if axis_size is not None else
@@ -1715,15 +1715,15 @@ def jvp(
17151715
0.19900084
17161716
"""
17171717
check_callable(fun)
1718-
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
1719-
1720-
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
1721-
"""Variant of jvp() that takes an lu.WrappedFun."""
17221718
if (not isinstance(primals, (tuple, list)) or
17231719
not isinstance(tangents, (tuple, list))):
17241720
raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
17251721
f"found {type(primals).__name__} and {type(tangents).__name__}.")
1722+
return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
1723+
primals, tangents, has_aux=has_aux)
17261724

1725+
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
1726+
"""Variant of jvp() that takes an lu.WrappedFun."""
17271727
ps_flat, tree_def = tree_flatten(primals)
17281728
ts_flat, tree_def_2 = tree_flatten(tangents)
17291729
if tree_def != tree_def_2:
@@ -1835,7 +1835,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
18351835
-6.676704
18361836
"""
18371837
check_callable(fun)
1838-
f = lu.wrap_init(fun)
1838+
f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
18391839
primals_flat, in_tree = tree_flatten(primals)
18401840
if has_aux:
18411841
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
@@ -1983,8 +1983,9 @@ def vjp(
19831983
raise NotImplementedError("reduce_axes argument to vjp is deprecated")
19841984
del reduce_axes
19851985
check_callable(fun)
1986-
return _vjp(
1987-
lu.wrap_init(fun), *primals, has_aux=has_aux)
1986+
wrapped_fun = lu.wrap_init(fun,
1987+
debug_info=debug_info("vjp", fun, primals, {}))
1988+
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
19881989

19891990
def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
19901991
"""Variant of vjp() that takes an lu.WrappedFun."""
@@ -2049,7 +2050,10 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
20492050
raise NotImplementedError("reduce_axes argument to transpose is deprecated")
20502051
del reduce_axes
20512052
primals_flat, in_tree = tree_flatten(primals)
2052-
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
2053+
flat_fun, out_tree = flatten_fun_nokwargs(
2054+
lu.wrap_init(fun,
2055+
debug_info=debug_info("linear_transpose", fun, primals, {})),
2056+
in_tree)
20532057
in_avals = map(shaped_abstractify, primals_flat)
20542058
in_dtypes = map(dtypes.dtype, in_avals)
20552059

‎jax/_src/api_util.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
6767
return tuple(map(_ensure_str, x))
6868

6969
@lu.transformation_with_aux2
70-
def flatten_fun(f, store, in_tree, *args_flat):
70+
def flatten_fun(f: Callable, store: lu.Store,
71+
in_tree: PyTreeDef, *args_flat):
7172
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
7273
ans = f(*py_args, **py_kwargs)
7374
ans, out_tree = tree_flatten(ans)
@@ -587,8 +588,8 @@ def debug_info(
587588
args: Sequence[Any],
588589
kwargs: dict[str, Any],
589590
*,
590-
static_argnums: tuple[int, ...] = (),
591-
static_argnames: tuple[str, ...] = (),
591+
static_argnums: Sequence[int] = (),
592+
static_argnames: Sequence[str] = (),
592593
result_paths_thunk: Callable[[], tuple[str, ...]] | None = None,
593594
# TODO(necula): check if we really need this, e.g., to speed up tracing?
594595
sourceinfo: str | None = None,

‎jax/_src/checkify.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,9 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
361361
else:
362362
jaxpr, consts = call_jaxpr, ()
363363
consts_ = tuple(HashableWrapper(c) for c in consts)
364-
partial_checkify = lu.hashable_partial(lu.wrap_init(
365-
checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
364+
partial_checkify = lu.hashable_partial(
365+
lu.wrap_init(checkify_jaxpr_flat_hashable, debug_info=jaxpr.debug_info),
366+
jaxpr, consts_, enabled_errors, err_tree)
366367
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
367368
partial_checkify)
368369

@@ -746,7 +747,7 @@ def jaxpr_to_checkify_jaxpr(
746747
checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
747748
jaxpr.consts, enabled_errors,
748749
err_tree)
749-
fun = lu.wrap_init(checkify_jaxpr_partial)
750+
fun = lu.wrap_init(checkify_jaxpr_partial, debug_info=jaxpr.jaxpr.debug_info)
750751
fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
751752

752753
new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)

‎jax/_src/core.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2416,8 +2416,9 @@ def call_impl(f: lu.WrappedFun, *args, **params):
24162416
class ClosedCallPrimitive(CallPrimitive):
24172417
def get_bind_params(self, params):
24182418
new_params = dict(params)
2419-
jaxpr = new_params.pop('call_jaxpr')
2420-
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
2419+
jaxpr: ClosedJaxpr = new_params.pop('call_jaxpr')
2420+
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts),
2421+
debug_info=jaxpr.jaxpr.debug_info)
24212422
return [subfun], new_params
24222423

24232424
closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')

0 commit comments

Comments
 (0)
Please sign in to comment.