How to Support New Models#
This document explains how to add support for new language models and multimodal large language models (MLLMs) in SGLang. It also covers how to test new models and register external implementations.
How to Support a New Language Model#
To support a new model in SGLang, you only need to add a single file under the SGLang Models Directory. You can learn from existing model implementations and create a new file for your model. For most models, you should be able to find a similar model to start with (e.g., start from Llama). Also refer to the section Port a Model from vLLM to SGLang below.
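As a rough orientation, a model file defines a causal-LM class and exposes it through a module-level `EntryClass` attribute that SGLang's loader discovers. The sketch below is a hedged outline, not a working implementation: the name `MyModelForCausalLM` is hypothetical and the method bodies are elided, so copy a close relative such as `llama.py` for the real details.

```python
# Hedged skeleton of a new model file (illustrative names; bodies elided).
import torch
from torch import nn


class MyModelForCausalLM(nn.Module):  # hypothetical model class
    def __init__(self, config, quant_config=None, prefix: str = ""):
        super().__init__()
        # Build embeddings, decoder layers, lm_head, logits processor, ...
        ...

    def forward(self, input_ids, positions, forward_batch):
        # Run the decoder and return logits for the next token.
        ...

    def load_weights(self, weights):
        # Map checkpoint parameter names onto this module's parameters.
        ...


# SGLang's model loader discovers the implementation via this attribute.
EntryClass = MyModelForCausalLM
```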
How to Support a New Multimodal Large Language Model#
To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the standard LLM support:
- Register your new model as multimodal: Extend `is_multimodal_model` in model_config.py to return `True` for your model.
- Register a new chat template: Only when your default chat template is unable to accept images as input, register a new chat template in conversation.py and the corresponding matching function.
- Multimodal Data Processor: Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your model's dedicated processor. See multimodal_processor.py for more details.
- Handle Multimodal Tokens: Implement a `pad_input_ids` function for your new model (see the sketch after this list). In this function, multimodal tokens in the prompt should be expanded (if necessary) and padded with multimodal-data hashes so that SGLang can recognize different multimodal data with `RadixAttention`.
- Handle Image Feature Extraction: Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.
- Adapt to Vision Attention: Adapt the multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
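To make the model-side hooks concrete, here is a minimal sketch of `pad_input_ids` and `get_image_feature`. Treat it as pseudocode shaped like the real thing: the class name `MyMLLMForCausalLM` is hypothetical, the exact signatures and helper types vary across SGLang versions, and the bodies are elided, so consult qwen2_vl.py or another MLLM implementation for the authoritative forms.

```python
# Hedged sketch only; signatures and helper types vary across SGLang versions.
import torch
from torch import nn


class MyMLLMForCausalLM(nn.Module):  # hypothetical MLLM class
    def pad_input_ids(self, input_ids, mm_inputs):
        # Expand each multimodal placeholder token (if necessary) and pad it
        # with the hash of the corresponding multimodal item, so that
        # RadixAttention can tell apart prompts whose text is identical but
        # whose images differ.
        ...

    def get_image_feature(self, image_input):
        # Run the vision tower over the raw image data and project the
        # result into the language model's embedding space.
        ...
```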
You can refer to Qwen2VL or other MLLM implementations; these models demonstrate how to correctly handle both multimodal and textual inputs.
Testing and Debugging#
Please include all of your testing and benchmarking results in the PR description.
Interactive Debugging#
For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands should give the same text output and very similar prefill logits:
Get the reference output:
```bash
python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,mllm}
```
Get the SGLang output:
```bash
python3 -m sglang.bench_one_batch --correct --model [new model]
```
Add the Model to the Test Suite#
To ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in the test_generation_models.py file. Test the new model on your local machine and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU, MMMU-Pro, etc.) in your PR.
For VLMs, also include a test in test_vision_openai_server_{x}.py (e.g., test_vision_openai_server_a.py, test_vision_openai_server_b.py).
Here is an example command for testing a new model on your local machine:

```bash
ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
```
Benchmark#
(Required) MMMU: Follow the MMMU benchmark README.md to get an SGLang vs. HF Transformers accuracy comparison. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.ai/developer_guide/benchmark_and_profiling.html to get a performance comparison: TTFT and throughput must meet or exceed the baselines (e.g., HF Transformers); see the example commands below.
(Optional) Other evals: If you ran other evals, please note the results in the PR description.
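For the performance comparison, a typical session looks roughly like the pair of commands below. This is a hedged illustration: flag sets change between SGLang versions, and the port and prompt count here are arbitrary choices, so defer to the benchmarking guide linked above.

```bash
# Launch a server with the new model (port chosen arbitrarily here).
python3 -m sglang.launch_server --model-path [new model] --port 30000

# In another shell, measure TTFT and throughput against that server.
python3 -m sglang.bench_serving --backend sglang --port 30000 --num-prompts 200
```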
Port a Model from vLLM to SGLang#
The vLLM Models Directory is a valuable resource, as vLLM covers many models. SGLang reuses vLLM’s interface and some layers, making it easier to port models from vLLM to SGLang.
To port a model from vLLM to SGLang:
Compare these two files for guidance: vLLM's llama.py and SGLang's llama.py.
The major differences include:
- Replace vLLM's `Attention` with `RadixAttention` (ensure you pass `layer_id` to `RadixAttention`); a hedged sketch follows this list.
- Replace vLLM's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace the multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
- Replace other vLLM layers (such as `RMSNorm`, `SiluAndMul`) with SGLang layers.
- Remove `Sample`.
- Change the `forward()` functions and add a `forward_batch()` method.
- Add `EntryClass` at the end.
- Ensure that the new implementation uses only SGLang components and does not rely on any vLLM components.
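As an illustration of the first item, the attention swap usually looks like the fragment below. This is a hedged sketch: the wrapper class and its constructor arguments are assumptions for the sake of a self-contained example, and `RadixAttention`'s keyword set can differ between versions, so mirror what llama.py does in practice.

```python
import torch
from torch import nn

from sglang.srt.layers.radix_attention import RadixAttention


class MyDecoderAttention(nn.Module):  # hypothetical wrapper for illustration
    def __init__(self, num_heads: int, head_dim: int, num_kv_heads: int, layer_id: int):
        super().__init__()
        scaling = head_dim ** -0.5
        # vLLM's Attention(...) becomes RadixAttention(...); the extra
        # layer_id argument lets the runtime locate this layer's KV cache.
        self.attn = RadixAttention(
            num_heads,
            head_dim,
            scaling,
            num_kv_heads=num_kv_heads,
            layer_id=layer_id,
        )
```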
Note: make sure you add your new model to the supported models list in the supported models documentation.
Registering an External Model Implementation#
In addition to the methods above, you can register your new model with the `ModelRegistry`
before launching the server.
This allows you to integrate your model without modifying the source code.
For example:
```python
from sglang.srt.models.registry import ModelRegistry
from sglang.srt.entrypoints.http_server import launch_server

# For a single model, add it to the registry:
ModelRegistry.models[model_name] = model_class

# For multiple models, you can imitate the import_model_classes() function:
from functools import lru_cache

@lru_cache()
def import_new_model_classes():
    model_arch_name_to_cls = {}
    # Populate model_arch_name_to_cls with your new model classes.
    ...
    return model_arch_name_to_cls

ModelRegistry.models.update(import_new_model_classes())

# Launch the server with your server arguments:
launch_server(server_args)
```
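Here, `server_args` is the `ServerArgs` object that the CLI entrypoint would otherwise construct from command-line flags; building one programmatically (e.g., `ServerArgs(model_path=...)` from sglang.srt.server_args) is left to your setup.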
Example: Implementing and Serving a Llama Wrapper Model#
Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the Offline Engine.
Implementing Our Model#
To keep things simple, this new model will be a thin wrapper around Llama 3.1-8B-Instruct, and our goal will just be to bias the output logits on each `forward` call by taking the square root of each positive logit.
Let's start by defining our model in a file called `llama_wrapper.py`.
The first step is to import the necessary libraries from SRT, which is SGLang’s internal backend.
```python
# In the file `llama_wrapper.py`
import torch
from transformers import LlamaConfig
from typing import Optional

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM
```
Next, we declare a new class for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via model_loader/loader.py.
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.
```python
class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
```
Now, we want to define the `forward` method, which is what will be called at inference time.
Note that the signature of `forward` is essentially the same for any model; you can take a look at the other models defined in the models directory for reference.
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at `forward_decode` and `forward_extend` in the `ModelRunner` class.
```python
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
```
We now call the `__call__` method of `self.model` (a member variable that `LlamaForCausalLM` defines in its `__init__` method), which eventually calls `LlamaForCausalLM`'s `forward` method.
After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).
```python
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )
```
After receiving the logits for the next token, we can finally perform our biasing step.
```python
        # Bias only the positive logits; the square root of a negative
        # value would be NaN, so negative logits pass through unchanged.
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits,
        )
        return res
```
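A brief note on this design choice: because the square root compresses values greater than 1 toward 1, the bias shrinks the gaps between the largest logits, flattening the output distribution over high-probability tokens, while the `torch.where` guard leaves negative logits untouched and avoids the NaNs that taking the square root of a negative value would produce.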
Now, our `LlamaWrapper` model is created and ready to be served!
Serving Our Model Via SGLang’s Offline Engine#
The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.
First, create a new file called `run.py`.
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
To do this, we first download the model's configuration and weights from Hugging Face.
```python
# In the file `run.py`
import asyncio
from functools import lru_cache

from huggingface_hub import snapshot_download

from llama_wrapper import LlamaWrapper  # Make sure to import our new model!

import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

# Make sure to request access to this model on Hugging Face, then export your
# `HF_TOKEN` to download the model snapshot
llama_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="./llama_ckpt",
)
```
Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in ./llama_ckpt/config.json to `LlamaWrapper`.
That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.
```
{
    "architectures": [
        # "LlamaForCausalLM"
        "LlamaWrapper"
    ],
    ...
}
```
However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
Thus, to register our `LlamaWrapper`, we follow the steps in the section above titled "Registering an External Model Implementation".
```python
@lru_cache()
def import_new_model_classes():
    model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
    return model_arch_name_to_cls

ModelRegistry.models.update(import_new_model_classes())
```
Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use the SGLang `Engine`'s non-streaming asynchronous generation endpoint.
```python
def main():
    llm = sgl.Engine(model_path="./llama_ckpt")
    sampling_params = {"temperature": 0.2, "top_k": 5}
    prompts = [
        "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
        "Provide a concise factual statement about France’s capital city. The capital of France is",
        "Explain possible future trends in artificial intelligence. The future of AI is",
    ]
    asyncio.run(run_llm(llm, sampling_params, prompts))
    llm.shutdown()


async def run_llm(
    llm,
    sampling_params,
    prompts,
) -> None:
    outputs = await llm.async_generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


if __name__ == "__main__":
    main()
```
Now, when we run `python run.py`, we will get the outputs of our newly created model!
Documentation#
Add the model to the table of supported models in generative_models.md or multimodal_language_models.md.
By following these guidelines, you can add support for new language models and multimodal large language models in SGLang and ensure they are thoroughly tested and easily integrated into the system.