DeepSeek Usage

SGLang provides many optimizations designed specifically for DeepSeek models, making it the inference engine recommended by the official DeepSeek team from Day 0.

This document outlines current optimizations for DeepSeek. For an overview of the implemented features, see the completed Roadmap.

Launch DeepSeek V3 with SGLang

To run DeepSeek V3/R1 models, the requirements are as follows (weight type: hardware configuration):

  • Full precision FP8 (recommended): 8 x H200, 8 x MI300X, or 2 x 8 x H100/800/20

  • Full precision BF16: 2 x 8 x H200, 2 x 8 x MI300X, 4 x 8 x H100/800/20, or 4 x 8 x A100/A800

  • Quantized weights (AWQ): 8 x H100/800/20 or 8 x A100/A800

  • Quantized weights (int8): 16 x A100/800 or 32 x L40S
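A rough back-of-the-envelope check shows why these configurations are needed. The sketch below is only an approximation under stated assumptions: 671B total parameters (the publicly stated DeepSeek-V3 size), nominal per-GPU memory (80 GB for the H100/H800/H20 and A100/A800 groups, the smallest in each), and approximate bytes per parameter for each weight type (AWQ assumed to be 4-bit). It ignores KV cache, activations, CUDA Graph buffers, and quantization-scale overhead, so real deployments need headroom beyond the weights alone.

```python
# Rough weight-memory check for the configurations listed above.
# Assumptions (not from SGLang): 671B parameters, nominal GPU memory in GB,
# approximate bytes per parameter; KV cache, activations and scale overhead ignored.
TOTAL_PARAMS = 671e9

BYTES_PER_PARAM = {"FP8": 1.0, "BF16": 2.0, "AWQ": 0.5, "int8": 1.0}  # AWQ assumed 4-bit
GPU_MEMORY_GB = {"H200": 141, "MI300X": 192, "H100": 80, "A100": 80, "L40S": 48}

CONFIGS = [
    ("FP8", "H200", 8), ("FP8", "MI300X", 8), ("FP8", "H100", 16),
    ("BF16", "H200", 16), ("BF16", "MI300X", 16), ("BF16", "H100", 32), ("BF16", "A100", 32),
    ("AWQ", "H100", 8), ("AWQ", "A100", 8),
    ("int8", "A100", 16), ("int8", "L40S", 32),
]


def weight_gb(weight_type: str) -> float:
    """Approximate size of the model weights in GB for one weight type."""
    return TOTAL_PARAMS * BYTES_PER_PARAM[weight_type] / 1e9


def weight_fraction(weight_type: str, gpu: str, num_gpus: int) -> float:
    """Fraction of the total GPU memory that the weights alone would occupy."""
    return weight_gb(weight_type) / (GPU_MEMORY_GB[gpu] * num_gpus)


for wtype, gpu, n in CONFIGS:
    print(f"{wtype:>4} on {n:>2} x {gpu:<6}: weights ~{weight_gb(wtype):6.0f} GB, "
          f"{weight_fraction(wtype, gpu, n):.0%} of GPU memory")
```

For example, BF16 weights (about 1342 GB) exceed the 1280 GB available on 2 x 8 x H100, which is consistent with the list above requiring 4 x 8 H100-class GPUs for BF16.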

Detailed commands for reference:

Download Weights

If you encounter errors when starting the server, make sure the weights have finished downloading. We recommend downloading them beforehand, or restarting the server repeatedly until all weights are downloaded. Please refer to the official DeepSeek V3 guide to download the weights.

Caching torch.compile

DeepSeek models have very large weights, so if you add the --enable-torch-compile flag, the first compilation with torch.compile takes some time. You can refer here to cache the compilation results, so that subsequent startups can reuse the cache and start faster.

Launch with one node of 8 x H200

Please refer to the example. Note that DeepSeek V3 is already in FP8, so do not run it with any quantization arguments such as --quantization fp8 or --kv-cache-dtype fp8_e5m2.

Running examples on multiple nodes

Optimizations

Multi-head Latent Attention (MLA) Throughput Optimizations

Description: MLA is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. SGLang has implemented specific optimizations for this, including:

  • Weight Absorption: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

  • MLA Attention Backends: SGLang currently supports several optimized MLA attention backends: FlashAttention3, FlashInfer, FlashMLA, CutlassMLA, and Triton. The default backend, FlashAttention3 (FA3), performs well across a wide range of workloads.

  • FP8 Quantization: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

  • CUDA Graph & torch.compile: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and torch.compile, which reduces latency and accelerates decoding for small batch sizes.

  • Chunked Prefix Cache: This optimization can increase throughput by splitting the prefix cache into chunks, processing each chunk with multi-head attention, and merging the resulting attention states. The improvement can be significant for chunked prefill on long sequences. Currently, this optimization is available only for the FlashAttention3 backend.
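To see why weight absorption is valid, here is a toy single-head NumPy sketch. It assumes that keys and values are up-projected from a cached latent vector by two matrices, which we call w_uk and w_uv (our names, not SGLang's). It leaves out the decoupled RoPE component, multiple heads, and all kernel-level details. It only checks that folding the up-projections into the query and the output gives the same result as decompressing the cache.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_latent, d_head = 16, 64, 32             # toy sizes, not DeepSeek's real dimensions

c_kv = rng.standard_normal((seq_len, d_latent))   # cached latent vector per token
w_uk = rng.standard_normal((d_latent, d_head))    # hypothetical latent -> key up-projection
w_uv = rng.standard_normal((d_latent, d_head))    # hypothetical latent -> value up-projection
q = rng.standard_normal((1, d_head))              # query of the token being decoded


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Naive order: decompress every cached latent into a key and a value, then attend.
k = c_kv @ w_uk                                   # (seq_len, d_head)
v = c_kv @ w_uv                                   # (seq_len, d_head)
out_naive = softmax(q @ k.T / np.sqrt(d_head)) @ v

# Absorbed order: q @ (c_kv @ w_uk).T == (q @ w_uk.T) @ c_kv.T, and
# p @ (c_kv @ w_uv) == (p @ c_kv) @ w_uv, so attention runs on the latent cache directly.
q_latent = q @ w_uk.T                             # (1, d_latent)
p = softmax(q_latent @ c_kv.T / np.sqrt(d_head))
out_absorbed = (p @ c_kv) @ w_uv

print("max abs diff:", np.abs(out_naive - out_absorbed).max())
```

In the absorbed order, per-token keys and values are never materialized. Each decoding step works directly on the compact latent cache, which is the reordering the Weight Absorption bullet describes.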

Overall, with these optimizations, we have achieved up to 7x acceleration in output throughput compared to the previous version.

Multi-head Latent Attention for DeepSeek Series Models

Usage: MLA optimization is enabled by default.

Reference: Check Blog and Slides for more details.
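The state merging in the chunked prefix cache optimization can be illustrated with log-sum-exp (LSE) rescaling, the standard way to combine attention outputs computed over disjoint sets of keys. This NumPy sketch illustrates that general technique and is not SGLang's FlashAttention3 kernel. It splits a cached prefix into chunks, attends to each chunk separately, and merges the partial results. The chunk size is a hypothetical value. No causal mask is applied, because every prefix token precedes the new tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
d_head = 16
q = rng.standard_normal((4, d_head))        # queries of 4 new (extend) tokens
k = rng.standard_normal((40, d_head))       # keys of a 40-token cached prefix
v = rng.standard_normal((40, d_head))       # values of the cached prefix


def attend(q, k, v):
    """Softmax attention over one chunk; returns the normalized output and per-query LSE."""
    s = q @ k.T / np.sqrt(d_head)
    m = s.max(axis=-1, keepdims=True)
    e = np.exp(s - m)
    denom = e.sum(axis=-1, keepdims=True)
    return (e / denom) @ v, (m + np.log(denom))[:, 0]


def merge(o_a, lse_a, o_b, lse_b):
    """Combine two partial results over disjoint key sets using their LSE values."""
    lse = np.logaddexp(lse_a, lse_b)
    return np.exp(lse_a - lse)[:, None] * o_a + np.exp(lse_b - lse)[:, None] * o_b, lse


out_full, lse_full = attend(q, k, v)        # reference: attend to the whole prefix at once

chunk = 16                                  # hypothetical chunk size
out, lse = attend(q, k[:chunk], v[:chunk])
for start in range(chunk, len(k), chunk):
    o_c, lse_c = attend(q, k[start:start + chunk], v[start:start + chunk])
    out, lse = merge(out, lse, o_c, lse_c)

print("max abs diff:", np.abs(out - out_full).max())
```

Each chunk only needs its own output and one LSE value per query. This is why the prefix can be processed in pieces without keeping the full score matrix in memory.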

Data Parallelism Attention

Description: This optimization applies data parallelism (DP) to the MLA attention mechanism of DeepSeek series models, which significantly reduces the KV cache size and enables larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer. Without DP attention, the KV cache is duplicated across all TP ranks.
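A simple arithmetic sketch shows what removing that duplication means for capacity. The MLA cache dimensions below (kv_lora_rank 512, qk_rope_head_dim 64, 61 layers) are our reading of the public DeepSeek-V3 config.json; verify them against your checkpoint. The BF16 cache and the 40 GiB per-GPU budget are hypothetical. The model is idealized: within one attention group every rank stores the same latent cache, so only the number of DP groups multiplies the number of distinct tokens the cluster can hold.

```python
# Assumed MLA cache dimensions (from our reading of the public DeepSeek-V3 config.json).
KV_LORA_RANK = 512        # compressed latent dimension
QK_ROPE_HEAD_DIM = 64     # decoupled RoPE key dimension, cached alongside the latent
NUM_LAYERS = 61
BYTES_PER_ELEM = 2        # assumes a BF16 KV cache

BYTES_PER_TOKEN = (KV_LORA_RANK + QK_ROPE_HEAD_DIM) * NUM_LAYERS * BYTES_PER_ELEM


def tokens_per_gpu(kv_budget_gib: float) -> int:
    """How many tokens' latent KV cache fit into one GPU's KV cache budget."""
    return int(kv_budget_gib * 2**30) // BYTES_PER_TOKEN


def cluster_token_capacity(kv_budget_gib: float, dp_size: int) -> int:
    """Distinct tokens the whole TP group can cache (idealized).

    Ranks inside one attention group hold identical copies of the latent cache,
    so capacity grows only with the number of attention DP groups.
    """
    return tokens_per_gpu(kv_budget_gib) * dp_size


budget = 40.0  # hypothetical per-GPU KV cache budget in GiB
print("bytes per token:", BYTES_PER_TOKEN)
print("--tp 8         :", cluster_token_capacity(budget, dp_size=1))
print("--tp 8 --dp 8  :", cluster_token_capacity(budget, dp_size=8))
```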

Data Parallelism Attention for DeepSeek Series Models

With data parallelism attention enabled, we have achieved up to 1.9x decoding throughput improvement compared to the previous version.

Data Parallelism Attention Performance Comparison

Usage:

  • Append --enable-dp-attention --tp 8 --dp 8 to the server arguments when using 8 H200 GPUs. This optimization improves peak throughput in high batch size scenarios where the server is limited by KV cache capacity. However, it is not recommended for low-latency, small-batch use cases.

  • DP and TP attention can be flexibly combined. For example, to deploy DeepSeek-V3/R1 on 2 nodes with 8 H100 GPUs each, you can specify --enable-dp-attention --tp 16 --dp 2. This configuration runs attention with 2 DP groups, each containing 8 TP GPUs.
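The way --tp and --dp combine can be pictured as splitting the TP ranks into equal attention groups. This is a hypothetical helper for illustration only. It assumes that --tp must be divisible by --dp and that ranks are assigned to groups contiguously, which may not match SGLang's internal rank layout. The MoE layers are still shared by all ranks.

```python
def attention_dp_groups(tp_size: int, dp_size: int) -> list[list[int]]:
    """Split tp_size GPU ranks into dp_size equally sized attention groups.

    Contiguous rank assignment is an assumption made for illustration.
    """
    if tp_size % dp_size != 0:
        raise ValueError("--tp is assumed to be divisible by --dp")
    group_size = tp_size // dp_size
    return [list(range(g * group_size, (g + 1) * group_size)) for g in range(dp_size)]


print(attention_dp_groups(16, 2))  # --tp 16 --dp 2: 2 groups of 8 ranks
print(attention_dp_groups(8, 8))   # --tp 8 --dp 8: 8 groups of 1 rank each
```

Under this layout, --tp 16 --dp 2 on two 8-GPU nodes gives one attention group per node.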

Reference: Check Blog.

Multi-Node Tensor Parallelism

Description: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node’s memory.

Usage: Check here for usage examples.

Block-wise FP8

Description: SGLang implements block-wise FP8 quantization with the following key optimizations:

  • Activation: E4M3 format using per-token-per-128-channel sub-vector scales with online casting.

  • Weight: Per-128x128-block quantization for better numerical stability.

  • DeepGEMM: The DeepGEMM kernel library optimized for FP8 matrix multiplications.
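The scaling scheme above can be sketched in NumPy. The sketch assumes E4M3's maximum finite value of 448 and 128-element groups. NumPy has no FP8 type, so fake_e4m3 (a hypothetical helper) only approximates E4M3 rounding by keeping 4 significant bits, and it ignores subnormals and special values. The loop in blockwise_fp8_matmul shows how per-token and per-block scales are applied along the K dimension. It does not reflect DeepGEMM's actual kernel.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3
BLOCK = 128           # group size for activations and block size for weights


def fake_e4m3(v):
    """Crude E4M3 emulation: keep 4 significant bits (1 implicit + 3 mantissa).
    Subnormals, NaN handling and exact tie rules are ignored."""
    m, e = np.frexp(v)
    return np.ldexp(np.round(m * 16) / 16, e)


def quantize_activation(x):
    """Per-token, per-128-channel scales. x: (tokens, channels)."""
    t, c = x.shape
    g = x.reshape(t, c // BLOCK, BLOCK)
    scale = np.maximum(np.abs(g).max(axis=-1, keepdims=True), 1e-12) / FP8_E4M3_MAX
    return fake_e4m3(g / scale).reshape(t, c), scale[..., 0]    # scale: (t, c // 128)


def quantize_weight(w):
    """Per-128x128-block scales. w: (out_features, in_features)."""
    o, i = w.shape
    b = w.reshape(o // BLOCK, BLOCK, i // BLOCK, BLOCK)
    scale = np.maximum(np.abs(b).max(axis=(1, 3), keepdims=True), 1e-12) / FP8_E4M3_MAX
    return fake_e4m3(b / scale).reshape(o, i), scale[:, 0, :, 0]  # scale: (o // 128, i // 128)


def blockwise_fp8_matmul(xq, sx, wq, sw):
    """y = x @ w.T, accumulating one 128-wide K block at a time and applying both scales."""
    t, k = xq.shape
    o = wq.shape[0]
    y = np.zeros((t, o))
    for kb in range(k // BLOCK):
        cols = slice(kb * BLOCK, (kb + 1) * BLOCK)
        partial = xq[:, cols] @ wq[:, cols].T          # product of the "FP8" values
        w_scale = np.repeat(sw[:, kb], BLOCK)          # expand (o // 128,) -> (o,)
        y += partial * sx[:, kb:kb + 1] * w_scale[None, :]
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 256))
w = rng.standard_normal((384, 256))
xq, sx = quantize_activation(x)
wq, sw = quantize_weight(w)
y = blockwise_fp8_matmul(xq, sx, wq, sw)
ref = x @ w.T
rel_err = np.linalg.norm(y - ref) / np.linalg.norm(ref)
print(f"relative error vs. full precision: {rel_err:.4f}")
```

Both scale factors are constant within one 128-wide K block for a given token and output block. This is why they can be applied once per block after the low-precision product, rather than once per element.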

Usage: The activation and weight optimizations above are enabled by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be turned off manually by setting the environment variable SGL_ENABLE_JIT_DEEPGEMM=0.

Before serving the DeepSeek model, precompile the DeepGEMM kernels using:

python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

The precompilation process typically takes around 10 minutes to complete.

Multi-token Prediction

Description: SGLang implements DeepSeek V3 Multi-Token Prediction (MTP) based on EAGLE speculative decoding. With this optimization, decoding speed improves by 1.8x at batch size 1 and by 1.5x at batch size 32 on an H200 TP8 setup.

Usage: Add arguments --speculative-algorithm, --speculative-num-steps, --speculative-eagle-topk and --speculative-num-draft-tokens to enable this feature. For example:

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8

  • The best configuration of --speculative-num-steps, --speculative-eagle-topk, and --speculative-num-draft-tokens for a given batch size can be found with the bench_speculative.py script. The minimum configuration is --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2, which can achieve a speedup for larger batch sizes.

  • The FlashAttention3, FlashMLA, and Triton backends fully support MTP. For the FlashInfer backend (--attention-backend flashinfer) with speculative decoding, set --speculative-eagle-topk to 1. MTP support for the CutlassMLA backend is still under development.

  • To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see this discussion):

    • Increase --max-running-requests. Its default value for MTP is 32, so for larger batch sizes you should raise it above that default.

    • Set --cuda-graph-bs, the list of batch sizes captured by CUDA Graph. The default captured batch sizes for speculative decoding are set here; you can add more batch sizes to the list.
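The flags above can be collected into a small launcher helper. mtp_launch_args is a hypothetical Python helper and not part of SGLang. It assumes that with --speculative-eagle-topk 1 the draft is a single chain, so --speculative-num-draft-tokens should equal --speculative-num-steps + 1 (as in the minimal configuration above). It also assumes that --cuda-graph-bs accepts a space-separated list of integers. The large-batch values are placeholders to tune, not recommendations.

```python
import shlex


def mtp_launch_args(model_path, tp, num_steps, eagle_topk, num_draft_tokens,
                    max_running_requests=None, cuda_graph_bs=None):
    """Build a launch_server argument list for DeepSeek MTP (EAGLE-based)."""
    # Assumed constraint: a top-1 draft is a chain of num_steps tokens plus the root.
    if eagle_topk == 1 and num_draft_tokens != num_steps + 1:
        raise ValueError("with eagle_topk 1, num_draft_tokens is assumed to be num_steps + 1")
    args = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_path,
        "--speculative-algorithm", "EAGLE",
        "--speculative-num-steps", str(num_steps),
        "--speculative-eagle-topk", str(eagle_topk),
        "--speculative-num-draft-tokens", str(num_draft_tokens),
        "--trust-remote-code", "--tp", str(tp),
    ]
    if max_running_requests is not None:   # raise above the MTP default of 32 for big batches
        args += ["--max-running-requests", str(max_running_requests)]
    if cuda_graph_bs:                      # extra batch sizes to capture with CUDA Graph
        args += ["--cuda-graph-bs", *map(str, cuda_graph_bs)]
    return args


# Minimal configuration from the example command above.
print(shlex.join(mtp_launch_args("deepseek-ai/DeepSeek-V3-0324", 8, 1, 1, 2)))

# Hypothetical large-batch variant (placeholder values).
large = mtp_launch_args("deepseek-ai/DeepSeek-V3-0324", 8, 1, 1, 2,
                        max_running_requests=128,
                        cuda_graph_bs=[1, 2, 4, 8, 16, 32, 64, 128])
print(shlex.join(large))
```

The resulting list can be passed to subprocess.run to start the server.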

Reasoning Content for DeepSeek R1

See Separate Reasoning.

Function calling for DeepSeek Models

Add the arguments --tool-call-parser deepseekv3 and --chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja (recommended) to enable this feature. For example (running on one 8 x H20 node):

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --tool-call-parser deepseekv3 --chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja

Sample Request:

curl "http://127.0.0.1:30000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{"temperature": 0, "max_tokens": 100, "model": "deepseek-ai/DeepSeek-V3-0324", "tools": [{"type": "function", "function": {"name": "query_weather", "description": "Get weather of an city, the user should supply a city first", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city, e.g. Beijing"}}, "required": ["city"]}}}], "messages": [{"role": "user", "content": "Hows the weather like in Qingdao today"}]}'
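The same request can be sent from Python using only the standard library. This sketch mirrors the curl payload. The helper names build_payload and post_chat_completion are our own and not part of SGLang. post_chat_completion only works while the server launched above is running on 127.0.0.1:30000.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:30000"  # host and port from the launch command above

WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "query_weather",
        "description": "Get weather of a city, the user should supply a city first",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string", "description": "The city, e.g. Beijing"}},
            "required": ["city"],
        },
    },
}


def build_payload(question, stream=False):
    """Chat-completions request body equivalent to the curl example."""
    return {
        "temperature": 0,
        "max_tokens": 100,
        "model": "deepseek-ai/DeepSeek-V3-0324",
        "stream": stream,
        "tools": [WEATHER_TOOL],
        "messages": [{"role": "user", "content": question}],
    }


def post_chat_completion(payload, base_url=BASE_URL):
    """POST a non-streaming request and return the decoded JSON response."""
    request = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as resp:
        return json.loads(resp.read())


# Requires a running server:
# result = post_chat_completion(build_payload("What's the weather like in Qingdao today?"))
```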

Expected Response

{"id":"6501ef8e2d874006bf555bc80cddc7c5","object":"chat.completion","created":1745993638,"model":"deepseek-ai/DeepSeek-V3-0324","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"id":"0","index":null,"type":"function","function":{"name":"query_weather","arguments":"{\"city\": \"Qingdao\"}"}}]},"logprobs":null,"finish_reason":"tool_calls","matched_stop":null}],"usage":{"prompt_tokens":116,"total_tokens":138,"completion_tokens":22,"prompt_tokens_details":null}}
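When consuming this response in Python, note that function.arguments is itself a JSON-encoded string, so it has to be decoded a second time. A standard-library sketch, using a trimmed copy of the response above (extract_tool_calls is our own helper name):

```python
import json


def extract_tool_calls(response):
    """Return (function name, decoded arguments) pairs from a chat.completion response."""
    calls = []
    for choice in response.get("choices", []):
        for call in choice["message"].get("tool_calls") or []:
            fn = call["function"]
            # "arguments" is a JSON string embedded in the JSON response, so decode it again.
            calls.append((fn["name"], json.loads(fn["arguments"])))
    return calls


# Trimmed copy of the expected response shown above.
response = json.loads(r'''{"choices":[{"index":0,"message":{"role":"assistant","content":null,
"tool_calls":[{"id":"0","index":null,"type":"function","function":{"name":"query_weather",
"arguments":"{\"city\": \"Qingdao\"}"}}]},"finish_reason":"tool_calls"}]}''')

print(extract_tool_calls(response))  # [('query_weather', {'city': 'Qingdao'})]
```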

Sample Streaming Request:

curl "http://127.0.0.1:30000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{"temperature": 0, "max_tokens": 100, "model": "deepseek-ai/DeepSeek-V3-0324","stream":true,"tools": [{"type": "function", "function": {"name": "query_weather", "description": "Get weather of an city, the user should supply a city first", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city, e.g. Beijing"}}, "required": ["city"]}}}], "messages": [{"role": "user", "content": "Hows the weather like in Qingdao today"}]}'

Expected Streamed Chunks (simplified for clarity):

data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{\""}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"city"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\":\""}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"Q"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"ing"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"dao"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\"}"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":null}}], "finish_reason": "tool_calls"}
data: [DONE]

The client needs to concatenate all argument fragments to reconstruct the complete tool call:

{"city": "Qingdao"}
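A minimal client-side sketch of that concatenation step in Python (standard library only). It assumes OpenAI-style server-sent-event lines like the chunks shown above. Fragments are grouped by the tool call's index, treated as 0 when absent as in the simplified chunks, so that several parallel tool calls would not be mixed together. collect_tool_calls is our own helper name.

```python
import json
from collections import defaultdict


def collect_tool_calls(sse_lines):
    """Rebuild tool-call arguments from streamed SSE lines.

    Returns one decoded argument object per tool-call index, in index order.
    """
    fragments = defaultdict(list)              # tool-call index -> argument fragments
    for line in sse_lines:
        if not line.startswith("data: "):
            continue                           # skip blank or non-data lines
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        for choice in chunk.get("choices", []):
            for call in choice.get("delta", {}).get("tool_calls") or []:
                index = call.get("index") or 0             # simplified chunks omit the index
                piece = (call.get("function") or {}).get("arguments") or ""
                fragments[index].append(piece)
    return [json.loads("".join(parts)) for _, parts in sorted(fragments.items())]


# The simplified chunks from above.
stream = [
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{\""}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"city"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\":\""}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"Q"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"ing"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"dao"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\"}"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":null}}], "finish_reason": "tool_calls"}',
    "data: [DONE]",
]

print(collect_tool_calls(stream))  # [{'city': 'Qingdao'}]
```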

Important Notes:

  1. Use a lower "temperature" value for better results.

  2. To receive more consistent tool call results, it is recommended to use --chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja. It provides an improved unified prompt.

FAQ

Q: Model loading is taking too long, and I’m encountering an NCCL timeout. What should I do?

A: If you’re experiencing extended model loading times and an NCCL timeout, you can try increasing the timeout duration. Add the argument --dist-timeout 3600 when launching your model. This will set the timeout to one hour, which often resolves the issue.