import argparse
import json
from typing import Dict
import logging

import torch
import transformers
from transformers import AutoTokenizer
from transformers.trainer_pt_utils import LabelSmoother
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int,
    system_message: str = "You are a helpful assistant."
) -> Dict:
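    # Expected layout of `sources`, inferred from the keys accessed below
    # ("conversations", "from", "value"); the example values are illustrative:
    # [
    #   {"conversations": [
    #       {"from": "user", "value": "..."},
    #       {"from": "assistant", "value": "..."}
    #   ]},
    #   ...
    # ]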
    roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}

    im_start = tokenizer.im_start_id
    im_end = tokenizer.im_end_id
    nl_tokens = tokenizer('\n').input_ids
    _system = tokenizer('system').input_ids + nl_tokens
    _user = tokenizer('user').input_ids + nl_tokens
    _assistant = tokenizer('assistant').input_ids + nl_tokens

    # Apply prompt templates
    data = []
    for i, source in enumerate(sources):
        source = source["conversations"]
        # Drop a leading turn that does not come from the user
        if roles[source[0]["from"]] != roles["user"]:
            source = source[1:]

        input_id, target = [], []
        system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
        input_id += system
        # System-prompt tokens are masked out of the target with IGNORE_TOKEN_ID
        target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens
        assert len(input_id) == len(target)
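        # Each turn below is encoded ChatML-style:
        #   <|im_start|>{role}\n{content}<|im_end|>\n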
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            _input_id = tokenizer(role).input_ids + nl_tokens + \
                tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
            input_id += _input_id
            # In the target, user turns are fully masked; for assistant turns only the
            # reply tokens (plus the closing <|im_end|> and newline) are kept
            if role == '<|im_start|>user':
                _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens
            elif role == '<|im_start|>assistant':
                _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \
                    _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens
            else:
                raise NotImplementedError
            target += _target
        assert len(input_id) == len(target)
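        # Truncate to max_len. GPTQ calibration only consumes input_ids and an
        # attention mask, so the masked `target` built above is not returned.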
        input_id = torch.tensor(input_id[:max_len], dtype=torch.int)
        target = torch.tensor(target[:max_len], dtype=torch.int)
        data.append(dict(input_ids=input_id, attention_mask=input_id.ne(tokenizer.pad_token_id)))

    return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Model Quantization using AutoGPTQ")
    parser.add_argument("--model_name_or_path", type=str, help="model path")
    parser.add_argument("--data_path", type=str, help="calibration data path")
    parser.add_argument("--out_path", type=str, help="output path of the quantized model")
    parser.add_argument("--max_len", type=int, default=8192, help="max length of calibration data")
    parser.add_argument("--bits", type=int, default=4, help="bit-width of the quantized model; 4 produces an int4 model")
    parser.add_argument("--group-size", type=int, default=128, help="group size of the quantized model")
    args = parser.parse_args()
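
    # Example invocation (the script file name, model name, and paths are illustrative,
    # not from the original):
    #   python quantize_gptq.py \
    #       --model_name_or_path Qwen/Qwen-7B-Chat \
    #       --data_path calibration_data.json \
    #       --out_path qwen-7b-chat-int4 \
    #       --bits 4 --group-size 128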

    quantize_config = BaseQuantizeConfig(
        bits=args.bits,
        group_size=args.group_size,
        damp_percent=0.01,
        desc_act=False,  # False significantly speeds up inference, but perplexity may be slightly worse
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model"
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    # The (Qwen-style) tokenizer has no pad token by default; reuse its end-of-document token for padding
    tokenizer.pad_token_id = tokenizer.eod_id
    data = preprocess(json.load(open(args.data_path)), tokenizer, args.max_len)

    model = AutoGPTQForCausalLM.from_pretrained(args.model_name_or_path, quantize_config, device_map="auto", trust_remote_code=True)

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    model.quantize(data, cache_examples_on_gpu=False)

    model.save_quantized(args.out_path, use_safetensors=True)
    tokenizer.save_pretrained(args.out_path)
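
# To sanity-check the exported files, the quantized model can be reloaded with
# auto_gptq's from_quantized API (a minimal sketch, not part of the script above):
#
#   from auto_gptq import AutoGPTQForCausalLM
#   model = AutoGPTQForCausalLM.from_quantized(args.out_path, device="cuda:0", trust_remote_code=True)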