26
26
from jax import lax
27
27
from jax import numpy as jnp
28
28
from jax import jvp , linearize , vjp , jit , make_jaxpr
29
- from jax .api_util import flatten_fun_nokwargs
29
+ from jax .api_util import flatten_fun_nokwargs , debug_info
30
30
from jax ._src import config
31
-
32
31
from jax ._src import core
33
32
from jax ._src import linear_util as lu
34
33
from jax ._src import util
@@ -48,14 +47,16 @@ def call(f, *args):
48
47
@util .curry
49
48
def core_call (f , * args ):
50
49
args , in_tree = jax .tree .flatten (args )
51
- f , out_tree = flatten_fun_nokwargs (lu .wrap_init (f ), in_tree )
50
+ dbg = debug_info ("core_call_test" , f , args , {})
51
+ f , out_tree = flatten_fun_nokwargs (lu .wrap_init (f , debug_info = dbg ), in_tree )
52
52
out = core .call_p .bind (f , * args )
53
53
return jax .tree .unflatten (out_tree (), out )
54
54
55
55
@util .curry
56
56
def core_closed_call (f , * args ):
57
57
args , in_tree = jax .tree .flatten (args )
58
- f , out_tree = flatten_fun_nokwargs (lu .wrap_init (f ), in_tree )
58
+ dbg = debug_info ("core_closed_call_test" , f , args , {})
59
+ f , out_tree = flatten_fun_nokwargs (lu .wrap_init (f , debug_info = dbg ), in_tree )
59
60
out = core .closed_call_p .bind (f , * args )
60
61
return jax .tree .unflatten (out_tree (), out )
61
62
@@ -362,7 +363,10 @@ def body(c, _):
362
363
363
364
aval = core .ShapedArray ((), jnp .dtype ('int32' ))
364
365
pval = pe .PartialVal .unknown (aval )
365
- jaxpr , _ , _ = pe .trace_to_jaxpr_nounits (lu .wrap_init (f ), [pval ], False )
366
+ jaxpr , _ , _ = pe .trace_to_jaxpr_nounits (
367
+ lu .wrap_init (f ,
368
+ debug_info = debug_info ("test" , f , (0 ,), {})),
369
+ [pval ], False )
366
370
dropvar , b = jaxpr .eqns [0 ].outvars
367
371
self .assertEqual (dropvar .aval , aval )
368
372
@@ -548,12 +552,13 @@ def test_staging_basic(self):
548
552
a = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
549
553
b = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
550
554
551
- @lu .wrap_init
552
555
def f (x , y ):
553
556
return x , y
554
557
555
558
jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (
556
- f , [n , a , b ], keep_inputs = [False , True , True ])
559
+ lu .wrap_init (f ,
560
+ debug_info = debug_info ("test" , f , (1 , 2 ), {})),
561
+ [n , a , b ], keep_inputs = [False , True , True ])
557
562
558
563
self .assertLen (jaxpr .invars , 3 )
559
564
self .assertEqual ((jaxpr .invars [0 ],), jaxpr .invars [1 ].aval .shape )
@@ -569,15 +574,16 @@ def test_staging_nested(self):
569
574
a = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
570
575
b = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
571
576
572
- @lu .wrap_init
573
577
def f (x , y ):
574
578
@jax .jit
575
579
def g (x , y , z , w ):
576
580
return (x , w )
577
581
return g (x , y , x , y )
578
582
579
583
jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (
580
- f , [n , a , b ], keep_inputs = [False , True , True ])
584
+ lu .wrap_init (f ,
585
+ debug_info = debug_info ("test" , f , (0 , 1 ), {})),
586
+ [n , a , b ], keep_inputs = [False , True , True ])
581
587
582
588
self .assertLen (jaxpr .invars , 1 + 2 ) # one axis size var, two other inputs
583
589
self .assertEqual ((jaxpr .invars [0 ],), jaxpr .invars [1 ].aval .shape )
@@ -605,15 +611,16 @@ def test_staging_nested_including_shape_arg(self):
605
611
a = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
606
612
b = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
607
613
608
- @lu .wrap_init
609
614
def f (x , y ):
610
615
@jax .jit
611
616
def g (_ , x , y , z , w ):
612
617
return (x , w )
613
618
return g (x .shape [0 ], x , y , x , y )
614
619
615
620
jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (
616
- f , [n , a , b ], keep_inputs = [False , True , True ])
621
+ lu .wrap_init (f ,
622
+ debug_info = debug_info ("test" , f , (1 , 2 ), {})),
623
+ [n , a , b ], keep_inputs = [False , True , True ])
617
624
618
625
# { lambda ; a:i32[] b:f32[a] c:f32[a]. let
619
626
# d:f32[a] e:f32[a] = xla_call[
@@ -641,15 +648,16 @@ def test_staging_primitive_applications(self):
641
648
a = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
642
649
b = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
643
650
644
- @lu .wrap_init
645
651
def f (x , y ):
646
652
z = lax .mul (x , y )
647
653
w = lax .sin (z )
648
654
u = lax_internal ._reduce_sum (w , [0 ])
649
655
return (u ,)
650
656
651
657
jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (
652
- f , [n , a , b ], keep_inputs = [False , True , True ])
658
+ lu .wrap_init (f ,
659
+ debug_info = debug_info ("test" , f , (1 , 2 ), {})),
660
+ [n , a , b ], keep_inputs = [False , True , True ])
653
661
654
662
self .assertLen (jaxpr .invars , 1 + 2 ) # one axis size var, two other inputs
655
663
self .assertLen (jaxpr .eqns , 3 )
@@ -667,14 +675,15 @@ def test_typecheck_staging_nested(self):
667
675
a = core .DShapedArray ((DBIdx (0 ),), jnp .dtype ('float32' ), weak_type = False )
668
676
b = core .DShapedArray ((DBIdx (1 ),), jnp .dtype ('float32' ), weak_type = False )
669
677
670
- @lu .wrap_init
671
678
def f (a , b ):
672
679
@jax .jit
673
680
def g (x ): return x
674
681
return g (a ),
675
682
676
683
jaxpr , _ , _ , () = pe .trace_to_jaxpr_dynamic (
677
- f , [n , m , a , b ], keep_inputs = [False , False , True , True ])
684
+ lu .wrap_init (f ,
685
+ debug_info = debug_info ("test" , f , (1 , 2 ), {})),
686
+ [n , m , a , b ], keep_inputs = [False , False , True , True ])
678
687
# { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
679
688
# e:f32[a] = xla_call[
680
689
# call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) }
0 commit comments