Source
- Original Skyformer paper (2021, BERT): https://arxiv.org/pdf/2111.00035
- Another symmetric attention (2024, BERT): https://arxiv.org/pdf/2406.06366

SKY (Symmetrization Kernelized attention for NYstrom method)
Overview
PyTorch Matrix Multiplication Trade-offs
From Perplexity
| Method | Usability | Performance | Flexibility |
|---|---|---|---|
| `@` operator | Very intuitive | Fast | Limited |
| `torch.matmul()` | Clear and explicit | Fast | Limited |
| `torch.einsum()` | Complex notation | Potentially slower | Highly flexible |
Conclusion
For most standard matrix multiplications, either the @ operator or torch.matmul() will suffice and perform efficiently. However, when dealing with more complex tensor operations that require specific summation patterns or manipulations, torch.einsum() offers significant advantages despite potential performance trade-offs. Choosing the right method depends on the specific requirements of your computation task in PyTorch.
Another explanation from ChatGPT.
| Method | Flexibility | Performance | Readability | Notes |
|---|---|---|---|---|
| `torch.mm` | Low | High | High | Best for 2D tensors only. |
| `torch.matmul` | High | Moderate | Moderate | Handles batch dimensions; prone to silent broadcasting. |
| `torch.bmm` | Moderate | High | Moderate | Optimized for 3D batched tensors. |
| `torch.einsum` | Very high | Moderate | Low-moderate | Great for complex operations; slower than specialized ops. |
| `torch.tensordot` | High | Moderate | Moderate | Flexible contractions, but requires explicit dims. |
| `@` operator | Low | High | High | Clean and concise for matrix multiplication. |
| `torch.outer` | Low | High | High | Specialized for outer products. |
| `torch.mul` | Moderate | High | High | Element-wise only. |
For general-purpose use:
- Use `torch.matmul` for its versatility.
- Use `torch.einsum` for advanced operations requiring flexibility.
- Use `torch.bmm` or `torch.mm` for performance-critical, batch-specific tasks.
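To make the trade-offs concrete, here is a small illustrative snippet (not taken from the Skyformer code) that computes the same batched attention-score product with each of the common methods:

```python
import torch

# Illustrative comparison: the same batched attention-score product, four ways.
q = torch.randn(8, 128, 64)   # (batch, seq_len, dim)
k = torch.randn(8, 128, 64)

scores_at     = q @ k.transpose(-1, -2)                # @ operator
scores_matmul = torch.matmul(q, k.transpose(-1, -2))   # explicit, broadcasts batch dims
scores_bmm    = torch.bmm(q, k.transpose(-1, -2))      # 3D tensors only
scores_einsum = torch.einsum('bnd,bmd->bnm', q, k)     # index notation, no explicit transpose

assert torch.allclose(scores_at, scores_einsum, atol=1e-5)
```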
Skyformer High-Level Overview
This code implements a specific type of transformer attention mechanism named Skyformer, which leverages sketching techniques to approximate kernel-based attention, aiming to reduce computational costs for long sequences. The model includes specialized kernels (`kernel_RS_RBF`, `kernel_SM`, etc.) and incorporates iterative inverse approximation (`iterative_inv`) for normalizing the sketched attention weights.
Code Breakdown
Linearized Attention
`linear_attention(q, k, v)`: describes linear attention with normalization in only three lines of code!
- Linearized Softmax Attention: compute the attention as \(\text{LinearAttention}(Q, K, V) = D^{-1}\,\bigl(Q (K^\top V)\bigr)\), where \(D = \operatorname{diag}(Q K^\top \mathbf{1})\) is the normalization term that sums over the keys.
- Steps in Linear Attention:
  - Sum the keys over the sequence: \(K_{\text{cumsum}} = \sum_n K_n\)
  - Compute the normalizer: \(D^{-1} = \dfrac{1}{Q \cdot K_{\text{cumsum}}}\)
  - Compute the context: \(\text{Context} = \sum_n k_n v_n^\top = K^\top V\)
  - Combine to produce the output: \(\text{Output} = D^{-1}\,(Q \cdot \text{Context})\)
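A minimal sketch of such a three-line `linear_attention` (assuming `q` and `k` have already been passed through a positive feature map, so the products are well defined):

```python
import torch

def linear_attention(q, k, v):
    # D^{-1} Q (K^T V): linear attention with normalization in three lines.
    k_cumsum = k.sum(dim=-2)                                           # sum keys over the sequence
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum)         # per-query normalizer D^{-1}
    context = torch.einsum('...nd,...ne->...de', k, v)                 # Context = K^T V
    return torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)  # Output = D^{-1} Q Context
```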
`rbf_attention(q, k, v)`:
- Performs RBF kernel-based attention, similar to `linear_attention` but with an additional normalization step.
Kernel Functions
These compute kernel transformations based on the input query (q) and key (k) matrices.
`kernel_SM`:
- Computes the standard softmax kernel for the given inputs.
- Applies the exponential directly to the (scaled) inner product of `q` and `k`.
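A hedged sketch of what such a softmax kernel looks like (the \(1/\sqrt{d}\) scaling here is an assumption; the actual code may scale `q` and `k` beforehand):

```python
import torch

def kernel_SM(q, k):
    # Unnormalized softmax kernel: exp(<q_i, k_j> / sqrt(d)).
    scale = q.shape[-1] ** -0.5
    return torch.exp(torch.einsum('...nd,...md->...nm', q, k) * scale)
```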
`kernel_RS_SM`:
- Similar to `kernel_SM` but includes random-sign multiplication for efficient computation.
- Handles the case where sketched keys are accumulated (`X2_accu=True`).
`kernel_RS_RBF`:
- Computes the RBF kernel with random-sign multiplication for approximation.
- Includes distance-based scaling using the squared norms of `q` and `k` (the diagonals of \(qq^\top\) and \(kk^\top\)).
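A hedged sketch of the RBF variant (omitting the random-sign part), showing where the squared-norm/diagonal scaling enters:

```python
import torch

def kernel_RBF(q, k):
    # exp(<q_i, k_j> - ||q_i||^2/2 - ||k_j||^2/2) = exp(-||q_i - k_j||^2 / 2).
    qk = torch.einsum('...nd,...md->...nm', q, k)
    q_sq = q.pow(2).sum(dim=-1).unsqueeze(-1) / 2   # (..., n, 1)
    k_sq = k.pow(2).sum(dim=-1).unsqueeze(-2) / 2   # (..., 1, m)
    return torch.exp(qk - q_sq - k_sq)
```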
Sketching Mechanism
`kernel_sketch`:
- Projects the concatenated `q` and `k` matrices into a lower-dimensional space using a sketching matrix.
- Computes the sketched softmax kernel (`AS`) for dimensionality reduction.
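A hedged sketch of the sub-sampling idea (argument names are assumptions): only the sampled columns of the full kernel matrix over \(X = [q; k]\) are ever materialized.

```python
import torch

def kernel_sketch(q, k, sample_idx, kernel_fn):
    # Column sub-sampling of A = kernel(X, X) with X = [q; k]:
    # only A S (the sampled columns) is computed, never the full 2n x 2n matrix.
    x = torch.cat([q, k], dim=-2)        # (..., 2n, d)
    landmarks = x[..., sample_idx, :]    # (..., m, d), m << 2n
    return kernel_fn(x, landmarks)       # AS: (..., 2n, m)
```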
Inverse Normalization
`iterative_inv(mat, n_iter=6)`:
- Iteratively computes an approximate matrix inverse using a series expansion.
- Starts with an initial approximation (`V`) scaled by the largest row/column sums.
- Refines the inverse approximation with recursive updates.
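A sketch of the Newton-Schulz-style iteration (following the scheme popularized by Nyströmformer; not verbatim Skyformer code):

```python
import torch

def iterative_inv(mat, n_iter=6):
    # Approximate (pseudo-)inverse via a series-expansion iteration.
    I = torch.eye(mat.size(-1), device=mat.device, dtype=mat.dtype)
    K = mat
    # Initial guess: K^T scaled by the largest row and column sums, so the iteration converges.
    V = K.transpose(-1, -2) / (torch.max(torch.sum(torch.abs(K), dim=-2)) *
                               torch.max(torch.sum(torch.abs(K), dim=-1)))
    for _ in range(n_iter):
        KV = torch.matmul(K, V)
        V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
    return V
```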
Skyformer Class
Initialization:
- Initializes hyperparameters such as the sequence length (`max_seq_len`), the number of features (`nb_features`), and the attention head dimension (`dim_heads`).
- Configures the kernel function (`self.kernel_fn`) based on the selected sketched kernel type.
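A hedged sketch of that initialization (class and argument names are assumptions; the kernel functions refer to the sketches above):

```python
import torch.nn as nn

class SkyformerAttention(nn.Module):
    def __init__(self, dim_heads, nb_features, max_seq_len, sketched_kernel="kernel_RS_RBF"):
        super().__init__()
        self.dim_heads = dim_heads          # per-head dimension
        self.nb_features = nb_features      # sketch size m
        self.max_seq_len = max_seq_len      # maximum sequence length n
        # Select the kernel used to build the sketched attention matrix
        # (kernel_RBF / kernel_SM are the illustrative sketches defined above).
        self.kernel_fn = kernel_RBF if sketched_kernel == "kernel_RS_RBF" else kernel_SM
```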
Uniform Sketching:
- Randomly selects rows and columns for sketching matrices.
- Applies a random sign for efficient kernel approximation.
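A hedged sketch of the uniform sketching step (names and signature are assumptions):

```python
import torch

def uniform_sketching(seq_len, nb_features):
    # Sample nb_features row/column indices uniformly and attach Rademacher (+/-1) signs.
    idx = torch.randint(0, seq_len, (nb_features,))
    signs = torch.randint(0, 2, (nb_features,), dtype=torch.float32) * 2 - 1
    return idx, signs
```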
Forward Pass:
- Preprocessing:
  - Normalizes `q`, `k`, and `v` using the mask and a scaling factor.
- Sketching:
  - Generates sketching matrices (`self.sketching_matrix`) and computes the sketched kernel representation (`AS`).
- Inverse Normalization:
  - Constructs a sketch-based self-attention approximation using `STAS` (the sketched kernel core matrix).
  - Normalizes it and computes its inverse (`STAS_inv`) via `iterative_inv`.
- Attention Context:
  - Computes the context vector (`context`) by combining the sketched `K` matrix and the value vectors (`v`).
  - Outputs the final attention (`out`).
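Putting the pieces together, here is a hedged end-to-end sketch of the forward pass (function and argument names are assumptions; it reuses `kernel_RBF` and `iterative_inv` from the sketches above and omits the mask and the final normalization for brevity):

```python
import torch

def sketched_attention(q, k, v, sample_idx, n_iter=6):
    # Nystrom-style approximation: A ~ (AS) (S^T A S)^+ (AS)^T,
    # so A @ v ~ AS @ STAS_inv @ (AS^T @ v).
    n = q.shape[-2]
    x = torch.cat([q, k], dim=-2)                           # Skyformer symmetrizes over X = [q; k]
    AS = kernel_RBF(x, x[..., sample_idx, :])               # (..., 2n, m) sketched kernel columns
    STAS = AS[..., sample_idx, :]                           # (..., m, m) core matrix S^T A S
    STAS_inv = iterative_inv(STAS, n_iter=n_iter)           # approximate inverse
    v_pad = torch.cat([torch.zeros_like(v), v], dim=-2)     # values live on the key half of X
    out = AS @ (STAS_inv @ (AS.transpose(-1, -2) @ v_pad))  # approximate A @ v_pad
    return out[..., :n, :]                                  # keep the query half as the context
```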
Key Features
- Kernel Approximation:
- Reduces computational complexity using sketching and kernel-based approximations.
- Supports various kernels (softmax, RBF).
- Efficient Inverse Computation:
- Utilizes iterative approximations for inverting matrices to avoid direct inversion.
- Memory-Efficient Sketching:
- Randomized sketching reduces memory usage while preserving essential information.
Use Cases
- Designed for long-sequence attention tasks where computational cost and memory usage are bottlenecks.
- Suitable for scenarios where approximations (e.g., kernel methods) are acceptable without significant loss in accuracy.