Tool and Function Calling#

This guide demonstrates how to use SGLang’s function calling functionality.

OpenAI Compatible API#

Launching the Server#

[1]:
from openai import OpenAI
import json
from sglang.utils import wait_for_server, print_highlight, terminate_process
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd
    import nest_asyncio

    nest_asyncio.apply()

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0"  # qwen25
)
wait_for_server(f"http://localhost:{port}")
[2025-05-30 02:25:31] server_args=ServerArgs(model_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='0.0.0.0', port=37324, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=426589854, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='qwen25', enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, 
disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-30 02:25:47] Attention backend not set. Use fa3 backend by default.
[2025-05-30 02:25:47] Init torch distributed begin.
[2025-05-30 02:25:49] Init torch distributed ends. mem usage=0.00 GB
[2025-05-30 02:25:49] init_expert_location from trivial
[2025-05-30 02:25:50] Load weight begin. avail mem=77.48 GB
[2025-05-30 02:25:51] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.55it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.43it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.39it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.42it/s]

[2025-05-30 02:25:54] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=46.51 GB, mem usage=30.97 GB.
[2025-05-30 02:25:54] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-05-30 02:25:54] Memory pool end. avail mem=45.21 GB
[2025-05-30 02:25:55] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=32768
[2025-05-30 02:25:55] INFO:     Started server process [2272902]
[2025-05-30 02:25:55] INFO:     Waiting for application startup.
[2025-05-30 02:25:55] INFO:     Application startup complete.
[2025-05-30 02:25:55] INFO:     Uvicorn running on http://0.0.0.0:37324 (Press CTRL+C to quit)
[2025-05-30 02:25:56] INFO:     127.0.0.1:55914 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:25:56] INFO:     127.0.0.1:55924 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-30 02:25:56] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:25:58] INFO:     127.0.0.1:55938 - "POST /generate HTTP/1.1" 200 OK
[2025-05-30 02:25:58] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.

Note that --tool-call-parser defines the parser used to interpret responses. Currently supported parsers include (see the example launch command after this list):

  • llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).

  • llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).

  • mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).

  • qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). In particular, for QwQ you can enable the reasoning parser together with the tool call parser; see the reasoning parser documentation for details.

  • deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).
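
For example, to serve one of the Llama 3.1 models listed above, you would launch the server with the matching parser. This is only a sketch mirroring the Qwen launch at the top of this page (adjust the model path and flags to your environment):

# Sketch: same launch helpers as in the first cell, but with the llama3 parser.
server_process_llama, port_llama = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0"  # llama3
)
wait_for_server(f"http://localhost:{port_llama}")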

Define Tools for Function Call#

Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and parameters (defined as a JSON Schema object with properties).

[2]:
# Define tools
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]

Define Messages#

[3]:
def get_messages():
    return [
        {
            "role": "user",
            "content": "What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.",
        }
    ]


messages = get_messages()

Initialize the Client#

[4]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id
[2025-05-30 02:26:01] INFO:     127.0.0.1:55954 - "GET /v1/models HTTP/1.1" 200 OK

Non-Streaming Request#

[5]:
# Non-streaming mode test
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0,
    top_p=0.95,
    max_tokens=1024,
    stream=False,  # Non-streaming
    tools=tools,
)
print_highlight("Non-stream response:")
print(response_non_stream)
print_highlight("==== content ====")
print(response_non_stream.choices[0].message.content)
print_highlight("==== tool_calls ====")
print(response_non_stream.choices[0].message.tool_calls)
[2025-05-30 02:26:01] Prefill batch. #new-seq: 1, #new-token: 281, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:01] Decode batch. #running-req: 1, #token: 314, token usage: 0.02, cuda graph: False, gen throughput (token/s): 6.28, #queue-req: 0
[2025-05-30 02:26:02] Decode batch. #running-req: 1, #token: 354, token usage: 0.02, cuda graph: False, gen throughput (token/s): 108.42, #queue-req: 0
[2025-05-30 02:26:02] Decode batch. #running-req: 1, #token: 394, token usage: 0.02, cuda graph: False, gen throughput (token/s): 106.80, #queue-req: 0
[2025-05-30 02:26:02] Decode batch. #running-req: 1, #token: 434, token usage: 0.02, cuda graph: False, gen throughput (token/s): 109.15, #queue-req: 0
[2025-05-30 02:26:02] INFO:     127.0.0.1:55954 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Non-stream response:
ChatCompletion(id='e843d2648b4f4f3cb987e0840fa83874', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content="To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name, state, and unit for temperature. Boston is located in Massachusetts, so the state abbreviation is 'MA'. For the temperature unit, since it's not specified, I will provide both Celsius and Fahrenheit options to give you a comprehensive view.\n\nReasoning: The `get_current_weather` function is the most appropriate tool to use for this query as it directly provides the current weather conditions for a specified location.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_P84gPecySAaY6SkXiKQluA', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "celsius"}', name='get_current_weather'), type='function', index=None), ChatCompletionMessageToolCall(id='call_RazFara7QIWx91Oet1JPcw', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "fahrenheit"}', name='get_current_weather'), type='function', index=None)], reasoning_content=None), matched_stop=None)], created=1748571961, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=172, prompt_tokens=281, total_tokens=453, completion_tokens_details=None, prompt_tokens_details=None))
==== content ====
To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name, state, and unit for temperature. Boston is located in Massachusetts, so the state abbreviation is 'MA'. For the temperature unit, since it's not specified, I will provide both Celsius and Fahrenheit options to give you a comprehensive view.

Reasoning: The `get_current_weather` function is the most appropriate tool to use for this query as it directly provides the current weather conditions for a specified location.
==== tool_calls ====
[ChatCompletionMessageToolCall(id='call_P84gPecySAaY6SkXiKQluA', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "celsius"}', name='get_current_weather'), type='function', index=None), ChatCompletionMessageToolCall(id='call_RazFara7QIWx91Oet1JPcw', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "fahrenheit"}', name='get_current_weather'), type='function', index=None)]

Handle Tools#

When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.

[6]:
name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name
arguments_non_stream = (
    response_non_stream.choices[0].message.tool_calls[0].function.arguments
)

print_highlight(f"Final streamed function call name: {name_non_stream}")
print_highlight(f"Final streamed function call arguments: {arguments_non_stream}")
Final streamed function call name: get_current_weather
Final streamed function call arguments: {"city": "Boston", "state": "MA", "unit": "celsius"}
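
The arguments field is a JSON string; decoding it gives a plain dict that you can pass to your Python function (the “Execute the Tool” section below does exactly this). A minimal sketch:

# Decode the JSON argument string into a dict of keyword arguments.
args_non_stream = json.loads(arguments_non_stream)
print_highlight(f"Decoded arguments: {args_non_stream}")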

Streaming Request#

[7]:
# Streaming mode test
print_highlight("Streaming response:")
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0,
    top_p=0.95,
    max_tokens=1024,
    stream=True,  # Enable streaming
    tools=tools,
)

texts = ""
tool_calls = []
name = ""
arguments = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        texts += chunk.choices[0].delta.content
    if chunk.choices[0].delta.tool_calls:
        tool_calls.append(chunk.choices[0].delta.tool_calls[0])
print_highlight("==== Text ====")
print(texts)

print_highlight("==== Tool Call ====")
for tool_call in tool_calls:
    print(tool_call)
Streaming response:
[2025-05-30 02:26:02] INFO:     127.0.0.1:55954 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-05-30 02:26:02] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 280, token usage: 0.01, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:03] Decode batch. #running-req: 1, #token: 302, token usage: 0.01, cuda graph: False, gen throughput (token/s): 99.37, #queue-req: 0
[2025-05-30 02:26:03] Decode batch. #running-req: 1, #token: 342, token usage: 0.02, cuda graph: False, gen throughput (token/s): 110.16, #queue-req: 0
[2025-05-30 02:26:03] Decode batch. #running-req: 1, #token: 382, token usage: 0.02, cuda graph: False, gen throughput (token/s): 111.06, #queue-req: 0
[2025-05-30 02:26:04] Decode batch. #running-req: 1, #token: 422, token usage: 0.02, cuda graph: False, gen throughput (token/s): 110.28, #queue-req: 0
==== Text ====
To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name, state, and unit for temperature. Boston is located in Massachusetts, so the state abbreviation is 'MA'. For the temperature unit, since it's not specified, I will provide both Celsius and Fahrenheit options to give you a comprehensive view.

Reasoning: The `get_current_weather` function is the most appropriate tool to use for this query as it directly provides the current weather conditions for a specified location.



==== Tool Call ====
ChoiceDeltaToolCall(index=0, id='call_6CoiLXFNRO-LvicLhzbnTw', function=ChoiceDeltaToolCallFunction(arguments='', name='get_current_weather'), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"city": "', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='Boston"', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=', "state": "', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='MA"', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments=', "unit": "', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='c', name=None), type='function')
ChoiceDeltaToolCall(index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='elsius"}', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='', name='get_current_weather'), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"city": "', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='Boston"', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments=', "state": "', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='MA"', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments=', "unit": "', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='f', name=None), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='ahrenheit"}', name=None), type='function')

Handle Tools#

When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.

[8]:
# Parse and combine function call arguments
arguments = []
for tool_call in tool_calls:
    if tool_call.function.name:
        print_highlight(f"Streamed function call name: {tool_call.function.name}")

    if tool_call.function.arguments:
        arguments.append(tool_call.function.arguments)

# Combine all fragments into a single JSON string
full_arguments = "".join(arguments)
print_highlight(f"streamed function call arguments: {full_arguments}")
Streamed function call name: get_current_weather
Streamed function call name: get_current_weather
streamed function call arguments: {"city": "Boston", "state": "MA", "unit": "celsius"}{"city": "Boston", "state": "MA", "unit": "fahrenheit"}
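
The snippet above joins the argument fragments of every streamed tool call into one string, which works for a single call but runs parallel calls together (as the concatenated output shows). A sketch that instead groups fragments by the index field of the streamed ChoiceDeltaToolCall objects:

# Group streamed argument fragments by tool call index, then decode each call.
calls_by_index = {}
for tool_call in tool_calls:
    entry = calls_by_index.setdefault(tool_call.index, {"name": None, "arguments": ""})
    if tool_call.function.name:
        entry["name"] = tool_call.function.name
    if tool_call.function.arguments:
        entry["arguments"] += tool_call.function.arguments

for index, entry in calls_by_index.items():
    print_highlight(f"Call {index}: {entry['name']}({json.loads(entry['arguments'])})")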

Define a Tool Function#

[9]:
# This is a demonstration; define a real function according to your use case.
def get_current_weather(city: str, state: str, unit: str):
    return (
        f"The weather in {city}, {state} is 85 degrees {unit}. It is "
        "partly cloudy, with highs in the 90's."
    )


available_tools = {"get_current_weather": get_current_weather}

Execute the Tool#

[10]:
messages.append(response_non_stream.choices[0].message)

# Call the corresponding tool function
tool_call = messages[-1].tool_calls[0]
tool_name = tool_call.function.name
tool_to_call = available_tools[tool_name]
result = tool_to_call(**(json.loads(tool_call.function.arguments)))
print_highlight(f"Function call result: {result}")
# messages.append({"role": "tool", "content": result, "name": tool_name})
messages.append(
    {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "content": str(result),
        "name": tool_name,
    }
)

print_highlight(f"Updated message history: {messages}")
Function call result: The weather in Boston, MA is 85 degrees celsius. It is partly cloudly, with highs in the 90's.
Updated message history: [{'role': 'user', 'content': "What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you."}, ChatCompletionMessage(content="To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name, state, and unit for temperature. Boston is located in Massachusetts, so the state abbreviation is 'MA'. For the temperature unit, since it's not specified, I will provide both Celsius and Fahrenheit options to give you a comprehensive view.\n\nReasoning: The `get_current_weather` function is the most appropriate tool to use for this query as it directly provides the current weather conditions for a specified location.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_P84gPecySAaY6SkXiKQluA', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "celsius"}', name='get_current_weather'), type='function', index=None), ChatCompletionMessageToolCall(id='call_RazFara7QIWx91Oet1JPcw', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "fahrenheit"}', name='get_current_weather'), type='function', index=None)], reasoning_content=None), {'role': 'tool', 'tool_call_id': 'call_P84gPecySAaY6SkXiKQluA', 'content': "The weather in Boston, MA is 85 degrees celsius. It is partly cloudly, with highs in the 90's.", 'name': 'get_current_weather'}]
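
The cell above executes only the first tool call. Since this request produced two parallel calls (Celsius and Fahrenheit), you would typically execute each one and append one tool message per call. A sketch that could replace the single-call handling above:

# Execute every tool call returned by the assistant and append one "tool"
# message per call, matching each result to its tool_call_id.
for tool_call in response_non_stream.choices[0].message.tool_calls:
    tool_fn = available_tools[tool_call.function.name]
    tool_result = tool_fn(**json.loads(tool_call.function.arguments))
    messages.append(
        {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": str(tool_result),
            "name": tool_call.function.name,
        }
    )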

Send Results Back to Model#

[11]:
final_response = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0,
    top_p=0.95,
    stream=False,
    tools=tools,
)
print_highlight("Non-stream response:")
print(final_response)

print_highlight("==== Text ====")
print(final_response.choices[0].message.content)
[2025-05-30 02:26:04] Prefill batch. #new-seq: 1, #new-token: 119, #cached-token: 384, token usage: 0.02, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:04] Decode batch. #running-req: 1, #token: 512, token usage: 0.03, cuda graph: False, gen throughput (token/s): 88.38, #queue-req: 0
[2025-05-30 02:26:05] Decode batch. #running-req: 1, #token: 552, token usage: 0.03, cuda graph: False, gen throughput (token/s): 105.47, #queue-req: 0
[2025-05-30 02:26:05] INFO:     127.0.0.1:55954 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Non-stream response:
ChatCompletion(id='9e28cf541d4a4625933df67fe8e11167', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="There seems to be an error in the weather data provided. The temperature of 85 degrees Celsius is not typical for Boston, as it would be extremely hot. Let's correct this by using the Fahrenheit unit instead.\n\nThe correct weather in Boston, MA, according to the data, is 85 degrees Fahrenheit. It is partly cloudy, with highs in the 90's.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=151645)], created=1748571964, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=79, prompt_tokens=503, total_tokens=582, completion_tokens_details=None, prompt_tokens_details=None))
==== Text ====
There seems to be an error in the weather data provided. The temperature of 85 degrees Celsius is not typical for Boston, as it would be extremely hot. Let's correct this by using the Fahrenheit unit instead.

The correct weather in Boston, MA, according to the data, is 85 degrees Fahrenheit. It is partly cloudy, with highs in the 90's.

Tool Choice Mode#

SGLang supports OpenAI’s tool_choice parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.

Supported Tool Choice Options#

  • ``tool_choice="required"``: Forces the model to call at least one tool

  • ``tool_choice={"type": "function", "function": {"name": "specific_function"}}``: Forces the model to call a specific function

Backend Compatibility#

Tool choice is fully supported with the Xgrammar backend, which is the default grammar backend (--grammar-backend xgrammar). However, it may not be fully supported with other backends such as outlines.

Example: Required Tool Choice#

[12]:
from openai import OpenAI
import json
from sglang.utils import wait_for_server, print_highlight, terminate_process
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd
    import nest_asyncio

    nest_asyncio.apply()

# Start a new server session for tool choice examples
server_process_tool_choice, port_tool_choice = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port_tool_choice}")

# Initialize client for tool choice examples
client_tool_choice = OpenAI(
    api_key="None", base_url=f"http://0.0.0.0:{port_tool_choice}/v1"
)
model_name_tool_choice = client_tool_choice.models.list().data[0].id

# Example with tool_choice="required" - forces the model to call a tool
messages_required = [
    {"role": "user", "content": "Hello, what is the capital of France?"}
]

# Define tools
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "unit"],
            },
        },
    }
]

response_required = client_tool_choice.chat.completions.create(
    model=model_name_tool_choice,
    messages=messages_required,
    temperature=0,
    max_tokens=1024,
    tools=tools,
    tool_choice="required",  # Force the model to call a tool
)

print_highlight("Response with tool_choice='required':")
print("Content:", response_required.choices[0].message.content)
print("Tool calls:", response_required.choices[0].message.tool_calls)
[2025-05-30 02:26:11] server_args=ServerArgs(model_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='0.0.0.0', port=30460, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=54154885, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='qwen25', enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, 
disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-30 02:26:23] Attention backend not set. Use fa3 backend by default.
[2025-05-30 02:26:23] Init torch distributed begin.
[2025-05-30 02:26:23] Init torch distributed ends. mem usage=0.00 GB
[2025-05-30 02:26:23] init_expert_location from trivial
[2025-05-30 02:26:23] Load weight begin. avail mem=57.96 GB
[2025-05-30 02:26:23] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.37it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.37it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.37it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.40it/s]

[2025-05-30 02:26:27] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=27.12 GB, mem usage=30.84 GB.
[2025-05-30 02:26:27] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-05-30 02:26:27] Memory pool end. avail mem=25.83 GB
[2025-05-30 02:26:27] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=32768
[2025-05-30 02:26:28] INFO:     Started server process [2275975]
[2025-05-30 02:26:28] INFO:     Waiting for application startup.
[2025-05-30 02:26:28] INFO:     Application startup complete.
[2025-05-30 02:26:28] INFO:     Uvicorn running on http://0.0.0.0:30460 (Press CTRL+C to quit)
[2025-05-30 02:26:28] INFO:     127.0.0.1:57598 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:26:29] INFO:     127.0.0.1:57600 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-30 02:26:29] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:31] INFO:     127.0.0.1:57612 - "POST /generate HTTP/1.1" 200 OK
[2025-05-30 02:26:31] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[2025-05-30 02:26:33] INFO:     127.0.0.1:38614 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:26:33] Prefill batch. #new-seq: 1, #new-token: 225, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:34] INFO:     127.0.0.1:38614 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response with tool_choice='required':
Content: None
Tool calls: [ChatCompletionMessageToolCall(id='call_WJHbEKYFRnew0RLA_mBN4g', function=Function(arguments='{"city": "Straßburg", "unit": "celsius"}', name='get_current_weather'), type='function', index=None)]

Example: Specific Function Choice#

[13]:
# Example with specific function choice - forces the model to call a specific function
messages_specific = [
    {"role": "user", "content": "What are the most attactive places in France?"}
]

response_specific = client_tool_choice.chat.completions.create(
    model=model_name_tool_choice,
    messages=messages_specific,
    temperature=0,
    max_tokens=1024,
    tools=tools,
    tool_choice={
        "type": "function",
        "function": {"name": "get_current_weather"},
    },  # Force the model to call the specific get_current_weather function
)

print_highlight("Response with specific function choice:")
print("Content:", response_specific.choices[0].message.content)
print("Tool calls:", response_specific.choices[0].message.tool_calls)

if response_specific.choices[0].message.tool_calls:
    tool_call = response_specific.choices[0].message.tool_calls[0]
    print(f"Called function: {tool_call.function.name}")
    print(f"Arguments: {tool_call.function.arguments}")
[2025-05-30 02:26:34] Prefill batch. #new-seq: 1, #new-token: 15, #cached-token: 211, token usage: 0.01, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:34] Decode batch. #running-req: 1, #token: 230, token usage: 0.01, cuda graph: False, gen throughput (token/s): 5.54, #queue-req: 0
[2025-05-30 02:26:35] INFO:     127.0.0.1:38614 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Response with specific function choice:
Content: None
Tool calls: [ChatCompletionMessageToolCall(id='call_okIeoNjtRyCWXqnWCyxsBw', function=Function(arguments='{"city": "Dijon", "unit": "celsius"}', name='get_current_weather'), type='function', index=None)]
Called function: get_current_weather
Arguments: {"city": "Dijon", "unit": "celsius"}

Native API and SGLang Runtime (SRT)#

[14]:
from transformers import AutoTokenizer
import requests

# generate an answer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

messages = get_messages()

input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools,
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]
print_highlight("==== Response ====")
print(gen_response)

# parse the response
parse_url = f"http://localhost:{port}/parse_function_call"

function_call_input = {
    "text": gen_response,
    "tool_call_parser": "qwen25",
    "tools": tools,
}

function_call_response = requests.post(parse_url, json=function_call_input)
function_call_response_json = function_call_response.json()

print_highlight("==== Text ====")
print(function_call_response_json["normal_text"])
print_highlight("==== Calls ====")
print("function name: ", function_call_response_json["calls"][0]["name"])
print("function arguments: ", function_call_response_json["calls"][0]["parameters"])
[2025-05-30 02:26:39] Prefill batch. #new-seq: 1, #new-token: 189, #cached-token: 55, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:26:39] Decode batch. #running-req: 1, #token: 254, token usage: 0.01, cuda graph: False, gen throughput (token/s): 1.18, #queue-req: 0
[2025-05-30 02:26:39] Decode batch. #running-req: 1, #token: 294, token usage: 0.01, cuda graph: False, gen throughput (token/s): 108.43, #queue-req: 0
[2025-05-30 02:26:39] Decode batch. #running-req: 1, #token: 334, token usage: 0.02, cuda graph: False, gen throughput (token/s): 108.76, #queue-req: 0
[2025-05-30 02:26:40] INFO:     127.0.0.1:50254 - "POST /generate HTTP/1.1" 200 OK
==== Response ====
To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name and the unit for temperature. Since you didn't specify the unit, I will assume you prefer the temperature in Celsius for this query.

Reasoning: The user wants to know the current weather conditions in Boston. Using the `get_current_weather` function will allow me to fetch and provide this information accurately.

<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Boston", "unit": "celsius"}}
</tool_call>
[2025-05-30 02:26:40] INFO:     127.0.0.1:50262 - "POST /parse_function_call HTTP/1.1" 200 OK
==== Text ====
To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name and the unit for temperature. Since you didn't specify the unit, I will assume you prefer the temperature in Celsius for this query.

Reasoning: The user wants to know the current weather conditions in Boston. Using the `get_current_weather` function will allow me to fetch and provide this information accurately.
==== Calls ====
function name:  get_current_weather
function arguments:  {"city": "Boston", "unit": "celsius"}
[15]:
terminate_process(server_process)
[2025-05-30 02:26:40] Child process unexpectedly failed with an exit code 9. pid=2273539
[2025-05-30 02:26:40] Child process unexpectedly failed with an exit code 9. pid=2273342

Offline Engine API#

[16]:
import sglang as sgl
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Tool, Function

llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")
tokenizer = llm.tokenizer_manager.tokenizer
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, tools=tools
)

sampling_params = {
    "max_new_tokens": 1024,
    "temperature": 0,
    "top_p": 0.95,
    "skip_special_tokens": False,
}

# 1) Offline generation
result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)
generated_text = result["text"]  # Assume there is only one prompt

print("=== Offline Engine Output Text ===")
print(generated_text)


# 2) Parse using FunctionCallParser
def convert_dict_to_tool(tool_dict: dict) -> Tool:
    function_dict = tool_dict.get("function", {})
    return Tool(
        type=tool_dict.get("type", "function"),
        function=Function(
            name=function_dict.get("name"),
            description=function_dict.get("description"),
            parameters=function_dict.get("parameters"),
        ),
    )


tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]

parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
normal_text, calls = parser.parse_non_stream(generated_text)

print("=== Parsing Result ===")
print("Normal text portion:", normal_text)
print("Function call portion:")
for call in calls:
    # call: ToolCallItem
    print(f"  - tool name: {call.name}")
    print(f"    parameters: {call.parameters}")

# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.36it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.24it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.19it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:03<00:00,  1.20it/s]

=== Offline Engine Output Text ===
To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name and the unit for temperature. Since you didn't specify the unit, I will assume you prefer the temperature in Celsius for this query.

Reasoning: The user wants to know the current weather conditions in Boston. Using the `get_current_weather` function will allow me to fetch and provide this information accurately.

<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Boston", "unit": "celsius"}}
</tool_call>
=== Parsing Result ===
Normal text portion: To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name and the unit for temperature. Since you didn't specify the unit, I will assume you prefer the temperature in Celsius for this query.

Reasoning: The user wants to know the current weather conditions in Boston. Using the `get_current_weather` function will allow me to fetch and provide this information accurately.
Function call portion:
  - tool name: get_current_weather
    parameters: {"city": "Boston", "unit": "celsius"}
[17]:
llm.shutdown()

Pythonic Tool Call Format (Llama-3.2 / Llama-3.3 / Llama-4)#

Some Llama models (such as Llama-3.2-1B, Llama-3.2-3B, Llama-3.3-70B, and Llama-4) support a “pythonic” tool call format, where the model outputs function calls as Python code, e.g.:

[get_current_weather(city="San Francisco", state="CA", unit="celsius")]
  • The output is a Python list of function calls, with arguments as Python literals (not JSON).

  • Multiple tool calls can be returned in the same list:

[get_current_weather(city="San Francisco", state="CA", unit="celsius"),
 get_current_weather(city="New York", state="NY", unit="fahrenheit")]

For more information, refer to Meta’s documentation on Zero shot function calling.
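
Because the output is ordinary Python syntax, you can also inspect it yourself (e.g., for debugging) with the standard-library ast module. This is only an illustration of the format; when the server is launched with --tool-call-parser pythonic, SGLang performs this parsing for you:

import ast

# Walk a pythonic tool call string: a list literal of function calls.
pythonic_output = '[get_current_weather(city="San Francisco", state="CA", unit="celsius")]'
call_list = ast.parse(pythonic_output, mode="eval").body  # ast.List of ast.Call nodes
for call in call_list.elts:
    name = call.func.id
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    print(name, kwargs)  # get_current_weather {'city': 'San Francisco', ...}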

How to enable#

  • Launch the server with --tool-call-parser pythonic

  • You may also specify --chat-template with the improved template for the model (e.g., --chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja). This is recommended because the model expects a special prompt format to reliably produce valid pythonic tool call outputs. The template ensures that the prompt structure (e.g., special tokens, message boundaries like <|eom|>, and function call delimiters) matches what the model was trained or fine-tuned on. If you do not use the correct chat template, tool calling may fail or produce inconsistent results. A launch sketch combining both flags follows this list.
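
A sketch that combines both flags, using the Llama 4 model and template path mentioned above (paths are illustrative; point them at your own checkout):

# Sketch: pythonic parser plus the recommended chat template for Llama 4.
server_process_pythonic, port_pythonic = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path meta-llama/Llama-4-Scout-17B-16E-Instruct --tool-call-parser pythonic --chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja"
)
wait_for_server(f"http://localhost:{port_pythonic}")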

Forcing Pythonic Tool Call Output Without a Chat Template#

If you don’t want to specify a chat template, you must give the model extremely explicit instructions in your messages to enforce pythonic output. For example, for Llama-3.2-1B-Instruct, you need:

[18]:
import openai

server_process, port = launch_server_cmd(
    " python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1"  # llama-3.2-1b-instruct
)
wait_for_server(f"http://localhost:{port}")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The name of the city or location.",
                    }
                },
                "required": ["location"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_tourist_attractions",
            "description": "Get a list of top tourist attractions for a given city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The name of the city to find attractions for.",
                    }
                },
                "required": ["city"],
            },
        },
    },
]


def get_messages():
    return [
        {
            "role": "system",
            "content": (
                "You are a travel assistant. "
                "When asked to call functions, ALWAYS respond ONLY with a python list of function calls, "
                "using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. "
                "Do NOT use JSON, do NOT use variables, do NOT use any other format. "
                "Here is an example:\n"
                '[get_weather(location="Paris"), get_tourist_attractions(city="Paris")]'
            ),
        },
        {
            "role": "user",
            "content": (
                "I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? "
                "Propose parallel tool calls at once, using the python list of function calls format as shown above."
            ),
        },
    ]


messages = get_messages()

client = openai.Client(base_url=f"http://localhost:{port}/v1", api_key="xxxxxx")
model_name = client.models.list().data[0].id


response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0,
    top_p=0.9,
    stream=False,  # Non-streaming
    tools=tools,
)
print_highlight("Non-stream response:")
print(response_non_stream)

response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0,
    top_p=0.9,
    stream=True,
    tools=tools,
)
texts = ""
tool_calls = []
name = ""
arguments = ""

for chunk in response_stream:
    if chunk.choices[0].delta.content:
        texts += chunk.choices[0].delta.content
    if chunk.choices[0].delta.tool_calls:
        tool_calls.append(chunk.choices[0].delta.tool_calls[0])

print_highlight("Streaming Response:")
print_highlight("==== Text ====")
print(texts)

print_highlight("==== Tool Call ====")
for tool_call in tool_calls:
    print(tool_call)

terminate_process(server_process)
[2025-05-30 02:27:09] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=31184, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=818427339, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='pythonic', enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, 
n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-30 02:27:20] Attention backend not set. Use fa3 backend by default.
[2025-05-30 02:27:20] Init torch distributed begin.
[2025-05-30 02:27:21] Init torch distributed ends. mem usage=0.00 GB
[2025-05-30 02:27:21] init_expert_location from trivial
[2025-05-30 02:27:22] Load weight begin. avail mem=60.49 GB
[2025-05-30 02:27:22] Using model weights format ['*.safetensors']
[2025-05-30 02:27:22] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.68it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.68it/s]

[2025-05-30 02:27:22] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=58.14 GB, mem usage=2.35 GB.
[2025-05-30 02:27:23] KV Cache is allocated. #tokens: 20480, K size: 0.31 GB, V size: 0.31 GB
[2025-05-30 02:27:23] Memory pool end. avail mem=57.28 GB
[2025-05-30 02:27:23] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-05-30 02:27:24] INFO:     Started server process [2281819]
[2025-05-30 02:27:24] INFO:     Waiting for application startup.
[2025-05-30 02:27:24] INFO:     Application startup complete.
[2025-05-30 02:27:24] INFO:     Uvicorn running on http://127.0.0.1:31184 (Press CTRL+C to quit)
[2025-05-30 02:27:24] INFO:     127.0.0.1:39024 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:27:25] INFO:     127.0.0.1:39028 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-30 02:27:25] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:27:26] INFO:     127.0.0.1:39030 - "POST /generate HTTP/1.1" 200 OK
[2025-05-30 02:27:26] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[2025-05-30 02:27:29] INFO:     127.0.0.1:39044 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:27:29] Prefill batch. #new-seq: 1, #new-token: 406, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:27:29] INFO:     127.0.0.1:39044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
Non-stream response:
ChatCompletion(id='100c1ffee7d04c9ab9997b3be946b617', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_l4n3NYmgTV-8KiWyRR8CJw', function=Function(arguments='{"location": "Tokyo"}', name='get_weather'), type='function', index=None), ChatCompletionMessageToolCall(id='call_wwsrXRiaThm9NR7vHZ-SWw', function=Function(arguments='{"city": "Tokyo"}', name='get_tourist_attractions'), type='function', index=None)], reasoning_content=None), matched_stop=None)], created=1748572049, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=20, prompt_tokens=407, total_tokens=427, completion_tokens_details=None, prompt_tokens_details=None))
[2025-05-30 02:27:29] INFO:     127.0.0.1:39044 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-05-30 02:27:29] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 406, token usage: 0.02, #running-req: 0, #queue-req: 0
[2025-05-30 02:27:29] Decode batch. #running-req: 1, #token: 420, token usage: 0.02, cuda graph: False, gen throughput (token/s): 6.59, #queue-req: 0
Streaming Response:
==== Text ====

==== Tool Call ====
ChoiceDeltaToolCall(index=0, id='call_w8nOVESsQjSm5TIZtzOcsQ', function=ChoiceDeltaToolCallFunction(arguments='{"location": "Tokyo"}', name='get_weather'), type='function')
ChoiceDeltaToolCall(index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"city": "Tokyo"}', name='get_tourist_attractions'), type='function')
[2025-05-30 02:27:29] Child process unexpectedly failed with an exit code 9. pid=2282215
[2025-05-30 02:27:29] Child process unexpectedly failed with an exit code 9. pid=2282145
Note:
The model may still default to JSON if it was heavily finetuned on that format. Prompt engineering (including examples) is the only way to increase the chance of pythonic output if you are not using a chat template.

How to support a new model?#

  1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:

TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]"
]
  2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:

class NewModelDetector(BaseFormatDetector):
  3. Add the new detector to the MultiFormatParser class that manages all the format detectors. A skeleton of such a detector is sketched below.
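
A skeleton of such a detector, added inside sglang/srt/function_call_parser.py next to the existing detectors. The method names below are assumptions modeled on the existing detector classes; check BaseFormatDetector in your SGLang version for the exact interface before implementing:

# Hypothetical skeleton -- verify the abstract methods required by
# BaseFormatDetector in your SGLang version; "<new_tool_tag>" is a placeholder.
class NewModelDetector(BaseFormatDetector):
    def has_tool_call(self, text: str) -> bool:
        # True if the model-specific tool tag appears in the generated text.
        return "<new_tool_tag>" in text

    def detect_and_parse(self, text: str, tools):
        # Split the generated text into the normal text portion and the
        # structured tool calls, and return them in the parser's expected format.
        ...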