Eagle3 for Llama3
Introduction
This document provides a step-by-step guide to reproducing the training process described in the EAGLE3 paper, using the script examples/run_llama3_eagle3_sgl_online.sh. We will walk through the script and explain each key step along the way.
Workflow
Step 1. Prepare Environment
We suggest using a virtual environment so that all dependencies install correctly. If you want to use Python >= 3.12, run export SETUPTOOLS_USE_DISTUTILS=local before installing.
uv venv --python 3.11
source .venv/bin/activate
cd PATH-TO-SpecForge
uv pip install -r requirements.txt
uv pip install -v .
After completing these steps, verify the installation by running the following command. It should finish without any error.
python -c "import specforge"
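If the import fails, the two usual causes are the ones mentioned above: a missing SETUPTOOLS_USE_DISTUTILS setting on Python >= 3.12, or running outside the activated virtual environment. The following is a hypothetical preflight helper (not part of SpecForge) that reports both conditions from inside Python.

```python
import importlib.util
import os
import sys


def preflight_problems() -> list:
    """Return human-readable setup problems; an empty list means none were found."""
    problems = []
    # The guide asks for SETUPTOOLS_USE_DISTUTILS=local when installing on Python >= 3.12.
    if sys.version_info >= (3, 12) and os.environ.get("SETUPTOOLS_USE_DISTUTILS") != "local":
        problems.append(
            "Python >= 3.12 detected: run `export SETUPTOOLS_USE_DISTUTILS=local` before installing"
        )
    # find_spec returns None when the package cannot be imported from this interpreter.
    if importlib.util.find_spec("specforge") is None:
        problems.append("specforge is not importable: is the virtual environment activated?")
    return problems


if __name__ == "__main__":
    for problem in preflight_problems():
        print("WARNING:", problem)
```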
Step 2. Prepare Model & Dataset
Next, prepare the model and datasets. First, download the target model and the raw datasets, then preprocess each dataset with prepare_data.py:
hf download meta-llama/Llama-3.1-8B-Instruct
hf download Aeala/ShareGPT_Vicuna_unfiltered --repo-type dataset
hf download HuggingFaceH4/ultrachat_200k --repo-type dataset
python scripts/prepare_data.py --dataset ultrachat --output_path /YOUR/PATH/Llama-3.1-8B-Instruct/dataset
python scripts/prepare_data.py --dataset sharegpt --output_path /YOUR/PATH/Llama-3.1-8B-Instruct/dataset
Then, launch the SGLang servers (one per GPU, on ports 30001 to 30004 in the loop below) and run generate_data_by_target.py to generate responses from the target model for each dataset. Make sure to update the SYSTEM_PROMPT value in generate_data_by_target.py to suit your requirements.
for i in {1..4}; do
CUDA_VISIBLE_DEVICES=$i python3 -m sglang.launch_server \
--model meta-llama/Llama-3.1-8B-Instruct \
--cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 \
--dtype bfloat16 --mem-frac=0.8 --port $((30000 + i)) &
done
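The loop starts the servers in the background, so the generation commands below may run before the model has finished loading. A minimal sketch of a hypothetical wait helper (not part of SpecForge or SGLang), assuming each SGLang server answers HTTP 200 on a /health route once it is ready; adjust the route if your SGLang version differs.

```python
import http.client
import time
import urllib.request


def wait_for_servers(addresses, timeout=600.0, interval=5.0):
    """Poll http://<addr>/health for every address until all return HTTP 200.

    Returns True once every server is healthy, or False if `timeout` seconds pass first.
    """
    deadline = time.monotonic() + timeout
    pending = list(addresses)
    while pending:
        still_pending = []
        for addr in pending:
            try:
                with urllib.request.urlopen(f"http://{addr}/health", timeout=interval) as resp:
                    if resp.status != 200:
                        still_pending.append(addr)
            except (OSError, http.client.HTTPException):
                # URLError and HTTPError subclass OSError: refused, timed out, or not ready yet.
                still_pending.append(addr)
        pending = still_pending
        if pending and time.monotonic() >= deadline:
            return False
        if pending:
            time.sleep(interval)
    return True
```

For the loop above, call wait_for_servers([f"127.0.0.1:{30000 + i}" for i in range(1, 5)]) and only start generation once it returns True.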
python scripts/generate_data_by_target.py \
--model-name meta-llama/Llama-3.1-8B-Instruct \
--raw-data-file /YOUR/PATH/Llama-3.1-8B-Instruct/dataset/sharegpt.jsonl \
--output-dir /YOUR/PATH/Llama-3.1-8B-Instruct/generated-dataset/sharegpt-llama-3.1-8b-instruct \
--max-concurrency 512 \
--num-per-shard 50000 \
--server-address-port 127.0.0.1:30001 127.0.0.1:30002 127.0.0.1:30003 127.0.0.1:30004
python scripts/generate_data_by_target.py \
--model-name meta-llama/Llama-3.1-8B-Instruct \
--raw-data-file /YOUR/PATH/Llama-3.1-8B-Instruct/dataset/ultrachat.jsonl \
--output-dir /YOUR/PATH/Llama-3.1-8B-Instruct/generated-dataset/ultrachat-llama-3.1-8b-instruct \
--max-concurrency 512 \
--num-per-shard 50000 \
--server-address-port 127.0.0.1:30001 127.0.0.1:30002 127.0.0.1:30003 127.0.0.1:30004
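Failed requests end up in error.jsonl. To get a quick overview before deciding whether to regenerate them, here is a hypothetical helper that groups the entries by message. The field name "error" is an assumption about the file's schema, so pass a different key if your entries use another name.

```python
import json
from collections import Counter


def summarize_errors(lines, key="error"):
    """Count JSONL error entries, grouped by the (truncated) value of `key`.

    Entries without `key` are counted under "<missing>".
    """
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        value = record.get(key, "<missing>") if isinstance(record, dict) else "<not an object>"
        # Truncate long messages so similar errors fall into the same bucket.
        counts[str(value)[:80]] += 1
    return counts
```

Usage: open error.jsonl and pass the file object, e.g. summarize_errors(f).most_common(), to see whether timeouts dominate.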
After generation finishes, review the failed entries in error.jsonl; most of them will likely be request timeouts. You can then decide whether to regenerate those samples. In my case, I chose not to, so I simply deleted error.jsonl before uploading to Hugging Face with the following commands:
hf repo create zhuyksir/Ultrachat-Sharegpt-Llama3.1-8B --repo-type dataset
hf upload zhuyksir/Ultrachat-Sharegpt-Llama3.1-8B /YOUR/PATH/Llama-3.1-8B-Instruct/generated-dataset/ultrachat-llama-3.1-8b-instruct --repo-type dataset --commit-message "generated dataset by Llama3.1-8B"
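Then load the dataset, hold out 5% of the samples for evaluation, and save both splits with the file names used by the commands below (run this from $DATASET_PATH, or move the files there afterwards):
from datasets import load_dataset

# Load the generated dataset from the Hugging Face Hub
ds = load_dataset("zhuyksir/Ultrachat-Sharegpt-Llama3.1-8B", split="train")
ds.to_json("merged.jsonl", orient="records", lines=True)
# Hold out 5% of the samples for evaluation
ds = ds.train_test_split(test_size=0.05)
train_ds = ds["train"]
test_ds = ds["test"]
# Write one JSON object per line, matching the paths passed to the training scripts
train_ds.to_json("sharegpt_ultrachat_train.jsonl", orient="records", lines=True)
test_ds.to_json("sharegpt_ultrachat_test.jsonl", orient="records", lines=True)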
Alternatively, for meta-llama/Llama-3.1-8B-Instruct, you can skip data generation and use the dataset we generated: zhuyksir/Ultrachat-Sharegpt-Llama3.1-8B.
Each row should have this structure:
{
"id": XXX,
"conversations":[
{"role": "system", "content": XXX},
{"role": "user", "content": XXX},
{"role": "assistant", "content": XXX},
...
]
}
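Before building the cache, it can help to check that every row follows the structure above. A minimal sketch of a hypothetical validator (not part of SpecForge) that checks the id field, the three roles shown, string contents, and the presence of at least one assistant turn:

```python
VALID_ROLES = {"system", "user", "assistant"}


def validate_row(row):
    """Return a list of problems with one dataset row; an empty list means it looks OK."""
    problems = []
    if "id" not in row:
        problems.append("missing 'id'")
    conversations = row.get("conversations")
    if not isinstance(conversations, list) or not conversations:
        return problems + ["'conversations' must be a non-empty list"]
    for i, turn in enumerate(conversations):
        if not isinstance(turn, dict):
            problems.append(f"turn {i} is not an object")
            continue
        if turn.get("role") not in VALID_ROLES:
            problems.append(f"turn {i} has unexpected role {turn.get('role')!r}")
        if not isinstance(turn.get("content"), str):
            problems.append(f"turn {i} has non-string content")
    if not any(t.get("role") == "assistant" for t in conversations if isinstance(t, dict)):
        # Rows like this contribute nothing to training: their loss mask is all zeros.
        problems.append("no assistant turn")
    return problems
```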
Second, we need to pre-build the cache for training.
During training, the text must be encoded into input IDs, and this encoding can be performed before training begins. The resulting cache files are saved under $CACHE_DIR. The script also builds the draft model's reduced vocabulary from the top-k most frequent tokens.
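To illustrate the top-k vocabulary selection, here is a minimal sketch assuming the draft vocabulary is simply the k most frequent target token ids, with a mapping in each direction. The function and the names d2t/t2d are illustrative; SpecForge's actual counting rules (for example, which tokens are counted) and its on-disk format may differ.

```python
from collections import Counter


def build_draft_vocab(token_id_seqs, k):
    """Keep the k most frequent target token ids as the draft vocabulary.

    Returns (d2t, t2d): d2t[i] is the target id of draft id i, and t2d maps each
    kept target id back to its draft id.
    """
    counts = Counter()
    for seq in token_id_seqs:
        counts.update(seq)
    # most_common ranks by frequency; sorting the kept ids makes draft ids
    # follow the target vocabulary's order.
    kept = sorted(tok for tok, _ in counts.most_common(k))
    d2t = kept
    t2d = {tok: i for i, tok in enumerate(kept)}
    return d2t, t2d
```

A smaller draft vocabulary shrinks the draft model's output layer; tokens outside it simply cannot be proposed by the draft model.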
With the --view-train-data option, you can inspect samples by index (e.g., indices 1 and 2 in the example below) to verify that the loss mask is generated correctly:
Green text indicates tokens where loss_mask == 1.
Red text indicates tokens where loss_mask == 0 (typically user input and the system prompt).
Since the goal is to train the draft model only on the target model's output, user text must be masked out. In other words, only tokens generated by the target model should contribute to the loss.
You might see this warning:
WARNING: No assistant response spans found in the conversation text.
This occurs when an error during data generation leaves a sample with only user inputs and no assistant responses. You can safely ignore this warning; the loss mask for such samples is set entirely to zero.
python scripts/build_eagle3_dataset_cache.py \
--target-model-path $MODEL_PATH \
--draft-model-config ./configs/llama3-8B-eagle3.json \
--train-data-path $DATASET_PATH/sharegpt_ultrachat_train.jsonl \
--eval-data-path $DATASET_PATH/sharegpt_ultrachat_test.jsonl \
--cache-dir $CACHE_DIR \
--chat-template $CHAT_TEMPLATE \
--max-length $MAX_LENGTH \
--view-train-data 1 2
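The masking rule described above (assistant tokens get 1, system and user tokens get 0) can be sketched as follows. This is a simplified illustration: it tokenizes on whitespace and ignores the chat template's special tokens, whereas the real script uses the target model's tokenizer and $CHAT_TEMPLATE. The function name is hypothetical.

```python
def build_loss_mask(turns, tokenize=str.split):
    """Tokenize a conversation and mark which tokens contribute to the loss.

    turns: list of {"role": ..., "content": ...} dicts.
    Returns (tokens, loss_mask), where loss_mask[i] == 1 only for assistant tokens.
    """
    tokens, loss_mask = [], []
    for turn in turns:
        piece = tokenize(turn["content"])
        tokens.extend(piece)
        # Only the target model's responses are learned; system/user text is masked out.
        flag = 1 if turn["role"] == "assistant" else 0
        loss_mask.extend([flag] * len(piece))
    return tokens, loss_mask
```

A conversation without an assistant turn yields an all-zero mask, which is exactly the case the warning above reports.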
Step 3. Start Training
Use the following command to train. We set total-steps=800000 and learning-rate=5e-5 to align with the official EAGLE repository config. Feel free to change these settings for your own experiments. Together, total-steps and warmup-ratio determine how the learning rate ramps up during warmup.
export NUM_GPUS=4
export OUTPUT_DIR=/YOUR/PATH/Llama-3.1-8B-Instruct/dev_outputs/
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
scripts/train_eagle3_sgl_online.py \
--target-model-path $MODEL_PATH \
--model-path $MODEL_PATH \
--draft-model-config ./configs/llama3-8B-eagle3.json \
--train-data-path $DATASET_PATH/sharegpt_ultrachat_train.jsonl \
--eval-data-path $DATASET_PATH/sharegpt_ultrachat_test.jsonl \
--tp-size $NUM_GPUS \
--output-dir $OUTPUT_DIR \
--num-epochs 10 \
--batch-size 1 \
--learning-rate 5e-5 \
--draft-attention-backend flex_attention \
--max-length $MAX_LENGTH \
--chat-template $CHAT_TEMPLATE \
--cache-dir $CACHE_DIR \
--mem-frac=0.4 \
--total-steps=800000 \
--dist-timeout=10 \
--wandb-project llama3-8b-eagle3 \
--wandb-name sgl-online \
--report-to wandb
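To get a feel for how total-steps and warmup-ratio shape the learning rate, here is a sketch assuming linear warmup over warmup-ratio × total-steps steps followed by cosine decay to zero. This is a common schedule, not necessarily the one SpecForge implements, and warmup_ratio=0.01 is an illustrative value rather than the script's default.

```python
import math


def lr_at(step, total_steps=800_000, warmup_ratio=0.01, peak_lr=5e-5):
    """Learning rate at `step` under linear warmup followed by cosine decay to 0."""
    warmup_steps = int(warmup_ratio * total_steps)
    if step < warmup_steps:
        # Linear ramp from near 0 up to peak_lr over the warmup phase.
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay from peak_lr to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
```

Under these assumptions, a larger warmup-ratio or total-steps stretches the ramp, so the peak of 5e-5 is reached later.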
Step 4. Benchmark
For Llama3.1-8B, we add a system prompt to all training data, following the approach used in the official EAGLE repository. Consequently, the benchmark must include the same system prompt to obtain the full accept length. Please uncomment the corresponding line in the benchmark script and add the system prompt.
Each entry in config_list is a comma-separated tuple of four numbers: batch_size, num_steps, topk, and num_verify_tokens. You can adjust the values in the list to experiment with different test cases.
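A small sketch of how each config string maps onto the four fields, using a hypothetical SpecConfig type. Our reading of the first entry is that with topk=1 the draft is a single chain, and the verifier checks the num_steps drafted tokens plus one root token (3 + 1 = 4); treat that as an interpretation rather than a documented rule.

```python
from typing import NamedTuple


class SpecConfig(NamedTuple):
    batch_size: int
    num_steps: int
    topk: int
    num_verify_tokens: int


def parse_config(entry: str) -> SpecConfig:
    """Parse a 'batch_size,num_steps,topk,num_verify_tokens' string."""
    values = [int(v) for v in entry.split(",")]
    if len(values) != 4:
        raise ValueError(f"expected 4 comma-separated integers, got {entry!r}")
    return SpecConfig(*values)


for entry in ["4,3,1,4", "4,7,10,60"]:
    cfg = parse_config(entry)
    # topk == 1 means one candidate per draft step (a chain); larger topk builds a tree.
    print(cfg, "chain" if cfg.topk == 1 else "tree")
```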
I have uploaded my trained EAGLE3 draft model to zhuyksir/EAGLE3-Llama-3.1-8B-Instruct. You are welcome to download it and check its accept length.
config_list=(
"4,3,1,4"
"4,7,10,60"
)
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 bench_model_speedup.py \
--model-path meta-llama/Llama-3.1-8B-Instruct \
--speculative-draft-model-path /YOUR/PATH/Llama-3.1-8B-Instruct/dev_outputs/epoch_0 \
--port 20001 \
--trust-remote-code \
--mem-fraction-static 0.8 \
--tp-size 4 \
--config-list "${config_list[@]}" \
--benchmark-list mtbench:80 gsm8k:200 humaneval:200 math500:200 \
--output output.jsonl
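When reading the accept length reported by the benchmark, one common definition is the average number of tokens emitted per target-model verification step: the accepted draft tokens plus the one token the target model contributes itself, so it is always at least 1. A minimal sketch of that arithmetic, with a hypothetical function name and no assumptions about the layout of output.jsonl:

```python
def average_accept_length(accepted_per_step):
    """Average tokens emitted per verification step.

    accepted_per_step[i] is the number of draft tokens accepted at step i; each
    verification step also emits one token from the target model itself.
    """
    if not accepted_per_step:
        raise ValueError("need at least one verification step")
    emitted = sum(a + 1 for a in accepted_per_step)
    return emitted / len(accepted_per_step)
```

For example, accepting 2, 3, and 0 draft tokens over three steps emits 3 + 4 + 1 = 8 tokens, an accept length of about 2.67.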