
Commit ea86f61

hzhwcmhfjklj077 authored and committed Dec 25, 2023
add run gptq
1 parent 65c7303 commit ea86f61

File tree

3 files changed: +193 −0 lines changed

 

README.md

+48
@@ -723,6 +723,54 @@ tokenizer.save_pretrained(new_model_directory)

Note: For multi-GPU training, you need to specify the proper hyperparameters for distributed training based on your machine. In addition, we advise you to specify the maximum sequence length with the argument `--model_max_length`, based on your data, memory footprint, and training speed.

### Quantize Fine-tuned Models

This section applies to full-parameter/LoRA fine-tuned models. (Note: you do not need to quantize a Q-LoRA fine-tuned model, because it is already quantized.)
If you use LoRA, please follow the instructions above to merge your model before quantization.

We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize the fine-tuned model.

```bash
pip install auto-gptq optimum
```

Note: AutoGPTQ currently has a bug, reported in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). A [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495) is available; you can pull that branch and install from source.
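
For example, installing from the PR branch might look like this (a minimal sketch; it assumes the fix is the head of that PR, and the local branch name `gptq-fix` is arbitrary):

```bash
# Sketch: build AutoGPTQ from the workaround PR branch linked above.
git clone https://github.com/PanQiWei/AutoGPTQ.git
cd AutoGPTQ
git fetch origin pull/495/head:gptq-fix   # fetch the PR head into a local branch
git checkout gptq-fix
pip install -v .                          # build and install from source
```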

First, prepare the calibration data. You can reuse the fine-tuning data, or use other data following the same format.
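
For reference, a minimal calibration file consistent with what `run_gptq.py` expects (a JSON list of samples, each containing a `conversations` list of user/assistant turns) could look like the following; the contents here are placeholders:

```json
[
  {
    "conversations": [
      {"from": "user", "value": "What is GPTQ quantization?"},
      {"from": "assistant", "value": "GPTQ is a post-training quantization method for large language models."}
    ]
  }
]
```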

Second, run the following script:

```bash
python run_gptq.py \
    --model_name_or_path $YOUR_LORA_MODEL_PATH \
    --data_path $DATA \
    --out_path $OUTPUT_PATH \
    --bits 4 # 4 for int4; 8 for int8
```

This step requires GPUs and may take a few hours, depending on your data size and model size.

Then, copy all `*.py`, `*.cu`, and `*.cpp` files and `generation_config.json` to the output path. We also recommend overwriting `config.json` with the file from the corresponding official quantized model
(for example, if you are fine-tuning `Qwen-7B-Chat` and use `--bits 4`, you can find the `config.json` in [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json)).
You should also rename ``gptq.safetensors`` to ``model.safetensors``; a sketch of this step follows.
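
For example (a minimal sketch; `$YOUR_LORA_MODEL_PATH` is the merged source model and `$OUTPUT_PATH` is the quantization output directory, as above):

```bash
# Sketch: copy the remote-code files and generation config from the source model.
# Not every model directory contains *.cu/*.cpp files; skip those patterns if absent.
cp $YOUR_LORA_MODEL_PATH/*.py $OUTPUT_PATH/
cp $YOUR_LORA_MODEL_PATH/*.cu $YOUR_LORA_MODEL_PATH/*.cpp $OUTPUT_PATH/
cp $YOUR_LORA_MODEL_PATH/generation_config.json $OUTPUT_PATH/
# Overwrite config.json with the one from the matching official quantized model,
# then rename the quantized weights.
mv $OUTPUT_PATH/gptq.safetensors $OUTPUT_PATH/model.safetensors
```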

Finally, test the model by the same method used to load the official quantized models. For example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/your/model",
    device_map="auto",
    trust_remote_code=True
).eval()

response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```

### Profiling of Memory and Speed

We profile the GPU memory and training speed of both LoRA (LoRA (emb) refers to training the embedding and output layer, while LoRA has no trainable embedding or output layer) and Q-LoRA in the single-GPU training setup. In this test, we experiment on a single A100-SXM4-80G GPU, with CUDA 11.8 and PyTorch 2.0, and Flash Attention 2 is applied. We uniformly use a batch size of 1 and gradient accumulation of 8. We profile the memory (GB) and speed (s/iter) of inputs of different lengths, namely 256, 512, 1024, 2048, 4096, and 8192. We also report the statistics of full-parameter fine-tuning with Qwen-7B on 2 A100 GPUs; due to the limitation of GPU memory, we only report the statistics for 256, 512, and 1024 tokens.

README_CN.md

+49
@@ -713,6 +713,55 @@ tokenizer.save_pretrained(new_model_directory)

Note: For distributed training, you need to specify the proper distributed-training hyperparameters according to your needs and machine. In addition, you should set your maximum sequence length with `--model_max_length`, based on your data, GPU memory, and expected training speed.

### Quantize Fine-tuned Models

This section is for quantizing full-parameter/LoRA fine-tuned models. (Note: you do not need to quantize a Q-LoRA model, because it is already quantized.)
If you want to quantize a LoRA fine-tuned model, please first merge your model weights according to the instructions above.

We recommend using [auto_gptq](https://github.com/PanQiWei/AutoGPTQ) to quantize your model.

```bash
pip install auto-gptq optimum
```

Note: AutoGPTQ currently has a bug; see [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/370). There is a [workaround PR](https://github.com/PanQiWei/AutoGPTQ/pull/495), and you can install from source using that branch.

First, prepare the calibration data. You can reuse your fine-tuning data, or prepare other data in the same way as for fine-tuning.

Second, run the following command:

```bash
python run_gptq.py \
    --model_name_or_path $YOUR_LORA_MODEL_PATH \
    --data_path $DATA \
    --out_path $OUTPUT_PATH \
    --bits 4 # 4 for int4; 8 for int8
```

This step requires GPUs and may take several hours, depending on the size of your calibration data and your model.

Next, copy all `*.py`, `*.cu`, and `*.cpp` files and the `generation_config.json` file from the original model into the output model directory. Also, overwrite the `config.json` in the output directory with the one from the corresponding official quantized model
(for example, if you fine-tuned `Qwen-7B-Chat` with `--bits 4`, you can find the corresponding `config.json` in the [Qwen-7B-Chat-Int4](https://huggingface.co/Qwen/Qwen-7B-Chat-Int4/blob/main/config.json) repository).
You also need to rename ``gptq.safetensors`` to ``model.safetensors``.

Finally, test your model in the same way you would load an official quantized model. For example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model", trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/your/model",
    device_map="auto",
    trust_remote_code=True
).eval()

response, history = model.chat(tokenizer, "你好", history=None)
print(response)
```

### Profiling of Memory and Speed

Below we record the GPU memory usage and training speed of the 7B and 14B models on a single GPU when using LoRA (LoRA (emb) means the embedding and output layers are trained, while LoRA does not optimize those parameters) and Q-LoRA, for inputs of different lengths. The evaluation runs on a single A100-SXM4-80G GPU with CUDA 11.8 and PyTorch 2.0, using Flash Attention 2. We uniformly use a batch size of 1 and gradient accumulation of 8, and record the memory usage (GB) and training speed (s/iter) for input lengths of 256, 512, 1024, 2048, 4096, and 8192. We also tested full-parameter fine-tuning of Qwen-7B on 2 A100 GPUs; limited by GPU memory, we only tested 256, 512, and 1024 tokens.

run_gptq.py

+96
@@ -0,0 +1,96 @@
import argparse
import json
from typing import Dict
import logging

import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> Dict:
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens
    _user = tokenizer('user').input_ids + nl_tokens
    _assistant = tokenizer('assistant').input_ids + nl_tokens

    # Apply prompt templates
    data = []
    for i, source in enumerate(sources):
        source = source["conversations"]
        if roles[source[0]["from"]] != roles["user"]:
            # Skip a leading turn that does not come from the user
            source = source[1:]

        input_id, target = [], []
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
        input_id += system
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system) - 3) + [im_end] + nl_tokens
        assert len(input_id) == len(target)
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            input_id += _input_id
            if role == '<|im_start|>user':
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id) - 3) + [im_end] + nl_tokens
            elif role == '<|im_start|>assistant':
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids) + 1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
        target = torch.tensor(target[:max_len], dtype=torch.int)
        # GPTQ calibration only needs the inputs; labels are not required.
        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))

    return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
    parser.add_argument("--model_name_or_path", type=str, help="model path")
    parser.add_argument("--data_path", type=str, help="calibration data path")
    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
    parser.add_argument("--bits", type=int, default=4, help="the bits of the quantized model. 4 indicates int4 models.")
    parser.add_argument("--group-size", type=int, default=128, help="the group size of the quantized model")
    args = parser.parse_args()

    quantize_config = BaseQuantizeConfig(
        bits=args.bits,
        group_size=args.group_size,
        damp_percent=0.01,
        desc_act=False,  # setting to False can significantly speed up inference but may slightly hurt perplexity
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model"
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eod_id
    data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)

    model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    model.quantize(data, cache_examples_on_gpu=False)

    model.save_quantized(args.out_path, use_safetensors=True)
    tokenizer.save_pretrained(args.out_path)
