@@ -460,12 +460,13 @@ def _addupdate_discharge(x, val, idx, tree):
460
460
return _prepend_scatter (x , indexer , val , add = True )
461
461
462
462
@weakref_lru_cache
463
- def _cached_closed_jaxpr_discharge (closed_jaxpr ):
463
+ def _cached_closed_jaxpr_discharge (closed_jaxpr : core . ClosedJaxpr ):
464
464
jaxpr , consts = closed_jaxpr .jaxpr , closed_jaxpr .consts
465
465
num_outs = len (jaxpr .outvars )
466
466
discharged_jaxpr , discharged_consts = discharge_state (jaxpr , consts )
467
467
discharged_closed_jaxpr = core .ClosedJaxpr (discharged_jaxpr , discharged_consts )
468
- fun = lu .wrap_init (core .jaxpr_as_fun (discharged_closed_jaxpr ))
468
+ fun = lu .wrap_init (core .jaxpr_as_fun (discharged_closed_jaxpr ),
469
+ debug_info = discharged_jaxpr .debug_info )
469
470
return discharged_closed_jaxpr , num_outs , fun
470
471
471
472
@register_discharge_rule (core .closed_call_p )
@@ -598,7 +599,6 @@ def _convert_outputs_to_writes(
598
599
assert not jaxpr .constvars , "Jaxpr shouldn't have constvars."
599
600
600
601
in_avals = [v .aval for v in jaxpr .invars ]
601
- @lu .wrap_init
602
602
def eval_jaxpr (* refs ):
603
603
# We split the refs into the original input refs and the dummy residual
604
604
# refs.
@@ -610,14 +610,15 @@ def eval_jaxpr(*refs):
610
610
res_ref_avals = [AbstractRef (v .aval ) if not isinstance (v .aval , AbstractRef )
611
611
else v .aval for v in jaxpr .outvars ]
612
612
jaxpr , _ , consts , () = pe .trace_to_jaxpr_dynamic (
613
- eval_jaxpr , [* in_avals , * res_ref_avals ])
613
+ lu .wrap_init (eval_jaxpr ,
614
+ debug_info = jaxpr .debug_info ),
615
+ [* in_avals , * res_ref_avals ])
614
616
assert not consts
615
617
return jaxpr , [core .ShapedArray (a .shape , a .dtype ) for a in res_ref_avals ]
616
618
617
619
def _convert_inputs_to_reads (num_res : int , jaxpr : core .Jaxpr ) -> core .Jaxpr :
618
620
assert not jaxpr .constvars , "Jaxpr should not have constvars"
619
621
620
- @lu .wrap_init
621
622
def eval_jaxpr (* refs ):
622
623
residual_refs , orig_refs = split_list (refs , [num_res ])
623
624
residual_vals = [r [...] for r in residual_refs ]
@@ -629,7 +630,9 @@ def eval_jaxpr(*refs):
629
630
res_ref_avals = [AbstractRef (aval ) if not isinstance (aval , AbstractRef ) else
630
631
aval for aval in res_val_avals ]
631
632
jaxpr , _ , (), () = pe .trace_to_jaxpr_dynamic (
632
- eval_jaxpr , [* res_ref_avals , * orig_ref_avals ])
633
+ lu .wrap_init (eval_jaxpr ,
634
+ debug_info = jaxpr .debug_info ),
635
+ [* res_ref_avals , * orig_ref_avals ])
633
636
return jaxpr
634
637
635
638
def _run_state_partial_eval (trace : pe .JaxprTrace , * tracers : pe .JaxprTracer ,
@@ -845,12 +848,13 @@ def _run_state_partial_eval_custom(
845
848
* [v .aval for v in res_staged_invars ], ** staged_params )
846
849
_ , staged_outvars = partition_list (in_unknowns , eqn .outvars )
847
850
if num_res :
848
- @ lu . wrap_init
851
+
849
852
def staged (* args ):
850
853
out = run_state_p .bind (* args , ** staged_params )
851
854
return out [num_res :]
852
- staged_call_jaxpr , _ , (), () = pe .trace_to_jaxpr_dynamic (staged ,
853
- [v .aval for v in res_staged_invars ])
855
+ staged_call_jaxpr , _ , (), () = pe .trace_to_jaxpr_dynamic (
856
+ lu .wrap_init (staged , debug_info = jaxpr_staged .debug_info ),
857
+ [v .aval for v in res_staged_invars ])
854
858
eqn_staged = pe .new_jaxpr_eqn (res_staged_invars ,
855
859
staged_outvars ,
856
860
core .closed_call_p ,
@@ -918,7 +922,9 @@ def trans(*args):
918
922
ad .backward_pass (tangent_jaxpr , False , (), (* primals_args , * ct_args ), ())
919
923
return []
920
924
jaxpr_trans , _ , consts , () = pe .trace_to_jaxpr_dynamic (
921
- lu .wrap_init (trans ), [v .aval for v in jaxpr .invars ])
925
+ lu .wrap_init (trans ,
926
+ debug_info = jaxpr .debug_info ),
927
+ [v .aval for v in jaxpr .invars ])
922
928
return jaxpr_trans , consts
923
929
924
930
def _run_state_transpose (in_cts , * args , jaxpr : core .Jaxpr ,
0 commit comments