
Co-Locating vLLM w/ training to achieve higher throughput and GPU utilization #3162

Open · wants to merge 26 commits into main

Conversation

@toslali-ibm commented Mar 26, 2025

What does this PR do?

Fixes #3064

Addresses:
#3114
#2971
#2922
#2887

Motivation:

Colocating vLLM processes with training workloads enables higher throughput and more efficient GPU utilization. Our test (see section below) shows ~2× faster GRPO training with N-1 GPUs, i.e., using 7 GPUs for both vLLM and training, compared to 8 GPUs with current TRL (7 GPUs for training plus a dedicated GPU for an isolated vLLM server).

[Screenshot: GRPO training time comparison, 7 GPUs colocated vs. 8 GPUs with a dedicated vLLM server]

Enabler:

vLLM (version >0.7.3) introduced support for an external launcher, allowing vLLM processes to run alongside other workloads on the same GPU.
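
For context, here is a minimal sketch (not the PR's code) of what the external launcher enables: a vLLM engine instantiated inside an existing training process, sharing its GPU. It assumes the script is started under a distributed launcher such as accelerate launch or torchrun.

# Minimal sketch: vLLM colocated inside the current (training) process.
# Run under an external launcher, e.g. `accelerate launch script.py`.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    distributed_executor_backend="external_launcher",  # reuse this process instead of spawning workers
    gpu_memory_utilization=0.5,  # leave headroom for the training job on the same GPU
)

# Each rank can then generate completions for its own batch, in-process.
outputs = llm.generate(["Solve: 2 + 2 = ?"], SamplingParams(temperature=0.7, max_tokens=64))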

Benefits:

  • Faster inference: speeds up GRPO training by reducing inference latency via parallel prompt processing (each vLLM instance works on its own device's batch)
  • Better GPU efficiency: frees up GPU resources by removing the need for a dedicated vLLM server; multiple vLLM instances can now share GPUs with training jobs, reducing GPU idle time
  • Supports tensor parallelism (eventually)
  • Ray-less solution

Implementation Notes:

  • Colocation behavior is controlled via the vllm_colocation parameter. If True, vLLM is colocated with the training process; if False, the default setup with a separate vLLM server is used.
  • The new get_vllm_client() proxy returns the appropriate client based on the training setup and configuration (see the sketch after this list).
  • VLLMColocationClient runs vLLM in-process for faster inference and better GPU utilization.
  • VLLMNoOpClient is a no-op fallback for non-main processes in distributed training when the default VLLMClient is used.
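
As a rough illustration of this dispatch (the class and argument names below are illustrative stubs, not the PR's actual implementations):

# Hypothetical sketch of the client-selection proxy; stubs only illustrate the dispatch.
class VLLMColocationClient:
    """Runs vLLM in-process on this rank (stub)."""
    def __init__(self, args, device):
        self.args, self.device = args, device

class VLLMClient:
    """Talks to a separate `trl vllm-serve` server (stub)."""
    def __init__(self, host, port):
        self.host, self.port = host, port

class VLLMNoOpClient:
    """No-op fallback for non-main ranks in server mode (stub)."""

def get_vllm_client(args, accelerator):
    if args.vllm_colocation:
        # colocated mode: every rank runs vLLM in-process and serves its own batch
        return VLLMColocationClient(args, device=accelerator.device)
    if accelerator.is_main_process:
        # server mode: only the main process talks to the remote vLLM server
        return VLLMClient(host=args.vllm_server_host, port=args.vllm_server_port)
    # other ranks get a no-op client and receive completions via broadcast
    return VLLMNoOpClient()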

Notes:

  • To keep this PR focused and simple, we first introduce vLLM colocation. Support for TP to enable larger models (e.g., 72B) will be added in a follow-up PR.

Testing vllm colocation

To run and test the trainer w/ vllm colocation enabled,

  • use grpo config (below)
  • run experiment.sh (below)

The example below trains the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B model on the open-r1/OpenR1-Math-220k dataset; execute the following bash script (experiment.sh).

Click to view experiment.sh
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    --num_processes=7 \
    src/open_r1/grpo.py \
    --config config.yaml
Click to view config.yaml
 # Model arguments
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
# chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n'  + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}"
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
dataset_name: open-r1/OpenR1-Math-220k # limo dataset is smaller - open-R1 fork of Fabian (problem key error will occur)
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
bf16: true

use_vllm: true
vllm_colocation: true
vllm_gpu_memory_utilization: 0.5
vllm_enable_prefix_caching: false
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 1.0e-06
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 10
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
max_prompt_length: 512
max_completion_length: 2048
max_steps: 20
num_generations: 16
num_train_epochs: 1
overwrite_output_dir: true
per_device_train_batch_size: 16 # original batch size
push_to_hub: false
reward_funcs:
- length

eval_strategy: "no"
save_strategy: "no"
report_to: none

seed: 42
temperature: 0.7

To run default trainer for comparison:

  • Run the vLLM server via vllm_server.sh (below)
  • Change CUDA_VISIBLE_DEVICES to 1,2,3,4,5,6,7 in experiment.sh and set vllm_colocation to false in config.yaml
Click to view vllm_server.sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

CC @fabianlim

Requirements

  • The vLLM library must be > 0.7.3 (to use external_launcher) and < 0.8.2 (the recently released 0.8.2 breaks current TRL)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@toslali-ibm toslali-ibm marked this pull request as ready for review March 26, 2025 19:05
@binary-husky (Contributor)

Nice! Can vLLM release GPU memory to prevent GPU OOM in this setting? Does this one work on 32B QwQ models?

@toslali-ibm (Author) commented Mar 28, 2025

Can vLLM release GPU memory to prevent GPU OOM in this setting? Does this one work on 32B QwQ models?
Hello!

  • Memory Release: Not yet, but an upcoming PR will introduce sleep functionality in vLLM to release GPU memory during training phases and help prevent OOM issues (see the sketch after this list).
  • 32B QwQ Models: Not yet. This PR focuses only on vLLM colocation, with TP support planned for the very next PR. We have a PoC where we use TP=8 with colocation to train a 72B model.
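
For reference, a rough sketch of what the planned sleep/wake cycle could look like, assuming vLLM's sleep-mode API (enable_sleep_mode, llm.sleep(), llm.wake_up()); the actual follow-up PR may differ:

# Sketch only: release vLLM's GPU memory during the training phase.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    distributed_executor_backend="external_launcher",
    gpu_memory_utilization=0.5,
    enable_sleep_mode=True,
)

completions = llm.generate(["A prompt"], SamplingParams(max_tokens=64))  # generation phase
llm.sleep(level=1)   # offload weights and drop KV cache to free GPU memory
# ... run the backward/optimizer step here using the freed memory ...
llm.wake_up()        # restore vLLM before the next generation phase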

@shirinyamani (Member)

@toslali-ibm Thanks for your contribution! Very nice!
For clarification purposes: so technically, your addition means that instead of a remote client sending requests to the vLLM server, the client is present on the same node/location as the vLLM server?

@toslali-ibm (Author) commented Mar 28, 2025

@toslali-ibm Thanks for your contribution! Very nice! For clarification purposes: so technically, your addition means that instead of a remote client sending requests to the vLLM server, the client is present on the same node/location as the vLLM server?

Hello @shirinyamani, thank you!

Essentially, there is no centralized vLLM server—instead, vLLM processes run directly on each device (see here), sharing the GPU with training workloads. Each device handles its own batch for generation. The key idea is that you don’t need a separate GPU dedicated solely to vLLM, as it would remain idle when no generation is taking place. As shown in our experiment, this colocation can achieve higher throughput using N-1 GPUs.

@shirinyamani shirinyamani removed the request for review from qgallouedec March 28, 2025 16:59
@shirinyamani (Member)

Since we are on this topic, I have one more question, maybe for both @binary-husky and @toslali-ibm.
In my PP branch, I am trying to add support for PP (pipeline_parallel_size), because I believe TP + PP would let us scale to larger models. However, the issue is that in vLLM, PP is only compatible with AsyncLLMEngine and is not supported directly in the LLM class. But since the LLM class wraps the LLMEngine and AsyncLLMEngine classes, there should be a way to add this support so that, for instance, the following would run smoothly:

trl vllm-serve --model Qwen/Qwen2.5-7B --tensor_parallel_size 4 --pipeline-parallel-size 2

The issue is that if we want to use AsyncLLMEngine, it does not have collective_rpc like the LLM class, with which we can easily init_communicator.
So I think if we find a solution for this combined with colocation (which supports TP), we can scale to very large models. WDYT?

@fabianlim (Contributor) commented Mar 29, 2025

@shirinyamani hi, this is @fabianlim; I work with @toslali-ibm. To answer your questions:

  1. I'm not super familiar with AsyncLLMEngine but will look into it more. My impression is that in the initialization of LLM, initialize_model_parallel is called, which instantiates both TP and PP. So my feeling is that TP and PP should be simultaneously possible, but this needs more checking.
  2. For the external_launcher used in this PR, we actually do not require collective_rpc. In this mode, vLLM looks "local" to the training node, i.e., we can directly access self.llm.llm_engine.model_executor.driver_worker.model_runner.model. If you see our changes, we just load the state dict directly in, saving the torch.distributed.broadcast call (see the sketch below).
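
A rough sketch of the weight sync this enables (attribute path as referenced above; vLLM internals may change across versions):

import torch

def sync_weights_to_colocated_vllm(llm, policy_model: torch.nn.Module):
    # vLLM runs in-process under the external launcher, so each rank can push
    # its updated policy weights straight into the local vLLM model, avoiding
    # the torch.distributed.broadcast used in server mode.
    vllm_model = llm.llm_engine.model_executor.driver_worker.model_runner.model
    vllm_model.load_weights(policy_model.state_dict().items())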

@binary-husky (Contributor)

So in my PP branch, I am trying to add support for PP, because I believe TP + PP would let us scale to larger models [...] So I think if we find a solution for this combined with colocation (which supports TP), we can scale to very large models, WDYT?

I have made some progress on AsyncLLMEngine:
#3182

@fabianlim (Contributor) commented Mar 31, 2025

Nice @binary-husky, my understanding is that your async feature works for grad_accum > 1, but currently we compare with your approach when grad_accum = 1. So my feeling is that our methods are orthogonal; they should also apply with grad_accum > 1, I believe.

Is it currently true that we require num_generations to divide num_devices * batch_size * grad_accum, i.e., that we can use grad accum to distribute generation prompts across time?
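
(For example, with the colocation config earlier in this thread, that product is 7 devices × per_device_train_batch_size 16 × gradient_accumulation_steps 4 = 448 completions, which num_generations = 16 divides evenly: 448 / 16 = 28 prompt groups.)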

cc: @shirinyamani

@toslali-ibm (Author) commented Mar 31, 2025

I recently conducted another experiment with a larger batch size of 24 and observed a 2.5x speedup when using N-1 GPUs -- 7 GPUs (with colocated vLLM and training) compared to 8 GPUs (where 1 GPU was dedicated to remote vLLM and 7 to training).

[Screenshot: training time comparison with batch size 24, colocated vs. server mode]

The next PRs will focus on:
a) Integrating vLLM's sleep function to offload model weights and KV cache, which both frees up memory for training after generating completions and allows higher vllm_gpu_memory_utilization values.
b) Adding TP to the colocated vLLMs.

If this PR is successfully merged, we can move forward on (a) and (b). Additionally, @binary-husky's nice work on the async feature for grad_accum > 1 seems orthogonal to colocated vLLM, meaning it can be applied to both the remote vLLM client and the colocated vLLM client to achieve more gains!

Looking forward to your feedback!

@qgallouedec @shirinyamani @fabianlim

@qgallouedec (Member) commented Mar 31, 2025

Thanks for this work!
How do you explain the faster throughput with the colocation approach?
The difference I see is that there are 7 vLLM instances instead of one. Perhaps this explains the higher throughput, despite the smaller KV cache size per instance?
So I'm wondering how to scale it up. Let's say we support TP for vLLM: it requires TP to match the number of processes for training, is this your idea?

If this PR is successfully merged, we can move forward on (a) and (b)

As this feature is experimental, I wouldn't wait until it's merged before moving on to the next steps, as it's these steps that will tell us whether it really makes sense to merge. If you want to keep the elements separate, you can always make a PR on this branch.

@toslali-ibm (Author) commented Mar 31, 2025

How do you explain the faster throughput with the colocation approach? [...] it requires TP to match the number of processes for training, is this your idea?

As this feature is experimental, I wouldn't wait until it's merged before moving on to the next steps [...]

Thanks @qgallouedec
We're seeing improved throughput because each rank now runs a vLLM worker, processing its own device's batch. Consider the comparison:

  • Case 1: 7 prompts with num_generations=16 for a centralized vLLM → 112 generations handled by a single vLLM server.
  • Case 2: 16 prompts with num_generations=1 per vLLM instance across 7 GPUs → also 112 generations in total, but distributed across devices for better parallelism.

To scale, we're incorporating TP. The key idea is that each TP device initializes an LLM() with tp=N and distributed_executor_backend="external_launcher". Each process holds 1/N of the model weights, and work is sharded such that all processes receive the same input and generate outputs collectively. There's an example of this setup in the official vLLM repo.

I'm currently building a small PoC to demonstrate TP combined with vllm_colocation.

it requires TP to match the number of processes for training, is this your idea?

I think this can be the initial setup, but it is also possible to create mini shards. Say you have 4 processes and each mini shard consists of 2 processes: you get 2 mini shards, [rank 0, rank 1] and [rank 2, rank 3], and each mini shard is responsible for running one instance of vLLM with TP=2. @fabianlim has a PoC of this, which we will incorporate on top of vllm_coloc here (see the sketch below).
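
To make the mini-shard idea concrete, here is a tiny illustrative helper (hypothetical, not from the PR) that groups training ranks into TP shards:

# Hypothetical helper: split `world_size` ranks into contiguous TP "mini shards".
# With 4 processes and tp_size=2 this yields [[0, 1], [2, 3]]; each shard would
# back one vLLM instance launched with TP=2 via the external launcher.
def tp_shard_ranks(world_size: int, tp_size: int) -> list[list[int]]:
    assert world_size % tp_size == 0, "tp_size must divide the number of processes"
    return [list(range(s, s + tp_size)) for s in range(0, world_size, tp_size)]

print(tp_shard_ranks(4, 2))  # [[0, 1], [2, 3]]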

@toslali-ibm (Author) commented Apr 1, 2025

Hello @qgallouedec

I have a working PoC for TP using external_launcher for 32B model (Qwen2.5-32B-Instruct).

The script sets TP=4, and each process initializes its own LLM() instance with tensor_parallel_size=4 and distributed_executor_backend="external_launcher". Each process holds 1/N of the model weights and participates in .generate()—collaborating to produce the output. The crucial part is ensuring all shards receive the same input.

Please let me know your comments/questions.

Note: TP w/ the external launcher works on vllm==0.7.3; 0.8.0 and onward have a bug associated with it, and I created a bug report.

CC @fabianlim

Click to see `poc.py` - run it via `accelerate launch --num_processes=4 poc.py` -- use vllm==0.7.3
"""
Each process instantiates a LLM() w/ tp=N and distributed_executor_backend="external_launcher". 
- Each process holds 1/N of the model weights
- Each process does .generate() — work together to generate the output
- The key part is ensuring that all processes receive the same input, because vLLM expects to generate jointly using participating shards
"""
from vllm import LLM, SamplingParams
from accelerate import Accelerator
from accelerate.utils import gather_object

# === Setup distributed environment ===
accelerator = Accelerator()
device = accelerator.device
tp_size = accelerator.num_processes
rank = accelerator.process_index
print(f"\n----------\nDevice: {device} | Tensor Parallel Size: {tp_size} | Process Rank: {rank}\n----------\n")

# === Each worker has local prompts ===
local_prompts = [
    f"Prompt 1 from worker {rank}, How is it going for you today?",
    f"Prompt 2 from worker {rank}, What is the weather like in Boston usually?"
]

# === Gather all prompts across workers ===
all_prompts = gather_object(local_prompts)

# === Initialize vLLM ===
llm = LLM(
    model="Qwen/Qwen2.5-32B-Instruct",
    tensor_parallel_size=tp_size,
    distributed_executor_backend="external_launcher",
    device="cuda",
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# === Generate outputs collectively — all ranks must call this with same inputs ===
outputs = llm.generate(
    prompts=all_prompts,
    sampling_params=sampling_params,
    use_tqdm=(rank == 0)
)

# === Print results from all ranks ===
for i, output in enumerate(outputs):
    prompt = all_prompts[i]
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r} --- Generated: {generated_text!r}\n")

# === Print results from rank 0 — shows all outputs ===
if rank == 0:
    print(f"\n==== Final Output (TP={tp_size}) ====\n")
    for i, output in enumerate(outputs):
        prompt = all_prompts[i]
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r} --- Generated: {generated_text!r}\n")

# === Slice outputs to keep only local shard's portion if TP > 1 ===
if tp_size > 1:
    process_index = rank
    tp_slice = slice(
        process_index * len(local_prompts),
        (process_index + 1) * len(local_prompts)
    )
    local_outputs = outputs[tp_slice]
else:
    local_outputs = outputs

# === Print generations for this rank's original prompts ===
for i, output in enumerate(local_outputs):
    prompt = local_prompts[i]
    generated_text = output.outputs[0].text
    print(f"\n\n\n----Local generations --- Rank {rank} -- Prompt: {prompt!r} --> Generated: {generated_text!r}\n")

@toslali-ibm (Author)

I’ve created a draft PR showcasing TP and vLLM sleep with vllm_colocation.

When training a 14B model (Qwen/Qwen2.5-14B-Instruct) across 8 GPUs (w/ TP=8), I observed a 1.7× speedup in colocation mode (see image for reference).

[Screenshot: training time comparison for Qwen/Qwen2.5-14B-Instruct with TP=8, colocated vs. server mode]

Below are the details of the experiment.

Install TRL from PR

  • pip uninstall -y trl
  • git clone -b tpcoloc https://github.com/toslali-ibm/trl.git
  • pip install -e trl

Run GRPO trainer in coloc mode:

  • CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/grpo.py --config config.yaml

Run GRPO trainer in server mode

  • simply remove vllm_tp and vllm_gpu_memory_utilization in config.yaml, then:
  • boot vllm server CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-14B-Instruct
  • start GRPO trainer CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes=4 src/open_r1/grpo.py --config config.yaml

Config

Click to see `config.yaml`
# Model arguments
model_name_or_path: Qwen/Qwen2.5-14B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: open-r1/OpenR1-Math-220k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
dataset_prompt_column: "problem"

# GRPO trainer config
bf16: true
use_vllm: true
vllm_tp: true # remove this in server mode
vllm_gpu_memory_utilization: 0.4 # remove this in server mode
vllm_enable_prefix_caching: false
vllm_max_model_len: 1024

do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false

learning_rate: 2.0e-05
log_completions: false
log_level: info
logging_first_step: true
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_grad_norm: 0.2
max_prompt_length: 512
max_completion_length: 512
max_steps: 10
num_generations: 4
num_train_epochs: 1

overwrite_output_dir: true
per_device_train_batch_size: 1

reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
eval_strategy: "no"
save_strategy: "no"
report_to: none

seed: 42
temperature: 0.7
warmup_ratio: 0.1

CC @qgallouedec @fabianlim
