Speculative Decoding#

SGLang provides an EAGLE-2-based speculative decoding option. The implementation aims to maximize speed and efficiency and is, to our knowledge, among the fastest in open-source LLM engines.

Note: Currently, Speculative Decoding in SGLang is compatible with radix cache and chunked prefill.

Performance Highlights#

  • Official EAGLE code (SafeAILab/EAGLE): ~200 tokens/s

  • Standard SGLang Decoding: ~156 tokens/s

  • EAGLE Decoding in SGLang: ~297 tokens/s

  • EAGLE Decoding in SGLang (w/ torch.compile): ~316 tokens/s

All benchmarks above were run on a single H100.
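
To put these figures in relative terms, the short calculation below divides each throughput by the standard SGLang decoding baseline. The numbers are copied from the list above; the dictionary keys are only labels chosen for this sketch.

```python
# Throughput figures (tokens/s) copied from the Performance Highlights list.
throughput = {
    "Official EAGLE code": 200,
    "Standard SGLang decoding": 156,
    "EAGLE in SGLang": 297,
    "EAGLE in SGLang (torch.compile)": 316,
}

# Speedup of each configuration relative to standard SGLang decoding.
baseline = throughput["Standard SGLang decoding"]
speedup = {name: tps / baseline for name, tps in throughput.items()}

for name, ratio in speedup.items():
    print(f"{name}: {ratio:.2f}x")
```

With these inputs, EAGLE decoding in SGLang comes out at roughly 1.9x the baseline, and roughly 2.0x with torch.compile.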

EAGLE Decoding#

To enable EAGLE speculative decoding, the following parameters are relevant:

  • speculative_draft_model_path: Specifies the draft model. This parameter is required.

  • speculative_num_steps: Depth of autoregressive drafting. Increases speculation range but risks rejection cascades. Default is 5.

  • speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and tends to raise the acceptance rate, but also increases memory and compute consumption. Default is 4.

  • speculative_num_draft_tokens: Maximum parallel verification capacity. Allows deeper tree evaluation but will lead to higher GPU memory usage. Default is 8.
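
As a rough illustration of how these parameters map onto launch flags, the sketch below collects them in a small dataclass and renders the corresponding command-line arguments. `EagleConfig` and `to_cli_args` are hypothetical helpers written for this example, not part of SGLang; the defaults are the ones stated in the list above, and the flag names are the ones used in the launch commands in this notebook.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EagleConfig:
    """Hypothetical container for the parameters described above."""

    speculative_draft_model_path: str  # required, no default
    speculative_num_steps: int = 5
    speculative_eagle_topk: int = 4
    speculative_num_draft_tokens: int = 8

    def to_cli_args(self) -> List[str]:
        """Render the settings as flags for sglang.launch_server."""
        return [
            "--speculative-algorithm", "EAGLE",
            "--speculative-draft-model-path", self.speculative_draft_model_path,
            "--speculative-num-steps", str(self.speculative_num_steps),
            "--speculative-eagle-topk", str(self.speculative_eagle_topk),
            "--speculative-num-draft-tokens", str(self.speculative_num_draft_tokens),
        ]


config = EagleConfig("lmsys/sglang-EAGLE-llama2-chat-7B", speculative_num_steps=3)
print(" ".join(config.to_cli_args()))
```

Keeping the settings in one object makes it easier to sweep over values such as the number of steps or the branching factor when tuning for a particular GPU.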

[1]:
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
    --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-03-16 09:53:19] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=33680, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=976982133, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, 
enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=8, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
[2025-03-16 09:53:41 TP0] Init torch distributed begin.
[2025-03-16 09:53:41 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:53:41 TP0] Load weight begin. avail mem=60.57 GB
[2025-03-16 09:53:41 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-16 09:53:42 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.22it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.50s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.40s/it]

[2025-03-16 09:53:45 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=32.29 GB, mem usage=28.27 GB.
[2025-03-16 09:53:45 TP0] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-03-16 09:53:45 TP0] Memory pool end. avail mem=22.10 GB
[2025-03-16 09:53:46 TP0] Init torch distributed begin.
[2025-03-16 09:53:46 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:53:46 TP0] Load weight begin. avail mem=21.53 GB
[2025-03-16 09:53:46 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.31s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.31s/it]

[2025-03-16 09:53:48 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=17.23 GB, mem usage=4.29 GB.
[2025-03-16 09:53:48 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-03-16 09:53:48 TP0] Memory pool end. avail mem=16.92 GB
[2025-03-16 09:53:48 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-03-16 09:53:48] INFO:     Started server process [194665]
[2025-03-16 09:53:48] INFO:     Waiting for application startup.
[2025-03-16 09:53:48] INFO:     Application startup complete.
[2025-03-16 09:53:48] INFO:     Uvicorn running on http://127.0.0.1:33680 (Press CTRL+C to quit)
[2025-03-16 09:53:48] INFO:     127.0.0.1:39888 - "GET /v1/models HTTP/1.1" 200 OK
[2025-03-16 09:53:49] INFO:     127.0.0.1:39892 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-16 09:53:49 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a parallel CI environment, so the throughput shown here is not representative of actual performance.
[2]:
import openai

client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
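    # This name is passed through as-is; the server launched above serves meta-llama/Llama-2-7b-chat-hf.
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",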
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-03-16 09:53:56 TP0] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-03-16 09:54:00] INFO:     127.0.0.1:39900 - "POST /generate HTTP/1.1" 200 OK
[2025-03-16 09:54:00] The server is fired up and ready to roll!
[2025-03-16 09:54:00] INFO:     127.0.0.1:51882 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='fce77eac125440b8be49c3476b66cd34', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1742118840, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
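
To get a rough tokens/s figure for a request like the one above, one option is to time the call and divide the reported completion tokens by the elapsed time. The sketch below is a minimal version of that idea: `measure_throughput` and `fake_request` are hypothetical names, and `fake_request` only stands in for a real request so the example runs without a server. The result is an end-to-end number that includes prefill and network latency, so it is only an approximation of decoding throughput.

```python
import time
from typing import Callable, Tuple


def measure_throughput(send_request: Callable[[], int]) -> Tuple[int, float]:
    """Time one request and return (completion_tokens, tokens_per_second).

    send_request is a hypothetical callable that performs the request and
    returns the number of generated tokens.
    """
    start = time.perf_counter()
    completion_tokens = send_request()
    elapsed = time.perf_counter() - start
    return completion_tokens, completion_tokens / elapsed


def fake_request() -> int:
    """Stand-in for a real request so the sketch runs without a server."""
    time.sleep(0.05)
    return 48  # completion_tokens reported in the response above


tokens, tps = measure_throughput(fake_request)
print(f"{tokens} tokens at {tps:.1f} tokens/s")
```

Against a running server, `fake_request` could be replaced by a function that calls `client.chat.completions.create(...)` and returns `response.usage.completion_tokens`.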
[3]:
terminate_process(server_process)

EAGLE Decoding with torch.compile#

You can also enable torch.compile for further optimizations and optionally set --cuda-graph-max-bs:

[4]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
    --enable-torch-compile --cuda-graph-max-bs 2
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-03-16 09:54:29] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=37153, mem_fraction_static=0.6, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=475128165, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, 
enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=True, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
[2025-03-16 09:54:49 TP0] Init torch distributed begin.
[2025-03-16 09:54:49 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:54:49 TP0] Load weight begin. avail mem=78.81 GB
[2025-03-16 09:54:49 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-16 09:54:50 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.74it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.33s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.22s/it]

[2025-03-16 09:54:53 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=66.15 GB, mem usage=12.66 GB.
[2025-03-16 09:54:53 TP0] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-03-16 09:54:53 TP0] Memory pool end. avail mem=55.96 GB
[2025-03-16 09:54:53 TP0] Init torch distributed begin.
[2025-03-16 09:54:53 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:54:53 TP0] Load weight begin. avail mem=55.39 GB
[2025-03-16 09:54:53 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.38s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.38s/it]

[2025-03-16 09:54:55 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=54.46 GB, mem usage=0.93 GB.
[2025-03-16 09:54:55 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-03-16 09:54:55 TP0] Memory pool end. avail mem=54.14 GB
[2025-03-16 09:54:55 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=4096
[2025-03-16 09:54:55] INFO:     Started server process [198053]
[2025-03-16 09:54:55] INFO:     Waiting for application startup.
[2025-03-16 09:54:55] INFO:     Application startup complete.
[2025-03-16 09:54:55] INFO:     Uvicorn running on http://127.0.0.1:37153 (Press CTRL+C to quit)
[2025-03-16 09:54:56] INFO:     127.0.0.1:49364 - "GET /v1/models HTTP/1.1" 200 OK
[2025-03-16 09:54:56] INFO:     127.0.0.1:49374 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-16 09:54:56 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,


[5]:
import openai

client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
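    # This name is passed through as-is; the server launched above serves meta-llama/Llama-2-7b-chat-hf.
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",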
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-03-16 09:55:02 TP0] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-03-16 09:55:04] INFO:     127.0.0.1:49376 - "POST /generate HTTP/1.1" 200 OK
[2025-03-16 09:55:04] The server is fired up and ready to roll!
[2025-03-16 09:55:04] INFO:     127.0.0.1:51350 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='ff9d28bef372415db580bf0767034499', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1742118904, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None))
[6]:
terminate_process(server_process)

EAGLE Decoding via Frequency-Ranked Speculative Sampling#

By employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.

In our implementation, set --speculative-token-map to enable the optimization. You can get the high-frequency tokens used by FR-Spec from this model, or obtain them by downloading them directly from this repo.

Thanks to Weilin Zhao and Zhousx for the contribution.
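
The idea of a frequency-ranked, truncated draft vocabulary can be sketched in a few lines of plain Python. This is a toy illustration under simplifying assumptions, not the SGLang or FR-Spec implementation: `build_hot_token_ids` and `draft_next_token` are hypothetical helpers, the corpus and weights are made up, and the draft step is greedy. The point is only that the draft head scores the rows of lm_head for the frequent tokens and skips the rest, then reports the result as an id in the full vocabulary.

```python
from collections import Counter
from typing import List, Sequence


def build_hot_token_ids(corpus_token_ids: Sequence[int], subset_size: int) -> List[int]:
    """Rank token ids by corpus frequency and keep the most frequent ones."""
    counts = Counter(corpus_token_ids)
    return [token_id for token_id, _ in counts.most_common(subset_size)]


def draft_next_token(
    hidden: Sequence[float],
    lm_head: Sequence[Sequence[float]],
    hot_token_ids: Sequence[int],
) -> int:
    """Greedy draft step that scores only the lm_head rows in hot_token_ids."""
    best_id, best_logit = hot_token_ids[0], float("-inf")
    for token_id in hot_token_ids:
        logit = sum(h * w for h, w in zip(hidden, lm_head[token_id]))
        if logit > best_logit:
            best_id, best_logit = token_id, logit
    return best_id  # an id in the full vocabulary


# Toy data: a 6-token vocabulary with 2-dimensional hidden states.
corpus = [3, 1, 3, 5, 3, 1, 0, 1, 3, 5]
lm_head = [[0.1, 0.0], [0.9, 0.2], [0.5, 0.5], [0.3, 0.9], [0.0, 0.1], [0.2, 0.2]]

hot_ids = build_hot_token_ids(corpus, subset_size=3)
token = draft_next_token([1.0, 1.0], lm_head, hot_ids)
print(hot_ids, token)
```

Here only three of the six lm_head rows are evaluated per draft step. In this sketch the restriction applies to drafting only; the drafted tokens would still be checked by the target model.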

[7]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
    --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-03-16 09:55:17] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='float16', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3-8B-Instruct', chat_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=36590, mem_fraction_static=0.7, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=720756809, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-LLaMA3-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map='thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, 
disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=2, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
[2025-03-16 09:55:23] Casting torch.bfloat16 to torch.float16.
[2025-03-16 09:55:39 TP0] Casting torch.bfloat16 to torch.float16.
[2025-03-16 09:55:39 TP0] Casting torch.bfloat16 to torch.float16.
[2025-03-16 09:55:39 TP0] Init torch distributed begin.
[2025-03-16 09:55:40 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:55:40 TP0] Load weight begin. avail mem=78.81 GB
[2025-03-16 09:55:40 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-16 09:55:40 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:05<00:15,  5.26s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:06<00:05,  2.91s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:11<00:03,  3.65s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:15<00:00,  4.08s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:15<00:00,  3.95s/it]

[2025-03-16 09:55:56 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=63.74 GB, mem usage=15.06 GB.
[2025-03-16 09:55:56 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-03-16 09:55:56 TP0] Memory pool end. avail mem=61.05 GB
/public_sglang_ci/runner-c-gpu-0/_work/sglang/sglang/python/sglang/srt/speculative/eagle_worker.py:568: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  hot_token_id = torch.load(token_map_path)
[2025-03-16 09:55:57 TP0] Warning: User-specified context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-03-16 09:55:57 TP0] Init torch distributed begin.
[2025-03-16 09:55:57 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:55:57 TP0] Load weight begin. avail mem=60.48 GB
[2025-03-16 09:55:57 TP0] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.20s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.21s/it]

[2025-03-16 09:55:59 TP0] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=58.78 GB, mem usage=1.70 GB.
[2025-03-16 09:55:59 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-03-16 09:55:59 TP0] Memory pool end. avail mem=58.70 GB
[2025-03-16 09:55:59 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=8192
[2025-03-16 09:55:59] INFO:     Started server process [201518]
[2025-03-16 09:55:59] INFO:     Waiting for application startup.
[2025-03-16 09:55:59] INFO:     Application startup complete.
[2025-03-16 09:55:59] INFO:     Uvicorn running on http://127.0.0.1:36590 (Press CTRL+C to quit)
[2025-03-16 09:56:00] INFO:     127.0.0.1:38362 - "GET /v1/models HTTP/1.1" 200 OK
[2025-03-16 09:56:00] INFO:     127.0.0.1:38366 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-16 09:56:00 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,


[8]:
import openai

client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
[2025-03-16 09:56:06 TP0] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 1, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-03-16 09:56:08] INFO:     127.0.0.1:38370 - "POST /generate HTTP/1.1" 200 OK
[2025-03-16 09:56:08] The server is fired up and ready to roll!
[2025-03-16 09:56:09] INFO:     127.0.0.1:55178 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response: ChatCompletion(id='4af7085e38cd49bfa9cd2db19b911d71', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. **France** - **Paris**\n2. **Japan** - **Tokyo**\n3. **Australia** - **Canberra**', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1742118969, model='meta-llama/Meta-Llama-3-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=39, prompt_tokens=18, total_tokens=57, completion_tokens_details=None, prompt_tokens_details=None))
[9]:
terminate_process(server_process)

References#

The EAGLE process is as follows:

  • Within EAGLE, the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).

  • The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style, branching out into multiple potential continuations (the branching factor per step is controlled by the speculative_eagle_topk parameter) to ensure a more coherent connection of context, and are given as input again.

  • EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.

This enhances drafting accuracy by operating on features instead of tokens, which gives more regular inputs, and by additionally passing the tokens from the next timestep to minimize the randomness introduced by sampling. Furthermore, the dynamic adjustment of the draft tree and the selection of reranked final nodes further increase the acceptance rate of draft tokens. For more details, see the paper.
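
The expansion and reranking steps described above can be sketched with a toy draft model. This is an illustrative simplification, not SGLang's implementation: `toy_draft` is a hypothetical stand-in that conditions on the last token only (a real EAGLE draft model works on features), and the score of a node is assumed to be the product of the draft probabilities along its path. The arguments `num_steps`, `topk`, and `num_draft_tokens` mirror the speculative_* parameters.

```python
import heapq
from typing import Callable, Dict, List, Tuple

Path = Tuple[int, ...]
DraftFn = Callable[[Path], Dict[int, float]]


def build_draft_tree(
    draft_fn: DraftFn, num_steps: int, topk: int, num_draft_tokens: int
) -> List[Tuple[float, Path]]:
    """Expand a draft tree, then rerank all nodes by cumulative probability."""
    frontier: List[Tuple[float, Path]] = [(1.0, ())]
    all_nodes: List[Tuple[float, Path]] = []
    for _ in range(num_steps):
        candidates = []
        for score, path in frontier:
            probs = draft_fn(path)
            # Branch into the topk most likely next tokens of this node.
            for token, p in heapq.nlargest(topk, probs.items(), key=lambda kv: kv[1]):
                candidates.append((score * p, path + (token,)))
        all_nodes.extend(candidates)
        # Only the topk most probable branches are expanded further.
        frontier = heapq.nlargest(topk, candidates)
    # Rerank every drafted node and keep the best num_draft_tokens of them.
    return heapq.nlargest(num_draft_tokens, all_nodes)


def toy_draft(path: Path) -> Dict[int, float]:
    """Hypothetical draft distribution over a 3-token vocabulary."""
    last = path[-1] if path else 0
    base = [0.6, 0.3, 0.1]
    # Probabilities rotate with the last token so that branches differ.
    return {(last + i) % 3: base[i] for i in range(3)}


tree = build_draft_tree(toy_draft, num_steps=2, topk=2, num_draft_tokens=3)
for score, path in tree:
    print(f"{score:.2f} {path}")
```

Because a child's cumulative probability never exceeds its parent's, selecting the highest-scoring nodes in this sketch keeps each selected node's ancestors as well (barring ties), so the result is still a tree that can be verified in one pass.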