
Commit 5a0d680

Authored Jan 21, 2025

feat: add flashinfer as 3rdparty and use rmsnorm as example (sgl-project#3033)

Parent: a4331cd

File tree: 11 files changed, +335 -2 lines

.github/workflows/pr-test-sgl-kernel.yml (+1)

@@ -41,6 +41,7 @@ jobs:
       - name: Install
         run: |
           pip3 install torch==2.5.1
+          pip3 install pytest
           pip3 uninstall sgl-kernel -y || true
           cd sgl-kernel
           pip3 install .

.gitignore (+2)

@@ -225,3 +225,5 @@ compile_commands.json

 # VSCode
 .vscode
+
+1

.gitmodules (+3)

@@ -4,3 +4,6 @@
 [submodule "sgl-kernel/3rdparty/cccl"]
 	path = sgl-kernel/3rdparty/cccl
 	url = https://github.com/NVIDIA/cccl.git
+[submodule "sgl-kernel/3rdparty/flashinfer"]
+	path = sgl-kernel/3rdparty/flashinfer
+	url = https://github.com/flashinfer-ai/flashinfer.git

sgl-kernel/3rdparty/flashinfer

Submodule flashinfer added at a0e99a3

sgl-kernel/THIRDPARTYNOTICES.txt (+225, new file)

@@ -0,0 +1,225 @@
+Notice for flashinfer-ai/flashinfer
+-------------------------------
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+-------------------------------------------------------------------------------------------------
+Some of the code in this project are adapted from other open-source projects with different
+licenses. This product also bundles some third-party components under other open source licenses.
+This section summarizes those components and their licenses.
+See licenses/ for text of these licenses.
+
+BSD 3-Clause License
+--------------------
+
+include/flashinfer/attention/hopper/epilogue.cuh
+include/flashinfer/attention/hopper/mainloop.cuh
+include/flashinfer/attention/hopper/kernel_traits.cuh
+include/flashinfer/attention/hopper/named_barrier.cuh
+include/flashinfer/attention/hopper/tile_scheduler.cuh
+include/flashinfer/attention/hopper/utils.cuh
+
+BSD 3-Clause "New" License
+--------------------------
+
+3rdparty/cutlass
+include/flashinfer/attention/hopper/block_sparse_gather.cuh

sgl-kernel/setup.py (+19 -2)

@@ -1,5 +1,6 @@
 from pathlib import Path

+import torch
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

@@ -24,10 +25,13 @@ def update_wheel_platform_tag():


 cutlass = root / "3rdparty" / "cutlass"
+flashinfer = root / "3rdparty" / "flashinfer"
 include_dirs = [
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
     root / "src" / "sgl-kernel" / "csrc",
+    flashinfer.resolve() / "include",
+    flashinfer.resolve() / "csrc",
 ]
 nvcc_flags = [
     "-DNDEBUG",
@@ -39,9 +43,21 @@ def update_wheel_platform_tag():
     "-gencode=arch=compute_89,code=sm_89",
     "-gencode=arch=compute_90,code=sm_90",
     "-gencode=arch=compute_90a,code=sm_90a",
-    "-U__CUDA_NO_HALF_OPERATORS__",
-    "-U__CUDA_NO_HALF2_OPERATORS__",
+    "-std=c++17",
+    "-use_fast_math",
+    "-DFLASHINFER_ENABLE_F16",
+    "-DFLASHINFER_ENABLE_BF16",
 ]
+for flag in [
+    "-D__CUDA_NO_HALF_OPERATORS__",
+    "-D__CUDA_NO_HALF_CONVERSIONS__",
+    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
+    "-D__CUDA_NO_HALF2_OPERATORS__",
+]:
+    try:
+        torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
+    except ValueError:
+        pass
 cxx_flags = ["-O3"]
 libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
@@ -56,6 +72,7 @@ def update_wheel_platform_tag():
             "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
             "src/sgl-kernel/csrc/rotary_embedding.cu",
+            "src/sgl-kernel/csrc/norm.cu",
         ],
         include_dirs=include_dirs,
         extra_compile_args={
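The `try`/`except ValueError` loop in the hunk above exists because PyTorch's extension builder injects `COMMON_NVCC_FLAGS` such as `-D__CUDA_NO_HALF_OPERATORS__`, which disable the CUDA half-precision operator overloads that flashinfer's fp16/bf16 code paths need. A minimal sketch of the same stripping pattern, run against a stand-in list rather than the real `torch.utils.cpp_extension.COMMON_NVCC_FLAGS` (the flag names are real; the starting list here is illustrative):

```python
# Stand-in for torch.utils.cpp_extension.COMMON_NVCC_FLAGS (illustrative).
common_nvcc_flags = [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "--expt-relaxed-constexpr",
]

for flag in [
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",  # not present above; remove() raises
    "-D__CUDA_NO_HALF2_OPERATORS__",       # not present above; remove() raises
]:
    try:
        common_nvcc_flags.remove(flag)
    except ValueError:
        # Flag absent in this torch version; nothing to strip.
        pass

print(common_nvcc_flags)  # ['--expt-relaxed-constexpr']
```

The `except ValueError: pass` makes the build robust across torch versions whose default flag set differs.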

sgl-kernel/src/sgl-kernel/__init__.py (+2)

@@ -6,6 +6,7 @@
     int8_scaled_mm,
     moe_align_block_size,
     register_graph_buffers,
+    rmsnorm,
     rotary_embedding,
     sampling_scaling_penalties,
 )
@@ -20,4 +21,5 @@
     "get_graph_buffer_ipc_meta",
     "register_graph_buffers",
     "rotary_embedding",
+    "rmsnorm",
 ]
sgl-kernel/src/sgl-kernel/csrc/norm.cu (+28, new file)

@@ -0,0 +1,28 @@
+#include <cstdint>
+#include <flashinfer/norm.cuh>
+
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer;
+
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);  // weight: (hidden_size)
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+  CHECK_EQ(output.size(0), batch_size);
+  CHECK_EQ(output.size(1), hidden_size);
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
+                                       static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
+    TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
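For reference, the per-row computation that this kernel delegates to flashinfer's `norm::RMSNorm` is `out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)`. A dependency-free Python sketch of that numerical contract (this mirrors the math, not the CUDA implementation):

```python
import math

def rmsnorm_ref(x, w, eps=1e-6):
    # out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps), computed over one row
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * wi * inv_rms for v, wi in zip(x, w)]

row = [1.0, 2.0, 2.0]              # mean of squares = (1 + 4 + 4) / 3 = 3.0
out = rmsnorm_ref(row, [1.0] * 3)  # each element scaled by ~1/sqrt(3)
```

Unlike LayerNorm, no mean is subtracted; only the root-mean-square scale is normalized, which is why a single fused pass over the row suffices.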

sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu (+5)

@@ -30,6 +30,9 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);

+// rms norm
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -45,4 +48,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
   // rotary embedding
   m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
+  // rms norm
+  m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
 }

sgl-kernel/src/sgl-kernel/ops/__init__.py (+18)

@@ -1,3 +1,6 @@
+from typing import Optional
+
+import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import (
@@ -7,6 +10,7 @@
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
 from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
+from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
 from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
@@ -76,3 +80,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):

 def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
     return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
+
+
+def rmsnorm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if out is None:
+        out = torch.empty_like(input)
+    stream = torch.cuda.current_stream().cuda_stream
+    stream_int = int(stream)
+    _rmsnorm(out, input, weight, eps, stream_int)
+    return out
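The `rmsnorm` wrapper above follows a common out-parameter convention: allocate the destination only when the caller does not supply one, then return it either way so both call styles compose. A toy, torch-free sketch of the same pattern (the `double_kernel` stand-in is hypothetical, used only so the wrapper logic can run anywhere):

```python
from typing import Callable, List, Optional

def with_optional_out(kernel: Callable[[List[float], List[float]], None]):
    """Wrap an in-place kernel(out, x) with optional output allocation,
    mirroring the rmsnorm(input, weight, eps, out=None) convention."""
    def wrapper(x: List[float], out: Optional[List[float]] = None) -> List[float]:
        if out is None:
            out = [0.0] * len(x)   # analogue of torch.empty_like(input)
        kernel(out, x)             # kernel writes its result into `out`
        return out                 # returned in both call styles
    return wrapper

def double_kernel(out: List[float], x: List[float]) -> None:
    # Hypothetical in-place kernel: out[i] = 2 * x[i]
    for i, v in enumerate(x):
        out[i] = 2.0 * v

double = with_optional_out(double_kernel)
res = double([1.0, 2.0])             # out allocated inside the wrapper
buf = [0.0, 0.0]
res2 = double([3.0, 4.0], out=buf)   # caller-owned buffer is reused
```

Passing `out=` lets callers reuse preallocated buffers (e.g. inside CUDA graphs), while the default keeps the one-liner `y = rmsnorm(x, w)` ergonomic.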

sgl-kernel/tests/test_rmsnorm.py (+31, new file)

@@ -0,0 +1,31 @@
+import pytest
+import torch
+from sgl_kernel import rmsnorm
+
+
+def llama_rms_norm(x, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * w.float()
+    x = x.to(orig_dtype)
+    return x
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_norm(batch_size, hidden_size, dtype, specify_out):
+    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    w = torch.randn(hidden_size).to(0).to(dtype)
+
+    y_ref = llama_rms_norm(x, w)
+    if specify_out:
+        y = torch.empty_like(x)
+        rmsnorm(x, w, out=y)
+    else:
+        y = rmsnorm(x, w)
+
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
