Public API

This page lists the public API of Attention.jl.
Modules

Attention — Module

`Attention`

A Julia package providing modular and extensible attention mechanisms for deep learning models.

It offers:

- A flexible `AbstractAttention` interface.
- Implementations like `DotProductAttention` and `NNlibAttention`.
- A full `MultiHeadAttention` layer compatible with Flux.
- Utilities such as `make_causal_mask`.
- Support for custom Q/K transformations (e.g., for RoPE in `MultiHeadAttention`).
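To give a rough feel for how these pieces fit together, here is a minimal, hypothetical sketch. Only the constructor signature and the names listed above come from this page; the self-attention call convention (a single array in, output plus attention scores out) is an assumption modeled on Flux's own MultiHeadAttention and may differ in practice.

```julia
# Hypothetical quick-start; the layer's call convention is assumed, not documented here.
using Attention

d_model, seq_len, batch = 64, 16, 8
x = randn(Float32, d_model, seq_len, batch)   # feature-first layout used throughout this page

mha = MultiHeadAttention(d_model, 4)          # 4 heads, otherwise default settings

# Assumed self-attention call returning (output, attention_scores),
# in the style of Flux's MultiHeadAttention.
y, scores = mha(x)
size(y)  # expected: (64, 16, 8)
```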
Attention Mechanisms

Attention.AbstractAttention — Type

`AbstractAttention`

Abstract type for attention mechanisms. Custom implementations should implement the `compute_attention` method.
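Since the extension point is a `compute_attention` method, a custom mechanism might be wired in roughly as sketched below. This is a toy example under the assumption that adding a method of `compute_attention` dispatching on the new type is the whole extension contract; the actual requirements may include more.

```julia
# Toy custom mechanism; assumes that adding a compute_attention method for the
# new type is the whole extension contract (bias, mask, nheads and fdrop are
# simply ignored here).
using Attention
using NNlib: batched_mul

struct UniformAttention <: Attention.AbstractAttention end

# Ignore query/key similarity and average the values uniformly over positions.
function Attention.compute_attention(::UniformAttention, q, k, v, bias=nothing;
                                     mask=nothing, nheads=1, fdrop=identity)
    _, len_k, batch = size(v)
    len_q = size(q, 2)
    α = fill!(similar(q, len_k, len_q, batch), 1 / len_k)  # uniform weights
    output = batched_mul(v, α)                             # (d, len_q, batch)
    return output, α
end
```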
Attention.compute_attention — Function

`compute_attention(mechanism::AbstractAttention, q, k, v, bias=nothing; mask=nothing, nheads=1, fdrop=identity)`

Compute attention based on the specified mechanism.

Arguments

- `mechanism`: The attention mechanism to use
- `q`: Query tensor of shape (dmodel, seqlen_q, batch)
- `k`: Key tensor of shape (dmodel, seqlen_k, batch)
- `v`: Value tensor of shape (dmodel, seqlen_v, batch)
- `bias`: Optional bias tensor
- `mask`: Optional mask tensor
- `nheads`: Number of attention heads
- `fdrop`: Dropout function to apply

Returns

- `output`: Output tensor of shape (dmodel, seqlen_q, batch)
- `attention_weights`: Attention weights
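For illustration, a minimal call following the documented shape convention could look like this (random data, example settings; combining the mask from `make_causal_mask` with this function is assumed to work as shown):

```julia
# Minimal sketch of the functional interface, following the documented
# (dmodel, seqlen, batch) shape convention.
using Attention

dmodel, seqlen, batch = 32, 10, 4
q = randn(Float32, dmodel, seqlen, batch)
k = randn(Float32, dmodel, seqlen, batch)
v = randn(Float32, dmodel, seqlen, batch)

mask = make_causal_mask(q)             # (seqlen, seqlen) boolean causal mask

output, attention_weights = compute_attention(DotProductAttention(), q, k, v;
                                              mask=mask, nheads=4)
size(output)  # expected: (32, 10, 4)
```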
Attention.DotProductAttention — Type

`DotProductAttention <: AbstractAttention`

Standard scaled dot-product attention as described in the "Attention Is All You Need" paper.
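For reference, the scaled dot-product attention defined in that paper is

```math
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
```

where `d_k` is the per-head key dimension; how the package maps its (dmodel, seqlen, batch) layout onto Q, K and V is an implementation detail.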
Attention.NNlibAttention — Type

`NNlibAttention <: AbstractAttention`

Attention implementation that uses NNlib's `dot_product_attention` when available. This provides a more optimized implementation that may be faster in some cases.
Attention.LinearAttention — Type

`LinearAttention <: AbstractAttention`

Linear attention implementation that computes attention scores using a feature map φ(x) = elu(x) + 1, following the method described in the paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention".

This implementation has linear complexity O(N) with respect to sequence length, compared to the quadratic complexity O(N²) of standard dot-product attention.
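To make the complexity claim concrete, here is a standalone, single-head sketch of the linearized computation from the cited paper; it is written for this page (feature-first matrices, no masking or numerical safeguards) and is not the package's internal implementation.

```julia
# Standalone illustration of linear attention with φ(x) = elu(x) + 1
# (Katharopoulos et al., 2020). Single head, feature-first (d, N) matrices,
# no masking; this is not the package's internal code.
using NNlib: elu

ϕ(x) = elu.(x) .+ 1                    # positive feature map

function linear_attention(q, k, v)     # q :: (d, N), k :: (d, N), v :: (d_v, N)
    Q, K = ϕ(q), ϕ(k)
    S = K * v'                         # (d, d_v): Σⱼ φ(kⱼ) vⱼᵀ — built in O(N)
    z = vec(sum(K; dims=2))            # (d,):     Σⱼ φ(kⱼ)    — the normalizer
    return (S' * Q) ./ (z' * Q)        # (d_v, N): per-position numerator / normalizer
end

d, N = 16, 1_000
q = randn(Float32, d, N)
k = randn(Float32, d, N)
v = randn(Float32, d, N)
y = linear_attention(q, k, v)          # (16, 1000); cost grows linearly with N
```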
Attention.MultiHeadAttention — Type

`MultiHeadAttention(d_model, nheads=8; bias=false, dropout_prob=0.0, attention_impl=DotProductAttention(), q_transform=identity, k_transform=identity)`

The multi-head dot-product attention layer used in Transformer architectures.

Returns the transformed input sequence and the attention scores.

Arguments

- `d_model`: The embedding dimension
- `nheads`: number of heads. Default `8`.
- `bias`: whether pointwise QKVO dense transforms use bias. Default `false`.
- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.
- `attention_impl`: the attention implementation to use. Default `DotProductAttention()`.
- `q_transform`: a function to apply to the query tensor after projection. Default `identity`.
- `k_transform`: a function to apply to the key tensor after projection. Default `identity`.
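As an illustration of these constructor options, the sketch below swaps in the NNlib-backed implementation and hooks custom Q/K transforms; the toy transform is a stand-in written for this example, not a utility provided by the package.

```julia
# Hypothetical constructor usage; only the keyword names come from the documented
# signature above. The Q/K transform below is a toy placeholder, not a real RoPE.
using Attention

d_model, nheads = 128, 8

# Placeholder hook where a rotary position embedding (RoPE) would go.
toy_qk_transform(x) = x   # identity stand-in; substitute a real RoPE here

mha = MultiHeadAttention(d_model, nheads;
                         bias=true,
                         dropout_prob=0.1,
                         attention_impl=NNlibAttention(),
                         q_transform=toy_qk_transform,
                         k_transform=toy_qk_transform)
```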
Utilities

Attention.make_causal_mask — Function

`make_causal_mask(x::AbstractArray, dims::Int=2)`

Create a causal mask for a sequence whose length is derived from `x`. The mask ensures that position `i` can only attend to positions `j ≤ i`.

Arguments

- `x`: Input array from which the sequence length is derived
- `dims`: Dimension along which to derive the sequence length (default: 2)

Returns

- A boolean mask matrix of shape (seqlen, seqlen) where `true` indicates allowed attention and `false` indicates masked (disallowed) attention.
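A small usage sketch, assuming the (dmodel, seqlen, batch) layout documented above so that the sequence length lives along dimension 2:

```julia
# Sketch: build a causal mask and pass it to the functional interface.
using Attention

x = randn(Float32, 8, 4, 2)            # (dmodel=8, seqlen=4, batch=2)
mask = make_causal_mask(x)             # 4×4 Bool matrix; true = attention allowed

output, α = compute_attention(DotProductAttention(), x, x, x; mask=mask)
size(output)  # expected: (8, 4, 2)
```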