Reasoning Parser#

SGLang supports parsing the reasoning content out from the “normal” content for reasoning models such as DeepSeek R1.

Supported Models#

Currently, SGLang supports the following reasoning models:

  • DeepSeek R1 series: The reasoning content is wrapped with <think> and </think> tags.

  • QwQ: The reasoning content is wrapped with <think> and </think> tags.
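For both families, a raw completion therefore looks roughly like the sketch below (the wording is illustrative, not actual model output): the chain of thought sits between the tags and the final answer follows them.

<think>
The user asks for 1 + 3. Adding the two numbers gives 4.
</think>
1 + 3 equals 4.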

Usage#

Launching the Server#

Specify the --reasoning-parser option.

[1]:
import requests
from openai import OpenAI
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process


server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1"
)

wait_for_server(f"http://localhost:{port}")
[2025-05-07 07:08:31] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_mode='auto', skip_tokenizer_init=False, enable_tokenizer_batch_encode=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=37323, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=611511635, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='deepseek-r1', dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
[2025-05-07 07:08:41] Attention backend not set. Use fa3 backend by default.
[2025-05-07 07:08:41] Init torch distributed begin.
[2025-05-07 07:08:42] Init torch distributed ends. mem usage=0.00 GB
[2025-05-07 07:08:42] Load weight begin. avail mem=78.58 GB
[2025-05-07 07:08:42] Ignore import error when loading sglang.srt.models.llama4.
[2025-05-07 07:08:43] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.44s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.42s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.42s/it]

[2025-05-07 07:08:46] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=50.98 GB, mem usage=27.61 GB.
[2025-05-07 07:08:47] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-05-07 07:08:47] Memory pool end. avail mem=49.61 GB
[2025-05-07 07:08:47] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-05-07 07:08:48] INFO:     Started server process [1421886]
[2025-05-07 07:08:48] INFO:     Waiting for application startup.
[2025-05-07 07:08:48] INFO:     Application startup complete.
[2025-05-07 07:08:48] INFO:     Uvicorn running on http://0.0.0.0:37323 (Press CTRL+C to quit)
[2025-05-07 07:08:49] INFO:     127.0.0.1:57068 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-07 07:08:49] INFO:     127.0.0.1:57082 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-07 07:08:49] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[2025-05-07 07:08:51] INFO:     127.0.0.1:57086 - "POST /generate HTTP/1.1" 200 OK
[2025-05-07 07:08:51] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.

Note that --reasoning-parser defines the parser used to interpret responses. Currently supported parsers include:

  • deepseek-r1: DeepSeek R1 series and QwQ (e.g. deepseek-ai/DeepSeek-R1, Qwen/QwQ-32B).
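For example, a QwQ server would use the same parser; a minimal launch sketch (host and port left at their defaults) is:

python3 -m sglang.launch_server --model-path Qwen/QwQ-32B --reasoning-parser deepseek-r1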

OpenAI Compatible API#

Using the OpenAI compatible API, the contract follows the DeepSeek API design established with the release of DeepSeek-R1:

  • reasoning_content: The content of the CoT.

  • content: The content of the final answer.

[2]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id

messages = [
    {
        "role": "user",
        "content": "What is 1+3?",
    }
]
[2025-05-07 07:08:54] INFO:     127.0.0.1:47862 - "GET /v1/models HTTP/1.1" 200 OK

Non-Streaming Request#

[3]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": True},
)
print_highlight("==== Reasoning ====")
print_highlight(response_non_stream.choices[0].message.reasoning_content)

print_highlight("==== Text ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-05-07 07:08:54] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-07 07:08:54] Decode batch. #running-req: 1, #token: 45, token usage: 0.00, gen throughput (token/s): 5.82, #queue-req: 0
[2025-05-07 07:08:54] Decode batch. #running-req: 1, #token: 85, token usage: 0.00, gen throughput (token/s): 109.74, #queue-req: 0
[2025-05-07 07:08:55] INFO:     127.0.0.1:47862 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Reasoning ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

Next, I perform the addition by adding these two numbers together.

Finally, I arrive at the conclusion that 1 plus 3 equals 4.
==== Text ====
**Solution:**

We are asked to find the sum of 1 and 3.

1+3=4

Therefore, the final answer is 4.
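The same non-streaming request can also be made without the OpenAI client. A minimal raw-HTTP sketch is shown below, assuming the server launched above is still running; separate_reasoning is passed as a top-level JSON field, which is what extra_body does for the client.

import requests

raw_response = requests.post(
    f"http://localhost:{port}/v1/chat/completions",
    json={
        "model": model_name,
        "messages": messages,
        "temperature": 0.6,
        "top_p": 0.95,
        "separate_reasoning": True,
    },
).json()

raw_message = raw_response["choices"][0]["message"]
print_highlight(raw_message["reasoning_content"])  # chain of thought
print_highlight(raw_message["content"])  # final answer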

Streaming Request#

[4]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-05-07 07:08:55] INFO:     127.0.0.1:47862 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-05-07 07:08:55] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-07 07:08:55] Decode batch. #running-req: 1, #token: 28, token usage: 0.00, gen throughput (token/s): 97.58, #queue-req: 0
[2025-05-07 07:08:55] Decode batch. #running-req: 1, #token: 68, token usage: 0.00, gen throughput (token/s): 105.91, #queue-req: 0
[2025-05-07 07:08:56] Decode batch. #running-req: 1, #token: 108, token usage: 0.01, gen throughput (token/s): 108.08, #queue-req: 0
[2025-05-07 07:08:56] Decode batch. #running-req: 1, #token: 148, token usage: 0.01, gen throughput (token/s): 108.55, #queue-req: 0
==== Reasoning ====
To solve the problem of adding 1 and 3, I start by identifying the two numbers involved.

Next, I perform the addition operation by combining these numbers together.

Finally, I calculate the sum to determine the result of the addition.
==== Text ====


To solve the addition problem 1+3, follow these simple steps:

1. **Identify the numbers to add:**
1and3

2. **Perform the addition:**
1+3=4

3. **Present the final answer:**
4

So, 1+3 equals 4.

Optionally, you can buffer the reasoning content so that it is emitted in the last reasoning chunk (or in the first chunk after the reasoning content) by setting stream_reasoning to False.

[5]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True, "stream_reasoning": False},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content = chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-05-07 07:08:56] INFO:     127.0.0.1:47862 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-05-07 07:08:56] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-07 07:08:56] Decode batch. #running-req: 1, #token: 27, token usage: 0.00, gen throughput (token/s): 91.47, #queue-req: 0
[2025-05-07 07:08:57] Decode batch. #running-req: 1, #token: 67, token usage: 0.00, gen throughput (token/s): 107.42, #queue-req: 0
==== Reasoning ====
I need to add the numbers 1 and 3.

Starting with 1, I add 3 to it.

The sum of 1 and 3 is 4.
==== Text ====


**Solution:**

We need to find the sum of 1 and 3.

1+3=4

Therefore, the final answer is 4.

Reasoning separation is enabled by default when --reasoning-parser is specified. To disable it, set the separate_reasoning option to False in the request.

[6]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": False},
)

print_highlight("==== Original Output ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-05-07 07:08:57] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-07 07:08:57] Decode batch. #running-req: 1, #token: 24, token usage: 0.00, gen throughput (token/s): 88.74, #queue-req: 0
[2025-05-07 07:08:58] Decode batch. #running-req: 1, #token: 64, token usage: 0.00, gen throughput (token/s): 110.42, #queue-req: 0
[2025-05-07 07:08:58] Decode batch. #running-req: 1, #token: 104, token usage: 0.01, gen throughput (token/s): 109.72, #queue-req: 0
[2025-05-07 07:08:58] Decode batch. #running-req: 1, #token: 144, token usage: 0.01, gen throughput (token/s): 108.78, #queue-req: 0
[2025-05-07 07:08:58] INFO:     127.0.0.1:47862 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Original Output ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I add the two numbers together: 1 plus 3 equals 4.

Therefore, the final answer is 4.


Sure! Let's solve the problem step by step.

**Question:** What is 1+3?

**Solution:**

1. **Identify the numbers to add:**
1and3

2. **Add the numbers together:**
1+3=4

**Final Answer:**
4

SGLang Native API#

[7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]

print_highlight("==== Original Output ====")
print_highlight(gen_response)

parse_url = f"http://localhost:{port}/separate_reasoning"
separate_reasoning_data = {
    "text": gen_response,
    "reasoning_parser": "deepseek-r1",
}
separate_reasoning_response_json = requests.post(
    parse_url, json=separate_reasoning_data
).json()
print_highlight("==== Reasoning ====")
print_highlight(separate_reasoning_response_json["reasoning_text"])
print_highlight("==== Text ====")
print_highlight(separate_reasoning_response_json["text"])
[2025-05-07 07:08:59] Prefill batch. #new-seq: 1, #new-token: 12, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-07 07:08:59] Decode batch. #running-req: 1, #token: 38, token usage: 0.00, gen throughput (token/s): 45.42, #queue-req: 0
[2025-05-07 07:09:00] Decode batch. #running-req: 1, #token: 78, token usage: 0.00, gen throughput (token/s): 109.77, #queue-req: 0
[2025-05-07 07:09:00] Decode batch. #running-req: 1, #token: 118, token usage: 0.01, gen throughput (token/s): 109.69, #queue-req: 0
[2025-05-07 07:09:00] INFO:     127.0.0.1:47872 - "POST /generate HTTP/1.1" 200 OK
==== Original Output ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I perform the addition by combining the two numbers: 1 plus 3 equals 4.

Therefore, the final answer is 4.


**Solution:**

We are asked to find the sum of 1 and 3.

1. Start with the number **1**.
2. Add the number **3** to it.
3. The sum of 1 and 3 is:

1+3=4

**Final Answer:** 4
[2025-05-07 07:09:00] INFO:     127.0.0.1:51698 - "POST /separate_reasoning HTTP/1.1" 200 OK
==== Reasoning ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I perform the addition by combining the two numbers: 1 plus 3 equals 4.

Therefore, the final answer is 4.
==== Text ====
**Solution:**

We are asked to find the sum of 1 and 3.

1. Start with the number **1**.
2. Add the number **3** to it.
3. The sum of 1 and 3 is:

1+3=4

**Final Answer:** 4
[8]:
terminate_process(server_process)
[2025-05-07 07:09:00] Child process unexpectedly failed with an exit code 9. pid=1422240
[2025-05-07 07:09:00] Child process unexpectedly failed with an exit code 9. pid=1422110

Offline Engine API#

[9]:
import sglang as sgl
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import print_highlight

llm = sgl.Engine(model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = {
    "max_new_tokens": 1024,
    "skip_special_tokens": False,
    "temperature": 0.6,
    "top_p": 0.95,
}
result = llm.generate(prompt=input, sampling_params=sampling_params)

generated_text = result["text"]  # Assume there is only one prompt

print_highlight("==== Original Output ====")
print_highlight(generated_text)

parser = ReasoningParser("deepseek-r1")
reasoning_text, text = parser.parse_non_stream(generated_text)
print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(text)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.49s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.44s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.45s/it]

==== Original Output ====
I need to calculate the sum of 1 and 3.

First, I'll add the two numbers together.

Then, I'll provide the result as the final answer.


Sure! Let's solve the problem step by step.

**Question:** What is 1+3?

**Solution:**

1. **Start with the first number:**

1

2. **Add the second number to it:**

1+3

3. **Calculate the sum:**

1+3=4

**Answer:** 4
==== Reasoning ====
I need to calculate the sum of 1 and 3.

First, I'll add the two numbers together.

Then, I'll provide the result as the final answer.
==== Text ====
Sure! Let's solve the problem step by step.

**Question:** What is 1+3?

**Solution:**

1. **Start with the first number:**

1

2. **Add the second number to it:**

1+3

3. **Calculate the sum:**

1+3=4

**Answer:** 4
[10]:
llm.shutdown()
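
ReasoningParser also exposes parse_stream_chunk for incremental parsing. A minimal sketch, feeding pre-split illustrative chunks rather than a live token stream, could look like this:

from sglang.srt.reasoning_parser import ReasoningParser

stream_parser = ReasoningParser("deepseek-r1", stream_reasoning=True)

# Illustrative chunks; in practice these would be incremental decode outputs.
chunks = ["<think>Add 1", " and 3 to get 4.", "</think>1 + 3 = 4."]

reasoning_text, normal_text = "", ""
for piece in chunks:
    r, t = stream_parser.parse_stream_chunk(piece)
    reasoning_text += r or ""
    normal_text += t or ""

print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(normal_text)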

Supporting New Reasoning Model Schemas#

For future reasoning models, you can implement the reasoning parser as a subclass of BaseReasoningFormatDetector in python/sglang/srt/reasoning_parser.py and register it for the new reasoning model schema accordingly. The reference implementation for DeepSeek R1 is shown below, followed by a sketch of a hypothetical new detector.

class DeepSeekR1Detector(BaseReasoningFormatDetector):
    """
    Detector for DeepSeek-R1 model.
    Assumes reasoning format:
      (<think>)*(.*)</think>
    Returns all the text before the </think> tag as `reasoning_text`
    and the rest of the text as `normal_text`.

    Args:
        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
            If True, streams reasoning content as it arrives.
    """

    def __init__(self, stream_reasoning: bool = False):
        # DeepSeek-R1 is assumed to be reasoning until `</think>` token
        super().__init__("<think>", "</think>", True, stream_reasoning=stream_reasoning)
        # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599


class ReasoningParser:
    """
    Parser that handles both streaming and non-streaming scenarios for extracting
    reasoning content from model outputs.

    Args:
        model_type (str): Type of model to parse reasoning from
        stream_reasoning (bool): If False, accumulates reasoning content until complete.
            If True, streams reasoning content as it arrives.
    """

    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
        "deepseek-r1": DeepSeekR1Detector
    }

    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
        if not model_type:
            raise ValueError("Model type must be specified")

        detector_class = self.DetectorMap.get(model_type.lower())
        if not detector_class:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.detector = detector_class(stream_reasoning=stream_reasoning)

    def parse_non_stream(self, full_text: str) -> StreamingParseResult:
        """Non-streaming call: one-time parsing"""
        ret = self.detector.detect_and_parse(full_text)
        return ret.reasoning_text, ret.normal_text

    def parse_stream_chunk(self, chunk_text: str) -> StreamingParseResult:
        """Streaming call: incremental parsing"""
        ret = self.detector.parse_streaming_increment(chunk_text)
        return ret.reasoning_text, ret.normal_text
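
As a hedged sketch (the model name, tag strings, and parser key below are hypothetical, not part of SGLang), a detector for a model that wraps its chain of thought in <reasoning> ... </reasoning> tags could mirror DeepSeekR1Detector and be registered under a new key:

class MyReasoningModelDetector(BaseReasoningFormatDetector):
    """
    Hypothetical detector for a model that wraps its chain of thought
    in <reasoning> ... </reasoning> tags.
    """

    def __init__(self, stream_reasoning: bool = False):
        # Mirror DeepSeekR1Detector: assume the model is reasoning until the end tag.
        super().__init__(
            "<reasoning>",
            "</reasoning>",
            True,
            stream_reasoning=stream_reasoning,
        )


# Register the detector so it can be selected, e.g. via --reasoning-parser my-model.
ReasoningParser.DetectorMap["my-model"] = MyReasoningModelDetector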