OpenAI APIs - Vision#

SGLang provides OpenAI-compatible APIs to enable a smooth transition from OpenAI services to self-hosted local models. A complete reference for the API is available in the OpenAI API Reference. This tutorial covers the vision APIs for vision language models.

SGLang supports a wide range of vision language models, including Llama 3.2, LLaVA-OneVision, Qwen2.5-VL, Gemma 3, and more.

As an alternative to the OpenAI API, you can also use the SGLang offline engine.
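For reference, here is a minimal sketch of the offline engine for vision input, running in-process with no HTTP server. It assumes the Engine.generate image_data parameter and Qwen2.5-VL's image placeholder tokens; the exact prompt format depends on your model's chat template, so consult the offline engine documentation.

import sglang as sgl

# Load the model into an in-process engine instead of a standalone server.
llm = sgl.Engine(model_path="Qwen/Qwen2.5-VL-7B-Instruct")

# The raw prompt must contain the model-specific image placeholder tokens;
# the tokens below follow Qwen2.5-VL's chat template.
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What's in this image?<|im_end|>\n<|im_start|>assistant\n"
)

out = llm.generate(
    prompt=prompt,
    sampling_params={"max_new_tokens": 128},
    image_data="https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
)
print(out["text"])

llm.shutdown()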

Launch a Server#

Launch the server in your terminal and wait for it to initialize.

[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

vision_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct
"""
)

wait_for_server(f"http://localhost:{port}")
W0814 09:48:26.120000 1089794 torch/utils/cpp_extension.py:2425] TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
W0814 09:48:26.120000 1089794 torch/utils/cpp_extension.py:2425] If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'] to specific architectures.
[2025-08-14 09:48:27] server_args=ServerArgs(model_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_path='Qwen/Qwen2.5-VL-7B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=36201, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', mem_fraction_static=0.7835956249999999, max_running_requests=128, max_queued_requests=9223372036854775807, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, device='cuda', tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=811650675, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, served_model_name='Qwen/Qwen2.5-VL-7B-Instruct', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, ep_size=1, moe_a2a_backend=None, enable_flashinfer_cutlass_moe=False, enable_flashinfer_trtllm_moe=False, enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, 
cuda_graph_max_bs=4, cuda_graph_bs=None, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, torch_compile_max_bs=32, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, enable_return_hidden_states=False, enable_triton_kernel_moe=False, enable_flashinfer_mxfp4_moe=False, scheduler_recv_interval=1, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, debug_tensor_dump_prefill_only=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, num_reserved_decode_tokens=512, pdlb_url=None, custom_weight_loader=[], weight_loader_disable_mmap=False, enable_pdmux=False, sm_group_num=3, enable_ep_moe=False, enable_deepep_moe=False)
[2025-08-14 09:48:28] Ignore import error when loading sglang.srt.multimodal.processors.glm4v: No module named 'transformers.models.glm4v_moe'
[2025-08-14 09:48:29] You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
[2025-08-14 09:48:30] Using default HuggingFace chat template with detected content format: openai
[2025-08-14 09:48:37] Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-08-14 09:48:37] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-08-14 09:48:38] Init torch distributed ends. mem usage=0.00 GB
[2025-08-14 09:48:39] Ignore import error when loading sglang.srt.models.glm4v_moe: No module named 'transformers.models.glm4v_moe'
[2025-08-14 09:48:39] Load weight begin. avail mem=78.58 GB
[2025-08-14 09:48:39] Multimodal attention backend not set. Use triton_attn.
[2025-08-14 09:48:39] Using triton_attn as multimodal attention backend.
[2025-08-14 09:48:39] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/5 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  20% Completed | 1/5 [00:00<00:02,  1.58it/s]
Loading safetensors checkpoint shards:  40% Completed | 2/5 [00:01<00:01,  1.51it/s]
Loading safetensors checkpoint shards:  60% Completed | 3/5 [00:01<00:01,  1.49it/s]
Loading safetensors checkpoint shards:  80% Completed | 4/5 [00:02<00:00,  1.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  1.90it/s]
Loading safetensors checkpoint shards: 100% Completed | 5/5 [00:02<00:00,  1.70it/s]

[2025-08-14 09:48:42] Load weight end. type=Qwen2_5_VLForConditionalGeneration, dtype=torch.bfloat16, avail mem=62.79 GB, mem usage=15.79 GB.
[2025-08-14 09:48:42] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-08-14 09:48:42] Memory pool end. avail mem=61.46 GB
[2025-08-14 09:48:43] Capture cuda graph begin. This can take up to several minutes. avail mem=60.89 GB
[2025-08-14 09:48:43] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=60.79 GB): 100%|██████████| 3/3 [00:01<00:00,  2.37it/s]
[2025-08-14 09:48:45] Capture cuda graph end. Time elapsed: 2.04 s. mem usage=0.13 GB. avail mem=60.76 GB.
[2025-08-14 09:48:46] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=128000, available_gpu_mem=60.76 GB
[2025-08-14 09:48:47] INFO:     Started server process [1089794]
[2025-08-14 09:48:47] INFO:     Waiting for application startup.
[2025-08-14 09:48:47] INFO:     Application startup complete.
[2025-08-14 09:48:47] INFO:     Uvicorn running on http://127.0.0.1:36201 (Press CTRL+C to quit)
[2025-08-14 09:48:48] INFO:     127.0.0.1:46058 - "GET /v1/models HTTP/1.1" 200 OK
[2025-08-14 09:48:48] INFO:     127.0.0.1:46072 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-08-14 09:48:48] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:48:48] INFO:     127.0.0.1:46088 - "POST /generate HTTP/1.1" 200 OK
[2025-08-14 09:48:48] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We run these notebooks in a CI environment, so the throughput numbers are not representative of actual performance.

Using cURL#

Once the server is up, you can send test requests using cURL or Python's requests library.

[2]:
import subprocess

curl_command = f"""
curl -s http://localhost:{port}/v1/chat/completions \\
  -H "Content-Type: application/json" \\
  -d '{{
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
      {{
        "role": "user",
        "content": [
          {{
            "type": "text",
            "text": "What’s in this image?"
          }},
          {{
            "type": "image_url",
            "image_url": {{
              "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
            }}
          }}
        ]
      }}
    ],
    "max_tokens": 300
  }}'
"""

response = subprocess.check_output(curl_command, shell=True).decode()
print_highlight(response)
[2025-08-14 09:48:54] Prefill batch. #new-seq: 1, #new-token: 307, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:48:55] Decode batch. #running-req: 1, #token: 340, token usage: 0.02, cuda graph: True, gen throughput (token/s): 4.79, #queue-req: 0,
[2025-08-14 09:48:55] Decode batch. #running-req: 1, #token: 380, token usage: 0.02, cuda graph: True, gen throughput (token/s): 128.29, #queue-req: 0,
[2025-08-14 09:48:55] INFO:     127.0.0.1:48142 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"id":"7f48d41b4c2e43e4bdcfce90ee7451a4","object":"chat.completion","created":1755164935,"model":"Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The image depicts a man leaning outside the back of a yellow taxi, engaged in pressing a shirt on what appears to be an ad hoc setup involving an ironing board and an iron. The taxi is parked on a busy street with another taxi parked in front of it and buildings that suggest an urban area in the background. The man is carefully pressing a pair of blue jeans, perhaps demonstrating or promoting laundry or ironing services.","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":307,"total_tokens":393,"completion_tokens":86,"prompt_tokens_details":null,"reasoning_tokens":0}}

Using Python Requests#

[3]:
import requests

url = f"http://localhost:{port}/v1/chat/completions"

data = {
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    "max_tokens": 300,
}

response = requests.post(url, json=data)
print_highlight(response.text)
[2025-08-14 09:48:56] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 306, token usage: 0.01, #running-req: 0, #queue-req: 0,
[2025-08-14 09:48:57] Decode batch. #running-req: 1, #token: 315, token usage: 0.02, cuda graph: True, gen throughput (token/s): 56.19, #queue-req: 0,
[2025-08-14 09:48:57] Decode batch. #running-req: 1, #token: 355, token usage: 0.02, cuda graph: True, gen throughput (token/s): 130.31, #queue-req: 0,
[2025-08-14 09:48:57] Decode batch. #running-req: 1, #token: 395, token usage: 0.02, cuda graph: True, gen throughput (token/s): 132.31, #queue-req: 0,
[2025-08-14 09:48:57] INFO:     127.0.0.1:48154 - "POST /v1/chat/completions HTTP/1.1" 200 OK
{"id":"346b113d77274f399f5691a2c7a1ac6f","object":"chat.completion","created":1755164937,"model":"Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The image depicts a man standing on or leaning on the rear of a yellow SUV parked on a busy street, which seems to be a taxi or rental vehicle both branded and named \"Luxury SUV Rental.\" He is dressed casually in a long-sleeve shirt and appears to be in the act of ironing garments draped over what may resemble an improvised ironing board or rack supported by stands set on the car's rear bumper. The street is lined with urban buildings, parked cars, and yellow taxis, suggesting that this is taking place in a city setting.","reasoning_content":null,"tool_calls":null},"logprobs":null,"finish_reason":"stop","matched_stop":151645}],"usage":{"prompt_tokens":307,"total_tokens":420,"completion_tokens":113,"prompt_tokens_details":null,"reasoning_tokens":0}}

Using OpenAI Python Client#

[4]:
from openai import OpenAI

client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What is in this image?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    max_tokens=300,
)

print_highlight(response.choices[0].message.content)
[2025-08-14 09:48:58] Prefill batch. #new-seq: 1, #new-token: 292, #cached-token: 15, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:48:58] Decode batch. #running-req: 1, #token: 322, token usage: 0.02, cuda graph: True, gen throughput (token/s): 56.57, #queue-req: 0,
[2025-08-14 09:48:58] Decode batch. #running-req: 1, #token: 362, token usage: 0.02, cuda graph: True, gen throughput (token/s): 132.42, #queue-req: 0,
[2025-08-14 09:48:58] INFO:     127.0.0.1:48166 - "POST /v1/chat/completions HTTP/1.1" 200 OK
This image depicts a man standing behind the open trunk of a yellow taxi on a city street, using an iron to press articles of clothing. The setting seems to be a busy urban area, with other taxis parked nearby and pedestrians visible in the background. It appears to be a humorous take on the idea of laundry services for city commuters.
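The client can also stream the reply token by token by passing stream=True; this is the standard OpenAI streaming interface and works the same way with image inputs. A minimal sketch, reusing the client from the cell above:

# Stream the reply incrementally instead of waiting for the full completion.
stream = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
                    },
                },
            ],
        }
    ],
    max_tokens=300,
    stream=True,
)

for chunk in stream:
    # Each chunk carries an incremental delta; content may be None for
    # role-only or final chunks, so guard before printing.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)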

Multiple-Image Inputs#

The server also accepts multiple images, as well as interleaved text and images, as long as the underlying model supports them.

[5]:
from openai import OpenAI

client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
                    },
                },
                {
                    "type": "text",
                    "text": "I have two very different images. They are not related at all. "
                    "Please describe the first image in one sentence, and then describe the second image in another sentence.",
                },
            ],
        }
    ],
    temperature=0,
)

print_highlight(response.choices[0].message.content)
[2025-08-14 09:48:59] Prefill batch. #new-seq: 1, #new-token: 2532, #cached-token: 14, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-08-14 09:48:59] Decode batch. #running-req: 1, #token: 2573, token usage: 0.13, cuda graph: True, gen throughput (token/s): 35.06, #queue-req: 0,
[2025-08-14 09:48:59] INFO:     127.0.0.1:57582 - "POST /v1/chat/completions HTTP/1.1" 200 OK
The first image shows a man ironing clothes on the back of a taxi in a busy urban street. The second image is a stylized logo featuring the letters "SGL" in orange with a book and a computer icon incorporated into the design.
[6]:
terminate_process(vision_process)