OpenAI APIs - Vision#
SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models. A complete reference for the API is available in the OpenAI API Reference. This tutorial covers the vision APIs for vision language models.
SGLang supports a wide range of vision language models, such as Llama 3.2, LLaVA-OneVision, Qwen2.5-VL, and Gemma 3.
As an alternative to the OpenAI API, you can also use the SGLang offline engine.
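For reference, a minimal offline-engine sketch might look like the following. It is not executed in this notebook; it assumes the sglang.Engine API, and the prompt must already contain the model's chat template and image placeholder tokens (shown here using Qwen2.5-VL's conventions).
import sglang as sgl

# Offline engine: loads the model in-process, with no HTTP server involved.
llm = sgl.Engine(model_path="Qwen/Qwen2.5-VL-7B-Instruct")

# The raw prompt must already include the chat template and the image
# placeholder tokens the model expects (these are Qwen2.5-VL's tokens).
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What's in this image?<|im_end|>\n<|im_start|>assistant\n"
)

output = llm.generate(
    prompt=prompt,
    image_data="https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
    sampling_params={"max_new_tokens": 128},
)
print(output["text"])
llm.shutdown()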
Launch A Server#
Launch the server in your terminal and wait for it to initialize.
[1]:
from sglang.test.test_utils import is_in_ci

if is_in_ci():
    from patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd

from sglang.utils import wait_for_server, print_highlight, terminate_process

vision_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-05-30 02:28:28] server_args=ServerArgs(model_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen2.5-VL-7B-Instruct', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, host='127.0.0.1', port=30547, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=886239110, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, 
n_share_experts_fusion=0, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, pdlb_url=None)
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[2025-05-30 02:28:30] You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[2025-05-30 02:28:31] Infer the chat template name from the model path and obtain the result: qwen2-vl.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[2025-05-30 02:28:40] You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[2025-05-30 02:28:41] Attention backend not set. Use flashinfer backend by default.
[2025-05-30 02:28:41] Automatically reduce --mem-fraction-static to 0.792 because this is a multimodal model.
[2025-05-30 02:28:41] Init torch distributed begin.
[2025-05-30 02:28:41] Init torch distributed ends. mem usage=0.00 GB
[2025-05-30 02:28:41] init_expert_location from trivial
[2025-05-30 02:28:41] Load weight begin. avail mem=46.58 GB
[2025-05-30 02:28:41] Multimodal attention backend not set. Use sdpa.
[2025-05-30 02:28:42] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 20% Completed | 1/5 [00:00<00:02, 1.61it/s]
Loading safetensors checkpoint shards: 40% Completed | 2/5 [00:01<00:01, 1.54it/s]
Loading safetensors checkpoint shards: 60% Completed | 3/5 [00:01<00:01, 1.49it/s]
Loading safetensors checkpoint shards: 80% Completed | 4/5 [00:02<00:00, 1.49it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00, 1.92it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00, 1.72it/s]
[2025-05-30 02:28:45] Load weight end. type=Qwen2_5_VLForConditionalGeneration, dtype=torch.bfloat16, avail mem=30.88 GB, mem usage=15.70 GB.
[2025-05-30 02:28:45] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-05-30 02:28:45] Memory pool end. avail mem=29.51 GB
2025-05-30 02:28:45,303 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
[2025-05-30 02:28:47] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=128000
[2025-05-30 02:28:47] INFO: Started server process [2286298]
[2025-05-30 02:28:47] INFO: Waiting for application startup.
[2025-05-30 02:28:47] INFO: Application startup complete.
[2025-05-30 02:28:47] INFO: Uvicorn running on http://127.0.0.1:30547 (Press CTRL+C to quit)
[2025-05-30 02:28:47] INFO: 127.0.0.1:41618 - "GET /v1/models HTTP/1.1" 200 OK
[2025-05-30 02:28:48] INFO: 127.0.0.1:41634 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-05-30 02:28:48] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
2025-05-30 02:28:50,004 - INFO - flashinfer.jit: Loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
2025-05-30 02:28:50,017 - INFO - flashinfer.jit: Finished loading JIT ops: batch_prefill_with_kv_cache_dtype_q_bf16_dtype_kv_bf16_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False_f16qk_False
[2025-05-30 02:28:51] INFO: 127.0.0.1:41650 - "POST /generate HTTP/1.1" 200 OK
[2025-05-30 02:28:51] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a CI parallel environment, so the throughput is not representative of actual performance.
Using cURL#
Once the server is up, you can send test requests using curl or requests.
[2]:
import subprocess

curl_command = f"""
curl -s http://localhost:{port}/v1/chat/completions \\
  -d '{{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {{
        "role": "user",
        "content": [
          {{
            "type": "text",
            "text": "What’s in this image?"
          }},
          {{
            "type": "image_url",
            "image_url": {{
              "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
            }}
          }}
        ]
      }}
    ],
    "max_tokens": 300
  }}'
"""

response = subprocess.check_output(curl_command, shell=True).decode()
print_highlight(response)

# Send the same request a second time; the repeat reuses the cached prefix,
# visible as "#cached-token: 306" in the server logs below.
response = subprocess.check_output(curl_command, shell=True).decode()
print_highlight(response)
[2025-05-30 02:28:53] Prefill batch. #new-seq: 1, #new-token: 307, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:28:54] Decode batch. #running-req: 1, #token: 340, token usage: 0.02, cuda graph: False, gen throughput (token/s): 5.38, #queue-req: 0
[2025-05-30 02:28:55] Decode batch. #running-req: 1, #token: 380, token usage: 0.02, cuda graph: False, gen throughput (token/s): 58.04, #queue-req: 0
[2025-05-30 02:28:55] INFO: 127.0.0.1:41654 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"id":"19e24b47f7484cd193854c2cbe9ebe12","object":"chat.completion","created":1748572132,"model":"Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The image shows a man doing laundry outdoors, ironing clothes using what appears to be a makeshift ironing board setup attached to the back of a taxi. He is standing behind the taxi and holding an iron while wearing a bright yellow shirt. There are other taxis parked nearby on a busy street with buildings, storefronts, and some American flags visible in the background.","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":307,"total_tokens":381,"completion_tokens":74,"prompt_tokens_details":null}}
[2025-05-30 02:28:55] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 306, token usage: 0.01, #running-req: 0, #queue-req: 0
[2025-05-30 02:28:56] Decode batch. #running-req: 1, #token: 346, token usage: 0.02, cuda graph: False, gen throughput (token/s): 37.22, #queue-req: 0
[2025-05-30 02:28:57] INFO: 127.0.0.1:34864 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"id":"8d20d409b918446bbefea911399812c4","object":"chat.completion","created":1748572135,"model":"Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The image shows a man ironing clothes using makeshift equipment attached to the rear of a New York City taxicab. The man is wearing a yellow shirt, and he appears to be improvising or performing in a humorous or quirky way, as this setup is not typically seen in everyday life. The background shows a busy street with typical New York City taxis and office buildings.","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":307,"total_tokens":383,"completion_tokens":76,"prompt_tokens_details":null}}
Using Python Requests#
[3]:
import requests

url = f"http://localhost:{port}/v1/chat/completions"

data = {
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    "max_tokens": 300,
}

response = requests.post(url, json=data)
print_highlight(response.text)
[2025-05-30 02:28:57] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 306, token usage: 0.01, #running-req: 0, #queue-req: 0
[2025-05-30 02:28:57] Decode batch. #running-req: 1, #token: 310, token usage: 0.02, cuda graph: False, gen throughput (token/s): 34.20, #queue-req: 0
[2025-05-30 02:28:58] Decode batch. #running-req: 1, #token: 350, token usage: 0.02, cuda graph: False, gen throughput (token/s): 58.67, #queue-req: 0
[2025-05-30 02:28:59] Decode batch. #running-req: 1, #token: 390, token usage: 0.02, cuda graph: False, gen throughput (token/s): 58.30, #queue-req: 0
[2025-05-30 02:28:59] INFO: 127.0.0.1:34880 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"id":"38054ccf140d4ab5bfba65f0ecf42856","object":"chat.completion","created":1748572137,"model":"Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The image depicts a person wearing a yellow shirt who is seemingly an employee at a car wash, performing징 in the back of a yellow taxi. He is using an ironing board-like device to wash the interior of the car, possibly cloth curtains, perhaps saved in the taxi to maintain cleanliness for taxi drivers or passengers. The setting appears to be an urban city street with visible cab signage and another taxi in the background.","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":307,"total_tokens":392,"completion_tokens":85,"prompt_tokens_details":null}}
Using OpenAI Python Client#
[4]:
from openai import OpenAI

client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What is in this image?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    max_tokens=300,
)

print_highlight(response.choices[0].message.content)
[2025-05-30 02:28:59] Prefill batch. #new-seq: 1, #new-token: 292, #cached-token: 15, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:29:00] Decode batch. #running-req: 1, #token: 345, token usage: 0.02, cuda graph: False, gen throughput (token/s): 33.76, #queue-req: 0
[2025-05-30 02:29:00] Decode batch. #running-req: 1, #token: 385, token usage: 0.02, cuda graph: False, gen throughput (token/s): 58.63, #queue-req: 0
[2025-05-30 02:29:00] INFO: 127.0.0.1:34890 - "POST /v1/chat/completions HTTP/1.1" 200 OK
The image depicts a man standing on the back of a yellow taxi, actively engaged in ironing a shirt. The shirt is laid flat on an improvised ironing board, with the man handling the iron and pressing it onto the fabric. The taxi is parked near a street curb amidst what seems to be an urban street scene. The background includes signage, buildings, and another taxi, highlighting a busy city environment.
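The examples above pass the image by URL. If the image lives on your local disk, you can instead embed it as a base64 data URL, following the OpenAI-compatible convention. A minimal sketch, reusing the client from the previous cell; "example_image.png" is a placeholder path:
import base64

# Read a local file and embed it as a data URL (assumes a PNG on disk;
# "example_image.png" is a placeholder path).
with open("example_image.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ],
    max_tokens=300,
)
print_highlight(response.choices[0].message.content)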
Multiple-Image Inputs#
The server also accepts multiple images, as well as interleaved text and images, provided the underlying model supports such inputs.
[5]:
from openai import OpenAI

client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
                    },
                },
                {
                    "type": "text",
                    "text": "I have two very different images. They are not related at all. "
                    "Please describe the first image in one sentence, and then describe the second image in another sentence.",
                },
            ],
        }
    ],
    temperature=0,
)

print_highlight(response.choices[0].message.content)
[2025-05-30 02:29:01] Prefill batch. #new-seq: 1, #new-token: 2532, #cached-token: 14, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-05-30 02:29:03] Decode batch. #running-req: 1, #token: 2581, token usage: 0.13, cuda graph: False, gen throughput (token/s): 16.81, #queue-req: 0
[2025-05-30 02:29:03] INFO: 127.0.0.1:34898 - "POST /v1/chat/completions HTTP/1.1" 200 OK
The first image shows a man ironing clothes on the back of a taxi in a busy urban street. The second image is a stylized logo featuring the letters "SGL" with a book and a computer icon incorporated into the design.
[6]:
terminate_process(vision_process)
[2025-05-30 02:29:03] Child process unexpectedly failed with an exit code 9. pid=2286812