PositionalEmbeddings
Documentation for PositionalEmbeddings.
The PositionalEmbeddings package provides implementations of positional embeddings for encoding sequential position information into feature vectors. This encoding is essential for models where the order of sequence elements must be preserved during processing.
The package implements two foundational approaches to positional encoding:
Rotary Position Embeddings (RoPE) encode positions by rotating vectors in 2D subspaces, enabling explicit relative position modeling through geometric transformations.
Absolute Positional Embeddings (AbsolutePE) create unique position markers using sinusoidal functions, following the original approach from "Attention Is All You Need."
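Both types are constructed directly from their dimensions; a quick preview (the sizes below are illustrative, see the API reference for details):
using PositionalEmbeddings
rope = RoPE(64, 1024)          # rotary embeddings for 64-dimensional heads, up to 1024 positions
abs_pe = AbsolutePE(512, 1024) # sinusoidal embeddings for 512-dimensional features, up to 1024 positions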
API Reference
PositionalEmbeddings.AbsolutePE
— Type
AbsolutePE{T<:AbstractArray}
AbsolutePE(embedding_size::Int, max_length::Int; base::Number=10_000)
Absolute Position Embeddings using sinusoidal frequencies from the "Attention Is All You Need" paper.
Formula:
PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
Fields
embedding_size::Int: Size of the embedding dimension (d_model)
max_length::Int: Maximum sequence length supported
embeddings::T: Positional embeddings
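As a standalone sketch of the formula above (illustrative only, not the package internals; positions are counted from 0 here):
# Build a small sinusoidal table directly from the formula
d_model, max_len, base = 8, 4, 10_000
pe = zeros(Float32, max_len, d_model)
for pos in 1:max_len, i in 0:(d_model ÷ 2 - 1)
    angle = (pos - 1) / base^(2i / d_model)
    pe[pos, 2i + 1] = sin(angle)   # even dimension, 0-based index 2i
    pe[pos, 2i + 2] = cos(angle)   # odd dimension, 0-based index 2i+1
end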
PositionalEmbeddings.RoPE
— Type
RoPE(head_size::Int, seq_len::Int;
     base::Number=10_000,
     scale::Number=1.0)
Rotary Position Embeddings (RoPE) implementation as described in the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding".
Construct a RoPE object with the following arguments:
head_size::Int: Head size to apply rotation to (must be a multiple of 2)
seq_len::Int: Maximum sequence length to support
base::Number=10_000: Base for geometric progression of frequencies
scale::Number=1.0: Scaling factor for the frequencies
Examples
# Create RoPE for a model with 512 head size and max sequence length of 1024
rope = RoPE(512, 1024)
# Apply RoPE to input tensor of shape (head_size, seq_len, nheads*batch_size)
Q = randn(Float32, 512, 100, 32)
Q_positioned = rope(Q)
PositionalEmbeddings.RoPE
— Method
(rope::RoPE)(x) -> AbstractArray
Apply Rotary Position Embeddings to the input array x of shape (head_size, seq_len, batch * num_heads).
Arguments
x: Input array whose first dimension must match rope.head_size and whose second dimension must not exceed the maximum cached sequence length.
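For example (shapes here are illustrative):
rope = RoPE(64, 1024)
x = randn(Float32, 64, 100, 16)   # (head_size, seq_len, batch * num_heads)
y = rope(x)                       # same shape as x, with positions encoded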
See also: RoPE
PositionalEmbeddings.compute_frequencies
— Function
compute_frequencies(dim::Int, seq_len::Int, base::Number=10_000)
Compute frequency bands for rotary position embeddings.
Arguments
dim::Int: Number of dimensions for the frequency bands
seq_len::Int: Maximum sequence length to compute frequencies for
base::Number=10_000: Base for geometric progression of frequencies
Returns
- Matrix of shape (dim, seq_len) containing frequency values
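For example (only the shape of the result is shown):
freqs = compute_frequencies(8, 16)   # 8 frequency dimensions for 16 positions
size(freqs)                          # (8, 16)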
PositionalEmbeddings.create_causal_mask
— Method
create_causal_mask(seq_len::Int)
Create a causal (autoregressive) attention mask that prevents positions from attending to future positions. This is commonly used in language models to ensure predictions only depend on previous tokens.
The mask ensures that position i can only attend to positions j ≤ i, creating a triangular pattern where the upper triangle including diagonal is masked (True) and the lower triangle is unmasked (False).
Arguments
seq_len::Int: Length of the sequence to create the mask for
Returns
- 3D boolean array of shape (seq_len, seq_len, 1) where True indicates positions to mask
Examples
julia> mask = create_causal_mask(3)[:,:,1]
3×3 Matrix{Bool}:
1 1 1 # First position can't attend anything
0 1 1 # Second position can attend to first only
0 0 1 # Third position can attend to first and second
PositionalEmbeddings.create_padding_mask
— Method
create_padding_mask(lengths::Vector{Int}, max_len::Int)
Create padding masks for batched sequences of varying lengths. This ensures that padded positions (positions beyond each sequence's actual length) are masked out and don't participate in attention.
Arguments
lengths::Vector{Int}: Actual length of each sequence in the batch
max_len::Int: Maximum sequence length (padded length)
Returns
- 3D boolean array of shape (batch_size, max_len, 1) where True indicates padded positions
Examples
# For 2 sequences of lengths 2 and 3, padded to length 4:
julia> mask = create_padding_mask([2, 3], 4)[:,:,1]
2×4 Matrix{Bool}:
0 0 1 1 # First sequence: length 2, positions 3-4 are padding
0 0 0 1 # Second sequence: length 3, position 4 is padding
Usage with Causal Mask
Padding and causal masks are often combined for batched autoregressive tasks:
seq_len = 5
batch_lengths = [3, 4]
# Create both masks
causal = create_causal_mask(seq_len) # Shape: (5, 5, 1)
padding = create_padding_mask(batch_lengths, seq_len) # Shape: (2, 5, 1)
# Combine the masks; permute the padding mask so its position axis lines up with
# the causal mask's first axis and its batch axis becomes the third (broadcast) axis
combined = causal .| permutedims(padding, (2, 3, 1))   # Shape: (5, 5, 2)
# combined will prevent:
# 1. Attending to future tokens (from causal mask)
# 2. Attending to padding tokens (from padding mask)
PositionalEmbeddings.neg_half
— Function
neg_half(x::AbstractArray, dim::Int=1)
Helper function that negates the second half of the array along dimension dim. This implementation uses the half-negated layout instead of interleaving pairs, as in LLaMA (see https://github.com/huggingface/transformers/issues/25199).
Arguments
x::AbstractArray: Input array
dim::Int=1: Dimension along which to perform the operation
Returns
- Array with second half negated along specified dimension
Usage Examples
# Create RoPE for head dimension of 64 and maximum sequence length of 1024
rope = RoPE(64, 1024)
# Apply to input tensor of shape (head_size, seq_len, nheads*batch)
# For example, with 64-dim heads, sequence length 100, 8 heads × 32 batch size:
x = randn(Float32, 64, 100, 256) # 256 = 8 heads × 32 batch
x_with_positions = rope(x)
Input tensors for RoPE must follow the shape (head_size, seq_len, nheads*batch). The head_size parameter must be even, seq_len represents your sequence length, and the final dimension combines the number of attention heads and the batch size.
The RoPE constructor accepts several parameters:
function RoPE(head_size::Int, seq_len::Int;
base::Number=10_000,
scale::Number=1.0,
T::Type=Float32)
The base parameter controls frequency bands for position encoding, with higher values creating slower-changing position representations. The scale parameter allows adjusting the positional encoding's influence.
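For longer contexts, a larger base produces slower-changing position representations (the values below are illustrative only):
rope_long = RoPE(64, 4096; base=500_000)   # stretched frequency bands for longer sequences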
Absolute Positional Embeddings
AbsolutePE implements fixed positional patterns through sinusoidal encoding:
# Create embeddings for 512-dimensional features up to length 1000
pe = AbsolutePE(512, 1000)
# Apply to input tensor of shape (seq_len, features, batch)
x = randn(Float32, 100, 512, 32)
x_with_positions = pe(x)
For AbsolutePE, tensors require the shape (seq_len, features, batch), where features matches your model's dimension and seq_len represents the sequence length.
The AbsolutePE constructor allows customization through:
function AbsolutePE(embedding_size::Int, max_length::Int; base::Number=10_000)
The base parameter influences the wavelength pattern of sinusoidal embeddings, with each dimension using a different frequency derived from this base value.
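For example, a smaller base shortens the wavelengths of the sinusoidal pattern (the value below is illustrative only):
pe_short = AbsolutePE(256, 2000; base=5_000)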
Flux Integration Example
This example shows how to integrate RoPE into a custom multi-head attention layer, RoPEMultiHeadAttention, that applies rotary embeddings to the query and key projections. Here's the complete implementation:
using PositionalEmbeddings
using LinearAlgebra
using WeightInitializers
using Functors
using NNlib
struct RoPEMultiHeadAttention{T<:AbstractFloat, A<:AbstractArray{T, 2}}
Wq::A
Wk::A
Wv::A
Wo::A
num_heads::Int
head_dim::Int
scale::T
rope::RoPE
end
function RoPEMultiHeadAttention(d_model::Int, num_heads::Int; maxlen=1000)
head_dim = d_model ÷ num_heads
@assert head_dim * num_heads == d_model "d_model ($d_model) must be divisible by num_heads ($num_heads)"
Wq = kaiming_normal(d_model, d_model)
Wk = kaiming_normal(d_model, d_model)
Wv = kaiming_normal(d_model, d_model)
Wo = kaiming_normal(d_model, d_model)
scale = Float32(sqrt(head_dim))
rope = RoPE(head_dim, maxlen)
RoPEMultiHeadAttention(Wq, Wk, Wv, Wo, num_heads, head_dim, scale, rope)
end
# Split: (d_model, seqlen, batch) -> (head_dim, seqlen, num_heads * batch)
function split_heads(x::AbstractArray, head_dim::Int, num_heads::Int)
d_model, seqlen, batch = size(x)
return reshape(permutedims(reshape(x, head_dim, num_heads, seqlen, batch), (1, 3, 2, 4)),
head_dim, seqlen, num_heads * batch)
end
# Join: (head_dim, seqlen, num_heads * batch) -> (d_model, seqlen, batch)
function join_heads(x::AbstractArray, head_dim::Int, num_heads::Int, batch_size::Int)
return reshape(permutedims(reshape(x, head_dim, size(x, 2), num_heads, batch_size), (1, 3, 2, 4)),
head_dim * num_heads, size(x, 2), batch_size)
end
function apply_mask(logits, mask)
neginf = typemin(eltype(logits))
ifelse.(mask, logits, neginf)
end
function (mha::RoPEMultiHeadAttention)(x::AbstractArray, mask=nothing)
d_model, seqlen, batch_size = size(x)
# Project and split heads in one go
q = split_heads(reshape(mha.Wq * reshape(x, d_model, :), d_model, seqlen, batch_size),
mha.head_dim, mha.num_heads)
k = split_heads(reshape(mha.Wk * reshape(x, d_model, :), d_model, seqlen, batch_size),
mha.head_dim, mha.num_heads)
v = split_heads(reshape(mha.Wv * reshape(x, d_model, :), d_model, seqlen, batch_size),
mha.head_dim, mha.num_heads)
# Apply RoPE
q = mha.rope(q)
k = mha.rope(k)
# All operations now work with (head_dim, seqlen, num_heads * batch)
attention_scores = NNlib.batched_mul(NNlib.batched_transpose(k), (q ./ mha.scale))
if !isnothing(mask)
    attention_scores = apply_mask(attention_scores, mask)
end
attention_probs = softmax(attention_scores; dims=1)
attention_output = NNlib.batched_mul(v, attention_probs)
# Join heads only at the very end
output = join_heads(attention_output, mha.head_dim, mha.num_heads, batch_size)
return reshape(mha.Wo * reshape(output, d_model, :), d_model, seqlen, batch_size)
end
Functors.@functor RoPEMultiHeadAttention
x = rand(Float32, 512, 20, 32);
mha = RoPEMultiHeadAttention(512, 8)
mha(x)
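The layer also accepts an optional mask; for example, with a causal mask whose length matches the 20-position input above:
mask = create_causal_mask(20)
mha(x, mask)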