Fine Tuning Llama 3.1 405B with Axolotl on a Lambda 1-Click Cluster
Personalizing SOTA Open Source AI
Llama 3.1 405B was released by Meta just over a month ago, positioning itself as the premier open source model capable of competing with proprietary frontier models. We have seen this proven out on both benchmarks and the LMSYS Chatbot Arena Leaderboard.
In this tutorial, you'll learn how to fine-tune the Meta Llama 3.1 model using Axolotl on a multi-node Lambda 1-Click Cluster (1CC). You can also download the code and examples below from the Axolotl Cookbook.
What we'll be covering:
Setting up Axolotl and its required environment.
Fine-tuning the Llama 3.1 405B base model using Maxime Labonne’s FineTome dataset: mlabonne/FineTome-100k
Prerequisites
This tutorial assumes the following prerequisites:
A working 1-Click Cluster. We recommend at least 64 H100 GPUs for a full-parameter fine-tune of 405B.
Setting up Axolotl
Once your 1-Click Cluster is set up, we can begin installing and setting up Axolotl. Detailed instructions on how to reserve and set up a 1-Click Cluster can be found at this link.
Setting up your SSH config
Lambda provisions several head nodes to act as both a jump proxy and a coordinator for the worker/GPU nodes. Note the inclusion of `ForwardAgent yes`, which allows us to easily authenticate via SSH between nodes.
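For reference, a minimal ~/.ssh/config entry might look like the sketch below. The head node alias, IP address, and key path are placeholders; use the values from your own cluster dashboard.
Host head-node
    HostName <head-node-ip>
    User ubuntu
    IdentityFile ~/.ssh/<your-key>
    ForwardAgent yes
Host ml-64-node-*
    User ubuntu
    ProxyJump head-node
    ForwardAgent yes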
Lambda provides a script for setting up inter-node passwordless SSH. To access it, go to the 1-Click cluster dashboard, click "SETUP" next to "SSH LOGIN," and proceed to the next page:
Setting up the shared filesystem
When requesting a 1-Click Cluster, an empty persistent file system is automatically created and attached to each node. Optionally, you can also create a new file system with a specific name or attach an existing one, as shown in the screenshot below.
For the remainder of the tutorial, a shared filesystem mounted at /home/ubuntu/cluster-ml-64 on all nodes will be used. We will store this path as an environment variable that can be referenced by different stages of the workflow.
Cloning the Cookbook
From one of the head nodes, we can clone the cookbook into the shared filesystem so that the configuration scripts are available on each node.
export PATH_STORAGE=/home/ubuntu/cluster-ml-64
git -C $PATH_STORAGE clone https://github.com/axolotl-ai-cloud/axolotl-cookbook.git
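As a quick sanity check, you can confirm from the head node that the clone is visible on the shared filesystem from a worker node. The hostname below follows the worker naming used later in this tutorial; substitute one of your own worker nodes.
ssh ml-64-node-001 ls ${PATH_STORAGE}/axolotl-cookbook/lambda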
Setting up mpirun (optional)
OpenMPI allows us to easily run applications in parallel across multiple hosts. While OpenMPI is optional for this walkthrough, it simplifies the workflow by allowing us to run a single command from a head node rather than individually running the same commands on each node.
First, let’s configure the hostfile so that mpirun knows which hosts are the worker nodes.
echo "ml-64-node-001 slots=1
ml-64-node-002 slots=1
ml-64-node-003 slots=1
ml-64-node-004 slots=1
ml-64-node-005 slots=1
ml-64-node-006 slots=1
ml-64-node-007 slots=1
ml-64-node-008 slots=1" > ${PATH_STORAGE}/1cc-hostfile.txt
From a head node, install OpenMPI:
sudo apt-get update
sudo apt-get install -y libucx0
sudo apt-get install -y openmpi-bin openmpi-doc libopenmpi-dev
We also need to install OpenMPI on each of the worker nodes, but we can’t use `mpirun` yet to do this. Instead, we can loop over the hostfile with xargs and run the commands over SSH. Note that `cat -n` prefixes each line with a line number, so every hostfile line has three fields, `xargs -n3` passes one full line per invocation, and `$1` holds the hostname. To update the package lists on each worker:
cat -n ${PATH_STORAGE}/1cc-hostfile.txt | xargs -n3 bash -c 'ssh $1 sudo apt-get update'
Or, to run the update and installation on all workers in parallel:
NUM_WORKERS=$(cat ${PATH_STORAGE}/1cc-hostfile.txt | wc -l) && \
cat -n ${PATH_STORAGE}/1cc-hostfile.txt | xargs -n3 -P $NUM_WORKERS bash -c 'ssh $1 sudo apt-get update' && \
cat -n ${PATH_STORAGE}/1cc-hostfile.txt | xargs -n3 -P $NUM_WORKERS bash -c 'ssh $1 sudo apt-get install -y libucx0' && \
cat -n ${PATH_STORAGE}/1cc-hostfile.txt | xargs -n3 -P $NUM_WORKERS bash -c 'ssh $1 sudo apt-get install -y openmpi-bin openmpi-doc libopenmpi-dev'
Let’s now verify our configuration; each worker node should print its hostname once:
mpirun --hostfile ${PATH_STORAGE}/1cc-hostfile.txt --prefix /usr hostname
Lambda also has additional information on running distributed commands using xargs available at https://github.com/LambdaLabsML/distributed-training-guide/tree/main/04-job-launchers-bash.
Quickstart
The configuration script to set up the environment and dependencies is available as setup.sh in the cookbook. From each worker/GPU node, you can simply run:
${PATH_STORAGE}/axolotl-cookbook/lambda/setup.sh $PATH_STORAGE
Or if using mpirun:
mpirun --hostfile ${PATH_STORAGE}/1cc-hostfile.txt --prefix /usr ${PATH_STORAGE}/axolotl-cookbook/lambda/setup.sh $PATH_STORAGE
HuggingFace Auth
Now let’s make sure the trainer has access to the Llama 3.1 405B model on HuggingFace by adding your auth token to the environment. You can grab your HuggingFace access token from https://huggingface.co/settings/tokens. In the previous setup step, we set the HF_HOME environment variable to the shared filesystem, so you only need to authenticate once, from the head node:
export HF_HOME="${PATH_STORAGE}/.cache/huggingface"
huggingface-cli login
You will also want to make sure you have access to the gated 405B model at https://huggingface.co/meta-llama/Meta-Llama-3.1-405B.
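As an optional check that the token works and that your account has been granted access to the gated repository, you can verify your login and pull just the small config.json rather than the full weights:
huggingface-cli whoami
huggingface-cli download meta-llama/Meta-Llama-3.1-405B config.json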
Downloading the Model into cache
Due to possible timeouts and other factors while downloading ~800GB of model weights, we recommend pre-downloading the weights into the HuggingFace Transformers model cache. From a head node, SSH into the first worker node and then run the download there:
ssh $(head -n 1 ${PATH_STORAGE}/1cc-hostfile.txt | awk '{print $1}')
export HF_HUB_ENABLE_HF_TRANSFER="1"
huggingface-cli download meta-llama/Meta-Llama-3.1-405B
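Once the download finishes, a quick way to sanity-check that the weights landed in the shared cache is to check its size. This assumes HF_HOME points at the shared cache, as configured in the setup step above.
du -sh ${HF_HOME}/hub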
Let's fine-tune
First, let’s dive into the configuration we’re going to use, most notably the FSDP configuration section. In particular, for 405B we need to use SHARDED_STATE_DICT whenever possible, as there is not enough VRAM on a single node to materialize the FULL_STATE_DICT checkpoint right after training.
fsdp_final_state_dict_type: SHARDED_STATE_DICT
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
Also, this configuration leverages LinkedIn’s Liger Kernels to improve GPU VRAM efficiency, allowing us to use a larger sequence length.
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
How do we think about the learning rate in the context of the batch size?
sequence_len: 4096
gradient_accumulation_steps: 1
micro_batch_size: 1
learning_rate: 1.0e-5
Taking a look at the Llama 3.1 technical report, we see a peak learning rate of 8e-5 at a batch size of 16M tokens per batch. From our configuration, our batch size is 4K tokens * 64 GPUs = 262K tokens per batch. Using square-root scaling, sqrt(262K / 16M) ≈ 0.128 ≈ 1/8, so we scale the peak learning rate by 1/8th to arrive at a learning rate of ~1.0e-5 for finetuning.
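As a quick sanity check of that arithmetic, the same square-root scaling can be computed in one line (the 16M-token batch size is the figure cited above from the technical report):
# Square-root LR scaling: peak_lr * sqrt(our_tokens_per_batch / their_tokens_per_batch)
python3 -c "import math; print(8e-5 * math.sqrt((4096 * 64) / 16e6))"  # prints ~1.02e-05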
Below are the contents of ${PATH_STORAGE}/axolotl-cookbook/lambda/train.sh. You may need to edit the file depending on your specific configuration.
#!/bin/bash
# Check that the shared storage path and the main node's hostname are provided
if [ -z "$1" ] || [ -z "$2" ]; then
    echo "Usage: $0 <path_storage> <main_node>"
    exit 1
fi
PATH_STORAGE=$1
MAIN_NODE=$2
NUM_NODES=8
JOB_ID=axolotl-lambda
MAIN_NODE=$MAIN_NODE:29500
YAML_CFG=${PATH_STORAGE}/axolotl-cookbook/lambda/configs/llama-3_1-405b-fft.yaml
# Derive this node's rank from the trailing digits of its hostname (e.g. ml-64-node-003 -> 2)
export NODE_IDX=$((10#$(hostname | grep -oE '[0-9]+$') - 1))
# Launch 8 processes (one per GPU) on this node, rendezvousing with the other nodes via c10d
/home/ubuntu/miniconda3/envs/pytorch/bin/torchrun --nnodes=$NUM_NODES --nproc-per-node=8 --node-rank=$NODE_IDX --rdzv-backend=c10d --rdzv-id=$JOB_ID --rdzv-endpoint=$MAIN_NODE -m axolotl.cli.train $YAML_CFG
The last step before finetuning is replacing $PATH_STORAGE with the appropriate path in the training config file:
envsubst < ${PATH_STORAGE}/axolotl-cookbook/lambda/configs/llama-3_1-405b-fft.yaml > temp.yaml && mv temp.yaml ${PATH_STORAGE}/axolotl-cookbook/lambda/configs/llama-3_1-405b-fft.yaml
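If you want to confirm that the substitution took effect, the following optional check should print nothing (i.e. no literal PATH_STORAGE references remain in the config):
grep -n 'PATH_STORAGE' ${PATH_STORAGE}/axolotl-cookbook/lambda/configs/llama-3_1-405b-fft.yaml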
To kick off the finetuning, simply run the following on each worker node:
${PATH_STORAGE}/axolotl-cookbook/lambda/train.sh $PATH_STORAGE $MAIN_NODE
Or, if using mpirun, run this from the head node:
MAIN_NODE=$(head -n 1 ${PATH_STORAGE}/1cc-hostfile.txt | awk '{print $1}')
mpirun --hostfile ${PATH_STORAGE}/1cc-hostfile.txt --prefix /usr ${PATH_STORAGE}/axolotl-cookbook/lambda/train.sh $PATH_STORAGE $MAIN_NODE
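While the job is running, one simple way to confirm that every GPU is busy is to spot-check a worker node over SSH. This reuses the hostfile from earlier; any worker hostname works.
ssh $(head -n 1 ${PATH_STORAGE}/1cc-hostfile.txt | awk '{print $1}') nvidia-smi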
Once the finetuning has completed, the model weights will be under the configured output_dir path of `${PATH_STORAGE}/axolotl-artifacts/outputs/llama3_1-405b-finetome/pytorch_model_fsdp_0`. To gather the distributed weights into a format useful for inference, we can merge the sharded state dict and re-shard the weights in a single step from one of the worker nodes:
python -m axolotl.cli.merge_sharded_fsdp_weights ${PATH_STORAGE}/axolotl-cookbook/lambda/configs/llama-3_1-405b-fft.yaml
Now the final weights are available under the `${PATH_STORAGE}/axolotl-artifacts/outputs/llama3_1-405b-finetome/merged/` path and can be uploaded to your preferred artifact store.
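For example, to push the merged weights to the Hugging Face Hub (the repository name below is a placeholder; substitute your own):
huggingface-cli upload your-username/llama-3.1-405b-finetome ${PATH_STORAGE}/axolotl-artifacts/outputs/llama3_1-405b-finetome/merged/ .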
You can download the scripts and configs from this tutorial in our Axolotl Cookbook in the Lambda section. The model weights from this tutorial are available at https://huggingface.co/axolotl-ai-co/finetome-llama-3.1-405b. Our WandB artifacts are available at https://wandb.ai/oaaic/llama-3.1-405b-fft/runs/8rcjfqeg.
With this follow-up tutorial, you can also learn how to run inference with your fine-tuned model.
Acknowledgements
We'd like to thank the Lambda team for all their help and for providing the compute on their 1-Click Cluster (1CC) environment. Also, many thanks to Jeffrey Quesnelle and Nous Research for their insights into fine-tuning 405B from their own fine-tuning of Hermes 3 405B.