Tool and Function Calling#
This guide demonstrates how to use SGLang's function calling functionality.
OpenAI Compatible API#
Launching the Server#
[1]:
from openai import OpenAI
import json
from sglang.utils import wait_for_server, print_highlight, terminate_process
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0"  # qwen25
)

wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:25:00] server_args=ServerArgs(model_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=35895, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=206007858, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser='qwen25', enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:25:10 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:25:10 TP0] Init torch distributed begin.
[2025-04-13 23:25:11 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:25:11 TP0] Load weight begin. avail mem=59.09 GB
[2025-04-13 23:25:11 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:25:11 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.69it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.61it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:01<00:00, 1.57it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.60it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.60it/s]
[2025-04-13 23:25:15 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=40.72 GB, mem usage=18.37 GB.
[2025-04-13 23:25:15 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-04-13 23:25:15 TP0] Memory pool end. avail mem=39.42 GB
[2025-04-13 23:25:15 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:25:15 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=32768
[2025-04-13 23:25:15] INFO: Started server process [2784710]
[2025-04-13 23:25:15] INFO: Waiting for application startup.
[2025-04-13 23:25:15] INFO: Application startup complete.
[2025-04-13 23:25:15] INFO: Uvicorn running on http://0.0.0.0:35895 (Press CTRL+C to quit)
[2025-04-13 23:25:16] INFO: 127.0.0.1:43808 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:25:16] INFO: 127.0.0.1:43810 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:25:16 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:25:19] INFO: 127.0.0.1:43824 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:25:19] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of actual performance.
Note that --tool-call-parser defines the parser used to interpret responses (see the example launch command after this list). Currently supported parsers include:
llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).
mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).
qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). In particular, for QwQ the reasoning parser can be enabled together with the tool call parser; details can be found in the reasoning parser documentation.
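For example, to serve a Llama 3.x model with its matching parser, the launch command is analogous to the one above. This is only a sketch; it assumes you have access to the meta-llama/Llama-3.1-8B-Instruct weights.

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")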
Define Tools for Function Call#
Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes the tool name, a description, and the parameters it accepts, defined as typed properties.
[2]:
# Define tools
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]
Define Messages#
[3]:
def get_messages():
    return [
        {
            "role": "user",
            "content": "What's the weather like in Boston today? Output a reasoning before act, then use the tools to help you.",
        }
    ]


messages = get_messages()
Initialize the Client#
[4]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id
[2025-04-13 23:25:21] INFO: 127.0.0.1:52520 - "GET /v1/models HTTP/1.1" 200 OK
Non-Streaming Request#
[5]:
# Non-streaming mode test
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.1,
    top_p=0.95,
    max_tokens=1024,
    stream=False,  # Non-streaming
    tools=tools,
)

print_highlight("Non-stream response:")
print(response_non_stream)
print_highlight("==== content ====")
print(response_non_stream.choices[0].message.content)
print_highlight("==== tool_calls ====")
print(response_non_stream.choices[0].message.tool_calls)
[2025-04-13 23:25:21 TP0] Prefill batch. #new-seq: 1, #new-token: 281, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:08 TP0] Decode batch. #running-req: 1, #token: 314, token usage: 0.02, gen throughput (token/s): 0.75, #queue-req: 0,
[2025-04-13 23:26:09 TP0] Decode batch. #running-req: 1, #token: 354, token usage: 0.02, gen throughput (token/s): 64.36, #queue-req: 0,
[2025-04-13 23:26:09] INFO: 127.0.0.1:52520 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ChatCompletion(id='8a20e4188d4b4569a0cca449c07f7df7', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content='To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name "Boston", the state "MA" (which is the two-letter abbreviation for Massachusetts), and specifying the unit of temperature in "fahrenheit" since it\'s commonly used in the United States.', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "fahrenheit"}', name='get_current_weather'), type='function')], reasoning_content=None), matched_stop=None)], created=1744586721, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=95, prompt_tokens=281, total_tokens=376, completion_tokens_details=None, prompt_tokens_details=None))
To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name "Boston", the state "MA" (which is the two-letter abbreviation for Massachusetts), and specifying the unit of temperature in "fahrenheit" since it's commonly used in the United States.
[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{"city": "Boston", "state": "MA", "unit": "fahrenheit"}', name='get_current_weather'), type='function')]
Handle Tools#
When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.
[6]:
name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name
arguments_non_stream = (
    response_non_stream.choices[0].message.tool_calls[0].function.arguments
)

print_highlight(f"Non-stream function call name: {name_non_stream}")
print_highlight(f"Non-stream function call arguments: {arguments_non_stream}")
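Because the non-streaming response returns the complete argument string, it can be decoded and dispatched directly. A minimal sketch, assuming the get_current_weather implementation defined later in this guide:

# Sketch: decode the JSON arguments and call a matching local implementation.
args = json.loads(arguments_non_stream)
print(get_current_weather(**args))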
Streaming Request#
[7]:
# Streaming mode test
print_highlight("Streaming response:")
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.1,
    top_p=0.95,
    max_tokens=1024,
    stream=True,  # Enable streaming
    tools=tools,
)

texts = ""
tool_calls = []
name = ""
arguments = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        texts += chunk.choices[0].delta.content
    if chunk.choices[0].delta.tool_calls:
        tool_calls.append(chunk.choices[0].delta.tool_calls[0])

print_highlight("==== Text ====")
print(texts)

print_highlight("==== Tool Call ====")
for tool_call in tool_calls:
    print(tool_call)
[2025-04-13 23:26:09] INFO: 127.0.0.1:52520 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-04-13 23:26:09 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 280, token usage: 0.01, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:10 TP0] Decode batch. #running-req: 1, #token: 299, token usage: 0.01, gen throughput (token/s): 56.17, #queue-req: 0,
[2025-04-13 23:26:10 TP0] Decode batch. #running-req: 1, #token: 339, token usage: 0.02, gen throughput (token/s): 60.07, #queue-req: 0,
[2025-04-13 23:26:11 TP0] Decode batch. #running-req: 1, #token: 379, token usage: 0.02, gen throughput (token/s): 61.62, #queue-req: 0,
To determine the current weather in Boston, I will use the `get_current_weather` function by providing the city name, state, and the temperature unit in which you prefer the information. Since the state of Boston is Massachusetts, the state abbreviation is 'MA'. For the temperature unit, I will use Fahrenheit as it is commonly used in the United States.
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='', name='get_current_weather'), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='{"city": "', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='Boston"', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments=', "state": "', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='MA"', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments=', "unit": "', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='f', name=None), type='function')
ChoiceDeltaToolCall(index=None, id='0', function=ChoiceDeltaToolCallFunction(arguments='ahrenheit"}', name=None), type='function')
Handle Tools#
When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly.
[8]:
# Parse and combine function call arguments
arguments = []
for tool_call in tool_calls:
    if tool_call.function.name:
        print_highlight(f"Streamed function call name: {tool_call.function.name}")
    if tool_call.function.arguments:
        arguments.append(tool_call.function.arguments)

# Combine all fragments into a single JSON string
full_arguments = "".join(arguments)
print_highlight(f"Streamed function call arguments: {full_arguments}")
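A single response can also contain several tool calls. In that case the streamed fragments can be grouped per call before joining; a minimal sketch, assuming each delta carries an index field (it may be None when only one call is present, as in the output above):

from collections import defaultdict

# Group streamed fragments per tool call, then reassemble each JSON argument string.
names, fragments = {}, defaultdict(list)
for tool_call in tool_calls:
    idx = tool_call.index or 0
    if tool_call.function.name:
        names[idx] = tool_call.function.name
    if tool_call.function.arguments:
        fragments[idx].append(tool_call.function.arguments)

combined_calls = [
    {"name": names.get(idx), "arguments": "".join(parts)}
    for idx, parts in sorted(fragments.items())
]
print(combined_calls)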
Define a Tool Function#
[9]:
# This is a demonstration; define a real function according to your usage.
def get_current_weather(city: str, state: str, unit: str):
    return (
        f"The weather in {city}, {state} is 85 degrees {unit}. It is "
        "partly cloudy, with highs in the 90's."
    )


available_tools = {"get_current_weather": get_current_weather}
Execute the Tool#
[10]:
call_data = json.loads(full_arguments)

messages.append(
    {
        "role": "user",
        "content": "",
        "tool_calls": {"name": "get_current_weather", "arguments": full_arguments},
    }
)

# Call the corresponding tool function
tool_name = messages[-1]["tool_calls"]["name"]
tool_to_call = available_tools[tool_name]
result = tool_to_call(**call_data)
print_highlight(f"Function call result: {result}")
messages.append({"role": "tool", "content": result, "name": tool_name})

print_highlight(f"Updated message history: {messages}")
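Note that the OpenAI chat format usually records the tool call on an assistant message and links the tool result to it via tool_call_id. A minimal sketch of that variant (the id value is illustrative; reuse the id returned by the server):

openai_style_messages = get_messages()
openai_style_messages.append(
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "0",  # illustrative; use the id from the server response
                "type": "function",
                "function": {"name": "get_current_weather", "arguments": full_arguments},
            }
        ],
    }
)
openai_style_messages.append(
    {
        "role": "tool",
        "tool_call_id": "0",  # must match the tool call id above
        "name": "get_current_weather",
        "content": result,
    }
)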
Send Results Back to Model#
[11]:
final_response = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.1,
    top_p=0.95,
    stream=False,
    tools=tools,
)
print_highlight("Non-stream response:")
print(final_response)

print_highlight("==== Text ====")
print(final_response.choices[0].message.content)
[2025-04-13 23:26:11 TP0] Prefill batch. #new-seq: 1, #new-token: 49, #cached-token: 279, token usage: 0.01, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:12] INFO: 127.0.0.1:52520 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ChatCompletion(id='859f9586e3ba456eab5bd2a64c3f8ea6', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The current weather in Boston, MA is 85 degrees Fahrenheit. It is partly cloudy, with expected highs in the 90's today.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=151645)], created=1744586771, model='Qwen/Qwen2.5-7B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=31, prompt_tokens=328, total_tokens=359, completion_tokens_details=None, prompt_tokens_details=None))
The current weather in Boston, MA is 85 degrees Fahrenheit. It is partly cloudy, with expected highs in the 90's today.
Native API and SGLang Runtime (SRT)#
[12]:
from transformers import AutoTokenizer
import requests
# generate an answer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
messages = get_messages()
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools,
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0.1,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]
print_highlight("==== Response ====")
print(gen_response)

# parse the response
parse_url = f"http://localhost:{port}/parse_function_call"

function_call_input = {
    "text": gen_response,
    "tool_call_parser": "qwen25",
    "tools": tools,
}

function_call_response = requests.post(parse_url, json=function_call_input)
function_call_response_json = function_call_response.json()

print_highlight("==== Text ====")
print(function_call_response_json["normal_text"])
print_highlight("==== Calls ====")
print("function name: ", function_call_response_json["calls"][0]["name"])
print("function arguments: ", function_call_response_json["calls"][0]["parameters"])
[2025-04-13 23:26:12 TP0] Prefill batch. #new-seq: 1, #new-token: 231, #cached-token: 55, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:12 TP0] Decode batch. #running-req: 1, #token: 289, token usage: 0.01, gen throughput (token/s): 37.15, #queue-req: 0,
[2025-04-13 23:26:12 TP0] Decode batch. #running-req: 1, #token: 329, token usage: 0.02, gen throughput (token/s): 72.30, #queue-req: 0,
[2025-04-13 23:26:13 TP0] Decode batch. #running-req: 1, #token: 369, token usage: 0.02, gen throughput (token/s): 68.97, #queue-req: 0,
[2025-04-13 23:26:13] INFO: 127.0.0.1:58140 - "POST /generate HTTP/1.1" 200 OK
To provide you with the current weather in Boston, I will first need to use the `get_current_weather` function. Since Boston is in the United States, I will assume the default unit for temperature is Fahrenheit, but I will check this to ensure accuracy.
Let's proceed with fetching the current weather in Boston, Massachusetts, in Fahrenheit.
<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA", "unit": "fahrenheit"}}
</tool_call>
[2025-04-13 23:26:13] INFO: 127.0.0.1:58148 - "POST /parse_function_call HTTP/1.1" 200 OK
To provide you with the current weather in Boston, I will first need to use the `get_current_weather` function. Since Boston is in the United States, I will assume the default unit for temperature is Fahrenheit, but I will check this to ensure accuracy.
Let's proceed with fetching the current weather in Boston, Massachusetts, in Fahrenheit.
function name: get_current_weather
function arguments: {"city": "Boston", "state": "MA", "unit": "fahrenheit"}
[13]:
terminate_process(server_process)
[2025-04-13 23:26:13] Child process unexpectedly failed with an exit code 9. pid=2785578
[2025-04-13 23:26:13] Child process unexpectedly failed with an exit code 9. pid=2785370
Offline Engine API#
[14]:
import sglang as sgl
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Tool, Function

llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")
tokenizer = llm.tokenizer_manager.tokenizer
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, tools=tools
)

sampling_params = {
    "max_new_tokens": 1024,
    "temperature": 0.1,
    "top_p": 0.95,
    "skip_special_tokens": False,
}

# 1) Offline generation
result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)
generated_text = result["text"]  # Assume there is only one prompt

print("=== Offline Engine Output Text ===")
print(generated_text)


# 2) Parse using FunctionCallParser
def convert_dict_to_tool(tool_dict: dict) -> Tool:
    function_dict = tool_dict.get("function", {})
    return Tool(
        type=tool_dict.get("type", "function"),
        function=Function(
            name=function_dict.get("name"),
            description=function_dict.get("description"),
            parameters=function_dict.get("parameters"),
        ),
    )


tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]

parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25")
normal_text, calls = parser.parse_non_stream(generated_text)

print("=== Parsing Result ===")
print("Normal text portion:", normal_text)
print("Function call portion:")
for call in calls:
    # call: ToolCallItem
    print(f"  - tool name: {call.name}")
    print(f"    parameters: {call.parameters}")

# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc.
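# (Sketch) For example, dispatch each parsed call to the local implementations
# defined earlier in this guide. Assumes available_tools and the json module
# from the previous cells are still in scope.
for call in calls:
    func = available_tools.get(call.name)
    if func is not None:
        print(func(**json.loads(call.parameters)))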
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.65it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.59it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:01<00:00, 1.58it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.62it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.61it/s]
=== Offline Engine Output Text ===
To provide you with the current weather in Boston, I will first need to use the `get_current_weather` function with the appropriate city, state, and unit for temperature. Since Boston is in Massachusetts, the state abbreviation is 'MA'. For the temperature unit, I will use Celsius as it is commonly used in many countries and can provide a different perspective compared to Fahrenheit.
<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA", "unit": "celsius"}}
</tool_call>
=== Parsing Result ===
Normal text portion: To provide you with the current weather in Boston, I will first need to use the `get_current_weather` function with the appropriate city, state, and unit for temperature. Since Boston is in Massachusetts, the state abbreviation is 'MA'. For the temperature unit, I will use Celsius as it is commonly used in many countries and can provide a different perspective compared to Fahrenheit.
Function call portion:
- tool name: get_current_weather
parameters: {"city": "Boston", "state": "MA", "unit": "celsius"}
[15]:
llm.shutdown()
How to support a new model?#
Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:
TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]",
]
Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model's specific function call format; a skeleton is sketched after this list. For example:
class NewModelDetector(BaseFormatDetector):
Add the new detector to the MultiFormatParser class that manages all the format detectors.
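A skeleton of such a detector might look roughly like the following. This is only a sketch: the tag strings are placeholders for the new model's format, and the exact hooks to override (for example has_tool_call and detect_and_parse) should be checked against BaseFormatDetector in your SGLang version.

class NewModelDetector(BaseFormatDetector):
    def __init__(self):
        super().__init__()
        # Placeholder tags; replace with the tokens the new model emits around tool calls.
        self.bot_token = "<new_tool_call>"
        self.eot_token = "</new_tool_call>"

    def has_tool_call(self, text: str) -> bool:
        # Cheap check used to decide whether the text contains a tool call at all.
        return self.bot_token in text

    def detect_and_parse(self, text, tools):
        # Split the generated text into the normal-text portion and the tool-call
        # portion, decode the latter into name/arguments pairs, and return them in
        # the same structure as the existing detectors (see Qwen25Detector).
        ...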