Tool Parser#
This guide demonstrates how to use SGLang's function calling functionality.
Currently supported parsers:#
| Parser | Supported Models | Notes |
|---|---|---|
| `llama3` | Llama 3.1 / 3.2 / 3.3 | |
| `llama4` | Llama 4 | |
| `mistral` | Mistral | |
| `qwen25` | Qwen 2.5 | For QwQ, the reasoning parser can be enabled together with the tool call parser. See the reasoning parser docs. |
| `deepseekv3` | DeepSeek-v3 | |
| `gpt-oss` | GPT-OSS | The gpt-oss tool parser filters out analysis-channel events and preserves only normal text, so the content can be empty when the model's explanation is in the analysis channel. To work around this, complete the tool round by returning tool results as `role: "tool"` messages. |
| `pythonic` | Llama-3.2 / Llama-3.3 / Llama-4 | Model outputs function calls as Python code. Requires `--tool-call-parser pythonic`. |
OpenAI Compatible API#
Launching the Server#
[1]:
import json
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
from openai import OpenAI
server_process, port = launch_server_cmd(
"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0 --log-level warning" # qwen25
)
wait_for_server(f"http://localhost:{port}")
W0909 21:19:07.125000 399758 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:07.125000 399758 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:transformers.configuration_utils:`torch_dtype` is deprecated! Use `dtype` instead!
All deep_gemm operations loaded successfully!
W0909 21:19:15.173000 400161 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:15.173000 400161 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0909 21:19:15.369000 400162 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:15.369000 400162 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-09 21:19:15] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-09 21:19:16] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.49it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.44it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.42it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.45it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.45it/s]
Capturing batches (bs=1 avail_mem=24.69 GB): 100%|██████████| 3/3 [00:00<00:00, 8.51it/s]
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
Note that --tool-call-parser
defines the parser used to interpret responses.
Define Tools for Function Call#
Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes the tool's name, a description, and its parameters, each defined as a typed property.
[2]:
# Define tools
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
]
Define Messages#
[3]:
def get_messages():
return [
{
"role": "user",
"content": "What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.",
}
]
messages = get_messages()
Initialize the Client#
[4]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id
Non-Streaming Request#
[5]:
# Non-streaming mode test
response_non_stream = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0,
top_p=0.95,
max_tokens=1024,
stream=False, # Non-streaming
tools=tools,
)
print_highlight("Non-stream response:")
print_highlight(response_non_stream)
print_highlight("==== content ====")
print_highlight(response_non_stream.choices[0].message.content)
print_highlight("==== tool_calls ====")
print_highlight(response_non_stream.choices[0].message.tool_calls)
Handle Tools#
When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.
[6]:
name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name
arguments_non_stream = (
response_non_stream.choices[0].message.tool_calls[0].function.arguments
)
print_highlight(f"Final streamed function call name: {name_non_stream}")
print_highlight(f"Final streamed function call arguments: {arguments_non_stream}")
Streaming Request#
[7]:
# Streaming mode test
print_highlight("Streaming response:")
response_stream = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0,
top_p=0.95,
max_tokens=1024,
stream=True, # Enable streaming
tools=tools,
)
texts = ""
tool_calls = []
name = ""
arguments = ""
for chunk in response_stream:
if chunk.choices[0].delta.content:
texts += chunk.choices[0].delta.content
if chunk.choices[0].delta.tool_calls:
tool_calls.append(chunk.choices[0].delta.tool_calls[0])
print_highlight("==== Text ====")
print_highlight(texts)
print_highlight("==== Tool Call ====")
for tool_call in tool_calls:
print_highlight(tool_call)
Handle Tools#
When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.
[8]:
# Parse and combine function call arguments
arguments = []
for tool_call in tool_calls:
if tool_call.function.name:
print_highlight(f"Streamed function call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments.append(tool_call.function.arguments)
# Combine all fragments into a single JSON string
full_arguments = "".join(arguments)
print_highlight(f"streamed function call arguments: {full_arguments}")
Define a Tool Function#
[9]:
# This is a demonstration; define a real function according to your usage.
def get_current_weather(city: str, state: str, unit: str):
    return (
        f"The weather in {city}, {state} is 85 degrees {unit}. It is "
        "partly cloudy, with highs in the 90s."
)
available_tools = {"get_current_weather": get_current_weather}
Execute the Tool#
[10]:
messages.append(response_non_stream.choices[0].message)
# Call the corresponding tool function
tool_call = messages[-1].tool_calls[0]
tool_name = tool_call.function.name
tool_to_call = available_tools[tool_name]
result = tool_to_call(**(json.loads(tool_call.function.arguments)))
print_highlight(f"Function call result: {result}")
# messages.append({"role": "tool", "content": result, "name": tool_name})
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": str(result),
"name": tool_name,
}
)
print_highlight(f"Updated message history: {messages}")
Send Results Back to Model#
[11]:
final_response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0,
top_p=0.95,
stream=False,
tools=tools,
)
print_highlight("Non-stream response:")
print_highlight(final_response)
print_highlight("==== Text ====")
print_highlight(final_response.choices[0].message.content)
Native API and SGLang Runtime (SRT)#
[12]:
from transformers import AutoTokenizer
import requests
# generate an answer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
messages = get_messages()
input_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
tools=tools,
)
gen_url = f"http://localhost:{port}/generate"
gen_data = {
"text": input,
"sampling_params": {
"skip_special_tokens": False,
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 0.95,
},
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]
print_highlight("==== Response ====")
print_highlight(gen_response)
# parse the response
parse_url = f"http://localhost:{port}/parse_function_call"
function_call_input = {
"text": gen_response,
"tool_call_parser": "qwen25",
"tools": tools,
}
function_call_response = requests.post(parse_url, json=function_call_input)
function_call_response_json = function_call_response.json()
print_highlight("==== Text ====")
print(function_call_response_json["normal_text"])
print_highlight("==== Calls ====")
print("function name: ", function_call_response_json["calls"][0]["name"])
print("function arguments: ", function_call_response_json["calls"][0]["parameters"])
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA", "unit": "fahrenheit"}}
To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name, state abbreviation, and the unit for temperature. For Boston, the state is Massachusetts, which has the abbreviation 'MA'. I will use the 'fahrenheit' unit for the temperature.
function name: get_current_weather
function arguments: {"city": "Boston", "state": "MA", "unit": "fahrenheit"}
[13]:
terminate_process(server_process)
Offline Engine API#
[14]:
import sglang as sgl
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Tool, Function
llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")
tokenizer = llm.tokenizer_manager.tokenizer
input_ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, tools=tools
)
# Note: for the gpt-oss tool parser, add "no_stop_trim": True to the sampling
# params to make sure the tool call token <call> is not trimmed.
sampling_params = {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 0.95,
"skip_special_tokens": False,
}
# 1) Offline generation
result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)
generated_text = result["text"] # Assume there is only one prompt
print_highlight("=== Offline Engine Output Text ===")
print_highlight(generated_text)
# 2) Parse using FunctionCallParser
def convert_dict_to_tool(tool_dict: dict) -> Tool:
function_dict = tool_dict.get("function", {})
return Tool(
type=tool_dict.get("type", "function"),
function=Function(
name=function_dict.get("name"),
description=function_dict.get("description"),
parameters=function_dict.get("parameters"),
),
)
tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]
parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
normal_text, calls = parser.parse_non_stream(generated_text)
print_highlight("=== Parsing Result ===")
print("Normal text portion:", normal_text)
print_highlight("Function call portion:")
for call in calls:
# call: ToolCallItem
print_highlight(f" - tool name: {call.name}")
print_highlight(f" parameters: {call.parameters}")
# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc.
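# For example (a sketch, not part of the original notebook cell): dispatch the
# parsed calls to real Python functions by reusing the `available_tools`
# mapping defined earlier; `call.parameters` holds the arguments as a JSON string.
for call in calls:
    tool_fn = available_tools.get(call.name)
    if tool_fn is not None:
        tool_result = tool_fn(**json.loads(call.parameters))
        print_highlight(f"Tool `{call.name}` returned: {tool_result}")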
W0909 21:19:37.438000 399236 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:37.438000 399236 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
All deep_gemm operations loaded successfully!
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:transformers.configuration_utils:`torch_dtype` is deprecated! Use `dtype` instead!
W0909 21:19:48.017000 401782 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:48.017000 401782 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0909 21:19:48.189000 401783 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:19:48.189000 401783 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-09 21:19:48] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.78it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.76it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:01<00:00, 1.73it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.76it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.76it/s]
Capturing batches (bs=1 avail_mem=47.11 GB): 100%|██████████| 3/3 [00:00<00:00, 5.17it/s]
Let's proceed with fetching the weather data.
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA", "unit": "celsius"}}
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA", "unit": "fahrenheit"}}
Normal text portion: To provide you with the current weather in Boston, I will use the `get_current_weather` function. This function requires the city name, state abbreviation, and the temperature unit you prefer. For Boston, the state is Massachusetts, which has the abbreviation 'MA'. You didn't specify a unit, so I'll provide the temperature in both Celsius and Fahrenheit for your convenience.
Let's proceed with fetching the weather data.
[15]:
llm.shutdown()
Tool Choice Mode#
SGLang supports OpenAI’s tool_choice
parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.
Supported Tool Choice Options#
- `tool_choice="required"`: Forces the model to call at least one tool
- `tool_choice={"type": "function", "function": {"name": "specific_function"}}`: Forces the model to call a specific function
Backend Compatibility#
Tool choice is fully supported with the Xgrammar backend, which is the default grammar backend (--grammar-backend xgrammar
). However, it may not be fully supported with other backends such as outlines
.
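For example, the default backend can be made explicit at launch time (same model and parser as in this guide):

python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --grammar-backend xgrammar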
Example: Required Tool Choice#
[16]:
from openai import OpenAI
from sglang.utils import wait_for_server, print_highlight, terminate_process
from sglang.test.doc_patch import launch_server_cmd
# Start a new server session for tool choice examples
server_process_tool_choice, port_tool_choice = launch_server_cmd(
"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0 --log-level warning"
)
wait_for_server(f"http://localhost:{port_tool_choice}")
# Initialize client for tool choice examples
client_tool_choice = OpenAI(
api_key="None", base_url=f"http://0.0.0.0:{port_tool_choice}/v1"
)
model_name_tool_choice = client_tool_choice.models.list().data[0].id
# Example with tool_choice="required" - forces the model to call a tool
messages_required = [
{"role": "user", "content": "Hello, what is the capital of France?"}
]
# Define tools
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
response_required = client_tool_choice.chat.completions.create(
model=model_name_tool_choice,
messages=messages_required,
temperature=0,
max_tokens=1024,
tools=tools,
tool_choice="required", # Force the model to call a tool
)
print_highlight("Response with tool_choice='required':")
print("Content:", response_required.choices[0].message.content)
print("Tool calls:", response_required.choices[0].message.tool_calls)
W0909 21:20:08.144000 403301 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:08.144000 403301 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:transformers.configuration_utils:`torch_dtype` is deprecated! Use `dtype` instead!
All deep_gemm operations loaded successfully!
W0909 21:20:16.068000 403860 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:16.068000 403860 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0909 21:20:16.096000 403861 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:16.096000 403861 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-09 21:20:16] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-09 21:20:18] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.83it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.72it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:01<00:00, 1.68it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.69it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.70it/s]
Capturing batches (bs=1 avail_mem=45.55 GB): 100%|██████████| 3/3 [00:00<00:00, 6.95it/s]
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
Content: None
Tool calls: [ChatCompletionMessageToolCall(id='call_78a9a793f5514884a177ac11', function=Function(arguments='{"city": "Paris", "unit": "celsius"}', name='get_current_weather'), type='function', index=-1)]
Example: Specific Function Choice#
[17]:
# Example with specific function choice - forces the model to call a specific function
messages_specific = [
{"role": "user", "content": "What are the most attactive places in France?"}
]
response_specific = client_tool_choice.chat.completions.create(
model=model_name_tool_choice,
messages=messages_specific,
temperature=0,
max_tokens=1024,
tools=tools,
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
}, # Force the model to call the specific get_current_weather function
)
print_highlight("Response with specific function choice:")
print("Content:", response_specific.choices[0].message.content)
print("Tool calls:", response_specific.choices[0].message.tool_calls)
if response_specific.choices[0].message.tool_calls:
tool_call = response_specific.choices[0].message.tool_calls[0]
print_highlight(f"Called function: {tool_call.function.name}")
print_highlight(f"Arguments: {tool_call.function.arguments}")
Content: None
Tool calls: [ChatCompletionMessageToolCall(id='call_037805898e654fefb081af6b', function=Function(arguments='{"city": "Paris", "unit": "celsius"}', name='get_current_weather'), type='function', index=-1)]
[18]:
terminate_process(server_process_tool_choice)
Pythonic Tool Call Format (Llama-3.2 / Llama-3.3 / Llama-4)#
Some Llama models (such as Llama-3.2-1B, Llama-3.2-3B, Llama-3.3-70B, and Llama-4) support a “pythonic” tool call format, where the model outputs function calls as Python code, e.g.:
[get_current_weather(city="San Francisco", state="CA", unit="celsius")]
The output is a Python list of function calls, with arguments as Python literals (not JSON).
Multiple tool calls can be returned in the same list:
[get_current_weather(city="San Francisco", state="CA", unit="celsius"),
get_current_weather(city="New York", state="NY", unit="fahrenheit")]
For more information, refer to Meta’s documentation on Zero shot function calling.
Note that this feature is still under development on Blackwell.
How to enable#
Launch the server with
--tool-call-parser pythonic
You may also specify --chat-template with the improved template for the model (e.g.,
--chat-template=examples/chat_template/tool_chat_template_llama4_pythonic.jinja
). This is recommended because the model expects a special prompt format to reliably produce valid pythonic tool call outputs. The template ensures that the prompt structure (e.g., special tokens, message boundaries like <|eom|>
, and function call delimiters) matches what the model was trained or fine-tuned on. If you do not use the correct chat template, tool calling may fail or produce inconsistent results.
Forcing Pythonic Tool Call Output Without a Chat Template#
If you don’t want to specify a chat template, you must give the model extremely explicit instructions in your messages to enforce pythonic output. For example, for Llama-3.2-1B-Instruct
, you need:
[19]:
import openai
server_process, port = launch_server_cmd(
" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1 --log-level warning" # llama-3.2-1b-instruct
)
wait_for_server(f"http://localhost:{port}")
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The name of the city or location.",
}
},
"required": ["location"],
},
},
},
{
"type": "function",
"function": {
"name": "get_tourist_attractions",
"description": "Get a list of top tourist attractions for a given city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city to find attractions for.",
}
},
"required": ["city"],
},
},
},
]
def get_messages():
return [
{
"role": "system",
"content": (
"You are a travel assistant. "
"When asked to call functions, ALWAYS respond ONLY with a python list of function calls, "
"using this format: [func_name1(param1=value1, param2=value2), func_name2(param=value)]. "
"Do NOT use JSON, do NOT use variables, do NOT use any other format. "
"Here is an example:\n"
'[get_weather(location="Paris"), get_tourist_attractions(city="Paris")]'
),
},
{
"role": "user",
"content": (
"I'm planning a trip to Tokyo next week. What's the weather like and what are some top tourist attractions? "
"Propose parallel tool calls at once, using the python list of function calls format as shown above."
),
},
]
messages = get_messages()
client = openai.Client(base_url=f"http://localhost:{port}/v1", api_key="xxxxxx")
model_name = client.models.list().data[0].id
response_non_stream = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0,
top_p=0.9,
stream=False, # Non-streaming
tools=tools,
)
print_highlight("Non-stream response:")
print_highlight(response_non_stream)
response_stream = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=0,
top_p=0.9,
stream=True,
tools=tools,
)
texts = ""
tool_calls = []
name = ""
arguments = ""
for chunk in response_stream:
if chunk.choices[0].delta.content:
texts += chunk.choices[0].delta.content
if chunk.choices[0].delta.tool_calls:
tool_calls.append(chunk.choices[0].delta.tool_calls[0])
print_highlight("Streaming Response:")
print_highlight("==== Text ====")
print_highlight(texts)
print_highlight("==== Tool Call ====")
for tool_call in tool_calls:
print_highlight(tool_call)
terminate_process(server_process)
W0909 21:20:39.876000 405479 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:39.876000 405479 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
WARNING:transformers.configuration_utils:`torch_dtype` is deprecated! Use `dtype` instead!
All deep_gemm operations loaded successfully!
W0909 21:20:50.003000 406508 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:50.003000 406508 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0909 21:20:50.214000 406507 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0909 21:20:50.214000 406507 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-09-09 21:20:50] `torch_dtype` is deprecated! Use `dtype` instead!
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-09-09 21:20:52] MOE_RUNNER_BACKEND is not initialized, using triton backend
All deep_gemm operations loaded successfully!
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.18it/s]
Capturing batches (bs=1 avail_mem=4.95 GB): 100%|██████████| 3/3 [00:00<00:00, 11.63it/s]
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
Note: The model may still default to JSON if it was heavily fine-tuned on that format. Prompt engineering (including examples) is the only way to increase the chance of pythonic output if you are not using a chat template.
How to support a new model?#
1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model's tool tags. Currently supported tags include:

TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]",
]
2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model's specific function call format. For example (a fuller sketch follows this list):
class NewModelDetector(BaseFormatDetector):
3. Add the new detector to the MultiFormatParser class that manages all the format detectors.
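To make the shape of step 2 concrete, below is a minimal, self-contained sketch of the parsing logic such a detector implements. It is not SGLang's actual implementation: the tool tags are hypothetical, and the real BaseFormatDetector interface (exact method signatures and return types) should be checked in sglang/srt/function_call_parser.py before writing a detector.

import json
from typing import List, Tuple

BOT_TOKEN = "<new_tool>"   # hypothetical opening tag, for illustration only
EOT_TOKEN = "</new_tool>"  # hypothetical closing tag

class NewModelDetectorSketch:
    """Standalone sketch of the logic a BaseFormatDetector subclass implements."""

    def has_tool_call(self, text: str) -> bool:
        # Cheap check so the parser can skip generations with no tool call.
        return BOT_TOKEN in text

    def detect_and_parse(self, text: str) -> Tuple[str, List[dict]]:
        # Text before the first tag is normal output; each payload between the
        # tags is expected to be JSON such as
        # {"name": "get_current_weather", "arguments": {"city": "Boston"}}.
        normal_text, _, rest = text.partition(BOT_TOKEN)
        calls = []
        for chunk in rest.split(BOT_TOKEN):
            payload = chunk.split(EOT_TOKEN)[0].strip()
            if payload:
                obj = json.loads(payload)
                calls.append(
                    {
                        "name": obj["name"],
                        "parameters": json.dumps(obj.get("arguments", {})),
                    }
                )
        return normal_text.strip(), calls

The real detectors additionally handle streaming (partial chunks arriving token by token) and return SGLang's internal result types; once written, the class is registered with MultiFormatParser as described in step 3.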