OpenAI APIs - Completions#
SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models. A complete reference for the API is available in the OpenAI API Reference.
This tutorial covers the following popular APIs:
chat/completions
completions
batches
Check out other tutorials to learn about vision APIs for vision-language models and embedding APIs for embedding models.
Launch A Server#
Launch the server in your terminal and wait for it to initialize.
[1]:
from sglang.test.test_utils import is_in_ci
if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
print(f"Server started on http://localhost:{port}")
[2025-03-16 09:55:55] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=37300, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=162573626, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=5, speculative_eagle_topk=4, speculative_num_draft_tokens=8, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
[2025-03-16 09:56:16 TP0] Init torch distributed begin.
[2025-03-16 09:56:16 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-03-16 09:56:16 TP0] Load weight begin. avail mem=78.81 GB
[2025-03-16 09:56:16 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-16 09:56:17 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.12it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.75it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.26it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00, 1.32it/s]
[2025-03-16 09:56:20 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.70 GB, mem usage=15.11 GB.
[2025-03-16 09:56:20 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-03-16 09:56:20 TP0] Memory pool end. avail mem=60.91 GB
[2025-03-16 09:56:21 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-03-16 09:56:21] INFO: Started server process [204027]
[2025-03-16 09:56:21] INFO: Waiting for application startup.
[2025-03-16 09:56:21] INFO: Application startup complete.
[2025-03-16 09:56:21] INFO: Uvicorn running on http://0.0.0.0:37300 (Press CTRL+C to quit)
[2025-03-16 09:56:22] INFO: 127.0.0.1:40000 - "GET /v1/models HTTP/1.1" 200 OK
[2025-03-16 09:56:22] INFO: 127.0.0.1:40002 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-16 09:56:22 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:25] INFO: 127.0.0.1:40006 - "POST /generate HTTP/1.1" 200 OK
[2025-03-16 09:56:25] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a CI parallel environment, so the throughput is not representative of actual performance.
Server started on http://localhost:37300
Chat Completions#
Usage#
The server fully implements the OpenAI API. It will automatically apply the chat template specified in the Hugging Face tokenizer, if one is available. You can also specify a custom chat template with --chat-template when launching the server.
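For example, the launch command from the first cell could be extended with this flag. The snippet below is a sketch that is not executed in this notebook; the value "llama-3-instruct" is illustrative, since --chat-template accepts either a built-in template name or a path to a custom template file.
# Sketch (not executed here): launch with a custom chat template.
# "llama-3-instruct" is an illustrative value; --chat-template accepts a
# built-in template name or a path to a custom template file.
server_process, port = launch_server_cmd(
    "python -m sglang.launch_server "
    "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
    "--chat-template llama-3-instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")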
[2]:
import openai
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
)
print_highlight(f"Response: {response}")
[2025-03-16 09:56:27 TP0] Prefill batch. #new-seq: 1, #new-token: 42, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:27 TP0] Decode batch. #running-req: 1, #token: 76, token usage: 0.00, gen throughput (token/s): 6.21, #queue-req: 0,
[2025-03-16 09:56:28] INFO: 127.0.0.1:40020 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Parameters#
The chat completions API accepts the same parameters as the OpenAI Chat Completions API. Refer to the OpenAI Chat Completions API for more details.
Here is an example of a detailed chat completion request:
[3]:
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[
{
"role": "system",
"content": "You are a knowledgeable historian who provides concise responses.",
},
{"role": "user", "content": "Tell me about ancient Rome"},
{
"role": "assistant",
"content": "Ancient Rome was a civilization centered in Italy.",
},
{"role": "user", "content": "What were their major achievements?"},
],
temperature=0.3, # Lower temperature for more focused responses
max_tokens=128, # Reasonable length for a concise response
top_p=0.95, # Slightly higher for better fluency
presence_penalty=0.2, # Mild penalty to avoid repetition
frequency_penalty=0.2, # Mild penalty for more natural language
n=1, # Single response is usually more stable
seed=42, # Keep for reproducibility
)
print_highlight(response.choices[0].message.content)
[2025-03-16 09:56:28 TP0] Prefill batch. #new-seq: 1, #new-token: 51, #cached-token: 25, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:28 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.01, gen throughput (token/s): 87.29, #queue-req: 0,
[2025-03-16 09:56:28 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.01, gen throughput (token/s): 89.70, #queue-req: 0,
[2025-03-16 09:56:29 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.01, gen throughput (token/s): 101.12, #queue-req: 0,
[2025-03-16 09:56:29] INFO: 127.0.0.1:40020 - "POST /v1/chat/completions HTTP/1.1" 200 OK
1. **Engineering and Architecture**: They built impressive structures like the Colosseum, Pantheon, and aqueducts, showcasing their engineering skills.
2. **Law and Governance**: The Romans developed the Twelve Tables, a foundation for modern law, and established the concept of citizenship, which spread throughout their empire.
3. **Military Conquests**: Rome expanded its territories through a series of military campaigns, creating the largest empire in the ancient world.
4. **Language and Literature**: Latin became the language of government, commerce, and literature, influencing modern languages like French, Spanish, and Italian
Streaming mode is also supported.
[4]:
stream = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "Say this is a test"}],
stream=True,
)
for chunk in stream:
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="")
[2025-03-16 09:56:29] INFO: 127.0.0.1:40020 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-03-16 09:56:29 TP0] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 30, token usage: 0.00, #running-req: 0, #queue-req: 0,
Let's get started. I'll respond to your test questions or prompts. What would you like to test?
[2025-03-16 09:56:29 TP0] Decode batch. #running-req: 1, #token: 62, token usage: 0.00, gen throughput (token/s): 94.23, #queue-req: 0,
Completions#
Usage#
The Completions API is similar to the Chat Completions API, but it takes a plain text prompt instead of the messages parameter and does not apply a chat template.
[5]:
response = client.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
prompt="List 3 countries and their capitals.",
temperature=0,
max_tokens=64,
n=1,
stop=None,
)
print_highlight(f"Response: {response}")
[2025-03-16 09:56:29 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:30 TP0] Decode batch. #running-req: 1, #token: 48, token usage: 0.00, gen throughput (token/s): 99.00, #queue-req: 0,
[2025-03-16 09:56:30] INFO: 127.0.0.1:40020 - "POST /v1/completions HTTP/1.1" 200 OK
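The response above follows the OpenAI Completions schema, so if you only need the generated text you can read it from the first choice:
# The raw continuation is in `choices[0].text`; token counts are in `usage`.
print_highlight(response.choices[0].text)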
Parameters#
The completions API accepts the same parameters as the OpenAI Completions API. Refer to the OpenAI Completions API for more details.
Here is an example of a detailed completions request:
[6]:
response = client.completions.create(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
prompt="Write a short story about a space explorer.",
temperature=0.7, # Moderate temperature for creative writing
max_tokens=150, # Longer response for a story
top_p=0.9, # Balanced diversity in word choice
stop=["\n\n", "THE END"], # Multiple stop sequences
presence_penalty=0.3, # Encourage novel elements
frequency_penalty=0.3, # Reduce repetitive phrases
n=1, # Generate one completion
seed=123, # For reproducible results
)
print_highlight(f"Response: {response}")
[2025-03-16 09:56:30 TP0] Prefill batch. #new-seq: 1, #new-token: 9, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:30 TP0] Decode batch. #running-req: 1, #token: 25, token usage: 0.00, gen throughput (token/s): 85.24, #queue-req: 0,
[2025-03-16 09:56:30 TP0] Decode batch. #running-req: 1, #token: 65, token usage: 0.00, gen throughput (token/s): 105.09, #queue-req: 0,
[2025-03-16 09:56:31 TP0] Decode batch. #running-req: 1, #token: 105, token usage: 0.01, gen throughput (token/s): 100.08, #queue-req: 0,
[2025-03-16 09:56:31 TP0] Decode batch. #running-req: 1, #token: 145, token usage: 0.01, gen throughput (token/s): 93.64, #queue-req: 0,
[2025-03-16 09:56:31] INFO: 127.0.0.1:40020 - "POST /v1/completions HTTP/1.1" 200 OK
Structured Outputs (JSON, Regex, EBNF)#
For the OpenAI-compatible structured outputs API, refer to Structured Outputs for more details.
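As a quick illustration, here is a minimal sketch of a JSON-schema-constrained chat completion, assuming the OpenAI-style response_format parameter described in that tutorial (the schema and the name "capital_info" below are made up for this example):
# Sketch: constrain the model to emit JSON matching a small schema.
capital_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
}

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "Give me the name and population of the capital of France in JSON.",
        },
    ],
    temperature=0,
    max_tokens=64,
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "capital_info", "schema": capital_schema},
    },
)
print_highlight(response.choices[0].message.content)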
Batches#
The Batches API is also supported for both chat completions and completions. You can upload your requests in a jsonl file, create a batch job, and retrieve the results once the batch job completes (this takes longer but costs less).
The batches APIs are:
batches
batches/{batch_id}/cancel
batches/{batch_id}
Here is an example of a batch job for chat completions; the flow for completions is similar.
[7]:
import json
import time
from openai import OpenAI
client = OpenAI(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
requests = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/chat/completions",
"body": {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": [
{"role": "user", "content": "Tell me a joke about programming"}
],
"max_tokens": 50,
},
},
{
"custom_id": "request-2",
"method": "POST",
"url": "/chat/completions",
"body": {
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"messages": [{"role": "user", "content": "What is Python?"}],
"max_tokens": 50,
},
},
]
input_file_path = "batch_requests.jsonl"
with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open(input_file_path, "rb") as f:
    file_response = client.files.create(file=f, purpose="batch")
batch_response = client.batches.create(
input_file_id=file_response.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
print_highlight(f"Batch job created with ID: {batch_response.id}")
[2025-03-16 09:56:31] INFO: 127.0.0.1:53462 - "POST /v1/files HTTP/1.1" 200 OK
[2025-03-16 09:56:31] INFO: 127.0.0.1:53462 - "POST /v1/batches HTTP/1.1" 200 OK
[2025-03-16 09:56:31 TP0] Prefill batch. #new-seq: 2, #new-token: 20, #cached-token: 60, token usage: 0.00, #running-req: 0, #queue-req: 0,
[8]:
while batch_response.status not in ["completed", "failed", "cancelled"]:
    time.sleep(3)
    print(f"Batch job status: {batch_response.status}...trying again in 3 seconds...")
    batch_response = client.batches.retrieve(batch_response.id)

if batch_response.status == "completed":
    print("Batch job completed successfully!")
    print(f"Request counts: {batch_response.request_counts}")

    result_file_id = batch_response.output_file_id
    file_response = client.files.content(result_file_id)
    result_content = file_response.read().decode("utf-8")

    results = [
        json.loads(line) for line in result_content.split("\n") if line.strip() != ""
    ]

    for result in results:
        print_highlight(f"Request {result['custom_id']}:")
        print_highlight(f"Response: {result['response']}")

    print_highlight("Cleaning up files...")
    # Only delete the result file ID since file_response is just content
    client.files.delete(result_file_id)
else:
    print_highlight(f"Batch job failed with status: {batch_response.status}")
    if hasattr(batch_response, "errors"):
        print_highlight(f"Errors: {batch_response.errors}")
[2025-03-16 09:56:32 TP0] Decode batch. #running-req: 1, #token: 64, token usage: 0.00, gen throughput (token/s): 81.19, #queue-req: 0,
Batch job status: validating...trying again in 3 seconds...
[2025-03-16 09:56:34] INFO: 127.0.0.1:53462 - "GET /v1/batches/batch_8faa3cf7-f27d-4026-babb-d362d1702905 HTTP/1.1" 200 OK
Batch job completed successfully!
Request counts: BatchRequestCounts(completed=2, failed=0, total=2)
[2025-03-16 09:56:34] INFO: 127.0.0.1:53462 - "GET /v1/files/backend_result_file-c96b2621-3885-4a0b-be9e-78b13d57dfd4/content HTTP/1.1" 200 OK
[2025-03-16 09:56:34] INFO: 127.0.0.1:53462 - "DELETE /v1/files/backend_result_file-c96b2621-3885-4a0b-be9e-78b13d57dfd4 HTTP/1.1" 200 OK
It takes a while to complete the batch job. You can use these two APIs to retrieve the batch job status or cancel the batch job.
batches/{batch_id}: Retrieve the batch job status.
batches/{batch_id}/cancel: Cancel the batch job.
Here is an example of how to check the batch job status.
[9]:
import json
import time
from openai import OpenAI
client = OpenAI(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
requests = []
for i in range(20):
    requests.append(
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "messages": [
                    {
                        "role": "system",
                        "content": f"{i}: You are a helpful AI assistant",
                    },
                    {
                        "role": "user",
                        "content": "Write a detailed story about topic. Make it very long.",
                    },
                ],
                "max_tokens": 64,
            },
        }
    )
input_file_path = "batch_requests.jsonl"
with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open(input_file_path, "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")
batch_job = client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
print_highlight(f"Created batch job with ID: {batch_job.id}")
print_highlight(f"Initial status: {batch_job.status}")
time.sleep(10)
max_checks = 5
for i in range(max_checks):
    batch_details = client.batches.retrieve(batch_id=batch_job.id)

    print_highlight(
        f"Batch job details (check {i+1} / {max_checks}) // ID: {batch_details.id} // Status: {batch_details.status} // Created at: {batch_details.created_at} // Input file ID: {batch_details.input_file_id} // Output file ID: {batch_details.output_file_id}"
    )
    print_highlight(
        f"<strong>Request counts: Total: {batch_details.request_counts.total} // Completed: {batch_details.request_counts.completed} // Failed: {batch_details.request_counts.failed}</strong>"
    )

    time.sleep(3)
[2025-03-16 09:56:35] INFO: 127.0.0.1:53478 - "POST /v1/files HTTP/1.1" 200 OK
[2025-03-16 09:56:35] INFO: 127.0.0.1:53478 - "POST /v1/batches HTTP/1.1" 200 OK
[2025-03-16 09:56:35 TP0] Prefill batch. #new-seq: 16, #new-token: 480, #cached-token: 400, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:56:35 TP0] Prefill batch. #new-seq: 4, #new-token: 120, #cached-token: 100, token usage: 0.02, #running-req: 16, #queue-req: 0,
[2025-03-16 09:56:35 TP0] Decode batch. #running-req: 20, #token: 925, token usage: 0.05, gen throughput (token/s): 106.77, #queue-req: 0,
[2025-03-16 09:56:35 TP0] Decode batch. #running-req: 20, #token: 1725, token usage: 0.08, gen throughput (token/s): 1921.86, #queue-req: 0,
[2025-03-16 09:56:45] INFO: 127.0.0.1:45234 - "GET /v1/batches/batch_77e71f43-60cd-4f47-a2a2-77906a188ca8 HTTP/1.1" 200 OK
[2025-03-16 09:56:48] INFO: 127.0.0.1:45234 - "GET /v1/batches/batch_77e71f43-60cd-4f47-a2a2-77906a188ca8 HTTP/1.1" 200 OK
[2025-03-16 09:56:51] INFO: 127.0.0.1:45234 - "GET /v1/batches/batch_77e71f43-60cd-4f47-a2a2-77906a188ca8 HTTP/1.1" 200 OK
[2025-03-16 09:56:54] INFO: 127.0.0.1:45234 - "GET /v1/batches/batch_77e71f43-60cd-4f47-a2a2-77906a188ca8 HTTP/1.1" 200 OK
[2025-03-16 09:56:57] INFO: 127.0.0.1:45234 - "GET /v1/batches/batch_77e71f43-60cd-4f47-a2a2-77906a188ca8 HTTP/1.1" 200 OK
Here is an example of how to cancel a batch job.
[10]:
import json
import time
from openai import OpenAI
import os
client = OpenAI(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
requests = []
for i in range(5000):
    requests.append(
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
                "messages": [
                    {
                        "role": "system",
                        "content": f"{i}: You are a helpful AI assistant",
                    },
                    {
                        "role": "user",
                        "content": "Write a detailed story about topic. Make it very long.",
                    },
                ],
                "max_tokens": 128,
            },
        }
    )
input_file_path = "batch_requests.jsonl"
with open(input_file_path, "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open(input_file_path, "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")
batch_job = client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
print_highlight(f"Created batch job with ID: {batch_job.id}")
print_highlight(f"Initial status: {batch_job.status}")
time.sleep(10)
try:
    cancelled_job = client.batches.cancel(batch_id=batch_job.id)
    print_highlight(f"Cancellation initiated. Status: {cancelled_job.status}")
    assert cancelled_job.status == "cancelling"

    # Monitor the cancellation process
    while cancelled_job.status not in ["failed", "cancelled"]:
        time.sleep(3)
        cancelled_job = client.batches.retrieve(batch_job.id)
        print_highlight(f"Current status: {cancelled_job.status}")

    # Verify final status
    assert cancelled_job.status == "cancelled"
    print_highlight("Batch job successfully cancelled")

except Exception as e:
    print_highlight(f"Error during cancellation: {e}")
    raise e

finally:
    try:
        del_response = client.files.delete(uploaded_file.id)
        if del_response.deleted:
            print_highlight("Successfully cleaned up input file")
        if os.path.exists(input_file_path):
            os.remove(input_file_path)
            print_highlight("Successfully deleted local batch_requests.jsonl file")
    except Exception as e:
        print_highlight(f"Error cleaning up: {e}")
        raise e
[2025-03-16 09:57:00] INFO: 127.0.0.1:41082 - "POST /v1/files HTTP/1.1" 200 OK
[2025-03-16 09:57:00] INFO: 127.0.0.1:41082 - "POST /v1/batches HTTP/1.1" 200 OK
[2025-03-16 09:57:01 TP0] Prefill batch. #new-seq: 21, #new-token: 50, #cached-token: 1105, token usage: 0.03, #running-req: 0, #queue-req: 0,
[2025-03-16 09:57:01 TP0] Prefill batch. #new-seq: 113, #new-token: 3390, #cached-token: 2825, token usage: 0.03, #running-req: 21, #queue-req: 257,
[2025-03-16 09:57:02 TP0] Decode batch. #running-req: 134, #token: 8199, token usage: 0.40, gen throughput (token/s): 158.82, #queue-req: 4866,
[2025-03-16 09:57:02 TP0] Decode batch. #running-req: 134, #token: 13559, token usage: 0.66, gen throughput (token/s): 9395.72, #queue-req: 4866,
[2025-03-16 09:57:03 TP0] Decode batch. #running-req: 134, #token: 18919, token usage: 0.92, gen throughput (token/s): 10890.33, #queue-req: 4866,
[2025-03-16 09:57:03 TP0] Decode out of memory happened. #retracted_reqs: 16, #new_token_ratio: 0.5776 -> 1.0000
[2025-03-16 09:57:03 TP0] Prefill batch. #new-seq: 11, #new-token: 330, #cached-token: 275, token usage: 0.88, #running-req: 118, #queue-req: 4871,
[2025-03-16 09:57:03 TP0] Prefill batch. #new-seq: 118, #new-token: 3540, #cached-token: 2950, token usage: 0.02, #running-req: 11, #queue-req: 4753,
[2025-03-16 09:57:03 TP0] Decode batch. #running-req: 129, #token: 6799, token usage: 0.33, gen throughput (token/s): 8067.53, #queue-req: 4753,
[2025-03-16 09:57:04 TP0] Decode batch. #running-req: 129, #token: 11959, token usage: 0.58, gen throughput (token/s): 11335.13, #queue-req: 4753,
[2025-03-16 09:57:04 TP0] Decode batch. #running-req: 129, #token: 17119, token usage: 0.84, gen throughput (token/s): 9750.46, #queue-req: 4753,
[2025-03-16 09:57:05 TP0] Prefill batch. #new-seq: 11, #new-token: 330, #cached-token: 275, token usage: 0.88, #running-req: 118, #queue-req: 4742,
[2025-03-16 09:57:05 TP0] Prefill batch. #new-seq: 119, #new-token: 3570, #cached-token: 2975, token usage: 0.02, #running-req: 11, #queue-req: 4623,
[2025-03-16 09:57:05 TP0] Decode batch. #running-req: 130, #token: 5681, token usage: 0.28, gen throughput (token/s): 5025.78, #queue-req: 4623,
[2025-03-16 09:57:06 TP0] Decode batch. #running-req: 130, #token: 10881, token usage: 0.53, gen throughput (token/s): 7946.56, #queue-req: 4623,
[2025-03-16 09:57:07 TP0] Decode batch. #running-req: 130, #token: 16081, token usage: 0.79, gen throughput (token/s): 11673.99, #queue-req: 4623,
[2025-03-16 09:57:07 TP0] Prefill batch. #new-seq: 11, #new-token: 330, #cached-token: 275, token usage: 0.89, #running-req: 119, #queue-req: 4612,
[2025-03-16 09:57:07 TP0] Prefill batch. #new-seq: 120, #new-token: 3600, #cached-token: 3000, token usage: 0.02, #running-req: 11, #queue-req: 4492,
[2025-03-16 09:57:07 TP0] Decode batch. #running-req: 131, #token: 4545, token usage: 0.22, gen throughput (token/s): 7937.40, #queue-req: 4492,
[2025-03-16 09:57:08 TP0] Decode batch. #running-req: 131, #token: 9785, token usage: 0.48, gen throughput (token/s): 11730.21, #queue-req: 4492,
[2025-03-16 09:57:08 TP0] Decode batch. #running-req: 131, #token: 15025, token usage: 0.73, gen throughput (token/s): 11398.93, #queue-req: 4492,
[2025-03-16 09:57:09 TP0] Prefill batch. #new-seq: 10, #new-token: 300, #cached-token: 250, token usage: 0.90, #running-req: 120, #queue-req: 4482,
[2025-03-16 09:57:09 TP0] Decode batch. #running-req: 120, #token: 18685, token usage: 0.91, gen throughput (token/s): 10682.03, #queue-req: 4482,
[2025-03-16 09:57:09 TP0] Prefill batch. #new-seq: 122, #new-token: 3660, #cached-token: 3050, token usage: 0.02, #running-req: 10, #queue-req: 4360,
[2025-03-16 09:57:09 TP0] Decode batch. #running-req: 132, #token: 8665, token usage: 0.42, gen throughput (token/s): 7639.36, #queue-req: 4360,
[2025-03-16 09:57:10 TP0] Decode batch. #running-req: 132, #token: 13945, token usage: 0.68, gen throughput (token/s): 11838.75, #queue-req: 4360,
[2025-03-16 09:57:10] INFO: 127.0.0.1:38860 - "POST /v1/batches/batch_5c14b250-a5ac-4514-ba7e-70b15c3f9db7/cancel HTTP/1.1" 200 OK
[2025-03-16 09:57:10 TP0] Prefill batch. #new-seq: 16, #new-token: 2432, #cached-token: 400, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-16 09:57:13] INFO: 127.0.0.1:38860 - "GET /v1/batches/batch_5c14b250-a5ac-4514-ba7e-70b15c3f9db7 HTTP/1.1" 200 OK
[2025-03-16 09:57:13] INFO: 127.0.0.1:38860 - "DELETE /v1/files/backend_input_file-0f1be146-cddf-449b-a8b8-2a288f33a47e HTTP/1.1" 200 OK
[11]:
terminate_process(server_process)