Sending Requests#
This notebook provides a quick-start guide to using SGLang for chat completions after installation.
For Vision Language Models, see OpenAI APIs - Vision.
For Embedding Models, see OpenAI APIs - Embedding and Encode (embedding model).
For Reward Models, see Classify (reward model).
Launch A Server#
[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
# This is equivalent to running the following command in your terminal
# python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct \
 --host 0.0.0.0
"""
)
wait_for_server(f"http://localhost:{port}")
W0814 06:24:08.936000 1231605 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:24:08.936000 1231605 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:24:11] server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=35526, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.874, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=44487854, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, 
cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:24:11] Using default HuggingFace chat template with detected content format: string
W0814 06:24:18.006000 1231982 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:24:18.006000 1231982 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:24:18.021000 1231983 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:24:18.021000 1231983 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:24:21] Attention backend not explicitly specified. Use fa3 backend by default.
[2025-08-14 06:24:21] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:24:22] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:24:23] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:24:23] Load weight begin. avail mem=78.58 GB
[2025-08-14 06:24:23] Using model weights format ['*.safetensors']
[2025-08-14 06:24:23] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 6.39it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 6.38it/s]
[2025-08-14 06:24:23] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=77.52 GB, mem usage=1.06 GB.
[2025-08-14 06:24:23] KV Cache is allocated. #tokens: 20480, K size: 0.12 GB, V size: 0.12 GB
[2025-08-14 06:24:23] Memory pool end. avail mem=77.12 GB
[2025-08-14 06:24:24] Capture cuda graph begin. This can take up to several minutes. avail mem=77.03 GB
[2025-08-14 06:24:24] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=76.96 GB): 100%|██████████| 3/3 [00:00<00:00, 10.37it/s]
[2025-08-14 06:24:24] Capture cuda graph end. Time elapsed: 0.88 s. mem usage=0.07 GB. avail mem=76.95 GB.
[2025-08-14 06:24:25] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=32768, available_gpu_mem=76.95 GB
[2025-08-14 06:24:25] INFO: Started server process [1231605]
[2025-08-14 06:24:25] INFO: Waiting for application startup.
[2025-08-14 06:24:25] INFO: Application startup complete.
[2025-08-14 06:24:25] INFO: Uvicorn running on http://0.0.0.0:35526 (Press CTRL+C to quit)
[2025-08-14 06:24:26] INFO: 127.0.0.1:48714 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:24:26] INFO: 127.0.0.1:48720 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:24:26] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:24:27] INFO: 127.0.0.1:48726 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:24:27] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
These notebooks run in a CI environment, so the throughput is not representative of actual performance.
Using cURL#
[2]:
import subprocess, json
curl_command = f"""
curl -s http://localhost:{port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{{"model": "qwen/qwen2.5-0.5b-instruct", "messages": [{{"role": "user", "content": "What is the capital of France?"}}]}}'
"""
response = json.loads(subprocess.check_output(curl_command, shell=True))
print_highlight(response)
[2025-08-14 06:24:31] Prefill batch. #new-seq: 1, #new-token: 36, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:24:31] INFO: 127.0.0.1:48728 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': '5717536f009742f8aa7f4e9240de00c0', 'object': 'chat.completion', 'created': 1755152671, 'model': 'qwen/qwen2.5-0.5b-instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.', 'reasoning_content': None, 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 151645}], 'usage': {'prompt_tokens': 36, 'total_tokens': 44, 'completion_tokens': 8, 'prompt_tokens_details': None, 'reasoning_tokens': 0}}
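Because the cURL output is already parsed into a Python dict, the assistant's reply can be read directly from the first choice. A minimal sketch reusing the response object from the cell above:

# Read the assistant message from the first choice of the parsed response
answer = response["choices"][0]["message"]["content"]
print_highlight(answer)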
Using Python Requests#
[3]:
import requests
url = f"http://localhost:{port}/v1/chat/completions"
data = {
    "model": "qwen/qwen2.5-0.5b-instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print_highlight(response.json())
[2025-08-14 06:24:31] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 35, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:24:31] INFO: 127.0.0.1:48742 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': 'd99cfdcb26454d63b56486d7bec76d6f', 'object': 'chat.completion', 'created': 1755152671, 'model': 'qwen/qwen2.5-0.5b-instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.', 'reasoning_content': None, 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 151645}], 'usage': {'prompt_tokens': 36, 'total_tokens': 44, 'completion_tokens': 8, 'prompt_tokens_details': None, 'reasoning_tokens': 0}}
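The OpenAI-compatible endpoint also accepts standard generation parameters such as temperature and max_tokens in the request body (the same parameters used with the OpenAI client below). A minimal sketch that adds them to the plain-requests payload:

# Same request as above, with sampling parameters added to the JSON body
data = {
    "model": "qwen/qwen2.5-0.5b-instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "temperature": 0,
    "max_tokens": 64,
}
response = requests.post(url, json=data)
print_highlight(response.json()["choices"][0]["message"]["content"])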
Using OpenAI Python Client#
[4]:
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print_highlight(response)
[2025-08-14 06:24:32] Prefill batch. #new-seq: 1, #new-token: 13, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:24:32] Decode batch. #running-req: 1, #token: 54, token usage: 0.00, cuda graph: True, gen throughput (token/s): 5.99, #queue-req: 0,
[2025-08-14 06:24:32] INFO: 127.0.0.1:48758 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ChatCompletion(id='6d006c6cc0f5477c907e86daf2660a48', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Sure, here are three countries and their respective capitals:\n\n1. **United States** - Washington, D.C.\n2. **Canada** - Ottawa\n3. **Australia** - Canberra', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=151645)], created=1755152672, model='qwen/qwen2.5-0.5b-instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=39, prompt_tokens=37, total_tokens=76, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0))
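The returned ChatCompletion object exposes the same fields as the JSON responses above, so the reply text and token usage can be read attribute-style. A minimal sketch:

# Access the reply text and token usage from the ChatCompletion object
print_highlight(response.choices[0].message.content)
print_highlight(f"Completion tokens: {response.usage.completion_tokens}")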
Streaming#
[5]:
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
# Use stream=True for streaming responses
response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
    stream=True,
)

# Handle the streaming output
for chunk in response:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
[2025-08-14 06:24:32] INFO: 127.0.0.1:48772 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-08-14 06:24:32] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 36, token usage: 0.00, #running-req: 0, #queue-req: 0,
Sure, here are three countries and their respective capitals:
1. **United States**[2025-08-14 06:24:32] Decode batch. #running-req: 1, #token: 55, token usage: 0.00, cuda graph: True, gen throughput (token/s): 285.01, #queue-req: 0,
- Washington, D.C.
2. **Canada** - Ottawa
3. **Australia** - Canberra
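If you need the complete reply after streaming, accumulate the deltas as they arrive. A minimal sketch reusing the client from the cell above:

# Collect streamed deltas and join them into the full reply
stream = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
    stream=True,
)
pieces = []
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        pieces.append(delta)
print_highlight("".join(pieces))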
Using Native Generation APIs#
You can also use the native /generate endpoint with requests, which provides more flexibility. An API reference is available at Sampling Parameters.
[6]:
import requests
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print_highlight(response.json())
[2025-08-14 06:24:32] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 2, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:24:32] Decode batch. #running-req: 1, #token: 24, token usage: 0.00, cuda graph: True, gen throughput (token/s): 525.97, #queue-req: 0,
[2025-08-14 06:24:32] INFO: 127.0.0.1:48782 - "POST /generate HTTP/1.1" 200 OK
{'text': ' Paris. It is the largest city in Europe and the second largest city in the world. It is located in the south of France, on the banks of the', 'output_ids': [12095, 13, 1084, 374, 279, 7772, 3283, 304, 4505, 323, 279, 2086, 7772, 3283, 304, 279, 1879, 13, 1084, 374, 7407, 304, 279, 9806, 315, 9625, 11, 389, 279, 13959, 315, 279], 'meta_info': {'id': '57e6c0ae3615405b9bc53662d8994b88', 'finish_reason': {'type': 'length', 'length': 32}, 'prompt_tokens': 5, 'completion_tokens': 32, 'cached_tokens': 2, 'e2e_latency': 0.05570030212402344}}
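Beyond the generated text, the /generate response carries a meta_info block with the finish reason and token counts, as shown in the output above. A minimal sketch reading those fields from the same response:

# Inspect the generated text and its metadata
result = response.json()
print_highlight(result["text"])
print_highlight(result["meta_info"]["finish_reason"])
print_highlight(result["meta_info"]["completion_tokens"])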
Streaming#
[7]:
import requests, json
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"]
        print(output[prev:], end="", flush=True)
        prev = len(output)
[2025-08-14 06:24:32] INFO: 127.0.0.1:48794 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:24:32] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 4, token usage: 0.00, #running-req: 0, #queue-req: 0,
Paris. It is the largest city in Europe and the second largest city in the world. It is located in the south of France[2025-08-14 06:24:32] Decode batch. #running-req: 1, #token: 32, token usage: 0.00, cuda graph: True, gen throughput (token/s): 538.37, #queue-req: 0,
, on the banks of the
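Other generation options go into the same sampling_params dict; the authoritative list is in the Sampling Parameters reference. A hedged sketch that assumes top_p and stop are accepted keys (check the reference before relying on them):

response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
            "top_p": 0.95,  # assumed parameter; see the Sampling Parameters reference
            "stop": ["\n"],  # assumed parameter; see the Sampling Parameters reference
        },
    },
)
print_highlight(response.json()["text"])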
[8]:
terminate_process(server_process)