# 💡 Customize Your Own Training

## 🔧 Customize Training Args
```bash
torchrun \
    --standalone \
    --nproc_per_node 8 \
    ./scripts/train_eagle3_online.py \
    --target-model-path meta-llama/Llama-3.1-8B-Instruct \
    --draft-model-config ./configs/llama3-8B-eagle3.json \
    --train-data-path ./cache/dataset/sharegpt.jsonl \
    --output-dir ./outputs/llama3-8b-eagle3 \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama3 \
    --cache-dir ./cache
```
If you wish to understand what each argument does, run `python scripts/train_eagle3_online.py --help` to see the full list of arguments. Some of the most important arguments are discussed below.

- `--chat-template`: the chat template to use for the model. Make sure you set it to the correct value for your target model.
- `--cache-dir`: the directory containing the dataset cache, including the `input_ids`, `loss_mask`, `attention_mask`, and `vocab_mapping`. Once generated, these caches make data loading much faster. The cache file name is obtained by hashing the dataset path to avoid cache collisions.
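As a rough sketch of the hashing idea (the helper name and hash choice below are illustrative assumptions, not SpecForge's actual implementation):

```python
import hashlib


def cache_key(dataset_path: str) -> str:
    # Derive a stable file name from the dataset path, so two
    # different datasets never share a cache file.
    return hashlib.md5(dataset_path.encode("utf-8")).hexdigest()


# The same path always maps to the same cache file name.
print(cache_key("./cache/dataset/sharegpt.jsonl"))
```

Because the key is deterministic, re-running training on the same dataset reuses the existing cache, while pointing at a different dataset produces a different cache file.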
## 💬 Customize Chat Template

You can register a new chat template for your model by adding a new entry to the `TEMPLATE_REGISTRY` in the `specforge.data.template.py` file.
```python
TEMPLATE_REGISTRY.register(
    name="your-template-name",
    template=ChatTemplate(
        assistant_header="xxx",
        user_header="xxx",
        system_prompt="xxx",
        end_of_turn_token="xxx",
    ),
)
```
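To pick values for the four fields, it helps to see how they typically compose a prompt. The sketch below is illustrative only: the `ChatTemplate` dataclass and `render_turn` function here are simplified stand-ins, not SpecForge's actual formatting code, filled with Llama-3-style header tokens.

```python
from dataclasses import dataclass


@dataclass
class ChatTemplate:
    assistant_header: str
    user_header: str
    system_prompt: str
    end_of_turn_token: str


def render_turn(t: ChatTemplate, user_msg: str) -> str:
    # Illustrative composition: system prompt, one user turn,
    # then the assistant header that the model completes.
    return (
        t.system_prompt
        + t.user_header + user_msg + t.end_of_turn_token
        + t.assistant_header
    )


llama3 = ChatTemplate(
    assistant_header="<|start_header_id|>assistant<|end_header_id|>\n\n",
    user_header="<|start_header_id|>user<|end_header_id|>\n\n",
    system_prompt="You are a helpful assistant.",
    end_of_turn_token="<|eot_id|>",
)
print(render_turn(llama3, "Hello!"))
```

If the registered headers and end-of-turn token do not match what your model saw during pretraining or instruction tuning, the loss mask and the learned distribution will be misaligned, so double-check them against your tokenizer's chat template.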
## 🪅 Customize Model

### Customize Target Model
If you wish to train Eagle3 for other models, you need to modify the `--target-model-path` value. We support loading these models directly from HuggingFace. However, if your model is too large and requires tensor parallelism, you can implement its tensor-parallel version yourself in the `specforge.modeling.target` directory. The CausalLM model should inherit from the `DistributedTargetModel` class in the `specforge.modeling.target.base.py` file and apply `ColumnParallelLinear` and `RowParallelLinear` to its submodules.
```python
from typing import Dict

import torch

from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear

from .base import DistributedTargetModel


class MyModelForCausalLM(MyModelPreTrainedModel, GenerationMixin, DistributedTargetModel):
    ...

    def load_weights(self, state_dict: Dict[str, torch.Tensor]):
        ...
```
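Why column parallelism for one projection and row parallelism for the next? Splitting the first weight along its output features and the second along its input features lets the two sharded matmuls compose without reshuffling activations; only a final sum of partial outputs is needed (the all-reduce step in real tensor parallelism). A toy NumPy sketch of the math, with invented shapes and two simulated ranks:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))     # activations
w1 = rng.standard_normal((8, 16))   # column-parallel weight (split outputs)
w2 = rng.standard_normal((16, 8))   # row-parallel weight (split inputs)

# Full (non-parallel) reference computation.
ref = (x @ w1) @ w2

# "Rank 0" and "rank 1" each hold half the output columns of w1...
h0, h1 = x @ w1[:, :8], x @ w1[:, 8:]
# ...and the matching input rows of w2; partial outputs are summed,
# which corresponds to the all-reduce in real tensor parallelism.
out = h0 @ w2[:8, :] + h1 @ w2[8:, :]

assert np.allclose(ref, out)
```

This block-matmul identity is why the intermediate activations never need to be gathered between the two layers: each rank only ever sees its own shard.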
Afterwards, you need to register this model in the `AutoDistributedTargetModel` class in the `specforge.modeling.auto.py` file.
```diff
 class AutoDistributedTargetModel(AutoModelForCausalLMBase):
     _model_mapping = {
         Llama4TextConfig: [Llama4ForCausalLM],
+        MyModelConfig: [MyModelForCausalLM],
     }
```
When `tp_size` is greater than 1, the script will automatically load the distributed version of the model for tensor parallelism.
### Customize Draft Model

If you want to change the draft model configuration, you can write your own configuration file and pass its path to the `--draft-model-config` argument. If you do not provide this argument, the script will automatically generate the draft model configuration from the target model configuration. If you wish to serve your customized draft model with SGLang, make sure you also implement the draft model in SGLang and that the architecture names match. To implement your own draft model, create a new class that inherits from the `Eagle3DraftModel` class in the `specforge.modeling.draft.base.py` file.
```python
from transformers import PretrainedConfig

from .base import Eagle3DraftModel


class MyModelConfig(PretrainedConfig):
    model_type = "mymodel"

    def __init__(self, **kwargs):
        ...


class MyModelEagle3(Eagle3DraftModel):
    config_class = MyModelConfig

    def __init__(self, config, quant_config=None) -> None:
        ...
```
You can then register these classes in the `AutoEagle3DraftModel` and `AutoDraftModelConfig` classes in the `specforge.modeling.auto.py` file for automatic model loading.
```diff
 class AutoEagle3DraftModel(AutoModelForCausalLMBase):
     # the model mapping is currently hardcoded, we should support lazy model mapping via registry
     _model_mapping = {
         LlamaConfig: [LlamaForCausalLMEagle3],
+        MyModelConfig: [MyModelEagle3],
     }


 class AutoDraftModelConfig:
     _config_mapping = {
         "LlamaForCausalLMEagle3": LlamaConfig,
+        "MyModelEagle3": MyModelConfig,
     }
```
In this way, as long as your `config.json` specifies the correct architecture name, the script will automatically load the correct draft model for you.
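For example, a minimal `config.json` for the custom draft model might look like the following. The field names and values here are illustrative placeholders; include whatever hyperparameters your `MyModelConfig` actually defines.

```json
{
  "architectures": ["MyModelEagle3"],
  "model_type": "mymodel",
  "hidden_size": 4096,
  "num_hidden_layers": 1,
  "vocab_size": 128256
}
```

The `architectures` entry is what `AutoDraftModelConfig` keys on, so it must match the class name you registered.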