Sampling Parameters#
This doc describes the sampling parameters of the SGLang Runtime. The /generate endpoint documented here is the low-level interface of the runtime.
If you want a high-level endpoint that can automatically handle chat templates, consider using the OpenAI Compatible API.
/generate Endpoint#
The /generate endpoint accepts the following parameters in JSON format. For detailed usage, see the native API doc.
| Argument | Type/Default | Description |
|---|---|---|
| text | | The input prompt. Can be a single prompt or a batch of prompts. |
| input_ids | | Alternative to `text`. Specify the input as token IDs instead of text. |
| sampling_params | | The sampling parameters as described in the sections below. |
| return_logprob | | Whether to return log probabilities for tokens. |
| logprob_start_len | | If returning log probabilities, specifies the start position in the prompt. Default is `-1`, which returns logprobs only for output tokens. |
| top_logprobs_num | | If returning log probabilities, specifies the number of top logprobs to return at each position. |
| stream | | Whether to stream the output. |
| lora_path | | Path to LoRA weights. |
| custom_logit_processor | | Custom logit processor for advanced sampling control. For usage, see below. |
| return_hidden_states | | Whether to return the hidden states of the model. Note that each time this setting changes, the CUDA graph will be recaptured, which might lead to a performance hit. See the examples for more information. |
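As a quick illustration, here is a minimal sketch of a request that combines several of these arguments to also retrieve log probabilities. It assumes a server is already running on localhost:30000, as in the examples further below, and simply prints the raw JSON response, since the exact layout of the returned metadata may vary between versions.

import requests

# Sketch: request log probabilities for the output tokens, plus the top-2
# alternatives at each position. Only the raw response is printed here.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "return_logprob": True,
        "top_logprobs_num": 2,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 8,
        },
    },
)
print(response.json())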
Sampling parameters#
Core parameters#
| Argument | Type/Default | Description |
|---|---|---|
| max_new_tokens | | The maximum output length measured in tokens. |
| stop | | One or multiple stop words. Generation will stop if one of these words is sampled. |
| stop_token_ids | | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
| temperature | | Temperature when sampling the next token. |
| top_p | | Top-p selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. |
| top_k | | Top-k randomly selects from the `top_k` tokens with the highest probabilities. |
| min_p | | Min-p samples from tokens with probability larger than `min_p * highest_token_probability`. |
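The core parameters are passed together inside the sampling_params dictionary of a /generate request. The sketch below uses illustrative values only and assumes a server running on localhost:30000, as in the examples further below.

import requests

# Sketch: core sampling parameters combined in one request. The specific
# values are illustrative, not recommendations.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a one-sentence summary of the French Revolution.",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
            "min_p": 0.05,
            "stop": ["\n\n"],
        },
    },
)
print(response.json())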
Penalizers#
| Argument | Type/Default | Description |
|---|---|---|
| frequency_penalty | | Penalizes tokens based on their frequency in the generation so far. Must be between `-2` and `2`, where negative values encourage repetition. |
| presence_penalty | | Penalizes tokens if they appeared in the generation so far. Must be between `-2` and `2`, where negative values encourage repetition. |
| min_new_tokens | | Forces the model to generate at least `min_new_tokens` tokens until a stop word or EOS token is sampled. |
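A minimal sketch of how the penalizers combine with the core parameters (illustrative values only; assumes a local server as in the examples below):

import requests

# Sketch: discourage repeated tokens and force a minimum output length
# before stop conditions are honored.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List three facts about Paris.",
        "sampling_params": {
            "max_new_tokens": 128,
            "temperature": 0.7,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3,
            "min_new_tokens": 16,
        },
    },
)
print(response.json())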
Constrained decoding#
Please refer to our dedicated guide on constrained decoding for the following parameters.
| Argument | Type/Default | Description |
|---|---|---|
| json_schema | | JSON schema for structured outputs. |
| regex | | Regex for structured outputs. |
| ebnf | | EBNF for structured outputs. |
Other options#
| Argument | Type/Default | Description |
|---|---|---|
| n | | Specifies the number of output sequences to generate per request. Generating multiple outputs in one request (`n > 1`) is discouraged; sending the same prompt several times as separate requests offers better control and efficiency. |
| spaces_between_special_tokens | | Whether or not to add spaces between special tokens during detokenization. |
| no_stop_trim | | Don't trim stop words or the EOS token from the generated text. |
| continue_final_message | | When enabled, the final assistant message is removed and its content is used as a prefill so that the model continues that message instead of starting a new turn. See openai_chat_with_response_prefill.py for examples. |
| ignore_eos | | Don't stop generation when the EOS token is sampled. |
| skip_special_tokens | | Remove special tokens during decoding. |
| custom_params | | Used when employing a custom logit processor; the dictionary is passed through to the processor. For usage, see below. |
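These options are likewise set inside sampling_params. The sketch below is illustrative only: it keeps generation going past the EOS token and preserves special tokens in the decoded output.

import requests

# Sketch: ignore_eos keeps the model generating up to max_new_tokens even
# after an EOS token; skip_special_tokens=False keeps special tokens in
# the decoded text.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 32,
            "temperature": 0,
            "ignore_eos": True,
            "skip_special_tokens": False,
        },
    },
)
print(response.json())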
Examples#
Normal#
Launch a server:
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
Send a request:
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print(response.json())
Detailed example in send request.
Streaming#
Send a request and stream the output:
import requests, json

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
print("")
Detailed example in openai compatible api.
Multimodal#
Launch a server:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov
Download an image:
curl -o example_image.png -L https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true
Send a request:
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
        "<|im_start|>assistant\n",
        "image_data": "example_image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print(response.json())
The image_data can be a file name, a URL, or a base64 encoded string. See also python/sglang/srt/utils.py:load_image.
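For example, to pass the image as a base64-encoded string instead of a file name, only the construction of image_data changes (a sketch based on the request above):

import base64
import requests

# Encode the downloaded image as base64 and pass it via "image_data"
# instead of a file name or URL.
with open("example_image.png", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
        "<|im_start|>assistant\n",
        "image_data": image_base64,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
print(response.json())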
Streaming is supported in a similar manner as above.
Detailed example in openai api vision.
Structured Outputs (JSON, Regex, EBNF)#
You can specify a JSON schema, regular expression, or EBNF to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (json_schema, regex, or ebnf) can be specified for a request.
SGLang supports two grammar backends:
Outlines: Supports JSON schema and regular expression constraints.
XGrammar (default): Supports JSON schema, regular expression, and EBNF constraints.
XGrammar currently uses the GGML BNF format.
If you want to use the Outlines backend instead, pass the --grammar-backend outlines flag:
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: xgrammar)
import json
import requests

json_schema = json.dumps({
    "type": "object",
    "properties": {
        "name": {"type": "string", "pattern": "^[\\w]+$"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
})

# JSON (works with both Outlines and XGrammar)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
print(response.json())

# Regular expression (Outlines backend only)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Paris is the capital of",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "regex": "(France|England)",
        },
    },
)
print(response.json())

# EBNF (XGrammar backend only)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a greeting.",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "ebnf": 'root ::= "Hello" | "Hi" | "Hey"',
        },
    },
)
print(response.json())
Detailed example in structured outputs.
Custom logit processor#
Launch a server with the --enable-custom-logit-processor flag:
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor
Define a custom logit processor that will always sample a specific token id.
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """A dummy logit processor that changes the logits to always
    sample the given token id.
    """

    def __call__(self, logits, custom_param_list):
        # Check that the number of logits matches the number of custom parameters
        assert logits.shape[0] == len(custom_param_list)
        key = "token_id"
        for i, param_dict in enumerate(custom_param_list):
            # Mask all other tokens
            logits[i, :] = -float("inf")
            # Assign highest probability to the specified token
            logits[i, param_dict[key]] = 0.0
        return logits
Send a request:
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 32,
            "custom_params": {"token_id": 5},
        },
    },
)
print(response.json())