# Introduction
As open-source models continue to grow larger, the need for robust infrastructure to handle large-scale AI training has never been more critical. At Felafax, we recently fine-tuned the **LLaMA 3.1 405B** model on AMD GPUs, demonstrating their ability to efficiently manage large-scale AI workloads. The experience was very positive, and we're excited to share that we've open-sourced all our work on ⭐ [GitHub](https://github.com/felafax/felafax).
AMD GPUs, especially the MI300X series, offer a strong alternative to NVIDIA AI hardware, providing higher performance per dollar. Our setup included a single node with **8 x AMD MI300x GPUs**, and we used **JAX** for fine-tuning. In this article, we'll share the full story of fine-tuning LLaMA 405B, including all the details of parameter sharding and LoRA implementation.
*A huge shout-out to **[TensorWave](https://tensorwave.com/)** for providing us with an AMD node for this project. Their support was invaluable and made this endeavor possible!*
If you'd like to try AMD GPUs or fine-tune the 405B model yourself, sign up for the [waitlist](https://tally.so/r/wbrRDE)!
Now, let's delve into how we leveraged JAX and AMD GPUs to orchestrate this large-scale fine-tuning run.
## What is JAX and why we picked it
JAX is a powerful machine learning library that combines NumPy-like APIs, automatic differentiation, and Google's XLA compiler. It offers superior APIs for model parallelism, making it ideal for training massive models like LLaMA 3.1 405B.
##### **Why I'm a big fan of JAX**:
1. **Pure Functions**: JAX encourages writing pure functions (a requirement if you want to compile your code with JIT), making code easier to compose, debug, and read (see the short example after this list).
2. **Advanced Parallelism**: JAX's flexible JIT APIs support advanced data and model parallelism out of the box, crucial for large-scale training.
3. **Cleaner Codebases**: JAX's design philosophy promotes writing code that is inherently portable across hardware platforms (CPU, GPUs, TPUs), resulting in cleaner, more maintainable codebases.
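As a quick illustration of the pure-function point, here is a minimal sketch (not taken from our codebase): because the loss below depends only on its inputs, it composes directly with `jax.jit` and `jax.grad`.
```python
import jax
import jax.numpy as jnp

# A pure function: the output depends only on the inputs, with no hidden state.
def mse_loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Because it is pure, it composes cleanly with JAX transformations.
grad_fn = jax.jit(jax.grad(mse_loss))  # compiled gradient of the loss w.r.t. w
w, x, y = jnp.ones(8), jnp.ones((4, 8)), jnp.ones(4)
print(grad_fn(w, x, y))
```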
For a deeper dive into JAX's advantages over PyTorch, I recommend reading the blog post [PyTorch is dead. Long live JAX](https://neel04.github.io/my-website/blog/pytorch_rant/).
##### **JAX is especially great for working with non-NVIDIA hardware:**
JAX offers several advantages when working with AMD:
1. **Hardware-agnostic Approach**: JAX leverages the XLA (Accelerated Linear Algebra) compiler, which compiles computations into a hardware-independent intermediate representation (HLO graph). This allows the same JAX code to be optimized and executed efficiently on different hardware backends, including AMD GPUs, without modification. You can inspect this representation yourself, as shown in the example after this list.
2. **Platform-independent Optimizations**: The XLA compiler performs optimizations independently of the underlying hardware, benefiting all supported platforms.
3. **Simplified Portability**: Transitioning from NVIDIA to AMD (or other supported hardware) with JAX requires minimal code changes. This contrasts with PyTorch, which is more tightly coupled with NVIDIA's CUDA ecosystem.
- PyTorch often relies on CUDA-specific implementations (e.g., `torch.cuda` calls, `scaled_dot_product_attention`).
- While PyTorch supports other backends like ROCm for AMD GPUs, porting code can be cumbersome due to NVIDIA-specific code paths.
- The process of "de-NVIDIAfying" PyTorch code can introduce complexity and hinder portability.
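To see that hardware-independent representation for yourself, you can ask JAX to print the lowered IR of any jitted function. This is a small sketch; on recent JAX versions, `Lowered.as_text()` returns the same IR text regardless of the backend underneath.
```python
import jax
import jax.numpy as jnp

def matmul(x, w):
    return jnp.dot(x, w)

# Lower the function to XLA's intermediate representation without executing it.
lowered = jax.jit(matmul).lower(jnp.ones((4, 8)), jnp.ones((8, 2)))
print(lowered.as_text())
```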
## JAX on AMD was a breezy setup!
Setting up JAX on AMD GPUs is straightforward:
```bash
# Pull the Docker Image:
docker pull rocm/jax:latest
# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
# Verify the Installation:
python3 -c 'import jax; print(jax.devices())'
```
I had access to an AMD node with 8 x AMD MI300x GPUs. Each MI300x has 192GB of HBM3 memory. It fares very well against the latest NVIDIA H100 GPUs. (Comparison below, source: [TensorWave](https://tensorwave.com/))
![[Pasted image 20240922223903.png]]
## Training LLaMA 405B: Performance and Scalability
Using JAX, I managed to train the **LLaMA 405B** model on AMD GPUs with impressive results.
We did LoRA fine-tuning with all model weights and LoRA parameters in `bfloat16` precision, using a LoRA rank of 8 and a LoRA alpha of 16 (the run configuration is also summarized in a sketch after the note below):
- **Model Size**: The LLaMA model weights occupy around **800GB** of VRAM.
- **LoRA Weights + Optimizer State**: Approximately **400GB** of VRAM.
- **Total VRAM Usage**: 77% of the total VRAM, around **1200GB**.
- **Constraints**: Due to the large size of the 405B model, there was limited VRAM headroom for the batch size and sequence length; I used a batch size of 16 and a sequence length of 64.
- **JIT Compilation**: Also, I couldn't run the JIT-compiled version due to space constraints; it likely requires slightly more space than the eager mode graph.
- **Training Speed**: ~35 tokens/second running in JAX eager mode (EDITED: 1 train step took 30s)
- **Memory Efficiency**: Consistently around 70%
- **Scaling**: With JAX, scaling was near-linear across 8 GPUs.
Note: We couldn't run the JIT-compiled version of the 405B model due to infra and VRAM constraints (we need to investigate this further). The entire training run was executed in JAX eager mode, so there is significant potential for performance improvements.
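For reference, here is the run configuration described above collected in one place (a hypothetical dataclass for illustration, not the repo's actual config object):
```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Values from the fine-tuning run described above.
    param_dtype: str = "bfloat16"  # base weights and LoRA params
    lora_rank: int = 8
    lora_alpha: int = 16
    batch_size: int = 16
    seq_length: int = 64
    num_gpus: int = 8              # 8 x AMD MI300X, 192GB HBM3 each
```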
Below are the GPU utilization, VRAM utilization, and `rocm-smi` readings across the 8 GPUs during one train step of the fine-tuning run:
- **GPU utilization**:
- ![[Pasted image 20240923093326.png]]
- **VRAM utilization**:
- ![[Pasted image 20240923093333.png]]
`rocm-smi` output:
| Device | Temp | Power | Partitions | Fan | Perf | PwrCap | VRAM% | GPU% |
|--------|--------|-------|------------------|-----|------|--------|-------|------|
| 0 | 58.0°C | 232.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 27% |
| 1 | 58.0°C | 233.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 25% |
| 2 | 56.0°C | 236.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 24% |
| 3 | 52.0°C | 228.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 23% |
| 4 | 59.0°C | 232.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 22% |
| 5 | 51.0°C | 230.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 21% |
| 6 | 61.0°C | 235.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 18% |
| 7 | 56.0°C | 227.0W| NPS1, SPX, 0 | 0% | auto | 750.0W | 77% | 18% |
Edit: Complete GPU utilization, VRAM utilization and rocm-smi data can be found on our [github repo](https://github.com/felafax/felafax?tab=readme-ov-file#amd-405b-fine-tuning-run).
## Our Training Setup
### We translated LLaMA 3.1 from PyTorch to JAX
![[Pasted image 20240922222526.png]]
We translated the LLaMA 3.1 architecture from PyTorch to JAX. You can check out our implementation in [this GitHub repository](https://github.com/felafax/felafax). Additionally, we shared our motivations and experiences in [this Hacker News article](https://news.ycombinator.com/item?id=41512142).
This translation opened up new possibilities for performance and scalability for us.
## Loading the Model and Sharding Parameters
Handling a massive model like LLaMA 405B requires efficient parameter sharding across multiple devices. Here's how I achieved it using JAX.
### Sharding Parameters in JAX
To efficiently distribute the massive LLaMA 405B model across the 8 AMD GPUs, we utilize JAX's device mesh feature ([code pointer](https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69)). The device mesh organizes the available devices into a multidimensional grid, allowing us to specify how computations and data are partitioned. In our setup, we create a mesh with the shape (1, 8, 1), naming the axes data parallelism (`dp`), fully sharded data parallelism (`fsdp`), and model parallelism (`mp`), respectively. We then apply specific sharding rules to the model parameters, specifying for each tensor in the model how its dimensions are split along these mesh axes.
```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

DEVICES = jax.devices()
DEVICE_COUNT = len(DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh((1, 8, 1))
MESH = Mesh(devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
```
#### Visualizing Sharding
You can visualize the sharding of arrays using `jax.debug.visualize_array_sharding`. This is incredibly helpful for verifying that your sharding specifications are being applied as intended.
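For example, here is a quick sketch (reusing the `MESH` created above) that shards a toy array along the `fsdp` axis and prints its device layout:
```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as PS

# Shard a toy array along the "fsdp" axis of the mesh, then print an ASCII
# diagram showing which device holds which slice.
x = jnp.zeros((16, 1024))
x = jax.device_put(x, NamedSharding(MESH, PS("fsdp", None)))
jax.debug.visualize_array_sharding(x)
```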
#### Partitioning Rules
We defined partitioning rules for different components of the model [here](https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/llama_config.py#L44):
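The verbatim rules live in that config file; schematically, they map parameter-name patterns to `PartitionSpec`s over the `("dp", "fsdp", "mp")` mesh axes, along these lines (illustrative, not the exact rules):
```python
from jax.sharding import PartitionSpec as PS

# Illustrative pattern -> PartitionSpec pairs (see llama_config.py for the real ones).
illustrative_partition_rules = (
    ("lm_head/kernel", PS("fsdp", "mp")),   # shard the LM head across the 8 GPUs
    ("attention_norm/kernel", PS(None)),    # replicate layer norms on every device
    ("ffn_norm/kernel", PS(None)),
    (".*", PS(None)),                       # fallback: replicate anything unmatched
)
```
Matching these patterns against each parameter's name yields a per-tensor pytree of `PartitionSpec`s, which is what the sharding helpers below consume.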
#### How Parameters Are Sharded
- **Normal Parameters**: Sharded across 8 GPUs.
- For example, the LM head kernel (`lm_head/kernel`) has two axes partitioned with `PS("fsdp", "mp")`, which maps to (8, 1) in our mesh, so the tensor is split across 8 GPUs along its first axis.
- ![[Pasted image 20240921172402.png]]
- **Replicated Parameters**:
- Parameters without any sharding specification are replicated across all devices.
- For instance, layer norms (`attention_norm/kernel` and `ffn_norm/kernel`) use `PS(None)`.
- ![[Pasted image 20240921173131 1.png]]
#### Applying Sharding Constraints
As we load the model, we shard the model weights incrementally using custom sharding functions:
```python
import jax
from jax.sharding import NamedSharding

def make_shard_and_gather_fns(partition_specs):
    def make_shard_fn(partition_spec):
        # MESH is the (1, 8, 1) device mesh created earlier.
        out_sharding = NamedSharding(MESH, partition_spec)

        def shard_fn(tensor):
            # Place the tensor on the devices with the requested sharding and
            # wait for the transfer to finish.
            return jax.device_put(tensor, out_sharding).block_until_ready()

        return shard_fn

    shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
    return shard_fns

# Create shard functions based on the partitioning rules
shard_fns = make_shard_and_gather_fns(partitioning_rules)
```
This allows us to place each parameter on the appropriate devices with the specified sharding.
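Applying them is then a single `tree_map` over the parameter pytree (a sketch of the loading flow, not the exact loader code from the repo):
```python
# Apply each parameter's shard function as the weights are loaded; afterwards
# every tensor lives on the devices dictated by its PartitionSpec.
params = jax.tree_util.tree_map(lambda shard_fn, param: shard_fn(param), shard_fns, params)
```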
#### Sharding the Training Batch
Initially, the training batch is created normally. Before feeding it into the model, we shard it across GPUs as shown in the code below:
```python
train_batch = jax.device_put(
train_batch, NamedSharding(self.mesh, PS("dp", "fsdp"))
)
```
Here, we're specifying that the training batch should be sharded across the data parallel (`"dp"`) and fully sharded data parallel (`"fsdp"`) axes, which maps to (1, 8) in our case, producing the visualizations below:
- before sharding
- ![[Pasted image 20240921173337.png]]
- after calling `jax.device_put`
- ![[Pasted image 20240921173419.png]]
## Implementing LoRA Training
**LoRA** (Low-Rank Adaptation) reduces the number of trainable parameters by factorizing weight updates into low-rank matrices. This is particularly useful for fine-tuning large models.
Key aspects of our LoRA implementation:
1. **Separate parameterization**: We keep the LoRA parameters (`lora_a` and `lora_b`) separate from the main model parameters.
2. **Gradient stopping**: We use `jax.lax.stop_gradient(kernel)` to prevent updates to the main model weights.
3. **Efficient matrix multiplication**: We use `lax.dot_general` for fast, precision-controlled matrix operations.
4. **Scaling factor**: The LoRA output is scaled by `(self.lora_alpha / self.lora_rank)` before being added to the main output.
### The `LoRADense` Layer
We implemented a custom `LoRADense` layer that incorporates LoRA parameters:
```python
class LoRADense(nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0

    @nn.compact
    def __call__(self, inputs: Any) -> Any:
        # Original kernel parameter (frozen via stop_gradient below)
        kernel = self.param('kernel', ...)
        y = lax.dot_general(inputs, jax.lax.stop_gradient(kernel), ...)

        # LoRA parameters (trainable), stored in a separate 'lora_params' collection
        lora_a = self.variable('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable('lora_params', 'lora_b', ..., ...)

        # Compute LoRA output: inputs @ lora_a @ lora_b
        lora_output = lax.dot_general(inputs, lora_a.value, ...)
        lora_output = lax.dot_general(lora_output, lora_b.value, ...)

        # Combine original output with the scaled LoRA modification
        y += (self.lora_alpha / self.lora_rank) * lora_output
        return y.astype(self.dtype)  # dtype is declared with the other (elided) fields
```
### Sharding LoRA Parameters
To efficiently distribute the **LoRA** parameters across devices, we applied specific sharding rules using JAX. This ensures that the LoRA parameters align with the sharding of the main model parameters, optimizing both memory usage and computational efficiency.
#### **LoRA A matrices (`lora_a`)**
- **Partition spec** we use: `PS("fsdp", "mp")`.
- **Visualization**:
- **Axes Sharding**: `lora_a` params across the layers are sharded as (8, 1), meaning the first axis is sharded across 8 devices (the `fsdp` axis) and the second axis is unsharded.
![[Pasted image 20240921172554.png]]
Figure shows that the first axis is sharded across 8 devices (`fsdp` axis), and the second axis is unsharded.
#### **LoRA B matrices (`lora_b`)**
- **Partition spec** we use: `PS("mp", "fsdp")`.
- **Visualization**:
- **Axes Sharding**: `lora_b` params across the layers are sharded as (1, 8), meaning the second axis is sharded across 8 devices (the `fsdp` axis) and the first axis is unsharded.
![[Pasted image 20240921172619.png]]
Figure shows that the second axis is divided among the devices in the `fsdp` axis, partitioning the columns of the matrix.
This sharding strategy optimizes the distribution of parameters, reduces communication overhead, and enhances parallelism during training. It ensures that each device holds only a portion of the LoRA parameters, enabling efficient scaling to large models like LLaMA 405B.
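In code, these two rules are simply additional entries in the partitioning rules (a sketch, in the same illustrative style as above):
```python
from jax.sharding import PartitionSpec as PS

# LoRA matrices mirror the sharding of the weights they adapt (illustrative):
lora_partition_rules = (
    ("lora_a", PS("fsdp", "mp")),  # (in_features, rank): first axis split across 8 GPUs
    ("lora_b", PS("mp", "fsdp")),  # (rank, out_features): second axis split across 8 GPUs
)
```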
### Updating Only LoRA Parameters
To optimize training while fine-tuning the LLaMA 405B model, we compute gradients only for the LoRA parameters, keeping the main model parameters frozen. This approach reduces memory usage and speeds up training because we're updating a smaller number of parameters. You can check out the implementation details in our [GitHub repository](https://github.com/felafax/felafax).
In our training loop, each step passes a batch of input data through the model. Since only the LoRA parameters are trainable, the loss is differentiated only with respect to them; the frozen base weights act as constants. We then backpropagate and apply updates to the LoRA parameters alone. By focusing the updates on just these parameters, we streamline the training process, making it feasible to fine-tune an extremely large model like LLaMA 405B efficiently across multiple GPUs.
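A minimal sketch of such a step (assuming an `optax` optimizer, a Flax `model` like the one above, and a hypothetical `cross_entropy_loss` helper; the actual trainer lives in the repo and was run in eager mode):
```python
import jax
import optax

optimizer = optax.adamw(learning_rate=1e-4)

def train_step(lora_params, opt_state, frozen_params, batch):
    def loss_fn(lora_params):
        # Only lora_params is differentiated; the frozen base weights are
        # passed in as constants, so no gradients are computed for them.
        logits = model.apply(
            {"params": frozen_params, "lora_params": lora_params}, batch["inputs"]
        )
        return cross_entropy_loss(logits, batch["labels"])  # hypothetical helper

    loss, grads = jax.value_and_grad(loss_fn)(lora_params)
    updates, opt_state = optimizer.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state, loss
```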
## Conclusion
Fine-tuning a massive model like **LLaMA 3.1 405B** on AMD GPUs using **JAX** has been a very positive experience. By leveraging JAX's powerful parallelism capabilities and its hardware-agnostic approach, I was able to effectively distribute the model across 8 AMD MI300x GPUs. The use of **parameter sharding** allowed the enormous model parameters to be efficiently managed across devices, enabling near-linear scaling and high memory efficiency.
This journey underscores the viability of AMD GPUs as a strong alternative to NVIDIA hardware for large-scale AI training. The seamless integration of JAX with ROCm support simplifies the transition and opens up new possibilities for the AI community. By sharing this experience and the accompanying code, I hope to encourage others to explore and leverage these tools for their own large-scale machine learning projects.
Check out our [GitHub repository](https://github.com/felafax/felafax) for the full code and to run this yourself!
This project was made possible by generous support from [TensorWave](https://tensorwave.com/)!
![[Pasted image 20240922234922.png]]