Structured Outputs#
You can specify a JSON schema, regular expression, or EBNF to constrain the model output. The model output is guaranteed to follow the given constraints. Only one constraint parameter (json_schema, regex, or ebnf) can be specified per request.
SGLang supports two grammar backends:
Outlines (default): Supports JSON schema and regular expression constraints.
XGrammar: Supports JSON schema and EBNF constraints; it currently uses the GGML BNF format.
We suggest using XGrammar whenever possible for its better performance. For more details, see XGrammar technical overview.
To use XGrammar, simply add --grammar-backend xgrammar when launching the server. If no backend is specified, Outlines will be used as the default.
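As a quick reference, here is a minimal sketch of how each constraint type is attached to an OpenAI-compatible chat completion request in the sections below (JSON schema via response_format; regex and EBNF via extra_body). It assumes a server is already running on port 30000, as launched in the first cell below, with a backend that supports the chosen constraint.

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
messages = [{"role": "user", "content": "What is the capital of France?"}]

# JSON schema constraint (Outlines or XGrammar), passed via response_format
client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=messages,
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "capital",
            "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
        },
    },
)

# Regular expression constraint (Outlines backend), passed via extra_body
client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=messages,
    extra_body={"regex": "(Paris|London)"},
)

# EBNF constraint (XGrammar backend), passed via extra_body
client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=messages,
    extra_body={"ebnf": 'root ::= "Paris" | "London"'},
)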
OpenAI Compatible API#
[1]:
from sglang.utils import (
    execute_shell_command,
    wait_for_server,
    terminate_process,
    print_highlight,
)

import openai
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

server_process = execute_shell_command(
    "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar"
)

wait_for_server("http://localhost:30000")
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
[2025-01-13 13:08:20] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, random_seed=36724584, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, dump_requests_folder=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2025-01-13 13:08:38 TP0] Init torch distributed begin.
[2025-01-13 13:08:38 TP0] Load weight begin. avail mem=78.81 GB
[2025-01-13 13:08:40 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.29it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.20it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.15it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.57it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.41it/s]
[2025-01-13 13:08:43 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.72 GB
[2025-01-13 13:08:43 TP0] KV Cache is allocated. K size: 27.13 GB, V size: 27.13 GB.
[2025-01-13 13:08:43 TP0] Memory pool end. avail mem=8.34 GB
[2025-01-13 13:08:43 TP0] Capture cuda graph begin. This can take up to several minutes.
100%|██████████| 23/23 [00:05<00:00, 3.89it/s]
[2025-01-13 13:08:49 TP0] Capture cuda graph end. Time elapsed: 5.93 s
[2025-01-13 13:08:49 TP0] max_total_num_tokens=444500, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-01-13 13:08:50] INFO: Started server process [210389]
[2025-01-13 13:08:50] INFO: Waiting for application startup.
[2025-01-13 13:08:50] INFO: Application startup complete.
[2025-01-13 13:08:50] INFO: Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2025-01-13 13:08:51] INFO: 127.0.0.1:46660 - "GET /v1/models HTTP/1.1" 200 OK
[2025-01-13 13:08:51] INFO: 127.0.0.1:46674 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-01-13 13:08:51 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:08:53] INFO: 127.0.0.1:46682 - "POST /generate HTTP/1.1" 200 OK
[2025-01-13 13:08:53] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
JSON#
You can directly define a JSON schema or use Pydantic to define and validate the response.
Using Pydantic
[2]:
from pydantic import BaseModel, Field


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Give me the information of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=128,
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "foo",
            # convert the pydantic model to json schema
            "schema": CapitalInfo.model_json_schema(),
        },
    },
)

response_content = response.choices[0].message.content
# validate the JSON response by the pydantic model
capital_info = CapitalInfo.model_validate_json(response_content)
print_highlight(f"Validated response: {capital_info.model_dump_json()}")
[2025-01-13 13:08:56 TP0] Prefill batch. #new-seq: 1, #new-token: 48, #cached-token: 1, cache hit rate: 1.79%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:08:56] INFO: 127.0.0.1:42770 - "POST /v1/chat/completions HTTP/1.1" 200 OK
JSON Schema Directly
[3]:
import json

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Give me the information of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=128,
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "foo", "schema": json.loads(json_schema)},
    },
)

print_highlight(response.choices[0].message.content)
[2025-01-13 13:08:56 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 48, cache hit rate: 46.67%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:08:56] INFO: 127.0.0.1:42770 - "POST /v1/chat/completions HTTP/1.1" 200 OK
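When the schema is supplied directly rather than through Pydantic, the reply can still be checked programmatically. Here is a minimal hedged sketch using the third-party jsonschema package (an assumption on our part, not something SGLang requires) to validate the parsed response against the same schema:

import json
import jsonschema  # assumed to be installed separately (pip install jsonschema)

# Parse the constrained completion and check it against the schema we sent
response_data = json.loads(response.choices[0].message.content)
jsonschema.validate(instance=response_data, schema=json.loads(json_schema))
print_highlight(f"Validated response: {response_data}")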
EBNF#
[4]:
ebnf_grammar = """
root ::= city | description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful geography bot."},
        {
            "role": "user",
            "content": "Give me the information of the capital of France.",
        },
    ],
    temperature=0,
    max_tokens=32,
    extra_body={"ebnf": ebnf_grammar},
)

print_highlight(response.choices[0].message.content)
[2025-01-13 13:08:56 TP0] Prefill batch. #new-seq: 1, #new-token: 27, #cached-token: 25, cache hit rate: 47.13%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:08:56 TP0] Decode batch. #running-req: 1, #token: 55, token usage: 0.00, gen throughput (token/s): 5.91, #queue-req: 0
[2025-01-13 13:08:56] INFO: 127.0.0.1:42770 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[5]:
# Restart the server with the default Outlines backend, since regular expression
# constraints are supported by Outlines rather than XGrammar.
terminate_process(server_process)

server_process = execute_shell_command(
    "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0"
)

wait_for_server("http://localhost:30000")
[2025-01-13 13:09:11] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, random_seed=45664428, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, dump_requests_folder=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2025-01-13 13:09:29 TP0] Init torch distributed begin.
[2025-01-13 13:09:29 TP0] Load weight begin. avail mem=78.81 GB
[2025-01-13 13:09:30 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.11it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.04it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.02it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.39it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.24it/s]
[2025-01-13 13:09:34 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.72 GB
[2025-01-13 13:09:34 TP0] KV Cache is allocated. K size: 27.13 GB, V size: 27.13 GB.
[2025-01-13 13:09:34 TP0] Memory pool end. avail mem=8.34 GB
[2025-01-13 13:09:34 TP0] Capture cuda graph begin. This can take up to several minutes.
100%|██████████| 23/23 [00:05<00:00, 4.51it/s]
[2025-01-13 13:09:39 TP0] Capture cuda graph end. Time elapsed: 5.11 s
[2025-01-13 13:09:40 TP0] max_total_num_tokens=444500, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072
[2025-01-13 13:09:40] INFO: Started server process [211139]
[2025-01-13 13:09:40] INFO: Waiting for application startup.
[2025-01-13 13:09:40] INFO: Application startup complete.
[2025-01-13 13:09:40] INFO: Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)
[2025-01-13 13:09:40] INFO: 127.0.0.1:47130 - "GET /v1/models HTTP/1.1" 200 OK
[2025-01-13 13:09:41] INFO: 127.0.0.1:47132 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-01-13 13:09:41 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:09:43] INFO: 127.0.0.1:47134 - "POST /generate HTTP/1.1" 200 OK
[2025-01-13 13:09:43] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
Regular expression#
[6]:
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
    ],
    temperature=0,
    max_tokens=128,
    extra_body={"regex": "(Paris|London)"},
)

print_highlight(response.choices[0].message.content)
[2025-01-13 13:09:45 TP0] Prefill batch. #new-seq: 1, #new-token: 41, #cached-token: 1, cache hit rate: 2.04%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:09:45] INFO: 127.0.0.1:59958 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[7]:
terminate_process(server_process)
Native API and SGLang Runtime (SRT)#
[8]:
server_process = execute_shell_command(
    """
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --grammar-backend xgrammar
"""
)

wait_for_server("http://localhost:30010")
[2025-01-13 13:09:58] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=30010, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, random_seed=459153817, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, dump_requests_folder=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2025-01-13 13:10:15 TP0] Init torch distributed begin.
[2025-01-13 13:10:15 TP0] Load weight begin. avail mem=78.81 GB
[2025-01-13 13:10:16 TP0] Using model weights format ['*.safetensors']
[2025-01-13 13:10:17 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.44it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.44it/s]
[2025-01-13 13:10:17 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=76.39 GB
[2025-01-13 13:10:17 TP0] KV Cache is allocated. K size: 33.47 GB, V size: 33.47 GB.
[2025-01-13 13:10:17 TP0] Memory pool end. avail mem=7.45 GB
[2025-01-13 13:10:17 TP0] Capture cuda graph begin. This can take up to several minutes.
100%|██████████| 23/23 [00:05<00:00, 4.20it/s]
[2025-01-13 13:10:23 TP0] Capture cuda graph end. Time elapsed: 5.48 s
[2025-01-13 13:10:23 TP0] max_total_num_tokens=2193171, max_prefill_tokens=16384, max_running_requests=4097, context_len=131072
[2025-01-13 13:10:24] INFO: Started server process [211865]
[2025-01-13 13:10:24] INFO: Waiting for application startup.
[2025-01-13 13:10:24] INFO: Application startup complete.
[2025-01-13 13:10:24] INFO: Uvicorn running on http://127.0.0.1:30010 (Press CTRL+C to quit)
[2025-01-13 13:10:24] INFO: 127.0.0.1:58260 - "GET /v1/models HTTP/1.1" 200 OK
[2025-01-13 13:10:25] INFO: 127.0.0.1:58264 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-01-13 13:10:25 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:10:27] INFO: 127.0.0.1:58272 - "POST /generate HTTP/1.1" 200 OK
[2025-01-13 13:10:27] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
JSON#
Using Pydantic
[9]:
import requests
import json
from pydantic import BaseModel, Field


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


# Make API request
response = requests.post(
    "http://localhost:30010/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json.dumps(CapitalInfo.model_json_schema()),
        },
    },
)
print_highlight(response.json())

response_data = json.loads(response.json()["text"])
# validate the response by the pydantic model
capital_info = CapitalInfo.model_validate(response_data)
print_highlight(f"Validated response: {capital_info.model_dump_json()}")
[2025-01-13 13:10:29 TP0] Prefill batch. #new-seq: 1, #new-token: 14, #cached-token: 1, cache hit rate: 4.55%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:10:30] INFO: 127.0.0.1:58288 - "POST /generate HTTP/1.1" 200 OK
JSON Schema Directly
[10]:
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
response = requests.post(
    "http://localhost:30010/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
print_highlight(response.json())
[2025-01-13 13:10:30 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 14, cache hit rate: 40.54%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:10:30 TP0] Decode batch. #running-req: 1, #token: 29, token usage: 0.00, gen throughput (token/s): 6.13, #queue-req: 0
[2025-01-13 13:10:30] INFO: 127.0.0.1:58292 - "POST /generate HTTP/1.1" 200 OK
EBNF#
[11]:
import requests

response = requests.post(
    "http://localhost:30010/generate",
    json={
        "text": "Give me the information of the capital of France.",
        "sampling_params": {
            "max_new_tokens": 128,
            "temperature": 0,
            "n": 3,
            "ebnf": (
                "root ::= city | description\n"
                'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
                'description ::= city " is " status\n'
                'status ::= "the capital of " country\n'
                'country ::= "England" | "France" | "Germany" | "Italy"'
            ),
        },
        "stream": False,
        "return_logprob": False,
    },
)

print_highlight(response.json())
[2025-01-13 13:10:30 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 1, cache hit rate: 33.33%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:10:30 TP0] Prefill batch. #new-seq: 3, #new-token: 3, #cached-token: 30, cache hit rate: 56.79%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:10:30] INFO: 127.0.0.1:58300 - "POST /generate HTTP/1.1" 200 OK
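Because the request above sets "n": 3, the native /generate endpoint returns a list of generations rather than a single object. A small hedged sketch for printing them separately, assuming each entry carries a "text" field as in the single-generation responses above:

# Iterate over the three constrained generations returned for n=3
for i, generation in enumerate(response.json()):
    print_highlight(f"Generation {i}: {generation['text']}")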
[12]:
# Restart the server with the default Outlines backend for the regular expression example below.
terminate_process(server_process)

server_process = execute_shell_command(
    """
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010
"""
)

wait_for_server("http://localhost:30010")
[2025-01-13 13:10:42] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='127.0.0.1', port=30010, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, random_seed=1037872448, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, dump_requests_folder=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)
[2025-01-13 13:10:59 TP0] Init torch distributed begin.
[2025-01-13 13:11:00 TP0] Load weight begin. avail mem=78.81 GB
[2025-01-13 13:11:01 TP0] Using model weights format ['*.safetensors']
[2025-01-13 13:11:01 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.22it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.22it/s]
[2025-01-13 13:11:02 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=76.39 GB
[2025-01-13 13:11:02 TP0] KV Cache is allocated. K size: 33.47 GB, V size: 33.47 GB.
[2025-01-13 13:11:02 TP0] Memory pool end. avail mem=7.45 GB
[2025-01-13 13:11:02 TP0] Capture cuda graph begin. This can take up to several minutes.
100%|██████████| 23/23 [00:04<00:00, 4.95it/s]
[2025-01-13 13:11:06 TP0] Capture cuda graph end. Time elapsed: 4.66 s
[2025-01-13 13:11:07 TP0] max_total_num_tokens=2193171, max_prefill_tokens=16384, max_running_requests=4097, context_len=131072
[2025-01-13 13:11:07] INFO: Started server process [212616]
[2025-01-13 13:11:07] INFO: Waiting for application startup.
[2025-01-13 13:11:07] INFO: Application startup complete.
[2025-01-13 13:11:07] INFO: Uvicorn running on http://127.0.0.1:30010 (Press CTRL+C to quit)
[2025-01-13 13:11:07] INFO: 127.0.0.1:44206 - "GET /v1/models HTTP/1.1" 200 OK
[2025-01-13 13:11:08] INFO: 127.0.0.1:44220 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-01-13 13:11:08 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:11:10] INFO: 127.0.0.1:44222 - "POST /generate HTTP/1.1" 200 OK
[2025-01-13 13:11:10] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
Regular expression#
[13]:
response = requests.post(
    "http://localhost:30010/generate",
    json={
        "text": "Paris is the capital of",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "regex": "(France|England)",
        },
    },
)
print_highlight(response.json())
[2025-01-13 13:11:12 TP0] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 1, cache hit rate: 7.69%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-01-13 13:11:12] INFO: 127.0.0.1:44226 - "POST /generate HTTP/1.1" 200 OK
[14]:
terminate_process(server_process)
Offline Engine API#
[15]:
import sglang as sgl

llm_xgrammar = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", grammar_backend="xgrammar"
)
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.16it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.09it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.08it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.30it/s]
100%|██████████| 23/23 [00:04<00:00, 4.73it/s]
JSON#
Using Pydantic
[16]:
import json
from pydantic import BaseModel, Field

prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


sampling_params = {
    "temperature": 0.1,
    "top_p": 0.95,
    "json_schema": json.dumps(CapitalInfo.model_json_schema()),
}

outputs = llm_xgrammar.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}")
    # validate the output by the pydantic model
    capital_info = CapitalInfo.model_validate_json(output["text"])
    print_highlight(f"Validated output: {capital_info.model_dump_json()}")
JSON Schema Directly
[17]:
prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

sampling_params = {"temperature": 0.1, "top_p": 0.95, "json_schema": json_schema}

outputs = llm_xgrammar.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
Generated text: {"name": "Beijing", "population": 21500000}
Generated text: {"name": "Paris", "population": 2141000}
Generated text: {"name": "Dublin", "population": 527617}
EBNF#
[18]:
prompts = [
    "Give me the information of the capital of France.",
    "Give me the information of the capital of Germany.",
    "Give me the information of the capital of Italy.",
]

sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "ebnf": (
        "root ::= city | description\n"
        'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
        'description ::= city " is " status\n'
        'status ::= "the capital of " country\n'
        'country ::= "England" | "France" | "Germany" | "Italy"'
    ),
}

outputs = llm_xgrammar.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
Generated text: Paris is the capital of France
Generated text: Berlin is the capital of Germany
Generated text: Paris is the capital of France
[19]:
# Switch to an engine with the default Outlines backend for the regular expression example below.
llm_xgrammar.shutdown()
llm_outlines = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
/public_sglang_ci/runner-a-gpu-45/_work/_tool/Python/3.9.21/x64/lib/python3.9/multiprocessing/resource_tracker.py:96: UserWarning: resource_tracker: process died unexpectedly, relaunching. Some resources might leak.
warnings.warn('resource_tracker: process died unexpectedly, '
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.16it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.09it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.08it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.30it/s]
100%|██████████| 23/23 [00:04<00:00, 4.62it/s]
Regular expression#
[20]:
prompts = [
    "Please provide information about London as a major global city:",
    "Please provide information about Paris as a major global city:",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95, "regex": "(France|England)"}

outputs = llm_outlines.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print_highlight("===============================")
    print_highlight(f"Prompt: {prompt}\nGenerated text: {output['text']}")
Generated text: England
Generated text: France
[21]:
llm_outlines.shutdown()