Structured Outputs For Reasoning Models#
When working with reasoning models that use special tokens like <think>...</think>
to denote reasoning sections, you might want to allow free-form text within these sections while still enforcing grammar constraints on the rest of the output.
SGLang provides a feature to disable grammar restrictions within reasoning sections. This is particularly useful for models that need to perform complex reasoning steps before providing a structured output.
To enable this feature, pass the --reasoning-parser
flag when launching the server. The reasoning parser determines the think-end token (for example, </think>
) that marks the end of the reasoning section.
Supported Models#
Currently, SGLang supports the following reasoning models:
- DeepSeek R1 series: the reasoning content is wrapped with <think> and </think> tags.
- QwQ: the reasoning content is wrapped with <think> and </think> tags.
Usage#
OpenAI Compatible API#
Specify the --grammar-backend
and --reasoning-parser
options when launching the server.
[1]:
import openai
import os
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
os.environ["TOKENIZERS_PARALLELISM"] = "false"
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1"
)
wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
W0814 06:18:54.501000 1214090 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:18:54.501000 1214090 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:18:57] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=39862, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.874, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=676351897, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='deepseek-r1', tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, 
disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:18:58] Using default HuggingFace chat template with detected content format: string
W0814 06:19:06.552000 1214996 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:19:06.552000 1214996 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:19:06.569000 1214995 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:19:06.569000 1214995 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:19:08] Attention backend not explicitly specified. Use fa3 backend by default.
[2025-08-14 06:19:08] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:19:08] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 06:19:09] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:19:09] Load weight begin. avail mem=52.54 GB
[2025-08-14 06:19:09] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.43s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.36s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.37s/it]
[2025-08-14 06:19:12] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=38.10 GB, mem usage=14.44 GB.
[2025-08-14 06:19:12] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-08-14 06:19:12] Memory pool end. avail mem=36.77 GB
[2025-08-14 06:19:12] Capture cuda graph begin. This can take up to several minutes. avail mem=36.67 GB
[2025-08-14 06:19:13] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=36.59 GB): 100%|██████████| 3/3 [00:00<00:00, 11.13it/s]
[2025-08-14 06:19:13] Capture cuda graph end. Time elapsed: 0.77 s. mem usage=0.09 GB. avail mem=36.58 GB.
[2025-08-14 06:19:14] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=131072, available_gpu_mem=36.58 GB
[2025-08-14 06:19:14] INFO: Started server process [1214090]
[2025-08-14 06:19:14] INFO: Waiting for application startup.
[2025-08-14 06:19:14] INFO: Application startup complete.
[2025-08-14 06:19:14] INFO: Uvicorn running on http://0.0.0.0:39862 (Press CTRL+C to quit)
[2025-08-14 06:19:14] INFO: 127.0.0.1:38994 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:19:15] INFO: 127.0.0.1:39008 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:19:15] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:16] INFO: 127.0.0.1:39012 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:19:16] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
JSON#
You can either directly define a JSON schema or use Pydantic to define and validate the response.
Using Pydantic
[2]:
from pydantic import BaseModel, Field
# Define the schema using Pydantic
class CapitalInfo(BaseModel):
name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
population: int = Field(..., description="Population of the capital city")
response = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
messages=[
{
"role": "assistant",
"content": "Give me the information and population of the capital of France in the JSON format.",
},
],
temperature=0,
max_tokens=2048,
response_format={
"type": "json_schema",
"json_schema": {
"name": "foo",
# convert the pydantic model to json schema
"schema": CapitalInfo.model_json_schema(),
},
},
)
print_highlight(
f"reasoing_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-08-14 06:19:19] Prefill batch. #new-seq: 1, #new-token: 21, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:20] Decode batch. #running-req: 1, #token: 55, token usage: 0.00, cuda graph: True, gen throughput (token/s): 5.92, #queue-req: 0,
[2025-08-14 06:19:21] Decode batch. #running-req: 1, #token: 95, token usage: 0.00, cuda graph: True, gen throughput (token/s): 169.26, #queue-req: 0,
[2025-08-14 06:19:21] Decode batch. #running-req: 1, #token: 135, token usage: 0.01, cuda graph: True, gen throughput (token/s): 149.67, #queue-req: 0,
[2025-08-14 06:19:21] Decode batch. #running-req: 1, #token: 175, token usage: 0.01, cuda graph: True, gen throughput (token/s): 144.31, #queue-req: 0,
[2025-08-14 06:19:21] Decode batch. #running-req: 1, #token: 215, token usage: 0.01, cuda graph: True, gen throughput (token/s): 167.25, #queue-req: 0,
[2025-08-14 06:19:22] Decode batch. #running-req: 1, #token: 255, token usage: 0.01, cuda graph: True, gen throughput (token/s): 166.69, #queue-req: 0,
[2025-08-14 06:19:22] Decode batch. #running-req: 1, #token: 295, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.72, #queue-req: 0,
[2025-08-14 06:19:22] Decode batch. #running-req: 1, #token: 335, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.70, #queue-req: 0,
[2025-08-14 06:19:22] Decode batch. #running-req: 1, #token: 375, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:23] Decode batch. #running-req: 1, #token: 415, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.71, #queue-req: 0,
[2025-08-14 06:19:23] Decode batch. #running-req: 1, #token: 455, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.67, #queue-req: 0,
[2025-08-14 06:19:23] INFO: 127.0.0.1:39024 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Next, I considered the population. I remember that Paris has a large population, but I'm not exactly sure of the current number. I think it's around 2 million, but I'm not 100% certain. I should double-check that to make sure I provide accurate information.
I also need to structure this in JSON. JSON requires key-value pairs, so I'll need to define the keys appropriately. Maybe "city" for the name, "country" for the capital, and "population" for the number. I should make sure the syntax is correct, with proper commas and quotation marks.
Wait, the user might be using this data for a project or a presentation. They probably need it in a structured format that's easy to parse. JSON is a good choice because it's widely used in web applications and data interchange. I should ensure the JSON is valid and well-formatted to avoid any issues when they use it.
I also wonder if they need more details, like the administrative region or some notable landmarks. But since they specifically asked for population, I'll stick to that. Maybe I should include a note about the population figure being approximate, just in case.
Putting it all together, I'll format the JSON with the city name, country, and population. I'll make sure the keys are in English to keep it clear. Double-checking the population number is crucial to maintain credibility. I think 2,147,000 is a commonly cited figure, so I'll go with that.
Finally, I'll present the JSON in a code block so it's easy to read and copy. I should also offer further assistance in case they need more data. That way, I cover all bases and ensure the user gets exactly what they're looking for.
content: {
"name": "Paris",
"population": 2147000
}
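Because the content section is constrained to the CapitalInfo schema, it can be parsed back into the Pydantic model. A minimal follow-up sketch (not part of the original notebook; it assumes the response object from the cell above):
# Parse the schema-constrained content back into the Pydantic model
capital = CapitalInfo.model_validate_json(response.choices[0].message.content)
print_highlight(f"{capital.name}: {capital.population}")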
JSON Schema Directly
[3]:
import json
json_schema = json.dumps(
{
"type": "object",
"properties": {
"name": {"type": "string", "pattern": "^[\\w]+$"},
"population": {"type": "integer"},
},
"required": ["name", "population"],
}
)
response = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
messages=[
{
"role": "assistant",
"content": "Give me the information and population of the capital of France in the JSON format.",
},
],
temperature=0,
max_tokens=2048,
response_format={
"type": "json_schema",
"json_schema": {"name": "foo", "schema": json.loads(json_schema)},
},
)
print_highlight(
f"reasoing_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-08-14 06:19:23] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 21, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:23] Decode batch. #running-req: 1, #token: 56, token usage: 0.00, cuda graph: True, gen throughput (token/s): 143.40, #queue-req: 0,
[2025-08-14 06:19:23] Decode batch. #running-req: 1, #token: 96, token usage: 0.00, cuda graph: True, gen throughput (token/s): 169.24, #queue-req: 0,
[2025-08-14 06:19:24] Decode batch. #running-req: 1, #token: 136, token usage: 0.01, cuda graph: True, gen throughput (token/s): 159.45, #queue-req: 0,
[2025-08-14 06:19:24] Decode batch. #running-req: 1, #token: 176, token usage: 0.01, cuda graph: True, gen throughput (token/s): 131.06, #queue-req: 0,
[2025-08-14 06:19:24] Decode batch. #running-req: 1, #token: 216, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.82, #queue-req: 0,
[2025-08-14 06:19:24] Decode batch. #running-req: 1, #token: 256, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.77, #queue-req: 0,
[2025-08-14 06:19:25] Decode batch. #running-req: 1, #token: 296, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:25] Decode batch. #running-req: 1, #token: 336, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:25] Decode batch. #running-req: 1, #token: 376, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:25] INFO: 127.0.0.1:39024 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Next, I thought about the structure. The user wants JSON, so I need to format it correctly with keys like "city", "population", and maybe "country". I should make sure the syntax is correct—no typos, proper commas, and brackets.
I also considered the user's possible needs. They might be doing a project or a presentation, so providing accurate data is crucial. Maybe they're a student learning about France's capitals or someone compiling demographic data. Either way, precision is key.
I decided to look up the latest population figure. A quick search shows that as of 2023, Paris has a population of around 3,600,000. That seems right, but I should note that populations can fluctuate due to births, deaths, and migration.
Putting it all together, I structured the JSON with the city name, population, and country. I made sure the numbers were accurate and the format was correct. I also added a comment to explain the population figure, just in case the user needed more context.
Finally, I sent the response, ensuring it was clear and met the user's request. I hope this helps them with whatever they're working on!
content: {
"name": "Paris",
"population": 3600000
}
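A quick validation sketch (illustrative only, assuming the response from the cell above): the constrained content should parse as JSON and contain the keys required by the schema.
import json

# Parse the constrained output and check the required keys from the schema
data = json.loads(response.choices[0].message.content)
assert all(key in data for key in ["name", "population"])
print(data)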
EBNF#
[4]:
ebnf_grammar = """
root ::= city | description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""
response = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
messages=[
{"role": "system", "content": "You are a helpful geography bot."},
{
"role": "assistant",
"content": "Give me the information and population of the capital of France in the JSON format.",
},
],
temperature=0,
max_tokens=2048,
extra_body={"ebnf": ebnf_grammar},
)
print_highlight(
f"reasoing_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-08-14 06:19:25] Prefill batch. #new-seq: 1, #new-token: 28, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:25] Decode batch. #running-req: 1, #token: 61, token usage: 0.00, cuda graph: True, gen throughput (token/s): 146.72, #queue-req: 0,
[2025-08-14 06:19:26] Decode batch. #running-req: 1, #token: 101, token usage: 0.00, cuda graph: True, gen throughput (token/s): 165.57, #queue-req: 0,
[2025-08-14 06:19:26] Decode batch. #running-req: 1, #token: 141, token usage: 0.01, cuda graph: True, gen throughput (token/s): 165.11, #queue-req: 0,
[2025-08-14 06:19:26] Decode batch. #running-req: 1, #token: 181, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.83, #queue-req: 0,
[2025-08-14 06:19:26] Decode batch. #running-req: 1, #token: 221, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.90, #queue-req: 0,
[2025-08-14 06:19:27] Decode batch. #running-req: 1, #token: 261, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.79, #queue-req: 0,
[2025-08-14 06:19:27] Decode batch. #running-req: 1, #token: 301, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.71, #queue-req: 0,
[2025-08-14 06:19:27] Decode batch. #running-req: 1, #token: 341, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:27] Decode batch. #running-req: 1, #token: 381, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.70, #queue-req: 0,
[2025-08-14 06:19:28] Decode batch. #running-req: 1, #token: 421, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:28] Decode batch. #running-req: 1, #token: 461, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.69, #queue-req: 0,
[2025-08-14 06:19:28] Decode batch. #running-req: 1, #token: 501, token usage: 0.02, cuda graph: True, gen throughput (token/s): 153.51, #queue-req: 0,
[2025-08-14 06:19:28] Decode batch. #running-req: 1, #token: 541, token usage: 0.03, cuda graph: True, gen throughput (token/s): 141.57, #queue-req: 0,
[2025-08-14 06:19:29] Decode batch. #running-req: 1, #token: 581, token usage: 0.03, cuda graph: True, gen throughput (token/s): 143.23, #queue-req: 0,
[2025-08-14 06:19:29] Decode batch. #running-req: 1, #token: 621, token usage: 0.03, cuda graph: True, gen throughput (token/s): 142.86, #queue-req: 0,
[2025-08-14 06:19:29] Decode batch. #running-req: 1, #token: 661, token usage: 0.03, cuda graph: True, gen throughput (token/s): 166.67, #queue-req: 0,
[2025-08-14 06:19:29] Decode batch. #running-req: 1, #token: 701, token usage: 0.03, cuda graph: True, gen throughput (token/s): 164.17, #queue-req: 0,
[2025-08-14 06:19:30] Decode batch. #running-req: 1, #token: 741, token usage: 0.04, cuda graph: True, gen throughput (token/s): 121.16, #queue-req: 0,
[2025-08-14 06:19:30] INFO: 127.0.0.1:39024 - "POST /v1/chat/completions HTTP/1.1" 200 OK
First, I need to figure out what the user is really looking for. They might be creating a dataset or a project that requires the capitals of various nations. Maybe they're a student working on a geography assignment or a developer building a mapping application. Either way, they need accurate and reliable data.
Looking at the list they provided: Albania, Algeria, Australia, Austria, Brazil, Canada, China, Colombia, Denmark, Egypt, France, Germany, Greece, Hungary, Iceland, India, Indonesia, Ireland, Italy, Japan, Mexico, Netherlands, Nigeria, Poland, Portugal, Russia, Saudi Arabia, Spain, Sweden, Switzerland, Thailand, Turkey, UK, USA. That's quite a comprehensive list, covering multiple continents.
I should make sure each country's capital is correct. I know some capitals off the top of my head, like Albania's Tirana, Algeria's Algiers, Australia's Canberra. But I should double-check the rest to avoid mistakes. For example, I'm pretty sure Germany's capital is Berlin, but I should confirm that. Same with countries like Mexico and Spain, where the capital is in a different state.
The user wants the information in JSON format, so each country will have a key with its name and capital, along with the population. I'll need to look up the most recent population estimates for each. Population numbers can change, so it's important to use the latest data. For example, Germany's population is around 83 million, while the UK's is about 67 million.
I should structure the JSON array correctly, ensuring each object has the same keys. Also, I'll need to format the numbers properly, using commas for readability. For instance, 2.1 million should be written as 2.1 million, not 2100000.
I wonder if the user needs the population data to be dynamic or if they prefer static numbers. Since they didn't specify, I'll go with the most recent estimates available. Also, I should present the data clearly, maybe adding comments or organizing it in a way that's easy to read, even though it's just a list.
Another thing to consider is the possibility of the user needing this data for a presentation or a report. If that's the case, accuracy is crucial. I should verify each piece of information to ensure reliability. For example, confirming that Japan's capital is Tokyo and not another city.
I should also think about the user's deeper needs. They might be building a quiz app where users can test their knowledge of country capitals, or perhaps they're creating a mapping tool that requires capital locations. Either way, providing accurate and up-to-date information is key.
I'll start by listing each country with its capital and population, making sure each entry is a separate object in the JSON array. I'll also add a comment at the top to explain the structure, so the user knows what each key represents. This will make the JSON more understandable for them.
Finally, I'll review the JSON to ensure there are no syntax errors, like missing commas or brackets. It's important that the JSON is valid so that the user can easily parse and use the data without issues.
content: London is the capital of France
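The grammar only constrains the token sequence; it does not guarantee factual correctness, as the output above shows. A small sanity-check sketch (assuming the response and grammar from the cell above) verifies that the content is a string the EBNF grammar can produce:
import re

# Regex equivalent of the EBNF grammar: a city, optionally followed by a description
pattern = re.compile(
    r"(London|Paris|Berlin|Rome)( is the capital of (England|France|Germany|Italy))?"
)
print(bool(pattern.fullmatch(response.choices[0].message.content)))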
Regular expression#
[5]:
response = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
messages=[
{"role": "assistant", "content": "What is the capital of France?"},
],
temperature=0,
max_tokens=2048,
extra_body={"regex": "(Paris|London)"},
)
print_highlight(
f"reasoing_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-08-14 06:19:30] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 2, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:30] Decode batch. #running-req: 1, #token: 31, token usage: 0.00, cuda graph: True, gen throughput (token/s): 83.34, #queue-req: 0,
[2025-08-14 06:19:30] Decode batch. #running-req: 1, #token: 71, token usage: 0.00, cuda graph: True, gen throughput (token/s): 140.63, #queue-req: 0,
[2025-08-14 06:19:31] Decode batch. #running-req: 1, #token: 111, token usage: 0.01, cuda graph: True, gen throughput (token/s): 169.25, #queue-req: 0,
[2025-08-14 06:19:31] INFO: 127.0.0.1:39024 - "POST /v1/chat/completions HTTP/1.1" 200 OK
content: Paris
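As a minimal check (assuming the response from the cell above), the content should fully match the regex passed via extra_body:
import re

# The constrained output must be one of the regex alternatives
print(re.fullmatch(r"(Paris|London)", response.choices[0].message.content) is not None)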
Structural Tag#
[6]:
tool_get_current_weather = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
tool_get_current_date = {
"type": "function",
"function": {
"name": "get_current_date",
"description": "Get the current date and time for a given timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
}
},
"required": ["timezone"],
},
},
}
schema_get_current_weather = tool_get_current_weather["function"]["parameters"]
schema_get_current_date = tool_get_current_date["function"]["parameters"]
def get_messages():
return [
{
"role": "system",
"content": f"""
# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.""",
},
{
"role": "assistant",
"content": "You are in New York. Please get the current date and time, and the weather.",
},
]
messages = get_messages()
response = client.chat.completions.create(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
messages=messages,
response_format={
"type": "structural_tag",
"max_new_tokens": 2048,
"structures": [
{
"begin": "<function=get_current_weather>",
"schema": schema_get_current_weather,
"end": "</function>",
},
{
"begin": "<function=get_current_date>",
"schema": schema_get_current_date,
"end": "</function>",
},
],
"triggers": ["<function="],
},
)
print_highlight(
f"reasoing_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-08-14 06:19:31] Prefill batch. #new-seq: 1, #new-token: 472, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:32] Decode batch. #running-req: 1, #token: 491, token usage: 0.02, cuda graph: True, gen throughput (token/s): 44.90, #queue-req: 0,
[2025-08-14 06:19:32] Decode batch. #running-req: 1, #token: 531, token usage: 0.03, cuda graph: True, gen throughput (token/s): 139.94, #queue-req: 0,
[2025-08-14 06:19:32] Decode batch. #running-req: 1, #token: 571, token usage: 0.03, cuda graph: True, gen throughput (token/s): 83.46, #queue-req: 0,
[2025-08-14 06:19:33] Decode batch. #running-req: 1, #token: 611, token usage: 0.03, cuda graph: True, gen throughput (token/s): 115.29, #queue-req: 0,
[2025-08-14 06:19:33] Decode batch. #running-req: 1, #token: 651, token usage: 0.03, cuda graph: True, gen throughput (token/s): 162.53, #queue-req: 0,
[2025-08-14 06:19:33] Decode batch. #running-req: 1, #token: 691, token usage: 0.03, cuda graph: True, gen throughput (token/s): 152.57, #queue-req: 0,
[2025-08-14 06:19:33] Decode batch. #running-req: 1, #token: 731, token usage: 0.04, cuda graph: True, gen throughput (token/s): 149.90, #queue-req: 0,
[2025-08-14 06:19:34] Decode batch. #running-req: 1, #token: 771, token usage: 0.04, cuda graph: True, gen throughput (token/s): 160.34, #queue-req: 0,
[2025-08-14 06:19:34] Decode batch. #running-req: 1, #token: 811, token usage: 0.04, cuda graph: True, gen throughput (token/s): 163.75, #queue-req: 0,
[2025-08-14 06:19:34] INFO: 127.0.0.1:39024 - "POST /v1/chat/completions HTTP/1.1" 200 OK
First, I should use the `get_current_date` function because it gives the date and time for a specific timezone. The user mentioned New York, so I'll set the timezone parameter to 'America/New_York'. That should provide the correct local date and time.
Next, for the weather, I'll use the `get_current_weather` function. I'll need the city, state, and the unit they want. The city is New York, the state is 'NY' since that's the two-letter abbreviation, and they didn't specify a unit, so I'll default to Fahrenheit.
I should structure the response clearly, showing both pieces of information. I'll list the date and time first, then the weather details. I'll make sure to include the sources where I got the information from—using the functions provided.
I also need to format the answer correctly, using the specified tags and JSON structure. Each function call will be in its own block with the appropriate parameters. That way, the user can easily see each part of the response without confusion.
Finally, I'll make sure the entire reply is on one line as per the instructions, keeping it concise and clear. I think that covers everything the user needs to know about the current weather and date in New York.
content:
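When the model emits a tool call, the content follows the <function=NAME>{...}</function> format defined by the structural tag. A small parsing sketch (an illustration, not part of the original notebook, assuming the response from the cell above):
import json
import re

# Extract each function call and its JSON arguments from the content
calls = re.findall(
    r"<function=(\w+)>(.*?)</function>",
    response.choices[0].message.content or "",
)
for name, arguments in calls:
    print(name, json.loads(arguments))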
Native API and SGLang Runtime (SRT)#
JSON#
Using Pydantic
[7]:
import requests
from pydantic import BaseModel, Field
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
# Define the schema using Pydantic
class CapitalInfo(BaseModel):
name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
population: int = Field(..., description="Population of the capital city")
messages = [
{
"role": "assistant",
"content": "Give me the information and population of the capital of France in the JSON format.",
},
]
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Make API request
response = requests.post(
f"http://localhost:{port}/generate",
json={
"text": text,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2048,
"json_schema": json.dumps(CapitalInfo.model_json_schema()),
},
},
)
print(response.json())
reasoning_content = response.json()["text"].split("</think>")[0]
content = response.json()["text"].split("</think>")[1]
print_highlight(f"reasoning_content: {reasoning_content}\n\ncontent: {content}")
[2025-08-14 06:19:35] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:35] Decode batch. #running-req: 1, #token: 56, token usage: 0.00, cuda graph: True, gen throughput (token/s): 49.76, #queue-req: 0,
[2025-08-14 06:19:35] Decode batch. #running-req: 1, #token: 96, token usage: 0.00, cuda graph: True, gen throughput (token/s): 165.63, #queue-req: 0,
[2025-08-14 06:19:35] Decode batch. #running-req: 1, #token: 136, token usage: 0.01, cuda graph: True, gen throughput (token/s): 165.32, #queue-req: 0,
[2025-08-14 06:19:35] Decode batch. #running-req: 1, #token: 176, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.88, #queue-req: 0,
[2025-08-14 06:19:36] Decode batch. #running-req: 1, #token: 216, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.91, #queue-req: 0,
[2025-08-14 06:19:36] Decode batch. #running-req: 1, #token: 256, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.83, #queue-req: 0,
[2025-08-14 06:19:36] Decode batch. #running-req: 1, #token: 296, token usage: 0.01, cuda graph: True, gen throughput (token/s): 165.05, #queue-req: 0,
[2025-08-14 06:19:36] Decode batch. #running-req: 1, #token: 336, token usage: 0.02, cuda graph: True, gen throughput (token/s): 164.96, #queue-req: 0,
[2025-08-14 06:19:37] Decode batch. #running-req: 1, #token: 376, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.70, #queue-req: 0,
[2025-08-14 06:19:37] Decode batch. #running-req: 1, #token: 416, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.68, #queue-req: 0,
[2025-08-14 06:19:37] INFO: 127.0.0.1:53050 - "POST /generate HTTP/1.1" 200 OK
{'text': 'Okay, so the user is asking for the information and population of the capital of France in JSON format. Let me break this down. First, I need to identify what the capital of France is. I know that Paris is the capital, so that\'s straightforward. \n\nNext, I need to find the population. I remember that Paris is a major city, so its population is quite large. I think it\'s over 3 million, but I\'m not exactly sure of the exact number. Maybe I should double-check that. \n\nWait, I recall that the population figure can vary depending on the source and the year. The user didn\'t specify a particular year, so I should probably go with the most recent estimate. I believe the population is around 3,500,000 as of 2023. \n\nNow, I need to structure this information into a JSON format. JSON typically uses key-value pairs, so I\'ll create an object with keys like "city", "population", and maybe "country" since the user mentioned France. \n\nI should make sure the keys are in English to keep it clear. The city is Paris, the population is 3,500,000, and the country is France. I\'ll format this into a JSON object. \n\nI also need to present this in a way that\'s easy to read, so I\'ll use proper syntax with quotation marks and commas in the right places. No trailing commas to avoid errors. \n\nPutting it all together, the JSON should look something like this: a dictionary with the keys and the corresponding values. I\'ll make sure to test it to ensure it\'s valid, but since I\'m just writing it out, I\'ll assume it\'s correct based on my knowledge. \n\nI think that\'s all. The user just needs the information in JSON, so this should satisfy their request.\n</think>{\n\n"name": "Paris",\n"population": 3500000}', 'output_ids': [32313, 11, 773, 279, 1196, 374, 10161, 369, 279, 1995, 323, 7042, 315, 279, 6722, 315, 9625, 304, 4718, 3561, 13, 6771, 752, 1438, 419, 1495, 13, 5512, 11, 358, 1184, 311, 10542, 1128, 279, 6722, 315, 9625, 374, 13, 358, 1414, 429, 12095, 374, 279, 6722, 11, 773, 429, 594, 30339, 13, 4710, 5847, 11, 358, 1184, 311, 1477, 279, 7042, 13, 358, 6099, 429, 12095, 374, 264, 3598, 3283, 11, 773, 1181, 7042, 374, 5008, 3460, 13, 358, 1744, 432, 594, 916, 220, 18, 3526, 11, 714, 358, 2776, 537, 6896, 2704, 315, 279, 4734, 1372, 13, 10696, 358, 1265, 1990, 15934, 429, 13, 4710, 14190, 11, 358, 19091, 429, 279, 7042, 7071, 646, 13289, 11649, 389, 279, 2530, 323, 279, 1042, 13, 576, 1196, 3207, 944, 13837, 264, 3953, 1042, 11, 773, 358, 1265, 4658, 728, 448, 279, 1429, 3213, 16045, 13, 358, 4411, 279, 7042, 374, 2163, 220, 18, 11, 20, 15, 15, 11, 15, 15, 15, 438, 315, 220, 17, 15, 17, 18, 13, 4710, 7039, 11, 358, 1184, 311, 5944, 419, 1995, 1119, 264, 4718, 3561, 13, 4718, 11136, 5711, 1376, 19083, 13530, 11, 773, 358, 3278, 1855, 458, 1633, 448, 6894, 1075, 330, 8926, 497, 330, 44441, 497, 323, 7196, 330, 11141, 1, 2474, 279, 1196, 9733, 9625, 13, 4710, 40, 1265, 1281, 2704, 279, 6894, 525, 304, 6364, 311, 2506, 432, 2797, 13, 576, 3283, 374, 12095, 11, 279, 7042, 374, 220, 18, 11, 20, 15, 15, 11, 15, 15, 15, 11, 323, 279, 3146, 374, 9625, 13, 358, 3278, 3561, 419, 1119, 264, 4718, 1633, 13, 4710, 40, 1083, 1184, 311, 3042, 419, 304, 264, 1616, 429, 594, 4135, 311, 1349, 11, 773, 358, 3278, 990, 6169, 19482, 448, 54231, 15423, 323, 76602, 304, 279, 1290, 7482, 13, 2308, 27748, 76602, 311, 5648, 5975, 13, 4710, 97904, 432, 678, 3786, 11, 279, 4718, 1265, 1401, 2494, 1075, 419, 25, 264, 10997, 448, 279, 6894, 323, 279, 12159, 2750, 13, 358, 3278, 1281, 2704, 311, 1273, 432, 311, 5978, 
432, 594, 2697, 11, 714, 2474, 358, 2776, 1101, 4378, 432, 700, 11, 358, 3278, 9658, 432, 594, 4396, 3118, 389, 847, 6540, 13, 4710, 40, 1744, 429, 594, 678, 13, 576, 1196, 1101, 3880, 279, 1995, 304, 4718, 11, 773, 419, 1265, 26553, 862, 1681, 624, 151649, 4257, 1, 606, 788, 330, 59604, 756, 1, 44441, 788, 220, 18, 20, 15, 15, 15, 15, 15, 92, 151643], 'meta_info': {'id': '51ff31fe7ce243ed814849b6932041b1', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 23, 'completion_tokens': 405, 'cached_tokens': 1, 'e2e_latency': 2.4754464626312256}}
Next, I need to find the population. I remember that Paris is a major city, so its population is quite large. I think it's over 3 million, but I'm not exactly sure of the exact number. Maybe I should double-check that.
Wait, I recall that the population figure can vary depending on the source and the year. The user didn't specify a particular year, so I should probably go with the most recent estimate. I believe the population is around 3,500,000 as of 2023.
Now, I need to structure this information into a JSON format. JSON typically uses key-value pairs, so I'll create an object with keys like "city", "population", and maybe "country" since the user mentioned France.
I should make sure the keys are in English to keep it clear. The city is Paris, the population is 3,500,000, and the country is France. I'll format this into a JSON object.
I also need to present this in a way that's easy to read, so I'll use proper syntax with quotation marks and commas in the right places. No trailing commas to avoid errors.
Putting it all together, the JSON should look something like this: a dictionary with the keys and the corresponding values. I'll make sure to test it to ensure it's valid, but since I'm just writing it out, I'll assume it's correct based on my knowledge.
I think that's all. The user just needs the information in JSON, so this should satisfy their request.
content: {
"name": "Paris",
"population": 3500000}
JSON Schema Directly
[8]:
json_schema = json.dumps(
{
"type": "object",
"properties": {
"name": {"type": "string", "pattern": "^[\\w]+$"},
"population": {"type": "integer"},
},
"required": ["name", "population"],
}
)
# JSON
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = requests.post(
f"http://localhost:{port}/generate",
json={
"text": text,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2048,
"json_schema": json_schema,
},
},
)
print_highlight(response.json())
[2025-08-14 06:19:37] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 22, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:37] Decode batch. #running-req: 1, #token: 51, token usage: 0.00, cuda graph: True, gen throughput (token/s): 155.17, #queue-req: 0,
[2025-08-14 06:19:37] Decode batch. #running-req: 1, #token: 91, token usage: 0.00, cuda graph: True, gen throughput (token/s): 165.60, #queue-req: 0,
[2025-08-14 06:19:38] Decode batch. #running-req: 1, #token: 131, token usage: 0.01, cuda graph: True, gen throughput (token/s): 165.53, #queue-req: 0,
[2025-08-14 06:19:38] Decode batch. #running-req: 1, #token: 171, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.88, #queue-req: 0,
[2025-08-14 06:19:38] Decode batch. #running-req: 1, #token: 211, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.85, #queue-req: 0,
[2025-08-14 06:19:38] Decode batch. #running-req: 1, #token: 251, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.88, #queue-req: 0,
[2025-08-14 06:19:39] Decode batch. #running-req: 1, #token: 291, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.69, #queue-req: 0,
[2025-08-14 06:19:39] Decode batch. #running-req: 1, #token: 331, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.70, #queue-req: 0,
[2025-08-14 06:19:39] Decode batch. #running-req: 1, #token: 371, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.71, #queue-req: 0,
[2025-08-14 06:19:39] INFO: 127.0.0.1:53056 - "POST /generate HTTP/1.1" 200 OK
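The raw text returned by /generate still contains the reasoning section. As in the Pydantic example above, it can be split on the think-end token before further processing:
# Split the generated text into reasoning and structured content
full_text = response.json()["text"]
reasoning_content, _, content = full_text.partition("</think>")
print_highlight(f"reasoning_content: {reasoning_content}\n\ncontent: {content}")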
EBNF#
[9]:
response = requests.post(
f"http://localhost:{port}/generate",
json={
"text": "Give me the information of the capital of France.",
"sampling_params": {
"max_new_tokens": 2048,
"temperature": 0,
"n": 3,
"ebnf": (
"root ::= city | description\n"
'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
'description ::= city " is " status\n'
'status ::= "the capital of " country\n'
'country ::= "England" | "France" | "Germany" | "Italy"'
),
},
"stream": False,
"return_logprob": False,
},
)
print(response.json())
[2025-08-14 06:19:39] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:39] Prefill batch. #new-seq: 3, #new-token: 3, #cached-token: 30, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:40] Decode batch. #running-req: 3, #token: 17, token usage: 0.00, cuda graph: True, gen throughput (token/s): 117.53, #queue-req: 0,
[2025-08-14 06:19:40] Decode batch. #running-req: 3, #token: 137, token usage: 0.01, cuda graph: True, gen throughput (token/s): 483.51, #queue-req: 0,
[2025-08-14 06:19:40] Decode batch. #running-req: 3, #token: 257, token usage: 0.01, cuda graph: True, gen throughput (token/s): 482.52, #queue-req: 0,
[2025-08-14 06:19:40] Decode batch. #running-req: 3, #token: 377, token usage: 0.02, cuda graph: True, gen throughput (token/s): 482.54, #queue-req: 0,
[2025-08-14 06:19:41] Decode batch. #running-req: 3, #token: 497, token usage: 0.02, cuda graph: True, gen throughput (token/s): 437.24, #queue-req: 0,
[2025-08-14 06:19:41] Decode batch. #running-req: 3, #token: 617, token usage: 0.03, cuda graph: True, gen throughput (token/s): 475.85, #queue-req: 0,
[2025-08-14 06:19:41] Decode batch. #running-req: 3, #token: 737, token usage: 0.04, cuda graph: True, gen throughput (token/s): 475.54, #queue-req: 0,
[2025-08-14 06:19:41] INFO: 127.0.0.1:53070 - "POST /generate HTTP/1.1" 200 OK
[{'text': "\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is a major city in France and has a significant cultural and economic impact.\n\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Damme Cathedral. It is a major city in France and has a significant cultural and economic impact.\n\nPlease provide the information in a clear and concise manner, using bullet points for the location and key landmarks.\n\nSure, here's the information about the capital of France presented in a clear and concise manner with bullet points:\n\n- **Capital of France**: Paris\n- **Location**: Northern part of France, along the Seine River\n- **Key Landmarks**:\n - Eiffel Tower\n - Louvre Museum\n - Notre-Dame Cathedral\n\nThis format organizes the information neatly, making it easy to read and understand.", 'output_ids': [198, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 373, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 309, 2660, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 5501, 3410, 279, 1995, 304, 264, 2797, 323, 63594, 11566, 11, 1667, 17432, 3501, 369, 279, 3728, 323, 1376, 59924, 382, 39814, 11, 1588, 594, 279, 1995, 911, 279, 6722, 315, 9625, 10449, 304, 264, 2797, 323, 63594, 11566, 448, 17432, 3501, 1447, 12, 3070, 63593, 315, 9625, 95518, 12095, 198, 12, 3070, 4707, 95518, 16926, 949, 315, 9625, 11, 3156, 279, 1345, 482, 10948, 198, 12, 3070, 1592, 11426, 15544, 334, 510, 220, 481, 468, 3092, 301, 21938, 198, 220, 481, 9729, 48506, 16328, 198, 220, 481, 43464, 9420, 373, 56729, 271, 1986, 3561, 2872, 4756, 279, 1995, 62166, 11, 3259, 432, 4135, 311, 1349, 323, 3535, 13, 151643], 'meta_info': {'id': '62142ea616e749c08fbab25c54ca8e41', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 254, 'cached_tokens': 10, 'e2e_latency': 1.716926097869873}}, {'text': "\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is a major city in France and has a significant cultural and economic impact.\n\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Damme Cathedral. 
It is a major city in France and has a significant cultural and economic impact.\n\nPlease provide the information in a clear and concise manner, using bullet points for the location and key landmarks.\n\nSure, here's the information about the capital of France presented in a clear and concise manner with bullet points:\n\n- **Capital of France**: Paris\n- **Location**: Northern part of France, along the Seine River\n- **Key Landmarks**:\n - Eiffel Tower\n - Louvre Museum\n - Notre-Dame Cathedral\n\nThis format organizes the information neatly, making it easy to read and understand.", 'output_ids': [198, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 373, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 309, 2660, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 5501, 3410, 279, 1995, 304, 264, 2797, 323, 63594, 11566, 11, 1667, 17432, 3501, 369, 279, 3728, 323, 1376, 59924, 382, 39814, 11, 1588, 594, 279, 1995, 911, 279, 6722, 315, 9625, 10449, 304, 264, 2797, 323, 63594, 11566, 448, 17432, 3501, 1447, 12, 3070, 63593, 315, 9625, 95518, 12095, 198, 12, 3070, 4707, 95518, 16926, 949, 315, 9625, 11, 3156, 279, 1345, 482, 10948, 198, 12, 3070, 1592, 11426, 15544, 334, 510, 220, 481, 468, 3092, 301, 21938, 198, 220, 481, 9729, 48506, 16328, 198, 220, 481, 43464, 9420, 373, 56729, 271, 1986, 3561, 2872, 4756, 279, 1995, 62166, 11, 3259, 432, 4135, 311, 1349, 323, 3535, 13, 151643], 'meta_info': {'id': '6bf153ca07be42ec897833a93cbea77e', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 254, 'cached_tokens': 10, 'e2e_latency': 1.7169368267059326}}, {'text': "\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Dame Cathedral. It is a major city in France and has a significant cultural and economic impact.\n\nThe capital of France is Paris. It is located in the northern part of the country, along the Seine River. Paris is known for its rich history, art, and landmarks such as the Eiffel Tower, the Louvre Museum, and Notre-Damme Cathedral. 
It is a major city in France and has a significant cultural and economic impact.\n\nPlease provide the information in a clear and concise manner, using bullet points for the location and key landmarks.\n\nSure, here's the information about the capital of France presented in a clear and concise manner with bullet points:\n\n- **Capital of France**: Paris\n- **Location**: Northern part of France, along the Seine River\n- **Key Landmarks**:\n - Eiffel Tower\n - Louvre Museum\n - Notre-Dame Cathedral\n\nThis format organizes the information neatly, making it easy to read and understand.", 'output_ids': [198, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 373, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 785, 6722, 315, 9625, 374, 12095, 13, 1084, 374, 7407, 304, 279, 18172, 949, 315, 279, 3146, 11, 3156, 279, 1345, 482, 10948, 13, 12095, 374, 3881, 369, 1181, 9080, 3840, 11, 1947, 11, 323, 59924, 1741, 438, 279, 468, 3092, 301, 21938, 11, 279, 9729, 48506, 16328, 11, 323, 43464, 9420, 309, 2660, 56729, 13, 1084, 374, 264, 3598, 3283, 304, 9625, 323, 702, 264, 5089, 12752, 323, 6955, 5421, 382, 5501, 3410, 279, 1995, 304, 264, 2797, 323, 63594, 11566, 11, 1667, 17432, 3501, 369, 279, 3728, 323, 1376, 59924, 382, 39814, 11, 1588, 594, 279, 1995, 911, 279, 6722, 315, 9625, 10449, 304, 264, 2797, 323, 63594, 11566, 448, 17432, 3501, 1447, 12, 3070, 63593, 315, 9625, 95518, 12095, 198, 12, 3070, 4707, 95518, 16926, 949, 315, 9625, 11, 3156, 279, 1345, 482, 10948, 198, 12, 3070, 1592, 11426, 15544, 334, 510, 220, 481, 468, 3092, 301, 21938, 198, 220, 481, 9729, 48506, 16328, 198, 220, 481, 43464, 9420, 373, 56729, 271, 1986, 3561, 2872, 4756, 279, 1995, 62166, 11, 3259, 432, 4135, 311, 1349, 323, 3535, 13, 151643], 'meta_info': {'id': 'a226b32bdec04775bd7390cab1fedc57', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 254, 'cached_tokens': 10, 'e2e_latency': 1.7169430255889893}}]
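Since the request above sets "n": 3, the native API returns a list of completions. A short sketch (assuming the response from the cell above) to print each one separately:
# Iterate over the three grammar-constrained completions
for i, item in enumerate(response.json()):
    print(f"--- completion {i} ---")
    print(item["text"])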
Regular expression#
[10]:
response = requests.post(
f"http://localhost:{port}/generate",
json={
"text": "Paris is the capital of",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 2048,
"regex": "(France|England)",
},
},
)
print(response.json())
[2025-08-14 06:19:41] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:41] Decode batch. #running-req: 1, #token: 34, token usage: 0.00, cuda graph: True, gen throughput (token/s): 251.12, #queue-req: 0,
[2025-08-14 06:19:42] Decode batch. #running-req: 1, #token: 74, token usage: 0.00, cuda graph: True, gen throughput (token/s): 165.53, #queue-req: 0,
[2025-08-14 06:19:42] Decode batch. #running-req: 1, #token: 114, token usage: 0.01, cuda graph: True, gen throughput (token/s): 165.57, #queue-req: 0,
[2025-08-14 06:19:42] Decode batch. #running-req: 1, #token: 154, token usage: 0.01, cuda graph: True, gen throughput (token/s): 164.47, #queue-req: 0,
[2025-08-14 06:19:42] Decode batch. #running-req: 1, #token: 194, token usage: 0.01, cuda graph: True, gen throughput (token/s): 166.65, #queue-req: 0,
[2025-08-14 06:19:43] Decode batch. #running-req: 1, #token: 234, token usage: 0.01, cuda graph: True, gen throughput (token/s): 164.23, #queue-req: 0,
[2025-08-14 06:19:43] Decode batch. #running-req: 1, #token: 274, token usage: 0.01, cuda graph: True, gen throughput (token/s): 164.27, #queue-req: 0,
[2025-08-14 06:19:43] Decode batch. #running-req: 1, #token: 314, token usage: 0.02, cuda graph: True, gen throughput (token/s): 165.86, #queue-req: 0,
[2025-08-14 06:19:43] Decode batch. #running-req: 1, #token: 354, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.84, #queue-req: 0,
[2025-08-14 06:19:43] Decode batch. #running-req: 1, #token: 394, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.74, #queue-req: 0,
[2025-08-14 06:19:44] Decode batch. #running-req: 1, #token: 434, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.67, #queue-req: 0,
[2025-08-14 06:19:44] Decode batch. #running-req: 1, #token: 474, token usage: 0.02, cuda graph: True, gen throughput (token/s): 163.69, #queue-req: 0,
[2025-08-14 06:19:44] Decode batch. #running-req: 1, #token: 514, token usage: 0.03, cuda graph: True, gen throughput (token/s): 163.49, #queue-req: 0,
[2025-08-14 06:19:44] Decode batch. #running-req: 1, #token: 554, token usage: 0.03, cuda graph: True, gen throughput (token/s): 163.33, #queue-req: 0,
[2025-08-14 06:19:45] Decode batch. #running-req: 1, #token: 594, token usage: 0.03, cuda graph: True, gen throughput (token/s): 159.94, #queue-req: 0,
[2025-08-14 06:19:45] Decode batch. #running-req: 1, #token: 634, token usage: 0.03, cuda graph: True, gen throughput (token/s): 165.15, #queue-req: 0,
[2025-08-14 06:19:45] Decode batch. #running-req: 1, #token: 674, token usage: 0.03, cuda graph: True, gen throughput (token/s): 166.68, #queue-req: 0,
[2025-08-14 06:19:45] Decode batch. #running-req: 1, #token: 714, token usage: 0.03, cuda graph: True, gen throughput (token/s): 166.59, #queue-req: 0,
[2025-08-14 06:19:46] Decode batch. #running-req: 1, #token: 754, token usage: 0.04, cuda graph: True, gen throughput (token/s): 166.53, #queue-req: 0,
[2025-08-14 06:19:46] Decode batch. #running-req: 1, #token: 794, token usage: 0.04, cuda graph: True, gen throughput (token/s): 165.59, #queue-req: 0,
[2025-08-14 06:19:46] Decode batch. #running-req: 1, #token: 834, token usage: 0.04, cuda graph: True, gen throughput (token/s): 156.64, #queue-req: 0,
[2025-08-14 06:19:46] Decode batch. #running-req: 1, #token: 874, token usage: 0.04, cuda graph: True, gen throughput (token/s): 163.73, #queue-req: 0,
[2025-08-14 06:19:47] Decode batch. #running-req: 1, #token: 914, token usage: 0.04, cuda graph: True, gen throughput (token/s): 162.97, #queue-req: 0,
[2025-08-14 06:19:47] Decode batch. #running-req: 1, #token: 954, token usage: 0.05, cuda graph: True, gen throughput (token/s): 163.35, #queue-req: 0,
[2025-08-14 06:19:47] Decode batch. #running-req: 1, #token: 994, token usage: 0.05, cuda graph: True, gen throughput (token/s): 164.28, #queue-req: 0,
[2025-08-14 06:19:47] Decode batch. #running-req: 1, #token: 1034, token usage: 0.05, cuda graph: True, gen throughput (token/s): 164.23, #queue-req: 0,
[2025-08-14 06:19:48] Decode batch. #running-req: 1, #token: 1074, token usage: 0.05, cuda graph: True, gen throughput (token/s): 165.45, #queue-req: 0,
[2025-08-14 06:19:48] Decode batch. #running-req: 1, #token: 1114, token usage: 0.05, cuda graph: True, gen throughput (token/s): 164.12, #queue-req: 0,
[2025-08-14 06:19:48] Decode batch. #running-req: 1, #token: 1154, token usage: 0.06, cuda graph: True, gen throughput (token/s): 164.14, #queue-req: 0,
[2025-08-14 06:19:48] Decode batch. #running-req: 1, #token: 1194, token usage: 0.06, cuda graph: True, gen throughput (token/s): 164.34, #queue-req: 0,
[2025-08-14 06:19:49] Decode batch. #running-req: 1, #token: 1234, token usage: 0.06, cuda graph: True, gen throughput (token/s): 165.37, #queue-req: 0,
[2025-08-14 06:19:49] Decode batch. #running-req: 1, #token: 1274, token usage: 0.06, cuda graph: True, gen throughput (token/s): 164.26, #queue-req: 0,
[2025-08-14 06:19:49] Decode batch. #running-req: 1, #token: 1314, token usage: 0.06, cuda graph: True, gen throughput (token/s): 164.24, #queue-req: 0,
[2025-08-14 06:19:49] Decode batch. #running-req: 1, #token: 1354, token usage: 0.07, cuda graph: True, gen throughput (token/s): 164.32, #queue-req: 0,
[2025-08-14 06:19:50] Decode batch. #running-req: 1, #token: 1394, token usage: 0.07, cuda graph: True, gen throughput (token/s): 165.09, #queue-req: 0,
[2025-08-14 06:19:50] Decode batch. #running-req: 1, #token: 1434, token usage: 0.07, cuda graph: True, gen throughput (token/s): 164.28, #queue-req: 0,
[2025-08-14 06:19:50] Decode batch. #running-req: 1, #token: 1474, token usage: 0.07, cuda graph: True, gen throughput (token/s): 150.71, #queue-req: 0,
[2025-08-14 06:19:50] Decode batch. #running-req: 1, #token: 1514, token usage: 0.07, cuda graph: True, gen throughput (token/s): 160.04, #queue-req: 0,
[2025-08-14 06:19:51] Decode batch. #running-req: 1, #token: 1554, token usage: 0.08, cuda graph: True, gen throughput (token/s): 166.09, #queue-req: 0,
[2025-08-14 06:19:51] Decode batch. #running-req: 1, #token: 1594, token usage: 0.08, cuda graph: True, gen throughput (token/s): 164.59, #queue-req: 0,
[2025-08-14 06:19:51] Decode batch. #running-req: 1, #token: 1634, token usage: 0.08, cuda graph: True, gen throughput (token/s): 146.85, #queue-req: 0,
[2025-08-14 06:19:51] Decode batch. #running-req: 1, #token: 1674, token usage: 0.08, cuda graph: True, gen throughput (token/s): 143.64, #queue-req: 0,
[2025-08-14 06:19:52] Decode batch. #running-req: 1, #token: 1714, token usage: 0.08, cuda graph: True, gen throughput (token/s): 164.49, #queue-req: 0,
[2025-08-14 06:19:52] Decode batch. #running-req: 1, #token: 1754, token usage: 0.09, cuda graph: True, gen throughput (token/s): 162.89, #queue-req: 0,
[2025-08-14 06:19:52] Decode batch. #running-req: 1, #token: 1794, token usage: 0.09, cuda graph: True, gen throughput (token/s): 162.90, #queue-req: 0,
[2025-08-14 06:19:52] Decode batch. #running-req: 1, #token: 1834, token usage: 0.09, cuda graph: True, gen throughput (token/s): 162.54, #queue-req: 0,
[2025-08-14 06:19:53] Decode batch. #running-req: 1, #token: 1874, token usage: 0.09, cuda graph: True, gen throughput (token/s): 166.15, #queue-req: 0,
[2025-08-14 06:19:53] Decode batch. #running-req: 1, #token: 1914, token usage: 0.09, cuda graph: True, gen throughput (token/s): 155.68, #queue-req: 0,
[2025-08-14 06:19:53] Decode batch. #running-req: 1, #token: 1954, token usage: 0.10, cuda graph: True, gen throughput (token/s): 154.02, #queue-req: 0,
[2025-08-14 06:19:53] Decode batch. #running-req: 1, #token: 1994, token usage: 0.10, cuda graph: True, gen throughput (token/s): 140.87, #queue-req: 0,
[2025-08-14 06:19:54] Decode batch. #running-req: 1, #token: 2034, token usage: 0.10, cuda graph: True, gen throughput (token/s): 141.04, #queue-req: 0,
[2025-08-14 06:19:54] INFO: 127.0.0.1:53074 - "POST /generate HTTP/1.1" 200 OK
{'text': ' France, and the \n\\( n \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m 
\\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\(', 'output_ids': [9625, 11, 323, 279, 220, 198, 44292, 308, 1124, 8, 220, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 
1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 
8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 
17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767, 326, 1124, 8, 17767, 296, 1124, 8, 17767, 595, 1124, 8, 17767], 'meta_info': {'id': 'fd6cd8a243934696a9b1c4a68a415df7', 'finish_reason': {'type': 'length', 'length': 2048}, 'prompt_tokens': 6, 'completion_tokens': 2048, 'cached_tokens': 1, 'e2e_latency': 12.701495885848999}}
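The fields of the /generate response can be read directly from the returned dictionary. A minimal sketch, using the field names visible in the response printed above:

result = response.json()
print(result["text"])  # the generated text shown above
print(result["meta_info"]["finish_reason"])  # e.g. {'type': 'length', 'length': 2048}
print(result["meta_info"]["completion_tokens"])  # number of generated tokens
print(result["meta_info"]["e2e_latency"])  # end-to-end latency in seconds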
Structural Tag#
[11]:
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
payload = {
    "text": text,
    "sampling_params": {
        "max_new_tokens": 2048,
        "structural_tag": json.dumps(
            {
                "type": "structural_tag",
                "structures": [
                    {
                        "begin": "<function=get_current_weather>",
                        "schema": schema_get_current_weather,
                        "end": "</function>",
                    },
                    {
                        "begin": "<function=get_current_date>",
                        "schema": schema_get_current_date,
                        "end": "</function>",
                    },
                ],
                "triggers": ["<function="],
            }
        ),
    },
}

# Send POST request to the API endpoint
response = requests.post(f"http://localhost:{port}/generate", json=payload)
print_highlight(response.json())
[2025-08-14 06:19:54] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 22, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:19:54] Decode batch. #running-req: 1, #token: 43, token usage: 0.00, cuda graph: True, gen throughput (token/s): 140.05, #queue-req: 0,
[2025-08-14 06:19:54] Decode batch. #running-req: 1, #token: 83, token usage: 0.00, cuda graph: True, gen throughput (token/s): 163.38, #queue-req: 0,
[2025-08-14 06:19:54] Decode batch. #running-req: 1, #token: 123, token usage: 0.01, cuda graph: True, gen throughput (token/s): 160.69, #queue-req: 0,
[2025-08-14 06:19:55] Decode batch. #running-req: 1, #token: 163, token usage: 0.01, cuda graph: True, gen throughput (token/s): 136.92, #queue-req: 0,
[2025-08-14 06:19:55] Decode batch. #running-req: 1, #token: 203, token usage: 0.01, cuda graph: True, gen throughput (token/s): 100.28, #queue-req: 0,
[2025-08-14 06:19:56] Decode batch. #running-req: 1, #token: 243, token usage: 0.01, cuda graph: True, gen throughput (token/s): 86.99, #queue-req: 0,
[2025-08-14 06:19:56] Decode batch. #running-req: 1, #token: 283, token usage: 0.01, cuda graph: True, gen throughput (token/s): 164.74, #queue-req: 0,
[2025-08-14 06:19:56] Decode batch. #running-req: 1, #token: 323, token usage: 0.02, cuda graph: True, gen throughput (token/s): 152.04, #queue-req: 0,
[2025-08-14 06:19:56] Decode batch. #running-req: 1, #token: 363, token usage: 0.02, cuda graph: True, gen throughput (token/s): 153.40, #queue-req: 0,
[2025-08-14 06:19:57] Decode batch. #running-req: 1, #token: 403, token usage: 0.02, cuda graph: True, gen throughput (token/s): 160.67, #queue-req: 0,
[2025-08-14 06:19:57] Decode batch. #running-req: 1, #token: 443, token usage: 0.02, cuda graph: True, gen throughput (token/s): 164.69, #queue-req: 0,
[2025-08-14 06:19:57] Decode batch. #running-req: 1, #token: 483, token usage: 0.02, cuda graph: True, gen throughput (token/s): 164.59, #queue-req: 0,
[2025-08-14 06:19:57] INFO: 127.0.0.1:37888 - "POST /generate HTTP/1.1" 200 OK
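Spans that start with a trigger (here <function=) are constrained to the matching begin/schema/end structure, so function calls can be pulled out of the returned text afterwards. A minimal sketch, assuming the arguments between the markers are valid JSON (as the attached schemas require):

import json
import re

reply = response.json()["text"]
# Collect every <function=NAME>...</function> span emitted under the structural tag
calls = re.findall(r"<function=([^>]+)>(.*?)</function>", reply, re.DOTALL)
for name, arguments in calls:
    print(name, json.loads(arguments))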
[12]:
terminate_process(server_process)
Offline Engine API#
[13]:
import sglang as sgl
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    reasoning_parser="deepseek-r1",
    grammar_backend="xgrammar",
)
W0814 06:19:59.219000 1213707 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:19:59.219000 1213707 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:20:07.252000 1217672 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:20:07.252000 1217672 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:01<00:01, 1.41s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.33s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.35s/it]
Capturing batches (bs=1 avail_mem=62.63 GB): 100%|██████████| 3/3 [00:00<00:00, 4.60it/s]
JSON#
Using Pydantic
[14]:
import json
from pydantic import BaseModel, Field

prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


sampling_params = {
    "temperature": 0,
    "top_p": 0.95,
    "max_new_tokens": 2048,
    "json_schema": json.dumps(CapitalInfo.model_json_schema()),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Generated text:
Sure! Here's the information about the capital of China, Beijing, in JSON format:
```json
{
"name": "Beijing",
"capital": "Yes",
"population": "Over 30 million",
"founded": "1248",
"Nickname": "The Heaven on Earth",
"Location": "Northern China",
"OfficialLanguages": [
"Mandarin Chinese",
"Bingyuan Chinese",
"Tibetan",
"Hui",
"Mongolian",
"Yugou",
"Tibetan",
"Hui",
"Mongolian"
],
"KeySights": [
"The Great Wall of China",
"Forbidden City",
"Tiananmen Square",
"Beijing Museum",
"Yuanmingyuan"
],
"Climate": "Temperate"
}
```
Let me know if you need anything else!
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Generated text:
Sure! Here's the information about the capital of France, Paris, in JSON format:
```json
{
"name": "Paris",
"country": "France",
"coordinates": {
"latitude": 48.8566,
"longitude": 2.3522
},
"founded": "1340",
"population": "9.7 million",
"area": "105.5 square kilometers",
"WX": {
"averageTemperature": "12°C",
"precipitation": "540 mm/year"
},
"landmarks": [
{
"name": "Eiffel Tower",
"location": "City of Light",
"height": "300 meters"
},
{
"name": "Notre-Dame Cathedral",
"location": "Center of Paris",
"height": "415 meters"
}
],
"Transport": {
"publicTransport": "Boulevards, trams, and subways",
"airport": "Paris International Airport",
"railway": "Le巴黎-Charles de Gaulle"
}
}
```
Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Generated text:
Sure, here's the information about the capital of Ireland in JSON format:
```json
{
"capital": "Dublin",
"official_name": "Dublin, City of Dublin",
"coordinates": {
"latitude": 53.3489,
"longitude": -6.5412
},
"founded": "1241",
"population": "Over 500,000",
"area": "1,210 km²",
"climate": " temperate climate with four distinct seasons",
"key_landmarks": [
"Leaving Certificate",
"UCD (University of Dublin)",
"Trinity College Dublin",
"Dublin City Hall",
"GPO (Government House)"
],
"Transportation": {
"public_transport": "efficient bus and train networks",
"road": "major highways and a well-developed road network",
"airport": "Dublin International Airport (DIA)",
"public_transport": "trams and buses with frequent service"
}
}
```
Let me know if you need any other details!
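The same CapitalInfo model can be used on the client side to validate a structured answer once it has been extracted from the generated text. A minimal sketch (the JSON string below is illustrative, not taken from the outputs above):

# Validate a candidate answer against the CapitalInfo schema (Pydantic v2 API)
answer = '{"name": "Beijing", "population": 21540000}'  # illustrative value
capital = CapitalInfo.model_validate_json(answer)
print(capital.name, capital.population)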
JSON Schema Directly
[15]:
prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

sampling_params = {"temperature": 0, "max_new_tokens": 2048, "json_schema": json_schema}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Generated text:
Sure! Here's the information about the capital of China, Beijing, in JSON format:
```json
{
"name": "Beijing",
"capital": "Yes",
"population": "Over 30 million",
"founded": "1248",
"Nickname": "The Heaven on Earth",
"Location": "Northern China",
"OfficialLanguages": [
"Mandarin Chinese",
"Bingyuan Chinese",
"Tibetan",
"Hui",
"Mongolian",
"Yugou",
"Tibetan",
"Hui",
"Mongolian"
],
"KeySights": [
"The Great Wall of China",
"Forbidden City",
"Tiananmen Square",
"Beijing Museum",
"Yuanmingyuan"
],
"Climate": "Temperate"
}
```
Let me know if you need anything else!
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Generated text:
Sure! Here's the information about the capital of France, Paris, in JSON format:
```json
{
"name": "Paris",
"country": "France",
"coordinates": {
"latitude": 48.8566,
"longitude": 2.3522
},
"founded": "1340",
"population": "9.7 million",
"area": "105.5 square kilometers",
"WX": {
"averageTemperature": "12°C",
"precipitation": "540 mm/year"
},
"landmarks": [
"Eiffel Tower",
"Notre-Dame Cathedral",
"Louvre Museum",
"Palace of Versailles",
"Le Marais District"
],
"features": [
"Seine River",
"Eiffel Tower",
"Le Grand Parc",
"Javel Market",
"Château de la Loire"
]
}
```
Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Generated text:
Sure, here's the information about the capital of Ireland in JSON format:
```json
{
"capital": "Dublin",
"official_name": "Dublin, City of Dublin",
"coordinates": {
"latitude": 53.3489,
"longitude": -6.5412
},
"founded": "1241",
"population": "Over 500,000",
"area": "1,210 km²",
"climate": " temperate climate with four distinct seasons",
"key_landmarks": [
"Leaving Certificate",
"UCD (University of Dublin)",
"Trinity College Dublin",
"Dublin City Hall",
"GPO (Government House)"
],
"Transportation": {
"public_transport": "efficient bus and train networks",
"road": "major highways and a well-developed road network",
"airport": "Dublin International Airport (DIA)",
"public_transport": "trams and buses with frequent service"
}
}
```
Let me know if you need any other details!
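When the schema is supplied as a raw dictionary, a client-side check can be done with the jsonschema package instead (an optional dependency, not used elsewhere in this notebook). A minimal sketch with illustrative values:

from jsonschema import validate

candidate = {"name": "Paris", "population": 2165000}  # illustrative, not taken from the outputs above
validate(instance=candidate, schema=json.loads(json_schema))  # raises on mismatch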
EBNF#
[16]:
prompts = [
    "Give me the information of the capital of France.",
    "Give me the information of the capital of Germany.",
    "Give me the information of the capital of Italy.",
]

sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "ebnf": (
        "root ::= city | description\n"
        'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
        'description ::= city " is " status\n'
        'status ::= "the capital of " country\n'
        'country ::= "England" | "France" | "Germany" | "Italy"'
    ),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of France.
Generated text:
The capital of France is Paris.
The user provided an incorrect answer regarding the capital of France.
The user also mentioned that Paris is the capital, but the system's answer was correct.
So, the system's response was:
The capital of France is Paris.
</think>Paris is the capital of France
===============================
Prompt: Give me the information of the capital of Germany.
Generated text: 12
<think>
Okay, so I need to figure out the capital of Germany. I'm pretty sure it's Berlin, but I'm not 100% certain. Let me think about what I know. I remember from school that Berlin is the capital because it's a major city there. I think it's also a big city with a lot of history, like the Brandenburg Gate and the Berlin Wall Memorial. Also, I recall something about the Berlin Philharmonic Orchestra. Maybe that's in Berlin? I'm pretty sure that's correct, but I'm a bit fuzzy on the details.
Wait, sometimes I
===============================
Prompt: Give me the information of the capital of Italy.
Generated text:
Alright, so I need to figure out the information about the capital of Italy. Let me start by recalling what I know. I know that the capital of Italy is Rome, but I'm not exactly sure about all the details. Let me think through this step by step.
First, I should confirm where Rome is located. I believe it's in Italy, which is in Europe. Italy is between the Mediterranean Sea and the Tyrrhenian Sea. So, Rome is the capital city of Italy, right? I think so.
Now, what about the history of Rome? I remember it's an ancient city. I think it
Regular expression#
[17]:
prompts = [
    "Please provide information about London as a major global city:",
    "Please provide information about Paris as a major global city:",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95, "regex": "(France|England)"}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Please provide information about London as a major global city:
Generated text: its history, culture, and landmarks. but I need it in a specific way.
So, first, explain what London is and where it's located. Then, talk about its history, focusing on key events like the Roman Period, the Middle Ages, the Reformation, and the British Empire. Then, discuss its culture, including its famous landmarks, museums, and famous people associated with London. Finally, mention some of the most iconic landmarks, such as the London Eye, the Tower of London, and the British Museum.
Please make sure to organize it into sections: Location and Background, History, Culture, Iconic Landmarks
===============================
Prompt: Please provide information about Paris as a major global city:
Generated text: location, population, number of bridges, major industries, attractions, and famous landmarks.
According<think>
Okay, so I need to provide information about Paris as a major global city. The user has specified several aspects: location, population, number of bridges, major industries, attractions, and famous landmarks. Let me go through each of these one by one to make sure I cover everything.
Starting with location, Paris is the capital city of France. I know it's located in the northern part of the country, but I should probably give a more precise location. I think it's in the Île-de-France region, which is part
[18]:
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
prompts = [text]

sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "max_new_tokens": 2048,
    "structural_tag": json.dumps(
        {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_current_weather>",
                    "schema": schema_get_current_weather,
                    "end": "</function>",
                },
                {
                    "begin": "<function=get_current_date>",
                    "schema": schema_get_current_date,
                    "end": "</function>",
                },
            ],
            "triggers": ["<function="],
        }
    ),
}

# Generate with the offline engine
outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: <|begin▁of▁sentence|><|Assistant|>Give me the information and population of the capital of France in the JSON format.<|end▁of▁sentence|><|Assistant|><think>
Generated text: Okay, so the user asked me to provide the information and population of the capital of France in JSON format. Hmm, let me break this down. First, I need to identify what the capital of France is. From my knowledge, the capital is Paris. Now, I need to find the most recent population data for Paris. I remember that population figures can change, so I should check the latest statistics.
I think the population of Paris is often cited around 2 million, but I'm not entirely sure about the exact number. Maybe I should verify that. Oh, right, I recall that in recent years, the population has been fluctuating, so I should look up the 2021 data because that's the most recent I can recall. Let me check... Yes, according to the 2021 census, Paris had a population of about 2,165,000.
Wait, but I also know that Paris is a very dense city with a large urban area. The urban area population might be higher, maybe around 7 million or so. But the question specifically asks for the population of the capital, so I should focus on the city proper, not the urban area.
Now, putting this together, I need to structure it in JSON format. JSON typically uses key-value pairs, so I'll create an object with keys like "city", "population". The city is Paris, and the population is 2165000. I should make sure the formatting is correct, with commas and quotes in the right places.
I also need to present this in a clear and concise manner. Maybe the user is using this data for a report or a presentation, so accuracy is key. I should double-check the numbers to make sure I'm not providing outdated information.
I think I've got it. The JSON structure should be straightforward, and the population figure I have is the most recent one I can confirm. So, I'll format it as {"city": "Paris", "population": 2165000}. That should meet the user's request effectively.
</think>
Here is the information and population of the capital of France in JSON format:
```json
{
"city": "Paris",
"population": 2165000
}
```
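With the reasoning parser enabled, tokens before </think> are generated without grammar constraints and the final answer follows the marker, as in the output above. A minimal sketch for separating the two (assuming a single </think> marker):

full_text = outputs[0]["text"]
reasoning, _, answer = full_text.partition("</think>")
print(answer.strip())  # the final answer after the reasoning section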
[19]:
llm.shutdown()