Multi-GPU vLLM inference with tensor parallelism, colocating policy model + ref model + vLLM engine on the same node #514

nhannguyen2709 opened this issue Mar 17, 2025 · 3 comments


@nhannguyen2709

Hello @lewtun @edbeeching,

I've created a custom fork based on the faster GRPO trainer PR with some nice improvements that allow large-scale training on just a single node. To summarize, I've done the following:

(1) Policy model + reference model + vLLM engines now live on the same node
(2) All GPUs can be used to generate rollouts, and vLLM tensor_parallel_size can be set to values > 1
(3) Policy model and optimizer states are offloaded to CPU before rollout generation and reloaded to GPU afterwards (see the sketch below). I've tested the offloading strategies with both DeepSpeed ZeRO-2 and ZeRO-3.
(4) Training with num_iterations > 1
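
A simplified sketch of the offload/reload pattern in (3), using plain PyTorch in place of the DeepSpeed-specific calls (the function names are illustrative, and with ZeRO-3 the real code has to gather the partitioned parameters first):

```python
import torch

def offload_for_rollout(model, optimizer):
    """Free GPU memory before handing the device over to vLLM
    (generic sketch; the fork wires this into the DeepSpeed engine)."""
    model.to("cpu")  # with ZeRO-3 the real code gathers shards first
    # Optimizer state (e.g. Adam moments) usually dominates memory.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to("cpu", non_blocking=True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()  # hand the freed memory back to vLLM

def reload_after_rollout(model, optimizer, device="cuda"):
    """Bring weights and optimizer state back for the optimizer step."""
    model.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device, non_blocking=True)
    torch.cuda.synchronize()
```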

I've been able to do full fine-tuning with Qwen 7B and LoRA fine-tuning with Qwen 14B on a single 8x H100 node.

If you're interested, I'm willing to open a PR and share more detailed training logs + evaluation on AIME 24-25.

@edbeeching
Collaborator

Hi @nhannguyen2709, that sounds promising. It would be great if you could share more details.

@lewtun
Member

lewtun commented Mar 17, 2025

This is very cool @nhannguyen2709! Just FYI, there's also a PR on TRL to enable multi-node support with vLLM: huggingface/trl#3094

Ultimately, we'd like to settle on a single implementation (easier to maintain), so perhaps you can also take a look at that PR and comment on whether it differs substantially from your approach?

@nhannguyen2709
Author

nhannguyen2709 commented Mar 18, 2025

@edbeeching @lewtun
I took a look at that PR and unfortunately it differs quite substantially from my approach: it launches the main training script on one node and, from inside it, additionally launches a vLLM engine on another node.
My approach keeps all components on a single node.

I've attached a diagram depicting my implementation below. The setup uses 8 accelerate processes and vLLM with a tensor parallel size of 4.
Co-locating the vLLM processes with the accelerate processes is enabled thanks to this PR (released in vLLM 0.7.3).
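
Roughly, the colocated setup looks like this: every training process constructs the same `LLM` with `distributed_executor_backend="external_launcher"`, and vLLM splits the existing 8-process world into two TP-4 groups (ranks 0-3 and 4-7). The checkpoint name and memory fraction below are just illustrative:

```python
# Launched with `accelerate launch` (or torchrun) across 8 processes.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",              # illustrative checkpoint
    tensor_parallel_size=4,                        # two TP-4 engines on 8 ranks
    distributed_executor_backend="external_launcher",
    enable_sleep_mode=True,                        # allows sleep()/wake_up()
    gpu_memory_utilization=0.3,                    # leave headroom for training
)
```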

┌────────────────────────────────── Node with 8x H100 GPUs ─────────────────────────────────┐
│                                                                                           │
│  ┌──────┐  ┌──────┐  ┌──────┐  ┌──────┐          ┌──────┐  ┌──────┐  ┌──────┐  ┌──────┐   │
│  │GPU 0 │  │GPU 1 │  │GPU 2 │  │GPU 3 │          │GPU 4 │  │GPU 5 │  │GPU 6 │  │GPU 7 │   │
│  └──┬───┘  └──┬───┘  └──┬───┘  └──┬───┘          └──┬───┘  └──┬───┘  └──┬───┘  └──┬───┘   │
│     │         │         │         │                 │         │         │         │       │
│  ┌──┴───┐  ┌──┴───┐  ┌──┴───┐  ┌──┴───┐          ┌──┴───┐  ┌──┴───┐  ┌──┴───┐  ┌──┴───┐   │
│  │Proc 0│  │Proc 1│  │Proc 2│  │Proc 3│          │Proc 4│  │Proc 5│  │Proc 6│  │Proc 7│   │
│  │      │  │      │  │      │  │      │          │      │  │      │  │      │  │      │   │
│  │Policy│  │Policy│  │Policy│  │Policy│          │Policy│  │Policy│  │Policy│  │Policy│   │
│  │Model │  │Model │  │Model │  │Model │          │Model │  │Model │  │Model │  │Model │   │
│  │+     │  │+     │  │+     │  │+     │          │+     │  │+     │  │+     │  │+     │   │
│  │Ref.  │  │Ref.  │  │Ref.  │  │Ref.  │          │Ref.  │  │Ref.  │  │Ref.  │  │Ref.  │   │
│  │Model │  │Model │  │Model │  │Model │          │Model │  │Model │  │Model │  │Model │   │
│  └──┬───┘  └──┬───┘  └──┬───┘  └──┬───┘          └──┬───┘  └──┬───┘  └──┬───┘  └──┬───┘   │
│     │         │         │         │                 │         │         │         │       │
│     ▼         ▼         ▼         ▼                 ▼         ▼         ▼         ▼       │
│  ┌────────────────────────────────────────────────────────────────────────────────────┐   │
│  │                          VLLMShardingManager.__enter__()                           │   │
│  │                                                                                    │   │
│  │  1. All-gather policy model weights from 8 accelerate processes                    │   │
│  │  2. DeepSpeed module offloaded to CPU to free GPU memory                           │   │
│  │  3. vLLM engines woken up (one for GPUs 0-3, one for GPUs 4-7)                     │   │
│  │  4. Sync policy model weights to vLLM engines                                      │   │
│  └────────────────────────────────────────┬───────────────────────────────────────────┘   │
│                                           │                                               │
│                                           ▼                                               │
│  ┌────────────────────────────┐                   ┌────────────────────────────┐          │
│  │     vLLM Engine 1 (TP=4)   │                   │     vLLM Engine 2 (TP=4)   │          │
│  │                            │                   │                            │          │
│  │  5. Preprocess: All-gather │                   │ 5. Preprocess: All-gather  │          │
│  │     prompt texts within TP │                   │    prompt texts within TP  │          │
│  │     group (GPUs 0-3)       │                   │    group (GPUs 4-7)        │          │
│  │                            │                   │                            │          │
│  │  6. Generate completions   │                   │ 6. Generate completions    │          │
│  │     with tensor parallel   │                   │    with tensor parallel    │          │
│  │                            │                   │                            │          │
│  │  7. Postprocess: Each rank │                   │ 7. Postprocess: Each rank  │          │
│  │     receives its slice     │                   │    receives its slice      │          │
│  │     of completions         │                   │    of completions          │          │
│  └────────────┬───────────────┘                   └────────────┬───────────────┘          │
│               │                                                │                          │
│               └──────────────────────────┬─────────────────────┘                          │
│                                          │                                                │
│                                          ▼                                                │
│  ┌────────────────────────────────────────────────────────────────────────────────────┐   │
│  │                          VLLMShardingManager.__exit__()                            │   │
│  │                                                                                    │   │
│  │  8. vLLM engines put to sleep to free GPU memory                                   │   │
│  │  9. DeepSpeed module loaded back to GPU for training                               │   │
│  └───────────────────────────────────────┬────────────────────────────────────────────┘   │
│                                          │                                                │
│                                          ▼                                                │
│  ┌──────┐  ┌──────┐  ┌──────┐  ┌──────┐          ┌──────┐  ┌──────┐  ┌──────┐  ┌──────┐   │
│  │GPU 0 │  │GPU 1 │  │GPU 2 │  │GPU 3 │          │GPU 4 │  │GPU 5 │  │GPU 6 │  │GPU 7 │   │
│  │Optim │  │Optim │  │Optim │  │Optim │          │Optim │  │Optim │  │Optim │  │Optim │   │
│  │Step  │  │Step  │  │Step  │  │Step  │          │Step  │  │Step  │  │Step  │  │Step  │   │
│  └──────┘  └──────┘  └──────┘  └──────┘          └──────┘  └──────┘  └──────┘  └──────┘   │
│                                                                                           │
└───────────────────────────────────────────────────────────────────────────────────────────┘
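
Here is a simplified sketch of the context manager implementing steps 1-9. Details are abbreviated: in particular, the ZeRO-3 gather in step 1 and the vLLM-internal weight-sync path in step 4 are simplified placeholders that vary across versions.

```python
import torch
import torch.distributed as dist

class VLLMShardingManager:
    """__enter__ hands the GPUs to vLLM; __exit__ hands them back (sketch)."""

    def __init__(self, ds_engine, llm, tp_group):
        self.ds_engine = ds_engine    # DeepSpeed-wrapped policy model
        self.llm = llm                # colocated vLLM engine (TP=4)
        self.tp_group = tp_group      # process group for ranks 0-3 or 4-7

    def __enter__(self):
        # 1. Gather the full policy state dict (ZeRO-3 shards parameters
        #    across ranks, so the real code uses GatheredParameters here).
        state_dict = self.ds_engine.module.state_dict()
        # 2. Offload the DeepSpeed module to CPU to free GPU memory.
        self.ds_engine.module.to("cpu")
        torch.cuda.empty_cache()
        # 3. Wake the sleeping vLLM engine on this TP group.
        self.llm.wake_up()
        # 4. Push the fresh policy weights into vLLM's model runner
        #    (internal attribute path; differs between vLLM versions).
        model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
        model.load_weights(state_dict.items())
        return self

    def generate(self, prompts, sampling_params):
        # 5. All-gather prompts within the TP group so every rank feeds
        #    the engine an identical batch (required in SPMD-style mode).
        world = dist.get_world_size(group=self.tp_group)
        gathered = [None] * world
        dist.all_gather_object(gathered, prompts, group=self.tp_group)
        flat = [p for rank_prompts in gathered for p in rank_prompts]
        # 6. Generate completions with tensor parallelism.
        outputs = self.llm.generate(flat, sampling_params)
        # 7. Each rank keeps only its own slice of the completions.
        rank = dist.get_rank(group=self.tp_group)
        per_rank = len(flat) // world
        return outputs[rank * per_rank : (rank + 1) * per_rank]

    def __exit__(self, *exc):
        # 8. Put vLLM to sleep, releasing weights and KV-cache memory.
        self.llm.sleep(level=1)
        # 9. Reload the DeepSpeed module onto the GPU for training.
        self.ds_engine.module.to(torch.cuda.current_device())
        return False
```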
