Public API

This page lists the public API of Attention.jl.

Modules

Attention (Module)
Attention

A Julia package providing modular and extensible attention mechanisms for deep learning models.

It offers:

  • A flexible AbstractAttention interface.
  • Implementations such as DotProductAttention, NNlibAttention, and LinearAttention.
  • A full MultiHeadAttention layer compatible with Flux.
  • Utilities such as make_causal_mask.
  • Support for custom Q/K transformations (e.g., for RoPE in MultiHeadAttention).
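As a quick orientation, here is a minimal self-attention sketch. The call style on the last line is assumed to follow the Flux convention of applying the layer directly to the input tensor; the dimensions are illustrative.

```julia
using Attention

d_model, seq_len, batch = 64, 10, 2
x = randn(Float32, d_model, seq_len, batch)

# An 8-head layer; DotProductAttention() is the default implementation.
mha = MultiHeadAttention(d_model, 8)

# Assumed Flux-style self-attention call: returns the transformed sequence
# and the attention scores.
y, scores = mha(x)
size(y)  # expected (64, 10, 2)
```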

Attention Mechanisms

Attention.AbstractAttention (Type)
AbstractAttention

Abstract type for attention mechanisms. Custom implementations should implement the compute_attention method.

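For illustration, a hedged sketch of a custom mechanism. UniformAttention is hypothetical, the method body ignores bias, mask, nheads, and fdrop to stay short, and the (seq_k, seq_q, batch) weight layout is one plausible convention rather than necessarily the package's.

```julia
using Attention

# Hypothetical mechanism: every query attends uniformly to all keys.
struct UniformAttention <: AbstractAttention end

function Attention.compute_attention(::UniformAttention, q, k, v, bias=nothing;
                                     mask=nothing, nheads=1, fdrop=identity)
    seq_q, seq_k = size(q, 2), size(k, 2)
    # Uniform weights of shape (seq_k, seq_q, batch); each column sums to 1.
    α = fill!(similar(q, seq_k, seq_q, size(q, 3)), 1 / seq_k)
    out = similar(q, size(v, 1), seq_q, size(q, 3))
    for b in axes(v, 3)
        out[:, :, b] = v[:, :, b] * α[:, :, b]
    end
    return out, α
end
```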
Attention.compute_attention (Function)
compute_attention(mechanism::AbstractAttention, q, k, v, bias=nothing;
                mask=nothing, nheads=1, fdrop=identity)

Compute attention based on the specified mechanism.

Arguments

  • mechanism: The attention mechanism to use
  • q: Query tensor of shape (d_model, seqlen_q, batch)
  • k: Key tensor of shape (d_model, seqlen_k, batch)
  • v: Value tensor of shape (d_model, seqlen_v, batch)
  • bias: Optional bias tensor
  • mask: Optional mask tensor
  • nheads: Number of attention heads
  • fdrop: Dropout function to apply

Returns

  • output: Output tensor of shape (d_model, seqlen_q, batch)
  • attention_weights: Attention weights
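A direct call with the dot-product mechanism, following the argument list above; the concrete sizes are illustrative.

```julia
using Attention

d_model, seq_len, batch = 32, 6, 4
q = randn(Float32, d_model, seq_len, batch)
k = randn(Float32, d_model, seq_len, batch)
v = randn(Float32, d_model, seq_len, batch)

mask = make_causal_mask(q)   # (6, 6) Bool mask over sequence positions
out, weights = compute_attention(DotProductAttention(), q, k, v;
                                 mask=mask, nheads=4)
size(out)  # expected (32, 6, 4)
```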
Attention.NNlibAttention (Type)
NNlibAttention <: AbstractAttention

Attention implementation that uses NNlib's dot_product_attention when available. This provides a more optimized implementation that may be faster in some cases.

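Because it shares the AbstractAttention interface, NNlibAttention can be swapped in without changing the rest of the call; a brief, self-contained sketch:

```julia
using Attention

q = k = v = randn(Float32, 32, 6, 4)

# Same interface, different mechanism; the result should match
# DotProductAttention up to floating-point differences.
out, weights = compute_attention(NNlibAttention(), q, k, v; nheads=4)
```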
Attention.LinearAttention (Type)
LinearAttention <: AbstractAttention

Linear attention implementation that computes attention scores using a feature map φ(x) = elu(x) + 1, following the method described in the paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention".

This implementation has linear complexity O(N) with respect to sequence length, compared to quadratic complexity O(N²) of standard dot-product attention.

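The reordering behind the linear complexity can be sketched on its own: with φ(x) = elu(x) + 1, each output is output_i = (Σ_j v_j φ(k_j)ᵀ) φ(q_i) / ((Σ_j φ(k_j))ᵀ φ(q_i)), so the seqlen × seqlen score matrix is never materialized. A single-batch, single-head illustration of that identity (not the package's implementation; assumes NNlib is available for elu):

```julia
using NNlib: elu

φ(x) = elu.(x) .+ 1

d, n = 8, 512
q, k, v = randn(Float32, d, n), randn(Float32, d, n), randn(Float32, d, n)

kv = v * φ(k)'                      # (d, d): key/value summary, built once
z  = sum(φ(k); dims=2)              # (d, 1): normalizer term
out = (kv * φ(q)) ./ (z' * φ(q))    # (d, n): O(n·d²) instead of O(n²·d)
```

Within the package, the same interface applies: LinearAttention() can be passed directly to compute_attention or used as the attention_impl of MultiHeadAttention.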
Attention.MultiHeadAttention (Type)
MultiHeadAttention(d_model, nheads=8; bias=false, dropout_prob=0.0, attention_impl=DotProductAttention(), q_transform=identity, k_transform=identity)

The multi-head dot-product attention layer used in Transformer architectures.

Returns the transformed input sequence and the attention scores.

Arguments

  • d_model: the embedding dimension.
  • nheads: the number of heads. Default: 8.
  • bias: whether the pointwise QKVO dense transforms use a bias. Default: false.
  • dropout_prob: dropout probability for the attention scores. Default: 0.0.
  • attention_impl: the attention implementation to use. Default: DotProductAttention().
  • q_transform: a function applied to the query tensor after projection. Default: identity.
  • k_transform: a function applied to the key tensor after projection. Default: identity.
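Putting the keyword arguments together, a hedged construction sketch: rope_like is a hypothetical stand-in for a real rotary-embedding transform, and the masked self-attention call on the last line is assumed from the Flux convention.

```julia
using Attention

d_model, nheads = 64, 8

# Hypothetical placeholder for a RoPE-style transform applied after the
# Q/K projections; a real implementation would rotate features by position.
rope_like(x) = x

mha = MultiHeadAttention(d_model, nheads;
                         bias=false,
                         dropout_prob=0.1,
                         attention_impl=NNlibAttention(),
                         q_transform=rope_like,
                         k_transform=rope_like)

x = randn(Float32, d_model, 16, 2)
mask = make_causal_mask(x)        # (16, 16) causal mask
y, scores = mha(x; mask=mask)     # assumed Flux-style call with a mask keyword
```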

Utilities

Attention.make_causal_mask (Function)
make_causal_mask(x::AbstractArray, dims::Int=2)

Create a causal mask for a sequence whose length is derived from x. The mask ensures that position i can only attend to positions j ≤ i.

Arguments

  • x: Input array from which sequence length is derived
  • dims: Dimension along which to derive the sequence length (default: 2)

Returns

  • A boolean mask matrix of shape (seqlen, seqlen) where true indicates allowed attention and false indicates masked (disallowed) attention.
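A small standalone check of the mask; the shape conventions follow the docstring above.

```julia
using Attention

x = randn(Float32, 8, 4, 2)   # (features, seq_len, batch); seq_len read from dims=2
m = make_causal_mask(x)       # 4×4 Bool matrix

count(m)  # 10 allowed (query, key) pairs for seq_len = 4: 4 + 3 + 2 + 1

# Pass the mask through to an attention call, e.g.
# compute_attention(DotProductAttention(), q, k, v; mask=m)
```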