Attention, Transformers, and Decoding
Attention stack
Grilly provides both module-level and backend-level attention:
Module API: nn.MultiheadAttention, nn.FlashAttention2
Backend API: backend.attention.* and backend.attention.flash_attention2(…)
Use module attention when building end-to-end model graphs. Use backend attention when calling kernels directly, for example in benchmarks.
Why Flash Attention 2 matters
Flash-style attention reduces memory pressure for long sequences by computing attention in tiled blocks rather than materializing the full sequence-by-sequence score matrix, whose size grows quadratically with sequence length.
In Grilly this can improve throughput and reduce the risk of out-of-memory (OOM) errors at practical sequence lengths, especially on consumer GPUs.
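The tiling idea can be illustrated in plain NumPy. The sketch below is not Grilly's kernel: the names naive_attention and tiled_attention and the block_size parameter are hypothetical, it handles a single head, and it tiles only the keys and values (Flash Attention 2 also tiles the queries and runs on the GPU). It shows how a running row maximum and row sum let the softmax be accumulated block by block, so only a (seq_len, block_size) slice of scores exists at any time.

```python
import numpy as np

def naive_attention(q, k, v):
    # Reference version: builds the full (seq_len, seq_len) score matrix.
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, block_size=16):
    # Hypothetical illustration: process keys/values one block at a time.
    seq_len, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)
    out = np.zeros((seq_len, v.shape[-1]))
    row_max = np.full(seq_len, -np.inf)   # running max of scores per query row
    row_sum = np.zeros(seq_len)           # running softmax denominator per row

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale    # only (seq_len, block) at a time

        new_max = np.maximum(row_max, scores.max(axis=-1))
        # Rescale what was accumulated under the old maximum.
        correction = np.exp(row_max - new_max)
        p = np.exp(scores - new_max[:, None])

        row_sum = row_sum * correction + p.sum(axis=-1)
        out = out * correction[:, None] + p @ v_blk
        row_max = new_max

    return out / row_sum[:, None]

rng = np.random.default_rng(0)
q = rng.standard_normal((40, 8))
k = rng.standard_normal((40, 8))
v = rng.standard_normal((40, 8))
max_diff = np.abs(tiled_attention(q, k, v) - naive_attention(q, k, v)).max()
print("max abs difference:", max_diff)
```

Both functions compute the same result up to floating-point rounding; the tiled version trades a little extra arithmetic for a much smaller peak working set.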
Transformer-facing modules
grilly.nn also includes transformer-oriented components:
TransformerEncoderLayer
TransformerDecoderLayer
RoPE
ProsodyModulatedAttention
decoding modules (GreedyDecoder, SampleDecoder)
Basic attention example
import numpy as np
import grilly.nn as nn
attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
x = np.random.randn(4, 32, 256).astype(np.float32)
# Self-attention: the same array is passed as query, key, and value.
out, weights = attn(query=x, key=x, value=x)
print(out.shape, weights.shape)
Backend flash attention example
import numpy as np
import grilly
backend = grilly.Compute()
q = np.random.randn(2, 8, 64, 64).astype(np.float32)
k = np.random.randn(2, 8, 64, 64).astype(np.float32)
v = np.random.randn(2, 8, 64, 64).astype(np.float32)
# Call the kernel directly on the backend, without building a module.
y = backend.attention.flash_attention2(q, k, v)
print(y.shape)
Decoding usage
Decoder modules convert logits into token choices:
greedy decoding for deterministic paths
sampled decoding for stochastic generation
Place them after the transformer's output projection head, which produces the logits.
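The difference between the two strategies can be sketched in plain NumPy. This is not the implementation of GreedyDecoder or SampleDecoder, whose signatures are not shown here; the functions greedy_decode and sample_decode and the temperature parameter are hypothetical names used only for illustration. The input is assumed to be a (batch, vocab_size) array of logits for the next token.

```python
import numpy as np

def greedy_decode(logits):
    # Deterministic: always pick the highest-scoring token.
    return np.argmax(logits, axis=-1)

def sample_decode(logits, temperature=1.0, rng=None):
    # Stochastic: draw a token from the softmax distribution.
    # Lower temperature sharpens the distribution, higher flattens it.
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    vocab_size = probs.shape[-1]
    return np.array([rng.choice(vocab_size, p=row) for row in probs])

# Two sequences, vocabulary of three tokens.
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 3.0, 0.2]])

greedy_tokens = greedy_decode(logits)
sampled_tokens = sample_decode(logits, temperature=0.8,
                               rng=np.random.default_rng(0))
print(greedy_tokens, sampled_tokens)
```

Greedy decoding returns the same tokens on every call, while sampled decoding varies with the random generator's state; passing a seeded generator makes a sampled run reproducible.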