Reasoning Parser#

SGLang supports parsing the reasoning content out of the “normal” (final answer) content for reasoning models such as DeepSeek-R1.

Supported Models & Parsers#

Model                  | Reasoning tags        | Parser                   | Notes
-----------------------|-----------------------|--------------------------|-------------------------------------------------
DeepSeek-R1 series     | <think> ... </think>  | deepseek-r1              | Supports all variants (R1, R1-0528, R1-Distill)
Standard Qwen3 models  | <think> ... </think>  | qwen3                    | Supports enable_thinking parameter
Qwen3-Thinking models  | <think> ... </think>  | qwen3 or qwen3-thinking  | Always generates thinking content
Kimi models            | ◁think▷ ... ◁/think▷  | kimi                     | Uses special thinking delimiters

Model-Specific Behaviors#

DeepSeek-R1 Family:

  • DeepSeek-R1: No <think> start tag, jumps directly to thinking content

  • DeepSeek-R1-0528: Generates both <think> start and </think> end tags

  • Both are handled by the same deepseek-r1 parser

Qwen3 Family:

  • Standard Qwen3 (e.g., Qwen3-2507): Use qwen3 parser, supports enable_thinking in chat templates

  • Qwen3-Thinking (e.g., Qwen3-235B-A22B-Thinking-2507): Use qwen3 or qwen3-thinking parser, always thinks

Kimi:

  • Kimi: Uses special ◁think▷ and ◁/think▷ tags

Usage#

Launching the Server#

Specify the --reasoning-parser option.

[1]:
import requests
from openai import OpenAI
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1"
)

wait_for_server(f"http://localhost:{port}")
W0814 06:16:29.465000 1203487 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:16:29.465000 1203487 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:16:32] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='0.0.0.0', port=38928, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.874, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=1070558583, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='deepseek-r1', tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', 
ds_sparse_decode_threshold=4096, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 06:16:33] Using default HuggingFace chat template with detected content format: string
W0814 06:16:40.639000 1204102 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:16:40.639000 1204102 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:16:41.997000 1204101 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:16:41.997000 1204101 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:16:43] Attention backend not explicitly specified. Use fa3 backend by default.
[2025-08-14 06:16:43] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 06:16:46] Init torch distributed ends. mem usage=30.02 GB
[2025-08-14 06:16:47] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 06:16:47] Load weight begin. avail mem=48.03 GB
[2025-08-14 06:16:47] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.27s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.18s/it]

[2025-08-14 06:16:50] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=32.16 GB, mem usage=15.88 GB.
[2025-08-14 06:16:50] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-08-14 06:16:50] Memory pool end. avail mem=30.58 GB
[2025-08-14 06:16:50] Capture cuda graph begin. This can take up to several minutes. avail mem=27.61 GB
[2025-08-14 06:16:51] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=27.49 GB): 100%|██████████| 3/3 [00:00<00:00,  6.56it/s]
[2025-08-14 06:16:52] Capture cuda graph end. Time elapsed: 1.91 s. mem usage=0.13 GB. avail mem=27.48 GB.
[2025-08-14 06:16:53] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=131072, available_gpu_mem=27.46 GB
[2025-08-14 06:16:53] INFO:     Started server process [1203487]
[2025-08-14 06:16:53] INFO:     Waiting for application startup.
[2025-08-14 06:16:53] INFO:     Application startup complete.
[2025-08-14 06:16:53] INFO:     Uvicorn running on http://0.0.0.0:38928 (Press CTRL+C to quit)
[2025-08-14 06:16:54] INFO:     127.0.0.1:43814 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 06:16:54] INFO:     127.0.0.1:43822 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 06:16:54] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:16:55] INFO:     127.0.0.1:43830 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 06:16:55] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI environment, so the throughput is not representative of actual performance.

Note that --reasoning-parser defines the parser used to interpret responses.
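
For example, serving a Qwen3-Thinking model uses the same flag with a different parser name. A minimal sketch reusing the launch helper from the cell above (the model path is illustrative; any Qwen3-Thinking checkpoint works the same way):

# Hedged sketch: launching with the qwen3-thinking parser (model path illustrative).
qwen3_server_process, qwen3_port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen3-235B-A22B-Thinking-2507 --host 0.0.0.0 --reasoning-parser qwen3-thinking"
)
wait_for_server(f"http://localhost:{qwen3_port}")

The remaining cells on this page continue to use the DeepSeek-R1-Distill server launched above.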

OpenAI Compatible API#

Using the OpenAI-compatible API, the contract follows the DeepSeek API design established with the release of DeepSeek-R1 (a sketch of the resulting message shape follows this list):

  • reasoning_content: The chain-of-thought (CoT) content.

  • content: The content of the final answer.
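
Concretely, when reasoning separation is enabled, the assistant message carries both fields side by side. A minimal sketch of the relevant part of the response (values are illustrative; other fields are omitted):

# Sketch: relevant fields of a chat completion message with separate_reasoning enabled.
message = {
    "role": "assistant",
    "reasoning_content": "I need to calculate the sum of 1 and 3. ...",  # chain of thought
    "content": "**Answer:** 4",  # final answer
}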

[2]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id

messages = [
    {
        "role": "user",
        "content": "What is 1+3?",
    }
]
[2025-08-14 06:16:59] INFO:     127.0.0.1:43846 - "GET /v1/models HTTP/1.1" 200 OK

Non-Streaming Request#

[3]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": True},
)
print_highlight("==== Reasoning ====")
print_highlight(response_non_stream.choices[0].message.reasoning_content)

print_highlight("==== Text ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-08-14 06:16:59] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:16:59] Decode batch. #running-req: 1, #token: 45, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.06, #queue-req: 0,
[2025-08-14 06:16:59] Decode batch. #running-req: 1, #token: 85, token usage: 0.00, cuda graph: True, gen throughput (token/s): 163.72, #queue-req: 0,
[2025-08-14 06:17:00] INFO:     127.0.0.1:43846 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Reasoning ====
I need to calculate the sum of 1 and 3.

Adding these two numbers together gives me 4.

So, the final answer is 4.
==== Text ====
Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Start with the number 1.**
2. **Add 3 to it.**

\[ 1 + 3 = 4 \]

**Answer:** \(\boxed{4}\)

Streaming Request#

[4]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-08-14 06:17:00] INFO:     127.0.0.1:43846 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-08-14 06:17:00] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:17:00] Decode batch. #running-req: 1, #token: 21, token usage: 0.00, cuda graph: True, gen throughput (token/s): 135.51, #queue-req: 0,
[2025-08-14 06:17:00] Decode batch. #running-req: 1, #token: 61, token usage: 0.00, cuda graph: True, gen throughput (token/s): 83.51, #queue-req: 0,
[2025-08-14 06:17:01] Decode batch. #running-req: 1, #token: 101, token usage: 0.00, cuda graph: True, gen throughput (token/s): 88.59, #queue-req: 0,
==== Reasoning ====
First, I identify the two numbers in the addition problem: 1 and 3.

Next, I add these two numbers together: 1 plus 3 equals 4.

Therefore, the final answer is 4.
==== Text ====


**Solution:**

We are asked to find the sum of 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).

Optionally, you can buffer the reasoning content and emit it as a single chunk (the last reasoning chunk, or the first chunk after the reasoning content) by setting stream_reasoning to False.

[5]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True, "stream_reasoning": False},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content = chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-08-14 06:17:01] INFO:     127.0.0.1:43846 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-08-14 06:17:01] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:17:01] Decode batch. #running-req: 1, #token: 50, token usage: 0.00, cuda graph: True, gen throughput (token/s): 101.79, #queue-req: 0,
[2025-08-14 06:17:01] Decode batch. #running-req: 1, #token: 90, token usage: 0.00, cuda graph: True, gen throughput (token/s): 167.26, #queue-req: 0,
==== Reasoning ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I add the two numbers together: 1 plus 3 equals 4.

Therefore, the final answer is 4.
==== Text ====


**Solution:**

We need to calculate the sum of 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).

Reasoning separation is enabled by default when a reasoning parser is specified. To disable it, set the ``separate_reasoning`` option to ``False`` in the request.

[6]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": False},
)

print_highlight("==== Original Output ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-08-14 06:17:01] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:17:02] Decode batch. #running-req: 1, #token: 39, token usage: 0.00, cuda graph: True, gen throughput (token/s): 129.78, #queue-req: 0,
[2025-08-14 06:17:02] Decode batch. #running-req: 1, #token: 79, token usage: 0.00, cuda graph: True, gen throughput (token/s): 163.75, #queue-req: 0,
[2025-08-14 06:17:02] INFO:     127.0.0.1:43846 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Original Output ====
First, I identify the two numbers in the problem: 1 and 3.

Next, I add the first number to the second number: 1 + 3.

Finally, I calculate the sum to find that 1 plus 3 equals 4.


**Solution:**

We need to find the sum of 1 and 3.

\[
1 + 3 = 4
\]

**Answer:** \boxed{4}

SGLang Native API#

[7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]

print_highlight("==== Original Output ====")
print_highlight(gen_response)

parse_url = f"http://localhost:{port}/separate_reasoning"
separate_reasoning_data = {
    "text": gen_response,
    "reasoning_parser": "deepseek-r1",
}
separate_reasoning_response_json = requests.post(
    parse_url, json=separate_reasoning_data
).json()
print_highlight("==== Reasoning ====")
print_highlight(separate_reasoning_response_json["reasoning_text"])
print_highlight("==== Text ====")
print_highlight(separate_reasoning_response_json["text"])
[2025-08-14 06:17:03] Prefill batch. #new-seq: 1, #new-token: 12, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 06:17:03] Decode batch. #running-req: 1, #token: 27, token usage: 0.00, cuda graph: True, gen throughput (token/s): 54.37, #queue-req: 0,
[2025-08-14 06:17:03] Decode batch. #running-req: 1, #token: 67, token usage: 0.00, cuda graph: True, gen throughput (token/s): 163.66, #queue-req: 0,
[2025-08-14 06:17:03] Decode batch. #running-req: 1, #token: 107, token usage: 0.01, cuda graph: True, gen throughput (token/s): 163.69, #queue-req: 0,
[2025-08-14 06:17:03] INFO:     127.0.0.1:43856 - "POST /generate HTTP/1.1" 200 OK
==== Original Output ====
I start by identifying the two numbers in the problem: 1 and 3.

Next, I add these two numbers together.

Finally, I calculate the sum to find the answer.


Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:**
We have the numbers 1 and 3.

2. **Add the numbers together:**
\[
1 + 3 = 4
\]

**Answer:**
\[
\boxed{4}
\]
[2025-08-14 06:17:03] INFO:     127.0.0.1:37044 - "POST /separate_reasoning HTTP/1.1" 200 OK
==== Reasoning ====
I start by identifying the two numbers in the problem: 1 and 3.

Next, I add these two numbers together.

Finally, I calculate the sum to find the answer.
==== Text ====
Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:**
We have the numbers 1 and 3.

2. **Add the numbers together:**
\[
1 + 3 = 4
\]

**Answer:**
\[
\boxed{4}
\]
[8]:
terminate_process(server_process)

Offline Engine API#

[9]:
import sglang as sgl
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import print_highlight

llm = sgl.Engine(model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = {
    "max_new_tokens": 1024,
    "skip_special_tokens": False,
    "temperature": 0.6,
    "top_p": 0.95,
}
result = llm.generate(prompt=input, sampling_params=sampling_params)

generated_text = result["text"]  # Assume there is only one prompt

print_highlight("==== Original Output ====")
print_highlight(generated_text)

parser = ReasoningParser("deepseek-r1")
reasoning_text, text = parser.parse_non_stream(generated_text)
print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(text)
W0814 06:17:05.426000 1202875 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:17:05.426000 1202875 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
W0814 06:17:13.204000 1206979 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:17:13.204000 1206979 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.45s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.36s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.37s/it]

Capturing batches (bs=1 avail_mem=29.99 GB): 100%|██████████| 3/3 [00:00<00:00, 11.11it/s]
==== Original Output ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

Next, I perform the addition of these two numbers to find the total.

Finally, I conclude that the result of adding 1 and 3 is 4.


To solve the problem \(1 + 3\), follow these simple steps:

1. **Start with the first number:**
\(1\)

2. **Add the second number:**
\(1 + 3\)

3. **Calculate the sum:**
\(1 + 3 = 4\)

Therefore, the final answer is:

\[
\boxed{4}
\]
==== Reasoning ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

Next, I perform the addition of these two numbers to find the total.

Finally, I conclude that the result of adding 1 and 3 is 4.
==== Text ====
To solve the problem \(1 + 3\), follow these simple steps:

1. **Start with the first number:**
\(1\)

2. **Add the second number:**
\(1 + 3\)

3. **Calculate the sum:**
\(1 + 3 = 4\)

Therefore, the final answer is:

\[
\boxed{4}
\]
[10]:
llm.shutdown()
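
ReasoningParser also exposes parse_stream_chunk for incremental parsing of streamed output (see the reference code in the next section). A minimal sketch, with illustrative chunks standing in for a streaming generate call:

# Sketch: incremental separation with parse_stream_chunk.
# The chunks below are illustrative; in practice they come from a streaming generator.
stream_parser = ReasoningParser("deepseek-r1")
reasoning_text, normal_text = "", ""
for chunk in ["I need to add", " 1 and 3, which gives 4.", "</think>", "The answer is 4."]:
    reasoning_part, normal_part = stream_parser.parse_stream_chunk(chunk)
    reasoning_text += reasoning_part or ""
    normal_text += normal_part or ""
print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(normal_text)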

Supporting New Reasoning Model Schemas#

For future reasoning models, you can implement a reasoning parser as a subclass of BaseReasoningFormatDetector in python/sglang/srt/reasoning_parser.py, register it in ReasoningParser.DetectorMap, and then select it with --reasoning-parser. The existing detectors are shown below for reference, followed by a sketch of a hypothetical new detector.

# Excerpt from python/sglang/srt/reasoning_parser.py
# (BaseReasoningFormatDetector and KimiDetector are defined in the same file).
from typing import Dict, Tuple, Type


class DeepSeekR1Detector(BaseReasoningFormatDetector):
    """
    Detector for DeepSeek-R1 family models.

    Supported models:
      - DeepSeek-R1: Always generates thinking content without <think> start tag
      - DeepSeek-R1-0528: Generates thinking content with <think> start tag

    This detector handles both patterns automatically.
    """

    def __init__(self, stream_reasoning: bool = True):
        super().__init__("<think>", "</think>", force_reasoning=True, stream_reasoning=stream_reasoning)


class Qwen3Detector(BaseReasoningFormatDetector):
    """
    Detector for standard Qwen3 models that support enable_thinking parameter.

    These models can switch between thinking and non-thinking modes:
      - enable_thinking=True: Generates <think>...</think> tags
      - enable_thinking=False: No thinking content generated
    """

    def __init__(self, stream_reasoning: bool = True):
        super().__init__("<think>", "</think>", force_reasoning=False, stream_reasoning=stream_reasoning)


class Qwen3ThinkingDetector(BaseReasoningFormatDetector):
    """
    Detector for Qwen3-Thinking models (e.g., Qwen3-235B-A22B-Thinking-2507).

    These models always generate thinking content without <think> start tag.
    They do not support the enable_thinking parameter.
    """

    def __init__(self, stream_reasoning: bool = True):
        super().__init__("<think>", "</think>", force_reasoning=True, stream_reasoning=stream_reasoning)


class ReasoningParser:
    """
    Parser that handles both streaming and non-streaming scenarios.

    Usage:
      # For standard Qwen3 models with enable_thinking support
      parser = ReasoningParser("qwen3")

      # For Qwen3-Thinking models that always think
      parser = ReasoningParser("qwen3-thinking")
    """

    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
        "deepseek-r1": DeepSeekR1Detector,
        "qwen3": Qwen3Detector,
        "qwen3-thinking": Qwen3ThinkingDetector,
        "kimi": KimiDetector,
    }

    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
        if not model_type:
            raise ValueError("Model type must be specified")

        detector_class = self.DetectorMap.get(model_type.lower())
        if not detector_class:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.detector = detector_class(stream_reasoning=stream_reasoning)

    def parse_non_stream(self, full_text: str) -> Tuple[str, str]:
        """Returns (reasoning_text, normal_text)"""
        ret = self.detector.detect_and_parse(full_text)
        return ret.reasoning_text, ret.normal_text

    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]:
        """Returns (reasoning_text, normal_text) for the current chunk"""
        ret = self.detector.parse_streaming_increment(chunk_text)
        return ret.reasoning_text, ret.normal_text
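
As a hedged sketch of this extension point, a hypothetical model that wraps its reasoning in <reasoning> ... </reasoning> tags could be supported by adding a detector class and a DetectorMap entry in python/sglang/srt/reasoning_parser.py. All names below are illustrative, not part of SGLang:

class HypotheticalReasoningDetector(BaseReasoningFormatDetector):
    """
    Detector for a hypothetical model that emits <reasoning> ... </reasoning> tags.
    """

    def __init__(self, stream_reasoning: bool = True):
        # force_reasoning=False: only text enclosed in the tags is treated as reasoning.
        super().__init__(
            "<reasoning>",
            "</reasoning>",
            force_reasoning=False,
            stream_reasoning=stream_reasoning,
        )


# Registering the schema name (in the source file) makes it selectable via
# ReasoningParser("hypothetical") and --reasoning-parser hypothetical.
ReasoningParser.DetectorMap["hypothetical"] = HypotheticalReasoningDetector

parser = ReasoningParser("hypothetical")
reasoning_text, normal_text = parser.parse_non_stream(
    "<reasoning>Add 1 and 3 to get 4.</reasoning>The answer is 4."
)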