OpenAI APIs - Completions#
SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models. A complete reference for the API is available in the OpenAI API Reference.
This tutorial covers the following popular APIs:
chat/completions
completions
Check out other tutorials to learn about vision APIs for vision-language models and embedding APIs for embedding models.
Launch A Server#
Launch the server in your terminal and wait for it to initialize.
[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
print(f"Server started on http://localhost:{port}")
W0814 09:47:05.185000 1085583 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 09:47:05.185000 1085583 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 09:47:06] server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=39185, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.874, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=731626249, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, 
cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 09:47:07] Using default HuggingFace chat template with detected content format: string
W0814 09:47:14.124000 1086151 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 09:47:14.124000 1086151 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 09:47:14.175000 1086152 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 09:47:14.175000 1086152 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 09:47:15] Attention backend not explicitly specified. Use fa3 backend by default.
[2025-08-14 09:47:15] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 09:47:17] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 09:47:18] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 09:47:18] Load weight begin. avail mem=78.58 GB
[2025-08-14 09:47:18] Using model weights format ['*.safetensors']
[2025-08-14 09:47:19] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 4.88it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 4.87it/s]
[2025-08-14 09:47:19] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=77.52 GB, mem usage=1.06 GB.
[2025-08-14 09:47:19] KV Cache is allocated. #tokens: 20480, K size: 0.12 GB, V size: 0.12 GB
[2025-08-14 09:47:19] Memory pool end. avail mem=77.12 GB
[2025-08-14 09:47:19] Capture cuda graph begin. This can take up to several minutes. avail mem=77.03 GB
[2025-08-14 09:47:19] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=76.96 GB): 100%|██████████| 3/3 [00:00<00:00, 10.99it/s]
[2025-08-14 09:47:20] Capture cuda graph end. Time elapsed: 0.85 s. mem usage=0.07 GB. avail mem=76.95 GB.
[2025-08-14 09:47:20] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=32768, available_gpu_mem=76.95 GB
[2025-08-14 09:47:21] INFO: Started server process [1085583]
[2025-08-14 09:47:21] INFO: Waiting for application startup.
[2025-08-14 09:47:21] INFO: Application startup complete.
[2025-08-14 09:47:21] INFO: Uvicorn running on http://0.0.0.0:39185 (Press CTRL+C to quit)
[2025-08-14 09:47:21] INFO: 127.0.0.1:54738 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 09:47:22] INFO: 127.0.0.1:54740 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 09:47:22] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:47:22] INFO: 127.0.0.1:54756 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 09:47:22] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI environment, so the throughput is not representative of actual performance.
Server started on http://localhost:39185
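Once the server reports it is ready, you can optionally verify it by listing the models it exposes through the OpenAI-compatible endpoint. The following is a minimal sketch, assuming the requests package is available in this environment and that the response follows OpenAI's standard list format:
import requests

# Query the OpenAI-style model list endpoint exposed by the server.
# The response is expected to follow OpenAI's {"object": "list", "data": [...]} format.
models = requests.get(f"http://localhost:{port}/v1/models").json()
print_highlight(f"Available models: {[m['id'] for m in models['data']]}")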
Chat Completions#
Usage#
The server fully implements the OpenAI API. It will automatically apply the chat template specified in the Hugging Face tokenizer, if one is available. You can also specify a custom chat template with --chat-template when launching the server (a launch sketch follows the example below).
[2]:
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-08-14 09:47:27] Prefill batch. #new-seq: 1, #new-token: 37, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 70, token usage: 0.00, cuda graph: True, gen throughput (token/s): 5.89, #queue-req: 0,
[2025-08-14 09:47:27] INFO: 127.0.0.1:54762 - "POST /v1/chat/completions HTTP/1.1" 200 OK
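As noted above, a custom chat template can be supplied when launching the server. Here is a minimal sketch; the template path is hypothetical, and depending on your SGLang version --chat-template may accept a built-in template name, a JSON template file, or a Jinja file:
# Sketch: launch a second server with a custom chat template.
# "./my_chat_template.jinja" is a hypothetical path; replace it with your own template.
custom_process, custom_port = launch_server_cmd(
    "python3 -m sglang.launch_server "
    "--model-path qwen/qwen2.5-0.5b-instruct "
    "--chat-template ./my_chat_template.jinja "
    "--host 0.0.0.0"
)
wait_for_server(f"http://localhost:{custom_port}")
# ...use it the same way as shown above, then shut it down:
terminate_process(custom_process)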
Parameters#
The chat completions API accepts OpenAI Chat Completions API’s parameters. Refer to OpenAI Chat Completions API for more details.
SGLang extends the standard API with the extra_body parameter, allowing for additional customization. One key option within extra_body is chat_template_kwargs, which can be used to pass arguments to the chat template processor.
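Besides chat_template_kwargs, extra_body can also carry other SGLang-specific options, such as additional sampling parameters. The sketch below assumes your SGLang version accepts top_k and repetition_penalty as request-level fields; consult the sampling parameters documentation for the authoritative list:
# Sketch: SGLang-specific sampling options passed through extra_body.
# top_k and repetition_penalty are assumptions here, not standard OpenAI parameters.
response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[{"role": "user", "content": "Name one ocean."}],
    max_tokens=32,
    extra_body={
        "top_k": 20,
        "repetition_penalty": 1.05,
    },
)
print_highlight(response.choices[0].message.content)
Here is an example of a detailed chat completion request using standard OpenAI parameters: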
[3]:
response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[
        {
            "role": "system",
            "content": "You are a knowledgeable historian who provides concise responses.",
        },
        {"role": "user", "content": "Tell me about ancient Rome"},
        {
            "role": "assistant",
            "content": "Ancient Rome was a civilization centered in Italy.",
        },
        {"role": "user", "content": "What were their major achievements?"},
    ],
    temperature=0.3,  # Lower temperature for more focused responses
    max_tokens=128,  # Reasonable length for a concise response
    top_p=0.95,  # Slightly higher for better fluency
    presence_penalty=0.2,  # Mild penalty to avoid repetition
    frequency_penalty=0.2,  # Mild penalty for more natural language
    n=1,  # Single response is usually more stable
    seed=42,  # Keep for reproducibility
)
print_highlight(response.choices[0].message.content)
[2025-08-14 09:47:27] Prefill batch. #new-seq: 1, #new-token: 49, #cached-token: 5, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 88, token usage: 0.00, cuda graph: True, gen throughput (token/s): 242.67, #queue-req: 0,
[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 128, token usage: 0.01, cuda graph: True, gen throughput (token/s): 620.86, #queue-req: 0,
[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 168, token usage: 0.01, cuda graph: True, gen throughput (token/s): 625.91, #queue-req: 0,
[2025-08-14 09:47:27] INFO: 127.0.0.1:54762 - "POST /v1/chat/completions HTTP/1.1" 200 OK
1. The construction of the Colosseum, one of the largest amphitheaters in the world.
2. The construction of the Pantheon, a basilica that served as a temple and church.
3. The development of Roman law, which established principles of justice and legal procedures.
4. The invention of the wheel, which revolutionized transportation.
5. The construction of aqueducts to supply water
Streaming mode is also supported.
[4]:
stream = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[{"role": "user", "content": "Say this is a test"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")
[2025-08-14 09:47:27] INFO: 127.0.0.1:54762 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-08-14 09:47:27] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 24, token usage: 0.00, #running-req: 0, #queue-req: 0,
Yes, I am Qwen, an AI created by Alibaba Cloud. My purpose is to assist you with any questions or tasks[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 60, token usage: 0.00, cuda graph: True, gen throughput (token/s): 432.95, #queue-req: 0,
you may have to the best of my abilities. If you have any specific questions or areas you'd like to discuss, please let me know, and I'll do my best to provide the information or[2025-08-14 09:47:27] Decode batch. #running-req: 1, #token: 100, token usage: 0.00, cuda graph: True, gen throughput (token/s): 664.12, #queue-req: 0,
answers you're looking for.
Enabling Model Thinking/Reasoning#
You can use chat_template_kwargs
to enable or disable the model’s internal thinking or reasoning process output. Set "enable_thinking": True
within chat_template_kwargs
to include the reasoning steps in the response. This requires launching the server with a compatible reasoning parser.
Reasoning Parser Options:
--reasoning-parser deepseek-r1: For DeepSeek-R1 family models (R1, R1-0528, R1-Distill)
--reasoning-parser qwen3: For both standard Qwen3 models that support the enable_thinking parameter and Qwen3-Thinking models
--reasoning-parser qwen3-thinking: For Qwen3-Thinking models; forces the reasoning version of the qwen3 parser
--reasoning-parser kimi: For Kimi thinking models
Here’s an example demonstrating how to enable thinking and retrieve the reasoning content separately (using separate_reasoning: True):
# For Qwen3 models with enable_thinking support:
# python3 -m sglang.launch_server --model-path QwQ/Qwen3-32B-250415 --reasoning-parser qwen3 ...

from openai import OpenAI

# Modify OpenAI's API key and API base to use SGLang's API server.
openai_api_key = "EMPTY"
openai_api_base = f"http://127.0.0.1:{port}/v1"  # Use the correct port

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

model = "QwQ/Qwen3-32B-250415"  # Use the model loaded by the server
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]

response = client.chat.completions.create(
    model=model,
    messages=messages,
    extra_body={
        "chat_template_kwargs": {"enable_thinking": True},
        "separate_reasoning": True,
    },
)

print("response.choices[0].message.reasoning_content: \n", response.choices[0].message.reasoning_content)
print("response.choices[0].message.content: \n", response.choices[0].message.content)
Example Output:
response.choices[0].message.reasoning_content:
Okay, so I need to figure out which number is greater between 9.11 and 9.8. Hmm, let me think. Both numbers start with 9, right? So the whole number part is the same. That means I need to look at the decimal parts to determine which one is bigger.
...
Therefore, after checking multiple methods—aligning decimals, subtracting, converting to fractions, and using a real-world analogy—it's clear that 9.8 is greater than 9.11.
response.choices[0].message.content:
To determine which number is greater between **9.11** and **9.8**, follow these steps:
...
**Answer**:
9.8 is greater than 9.11.
Setting "enable_thinking": False
(or omitting it) will result in reasoning_content
being None
.
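For example, continuing the sketch above against a reasoning-parser-enabled server:
# Sketch: with thinking disabled, the parser is expected to return no reasoning content.
response = client.chat.completions.create(
    model=model,
    messages=messages,
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
print("reasoning_content:", response.choices[0].message.reasoning_content)  # expected: None
print("content:", response.choices[0].message.content)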
Note for Qwen3-Thinking models: These models always generate thinking content and do not support the enable_thinking parameter. Use --reasoning-parser qwen3-thinking or --reasoning-parser qwen3 to parse the thinking content.
Completions#
Usage#
The Completions API is similar to the Chat Completions API, but it takes a raw prompt instead of the messages parameter and applies no chat template.
[5]:
response = client.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    prompt="List 3 countries and their capitals.",
    temperature=0,
    max_tokens=64,
    n=1,
    stop=None,
)
print_highlight(f"Response: {response}")
[2025-08-14 09:47:28] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:47:28] Decode batch. #running-req: 1, #token: 42, token usage: 0.00, cuda graph: True, gen throughput (token/s): 318.43, #queue-req: 0,
[2025-08-14 09:47:28] INFO: 127.0.0.1:54762 - "POST /v1/completions HTTP/1.1" 200 OK
Parameters#
The completions API accepts OpenAI Completions API’s parameters. Refer to OpenAI Completions API for more details.
Here is an example of a detailed completions request:
[6]:
response = client.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    prompt="Write a short story about a space explorer.",
    temperature=0.7,  # Moderate temperature for creative writing
    max_tokens=150,  # Longer response for a story
    top_p=0.9,  # Balanced diversity in word choice
    stop=["\n\n", "THE END"],  # Multiple stop sequences
    presence_penalty=0.3,  # Encourage novel elements
    frequency_penalty=0.3,  # Reduce repetitive phrases
    n=1,  # Generate one completion
    seed=123,  # For reproducible results
)
print_highlight(f"Response: {response}")
[2025-08-14 09:47:28] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:47:28] Decode batch. #running-req: 1, #token: 19, token usage: 0.00, cuda graph: True, gen throughput (token/s): 496.44, #queue-req: 0,
[2025-08-14 09:47:28] INFO: 127.0.0.1:54762 - "POST /v1/completions HTTP/1.1" 200 OK
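Streaming also works for the Completions API, analogous to the chat example above. A brief sketch, assuming streamed completion chunks carry their text in choices[0].text as in the OpenAI Completions API:
# Sketch: stream a plain-prompt completion token by token.
stream = client.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    prompt="The capital of France is",
    max_tokens=32,
    stream=True,
)
for chunk in stream:
    if chunk.choices[0].text:
        print(chunk.choices[0].text, end="")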
Structured Outputs (JSON, Regex, EBNF)#
For the OpenAI-compatible structured outputs API, refer to Structured Outputs for more details.
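As a quick taste, here is a hedged sketch of a JSON-schema-constrained chat completion against the server launched above; the schema and field names are illustrative, so see the Structured Outputs tutorial for the authoritative usage and supported options:
import json

# Illustrative schema: constrain the reply to a small JSON object.
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = client.chat.completions.create(
    model="qwen/qwen2.5-0.5b-instruct",
    messages=[
        {"role": "user", "content": "Give the name and population of the capital of France as JSON."},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "capital_info", "schema": json.loads(json_schema)},
    },
    max_tokens=64,
)
print_highlight(response.choices[0].message.content)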
[7]:
terminate_process(server_process)