Speculative Decoding
SGLang now provides EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding. Our implementation is designed for speed and efficiency and is considered to be among the fastest in open-source LLM engines. Note: speculative decoding in SGLang is currently compatible with radix cache and chunked prefill.
Performance Highlights
The table below shows the throughput improvements that EAGLE-2 and EAGLE-3 decoding achieve for Llama-3.1-8B-Instruct on MT-Bench. For further details, see the EAGLE-3 paper.
| Method | Throughput (tokens/s) |
|---|---|
| SGLang (w/o speculative, 1x H100) | 158.34 |
| SGLang + EAGLE-2 (1x H100) | 244.10 |
| SGLang + EAGLE-3 (1x H100) | 373.25 |
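The relative speedups follow directly from the table: divide each speculative throughput by the non-speculative baseline. The snippet below only re-derives ratios from the three numbers above. It measures nothing new.

```python
# Throughput (tokens/s) from the table above: Llama-3.1-8B-Instruct, MT-Bench, 1x H100.
baseline_tps = 158.34
speculative_tps = {"EAGLE-2": 244.10, "EAGLE-3": 373.25}

# Speedup is the ratio of speculative to non-speculative throughput.
speedups = {name: tps / baseline_tps for name, tps in speculative_tps.items()}

for name, speedup in speedups.items():
    print(f"{name}: {speedup:.2f}x")
# EAGLE-2: 1.54x
# EAGLE-3: 2.36x
```

In other words, EAGLE-3 roughly doubles the benefit of EAGLE-2 over plain decoding in this benchmark.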
EAGLE Decoding
To enable EAGLE speculative decoding, the following parameters are relevant. On the command line they are written with dashes, for example --speculative-draft-model-path:
- speculative_draft_model_path: Specifies the draft model. This parameter is required.
- speculative_num_steps: Depth of autoregressive drafting. A larger value increases the speculation range but risks rejection cascades. Default is 5.
- speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and leads to a higher acceptance rate, but also to higher memory/compute consumption. Default is 4.
- speculative_num_draft_tokens: Maximum number of draft tokens verified in parallel. A larger value allows deeper tree evaluation but leads to higher GPU memory usage. Default is 8.
These parameters are the same for EAGLE-2 and EAGLE-3.
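To make speculative_num_steps and the idea of a rejection cascade concrete, here is a deliberately simplified sketch of one greedy draft-and-verify round with a single draft chain, which corresponds to topk = 1. This is not SGLang's EAGLE implementation. EAGLE drafts a tree from the target model's hidden features and verifies it in one batched forward pass. The toy target and draft functions below are hypothetical stand-ins invented for illustration.

```python
from typing import Callable, List

# Hypothetical toy "model": maps a token context to its greedy next token.
Model = Callable[[List[int]], int]


def speculative_step(target: Model, draft: Model, context: List[int], num_steps: int) -> List[int]:
    """One greedy draft-then-verify round with a single draft chain (topk = 1).

    Returns the tokens emitted this round: the longest prefix of the draft
    that the target agrees with, plus one token chosen by the target itself.
    """
    # Draft phase: the cheap model proposes num_steps tokens autoregressively.
    proposal: List[int] = []
    for _ in range(num_steps):
        proposal.append(draft(context + proposal))

    # Verify phase: a real engine scores all positions in one forward pass;
    # this toy calls the target once per position for clarity.
    accepted: List[int] = []
    for tok in proposal:
        expected = target(context + accepted)
        if tok != expected:
            # First mismatch: every later draft token is discarded unchecked.
            return accepted + [expected]
        accepted.append(tok)

    # All draft tokens accepted: the target contributes one bonus token.
    return accepted + [target(context + accepted)]


def target(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10


def draft(ctx: List[int]) -> int:
    # Agrees with the target except right after token 3.
    return 0 if ctx[-1] == 3 else (ctx[-1] + 1) % 10


print(speculative_step(target, draft, [0], num_steps=5))  # [1, 2, 3, 4]
print(speculative_step(target, draft, [5], num_steps=3))  # [6, 7, 8, 9]
```

In the first call the draft diverges at its fourth proposal. The round still emits 4 tokens, but the fifth draft token is wasted work. Deeper drafting pays off only while the draft keeps agreeing with the target.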
EAGLE-2 Decoding
You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate draft model with --speculative-draft-model-path.
[1]:
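from sglang.test.test_utils import is_in_ci

# In CI the docs use a patched launcher; elsewhere, use the standard helper.
if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

# Helpers to wait for the server, print highlighted output, and shut it down.
from sglang.utils import wait_for_server, print_highlight, terminate_process

import openai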
[2]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
--speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:28:14] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=34089, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=38140663, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, 
disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=8, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:28:33 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:28:33 TP0] Init torch distributed begin.
[2025-04-13 23:28:33 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:28:33 TP0] Load weight begin. avail mem=61.47 GB
[2025-04-13 23:28:34 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:28:34 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.56s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.08s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.16s/it]
[2025-04-13 23:28:37 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=48.81 GB, mem usage=12.66 GB.
[2025-04-13 23:28:38 TP0] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-04-13 23:28:38 TP0] Memory pool end. avail mem=38.62 GB
[2025-04-13 23:28:38 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:28:39 TP0] Init torch distributed begin.
[2025-04-13 23:28:39 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:28:39 TP0] Load weight begin. avail mem=38.04 GB
[2025-04-13 23:28:39 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.15s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.15s/it]
[2025-04-13 23:28:40 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=37.11 GB, mem usage=0.93 GB.
[2025-04-13 23:28:40 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-04-13 23:28:40 TP0] Memory pool end. avail mem=36.80 GB
[2025-04-13 23:28:40 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:28:41 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-04-13 23:28:41] INFO: Started server process [2797081]
[2025-04-13 23:28:41] INFO: Waiting for application startup.
[2025-04-13 23:28:41] INFO: Application startup complete.
[2025-04-13 23:28:41] INFO: Uvicorn running on http://127.0.0.1:34089 (Press CTRL+C to quit)
[2025-04-13 23:28:42] INFO: 127.0.0.1:34586 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:28:42] INFO: 127.0.0.1:34592 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:28:42 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:28:46] INFO: 127.0.0.1:34604 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:28:46] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[3]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-04-13 23:28:47 TP0] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:28:47] INFO: 127.0.0.1:34620 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[4]:
terminate_process(server_process)
[2025-04-13 23:28:47] Child process unexpectedly failed with an exit code 9. pid=2797836
[2025-04-13 23:28:47] Child process unexpectedly failed with an exit code 9. pid=2797633
EAGLE-2 Decoding with torch.compile
You can also enable torch.compile with --enable-torch-compile for further optimization, and optionally set --torch-compile-max-bs to cap the maximum batch size used with torch.compile:
[5]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
--enable-torch-compile --torch-compile-max-bs 2
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:28:57] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=35740, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=968882636, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, 
disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=True, torch_compile_max_bs=2, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:29:07 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:29:07 TP0] Init torch distributed begin.
[2025-04-13 23:29:07 TP0] Init torch distributed ends. mem usage=0.02 GB
[2025-04-13 23:29:07 TP0] Load weight begin. avail mem=59.40 GB
[2025-04-13 23:29:08 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:29:08 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.65s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.12s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.20s/it]
[2025-04-13 23:29:11 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=46.63 GB, mem usage=12.77 GB.
[2025-04-13 23:29:11 TP0] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-04-13 23:29:11 TP0] Memory pool end. avail mem=36.43 GB
[2025-04-13 23:29:12 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:29:12 TP0] Init torch distributed begin.
[2025-04-13 23:29:12 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:29:12 TP0] Load weight begin. avail mem=55.16 GB
[2025-04-13 23:29:12 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.30s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.30s/it]
[2025-04-13 23:29:13 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=54.23 GB, mem usage=0.93 GB.
[2025-04-13 23:29:13 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-04-13 23:29:13 TP0] Memory pool end. avail mem=53.92 GB
[2025-04-13 23:29:13 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:29:14 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-04-13 23:29:14] INFO: Started server process [2799655]
[2025-04-13 23:29:14] INFO: Waiting for application startup.
[2025-04-13 23:29:14] INFO: Application startup complete.
[2025-04-13 23:29:14] INFO: Uvicorn running on http://127.0.0.1:35740 (Press CTRL+C to quit)
[2025-04-13 23:29:14] INFO: 127.0.0.1:53528 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:29:15] INFO: 127.0.0.1:53530 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:29:15 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[6]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-04-13 23:29:20 TP0] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 23:29:21] INFO: 127.0.0.1:53534 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:29:21] The server is fired up and ready to roll!
[2025-04-13 23:29:21] INFO: 127.0.0.1:47450 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[7]:
terminate_process(server_process)
[2025-04-13 23:29:21] Child process unexpectedly failed with an exit code 9. pid=2800510
[2025-04-13 23:29:21] Child process unexpectedly failed with an exit code 9. pid=2800286
EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling
By using a truncated vocabulary of high-frequency tokens in the draft model, EAGLE speculative decoding reduces the computational overhead of lm_head and speeds up the pipeline without degrading output quality. For more details, check out the paper.
In our implementation, set --speculative-token-map to enable the optimization. You can get the high-frequency tokens used in FR-Spec from this model, or you can download them directly from this repo.
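The idea behind the token map can be sketched in a few lines. This is a hypothetical toy, not SGLang's code. In it, lm_head is a V x d matrix of output embeddings, and token_map lists the ids of the K high-frequency tokens, which is the role played by the file passed to --speculative-token-map. The draft head scores only those K rows, so its cost drops from about V*d to K*d multiply-adds. The target model still verifies every draft token against the full vocabulary. A token outside the map can therefore never be drafted, but verification still decides the output, which is why quality is preserved.

```python
from typing import List, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def argmax(values: List[float]) -> int:
    return max(range(len(values)), key=values.__getitem__)


def draft_full_vocab(hidden: Vector, lm_head: Sequence[Vector]) -> int:
    """Baseline draft head: score every one of the V vocabulary rows."""
    return argmax([dot(row, hidden) for row in lm_head])


def draft_truncated_vocab(hidden: Vector, lm_head: Sequence[Vector], token_map: Sequence[int]) -> int:
    """FR-Spec-style draft head: score only the K rows listed in token_map,
    then map the winning local index back to a full-vocabulary token id."""
    logits = [dot(lm_head[t], hidden) for t in token_map]
    return token_map[argmax(logits)]


# Toy example with made-up numbers: V = 6 tokens, hidden size d = 2.
lm_head = [[0.1, 0.0], [0.9, 0.2], [0.0, 1.0], [0.3, 0.3], [1.0, 1.0], [0.2, 0.1]]
hidden = [1.0, 0.5]

print(draft_full_vocab(hidden, lm_head))                  # 4
print(draft_truncated_vocab(hidden, lm_head, [1, 4, 2]))  # 4: best token is in the map
print(draft_truncated_vocab(hidden, lm_head, [1, 2, 3]))  # 1: token 4 cannot be drafted
```

The last call shows the trade-off. When the best token is missing from the map, the draft proposes a worse candidate, which the target is then more likely to reject. Choosing the map from token frequency statistics keeps such misses rare.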
Thanks to Weilin Zhao and Zhousx for this contribution.
[8]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
--mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:29:30] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3-8B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=31635, mem_fraction_static=0.7, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=473383947, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-LLaMA3-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map='thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, 
disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:29:30] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:29:41 TP0] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:29:41 TP0] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:29:41 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:29:41 TP0] Init torch distributed begin.
[2025-04-13 23:29:42 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:29:42 TP0] Load weight begin. avail mem=63.64 GB
[2025-04-13 23:29:42 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:29:42 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:04<00:12, 4.22s/it]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:07<00:07, 3.81s/it]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:11<00:03, 3.69s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00, 2.74s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00, 3.14s/it]
[2025-04-13 23:29:55 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=46.24 GB, mem usage=17.40 GB.
[2025-04-13 23:29:55 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-04-13 23:29:55 TP0] Memory pool end. avail mem=43.54 GB
[2025-04-13 23:29:56 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:29:56 TP0] Warning: User-specified context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-04-13 23:29:56 TP0] Init torch distributed begin.
[2025-04-13 23:29:56 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:29:56 TP0] Load weight begin. avail mem=42.97 GB
[2025-04-13 23:29:57 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.11s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.11s/it]
[2025-04-13 23:29:58 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=41.27 GB, mem usage=1.70 GB.
[2025-04-13 23:29:58 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-04-13 23:29:58 TP0] Memory pool end. avail mem=41.19 GB
[2025-04-13 23:29:58 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:29:59 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=8192
[2025-04-13 23:29:59] INFO: Started server process [2802239]
[2025-04-13 23:29:59] INFO: Waiting for application startup.
[2025-04-13 23:29:59] INFO: Application startup complete.
[2025-04-13 23:29:59] INFO: Uvicorn running on http://127.0.0.1:31635 (Press CTRL+C to quit)
[2025-04-13 23:30:00] INFO: 127.0.0.1:60142 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:30:00] INFO: 127.0.0.1:60154 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:30:00 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[9]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-04-13 23:30:05 TP0] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 23:30:07] INFO: 127.0.0.1:60162 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:30:07] The server is fired up and ready to roll!
[2025-04-13 23:30:07] INFO: 127.0.0.1:60176 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[10]:
terminate_process(server_process)
[2025-04-13 23:30:07] Child process unexpectedly failed with an exit code 9. pid=2803207
[2025-04-13 23:30:07] Child process unexpectedly failed with an exit code 9. pid=2803013
EAGLE-3 Decoding
You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate EAGLE-3 draft model with --speculative-draft-model-path.
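Because EAGLE-2 and EAGLE-3 share the same speculative flags, switching between them only changes --speculative-algorithm and the matching draft model. The hypothetical helper below, build_eagle_launch_cmd, assembles a launch command from the flags used on this page. Its defaults mirror the EAGLE-3 example that follows, not SGLang's own defaults. It only builds a string and does not check whether a draft model fits the target model.

```python
import shlex
from typing import Sequence


def build_eagle_launch_cmd(
    model_path: str,
    draft_model_path: str,
    algorithm: str = "EAGLE3",  # "EAGLE" selects EAGLE-2, "EAGLE3" selects EAGLE-3
    num_steps: int = 5,
    eagle_topk: int = 8,
    num_draft_tokens: int = 32,
    extra_args: Sequence[str] = (),
) -> str:
    """Hypothetical helper: assemble an sglang.launch_server command line."""
    if algorithm not in ("EAGLE", "EAGLE3"):
        raise ValueError(f"unsupported speculative algorithm: {algorithm!r}")
    args = [
        "python3", "-m", "sglang.launch_server",
        "--model-path", model_path,
        "--speculative-algorithm", algorithm,
        "--speculative-draft-model-path", draft_model_path,
        "--speculative-num-steps", str(num_steps),
        "--speculative-eagle-topk", str(eagle_topk),
        "--speculative-num-draft-tokens", str(num_draft_tokens),
        *extra_args,
    ]
    # shlex.join quotes any argument that needs it, so the result is shell-safe.
    return shlex.join(args)


cmd = build_eagle_launch_cmd(
    "meta-llama/Llama-3.1-8B-Instruct",
    "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    extra_args=["--mem-fraction-static", "0.6", "--cuda-graph-max-bs", "2", "--dtype", "float16"],
)
print(cmd)
```

The resulting string should be usable with launch_server_cmd in the same way as the inline command strings on this page.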
[11]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algorithm EAGLE3 \
--speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
--cuda-graph-max-bs 2 --dtype float16
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:30:16] server_args=ServerArgs(model_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.1-8B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=38565, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=872681703, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm='EAGLE3', speculative_draft_model_path='jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=32, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, 
disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:30:16] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:30:26 TP0] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:30:27 TP0] Casting torch.bfloat16 to torch.float16.
[2025-04-13 23:30:27 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:30:27 TP0] Init torch distributed begin.
[2025-04-13 23:30:28 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:30:28 TP0] Load weight begin. avail mem=44.25 GB
[2025-04-13 23:30:28 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:30:30 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:04<00:12, 4.22s/it]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:07<00:07, 3.88s/it]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:11<00:03, 3.77s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00, 2.75s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:12<00:00, 3.17s/it]
[2025-04-13 23:30:43 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=46.14 GB, mem usage=-1.89 GB.
[2025-04-13 23:30:43 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-04-13 23:30:43 TP0] Memory pool end. avail mem=43.34 GB
[2025-04-13 23:30:43 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:30:44 TP0] Warning: User-specified context_length (131072) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-04-13 23:30:44 TP0] Init torch distributed begin.
[2025-04-13 23:30:44 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:30:44 TP0] Load weight begin. avail mem=42.77 GB
[2025-04-13 23:30:44 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.99it/s]
[2025-04-13 23:30:45 TP0] Load weight end. type=LlamaForCausalLMEagle3, dtype=torch.float16, avail mem=41.00 GB, mem usage=1.77 GB.
[2025-04-13 23:30:45 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-04-13 23:30:45 TP0] Memory pool end. avail mem=40.92 GB
[2025-04-13 23:30:45 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:30:45 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-04-13 23:30:46] INFO: Started server process [2805051]
[2025-04-13 23:30:46] INFO: Waiting for application startup.
[2025-04-13 23:30:46] INFO: Application startup complete.
[2025-04-13 23:30:46] INFO: Uvicorn running on http://127.0.0.1:38565 (Press CTRL+C to quit)
[2025-04-13 23:30:46] INFO: 127.0.0.1:54946 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:30:47] INFO: 127.0.0.1:54954 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:30:47 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[12]:
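# Connect to the local SGLang server through its OpenAI-compatible endpoint.
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

# Speculative decoding is configured on the server, so the request itself is a
# standard chat completion: greedy decoding (temperature=0), at most 64 tokens.
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print_highlight(f"Response: {response}")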
[2025-04-13 23:30:51 TP0] Prefill batch. #new-seq: 1, #new-token: 42, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 23:30:53] INFO: 127.0.0.1:54956 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:30:53] The server is fired up and ready to roll!
[2025-04-13 23:30:54] INFO: 127.0.0.1:36182 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[13]:
terminate_process(server_process)
[2025-04-13 23:30:54] Child process unexpectedly failed with an exit code 9. pid=2805671
[2025-04-13 23:30:54] Child process unexpectedly failed with an exit code 9. pid=2805602
References
The EAGLE process is as follows:
Within EAGLE, the draft model predicts the next feature vector, i.e., the last hidden state of the original LLM, using the feature sequence \((f_1, \ldots, f_k)\) and the token sequence \((t_2, \ldots, t_{k+1})\).
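To make this data flow concrete, here is a toy Python sketch of a single drafting step: each feature \(f_i\) is paired with the following token \(t_{i+1}\), a stand-in draft model predicts \(f_{k+1}\), and an LM head turns it into \(p_{k+2}\). The averaging "draft model", the tiny weight matrices, and all function names are hypothetical illustrations, not SGLang or EAGLE code; a real draft model is a small trained transformer, and the next token may be sampled rather than chosen greedily as it is here.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def softmax(logits: Sequence[float]) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_draft_model(features: List[Vector], tokens: List[int],
                    embed: Matrix, w_fuse: Matrix) -> Vector:
    """Hypothetical stand-in for the EAGLE draft model.

    Input i pairs feature f_i with the *next* token t_{i+1}. Each pair is
    concatenated and projected back to the feature size. A real draft model
    would run a trained transformer over these fused inputs; this toy version
    simply averages them to "predict" f_{k+1}.
    """
    assert len(features) == len(tokens)
    fused = [matvec(w_fuse, f + embed[t]) for f, t in zip(features, tokens)]
    dim = len(fused[0])
    return [sum(v[d] for v in fused) / len(fused) for d in range(dim)]


def draft_next_token(features, tokens, embed, w_fuse, lm_head) -> Tuple[Vector, Vector, int]:
    f_next = toy_draft_model(features, tokens, embed, w_fuse)  # f_{k+1}
    probs = softmax(matvec(lm_head, f_next))                   # p_{k+2} = LMHead(f_{k+1})
    token = max(range(len(probs)), key=probs.__getitem__)      # greedy choice of t_{k+2}
    return f_next, probs, token


# Toy setup (hypothetical values): hidden size 2, vocabulary of 3 tokens, k = 2.
embed = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # token embeddings
w_fuse = [[1, 0, 1, 0], [0, 1, 0, 1]]          # adds feature and embedding
lm_head = [[1, 0], [0, 1], [0.5, 0.5]]         # LM head (vocab x hidden)
features = [[0.5, 0.0], [0.0, 0.5]]            # (f_1, f_2)
tokens = [1, 2]                                # (t_2, t_3)

f_next, probs, token = draft_next_token(features, tokens, embed, w_fuse, lm_head)
print(f_next, token)  # [0.75, 1.25] 1
```

Appending \(f_{k+1}\) and \(t_{k+2}\) to the two input sequences and repeating the step yields the autoregressive drafting that the tree expansion builds on.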
The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style, branching out into multiple potential continuations, with the branching factor per step controlled by the speculative_eagle_topk parameter. This keeps the drafted continuations coherently connected to the context, and the extended sequences are given as input again.
EAGLE-2 additionally uses the draft model to estimate how probable the individual branches of the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking selects only the top speculative_num_draft_tokens final nodes as draft tokens.
EAGLE-3 removes the feature prediction objective, incorporates low- and mid-layer features, and is trained in an on-policy manner.
This enhances drafting accuracy in two ways: operating on features instead of tokens provides more regular inputs, and additionally passing in the tokens from the next timestep minimizes the randomness introduced by sampling. Furthermore, the dynamic adjustment of the draft tree and the selection of reranked final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers.
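The tree expansion and EAGLE-2 style reranking can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not SGLang's implementation: toy_draft is a hypothetical bigram table standing in for the draft model, and the parameters num_steps, topk, and num_draft_tokens only loosely mirror speculative_num_steps, speculative_eagle_topk, and speculative_num_draft_tokens. Each node is scored by the product of draft probabilities along its path (accumulated as log-probabilities); at every depth only the topk most confident new nodes are expanded further, and at the end all candidates are reranked to keep the best num_draft_tokens.

```python
import heapq
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class Node:
    tokens: Tuple[int, ...]  # drafted path below the root (the last accepted token)
    log_prob: float          # sum of draft log-probs along the path


def expand_and_rerank(draft_fn: Callable[[Tuple[int, ...]], Dict[int, float]],
                      num_steps: int, topk: int, num_draft_tokens: int) -> List[Node]:
    frontier = [Node((), 0.0)]
    candidates: List[Node] = []
    for _ in range(num_steps):  # tree depth
        children: List[Node] = []
        for node in frontier:
            probs = draft_fn(node.tokens)
            # Branch into the topk most likely next tokens of this node.
            for tok, p in heapq.nlargest(topk, probs.items(), key=lambda kv: kv[1]):
                children.append(Node(node.tokens + (tok,), node.log_prob + math.log(p)))
        candidates.extend(children)
        # Dynamic tree: only the topk most confident new nodes are expanded further.
        frontier = heapq.nlargest(topk, children, key=lambda n: n.log_prob)
    # Reranking: keep the num_draft_tokens best nodes across all depths.
    return heapq.nlargest(num_draft_tokens, candidates, key=lambda n: n.log_prob)


def toy_draft(path: Tuple[int, ...]) -> Dict[int, float]:
    """Hypothetical draft model: next-token probabilities depend only on the last token."""
    table = {
        0: {1: 0.6, 2: 0.3, 3: 0.1},
        1: {2: 0.7, 3: 0.2, 0: 0.1},
        2: {3: 0.5, 0: 0.35, 1: 0.15},
        3: {0: 0.8, 1: 0.15, 2: 0.05},
    }
    return table[path[-1] if path else 0]  # assume the root context ends in token 0


drafts = expand_and_rerank(toy_draft, num_steps=3, topk=2, num_draft_tokens=4)
for node in drafts:
    print(node.tokens, round(math.exp(node.log_prob), 4))
# (1,) 0.6
# (1, 2) 0.42
# (2,) 0.3
# (1, 2, 3) 0.21
```

Because every child scores strictly lower than its parent (assuming draft probabilities below 1), a node can only survive reranking if its parent does too, so the selected drafts always form a connected subtree, which is what tree-based verification needs.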
For guidance on how to train your own EAGLE model, please see the EAGLE repo.