Speculative Decoding#
SGLang provides EAGLE-based speculative decoding (EAGLE-2 and EAGLE-3). Our implementation aims to maximize speed and efficiency and is considered among the fastest in open-source LLM engines.
Performance Highlights#
The table below shows the substantial throughput improvements that EAGLE-3 decoding achieves for Llama-3.1-Instruct 8B on MT-Bench. For further details, please see the EAGLE-3 paper.
Method | Throughput (tokens/s)
---|---
SGLang (w/o speculative, 1x H100) | 158.34
SGLang + EAGLE-2 (1x H100) | 244.10
SGLang + EAGLE-3 (1x H100) | 373.25
EAGLE Decoding#
To enable EAGLE speculative decoding, the following parameters are relevant:
- speculative_draft_model_path: Specifies the draft model. This parameter is required.
- speculative_num_steps: Depth of autoregressive drafting. Increasing it widens the speculation range but risks rejection cascades. Default is 5.
- speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and acceptance rate, but also increases memory and compute consumption. Default is 4.
- speculative_num_draft_tokens: Maximum parallel verification capacity. A larger value allows deeper tree evaluation but leads to higher GPU memory usage. Default is 8.
These parameters are the same for EAGLE-2 and EAGLE-3.
You can find the best combinations of these parameters with bench_speculative.py.
In the documentation below, we set --cuda-graph-max-bs to a small value for faster engine startup. For your own workloads, please tune the above parameters together with --cuda-graph-max-bs, --max-running-requests, and --mem-fraction-static for the best performance.
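If you want a quick, rough sense of how these knobs interact before running the full bench_speculative.py sweep, the minimal sketch below launches a server per configuration and times a single fixed request. It reuses the launch_server_cmd, wait_for_server, and terminate_process helpers used throughout this notebook; the grid values and the prompt are illustrative assumptions, not recommendations.
# Minimal sketch: time one fixed chat request under a few
# (num_steps, eagle_topk, num_draft_tokens) combinations.
# The grid values below are illustrative, not tuned recommendations.
import time
import openai
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, terminate_process

grid = [(3, 4, 16), (5, 8, 32)]  # (num_steps, eagle_topk, num_draft_tokens)

for steps, topk, draft_tokens in grid:
    server_process, port = launch_server_cmd(
        f"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \
    --speculative-num-steps {steps} --speculative-eagle-topk {topk} \
    --speculative-num-draft-tokens {draft_tokens} --cuda-graph-max-bs 8
"""
    )
    wait_for_server(f"http://localhost:{port}")
    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
    t0 = time.time()
    resp = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",
        messages=[{"role": "user", "content": "Write a short paragraph about the ocean."}],
        temperature=0,
        max_tokens=256,
    )
    elapsed = time.time() - t0
    print(f"steps={steps} topk={topk} draft={draft_tokens}: "
          f"{resp.usage.completion_tokens / elapsed:.1f} tokens/s")
    terminate_process(server_process)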
EAGLE-2 decoding#
You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate draft model.
[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
import openai
[2]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
--speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:30:20.176000 1234945 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:30:20.176000 1234945 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-08-14 06:30:21] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=38842, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.849, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=919763188, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, 
disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:30:22] Using default HuggingFace chat template with detected content format: string
W0814 06:30:28.999000 1235160 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:30:28.999000 1235160 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:30:29.000000 1235161 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:30:29.000000 1235161 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:30:30] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 06:30:30] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:30:30] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:30:31] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:30:31] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:30:32] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.70s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.22s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.29s/it]
[2025-08-14 06:30:34] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=65.93 GB, mem usage=12.66 GB.
[2025-08-14 06:30:34] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-08-14 06:30:34] Memory pool end. avail mem=55.73 GB
[2025-08-14 06:30:35] Capture cuda graph begin. This can take up to several minutes. avail mem=55.16 GB
[2025-08-14 06:30:35] Capture cuda graph bs [1, 2, 3, 4]
Capturing batches (bs=1 avail_mem=54.86 GB): 100%|██████████| 4/4 [00:00<00:00, 5.57it/s]
[2025-08-14 06:30:36] Capture cuda graph end. Time elapsed: 1.27 s. mem usage=0.37 GB. avail mem=54.79 GB.
[2025-08-14 06:30:36] Init torch distributed begin.
[2025-08-14 06:30:36] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:30:36] Load weight begin. avail mem=54.79 GB
[2025-08-14 06:30:36] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.24s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.24s/it]
[2025-08-14 06:30:38] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=53.86 GB, mem usage=0.93 GB.
[2025-08-14 06:30:38] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-08-14 06:30:38] Memory pool end. avail mem=53.55 GB
[2025-08-14 06:30:38] Capture draft cuda graph begin. This can take up to several minutes. avail mem=53.77 GB
Capturing batches (bs=1 avail_mem=53.65 GB): 100%|██████████| 4/4 [00:04<00:00, 1.17s/it]
[2025-08-14 06:30:44] Capture draft cuda graph end. Time elapsed: 5.50 s. mem usage=0.14 GB. avail mem=53.64 GB.
[2025-08-14 06:30:44] Capture draft extend cuda graph begin. This can take up to several minutes. avail mem=53.64 GB
Capturing batches (bs=1 avail_mem=53.56 GB): 100%|██████████| 4/4 [00:00<00:00, 118.61it/s]
[2025-08-14 06:30:44] Capture draft extend cuda graph end. Time elapsed: 0.58 s. mem usage=0.08 GB. avail mem=53.56 GB.
[2025-08-14 06:30:44] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=4096, available_gpu_mem=53.56 GB
[2025-08-14 06:30:44] INFO: Started server process [1234945]
[2025-08-14 06:30:44] INFO: Waiting for application startup.
[2025-08-14 06:30:44] INFO: Application startup complete.
[2025-08-14 06:30:44] INFO: Uvicorn running on http://127.0.0.1:38842 (Press CTRL+C to quit)
[2025-08-14 06:30:45] INFO: 127.0.0.1:52410 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:30:45] INFO: 127.0.0.1:52426 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:30:45] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:30:45] INFO: 127.0.0.1:52432 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:30:45] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
[3]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-08-14 06:30:50] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:30:50] INFO: 127.0.0.1:52436 - "POST /v1/chat/completions HTTP/1.1" 200 OK
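As a rough, illustrative check (CI numbers are not representative), you can estimate decode throughput against the running server from the completion token count reported by the OpenAI-compatible API together with wall-clock time; the prompt and max_tokens below are arbitrary.
# Rough throughput estimate against the running EAGLE-2 server (illustrative only).
import time

t0 = time.time()
bench_response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Explain speculative decoding in two sentences."}],
    temperature=0,
    max_tokens=128,
)
elapsed = time.time() - t0
print_highlight(
    f"Generated {bench_response.usage.completion_tokens} tokens in {elapsed:.2f}s "
    f"({bench_response.usage.completion_tokens / elapsed:.1f} tokens/s)"
)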
[4]:
terminate_process(server_process)
EAGLE-2 Decoding with torch.compile#
You can also enable torch.compile for further optimizations and optionally set --torch-compile-max-bs:
[5]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
--enable-torch-compile --torch-compile-max-bs 2
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:30:56.711000 1235966 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:30:56.711000 1235966 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-08-14 06:30:58] server_args=ServerArgs(model_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_path='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=39225, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.6, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=474346026, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='meta-llama/Llama-2-7b-chat-hf', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-llama2-chat-7B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, 
disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=True, torch_compile_max_bs=2, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:30:58] Using default HuggingFace chat template with detected content format: string
W0814 06:31:05.258000 1236181 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:31:05.258000 1236181 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:31:05.265000 1236182 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:31:05.265000 1236182 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:31:06] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 06:31:06] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:31:06] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:31:07] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:31:07] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:31:07] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.65s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.15s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.22s/it]
[2025-08-14 06:31:10] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=65.93 GB, mem usage=12.66 GB.
[2025-08-14 06:31:10] KV Cache is allocated. #tokens: 20480, K size: 5.00 GB, V size: 5.00 GB
[2025-08-14 06:31:10] Memory pool end. avail mem=55.73 GB
[2025-08-14 06:31:10] Capture cuda graph begin. This can take up to several minutes. avail mem=55.16 GB
[2025-08-14 06:31:10] Capture cuda graph bs [1, 2, 3, 4]
Capturing batches (bs=2 avail_mem=54.89 GB): 25%|██▌ | 1/4 [00:00<00:01, 2.11it/s]/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
torch._dynamo.utils.warn_once(msg)
AUTOTUNE mm(128x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
mm 0.0484 ms 100.0%
triton_mm_18 0.0495 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_12 0.0525 ms 92.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_8 0.0537 ms 90.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_7 0.0562 ms 86.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_11 0.0563 ms 86.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_17 0.0580 ms 83.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_10 0.0675 ms 71.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
triton_mm_14 0.0682 ms 71.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
triton_mm_4 0.0713 ms 68.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.3443 seconds and 0.6818 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
mm 0.0222 ms 100.0%
triton_mm_27 0.0232 ms 95.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_31 0.0270 ms 82.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_23 0.0304 ms 73.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_37 0.0321 ms 69.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_26 0.0403 ms 55.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_30 0.0409 ms 54.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_22 0.0420 ms 52.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_20 0.0424 ms 52.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
triton_mm_36 0.0431 ms 51.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.2822 seconds and 0.4825 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
mm 0.0759 ms 100.0%
triton_mm_49 0.0761 ms 99.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_55 0.0778 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_50 0.0804 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_56 0.0837 ms 90.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_46 0.0925 ms 82.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_45 0.0935 ms 81.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_47 0.0967 ms 78.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_48 0.1007 ms 75.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
triton_mm_54 0.1011 ms 75.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4237 seconds and 0.2079 seconds precompiling for 20 choices
AUTOTUNE mm(128x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
mm 0.0499 ms 100.0%
triton_mm_65 0.0517 ms 96.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_69 0.0577 ms 86.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_61 0.0650 ms 76.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_75 0.0719 ms 69.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_64 0.0929 ms 53.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_68 0.0939 ms 53.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_60 0.0970 ms 51.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_74 0.0997 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_58 0.1030 ms 48.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4279 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
triton_mm_93 0.1026 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_94 0.1045 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
mm 0.1054 ms 97.3%
triton_mm_88 0.1083 ms 94.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_83 0.1170 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_87 0.1171 ms 87.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_84 0.1259 ms 81.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_92 0.1278 ms 80.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_85 0.1365 ms 75.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_89 0.1458 ms 70.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5027 seconds and 0.3809 seconds precompiling for 20 choices
Capturing batches (bs=1 avail_mem=54.80 GB): 75%|███████▌ | 3/4 [00:20<00:07, 7.42s/it]AUTOTUNE mm(64x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
mm 0.0477 ms 100.0%
triton_mm_107 0.0480 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_111 0.0486 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_103 0.0493 ms 96.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_99 0.0511 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_102 0.0556 ms 85.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_106 0.0559 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_96 0.0570 ms 83.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
triton_mm_98 0.0612 ms 78.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_97 0.0639 ms 74.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2838 seconds and 0.3524 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
mm 0.0220 ms 100.0%
triton_mm_116 0.0226 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_120 0.0232 ms 95.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_124 0.0267 ms 82.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_128 0.0307 ms 71.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_113 0.0400 ms 55.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
triton_mm_119 0.0401 ms 54.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_123 0.0408 ms 53.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_115 0.0411 ms 53.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_114 0.0426 ms 51.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2538 seconds and 0.3580 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
triton_mm_136 0.0752 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_137 0.0752 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
mm 0.0755 ms 99.7%
triton_mm_140 0.0756 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_141 0.0767 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_145 0.0797 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_133 0.0834 ms 90.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_139 0.0918 ms 81.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
triton_mm_130 0.0958 ms 78.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
triton_mm_143 0.0959 ms 78.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.3439 seconds and 0.1900 seconds precompiling for 18 choices
AUTOTUNE mm(64x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
triton_mm_150 0.0472 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_154 0.0485 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
mm 0.0488 ms 96.7%
triton_mm_158 0.0554 ms 85.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_162 0.0646 ms 73.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_149 0.0880 ms 53.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_153 0.0897 ms 52.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_157 0.0915 ms 51.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_148 0.0923 ms 51.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_147 0.0940 ms 50.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.3809 seconds and 0.0002 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
triton_mm_170 0.0998 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
triton_mm_174 0.0998 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_179 0.1004 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
triton_mm_175 0.1007 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
triton_mm_171 0.1016 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
mm 0.1029 ms 97.0%
triton_mm_167 0.1115 ms 89.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
triton_mm_172 0.1188 ms 84.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_176 0.1236 ms 80.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
triton_mm_173 0.1308 ms 76.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4644 seconds and 0.3392 seconds precompiling for 18 choices
Capturing batches (bs=1 avail_mem=54.80 GB): 100%|██████████| 4/4 [00:36<00:00, 9.14s/it]
[2025-08-14 06:31:47] Capture cuda graph end. Time elapsed: 37.05 s. mem usage=0.43 GB. avail mem=54.73 GB.
[2025-08-14 06:31:48] Init torch distributed begin.
[2025-08-14 06:31:48] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:31:48] Load weight begin. avail mem=54.73 GB
[2025-08-14 06:31:48] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.17s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.17s/it]
[2025-08-14 06:31:49] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=53.80 GB, mem usage=0.93 GB.
[2025-08-14 06:31:49] KV Cache is allocated. #tokens: 20480, K size: 0.16 GB, V size: 0.16 GB
[2025-08-14 06:31:49] Memory pool end. avail mem=53.49 GB
[2025-08-14 06:31:49] Capture draft cuda graph begin. This can take up to several minutes. avail mem=53.68 GB
Capturing batches (bs=1 avail_mem=53.49 GB): 100%|██████████| 4/4 [00:14<00:00, 3.52s/it]
[2025-08-14 06:32:04] Capture draft cuda graph end. Time elapsed: 14.67 s. mem usage=0.23 GB. avail mem=53.45 GB.
[2025-08-14 06:32:04] Capture draft extend cuda graph begin. This can take up to several minutes. avail mem=53.45 GB
Capturing batches (bs=1 avail_mem=53.35 GB): 100%|██████████| 4/4 [00:00<00:00, 33.08it/s]
[2025-08-14 06:32:05] Capture draft extend cuda graph end. Time elapsed: 0.88 s. mem usage=0.10 GB. avail mem=53.35 GB.
[2025-08-14 06:32:05] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=4096, available_gpu_mem=53.35 GB
[2025-08-14 06:32:05] INFO: Started server process [1235966]
[2025-08-14 06:32:05] INFO: Waiting for application startup.
[2025-08-14 06:32:05] INFO: Application startup complete.
[2025-08-14 06:32:05] INFO: Uvicorn running on http://127.0.0.1:39225 (Press CTRL+C to quit)
[2025-08-14 06:32:05] INFO: 127.0.0.1:60042 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:32:06] INFO: 127.0.0.1:60044 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:32:06] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:32:08] INFO: 127.0.0.1:60052 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:32:08] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
[6]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Llama-2-7b-chat-hf",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-08-14 06:32:10] Prefill batch. #new-seq: 1, #new-token: 16, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:32:10] INFO: 127.0.0.1:60060 - "POST /v1/chat/completions HTTP/1.1" 200 OK
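Optionally, the same server can be queried in streaming mode. This is a small illustrative sketch using the standard OpenAI streaming API (the prompt is arbitrary); with speculative decoding, several accepted draft tokens may arrive within a single chunk, depending on the stream interval.
# Streaming sketch against the same torch.compile-enabled server (illustrative only).
stream = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Name three famous rivers."}],
    temperature=0,
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()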
[7]:
terminate_process(server_process)
EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling#
By employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.
In our implementation, set --speculative-token-map to enable this optimization. You can obtain the high-frequency token map used by FR-Spec from this model, or download the token files directly from this repo.
Thanks to Weilin Zhao and Zhousx for the contribution.
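To make the idea concrete, here is a minimal, self-contained PyTorch sketch of a truncated high-frequency lm_head. The shapes and the token map are toy values for illustration only and do not reflect SGLang's internals, which apply the mapping automatically when --speculative-token-map is set.
# Conceptual sketch (not SGLang internals): restrict the draft lm_head to a
# truncated vocabulary of high-frequency tokens. Toy sizes for illustration.
import torch

hidden_size, full_vocab, freq_vocab = 64, 1000, 256
hidden = torch.randn(1, hidden_size)
lm_head_weight = torch.randn(full_vocab, hidden_size)

# Hypothetical token map: ids of the most frequent tokens
# (FR-Spec ships such a file, e.g. freq_32768.pt).
token_map = torch.randperm(full_vocab)[:freq_vocab]

# The draft model scores only the truncated vocabulary,
# a smaller matmul than the full lm_head.
truncated_weight = lm_head_weight[token_map]   # [freq_vocab, hidden_size]
draft_logits = hidden @ truncated_weight.t()   # [1, freq_vocab]

# Map the winning index back to a token id in the full vocabulary.
draft_token_id = token_map[draft_logits.argmax(dim=-1)]
print(draft_token_id.item())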
[8]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
--speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
--mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:32:17.165000 1241394 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:32:17.165000 1241394 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-08-14 06:32:18] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=34408, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='float16', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.7, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=326952710, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='meta-llama/Meta-Llama-3-8B-Instruct', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm='EAGLE', speculative_draft_model_path='lmsys/sglang-EAGLE-LLaMA3-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=64, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map='thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt', ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, 
ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:32:18] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:32:19] Using default HuggingFace chat template with detected content format: string
W0814 06:32:25.476000 1241611 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:32:25.476000 1241611 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:32:25.476000 1241610 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:32:25.476000 1241610 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:32:28] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:32:28] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:32:28] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 06:32:28] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:32:28] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:32:29] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:32:29] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:32:30] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:04<00:12, 4.07s/it]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:08<00:08, 4.17s/it]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:12<00:04, 4.21s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00, 2.99s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00, 3.42s/it]
[2025-08-14 06:32:43] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=63.52 GB, mem usage=15.06 GB.
[2025-08-14 06:32:43] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-08-14 06:32:43] Memory pool end. avail mem=60.83 GB
[2025-08-14 06:32:44] Capture cuda graph begin. This can take up to several minutes. avail mem=60.25 GB
[2025-08-14 06:32:44] Capture cuda graph bs [1, 2, 3, 4]
Capturing batches (bs=1 avail_mem=59.74 GB): 100%|██████████| 4/4 [00:01<00:00, 3.58it/s]
[2025-08-14 06:32:45] Capture cuda graph end. Time elapsed: 1.67 s. mem usage=0.59 GB. avail mem=59.67 GB.
[2025-08-14 06:32:46] Warning: User-specified context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-08-14 06:32:46] Init torch distributed begin.
[2025-08-14 06:32:46] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:32:46] Load weight begin. avail mem=59.67 GB
[2025-08-14 06:32:46] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.04s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00, 1.04s/it]
[2025-08-14 06:32:47] Load weight end. type=LlamaForCausalLMEagle, dtype=torch.float16, avail mem=57.97 GB, mem usage=1.70 GB.
[2025-08-14 06:32:47] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-08-14 06:32:47] Memory pool end. avail mem=57.89 GB
[2025-08-14 06:32:48] Capture draft cuda graph begin. This can take up to several minutes. avail mem=58.56 GB
Capturing batches (bs=1 avail_mem=58.39 GB): 100%|██████████| 4/4 [00:06<00:00, 1.53s/it]
[2025-08-14 06:32:54] Capture draft cuda graph end. Time elapsed: 6.60 s. mem usage=0.22 GB. avail mem=58.35 GB.
[2025-08-14 06:32:54] Capture draft extend cuda graph begin. This can take up to several minutes. avail mem=58.35 GB
Capturing batches (bs=1 avail_mem=58.25 GB): 100%|██████████| 4/4 [00:00<00:00, 93.90it/s]
[2025-08-14 06:32:55] Capture draft extend cuda graph end. Time elapsed: 0.52 s. mem usage=0.10 GB. avail mem=58.25 GB.
[2025-08-14 06:32:55] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=8192, available_gpu_mem=58.25 GB
[2025-08-14 06:32:55] INFO: Started server process [1241394]
[2025-08-14 06:32:55] INFO: Waiting for application startup.
[2025-08-14 06:32:55] INFO: Application startup complete.
[2025-08-14 06:32:55] INFO: Uvicorn running on http://127.0.0.1:34408 (Press CTRL+C to quit)
[2025-08-14 06:32:56] INFO: 127.0.0.1:58092 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:32:56] INFO: 127.0.0.1:58098 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:32:56] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:32:56] INFO: 127.0.0.1:58104 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:32:56] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI environment, so the throughput is not representative of actual performance.
[9]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-08-14 06:33:01] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:33:01] INFO: 127.0.0.1:58116 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[10]:
terminate_process(server_process)
EAGLE-3 Decoding#
You can enable EAGLE-3 decoding by setting --speculative_algorithm EAGLE3
and choosing an appropriate model.
[11]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algorithm EAGLE3 \
--speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
--speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
--cuda-graph-max-bs 2 --dtype float16
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:33:07.618000 1242456 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:33:07.618000 1242456 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-08-14 06:33:08] server_args=ServerArgs(model_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=37782, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='float16', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.6, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=366771840, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='meta-llama/Llama-3.1-8B-Instruct', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm='EAGLE3', speculative_draft_model_path='jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B', speculative_num_steps=5, speculative_eagle_topk=8, speculative_num_draft_tokens=32, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', 
ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:33:09] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:33:09] Using default HuggingFace chat template with detected content format: string
W0814 06:33:15.782000 1242673 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:33:15.782000 1242673 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:33:15.873000 1242672 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:33:15.873000 1242672 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:33:16] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:33:17] Casting torch.bfloat16 to torch.float16.
[2025-08-14 06:33:17] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 06:33:17] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:33:18] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:33:18] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:33:19] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:33:19] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:04<00:12, 4.09s/it]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:08<00:08, 4.19s/it]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:12<00:04, 4.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00, 2.96s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00, 3.40s/it]
[2025-08-14 06:33:33] Load weight end. type=LlamaForCausalLM, dtype=torch.float16, avail mem=63.48 GB, mem usage=15.11 GB.
[2025-08-14 06:33:33] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-08-14 06:33:33] Memory pool end. avail mem=60.72 GB
[2025-08-14 06:33:33] Capture cuda graph begin. This can take up to several minutes. avail mem=60.15 GB
[2025-08-14 06:33:33] Capture cuda graph bs [1, 2, 3, 4]
Capturing batches (bs=1 avail_mem=59.67 GB): 100%|██████████| 4/4 [00:00<00:00, 6.14it/s]
[2025-08-14 06:33:34] Capture cuda graph end. Time elapsed: 1.23 s. mem usage=0.54 GB. avail mem=59.60 GB.
[2025-08-14 06:33:35] Warning: User-specified context_length (131072) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors.
[2025-08-14 06:33:35] Init torch distributed begin.
[2025-08-14 06:33:35] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:33:35] Load weight begin. avail mem=59.60 GB
[2025-08-14 06:33:35] Using model weights format ['*.bin']
Loading pt checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.02it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.02it/s]
[2025-08-14 06:33:35] Load weight end. type=LlamaForCausalLMEagle3, dtype=torch.float16, avail mem=57.46 GB, mem usage=2.15 GB.
[2025-08-14 06:33:35] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-08-14 06:33:35] Memory pool end. avail mem=57.38 GB
[2025-08-14 06:33:36] Capture draft cuda graph begin. This can take up to several minutes. avail mem=58.05 GB
Capturing batches (bs=1 avail_mem=58.24 GB): 100%|██████████| 4/4 [00:04<00:00, 1.17s/it]
[2025-08-14 06:33:41] Capture draft cuda graph end. Time elapsed: 5.18 s. mem usage=-0.14 GB. avail mem=58.20 GB.
[2025-08-14 06:33:41] Capture draft extend cuda graph begin. This can take up to several minutes. avail mem=58.20 GB
Capturing batches (bs=1 avail_mem=58.09 GB): 100%|██████████| 4/4 [00:00<00:00, 90.79it/s]
[2025-08-14 06:33:41] Capture draft extend cuda graph end. Time elapsed: 0.57 s. mem usage=0.11 GB. avail mem=58.09 GB.
[2025-08-14 06:33:41] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=131072, available_gpu_mem=58.09 GB
[2025-08-14 06:33:42] INFO: Started server process [1242456]
[2025-08-14 06:33:42] INFO: Waiting for application startup.
[2025-08-14 06:33:42] INFO: Application startup complete.
[2025-08-14 06:33:42] INFO: Uvicorn running on http://127.0.0.1:37782 (Press CTRL+C to quit)
[2025-08-14 06:33:43] INFO: 127.0.0.1:53794 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:33:43] INFO: 127.0.0.1:53806 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:33:43] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:33:43] INFO: 127.0.0.1:53814 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:33:43] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI environment, so the throughput is not representative of actual performance.
[12]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-08-14 06:33:48] Prefill batch. #new-seq: 1, #new-token: 42, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:33:48] INFO: 127.0.0.1:53818 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[13]:
terminate_process(server_process)
Multi-Token Prediction#
SGLang supports MTP (Multi-Token Prediction) via speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the DeepSeek documentation).
[14]:
server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \
--speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
--mem-fraction 0.5
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:33:55.259000 1243477 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:33:55.259000 1243477 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
[2025-08-14 06:33:56] server_args=ServerArgs(model_path='XiaomiMiMo/MiMo-7B-RL', tokenizer_path='XiaomiMiMo/MiMo-7B-RL', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=True, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=31404, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.5, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=793073465, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='XiaomiMiMo/MiMo-7B-RL', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm='EAGLE', speculative_draft_model_path=None, speculative_num_steps=1, speculative_eagle_topk=1, speculative_num_draft_tokens=2, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, 
disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=True, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:33:57] Using default HuggingFace chat template with detected content format: string
W0814 06:34:05.325000 1243693 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:34:05.325000 1243693 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:34:06.009000 1243694 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:34:06.009000 1243694 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:34:06] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 06:34:06] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:34:10] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:34:11] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:34:11] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:34:11] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.55it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.33it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.35it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.34it/s]
[2025-08-14 06:34:14] Load weight end. type=MiMoForCausalLM, dtype=torch.bfloat16, avail mem=64.28 GB, mem usage=14.30 GB.
[2025-08-14 06:34:15] KV Cache is allocated. #tokens: 20480, K size: 1.41 GB, V size: 1.41 GB
[2025-08-14 06:34:15] Memory pool end. avail mem=61.24 GB
[2025-08-14 06:34:15] Capture cuda graph begin. This can take up to several minutes. avail mem=60.55 GB
[2025-08-14 06:34:16] Capture cuda graph bs [1, 2, 3, 4]
Capturing batches (bs=1 avail_mem=60.22 GB): 100%|██████████| 4/4 [00:00<00:00, 5.27it/s]
[2025-08-14 06:34:17] Capture cuda graph end. Time elapsed: 1.96 s. mem usage=0.39 GB. avail mem=60.15 GB.
[2025-08-14 06:34:17] Init torch distributed begin.
[2025-08-14 06:34:17] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:34:17] Load weight begin. avail mem=60.15 GB
[2025-08-14 06:34:17] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:00, 5.58it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00, 9.48it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00, 9.00it/s]
[2025-08-14 06:34:18] Load weight end. type=MiMoMTP, dtype=torch.bfloat16, avail mem=57.44 GB, mem usage=2.71 GB.
[2025-08-14 06:34:18] KV Cache is allocated. #tokens: 20480, K size: 0.04 GB, V size: 0.04 GB
[2025-08-14 06:34:18] Memory pool end. avail mem=57.36 GB
[2025-08-14 06:34:18] Capture draft cuda graph begin. This can take up to several minutes. avail mem=59.45 GB
Capturing batches (bs=1 avail_mem=59.45 GB): 100%|██████████| 4/4 [00:01<00:00, 2.06it/s]
[2025-08-14 06:34:21] Capture draft cuda graph end. Time elapsed: 2.50 s. mem usage=0.01 GB. avail mem=59.45 GB.
[2025-08-14 06:34:21] Capture draft extend cuda graph begin. This can take up to several minutes. avail mem=59.45 GB
Capturing batches (bs=1 avail_mem=59.35 GB): 100%|██████████| 4/4 [00:00<00:00, 61.35it/s]
[2025-08-14 06:34:21] Capture draft extend cuda graph end. Time elapsed: 0.52 s. mem usage=0.09 GB. avail mem=59.35 GB.
[2025-08-14 06:34:21] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=32768, available_gpu_mem=59.35 GB
[2025-08-14 06:34:22] INFO: Started server process [1243477]
[2025-08-14 06:34:22] INFO: Waiting for application startup.
[2025-08-14 06:34:22] INFO: Application startup complete.
[2025-08-14 06:34:22] INFO: Uvicorn running on http://0.0.0.0:31404 (Press CTRL+C to quit)
[2025-08-14 06:34:22] INFO: 127.0.0.1:33788 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:34:23] INFO: 127.0.0.1:33794 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:34:23] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:34:23] INFO: 127.0.0.1:33802 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:34:23] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI environment, so the throughput is not representative of actual performance.
[15]:
import requests
url = f"http://localhost:{port}/v1/chat/completions"
data = {
"model": "XiaomiMiMo/MiMo-7B-RL",
"messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print_highlight(response.json())
[2025-08-14 06:34:27] Prefill batch. #new-seq: 1, #new-token: 26, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:34:28] Decode batch. #running-req: 1, #token: 91, token usage: 0.00, accept len: 1.76, cuda graph: True, gen throughput (token/s): 11.00, #queue-req: 0,
[2025-08-14 06:34:28] Decode batch. #running-req: 1, #token: 161, token usage: 0.01, accept len: 1.75, cuda graph: True, gen throughput (token/s): 155.79, #queue-req: 0,
[2025-08-14 06:34:29] Decode batch. #running-req: 1, #token: 225, token usage: 0.01, accept len: 1.60, cuda graph: True, gen throughput (token/s): 145.97, #queue-req: 0,
[2025-08-14 06:34:29] INFO: 127.0.0.1:41566 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[16]:
terminate_process(server_process)
References#
The EAGLE process is as follows:
Within EAGLE, the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).
The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style (branching out multiple potential continuations, with the branching factor per step controlled by the speculative_eagle_topk parameter) to ensure a more coherent connection of context, and are fed back as input.
EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.
EAGLE-3 removes the feature prediction objective, incorporates low- and mid-layer features, and is trained in an on-policy manner.
This approach improves drafting accuracy: features are more regular than tokens, so the draft model receives more predictable inputs, and additionally passing the sampled tokens of the next timestep reduces the randomness introduced by sampling. Furthermore, the dynamic adjustment of the draft tree and the reranking of final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers.
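To make the draft-and-verify loop concrete, below is a minimal, self-contained Python sketch of a single speculative decoding step with a chain-shaped draft (equivalent to speculative_eagle_topk = 1). The toy_target_model and toy_draft_model functions are hypothetical stand-ins introduced only for illustration; the real EAGLE draft head operates on hidden-state features and expands a tree of candidates that is verified in one batched forward pass, which this sketch deliberately omits.

import random

random.seed(0)
VOCAB = list(range(10))  # toy vocabulary of 10 token ids


def toy_target_model(prefix):
    """Stand-in for the large model: a deterministic next-token rule."""
    return (sum(prefix) + len(prefix)) % len(VOCAB)


def toy_draft_model(prefix):
    """Stand-in for the draft head: agrees with the target most of the time."""
    guess = toy_target_model(prefix)
    return guess if random.random() < 0.8 else random.choice(VOCAB)


def speculative_step(prefix, num_steps=3):
    """Draft `num_steps` tokens autoregressively, then verify them.

    Returns the tokens accepted in this step. The target model's own
    prediction at the first mismatch (or after all drafts are accepted)
    is always appended, so at least one token is produced per step.
    """
    # 1) Drafting: the small model proposes a chain of candidate tokens.
    draft, ctx = [], list(prefix)
    for _ in range(num_steps):
        t = toy_draft_model(ctx)
        draft.append(t)
        ctx.append(t)

    # 2) Verification: the target model checks each draft token in order
    #    (done in a single parallel forward pass in a real engine).
    accepted, ctx = [], list(prefix)
    for t in draft:
        target_t = toy_target_model(ctx)
        if t == target_t:
            accepted.append(t)
            ctx.append(t)
        else:
            accepted.append(target_t)  # bonus token from the target model
            return accepted
    accepted.append(toy_target_model(ctx))
    return accepted


prefix = [1, 2, 3]
out = speculative_step(prefix)
print(f"accepted {len(out)} tokens in one target-model step: {out}")

The longer the accepted prefix per step (the "accept len" reported in the server logs above), the fewer target-model forward passes are needed per generated token, which is where the speedup comes from.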
For guidance on how to train your own EAGLE model, please see the EAGLE repo.