Speculative Decoding#

SGLang now provides an EAGLE-based speculative decoding option. The implementation is designed for speed and efficiency and is considered among the fastest in open-source LLM engines, as the benchmark numbers below illustrate.

Note: Speculative decoding in SGLang currently does not support the radix cache.

Performance Highlights#

  • Official EAGLE code (SafeAILab/EAGLE): ~200 tokens/s

  • Standard SGLang Decoding: ~156 tokens/s

  • EAGLE Decoding in SGLang: ~297 tokens/s

  • EAGLE Decoding in SGLang (w/ torch.compile): ~316 tokens/s

All benchmarks above were measured on a single H100 GPU.
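If you want a rough tokens-per-second figure on your own hardware, one simple check is to time a long generation against a running server and divide the reported completion tokens by the wall-clock time. The snippet below is only a sketch, not the official benchmark script: it assumes a server is already reachable on port (as launched in the sections below), and the prompt and max_tokens values are arbitrary illustration choices. The measurement also includes prefill and HTTP overhead, so it slightly understates pure decoding speed.

import time
import openai

# Assumes an SGLang server is already running on `port` (see the launch cells below).
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

start = time.perf_counter()
response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Write a short story about a robot."}],
    temperature=0,
    max_tokens=512,
)
elapsed = time.perf_counter() - start

# Rough decode throughput: generated tokens divided by wall-clock time.
tokens = response.usage.completion_tokens
print(f"{tokens} tokens in {elapsed:.2f} s -> {tokens / elapsed:.1f} tokens/s")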

EAGLE Decoding#

To enable EAGLE-based speculative decoding, specify the draft model with --speculative-draft and the EAGLE parameters --speculative-num-steps, --speculative-eagle-topk, and --speculative-num-draft-tokens:

[1]:
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE \
    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-02-23 08:34:33] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=39443, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=-1, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=True, tp_size=1, stream_interval=1, stream_output=False, random_seed=402206106, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path='lmzheng/sglang-EAGLE-llama2-chat-7B', speculative_algorithm='EAGLE', speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=True, disable_jump_forward=False, disable_cuda_graph=True, disable_cuda_graph_padding=True, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, return_hidden_states=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False)
[2025-02-23 08:34:53 TP0] Init torch distributed begin.
[2025-02-23 08:34:54 TP0] Load weight begin. avail mem=59.32 GB
[2025-02-23 08:34:55 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.58it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.35s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.24s/it]

[2025-02-23 08:34:58 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=32.11 GB
[2025-02-23 08:34:58 TP0] KV Cache is allocated. K size: 5.00 GB, V size: 5.00 GB.
[2025-02-23 08:34:58 TP0] Memory pool end. avail mem=21.99 GB
[2025-02-23 08:34:59 TP0] Init torch distributed begin.
[2025-02-23 08:34:59 TP0] Load weight begin. avail mem=21.42 GB
[2025-02-23 08:34:59 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
/public_sglang_ci/runner-c-gpu-1/_work/sglang/sglang/python/sglang/srt/model_loader/weight_utils.py:447: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.28s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.28s/it]

[2025-02-23 08:35:00 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=20.49 GB
[2025-02-23 08:35:00 TP0] KV Cache is allocated. K size: 0.16 GB, V size: 0.16 GB.
[2025-02-23 08:35:00 TP0] Memory pool end. avail mem=20.17 GB
[2025-02-23 08:35:00 TP0] max_total_num_tokens=20480, chunked_prefill_size=-1, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-02-23 08:35:00] INFO:     Started server process [1030648]
[2025-02-23 08:35:00] INFO:     Waiting for application startup.
[2025-02-23 08:35:00] INFO:     Application startup complete.
[2025-02-23 08:35:00] INFO:     Uvicorn running on http://127.0.0.1:39443 (Press CTRL+C to quit)
[2025-02-23 08:35:01] INFO:     127.0.0.1:35618 - "GET /v1/models HTTP/1.1" 200 OK
[2025-02-23 08:35:01] INFO:     127.0.0.1:35628 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-02-23 08:35:01 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:35:05] INFO:     127.0.0.1:35640 - "POST /generate HTTP/1.1" 200 OK
[2025-02-23 08:35:05] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a parallel CI environment, so the throughput is not representative of actual performance.
[2]:
import openai

client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-02-23 08:35:06 TP0] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:35:06] INFO:     127.0.0.1:35656 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='da6119ed7e3942538eaf6e262621abea', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=2)], created=1740299706, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
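Besides the OpenAI-compatible API, the startup log above shows a health-check request to the server's native /generate endpoint. If you prefer to call it directly, a minimal sketch with requests follows; the request schema shown here (a "text" prompt plus a "sampling_params" dict) is what SGLang's native API expected at the time of writing, so check the current documentation if your version differs.

import requests

# Call the native /generate endpoint directly (the same route hit during startup).
resp = requests.post(
    f"http://127.0.0.1:{port}/generate",
    json={
        "text": "List 3 countries and their capitals.",
        "sampling_params": {"temperature": 0, "max_new_tokens": 64},
    },
)
# The native API returns a JSON object whose "text" field holds the completion.
print_highlight(resp.json()["text"])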
[3]:
terminate_process(server_process)

EAGLE Decoding with torch.compile#

You can also enable torch.compile for further speedups; the command below additionally lowers --mem-fraction and caps --cuda-graph-max-bs:

[4]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algo EAGLE \
    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
            --enable-torch-compile --cuda-graph-max-bs 2
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-02-23 08:35:20] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=38402, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=-1, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=True, tp_size=1, stream_interval=1, stream_output=False, random_seed=622916098, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path='lmzheng/sglang-EAGLE-llama2-chat-7B', speculative_algorithm='EAGLE', speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=True, disable_jump_forward=False, disable_cuda_graph=True, disable_cuda_graph_padding=True, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=True, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, return_hidden_states=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False)
[2025-02-23 08:35:38 TP0] Init torch distributed begin.
[2025-02-23 08:35:38 TP0] Load weight begin. avail mem=59.83 GB
[2025-02-23 08:35:40 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.62it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.26s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.17s/it]

[2025-02-23 08:35:42 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=66.15 GB
[2025-02-23 08:35:42 TP0] KV Cache is allocated. K size: 5.00 GB, V size: 5.00 GB.
[2025-02-23 08:35:42 TP0] Memory pool end. avail mem=56.03 GB
[2025-02-23 08:35:43 TP0] Init torch distributed begin.
[2025-02-23 08:35:43 TP0] Load weight begin. avail mem=55.45 GB
[2025-02-23 08:35:43 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
/public_sglang_ci/runner-c-gpu-1/_work/sglang/sglang/python/sglang/srt/model_loader/weight_utils.py:447: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state = torch.load(bin_file, map_location="cpu")
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.30s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.30s/it]

[2025-02-23 08:35:44 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=54.53 GB
[2025-02-23 08:35:44 TP0] KV Cache is allocated. K size: 0.16 GB, V size: 0.16 GB.
[2025-02-23 08:35:44 TP0] Memory pool end. avail mem=54.21 GB
[2025-02-23 08:35:45 TP0] max_total_num_tokens=20480, chunked_prefill_size=-1, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-02-23 08:35:45] INFO:     Started server process [1033198]
[2025-02-23 08:35:45] INFO:     Waiting for application startup.
[2025-02-23 08:35:45] INFO:     Application startup complete.
[2025-02-23 08:35:45] INFO:     Uvicorn running on http://127.0.0.1:38402 (Press CTRL+C to quit)
[2025-02-23 08:35:45] INFO:     127.0.0.1:55980 - "GET /v1/models HTTP/1.1" 200 OK
[2025-02-23 08:35:46] INFO:     127.0.0.1:55992 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-02-23 08:35:46 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:35:49] INFO:     127.0.0.1:55998 - "POST /generate HTTP/1.1" 200 OK
[2025-02-23 08:35:49] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a parallel CI environment, so the throughput is not representative of actual performance.
[5]:
import openai

client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-02-23 08:35:50 TP0] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:35:51] INFO:     127.0.0.1:56014 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='76abc0ca3ed94497b9e8c2b21600849d', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=2)], created=1740299751, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
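To get a feel for the decoding speed interactively, you can also stream the completion and watch tokens arrive as they are decoded. A short sketch using the same client and request parameters as above:

# Stream the same request; the rate at which chunks print gives a rough
# visual sense of the decoding speed.
stream = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()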
[6]:
terminate_process(server_process)