SGLang Native APIs#
Apart from the OpenAI-compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:
- /generate (text generation model)
- /get_model_info
- /get_server_info
- /health
- /health_generate
- /flush_cache
- /update_weights
- /encode (embedding model)
- /classify (reward model)
- /start_expert_distribution_record
- /stop_expert_distribution_record
- /dump_expert_distribution_record
We mainly use requests to test these APIs in the following examples. You can also use curl.
Launch A Server#
[1]:
import requests
from sglang.test.test_utils import is_in_ci
if is_in_ci():
from patch import launch_server_cmd
else:
from sglang.utils import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process
server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:25:00] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=34881, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=332652088, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:25:10 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:25:10 TP0] Init torch distributed begin.
[2025-04-13 23:25:11 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:25:11 TP0] Load weight begin. avail mem=59.58 GB
[2025-04-13 23:25:11 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:25:12 TP0] Using model weights format ['*.safetensors']
[2025-04-13 23:25:12 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.40it/s]
[2025-04-13 23:25:13 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=42.11 GB, mem usage=17.48 GB.
[2025-04-13 23:25:13 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.31 GB, V size: 0.31 GB
[2025-04-13 23:25:13 TP0] Memory pool end. avail mem=41.25 GB
[2025-04-13 23:25:14 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:25:14 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-04-13 23:25:15] INFO: Started server process [2784775]
[2025-04-13 23:25:15] INFO: Waiting for application startup.
[2025-04-13 23:25:15] INFO: Application startup complete.
[2025-04-13 23:25:15] INFO: Uvicorn running on http://0.0.0.0:34881 (Press CTRL+C to quit)
[2025-04-13 23:25:15] INFO: 127.0.0.1:49774 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:25:16] INFO: 127.0.0.1:49782 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:25:16 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:25:18] INFO: 127.0.0.1:49792 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:25:18] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
Generate (text generation model)#
Generate completions. This is similar to /v1/completions in the OpenAI API. Detailed parameters can be found in the sampling parameters documentation.
[2]:
url = f"http://localhost:{port}/generate"
data = {"text": "What is the capital of France?"}
response = requests.post(url, json=data)
print_highlight(response.json())
[2025-04-13 23:25:20 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 1, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:25:20 TP0] Decode batch. #running-req: 1, #token: 41, token usage: 0.00, gen throughput (token/s): 6.66, #queue-req: 0,
[2025-04-13 23:25:20] INFO: 127.0.0.1:55534 - "POST /generate HTTP/1.1" 200 OK
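The request also accepts a sampling_params object. Below is a minimal sketch that uses greedy decoding and caps the output length; the temperature and max_new_tokens fields are the same ones used in the tokenizer-free example at the end of this notebook.

```python
import requests

url = f"http://localhost:{port}/generate"
data = {
    "text": "What is the capital of France?",
    "sampling_params": {
        "temperature": 0,       # greedy decoding
        "max_new_tokens": 32,   # cap the completion length
    },
}
response = requests.post(url, json=data)
print_highlight(response.json())
```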
Get Model Info#
Get the information of the model.
- model_path: The path/name of the model.
- is_generation: Whether the model is used as a generation model or an embedding model.
- tokenizer_path: The path/name of the tokenizer.
[3]:
url = f"http://localhost:{port}/get_model_info"
response = requests.get(url)
response_json = response.json()
print_highlight(response_json)
assert response_json["model_path"] == "meta-llama/Llama-3.2-1B-Instruct"
assert response_json["is_generation"] is True
assert response_json["tokenizer_path"] == "meta-llama/Llama-3.2-1B-Instruct"
assert response_json.keys() == {"model_path", "is_generation", "tokenizer_path"}
[2025-04-13 23:25:20] INFO: 127.0.0.1:55536 - "GET /get_model_info HTTP/1.1" 200 OK
Get Server Info#
Get the server information, including CLI arguments, token limits, and memory pool sizes.

Note: get_server_info merges the following deprecated endpoints:

- get_server_args
- get_memory_pool_size
- get_max_total_num_tokens
[4]:
# get_server_info
url = f"http://localhost:{port}/get_server_info"
response = requests.get(url)
print_highlight(response.text)
[2025-04-13 23:25:20] INFO: 127.0.0.1:55550 - "GET /get_server_info HTTP/1.1" 200 OK
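If you only need a few values, you can read them from the returned JSON. This is a minimal sketch; the key names below (e.g. max_total_num_tokens) are assumptions based on the merged endpoints listed above and may vary across versions.

```python
import requests

server_info = requests.get(f"http://localhost:{port}/get_server_info").json()
# These keys are assumed to exist because get_server_info merges the deprecated
# get_max_total_num_tokens endpoint and the server CLI arguments; verify against your version.
print_highlight(f"max_total_num_tokens: {server_info.get('max_total_num_tokens')}")
print_highlight(f"max_running_requests: {server_info.get('max_running_requests')}")
```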
Health Check#
- /health: Check the health of the server.
- /health_generate: Check the health of the server by generating one token.
[5]:
url = f"http://localhost:{port}/health_generate"
response = requests.get(url)
print_highlight(response.text)
[2025-04-13 23:25:20 TP0] Prefill batch. #new-seq: 1, #new-token: 1, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:25:21] INFO: 127.0.0.1:55556 - "GET /health_generate HTTP/1.1" 200 OK
[6]:
url = f"http://localhost:{port}/health"
response = requests.get(url)
print_highlight(response.text)
[2025-04-13 23:25:21] INFO: 127.0.0.1:55560 - "GET /health HTTP/1.1" 200 OK
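A common pattern is to poll /health until the server is ready before sending traffic. Below is a minimal sketch (the timeout and retry interval are arbitrary choices); the wait_for_server helper used above serves a similar purpose.

```python
import time
import requests

def wait_until_healthy(base_url: str, timeout_s: float = 60.0) -> bool:
    """Poll /health until it returns 200 or the timeout expires."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{base_url}/health", timeout=2).status_code == 200:
                return True
        except requests.RequestException:
            pass  # server not reachable yet
        time.sleep(1)
    return False

print_highlight(wait_until_healthy(f"http://localhost:{port}"))
```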
Flush Cache#
Flush the radix cache. It will be automatically triggered when the model weights are updated by the /update_weights API.
[7]:
# flush cache
url = f"http://localhost:{port}/flush_cache"
response = requests.post(url)
print_highlight(response.text)
[2025-04-13 23:25:21] INFO: 127.0.0.1:55568 - "POST /flush_cache HTTP/1.1" 200 OK
[2025-04-13 23:25:21 TP0] Cache flushed successfully!
Please check backend logs for more details. (When there are running or waiting requests, the operation will not be performed.)
Update Weights From Disk#
Update model weights from disk without restarting the server. Only applicable for models with the same architecture and parameter size.
SGLang supports the update_weights_from_disk API for continuous evaluation during training (save a checkpoint to disk and update the weights from disk).
[8]:
# successful update with same architecture and size
url = f"http://localhost:{port}/update_weights_from_disk"
data = {"model_path": "meta-llama/Llama-3.2-1B"}
response = requests.post(url, json=data)
print_highlight(response.text)
assert response.json()["success"] is True
assert response.json()["message"] == "Succeeded to update model weights."
[2025-04-13 23:25:21] Start update_weights. Load format=auto
[2025-04-13 23:25:21 TP0] Update engine weights online from disk begin. avail mem=38.47 GB
[2025-04-13 23:25:21 TP0] Using model weights format ['*.safetensors']
[2025-04-13 23:25:21 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.33it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.33it/s]
[2025-04-13 23:25:22 TP0] Update weights end.
[2025-04-13 23:25:22 TP0] Cache flushed successfully!
[2025-04-13 23:25:22] INFO: 127.0.0.1:55576 - "POST /update_weights_from_disk HTTP/1.1" 200 OK
[9]:
# failed update with different parameter size or wrong name
url = f"http://localhost:{port}/update_weights_from_disk"
data = {"model_path": "meta-llama/Llama-3.2-1B-wrong"}
response = requests.post(url, json=data)
response_json = response.json()
print_highlight(response_json)
assert response_json["success"] is False
assert response_json["message"] == (
"Failed to get weights iterator: "
"meta-llama/Llama-3.2-1B-wrong"
" (repository not found)."
)
[2025-04-13 23:25:22] Start update_weights. Load format=auto
[2025-04-13 23:25:22 TP0] Update engine weights online from disk begin. avail mem=38.47 GB
[2025-04-13 23:25:22 TP0] Failed to get weights iterator: meta-llama/Llama-3.2-1B-wrong (repository not found).
[2025-04-13 23:25:22] INFO: 127.0.0.1:55578 - "POST /update_weights_from_disk HTTP/1.1" 400 Bad Request
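For the continuous-evaluation workflow described above, a training loop typically saves a checkpoint to disk and then asks the running server to reload it. Below is a minimal sketch; the checkpoint directory and the save call in the comment are hypothetical placeholders for whatever your training code produces.

```python
import requests

def reload_checkpoint(base_url: str, checkpoint_dir: str) -> bool:
    """Ask a running SGLang server to reload weights from a local checkpoint directory."""
    resp = requests.post(
        f"{base_url}/update_weights_from_disk",
        json={"model_path": checkpoint_dir},
    )
    result = resp.json()
    print_highlight(result)
    return result.get("success", False)

# Example (hypothetical): after trainer.save_pretrained("ckpts/step_1000"),
# hot-swap the serving weights without restarting the server.
# reload_checkpoint(f"http://localhost:{port}", "ckpts/step_1000")
```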
Encode (embedding model)#
Encode text into embeddings. Note that this API is only available for embedding models and will raise an error for generation models. Therefore, we launch a new server to serve an embedding model.
[10]:
terminate_process(server_process)
embedding_process, port = launch_server_cmd(
"""
python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \
--host 0.0.0.0 --is-embedding
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:25:22] Child process unexpectedly failed with an exit code 9. pid=2785572
[2025-04-13 23:25:32] server_args=ServerArgs(model_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_path='Alibaba-NLP/gte-Qwen2-7B-instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Alibaba-NLP/gte-Qwen2-7B-instruct', chat_template=None, completion_template=None, is_embedding=True, revision=None, host='0.0.0.0', port=33354, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=147376542, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:25:32] Downcasting torch.float32 to torch.float16.
[2025-04-13 23:25:41 TP0] Downcasting torch.float32 to torch.float16.
[2025-04-13 23:25:42 TP0] Overlap scheduler is disabled for embedding models.
[2025-04-13 23:25:42 TP0] Downcasting torch.float32 to torch.float16.
[2025-04-13 23:25:42 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:25:42 TP0] Init torch distributed begin.
[2025-04-13 23:25:42 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:25:42 TP0] Load weight begin. avail mem=61.54 GB
[2025-04-13 23:25:42 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:25:43 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 14% Completed | 1/7 [00:01<00:11, 1.84s/it]
Loading safetensors checkpoint shards: 29% Completed | 2/7 [00:03<00:08, 1.80s/it]
Loading safetensors checkpoint shards: 43% Completed | 3/7 [00:05<00:07, 1.83s/it]
Loading safetensors checkpoint shards: 57% Completed | 4/7 [00:07<00:05, 1.83s/it]
Loading safetensors checkpoint shards: 71% Completed | 5/7 [00:09<00:03, 1.82s/it]
Loading safetensors checkpoint shards: 86% Completed | 6/7 [00:10<00:01, 1.64s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:11<00:00, 1.40s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:11<00:00, 1.61s/it]
[2025-04-13 23:25:54 TP0] Load weight end. type=Qwen2ForCausalLM, dtype=torch.float16, avail mem=47.11 GB, mem usage=14.43 GB.
[2025-04-13 23:25:54 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.55 GB, V size: 0.55 GB
[2025-04-13 23:25:54 TP0] Memory pool end. avail mem=45.74 GB
[2025-04-13 23:25:55 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-04-13 23:25:55] INFO: Started server process [2787472]
[2025-04-13 23:25:55] INFO: Waiting for application startup.
[2025-04-13 23:25:55] INFO: Application startup complete.
[2025-04-13 23:25:55] INFO: Uvicorn running on http://0.0.0.0:33354 (Press CTRL+C to quit)
[2025-04-13 23:25:55] INFO: 127.0.0.1:56960 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:25:56] INFO: 127.0.0.1:56976 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:25:56 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:25:57] INFO: 127.0.0.1:56992 - "POST /encode HTTP/1.1" 200 OK
[2025-04-13 23:25:57] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[11]:
# successful encode for embedding model
url = f"http://localhost:{port}/encode"
data = {"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "text": "Once upon a time"}
response = requests.post(url, json=data)
response_json = response.json()
print_highlight(f"Text embedding (first 10): {response_json['embedding'][:10]}")
[2025-04-13 23:26:00 TP0] Prefill batch. #new-seq: 1, #new-token: 4, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:00] INFO: 127.0.0.1:35998 - "POST /encode HTTP/1.1" 200 OK
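The endpoint can also be used to compare texts by their embeddings. Below is a minimal sketch that sends two texts in one request (list input is assumed to behave like the list-based /classify example later in this notebook) and computes their cosine similarity.

```python
import math
import requests

url = f"http://localhost:{port}/encode"
# Batched input: a list of texts is assumed to return one embedding object per text.
data = {"model": "Alibaba-NLP/gte-Qwen2-7B-instruct", "text": ["king", "queen"]}
outputs = requests.post(url, json=data).json()
emb_a, emb_b = outputs[0]["embedding"], outputs[1]["embedding"]

# Cosine similarity between the two embeddings.
dot = sum(a * b for a, b in zip(emb_a, emb_b))
norm_a = math.sqrt(sum(a * a for a in emb_a))
norm_b = math.sqrt(sum(b * b for b in emb_b))
print_highlight(f"Cosine similarity: {dot / (norm_a * norm_b):.4f}")
```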
[12]:
terminate_process(embedding_process)
[2025-04-13 23:26:00] Child process unexpectedly failed with an exit code 9. pid=2787821
[2025-04-13 23:26:00] Child process unexpectedly failed with an exit code 9. pid=2787691
Classify (reward model)#
SGLang Runtime also supports reward models. Here we use a reward model to classify the quality of pairwise generations.
[13]:
terminate_process(embedding_process)
# Note that SGLang now treats embedding models and reward models as the same type of models.
# This will be updated in the future.
reward_process, port = launch_server_cmd(
"""
python -m sglang.launch_server --model-path Skywork/Skywork-Reward-Llama-3.1-8B-v0.2 --host 0.0.0.0 --is-embedding
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:26:10] server_args=ServerArgs(model_path='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', tokenizer_path='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Skywork/Skywork-Reward-Llama-3.1-8B-v0.2', chat_template=None, completion_template=None, is_embedding=True, revision=None, host='0.0.0.0', port=30766, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=689574379, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:26:21 TP0] Overlap scheduler is disabled for embedding models.
[2025-04-13 23:26:21 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:26:21 TP0] Init torch distributed begin.
[2025-04-13 23:26:22 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:26:22 TP0] Load weight begin. avail mem=59.44 GB
[2025-04-13 23:26:22 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:26:24 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.18it/s]
Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.10it/s]
Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.08it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.63it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.39it/s]
[2025-04-13 23:26:27 TP0] Load weight end. type=LlamaForSequenceClassification, dtype=torch.bfloat16, avail mem=30.38 GB, mem usage=29.06 GB.
[2025-04-13 23:26:27 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.25 GB, V size: 1.25 GB
[2025-04-13 23:26:27 TP0] Memory pool end. avail mem=27.58 GB
[2025-04-13 23:26:28 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-04-13 23:26:28] INFO: Started server process [2789411]
[2025-04-13 23:26:28] INFO: Waiting for application startup.
[2025-04-13 23:26:28] INFO: Application startup complete.
[2025-04-13 23:26:28] INFO: Uvicorn running on http://0.0.0.0:30766 (Press CTRL+C to quit)
[2025-04-13 23:26:28] INFO: 127.0.0.1:58360 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:26:29] INFO: 127.0.0.1:58364 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:26:29 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:30] INFO: 127.0.0.1:58380 - "POST /encode HTTP/1.1" 200 OK
[2025-04-13 23:26:30] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[14]:
from transformers import AutoTokenizer
PROMPT = (
"What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
CONVS = [
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]
tokenizer = AutoTokenizer.from_pretrained("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2")
prompts = tokenizer.apply_chat_template(CONVS, tokenize=False)
url = f"http://localhost:{port}/classify"
data = {"model": "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", "text": prompts}
responses = requests.post(url, json=data).json()
for response in responses:
print_highlight(f"reward: {response['embedding'][0]}")
[2025-04-13 23:26:34 TP0] Prefill batch. #new-seq: 2, #new-token: 136, #cached-token: 2, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:26:34] INFO: 127.0.0.1:58384 - "POST /classify HTTP/1.1" 200 OK
[15]:
terminate_process(reward_process)
Capture expert selection distribution in MoE models#
SGLang Runtime supports recording, for each expert in a MoE model, the number of times it is selected during a run. This is useful for analyzing the model's throughput and planning optimizations.
Note: we only print the first 10 lines of the CSV below for readability. Please adjust accordingly if you want to analyze the results more deeply.
[16]:
expert_record_server_process, port = launch_server_cmd(
"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:26:43] server_args=ServerArgs(model_path='Qwen/Qwen1.5-MoE-A2.7B', tokenizer_path='Qwen/Qwen1.5-MoE-A2.7B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='Qwen/Qwen1.5-MoE-A2.7B', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=31146, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=58589652, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:26:53 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:26:53 TP0] Init torch distributed begin.
[2025-04-13 23:26:53 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:26:53 TP0] Load weight begin. avail mem=58.82 GB
[2025-04-13 23:26:53 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:26:54 TP0] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/8 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 12% Completed | 1/8 [00:01<00:12, 1.74s/it]
Loading safetensors checkpoint shards: 25% Completed | 2/8 [00:03<00:11, 1.85s/it]
Loading safetensors checkpoint shards: 38% Completed | 3/8 [00:05<00:08, 1.66s/it]
Loading safetensors checkpoint shards: 50% Completed | 4/8 [00:06<00:06, 1.66s/it]
Loading safetensors checkpoint shards: 62% Completed | 5/8 [00:08<00:04, 1.59s/it]
Loading safetensors checkpoint shards: 75% Completed | 6/8 [00:09<00:03, 1.58s/it]
Loading safetensors checkpoint shards: 88% Completed | 7/8 [00:10<00:01, 1.32s/it]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:10<00:00, 1.04it/s]
Loading safetensors checkpoint shards: 100% Completed | 8/8 [00:10<00:00, 1.35s/it]
[2025-04-13 23:27:05 TP0] Load weight end. type=Qwen2MoeForCausalLM, dtype=torch.bfloat16, avail mem=51.75 GB, mem usage=7.07 GB.
[2025-04-13 23:27:05 TP0] KV Cache is allocated. #tokens: 20480, K size: 1.88 GB, V size: 1.88 GB
[2025-04-13 23:27:05 TP0] Memory pool end. avail mem=47.34 GB
[2025-04-13 23:27:05 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:27:06 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=8192
[2025-04-13 23:27:06] INFO: Started server process [2791321]
[2025-04-13 23:27:06] INFO: Waiting for application startup.
[2025-04-13 23:27:06] INFO: Application startup complete.
[2025-04-13 23:27:06] INFO: Uvicorn running on http://0.0.0.0:31146 (Press CTRL+C to quit)
[2025-04-13 23:27:06] INFO: 127.0.0.1:40446 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:27:07] INFO: 127.0.0.1:40454 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:27:07 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:27:09 TP0] Using default MoE config. Performance might be sub-optimal! Config file not found at /public_sglang_ci/runner-l1b-gpu-23/_work/sglang/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=60,N=1408,device_name=NVIDIA_H100_80GB_HBM3.json, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py.
[2025-04-13 23:27:10] INFO: 127.0.0.1:40458 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:27:10] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[17]:
response = requests.post(f"http://localhost:{port}/start_expert_distribution_record")
print_highlight(response)
url = f"http://localhost:{port}/generate"
data = {"text": "What is the capital of France?"}
response = requests.post(url, json=data)
print_highlight(response.json())
response = requests.post(f"http://localhost:{port}/stop_expert_distribution_record")
print_highlight(response)
response = requests.post(f"http://localhost:{port}/dump_expert_distribution_record")
print_highlight(response)
import glob
output_file = glob.glob("expert_distribution_*.csv")[0]
with open(output_file, "r") as f:
print_highlight("\n| Layer ID | Expert ID | Count |")
print_highlight("|----------|-----------|--------|")
next(f)
for i, line in enumerate(f):
if i < 9:
layer_id, expert_id, count = line.strip().split(",")
print_highlight(f"| {layer_id:8} | {expert_id:9} | {count:6} |")
[2025-04-13 23:27:11 TP0] Resetting expert distribution record...
[2025-04-13 23:27:11] INFO: 127.0.0.1:52644 - "POST /start_expert_distribution_record HTTP/1.1" 200 OK
[2025-04-13 23:27:11 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:27:12 TP0] Decode batch. #running-req: 1, #token: 40, token usage: 0.00, gen throughput (token/s): 6.30, #queue-req: 0,
[2025-04-13 23:27:13 TP0] Decode batch. #running-req: 1, #token: 80, token usage: 0.00, gen throughput (token/s): 44.84, #queue-req: 0,
[2025-04-13 23:27:14 TP0] Decode batch. #running-req: 1, #token: 120, token usage: 0.01, gen throughput (token/s): 44.97, #queue-req: 0,
[2025-04-13 23:27:14] INFO: 127.0.0.1:52650 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:27:14] INFO: 127.0.0.1:52654 - "POST /stop_expert_distribution_record HTTP/1.1" 200 OK
[2025-04-13 23:27:14 TP0] Resetting expert distribution record...
[2025-04-13 23:27:14] INFO: 127.0.0.1:52668 - "POST /dump_expert_distribution_record HTTP/1.1" 200 OK
| Layer ID | Expert ID | Count |
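For deeper analysis than the preview above, the dumped CSV can be loaded into pandas. This is a minimal sketch; it assumes the three-column layer/expert/count layout shown above and renames the columns for clarity.

```python
import glob
import pandas as pd

output_file = glob.glob("expert_distribution_*.csv")[0]
df = pd.read_csv(output_file)
df.columns = ["layer_id", "expert_id", "count"]  # assumed 3-column layout, as previewed above

# Top 3 most frequently selected experts per layer (first two layers shown).
top_experts = (
    df.sort_values("count", ascending=False)
    .groupby("layer_id")
    .head(3)
    .sort_values(["layer_id", "count"], ascending=[True, False])
)
print_highlight(top_experts.head(6).to_string(index=False))
```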
[18]:
terminate_process(expert_record_server_process)
[2025-04-13 23:27:14] Child process unexpectedly failed with an exit code 9. pid=2791672
[2025-04-13 23:27:14] Child process unexpectedly failed with an exit code 9. pid=2791606
Skip Tokenizer and Detokenizer#
SGLang Runtime also supports skipping the tokenizer and detokenizer. This is useful in cases like integrating with RLHF workflows.
[19]:
tokenizer_free_server_process, port = launch_server_cmd(
"""
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --skip-tokenizer-init
"""
)
wait_for_server(f"http://localhost:{port}")
[2025-04-13 23:27:23] server_args=ServerArgs(model_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_path='meta-llama/Llama-3.2-1B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=True, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='meta-llama/Llama-3.2-1B-Instruct', chat_template=None, completion_template=None, is_embedding=False, revision=None, host='127.0.0.1', port=36300, mem_fraction_static=0.88, max_running_requests=200, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, stream_interval=1, stream_output=False, random_seed=407596092, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, enable_llama4_multimodal=None, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_deepep_moe=False, deepep_mode='auto', enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, enable_flashinfer_mla=False, enable_flashmla=False, flashinfer_mla_disable_ragged=False, warmups=None, n_share_experts_fusion=0, disable_shared_experts_fusion=False, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disable_fast_image_processor=False)
[2025-04-13 23:27:33 TP0] Attention backend not set. Use flashinfer backend by default.
[2025-04-13 23:27:33 TP0] Init torch distributed begin.
[2025-04-13 23:27:34 TP0] Init torch distributed ends. mem usage=0.00 GB
[2025-04-13 23:27:34 TP0] Load weight begin. avail mem=57.90 GB
[2025-04-13 23:27:34 TP0] Ignore import error when loading sglang.srt.models.llama4.
[2025-04-13 23:27:35 TP0] Using model weights format ['*.safetensors']
[2025-04-13 23:27:35 TP0] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.46it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.46it/s]
[2025-04-13 23:27:35 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=55.45 GB, mem usage=2.44 GB.
[2025-04-13 23:27:35 TP0] KV Cache is allocated. #tokens: 20480, K size: 0.31 GB, V size: 0.31 GB
[2025-04-13 23:27:35 TP0] Memory pool end. avail mem=51.59 GB
[2025-04-13 23:27:36 TP0]
CUDA Graph is DISABLED.
This will cause significant performance degradation.
CUDA Graph should almost never be disabled in most usage scenarios.
If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.
[2025-04-13 23:27:36 TP0] max_total_num_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=200, context_len=131072
[2025-04-13 23:27:36] INFO: Started server process [2793543]
[2025-04-13 23:27:36] INFO: Waiting for application startup.
[2025-04-13 23:27:36] INFO: Application startup complete.
[2025-04-13 23:27:36] INFO: Uvicorn running on http://127.0.0.1:36300 (Press CTRL+C to quit)
[2025-04-13 23:27:36] INFO: 127.0.0.1:37614 - "GET /v1/models HTTP/1.1" 200 OK
[2025-04-13 23:27:37] INFO: 127.0.0.1:37622 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-04-13 23:27:37 TP0] Prefill batch. #new-seq: 1, #new-token: 3, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:27:39] INFO: 127.0.0.1:37638 - "POST /generate HTTP/1.1" 200 OK
[2025-04-13 23:27:39] The server is fired up and ready to roll!
NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
We are running those notebooks in a CI parallel environment, so the throughput is not representative of the actual performance.
[20]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
input_text = "What is the capital of France?"
input_tokens = tokenizer.encode(input_text)
print_highlight(f"Input Text: {input_text}")
print_highlight(f"Tokenized Input: {input_tokens}")
response = requests.post(
f"http://localhost:{port}/generate",
json={
"input_ids": input_tokens,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 256,
"stop_token_ids": [tokenizer.eos_token_id],
},
"stream": False,
},
)
output = response.json()
output_tokens = output["output_ids"]
output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)
print_highlight(f"Tokenized Output: {output_tokens}")
print_highlight(f"Decoded Output: {output_text}")
print_highlight(f"Output Text: {output['meta_info']['finish_reason']}")
[2025-04-13 23:27:42 TP0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0,
[2025-04-13 23:27:42 TP0] Decode batch. #running-req: 1, #token: 41, token usage: 0.00, gen throughput (token/s): 6.66, #queue-req: 0,
[2025-04-13 23:27:42 TP0] Decode batch. #running-req: 1, #token: 81, token usage: 0.00, gen throughput (token/s): 198.73, #queue-req: 0,
[2025-04-13 23:27:42] INFO: 127.0.0.1:46324 - "POST /generate HTTP/1.1" 200 OK
The capital of France is Paris. Paris is the most populous city in France and is known for its rich history, art, fashion, and cuisine. It is also home to many famous landmarks such as the Eiffel Tower, Notre Dame Cathedral, and the Louvre Museum. Paris is a popular tourist destination and is often referred to as the "City of Light" due to its association with the Enlightenment and the French Revolution.<|eot_id|>
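Because the raw /generate endpoint works directly on token IDs, it can also be fed pre-tokenized batches, which is convenient in RLHF-style pipelines. Below is a minimal sketch; sending a list of token-ID lists in one request is an assumption, so verify the batching behavior against your SGLang version.

```python
# Batched request with token IDs (a sketch; list-of-lists batching is assumed
# to be supported by /generate).
prompts = ["What is the capital of France?", "List three primary colors."]
batch_input_ids = [tokenizer.encode(p) for p in prompts]

response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "input_ids": batch_input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 64},
    },
)
for item in response.json():
    print_highlight(tokenizer.decode(item["output_ids"], skip_special_tokens=True))
```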
[21]:
terminate_process(tokenizer_free_server_process)
[2025-04-13 23:27:42] Child process unexpectedly failed with an exit code 9. pid=2794261