Public API
This page lists the public API of Attention.jl.
Modules
Attention — Module

A Julia package providing modular and extensible attention mechanisms for deep learning models.
It offers:
- A flexible `AbstractAttention` interface.
- Implementations like `DotProductAttention` and `NNlibAttention`.
- A full `MultiHeadAttention` layer compatible with Flux.
- Utilities such as `make_causal_mask`.
- Support for custom Q/K transformations (e.g., for RoPE in `MultiHeadAttention`).
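As a quick orientation, here is a minimal sketch exercising the documented entry points; the tensor sizes are illustrative only, and the shapes follow the (dmodel, seqlen, batch) layout described under `compute_attention` below.

```julia
using Attention

# Toy tensors in the (dmodel, seqlen, batch) layout used throughout this API.
q = rand(Float32, 64, 10, 2)  # 10 query positions
k = rand(Float32, 64, 12, 2)  # 12 key positions
v = rand(Float32, 64, 12, 2)  # 12 value positions

# Scaled dot-product attention; returns the output and the attention weights.
output, weights = compute_attention(DotProductAttention(), q, k, v)
size(output)  # (64, 10, 2)
```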
Attention Mechanisms
Attention.AbstractAttention — Type

`AbstractAttention`

Abstract type for attention mechanisms. Custom implementations should implement the `compute_attention` method.
Attention.compute_attention — Function

`compute_attention(mechanism::AbstractAttention, q, k, v, bias=nothing; mask=nothing, nheads=1, fdrop=identity)`

Compute attention based on the specified mechanism.
Arguments
- `mechanism`: The attention mechanism to use
- `q`: Query tensor of shape (dmodel, seqlen_q, batch)
- `k`: Key tensor of shape (dmodel, seqlen_k, batch)
- `v`: Value tensor of shape (dmodel, seqlen_v, batch)
- `bias`: Optional bias tensor
- `mask`: Optional mask tensor
- `nheads`: Number of attention heads
- `fdrop`: Dropout function to apply
Returns
- `output`: Output tensor of shape (dmodel, seqlen_q, batch)
- `attention_weights`: Attention weights
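Because the interface is this single method, a new mechanism only needs a subtype of `AbstractAttention` and a `compute_attention` method. Below is a minimal sketch: the unscaled-score mechanism is invented purely for illustration, and `bias`/`nheads` are ignored for brevity.

```julia
using Attention, NNlib

# Hypothetical mechanism for illustration: dot-product scores without the
# 1/sqrt(d) scaling. `bias` and `nheads` are ignored to keep the sketch short.
struct UnscaledAttention <: Attention.AbstractAttention end

function Attention.compute_attention(::UnscaledAttention, q, k, v, bias=nothing;
                                     mask=nothing, nheads=1, fdrop=identity)
    # scores[i, j, b] = <k[:, i, b], q[:, j, b]>  ->  (seqlen_k, seqlen_q, batch)
    scores = NNlib.batched_mul(NNlib.batched_transpose(k), q)
    if mask !== nothing
        # `true` = allowed; push disallowed scores to -Inf before the softmax.
        scores = ifelse.(mask, scores, typemin(eltype(scores)))
    end
    α = fdrop(NNlib.softmax(scores; dims=1))  # normalize over the key dimension
    return NNlib.batched_mul(v, α), α
end
```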
Attention.DotProductAttention — Type

`DotProductAttention <: AbstractAttention`

Standard scaled dot-product attention, as described in the paper "Attention Is All You Need".
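For reference, the formulation from that paper, where d_k is the key dimension:

```math
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
```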
Attention.NNlibAttention — Type

`NNlibAttention <: AbstractAttention`

Attention implementation that uses NNlib's `dot_product_attention` when available. This provides a more optimized implementation that may be faster in some cases.
Attention.LinearAttention — Type

`LinearAttention <: AbstractAttention`

Linear attention implementation that computes attention scores using a feature map φ(x) = elu(x) + 1, following the method described in the paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention".
This implementation has linear complexity O(N) in the sequence length, compared with the quadratic complexity O(N²) of standard dot-product attention.
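The linear cost follows from reassociating the product: once the softmax is replaced by the kernel φ, the sums over keys in the numerator and denominator below can be computed once and reused for every query position i, so the work grows as O(N) rather than O(N²):

```math
\mathrm{Attn}(Q, K, V)_i
  = \frac{\phi(q_i)^\top \left( \sum_j \phi(k_j)\, v_j^\top \right)}
         {\phi(q_i)^\top \sum_j \phi(k_j)},
\qquad \phi(x) = \mathrm{elu}(x) + 1
```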
Attention.MultiHeadAttention — Type

`MultiHeadAttention(d_model, nheads=8; bias=false, dropout_prob=0.0, attention_impl=DotProductAttention(), q_transform=identity, k_transform=identity)`

The multi-head dot-product attention layer used in Transformer architectures. Returns the transformed input sequence and the attention scores.
Arguments
- `d_model`: The embedding dimension.
- `nheads`: Number of heads. Default `8`.
- `bias`: Whether the pointwise QKVO dense transforms use bias. Default `false`.
- `dropout_prob`: Dropout probability for the attention scores. Default `0.0`.
- `attention_impl`: The attention implementation to use. Default `DotProductAttention()`.
- `q_transform`: A function applied to the query tensor after projection. Default `identity`.
- `k_transform`: A function applied to the key tensor after projection. Default `identity`.
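A usage sketch follows; note that the call convention shown (the layer applied as `mha(q, k, v)` with `mask` as a keyword, as in Flux's `MultiHeadAttention`) is an assumption based on the stated Flux compatibility, not something this page specifies.

```julia
using Attention

mha = MultiHeadAttention(64, 8; dropout_prob=0.1)

x = rand(Float32, 64, 10, 2)   # (d_model, seq_len, batch)
mask = make_causal_mask(x)     # causal self-attention mask

# Assumed Flux-style call: self-attention with q = k = v = x.
y, scores = mha(x, x, x; mask=mask)

# The q_transform/k_transform hooks run after the Q/K projections, which is
# where a rotary embedding would go (`rope` is a hypothetical helper):
# mha_rope = MultiHeadAttention(64, 8; q_transform=rope, k_transform=rope)
```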
Utilities
Attention.make_causal_mask — Function

`make_causal_mask(x::AbstractArray, dims::Int=2)`

Create a causal mask for a sequence whose length is derived from `x`. The mask ensures that position i can only attend to positions j ≤ i.
Arguments
- `x`: Input array from which the sequence length is derived
- `dims`: Dimension along which to derive the sequence length (default: 2)
Returns
- A boolean mask matrix of shape (seqlen, seqlen), where `true` indicates allowed attention and `false` indicates masked (disallowed) attention.
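For example (the printed orientation assumes the upper-triangular (key, query) layout used by NNlib's `make_causal_mask`; only the "position i attends to j ≤ i" semantics are guaranteed by the docstring):

```julia
using Attention

x = rand(Float32, 64, 4, 1)   # sequence length 4 along dimension 2
mask = make_causal_mask(x)    # 4×4 Bool matrix

# Assuming NNlib's upper-triangular convention:
# mask == [1 1 1 1;
#          0 1 1 1;
#          0 0 1 1;
#          0 0 0 1]
output, weights = compute_attention(DotProductAttention(), x, x, x; mask=mask)
```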