Reasoning Parser#

SGLang supports parsing the reasoning content out from the "normal" content for reasoning models such as DeepSeek R1.
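For example, DeepSeek R1-style models wrap their chain of thought in <think> ... </think> tags. The following is a minimal sketch of what the parser does with such an output; the example string is made up, and the API shown is the same ReasoningParser used in the Offline Engine section below.

from sglang.srt.reasoning_parser import ReasoningParser

# Illustrative raw model output: chain of thought followed by the final answer.
raw_output = "<think>1 plus 3 equals 4.</think>The answer is 4."

parser = ReasoningParser("deepseek-r1")
reasoning_text, normal_text = parser.parse_non_stream(raw_output)
# reasoning_text holds the chain of thought; normal_text holds the final answer.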

Supported Models#

Currently, SGLang supports the following reasoning models:

  • DeepSeek R1 series (e.g., deepseek-ai/DeepSeek-R1 and distilled variants such as deepseek-ai/DeepSeek-R1-Distill-Qwen-7B): the reasoning content is wrapped in <think> and </think> tags; use --reasoning-parser deepseek-r1.

Usage#

Launching the Server#

Specify the --reasoning-parser option when launching the server, e.g. --reasoning-parser deepseek-r1 for DeepSeek R1 series models.

[1]:
import requests
from openai import OpenAI
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process


server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1"
)

wait_for_server(f"http://localhost:{port}")
[2025-03-04 05:24:15] server_args=ServerArgs(model_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_path='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30924, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, prefill_only_one_req=False, tp_size=1, stream_interval=1, stream_output=False, random_seed=794947657, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser='deepseek-r1', dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=5, speculative_eagle_topk=4, speculative_num_draft_tokens=8, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, enable_flashinfer_mla=False, flashinfer_mla_disable_ragged=False, warmups=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False)
[2025-03-04 05:24:34 TP0] Init torch distributed begin.
[2025-03-04 05:24:34 TP0] Load weight begin. avail mem=59.64 GB
[2025-03-04 05:24:34 TP0] The following error message 'operation scheduled before its operands' can be ignored.
[2025-03-04 05:24:35 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.64s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.59s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.59s/it]

[2025-03-04 05:24:39 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=29.59 GB
[2025-03-04 05:24:39 TP0] KV Cache is allocated. K size: 0.55 GB, V size: 0.55 GB.
[2025-03-04 05:24:39 TP0] Memory pool end. avail mem=28.28 GB
[2025-03-04 05:24:40 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-03-04 05:24:40] INFO:     Started server process [1867107]
[2025-03-04 05:24:40] INFO:     Waiting for application startup.
[2025-03-04 05:24:40] INFO:     Application startup complete.
[2025-03-04 05:24:40] INFO:     Uvicorn running on http://0.0.0.0:30924 (Press CTRL+C to quit)
[2025-03-04 05:24:40] INFO:     127.0.0.1:47050 - "GET /v1/models HTTP/1.1" 200 OK
[2025-03-04 05:24:41] INFO:     127.0.0.1:56182 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-03-04 05:24:41 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:44] INFO:     127.0.0.1:56196 - "POST /generate HTTP/1.1" 200 OK
[2025-03-04 05:24:44] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running these notebooks in a CI parallel environment, so the throughput is not representative of actual performance.

OpenAI Compatible API#

Using the OpenAI compatible API, the contract follows the DeepSeek API design established with the release of DeepSeek-R1:

  • reasoning_content: The content of the CoT.

  • content: The content of the final answer.
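For reference, these fields can also be read from a raw HTTP request. The following is a minimal sketch, assuming the server launched above is still running on port and serving the same model; the OpenAI client cells below are the recommended path.

import requests

resp = requests.post(
    f"http://localhost:{port}/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        "messages": [{"role": "user", "content": "What is 1+3?"}],
        "separate_reasoning": True,  # default when a reasoning parser is set
    },
).json()

message = resp["choices"][0]["message"]
print(message["reasoning_content"])  # chain of thought
print(message["content"])            # final answer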

[2]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id

messages = [
    {
        "role": "user",
        "content": "What is 1+3?",
    }
]
[2025-03-04 05:24:45] INFO:     127.0.0.1:56208 - "GET /v1/models HTTP/1.1" 200 OK

Non-Streaming Request#

[3]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": True},
)
print_highlight("==== Reasoning ====")
print_highlight(response_non_stream.choices[0].message.reasoning_content)

print_highlight("==== Text ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-03-04 05:24:46 TP0] Prefill batch. #new-seq: 1, #new-token: 11, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:46 TP0] Decode batch. #running-req: 1, #token: 45, token usage: 0.00, gen throughput (token/s): 6.46, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:46 TP0] Decode batch. #running-req: 1, #token: 85, token usage: 0.00, gen throughput (token/s): 92.39, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:47] INFO:     127.0.0.1:56208 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Reasoning ====
First, I recognize that the problem is asking for the sum of two numbers: 1 and 3.

Next, I add the numbers together: 1 plus 3 equals 4.

Finally, I conclude that the result of 1 plus 3 is 4.
==== Text ====
**Solution:**

We are asked to find the sum of \(1\) and \(3\).

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).

Streaming Request#

[4]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-03-04 05:24:47] INFO:     127.0.0.1:56208 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-03-04 05:24:47 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:47 TP0] Decode batch. #running-req: 1, #token: 21, token usage: 0.00, gen throughput (token/s): 79.18, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:47 TP0] Decode batch. #running-req: 1, #token: 61, token usage: 0.00, gen throughput (token/s): 89.28, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:48 TP0] Decode batch. #running-req: 1, #token: 101, token usage: 0.00, gen throughput (token/s): 91.24, largest-len: 0, #queue-req: 0,
==== Reasoning ====
First, I identify the two numbers to be added: 1 and 3.

Next, I perform the addition of these two numbers.

Finally, I calculate the sum to determine that 1 plus 3 equals 4.
==== Text ====


**Solution:**

We need to calculate the sum of the numbers 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).

Optionally, you can buffer the reasoning content and deliver it in a single chunk, the last reasoning chunk (i.e., the first chunk after the reasoning content ends), by setting stream_reasoning to False.

[5]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True, "stream_reasoning": False},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content = chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
[2025-03-04 05:24:48] INFO:     127.0.0.1:56208 - "POST /v1/chat/completions HTTP/1.1" 200 OK
[2025-03-04 05:24:48 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:48 TP0] Decode batch. #running-req: 1, #token: 48, token usage: 0.00, gen throughput (token/s): 85.39, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:49 TP0] Decode batch. #running-req: 1, #token: 88, token usage: 0.00, gen throughput (token/s): 91.07, largest-len: 0, #queue-req: 0,
==== Reasoning ====
First, I need to calculate the sum of 1 and 3.

Adding these two numbers together gives me 4.

Therefore, the final answer is 4.
==== Text ====


**Solution:**

We need to calculate the sum of 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).

Reasoning separation is enabled by default when --reasoning-parser is specified. To disable it, set the ``separate_reasoning`` option to ``False`` in the request.

[6]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": False},
)

print_highlight("==== Original Output ====")
print_highlight(response_non_stream.choices[0].message.content)
[2025-03-04 05:24:49 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 11, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:49 TP0] Decode batch. #running-req: 1, #token: 49, token usage: 0.00, gen throughput (token/s): 84.70, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:50 TP0] Decode batch. #running-req: 1, #token: 89, token usage: 0.00, gen throughput (token/s): 90.76, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:50 TP0] Decode batch. #running-req: 1, #token: 129, token usage: 0.01, gen throughput (token/s): 77.10, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:50] INFO:     127.0.0.1:56208 - "POST /v1/chat/completions HTTP/1.1" 200 OK
==== Original Output ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

Next, I perform the addition by combining the two numbers: 1 plus 3 equals 4.

Therefore, the final answer is 4.


**Solution:**

To find the sum of 1 and 3, follow these simple steps:

1. **Start with the number 1.**
2. **Add the number 3 to it.**
3. **Calculate the total.**

\[
1 + 3 = 4
\]

**Final Answer:** \(\boxed{4}\)

SGLang Native API#

[7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]

print_highlight("==== Original Output ====")
print_highlight(gen_response)

parse_url = f"http://localhost:{port}/separate_reasoning"
separate_reasoning_data = {
    "text": gen_response,
    "reasoning_parser": "deepseek-r1",
}
separate_reasoning_response_json = requests.post(
    parse_url, json=separate_reasoning_data
).json()
print_highlight("==== Reasoning ====")
print_highlight(separate_reasoning_response_json["reasoning_text"])
print_highlight("==== Text ====")
print_highlight(separate_reasoning_response_json["text"])
[2025-03-04 05:24:56 TP0] Prefill batch. #new-seq: 1, #new-token: 12, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-03-04 05:24:56 TP0] Decode batch. #running-req: 1, #token: 42, token usage: 0.00, gen throughput (token/s): 6.40, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:57 TP0] Decode batch. #running-req: 1, #token: 82, token usage: 0.00, gen throughput (token/s): 89.78, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:57 TP0] Decode batch. #running-req: 1, #token: 122, token usage: 0.01, gen throughput (token/s): 91.18, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:58 TP0] Decode batch. #running-req: 1, #token: 162, token usage: 0.01, gen throughput (token/s): 81.47, largest-len: 0, #queue-req: 0,
[2025-03-04 05:24:58] INFO:     127.0.0.1:56276 - "POST /generate HTTP/1.1" 200 OK
==== Original Output ====
First, I need to identify the two numbers in the problem, which are 1 and 3.

Next, I'll add these two numbers together.

After performing the addition, I find that the sum of 1 and 3 is 4.


Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Perform the addition:**
\[
1 + 3 = 4
\]

3. **Final Answer:**
\[
\boxed{4}
\]
[2025-03-04 05:24:58] INFO:     127.0.0.1:56292 - "POST /separate_reasoning HTTP/1.1" 200 OK
==== Reasoning ====
First, I need to identify the two numbers in the problem, which are 1 and 3.

Next, I'll add these two numbers together.

After performing the addition, I find that the sum of 1 and 3 is 4.
==== Text ====
Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Perform the addition:**
\[
1 + 3 = 4
\]

3. **Final Answer:**
\[
\boxed{4}
\]
[8]:
terminate_process(server_process)

Offline Engine API#

[9]:
import sglang as sgl
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.utils import print_highlight

llm = sgl.Engine(model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = {
    "max_new_tokens": 1024,
    "skip_special_tokens": False,
    "temperature": 0.6,
    "top_p": 0.95,
}
result = llm.generate(prompt=input, sampling_params=sampling_params)

generated_text = result["text"]  # Assume there is only one prompt

print_highlight("==== Original Output ====")
print_highlight(generated_text)

parser = ReasoningParser("deepseek-r1")
reasoning_text, text = parser.parse_non_stream(generated_text)
print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(text)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.68s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.60s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.61s/it]

==== Original Output ====
To solve the problem of adding 1 and 3, I start by identifying the two numbers involved.

Next, I add these numbers together to find the total.

Finally, I conclude that the sum of 1 and 3 is 4.


**Solution:**

We are asked to find the sum of 1 and 3.

1. **Identify the numbers to add:**
- First number: 1
- Second number: 3

2. **Perform the addition:**
\[
1 + 3 = 4
\]

3. **State the final answer:**
\[
\boxed{4}
\]

**Answer:** \(\boxed{4}\)
==== Reasoning ====
To solve the problem of adding 1 and 3, I start by identifying the two numbers involved.

Next, I add these numbers together to find the total.

Finally, I conclude that the sum of 1 and 3 is 4.
==== Text ====
**Solution:**

We are asked to find the sum of 1 and 3.

1. **Identify the numbers to add:**
- First number: 1
- Second number: 3

2. **Perform the addition:**
\[
1 + 3 = 4
\]

3. **State the final answer:**
\[
\boxed{4}
\]

**Answer:** \(\boxed{4}\)
[10]:
llm.shutdown()

Supporting New Reasoning Model Schemas#

For future reasoning models, you can implement the reasoning parser as a subclass of BaseReasoningFormatDetector in python/sglang/srt/reasoning_parser.py, register it in ReasoningParser.DetectorMap, and then select it with the --reasoning-parser option. The reference implementation for DeepSeek R1 is shown below, followed by a hypothetical sketch for a new schema.

class DeepSeekR1Detector(BaseReasoningFormatDetector):
    """
    Detector for DeepSeek-R1 model.
    Assumes reasoning format:
      (<think>)*(.*)</think>
    Returns all the text before the </think> tag as `reasoning_text`
    and the rest of the text as `normal_text`.

    Args:
        stream_reasoning (bool): If False, accumulates reasoning content until the end tag.
            If True, streams reasoning content as it arrives.
    """

    def __init__(self, stream_reasoning: bool = False):
        # DeepSeek-R1 is assumed to be reasoning until `</think>` token
        super().__init__("<think>", "</think>", True, stream_reasoning=stream_reasoning)
        # https://github.com/sgl-project/sglang/pull/3202#discussion_r1950153599


class ReasoningParser:
    """
    Parser that handles both streaming and non-streaming scenarios for extracting
    reasoning content from model outputs.

    Args:
        model_type (str): Type of model to parse reasoning from
        stream_reasoning (bool): If False, accumulates reasoning content until complete.
            If True, streams reasoning content as it arrives.
    """

    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
        "deepseek-r1": DeepSeekR1Detector
    }

    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
        if not model_type:
            raise ValueError("Model type must be specified")

        detector_class = self.DetectorMap.get(model_type.lower())
        if not detector_class:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.detector = detector_class(stream_reasoning=stream_reasoning)

    def parse_non_stream(self, full_text: str) -> StreamingParseResult:
        """Non-streaming call: one-time parsing"""
        ret = self.detector.detect_and_parse(full_text)
        return ret.reasoning_text, ret.normal_text

    def parse_stream_chunk(self, chunk_text: str) -> StreamingParseResult:
        """Streaming call: incremental parsing"""
        ret = self.detector.parse_streaming_increment(chunk_text)
        return ret.reasoning_text, ret.normal_text