Speculative Decoding#

SGLang now provides an EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding option. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.

Performance Highlights#

The table below shows the large throughput improvements that EAGLE-3 decoding achieves for Llama 3.1 8B Instruct on MT-Bench. For further details, please see the EAGLE-3 paper.

Method                                 Throughput (tokens/s)
SGLang (w/o speculative, 1x H100)      158.34
SGLang + EAGLE-2 (1x H100)             244.10
SGLang + EAGLE-3 (1x H100)             373.25

EAGLE Decoding#

To enable EAGLE speculative decoding, the following parameters are relevant:

  • speculative_draft_model_path: Specifies the draft model. This parameter is required.

  • speculative_num_steps: Depth of autoregressive drafting. A larger value increases the speculation range but also the risk of rejection cascades. Default is 5.

  • speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and the acceptance rate, but also increases memory and compute consumption. Default is 4.

  • speculative_num_draft_tokens: Maximum parallel verification capacity. A larger value allows deeper tree evaluation but increases GPU memory usage. Default is 8.

These parameters are the same for EAGLE-2 and EAGLE-3.

You can find the best combinations of these parameters with bench_speculative.py.

In the documentation below, we set --cuda-graph-max-bs to a small value for faster engine startup. For your own workloads, please tune the above parameters together with --cuda-graph-max-bs, --max-running-requests, and --mem-fraction-static for the best performance.
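
If you prefer to configure these options programmatically rather than via the CLI, the offline Engine API accepts the same settings. The sketch below is a minimal example under the assumption that the keyword arguments mirror the server flags (dashes replaced by underscores); verify the exact names against your SGLang version.

import sglang as sgl

# Minimal sketch: keyword names are assumed to mirror the CLI flags shown on this page.
llm = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
    speculative_num_steps=3,
    speculative_eagle_topk=4,
    speculative_num_draft_tokens=16,
    cuda_graph_max_bs=8,
)
print(llm.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 16}))
llm.shutdown()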

EAGLE-2 Decoding#

You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate model.

[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

import openai
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
[2]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
    --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:transformers.configuration_utils:`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
WARNING:sglang.srt.server_args:
########################################################################
# For contributors and developers:                                    #
# Please move environment variable definitions to sglang.srt.environ  #
# using the following pattern:                                        #
#     SGLANG_XXX = EnvBool(False)                                     #
#                                                                     #
########################################################################

All deep_gemm operations loaded successfully!
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:23:05] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-28 18:23:08] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.74s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.21s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.29s/it]

Capturing batches (bs=1 avail_mem=38.17 GB): 100%|██████████| 4/4 [00:44<00:00, 11.06s/it]
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.29s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.29s/it]

Capturing batches (bs=1 avail_mem=34.47 GB): 100%|██████████| 4/4 [00:23<00:00,  5.99s/it]
Capturing batches (bs=1 avail_mem=34.37 GB): 100%|██████████| 4/4 [00:00<00:00, 19.50it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[3]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='d58f5cdabe574afea0fcbe177472d2d4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1759083933, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
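
Speculative decoding is transparent to the OpenAI-compatible API, so client-side features such as streaming work unchanged. Below is a minimal sketch that reuses the client and port defined above; the prompt is only an illustrative choice.

stream = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[{"role": "user", "content": "Name two oceans."}],
    temperature=0,
    max_tokens=32,
    stream=True,  # draft/verify happens server-side, so streamed chunks arrive as usual
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
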
[4]:
terminate_process(server_process)

EAGLE-2 Decoding with torch.compile#

You can also enable torch.compile for further optimizations and optionally set --torch-compile-max-bs:

[5]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
            --enable-torch-compile --torch-compile-max-bs 2 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
WARNING:sglang.srt.server_args:
########################################################################
# For contributors and developers:                                    #
# Please move environment variable definitions to sglang.srt.environ  #
# using the following pattern:                                        #
#     SGLANG_XXX = EnvBool(False)                                     #
#                                                                     #
########################################################################

All deep_gemm operations loaded successfully!
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:25:42] `torch_dtype` is deprecated! Use `dtype` instead!
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:25:50] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-28 18:25:51] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.79s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.24s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.32s/it]

Capturing batches (bs=2 avail_mem=38.46 GB):  25%|██▌       | 1/4 [00:44<02:12, 44.18s/it]/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
AUTOTUNE mm(128x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0484 ms 100.0%
  triton_mm_18 0.0498 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_12 0.0530 ms 91.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_8 0.0537 ms 90.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_7 0.0562 ms 86.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_11 0.0564 ms 85.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_17 0.0587 ms 82.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_10 0.0679 ms 71.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_14 0.0686 ms 70.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_4 0.0712 ms 68.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.3830 seconds and 0.7167 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0221 ms 100.0%
  triton_mm_27 0.0231 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_31 0.0268 ms 82.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_23 0.0304 ms 72.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_37 0.0317 ms 69.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_26 0.0399 ms 55.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_30 0.0407 ms 54.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_22 0.0422 ms 52.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_20 0.0425 ms 52.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_36 0.0434 ms 51.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.2954 seconds and 0.4864 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_49 0.0746 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.0761 ms 98.1%
  triton_mm_55 0.0767 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_50 0.0806 ms 92.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_56 0.0840 ms 88.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_46 0.0917 ms 81.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_45 0.0940 ms 79.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_47 0.0974 ms 76.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_48 0.1009 ms 73.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_54 0.1022 ms 73.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4815 seconds and 0.1723 seconds precompiling for 20 choices
AUTOTUNE mm(128x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
  mm 0.0504 ms 100.0%
  triton_mm_65 0.0512 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_69 0.0577 ms 87.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_61 0.0657 ms 76.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_75 0.0722 ms 69.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_64 0.0929 ms 54.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_68 0.0941 ms 53.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_60 0.0992 ms 50.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_74 0.1001 ms 50.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_58 0.1030 ms 49.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4304 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE mm(128x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_93 0.1027 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_94 0.1044 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  mm 0.1045 ms 98.2%
  triton_mm_88 0.1086 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_83 0.1166 ms 88.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_87 0.1177 ms 87.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_84 0.1260 ms 81.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_92 0.1268 ms 81.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_85 0.1349 ms 76.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_89 0.1412 ms 72.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5267 seconds and 0.4115 seconds precompiling for 20 choices
Capturing batches (bs=1 avail_mem=37.75 GB):  75%|███████▌  | 3/4 [01:04<00:19, 19.12s/it]AUTOTUNE mm(64x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_103 0.0429 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  mm 0.0477 ms 90.1%
  triton_mm_111 0.0481 ms 89.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_107 0.0488 ms 88.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_99 0.0509 ms 84.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_102 0.0556 ms 77.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_106 0.0559 ms 76.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_96 0.0569 ms 75.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_98 0.0618 ms 69.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_97 0.0631 ms 68.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.6818 seconds and 0.3625 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0211 ms 100.0%
  triton_mm_116 0.0220 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_120 0.0227 ms 92.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_124 0.0264 ms 79.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_128 0.0298 ms 70.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_119 0.0345 ms 61.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_113 0.0396 ms 53.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_123 0.0404 ms 52.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_115 0.0406 ms 51.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_114 0.0421 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.6025 seconds and 0.3533 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0725 ms 100.0%
  triton_mm_136 0.0748 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_137 0.0748 ms 96.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_140 0.0752 ms 96.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_141 0.0767 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_145 0.0796 ms 91.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_133 0.0836 ms 86.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_139 0.0890 ms 81.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_138 0.0899 ms 80.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_143 0.0915 ms 79.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.8403 seconds and 0.1925 seconds precompiling for 18 choices
AUTOTUNE mm(64x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
  triton_mm_150 0.0470 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_154 0.0480 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  mm 0.0483 ms 97.5%
  triton_mm_158 0.0557 ms 84.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_162 0.0645 ms 73.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_153 0.0817 ms 57.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_157 0.0824 ms 57.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_149 0.0885 ms 53.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_148 0.0907 ms 51.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_147 0.0927 ms 50.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8755 seconds and 0.0003 seconds precompiling for 18 choices
AUTOTUNE mm(64x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_170 0.0941 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_174 0.1002 ms 93.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_175 0.1004 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_179 0.1004 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_171 0.1014 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  mm 0.1034 ms 91.0%
  triton_mm_167 0.1113 ms 84.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_172 0.1193 ms 78.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_176 0.1226 ms 76.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_177 0.1316 ms 71.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9355 seconds and 0.3468 seconds precompiling for 18 choices
Capturing batches (bs=1 avail_mem=37.75 GB): 100%|██████████| 4/4 [01:23<00:00, 20.96s/it]
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.14s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.14s/it]

Capturing batches (bs=1 avail_mem=52.98 GB): 100%|██████████| 4/4 [00:29<00:00,  7.38s/it]
Capturing batches (bs=1 avail_mem=36.85 GB): 100%|██████████| 4/4 [00:00<00:00, 37.27it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[6]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='d53fccbe972c4cf7b8d364a9a2de3c8a', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1759084144, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[7]:
terminate_process(server_process)

EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling#

By employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.

In our implementation, set --speculative-token-map to enable the optimization. You can get the FR-Spec high-frequency token list from this model, or obtain it by downloading the token files directly from this repo.

Thanks to Weilin Zhao and Zhousx for the contribution.
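
If you prefer to keep the token map as a local file, one convenient workflow (our suggestion, not the only option) is to download it from the Hugging Face Hub and pass the resulting path to --speculative-token-map:

from huggingface_hub import hf_hub_download

# Download the FR-Spec high-frequency token map referenced in the command below.
token_map_path = hf_hub_download(
    repo_id="thunlp/LLaMA3-Instruct-8B-FR-Spec",
    filename="freq_32768.pt",
)
print(token_map_path)  # pass this local path via --speculative-token-map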

[8]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
    --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16  --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
WARNING:sglang.srt.server_args:
########################################################################
# For contributors and developers:                                    #
# Please move environment variable definitions to sglang.srt.environ  #
# using the following pattern:                                        #
#     SGLANG_XXX = EnvBool(False)                                     #
#                                                                     #
########################################################################

All deep_gemm operations loaded successfully!
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:29:12] `torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:29:12] Casting torch.bfloat16 to torch.float16.
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:29:20] `torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:29:20] Casting torch.bfloat16 to torch.float16.
[2025-09-28 18:29:20] Casting torch.bfloat16 to torch.float16.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-28 18:29:21] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:12,  4.21s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:08,  4.30s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:12<00:04,  4.26s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00,  3.01s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:13<00:00,  3.47s/it]

Capturing batches (bs=1 avail_mem=57.63 GB): 100%|██████████| 4/4 [00:45<00:00, 11.50s/it]
[2025-09-28 18:30:25] Warning: Target model's context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config.
[2025-09-28 18:30:25] Overriding the draft model's max_position_embeddings to 8192.
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.12s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.12s/it]

Capturing batches (bs=1 avail_mem=56.28 GB): 100%|██████████| 4/4 [00:08<00:00,  2.00s/it]
Capturing batches (bs=1 avail_mem=56.14 GB): 100%|██████████| 4/4 [00:00<00:00, 19.90it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[9]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='c45d29a0c5a445bd9862e52a6d52fae4', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. **France** - **Paris**\n2. **Japan** - **Tokyo**\n3. **Australia** - **Canberra**', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1759084305, model='meta-llama/Meta-Llama-3-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=39, prompt_tokens=18, total_tokens=57, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[10]:
terminate_process(server_process)

EAGLE-3 Decoding#

You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate model.

[11]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct  --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
        --cuda-graph-max-bs 2 --dtype float16 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
WARNING:sglang.srt.server_args:
########################################################################
# For contributors and developers:                                    #
# Please move environment variable definitions to sglang.srt.environ  #
# using the following pattern:                                        #
#     SGLANG_XXX = EnvBool(False)                                     #
#                                                                     #
########################################################################

All deep_gemm operations loaded successfully!
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:31:53] `torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:31:53] Casting torch.bfloat16 to torch.float16.
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:32:01] `torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:32:01] Casting torch.bfloat16 to torch.float16.
[2025-09-28 18:32:02] Casting torch.bfloat16 to torch.float16.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-28 18:32:02] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:13,  4.36s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:09,  4.52s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:13<00:04,  4.43s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.13s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.61s/it]

Capturing batches (bs=1 avail_mem=40.13 GB): 100%|██████████| 4/4 [00:45<00:00, 11.41s/it]
[2025-09-28 18:33:06] Warning: Target model's context_length (131072) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config.
[2025-09-28 18:33:06] Overriding the draft model's max_position_embeddings to 131072.
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.74it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.73it/s]

Capturing batches (bs=1 avail_mem=38.10 GB): 100%|██████████| 4/4 [00:05<00:00,  1.29s/it]
Capturing batches (bs=1 avail_mem=37.95 GB): 100%|██████████| 4/4 [00:00<00:00, 28.44it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[12]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='1db20a6881ea4a93943eda546cc88e75', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n Capital: Tokyo\n\n2. Country: Australia\n Capital: Canberra\n\n3. Country: Brazil\n Capital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1759084465, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=43, total_tokens=86, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
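
To check how well speculation is paying off, you can inspect the server's runtime information. The sketch below queries the /get_server_info endpoint and prints any fields that look related to speculative decoding; the exact field names (for example, an average accept length) vary between SGLang versions, so treat this as exploratory.

import requests

info = requests.get(f"http://localhost:{port}/get_server_info").json()
# Show fields that mention speculation or acceptance; otherwise list the available keys
# so you can inspect the payload yourself.
spec_fields = {k: v for k, v in info.items() if "spec" in k.lower() or "accept" in k.lower()}
print_highlight(spec_fields if spec_fields else sorted(info.keys()))
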
[13]:
terminate_process(server_process)

Multi-Token Prediction#

SGLang supports MTP (Multi-Token Prediction) via speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the DeepSeek doc).

[14]:
server_process, port = launch_server_cmd(
    """
    python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \
    --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
    --mem-fraction 0.5 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
WARNING:sglang.srt.server_args:Overlap scheduler is disabled because of using eagle speculative decoding.
WARNING:sglang.srt.server_args:
########################################################################
# For contributors and developers:                                    #
# Please move environment variable definitions to sglang.srt.environ  #
# using the following pattern:                                        #
#     SGLANG_XXX = EnvBool(False)                                     #
#                                                                     #
########################################################################

All deep_gemm operations loaded successfully!
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:34:32] `torch_dtype` is deprecated! Use `dtype` instead!
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-28 18:34:41] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-28 18:34:42] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.15it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.07it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.05it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.12it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.10it/s]

Capturing batches (bs=1 avail_mem=9.09 GB): 100%|██████████| 4/4 [01:02<00:00, 15.70s/it]
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  4.16it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  7.14it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  6.77it/s]

Capturing batches (bs=1 avail_mem=8.31 GB): 100%|██████████| 4/4 [00:02<00:00,  1.85it/s]
Capturing batches (bs=1 avail_mem=8.06 GB): 100%|██████████| 4/4 [00:00<00:00, 50.49it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[15]:
import requests

url = f"http://localhost:{port}/v1/chat/completions"

data = {
    "model": "XiaomiMiMo/MiMo-7B-RL",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

response = requests.post(url, json=data)
print_highlight(response.json())
{'id': '6c1d5b772d4e4ae787de2b29c4928f47', 'object': 'chat.completion', 'created': 1759084625, 'model': 'XiaomiMiMo/MiMo-7B-RL', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': "\nOkay, the user is asking for the capital of France. Let me start by recalling what I know about France's geography. France is a country in Western Europe. I remember that the capital is a major city, probably one of the well-known ones. Paris comes to mind immediately because I've heard of it being the capital. But wait, maybe there's another city? Let me think. Some countries have their capitals different from their largest cities, but in France, isn't Paris both the capital and the largest city? Yes, that's right. Paris is the capital. To be sure, I can cross-verify. The Eiffel Tower is in Paris, which is a major landmark. The French government is headquartered there as well. So, putting it all together, the answer should be Paris. No, there's no confusion here. Other major cities in France include Lyon, Marseille, and Toulouse, but none of those are the capital. So the correct answer is Paris.\n\nThe capital of France is **Paris**. This iconic city is home to major landmarks like the Eiffel Tower, the Louvre Museum, and the Palace of Versailles. As France's largest city and political center, Paris serves as the heart of the nation's culture, economy, and diplomacy. 🇫🇷", 'reasoning_content': None, 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 151645}], 'usage': {'prompt_tokens': 26, 'total_tokens': 298, 'completion_tokens': 272, 'prompt_tokens_details': None, 'reasoning_tokens': 0}, 'metadata': {'weight_version': 'default'}}
[16]:
terminate_process(server_process)

References#

The EAGLE process is as follows:

  • Within EAGLE the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).

  • The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style—branching out multiple potential continuations, with the branching factor per step controlled by the speculative_eagle_topk parameter—to ensure a more coherent connection of context, and are given as input again.

  • EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.

  • EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.

This enhances drafting accuracy by operating on features instead of tokens, which are more regular inputs, and by additionally passing the tokens from the next timestep to minimize randomness effects from sampling. Furthermore, the dynamic adjustment of the draft tree and the selection of reranked final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers.
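
The toy sketch below (model-free, and not SGLang's implementation) illustrates the draft-and-verify loop described above: a stand-in draft head proposes a chain of tokens, and a stand-in target model accepts drafted tokens until the first disagreement. For brevity, the draft tree is collapsed to a single greedy path; real EAGLE keeps up to speculative_eagle_topk branches per step and reranks them.

import random

random.seed(0)
VOCAB = list("abcde")

def draft_next(token, topk=2):
    # Stand-in for the EAGLE draft head: propose `topk` candidates (ignores context for simplicity).
    return random.sample(VOCAB, k=topk)

def target_next(token):
    # Stand-in for the target model's greedy choice at verification time.
    return VOCAB[(VOCAB.index(token) + 1) % len(VOCAB)]

def speculate(prefix_token, num_steps=3):
    # Draft a chain of `num_steps` tokens, keeping only the top-ranked branch each step.
    path, cur = [], prefix_token
    for _ in range(num_steps):
        cur = draft_next(cur)[0]
        path.append(cur)
    return path

def verify(prefix_token, path):
    # Accept drafted tokens one by one until the target model disagrees.
    accepted, cur = [], prefix_token
    for tok in path:
        if target_next(cur) != tok:
            break
        accepted.append(tok)
        cur = tok
    return accepted

drafted = speculate("a")
print("drafted:", drafted, "accepted:", verify("a", drafted))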

For guidance on how to train your own EAGLE model, please see the EAGLE repo.