
0️⃣1️⃣🤗 BitNet-Transformers: Hugging Face Transformers implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch with the Mistral architecture

BitNet Architecture


Prepare Dev env

# Clone this repo
git clone https://github.com/DewEfresh/bitnet-transformers
cd bitnet-transformers

# Install requirements
pip install -r clm_requirements.txt

# Clone transformers repo
git clone https://github.com/huggingface/transformers
pip install -e transformers

# Update Llama(2) model
rm ./transformers/src/transformers/models/llama/modeling_llama.py
ln -s $(pwd)/bitnet_mistral/modeling_llama.py ./transformers/src/transformers/models/llama/modeling_llama.py

# Update Mistral model
rm ./transformers/src/transformers/models/mistral/modeling_mistral.py
ln -s $(pwd)/bitnet_mistral/modeling_mistral.py ./transformers/src/transformers/models/mistral/modeling_mistral.py

The commands above replace the stock modeling_llama.py and modeling_mistral.py in transformers with symlinks to the files in bitnet_mistral/. Since the files are linked, any changes made to them in this repo are reflected directly in the transformers install.
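To confirm that the editable install picks up the linked files, a quick sanity check (a minimal sketch, nothing repo-specific) is to print where Python resolves the modeling modules; both paths should point into ./transformers/src/, where the files are now symlinks into bitnet_mistral/:

# Sanity check: both modules should resolve to the symlinked files
import transformers.models.llama.modeling_llama as modeling_llama
import transformers.models.mistral.modeling_mistral as modeling_mistral

print(modeling_llama.__file__)
print(modeling_mistral.__file__)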

Train Wikitext-103

Training loss graph for BitLLAMA trained on Wikitext-103

You can track metrics via wandb

./train_wikitext.sh
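If you want runs to land in your own wandb workspace, a minimal sketch (assuming train_wikitext.sh wraps a Hugging Face Trainer run with wandb reporting enabled; the project name below is just a placeholder) is:

# Set the wandb project before launching the training script
import os
import subprocess

os.environ["WANDB_PROJECT"] = "bitnet-wikitext103"  # placeholder project name
subprocess.run(["bash", "./train_wikitext.sh"], check=True)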

GPU Mem Usage Comparison

Train Config

  • Batch size: 1
  • Gradient accumulation: 1
  • Seq length: 2048
  • Model: LLamaForCausalLM with BitLinear layer
  • Model size: 47,452,672 parameters (47.5M)

Original LLAMA - 16bit

  • Uses 250MB of GPU memory for model weights

BitLLAMA - Mixed 16bit

  • Uses 200MB of GPU memory for model weights
  • Stores model weights in bf16 (or fp16)
  • Stores the -1/+1 1-bit weights in int8
  • Uses more memory during training than the original LLAMA, since the 1-bit and 16-bit weights are kept together (see the sketch below)
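The extra training memory comes from keeping the 16-bit latent weights for the optimizer alongside the binarized copies used in the forward pass. A minimal sketch of that pattern (an illustration only, not the repo's exact BitLinear, which also quantizes activations as in the paper):

# Illustrative BitLinear-style layer: 16-bit latent weight, binarized on the fly
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinearSketch(nn.Linear):
    def forward(self, x):
        w = self.weight                         # 16-bit latent weight, kept for the optimizer
        w_centered = w - w.mean()
        scale = w_centered.abs().mean()         # per-tensor scale
        w_bin = torch.sign(w_centered) * scale  # -1/+1 times scale
        # Straight-through estimator: forward uses w_bin, gradients flow to w
        w_ste = w + (w_bin - w).detach()
        return F.linear(x, w_ste, self.bias)

During training both w and w_bin are alive at the same time, which is why this variant uses more memory than the plain 16-bit model.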

BitLLAMA - 8bit

  • Uses 100MB of GPU memory for model weights
  • Dequantizes to bf16 (or fp16) on the fly when needed
  • Stores the 1-bit BitLinear weights and the other weights in 8-bit (see the sketch below)
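As a rough illustration (not the repo's code), 8-bit storage amounts to keeping the sign of each weight in int8 plus one floating-point scale, and materializing bf16 weights only when a layer runs:

# Illustrative 8-bit storage for binarized weights
import torch

w = torch.randn(1024, 1024)                          # latent full-precision weight
scale = w.abs().mean()                               # per-tensor scale
w_int8 = torch.where(w >= 0, 1, -1).to(torch.int8)   # -1/+1, one byte per weight
w_bf16 = w_int8.to(torch.bfloat16) * scale           # dequantized on the fly when needed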

BitLLAMA - 1bit

  • Dequantizes to bf16 (or fp16) on the fly when needed
  • Stores the 1-bit weights in 1-bit form
  • GPU memory usage: TBD
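To actually reach about 1 bit per weight at rest, the signs can be packed eight per byte into uint8 and unpacked to bf16 on the fly. A minimal sketch of such packing (illustrative only, not the repo's storage format):

# Pack/unpack -1/+1 weights at 1 bit per weight (8 weights per uint8 byte)
import torch

def pack_sign_bits(w: torch.Tensor) -> torch.Tensor:
    bits = (w.flatten() > 0).to(torch.uint8)   # +1 -> 1, -1 -> 0
    pad = (-bits.numel()) % 8                  # pad to a multiple of 8
    if pad:
        bits = torch.cat([bits, bits.new_zeros(pad)])
    shifts = torch.arange(8, dtype=torch.uint8)
    return (bits.view(-1, 8) << shifts).sum(dim=1).to(torch.uint8)

def unpack_sign_bits(packed: torch.Tensor, numel: int) -> torch.Tensor:
    shifts = torch.arange(8, dtype=torch.uint8)
    bits = (packed.unsqueeze(1) >> shifts) & 1
    return bits.flatten()[:numel].to(torch.bfloat16) * 2 - 1   # 0/1 -> -1/+1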

Todo

  • Add BitLinear layer
  • Add LLamaForCausalLM model with BitLinear layer
    • Update .save_pretrained method (for 1-bit weight saving)
  • Add sample code for LM training
  • Update the BitLinear layer to use true 1-bit weights
    • Use uint8 instead of bfloat16
    • Use a custom CUDA kernel for the 1-bit weights
