Speculative Decoding#

SGLang now provides an EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding option. Our implementation aims to maximize speed and efficiency and is considered among the fastest in open-source LLM engines. Note: speculative decoding in SGLang is currently compatible with radix cache and chunked prefill.

Performance Highlights#

The table below shows the throughput improvement that EAGLE-3 decoding achieves for Llama-3.1-Instruct 8B on MT-Bench. For further details, please see the EAGLE-3 paper.

| Method | Throughput (tokens/s) |
| --- | --- |
| SGLang (w/o speculative, 1x H100) | 158.34 |
| SGLang + EAGLE-2 (1x H100) | 244.10 |
| SGLang + EAGLE-3 (1x H100) | 373.25 |

EAGLE Decoding#

To enable EAGLE speculative decoding, the following parameters are relevant:

  • speculative_draft_model_path: Specifies the draft model. This parameter is required.

  • speculative_num_steps: Depth of autoregressive drafting. A larger value increases the speculation range but risks rejection cascades. Default is 5.

  • speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and the acceptance rate, but also increases memory and compute consumption. Default is 4.

  • speculative_num_draft_tokens: Maximum parallel verification capacity. A larger value allows deeper tree evaluation at the cost of higher GPU memory usage. Default is 8.

These parameters are the same for EAGLE-2 and EAGLE-3; a sketch that combines them into a launch command is shown below.
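
For illustration, the flags above might be combined into a launch command like the following sketch, which simply plugs in the default values together with the Llama-2 base and draft models used in the next section (adapt the model paths to your own setup):

from sglang.utils import launch_server_cmd, wait_for_server  # helpers used throughout this notebook

# Launch a server with the default speculative-decoding settings described above.
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
    --speculative-eagle-topk 4 --speculative-num-draft-tokens 8
"""
)
wait_for_server(f"http://localhost:{port}")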

EAGLE-2 decoding#

You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate draft model.

[1]:
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

import openai
[2]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
    --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8
"""
)

wait_for_server(f"http://localhost:{port}")
Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-05-29 23:45:53] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=36444, mem_fraction_static=0.836, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=719151451, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=8, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, 
moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-29 23:45:53] Infer the chat template name from the model path and obtain the result: llama-2.
[2025-05-29 23:46:04] Attention backend not set. Use flashinfer backend by default.
[2025-05-29 23:46:04] Init torch distributed begin.
[2025-05-29 23:46:04] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:46:04] init_expert_location from trivial
[2025-05-29 23:46:04] Load weight begin. avail mem=78.50 GB
[2025-05-29 23:46:05] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.84s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.29s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.37s/it]

[2025-05-29 23:46:08] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=65.93 GB, mem usage=12.57 GB.
[2025-05-29 23:46:08] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-05-29 23:46:08] Memory pool end. avail mem=55.73 GB
2025-05-29 23:46:08,398 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-29 23:46:08] Init torch distributed begin.
[2025-05-29 23:46:08] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:46:08] Load weight begin. avail mem=55.16 GB
[2025-05-29 23:46:08] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.26s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.26s/it]

[2025-05-29 23:46:10] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=54.23 GB, mem usage=0.93 GB.
[2025-05-29 23:46:10] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-05-29 23:46:10] Memory pool end. avail mem=53.92 GB
[2025-05-29 23:46:10] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-05-29 23:46:10] INFO:     Started server process [2798982]
[2025-05-29 23:46:10] INFO:     Waiting for application startup.
[2025-05-29 23:46:10] INFO:     Application startup complete.
[2025-05-29 23:46:10] INFO:     Uvicorn running on http://127.0.0.1:36444 (Press CTRL+C to quit)
[2025-05-29 23:46:10] INFO:     127.0.0.1:58788 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-29 23:46:11] INFO:     127.0.0.1:58796 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-29 23:46:11] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:46:12,060 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:46:12,074 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:46:12,081 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:46:12,092 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:46:12,495 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:46:12,507 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:46:14,454 - INFO - flashinfer.jit: Loading JIT ops: quantization
2025-05-29 23:46:14,467 - INFO - flashinfer.jit: Finished loading JIT ops: quantization
[2025-05-29 23:46:14] INFO:     127.0.0.1:58800 - "POST /generate HTTP/1.1" 200 OK
[2025-05-29 23:46:14] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[3]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-05-29 23:46:16] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:46:16,167 - INFO - flashinfer.jit: Loading JIT ops: cascade
2025-05-29 23:46:16,180 - INFO - flashinfer.jit: Finished loading JIT ops: cascade
[2025-05-29 23:46:16] INFO:     127.0.0.1:57680 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='0cdf60ecb3da42058be1108f02967353', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1748562376, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
[4]:
terminate_process(server_process)
[2025-05-29 23:46:16] Child process unexpectedly failed with an exit code 9. pid=2799263

EAGLE-2 Decoding with torch.compile#

You can also enable torch.compile for further optimizations and optionally set --torch-compile-max-bs:

[5]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
            --enable-torch-compile --torch-compile-max-bs 2
"""
)

wait_for_server(f"http://localhost:{port}")
Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-05-29 23:46:29] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=32855, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=720784246, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=True, torch_compile_max_bs=2, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, 
moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-29 23:46:29] Infer the chat template name from the model path and obtain the result: llama-2.
[2025-05-29 23:46:40] Attention backend not set. Use flashinfer backend by default.
[2025-05-29 23:46:40] Init torch distributed begin.
[2025-05-29 23:46:40] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:46:40] init_expert_location from trivial
[2025-05-29 23:46:40] Load weight begin. avail mem=78.50 GB
[2025-05-29 23:46:40] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.84s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.30s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.38s/it]

[2025-05-29 23:46:43] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=65.93 GB, mem usage=12.57 GB.
[2025-05-29 23:46:43] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-05-29 23:46:43] Memory pool end. avail mem=55.73 GB
2025-05-29 23:46:44,043 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-29 23:46:44] Init torch distributed begin.
[2025-05-29 23:46:44] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:46:44] Load weight begin. avail mem=55.16 GB
[2025-05-29 23:46:44] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.28s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.28s/it]

[2025-05-29 23:46:45] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=54.23 GB, mem usage=0.93 GB.
[2025-05-29 23:46:45] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-05-29 23:46:45] Memory pool end. avail mem=53.92 GB
[2025-05-29 23:46:46] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-05-29 23:46:46] INFO:     Started server process [2800005]
[2025-05-29 23:46:46] INFO:     Waiting for application startup.
[2025-05-29 23:46:46] INFO:     Application startup complete.
[2025-05-29 23:46:46] INFO:     Uvicorn running on http://127.0.0.1:32855 (Press CTRL+C to quit)
[2025-05-29 23:46:46] INFO:     127.0.0.1:49628 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-29 23:46:47] INFO:     127.0.0.1:49642 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-29 23:46:47] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:46:47,752 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:46:47,766 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:46:47,772 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:46:47,782 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:46:48,196 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:46:48,212 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:46:50,152 - INFO - flashinfer.jit: Loading JIT ops: quantization
2025-05-29 23:46:50,164 - INFO - flashinfer.jit: Finished loading JIT ops: quantization
[2025-05-29 23:46:50] INFO:     127.0.0.1:49648 - "POST /generate HTTP/1.1" 200 OK
[2025-05-29 23:46:50] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[6]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-05-29 23:46:51] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:46:51,630 - INFO - flashinfer.jit: Loading JIT ops: cascade
2025-05-29 23:46:51,642 - INFO - flashinfer.jit: Finished loading JIT ops: cascade
[2025-05-29 23:46:51] INFO:     127.0.0.1:49654 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='a0a73f08005b462e9074f364577f99fb', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1748562411, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
[7]:
terminate_process(server_process)
[2025-05-29 23:46:51] Child process unexpectedly failed with an exit code 9. pid=2800285

EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling#

By restricting the draft model to a truncated vocabulary of high-frequency tokens, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.

In our implementation, set --speculative-token-map to enable this optimization. You can obtain the FR-Spec high-frequency token map from this model, or download the token files directly from this repo. A toy sketch of the lm_head truncation follows below.

Thanks to Weilin Zhao and Zhousx for the contribution.
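
To make the effect concrete, here is a toy sketch (not SGLang's implementation) of how a truncated high-frequency token map shrinks the draft model's lm_head matmul. The hidden size and the random token map are illustrative stand-ins for the real model dimensions and the freq_32768.pt file:

import torch

vocab_size, hidden_size, freq_size = 128256, 64, 32768   # Llama-3 vocabulary; tiny hidden dim for illustration
lm_head_weight = torch.randn(vocab_size, hidden_size)     # full lm_head of the draft model
freq_token_ids = torch.randperm(vocab_size)[:freq_size]   # stand-in for the token ids stored in freq_32768.pt
truncated_weight = lm_head_weight[freq_token_ids]         # 32768 rows instead of 128256

feature = torch.randn(hidden_size)                        # a draft-model hidden state
logits = truncated_weight @ feature                       # roughly 4x less lm_head compute while drafting
draft_token_id = freq_token_ids[logits.argmax()]          # map back to the id in the full vocabulary
print(draft_token_id.item())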

[8]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
    --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16
"""
)

wait_for_server(f"http://localhost:{port}")
Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-05-29 23:47:02] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3-8B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=31345, mem_fraction_static=0.7, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=187492853, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-LLaMA3-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map='thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', 
flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-29 23:47:02] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:47:12] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:47:13] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:47:13] Attention backend not set. Use flashinfer backend by default.
[2025-05-29 23:47:13] Init torch distributed begin.
[2025-05-29 23:47:13] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:47:13] init_expert_location from trivial
[2025-05-29 23:47:14] Load weight begin. avail mem=78.50 GB
[2025-05-29 23:47:15] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:13,  4.59s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:08,  4.18s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:12<00:03,  3.95s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00,  2.85s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00,  3.33s/it]

[2025-05-29 23:47:28] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=63.52 GB, mem usage=14.98 GB.
[2025-05-29 23:47:28] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-05-29 23:47:28] Memory pool end. avail mem=60.83 GB
2025-05-29 23:47:28,979 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-29 23:47:29] Warning: User-specified context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-05-29 23:47:29] Init torch distributed begin.
[2025-05-29 23:47:29] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:47:29] Load weight begin. avail mem=60.25 GB
[2025-05-29 23:47:29] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.11s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.11s/it]

[2025-05-29 23:47:30] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=58.55 GB, mem usage=1.70 GB.
[2025-05-29 23:47:30] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-05-29 23:47:30] Memory pool end. avail mem=58.47 GB
[2025-05-29 23:47:31] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=8192
[2025-05-29 23:47:32] INFO:     Started server process [2801027]
[2025-05-29 23:47:32] INFO:     Waiting for application startup.
[2025-05-29 23:47:32] INFO:     Application startup complete.
[2025-05-29 23:47:32] INFO:     Uvicorn running on http://127.0.0.1:31345 (Press CTRL+C to quit)
[2025-05-29 23:47:33] INFO:     127.0.0.1:35400 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-29 23:47:33] INFO:     127.0.0.1:35406 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-29 23:47:33] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:47:33,640 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:47:33,654 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:47:33,660 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:47:33,671 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:47:34,078 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:47:34,090 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:47:36,127 - INFO - flashinfer.jit: Loading JIT ops: quantization
2025-05-29 23:47:36,140 - INFO - flashinfer.jit: Finished loading JIT ops: quantization
[2025-05-29 23:47:36] INFO:     127.0.0.1:35410 - "POST /generate HTTP/1.1" 200 OK
[2025-05-29 23:47:36] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[9]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-05-29 23:47:38] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:47:38,099 - INFO - flashinfer.jit: Loading JIT ops: cascade
2025-05-29 23:47:38,111 - INFO - flashinfer.jit: Finished loading JIT ops: cascade
[2025-05-29 23:47:38] INFO:     127.0.0.1:35422 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='9ecc18c22cbf4b1dbf05f7fb0c6de784', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. **France** - **Paris**\n2. **Japan** - **Tokyo**\n3. **Australia** - **Canberra**', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1748562458, model='meta-llama/Meta-Llama-3-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=39, prompt_tokens=18, total_tokens=57, completion_tokens_details=None, prompt_tokens_details=None))
[10]:
terminate_process(server_process)
[2025-05-29 23:47:38] Child process unexpectedly failed with an exit code 9. pid=2801307
[2025-05-29 23:47:38] Child process unexpectedly failed with an exit code 9. pid=2801173

EAGLE-3 Decoding#

You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate draft model.

[11]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct  --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
        --cuda-graph-max-bs 2 --dtype float16
"""
)

wait_for_server(f"http://localhost:{port}")
Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-05-29 23:47:50] server_args=ServerArgs(model_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.1-8B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=39060, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=872261174, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE3', speculative_draft_model_path='jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=32, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, 
warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-29 23:47:50] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:48:01] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:48:01] Casting torch.bfloat16 to torch.float16.
[2025-05-29 23:48:01] Attention backend not set. Use flashinfer backend by default.
[2025-05-29 23:48:01] Init torch distributed begin.
[2025-05-29 23:48:01] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:48:01] init_expert_location from trivial
[2025-05-29 23:48:02] Load weight begin. avail mem=78.50 GB
[2025-05-29 23:48:02] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:13,  4.62s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:09<00:09,  4.54s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:13<00:04,  4.29s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.15s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.62s/it]

[2025-05-29 23:48:17] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=63.48 GB, mem usage=15.02 GB.
[2025-05-29 23:48:17] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-05-29 23:48:17] Memory pool end. avail mem=60.68 GB
2025-05-29 23:48:17,499 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-29 23:48:17] Warning: User-specified context_length (131072) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-05-29 23:48:18] Init torch distributed begin.
[2025-05-29 23:48:18] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:48:18] Load weight begin. avail mem=60.11 GB
[2025-05-29 23:48:18] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.76it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.76it/s]

[2025-05-29 23:48:19] Load weight end. type=LlamaForCausalLMEagle3, dtype=torch.float16, avail mem=58.34 GB, mem usage=1.77 GB.
[2025-05-29 23:48:19] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-05-29 23:48:19] Memory pool end. avail mem=58.26 GB
[2025-05-29 23:48:19] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-05-29 23:48:20] INFO:     Started server process [2802050]
[2025-05-29 23:48:20] INFO:     Waiting for application startup.
[2025-05-29 23:48:20] INFO:     Application startup complete.
[2025-05-29 23:48:20] INFO:     Uvicorn running on http://127.0.0.1:39060 (Press CTRL+C to quit)
[2025-05-29 23:48:20] INFO:     127.0.0.1:60612 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-29 23:48:21] INFO:     127.0.0.1:60614 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-29 23:48:21] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:48:21,709 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:48:21,724 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:48:21,731 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:48:21,741 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:48:22,134 - INFO - flashinfer.jit: Loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:48:22,146 - INFO - flashinfer.jit: Finished loading JIT ops: batch_decode_with_kv_cache_dtype_q_f16_dtype_kv_f16_dtype_o_f16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False
2025-05-29 23:48:24,093 - INFO - flashinfer.jit: Loading JIT ops: quantization
2025-05-29 23:48:24,106 - INFO - flashinfer.jit: Finished loading JIT ops: quantization
[2025-05-29 23:48:24] INFO:     127.0.0.1:60624 - "POST /generate HTTP/1.1" 200 OK
[2025-05-29 23:48:24] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[12]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-05-29 23:48:25] Prefill batch. #new-seq: 1, #new-token: 42, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:48:25,620 - INFO - flashinfer.jit: Loading JIT ops: cascade
2025-05-29 23:48:25,632 - INFO - flashinfer.jit: Finished loading JIT ops: cascade
[2025-05-29 23:48:25] INFO:     127.0.0.1:54062 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='07c23ea7001d4569aba74858dad66712', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n Capital: Tokyo\n\n2. Country: Australia\n Capital: Canberra\n\n3. Country: Brazil\n Capital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1748562505, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=43, total_tokens=86, completion_tokens_details=None, prompt_tokens_details=None))
[13]:
terminate_process(server_process)
[2025-05-29 23:48:25] Child process unexpectedly failed with an exit code 9. pid=2802330

Multi Token Prediction#

SGLang supports MTP (Multi-Token Prediction) via speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the DeepSeek doc).

[14]:
server_process, port = launch_server_cmd(
    """
    python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \
    --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
    --mem-fraction 0.5
"""
)

wait_for_server(f"http://localhost:{port}")
Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-05-29 23:48:36] server_args=ServerArgs(model_path='XiaomiMiMo/MiMo-7B-RL', tokenizer_path='XiaomiMiMo/MiMo-7B-RL', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=True, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='XiaomiMiMo/MiMo-7B-RL', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='0.0.0.0', port=36268, mem_fraction_static=0.5, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=990595073, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path=None, speculative_num_steps=1, speculative_eagle_topk=1, speculative_num_draft_tokens=2, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, 
disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-29 23:48:49] Attention backend not set. Use flashinfer backend by default.
[2025-05-29 23:48:49] Init torch distributed begin.
[2025-05-29 23:48:49] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:48:49] init_expert_location from trivial
[2025-05-29 23:48:49] Load weight begin. avail mem=78.50 GB
[2025-05-29 23:48:50] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.37it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.26it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.22it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.26it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.26it/s]

[2025-05-29 23:48:53] Load weight end. type=MiMoForCausalLM, dtype=torch.bfloat16, avail mem=64.28 GB, mem usage=14.21 GB.
[2025-05-29 23:48:53] KV Cache is allocated. #tokens: 20480, K size: 1.41 GB, V size: 1.41 GB
[2025-05-29 23:48:53] Memory pool end. avail mem=61.23 GB
2025-05-29 23:48:53,655 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-29 23:48:54] Init torch distributed begin.
[2025-05-29 23:48:54] Init torch distributed ends. mem usage=0.00 GB
[2025-05-29 23:48:54] Load weight begin. avail mem=60.54 GB
[2025-05-29 23:48:54] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  4.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  7.30it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  7.01it/s]

[2025-05-29 23:48:55] Load weight end. type=MiMoMTP, dtype=torch.bfloat16, avail mem=57.83 GB, mem usage=2.71 GB.
[2025-05-29 23:48:55] KV Cache is allocated. #tokens: 20480, K size: 1.41 GB, V size: 1.41 GB
[2025-05-29 23:48:55] Memory pool end. avail mem=54.88 GB
[2025-05-29 23:48:55] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=32768
[2025-05-29 23:48:56] INFO:     Started server process [2803071]
[2025-05-29 23:48:56] INFO:     Waiting for application startup.
[2025-05-29 23:48:56] INFO:     Application startup complete.
[2025-05-29 23:48:56] INFO:     Uvicorn running on http://0.0.0.0:36268 (Press CTRL+C to quit)
[2025-05-29 23:48:56] INFO:     127.0.0.1:37694 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-29 23:48:57] INFO:     127.0.0.1:37704 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-29 23:48:57] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-29 23:48:57,549 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:48:57,568 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False_sm90
2025-05-29 23:48:57,575 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:48:57,586 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-29 23:48:58,576 - INFO - flashinfer.jit: Loading JIT ops: quantization
2025-05-29 23:48:58,589 - INFO - flashinfer.jit: Finished loading JIT ops: quantization
[2025-05-29 23:48:58] INFO:     127.0.0.1:37714 - "POST /generate HTTP/1.1" 200 OK
[2025-05-29 23:48:58] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[15]:
import requests

url = f"http://localhost:{port}/v1/chat/completions"

data = {
    "model": "XiaomiMiMo/MiMo-7B-RL",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

response = requests.post(url, json=data)
print_highlight(response.json())
[2025-05-29 23:49:01] Prefill batch. #new-seq: 1, #new-token: 26, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-29 23:49:02] Decode batch. #running-req: 1, #token: 91, token usage: 0.00, accept len: 1.76, cuda graph: False, gen throughput (token/s): 10.13, #queue-req: 0
[2025-05-29 23:49:03] Decode batch. #running-req: 1, #token: 157, token usage: 0.01, accept len: 1.65, cuda graph: False, gen throughput (token/s): 100.05, #queue-req: 0
[2025-05-29 23:49:03] Decode batch. #running-req: 1, #token: 227, token usage: 0.01, accept len: 1.75, cuda graph: False, gen throughput (token/s): 108.34, #queue-req: 0
[2025-05-29 23:49:04] Decode batch. #running-req: 1, #token: 299, token usage: 0.01, accept len: 1.80, cuda graph: False, gen throughput (token/s): 109.81, #queue-req: 0
[2025-05-29 23:49:05] INFO:     127.0.0.1:37722 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': 'b9461f3c5ab94874baf44f1847e5ef68', 'object': 'chat.completion', 'created': 1748562541, 'model': 'XiaomiMiMo/MiMo-7B-RL', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': '\nOkay, so the user is asking, "What is the capital of France?" Let me start by recalling what I know about France. France is a country in Europe, right? I remember that Paris is a major city there. But wait, is Paris actually the capital? Let me make sure I\'m not confusing it with another city. Maybe some people think it\'s another place, but I\'m pretty sure Paris is the capital.\n\nWait, maybe there\'s a historical reason for that. Before Paris, was there another capital? No, I think Paris has been the capital for a long time. Let me think about regions in France. There\'s Provence, which is in the south, but that\'s not the capital. Lyon is a big city, but I don\'t think it\'s the capital. Then there\'s Bretagne in the west, but again, not the capital. So Paris must be the correct answer. \n\nAlso, cultural landmarks. The Eiffel Tower, the Louvre Museum, those are all in Paris. The French government seat is in Paris too, right? The parliament, the presidential palace, all those institutions. So yes, Paris is definitely the capital. I don\'t think there\'s any other city that serves as the capital of France. Maybe some people might confuse it with other cities, but no, Paris is it. So the answer should be straightforward.\n\n\nThe capital of France is **Paris**. This vibrant city is renowned for its iconic landmarks like the Eiffel Tower, Louvre Museum, and charming neighborhoods such as the Left Bank. It serves as the political, cultural, and economic heart of the country.', 'reasoning_content': None, 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 151645}], 'usage': {'prompt_tokens': 26, 'total_tokens': 367, 'completion_tokens': 341, 'prompt_tokens_details': None}}
[16]:
terminate_process(server_process)
[2025-05-29 23:49:05] Child process unexpectedly failed with an exit code 9. pid=2803352

References#

The EAGLE process is as follows:

  • Within EAGLE, the draft model predicts the next feature vector, i.e., the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).

  • The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style—branching out multiple potential continuations, with the branching factor per step controlled by the speculative_eagle_topk parameter—to ensure a more coherent connection of context, and are given as input again.

  • EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.

  • EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.

Operating on features rather than tokens gives the draft model more regular inputs and enhances drafting accuracy, and additionally passing the tokens from the next timestep reduces the randomness introduced by sampling. Furthermore, the dynamic adjustment of the draft tree and the reranked selection of final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers. A toy sketch of the drafting loop is given below.
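
As an illustration of the drafting loop described above, here is a self-contained toy sketch in which random linear layers stand in for the draft model and LM head. It is not SGLang's implementation, and for brevity only the most recent feature and token are fed to the draft model:

import torch

hidden, vocab, steps, topk = 16, 32, 3, 2    # steps ~ speculative_num_steps, topk ~ speculative_eagle_topk
draft = torch.nn.Linear(2 * hidden, hidden)   # predicts the next feature from (feature, token embedding)
lm_head = torch.nn.Linear(hidden, vocab)      # maps a feature to token logits
embed = torch.nn.Embedding(vocab, hidden)


def expand(feature, token, depth):
    """Expand the draft tree: predict the next feature, then branch into topk child tokens."""
    if depth == steps:
        return [[token]]
    next_feature = draft(torch.cat([feature, embed(token)]))   # f_{k+1} from (f_k, t_{k+1})
    probs = torch.softmax(lm_head(next_feature), dim=-1)       # p_{k+2} = LMHead(f_{k+1})
    paths = []
    for child in torch.topk(probs, topk).indices:              # keep the topk candidate tokens per node
        for tail in expand(next_feature, child, depth + 1):
            paths.append([token] + tail)
    return paths


root_feature = torch.randn(hidden)   # last hidden state of the target model
root_token = torch.tensor(0)         # last accepted token id
tree_paths = expand(root_feature, root_token, 0)
print(len(tree_paths), "candidate continuations of length", steps + 1)   # topk**steps paths

A real implementation batches all nodes of a given depth into a single draft-model forward pass and, as described above, reranks the expanded tree so that only the top speculative_num_draft_tokens nodes are sent to the target model for verification.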

For guidance on how to train your own EAGLE model, please see the EAGLE repo.