In this post, the gradient of the attention op will be derived from a single rule used to implement reverse mode automatic differentiation.

The attention mechanism is the core building block of the transformer architecture, which underlies today’s most successful language models. It has been shown that attention can replace recurrent blocks in neural networks while greatly increasing parallelism. This enables training on datasets of unprecedented size and has led to large language models (LLMs).

We start with the definition of the attention block, then explain the basics of automatic differentiation (AD), and finally derive the gradient of the attention block from a basic AD rule. No special knowledge is needed to follow the steps, only elementary math: computing the derivatives of multivariate functions and the very basics of linear algebra (for matrix computations).

Basics

In this section, we define the attention mechanism and the basics of reverse mode automatic differentiation.

Attention description

For simplicity, assume that we have two neural network layers and would like to insert an attention block between them. The first layer, $L_a$, has an output of shape $N \times d_v$ excluding the batch dimension ($N$ features of dimension $d_v$, such as the $d_v$-dimensional embeddings of $N$ tokens), and the second layer, $L_b$, one of shape $M \times d_k$. The composition is $L_b(a(L_a))$, where $a$ implements the attention ($N=M$ when processing tokens).

Self-attention, when applied between two layers, is a special form of attention that weights each element of $y$ (the output of $L_a$) using a learned lookup table and feeds the re-weighted elements into $L_b$ through the attention mechanism $a$. The input of $L_b$ is called the query, because it is submitted to the attention mechanism to query the output of $L_a$ and re-weight its features; i.e., each input of $L_b$ becomes a weighted sum of all outputs of $L_a$, conditioned on the query itself. This mechanism enables a token in layer $L_b$ to attend to particular tokens of layer $L_a$ by assigning them larger weights. Since everything operates on the same $y$, this form of attention is called self-attention; other types of attention exist as well. For simplicity, we set $d_k=d_v=d$ and consider $N$ tokens. For a detailed view of self-attention in transformers, consult the attention source code of nanoGPT (https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L29).

In summary, the inputs of the attention block are the query $\mathbf{Q} \in \mathbb{R}^{N \times d}$, the key $\mathbf{K} \in \mathbb{R}^{N \times d}$ and the value $\mathbf{V} \in \mathbb{R}^{N \times d}$, and its output is:

$\mathbf{O}=a(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\text{softmax}(\mathbf{QK}^T)\mathbf{V}$.

The steps to compute the attention output therefore are the following:

  • Compute $\mathbf{R=QK}^T$, $\mathbf{R} \in \mathbb{R}^{N \times N}$,
  • Compute $\mathbf{P}=\text{softmax}(\mathbf{R})$ (row-wise),
  • Compute $\mathbf{O=PV}$, $\mathbf{O} \in \mathbb{R}^{N \times d}$.
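The three steps map one-to-one onto array operations. Below is a minimal NumPy sketch (the function name `attention_forward` is ours, not from the post's code) that computes the softmax row-wise with broadcasting and also returns the intermediates $\mathbf{R}$ and $\mathbf{P}$, since the backward pass derived below needs them.

```python
import numpy as np

def attention_forward(Q, K, V):
    """Unscaled attention following the three steps above; returns O, R, P."""
    R = Q @ K.T                                 # (N, N) scores
    R_shift = R - R.max(axis=1, keepdims=True)  # subtract the row max for stability
    E = np.exp(R_shift)
    P = E / E.sum(axis=1, keepdims=True)        # row-wise softmax, (N, N)
    O = P @ V                                   # (N, d) output
    return O, R, P

rng = np.random.default_rng(0)
N, d = 4, 8
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, R, P = attention_forward(Q, K, V)
print(O.shape, R.shape, P.shape)        # (4, 8) (4, 4) (4, 4)
print(np.allclose(P.sum(axis=1), 1.0))  # True: every row of P is a distribution
```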

A few components are left out because their derivation is straightforward: the attention mask (a token must not attend to the subsequent tokens in the window), dropout, and the scaling of the softmax input by $1/\sqrt{d_k}$ (scaled dot-product attention, Vaswani et al., 2017).
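To illustrate why these components are straightforward, here is a hedged sketch (our own code, not part of the post's implementation) that adds the $1/\sqrt{d}$ scaling and a causal mask; dropout is still omitted. It reuses the backward formulas derived later in this post. Masked scores are set to $-\infty$, so the corresponding entries of $\mathbf{P}$, and therefore of $\mathbf{\bar{R}}$, are zero, while the constant scaling factor simply multiplies $\mathbf{\bar{R}}$ before the final matmul. The gradients are checked against central finite differences of the scalar $\sum_{ij} o_{ij}\bar{o}_{ij}$ for a random $\mathbf{\bar{O}}$.

```python
import numpy as np

def softmax_rows(X):
    X = X - X.max(axis=1, keepdims=True)   # row max is finite: the diagonal is never masked
    E = np.exp(X)                          # exp(-inf) = 0 for masked entries
    return E / E.sum(axis=1, keepdims=True)

def masked_attention(Q, K, V):
    N, d = Q.shape
    c = 1.0 / np.sqrt(d)                                # scaling of the softmax input
    future = np.triu(np.ones((N, N), dtype=bool), k=1)  # True where key index > query index
    R = np.where(future, -np.inf, c * (Q @ K.T))
    P = softmax_rows(R)
    return P @ V, P, c

def masked_attention_backward(Q, K, V, P, c, dO):
    dV = P.T @ dO
    dP = dO @ V.T
    dR = P * dP - np.sum(P * dP, axis=1, keepdims=True) * P  # zero wherever P is zero
    dS = c * dR                                              # chain rule through the scaling
    return dS @ K, dS.T @ Q, dV

def numeric_grad(f, X, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. X (X is perturbed in place)."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        old = X[idx]
        X[idx] = old + eps
        f_plus = f()
        X[idx] = old - eps
        f_minus = f()
        X[idx] = old
        G[idx] = (f_plus - f_minus) / (2 * eps)
    return G

rng = np.random.default_rng(0)
N, d = 5, 3
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
O, P, c = masked_attention(Q, K, V)
dQ, dK, dV = masked_attention_backward(Q, K, V, P, c, dO)

# Scalar whose gradient with respect to O is exactly dO.
loss = lambda: np.sum(masked_attention(Q, K, V)[0] * dO)
for name, X, G in [("Q", Q, dQ), ("K", K, dK), ("V", V, dV)]:
    print(name, np.allclose(numeric_grad(loss, X), G, atol=1e-6))
```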

Automatic differentiation

To compute the gradient of the attention output $\mathbf{O}$ with respect to the inputs $\mathbf{Q}, \mathbf{K}, \mathbf{V}$, we use automatic differentiation. AD specifies how the chain rule is applied to compute the derivatives of a function, and it operates on a directed computational graph: the nodes are the operations we perform, the operands flow into a node through its incoming edges, and the result flows out of the node through its outgoing edges. There are two types of AD, forward mode and reverse mode (and mixtures of the two), each with its own benefits and drawbacks. Forward mode applies the chain rule from the inside out: it traverses the graph starting at the inputs and computes the derivative of each op from its predecessors (inputs). Reverse mode applies the chain rule from the outside in: it starts at the output node, progresses towards the inputs of the network, and computes the derivative for each op from its successors. Most neural network libraries use reverse mode because it is computationally cheaper when the number of inputs is much larger than the number of outputs, which is exactly the case with neural networks (many parameters, a single scalar loss).

The price is memory. To compute the derivative of an operation, we need the values computed in the forward pass. Thus, we first evaluate the function (the forward computation), save the result of each node in the computational graph, and only then start computing the derivatives (the backward computation). This requires significantly more memory, as we have to keep every intermediate result of the forward pass. To summarize, we compute the following values, called adjoints:

$\bar{w}=\sum_{w_* \in s(w)}\bar{w_*}\frac{\partial w_*}{\partial w}$,

where $s(w)$ is the set of successor nodes of $w$ in the computational graph. In simple terms, to get the derivative of the function with respect to $w$, we visit each successor $w_*$ of $w$ and multiply its already computed adjoint $\bar{w_*}$ by the local derivative $\frac{\partial w_*}{\partial w}$. The local derivative is evaluated from the stored forward values of the predecessors of $w_*$ ($w$ being one of them). Summing these products over all successors gives $\bar{w}$.

A reverse mode autodiff implementation therefore first orders the nodes topologically according to the $s(\cdot)$ relation, starts with the node that has no successor (the output), and evaluates the rule above for each node in this order, so that every adjoint $\bar{w_*}$ is available by the time it is needed. As the gradients “flow” from the output to the input in the backward phase, I will use the term incoming gradient for the gradient of the function output with respect to the output of an operation. The adjoint of the function output itself is the seed value, which is set to $1$.
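The adjoint rule and the topological ordering fit into a few lines of code. Below is a minimal scalar sketch (the `Node` class, `exp`, and `backward` are hypothetical names for illustration, not a real library): each node stores its forward value and the local derivatives with respect to its predecessors, and `backward` visits nodes so that all successors of $w$ are processed before $w$, each successor adding $\bar{w_*}\frac{\partial w_*}{\partial w}$ to $\bar{w}$. In the example $f(x, y) = xy + e^x$, the variable $x$ has two successors, so its adjoint is a sum of two terms.

```python
import math

class Node:
    """A scalar node of the computational graph (minimal sketch)."""
    def __init__(self, value, parents=()):
        self.value = value        # forward value, stored for the backward pass
        self.parents = parents    # (predecessor, d self / d predecessor) pairs
        self.adjoint = 0.0        # accumulated by backward()

    def __add__(self, other):
        return Node(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return Node(self.value * other.value, [(self, other.value), (other, self.value)])

def exp(x):
    e = math.exp(x.value)
    return Node(e, [(x, e)])

def backward(output):
    order, visited = [], set()
    def visit(node):                 # depth-first post-order: predecessors first
        if id(node) in visited:
            return
        visited.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)
    visit(output)
    output.adjoint = 1.0             # seed
    for node in reversed(order):     # every successor is processed before its predecessors
        for parent, local_grad in node.parents:
            parent.adjoint += node.adjoint * local_grad   # one term of the adjoint sum

x, y = Node(1.0), Node(2.0)
f = x * y + exp(x)
backward(f)
print(x.adjoint, y.adjoint)   # approximately 4.71828 (= y + e^x) and 1.0 (= x)
```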

Usually, we work with many variables organized into vectors (or matrices). In that case, the rule above computes a vector-Jacobian product (VJP), where the vector is the incoming gradient of the operation and the Jacobian contains the partial derivatives of each output of the operation with respect to each of its inputs. In practice, each supported operation is implemented as a component that computes the VJP for that particular op. Note that it is often unnecessary (and wasteful) to materialize the whole Jacobian to compute the VJP. (Forward mode autodiff, on the other hand, computes Jacobian-vector products (JVPs). The Jacobian has as many rows as there are outputs and as many columns as there are inputs; one JVP yields one column of it and one VJP yields one row. Obtaining all derivatives therefore takes one forward-mode pass per input but only one reverse-mode pass per output, which is why computing VJPs (reverse mode autodiff) is cheaper when we have more inputs than outputs.)
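To make the JVP/VJP distinction concrete, here is a small NumPy sketch on a hypothetical op $f(\mathbf{x}) = \tanh(\mathbf{W}\mathbf{x})$ (chosen only for illustration) with $n=5$ inputs and $m=3$ outputs. Both products are computed without materializing the Jacobian $\text{diag}(1-\mathbf{y}^2)\mathbf{W}$; rebuilding the full Jacobian takes $n$ JVPs (one per column) but only $m$ VJPs (one per row).

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 5
W = rng.standard_normal((m, n))
x = rng.standard_normal(n)
y = np.tanh(W @ x)                 # forward value, stored for the derivatives

def jvp(u):
    # J u with J = diag(1 - y^2) W, without forming J
    return (1 - y**2) * (W @ u)

def vjp(v):
    # v^T J = W^T (v * (1 - y^2)), without forming J
    return W.T @ ((1 - y**2) * v)

# Forward mode: n JVPs, one per column of the Jacobian.
J_fwd = np.stack([jvp(e) for e in np.eye(n)], axis=1)
# Reverse mode: m VJPs, one per row of the Jacobian.
J_rev = np.stack([vjp(e) for e in np.eye(m)], axis=0)
print(J_fwd.shape, np.allclose(J_fwd, J_rev))   # (3, 5) True
```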

Computing the gradients

The attention consists of two kinds of ops: matmul and softmax. We start with the gradient of the matmul, then derive the gradient of the softmax, and finally combine the results.

Computing $\mathbf{\bar{V}}$ and $\mathbf{\bar{P}}$

In reverse mode autodiff, we progress from the output towards the inputs, so let us start with $\frac{d\mathbf{O}}{d\mathbf{V}}$, the easiest one, since it comes from a single matmul ($\mathbf{O}=\mathbf{PV} \in \mathbb{R}^{N \times d}$, $\mathbf{P} \in \mathbb{R}^{N \times N}$, $\mathbf{V} \in \mathbb{R}^{N \times d}$). In autodiff terms, we would like to compute $\bar{v}_{ij}$. The successors of $v_{ij}$ are the values $o_{kj}$, $k=1, \dots, N$. Therefore,

$\bar{v}_{ij}= \sum_{k=1}^N \bar{o}_{kj} \frac{\partial o_{kj}}{\partial v_{ij}} = $

$ \bar{o}_{1j} \frac{\partial o_{1j}}{\partial v_{ij}} + \bar{o}_{2j} \frac{\partial o_{2j}}{\partial v_{ij}} + \dots + \bar{o}_{Nj} \frac{\partial o_{Nj}}{\partial v_{ij}} $.

Each term contains the derivative of a dot product:

$o_{kj}=\sum_{l=1}^{N} p_{kl}v_{lj} = $

$ p_{k1}v_{1j} + p_{k2}v_{2j} + \dots + p_{kN}v_{Nj}$,

and when differentiating with respect to $v_{ij}$, every term vanishes except the one with $l=i$, therefore

$\frac{\partial o_{kj}}{\partial v_{ij}} = p_{ki}$.

By substituting back we get:

$\bar{v}_{ij}= \sum_{k=1}^N \bar{o}_{kj} \frac{\partial o_{kj}}{\partial v_{ij}} = \sum_{k=1}^N\bar{o}_{kj}p_{ki} $.

Similarly, the gradient with respect to $\mathbf{P}$ is:

$\bar{p}_{ij}= \sum_{k=1}^{d} \bar{o}_{ik} \frac{\partial o_{ik}}{\partial p_{ij}} = $

$ \bar{o}_{i1} \frac{\partial o_{i1}}{\partial p_{ij}} + \bar{o}_{i2} \frac{\partial o_{i2}}{\partial p_{ij}} + \dots + \bar{o}_{id} \frac{\partial o_{id}}{\partial p_{ij}} $,

and because the dot product is:

$o_{ik} = \sum_{l=1}^N p_{il}v_{lk} $,

therefore (because only the term $l=j$ is nonzero):

$\frac{\partial o_{ik}}{\partial p_{ij}} = v_{jk} $,

thus:

$\bar{p}_{ij}=\sum_{k=1}^{d} \bar{o}_{ik} v_{jk} $.
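The element-wise formulas translate directly into loops. The sketch below (random inputs, our own variable names) evaluates $\bar{v}_{ij}$ and $\bar{p}_{ij}$ literally and confirms that they coincide with the matrix products $\mathbf{P}^T\mathbf{\bar{O}}$ and $\mathbf{\bar{O}}\mathbf{V}^T$.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 4, 3
P = rng.random((N, N))
V = rng.standard_normal((N, d))
dO = rng.standard_normal((N, d))   # incoming gradient, bar(O)

dV = np.zeros((N, d))
for i in range(N):
    for j in range(d):
        # bar(v)_ij = sum_k bar(o)_kj * p_ki
        dV[i, j] = sum(dO[k, j] * P[k, i] for k in range(N))

dP = np.zeros((N, N))
for i in range(N):
    for j in range(N):
        # bar(p)_ij = sum_k bar(o)_ik * v_jk
        dP[i, j] = sum(dO[i, k] * V[j, k] for k in range(d))

print(np.allclose(dV, P.T @ dO), np.allclose(dP, dO @ V.T))   # True True
```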

Gradient computation (VJP) for matmul in reverse mode autodiff

As a general rule, if we have the expression $\mathbf{C}=\mathbf{A}\mathbf{B}$ ($\mathbf{A} \in \mathbb{R}^{M \times K}, \mathbf{B} \in \mathbb{R}^{K \times N}$), then the VJP becomes:

$\bar{b}_{ij} = \sum_{k=1}^M \bar{c}_{kj} a_{ki} $,

while

$\bar{a}_{ij} = \sum_{k=1}^N \bar{c}_{ik} b_{jk} $.

In matrix form:

$\mathbf{\bar{B}} = \mathbf{A}^T \mathbf{\bar{C}}$ (equivalently, $\mathbf{\bar{B}}^T = \mathbf{\bar{C}}^T \mathbf{A}$),

and

$\mathbf{\bar{A}} = \mathbf{\bar{C}} \mathbf{B}^T $.
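With non-square shapes, the dimensions alone pin down the order of the factors: $\mathbf{A}^T\mathbf{\bar{C}}$ is $K \times N$ like $\mathbf{B}$, and $\mathbf{\bar{C}}\mathbf{B}^T$ is $M \times K$ like $\mathbf{A}$. A minimal sketch (the helper `matmul_vjp` is our own name) checks both against central finite differences of the scalar $\sum_{ij} c_{ij}\bar{c}_{ij}$, whose gradient is exactly the VJP.

```python
import numpy as np

def matmul_vjp(A, B, dC):
    """VJP of C = A @ B: returns (dA, dB) for the incoming gradient dC."""
    dA = dC @ B.T    # (M, N) @ (N, K) -> (M, K), same shape as A
    dB = A.T @ dC    # (K, M) @ (M, N) -> (K, N), same shape as B
    return dA, dB

rng = np.random.default_rng(1)
M, K, N = 2, 3, 4                   # deliberately non-square
A, B = rng.standard_normal((M, K)), rng.standard_normal((K, N))
dC = rng.standard_normal((M, N))
dA, dB = matmul_vjp(A, B, dC)

eps = 1e-6
loss = lambda A, B: np.sum((A @ B) * dC)
num_dA = np.zeros_like(A)
for idx in np.ndindex(*A.shape):
    E = np.zeros_like(A)
    E[idx] = eps
    num_dA[idx] = (loss(A + E, B) - loss(A - E, B)) / (2 * eps)
num_dB = np.zeros_like(B)
for idx in np.ndindex(*B.shape):
    E = np.zeros_like(B)
    E[idx] = eps
    num_dB[idx] = (loss(A, B + E) - loss(A, B - E)) / (2 * eps)

print(np.allclose(dA, num_dA, atol=1e-6), np.allclose(dB, num_dB, atol=1e-6))   # True True
```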

Differentiating through the softmax: computing $\mathbf{\bar{R}} $

We would like to determine the gradient of the softmax with respect to its input. Recall that we defined $\mathbf{P}=\text{softmax}(\mathbf{R}) \in \mathbb{R}^{N \times N}$ previously. We need $\frac{d\mathbf{P}}{d\mathbf{R}}$ to turn the already computed adjoints $\bar{p}_{ij}$ into the adjoints $\bar{r}_{ij}$. The definition of softmax for $N$ input variables is:

$\text{softmax}(x_1, \dots, x_N) = $

$\Big(\frac{e^{x_1}}{\sum_k e^{x_k}}, \dots, \frac{e^{x_N}}{\sum_k e^{x_k}}\Big) $.

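Let $\mathbf{J}_{\text{softmax}}$ be the Jacobian of the softmax operation, defined as:

$\mathbf{J}_{\text{softmax}}(x_1, \dots, x_N) = $

$\Big(\nabla \text{softmax}_1(x_1, \dots, x_N), \ldots, $

$\nabla \text{softmax}_N(x_1, \dots, x_N)\Big)^T $

where $\text{softmax}_i$ is the $i$-th element of the result of the softmax operation. The $i$-th row of the Jacobian is thus the gradient of $\text{softmax}_i$, i.e., $(\mathbf{J}_\text{softmax})_{ij} = \frac{\partial \text{softmax}_i}{\partial x_j}$.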

To compute the Jacobian, we need the values $\frac{\partial \text{softmax}_i(x_1, \dots, x_N)}{\partial x_j}$. There are two cases, depending on whether the element lies on the diagonal or not:

$\frac{\partial \text{softmax}_i(x_1, \dots, x_N)}{\partial x_i} = \frac{e^{x_i} \sum_k e^{x_k} - (e^{x_i})^2}{(\sum_k e^{x_k})^2} = $

$\frac{e^{x_i}}{\sum_k e^{x_k}} - \frac{(e^{x_i})^2}{(\sum_k e^{x_k})^2} = $

$\text{softmax}_i(x_1, \dots, x_N) - \text{softmax}_i(x_1, \dots, x_N)^2 $

for the diagonal elements ($i=j$), and

$\frac{\partial \text{softmax}_i(x_1, \dots, x_N)}{\partial x_{j \neq i}} = \frac{- e^{x_i}e^{x_j}}{(\sum_k e^{x_k})^2} =$

$ -\text{softmax}_i(x_1, \dots, x_N) \text{softmax}_j(x_1, \dots, x_N) $

for the rest.

Merging the two cases, the Jacobian of $\text{softmax}(\vec{x})$, where $\vec{x}= (x_1, \dots, x_N)^T$, is:

$\mathbf{J}_\text{softmax}(\vec{x}) = $

$\text{diag}(\text{softmax}(\vec{x})) - \text{softmax}(\vec{x}) \text{softmax}(\vec{x})^T $
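A quick numerical sanity check of this closed form, as a minimal sketch under our own naming: the Jacobian is built from the formula and compared column by column with central finite differences. It is also symmetric, and its rows sum to zero because the softmax outputs always sum to one.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))    # shifting by the max does not change the result
    return e / e.sum()

def softmax_jacobian(x):
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)   # diag(softmax(x)) - softmax(x) softmax(x)^T

rng = np.random.default_rng(2)
x = rng.standard_normal(5)
eps = 1e-6
# Column j: d softmax(x) / d x_j by central differences.
J_num = np.stack([(softmax(x + eps * e) - softmax(x - eps * e)) / (2 * eps)
                  for e in np.eye(len(x))], axis=1)
J = softmax_jacobian(x)
print(np.allclose(J, J_num, atol=1e-8), np.allclose(J, J.T))   # True True
```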

This is the “work out” that the FlashAttention-2 paper (Dao, 2023) mentions: “One can work out that if $p = \text{softmax}(s)$ for some vector $s$ and $p$, then with output gradient $dp$, the input gradient $ds = (diag(p)-pp^T)dp$.” (Note that this formula does not only give the Jacobian, but also computes the VJP for one row of adjoints, which we derive below.)

Plugging into the autodiff algorithm: determining the VJP of softmax

Since the softmax is computed row-wise, the direct successors of $r_{ij}$ in the computational graph are the values $p_{ik}$, $k=1, \dots, N$ (the entire $i$-th row, which $r_{ij}$ is part of), where $p_{ik} = \text{softmax}_k(r_{i1}, \dots, r_{iN})$.

Therefore, the adjoint $\bar{r}_{ij}$ is computed as:

$\bar{r}_{ij} = \sum_{k=1}^N \bar{p}_{ik} \frac{\partial p_{ik}}{\partial r_{ij}} = $

$ \sum_{k=1}^N \bar{p}_{ik} \frac{\partial \text{softmax}_k(r_{i1}, \dots, r_{iN})} {\partial r_{ij}} = $

$\sum_{k=1}^N \bar{p}_{ik} \mathbf{J}_\text{softmax}(r_{i1}, \dots, r_{iN})_{kj} =$

$ \big(\mathbf{J}_\text{softmax}(\mathbf{R}_i)^T \mathbf{\bar{P}}_i\big)_{j} =$

$ \big(\mathbf{J}_\text{softmax}(\mathbf{R}_i) \mathbf{\bar{P}}_i\big)_{j} $

Note that $\mathbf{J}_\text{softmax}$ is symmetric, since it is a diagonal matrix minus the outer product of a vector with itself; this justifies the last step.

To compute the whole $i$-th row of $\mathbf{\bar{R}} \in \mathbb{R}^{N \times N}$ ($\mathbf{\bar{R}}_i \in \mathbb{R}^N$), we simply multiply the Jacobian by the incoming gradient:

$\mathbf{\bar{R}}_i = \mathbf{J}_\text{softmax}(\mathbf{R}_i) \mathbf{\bar{P}}_i $

where, by the chain rule, $\mathbf{\bar{R}}$ is the gradient of $\mathbf{O}$ with respect to $\mathbf{R}$ ($\frac{d\mathbf{O}}{d\mathbf{R}}$), $\mathbf{\bar{P}} \in \mathbb{R}^{N \times N}$ is the incoming gradient, i.e., $\frac{d\mathbf{O}}{d\mathbf{P}}$ in our case (already computed), and $\mathbf{\bar{P}}_i \in \mathbb{R}^N$ is the $i$-th row of $\mathbf{\bar{P}}$ (treated as a column vector once extracted from the matrix).

Since $\mathbf{J}_\text{softmax}(\mathbf{R}_i) = \text{diag}(\mathbf{P}_i) - \mathbf{P}_i\mathbf{P}_i^T$ according to the definition of $\mathbf{J}_\text{softmax}$, we have:

$\mathbf{\bar{R}}_i = (\text{diag}(\mathbf{P}_i) - \mathbf{P}_i{\mathbf{P}_i}^T)\mathbf{\bar{P}}_i $

So far we have computed the gradient of a single row, and evaluating the formula above as written would require building an $N \times N$ matrix for every row of the output. We can avoid this. Multiplying a vector by a diagonal matrix is the same as taking its element-wise (Hadamard) product, denoted by $\circ$, with the diagonal, and by associativity $\mathbf{P}_i\mathbf{P}_i^T\mathbf{\bar{P}}_i = (\mathbf{P}_i^T\mathbf{\bar{P}}_i)\mathbf{P}_i$, which replaces the outer product with an inner product:

$\mathbf{\bar{R}}_i = \mathbf{P}_i \circ \mathbf{\bar{P}}_i - ({\mathbf{P}_i}^T \mathbf{\bar{P}}_i)\mathbf{P}_i $.

Computing one row therefore only requires the element-wise product of two vectors and a vector scaled by the dot product of two vectors; no matrix operations are needed.

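Stacking the rows, the full adjoint $\mathbf{\bar{R}}$ becomes:

$\mathbf{\bar{R}} = \mathbf{P} \circ \mathbf{\bar{P}} - \text{rowsum}(\mathbf{P} \circ \mathbf{\bar{P}}) \circ \mathbf{P} $,

where $\text{rowsum}(\cdot)$ returns the column vector of row sums, broadcast across the columns of $\mathbf{P}$.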
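The vectorized form can be checked against the per-row Jacobian product it replaces. In this minimal sketch (random inputs, our own names), the first computation materializes $\text{diag}(\mathbf{P}_i) - \mathbf{P}_i\mathbf{P}_i^T$ for every row, while the second uses only element-wise products and a broadcast row sum, as in the formula above.

```python
import numpy as np

rng = np.random.default_rng(3)
N = 4
R = rng.standard_normal((N, N))
E = np.exp(R - R.max(axis=1, keepdims=True))
P = E / E.sum(axis=1, keepdims=True)     # row-wise softmax
dP = rng.standard_normal((N, N))         # incoming gradient, bar(P)

# Per row: materialize J(R_i) = diag(P_i) - P_i P_i^T and multiply.
dR_rows = np.stack([(np.diag(P[i]) - np.outer(P[i], P[i])) @ dP[i] for i in range(N)])

# Vectorized: P∘dP - rowsum(P∘dP)∘P; the keepdims column broadcasts over columns.
dR = P * dP - np.sum(P * dP, axis=1, keepdims=True) * P

print(np.allclose(dR, dR_rows))   # True
```

A side effect visible in this form: each row of $\mathbf{\bar{R}}$ sums to zero, since the row sum of $\mathbf{P} \circ \mathbf{\bar{P}}$ is subtracted with weights $\mathbf{P}_i$ that sum to one.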

Computing $\mathbf{\bar{Q}}$ and $\mathbf{\bar{K}}$

Since $\mathbf{R}=\mathbf{QK}^T$ and we have already determined $\mathbf{\bar{R}}$, we can apply the same matmul rule we used for determining $\mathbf{\bar{P}}$ and $\mathbf{\bar{V}}$ (with $\mathbf{A}=\mathbf{Q}$ and $\mathbf{B}=\mathbf{K}^T$):

$\mathbf{\bar{Q}}=\mathbf{\bar{R}} ({\mathbf{K}^T})^T = \mathbf{\bar{R}K}$,

and

$\mathbf{\bar{K}}^T=\mathbf{Q}^T\mathbf{\bar{R}}$, thus $\mathbf{\bar{K}}=\mathbf{\bar{R}}^T\mathbf{Q}$.
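As a hedged sanity check of these two results in isolation, the sketch below treats a random matrix as the incoming gradient $\mathbf{\bar{R}}$ (an assumption for testing only) and compares $\mathbf{\bar{R}}\mathbf{K}$ and $\mathbf{\bar{R}}^T\mathbf{Q}$ with central finite differences of $\sum_{ij} r_{ij}\bar{r}_{ij}$ where $\mathbf{R}=\mathbf{QK}^T$. The helper `num_grad` is our own.

```python
import numpy as np

rng = np.random.default_rng(4)
N, d = 4, 3
Q, K = rng.standard_normal((N, d)), rng.standard_normal((N, d))
dR = rng.standard_normal((N, N))   # stands in for the incoming gradient bar(R)

dQ = dR @ K                        # bar(Q) = bar(R) (K^T)^T
dK = dR.T @ Q                      # bar(K) = (Q^T bar(R))^T

loss = lambda Q, K: np.sum((Q @ K.T) * dR)

def num_grad(f, X, eps=1e-6):
    """Central finite differences of the scalar f(X) w.r.t. X."""
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        E = np.zeros_like(X)
        E[idx] = eps
        G[idx] = (f(X + E) - f(X - E)) / (2 * eps)
    return G

print(np.allclose(dQ, num_grad(lambda X: loss(X, K), Q), atol=1e-6),
      np.allclose(dK, num_grad(lambda X: loss(Q, X), K), atol=1e-6))   # True True
```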

Putting it all together

The computed values can simply be chained together to obtain the gradients with respect to each input variable. See the function naive_attention_backward in the code below.

Summary

We derived the gradient of the attention op from the adjoint rule of reverse mode automatic differentiation, which required deriving the VJPs of the matmul and softmax ops. In an upcoming post, I will show how it can be computed efficiently, and in a memory-efficient way, on the GPU using the techniques proposed in the FlashAttention papers (Dao et al., 2022; Dao, 2023).

NumPy + PyTorch implementation

We implement the forward and backward steps with NumPy (using elementary ops) to test whether the derivation of the VJP is correct, and compare the outputs with the PyTorch implementation. Note that this code does not scale the softmax input (the exp() arguments) by $1/\sqrt{d_k}$, therefore it should not be used in production! In the code, the softmax input $\mathbf{R}$ is named S, following the convention in (Dao et al., 2022; Dao, 2023).

$ cat attention.py
import torch
import numpy as np
def verify_results(A, B, TOL):
    diff = A-B
    if np.mean(diff**2) > TOL:
        print('--> ERROR! Difference: ', np.mean(diff**2))
        print(diff)
    else:
        print('Result matches!')
def row_softmax(A):
    B = np.zeros_like(A)
    for row_id in range(len(A)):
        row = A[row_id]
        row_norm = row-np.max(row)
        row_norm_exp = np.exp(row_norm)
        B[row_id] = row_norm_exp / np.sum(row_norm_exp)
    return B
def naive_attention(Q, K, V):
    S = Q @ np.transpose(K, (1, 0))
    P = row_softmax(S)
    O = P @ V
    return O, S, P
def t_naive_attention(Q, K, V):
    S = Q @ torch.transpose(K, 0, 1)
    P = torch.softmax(S, 1)
    O_naive = P @ V
    return O_naive
def naive_attention_backward(Q, K, V, P, dO):
    # O=PV
    dV = (np.transpose(P, (1, 0)) @ dO)
    dP = dO @ np.transpose(V, (1, 0))
    
    # P = softmax(S)
    dS = P * dP - (np.sum(P * dP, axis=1, keepdims=True)*P)
    
    # S = QK'
    dQ = dS @ K
    dK = np.transpose(dS, (1, 0)) @ Q
    return dQ, dK, dV
def main():
    TOL = 1e-10
    N = 4
    d = 8
    # Note: we should not generate large numbers as the naive 
    # implementation is not numerically stable!
    Q = np.random.rand(N, d)
    K = np.random.rand(N, d)
    V = np.random.rand(N, d)
    print('Forward computation:')
    O_naive, S_naive, P_naive = naive_attention(Q, K, V)
    t_Q = torch.from_numpy(Q)
    t_Q.requires_grad = True
    t_K = torch.from_numpy(K)
    t_K.requires_grad = True
    t_V = torch.from_numpy(V)
    t_V.requires_grad = True
    t_O_naive = t_naive_attention(t_Q, t_K, t_V)
    
    verify_results(t_O_naive.detach().numpy(), O_naive, TOL)
    print('Backward computation:')
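    # A random upstream gradient tests the backward pass more thoroughly than
    # an all-ones one (with ones, every row of dP = dO V^T is identical).
    dO = np.random.rand(*O_naive.shape)
    dQ_naive, dK_naive, dV_naive = naive_attention_backward(Q, K, V, P_naive, dO)
    grad = torch.autograd.grad(t_O_naive, [t_Q, t_K, t_V], grad_outputs=torch.from_numpy(dO), retain_graph=True)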
    t_dQ = grad[0]
    t_dK = grad[1]
    t_dV = grad[2]
    print('Verifying dQ')
    verify_results(t_dQ.detach().numpy(), dQ_naive, TOL)
    print('Verifying dK')
    verify_results(t_dK.detach().numpy(), dK_naive, TOL)
    print('Verifying dV')
    verify_results(t_dV.detach().numpy(), dV_naive, TOL)    
if __name__ == '__main__':
    main()

Execution:

$ python3 attention.py
Forward computation:
Result matches!
Backward computation:
Verifying dQ
Result matches!
Verifying dK
Result matches!
Verifying dV
Result matches!
