Query Vision Language Model#
Querying Qwen2.5-VL#
[1]:
import nest_asyncio
nest_asyncio.apply() # Run this first.
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
chat_template = "qwen2-vl"
[2]:
# Let's create a prompt.
from io import BytesIO
import requests
from PIL import Image
from sglang.srt.conversation import chat_templates
image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)
conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]
print(conv.get_prompt())
image
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What's shown here: <|vision_start|><|image_pad|><|vision_end|>?<|im_end|>
<|im_start|>assistant
[2]:

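The printed prompt shows how the qwen2-vl template wraps the image placeholder in <|vision_start|><|image_pad|><|vision_end|>; the engine later replaces the pad token with the image's embeddings. As an optional cross-check (not part of the original notebook, and assuming the Hugging Face processor ships the same chat template), an equivalent prompt should come out of AutoProcessor.apply_chat_template:

from transformers import AutoProcessor

hf_processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's shown here: "},
            {"type": "image"},
            {"type": "text", "text": "?"},
        ],
    }
]
print(hf_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))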
Query via the offline Engine API#
[3]:
from sglang import Engine
llm = Engine(
    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8
)
W0814 06:20:35.539000 1220226 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:20:35.539000 1220226 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 06:20:38] You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:00<00:00, 1.40it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00, 1.43it/s]
Capturing batches (bs=1 avail_mem=10.19 GB): 100%|██████████| 23/23 [00:04<00:00, 4.72it/s]
[4]:
out = llm.generate(prompt=conv.get_prompt(), image_data=[image])
print(out["text"])
The image shows a New York City street scene with a person hanging clothes on a clothesline attached to the back of a New York City taxi. The background includes a storefront with various goods on display and another taxi driving by. Theगर्दिओं are likely doing laundry on the street, which is a common occurrence in New York City.
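By default the engine chooses its own sampling settings. Engine.generate also accepts a sampling_params dict if you want to bound or stabilize the output; the keys below are SGLang's standard sampling parameters, shown as a small sketch rather than taken from the original notebook.

out = llm.generate(
    prompt=conv.get_prompt(),
    image_data=[image],
    sampling_params={"max_new_tokens": 128, "temperature": 0.0},
)
print(out["text"])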
Query via the offline Engine API, but send precomputed embeddings#
[5]:
# Compute the image embeddings using Hugging Face.
from transformers import AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration
processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
vision = (
    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()
)
[6]:
processed_prompt = processor(
    images=[image], text=conv.get_prompt(), return_tensors="pt"
)
input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()

# Run the vision tower to get the image embeddings the engine would normally compute itself.
precomputed_embeddings = vision(
    processed_prompt["pixel_values"].cuda(), processed_prompt["image_grid_thw"].cuda()
)

# Package the embeddings as a multimodal item; the engine uses them directly
# instead of running its own vision encoder.
mm_item = dict(
    modality="IMAGE",
    image_grid_thw=processed_prompt["image_grid_thw"],
    precomputed_embeddings=precomputed_embeddings,
)
out = llm.generate(input_ids=input_ids, image_data=[mm_item])
print(out["text"])
The image shows a scene in an urban area with a yellow street and a few cars parked along the side. Specifically:
- There are two yellow taxis parked on the street.
- A person is washing clothes using an upright washing appliance mounted on a folding stand.
- The washing appliance is blue and has a cord connected to it, indicating it's an electric washing machine.
- The person is standing alone, seemingly engrossed in the task.
- There areometers or billboards along the street with various advertisements.
- The backdrop includes a multi-story building with windows and国有企业 signage.
This setup appears to be for a protest against car ownership
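If you intend to run the Llama 4 example below in the same process, free the Qwen engine and the Hugging Face vision tower first. This is a minimal cleanup sketch, assuming everything above ran in one Python session:

import gc

import torch

llm.shutdown()  # stop the SGLang engine and release its GPU memory
del vision, processor  # drop the HF modules used only for the precomputed embeddings
gc.collect()
torch.cuda.empty_cache()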
Querying Llama 4 (Vision)#
[7]:
import nest_asyncio
nest_asyncio.apply() # Run this first.
model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
chat_template = "llama-4"
[8]:
# Let's create a prompt.
from io import BytesIO
import requests
from PIL import Image
from sglang.srt.conversation import chat_templates
image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)
conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]
print(conv.get_prompt())
print(f"Image size: {image.size}")
image
<|header_start|>user<|header_end|>
What's shown here: <|image|>?<|eot|><|header_start|>assistant<|header_end|>
Image size: (570, 380)
[8]:

Query via the offline Engine API#
[9]:
from sglang.test.test_utils import is_in_ci
if not is_in_ci():
    from sglang import Engine

    llm = Engine(
        model_path=model_path,
        trust_remote_code=True,
        enable_multimodal=True,
        mem_fraction_static=0.8,
        tp_size=4,  # Llama-4-Scout-17B-16E is large, so shard it across 4 GPUs.
        attention_backend="fa3",
        context_length=65536,
    )
[10]:
if not is_in_ci():
    out = llm.generate(prompt=conv.get_prompt(), image_data=[image])
    print(out["text"])
Query via the offline Engine API, but send precomputed embeddings#
[11]:
if not is_in_ci():
    # Compute the image embeddings using Hugging Face.
    from transformers import AutoProcessor
    from transformers import Llama4ForConditionalGeneration

    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_path, torch_dtype="auto"
    ).eval()
    vision = model.vision_model.cuda()
    multi_modal_projector = model.multi_modal_projector.cuda()
[12]:
if not is_in_ci():
    processed_prompt = processor(
        images=[image], text=conv.get_prompt(), return_tensors="pt"
    )
    print(f'{processed_prompt["pixel_values"].shape=}')
    input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()

    # Run the vision tower, then map its features into the language model's
    # embedding space with the multi-modal projector.
    image_outputs = vision(
        processed_prompt["pixel_values"].to("cuda"), output_hidden_states=False
    )
    image_features = image_outputs.last_hidden_state
    vision_flat = image_features.view(-1, image_features.size(-1))
    precomputed_embeddings = multi_modal_projector(vision_flat)

    mm_item = dict(modality="IMAGE", precomputed_embeddings=precomputed_embeddings)
    out = llm.generate(input_ids=input_ids, image_data=[mm_item])
    print(out["text"])
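When you are done, shutting the engine down releases the GPUs. A short closing sketch, assuming the engine was created outside CI:

if not is_in_ci():
    llm.shutdown()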