Offline Engine API#
SGLang provides a direct inference engine that runs without an HTTP server, which is useful when the additional HTTP layer adds unnecessary complexity or overhead. There are two general use cases:
Offline Batch Inference
Custom Server on Top of the Engine
This document focuses on offline batch inference, demonstrating four inference modes:
Non-streaming synchronous generation
Streaming synchronous generation
Non-streaming asynchronous generation
Streaming asynchronous generation
Additionally, you can easily build a custom server on top of the SGLang offline engine. A detailed example in a standalone Python script can be found in custom_server.
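As a rough illustration of that second use case, the sketch below wraps the engine in a small FastAPI app. This is a hedged sketch, not the bundled custom_server example: the fastapi/uvicorn dependency, the /generate route, and the request shape are assumptions made here.

# Minimal sketch of a custom server on top of the offline engine.
# Assumptions: fastapi and uvicorn are installed, and a single-prompt call
# to async_generate returns one output dict. See custom_server for the
# authoritative example.
from fastapi import FastAPI

import sglang as sgl

app = FastAPI()
llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")

@app.post("/generate")
async def generate(request: dict):
    output = await llm.async_generate(
        request["prompt"], request.get("sampling_params", {})
    )
    return {"text": output["text"]}

# Run with, e.g.: uvicorn my_server:app --port 8000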
Nest Asyncio#
Note that if you want to use the offline engine in IPython or other code that is already running inside an event loop, you need to apply nest_asyncio first:
import nest_asyncio
nest_asyncio.apply()
Advanced Usage#
The engine supports vision-language model (VLM) inference as well as extracting hidden states.
Please see the examples for further use cases.
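For hidden states specifically, the sketch below shows the general shape of the call. It assumes the enable_return_hidden_states server argument together with a per-request return_hidden_states flag, and that the states are surfaced under meta_info; treat it as a hedged outline and consult the examples for the authoritative version.

# Hedged sketch of hidden-state extraction; the meta_info field name is an
# assumption here, so check the examples for the exact output layout.
import sglang as sgl

hs_llm = sgl.Engine(
    model_path="qwen/qwen2.5-0.5b-instruct",
    enable_return_hidden_states=True,  # server-level switch
)
out = hs_llm.generate(
    "The capital of France is",
    {"temperature": 0.0, "max_new_tokens": 8},
    return_hidden_states=True,  # per-request switch
)
hidden = out["meta_info"]["hidden_states"]  # assumed location of the states
hs_llm.shutdown()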
Offline Batch Inference#
The SGLang offline engine supports batch inference with efficient scheduling.
[1]:
# launch the offline engine
import asyncio
import sglang as sgl
import sglang.test.doc_patch
from sglang.utils import async_stream_and_merge, stream_and_merge
llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")
[2025-11-12 15:34:21] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:34:21] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:34:21] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:34:23] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:34:23] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:34:24] WARNING server_args.py:1197: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-12 15:34:24] INFO engine.py:123: server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.835, max_running_requests=128, max_queued_requests=None, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=1071819355, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='error', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, 
speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method=None, kt_cpuinfer=None, kt_threadpool_count=None, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, 
enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
[2025-11-12 15:34:30] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:34:30] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:34:30] INFO utils.py:164: NumExpr defaulting to 16 threads.
WARNING:sglang.srt.mem_cache.memory_pool_host:Current platform not support pin_memory
[2025-11-12 15:34:32] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:34:32] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 4.39it/s]
Capturing batches (bs=1 avail_mem=71.74 GB): 100%|██████████| 20/20 [00:01<00:00, 10.37it/s]
Non-streaming Synchronous Generation#
[2]:
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

outputs = llm.generate(prompts, sampling_params)

for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Hello, my name is
Generated text: Mikaela and I am the owner of The Royal Eiffel Tower. I am passionate about living a sustainable lifestyle and want to share my personal stories and experiences with others who share the same values.
I decided to start my own business because I feel that it is important to be a part of something bigger than myself, to help others, and to create something that has a positive impact. I am always looking for new ideas and opportunities to make a difference and I am excited to start The Royal Eiffel Tower! I hope that my customers can find value in their experience, whether they are looking for a unique gift, a cozy
===============================
Prompt: The president of the United States is
Generated text: running for a second term. He will be replaced by a new president immediately after the inauguration. What is the probability that the president is re-elected given that he is defeated by his opponents in the election? To determine the probability that the president is re-elected given that he is defeated, we need to consider the following:
1. The president has a 50% chance of winning the election.
2. The president is defeated by his opponents.
3. Given that the president is defeated, the probability of him winning the election is the same as the probability of winning in an election with all opponents eliminated.
Since the president is defeated,
===============================
Prompt: The capital of France is
Generated text: :
A. Paris
B. London
C. Rome
D. Madrid
The capital of France is:
A. Paris
Paris is the capital of France. The other options are not capitals of France:
- London is the capital of England.
- Rome is the capital of Italy.
- Madrid is the capital of Spain.
None of these are the capitals of France. The correct answer is Paris.
To verify: Paris is indeed the capital of France, located in the southeast of the country. It's known for its historical landmarks such as the Eiffel Tower and the Louvre Museum. Paris is a popular tourist
===============================
Prompt: The future of AI is
Generated text: here and it is already taking over our lives. Predictive Analytics is a critical area of interest because we are already seeing the emergence of AI with a potentially life changing impact on society. With the recent release of the iPhone 12, Apple is the first major company to release a device that is built around AI. This is not the first time that Apple has embraced AI, but their iPhone 12 is a game changer in the market. If you are a follower of Apple and want to learn about the future of AI and how Apple is making it happen, then this book is a must read.
This book discusses the future of
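Besides the generated text, each output dict carries a meta_info entry with per-request bookkeeping. A hedged peek below; the exact keys, such as prompt_tokens and completion_tokens, are assumptions and may vary by release.

# Inspect the bookkeeping attached to the first output; key names are
# assumptions based on common SGLang releases.
meta = outputs[0]["meta_info"]
print(meta.get("prompt_tokens"), meta.get("completion_tokens"), meta.get("finish_reason"))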
Streaming Synchronous Generation#
[3]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {
    "temperature": 0.2,
    "top_p": 0.9,
}

print("\n=== Testing synchronous streaming generation with overlap removal ===\n")

for prompt in prompts:
    print(f"Prompt: {prompt}")
    merged_output = stream_and_merge(llm, prompt, sampling_params)
    print("Generated text:", merged_output)
    print()
=== Testing synchronous streaming generation with overlap removal ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: [Name], and I'm a [job title] at [company name]. I'm excited to meet you and learn more about you. What can you tell me about yourself? I'm a [job title] at [company name], and I'm passionate about [job title] and [job title]. I enjoy [job title] because [reason for interest]. What's your favorite hobby or activity? I love [hobby or activity]. What's your favorite book or movie? I love [book or movie]. What's your favorite food? I love [food]. What's your favorite color? I love [color]. What
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris, the city known for its iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum. It is also home to the French Parliament, the French National Museum, and the French Academy of Sciences. Paris is a bustling metropolis with a rich cultural heritage and is a major economic and political center in Europe. The city is known for its fashion, art, and cuisine, and is a popular tourist destination. Paris is also home to the Eiffel Tower, the Louvre Museum, and the Notre-Dame Cathedral. The city is known for its iconic landmarks and is a major economic and
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: likely to be characterized by rapid advancements in areas such as machine learning, natural language processing, and computer vision. Some possible future trends in AI include:
1. Increased use of AI in healthcare: AI is already being used in healthcare to diagnose and treat diseases, and it has the potential to revolutionize the field. AI-powered diagnostic tools, such as AI-powered X-ray machines, could significantly improve patient outcomes.
2. Increased use of AI in finance: AI is already being used in finance to automate trading, fraud detection, and risk management. As AI technology continues to improve, we can expect to see even more sophisticated applications in finance.
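For reference, stream_and_merge wraps a plain streaming call: passing stream=True to generate yields chunk dicts as tokens arrive. A minimal sketch follows; whether each chunk's text is cumulative or incremental can differ by version, which is exactly the overlap the helper above removes.

# Raw synchronous streaming without the merging helper; chunks may overlap.
for chunk in llm.generate(prompts[0], sampling_params, stream=True):
    print(chunk["text"], end="", flush=True)
print()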
Non-streaming Asynchronous Generation#
[4]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous batch generation ===")


async def main():
    outputs = await llm.async_generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


asyncio.run(main())
=== Testing asynchronous batch generation ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: [Character's Name]. I am a [Age] year old [Occupation or Profession] who has always been passionate about [Why is it that you are passionate about [Occupation or Profession]].
I am always learning and growing, and I am always up for new challenges. I am a team player, always looking to contribute to the team and get the best out of everyone. I am an excellent communicator, always able to convey my ideas clearly and effectively. I am a hard worker, always putting in the extra effort to get things done.
And most importantly, I am a friend. I am always there for you, whether
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris.
Paris is the largest city in France and the second-largest city in the European Union after Rome. It is known for its beautiful architecture, rich cultural heritage, and annual celebration of the Etoile de Paris (the Star of Paris). Paris is also a world-renowned music capital, home to iconic venues like the Opéra Garnier and the Théâtre du Châtelet. The city is known for its bustling street food, festivals, and world-class museums, including the Louvre and the Centre Pompidou. Paris is a popular tourist destination, known for its vibrant nightlife and extensive museums and art galleries. As
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: expected to be characterized by rapid development and innovation, as well as ongoing changes to the way that AI is used and deployed. Here are some possible trends that could influence the future of AI:
1. Increased focus on ethical considerations: As AI systems become more advanced, they will increasingly be used for decision-making processes that affect people's lives. Therefore, there is an increasing emphasis on ethical considerations and the potential for unintended consequences. AI developers will need to be mindful of the impact that their algorithms and models may have on society and take steps to mitigate any potential harms.
2. Continued development of deep learning: Deep learning is a key area of
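Because async_generate returns a coroutine per call, the prompts can also be issued as independent tasks rather than one batched call, letting the engine's scheduler batch them internally. A sketch, assuming a single-prompt call returns one output dict:

# Launch one task per prompt and gather the results concurrently.
async def gather_main():
    tasks = [llm.async_generate(p, sampling_params) for p in prompts]
    results = await asyncio.gather(*tasks)
    for prompt, result in zip(prompts, results):
        print(f"\nPrompt: {prompt}\nGenerated text: {result['text']}")


asyncio.run(gather_main())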
Streaming Asynchronous Generation#
[5]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous streaming generation (no repeats) ===")


async def main():
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        print("Generated text: ", end="", flush=True)

        # Replace direct calls to async_generate with our custom overlap-aware version
        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):
            print(cleaned_chunk, end="", flush=True)

        print()  # New line after each prompt


asyncio.run(main())
=== Testing asynchronous streaming generation (no repeats) ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: Emily, and I'm a friendly, laid-back barista at a local coffee shop. I'm here to serve you all the time and make sure your drink is perfect for you. I love brewing coffee and helping people find their way around the bustling coffee shop scene. I'm a go-to for those who want to start their day with a caffeine fix or a smoothie. I'm here to assist you in finding the perfect cup of coffee and bring you the best experience possible. How can I help you today? I'll take care of you and make sure that you're getting the best experience possible. What do you need help with
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris, an ancient city nestled in the Saintes-Maries-de-la-Soleil mountains on the Mediterranean coast.
Paris is a vibrant metropolis known for its rich history, cultural importance, and stunning architecture. The city's streets are lined with historic monuments, including the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral. It is also home to iconic landmarks such as the Seine River and the Arc de Triomphe. Despite its size, Paris boasts a diverse population and is a major cultural and financial center in Europe. Its status as both a political and economic capital has made it a popular destination for tourists from
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: likely to be characterized by a number of different trends that will shape how the technology is used and developed. Here are some potential areas of development that could be expected in the coming years:
1. Increased efficiency and accuracy: One of the biggest challenges facing AI is its ability to process and analyze large amounts of data quickly and accurately. As we become more data-driven, we may see a growing trend toward more efficient and accurate AI systems, with the goal of making data-driven decisions with greater speed and precision.
2. Deep learning: Deep learning is a type of machine learning that involves building complex neural networks with many layers. As the technology continues
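As in the synchronous case, async_stream_and_merge wraps a plain streaming call: awaiting async_generate with stream=True yields an async generator of chunk dicts. A hedged sketch of the raw form:

# Raw asynchronous streaming without the merging helper; chunks may overlap.
async def raw_stream(prompt):
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    async for chunk in generator:
        print(chunk["text"], end="", flush=True)
    print()


asyncio.run(raw_stream(prompts[0]))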
[6]:
llm.shutdown()