feat: Support cos_sin_cache in all cases. #3020

Open
yuxianq wants to merge 4 commits into main from user/yuxianq/cos-sin-cache
Conversation

@yuxianq (Collaborator) commented Mar 24, 2025

This PR contains the following updates:

  1. Handle fuse_pos_embd=True/False and create RotaryEmbedding inside the attention module, so that users don't need to handle it in the modeling files.
  2. Cache cos/sin for the unfused rope implementation. If flashinfer is available, use apply_rope_with_cos_sin_cache_inplace instead of apply_rope_inplace; otherwise, fall back to a pure PyTorch implementation, which now supports any rope type (see the sketch after this list).
  3. Use create_rope_const_params to create and cache cos_sin_cache for all rope types, including the DeepSeek yarn rope.
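
A minimal sketch of the idea in item 2, not the PR's actual code: a rotary-embedding module that builds the cos/sin cache once, calls flashinfer's apply_rope_with_cos_sin_cache_inplace when the package is available, and otherwise applies a pure PyTorch rotate-half rope from the same cache. The class and buffer names are illustrative, and the flashinfer call signature is assumed.

```python
import torch

try:
    # Optional fast path; the exact signature is an assumption based on flashinfer's docs.
    from flashinfer import apply_rope_with_cos_sin_cache_inplace
    HAS_FLASHINFER = True
except ImportError:
    HAS_FLASHINFER = False


class SimpleRotaryEmbedding(torch.nn.Module):
    """Hypothetical rope module: one cached cos/sin buffer, two execution paths."""

    def __init__(self, head_dim: int, max_positions: int = 8192, base: float = 10000.0):
        super().__init__()
        self.head_dim = head_dim
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
        # [max_positions, head_dim]: first half cos, second half sin.
        self.register_buffer("cos_sin_cache",
                             torch.cat([freqs.cos(), freqs.sin()], dim=-1),
                             persistent=False)

    def forward(self, positions: torch.Tensor, q: torch.Tensor, k: torch.Tensor):
        # q, k: [num_tokens, num_heads * head_dim]; positions: [num_tokens]
        if HAS_FLASHINFER:
            apply_rope_with_cos_sin_cache_inplace(
                positions, q, k, self.head_dim, self.cos_sin_cache, is_neox=True)
            return q, k

        # Pure PyTorch fallback: neox-style rotate-half using the cached values.
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)
        cos = torch.cat([cos, cos], dim=-1).unsqueeze(-2)  # broadcast over heads
        sin = torch.cat([sin, sin], dim=-1).unsqueeze(-2)

        def rotate_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)

        q = q.view(*q.shape[:-1], -1, self.head_dim)
        k = k.view(*k.shape[:-1], -1, self.head_dim)
        q = (q * cos + rotate_half(q) * sin).flatten(-2)
        k = (k * cos + rotate_half(k) * sin).flatten(-2)
        return q, k
```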

@yuxianq yuxianq requested review from hlu1, BestJuly, QiJune and kaiyux March 24, 2025 09:44
@yuxianq (Collaborator, author) commented Mar 24, 2025

/bot run --add-multi-gpu-test

@niukuo (Collaborator) commented Mar 24, 2025

PR_Github #283 [ run ] triggered by Bot

@niukuo (Collaborator) commented Mar 24, 2025

PR_Github #283 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #272 completed with status: 'FAILURE'

@yuxianq (Collaborator, author) commented Mar 25, 2025

/bot run --add-multi-gpu-test

@niukuo (Collaborator) commented Mar 25, 2025

PR_Github #387 [ run ] triggered by Bot

@niukuo (Collaborator) commented Mar 25, 2025

PR_Github #387 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #345 completed with status: 'FAILURE'

@yuxianq (Collaborator, author) commented Mar 25, 2025

/bot run --add-multi-gpu-test

@niukuo (Collaborator) commented Mar 25, 2025

PR_Github #430 [ run ] triggered by Bot

@niukuo (Collaborator) commented Mar 25, 2025

PR_Github #430 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #369 completed with status: 'FAILURE'

@yuxianq yuxianq force-pushed the user/yuxianq/cos-sin-cache branch from 86c5593 to 95840f8 Compare March 26, 2025 03:57
Signed-off-by: Yuxian Qiu <[email protected]>
@yuxianq yuxianq force-pushed the user/yuxianq/cos-sin-cache branch from 95840f8 to 54d797b Compare March 26, 2025 04:00
@yuxianq (Collaborator, author) commented Mar 26, 2025

/bot run --add-multi-gpu-test

@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #510 [ run ] triggered by Bot

@BestJuly BestJuly removed their request for review March 26, 2025 05:20
@BestJuly (Collaborator) commented:
I think I was pinged by mistake; is the review request actually meant for @litaotju?

@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #510 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #437 completed with status: 'FAILURE'

@QiJune (Collaborator) commented Mar 26, 2025

@yuxianq Can we split this PR into several smaller PRs? For example, the first item could be a single PR.

  1. Handle fuse_pos_embd=True/False and create RotaryEmbedding inside the attention module, so that users don't need to handle it in the modeling files.

@yuxianq yuxianq requested a review from litaotju March 26, 2025 07:01
@yuxianq (Collaborator, author) commented Mar 26, 2025

Can we split this PR into several smaller PRs? For example, the first item could be a single PR.

@QiJune I will give it a try. Let me pass CI first to validate that these features work correctly.

Signed-off-by: Yuxian Qiu <[email protected]>
@yuxianq (Collaborator, author) commented Mar 26, 2025

/bot run --add-multi-gpu-test

@yuxianq yuxianq changed the title from "Support cos_sin_cache in all cases." to "feat: Support cos_sin_cache in all cases." Mar 26, 2025
@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #550 [ run ] triggered by Bot

@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #550 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #469 completed with status: 'FAILURE'

@yuxianq (Collaborator, author) commented Mar 26, 2025

/bot run --disable-fail-fast --add-multi-gpu-test

@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #584 [ run ] triggered by Bot

@niukuo (Collaborator) commented Mar 26, 2025

PR_Github #584 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #497 completed with status: 'FAILURE'

@@ -5,9 +5,8 @@
from tensorrt_llm.quantization import (quantize_and_export,
quantize_nemo_and_export)

mp.set_start_method("spawn", force=True)
Collaborator commented:

Just curious, is someone importing this file? It should be used as a CLI command only.
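
A hypothetical illustration of the reviewer's point: calling mp.set_start_method at import time affects anything that merely imports the module, so one option is to set it only in the CLI entry point. The main() body here is a placeholder, not the real quantization script.

```python
import multiprocessing as mp


def main():
    # Placeholder for the quantize_and_export / quantize_nemo_and_export CLI logic.
    pass


if __name__ == "__main__":
    # Only take effect when run as a command, keeping module import side-effect free.
    mp.set_start_method("spawn", force=True)
    main()
```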

@@ -327,57 +338,77 @@ def from_config(config) -> "RopeParams":
rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
rope_params.mscale = rope_scaling.get("mscale", 1.0)
rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
if config.model_type == "deepseek_v3":
Collaborator commented:

This looks somewhat ad hoc to me. Is it possible to avoid relying on the hard-coded model type string here, through a more general interface?
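
One hypothetical shape the suggestion could take: a small registry of model-specific rope overrides keyed by model type, so from_config stays generic. None of these names exist in TensorRT-LLM; they only sketch the idea.

```python
# Registry of model-specific rope-parameter tweaks (illustrative only).
_ROPE_OVERRIDES = {}


def register_rope_override(model_type: str):
    def wrap(fn):
        _ROPE_OVERRIDES[model_type] = fn
        return fn
    return wrap


@register_rope_override("deepseek_v3")
def _deepseek_v3_rope(rope_params, rope_scaling: dict):
    # DeepSeek-V3-specific adjustments (e.g. mscale handling) would live here,
    # next to the model, instead of inside RopeParams.from_config.
    return rope_params


def apply_rope_overrides(rope_params, config, rope_scaling: dict):
    override = _ROPE_OVERRIDES.get(getattr(config, "model_type", None))
    return override(rope_params, rope_scaling) if override else rope_params
```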

assert self.scale_type != RotaryScalingType.longrope, "Long RoPE is not yet supported."
if self.scale_type == RotaryScalingType.yarn:
rope_inv_freq = None
rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_deepseek_attention_plugin(
Collaborator commented:

Maybe a nit: can we name it create_*_positions_for_yarn or something similar, so it isn't tied specifically to DeepSeek?

@@ -102,7 +102,17 @@ def __init__(
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
self.pos_embd_params = pos_embd_params
self.rotary_emb = rotary_emb

self.support_rope = self.attn_backend == "TRTLLM"
Collaborator commented:

The custom op will do the rope fusion here, right? Maybe rename it to something like self.rope_fused_in_custom_op = True?

self.rotary_emb = rotary_emb

self.support_rope = self.attn_backend == "TRTLLM"
self.support_fused_qkv = self.attn_backend == "TRTLLM"
Collaborator commented:

"support" is a vague word.
Which of these does "support" mean?

  1. Configurable: both fused QKV and unfused QKV can run, or
  2. Requires fused QKV?
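
If the answer is the second option, one hypothetical renaming (an assumption, not the PR's actual code) could make the semantics explicit:

```python
class AttentionFlagsSketch:
    """Illustrative only: more explicit names for the two backend flags."""

    def __init__(self, attn_backend: str):
        # The TRTLLM custom op applies rope inside the kernel.
        self.rope_fused_in_attn_op = attn_backend == "TRTLLM"
        # The TRTLLM custom op expects a single fused QKV tensor as input.
        self.requires_fused_qkv = attn_backend == "TRTLLM"
```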

@@ -249,6 +249,8 @@ def __init__(
attn_backend=attn_backend,
load_format=pytorch_backend_config.load_format,
)
if not hasattr(self.model, 'extra_attrs'):
Collaborator commented:

When will this be true? Can we always attach extra_attrs in _load_model so that this check isn't needed?

gather_ids)
else:
return self._forward_step(inputs, gather_ids)
with model_extra_attrs(self.model.extra_attrs):
Collaborator commented:

Shall we wrap the whole forward function inside this context manager?
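
A hypothetical sketch of that suggestion, reusing the names visible in the diff; the surrounding method structure is assumed, not copied from the PR:

```python
def forward(self, inputs, gather_ids=None):
    # Enter the context once at the top so every branch below runs inside it.
    with model_extra_attrs(self.model.extra_attrs):
        # ... existing branching on gather_ids, unchanged ...
        return self._forward_step(inputs, gather_ids)
```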

@@ -122,7 +122,7 @@ def submit_sync(self, task: Callable[..., T], *args, **kwargs) -> List[T]:

def shutdown(self):
if self.mpi_pool is not None:
self.mpi_pool.shutdown(wait=False)
Collaborator commented:

@Superjomn do you remember if we have some reason to make this "wait=False"?

@@ -36,6 +36,7 @@ def test_llm_api(self, import_oot_code: bool):
llm = LLM(model=model_dir,
kv_cache_config=kv_cache_config,
max_num_tokens=2048)
del llm
Collaborator commented:

Won't this llm object be automatically destroyed when the function returns? It's just a local variable.

@@ -216,3 +216,4 @@ async def test():
1.0), f"Expected '{expected}' but get '{result}'"

asyncio.run(test())
del llm
Collaborator commented:

Same here: is there some reason the object is not deleted by Python when the function returns?

Signed-off-by: Yuxian Qiu <[email protected]>
@yuxianq (Collaborator, author) commented Apr 2, 2025

/bot run --disable-fail-fast --stage-list "A30-7"
