DEV Community

Mayank Ketkar

Compiling the Vision Encoder: Squeezing 3% More Throughput from Qwen3-VL on Hopper GPUs

When you run a vision-language model through vLLM, the framework does something clever: it compiles the LLM decoder with torch.compile, fuses operators, and captures CUDA graphs for maximum throughput. But there is a component it quietly leaves behind -- the Vision Transformer (ViT) encoder that processes your images. It runs in plain eager mode, every single time.

We changed that for Qwen3-VL. The result: 3.4% higher throughput on an NVIDIA H200, three previously unknown bugs discovered and fixed, and a one-flag change that any vLLM user can enable today.

This post walks through the engineering story -- why the encoder was left behind, how we ported compilation support from a sibling model, what broke along the way, and what the profiler actually says about where the time goes.


Why Does the Encoder Run Eager?

vLLM's compilation infrastructure is built around the LLM decoder. When you launch an inference server, the startup sequence compiles the decoder's forward pass with torch.compile, traces its graph, and captures CUDA graphs at various batch sizes. This eliminates Python overhead and enables kernel fusion across attention, LayerNorm, and MLP layers.

The multimodal encoder -- the ViT that converts raw image pixels into embedding vectors -- gets none of this treatment. The reason is a single boolean flag in vLLM's compilation config:

compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for Qwen2_5_vl and mLLaMa4
models on selected platforms. Disabled by default until
more models are supported/tested to work."""

The default is False, and for good reason. Vision encoders face a fundamental tension with compilation: variable input shapes. Different requests can carry images at different resolutions, producing different numbers of patches. CUDA graphs require fixed tensor shapes at capture time. A general-purpose serving framework cannot assume that every image will be the same size.

But for batch inference workloads with fixed-size images -- which is common in production pipelines processing standardized camera frames, satellite tiles, or document pages -- this conservatism leaves performance on the table. If your images are all the same resolution, the encoder always receives identically shaped tensors, and torch.compile can fully specialize.
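
To make the fixed-shape argument concrete, here is a minimal sketch of how image resolution maps to the number of patches the encoder receives. The helper name `num_patches` and the patch size of 16 are illustrative assumptions, not values taken from the vLLM source.

```python
def num_patches(height: int, width: int, patch_size: int = 16) -> int:
    """Number of non-overlapping square patches in one image.

    patch_size=16 is an assumed value for illustration only.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be multiples of patch_size")
    return (height // patch_size) * (width // patch_size)


# Mixed resolutions: dimension 0 of the encoder input changes per request.
mixed = [num_patches(h, w) for h, w in [(224, 224), (448, 448), (224, 448)]]

# Fixed resolution: every request yields the same shape.
fixed = [num_patches(448, 448) for _ in range(3)]

print(mixed, fixed)
```

With mixed resolutions the patch count differs per request, so the compiler has to treat that dimension as dynamic. With a single resolution it never changes, which is the case this post targets.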

There was a second, more specific problem: Qwen3-VL simply lacked the compilation decorators. Its sibling model, Qwen2.5-VL, already had full torch.compile support for its encoder. Qwen3-VL shared much of the same architecture (including the identical attention implementation), but the compilation wiring was never ported over.


The Pattern: Porting from Qwen2.5-VL

vLLM uses a decorator-based system for selective compilation. Rather than compiling an entire model's forward pass (which would break on Python control flow, NumPy calls, and dynamic branching), it compiles individual submodules whose forward() methods contain only clean tensor operations.

Qwen2.5-VL already had this wired up for three encoder submodules: VisionPatchEmbed, VisionBlock, and VisionPatchMerger. Our task was to replicate the exact same pattern in Qwen3-VL.

The Decorator

Each compilable submodule gets a @support_torch_compile decorator that declares which tensor dimensions are dynamic and provides a gating function:

@support_torch_compile(
    dynamic_arg_dims={"x": 0},
    enable_if=should_torch_compile_mm_vit,
)
class Qwen3_VisionPatchEmbed(nn.Module):
    ...

The dynamic_arg_dims={"x": 0} tells torch.compile that dimension 0 of the input tensor x can vary between calls (different numbers of patches), so it should not bake that shape into the compiled graph. The enable_if callback is a one-liner that checks whether the user opted in:

def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool:
    return vllm_config.compilation_config.compile_mm_encoder

When compile_mm_encoder is False (the default), the decorator sets self.do_not_compile = True and the forward pass runs in eager mode -- zero overhead, zero behavior change. When it is True, the decorator wraps the module in torch.compile on first call and uses compiled execution from then on.
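
The gating behavior can be sketched in plain Python. This is not vLLM's decorator: `support_compile_sketch`, the dict-based config, and the `compiled_calls` counter are hypothetical stand-ins that only show how an opt-in callback can select between an eager and a "compiled" path at construction time.

```python
from typing import Callable


def support_compile_sketch(enable_if: Callable[[dict], bool]):
    """Hypothetical stand-in for a gated compile decorator (not vLLM's code)."""

    def decorate(cls):
        original_init = cls.__init__

        def __init__(self, config: dict, *args, **kwargs):
            original_init(self, config, *args, **kwargs)
            # Decide once, at construction time, whether to compile.
            self.do_not_compile = not enable_if(config)

        def __call__(self, x):
            if self.do_not_compile:
                return self.forward(x)  # eager path, behavior unchanged
            # Placeholder for the compiled path: we only count that it ran.
            self.compiled_calls += 1
            return self.forward(x)

        cls.__init__ = __init__
        cls.__call__ = __call__
        return cls

    return decorate


def should_compile(config: dict) -> bool:
    return config.get("compile_mm_encoder", False)


@support_compile_sketch(enable_if=should_compile)
class PatchEmbedSketch:
    def __init__(self, config: dict):
        self.compiled_calls = 0

    def forward(self, x):
        return [v * 2 for v in x]


eager = PatchEmbedSketch({})
opted_in = PatchEmbedSketch({"compile_mm_encoder": True})
eager_out = eager([1, 2, 3])
compiled_out = opted_in([1, 2, 3])
```

Both instances return the same result. Only the opted-in one goes through the alternate path, which is why leaving the flag off changes nothing for existing users.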

The Model Tags

The second piece of wiring is set_model_tag, a context manager that tells the compilation backend to use separate caches for encoder versus decoder components. Without tags, the encoder and decoder would share a single compile cache, causing shape mismatches when the compiler tries to reuse a graph compiled for decoder weight shapes on encoder weights.

In Qwen3_VisionTransformer.__init__():

# DO NOT MOVE THIS IMPORT
from vllm.compilation.backends import set_model_tag

with set_model_tag("Qwen3_VisionPatchEmbed", is_encoder=True):
    self.patch_embed = Qwen3_VisionPatchEmbed(...)

with set_model_tag("Qwen3_VisionPatchMerger", is_encoder=True):
    self.merger = Qwen3_VisionPatchMerger(...)

# Deepstack mergers need a separate tag (different weight shapes!)
with set_model_tag("Qwen3_VisionPatchMerger_deepstack", is_encoder=True):
    self.deepstack_merger_list = nn.ModuleList([...])

with set_model_tag("Qwen3_VisionBlock", is_encoder=True):
    self.blocks = nn.ModuleList([
        Qwen3_VisionBlock(...) for _ in range(depth)
    ])

That comment about DO NOT MOVE THIS IMPORT is not a joke -- it matches the exact pattern in Qwen2.5-VL and relates to import ordering constraints with the compilation backend (see vllm#27044).

Notice the deepstack mergers get their own tag, separate from the main merger. This was not in the original plan. It was the fix for Bug #2, which we will get to shortly.

What Gets Compiled

The Qwen3-VL vision encoder has three distinct compilable submodules:

| Submodule | What It Does | Dynamic Dims |
| --- | --- | --- |
| Qwen3_VisionPatchEmbed | Reshape + Conv3D + reshape (pixels to patch embeddings) | x: dim 0 |
| Qwen3_VisionBlock (x24) | LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual | x, cu_seqlens, cos, sin: dim 0 |
| Qwen3_VisionPatchMerger | LayerNorm -> Linear -> GELU -> Linear (merge spatial patches) | x: dim 0 |

The outer VisionTransformer.forward() -- which orchestrates these submodules -- is deliberately not compiled. It contains NumPy operations (np.array, np.cumsum), Python control flow (isinstance, list comprehensions), and .tolist() calls that would cause graph breaks. The per-submodule pattern avoids all of this.
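
As an illustration of the host-side bookkeeping that stays outside the compiled region, here is a pure-Python sketch of computing cumulative sequence lengths. It uses `itertools.accumulate` rather than NumPy, and both the `(t, h, w)` grid layout and the function name `cu_seqlens_sketch` are assumptions made for this example.

```python
from itertools import accumulate


def cu_seqlens_sketch(grids: list[tuple[int, int, int]]) -> list[int]:
    """Cumulative sequence boundaries for a batch of images or videos.

    Each grid is (t, h, w) in patches: t frames of h * w patches each.
    """
    per_frame = [h * w for t, h, w in grids for _ in range(t)]
    return [0] + list(accumulate(per_frame))


# One image of 4x4 patches, then a two-frame clip of 2x3 patches per frame.
boundaries = cu_seqlens_sketch([(1, 4, 4), (2, 2, 3)])
print(boundaries)
```

The length of the result depends on the input data, and it is built with Python loops. Logic like this is cheap to run eagerly and awkward to trace, so it makes sense to leave it in the uncompiled outer method.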


Zero Graph Breaks

The first compile attempt was the moment of truth. We enabled TORCH_LOGS=+dynamo and TORCH_COMPILE_DEBUG=1, loaded a handful of test images, and watched TorchDynamo trace through the encoder.

Result: zero graph breaks. The single COMPILING GRAPH event reported:

COMPILING GRAPH due to GraphCompileReason(
    reason='return_value',
    user_stack=[<FrameSummary file qwen3_vl.py, line 1169 in forward>],
    graph_break=False
)

This was expected but still satisfying. The per-submodule compilation pattern is specifically designed to isolate clean tensor operations from Python control flow. Each compiled forward method contains nothing but torch operations -- reshapes, linear projections, attention, LayerNorm, residual additions. No data-dependent control flow, no Python-side data structures, no calls that escape the Dynamo graph.

The key insight: if you tried to compile the entire VisionTransformer.forward() as one graph, you would hit graph breaks immediately on the NumPy calls that compute positional embeddings and cumulative sequence lengths. By compiling only the inner submodules, you get all the fusion benefits with none of the graph break headaches.


Three Bugs Found and Fixed

Zero graph breaks did not mean zero problems. The first full run crashed. Then it crashed differently. Then it crashed a third way. Here is what we found.

Bug 1: AssertionError: Forward context is not set in profile_run()

The crash:

AssertionError: Forward context is not set.
Please use `set_forward_context` to set the forward context.

What happened: When vLLM starts up, it runs a profiling pass (profile_run()) to determine memory usage. This calls self.model.embed_multimodal() to profile the encoder. In eager mode, this works fine -- the encoder's forward methods are just regular PyTorch calls.

But with @support_torch_compile, the compilation backend wraps each submodule in a CUDAGraphWrapper. The wrapper's __call__ method reads forward_context.cudagraph_runtime_mode to decide whether to execute via CUDA graph or fall through to eager. Without a forward context set, it crashes.

The fix: Wrap the profiling call in set_forward_context:

with set_forward_context(attn_metadata=None, vllm_config=self.vllm_config):
    dummy_encoder_outputs = self.model.embed_multimodal(
        **batched_dummy_mm_inputs
    )

Since attn_metadata=None, the wrapper sees CUDAGraphMode.NONE and falls through to eager execution -- exactly the behavior we want during profiling.

Bug 2: AssertionError: expected size 1024==4096

The crash:

AssertionError: expected size 1024==4096, stride 1==1 at dim=0

What happened: Qwen3-VL has two types of patch mergers. The main merger has a LayerNorm over context_dim=1024 (the per-patch hidden size before spatial merging). The deepstack mergers have a LayerNorm over hidden_size=4096 (the full hidden size, via use_postshuffle_norm=True). Both use the Qwen3_VisionPatchMerger class.

In our initial implementation, both mergers shared the same set_model_tag("Qwen3_VisionPatchMerger") context. This meant they shared a single compiled graph cache. When torch.compile traced through the main merger (norm weight shape (1024,)), it cached a graph with that shape baked in. When the deepstack merger tried to reuse the same cached graph with its (4096,) weights -- crash.

The fix: Separate model tags:

with set_model_tag("Qwen3_VisionPatchMerger", is_encoder=True):
    self.merger = ...          # LayerNorm over 1024

with set_model_tag("Qwen3_VisionPatchMerger_deepstack", is_encoder=True):
    self.deepstack_merger_list = ...  # LayerNorm over 4096

Same Python class, different compile caches. The tag system was designed exactly for this -- but you have to remember to use it when two instances of the same class have different weight shapes.

Bug 3: Same as Bug 1, but in _execute_mm_encoder()

The profiling fix (Bug 1) resolved the startup crash, but the same AssertionError appeared during actual inference. The encoder execution path in _execute_mm_encoder() also called embed_multimodal() without setting forward context.

The fix: Same pattern -- wrap the encoder execution loop in set_forward_context(attn_metadata=None, ...).

Defense-in-Depth

After fixing both call sites, we added a belt-and-suspenders guard in CUDAGraphWrapper.__call__ itself:

def __call__(self, *args, **kwargs):
    if not is_forward_context_available():
        return self.runnable(*args, **kwargs)  # Eager fallback
    forward_context = get_forward_context()
    ...

If any future code path calls a compiled encoder submodule without setting forward context, it gracefully falls through to eager execution instead of crashing. This is defense-in-depth -- the primary fix is ensuring all call sites set the context, but the guard protects against regressions.
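
A self-contained toy version of the two fixes, the context manager at the call site and the guard inside the wrapper, might look like the following. All names here (`forward_context_sketch`, `context_available`, `GraphWrapperSketch`) are hypothetical analogues written for this post, not the vLLM implementations.

```python
from contextlib import contextmanager

_forward_context = None  # module-level slot; None means "not set"


@contextmanager
def forward_context_sketch(mode: str = "NONE"):
    """Hypothetical analogue of a forward-context manager."""
    global _forward_context
    previous, _forward_context = _forward_context, {"mode": mode}
    try:
        yield
    finally:
        _forward_context = previous


def context_available() -> bool:
    return _forward_context is not None


class GraphWrapperSketch:
    def __init__(self, runnable, guarded: bool):
        self.runnable = runnable
        self.guarded = guarded

    def __call__(self, *args):
        if self.guarded and not context_available():
            return self.runnable(*args)  # eager fallback
        assert _forward_context is not None, "Forward context is not set."
        # Mode "NONE" means: skip graph replay and run eagerly.
        return self.runnable(*args)


def encoder(x):
    return x + 1


unguarded = GraphWrapperSketch(encoder, guarded=False)
guarded = GraphWrapperSketch(encoder, guarded=True)

try:
    unguarded(1)
    crashed = False
except AssertionError:
    crashed = True

fallback_result = guarded(1)        # no context set, the guard falls back
with forward_context_sketch():
    context_result = unguarded(1)   # context set, no guard needed
```

Either fix alone avoids the crash in this toy. Keeping both means a forgotten call site degrades to eager execution instead of failing.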


Profiling: Where the Time Goes

With compilation working, we instrumented the encoder with torch.cuda.Event timing to measure exactly how much each component contributes and how much compilation helps.

The Encoder Is Only 13.5% of Total Inference Time

For Qwen3-VL-2B on our workload, the ViT encoder processes each image once to produce embedding tokens, then the LLM decoder generates the output sequence. The decoder dominates.

| Component | Baseline (ms) | Compiled (ms) | Speedup |
| --- | --- | --- | --- |
| PatchEmbed | 5.2 | 6.2 | -19% |
| VisionBlocks (24) | 352.5 | 330.2 | +6.3% |
| PatchMerger | 3.8 | 5.3 | -39% |
| Total Encoder | 450.3 | 430.5 | +4.4% |

VisionBlocks Win, Small Ops Lose

The 24 VisionBlocks are where compilation shines. Each block runs LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual. The Inductor backend fuses these into fewer, more efficient kernels. Blocks 1-23 each show a consistent 7-8% speedup; summed over all 24 blocks, the reduction is 22.3ms (6.3%).

PatchEmbed and PatchMerger show the opposite: compilation makes them slower. These are tiny operations (~0.3ms per call). The @support_torch_compile decorator adds Python dispatch overhead on every call, and at this scale, the overhead exceeds the fusion benefit. It is a classic tradeoff -- compilation has a per-call dispatch cost that only pays off when the compiled operation is large enough.

A pragmatic optimization would be to remove the @support_torch_compile decorators from PatchEmbed and PatchMerger, compiling only VisionBlocks. The net encoder speedup would actually be slightly higher without the small-op regressions. But the dispatch overhead is small in absolute terms (a few milliseconds total), and having all submodules wired for compilation maintains consistency with the Qwen2.5-VL pattern.
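
To put a rough number on "slightly higher", the timings from the component table can be recombined. This is a back-of-the-envelope sketch, not a measurement: it assumes the small ops would return exactly to their baseline times when left eager, and that the time not attributed to the three submodules stays the same in both runs.

```python
# Timings in ms, copied from the component table in this section.
baseline = {"patch_embed": 5.2, "blocks": 352.5, "merger": 3.8}
compiled = {"patch_embed": 6.2, "blocks": 330.2, "merger": 5.3}
total_baseline, total_compiled = 450.3, 430.5

# Time not attributed to the three submodules (assumed unchanged).
other = total_baseline - sum(baseline.values())

# Hypothetical variant: compile only the blocks, keep small ops eager.
blocks_only_total = (
    other + baseline["patch_embed"] + compiled["blocks"] + baseline["merger"]
)

measured_gain = (total_baseline - total_compiled) / total_baseline
blocks_only_gain = (total_baseline - blocks_only_total) / total_baseline

print(f"unattributed: {other:.1f} ms")
print(f"all compiled: {measured_gain:.1%}")
print(f"blocks only:  {blocks_only_gain:.1%}")
```

Under those assumptions, the blocks-only variant would land near 5.0% instead of 4.4%, a difference of about 2.5ms in total encoder time.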

Why 4.4% Encoder Speedup Becomes 3.4% End-to-End

With the encoder representing 13.5% of total inference time, a 4.4% encoder speedup by itself predicts only about a 0.6% reduction in total wall time under Amdahl's Law (0.135 x 0.044 is roughly 0.006). The measured end-to-end improvement of 3.4% is larger than that simple calculation suggests, likely because the compilation also reduces Python overhead and improves memory access patterns in ways that benefit the surrounding orchestration code.
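
The Amdahl's Law estimate can be checked with a few lines of arithmetic. The function name `amdahl_speedup` is ours. The inputs are the encoder share and encoder timings reported in this section, plus the throughput figures from the end-to-end benchmark.

```python
def amdahl_speedup(fraction: float, local_speedup: float) -> float:
    """Overall speedup when `fraction` of runtime is sped up by `local_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / local_speedup)


encoder_fraction = 0.135          # encoder share of total inference time
encoder_speedup = 450.3 / 430.5   # baseline / compiled encoder time

predicted = amdahl_speedup(encoder_fraction, encoder_speedup)
predicted_gain_pct = (predicted - 1.0) * 100      # about 0.6
measured_gain_pct = (33.42 / 32.33 - 1.0) * 100   # about 3.4

print(f"predicted: {predicted_gain_pct:.2f}%  measured: {measured_gain_pct:.2f}%")
```

The gap between the two figures is what the paragraph above attributes to effects outside the instrumented encoder time.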


End-to-End Benchmark

We ran a full A/B comparison over ~8,000 samples on an NVIDIA H200, with 10-sample warmup excluded from measurements.

| Metric | Baseline | Compiled | Delta |
| --- | --- | --- | --- |
| Throughput | 32.33 samp/s | 33.42 samp/s | +3.4% |
| Generate time | 266.1s | 257.4s | -8.7s |
| Per-sample latency | 30.93ms | 29.92ms | -1.0ms |
| Model load time | 37.3s | 50.2s | +12.9s |

The improvement was positive at both scales we measured, though not equal in size: +0.9% at 100 samples, where run-to-run noise is larger, and +3.4% on the full dataset.

The model load time increase (+12.9s) is a one-time cost for Dynamo bytecode transforms and Inductor codegen on the encoder submodules. On subsequent runs, the compilation cache (~/.cache/vllm/torch_compile_cache/) eliminates recompilation entirely -- subsequent startups are only marginally slower than baseline. In a production serving context, this compilation happens once at server startup and all subsequent inference benefits from the speedup.

Break-Even Analysis

| Parameter | Value |
| --- | --- |
| One-time compilation overhead | 12.9s |
| Per-sample time saving | ~1.0ms |
| Break-even point | ~12,900 samples |

For the first-ever run (cold compilation cache), you need to process approximately 13,000 samples before the compilation overhead is amortized. For any subsequent run with a warm cache, the benefit is immediate.
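
The break-even figure is simply the one-time overhead divided by the per-sample saving. The sketch below reproduces it from the benchmark numbers and also shows how sensitive it is to rounding the 1.0ms saving. The variable names are ours.

```python
compile_overhead_s = 12.9      # extra model load time on a cold cache
saving_per_sample_s = 1.0e-3   # rounded per-sample latency reduction

break_even_samples = compile_overhead_s / saving_per_sample_s

# Sensitivity check: the unrounded latency delta from the benchmark table.
exact_saving_s = (30.93 - 29.92) / 1000
break_even_exact = compile_overhead_s / exact_saving_s

print(round(break_even_samples), round(break_even_exact))
```

Either way the break-even sits a little under 13,000 samples. That is more than the ~8,000 samples in the benchmark run, so on a cold cache a job of that size would not yet recover the compilation cost.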


Output Correctness

An important caveat: compiled and baseline modes produce slightly different outputs on some inputs. This is expected behavior from torch.compile -- the Inductor backend may apply different operator fusion, reduction ordering, and kernel implementations that change floating-point rounding at the bit level. These tiny differences in intermediate activations can cascade through the encoder, shift logits by small amounts, and occasionally flip the argmax for borderline tokens during autoregressive decoding.

Concretely:

  • Both modes are individually deterministic -- the same mode always produces the same output for the same input, run after run.
  • They are not cross-compatible -- baseline and compiled may differ on some samples.
  • The differences are small in magnitude and affect only a fraction of samples.

This is a property of torch.compile itself, not of our changes. If your application requires bitwise reproducibility between compiled and non-compiled modes, this is worth knowing. If you only need consistency within a single mode (the more common requirement), both modes deliver it.
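
The underlying mechanism is that floating-point addition is not associative, so a different reduction order can change the last bit of a result. The toy example below uses Python floats rather than GPU kernels, with a deliberately contrived pair of "logits" to show how a one-bit difference can flip an argmax.

```python
# The same three terms, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)


def argmax(scores):
    # Returns the first index holding the maximum value.
    return max(range(len(scores)), key=scores.__getitem__)


# A borderline pair of logits: token 0 sits at 0.6, token 1 at the sum.
token_a = argmax([0.6, left_to_right])
token_b = argmax([0.6, right_to_left])

print(left_to_right == right_to_left, token_a, token_b)
```

The two sums differ by roughly 1e-16, yet the selected token changes. Real logits are rarely this close, which fits the observation that only a fraction of samples differ between modes.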


When Would This Matter More?

A 3.4% throughput improvement is real and free (once the cache is warm), but it is bounded by the encoder's share of total inference time. For Qwen3-VL-2B, the ViT encoder is small relative to the LLM decoder. Several scenarios would amplify the benefit:

Larger ViT encoders. The same 7-8% per-block VisionBlock speedup, applied to a model whose vision encoder has more blocks or more expensive ones, would yield a larger end-to-end improvement, provided the encoder's share of total inference time grows with it.

Video workloads. Video inputs require processing many frames, multiplying encoder invocations per request. The encoder's share of total time increases, and the compilation benefit compounds.

High-concurrency serving. When many requests arrive simultaneously, encoder latency adds up across the batch. Shaving 4.4% off each encoder call reduces queuing delay.

Bandwidth-bound GPUs. The H200 is a compute-rich Hopper GPU. On more bandwidth-constrained hardware like the L40S, the operator fusion from torch.compile (which reduces memory traffic by eliminating intermediate tensor materializations) would likely yield a larger percentage improvement.

Higher-resolution images. More patches per image means more work in the VisionBlocks, which are the primary beneficiaries of compilation.


How to Enable It

One flag:

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-VL-2B-Instruct",
    compilation_config={"compile_mm_encoder": True},
    # ... other settings
)

Or via the CLI:

vllm serve Qwen/Qwen3-VL-2B-Instruct \
    --compilation-config '{"compile_mm_encoder": true}'

That is it. No model changes, no custom code, no configuration gymnastics. The flag tells vLLM to apply torch.compile to the ViT encoder submodules. Compilation (or a cache load) is triggered the first time the encoder runs. In our runs that happened during the startup profiling pass and showed up as extra model load time. All subsequent calls use the compiled kernels.

First Run vs. Subsequent Runs

On the very first run with a new model or new vLLM version, you will see a longer model load time (~13s extra) as TorchDynamo traces and Inductor generates code for the encoder submodules. These artifacts are cached to ~/.cache/vllm/torch_compile_cache/.

On all subsequent runs, the cached artifacts load in seconds, and the throughput benefit is immediate.


Conclusion

This was a small change -- six modifications across two files for the core enablement, plus four files touched for bug fixes. The pattern was already established by Qwen2.5-VL; we just ported it to Qwen3-VL. But small changes can have disproportionate engineering value when they uncover latent bugs.

The three bugs we found -- missing set_forward_context in two encoder execution paths, and shared compile caches for mergers with different weight shapes -- are not specific to Qwen3-VL. They could affect any model that enables compile_mm_encoder and exercises the same code paths. The fixes (including the defense-in-depth guard in CUDAGraphWrapper) benefit the entire vLLM multimodal compilation infrastructure.

The profiling results tell an honest story: the ViT encoder is a small fraction of end-to-end time for a 2B parameter model, so a 4.4% encoder speedup cannot by itself move the total very far, and the measured end-to-end gain is a modest 3.4%. But it is a nearly free 3.4%: one flag, compilation cached after the first run, and deterministic outputs within a single mode. For models with larger vision encoders, video workloads, or bandwidth-constrained hardware, the benefit would likely be larger.

Sometimes the most useful engineering work is not building something new, but noticing that a capability already exists in the codebase and was never wired up for your model.


Summary of Changes

| File | Change |
| --- | --- |
| vllm/model_executor/models/qwen3_vl.py | @support_torch_compile decorators on 3 encoder submodules + set_model_tag wiring |
| vllm/config/compilation.py | Updated compile_mm_encoder docstring to include Qwen3-VL |
| vllm/v1/worker/gpu_model_runner.py | set_forward_context wrapper in _execute_mm_encoder() and profile_run() |
| vllm/compilation/cuda_graph.py | is_forward_context_available() guard in CUDAGraphWrapper.__call__ |

Hardware and Software

  • GPU: NVIDIA H200 (141 GB HBM3e)
  • vLLM: 0.15.x (main branch)
  • PyTorch: 2.9.x
  • Model: Qwen3-VL-2B-Instruct (fine-tuned checkpoint)
  • Workload: ~8,000 fixed-resolution images, single GPU, temperature=0.0, max_tokens=128
