> In this post, I'll walk you through the Fast Hadamard Transform, why it matters, and how to implement it, with some code examples.
## Introduction
The **Fast Hadamard Transform**, also known as the [Fast Walsh-Hadamard Transform](https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform), is an efficient algorithm used to compute the Walsh-Hadamard Transform (WHT) in computational mathematics.
The [Hadamard Transform](https://en.wikipedia.org/wiki/Hadamard_transform) is an example of a generalized class of Fourier transforms. I came across this concept while reading the paper [SpinQuant: LLM Quantization with Learned Rotations](https://arxiv.org/abs/2405.16406). The paper applies rotation matrices to the weights and activations of large language models (LLMs) before post-training quantization, which suppresses outliers and makes the tensors easier to quantize. It also observes that random Hadamard matrices tend to give better quantized accuracy than random orthogonal rotation matrices.
The Hadamard matrix $H_m$ is a $2^m \times 2^m$ square matrix whose entries are either $+1$ or $-1$. Its rows are mutually orthogonal, which means $H_m^T H_m = nI$ with $n = 2^m$, so the scaled matrix $Q = H_m / \sqrt{n}$ is orthogonal. According to [SliceGPT](https://arxiv.org/abs/2401.15024), multiplying by an orthogonal matrix preserves computational invariance through RMSNorm, as long as the RMSNorm applies no re-scaling:
$\text{RMSNorm}(X) = \text{RMSNorm}(XQ^T)Q$
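Both claims are easy to check numerically. The sketch below is a quick check under two assumptions: $H_m$ is built with Sylvester's construction, and RMSNorm has no learnable scale. The helper names `sylvester_hadamard` and `rmsnorm` are hypothetical.

```python
import numpy as np


def sylvester_hadamard(m: int) -> np.ndarray:
    """2^m x 2^m Hadamard matrix built with H_m = [[H_{m-1}, H_{m-1}], [H_{m-1}, -H_{m-1}]]."""
    H = np.ones((1, 1))
    for _ in range(m):
        H = np.block([[H, H], [H, -H]])
    return H


def rmsnorm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last axis without a learnable scale (no re-scaling)."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)


m = 4
n = 2 ** m
H = sylvester_hadamard(m)
Q = H / np.sqrt(n)  # normalized Hadamard matrix, orthogonal
X = np.random.default_rng(0).standard_normal((3, n))

print(np.array_equal(H.T @ H, n * np.eye(n)))         # True: rows are mutually orthogonal
print(np.allclose(rmsnorm(X), rmsnorm(X @ Q.T) @ Q))  # True: RMSNorm(X) == RMSNorm(X Q^T) Q
```

The invariance holds because $XQ^T$ has the same row norms as $X$, so both sides divide by the same RMS value.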
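A straightforward Hadamard transform of a length-$n$ vector costs $O(n^2)$ operations, because it is a dense matrix-vector multiplication. However, the $2^m \times 2^m$ Hadamard matrix from Sylvester's construction has the recursive structure $H_m = \begin{pmatrix} H_{m-1} & H_{m-1} \\ H_{m-1} & -H_{m-1} \end{pmatrix}$, and its $\pm 1$ entries mean every step needs only additions and subtractions. A fast version exploits this to reduce the cost to $O(n\log n)$.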
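To see where the $O(n\log n)$ comes from, write Sylvester's recursion with Kronecker products, where $H_1 = \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}$ and $I_k$ denotes the $k \times k$ identity. By the mixed-product property of the Kronecker product, $H_m$ factors into $m = \log_2 n$ sparse stages:

$$
H_m = H_1 \otimes H_{m-1} = \underbrace{H_1 \otimes H_1 \otimes \cdots \otimes H_1}_{m \text{ factors}} = \prod_{k=1}^{m} \left( I_{2^{m-k}} \otimes H_1 \otimes I_{2^{k-1}} \right)
$$

Each factor $I_{2^{m-k}} \otimes H_1 \otimes I_{2^{k-1}}$ has exactly two nonzero entries ($\pm 1$) per row: it combines element $j$ with element $j + 2^{k-1}$ through one addition and one subtraction, i.e. a butterfly stage with stride $h = 2^{k-1}$. One stage costs $n$ additions/subtractions, so all $m$ stages cost $n\log_2 n$ instead of the roughly $n^2$ operations of the dense product. The factors also commute, so the stages can run with strides $1, 2, 4, \dots$ or in the reverse order.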
## Fast Hadamard Transform
Before diving into the Fast Hadamard Transform, let’s first take a moment to understand the conventional Hadamard Transform.
### Conventional Hadamard Transform
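Before performing the Hadamard transform, we need to construct a Hadamard matrix. There are many construction methods, such as Sylvester's recursive construction (used by `scipy.linalg.hadamard`, which therefore only supports orders that are powers of 2) and the [Paley construction](https://en.wikipedia.org/wiki/Paley_construction). Many explicit Hadamard matrices are also available in [Neil Sloane's library](http://neilsloane.com/hadamard/).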
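The Paley construction is what makes orders that are not powers of 2 possible. Here is a minimal sketch of Paley construction I, assuming $q$ is a prime with $q \equiv 3 \pmod 4$. It builds the $q \times q$ Jacobsthal matrix with entries $\chi(a - b)$, where $\chi$ is the quadratic character modulo $q$, and borders it to get a Hadamard matrix of order $q + 1$. The function name `paley_hadamard` is hypothetical. Prime powers such as $q = 27$ (order 28) also work in theory, but they need finite-field arithmetic that this sketch skips.

```python
import numpy as np


def paley_hadamard(q: int) -> np.ndarray:
    """Paley construction I: (q + 1) x (q + 1) Hadamard matrix for a prime q with q % 4 == 3."""
    if q < 3 or q % 4 != 3 or any(q % d == 0 for d in range(2, int(q ** 0.5) + 1)):
        raise ValueError("this sketch only handles primes q with q % 4 == 3")

    def chi(a: int) -> int:
        # quadratic character mod q via Euler's criterion: 0, +1 (square) or -1 (non-square)
        a %= q
        if a == 0:
            return 0
        return 1 if pow(a, (q - 1) // 2, q) == 1 else -1

    # Jacobsthal matrix: entry (a, b) is chi(a - b); skew-symmetric because chi(-1) = -1
    jacobsthal = np.array([[chi(a - b) for b in range(q)] for a in range(q)], dtype=np.int64)
    n = q + 1
    S = np.zeros((n, n), dtype=np.int64)
    S[0, 1:] = 1   # border row of +1
    S[1:, 0] = -1  # border column of -1
    S[1:, 1:] = jacobsthal
    return np.eye(n, dtype=np.int64) + S  # H = I + S, so H @ H.T == (q + 1) * I
```

For example, `paley_hadamard(19)` returns a 20 × 20 matrix with $\pm 1$ entries and $HH^T = 20I$.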
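Here is a reference Python implementation of the conventional Hadamard transform, using PyTorch and SciPy.
```python
import math

import torch
import torch.nn.functional as F
from scipy.linalg import hadamard


def hadamard_transform_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """
    x: (..., dim)
    out: (..., dim)
    """
    x_shape = x.shape
    dim = x.shape[-1]
    x = x.reshape(-1, dim)
    # scipy's hadamard() only supports powers of 2, so zero-pad the last dimension
    log_dim = math.ceil(math.log2(dim))
    dim_padded = 2 ** log_dim
    if dim != dim_padded:
        x = F.pad(x, (0, dim_padded - dim))
    # dense O(n^2) product; F.linear computes x @ H^T, and H is symmetric
    H = torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device)
    out = F.linear(x, H)
    out = out * scale
    # drop the padded columns and restore the original shape
    return out[..., :dim].reshape(*x_shape)
```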
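We use `scipy` to construct a Hadamard matrix of the required size and PyTorch's `F.linear` to perform the dense matrix multiplication, which costs $O(n^2)$ per vector. Because `scipy.linalg.hadamard` only builds matrices whose order is a power of 2 (the same requirement as the Fast Hadamard Transform), the last dimension is zero-padded to the next power of 2, and the output is truncated back to `dim`.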
### Fast Implementation
The butterfly diagram below illustrates the fast algorithm on an 8-element input.
![[fast_WHT.svg.png]]
The input has eight elements, processed in three stages. In the first stage, each pair of elements that are 4 positions apart is combined by one addition and one subtraction (stride 4). In the second stage the stride drops to 2, and in the final stage it becomes 1.

Here is a simple Python implementation of this fast algorithm. It runs the same stages in the opposite order (strides 1, 2, 4), which gives the same result because the stages commute. The stride, denoted by `h`, starts at 1 and doubles at each iteration. The outer loop variable `i` advances in steps of `2*h`, so each block covers the indices `[i, i + 2*h)`. For each `i`, the inner loop over `j` runs from `i` to `i + h - 1` and replaces `a[j]` and `a[j + h]` with their sum and difference, so all `2*h` elements of the block are used.

Take stride `h = 2` as an example, shown in the second column of the illustration. When `i = 0`, `j` iterates over `range(0, 2)`, i.e. the values 0 and 1. When `j = 0`, let `x = a[0]` and `y = a[2]`; the butterfly sets `a[0] = x + y` and `a[2] = x - y`. Likewise, `j = 1` pairs `a[1]` with `a[3]`, and the next block (`i = 4`) pairs `a[4]` with `a[6]` and `a[5]` with `a[7]`.
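```python
import math

import numpy as np


def fwht(a: np.ndarray) -> None:
    """In-place, orthonormal Fast Walsh–Hadamard Transform of a 1-D float array a.

    len(a) must be a power of 2.
    """
    n = len(a)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    h = 1
    while h < n:
        # butterfly stage with stride h: combine a[j] and a[j + h]
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = x + y
                a[j + h] = x - y
        # scale each stage by 1/sqrt(2), so the whole transform is scaled by 1/sqrt(n)
        a /= math.sqrt(2)
        h *= 2
```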
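The loops above perform $O(n\log n)$ work, but they touch one element at a time in the Python interpreter. As a sketch (the name `fwht_vectorized` is hypothetical), each butterfly stage can instead be a single NumPy operation: view the last axis as `(blocks, 2, h)`, where `[b, 0, k]` plays the role of `a[j]` and `[b, 1, k]` the role of `a[j + h]`. It uses the same $1/\sqrt{n}$ overall scaling and also works on a batch of vectors.

```python
import numpy as np


def fwht_vectorized(x: np.ndarray) -> np.ndarray:
    """Orthonormal Fast Walsh-Hadamard Transform along the last axis (Sylvester ordering)."""
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[-1]
    if n == 0 or n & (n - 1):
        raise ValueError("last dimension must be a power of 2")
    batch = x.shape[:-1]
    h = 1
    while h < n:
        # view the last axis as (blocks, 2, h): [..., b, 0, k] is a[j], [..., b, 1, k] is a[j + h]
        y = x.reshape(*batch, n // (2 * h), 2, h)
        top, bottom = y[..., 0, :], y[..., 1, :]
        # one butterfly stage for every block at once
        x = np.stack((top + bottom, top - bottom), axis=-2).reshape(*batch, n)
        h *= 2
    return x / np.sqrt(n)  # equivalent to 1/sqrt(2) per stage


x = np.array([1, 0, 1, 0, 0, 1, 1, 0])
print(fwht_vectorized(x) * np.sqrt(8))  # approximately [4, 2, 0, -2, 0, 2, 0, 2]
```

Multiplying by $\sqrt{8}$ undoes the normalization and recovers the unnormalized Walsh-Hadamard transform of the 8-element input. Because the orthonormal transform is its own inverse ($H_m/\sqrt{n}$ is symmetric and orthogonal), applying it twice returns the input.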
For an optimized implementation, see the [CUDA version](https://github.com/Dao-AILab/fast-hadamard-transform) of the Fast Hadamard Transform from Dao-AILab.
## Fast Hadamard Transform on Any Dimension
The Fast Hadamard Transform (FHT) described above only applies to inputs whose length is a power of two. How can we handle other dimensions? One approach is to factor the dimension as $n = K \cdot 2^m$. The Kronecker product of a $K \times K$ Hadamard matrix with $H_m$ is itself a Hadamard matrix of order $n$, so the FHT can handle the $2^m$ factor while an ordinary matrix multiplication handles the $K$ factor.
For example, the `intermediate_size` of Llama 3 8B is 14,336, which is not a power of two, so a direct FHT cannot be applied. However, it factors as $14{,}336 = 28 \times 2^9$: the FHT is applied along an axis of size $2^9 = 512$, followed by a conventional Hadamard transform with a $28 \times 28$ matrix along the remaining axis.
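Finding such a factorization is a small search. The sketch below assumes explicit Hadamard matrices of orders 12, 20 and 28 are on hand; the helper name `split_hadamard_dim` and that list of orders are hypothetical, chosen for illustration.

```python
def split_hadamard_dim(n: int, available_orders=(12, 20, 28)) -> tuple[int, int]:
    """Return (K, 2**m) with n == K * 2**m, where K == 1 or K is in available_orders.

    Raises ValueError if no such factorization exists.
    """

    def is_pow2(v: int) -> bool:
        return v > 0 and v & (v - 1) == 0

    if is_pow2(n):
        return 1, n  # plain FHT, no extra matrix needed
    for k in sorted(available_orders):
        if n % k == 0 and is_pow2(n // k):
            return k, n // k  # FHT on n // k, explicit k x k Hadamard on the rest
    raise ValueError(f"no Hadamard factorization for n = {n}")


print(split_hadamard_dim(14336))  # (28, 512)
```

With this short list of orders, a size such as 11,008 ($2^8 \times 43$) has no factorization, so a practical setup needs a larger catalog of explicit Hadamard matrices.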
Here is the code sample.
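```python
import math

import torch
from fast_hadamard_transform import hadamard_transform  # Dao-AILab CUDA kernel


def matmul_hadU_cuda(X: torch.Tensor, hadK: torch.Tensor, K: int) -> torch.Tensor:
    # hadK: K x K Hadamard matrix with +/-1 entries; n // K must be a power of 2
    n = X.shape[-1]
    inp = X.reshape(-1, K, n // K)
    # FHT along the power-of-two axis; 1/sqrt(n) normalizes the full K * (n // K) transform
    inp = hadamard_transform(inp.contiguous(), 1.0 / math.sqrt(n))
    # explicit Hadamard transform along the K axis (batched matmul)
    inp = hadK.to(inp.device).to(inp.dtype) @ inp
    return inp.reshape(X.shape)
```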
The `hadK` parameter is a $K \times K$ Hadamard matrix with $\pm 1$ entries; in the Llama example, `K = 28`, so `hadK` is $28 \times 28$. The input tensor `X` is first reshaped so that its last dimension becomes `[28, 2^9]`. The FHT is then applied along the last ($2^9$) axis, and the remaining axis of size 28 is transformed by matrix multiplication with `hadK`. Because the scale $1/\sqrt{n}$ covers both factors, the overall transform is orthogonal.
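To check the reshape trick without a GPU, here is a NumPy sketch (all names are hypothetical) that mirrors `matmul_hadU_cuda`. It uses a $12 \times 12$ Paley matrix ($q = 11$) as `hadK` and a dense Sylvester matrix as a stand-in for the FHT on the power-of-two axis. Reshaping a row to `(K, n // K)`, transforming the last axis by $H_m$ and the first axis by `hadK` is the same as multiplying by the Kronecker product of `hadK` and $H_m$, scaled by $1/\sqrt{n}$.

```python
import numpy as np


def sylvester(n: int) -> np.ndarray:
    """Sylvester Hadamard matrix of order n (n must be a power of 2)."""
    H = np.ones((1, 1))
    while H.shape[0] < n:
        H = np.kron(np.array([[1.0, 1.0], [1.0, -1.0]]), H)
    return H


def paley12() -> np.ndarray:
    """12 x 12 Hadamard matrix from Paley construction I with q = 11."""
    q = 11

    def chi(a: int) -> int:
        # quadratic character mod q via Euler's criterion
        a %= q
        if a == 0:
            return 0
        return 1 if pow(a, (q - 1) // 2, q) == 1 else -1

    H = np.eye(q + 1)
    H[0, 1:] += 1
    H[1:, 0] -= 1
    H[1:, 1:] += np.array([[chi(a - b) for b in range(q)] for a in range(q)])
    return H


def matmul_had_mixed(X: np.ndarray, hadK: np.ndarray) -> np.ndarray:
    """NumPy mirror of matmul_hadU_cuda: applies kron(hadK, H) / sqrt(n) to each row of X."""
    K, n = hadK.shape[0], X.shape[-1]
    inp = X.reshape(-1, K, n // K)
    # stand-in for the FHT along the power-of-two axis (dense here for clarity)
    inp = inp @ sylvester(n // K) / np.sqrt(n)
    # explicit K x K Hadamard transform along the remaining axis
    inp = hadK @ inp
    return inp.reshape(X.shape)


X = np.random.default_rng(0).standard_normal((4, 12 * 8))
Y = matmul_had_mixed(X, paley12())
```

Here `Y` matches `X @ np.kron(paley12(), sylvester(8)).T / np.sqrt(96)`, and the row norms of `X` are preserved because the scaled Kronecker matrix is orthogonal.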
## Conclusion
This post documents my learning journey with the Fast Hadamard Transform, a concept I encountered while studying an LLM quantization paper. It has become a key technique for mitigating channel-wise outliers in large language models and is widely used in the LLM community, including in implementations like FlashAttention-3.
The CUDA implementation of the Fast Hadamard Transform is somewhat complex, so it is not discussed in this post. I may write another blog to share my insights on the CUDA code—stay tuned!
## References
- [SpinQuant: LLM Quantization with Learned Rotations](https://arxiv.org/abs/2405.16406)
- [SliceGPT: Compress Large Language Models by Deleting Rows and Columns](https://arxiv.org/abs/2401.15024)
- [Paley construction (Wikipedia)](https://en.wikipedia.org/wiki/Paley_construction)
- [Dao-AILab/fast-hadamard-transform (GitHub)](https://github.com/Dao-AILab/fast-hadamard-transform)
- [A Library of Hadamard Matrices (N. J. A. Sloane)](http://www.neilsloane.com/hadamard/index.html)