RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of the RNN and the transformer: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.



LoRA fork of RWKV-LM

A fork of RWKV-LM with LoRA finetuning support. Currently only RWKV-v4neo is supported. The LoRA module is implemented from scratch so that it works with the TorchScript JIT. Existing RWKV-v4neo models/checkpoints work out of the box. Only the LoRA-finetuned weights are checkpointed during training: this yields much smaller checkpoints, but you now need to specify the base model in order to use them. See args.MODEL_LOAD and args.MODEL_LORA in RWKV-v4neo/chat.py.
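
For orientation, a LoRA layer of this kind wraps a frozen linear projection with a trainable low-rank update. Below is a minimal, hypothetical sketch of a TorchScript-friendly LoRA linear; the class and attribute names are illustrative, not the fork's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoraLinear(nn.Module):  # hypothetical name, not the fork's actual class
    """A frozen linear projection plus a trainable low-rank update, scriptable with torch.jit."""
    def __init__(self, in_features: int, out_features: int, r: int = 8,
                 alpha: int = 16, dropout: float = 0.01):
        super().__init__()
        # Base weight is frozen; in practice it comes from the base checkpoint. Bias omitted for brevity.
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        # Only the low-rank factors A and B are trained.
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight)
        update = F.linear(F.linear(self.dropout(x), self.lora_A), self.lora_B)
        return base + self.scaling * update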

To finetune an existing model with LoRA, run the training script as you would for full finetuning (all your familiar options still apply), adding the LoRA options, from within the RWKV-v4neo directory:

python3 train.py \
  --load_model <pretrained base model> \
  --proj_dir <place to save checkpoints> \
  --data_file <data for finetune> \
  --data_type <data type for finetune, recommend binidx> \
  --vocab_size 50277 --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 --micro_bsz 2 --accumulate_grad_batches 4 \
  --n_layer 24 --n_embd 1024 --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 \
  --lora --lora_r 8 --lora_alpha 16 --lora_dropout 0.01 \
  --lora_load <lora checkpoint to continue training, optional> \
  --lora_parts=att,ffn,time,ln # configure which parts to finetune

The r, alpha and dropout options are yours to choose. att, ffn, time and ln refer to the TimeMix parameters, the ChannelMix parameters, the time decay/first/mix parameters, and the layernorm parameters, respectively; DON'T FORGET to list every group of parameters you want finetuned in --lora_parts (a sketch of what that selection amounts to follows below). I'm still experimenting with different configurations; reports of your experience are welcome!
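
Conceptually, only the injected LoRA matrices plus the parameter groups named in --lora_parts receive gradients, while everything else stays frozen. The following is a hypothetical sketch of that selection; the name patterns are illustrative, and in the actual trainer att and ffn likely control where LoRA adapters are injected rather than unfreezing those blocks' native weights.

import torch

lora_parts = {"att", "ffn", "time", "ln"}  # as in --lora_parts=att,ffn,time,ln

def keep_trainable(name: str) -> bool:
    # LoRA A/B factors are always trained; other groups only if named in lora_parts.
    if "lora_" in name:
        return True
    if "time" in lora_parts and ".time_" in name:   # time decay / first / mix parameters
        return True
    if "ln" in lora_parts and ".ln" in name:        # layernorm weights and biases
        return True
    return False

def freeze_for_lora(model: torch.nn.Module) -> None:
    # Everything not selected above is frozen and excluded from optimizer updates.
    for name, param in model.named_parameters():
        param.requires_grad = keep_trainable(name)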

Use json2binidx to convert your data into binidx, the format best suited to this trainer implementation. Once you have the pair of files path/to/foo.bin and path/to/foo.idx, pass --data_file path/to/foo --data_type binidx as arguments; note that the .bin and .idx suffixes are omitted from --data_file.
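
If your raw data is plain text, json2binidx typically consumes a jsonl file (one JSON object per line, each with a "text" field); check its README for the exact input format and conversion flags it expects. Assuming that format, preparing the input might look like:

import json

# Assumed input format for json2binidx: one {"text": ...} object per line (verify against its README).
documents = ["First training document ...", "Second training document ..."]

with open("path/to/foo.jsonl", "w", encoding="utf-8") as f:
    for doc in documents:
        f.write(json.dumps({"text": doc}, ensure_ascii=False) + "\n")

# After running json2binidx on foo.jsonl you should end up with
# path/to/foo.bin and path/to/foo.idx, matching --data_file path/to/foo --data_type binidx.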

To use the finetuned model, run chat.py as usual with the checkpoints from your specified proj_dir, but remember to align the LoRA-related options with what you specified during training:

args.MODEL_LOAD = 'your_base_model.pth'   # the pretrained base model
args.MODEL_LORA = 'your_lora_checkpoint.pth'
args.lora_r = 8
args.lora_alpha = 16                      # must match the values used during training
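
For intuition, the LoRA update applied on top of the base weights is W' = W + (alpha / r) * B A, so a LoRA checkpoint can in principle be merged into the base model offline. A hedged sketch follows (not chat.py's actual loading code; checkpoint key names are illustrative):

import torch

# Hypothetical offline merge of a LoRA checkpoint into base weights: W' = W + (alpha / r) * B @ A.
# Key names below are illustrative; the real checkpoints may use different naming.
base = torch.load("your_base_model.pth", map_location="cpu")
lora = torch.load("your_lora_checkpoint.pth", map_location="cpu")
r, alpha = 8, 16  # must match the values used during training

merged = dict(base)
for key, w in base.items():
    if not key.endswith(".weight"):
        continue
    prefix = key[: -len(".weight")]
    a_key, b_key = prefix + ".lora_A", prefix + ".lora_B"
    if a_key in lora and b_key in lora:
        merged[key] = w + (alpha / r) * (lora[b_key].float() @ lora[a_key].float())

torch.save(merged, "merged_model.pth")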
