Query Vision Language Model#

Querying Qwen2.5-VL#

[1]:
import nest_asyncio

nest_asyncio.apply()  # Run this first.

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
chat_template = "qwen2-vl"
[2]:
# Let's create a prompt.

from io import BytesIO
import requests
from PIL import Image

from sglang.srt.conversation import chat_templates

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)

conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]

print(conv.get_prompt())
image
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What's shown here: <|vision_start|><|image_pad|><|vision_end|>?<|im_end|>
<|im_start|>assistant

[2]:
../_images/advanced_features_vlm_query_3_1.png

Query via the offline Engine API#

[3]:
from sglang import Engine

llm = Engine(
    model_path=model_path, chat_template=chat_template, mem_fraction_static=0.8
)
W0814 06:20:35.539000 1220226 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 06:20:35.539000 1220226 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.43it/s]

Capturing batches (bs=1 avail_mem=10.19 GB): 100%|██████████| 23/23 [00:04<00:00,  4.72it/s]
[4]:
out = llm.generate(prompt=conv.get_prompt(), image_data=[image])
print(out["text"])
The image shows a New York City street scene with a person hanging clothes on a clothesline attached to the back of a New York City taxi. The background includes a storefront with various goods on display and another taxi driving by. The people are likely doing laundry on the street, which is a common occurrence in New York City.
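
By default the engine picks its own sampling settings, which is why the description above runs long. If you want shorter, more deterministic output, `Engine.generate` also accepts a `sampling_params` dictionary. A minimal sketch, assuming the standard SGLang sampling keys `temperature` and `max_new_tokens` (verify the supported keys against your installed version):

# Optional: greedy decoding with a length cap for a more focused answer.
out = llm.generate(
    prompt=conv.get_prompt(),
    image_data=[image],
    sampling_params={"temperature": 0.0, "max_new_tokens": 128},
)
print(out["text"])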

Query via the offline Engine API, but send precomputed embeddings#

[5]:
# Compute the image embeddings using Hugging Face.

from transformers import AutoProcessor
from transformers import Qwen2_5_VLForConditionalGeneration

processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
vision = (
    Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path).eval().visual.cuda()
)
[6]:
processed_prompt = processor(
    images=[image], text=conv.get_prompt(), return_tensors="pt"
)
input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()
precomputed_embeddings = vision(
    processed_prompt["pixel_values"].cuda(), processed_prompt["image_grid_thw"].cuda()
)

mm_item = dict(
    modality="IMAGE",
    image_grid_thw=processed_prompt["image_grid_thw"],
    precomputed_embeddings=precomputed_embeddings,
)
out = llm.generate(input_ids=input_ids, image_data=[mm_item])
print(out["text"])
The image shows a scene in an urban area with a yellow street and a few cars parked along the side. Specifically:

- There are two yellow taxis parked on the street.
- A person is washing clothes using an upright washing appliance mounted on a folding stand.
- The washing appliance is blue and has a cord connected to it, indicating it's an electric washing machine.
- The person is standing alone, seemingly engrossed in the task.
- There are meters or billboards along the street with various advertisements.
- The backdrop includes a multi-story building with windows and signage.

This setup appears to be for a protest against car ownership
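
Before moving on to Llama 4 below, it helps to free the GPU memory held by the Qwen engine and the Hugging Face vision tower, since the next model is much larger. A minimal sketch, assuming the `Engine` object exposes a `shutdown()` method as in recent SGLang releases:

import gc

import torch

llm.shutdown()  # stop the Qwen engine and release its GPU memory
del vision  # drop the Hugging Face vision tower used for the precomputed embeddings
gc.collect()
torch.cuda.empty_cache()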

Querying Llama 4 (Vision)#

[7]:
import nest_asyncio

nest_asyncio.apply()  # Run this first.

model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
chat_template = "llama-4"
[8]:
# Let's create a prompt.

from io import BytesIO
import requests
from PIL import Image

from sglang.srt.conversation import chat_templates

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)

conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]

print(conv.get_prompt())
print(f"Image size: {image.size}")

image
<|header_start|>user<|header_end|>

What's shown here: <|image|>?<|eot|><|header_start|>assistant<|header_end|>


Image size: (570, 380)
[8]:
../_images/advanced_features_vlm_query_12_1.png

Query via the offline Engine API#

[9]:
from sglang.test.test_utils import is_in_ci

if not is_in_ci():
    from sglang import Engine

    llm = Engine(
        model_path=model_path,
        trust_remote_code=True,
        enable_multimodal=True,
        mem_fraction_static=0.8,
        tp_size=4,
        attention_backend="fa3",
        context_length=65536,
    )
[10]:
if not is_in_ci():
    out = llm.generate(prompt=conv.get_prompt(), image_data=[image])
    print(out["text"])

Query via the offline Engine API, but send precomputed embeddings#

[11]:
if not is_in_ci():
    # Compute the image embeddings using Hugging Face.

    from transformers import AutoProcessor
    from transformers import Llama4ForConditionalGeneration

    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
    model = Llama4ForConditionalGeneration.from_pretrained(
        model_path, torch_dtype="auto"
    ).eval()
    vision = model.vision_model.cuda()
    multi_modal_projector = model.multi_modal_projector.cuda()
[12]:
if not is_in_ci():
    processed_prompt = processor(
        images=[image], text=conv.get_prompt(), return_tensors="pt"
    )
    print(f'{processed_prompt["pixel_values"].shape=}')
    input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()

    image_outputs = vision(
        processed_prompt["pixel_values"].to("cuda"), output_hidden_states=False
    )
    image_features = image_outputs.last_hidden_state
    vision_flat = image_features.view(-1, image_features.size(-1))
    precomputed_embeddings = multi_modal_projector(vision_flat)

    mm_item = dict(modality="IMAGE", precomputed_embeddings=precomputed_embeddings)
    out = llm.generate(input_ids=input_ids, image_data=[mm_item])
    print(out["text"])