> The increasing prevalence of Large Language Models (LLMs) has elevated the significance of techniques such as flash attention. Understanding this mechanism is crucial for anyone engaged in LLM development.
> This blog explains the mathematical underpinnings of flash attention, to make its kernel easier to understand and implement.
> The flash attention kernel is complex, so a thorough grasp of its theoretical framework and mathematical derivation is a necessary first step towards effective coding.

## Softmax Revisit

Before diving deeper into attention, it's important to understand the softmax function, a crucial part of this process.

**Softmax**, already used in cross-entropy loss calculations before the introduction of the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper, converts a set of numbers into a probability distribution: every value lies between 0 and 1, and the values sum to 1.

The mathematical representation of softmax is:

$\text{softmax}(\{x_1, \ldots, x_N\}) = \left\{ \frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}} \right\}_{i=1}^N$

While straightforward in principle, softmax can run into numerical issues when the input values ($x_i$) are very large, because computing $e^{x_i}$ can overflow. For example, the `float16` data type has a maximum finite value of 65,504, so for $x = 12$ the value $e^x \approx 162{,}754$ already exceeds this limit. Improved methods are therefore needed to avoid this numerical instability.

## Safe Softmax

To address the numerical instability issues with softmax, a technique called "safe softmax" was developed. The core idea is quite simple: by subtracting a large number ($m$) from each input value ($x_i$), we effectively divide both the numerator and the denominator by $e^m$, which leaves the result unchanged. The exponents $x_i - m$ are smaller than the original $x_i$, so the exponentials are less likely to overflow.
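
To make the overflow and the effect of the subtraction concrete, here is a minimal pure-Python sketch. The helper names (`naive_softmax`, `shifted_softmax`, `fits_in_float16`) are illustrative and not taken from any library; the `float16` check relies on the standard `struct` module's half-precision format `e`, which raises `OverflowError` for values that are too large.

```python
import math
import struct

def naive_softmax(xs):
    """Softmax computed directly from the definition (no overflow protection)."""
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def shifted_softmax(xs, m):
    """Softmax with a constant m subtracted from every input before exponentiation."""
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def fits_in_float16(value):
    """Return True if value can be stored as an IEEE 754 half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False

xs = [10.0, 11.0, 12.0]

# e^12 is about 162,754, which is too large for float16.
print(fits_in_float16(math.exp(12.0)))          # False
# After subtracting m = 12, the largest exponential is e^0 = 1.
print(fits_in_float16(math.exp(12.0 - 12.0)))   # True

# Subtracting a constant from every input does not change the softmax output.
a = naive_softmax(xs)
b = shifted_softmax(xs, m=12.0)
print(all(math.isclose(p, q) for p, q in zip(a, b)))   # True
```

The sketch computes in Python's double precision and only uses `float16` as a storage check, so it illustrates the range problem rather than reproducing half-precision arithmetic.
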
The question then becomes: what should this "large number" be? A common approach is to use the maximum of all input values, which we denote as $m = \max_{j=1}^N(x_j)$. This leads to the "safe" softmax formula:

$\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}} = \frac{e^{x_i - m}}{\sum_{j=1}^N e^{x_j - m}}$

By subtracting the maximum value, we ensure that at least one exponent is zero and all others are non-positive. Every exponential then lies in $(0, 1]$, so overflow cannot occur.

Let's analyze the computational steps involved in the "safe" softmax:

1. **Finding the Maximum:** We loop through all the elements to find the maximum value $m_N$, keeping a running maximum (starting from $m_0 = -\infty$):
   $m_i \leftarrow \max(m_{i-1}, x_i)$
2. **Calculating the Denominator:** Then, we iterate through the elements again to compute the denominator, the sum of exponentiated values, each shifted by the maximum (starting from $d_0 = 0$):
   $d_i \leftarrow d_{i-1} + e^{x_i - m_N}$
3. **Computing the Final Result:** Finally, we compute each element of the softmax output:
   $a_i \leftarrow \frac{e^{x_i-m_N}}{d_N}$

![[safe_softmax.png]]

This process requires looping through all the elements three separate times. The question is: can we reduce the number of loops, and so make the "safe" softmax computation more efficient?

## Online Softmax

The iterative computation $d_i = d_{i-1} + e^{x_i - m_N}$ depends on $m_N$, which is only known once the first loop has finished. Can this dependency be eliminated? The paper [Online normalizer calculation for softmax](https://arxiv.org/abs/1805.02867) addresses this with a **surrogate** approach. The summation can be expressed as $d_i = \sum_{j=1}^i e^{x_j - m_N}$.
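
For reference, the three-pass procedure from the Safe Softmax section can be sketched in plain Python as follows. Its second loop is exactly the summation just written, and it cannot start until the first loop has produced $m_N$. The function name `safe_softmax_three_pass` is illustrative, not from any library.

```python
import math

def safe_softmax_three_pass(xs):
    """Safe softmax written as three explicit loops over the input."""
    # Pass 1: running maximum, m_i = max(m_{i-1}, x_i), starting from -infinity.
    m = -math.inf
    for x in xs:
        m = max(m, x)

    # Pass 2: denominator, d_i = d_{i-1} + exp(x_i - m_N), starting from 0.
    d = 0.0
    for x in xs:
        d += math.exp(x - m)

    # Pass 3: outputs, a_i = exp(x_i - m_N) / d_N.
    return [math.exp(x - m) / d for x in xs]

# Inputs this large would overflow math.exp without the shift by the maximum.
probs = safe_softmax_three_pass([1000.0, 1001.0, 1002.0])
print(probs)
print(sum(probs))
```
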
Substituting $m_i$ for $m_N$ gives a modified quantity, $d'_i = \sum_{j=1}^i e^{x_j - m_i}$, which no longer depends on $N$. Although $d_i$ and $d'_i$ differ in general, they coincide at $i = N$. Since $d_N$ is the only value needed in the end, $d'_i$ can stand in for $d_i$ in the intermediate steps. Writing $d'_i$ in terms of $d'_{i-1}$ gives the following derivation:

$\begin{aligned}
d'_i &= \sum_{j=1}^i e^{x_j - m_i} \\
&= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\
&= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\
&= d'_{i-1} e^{m_{i-1} - m_i} + e^{x_i - m_i}
\end{aligned}$

![[online_softmax.png|500]]

This transformation merges the first two loops into a single loop that yields both $m_N$ and $d_N$.

## Attention Revisit

Before delving into the specifics of flash attention, let's review the basics of standard self-attention. Self-attention uses three components: Query (Q), Key (K), and Value (V). The computation unfolds as follows:

1. **Correlation Calculation:** We begin by computing the product of Q and the transpose of K (denoted as $K^T$). This yields a matrix, X, whose entries are the dot products between each query and each key.
2. **Normalization:** Next, a softmax function is applied to each row of X. This normalizes the correlation scores into weights that sum to one per row. These weights, represented by A, indicate the relative importance of each value.
3. **Weighted Sum:** Finally, the normalized weights (A) are applied to the Value matrix (V). The result, O, is a weighted sum of the values, representing the output of the self-attention mechanism.
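
The three steps above can be sketched in plain Python on small lists of row vectors. This is only an illustration under the same simplifications as the text (no scaling factor, no masking), and the names `softmax_row` and `attention` are hypothetical helpers.

```python
import math

def softmax_row(row):
    """Safe softmax of one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    d = sum(exps)
    return [e / d for e in exps]

def attention(Q, K, V):
    """Standard attention O = softmax(Q K^T) V on lists of row vectors."""
    # Step 1: X[k][i] is the dot product of query k with key i.
    X = [[sum(a * b for a, b in zip(q_row, k_row)) for k_row in K] for q_row in Q]
    # Step 2: normalize each row of X into weights that sum to one.
    A = [softmax_row(row) for row in X]
    # Step 3: each output row is a weighted sum of the rows of V.
    dim_v = len(V[0])
    return [
        [sum(a * v_row[c] for a, v_row in zip(a_row, V)) for c in range(dim_v)]
        for a_row in A
    ]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
O = attention(Q, K, V)
print(O)
```

Note that this version materializes the full matrices X and A, which is the memory cost that the later sections aim to avoid.
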

Here is the complete mathematical formulation:

$\begin{aligned}
X &= QK^T \\
A &= \text{softmax}(X) \\
O &= AV
\end{aligned}$

These steps capture the essence of self-attention, despite omitting certain minor details that do not detract from the core logic.

With online softmax, computing the attention output for one query row takes two loops.

The first loop calculates intermediate values. For a given query vector $Q[k,:]$, it computes $x_i = Q[k, :]K^T[:, i]$, where $K$ is the key matrix, and updates $m_i$ and $d'_i$ with the online softmax recurrence. Upon completion of this loop, the values $d'_N$ and $m_N$ are available.

The second loop determines the attention weights and the output vector. The attention weight is $a_i = \frac{e^{x_i - m_N}}{d'_N}$, and the output vector is accumulated as $o_i = o_{i-1} + a_i V[i,:]$, where $V$ is the value matrix. The final attention output, $o_N$, is the result of this loop.

The central question thus becomes: can the number of loops be reduced further, improving computational efficiency? That is what flash-attention does.

![[self_attn.png]]

## Flash-Attention

The calculation of the attention output, $\mathbf{o}_i$, can be written as:

$\mathbf{o}_i = \sum_{j=1}^i \left( \frac{e^{x_j - m_N}}{d'_N} \mathbf{V}[j, :] \right)$

Similar to the strategy employed in online softmax, a **surrogate** technique can be applied to eliminate the dependency on $N$. By substituting $d'_N$ with $d'_i$ and $m_N$ with $m_i$, the formulation becomes:

$\mathbf{o}'_i = \left( \sum_{j=1}^i \frac{e^{x_j - m_i}}{d'_i} \mathbf{V}[j, :] \right)$

Crucially, the final output $\mathbf{o}'_N$ equals $\mathbf{o}_N$, so the intermediate values $\mathbf{o}_i$ never need to be computed.
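
As a baseline for the single-loop version, the two-loop procedure from the Attention Revisit section can be sketched in plain Python for one query vector. The first loop is the online softmax recurrence for $m_i$ and $d'_i$; the second accumulates $o_i$. The function name `attention_two_loops` and the choice to cache the scores in a list are assumptions made for illustration.

```python
import math

def attention_two_loops(q, K, V):
    """Attention output for one query vector q, using two loops over the keys."""
    # Loop 1: online softmax statistics m_N and d'_N.
    m = -math.inf   # running maximum m_i
    d = 0.0         # running denominator d'_i
    xs = []         # cached scores x_i (they could also be recomputed in loop 2)
    for k_row in K:
        x = sum(a * b for a, b in zip(q, k_row))   # x_i = q . K[i, :]
        xs.append(x)
        m_new = max(m, x)
        # d'_i = d'_{i-1} * exp(m_{i-1} - m_i) + exp(x_i - m_i)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new

    # Loop 2: o_i = o_{i-1} + a_i * V[i, :], with a_i = exp(x_i - m_N) / d'_N.
    o = [0.0] * len(V[0])
    for x, v_row in zip(xs, V):
        a = math.exp(x - m) / d
        o = [o_c + a * v_c for o_c, v_c in zip(o, v_row)]
    return o

q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(attention_two_loops(q, K, V))
```

The second loop still needs $m_N$ and $d'_N$ from the first, which is the dependency the surrogate $\mathbf{o}'_i$ is meant to remove.
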
Analogous to the online softmax derivation, this formulation can be expressed as a recurrence relation between $\mathbf{o}'_i$ and $\mathbf{o}'_{i-1}$:

$\begin{aligned}
\mathbf{o}'_i &= \sum_{j=1}^i \frac{e^{x_j - m_i}}{d'_i} \mathbf{V}[j, :] \\
&= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d'_i} \mathbf{V}[j, :] \right) + \frac{e^{x_i - m_i}}{d'_i} \mathbf{V}[i, :] \\
&= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d'_{i-1}} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d'_{i-1}}{d'_i} \mathbf{V}[j, :] \right) + \frac{e^{x_i - m_i}}{d'_i} \mathbf{V}[i, :] \\
&= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d'_{i-1}} \mathbf{V}[j, :] \right) \frac{d'_{i-1}}{d'_i} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d'_i} \mathbf{V}[i, :] \\
&= \mathbf{o}'_{i-1} \frac{d'_{i-1}}{d'_i} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d'_i} \mathbf{V}[i, :]
\end{aligned}$

This transformation enables the computation of $\mathbf{o}'_i$ within a single loop, yielding the final desired output, $\mathbf{o}'_N$, at the end of the iteration. This is the core principle behind the flash-attention algorithm.

![[flash_attn.png|500]]

The accompanying diagrams visually represent the essence of flash-attention, as explained in the lecture notes "[From Online Softmax to FlashAttention](https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf)" by Zihao Ye.

![[flash_attn_illustrate.png|400]]

## Flash-Decoding

Flash-Attention is highly efficient in most scenarios, with the notable exception of long-context decoding. It parallelizes across query blocks and the batch dimension, but during decoding the query length is typically 1, so with small batches it fails to fully utilize the GPU's streaming multiprocessors. Below is a simplified illustration of Flash-Attention:

![[Queries.gif]]

A key limitation becomes apparent: the loop over keys and values runs sequentially, so the work is not parallelized across keys and values.
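
The single-loop recurrence derived in the Flash-Attention section can be sketched in plain Python for one query vector. It is a minimal illustration, not a tiled GPU kernel, and the function name `flash_attention_row` is hypothetical. Each iteration depends on the previous values of $m$, $d'$, and $\mathbf{o}'$, which is the sequential dependence over keys and values described above.

```python
import math

def flash_attention_row(q, K, V):
    """Attention output for one query vector q in a single loop over the keys."""
    m = -math.inf            # running maximum m_i
    d = 0.0                  # running denominator d'_i
    o = [0.0] * len(V[0])    # running output o'_i
    for k_row, v_row in zip(K, V):
        x = sum(a * b for a, b in zip(q, k_row))   # x_i = q . K[i, :]
        m_new = max(m, x)
        scale = math.exp(m - m_new)    # exp(m_{i-1} - m_i)
        p = math.exp(x - m_new)        # exp(x_i - m_i)
        d_new = d * scale + p          # d'_i
        # o'_i = o'_{i-1} * (d'_{i-1} / d'_i) * exp(m_{i-1} - m_i)
        #        + (exp(x_i - m_i) / d'_i) * V[i, :]
        o = [o_c * (d * scale / d_new) + (p / d_new) * v_c
             for o_c, v_c in zip(o, v_row)]
        m, d = m_new, d_new
    return o

q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(flash_attention_row(q, K, V))
```

For this input the scores are $x = (1, 0, 1)$, so the exact result is $(3, 4)$ up to floating-point rounding, which matches a direct softmax-weighted sum of the rows of V.
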
Flash-Decoding addresses this constraint by parallelizing across keys and values, at the cost of an additional final reduction step. The following illustration demonstrates Flash-Decoding, where keys/values are partitioned into five chunks, enabling parallel memory-efficient attention computation:

![[Flash Decoding Inference.gif]]

This optimization comes with certain trade-offs. Flash-Decoding introduces a small overhead through its final reduction step and requires an additional log-sum-exp value for each chunk's attention scores. Log-sum-exp is used rather than sum-exp for numerical stability, since raw exponentials can overflow or underflow for extreme values.

To understand the final reduction step, let us first examine the per-chunk computation. In this section $i$ indexes a chunk rather than a position, and the sums run over the keys $j$ that belong to chunk $i$:

$\mathbf{o}'_i = \sum_{j \in \text{chunk } i} \left( \frac{e^{x_j - m_i}}{d'_i} \mathbf{V}[j, :] \right)$

Applying Flash-Attention to each chunk yields $\mathbf{o}'_i$, $d'_i$, and $m_i$, where $m_i$ and $d'_i$ are the maximum and the denominator of chunk $i$ alone. Because each $d'_i$ is measured relative to its own maximum $m_i$, the global denominator is not a plain sum of the $d'_i$. Each term must first be rescaled to the global maximum $m_N = \max(m_1, m_2, m_3, \dots)$, which gives $d'_N = \sum_i d'_i e^{m_i - m_N}$. To maintain numerical stability, the per-chunk statistics are collected as a log-sum-exp, $\log \sum_{j \in \text{chunk } i} e^{x_j} = m_i + \log d'_i$.

Multiplying both sides of the expression for $\mathbf{o}'_i$ by $d'_i$ removes the per-chunk normalization:

$\mathbf{o}'_i d'_i = \sum_{j \in \text{chunk } i} \left(e^{x_j - m_i} \mathbf{V}[j, :] \right)$

Multiplying both sides by $e^{m_i}$ then removes the per-chunk maximum as well:

$\mathbf{o}'_i d'_i e^{m_i} = \sum_{j \in \text{chunk } i} \left(e^{x_j} \mathbf{V}[j, :] \right)$

The final reduction step synthesizes these components.
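
As a rough sketch of how these pieces combine, the following plain-Python code runs the single-loop recurrence on each chunk and then reduces the per-chunk results. The names `chunk_attention` and `flash_decoding_row` and the equal-size chunking are assumptions for illustration, and the chunks are processed sequentially here although they are independent and could run in parallel. Note that each chunk's denominator is rescaled by $e^{m_i - m_N}$ before summing, since it was computed relative to that chunk's own maximum.

```python
import math

def chunk_attention(q, K_chunk, V_chunk):
    """Single-loop recurrence on one chunk; returns (o'_i, d'_i, m_i)."""
    m, d = -math.inf, 0.0
    o = [0.0] * len(V_chunk[0])
    for k_row, v_row in zip(K_chunk, V_chunk):
        x = sum(a * b for a, b in zip(q, k_row))
        m_new = max(m, x)
        scale = math.exp(m - m_new)
        p = math.exp(x - m_new)
        d_new = d * scale + p
        o = [o_c * (d * scale / d_new) + (p / d_new) * v_c
             for o_c, v_c in zip(o, v_row)]
        m, d = m_new, d_new
    return o, d, m

def flash_decoding_row(q, K, V, num_chunks):
    """Split keys/values into chunks, process each chunk, then reduce."""
    size = math.ceil(len(K) / num_chunks)
    partials = [chunk_attention(q, K[s:s + size], V[s:s + size])
                for s in range(0, len(K), size)]

    # Global maximum and globally rescaled denominator.
    m_global = max(m for _, _, m in partials)
    d_global = sum(d * math.exp(m - m_global) for _, d, m in partials)

    # Final reduction: sum_i o'_i * d'_i * exp(m_i - m_N) / d'_N.
    out = [0.0] * len(V[0])
    for o, d, m in partials:
        w = d * math.exp(m - m_global) / d_global
        out = [acc + w * o_c for acc, o_c in zip(out, o)]
    return out

q = [1.0, 0.5]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [-1.0, 3.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(flash_decoding_row(q, K, V, num_chunks=2))
print(chunk_attention(q, K, V)[0])   # same keys processed as a single chunk
```

The two printed vectors should agree up to floating-point rounding. An implementation that stores the log-sum-exp $m_i + \log d'_i$ per chunk carries the same information as the pair $(d'_i, m_i)$ used in this sketch.
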
Given $\mathbf{o}'_i d'_i e^{m_i} = \sum_{j \in \text{chunk } i} \left(e^{x_j} \mathbf{V}[j, :] \right)$ for each chunk $i$, together with $m_N = \max(m_1, m_2, m_3, \dots)$ and the rescaled denominator $d'_N = \sum_{i} d'_i e^{m_i - m_N}$, the final output is computed as:

$\mathbf{o}'_N = \sum_{i=1}^c \mathbf{o}'_i d'_i \frac{e^{m_i-m_N}}{d'_N}$

where $c$ denotes the number of chunks. This final reduction across all chunks yields the complete Flash-Decoding result.

## Summary

This article presents a mathematical derivation of flash-attention, beginning with the foundational softmax operation and progressing through to flash-attention and flash-decoding. The mathematical principles established here serve as groundwork for understanding the more complex Triton and CUDA kernel implementations, which will be explored in a subsequent article.

I am particularly indebted to Zihao Ye's [excellent lecture notes](https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf), which provide an exceptionally clear and intuitive explanation of flash-attention's complexity. This article draws heavily from that invaluable resource.

## Reference

- https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
- https://pytorch.org/blog/flash-decoding/