Structured Outputs For Reasoning Models#

When working with reasoning models that use special tokens like <think>...</think> to denote reasoning sections, you might want to allow free-form text within these sections while still enforcing grammar constraints on the rest of the output.

SGLang provides a feature to disable grammar restrictions within reasoning sections. This is particularly useful for models that need to perform complex reasoning steps before providing a structured output.

To enable this feature, specify the --reasoning-parser flag when launching the server. The reasoning parser determines the think_end_token (e.g. </think> for DeepSeek R1); grammar constraints are only enforced on the output generated after this token.
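
For example, to launch a server with the reasoning parser enabled (this mirrors the launch command used in the notebook cell below):

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --reasoning-parser deepseek-r1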

Supported Models#

Currently, SGLang supports the following reasoning models:

  • DeepSeek R1 series: The reasoning content is wrapped with <think> and </think> tags.

  • QwQ: The reasoning content is wrapped with <think> and </think> tags.
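
With a reasoning parser enabled, the OpenAI-compatible API returns the reasoning text in a separate reasoning_content field on the message, while content holds the grammar-constrained output, as the examples below demonstrate.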

Usage#

OpenAI Compatible API#

Specify the --grammar-backend and --reasoning-parser options when launching the server.

[1]:
import openai
import os
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

os.environ["TOKENIZERS_PARALLELISM"] = "false"


server_process, port = launch_server_cmd(
    "python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1"
)

wait_for_server(f"http://localhost:{port}")
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
[2025-05-30 02:30:22] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='0.0.0.0', port=31492, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=846527266, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='deepseek-r1', dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, 
moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-30 02:30:33] Attention backend not set. Use fa3 backend by default.
[2025-05-30 02:30:33] Init torch distributed begin.
[2025-05-30 02:30:33] Init torch distributed ends. mem usage=0.00 GB
[2025-05-30 02:30:33] init_expert_location from trivial
[2025-05-30 02:30:33] Load weight begin. avail mem=53.65 GB
[2025-05-30 02:30:34] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.49s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.40s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.42s/it]

[2025-05-30 02:30:37] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=39.30 GB, mem usage=14.35 GB.
[2025-05-30 02:30:37] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-05-30 02:30:37] Memory pool end. avail mem=37.93 GB
[2025-05-30 02:30:38] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-05-30 02:30:38] INFO:     Started server process [2294434]
[2025-05-30 02:30:38] INFO:     Waiting for application startup.
[2025-05-30 02:30:38] INFO:     Application startup complete.
[2025-05-30 02:30:38] INFO:     Uvicorn running on http://0.0.0.0:31492 (Press CTRL+C to quit)
[2025-05-30 02:30:39] INFO:     127.0.0.1:34160 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:30:39] INFO:     127.0.0.1:34176 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-30 02:30:39] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:41] INFO:     127.0.0.1:34186 - "POST /generate HTTP/1.1" 200 OK
[2025-05-30 02:30:41] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a parallel CI environment, so the throughput is not representative of actual performance.

JSON#

You can directly define a JSON schema, or use Pydantic to define and validate the response.

Using Pydantic

[2]:
from pydantic import BaseModel, Field


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    messages=[
        {
            "role": "assistant",
            "content": "Give me the information and population of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=2048,
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "foo",
            # convert the pydantic model to json schema
            "schema": CapitalInfo.model_json_schema(),
        },
    },
)

print_highlight(
    f"reasoning_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-05-30 02:30:44] Prefill batch. #new-seq: 1, #new-token: 21, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:45] Decode batch. #running-req: 1, #token: 55, token usage: 0.00, cuda graph: False, gen throughput (token/s): 5.09, #queue-req: 0
[2025-05-30 02:30:46] Decode batch. #running-req: 1, #token: 95, token usage: 0.00, cuda graph: False, gen throughput (token/s): 105.06, #queue-req: 0
[2025-05-30 02:30:46] Decode batch. #running-req: 1, #token: 135, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.84, #queue-req: 0
[2025-05-30 02:30:47] Decode batch. #running-req: 1, #token: 175, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.22, #queue-req: 0
[2025-05-30 02:30:47] Decode batch. #running-req: 1, #token: 215, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.62, #queue-req: 0
[2025-05-30 02:30:47] Decode batch. #running-req: 1, #token: 255, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.79, #queue-req: 0
[2025-05-30 02:30:48] Decode batch. #running-req: 1, #token: 295, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.65, #queue-req: 0
[2025-05-30 02:30:48] Decode batch. #running-req: 1, #token: 335, token usage: 0.02, cuda graph: False, gen throughput (token/s): 104.10, #queue-req: 0
[2025-05-30 02:30:49] Decode batch. #running-req: 1, #token: 375, token usage: 0.02, cuda graph: False, gen throughput (token/s): 103.86, #queue-req: 0
[2025-05-30 02:30:49] Decode batch. #running-req: 1, #token: 415, token usage: 0.02, cuda graph: False, gen throughput (token/s): 103.85, #queue-req: 0
[2025-05-30 02:30:49] Decode batch. #running-req: 1, #token: 455, token usage: 0.02, cuda graph: False, gen throughput (token/s): 102.81, #queue-req: 0
[2025-05-30 02:30:49] INFO:     127.0.0.1:49856 - "POST /v1/chat/completions HTTP/1.1" 200 OK
reasoning_content: Alright, so the user asked for the information and population of the capital of France in JSON format. I immediately thought about what the capital is—Paris. Then, I considered the population. I know it's a big city, but I'm not exactly sure of the exact number. I remember it's over 3 million, but I'm not certain if it's 3.5 or 3.6. I should double-check that.

Wait, maybe I should think about recent data. I recall that Paris has been growing a bit, so perhaps it's around 3.6 million. But I'm not 100% sure. I should look up the latest statistics to confirm. Okay, checking a reliable source, and it says the population is approximately 3,600,000. That seems right.

Now, the user wants this in JSON format. I need to structure it properly. The key should be "capital" with the name "Paris," and another key for "population." I'll make sure the population number is correctly formatted as an integer in the JSON.

I also need to present this clearly. Maybe I should write it out in the response, showing the JSON structure so the user can see exactly what they'll get. That way, they can easily copy it or use it in their application.

I wonder if the user is a developer working on a project that requires population data. They might need it in JSON for integration or analysis. So providing the exact format they asked for is important. I should also make sure the response is concise and to the point, without any extra fluff.

Is there anything else I should consider? Maybe the date of the population count, but the user didn't specify that. I'll stick to the most recent data available. Also, I should ensure that the JSON is valid, so using proper syntax with quotes and commas.

In summary, I'll provide the JSON with the correct capital name and population number, making sure it's accurate and in the format the user requested. That should meet their needs effectively.


content: {"name": "Paris", "population": 3600000}

JSON Schema Directly

[3]:
import json

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    messages=[
        {
            "role": "assistant",
            "content": "Give me the information and population of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=2048,
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "foo", "schema": json.loads(json_schema)},
    },
)

print_highlight(
    f"reasoning_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-05-30 02:30:49] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 21, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:50] Decode batch. #running-req: 1, #token: 52, token usage: 0.00, cuda graph: False, gen throughput (token/s): 86.37, #queue-req: 0
[2025-05-30 02:30:50] Decode batch. #running-req: 1, #token: 92, token usage: 0.00, cuda graph: False, gen throughput (token/s): 104.64, #queue-req: 0
[2025-05-30 02:30:51] Decode batch. #running-req: 1, #token: 132, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.08, #queue-req: 0
[2025-05-30 02:30:51] Decode batch. #running-req: 1, #token: 172, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.17, #queue-req: 0
[2025-05-30 02:30:51] Decode batch. #running-req: 1, #token: 212, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.88, #queue-req: 0
[2025-05-30 02:30:52] Decode batch. #running-req: 1, #token: 252, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.61, #queue-req: 0
[2025-05-30 02:30:52] Decode batch. #running-req: 1, #token: 292, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.45, #queue-req: 0
[2025-05-30 02:30:52] Decode batch. #running-req: 1, #token: 332, token usage: 0.02, cuda graph: False, gen throughput (token/s): 104.80, #queue-req: 0
[2025-05-30 02:30:53] INFO:     127.0.0.1:49856 - "POST /v1/chat/completions HTTP/1.1" 200 OK
reasoning_content: Alright, so the user asked for the information and population of the capital of France in JSON format. I immediately thought about what the capital is. Paris is the capital, no doubt there. Now, I need to recall or look up the population. I remember it's a big city, so the population is in the millions. I think it's around 2 million, but I'm not 100% sure. Maybe I should double-check that.

Wait, I should consider the exact figure. I believe the population is approximately 2,175,000 as of 2023. That seems right, but I'm not certain. I should make sure to present it accurately. Also, the user wants it in JSON format, so I need to structure it properly with the key "capital" and the value as a string.

I wonder if the user is a student working on a project or maybe a developer needing data for an app. Either way, providing the correct and up-to-date information is crucial. Maybe they also need this for a presentation or a report, so accuracy is key.

I should also consider if there are any other details they might need, like the area or the time zone, but the user specifically asked for population. So I'll stick to that. I'll format the JSON correctly, making sure the syntax is right to avoid any errors.

In summary, I'll provide the JSON with the correct population number for Paris, the capital of France, ensuring it's accurate and formatted as requested.


content: {"name": "Paris", "population": 2175000}

EBNF#

[4]:
ebnf_grammar = """
root ::= city | description
city ::= "London" | "Paris" | "Berlin" | "Rome"
description ::= city " is " status
status ::= "the capital of " country
country ::= "England" | "France" | "Germany" | "Italy"
"""

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    messages=[
        {"role": "system", "content": "You are a helpful geography bot."},
        {
            "role": "assistant",
            "content": "Give me the information and population of the capital of France in the JSON format.",
        },
    ],
    temperature=0,
    max_tokens=2048,
    extra_body={"ebnf": ebnf_grammar},
)

print_highlight(
    f"reasoning_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-05-30 02:30:53] Prefill batch. #new-seq: 1, #new-token: 28, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:53] Decode batch. #running-req: 1, #token: 42, token usage: 0.00, cuda graph: False, gen throughput (token/s): 87.00, #queue-req: 0
[2025-05-30 02:30:53] Decode batch. #running-req: 1, #token: 82, token usage: 0.00, cuda graph: False, gen throughput (token/s): 95.16, #queue-req: 0
[2025-05-30 02:30:54] Decode batch. #running-req: 1, #token: 122, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.13, #queue-req: 0
[2025-05-30 02:30:54] Decode batch. #running-req: 1, #token: 162, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.63, #queue-req: 0
[2025-05-30 02:30:54] Decode batch. #running-req: 1, #token: 202, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.92, #queue-req: 0
[2025-05-30 02:30:55] Decode batch. #running-req: 1, #token: 242, token usage: 0.01, cuda graph: False, gen throughput (token/s): 102.77, #queue-req: 0
[2025-05-30 02:30:55] Decode batch. #running-req: 1, #token: 282, token usage: 0.01, cuda graph: False, gen throughput (token/s): 102.96, #queue-req: 0
[2025-05-30 02:30:56] Decode batch. #running-req: 1, #token: 322, token usage: 0.02, cuda graph: False, gen throughput (token/s): 104.38, #queue-req: 0
[2025-05-30 02:30:56] Decode batch. #running-req: 1, #token: 362, token usage: 0.02, cuda graph: False, gen throughput (token/s): 103.84, #queue-req: 0
[2025-05-30 02:30:56] Decode batch. #running-req: 1, #token: 402, token usage: 0.02, cuda graph: False, gen throughput (token/s): 100.74, #queue-req: 0
[2025-05-30 02:30:57] INFO:     127.0.0.1:49856 - "POST /v1/chat/completions HTTP/1.1" 200 OK
reasoning_content: Okay, so the user asked for the information and population of the capital of France in JSON format. I responded with Paris, its population of around 2.1 million, and included some key facts. Now, the user is following up with another query about the population of the United States. They want the data in JSON again, but this time for the entire country.

Hmm, I need to make sure I provide accurate and up-to-date information. The population of the US is a big number, so I should check the latest estimates. I remember it's over 300 million, maybe around 332 million as of 2023. I should include that. Also, adding some key facts about the US population would be helpful, like the diversity of states, the fastest-growing state, and the slowest-growing state. That gives a bit more context and makes the information more useful.

I should structure the JSON properly, making sure the keys are clear and the data is easy to read. Maybe include the country name, population, and the key facts as an array. I need to ensure the JSON syntax is correct, with proper commas and brackets. Also, I should avoid any markdown formatting as per the instructions, keeping it plain text.

Wait, should I mention the growth rate? The user didn't ask for it, but it's a related fact. Maybe just stick to what they asked for unless they specify more details. I'll focus on the population and the key facts they provided.

Double-checking the population number is crucial. I think it's 332.8 million as of the latest data, so I'll use that. For the key facts, I'll list the most populous and least populous states. That should cover the user's needs and provide a comprehensive answer in the JSON format they requested.


content: Rome is the capital of Italy

Regular expression#

[5]:
response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    messages=[
        {"role": "assistant", "content": "What is the capital of France?"},
    ],
    temperature=0,
    max_tokens=2048,
    extra_body={"regex": "(Paris|London)"},
)

print_highlight(
    f"reasoning_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-05-30 02:30:57] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 2, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:57] Decode batch. #running-req: 1, #token: 41, token usage: 0.00, cuda graph: False, gen throughput (token/s): 94.19, #queue-req: 0
[2025-05-30 02:30:57] Decode batch. #running-req: 1, #token: 81, token usage: 0.00, cuda graph: False, gen throughput (token/s): 104.68, #queue-req: 0
[2025-05-30 02:30:58] INFO:     127.0.0.1:49856 - "POST /v1/chat/completions HTTP/1.1" 200 OK
reasoning_content: Alright, so the user just asked, "What is the capital of France?" Hmm, that's a pretty straightforward question. I should make sure I provide a clear and accurate answer. Let me think, Paris is definitely the capital. But wait, is there any chance I might have mixed it up with another country? No, I'm pretty sure France's capital is Paris. Maybe I should double-check that to be safe. Yeah, Paris is correct. I'll go with that.


content: London

Structural Tag#

[6]:
tool_get_current_weather = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, e.g. 'San Francisco'",
                },
                "state": {
                    "type": "string",
                    "description": "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'",
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["city", "state", "unit"],
        },
    },
}

tool_get_current_date = {
    "type": "function",
    "function": {
        "name": "get_current_date",
        "description": "Get the current date and time for a given timezone",
        "parameters": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
                }
            },
            "required": ["timezone"],
        },
    },
}

schema_get_current_weather = tool_get_current_weather["function"]["parameters"]
schema_get_current_date = tool_get_current_date["function"]["parameters"]


def get_messages():
    return [
        {
            "role": "system",
            "content": f"""
# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
You have access to the following functions:
Use the function 'get_current_weather' to: Get the current weather in a given location
{tool_get_current_weather["function"]}
Use the function 'get_current_date' to: Get the current date and time for a given timezone
{tool_get_current_date["function"]}
If you choose to call a function, ONLY reply in the following format:
<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{{"example_name": "example_value"}}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.""",
        },
        {
            "role": "assistant",
            "content": "You are in New York. Please get the current date and time, and the weather.",
        },
    ]


messages = get_messages()

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    messages=messages,
    response_format={
        "type": "structural_tag",
        "max_new_tokens": 2048,
        "structures": [
            {
                "begin": "<function=get_current_weather>",
                "schema": schema_get_current_weather,
                "end": "</function>",
            },
            {
                "begin": "<function=get_current_date>",
                "schema": schema_get_current_date,
                "end": "</function>",
            },
        ],
        "triggers": ["<function="],
    },
)

print_highlight(
    f"reasoning_content: {response.choices[0].message.reasoning_content}\n\ncontent: {response.choices[0].message.content}"
)
[2025-05-30 02:30:58] Prefill batch. #new-seq: 1, #new-token: 472, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:30:58] Decode batch. #running-req: 1, #token: 480, token usage: 0.02, cuda graph: False, gen throughput (token/s): 57.67, #queue-req: 0
[2025-05-30 02:30:58] Decode batch. #running-req: 1, #token: 520, token usage: 0.03, cuda graph: False, gen throughput (token/s): 100.60, #queue-req: 0
[2025-05-30 02:30:59] Decode batch. #running-req: 1, #token: 560, token usage: 0.03, cuda graph: False, gen throughput (token/s): 101.36, #queue-req: 0
[2025-05-30 02:30:59] Decode batch. #running-req: 1, #token: 600, token usage: 0.03, cuda graph: False, gen throughput (token/s): 101.41, #queue-req: 0
[2025-05-30 02:30:59] Decode batch. #running-req: 1, #token: 640, token usage: 0.03, cuda graph: False, gen throughput (token/s): 100.78, #queue-req: 0
[2025-05-30 02:31:00] Decode batch. #running-req: 1, #token: 680, token usage: 0.03, cuda graph: False, gen throughput (token/s): 98.60, #queue-req: 0
[2025-05-30 02:31:00] Decode batch. #running-req: 1, #token: 720, token usage: 0.04, cuda graph: False, gen throughput (token/s): 100.10, #queue-req: 0
[2025-05-30 02:31:01] INFO:     127.0.0.1:49856 - "POST /v1/chat/completions HTTP/1.1" 200 OK
reasoning_content: Okay, so the user is asking for the current date and time in New York and the weather there. I need to figure out how to get this information using the functions provided.

First, I think I should use the get_current_date function. It requires a timezone parameter. The user mentioned New York, so I should specify 'America/New_York' as the timezone. That should give me the current date and time in that location.

Next, for the weather, I should use get_current_weather. I need the city, state, and unit. The city is New York, the state is NY, and since the user didn't specify, I'll default to Celsius for the temperature unit.

I should structure the function calls step by step. First, call get_current_date with the timezone. Then, call get_current_weather with the city, state, and unit. I'll need to format each function call correctly, making sure to include the parameters in JSON format within the function call syntax.

Finally, I'll present both results to the user, showing the date/time and weather information clearly.


content: {"timezone": "America/New_York"}
{"city": "New York", "state": "NY", "unit": "celsius"}

Native API and SGLang Runtime (SRT)#

JSON#

Using Pydantic

[7]:
import requests
from pydantic import BaseModel, Field
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


messages = [
    {
        "role": "assistant",
        "content": "Give me the information and population of the capital of France in the JSON format.",
    },
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# Make API request
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 2048,
            "json_schema": json.dumps(CapitalInfo.model_json_schema()),
        },
    },
)
print(response.json())


full_text = response.json()["text"]
reasoning_content = full_text.split("</think>")[0]
content = full_text.split("</think>")[1]
print_highlight(f"reasoning_content: {reasoning_content}\n\ncontent: {content}")
[2025-05-30 02:31:05] Prefill batch. #new-seq: 1, #new-token: 22, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:31:05] Decode batch. #running-req: 1, #token: 41, token usage: 0.00, cuda graph: False, gen throughput (token/s): 8.64, #queue-req: 0
[2025-05-30 02:31:05] Decode batch. #running-req: 1, #token: 81, token usage: 0.00, cuda graph: False, gen throughput (token/s): 102.33, #queue-req: 0
[2025-05-30 02:31:06] Decode batch. #running-req: 1, #token: 121, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.04, #queue-req: 0
[2025-05-30 02:31:06] Decode batch. #running-req: 1, #token: 161, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.86, #queue-req: 0
[2025-05-30 02:31:07] Decode batch. #running-req: 1, #token: 201, token usage: 0.01, cuda graph: False, gen throughput (token/s): 87.02, #queue-req: 0
[2025-05-30 02:31:07] Decode batch. #running-req: 1, #token: 241, token usage: 0.01, cuda graph: False, gen throughput (token/s): 85.38, #queue-req: 0
[2025-05-30 02:31:07] Decode batch. #running-req: 1, #token: 281, token usage: 0.01, cuda graph: False, gen throughput (token/s): 84.64, #queue-req: 0
[2025-05-30 02:31:08] Decode batch. #running-req: 1, #token: 321, token usage: 0.02, cuda graph: False, gen throughput (token/s): 85.29, #queue-req: 0
[2025-05-30 02:31:08] Decode batch. #running-req: 1, #token: 361, token usage: 0.02, cuda graph: False, gen throughput (token/s): 89.60, #queue-req: 0
[2025-05-30 02:31:09] INFO:     127.0.0.1:51912 - "POST /generate HTTP/1.1" 200 OK
{'text': 'Okay, so the user is asking for the information and population of the capital of France in JSON format. Let me break this down.\n\nFirst, I need to identify the capital of France. I know that Paris is the capital, so that\'s straightforward. Now, I should find the most recent population data. I remember that the population of Paris has been growing, but I\'m not sure of the exact number. I think it\'s around 2 million, but I should verify that.\n\nWait, I should check the latest statistics to be accurate. Maybe I can recall that as of 2023, the population was approximately 2,150,000. That seems about right. I should make sure to include this number in the JSON.\n\nNext, I need to structure this information into a JSON format. JSON typically uses key-value pairs, so I\'ll create an object with keys like "city", "population", and "country". The city is Paris, the population is 2,150,000, and the country is France.\n\nI should also consider the format. The user wants it in JSON, so I\'ll make sure to use proper syntax with quotes and commas. I\'ll avoid any markdown since they specified that, so just plain JSON.\n\nPutting it all together, the JSON object will have the city, population, and country. I\'ll double-check the numbers to ensure accuracy. I think 2,150,000 is correct, but if I\'m unsure, I might mention that the data is approximate.\n\nFinally, I\'ll present the JSON without any additional text, just the code, as per the user\'s request. That should fulfill their query effectively.\n</think>{\n  "name": "Paris",\n  "population": 2150000\n}', 'meta_info': {'id': 'b628cbb1c5bf4f50ade3cf21712d7f58', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 23, 'completion_tokens': 374, 'cached_tokens': 1, 'e2e_latency': 4.016446590423584}}
reasoning_content: Okay, so the user is asking for the information and population of the capital of France in JSON format. Let me break this down.

First, I need to identify the capital of France. I know that Paris is the capital, so that's straightforward. Now, I should find the most recent population data. I remember that the population of Paris has been growing, but I'm not sure of the exact number. I think it's around 2 million, but I should verify that.

Wait, I should check the latest statistics to be accurate. Maybe I can recall that as of 2023, the population was approximately 2,150,000. That seems about right. I should make sure to include this number in the JSON.

Next, I need to structure this information into a JSON format. JSON typically uses key-value pairs, so I'll create an object with keys like "city", "population", and "country". The city is Paris, the population is 2,150,000, and the country is France.

I should also consider the format. The user wants it in JSON, so I'll make sure to use proper syntax with quotes and commas. I'll avoid any markdown since they specified that, so just plain JSON.

Putting it all together, the JSON object will have the city, population, and country. I'll double-check the numbers to ensure accuracy. I think 2,150,000 is correct, but if I'm unsure, I might mention that the data is approximate.

Finally, I'll present the JSON without any additional text, just the code, as per the user's request. That should fulfill their query effectively.


content: {
"name": "Paris",
"population": 2150000
}
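
The split above assumes the model always emits the end token. A slightly more defensive version (a minimal sketch, assuming "</think>" is the think_end_token for this parser):

full_text = response.json()["text"]
if "</think>" in full_text:
    # Split only on the first occurrence of the end token.
    reasoning_content, content = full_text.split("</think>", 1)
else:
    # The model never closed the reasoning section; treat everything as content.
    reasoning_content, content = "", full_text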

JSON Schema Directly

[8]:
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": text,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 2048,
            "json_schema": json_schema,
        },
    },
)

print_highlight(response.json())
[2025-05-30 02:31:09] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 22, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:31:09] Decode batch. #running-req: 1, #token: 27, token usage: 0.00, cuda graph: False, gen throughput (token/s): 97.44, #queue-req: 0
[2025-05-30 02:31:09] Decode batch. #running-req: 1, #token: 67, token usage: 0.00, cuda graph: False, gen throughput (token/s): 104.06, #queue-req: 0
[2025-05-30 02:31:10] Decode batch. #running-req: 1, #token: 107, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.50, #queue-req: 0
[2025-05-30 02:31:10] Decode batch. #running-req: 1, #token: 147, token usage: 0.01, cuda graph: False, gen throughput (token/s): 100.65, #queue-req: 0
[2025-05-30 02:31:10] Decode batch. #running-req: 1, #token: 187, token usage: 0.01, cuda graph: False, gen throughput (token/s): 95.92, #queue-req: 0
[2025-05-30 02:31:11] Decode batch. #running-req: 1, #token: 227, token usage: 0.01, cuda graph: False, gen throughput (token/s): 92.89, #queue-req: 0
[2025-05-30 02:31:11] Decode batch. #running-req: 1, #token: 267, token usage: 0.01, cuda graph: False, gen throughput (token/s): 94.75, #queue-req: 0
[2025-05-30 02:31:12] Decode batch. #running-req: 1, #token: 307, token usage: 0.01, cuda graph: False, gen throughput (token/s): 100.33, #queue-req: 0
[2025-05-30 02:31:12] Decode batch. #running-req: 1, #token: 347, token usage: 0.02, cuda graph: False, gen throughput (token/s): 95.36, #queue-req: 0
[2025-05-30 02:31:12] Decode batch. #running-req: 1, #token: 387, token usage: 0.02, cuda graph: False, gen throughput (token/s): 98.96, #queue-req: 0
[2025-05-30 02:31:13] Decode batch. #running-req: 1, #token: 427, token usage: 0.02, cuda graph: False, gen throughput (token/s): 98.64, #queue-req: 0
[2025-05-30 02:31:13] Decode batch. #running-req: 1, #token: 467, token usage: 0.02, cuda graph: False, gen throughput (token/s): 94.21, #queue-req: 0
[2025-05-30 02:31:14] Decode batch. #running-req: 1, #token: 507, token usage: 0.02, cuda graph: False, gen throughput (token/s): 94.03, #queue-req: 0
[2025-05-30 02:31:14] INFO:     127.0.0.1:51916 - "POST /generate HTTP/1.1" 200 OK
{'text': 'Okay, so the user is asking for the information and population of the capital of France in JSON format. Let me break this down. First, I need to identify what the capital of France is. I know that Paris is the capital, so that\'s the starting point.\n\nNext, I need to find the population of Paris. I remember that Paris is a major city with a large population, but I\'m not exactly sure of the current number. I think it\'s around 2 million, but I should double-check that. Maybe I can recall that it\'s approximately 2,150,000 as of recent estimates.\n\nNow, the user wants this information in JSON format. JSON stands for JavaScript Object Notation, which is a way to structure data. I need to create a JSON object that includes the key "capital" with the value "Paris" and another key "population" with the number I just thought of.\n\nI should make sure the JSON syntax is correct. That means using double quotes for keys and string values, and commas appropriately between key-value pairs. Also, the numbers should be in quotes if they\'re strings, but population is a number, so it should be without quotes.\n\nPutting it all together, the JSON object should look like this: {"capital": "Paris", "population": 2150000}. I should present this clearly so the user can easily understand and use the information.\n\nI wonder if the user needs more details, like the population figure\'s source or the exact year it was recorded. But since they didn\'t ask for that, I\'ll stick to the information requested. Maybe they just need a straightforward data structure for a program or a report.\n\nAlso, considering the user\'s request, they might be a student working on a project, or perhaps someone developing an app that requires capital cities and their populations. Either way, providing accurate and concise data is key.\n\nI should ensure that the JSON is correctly formatted to avoid any errors when the user tries to use it. No markdown, just plain JSON. That way, it\'s easy to copy and paste into their code or document.\n\nIn summary, I identified the capital, found the population number, structured it into a JSON object, and made sure the syntax is correct. I think this should meet the user\'s needs effectively.\n{\n "name": "Paris",\n "population": 2150000\n}', 'meta_info': {'id': '4c53a5df72e640bcbe000542857ff2df', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 23, 'completion_tokens': 497, 'cached_tokens': 22, 'e2e_latency': 5.104497194290161}}

EBNF#

[9]:
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "Give me the information of the capital of France.",
        "sampling_params": {
            "max_new_tokens": 2048,
            "temperature": 0,
            "n": 3,
            "ebnf": (
                "root ::= city | description\n"
                'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
                'description ::= city " is " status\n'
                'status ::= "the capital of " country\n'
                'country ::= "England" | "France" | "Germany" | "Italy"'
            ),
        },
        "stream": False,
        "return_logprob": False,
    },
)

print(response.json())
[2025-05-30 02:31:14] Prefill batch. #new-seq: 1, #new-token: 10, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:31:14] Prefill batch. #new-seq: 3, #new-token: 3, #cached-token: 30, token usage: 0.00, #running-req: 1, #queue-req: 0
[2025-05-30 02:31:14] Decode batch. #running-req: 3, #token: 89, token usage: 0.00, cuda graph: False, gen throughput (token/s): 147.71, #queue-req: 0
[2025-05-30 02:31:15] Decode batch. #running-req: 3, #token: 209, token usage: 0.01, cuda graph: False, gen throughput (token/s): 257.04, #queue-req: 0
[2025-05-30 02:31:15] Decode batch. #running-req: 3, #token: 329, token usage: 0.02, cuda graph: False, gen throughput (token/s): 282.86, #queue-req: 0
[2025-05-30 02:31:16] Decode batch. #running-req: 3, #token: 449, token usage: 0.02, cuda graph: False, gen throughput (token/s): 289.19, #queue-req: 0
[2025-05-30 02:31:16] Decode batch. #running-req: 3, #token: 569, token usage: 0.03, cuda graph: False, gen throughput (token/s): 286.94, #queue-req: 0
[2025-05-30 02:31:16] INFO:     127.0.0.1:39830 - "POST /generate HTTP/1.1" 200 OK
[{'text': "\nThe capital of France is Paris.\n\nThat's all the information I have.\n\nOkay, so I need to figure out the capital of France. I know that Paris is the capital, but I'm not entirely sure. Let me think about why I think that. I've heard it mentioned a lot, especially in movies and TV shows. People often go there for business or tourism. Also, I remember learning in school that Paris is a major city in France, known for landmarks like the Eiffel Tower and the Louvre Museum. Those places are famous worldwide, which makes me think that Paris is indeed the capital. Maybe I can cross-check this with some other sources or my notes. Wait, I don't have any other information right now, but based on what I know, Paris is the capital of France. I don't recall any other major city in France being referred to as the capital. So, I'm pretty confident that Paris is correct.\n</think>Paris is the capital of France", 'meta_info': {'id': 'c8c1ed6c51324f04aa1c657b17c94eca', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 201, 'cached_tokens': 10, 'e2e_latency': 2.344856023788452}}, {'text': "\nThe capital of France is Paris.\n\nThat's all the information I have.\n\nOkay, so I need to figure out the capital of France. I know that Paris is the capital, but I'm not entirely sure. Let me think about why I think that. I've heard it mentioned a lot, especially in movies and TV shows. People often go there for business or tourism. Also, I remember learning in school that Paris is a major city in France, known for landmarks like the Eiffel Tower and the Louvre Museum. Those places are famous worldwide, which makes me think that Paris is indeed the capital. Maybe I can cross-check this with some other sources or my notes. Wait, I don't have any other information right now, but based on what I know, Paris is the capital of France. I don't recall any other major city in France being referred to as the capital. So, I'm pretty confident that Paris is correct.\n</think>Paris is the capital of France", 'meta_info': {'id': 'c14e20d5cc08424d950077d9c6011578', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 201, 'cached_tokens': 10, 'e2e_latency': 2.344860792160034}}, {'text': "\nThe capital of France is Paris.\n\nThat's all the information I have.\n\nOkay, so I need to figure out the capital of France. I know that Paris is the capital, but I'm not entirely sure. Let me think about why I think that. I've heard it mentioned a lot, especially in movies and TV shows. People often go there for business or tourism. Also, I remember learning in school that Paris is a major city in France, known for landmarks like the Eiffel Tower and the Louvre Museum. Those places are famous worldwide, which makes me think that Paris is indeed the capital. Maybe I can cross-check this with some other sources or my notes. Wait, I don't have any other information right now, but based on what I know, Paris is the capital of France. I don't recall any other major city in France being referred to as the capital. So, I'm pretty confident that Paris is correct.\n</think>Paris is the capital of France", 'meta_info': {'id': '33bb698f7c4d418eb2fc90ae6ddd944a', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 11, 'completion_tokens': 201, 'cached_tokens': 10, 'e2e_latency': 2.3448638916015625}}]
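
With "n": 3, the endpoint returns a list of completions. Each one contains free-form reasoning followed by the grammar-constrained answer after the end token, which you can extract like this (a minimal sketch):

for item in response.json():
    # Keep only the grammar-constrained part after the reasoning section.
    print(item["text"].split("</think>")[-1])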

Regular expression#

[10]:
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "Paris is the capital of",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 2048,
            "regex": "(France|England)",
        },
    },
)
print(response.json())
[2025-05-30 02:31:16] Prefill batch. #new-seq: 1, #new-token: 5, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:31:17] Decode batch. #running-req: 1, #token: 31, token usage: 0.00, cuda graph: False, gen throughput (token/s): 166.48, #queue-req: 0
[2025-05-30 02:31:17] Decode batch. #running-req: 1, #token: 71, token usage: 0.00, cuda graph: False, gen throughput (token/s): 104.75, #queue-req: 0
[2025-05-30 02:31:17] Decode batch. #running-req: 1, #token: 111, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.74, #queue-req: 0
[2025-05-30 02:31:18] Decode batch. #running-req: 1, #token: 151, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.91, #queue-req: 0
[2025-05-30 02:31:18] Decode batch. #running-req: 1, #token: 191, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.34, #queue-req: 0
[2025-05-30 02:31:18] Decode batch. #running-req: 1, #token: 231, token usage: 0.01, cuda graph: False, gen throughput (token/s): 100.87, #queue-req: 0
[2025-05-30 02:31:19] Decode batch. #running-req: 1, #token: 271, token usage: 0.01, cuda graph: False, gen throughput (token/s): 99.26, #queue-req: 0
[2025-05-30 02:31:19] Decode batch. #running-req: 1, #token: 311, token usage: 0.02, cuda graph: False, gen throughput (token/s): 92.52, #queue-req: 0
[2025-05-30 02:31:20] Decode batch. #running-req: 1, #token: 351, token usage: 0.02, cuda graph: False, gen throughput (token/s): 84.83, #queue-req: 0
[2025-05-30 02:31:20] Decode batch. #running-req: 1, #token: 391, token usage: 0.02, cuda graph: False, gen throughput (token/s): 84.91, #queue-req: 0
[2025-05-30 02:31:21] Decode batch. #running-req: 1, #token: 431, token usage: 0.02, cuda graph: False, gen throughput (token/s): 94.79, #queue-req: 0
[2025-05-30 02:31:21] Decode batch. #running-req: 1, #token: 471, token usage: 0.02, cuda graph: False, gen throughput (token/s): 97.83, #queue-req: 0
[2025-05-30 02:31:21] Decode batch. #running-req: 1, #token: 511, token usage: 0.02, cuda graph: False, gen throughput (token/s): 105.04, #queue-req: 0
[2025-05-30 02:31:22] Decode batch. #running-req: 1, #token: 551, token usage: 0.03, cuda graph: False, gen throughput (token/s): 103.42, #queue-req: 0
[2025-05-30 02:31:22] Decode batch. #running-req: 1, #token: 591, token usage: 0.03, cuda graph: False, gen throughput (token/s): 102.74, #queue-req: 0
[2025-05-30 02:31:23] Decode batch. #running-req: 1, #token: 631, token usage: 0.03, cuda graph: False, gen throughput (token/s): 102.58, #queue-req: 0
[2025-05-30 02:31:23] Decode batch. #running-req: 1, #token: 671, token usage: 0.03, cuda graph: False, gen throughput (token/s): 102.34, #queue-req: 0
[2025-05-30 02:31:23] Decode batch. #running-req: 1, #token: 711, token usage: 0.03, cuda graph: False, gen throughput (token/s): 102.12, #queue-req: 0
[2025-05-30 02:31:24] Decode batch. #running-req: 1, #token: 751, token usage: 0.04, cuda graph: False, gen throughput (token/s): 103.87, #queue-req: 0
[2025-05-30 02:31:24] Decode batch. #running-req: 1, #token: 791, token usage: 0.04, cuda graph: False, gen throughput (token/s): 102.88, #queue-req: 0
[2025-05-30 02:31:25] Decode batch. #running-req: 1, #token: 831, token usage: 0.04, cuda graph: False, gen throughput (token/s): 103.52, #queue-req: 0
[2025-05-30 02:31:25] Decode batch. #running-req: 1, #token: 871, token usage: 0.04, cuda graph: False, gen throughput (token/s): 102.41, #queue-req: 0
[2025-05-30 02:31:25] Decode batch. #running-req: 1, #token: 911, token usage: 0.04, cuda graph: False, gen throughput (token/s): 102.02, #queue-req: 0
[2025-05-30 02:31:26] Decode batch. #running-req: 1, #token: 951, token usage: 0.05, cuda graph: False, gen throughput (token/s): 102.53, #queue-req: 0
[2025-05-30 02:31:26] Decode batch. #running-req: 1, #token: 991, token usage: 0.05, cuda graph: False, gen throughput (token/s): 101.92, #queue-req: 0
[2025-05-30 02:31:26] Decode batch. #running-req: 1, #token: 1031, token usage: 0.05, cuda graph: False, gen throughput (token/s): 102.20, #queue-req: 0
[2025-05-30 02:31:27] Decode batch. #running-req: 1, #token: 1071, token usage: 0.05, cuda graph: False, gen throughput (token/s): 100.49, #queue-req: 0
[2025-05-30 02:31:27] Decode batch. #running-req: 1, #token: 1111, token usage: 0.05, cuda graph: False, gen throughput (token/s): 101.06, #queue-req: 0
[2025-05-30 02:31:28] Decode batch. #running-req: 1, #token: 1151, token usage: 0.06, cuda graph: False, gen throughput (token/s): 101.95, #queue-req: 0
[2025-05-30 02:31:28] Decode batch. #running-req: 1, #token: 1191, token usage: 0.06, cuda graph: False, gen throughput (token/s): 100.83, #queue-req: 0
[2025-05-30 02:31:28] Decode batch. #running-req: 1, #token: 1231, token usage: 0.06, cuda graph: False, gen throughput (token/s): 102.17, #queue-req: 0
[2025-05-30 02:31:29] Decode batch. #running-req: 1, #token: 1271, token usage: 0.06, cuda graph: False, gen throughput (token/s): 101.50, #queue-req: 0
[2025-05-30 02:31:29] Decode batch. #running-req: 1, #token: 1311, token usage: 0.06, cuda graph: False, gen throughput (token/s): 101.43, #queue-req: 0
[2025-05-30 02:31:30] Decode batch. #running-req: 1, #token: 1351, token usage: 0.07, cuda graph: False, gen throughput (token/s): 101.42, #queue-req: 0
[2025-05-30 02:31:30] Decode batch. #running-req: 1, #token: 1391, token usage: 0.07, cuda graph: False, gen throughput (token/s): 101.24, #queue-req: 0
[2025-05-30 02:31:30] Decode batch. #running-req: 1, #token: 1431, token usage: 0.07, cuda graph: False, gen throughput (token/s): 100.98, #queue-req: 0
[2025-05-30 02:31:31] Decode batch. #running-req: 1, #token: 1471, token usage: 0.07, cuda graph: False, gen throughput (token/s): 100.52, #queue-req: 0
[2025-05-30 02:31:31] Decode batch. #running-req: 1, #token: 1511, token usage: 0.07, cuda graph: False, gen throughput (token/s): 100.82, #queue-req: 0
[2025-05-30 02:31:32] Decode batch. #running-req: 1, #token: 1551, token usage: 0.08, cuda graph: False, gen throughput (token/s): 101.39, #queue-req: 0
[2025-05-30 02:31:32] Decode batch. #running-req: 1, #token: 1591, token usage: 0.08, cuda graph: False, gen throughput (token/s): 102.62, #queue-req: 0
[2025-05-30 02:31:32] Decode batch. #running-req: 1, #token: 1631, token usage: 0.08, cuda graph: False, gen throughput (token/s): 102.56, #queue-req: 0
[2025-05-30 02:31:33] Decode batch. #running-req: 1, #token: 1671, token usage: 0.08, cuda graph: False, gen throughput (token/s): 102.03, #queue-req: 0
[2025-05-30 02:31:33] Decode batch. #running-req: 1, #token: 1711, token usage: 0.08, cuda graph: False, gen throughput (token/s): 102.16, #queue-req: 0
[2025-05-30 02:31:34] Decode batch. #running-req: 1, #token: 1751, token usage: 0.09, cuda graph: False, gen throughput (token/s): 101.97, #queue-req: 0
[2025-05-30 02:31:34] Decode batch. #running-req: 1, #token: 1791, token usage: 0.09, cuda graph: False, gen throughput (token/s): 102.13, #queue-req: 0
[2025-05-30 02:31:34] Decode batch. #running-req: 1, #token: 1831, token usage: 0.09, cuda graph: False, gen throughput (token/s): 102.44, #queue-req: 0
[2025-05-30 02:31:35] Decode batch. #running-req: 1, #token: 1871, token usage: 0.09, cuda graph: False, gen throughput (token/s): 101.92, #queue-req: 0
[2025-05-30 02:31:35] Decode batch. #running-req: 1, #token: 1911, token usage: 0.09, cuda graph: False, gen throughput (token/s): 102.02, #queue-req: 0
[2025-05-30 02:31:36] Decode batch. #running-req: 1, #token: 1951, token usage: 0.10, cuda graph: False, gen throughput (token/s): 102.90, #queue-req: 0
[2025-05-30 02:31:36] Decode batch. #running-req: 1, #token: 1991, token usage: 0.10, cuda graph: False, gen throughput (token/s): 102.61, #queue-req: 0
[2025-05-30 02:31:36] Decode batch. #running-req: 1, #token: 2031, token usage: 0.10, cuda graph: False, gen throughput (token/s): 102.74, #queue-req: 0
[2025-05-30 02:31:37] INFO:     127.0.0.1:39834 - "POST /generate HTTP/1.1" 200 OK
{'text': ' France, and the \n\\( n \\)  \\( m \\) \\( k \\) \\( l \\) \\( m \\) \\( k \\) \\( l \\) ... [the model degenerates into this repeating \\( m \\) \\( k \\) \\( l \\) pattern until it hits the token limit] ...', 'meta_info': {'id': '5066e1e740df4fd5a3d912e4b86abd02', 'finish_reason': {'type': 'length', 'length': 2048}, 'prompt_tokens': 6, 'completion_tokens': 2048, 'cached_tokens': 1, 'e2e_latency': 20.29635262489319}}

Structural Tag#
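
The structural tag lets the model generate free-form text until one of the `triggers` (here `<function=`) appears; once a matching `begin` tag is produced, the content up to the corresponding `end` tag is constrained to the associated JSON schema.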

[11]:
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
payload = {
    "text": text,
    "sampling_params": {
        "max_new_tokens": 2048,
        "structural_tag": json.dumps(
            {
                "type": "structural_tag",
                "structures": [
                    {
                        "begin": "<function=get_current_weather>",
                        "schema": schema_get_current_weather,
                        "end": "</function>",
                    },
                    {
                        "begin": "<function=get_current_date>",
                        "schema": schema_get_current_date,
                        "end": "</function>",
                    },
                ],
                "triggers": ["<function="],
            }
        ),
    },
}


# Send POST request to the API endpoint
response = requests.post(f"http://localhost:{port}/generate", json=payload)
print_highlight(response.json())
[2025-05-30 02:31:37] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 22, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:31:37] Decode batch. #running-req: 1, #token: 40, token usage: 0.00, cuda graph: False, gen throughput (token/s): 97.50, #queue-req: 0
[2025-05-30 02:31:37] Decode batch. #running-req: 1, #token: 80, token usage: 0.00, cuda graph: False, gen throughput (token/s): 102.61, #queue-req: 0
[2025-05-30 02:31:37] Decode batch. #running-req: 1, #token: 120, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.74, #queue-req: 0
[2025-05-30 02:31:38] Decode batch. #running-req: 1, #token: 160, token usage: 0.01, cuda graph: False, gen throughput (token/s): 105.38, #queue-req: 0
[2025-05-30 02:31:38] Decode batch. #running-req: 1, #token: 200, token usage: 0.01, cuda graph: False, gen throughput (token/s): 104.75, #queue-req: 0
[2025-05-30 02:31:39] Decode batch. #running-req: 1, #token: 240, token usage: 0.01, cuda graph: False, gen throughput (token/s): 103.49, #queue-req: 0
[2025-05-30 02:31:39] Decode batch. #running-req: 1, #token: 280, token usage: 0.01, cuda graph: False, gen throughput (token/s): 102.99, #queue-req: 0
[2025-05-30 02:31:39] Decode batch. #running-req: 1, #token: 320, token usage: 0.02, cuda graph: False, gen throughput (token/s): 103.10, #queue-req: 0
[2025-05-30 02:31:40] Decode batch. #running-req: 1, #token: 360, token usage: 0.02, cuda graph: False, gen throughput (token/s): 103.06, #queue-req: 0
[2025-05-30 02:31:40] Decode batch. #running-req: 1, #token: 400, token usage: 0.02, cuda graph: False, gen throughput (token/s): 101.94, #queue-req: 0
[2025-05-30 02:31:41] Decode batch. #running-req: 1, #token: 440, token usage: 0.02, cuda graph: False, gen throughput (token/s): 100.12, #queue-req: 0
[2025-05-30 02:31:41] Decode batch. #running-req: 1, #token: 480, token usage: 0.02, cuda graph: False, gen throughput (token/s): 101.93, #queue-req: 0
[2025-05-30 02:31:41] Decode batch. #running-req: 1, #token: 520, token usage: 0.03, cuda graph: False, gen throughput (token/s): 104.17, #queue-req: 0
[2025-05-30 02:31:42] Decode batch. #running-req: 1, #token: 560, token usage: 0.03, cuda graph: False, gen throughput (token/s): 98.93, #queue-req: 0
[2025-05-30 02:31:42] INFO:     127.0.0.1:37886 - "POST /generate HTTP/1.1" 200 OK
{'text': 'Alright, so I need to find the information and population of the capital of France in JSON format. Hmm, okay, let\'s break this down step by step. First, I know that the capital of France is Paris, so that\'s my starting point. I remember that Paris is a major city in France, famous for places like the Eiffel Tower and the Louvre Museum. But I need specific data about its population.\n\nWait, I\'m not entirely sure about the exact population of Paris. I think it\'s a big city, so maybe around 2 million people? But I\'m not certain. Let me think—if I were in France, I\'d probably hear Paris referred to as the largest city, so that makes sense. But I should probably check some sources to be accurate.\n\nI recall that the population figures for cities can vary a lot depending on the year and the source. Maybe I should look for the latest official data. The French government probably releases such information regularly. I think they use censuses or other surveys to get population numbers. \n\nI also remember that when it comes to cities in France, Paris is indeed the most populous. Other major cities like Lyon, Marseille, and Lille are also significant, but Paris leads. So, if I were to create a JSON structure, I should include the basic info about Paris as the capital and then the population number.\n\nWait, should I include the exact figure? I think I saw once that Paris has a population around 2.1 million. But I\'m not 100% sure if that\'s the latest number. Maybe it\'s slightly more or less. I should maybe phrase it as an approximate figure to be cautious because population numbers can fluctuate year to year.\n\nOkay, putting this together, the JSON should have a key like "capital" with Paris as the value, and another key like "population" with the approximate number. But I need to present it in a proper JSON format. So, it would be something like:\n\n{\n "capital": "Paris",\n "population": 2100000\n}\n\nI think that\'s it. I should mention that the population figure is approximate and might change. Also, since the user asked for information and population, maybe they expect it in a more detailed JSON with sub-keys. But since the initial answer was straightforward, perhaps that\'s sufficient for now.\n\n\nHere is the information about the capital of France, Paris, presented in JSON format:\n\n```json\n{\n "capital": "Paris",\n "population": 2100000\n}\n```\n\nPlease note that the population figure is approximate and may vary slightly depending on the source and year of data.', 'meta_info': {'id': '8fc879bf8e8a499b851787aa344ee399', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 23, 'completion_tokens': 556, 'cached_tokens': 22, 'e2e_latency': 5.422600984573364}}
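
The function-call segments can then be recovered client-side. Below is a minimal sketch (not part of the SGLang API), assuming calls are emitted exactly as `<function=NAME>{...}</function>` per the tags configured above:

import json
import re


def extract_function_calls(generated_text):
    # Each constrained call is emitted as <function=NAME>{...args...}</function>;
    # capture the function name and parse the JSON arguments between the tags.
    pattern = r"<function=(\w+)>(.*?)</function>"
    return [
        {"name": name, "arguments": json.loads(args)}
        for name, args in re.findall(pattern, generated_text, re.DOTALL)
    ]


print(extract_function_calls(response.json()["text"]))
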
[12]:
terminate_process(server_process)
[2025-05-30 02:31:42] Child process unexpectedly failed with an exit code 9. pid=2294859
[2025-05-30 02:31:42] Child process unexpectedly failed with an exit code 9. pid=2294788

Offline Engine API#
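
The same constrained-decoding features are available without an HTTP server via `sgl.Engine`; the reasoning parser and grammar backend are passed as constructor arguments.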

[13]:
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    reasoning_parser="deepseek-r1",
    grammar_backend="xgrammar",
)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.46s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.41s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.42s/it]

JSON#

Using Pydantic

[14]:
import json
from pydantic import BaseModel, Field


prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]


# Define the schema using Pydantic
class CapitalInfo(BaseModel):
    name: str = Field(..., pattern=r"^\w+$", description="Name of the capital city")
    population: int = Field(..., description="Population of the capital city")


sampling_params = {
    "temperature": 0,
    "top_p": 0.95,
    "max_new_tokens": 2048,
    "json_schema": json.dumps(CapitalInfo.model_json_schema()),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Generated text:
Sure, here's the information about the capital of China, Beijing, in JSON format:

```json
{
  "name": "Beijing",
  "capital": "Yes",
  "population": "Over 30 million",
  "founded": "1248",
  "Nickname": "The Heaven on Earth",
  "Location": "Northern China",
  "OfficialLanguages": [
    "Mandarin Chinese",
    "Bingyuan Chinese",
    "Tibetan",
    "Hui",
    "Mongolian",
    "Yugoslav",
    "Other"
  ],
  "KeySights": [
    "The Great Wall",
    "Tiananmen Square",
    "Forbidden City",
    "Beijing Museum",
    "Yuanmingyuan"
  ],
  "Climate": "Temperate"
}
```

Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Generated text:
Sure! Here's the information about the capital of France, Paris, in JSON format:

```json
{
  "name": "Paris",
  "country": "France",
  "coordinates": {
    "latitude": 48.8566,
    "longitude": 2.3522
  },
  "founded": "1340",
  "population": "9.7 million",
  "area": "105.5 square kilometers",
  "features": {
    "bridges": "The Eiffel Tower, Notre-Dame, and the Seine River",
    "landmarks": "The Louvre Museum, Montmartre, and the Champs-Élysées"
  },
  "elevation": "2 meters",
  "time_zone": "Central European Time (CET)"
}
```

Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Generated text:
Sure, here's the information about the capital of Ireland in JSON format:

```json
{
  "capital": "Dublin",
  "official_name": "Dublin City",
  "region": "Dublin",
  "coordinates": {
    "latitude": 53.3489,
    "longitude": -6.2009
  },
  "founded": "1543",
  "population": 1,234,567,
  "area": {
    "total": 123.45,
    "land": 112.34,
    "water": 11.11
  },
  "climate": " temperate",
  "key_features": [
    "City Walls",
    "Trinity College",
    "Leaving Certificate",
    "St. Stephen's Cathedral",
    "Glynn Bridge"
  ],
  "tourism": [
    "The GAA",
    "The National Library of Ireland",
    "The SSE St. Patrick's Cathedral",
    "The Phoenix Park",
    "The Book of Kells"
  ]
}
```

Let me know if you need any adjustments!
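
Since the model may wrap the constrained answer in extra prose, one way to recover a typed object is to parse the JSON block back through the same Pydantic model. A minimal sketch, assuming the structured answer is the outermost {...} block in the generated text:

import re


def parse_capital_info(generated_text):
    # Grab the outermost {...} block and validate it against CapitalInfo;
    # return None when nothing parseable (or schema-conforming) is found.
    match = re.search(r"\{.*\}", generated_text, re.DOTALL)
    if match is None:
        return None
    try:
        return CapitalInfo.model_validate_json(match.group(0))
    except ValueError:  # pydantic.ValidationError subclasses ValueError
        return None


print(parse_capital_info(outputs[0]["text"]))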

JSON Schema Directly

[15]:
prompts = [
    "Give me the information of the capital of China in the JSON format.",
    "Give me the information of the capital of France in the JSON format.",
    "Give me the information of the capital of Ireland in the JSON format.",
]

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

sampling_params = {"temperature": 0, "max_new_tokens": 2048, "json_schema": json_schema}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of China in the JSON format.
Generated text:
Sure, here's the information about the capital of China, Beijing, in JSON format:

```json
{
  "name": "Beijing",
  "capital": "Yes",
  "population": "Over 30 million",
  "founded": "1248",
  "Nickname": "The Heaven on Earth",
  "Location": "Northern China",
  "OfficialLanguages": [
    "Mandarin Chinese",
    "Bingyuan Chinese",
    "Tibetan",
    "Hui",
    "Mongolian",
    "Yugoslav",
    "Other"
  ],
  "KeySights": [
    "The Great Wall",
    "Tiananmen Square",
    "Forbidden City",
    "Beijing Museum",
    "Yuanmingyuan"
  ],
  "Climate": "Temperate"
}
```

Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of France in the JSON format.
Generated text:
Sure! Here's the information about the capital of France, Paris, in JSON format:

```json
{
  "name": "Paris",
  "country": "France",
  "coordinates": {
    "latitude": 48.8566,
    "longitude": 2.3522
  },
  "founded": "1340",
  "population": "9.7 million",
  "area": "105.5 square kilometers",
  "features": {
    "bridges": "The Eiffel Tower, Notre-Dame, and the Seine River",
    "landmarks": "The Louvre Museum, Montmartre, and the Champs-Élysées"
  },
  "elevation": "2 meters",
  "time_zone": "Central European Time (CET)"
}
```

Let me know if you need any other information!
===============================
Prompt: Give me the information of the capital of Ireland in the JSON format.
Generated text:
Sure, here's the information about the capital of Ireland in JSON format:

```json
{
  "capital": "Dublin",
  "official_name": "Dublin City",
  "region": "Dublin",
  "coordinates": {
    "latitude": 53.3489,
    "longitude": -6.2009
  },
  "founded": "1241",
  "population": 1,234,567,
  "area": {
    "total": 123.45,
    "land": 112.34,
    "water": 11.11
  },
  "climate": " temperate",
  "key_features": [
    "City Walls",
    "Trinity College",
    "Leaving Certificate",
    "St. Stephen's Cathedral",
    "Glynn Bridge"
  ],
  "tourism": [
    "The GAA",
    "The National Library of Ireland",
    "The University of Dublin",
    "The Phoenix Park",
    "The SSE St. Patrick's Cathedral Quarter"
  ]
}
```

Let me know if you need any adjustments!
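
The raw schema string can also be checked without Pydantic, for example with the third-party `jsonschema` package (an assumption; it is not bundled with SGLang). A hypothetical parsed answer is used for illustration:

import jsonschema  # third-party validator, assumed installed separately

# Raises jsonschema.ValidationError if the instance violates the schema above.
jsonschema.validate(
    instance={"name": "Beijing", "population": 21540000},  # illustrative values
    schema=json.loads(json_schema),
)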

EBNF#

[16]:
prompts = [
    "Give me the information of the capital of France.",
    "Give me the information of the capital of Germany.",
    "Give me the information of the capital of Italy.",
]

sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "ebnf": (
        "root ::= city | description\n"
        'city ::= "London" | "Paris" | "Berlin" | "Rome"\n'
        'description ::= city " is " status\n'
        'status ::= "the capital of " country\n'
        'country ::= "England" | "France" | "Germany" | "Italy"'
    ),
}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Give me the information of the capital of France.
Generated text:
The capital of France is Paris.

Sure, let me elaborate.

The capital of France is Paris, which is one of the most famous cities in the world. It's located in the northern part of the country, situated on the Seine River. Paris has a rich history and is known for its landmarks such as the Eiffel Tower, the Louvre Museum, the Notre-Dame Cathedral, and the Arc de Triomphe. It's also the headquarters of many international organizations and businesses. Paris is celebrated for its cultural heritage, culinary arts, and vibrant arts scene. Its cuisine, known as French cuisine, is renowned for
===============================
Prompt: Give me the information of the capital of Germany.
Generated text: 1. How many letters does the name of the capital city have?2. What is the population of the capital city?3. What is the name of the capital city?

1. The capital city is Berlin. It has 6 letters.
2. As of the latest estimates, the population of Berlin is around 3.7 million.
3. The name of the capital city is Berlin.
</think>Rome is the capital of Italy
===============================
Prompt: Give me the information of the capital of Italy.
Generated text:  the capital of Italy is Milan.

- its geographical location is in Northern Italy, between Lake Como and Lake Iseo, and the Alps.

- it is a major city in Italy, serving as the administrative center.

- the city is known for its rich history and cultural heritage.

- it is also a significant economic center.

- the capital of Italy is Milan.

Formalize these points into a structured summary with proper English punctuation and grammar, ensuring each point is clearly presented.
Okay, so I need to figure out how to structure the information about Milan, the capital of Italy, into a clear and formal summary. Let me go
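
Note how the second sample switches to a grammar-conforming sentence immediately after `</think>`: the EBNF constraint applies only to tokens emitted after the reasoning section. A small sketch for isolating that constrained tail, assuming `</think>` is the think-end token:

def constrained_tail(generated_text, think_end_token="</think>"):
    # Everything after the reasoning delimiter is the grammar-constrained part;
    # if the model never closed its think section, return the text unchanged.
    _, sep, tail = generated_text.partition(think_end_token)
    return tail if sep else generated_text


print(constrained_tail(outputs[1]["text"]))  # -> "Rome is the capital of Italy"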

Regular expression#

[17]:
prompts = [
    "Please provide information about London as a major global city:",
    "Please provide information about Paris as a major global city:",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95, "regex": "(France|England)"}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Please provide information about London as a major global city:
Generated text:  its economy, culture, and history.1. 1.1 Introduction: The introduction should briefly describe London as a major global city. It should be around 3-4 sentences. Include at least one specific example of London's economy, culture, or history in the introduction.2. 1.2 Economy: The economy section should describe London's economy in detail, including at least five key sectors. For each sector, provide a brief description and example of how it contributes to the economy.3. 1.3 Culture: The culture section should describe London's cultural significance, including at least three aspects. For each aspect
===============================
Prompt: Please provide information about Paris as a major global city:
Generated text:  key facts and figures, cultural significance, and major attractions.

1. Key facts and figures about Paris:
   - Population
   - Area
   - Number of inhabitants per square kilometer
   - Number of museums
   - Number of universities
   - Number of bridges
   - Number of streets
   - Average annual temperature
   - Average annual precipitation

2. Cultural significance of Paris:
   - Historical importance
   - Contributions to modern art
   - Important landmarks and museums
   - Role in the French revolution
   - Current political significance

3. Major attractions:
   - Eiffel Tower
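
As with the other constraints, the regex applies to the post-reasoning portion of the output, so a quick client-side check looks like this (a sketch; assumes `</think>` delimits the reasoning section):

import re

# True when the grammar-constrained tail matches the regex from sampling_params.
tail = outputs[0]["text"].split("</think>")[-1].strip()
print(re.fullmatch(r"(France|England)", tail) is not None)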

[18]:
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
prompts = [text]


sampling_params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "max_new_tokens": 2048,
    "structural_tag": json.dumps(
        {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_current_weather>",
                    "schema": schema_get_current_weather,
                    "end": "</function>",
                },
                {
                    "begin": "<function=get_current_date>",
                    "schema": schema_get_current_date,
                    "end": "</function>",
                },
            ],
            "triggers": ["<function="],
        }
    ),
}


# Generate with the offline engine
outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: <|begin▁of▁sentence|><|Assistant|>Give me the information and population of the capital of France in the JSON format.<|end▁of▁sentence|><|Assistant|><think>

Generated text: Okay, so I need to figure out the capital of France and then provide its population in a JSON format. Let me start by recalling what I know about France. I know that Paris is the capital, right? It's the most populous city in France and also the country's capital. Now, I need to find out the current population of Paris.

I remember that population figures can change over time due to births, deaths, and migration. I think the population number I've heard before is around 2 million. But I'm not entirely sure if that's the most recent number. I should probably double-check that. Maybe I can think about recent years. I know that in 2017, Paris had a population of about 2,165,000. I think that number has grown a bit since then.

I also recall that major cities in France tend to have larger populations than smaller ones, but Paris is definitely the biggest. So, if 2017 was around 2.165 million, by 2022, which is the year I think the assistant mentioned, it had grown to about 2.177 million. That seems reasonable, with an increase of roughly 0.012 million, or 12,000 people. That makes sense considering factors like economic growth, tourism, and migration into the city.

So, putting this together, the JSON format should include a key "capital" with the value "Paris" and another key "population" with the value of approximately 2,177,000. I should present this in a JSON structure, making sure it's properly formatted with commas and quotes.

Wait, am I missing anything? The user asked for the population, so I don't need to include other cities or regions, just Paris. Also, the population number should be accurate, so I should confirm the exact figure. Maybe I can think about recent news or official statistics. I believe the latest official estimate for 2022 is indeed around 2.177 million.

I don't think I need to worry about units here, just the numerical value. So, the JSON should be straightforward with the keys as specified. Let me just write that out to make sure it's correct.

Yes, that should do it. I think I've covered all the necessary points and the information is accurate based on what I remember and can verify.
</think>

```json
{
  "capital": "Paris",
  "population": 2177000
}
```
[19]:
llm.shutdown()