SGLang Native APIs#

Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:

  • /generate (text generation model)

  • /get_model_info

  • /get_server_info

  • /health

  • /health_generate

  • /flush_cache

  • /update_weights

  • /encode (embedding model)

  • /v1/rerank (cross-encoder rerank model)

  • /classify (reward model)

  • /start_expert_distribution_record

  • /stop_expert_distribution_record

  • /dump_expert_distribution_record

  • /tokenize

  • /detokenize

  • A full list of these APIs can be found at http_server.py

We mainly use requests to test these APIs in the following examples. You can also use curl.

Launch A Server#

[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0 --log-level warning"
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:33:25] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:33:25] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:33:25] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:33:30] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:33:30] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:33:30] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:33:32] WARNING server_args.py:1197: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-12 15:33:32] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:33:32] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:33:38] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:33:38] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:33:38] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:33:40] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:33:40] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:33:44] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:33:44] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:33:44] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:33:46] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:33:46] WARNING memory_pool_host.py:36: Current platform not support pin_memory
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.94it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.93it/s]

Capturing batches (bs=1 avail_mem=8.32 GB): 100%|██████████| 3/3 [00:00<00:00, 11.12it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.

Generate (text generation model)#

Generate completions. This is similar to /v1/completions in the OpenAI API. Detailed parameters can be found in the sampling parameters documentation.

[2]:
import requests

url = f"http://localhost:{port}/generate"
data = {"text": "What is the capital of France?"}

response = requests.post(url, json=data)
print_highlight(response.json())
{'text': ' Paris (Paris) is the capital of France (France). \n\nP.S. Paris, French (French) is the primary language in France. Thus, "Paris" and "French" can be included in this question. \n\nWe can also remove the noun "capital" if the answer lies in "French", that answer will be automatically selected. Actually, it didn\'t cross my mind at first because there\'s a lot of nouns in this question. Therefore, it helps make a right choice for small group like this. \n\nSince there\'s only 2 answers, Majority wins round in question, so I just chose the most remote answer', 'output_ids': [12095, 320, 59604, 8, 374, 279, 6722, 315, 9625, 320, 49000, 568, 4710, 47, 808, 13, 12095, 11, 8585, 320, 43197, 8, 374, 279, 6028, 4128, 304, 9625, 13, 14301, 11, 330, 59604, 1, 323, 330, 43197, 1, 646, 387, 5230, 304, 419, 3405, 13, 4710, 1654, 646, 1083, 4057, 279, 36921, 330, 65063, 1, 421, 279, 4226, 15448, 304, 330, 43197, 497, 429, 4226, 686, 387, 9463, 4091, 13, 33763, 11, 432, 3207, 944, 5312, 847, 3971, 518, 1156, 1576, 1052, 594, 264, 2696, 315, 89838, 304, 419, 3405, 13, 15277, 11, 432, 8609, 1281, 264, 1290, 5754, 369, 2613, 1874, 1075, 419, 13, 4710, 12549, 1052, 594, 1172, 220, 17, 11253, 11, 54035, 14816, 4778, 304, 3405, 11, 773, 358, 1101, 14554, 279, 1429, 8699, 4226], 'meta_info': {'id': 'c04f84dcb6094d8e8c22f205e70bea91', 'finish_reason': {'type': 'length', 'length': 128}, 'prompt_tokens': 7, 'weight_version': 'default', 'total_retractions': 0, 'completion_tokens': 128, 'cached_tokens': 0, 'e2e_latency': 0.29218316078186035, 'response_sent_to_client_ts': 1762961650.9033}}
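
The /generate endpoint also accepts a sampling_params object to control decoding. Below is a minimal sketch reusing the requests setup above; temperature and max_new_tokens are standard sampling parameters (see the sampling parameters documentation for the full list).

data = {
    "text": "What is the capital of France?",
    "sampling_params": {
        "temperature": 0,  # greedy decoding
        "max_new_tokens": 32,  # cap the completion length
    },
}

response = requests.post(url, json=data)
print_highlight(response.json()["text"])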

Get Model Info#

Get information about the model.

  • model_path: The path/name of the model.

  • is_generation: Whether the model is used as a generation model or an embedding model.

  • tokenizer_path: The path/name of the tokenizer.

  • preferred_sampling_params: The default sampling params specified via --preferred-sampling-params. None is returned in this example as we did not explicitly configure it in server args.

  • weight_version: This field contains the version of the model weights. This is often used to track changes or updates to the model’s trained parameters.

  • has_image_understanding: Whether the model has image-understanding capability.

  • has_audio_understanding: Whether the model has audio-understanding capability.

[3]:
url = f"http://localhost:{port}/get_model_info"

response = requests.get(url)
response_json = response.json()
print_highlight(response_json)
assert response_json["model_path"] == "qwen/qwen2.5-0.5b-instruct"
assert response_json["is_generation"] is True
assert response_json["tokenizer_path"] == "qwen/qwen2.5-0.5b-instruct"
assert response_json["preferred_sampling_params"] is None
assert response_json.keys() == {
    "model_path",
    "is_generation",
    "tokenizer_path",
    "preferred_sampling_params",
    "weight_version",
    "has_image_understanding",
    "has_audio_understanding",
}
{'model_path': 'qwen/qwen2.5-0.5b-instruct', 'tokenizer_path': 'qwen/qwen2.5-0.5b-instruct', 'is_generation': True, 'preferred_sampling_params': None, 'weight_version': 'default', 'has_image_understanding': False, 'has_audio_understanding': False}

Get Server Info#

Get the server information, including CLI arguments, token limits, and memory pool sizes.

  • Note: get_server_info merges the following deprecated endpoints:

    • get_server_args

    • get_memory_pool_size

    • get_max_total_num_tokens

[4]:
url = f"http://localhost:{port}/get_server_info"

response = requests.get(url)
print_highlight(response.text)
{"model_path":"qwen/qwen2.5-0.5b-instruct","tokenizer_path":"qwen/qwen2.5-0.5b-instruct","tokenizer_mode":"auto","tokenizer_worker_num":1,"skip_tokenizer_init":false,"load_format":"auto","model_loader_extra_config":"{}","trust_remote_code":false,"context_length":null,"is_embedding":false,"enable_multimodal":null,"revision":null,"model_impl":"auto","host":"0.0.0.0","port":39540,"grpc_mode":false,"skip_server_warmup":false,"warmups":null,"nccl_port":null,"checkpoint_engine_wait_weights_before_ready":false,"dtype":"auto","quantization":null,"quantization_param_path":null,"kv_cache_dtype":"auto","enable_fp32_lm_head":false,"modelopt_quant":null,"modelopt_checkpoint_restore_path":null,"modelopt_checkpoint_save_path":null,"modelopt_export_path":null,"quantize_and_serve":false,"mem_fraction_static":0.841,"max_running_requests":128,"max_queued_requests":null,"max_total_tokens":20480,"chunked_prefill_size":8192,"max_prefill_tokens":16384,"schedule_policy":"fcfs","enable_priority_scheduling":false,"abort_on_priority_when_disabled":false,"schedule_low_priority_values_first":false,"priority_scheduling_preemption_threshold":10,"schedule_conservativeness":1.0,"page_size":1,"hybrid_kvcache_ratio":null,"swa_full_tokens_ratio":0.8,"disable_hybrid_swa_memory":false,"radix_eviction_policy":"lru","device":"cuda","tp_size":1,"pp_size":1,"pp_max_micro_batch_size":null,"stream_interval":1,"stream_output":false,"random_seed":697426651,"constrained_json_whitespace_pattern":null,"constrained_json_disable_any_whitespace":false,"watchdog_timeout":300,"dist_timeout":null,"download_dir":null,"base_gpu_id":0,"gpu_id_step":1,"sleep_on_idle":false,"log_level":"warning","log_level_http":null,"log_requests":false,"log_requests_level":2,"crash_dump_folder":null,"show_time_cost":false,"enable_metrics":false,"enable_metrics_for_all_schedulers":false,"tokenizer_metrics_custom_labels_header":"x-custom-labels","tokenizer_metrics_allowed_custom_labels":null,"bucket_time_to_first_token":null,"bucket_inter_token_latency":null,"bucket_e2e_request_latency":null,"collect_tokens_histogram":false,"prompt_tokens_buckets":null,"generation_tokens_buckets":null,"gc_warning_threshold_secs":0.0,"decode_log_interval":40,"enable_request_time_stats_logging":false,"kv_events_config":null,"enable_trace":false,"otlp_traces_endpoint":"localhost:4317","api_key":null,"served_model_name":"qwen/qwen2.5-0.5b-instruct","weight_version":"default","chat_template":null,"completion_template":null,"file_storage_path":"sglang_storage","enable_cache_report":false,"reasoning_parser":null,"tool_call_parser":null,"tool_server":null,"sampling_defaults":"model","dp_size":1,"load_balance_method":"round_robin","load_watch_interval":0.1,"prefill_round_robin_balance":false,"dist_init_addr":null,"nnodes":1,"node_rank":0,"json_model_override_args":"{}","preferred_sampling_params":null,"enable_lora":null,"max_lora_rank":null,"lora_target_modules":null,"lora_paths":null,"max_loaded_loras":null,"max_loras_per_batch":8,"lora_eviction_policy":"lru","lora_backend":"csgmv","max_lora_chunk_size":16,"attention_backend":"fa3","decode_attention_backend":null,"prefill_attention_backend":null,"sampling_backend":"flashinfer","grammar_backend":"xgrammar","mm_attention_backend":null,"nsa_prefill_backend":"flashmla_sparse","nsa_decode_backend":"fa3","speculative_algorithm":null,"speculative_draft_model_path":null,"speculative_draft_model_revision":null,"speculative_draft_load_format":null,"speculative_num_steps":null,"speculative_eagle_topk":null,"speculative_num_draft_tokens":null,"speculat
ive_accept_threshold_single":1.0,"speculative_accept_threshold_acc":1.0,"speculative_token_map":null,"speculative_attention_mode":"prefill","speculative_moe_runner_backend":null,"speculative_ngram_min_match_window_size":1,"speculative_ngram_max_match_window_size":12,"speculative_ngram_min_bfs_breadth":1,"speculative_ngram_max_bfs_breadth":10,"speculative_ngram_match_type":"BFS","speculative_ngram_branch_length":18,"speculative_ngram_capacity":10000000,"ep_size":1,"moe_a2a_backend":"none","moe_runner_backend":"auto","flashinfer_mxfp4_moe_precision":"default","enable_flashinfer_allreduce_fusion":false,"deepep_mode":"auto","ep_num_redundant_experts":0,"ep_dispatch_algorithm":"static","init_expert_location":"trivial","enable_eplb":false,"eplb_algorithm":"auto","eplb_rebalance_num_iterations":1000,"eplb_rebalance_layers_per_chunk":null,"eplb_min_rebalancing_utilization_threshold":1.0,"expert_distribution_recorder_mode":null,"expert_distribution_recorder_buffer_size":1000,"enable_expert_distribution_metrics":false,"deepep_config":null,"moe_dense_tp_size":null,"elastic_ep_backend":null,"mooncake_ib_device":null,"max_mamba_cache_size":null,"mamba_ssm_dtype":"float32","mamba_full_memory_ratio":0.9,"enable_hierarchical_cache":false,"hicache_ratio":2.0,"hicache_size":0,"hicache_write_policy":"write_through","hicache_io_backend":"kernel","hicache_mem_layout":"layer_first","hicache_storage_backend":null,"hicache_storage_prefetch_policy":"best_effort","hicache_storage_backend_extra_config":null,"enable_lmcache":false,"kt_weight_path":null,"kt_method":"AMXINT4","kt_cpuinfer":null,"kt_threadpool_count":2,"kt_num_gpu_experts":null,"kt_max_deferred_experts_per_token":null,"enable_double_sparsity":false,"ds_channel_config_path":null,"ds_heavy_channel_num":32,"ds_heavy_token_num":256,"ds_heavy_channel_type":"qk","ds_sparse_decode_threshold":4096,"cpu_offload_gb":0,"offload_group_size":-1,"offload_num_in_group":1,"offload_prefetch_step":1,"offload_mode":"cpu","multi_item_scoring_delimiter":null,"disable_radix_cache":false,"cuda_graph_max_bs":4,"cuda_graph_bs":[1,2,4],"disable_cuda_graph":false,"disable_cuda_graph_padding":false,"enable_profile_cuda_graph":false,"enable_cudagraph_gc":false,"enable_nccl_nvls":false,"enable_symm_mem":false,"disable_flashinfer_cutlass_moe_fp4_allgather":false,"enable_tokenizer_batch_encode":false,"disable_tokenizer_batch_decode":false,"disable_outlines_disk_cache":false,"disable_custom_all_reduce":false,"enable_mscclpp":false,"enable_torch_symm_mem":false,"disable_overlap_schedule":false,"enable_mixed_chunk":false,"enable_dp_attention":false,"enable_dp_lm_head":false,"enable_two_batch_overlap":false,"enable_single_batch_overlap":false,"tbo_token_distribution_threshold":0.48,"enable_torch_compile":false,"enable_piecewise_cuda_graph":false,"torch_compile_max_bs":32,"piecewise_cuda_graph_max_tokens":4096,"piecewise_cuda_graph_tokens":[4,8,12,16,20,24,28,32,48,64,80,96,112,128,144,160,176,192,208,224,240,256,288,320,352,384,416,448,480,512,640,768,896,1024,1152,1280,1408,1536,1664,1792,1920,2048,2176,2304,2432,2560,2688,2816,2944,3072,3200,3328,3456,3584,3712,3840,3968,4096],"piecewise_cuda_graph_compiler":"eager","torchao_config":"","enable_nan_detection":false,"enable_p2p_check":false,"triton_attention_reduce_in_fp32":false,"triton_attention_num_kv_splits":8,"triton_attention_split_tile_size":null,"num_continuous_decode_steps":1,"delete_ckpt_after_loading":false,"enable_memory_saver":false,"enable_weights_cpu_backup":false,"allow_auto_truncate":false,"enable_custom_logit_processor":f
alse,"flashinfer_mla_disable_ragged":false,"disable_shared_experts_fusion":false,"disable_chunked_prefix_cache":false,"disable_fast_image_processor":false,"keep_mm_feature_on_device":false,"enable_return_hidden_states":false,"scheduler_recv_interval":1,"numa_node":null,"enable_deterministic_inference":false,"rl_on_policy_target":null,"enable_dynamic_batch_tokenizer":false,"dynamic_batch_tokenizer_batch_size":32,"dynamic_batch_tokenizer_batch_timeout":0.002,"debug_tensor_dump_output_folder":null,"debug_tensor_dump_layers":null,"debug_tensor_dump_input_file":null,"debug_tensor_dump_inject":false,"disaggregation_mode":"null","disaggregation_transfer_backend":"mooncake","disaggregation_bootstrap_port":8998,"disaggregation_decode_tp":null,"disaggregation_decode_dp":null,"disaggregation_prefill_pp":1,"disaggregation_ib_device":null,"disaggregation_decode_enable_offload_kvcache":false,"num_reserved_decode_tokens":512,"disaggregation_decode_polling_interval":1,"custom_weight_loader":[],"weight_loader_disable_mmap":false,"remote_instance_weight_loader_seed_instance_ip":null,"remote_instance_weight_loader_seed_instance_service_port":null,"remote_instance_weight_loader_send_weights_group_ports":null,"enable_pdmux":false,"pdmux_config_path":null,"sm_group_num":8,"mm_max_concurrent_calls":32,"mm_per_request_timeout":10.0,"decrypted_config_file":null,"decrypted_draft_config_file":null,"status":"ready","max_total_num_tokens":20480,"max_req_input_len":20474,"internal_states":[{"model_path":"qwen/qwen2.5-0.5b-instruct","tokenizer_path":"qwen/qwen2.5-0.5b-instruct","tokenizer_mode":"auto","tokenizer_worker_num":1,"skip_tokenizer_init":false,"load_format":"auto","model_loader_extra_config":"{}","trust_remote_code":false,"context_length":null,"is_embedding":false,"enable_multimodal":null,"revision":null,"model_impl":"auto","host":"0.0.0.0","port":39540,"grpc_mode":false,"skip_server_warmup":false,"warmups":null,"nccl_port":null,"checkpoint_engine_wait_weights_before_ready":false,"dtype":"auto","quantization":null,"quantization_param_path":null,"kv_cache_dtype":"auto","enable_fp32_lm_head":false,"modelopt_quant":null,"modelopt_checkpoint_restore_path":null,"modelopt_checkpoint_save_path":null,"modelopt_export_path":null,"quantize_and_serve":false,"mem_fraction_static":0.841,"max_running_requests":128,"max_queued_requests":null,"max_total_tokens":20480,"chunked_prefill_size":8192,"max_prefill_tokens":16384,"schedule_policy":"fcfs","enable_priority_scheduling":false,"abort_on_priority_when_disabled":false,"schedule_low_priority_values_first":false,"priority_scheduling_preemption_threshold":10,"schedule_conservativeness":1.0,"page_size":1,"hybrid_kvcache_ratio":null,"swa_full_tokens_ratio":0.8,"disable_hybrid_swa_memory":false,"radix_eviction_policy":"lru","device":"cuda","tp_size":1,"pp_size":1,"pp_max_micro_batch_size":128,"stream_interval":1,"stream_output":false,"random_seed":697426651,"constrained_json_whitespace_pattern":null,"constrained_json_disable_any_whitespace":false,"watchdog_timeout":300,"dist_timeout":null,"download_dir":null,"base_gpu_id":0,"gpu_id_step":1,"sleep_on_idle":false,"log_level":"warning","log_level_http":null,"log_requests":false,"log_requests_level":2,"crash_dump_folder":null,"show_time_cost":false,"enable_metrics":false,"enable_metrics_for_all_schedulers":false,"tokenizer_metrics_custom_labels_header":"x-custom-labels","tokenizer_metrics_allowed_custom_labels":null,"bucket_time_to_first_token":null,"bucket_inter_token_latency":null,"bucket_e2e_request_latency":null,"collect_tokens_his
togram":false,"prompt_tokens_buckets":null,"generation_tokens_buckets":null,"gc_warning_threshold_secs":0.0,"decode_log_interval":40,"enable_request_time_stats_logging":false,"kv_events_config":null,"enable_trace":false,"otlp_traces_endpoint":"localhost:4317","api_key":null,"served_model_name":"qwen/qwen2.5-0.5b-instruct","weight_version":"default","chat_template":null,"completion_template":null,"file_storage_path":"sglang_storage","enable_cache_report":false,"reasoning_parser":null,"tool_call_parser":null,"tool_server":null,"sampling_defaults":"model","dp_size":1,"load_balance_method":"round_robin","load_watch_interval":0.1,"prefill_round_robin_balance":false,"dist_init_addr":null,"nnodes":1,"node_rank":0,"json_model_override_args":"{}","preferred_sampling_params":null,"enable_lora":null,"max_lora_rank":null,"lora_target_modules":null,"lora_paths":null,"max_loaded_loras":null,"max_loras_per_batch":8,"lora_eviction_policy":"lru","lora_backend":"csgmv","max_lora_chunk_size":16,"attention_backend":"fa3","decode_attention_backend":"fa3","prefill_attention_backend":"fa3","sampling_backend":"flashinfer","grammar_backend":"xgrammar","mm_attention_backend":null,"nsa_prefill_backend":"flashmla_sparse","nsa_decode_backend":"fa3","speculative_algorithm":null,"speculative_draft_model_path":null,"speculative_draft_model_revision":null,"speculative_draft_load_format":null,"speculative_num_steps":null,"speculative_eagle_topk":null,"speculative_num_draft_tokens":null,"speculative_accept_threshold_single":1.0,"speculative_accept_threshold_acc":1.0,"speculative_token_map":null,"speculative_attention_mode":"prefill","speculative_moe_runner_backend":null,"speculative_ngram_min_match_window_size":1,"speculative_ngram_max_match_window_size":12,"speculative_ngram_min_bfs_breadth":1,"speculative_ngram_max_bfs_breadth":10,"speculative_ngram_match_type":"BFS","speculative_ngram_branch_length":18,"speculative_ngram_capacity":10000000,"ep_size":1,"moe_a2a_backend":"none","moe_runner_backend":"auto","flashinfer_mxfp4_moe_precision":"default","enable_flashinfer_allreduce_fusion":false,"deepep_mode":"auto","ep_num_redundant_experts":0,"ep_dispatch_algorithm":"static","init_expert_location":"trivial","enable_eplb":false,"eplb_algorithm":"auto","eplb_rebalance_num_iterations":1000,"eplb_rebalance_layers_per_chunk":null,"eplb_min_rebalancing_utilization_threshold":1.0,"expert_distribution_recorder_mode":null,"expert_distribution_recorder_buffer_size":1000,"enable_expert_distribution_metrics":false,"deepep_config":null,"moe_dense_tp_size":null,"elastic_ep_backend":null,"mooncake_ib_device":null,"max_mamba_cache_size":null,"mamba_ssm_dtype":"float32","mamba_full_memory_ratio":0.9,"enable_hierarchical_cache":false,"hicache_ratio":2.0,"hicache_size":0,"hicache_write_policy":"write_through","hicache_io_backend":"kernel","hicache_mem_layout":"layer_first","hicache_storage_backend":null,"hicache_storage_prefetch_policy":"best_effort","hicache_storage_backend_extra_config":null,"enable_lmcache":false,"kt_weight_path":null,"kt_method":"AMXINT4","kt_cpuinfer":null,"kt_threadpool_count":2,"kt_num_gpu_experts":null,"kt_max_deferred_experts_per_token":null,"enable_double_sparsity":false,"ds_channel_config_path":null,"ds_heavy_channel_num":32,"ds_heavy_token_num":256,"ds_heavy_channel_type":"qk","ds_sparse_decode_threshold":4096,"cpu_offload_gb":0,"offload_group_size":-1,"offload_num_in_group":1,"offload_prefetch_step":1,"offload_mode":"cpu","multi_item_scoring_delimiter":null,"disable_radix_cache":false,"cuda_graph_max_bs":4,"cuda_grap
h_bs":[1,2,4],"disable_cuda_graph":false,"disable_cuda_graph_padding":false,"enable_profile_cuda_graph":false,"enable_cudagraph_gc":false,"enable_nccl_nvls":false,"enable_symm_mem":false,"disable_flashinfer_cutlass_moe_fp4_allgather":false,"enable_tokenizer_batch_encode":false,"disable_tokenizer_batch_decode":false,"disable_outlines_disk_cache":false,"disable_custom_all_reduce":false,"enable_mscclpp":false,"enable_torch_symm_mem":false,"disable_overlap_schedule":false,"enable_mixed_chunk":false,"enable_dp_attention":false,"enable_dp_lm_head":false,"enable_two_batch_overlap":false,"enable_single_batch_overlap":false,"tbo_token_distribution_threshold":0.48,"enable_torch_compile":false,"enable_piecewise_cuda_graph":false,"torch_compile_max_bs":32,"piecewise_cuda_graph_max_tokens":4096,"piecewise_cuda_graph_tokens":[4,8,12,16,20,24,28,32,48,64,80,96,112,128,144,160,176,192,208,224,240,256,288,320,352,384,416,448,480,512,640,768,896,1024,1152,1280,1408,1536,1664,1792,1920,2048,2176,2304,2432,2560,2688,2816,2944,3072,3200,3328,3456,3584,3712,3840,3968,4096],"piecewise_cuda_graph_compiler":"eager","torchao_config":"","enable_nan_detection":false,"enable_p2p_check":false,"triton_attention_reduce_in_fp32":false,"triton_attention_num_kv_splits":8,"triton_attention_split_tile_size":null,"num_continuous_decode_steps":1,"delete_ckpt_after_loading":false,"enable_memory_saver":false,"enable_weights_cpu_backup":false,"allow_auto_truncate":false,"enable_custom_logit_processor":false,"flashinfer_mla_disable_ragged":false,"disable_shared_experts_fusion":false,"disable_chunked_prefix_cache":true,"disable_fast_image_processor":false,"keep_mm_feature_on_device":false,"enable_return_hidden_states":false,"scheduler_recv_interval":1,"numa_node":null,"enable_deterministic_inference":false,"rl_on_policy_target":null,"enable_dynamic_batch_tokenizer":false,"dynamic_batch_tokenizer_batch_size":32,"dynamic_batch_tokenizer_batch_timeout":0.002,"debug_tensor_dump_output_folder":null,"debug_tensor_dump_layers":null,"debug_tensor_dump_input_file":null,"debug_tensor_dump_inject":false,"disaggregation_mode":"null","disaggregation_transfer_backend":"mooncake","disaggregation_bootstrap_port":8998,"disaggregation_decode_tp":null,"disaggregation_decode_dp":null,"disaggregation_prefill_pp":1,"disaggregation_ib_device":null,"disaggregation_decode_enable_offload_kvcache":false,"num_reserved_decode_tokens":512,"disaggregation_decode_polling_interval":1,"custom_weight_loader":[],"weight_loader_disable_mmap":false,"remote_instance_weight_loader_seed_instance_ip":null,"remote_instance_weight_loader_seed_instance_service_port":null,"remote_instance_weight_loader_send_weights_group_ports":null,"enable_pdmux":false,"pdmux_config_path":null,"sm_group_num":8,"mm_max_concurrent_calls":32,"mm_per_request_timeout":10.0,"decrypted_config_file":null,"decrypted_draft_config_file":null,"use_mla_backend":false,"last_gen_throughput":661.3036578435012,"memory_usage":{"weight":1.06,"kvcache":0.23,"token_capacity":20480,"graph":0.07}}],"version":"0.5.5.post2"}

Health Check#

  • /health: Check the health of the server.

  • /health_generate: Check the health of the server by generating one token.

[5]:
url = f"http://localhost:{port}/health_generate"

response = requests.get(url)
print_highlight(response.text)
[6]:
url = f"http://localhost:{port}/health"

response = requests.get(url)
print_highlight(response.text)
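
These endpoints are convenient for readiness and liveness probes. Below is a minimal polling helper as a sketch; the timeout and interval values are arbitrary choices, not SGLang defaults.

import time

def wait_until_healthy(base_url, timeout_s=60, interval_s=1):
    """Poll /health until the server returns HTTP 200 or the timeout expires."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{base_url}/health", timeout=5).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not reachable yet
        time.sleep(interval_s)
    return False

print_highlight(wait_until_healthy(f"http://localhost:{port}"))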

Flush Cache#

Flush the radix cache. It will be automatically triggered when the model weights are updated by the /update_weights API.

[7]:
url = f"http://localhost:{port}/flush_cache"

response = requests.post(url)
print_highlight(response.text)
Cache flushed.
Please check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)

Update Weights From Disk#

Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.

SGLang supports the update_weights_from_disk API for continuous evaluation during training (save a checkpoint to disk, then update the weights from disk).

[8]:
# successful update with same architecture and size

url = f"http://localhost:{port}/update_weights_from_disk"
data = {"model_path": "qwen/qwen2.5-0.5b-instruct"}

response = requests.post(url, json=data)
print_highlight(response.text)
assert response.json()["success"] is True
assert response.json()["message"] == "Succeeded to update model weights."
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.21it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.20it/s]

{"success":true,"message":"Succeeded to update model weights.","num_paused_requests":0}
[9]:
# failed update with different parameter size or wrong name

url = f"http://localhost:{port}/update_weights_from_disk"
data = {"model_path": "qwen/qwen2.5-0.5b-instruct-wrong"}

response = requests.post(url, json=data)
response_json = response.json()
print_highlight(response_json)
assert response_json["success"] is False
assert response_json["message"] == (
    "Failed to get weights iterator: "
    "qwen/qwen2.5-0.5b-instruct-wrong"
    " (repository not found)."
)
[2025-11-12 15:34:13] Failed to get weights iterator: qwen/qwen2.5-0.5b-instruct-wrong (repository not found).
{'success': False, 'message': 'Failed to get weights iterator: qwen/qwen2.5-0.5b-instruct-wrong (repository not found).', 'num_paused_requests': 0}
[10]:
terminate_process(server_process)

Encode (embedding model)#

Encode text into embeddings. Note that this API is only available for embedding models and will raise an error for generation models. Therefore, we launch a new server to serve an embedding model.

[11]:
embedding_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct \
    --host 0.0.0.0 --is-embedding --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:34:18] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:34:18] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:34:18] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:34:21] INFO model_config.py:884: Downcasting torch.float32 to torch.float16.
[2025-11-12 15:34:21] WARNING server_args.py:1197: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-12 15:34:21] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:34:21] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:34:27] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:34:27] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:34:27] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:34:29] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:34:29] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:34:33] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:34:33] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:34:33] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:34:35] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:34:35] WARNING memory_pool_host.py:36: Current platform not support pin_memory
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.95s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.30s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.40s/it]



NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[12]:
# successful encode for embedding model

url = f"http://localhost:{port}/encode"
data = {"model": "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "text": "Once upon a time"}

response = requests.post(url, json=data)
response_json = response.json()
print_highlight(f"Text embedding (first 10): {response_json['embedding'][:10]}")
Text embedding (first 10): [-0.00023102760314941406, -0.04986572265625, -0.0032711029052734375, 0.011077880859375, -0.0140533447265625, 0.0159912109375, -0.01441192626953125, 0.0059051513671875, -0.0228424072265625, 0.0272979736328125]
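
Since /encode returns a plain list of floats, you can compare two texts directly. A minimal sketch computing cosine similarity with the same request format as above:

import math

def embed(text):
    data = {"model": "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "text": text}
    return requests.post(f"http://localhost:{port}/encode", json=data).json()["embedding"]

a = embed("Once upon a time")
b = embed("A long time ago")

dot = sum(x * y for x, y in zip(a, b))
cos = dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))
print_highlight(f"Cosine similarity: {cos:.4f}")
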
[13]:
terminate_process(embedding_process)

v1/rerank (cross-encoder rerank model)#

Rerank a list of documents for a given query using a cross-encoder model. Note that this API is only available for cross-encoder models such as BAAI/bge-reranker-v2-m3 and requires the triton or torch_native attention backend.

[14]:
reranker_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path BAAI/bge-reranker-v2-m3 \
    --host 0.0.0.0 --disable-radix-cache --chunked-prefill-size -1 --attention-backend triton --is-embedding --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:35:05] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:35:05] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:35:05] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:35:08] INFO model_config.py:884: Downcasting torch.float32 to torch.float16.
[2025-11-12 15:35:08] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:35:08] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:35:14] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:35:14] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:35:14] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:35:16] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:35:16] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:35:20] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:35:20] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:35:20] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:35:22] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:35:22] WARNING memory_pool_host.py:36: Current platform not support pin_memory
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.55it/s]



NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[15]:
# compute rerank scores for query and documents

url = f"http://localhost:{port}/v1/rerank"
data = {
    "model": "BAAI/bge-reranker-v2-m3",
    "query": "what is panda?",
    "documents": [
        "hi",
        "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
    ],
}

response = requests.post(url, json=data)
response_json = response.json()
for item in response_json:
    print_highlight(f"Score: {item['score']:.2f} - Document: '{item['document']}'")
Score: 5.26 - Document: 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.'
Score: -8.19 - Document: 'hi'
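
Higher scores indicate a better match, so selecting the top document is a one-liner over the response shown above:

best = max(response_json, key=lambda item: item["score"])
print_highlight(f"Best match: '{best['document']}' (score {best['score']:.2f})")
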
[16]:
terminate_process(reranker_process)

Classify (reward model)#

SGLang Runtime also supports reward models. Here we use a reward model to classify the quality of pairwise generations.

[17]:
# Note that SGLang now treats embedding models and reward models as the same type of models.
# This will be updated in the future.

reward_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:35:52] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:35:52] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:35:52] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:35:54] WARNING server_args.py:1197: Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-11-12 15:35:54] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:35:54] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:36:00] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:36:00] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:36:00] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:36:02] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:36:02] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:36:05] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:36:05] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:36:05] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:36:07] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:36:07] WARNING memory_pool_host.py:36: Current platform not support pin_memory
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:02,  1.28it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.21it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.21it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.81it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.54it/s]



NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[18]:
from transformers import AutoTokenizer

PROMPT = (
    "What is the range of the numeric output of a sigmoid node in a neural network?"
)

RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."

CONVS = [
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]

tokenizer = AutoTokenizer.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)

url = f"http://localhost:{port}/classify"
data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": prompts}

responses = requests.post(url, json=data).json()
for response in responses:
    print_highlight(f"reward: {response['embedding'][0]}")
reward: -24.125
reward: 1.171875
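
The reward is returned as the first element of the embedding field, so picking the preferred response reduces to an argmax over the scores. A small sketch reusing the variables above:

rewards = [r["embedding"][0] for r in responses]
best_idx = max(range(len(rewards)), key=lambda i: rewards[i])
print_highlight(f"Preferred response: RESPONSE{best_idx + 1} (reward {rewards[best_idx]:.3f})")
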
[19]:
terminate_process(reward_process)

Capture expert selection distribution in MoE models#

SGLang Runtime supports recording how many times each expert is selected while running a MoE model. This is useful for analyzing the model's throughput and planning optimizations.

Note: We only print out the first 10 lines of the CSV below for better readability. Please adjust accordingly if you want to analyze the results more deeply.

[20]:
expert_record_server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat --log-level warning"
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:37:35] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:37:35] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:37:35] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:37:37] WARNING server_args.py:1197: Attention backend not explicitly specified. Use flashinfer backend by default.
[2025-11-12 15:37:37] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:37:37] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:37:42] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:37:42] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:37:42] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:37:44] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:37:44] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:37:48] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:37:48] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:37:48] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:37:50] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:37:50] WARNING memory_pool_host.py:36: Current platform not support pin_memory
Loading safetensors checkpoint shards:   0% Completed | 0/8 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  12% Completed | 1/8 [00:00<00:05,  1.22it/s]
Loading safetensors checkpoint shards:  25% Completed | 2/8 [00:01<00:05,  1.14it/s]
Loading safetensors checkpoint shards:  38% Completed | 3/8 [00:02<00:04,  1.14it/s]
Loading safetensors checkpoint shards:  50% Completed | 4/8 [00:03<00:03,  1.16it/s]
Loading safetensors checkpoint shards:  62% Completed | 5/8 [00:04<00:02,  1.17it/s]
Loading safetensors checkpoint shards:  75% Completed | 6/8 [00:05<00:01,  1.15it/s]
Loading safetensors checkpoint shards:  88% Completed | 7/8 [00:06<00:00,  1.18it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:06<00:00,  1.56it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:06<00:00,  1.29it/s]

Capturing batches (bs=4 avail_mem=28.21 GB):   0%|          | 0/3 [00:00<?, ?it/s][2025-11-12 15:38:19] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /public_sglang_ci/runner-l1c-gpu-45/_work/sglang/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=60,N=1408,device_name=NVIDIA_H100_80GB_HBM3.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
[2025-11-12 15:38:19] Using default MoE kernel config. Performance might be sub-optimal! Config file not found at /public_sglang_ci/runner-l1c-gpu-45/_work/sglang/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=60,N=1408,device_name=NVIDIA_H100_80GB_HBM3_down.json, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
Capturing batches (bs=1 avail_mem=47.15 GB): 100%|██████████| 3/3 [00:10<00:00,  3.36s/it]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server; the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[21]:
response = requests.post(f"http://localhost:{port}/start_expert_distribution_record")
print_highlight(response)

url = f"http://localhost:{port}/generate"
data = {"text": "What is the capital of France?"}

response = requests.post(url, json=data)
print_highlight(response.json())

response = requests.post(f"http://localhost:{port}/stop_expert_distribution_record")
print_highlight(response)

response = requests.post(f"http://localhost:{port}/dump_expert_distribution_record")
print_highlight(response)
{'text': ' Paris.', 'output_ids': [12095, 13, 151643], 'meta_info': {'id': '8fe735065b6a456a8f0e1231aabd250d', 'finish_reason': {'type': 'stop', 'matched': 151643}, 'prompt_tokens': 7, 'weight_version': 'default', 'total_retractions': 0, 'completion_tokens': 3, 'cached_tokens': 0, 'e2e_latency': 0.38155221939086914, 'response_sent_to_client_ts': 1762961956.2824721}}
[22]:
terminate_process(expert_record_server_process)

Tokenize/Detokenize Example (Round Trip)#

This example demonstrates how to use the /tokenize and /detokenize endpoints together. We first tokenize a string, then detokenize the resulting IDs to reconstruct the original text. This workflow is useful when you need to handle tokenization externally but still leverage the server for detokenization.

[23]:
tokenizer_free_server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct
"""
)

wait_for_server(f"http://localhost:{port}")
[2025-11-12 15:39:22] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:39:22] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:39:22] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:39:25] WARNING server_args.py:1197: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-12 15:39:25] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:39:25] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:39:25] server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=35750, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.841, max_running_requests=128, max_queued_requests=None, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=368876061, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, 
speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method='AMXINT4', kt_cpuinfer=None, kt_threadpool_count=2, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=[1, 2, 4], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, 
disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, decrypted_config_file=None, decrypted_draft_config_file=None)
[2025-11-12 15:39:25] Using default HuggingFace chat template with detected content format: string
[2025-11-12 15:39:31] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:39:31] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:39:31] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:39:32] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-12 15:39:32] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-12 15:39:32] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-12 15:39:33] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:39:33] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:39:34] INFO trace.py:69: opentelemetry package is not installed, tracing disabled
[2025-11-12 15:39:34] WARNING memory_pool_host.py:36: Current platform not support pin_memory
[2025-11-12 15:39:34] Init torch distributed begin.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2025-11-12 15:39:35] Init torch distributed ends. mem usage=0.00 GB
[2025-11-12 15:39:35] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-11-12 15:39:55] Load weight begin. avail mem=78.07 GB
[2025-11-12 15:39:55] Found local HF snapshot for qwen/qwen2.5-0.5b-instruct at /hf_home/hub/models--qwen--qwen2.5-0.5b-instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775; skipping download.
[2025-11-12 15:39:55] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.56it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  4.56it/s]

[2025-11-12 15:39:56] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=77.01 GB, mem usage=1.06 GB.
[2025-11-12 15:39:56] Using KV cache dtype: torch.bfloat16
[2025-11-12 15:39:56] KV Cache is allocated. #tokens: 20480, K size: 0.12 GB, V size: 0.12 GB
[2025-11-12 15:39:56] Memory pool end. avail mem=76.61 GB
[2025-11-12 15:39:56] Capture cuda graph begin. This can take up to several minutes. avail mem=76.51 GB
[2025-11-12 15:39:56] Capture cuda graph bs [1, 2, 4]
Capturing batches (bs=1 avail_mem=76.45 GB): 100%|██████████| 3/3 [00:00<00:00, 11.66it/s]
[2025-11-12 15:39:57] Capture cuda graph end. Time elapsed: 0.84 s. mem usage=0.07 GB. avail mem=76.44 GB.
[2025-11-12 15:39:57] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=128, context_len=32768, available_gpu_mem=76.44 GB
[2025-11-12 15:39:58] INFO:     Started server process [685732]
[2025-11-12 15:39:58] INFO:     Waiting for application startup.
[2025-11-12 15:39:58] Using default chat sampling params from model generation config: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
[2025-11-12 15:39:58] Using default chat sampling params from model generation config: {'repetition_penalty': 1.1, 'temperature': 0.7, 'top_k': 20, 'top_p': 0.8}
[2025-11-12 15:39:58] INFO:     Application startup complete.
[2025-11-12 15:39:58] INFO:     Uvicorn running on http://127.0.0.1:35750 (Press CTRL+C to quit)
[2025-11-12 15:39:58] INFO:     127.0.0.1:50138 - "GET /v1/models HTTP/1.1" 200 OK
[2025-11-12 15:39:59] INFO:     127.0.0.1:50148 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-11-12 15:39:59] Prefill batch, #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-11-12 15:39:59] INFO:     127.0.0.1:50156 - "POST /generate HTTP/1.1" 200 OK
[2025-11-12 15:39:59] The server is fired up and ready to roll!


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and the notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the server's log level to warning; the default log level is info.
We run these notebooks in a CI environment, so the throughput is not representative of actual performance.
[24]:
import requests
from sglang.utils import print_highlight

base_url = f"http://localhost:{port}"
tokenize_url = f"{base_url}/tokenize"
detokenize_url = f"{base_url}/detokenize"

model_name = "qwen/qwen2.5-0.5b-instruct"
input_text = "SGLang provides efficient tokenization endpoints."
print_highlight(f"Original Input Text:\n'{input_text}'")

# --- tokenize the input text ---
tokenize_payload = {
    "model": model_name,
    "prompt": input_text,
    "add_special_tokens": False,
}
try:
    tokenize_response = requests.post(tokenize_url, json=tokenize_payload)
    tokenize_response.raise_for_status()
    tokenization_result = tokenize_response.json()
    token_ids = tokenization_result.get("tokens")

    if not token_ids:
        raise ValueError("Tokenization returned empty tokens.")

    print_highlight(f"\nTokenized Output (IDs):\n{token_ids}")
    print_highlight(f"Token Count: {tokenization_result.get('count')}")
    print_highlight(f"Max Model Length: {tokenization_result.get('max_model_len')}")

    # --- detokenize the obtained token IDs ---
    detokenize_payload = {
        "model": model_name,
        "tokens": token_ids,
        "skip_special_tokens": True,
    }

    detokenize_response = requests.post(detokenize_url, json=detokenize_payload)
    detokenize_response.raise_for_status()
    detokenization_result = detokenize_response.json()
    reconstructed_text = detokenization_result.get("text")

    print_highlight(f"\nDetokenized Output (Text):\n'{reconstructed_text}'")

    if input_text == reconstructed_text:
        print_highlight(
            "\nRound Trip Successful: Original and reconstructed text match."
        )
    else:
        print_highlight(
            "\nRound Trip Mismatch: Original and reconstructed text differ."
        )

except requests.exceptions.RequestException as e:
    print_highlight(f"\nHTTP Request Error: {e}")
except Exception as e:
    print_highlight(f"\nAn error occurred: {e}")
Original Input Text:
'SGLang provides efficient tokenization endpoints.'
[2025-11-12 15:40:03] INFO:     127.0.0.1:58666 - "POST /tokenize HTTP/1.1" 200 OK

Tokenized Output (IDs):
[50, 3825, 524, 5707, 11050, 3950, 2022, 36342, 13]
Token Count: 9
Max Model Length: 131072
[2025-11-12 15:40:03] INFO:     127.0.0.1:58668 - "POST /detokenize HTTP/1.1" 200 OK

Detokenized Output (Text):
'SGLang provides efficient tokenization endpoints.'

Round Trip Successful: Original and reconstructed text match.
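As a follow-up, here is a minimal sketch (not part of the example above) that reuses /tokenize to check a prompt against the reported max_model_len before sending it to /generate. It reuses base_url, model_name, and print_highlight from the previous cell; the count_tokens helper is our own illustration, and the /generate payload shape ({"text": ..., "sampling_params": ...}) follows SGLang's native /generate API shown earlier in this document.

import requests


def count_tokens(base_url: str, prompt: str):
    """Return (token_count, max_model_len) as reported by the /tokenize endpoint."""
    resp = requests.post(
        f"{base_url}/tokenize",
        json={"model": model_name, "prompt": prompt, "add_special_tokens": False},
    )
    resp.raise_for_status()
    data = resp.json()
    return data["count"], data["max_model_len"]


prompt = "List three uses of a tokenizer endpoint."
n_tokens, max_len = count_tokens(base_url, prompt)
print_highlight(f"Prompt uses {n_tokens} of {max_len} tokens.")

# Only generate if the prompt fits in the context window (illustrative check).
if n_tokens < max_len:
    gen_response = requests.post(
        f"{base_url}/generate",
        json={"text": prompt, "sampling_params": {"max_new_tokens": 32}},
    )
    gen_response.raise_for_status()
    print_highlight(gen_response.json()["text"])

This pattern is useful when a client needs to enforce its own length limits or budget max_new_tokens before submitting a request.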
[25]:
terminate_process(tokenizer_free_server_process)