[[2024-10-11-Linear_Attention]] [[2024-10-10-Attention_Math]]
Source
Code base for this article: https://github.com/lucidrains/performer-pytorch. Code base for the Attention Kernel Code article: https://teddykoker.com/2020/11/performers/, which covers the same thing in a simpler form for illustration.
Performer = FAVOR+ (Fast Attention Via positive Orthogonal Random features)
- Fast Attention (FA); the V stands for "Via"
- Positive Random Features (+ / PRF)
- Orthogonal Random Features (ORF)
Fast Attention Code
![[Pasted image 20241111233206.png]]
In the figure below, $L_L$ is a typo; it should be $I_L$. ![[Pasted image 20241111233229.png]]
FastAttention Class
A class implementing fast attention with customizable kernels, causal settings, and optional projections.
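As a quick orientation, here is a usage sketch modeled on the performer-pytorch README (treat the exact argument names such as `dim_heads`, `nb_features`, and `causal` as assumptions if your version differs):

```python
import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split out: (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn = FastAttention(dim_heads=64, nb_features=256, causal=False)
out = attn(q, k, v)  # (1, 8, 512, 64)
```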
ProjectionUpdater Class (New to Fast Attention)
In plain terms, this is the $\phi(\cdot)$ in the equation above. Because it is built from random features, it needs to be periodically updated to get better performance. (Question: could this be used for adaptation during testing, similar to how the piezoelectric effect in human bone optimizes the bone structure?)
Tracks when to update projection matrices in FastAttention to improve training efficiency. It only redraws projections at a specified interval.
This class (ProjectionUpdater) and function (redraw_projection_matrix) are only used during training; the projections are not updated during inference.
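A minimal sketch of the redraw logic described above (simplified; the counter attribute and the `find_modules` helper are assumptions about the repo's internals):

```python
# Simplified sketch of ProjectionUpdater.redraw_projections, not the exact repo code.
def redraw_projections(self):
    if not self.training:
        return  # inference never redraws the random features
    if self.feature_redraw_interval is not None and \
            self.calls_since_last_redraw >= self.feature_redraw_interval:
        # re-sample the random projection of every FastAttention submodule
        for fast_attention in find_modules(self.instance, FastAttention):
            fast_attention.redraw_projection_matrix()
        self.calls_since_last_redraw = 0
    else:
        self.calls_since_last_redraw += 1
```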
Transformer Layers
- ReZero, PreScaleNorm, PreLayerNorm: Variants of normalization layers that adapt to different normalization strategies.
- ReZero is for training stability.
- Scale Norm is a simpler norm.
- Chunk: Divides tensors into chunks, processes each chunk independently, and recombines them, reducing memory usage (see the sketch after this list).
- FeedForward: A standard feedforward network layer with optional GLU (Gated Linear Unit) activation.
- (New, key point) Attention: Implements a multi-headed attention mechanism that can include local and global heads, rotary or absolute positional embeddings, and dropout.
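To make the Chunk idea concrete, here is a minimal hypothetical sketch (not the repo's actual Chunk class): split the input along the sequence dimension, run the feedforward on each piece, and concatenate the results.

```python
import torch

def chunked_feedforward(ff, x, chunks=2, dim=-2):
    # Apply ff to each chunk separately to lower peak activation memory,
    # then stitch the pieces back together along the same dimension.
    return torch.cat([ff(chunk) for chunk in x.chunk(chunks, dim=dim)], dim=dim)
```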
The partial function in Python, from the functools module, is used to create a new function with some of its arguments pre-filled (partially applied). This is useful when you want to fix certain parameters of a function and leave others open to be specified later. Essentially, it allows you to create a simpler function from an existing one by pre-setting some of the parameters.
Why Use partial?
- Simplify Function Calls: If you frequently call a function with certain parameters fixed, `partial` lets you create a new function with those parameters pre-set, reducing redundancy.
- Function Customization: In complex code, `partial` enables more modular and customizable code by allowing functions to be adapted with specific parameters for different contexts.
- Improve Readability: It makes code more readable by reducing the need for repeated arguments and creating functions that are easier to understand at a glance.
How partial Works
partial takes a function as its first argument and then any arguments or keyword arguments you want to pre-set. It returns a new function where those pre-set arguments are “locked in,” while the remaining arguments can still be supplied when calling the new function.
Here’s an example:
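The example (reconstructed from the description below) looks roughly like this:

```python
from functools import partial

def multiply(x, y):
    return x * y

# Pre-fill the first argument: double(y) is equivalent to multiply(2, y)
double = partial(multiply, 2)

print(double(5))  # 10
```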
In this case:
- `partial(multiply, 2)` creates a new function `double` that is equivalent to `multiply(2, y)`.
- Calling `double(5)` is the same as calling `multiply(2, 5)`.
Example in Context of FastAttention
In the FastAttention class, partial is used to create functions with specific parameters pre-set. For instance:
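The relevant line is roughly the following (a sketch based on the description below; the exact keyword values such as `self.nb_features` and `ortho_scaling` are assumptions):

```python
from functools import partial

# Inside FastAttention.__init__: lock in the matrix shape and scaling,
# so later calls only need to supply the device.
self.create_projection = partial(
    gaussian_orthogonal_random_matrix,
    nb_rows=self.nb_features,
    nb_columns=dim_heads,
    scaling=ortho_scaling,
)
```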
In this example:
- `partial` is used to create `self.create_projection`, a version of `gaussian_orthogonal_random_matrix` with `nb_rows`, `nb_columns`, and `scaling` pre-set.
- `self.create_projection` can now be called with just a `device` argument, simplifying function calls and making the code more modular.
Fast Attention
The FA algorithm from the original paper is shown below. Bidirectional = non-causal linear attention; unidirectional = causal linear attention. ![[Pasted image 20241113193123.png]]
Core code
- linear_attention: Non-causal linear attention for efficient matrix multiplications.
- causal_linear_attention: Causal linear attention, preventing future tokens from attending to past ones.
- causal_linear_attention_noncuda: Inefficient CPU-based implementation of causal linear attention for reference. It is not actually used and is included only for explanation.
Non-causal Linear Attention (straightforward)
This function computes non-causal linear attention for transformer models. Linear attention mechanisms approximate the standard softmax attention to reduce computational complexity from $O(n^2)$ to $O(n)$, where $n$ is the sequence length.
It is only four lines of code.
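Here is a sketch of `linear_attention` reconstructed from the repo's approach (treat the exact einsum strings as an approximation); the body really is four lines:

```python
import torch

def linear_attention(q, k, v):
    # sum of (feature-mapped) keys over the sequence dimension
    k_cumsum = k.sum(dim=-2)
    # D^{-1}: one normalization scalar per query position
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    # phi(K)^T V computed once for the whole sequence: shape (..., d, e)
    context = torch.einsum('...nd,...ne->...de', k, v)
    # phi(Q) (phi(K)^T V), scaled by D^{-1}
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
```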
Inputs
- `q`: Query tensor of shape `(..., n, d)`
- `k`: Key tensor of shape `(..., n, d)`
- `v`: Value tensor of shape `(..., n, e)`
Here, ... represents any number of batch dimensions, n is the sequence length, d is the embedding dimension, and e is the value dimension.
Outputs
- `out`: Output tensor of shape `(..., n, e)`
Mathematical Interpretation
This function approximates the standard softmax attention using linear operations:
\[\text{Attention}(Q, K, V) \approx D^{-1}\phi(Q)\left(\phi(K)^\top V\right)\]
where $\phi(\cdot)$ is a feature map (e.g., exponential function) that linearizes the softmax kernel.
Causal Linear Attention (a bit mind-bending)
Standard attention first computes $QK^{\top}$ and then directly adds the causal mask $M$ (a strictly upper-triangular matrix with elements $-\infty$), or alternatively computes $\text{softmax}(QK^{\top}) \otimes M$, where $\otimes$ is element-wise multiplication and $M$ here is a lower-triangular matrix with elements $+1$. Finally, the result is multiplied by $V$.
Fast Attention (linear attention), however, computes $K^{\top} V$ first, which makes handling the causal mask quite tricky. The original paper's explanation is shown below. The computation has to become an iterative process (over index $i$), which can in fact be viewed as another kind of RNN. This iterative process can also be parallelized, reducing the time from $O(L)$ to $O(\log L)$ at the cost of extra computation and storage (Eq. (11)).
![[Pasted image 20241113190707.png]]
Let us look at the slower version (without the CUDA code), which uses an iterative method similar to an RNN implementation, as shown below.
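Assembled from the code fragments walked through below (so the overall shape should be faithful, even if the formatting differs slightly from the repo):

```python
import torch

def causal_linear_attention_noncuda(q, k, v, chunk_size=128, eps=1e-6):
    last_k_cumsum = 0
    last_context_cumsum = 0
    outs = []

    # process the sequence chunk by chunk, carrying running sums forward
    for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim=-2), (q, k, v))):
        k_cumsum = last_k_cumsum + k.cumsum(dim=-2)

        D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps)
        context = torch.einsum('...nd,...ne->...nde', k, v)
        context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
        out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)

        # carry the last cumulative sums over to the next chunk
        last_k_cumsum = k_cumsum[..., -1:, :]
        last_context_cumsum = context_cumsum[..., -1:, :, :]
        outs.append(out)

    return torch.cat(outs, dim=-2)
```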
causal_linear_attention_noncuda Function
The causal_linear_attention_noncuda function is an implementation of causal linear attention without using CUDA acceleration. Although it’s labeled as “inefficient” and “not being used,” it serves as a reference implementation to understand how causal linear attention can be computed using cumulative sums. This function processes queries (q), keys (k), and values (v) to compute the attention output in a causal (autoregressive) manner, ensuring that each position in the sequence only attends to previous positions.
Below, we will break down the code step by step, explain each component, and relate it to the underlying mathematical operations.
Function Signature
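Reconstructed from the parameter list below, the signature is:

```python
def causal_linear_attention_noncuda(q, k, v, chunk_size=128, eps=1e-6):
    ...
```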
- Inputs:
  - `q`: Query tensor of shape `(..., N, D)`.
  - `k`: Key tensor of shape `(..., N, D)`.
  - `v`: Value tensor of shape `(..., N, E)`.
  - `chunk_size`: Number of chunks to split the sequence into (despite the name, it is passed to `torch.chunk` as the chunk count; default: 128).
  - `eps`: Small epsilon value to prevent division by zero (default: 1e-6).
- Outputs:
  - Returns the attention output tensor of shape `(..., N, E)`.
Initialization
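Per the bullets below, the initialization is simply:

```python
last_k_cumsum = 0        # running sum of keys from previous chunks
last_context_cumsum = 0  # running sum of key-value products from previous chunks
outs = []                # collects per-chunk outputs
```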
- Purpose:
  - `last_k_cumsum`: Stores the cumulative sum of keys up to the previous chunk.
  - `last_context_cumsum`: Stores the cumulative sum of the context (key-value products) up to the previous chunk.
  - `outs`: List to collect output chunks.
Processing Chunks
The function processes the input tensors in chunks along the sequence dimension to manage memory usage and simulate causal computation.
Chunking the Inputs
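The chunking loop (reconstructed from the explanation below) looks like:

```python
for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim=-2), (q, k, v))):
    ...
```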
- Explanation:
  - `t.chunk(chunk_size, dim=-2)`: Splits each tensor (`q`, `k`, `v`) along the sequence dimension (`dim=-2`). Note that `torch.chunk` interprets its first argument as the number of chunks, despite the variable name `chunk_size`.
  - `zip(...)`: Iterates over the corresponding chunks of `q`, `k`, and `v`.
Iterating Over Chunks
For each chunk, the function performs the following steps:
Cumulative Sum of Keys

```python
k_cumsum = last_k_cumsum + k.cumsum(dim=-2)
```

- Explanation:
  - `k.cumsum(dim=-2)`: Computes the cumulative sum of `k` along the sequence dimension within the current chunk.
  - `last_k_cumsum`: Adds the cumulative sum from the previous chunks to account for all previous positions.
- Result: `k_cumsum` contains the cumulative sum of keys up to the current position.
Computing the Normalization Factor $D_{\text{inv}}$

```python
D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps)
```

- Explanation:
  - `torch.einsum('...nd,...nd->...n', q, k_cumsum + eps)`: Computes the dot product between `q` and `k_cumsum` along the feature dimension (`D`), resulting in a tensor of shape `(..., N)`.
  - `1. / (...)`: Takes the reciprocal to obtain $D_{\text{inv}}$, the inverse of the normalization factor.
  - `eps`: A small value added for numerical stability to prevent division by zero.
- Mathematical Representation: \(D_{\text{inv}, t} = \frac{1}{\sum_{i=1}^{t} q_t^\top k_i + \epsilon} = \frac{1}{q_t^\top \sum_{i=1}^{t} k_i + \epsilon}\), where $t$ indexes the positions in the sequence.
Computing the Context Tensor

```python
context = torch.einsum('...nd,...ne->...nde', k, v)
```

- Explanation:
  - `torch.einsum('...nd,...ne->...nde', k, v)`: Computes the outer product between `k` and `v` for each position, resulting in a tensor of shape `(..., N, D, E)`.
- Result: `context` contains the key-value products for the current chunk.
Cumulative Sum of Context

```python
context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
```

- Explanation:
  - `context.cumsum(dim=-3)`: Computes the cumulative sum of the context along the sequence dimension (now at `dim=-3` due to the additional dimension from the outer product).
  - `last_context_cumsum`: Adds the cumulative sum from the previous chunks.
- Result: `context_cumsum` contains the cumulative sum of key-value products up to the current position.
Computing the Output

```python
out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)
```

- Explanation:
  - `torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)`: Performs a tensor contraction that:
    - Multiplies `context_cumsum` with `q` along the `D` dimension.
    - Scales the result by `D_inv`.
- Result: `out` is the attention output for the current chunk, of shape `(..., N, E)`.
- Mathematical Representation: \(\text{out}_t = D_{\text{inv}, t} \left( \sum_{i=1}^{t} \left( q_t^\top k_i \right) v_i \right) = D_{\text{inv}, t} \left( q_t^{\top} \sum_{i=1}^{t} k_i v_i^\top \right)\)
Updating Cumulative Sums

```python
last_k_cumsum = k_cumsum[..., -1:, :]
last_context_cumsum = context_cumsum[..., -1:, :, :]
```

- Explanation:
  - Extracts the last cumulative sum from the current chunk to carry over to the next chunk.
  - Slices the tensors to keep only the last position (`-1:`) along the sequence dimension.
Collecting the Output

```python
outs.append(out)
```

- Explanation:
  - Appends the output `out` of the current chunk to the list `outs`.
Concatenating the Outputs
After processing all chunks, the outputs are concatenated along the sequence dimension:
```python
return torch.cat(outs, dim=-2)
```

- Explanation:
  - `torch.cat(outs, dim=-2)`: Concatenates the list of per-chunk output tensors along the sequence dimension to form the final output, which the function returns.
Mathematical Understanding
Causal Linear Attention
The causal linear attention mechanism aims to compute the attention output efficiently by leveraging the associativity of matrix multiplication and avoiding explicit computation of the attention matrix, which is quadratic in sequence length.
The standard attention mechanism is defined as:
\[\text{Attention}(Q, K, V) = \text{softmax}(Q K^\top) V\]
In linear attention, we approximate the softmax with a feature mapping $\phi(\cdot)$ such that:
\[\text{Attention}(Q, K, V) \approx D^{-1} \left( \phi(Q) \left( \phi(K)^\top V \right) \right)\]
For causal attention, we need to ensure that each position $t$ only attends to positions $\leq t$.
Implementing Causality with Cumulative Sums
By using cumulative sums, we can efficiently compute the causal attention outputs:
- Cumulative Sum of Keys:
- $S_t = \sum_{i=1}^{t} k_i$
- Cumulative Sum of Context:
- $Z_t = \sum_{i=1}^{t} k_i v_i^\top$
- Normalization Factor:
- $D_t = q_t^\top S_t + \epsilon$
- Attention Output:
- $\text{out}_t = \frac{1}{D_t} q_t^\top Z_t$
This ensures that at each position $t$, the output depends only on keys and values from positions $\leq t$.
Relating Code to Mathematical Equations
Computing $S_t$ (Cumulative Sum of Keys):

```python
k_cumsum = last_k_cumsum + k.cumsum(dim=-2)
```

Corresponds to: \(S_t = S_{t-1} + k_t\)

Computing $Z_t$ (Cumulative Sum of Context):

```python
context = torch.einsum('...nd,...ne->...nde', k, v)
context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
```

Corresponds to: \(Z_t = Z_{t-1} + k_t v_t^\top\)

Computing $D_t$ (Normalization Factor):

```python
D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps)
```

Corresponds to: \(D_t = q_t^\top S_t + \epsilon\)

Computing the Output $\text{out}_t$:

```python
out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)
```

Corresponds to: \(\text{out}_t = \frac{1}{D_t} q_t^\top Z_t\)
Recap
Since we don’t have the actual figure, let’s conceptually understand how the cumulative sums implement causal attention.
- Sequence Processing:
- The sequence is processed in chunks to simulate streaming data or to handle long sequences without exceeding memory limits.
- Causality Enforcement:
- By using cumulative sums, each position’s output only depends on current and past positions.
- No information from future positions is included in the computation.
- Chunking Mechanism:
  - The cumulative sums (`k_cumsum` and `context_cumsum`) are updated incrementally.
  - The last cumulative sums from the previous chunk are used to initialize the current chunk's cumulative sums.
- Attention Computation:
- The attention output at each position is computed using the cumulative sums up to that position.
- The normalization factor ensures proper scaling of the outputs.
Practical Implications
- Inefficiency:
- The code is labeled as “inefficient” because it does not leverage optimized GPU operations or memory-efficient algorithms.
- It serves as a reference implementation to illustrate the concept.
- Not Using CUDA:
- The absence of CUDA-specific code means that it may not run efficiently on GPUs.
- In practice, optimized implementations would use custom CUDA kernels or libraries designed for performance.
Summary
The causal_linear_attention_noncuda function:
- Implements causal linear attention by computing cumulative sums of keys and context.
- Ensures that each output position only depends on current and past positions (causality).
- Processes the sequence in chunks to manage memory usage.
- Uses Einstein summation (`torch.einsum`) for tensor contractions.
- Is an illustrative implementation for understanding causal linear attention, albeit not optimized for performance.
There is an optimized CUDA version, shown below. It is not clear how causal_dot_product_fn is implemented.
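A simplified sketch of the structure (an assumption based on how performer-pytorch delegates the causal prefix-sum to the `CausalDotProduct` CUDA kernel from the fast-transformers package; the real code also handles autocast and half precision):

```python
import torch
from fast_transformers.causal_product import CausalDotProduct

def causal_linear_attention(q, k, v, eps=1e-6):
    causal_dot_product_fn = CausalDotProduct.apply  # custom CUDA kernel

    # same normalization as the non-CUDA version
    k_cumsum = k.cumsum(dim=-2) + eps
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))

    # computes sum_{i <= t} (q_t . k_i) v_i for every position t in one fused kernel
    out = causal_dot_product_fn(q, k, v)
    out = torch.einsum('...nd,...n->...nd', out, D_inv)
    return out
```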
Kernel Functions: ORF+ (Positive Orthogonal Random Features)
These are optimized implementations of softmax and generalized kernels:
- softmax_kernel: Computes a softmax-based kernel using projections. Used to replace standard softmax attention.
- generalized_kernel: Implements a general kernel function, defaulting to ReLU, with optional projection.
- orthogonal_matrix_chunk & gaussian_orthogonal_random_matrix: Create random matrices with orthogonal constraints, useful for efficient kernel approximations.
QR decomposition produces the orthogonal basis.
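A minimal sketch of how one orthogonal block can be drawn (modeled on `orthogonal_matrix_chunk`; treat the exact signature as an assumption):

```python
import torch

def orthogonal_matrix_chunk(cols, device=None):
    # Start from an iid Gaussian block and orthogonalize it with QR;
    # the rows of the returned matrix form an orthonormal basis.
    unstructured_block = torch.randn((cols, cols), device=device)
    q, _ = torch.linalg.qr(unstructured_block)
    return q.t()
```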
Key Points
- Linear Attention: Replaces quadratic complexity attention with linear complexity attention for faster and memory-efficient computations.
- Projection Updater: Manages periodic redraws of projection matrices to improve performance without recalculating projections every step.
- Flexible Configuration: Various options for normalization, token shifting, cross-attention, and dropout make it highly customizable.
Performer Class
The main Performer class is an efficient variant of the transformer. Key features:
- Supports multiple heads and local attention windows.
- Reversible layers allow for memory-efficient training (these reversible blocks wrap the attention and feedforward layers; the technique was first brought in from the Reformer model).
- Feature redraw interval: Redraws projection matrices periodically.
- ProjectionUpdater is used to manage feature projection updates automatically.
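A usage sketch modeled on the performer-pytorch README (argument names are assumptions if your version differs):

```python
import torch
from performer_pytorch import Performer

model = Performer(
    dim=512,
    depth=4,
    heads=8,
    causal=True,
)

x = torch.randn(1, 2048, 512)
out = model(x)  # (1, 2048, 512)
```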
PerformerLM Class
This is the full transformer language model using the Performer attention mechanism. Key components:
- Embedding layers: For tokens and positional encodings.
- Performer: Uses the efficient Performer attention mechanism for self-attention.
- Layer normalization and linear output layer: Applied to the final output.
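A usage sketch, again modeled on the performer-pytorch README (exact keyword names are assumptions):

```python
import torch
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens=20000,
    max_seq_len=2048,
    dim=512,
    depth=6,
    heads=8,
    causal=True,                   # autoregressive language modeling
    feature_redraw_interval=1000,  # redraw the random features every 1000 calls
)

x = torch.randint(0, 20000, (1, 2048))
logits = model(x)  # (1, 2048, 20000): vocabulary logits per position
```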
The main differences between PerformerLM and Performer in the code are related to their specific roles within the transformer architecture. Here’s a breakdown of each:
1. Performer
The Performer class implements the core attention mechanism and feedforward layers of the Performer architecture. It’s the main building block that defines how tokens interact with each other through attention and nonlinear transformations. Key points about Performer:
- Attention Layers: uses `FastAttention`, which can be set to either causal or non-causal mode.
- Local Attention Support: `Performer` includes support for both local and global attention heads. This allows a mixture of local attention (attention to a limited neighborhood) and global attention (full context).
- Feedforward Layers: chunked for memory efficiency, with reversible-sequence support that saves memory during training by not storing intermediate activations.
2. PerformerLM
The PerformerLM class is a language model that builds on top of Performer. It wraps the Performer class and adds elements specific to language modeling tasks. Key points about PerformerLM:
- Token Embeddings
- Positional Embeddings: Adds positional information to the tokens to maintain the sequence order. It supports multiple types of positional embeddings, such as fixed sinusoidal, rotary, or absolute positional embeddings.
- Output Layer: After processing through the `Performer` layers, `PerformerLM` includes a final layer to map the transformed embeddings back to vocabulary logits for language modeling tasks. If `tie_embed` is set to `True`, it ties the input and output embeddings for parameter efficiency.
- Language Model-Specific Configurations: `PerformerLM` includes specific configurations for language modeling, like the ability to handle maximum sequence lengths and efficient generation through causal attention.