SGLang Frontend Language#
The SGLang frontend language can be used to define prompts in a simple, convenient, and structured way.
Launch A Server#
Launch the server in your terminal and wait for it to initialize.
[1]:
import requests
import os
from sglang import assistant_begin, assistant_end
from sglang import assistant, function, gen, system, user
from sglang import image
from sglang import RuntimeEndpoint, set_default_backend
from sglang.srt.utils import load_image
from sglang.test.test_utils import is_in_ci
from sglang.utils import print_highlight, terminate_process, wait_for_server
if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd
server_process, port = launch_server_cmd(
    "python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
print(f"Server started on http://localhost:{port}")
[2025-04-13 03:11:52] server_args=ServerArgs(model_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=39093, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=296844672, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disable_fast_image_processor=False)
[2025-04-13 03:12:01 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 03:12:01 TP0] Init torch distributed begin.
[2025-04-13 03:12:02 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 03:12:02 TP0] Load weight begin. avail mem=41.41 GB
[2025-04-13 03:12:02 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 03:12:03 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.70it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.65it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:01<00:00, 1.59it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.65it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.64it/s]
[2025-04-13 03:12:06 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=46.91 GB, mem usage=-5.51 GB.
[2025-04-13 03:12:06 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-04-13 03:12:06 TP0] Memory pool end. avail mem=45.62 GB
[2025-04-13 03:12:06 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 03:12:06 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=32768
[2025-04-13 03:12:07] INFO: Started server process [2657496]
[2025-04-13 03:12:07] INFO: Waiting for application startup.
[2025-04-13 03:12:07] INFO: Application startup complete.
[2025-04-13 03:12:07] INFO: Uvicorn running on http://0.0.0.0:39093 (Press CTRL+C to quit)
[2025-04-13 03:12:07] INFO: 127.0.0.1:50274 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 03:12:08] INFO: 127.0.0.1:50286 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 03:12:08 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
Server started on http://localhost:39093
Set the default backend. Note: besides the local server, you may also use OpenAI or other API endpoints.
[2]:
set_default_backend(RuntimeEndpoint(f"http://localhost:{port}"))
[2025-04-13 03:12:12] INFO: 127.0.0.1:53726 - "GET /get_model_info HTTP/1.1" 200 OK
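Besides the local RuntimeEndpoint, the same frontend code can also target an OpenAI-style API backend. The lines below are a minimal sketch, assuming the openai client dependency is installed and the OPENAI_API_KEY environment variable is set; the model name is only an example.
from sglang import OpenAI, set_default_backend
# Hedged sketch: route all @function programs to an OpenAI-compatible backend
# instead of the local server started above.
set_default_backend(OpenAI("gpt-3.5-turbo"))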
Basic Usage#
The simplest way to use the SGLang frontend language is a single-turn question-and-answer dialog between a user and an assistant.
[3]:
@function
def basic_qa(s, question):
s += system(f"You are a helpful assistant than can answer questions.")
s += user(question)
s += assistant(gen("answer", max_tokens=512))
[4]:
state = basic_qa("List 3 countries and their capitals.")
print_highlight(state["answer"])
[2025-04-13 03:12:13 TP0] Prefill batch. #new-seq: 1, #new-token: 31, #cached-token: 0, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 03:12:14] INFO: 127.0.0.1:50298 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:14] The server is fired up and ready to roll!
[2025-04-13 03:12:14] INFO: 127.0.0.1:53742 - "POST /generate HTTP/1.1" 200 OK
1. France - Paris
2. Japan - Tokyo
3. Brazil - Brasília
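Besides reading named results such as state["answer"], you can inspect the whole interaction. A small sketch, assuming the returned state object exposes a text() method (its streaming counterpart, text_iter(), is used in the Streaming section below):
# Hedged sketch: print the full prompt and completion, including chat-template tokens.
print(state.text())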
Multi-turn Dialog#
SGLang frontend language can also be used to define multi-turn dialogs.
[5]:
@function
def multi_turn_qa(s):
s += system(f"You are a helpful assistant than can answer questions.")
s += user("Please give me a list of 3 countries and their capitals.")
s += assistant(gen("first_answer", max_tokens=512))
s += user("Please give me another list of 3 countries and their capitals.")
s += assistant(gen("second_answer", max_tokens=512))
return s
state = multi_turn_qa()
print_highlight(state["first_answer"])
print_highlight(state["second_answer"])
[2025-04-13 03:12:14 TP0] Prefill batch. #new-seq: 1, #new-token: 18, #cached-token: 18, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:14 TP0] Decode batch. #running-req: 1, #token: 46, token usage: 0.00, gen throughput (token/s): 5.98, #queue-req: 0,
[2025-04-13 03:12:15] INFO: 127.0.0.1:53750 - "POST /generate HTTP/1.1" 200 OK
1. **France** - Paris
2. **Canada** - Ottawa
3. **Japan** - Tokyo
[2025-04-13 03:12:15 TP0] Prefill batch. #new-seq: 1, #new-token: 23, #cached-token: 72, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:15 TP0] Decode batch. #running-req: 1, #token: 108, token usage: 0.01, gen throughput (token/s): 67.35, #queue-req: 0,
[2025-04-13 03:12:15] INFO: 127.0.0.1:53754 - "POST /generate HTTP/1.1" 200 OK
1. **Italy** - Rome
2. **Australia** - Canberra
3. **Brazil** - Brasília
Control flow#
You may use any Python code within the function to define more complex control flows.
[6]:
@function
def tool_use(s, question):
    s += assistant(
        "To answer this question: "
        + question
        + ". I need to use a "
        + gen("tool", choices=["calculator", "search engine"])
        + ". "
    )
    if s["tool"] == "calculator":
        s += assistant("The math expression is: " + gen("expression"))
    elif s["tool"] == "search engine":
        s += assistant("The key word to search is: " + gen("word"))
state = tool_use("What is 2 * 2?")
print_highlight(state["tool"])
print_highlight(state["expression"])
[2025-04-13 03:12:15 TP0] Prefill batch. #new-seq: 1, #new-token: 25, #cached-token: 8, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:15] INFO: 127.0.0.1:53764 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:15 TP0] Prefill batch. #new-seq: 1, #new-token: 2, #cached-token: 31, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:15 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 31, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 03:12:15] INFO: 127.0.0.1:53774 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:15 TP0] Prefill batch. #new-seq: 1, #new-token: 13, #cached-token: 33, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:15 TP0] Decode batch. #running-req: 1, #token: 58, token usage: 0.00, gen throughput (token/s): 79.89, #queue-req: 0,
[2025-04-13 03:12:16 TP0] Decode batch. #running-req: 1, #token: 98, token usage: 0.00, gen throughput (token/s): 64.11, #queue-req: 0,
[2025-04-13 03:12:17 TP0] Decode batch. #running-req: 1, #token: 138, token usage: 0.01, gen throughput (token/s): 63.50, #queue-req: 0,
[2025-04-13 03:12:17] INFO: 127.0.0.1:53788 - "POST /generate HTTP/1.1" 200 OK
You don't necessarily need a calculator for this, as you can easily compute this in your head or on paper. But if we were to use a calculator, you would do the following:
1. Enter the number 2.
2. Press the multiplication (*) button.
3. Enter the number 2.
4. Press the equals (=) button.
When you do this, the calculator will display the result: 4.
So, 2 * 2 = 4.
Parallelism#
Use fork to launch parallel prompts. Because sgl.gen is non-blocking, the for loop below issues two generation calls in parallel.
[7]:
@function
def tip_suggestion(s):
    s += assistant(
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )
    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += assistant(
            f"Now, expand tip {i+1} into a paragraph:\n"
            + gen("detailed_tip", max_tokens=256, stop="\n\n")
        )
    s += assistant("Tip 1:" + forks[0]["detailed_tip"] + "\n")
    s += assistant("Tip 2:" + forks[1]["detailed_tip"] + "\n")
    s += assistant(
        "To summarize the above two tips, I can say:\n" + gen("summary", max_tokens=512)
    )
state = tip_suggestion()
print_highlight(state["summary"])
[2025-04-13 03:12:17 TP0] Prefill batch. #new-seq: 1, #new-token: 35, #cached-token: 14, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:17 TP0] Prefill batch. #new-seq: 1, #new-token: 35, #cached-token: 14, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 03:12:17 TP0] Decode batch. #running-req: 2, #token: 110, token usage: 0.01, gen throughput (token/s): 97.52, #queue-req: 0,
[2025-04-13 03:12:17] INFO: 127.0.0.1:53800 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:18 TP0] Decode batch. #running-req: 1, #token: 117, token usage: 0.01, gen throughput (token/s): 80.53, #queue-req: 0,
[2025-04-13 03:12:19 TP0] Decode batch. #running-req: 1, #token: 157, token usage: 0.01, gen throughput (token/s): 62.42, #queue-req: 0,
[2025-04-13 03:12:19 TP0] Decode batch. #running-req: 1, #token: 197, token usage: 0.01, gen throughput (token/s): 62.49, #queue-req: 0,
[2025-04-13 03:12:20 TP0] Decode batch. #running-req: 1, #token: 237, token usage: 0.01, gen throughput (token/s): 62.40, #queue-req: 0,
[2025-04-13 03:12:20] INFO: 127.0.0.1:53796 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:20 TP0] Prefill batch. #new-seq: 1, #new-token: 264, #cached-token: 39, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:20 TP0] Decode batch. #running-req: 1, #token: 333, token usage: 0.02, gen throughput (token/s): 58.67, #queue-req: 0,
[2025-04-13 03:12:21 TP0] Decode batch. #running-req: 1, #token: 373, token usage: 0.02, gen throughput (token/s): 61.36, #queue-req: 0,
[2025-04-13 03:12:22 TP0] Decode batch. #running-req: 1, #token: 413, token usage: 0.02, gen throughput (token/s): 61.74, #queue-req: 0,
[2025-04-13 03:12:22 TP0] Decode batch. #running-req: 1, #token: 453, token usage: 0.02, gen throughput (token/s): 60.90, #queue-req: 0,
[2025-04-13 03:12:23 TP0] Decode batch. #running-req: 1, #token: 493, token usage: 0.02, gen throughput (token/s): 62.10, #queue-req: 0,
[2025-04-13 03:12:24] INFO: 127.0.0.1:33892 - "POST /generate HTTP/1.1" 200 OK
2. **Regular Exercise**: Engaging in regular exercise is crucial for maintaining good health and well-being. Exercise helps to strengthen your muscles and bones, improve cardiovascular health, enhance mood, and manage weight. Aim for at least 150 minutes of moderate-intensity aerobic activity or 75 minutes of vigorous-intensity aerobic activity each week, along with muscle-strengthening exercises on two or more days. Activities like walking, jogging, cycling, swimming, and weightlifting are excellent choices. Consistency is key, so find exercises you enjoy to stick with them long-term.
Together, these tips can significantly contribute to a healthier lifestyle.
Constrained Decoding#
Use regex to specify a regular expression as a decoding constraint. This is only supported for local models.
[8]:
@function
def regular_expression_gen(s):
s += user("What is the IP address of the Google DNS servers?")
s += assistant(
gen(
"answer",
temperature=0,
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
)
state = regular_expression_gen()
print_highlight(state["answer"])
[2025-04-13 03:12:24 TP0] Decode batch. #running-req: 0, #token: 0, token usage: 0.00, gen throughput (token/s): 62.36, #queue-req: 0,
[2025-04-13 03:12:24 TP0] Prefill batch. #new-seq: 1, #new-token: 18, #cached-token: 12, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:27] INFO: 127.0.0.1:33906 - "POST /generate HTTP/1.1" 200 OK
Use regex to define a JSON decoding schema.
[9]:
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)
@function
def character_gen(s, name):
    s += user(
        f"{name} is a character in Harry Potter. Please fill in the following information about this character."
    )
    s += assistant(gen("json_output", max_tokens=256, regex=character_regex))
state = character_gen("Harry Potter")
print_highlight(state["json_output"])
[2025-04-13 03:12:27 TP0] Prefill batch. #new-seq: 1, #new-token: 24, #cached-token: 14, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:28 TP0] Decode batch. #running-req: 1, #token: 64, token usage: 0.00, gen throughput (token/s): 10.54, #queue-req: 0,
[2025-04-13 03:12:28 TP0] Decode batch. #running-req: 1, #token: 104, token usage: 0.01, gen throughput (token/s): 61.90, #queue-req: 0,
[2025-04-13 03:12:29 TP0] Decode batch. #running-req: 1, #token: 144, token usage: 0.01, gen throughput (token/s): 61.27, #queue-req: 0,
[2025-04-13 03:12:29] INFO: 127.0.0.1:33914 - "POST /generate HTTP/1.1" 200 OK
"name": "Harry Potter",
"house": "Gryffindor",
"blood status": "Half-blood",
"occupation": "student",
"wand": {
"wood": "Willow",
"core": " Phoenix feather",
"length": 10.5
},
"alive": "Alive",
"patronus": "Stag",
"bogart": "Frank Bryce"
}
Batching#
Use run_batch to run a batch of prompts.
[10]:
@function
def text_qa(s, question):
    s += user(question)
    s += assistant(gen("answer", stop="\n"))
states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
    progress_bar=True,
)
for i, state in enumerate(states):
    print_highlight(f"Answer {i+1}: {state['answer']}")
[2025-04-13 03:12:29 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 13, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:29] INFO: 127.0.0.1:49942 - "POST /generate HTTP/1.1" 200 OK
0%| | 0/3 [00:00<?, ?it/s]
[2025-04-13 03:12:29 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 17, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:29 TP0] Prefill batch. #new-seq: 2, #new-token: 18, #cached-token: 34, token usage: 0.00, #running-req: 1, #queue-req: 0,
[2025-04-13 03:12:29] INFO: 127.0.0.1:49962 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:29] INFO: 127.0.0.1:49968 - "POST /generate HTTP/1.1" 200 OK
100%|██████████| 3/3 [00:00<00:00, 15.18it/s]
[2025-04-13 03:12:29] INFO: 127.0.0.1:49948 - "POST /generate HTTP/1.1" 200 OK
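run_batch issues the requests concurrently. The snippet below is a hedged sketch of capping client-side concurrency; it reuses the text_qa function defined above and assumes run_batch accepts an optional num_threads argument.
states = text_qa.run_batch(
    [
        {"question": "What is the capital of Germany?"},
        {"question": "What is the capital of Italy?"},
    ],
    num_threads=2,  # assumed optional argument; the default chooses a value automatically
    progress_bar=True,
)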
Streaming#
Use stream to stream the output to the user.
[11]:
@function
def text_qa(s, question):
    s += user(question)
    s += assistant(gen("answer", stop="\n"))
state = text_qa.run(
    question="What is the capital of France?", temperature=0.1, stream=True
)
for out in state.text_iter():
    print(out, end="", flush=True)
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant
[2025-04-13 03:12:29] INFO: 127.0.0.1:49982 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:12:29 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 25, token usage: 0.00, #running-req: 0, #queue-req: 0,
The capital of France is Paris.<|im_end|>
Complex Prompts#
You may use {system|user|assistant}_{begin|end} to define complex prompts.
[12]:
@function
def chat_example(s):
s += system("You are a helpful assistant.")
# Same as: s += s.system("You are a helpful assistant.")
with s.user():
s += "Question: What is the capital of France?"
s += assistant_begin()
s += "Answer: " + gen("answer", max_tokens=100, stop="\n")
s += assistant_end()
state = chat_example()
print_highlight(state["answer"])
[2025-04-13 03:12:29 TP0] Prefill batch. #new-seq: 1, #new-token: 17, #cached-token: 14, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:12:29] INFO: 127.0.0.1:49998 - "POST /generate HTTP/1.1" 200 OK
[13]:
terminate_process(server_process)
[2025-04-13 03:12:29] Child process unexpectedly failed with an exit code 9. pid=2657706
[2025-04-13 03:12:29] Child process unexpectedly failed with an exit code 9. pid=2657634
Multi-modal Generation#
You may use the SGLang frontend language to define multi-modal prompts. See here for supported models.
[14]:
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
print(f"Server started on http://localhost:{port}")
[2025-04-13 03:12:39] server_args=ServerArgs(model_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-VL-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=33739, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=303220435, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disable_fast_image_processor=False)
[2025-04-13 03:12:49 TP0] Overlap scheduler is disabled for multimodal models.
[2025-04-13 03:12:49 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 03:12:49 TP0] Automatically reduce --mem-fraction-static to 0.792 because this is a multimodal model.
[2025-04-13 03:12:49 TP0] Automatically turn off --chunked-prefill-size for multimodal model.
[2025-04-13 03:12:49 TP0] Automatically disable radix cache for qwen-vl series.
[2025-04-13 03:12:49 TP0] Init torch distributed begin.
[2025-04-13 03:12:49 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 03:12:49 TP0] Load weight begin. avail mem=61.24 GB
[2025-04-13 03:12:50 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 03:12:52 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 20% Completed | 1/5 [00:00<00:03, 1.26it/s]
Loading safetensors checkpoint shards: 40% Completed | 2/5 [00:01<00:02, 1.34it/s]
Loading safetensors checkpoint shards: 60% Completed | 3/5 [00:02<00:01, 1.36it/s]
Loading safetensors checkpoint shards: 80% Completed | 4/5 [00:02<00:00, 1.38it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00, 1.81it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:03<00:00, 1.57it/s]
[2025-04-13 03:12:55 TP0] Load weight end. type=Qwen2_5_VLForConditionalGeneration, dtype=torch.bfloat16, avail mem=45.45 GB, mem usage=15.79 GB.
[2025-04-13 03:12:55 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-04-13 03:12:55 TP0] Memory pool end. avail mem=44.08 GB
[2025-04-13 03:12:56 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 03:12:57 TP0] max_total_num_tokens=20480, chunked_prefill_size=-1, max_prefill_tokens=16384, max_running_requests=200, context_len=128000
[2025-04-13 03:12:58] INFO: Started server process [2659496]
[2025-04-13 03:12:58] INFO: Waiting for application startup.
[2025-04-13 03:12:58] INFO: Application startup complete.
[2025-04-13 03:12:58] INFO: Uvicorn running on http://0.0.0.0:33739 (Press CTRL+C to quit)
[2025-04-13 03:12:59] INFO: 127.0.0.1:58992 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 03:12:59] INFO: 127.0.0.1:59002 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 03:12:59 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
Server started on http://localhost:33739
[15]:
set_default_backend(RuntimeEndpoint(f"http://localhost:{port}"))
[2025-04-13 03:13:04] INFO: 127.0.0.1:59020 - "GET /get_model_info HTTP/1.1" 200 OK
Ask a question about an image.
[16]:
@function
def image_qa(s, image_file, question):
    s += user(image(image_file) + question)
    s += assistant(gen("answer", max_tokens=256))
image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
image_bytes, _ = load_image(image_url)
state = image_qa(image_bytes, "What is in the image?")
print_highlight(state["answer"])
[2025-04-13 03:13:04] INFO: 127.0.0.1:59010 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 03:13:04] The server is fired up and ready to roll!
[2025-04-13 03:13:05 TP0] Prefill batch. #new-seq: 1, #new-token: 307, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 03:13:06 TP0] Decode batch. #running-req: 1, #token: 340, token usage: 0.02, gen throughput (token/s): 4.60, #queue-req: 0,
[2025-04-13 03:13:07 TP0] Decode batch. #running-req: 1, #token: 380, token usage: 0.02, gen throughput (token/s): 44.17, #queue-req: 0,
[2025-04-13 03:13:07] INFO: 127.0.0.1:59022 - "POST /generate HTTP/1.1" 200 OK
In front when the crutches phenomena interact beats parked ornate substrata public transport ownership.
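image() also accepts a local file path, so load_image is only needed when you want to fetch or preprocess the image yourself. A sketch, assuming a file named example_image.png exists next to the notebook:
# Hedged sketch: pass a path on disk instead of raw image bytes (the file name is an example).
state = image_qa("./example_image.png", "Describe the image in one sentence.")
print_highlight(state["answer"])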
[17]:
terminate_process(server_process)
[2025-04-13 03:13:08] Child process unexpectedly failed with an exit code 9. pid=2659811