GPT models built with JAX
This project implements the GPT series of models using JAX and Flax's NNX library.
Install the uv Python package management tool:
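As a taste of the NNX style, here is a minimal, illustrative sketch of a GPT-style feed-forward block (not the repo's actual model code; the class name and dimensions are made up for illustration):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MLPBlock(nnx.Module):
    """GPT-style feed-forward block; a toy illustration of the NNX API."""

    def __init__(self, d_model: int, rngs: nnx.Rngs):
        self.fc = nnx.Linear(d_model, 4 * d_model, rngs=rngs)
        self.proj = nnx.Linear(4 * d_model, d_model, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.proj(jax.nn.gelu(self.fc(x)))

block = MLPBlock(d_model=768, rngs=nnx.Rngs(0))
y = block(jnp.ones((1, 1024, 768)))  # (batch, seq, d_model)
```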
curl -LsSf https://astral.sh/uv/install.sh | sh
The main commands available in the Makefile are:
make install
- Install dependencies from lockfile

make dev
- Install all dependencies including dev from lockfile

make clean
- Clean build artifacts and cache

make build
- Build package

make lint
- Run linting

make format
- Format code

make lab
- Run Jupyter lab server from the project directory
To see all available commands and their descriptions, run: make help
The training run can be reproduced using notebooks/train_gpt2.ipynb.
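The broad shape of an NNX training step looks like the sketch below. This is a minimal stand-in, not the notebook's code: `TinyLM` is a hypothetical placeholder for the GPT model, and the optimizer API details vary across Flax versions.

```python
import jax.numpy as jnp
import optax
from flax import nnx

class TinyLM(nnx.Module):
    """Hypothetical stand-in for the notebook's GPT model: embed -> head."""

    def __init__(self, vocab: int, d_model: int, rngs: nnx.Rngs):
        self.embed = nnx.Embed(vocab, d_model, rngs=rngs)
        self.head = nnx.Linear(d_model, vocab, rngs=rngs)

    def __call__(self, tokens):
        return self.head(self.embed(tokens))  # (batch, seq, vocab)

model = TinyLM(vocab=50257, d_model=128, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(3e-4))

@nnx.jit
def train_step(model, optimizer, tokens, targets):
    def loss_fn(model):
        logits = model(tokens)
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # applies the AdamW update in place
    return loss

tokens = jnp.zeros((2, 16), jnp.int32)
loss = train_step(model, optimizer, tokens, tokens)  # real code would shift targets by one
```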
A machine with 8 x NVIDIA A100 80GB GPUs was used to train for 1 epoch on a 10B-token sample of the FineWeb-Edu dataset. Validation was performed on 1% of the dataset.
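For the multi-GPU setup, simple data parallelism in JAX might look like the following (a sketch under the assumption that the batch is sharded across devices; the notebook may use a different strategy):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning all local devices (8 on the machine described above).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the leading (batch) dimension across devices; jit-compiled computation
# over such inputs then runs data-parallel, with gradients reduced automatically.
batch = jnp.zeros((8, 1024), jnp.int32)
batch = jax.device_put(batch, NamedSharding(mesh, P("data")))
```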
The trained model was evaluated on the HellaSwag benchmark, achieving a score of 0.3025.
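HellaSwag is scored by ranking four candidate endings per context. A common approach, sketched below, picks the ending with the highest average log-likelihood under the model (the `model` call signature is an assumption, not the repo's API):

```python
import jax
import jax.numpy as jnp

def ending_logprob(model, context_ids: jnp.ndarray, ending_ids: jnp.ndarray) -> jnp.ndarray:
    """Mean log-likelihood of the ending tokens given the context.

    Assumes `model` maps token ids (batch, seq) to logits (batch, seq, vocab).
    """
    ids = jnp.concatenate([context_ids, ending_ids])[None, :]
    logits = model(ids[:, :-1])                     # position t predicts token t+1
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    targets = ids[:, 1:]
    tok_lp = jnp.take_along_axis(logprobs, targets[..., None], axis=-1)[..., 0]
    return tok_lp[0, -ending_ids.shape[0]:].mean()  # score only the ending tokens

# The predicted answer is the argmax of ending_logprob over the four endings.
```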