Sending Requests
This notebook is a quick-start guide to sending chat completion requests to SGLang after installation.
For Vision Language Models, see OpenAI APIs - Vision.
For Embedding Models, see OpenAI APIs - Embedding and Encode (embedding model).
For Reward Models, see Classify (reward model).
Launch A Server
from sglang.test.test_utils import is_in_ci
from sglang.utils import wait_for_server, print_highlight, terminate_process

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

# This is equivalent to running the following command in your terminal:
# python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0

server_process, port = launch_server_cmd(
    """
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
 --host 0.0.0.0
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-02-23 08:36:17] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, skip_tokenizer_init=False, host='0.0.0.0', port=37579, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, stream_output=False, random_seed=812393265, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='sglang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_draft_model_path=None, speculative_algorithm=None, speculative_num_steps=5, speculative_num_draft_tokens=64, speculative_eagle_topk=8, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, 
enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, return_hidden_states=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False)
[2025-02-23 08:36:35 TP0] Init torch distributed begin.
[2025-02-23 08:36:35 TP0] Load weight begin. avail mem=54.07 GB
[2025-02-23 08:36:37 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:01<00:03, 1.17s/it]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.48it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.06it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.11it/s]
[2025-02-23 08:36:41 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=38.96 GB
[2025-02-23 08:36:41 TP0] KV Cache is allocated. K size: 1.25 GB, V size: 1.25 GB.
[2025-02-23 08:36:41 TP0] Memory pool end. avail mem=36.23 GB
[2025-02-23 08:36:42 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-02-23 08:36:42] INFO: Started server process [1035925]
[2025-02-23 08:36:42] INFO: Waiting for application startup.
[2025-02-23 08:36:42] INFO: Application startup complete.
[2025-02-23 08:36:42] INFO: Uvicorn running on http://0.0.0.0:37579 (Press CTRL+C to quit)
[2025-02-23 08:36:43] INFO: 127.0.0.1:51848 - "GET /v1/models HTTP/1.1" 200 OK
[2025-02-23 08:36:43] INFO: 127.0.0.1:51852 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-02-23 08:36:43 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:36:46] INFO: 127.0.0.1:51854 - "POST /generate HTTP/1.1" 200 OK
[2025-02-23 08:36:46] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
These notebooks run in a parallel CI environment, so the throughput shown here is not representative of actual performance.
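wait_for_server returns once the server answers HTTP requests. In the log above, the first request after Uvicorn starts is a GET /v1/models. If you start the server yourself (for example, from a separate terminal) and want a readiness check that does not import sglang, here is a minimal standard-library sketch. It assumes that an HTTP 200 from /v1/models means the server is ready. It is not the actual implementation of sglang.utils.wait_for_server, and wait_until_ready is a hypothetical name.

```python
import time
import urllib.error
import urllib.request


def wait_until_ready(base_url, timeout=60.0, interval=1.0, opener=urllib.request.urlopen):
    """Poll GET {base_url}/v1/models until it returns HTTP 200 or `timeout` seconds pass.

    `opener` is injectable so the loop can be exercised without a live server.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with opener(f"{base_url}/v1/models", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except OSError:
            # URLError, HTTPError and ConnectionRefusedError are all OSError subclasses:
            # the server is not accepting requests yet, so retry after a pause.
            pass
        time.sleep(interval)
    return False
```

For example, call wait_until_ready(f"http://localhost:{port}", timeout=300) before sending the requests below. A fixed polling interval keeps the sketch simple. A production check might also back off between attempts.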
Using cURL
import subprocess, json
curl_command = f"""
curl -s http://localhost:{port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{{"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "messages": [{{"role": "user", "content": "What is the capital of France?"}}]}}'
"""
response = json.loads(subprocess.check_output(curl_command, shell=True))
print_highlight(response)
[2025-02-23 08:36:48 TP0] Prefill batch. #new-seq: 1, #new-token: 41, #cached-token: 1, cache hit rate: 2.04%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:36:48] INFO: 127.0.0.1:51858 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': '2bf9a12fe3c84a719ce05e982df34fa0', 'object': 'chat.completion', 'created': 1740299808, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.', 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 42, 'total_tokens': 50, 'completion_tokens': 8, 'prompt_tokens_details': None}}
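The cell above shells out to curl through subprocess. You can send the same POST to /v1/chat/completions with only the standard library and skip the subprocess. This is a minimal sketch. build_chat_request and send are hypothetical helper names, and the port 30000 in the test is only a placeholder.

```python
import json
import urllib.request


def build_chat_request(base_url, model, prompt):
    """Build (but do not send) the same request the curl command issues."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(req):
    """Send the request and decode the JSON body (requires a running server)."""
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())
```

With the server from this notebook running, send(build_chat_request(f"http://localhost:{port}", "meta-llama/Meta-Llama-3.1-8B-Instruct", "What is the capital of France?")) should return a dict with the same shape as the one printed above.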
Using Python Requests
import requests
url = f"http://localhost:{port}/v1/chat/completions"
data = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print_highlight(response.json())
[2025-02-23 08:36:48 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 41, cache hit rate: 46.15%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:36:48] INFO: 127.0.0.1:51866 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{'id': 'b6c874ae69e74171ae7f025ccb6b7758', 'object': 'chat.completion', 'created': 1740299808, 'model': 'meta-llama/Meta-Llama-3.1-8B-Instruct', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'The capital of France is Paris.', 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 128009}], 'usage': {'prompt_tokens': 42, 'total_tokens': 50, 'completion_tokens': 8, 'prompt_tokens_details': None}}
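The endpoint returns the same JSON shape whether you call it with cURL or requests. If you only need the reply and its token counts, a small helper can extract them. summarize_chat_response is a hypothetical name. The field paths come from the responses printed above.

```python
def summarize_chat_response(resp):
    """Extract the reply text, finish reason and token usage from a /v1/chat/completions body."""
    choice = resp["choices"][0]
    usage = resp["usage"]
    return {
        "content": choice["message"]["content"],
        "finish_reason": choice["finish_reason"],
        "prompt_tokens": usage["prompt_tokens"],
        "completion_tokens": usage["completion_tokens"],
        # In the responses above, total_tokens == prompt_tokens + completion_tokens (42 + 8 = 50).
        "total_tokens": usage["total_tokens"],
    }
```

For example, summarize_chat_response(response.json()) applied to the cell above would give content "The capital of France is Paris." and finish_reason "stop".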
Using OpenAI Python Client
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print_highlight(response)
[2025-02-23 08:36:48 TP0] Prefill batch. #new-seq: 1, #new-token: 13, #cached-token: 30, cache hit rate: 53.73%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:36:48 TP0] Decode batch. #running-req: 1, #token: 60, token usage: 0.00, gen throughput (token/s): 6.52, #queue-req: 0
[2025-02-23 08:36:49] INFO: 127.0.0.1:51882 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ChatCompletion(id='ab4ea13a04224d3995604a6568c74177', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n Capital: Tokyo\n\n2. Country: Australia\n Capital: Canberra\n\n3. Country: Brazil\n Capital: Brasília', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None), matched_stop=128009)], created=1740299809, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=43, total_tokens=86, completion_tokens_details=None, prompt_tokens_details=None))
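Because this request caps the reply with max_tokens=64, check finish_reason before you trust the text. It is "stop" above, which means the model ended on its own stop condition (matched_stop=128009). A value of "length" would mean the cap was hit. This is a minimal sketch. reply_or_raise is a hypothetical helper, and the SimpleNamespace object only mirrors the ChatCompletion fields printed above so the sketch runs offline.

```python
from types import SimpleNamespace


def reply_or_raise(completion):
    """Return the assistant text; raise if generation stopped because max_tokens was hit."""
    choice = completion.choices[0]
    # "stop": ended on a stop condition; "length": the max_tokens cap was reached.
    if choice.finish_reason == "length":
        raise RuntimeError("reply truncated by max_tokens; consider raising the limit")
    return choice.message.content


# Offline stand-in exposing only the attributes reply_or_raise reads.
fake = SimpleNamespace(
    choices=[SimpleNamespace(finish_reason="stop",
                             message=SimpleNamespace(content="Here are 3 countries..."))]
)
print(reply_or_raise(fake))
```

Against the live server, pass in the response object from the cell above: reply_or_raise(response).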
Streaming
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
# Use stream=True for streaming responses
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
    stream=True,
)

# Handle the streaming output
for chunk in response:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
[2025-02-23 08:36:49] INFO: 127.0.0.1:51890 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-02-23 08:36:49 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 42, cache hit rate: 64.41%, token usage: 0.00, #running-req: 0, #queue-req: 0
Here are 3 countries and their capitals:
1. Country:[2025-02-23 08:36:49 TP0] Decode batch. #running-req: 1, #token: 57, token usage: 0.00, gen throughput (token/s): 58.49, #queue-req: 0
Japan
Capital: Tokyo
2. Country: Australia
Capital: Canberra
3. Country: Brazil
Capital: Brasília
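The streaming cell prints each delta as it arrives. If you also need the complete reply afterwards, you can accumulate the deltas instead. This is a minimal sketch. collect_stream and fake_chunk are hypothetical names, and fake_chunk only mirrors chunk.choices[0].delta.content so the sketch runs without a server.

```python
from types import SimpleNamespace


def collect_stream(chunks):
    """Concatenate delta.content from streamed chat-completion chunks into one string."""
    parts = []
    for chunk in chunks:
        if not chunk.choices:
            continue  # defensive: skip chunks that carry no choices
        delta = chunk.choices[0].delta.content
        if delta:  # skip None/empty deltas, as the cell above does
            parts.append(delta)
    return "".join(parts)


def fake_chunk(text):
    # Offline stand-in exposing only chunk.choices[0].delta.content.
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content=text))])


print(collect_stream([fake_chunk(None), fake_chunk("Here are"), fake_chunk(" 3 countries")]))
```

With the live client, collect_stream(client.chat.completions.create(..., stream=True)) would return the whole reply once the stream ends. To print and accumulate at the same time, append inside the printing loop instead.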
Using Native Generation APIs
You can also call the native /generate endpoint with requests, which offers more flexibility. For the full list of options, see the Sampling Parameters API reference.
import requests
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print_highlight(response.json())
[2025-02-23 08:36:49 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 3, cache hit rate: 63.93%, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-02-23 08:36:50 TP0] Decode batch. #running-req: 1, #token: 17, token usage: 0.00, gen throughput (token/s): 66.25, #queue-req: 0
[2025-02-23 08:36:50] INFO: 127.0.0.1:51894 - "POST /generate HTTP/1.1" 200 OK
{'text': ' a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the', 'meta_info': {'id': 'd70805728eb54c43a01baf80e12b67fa', 'finish_reason': {'type': 'length', 'length': 32}, 'prompt_tokens': 6, 'completion_tokens': 32, 'cached_tokens': 3}}
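The native response has a different shape from the OpenAI-compatible one. text holds only the continuation (the prompt is not echoed), and the stop reason lives under meta_info. Here finish_reason has type "length" and completion_tokens is 32, so the reply was cut off at max_new_tokens. This is a minimal sketch of building the request body and checking for truncation. generate_payload and was_truncated are hypothetical helpers that use only the fields shown in these cells.

```python
def generate_payload(text, temperature=0.0, max_new_tokens=32, stream=False):
    """Build the JSON body for the native /generate route, as used in these cells."""
    body = {
        "text": text,
        "sampling_params": {"temperature": temperature, "max_new_tokens": max_new_tokens},
    }
    if stream:
        body["stream"] = True
    return body


def was_truncated(result):
    """True if the native response stopped because max_new_tokens was reached."""
    return result["meta_info"]["finish_reason"]["type"] == "length"
```

For example, requests.post(f"http://localhost:{port}/generate", json=generate_payload("The capital of France is")) reproduces the request above, and was_truncated(response.json()) would be True for the output shown.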
Streaming
import requests, json
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"]
        # Each event carries the full text so far; print only the unseen suffix.
        print(output[prev:], end="", flush=True)
        prev = len(output)
[2025-02-23 08:36:50] INFO: 127.0.0.1:51908 - "POST /generate HTTP/1.1" 200 OK
[2025-02-23 08:36:50 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 5, cache hit rate: 64.55%, token usage: 0.00, #running-req: 0, #queue-req: 0
a city of romance, art, fashion, and cuisine. Paris is a must-visit[2025-02-23 08:36:50 TP0] Decode batch. #running-req: 1, #token: 25, token usage: 0.00, gen throughput (token/s): 55.96, #queue-req: 0
destination for anyone who loves history, architecture, and culture. From the
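The loop above relies on each "data:" event carrying the full text generated so far. That is why it keeps a prev offset and prints only the new suffix. Wrapping the same logic in a generator separates the parsing from the printing and makes it testable without a server. This is a minimal sketch. iter_generate_deltas is a hypothetical name.

```python
import json


def iter_generate_deltas(lines):
    """Yield only the newly generated text from native /generate streaming events.

    `lines` is an iterable of bytes or str, e.g. response.iter_lines().
    """
    prev = 0
    for raw in lines:
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if not line.startswith("data:"):
            continue  # skip blank separator lines between events
        if line.strip() == "data: [DONE]":
            break
        text = json.loads(line[len("data:"):].strip())["text"]
        yield text[prev:]  # cumulative text, so emit just the unseen suffix
        prev = len(text)
```

With the streaming request above, you could write: for delta in iter_generate_deltas(response.iter_lines()): print(delta, end="", flush=True).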
terminate_process(server_process)