Recently, the min-P paper was [presented](https://iclr.cc/virtual/2025/oral/31888) at ICLR. It's a sampling method where you ignore the logits of the tokens that aren't within a certain percent of the most likely token. The intuitions and explanations around this can be found in [this reddit post](https://www.reddit.com/r/LocalLLaMA/comments/17vonjo/your_settings_are_probably_hurting_your_model_why/).
Typically it's done by applying the softmax, finding the most likely probability, and then recording all the tokens that pass the threshold. You then set every other logit to $-\infty$ and pass the result through the softmax again, renormalizing the surviving tokens' probabilities among themselves.
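In JAX, a direct translation of that two-softmax procedure looks something like this (a sketch; the function and variable names are mine):

```python
import jax
import jax.numpy as jnp

def min_p_naive(logits: jax.Array, delta: float = 0.1) -> jax.Array:
    # First softmax: compute the full probability distribution.
    probs = jax.nn.softmax(logits, axis=-1)
    # A token survives if it is within a factor delta of the most likely token.
    keep = probs >= delta * probs.max(axis=-1, keepdims=True)
    # Mask everything else to -inf and renormalize with a second softmax.
    masked = jnp.where(keep, logits, -jnp.inf)
    return jax.nn.softmax(masked, axis=-1)
```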
This is just a quick post showing how you can derive a faster algorithm for this with basic math.
We start with our min-p definition, where $l$ are our original logits and $q_i$ is our sampling distribution.
$
\begin{align*}
p_i &= \dfrac{e^{l_i}}{\sum_j e^{l_j}}\\
p_\text{max} &= \max_i p_i \\
m_i &= \begin{cases}
l_i & \text{if } p_i \geq \delta \times p_\text{max}\\
-\infty & \text{else}
\end{cases} \\
q_i &= \dfrac{e^{m_i}}{\sum_j e^{m_j}}
\end{align*}
$
We can easily rewrite the condition in the middle as a ratio between the probabilities, which carries an important property:
$
\begin{align*}
m_i &= \begin{cases}
l_i & \text{if } \dfrac{p_i}{p_\text{max}} \geq \delta\\
-\infty & \text{else}
\end{cases} \\
\end{align*}
$
Calculating the *ratio* between probabilities doesn't require the normalization constant (the summation in the softmax). It's easy to calculate the relative likelihood of two elements in a distribution. But it's hard to know the absolute likelihood of either independently. We can take advantage of this in the following way:
$
\begin{align*}
\dfrac{p_i}{p_\text{max}} &= \dfrac{\dfrac{e^{l_i}}{\cancel{\sum_j e^{l_j}}}}{\dfrac{e^{l_\text{max}}}{\cancel{\sum_j e^{l_j}}}} \\
&= \dfrac{e^{l_i}}{e^{l_\text{max}}} \\
&= e^{l_i-l_\text{max}}
\end{align*}
$
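A quick numeric sanity check of that identity (the values are arbitrary):

```python
import math

logits = [2.0, 0.5, -1.0]
Z = sum(math.exp(l) for l in logits)        # the softmax normalizer
probs = [math.exp(l) / Z for l in logits]
l_max, p_max = max(logits), max(probs)

for l, p in zip(logits, probs):
    # p_i / p_max == exp(l_i - l_max): the normalizer cancels out.
    assert math.isclose(p / p_max, math.exp(l - l_max))
```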
As we can see, we can avoid taking the softmax over the full distribution. We can still do one more optimization by taking the log of both sides of the condition (the log is monotonic, so the inequality is preserved):
$
\begin{align*}
\dfrac{p_i}{p_\text{max}} \geq \delta
&\iff e^{l_i-l_\text{max}} \geq \delta \\
&\iff l_i-l_\text{max} \geq \log(\delta) \\
&\iff l_i \geq l_\text{max} + \log(\delta)
\end{align*}
$
Again, the threshold can be computed without doing the full softmax, as long as you know the largest logit. Our final full expression then becomes:
$
\begin{align*}
m_i &= \begin{cases}
l_i & \text{if } l_i \geq l_\text{max} + \log(\delta)\\
-\infty & \text{else}
\end{cases} \\
q_i &= \dfrac{e^{m_i}}{\sum_j e^{m_j}}
\end{align*}
$
**In JAX**:
```python
import jax
import jax.numpy as jnp

def min_p(logits: jax.Array, p: float = .9) -> jax.Array:
    threshold = logits.max(axis=-1) + jnp.log(p)  # use same log as softmax!
    logits = jnp.where(logits >= threshold[..., None], logits, -jnp.inf)
    # optionally apply softmax, or other filters
    # return jax.nn.softmax(logits, axis=-1)
    return logits
```
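To convince ourselves the two routes agree, we can compare the probability-space mask from the original definition against the logit-space mask on random logits (a self-contained check; `delta` plays the role of $\delta$):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (4, 32))
delta = 0.1

# Probability-space mask: the naive min-p definition.
probs = jax.nn.softmax(logits, axis=-1)
mask_prob = probs >= delta * probs.max(axis=-1, keepdims=True)

# Logit-space mask: the derived shortcut.
mask_logit = logits >= logits.max(axis=-1, keepdims=True) + jnp.log(delta)

assert bool(jnp.all(mask_prob == mask_logit))
```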
vLLM still uses the inefficient method; if you want an easy open-source contribution, you can go and fix it. The code is [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/sampler.py#L387-L413).
Another cool benefit of this is that you can now use the logprobs from common inference endpoints to approximate min_p (if they also support prefilling). You don't have access to the original probabilities or the full logits, but usually it'll be close enough.
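Since $\log p_i = l_i - \log Z$, the threshold $l_i \geq l_\text{max} + \log(\delta)$ carries over unchanged to logprobs. Here's a sketch of the idea, assuming an endpoint that returns top-k logprobs as a token-to-logprob mapping (the exact response shape of real APIs varies):

```python
import math

def approx_min_p(top_logprobs: dict[str, float], delta: float = 0.1) -> dict[str, float]:
    """Filter and renormalize top-k logprobs with the min-p rule.

    The constant -log Z shifts every logprob equally, so the logit-space
    threshold works directly on logprobs.
    """
    lp_max = max(top_logprobs.values())
    kept = {t: lp for t, lp in top_logprobs.items() if lp >= lp_max + math.log(delta)}
    # Renormalize the surviving mass. This is where the approximation lives:
    # any probability mass outside the returned top-k is unavoidably dropped.
    total = sum(math.exp(lp) for lp in kept.values())
    return {t: math.exp(lp) / total for t, lp in kept.items()}
```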