Skip to content

Commit 2ac189e

Browse files
authored Mar 10, 2025
Amd test fp8 (sgl-project#4261)
1 parent 5a6400e commit 2ac189e

File tree

6 files changed

+84
-0
lines changed

6 files changed

+84
-0
lines changed
 

‎.github/workflows/pr-test-amd.yml

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ jobs:
5555
timeout-minutes: 20
5656
run: |
5757
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
58+
docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_fp8_accuracy.py
5859
docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
5960
6061
mla-test-1-gpu-amd:

‎python/sglang/srt/configs/model_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _verify_quantization(self) -> None:
237237
"compressed_tensors",
238238
"compressed-tensors",
239239
"fbgemm_fp8",
240+
"w8a8_fp8",
240241
]
241242
optimized_quantization_methods = [
242243
"fp8",

‎python/sglang/srt/layers/quantization/fp8_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
else:
3333
from sgl_kernel import fp8_scaled_mm
3434

35+
# Input scaling factors are no longer optional in _scaled_mm starting
36+
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
37+
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
38+
3539

3640
def cutlass_fp8_supported():
3741
if not _is_cuda:

‎python/sglang/test/test_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from sglang.utils import get_exception_traceback
2929

3030
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
31+
DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
32+
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
33+
"neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
34+
)
3135
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
3236
DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
3337
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"

‎test/srt/run_suite.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class TestFile:
6969
TestFile("test_vision_llm.py", 18.4),
7070
TestFile("test_vision_openai_server.py", 344),
7171
TestFile("test_w8a8_quantization.py", 46),
72+
TestFile("test_eval_fp8_accuracy.py", 172),
7273
],
7374
"nightly": [
7475
TestFile("test_nightly_gsm8k_eval.py"),

‎test/srt/test_eval_fp8_accuracy.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import unittest
2+
from types import SimpleNamespace
3+
4+
from sglang.srt.utils import kill_process_tree
5+
from sglang.test.run_eval import run_eval
6+
from sglang.test.test_utils import (
7+
DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST,
8+
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
9+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
10+
DEFAULT_URL_FOR_TEST,
11+
popen_launch_server,
12+
)
13+
14+
15+
class TestEvalFP8Accuracy(unittest.TestCase):
    """MMLU accuracy regression test for a statically-quantized FP8 checkpoint.

    A local sglang server is launched once for the whole class and torn
    down afterwards; the test then runs the MMLU eval against it and
    enforces a minimum score.
    """

    @classmethod
    def setUpClass(cls):
        # Launch the server a single time and share it across test methods.
        cls.model = DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server together with any child processes it spawned.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        # Build the eval configuration as a dict first, then wrap it in a
        # SimpleNamespace as run_eval expects attribute-style access.
        eval_config = {
            "base_url": self.base_url,
            "model": self.model,
            "eval_name": "mmlu",
            "num_examples": 64,
            "num_threads": 32,
            "temperature": 0.1,
        }
        metrics = run_eval(SimpleNamespace(**eval_config))
        # Threshold chosen by the original test; acts as a regression floor.
        self.assertGreaterEqual(metrics["score"], 0.62)
class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
    """MMLU accuracy regression test for dynamic FP8 quantization.

    Same harness as the static-FP8 test, but the server is started with
    the `--quantization w8a8_fp8` flag so weights/activations are
    quantized dynamically at load/run time.
    """

    @classmethod
    def setUpClass(cls):
        # One shared server for the whole class; note the extra
        # quantization flag compared to the static-FP8 test above.
        cls.model = DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--quantization", "w8a8_fp8"],
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the server process tree completely.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        # Assemble eval options, then hand them to run_eval as a namespace.
        eval_config = {
            "base_url": self.base_url,
            "model": self.model,
            "eval_name": "mmlu",
            "num_examples": 64,
            "num_threads": 32,
            "temperature": 0.1,
        }
        metrics = run_eval(SimpleNamespace(**eval_config))
        # Regression floor from the original test.
        self.assertGreaterEqual(metrics["score"], 0.70)
72+
# Allow running this file directly (e.g. `python3 test_eval_fp8_accuracy.py`)
# in addition to discovery via the test suite runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)
Please sign in to comment.