I've been running a speed test in my repository (https://github.com/danielpmorton/cbfpy/blob/main/test/test_speed.py) and I've noticed a significant slowdown with newer versions of JAX/jaxlib. With JAX 0.4.32 and later, this script runs at approximately 1/3 of the speed it reaches with 0.4.31 and earlier versions.
I know 0.4.32 is a yanked release, but this slowdown still exists in 0.4.33 and later.
I'll try to come up with a minimal working example, as I know this is hard to reproduce without installing my repo. In general, I'm solving a lot of QPs on CPU using qpax (https://github.com/kevin-tracy/qpax)
(cbfpy_release) dmorton@asl-nuc:~/cbfpy_release$ pip install "jaxlib==0.4.32" "jax==0.4.32"
(... various pip installation logging)
(cbfpy_release) dmorton@asl-nuc:~/cbfpy_release$ python test/test_speed.py
pybullet build time: Nov 28 2023 23:48:36
pygame 2.6.1 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
ACC average Hz: 24240.550542352157
.Point robot average Hz: 24371.811879894532
.
----------------------------------------------------------------------
Ran 2 tests in 2.409s
OK
(cbfpy_release) dmorton@asl-nuc:~/cbfpy_release$ pip install "jaxlib==0.4.31" "jax==0.4.31"
(... various pip installation logging)
(cbfpy_release) dmorton@asl-nuc:~/cbfpy_release$ python test/test_speed.py
pybullet build time: Nov 28 2023 23:48:36
pygame 2.6.1 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html
ACC average Hz: 57967.98398949308
.Point robot average Hz: 62571.14449011536
.
----------------------------------------------------------------------
Ran 2 tests in 2.073s
OK
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.10.8 (main, May 30 2024, 10:58:14) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='asl-nuc', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May 7 09:00:52 UTC 2', machine='x86_64')
Another note: the following code, which I noticed in a few other threads here (clemisch/jaxtomo#3, #23590), does not fix the problem and in fact makes the issue much worse:
import os
XLA_flag = "--xla_cpu_use_thunk_runtime=false "
os.environ["XLA_FLAGS"] = XLA_flag
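For anyone retrying this workaround, here is a slightly more defensive variant, offered only as a sketch. It assumes the flag has to be in the environment before JAX initializes its backend, so the safe choice is to call it before the first `import jax`. `add_xla_flag` is a hypothetical helper, not part of JAX; it appends to any existing `XLA_FLAGS` value instead of overwriting it.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS, keeping any flags that are already set.

    Hypothetical helper for illustration; not a JAX or XLA API.
    """
    existing = os.environ.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# Assumption: this runs before the first `import jax` in the process.
add_xla_flag("--xla_cpu_use_thunk_runtime=false")
```

Calling the helper twice with the same flag leaves `XLA_FLAGS` unchanged, so it is safe to use from more than one entry point.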
Thanks for the report! There's quite a lot going on in your example, and I'm not very familiar with the dependencies, so I don't immediately have intuition about what to try. That said, I'm also surprised that setting that environment variable didn't improve the situation! It would be useful if you could dig a little deeper and isolate the issue with a smaller, self-contained reproducer so that we have something to work with.
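One possible starting point for such a reproducer is sketched below. It is not taken from cbfpy or qpax: `step` is a hypothetical stand-in workload (a tiny jitted linear solve) and `average_hz` is an illustrative timing helper. The idea is to time many small jitted calls on CPU and report the rate in Hz, so the same script can be run under different JAX versions and compared.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in workload: a small, well-conditioned linear solve.
    # Replace this with the smallest piece of the real computation that
    # still shows the slowdown.
    A = jnp.eye(4) + 0.1 * jnp.outer(x, x)
    return jnp.linalg.solve(A, x)


def average_hz(fn, x, n_iters=1000):
    """Return the average number of calls per second for fn(x)."""
    fn(x).block_until_ready()  # warm-up call so compilation is not timed
    start = time.perf_counter()
    for _ in range(n_iters):
        # block_until_ready() makes each iteration wait for the result,
        # so asynchronous dispatch does not hide the real cost.
        fn(x).block_until_ready()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


if __name__ == "__main__":
    x = jnp.arange(1.0, 5.0)
    print("jax:", jax.__version__)
    print(f"average Hz: {average_hz(step, x):.1f}")
```

Running the same file after each `pip install "jax==..." "jaxlib==..."` gives directly comparable numbers. If this stand-in does not show the regression, swapping a single qpax solve into `step` would be the natural next step.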