Canon Layers

K = 4: the short causal window used in the main Canon implementation
DK: depthwise parameter cost, instead of the full-convolution cost $D^2K$
A/B/C/D: four insertion points (before attention, inside attention, before the MLP, inside the MLP)

Local causal mixer

A Canon layer adds a small weighted mixture of nearby past token states to the current token.

mixed = 0.20*0.25 + 0.30*0.50 + 0.40*0.75 + 0.10*1.00 = 0.600
output = 0.600 + residual(1.00) = 1.600
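The same arithmetic can be reproduced in a few lines of plain Python. This is a minimal single-channel sketch. The helper name `canon_at` is hypothetical, and the weights are the illustrative numbers above, listed by offset $r$ with the current token first.

```python
def canon_at(xs, weights, t):
    """Residual causal mix for one channel at position t.

    weights[r] multiplies xs[t - r] (r = 0 is the current token);
    positions before the start of the sequence count as zero.
    """
    mixed = sum(w * xs[t - r] for r, w in enumerate(weights) if t - r >= 0)
    return mixed, xs[t] + mixed


# Inputs ordered oldest -> current; weights ordered by offset r = 0, 1, 2, 3.
xs = [0.25, 0.50, 0.75, 1.00]
weights = [0.10, 0.40, 0.30, 0.20]
mixed, out = canon_at(xs, weights, t=3)
print(round(mixed, 3), round(out, 3))  # 0.6 1.6
```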

Depthwise vs full local convolution

Canon uses a separate short causal filter per channel. Full convolution would also mix channels and is much more expensive.

depthwise = D*K = 4,096*4 = 16,384
full = D^2*K = 4,096^2*4 = 67,108,864
ratio = full/depthwise = 4,096x
Canon cost: $O(BTDK)$

Canon Layers are a small architectural primitive from Zeyuan Allen-Zhu’s Physics of Language Models: Part 4.1. The basic idea is simple: give every token a cheap causal path to nearby past token states.

That path is not meant to replace attention. It handles a different job.

Attention should spend capacity on content-addressed routing and retrieval. It should not have to spend layers on routine neighbor-to-neighbor transport.

The mechanism is a residual causal depthwise convolution over the sequence axis. It is local, cheap, and easy to insert into existing Transformer, linear-attention, and SSM-style blocks.

1. The missing path in a standard Transformer

A pre-norm Transformer block usually has:

\[x^{(\ell+\frac12)} = x^{(\ell)} + \operatorname{Attn}\left(\operatorname{Norm}(x^{(\ell)})\right),\] \[x^{(\ell+1)} = x^{(\ell+\frac12)} + \operatorname{MLP}\left(\operatorname{Norm}(x^{(\ell+\frac12)})\right).\]

This gives two strong paths:

  • a vertical residual path, where token position $t$ preserves and refines its own representation across layers;
  • a global attention path, where token position $t$ can retrieve content from previous positions.

But the MLP is pointwise over tokens. It mixes channels, not positions. Attention can move information from $t-1$ to $t$, but attention is a global content-routing mechanism. Using it for routine local relay is expensive and depth-inefficient.

Canon adds a third path:

\[\text{nearby causal context} \quad\rightarrow\quad \text{current token state}.\]

That is why the paper describes Canon as horizontal information flow. The ordinary residual stream is vertical across depth; Canon is local residual flow across positions.

Canon as local horizontal residual flow: the current token receives a small learned mixture of nearby causal states.

2. Associative recall shows the problem

Consider the causal sequence:

[A] [B] ... [A] [?]

The desired next token is [B]. A natural mechanism is:

  1. the second [A] attends to the first [A];
  2. the representation at the first [A] carries enough information to identify the following [B];
  3. the model predicts [B].

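The catch is causal masking. The first [A] can never see its future neighbor [B], at any layer, so step 2 cannot happen as written. Information only flows forward. The model therefore needs two operations. The first is local: it writes "my predecessor is [A]" from the first [A] into the representation at [B]. The second is global: the second [A] attends to the position marked as following an [A] and reads out [B].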

Canon makes that first local enrichment cheap.
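A symbolic toy in plain Python makes the two steps concrete. It is not a neural network, and the function names are hypothetical. One local step lets each position record its predecessor, which is information that may only flow forward (from [A] into the slot of [B]). One global step lets the final token look up the position whose predecessor matches it.

```python
def previous_token_features(tokens):
    """Local causal step: position t records the token at t - 1 (None at t = 0).

    This is the one-step neighbor transport that a Canon layer makes cheap.
    """
    return [None] + tokens[:-1]


def associative_recall(tokens):
    """Global step: the final token looks for an earlier position whose
    predecessor equals itself, then reads out the token stored there."""
    query = tokens[-1]
    prev = previous_token_features(tokens)
    # Scan earlier positions only (causal); keep the most recent match.
    matches = [t for t in range(len(tokens) - 1) if prev[t] == query]
    return tokens[matches[-1]] if matches else None


print(associative_recall(["A", "B", "C", "D", "A"]))  # B
```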

3. The Canon operator

Let a sequence of hidden states be:

\[H=(h_1,\ldots,h_T), \qquad h_t\in\mathbb{R}^{m}.\]

A width-4 Canon layer computes:

\[\widetilde h_t = w_0\odot h_t +w_1\odot h_{t-1} +w_2\odot h_{t-2} +w_3\odot h_{t-3},\]

where $w_r\in\mathbb{R}^{m}$ are learned channelwise weights, $\odot$ is elementwise multiplication, and missing past states are zero-padded.

The residual form is:

\[h'_t = h_t + \operatorname{Conv1D}_{\mathrm{causal},K=4} \left(h_t,h_{t-1},h_{t-2},h_{t-3}\right).\]

Equivalently, for batch index $b$, position $t$, channel $c$, and kernel size $K$:

\[y_{b,t,c} = x_{b,t,c} + \sum_{r=0}^{K-1} a_{c,r}\,x_{b,t-r,c},\]

with $x_{b,t-r,c}=0$ whenever position $t-r$ falls before the start of the sequence (zero padding).

The key word is depthwise. Channel $c$ reads only channel $c$ over nearby positions. Canon does not perform hidden-dimension mixing; the projections and MLP still own that job.
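Written straight from the formula, the operator is a triple loop in which channel $c$ only ever reads channel $c$. The plain-Python reference below uses the hypothetical name `canon_reference`. It is far too slow for training, but it can serve as a ground truth when testing a vectorized version.

```python
def canon_reference(x, a):
    """Residual causal depthwise convolution, written straight from the formula.

    x: nested list [B][T][C] of numbers.
    a: nested list [C][K]; a[c][r] multiplies x[b][t - r][c] (r = 0 is the current token).
    Returns y with y[b][t][c] = x[b][t][c] + sum_r a[c][r] * x[b][t - r][c],
    treating positions before the sequence start as zero.
    """
    y = []
    for seq in x:
        out_seq = []
        for t, row in enumerate(seq):
            out_row = []
            for c, value in enumerate(row):
                # Channel c reads only channel c at nearby past positions.
                mixed = sum(a[c][r] * seq[t - r][c] for r in range(len(a[c])) if t - r >= 0)
                out_row.append(value + mixed)
            out_seq.append(out_row)
        y.append(out_seq)
    return y


# K = 2: channel 0 copies its previous value, channel 1 doubles its current value.
x = [[[1, 10], [2, 20], [3, 30]]]
a = [[0, 1], [1, 0]]
print(canon_reference(x, a))  # [[[1, 20], [3, 40], [5, 60]]]
```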

4. Why the residual matters

Without the residual path:

\[h'_t=\operatorname{Canon}(H)_t.\]

With the residual path:

\[h'_t=h_t+\operatorname{Canon}(H)_t.\]

The residual version is easier to insert because it starts as a local perturbation around the existing representation. If the local signal is useful, the model can add it. If it is not useful, the model can learn small weights without destroying the vertical residual stream.
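With all-zero weights, the residual form returns its input unchanged, while the non-residual form wipes the token state out. The plain-Python sketch below shows this for a single channel. The helper names are hypothetical.

```python
def causal_mix(xs, weights):
    """Non-residual Canon for one channel: sum_r weights[r] * xs[t - r], zero-padded."""
    return [
        sum(w * xs[t - r] for r, w in enumerate(weights) if t - r >= 0)
        for t in range(len(xs))
    ]


def canon(xs, weights, residual=True):
    mixed = causal_mix(xs, weights)
    return [x + m for x, m in zip(xs, mixed)] if residual else mixed


xs = [0.5, -1.0, 2.0, 3.0]
zero = [0.0, 0.0, 0.0, 0.0]
print(canon(xs, zero, residual=True))   # [0.5, -1.0, 2.0, 3.0]
print(canon(xs, zero, residual=False))  # [0.0, 0.0, 0.0, 0.0]
```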

The Canon paper’s ablations report that residual Canon is materially more stable and efficient than non-residual variants. The implementation also exposes canon_residual as a configuration flag, with the released LlamaCanon path defaulting to residual behavior.

5. Canon is not local attention

Local attention computes content-dependent weights:

\[y_t = \sum_{j=t-w}^{t}\alpha_{t,j}v_j, \qquad \alpha_{t,j} = \operatorname{softmax}_j \left( \frac{q_t^\top k_j}{\sqrt{d_h}} \right).\]

Canon computes fixed learned local propagation:

\[y_t = x_t + \sum_{r=0}^{K-1}a_r\odot x_{t-r}.\]

The distinction matters:

| Mechanism | Main job | Weights depend on content? | Scope |
|---|---|---|---|
| Full attention | global retrieval and routing | yes | all past tokens |
| Local attention | adaptive local retrieval | yes | local window |
| Canon | cheap causal transport | no, in the studied version | tiny causal window |
| MLP | channel transformation | no token mixing | one token |


Canon is closer to a short learned transport operator than to a retrieval mechanism.

6. Where Canon goes in a Transformer block

The paper studies four insertion points. For hidden width $d$, Canon-ABCD means:

Canon-A/B/C/D insertion points in a pre-norm Transformer block.

Canon-A: after attention RMSNorm, before the Q/K/V projections; width $m=d$
Canon-B: after the Q/K/V projections, on the concatenated projected representation; width $m=n_qd_h+2n_{kv}d_h$ for GQA
Canon-C: after MLP RMSNorm, before the MLP projections; width $m=d$
Canon-D: inside the MLP, before the activation; for gated MLPs it acts on the concatenated gate/up branches

Canon-A: before attention

\[u = \operatorname{Canon}_A(\operatorname{Norm}(x)).\]

Then:

\[q=W_qu, \qquad k=W_ku, \qquad v=W_vu.\]

Attention receives token states that already contain a short causal neighborhood.

Canon-B: inside attention

After Q/K/V projection:

\[z_t=[q_t;k_t;v_t].\]

Canon-B applies local mixing to that projected representation:

\[z'_t=\operatorname{Canon}_B(z)_t, \qquad [q'_t;k'_t;v'_t]=z'_t.\]

For ordinary MHA with equal Q/K/V widths, $m=3d$. For grouped-query attention:

\[m=n_qd_h+2n_{kv}d_h.\]

The released LlamaCanon code computes exactly this total dimension before constructing canonB.

Canon-C: before the MLP

\[r = \operatorname{Canon}_C(\operatorname{Norm}(x^{(\ell+\frac12)})).\]

The MLP receives a locally enriched representation.

Canon-D: inside the MLP

For a gated MLP:

\[\operatorname{MLP}(r) = W_{\mathrm{down}} \left( \phi(W_{\mathrm{gate}}r) \odot W_{\mathrm{up}}r \right).\]

LlamaCanon concatenates the gate and up projections:

\[z_t=[g_t;u_t], \qquad z'_t=\operatorname{Canon}_D(z)_t, \qquad [g'_t;u'_t]=z'_t,\]

then computes:

\[W_{\mathrm{down}}\left(\phi(g'_t)\odot u'_t\right).\]

For a Llama-style gated MLP with intermediate width $\frac{8}{3}d$, Canon-D has width:

\[m=2\cdot\frac{8}{3}d=\frac{16}{3}d.\]
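The four widths, and the parameter cost they imply at $K=4$, can be tallied in a few lines of Python. The configuration below is hypothetical, using Llama-style sizes chosen for illustration with $d_{ff}=11008$ standing in for a rounded $\frac{8}{3}d$. The helper name is also hypothetical.

```python
def canon_widths(d, n_q, n_kv, d_h, d_ff):
    """Channel width m at each Canon insertion point in a GQA + gated-MLP block."""
    return {
        "A": d,                           # after attention RMSNorm
        "B": n_q * d_h + 2 * n_kv * d_h,  # concatenated [q; k; v]
        "C": d,                           # after MLP RMSNorm
        "D": 2 * d_ff,                    # concatenated [gate; up]
    }


# Hypothetical Llama-style config (not taken from the paper or a released checkpoint).
widths = canon_widths(d=4096, n_q=32, n_kv=8, d_h=128, d_ff=11008)
K = 4
params_per_layer = sum(m * K for m in widths.values())  # depthwise, no bias: m * K each
print(widths)            # {'A': 4096, 'B': 6144, 'C': 4096, 'D': 22016}
print(params_per_layer)  # 145408
```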

7. Canon-ABCD pseudocode

The same residual local mixer appears at different internal representations:

def canon_residual(x, canon_conv):
    # x: [batch, seq, channels]
    # canon_conv: causal depthwise Conv1d over the sequence dimension
    return x + canon_conv(x)

For a Llama-style pre-norm block with grouped-query attention and a gated MLP:

def llama_block_with_canon(x, mask=None, cache=None):
    # x: [B, T, d]

    residual = x
    h = rmsnorm_attn(x)

    if canonA is not None:
        h = canon_residual(h, canonA)          # [B, T, d]

    q = q_proj(h)
    k = k_proj(h)
    v = v_proj(h)

    if canonB is not None:
        qkv = concat([q, k, v], dim=-1)
        qkv = canon_residual(qkv, canonB)
        q, k, v = split(qkv, [q_dim, k_dim, v_dim], dim=-1)

    q = apply_rope(q)
    k = apply_rope(k)
    a = causal_attention(q, k, v, mask=mask, cache=cache)
    x = residual + o_proj(a)

    residual = x
    h = rmsnorm_mlp(x)

    if canonC is not None:
        h = canon_residual(h, canonC)          # [B, T, d]

    gate = gate_proj(h)
    up = up_proj(h)

    if canonD is not None:
        z = concat([gate, up], dim=-1)
        z = canon_residual(z, canonD)
        gate, up = z.chunk(2, dim=-1)

    x = residual + down_proj(silu(gate) * up)
    return x

Partial variants such as Canon-AC, Canon-ACD, or Canon-ABC are also meaningful. The paper’s ablations find that the benefits are cumulative, and that Canon-ACD can help even without modifying the attention projections.
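For readers who want something executable, here is a self-contained PyTorch sketch with the same structure as the pseudocode. It is a toy and not the released LlamaCanon code. The class names, default sizes, fused `qkv_proj`/`gate_up_proj` layers, and the hand-written RMSNorm are all assumptions. RoPE, padding masks, and decode caches are left out, so it behaves like a NoPE block. It needs `F.scaled_dot_product_attention`, which requires PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalDepthwiseResidual(nn.Module):
    """y = x + causal depthwise Conv1d(x) over the sequence axis; x is [B, T, C]."""

    def __init__(self, channels: int, kernel_size: int = 4):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, bias=False)

    def forward(self, x):
        xt = F.pad(x.transpose(1, 2), (self.kernel_size - 1, 0))  # left pad only: causal
        return x + self.conv(xt).transpose(1, 2)


class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class CanonBlock(nn.Module):
    """Toy pre-norm block: GQA attention + gated MLP with optional Canon-A/B/C/D."""

    def __init__(self, d=64, n_q=4, n_kv=2, d_ff=172, canon_set="ABCD", kernel_size=4):
        super().__init__()
        self.n_q, self.n_kv, self.d_h = n_q, n_kv, d // n_q
        self.q_dim, self.kv_dim = n_q * self.d_h, n_kv * self.d_h
        self.norm_attn, self.norm_mlp = RMSNorm(d), RMSNorm(d)
        # Fused projections: equivalent to separate q/k/v and gate/up Linear layers.
        self.qkv_proj = nn.Linear(d, self.q_dim + 2 * self.kv_dim, bias=False)
        self.o_proj = nn.Linear(self.q_dim, d, bias=False)
        self.gate_up_proj = nn.Linear(d, 2 * d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d, bias=False)

        def canon(tag, width):
            if tag in canon_set:
                return CausalDepthwiseResidual(width, kernel_size)
            return nn.Identity()

        self.canonA = canon("A", d)
        self.canonB = canon("B", self.q_dim + 2 * self.kv_dim)
        self.canonC = canon("C", d)
        self.canonD = canon("D", 2 * d_ff)

    def forward(self, x):  # x: [B, T, d]
        B, T, _ = x.shape
        h = self.canonA(self.norm_attn(x))
        qkv = self.canonB(self.qkv_proj(h))  # Canon-B on the concatenated [q; k; v]
        q, k, v = torch.split(qkv, [self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        q = q.reshape(B, T, self.n_q, self.d_h).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv, self.d_h).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv, self.d_h).transpose(1, 2)
        rep = self.n_q // self.n_kv  # GQA: each KV head serves `rep` query heads
        k, v = k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
        a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o_proj(a.transpose(1, 2).reshape(B, T, self.q_dim))

        h = self.canonC(self.norm_mlp(x))
        gate, up = self.canonD(self.gate_up_proj(h)).chunk(2, dim=-1)  # Canon-D on [gate; up]
        return x + self.down_proj(F.silu(gate) * up)


torch.manual_seed(0)
block = CanonBlock()
x = torch.randn(2, 10, 64)
print(block(x).shape)  # torch.Size([2, 10, 64])
```

Because every Canon conv is left-padded and attention is causal, changing a later token should never change an earlier output. That is a cheap invariant to test when wiring Canon into a real block.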

8. Tensor shapes for the core mixer

The minimal PyTorch version for a [B,T,D] tensor is:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CanonResidualMixer(nn.Module):
    def __init__(self, channels: int, kernel_size: int = 4):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            groups=channels,
            bias=False,
        )

    def forward(self, x):
        # x: [B, T, D]
        xt = x.transpose(1, 2)                    # [B, D, T]
        xt = F.pad(xt, (self.kernel_size - 1, 0)) # [B, D, T + K - 1]
        mixed = self.conv(xt)                     # [B, D, T]
        mixed = mixed.transpose(1, 2)             # [B, T, D]
        return x + mixed

Shape summary:

| Step | Shape | Meaning |
|---|---|---|
| input x | [B,T,D] | Transformer layout |
| transpose | [B,D,T] | Conv1d layout |
| left pad by K-1 | [B,D,T+K-1] | causal boundary handling |
| depthwise conv | [B,D,T] | local sequence mixing |
| transpose back | [B,T,D] | Transformer layout |
| residual add | [B,T,D] | unchanged external shape |

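For $K=4$, one channel computes the expression below. Here $a_{c,k}$ denotes the raw PyTorch kernel index $k$ rather than the offset $r$ used earlier: index $k$ multiplies the token at offset $r=K-1-k$, so $a_{c,3}$ weights the current token.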

\[\operatorname{mixed}_{b,t,c} = a_{c,0}x_{b,t-3,c} +a_{c,1}x_{b,t-2,c} +a_{c,2}x_{b,t-1,c} +a_{c,3}x_{b,t,c}.\]

Then:

\[y_{b,t,c}=x_{b,t,c}+\operatorname{mixed}_{b,t,c}.\]
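A quick way to confirm the index reversal is to feed the kernel weights 1000, 100, 10, 1 through `F.conv1d` on the inputs 1 to 6. Each output then spells out, digit by digit, which past value each kernel index touched. This sanity check uses float64 so that the sums are exact.

```python
import torch
import torch.nn.functional as F

K = 4
x = torch.arange(1.0, 7.0, dtype=torch.float64).view(1, 1, 6)       # [B=1, D=1, T=6]: 1..6
w = torch.tensor([[[1000.0, 100.0, 10.0, 1.0]]], dtype=torch.float64)  # [D, 1, K], index k = 0..3
mixed = F.conv1d(F.pad(x, (K - 1, 0)), w, groups=x.shape[1])           # left pad keeps it causal
print(mixed.flatten().tolist())
# [1.0, 12.0, 123.0, 1234.0, 2345.0, 3456.0]
```

The units digit (index 3) always holds the current token, and the thousands digit (index 0) holds $x_{t-3}$. At $t=0$ only the current token survives, because the left padding supplies zeros.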

9. Why groups=channels matters

With depthwise convolution:

nn.Conv1d(D, D, K, groups=D)

the parameter tensor has shape:

\[[D,1,K],\]

so parameters scale as:

\[DK.\]

With full convolution:

nn.Conv1d(D, D, K, groups=1)

the parameter tensor has shape:

\[[D,D,K],\]

so parameters scale as:

\[D^2K.\]

For $D=4096$ and $K=4$:

\[DK=16{,}384, \qquad D^2K\approx 67\text{ million}.\]

That gap is the reason Canon isolates local sequence transport from channel mixing. Channel mixing remains in the projections and MLP, where it already exists.
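PyTorch's own parameter shapes confirm the arithmetic. This sketch builds both layers on the `meta` device, which records shapes without allocating storage. Otherwise, the full convolution alone would need about 268 MB in float32.

```python
import torch.nn as nn

D, K = 4096, 4
# device="meta": shape-only tensors, no real memory is allocated.
depthwise = nn.Conv1d(D, D, K, groups=D, bias=False, device="meta")
full = nn.Conv1d(D, D, K, groups=1, bias=False, device="meta")

print(tuple(depthwise.weight.shape), depthwise.weight.numel())  # (4096, 1, 4) 16384
print(tuple(full.weight.shape), full.weight.numel())            # (4096, 4096, 4) 67108864
print(full.weight.numel() // depthwise.weight.numel())          # 4096
```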

Depthwise Canon uses one short causal filter per channel. A full local convolution would mix every channel into every output channel.

10. Complexity and runtime

For batch $B$, sequence length $T$, hidden width $D$, and small kernel $K$:

\[\operatorname{cost}_{\mathrm{Canon}} = O(BTDK).\]

The attention matrix/value aggregation term is roughly:

\[\operatorname{cost}_{\mathrm{attention}} = O(BT^2D),\]

plus projection costs.

Asymptotically, Canon is tiny. Practically, it is not free: every additional operator can add memory movement and kernel-launch overhead. The Part 4.1 paper reports that Canon-ABCD adds fewer than $0.45\%$ parameters for GPT-2-small, and for a 1.3B Llama-style model it adds about $0.0063\%$ parameters. The same footnote reports nonzero naive H100 runtime overheads, with Canon-AC cheaper than Canon-ABCD.

The released code uses a ShortConvolution wrapper with causal_conv1d when it is available and when the kernel size is in $\{2,3,4\}$. During generation, the convolution cache stores only the last $K$ states per channel:

\[\text{cache shape}=[B,D,K].\]
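The decode-time cache can be sketched in plain PyTorch, assuming a simple fallback rather than the fused `causal_conv1d` update kernel. The state is a `[B, D, K]` window of the most recent inputs, and each new token shifts it by one. The function names are hypothetical. The check at the end confirms that token-by-token decoding reproduces the full-sequence (prefill) convolution.

```python
import torch
import torch.nn.functional as F


def canon_full(x, w):
    """Prefill path. x: [B, D, T]; w: [D, K] in PyTorch kernel order (last column = current)."""
    K = w.shape[1]
    return x + F.conv1d(F.pad(x, (K - 1, 0)), w.unsqueeze(1), groups=x.shape[1])


def canon_decode_step(cache, x_t, w):
    """Decode path. cache: [B, D, K] holding the last K inputs (oldest first); x_t: [B, D]."""
    cache = torch.cat([cache[:, :, 1:], x_t.unsqueeze(-1)], dim=-1)  # drop oldest, append newest
    y_t = x_t + (cache * w).sum(-1)  # same dot product the convolution computes at this step
    return cache, y_t


torch.manual_seed(0)
B, D, T, K = 2, 3, 7, 4
x, w = torch.randn(B, D, T), torch.randn(D, K)
cache = torch.zeros(B, D, K)  # zeros play the role of the causal left padding
steps = []
for t in range(T):
    cache, y_t = canon_decode_step(cache, x[:, :, t], w)
    steps.append(y_t)
print(torch.allclose(torch.stack(steps, dim=-1), canon_full(x, w), atol=1e-5))  # True
```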

11. The synthetic playground

The paper argues that academic-scale real-data pretraining can be too noisy for architecture science. Perplexity mixes many skills together; benchmark swings can hide whether an architecture improved reasoning, knowledge storage, local composition, or something else.

The Part 4.1 experiments therefore use five controlled synthetic pretraining tasks:

| Task | Capability | Core requirement |
|---|---|---|
| Depo | reasoning depth | follow a directed permutation for $k$ hops |
| Brevo | reasoning breadth | process recursive dependencies in a DAG |
| Capo | knowledge capacity | store synthetic facts in parameters |
| Mano | knowledge manipulation | retrieve learned facts and compute over them |
| Lano | hierarchical structure | learn CFG-like recursive constraints |

The point is not that synthetic tasks are the final benchmark. They isolate mechanisms. If a change improves Depo but not Capo, or helps NoPE but not RoPE, the result is easier to interpret than a single mixed-corpus loss number.

Depo: depth

Depo builds a directed permutation from key-value pairs:

<bos> x1 y1 x2 y2 ... xn yn <query_k> q <ans> a <eos>

If the pairs define $f(x_i)=y_i$, the target is:

\[a=f^{(k)}(q).\]

The model must compute the $k$-hop successor internally, without writing intermediate chain-of-thought tokens. Depo2 makes each node span multiple tokens, so a 4-token Canon window cannot solve the task by direct copying. The local mixer must improve segment representations that attention can later chain globally.
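A toy generator helps make the format concrete. This is a hypothetical simplification, not the paper's generator. Nodes are single tokens (so it resembles Depo rather than Depo2), and the permutation is a single random cycle, which is one simple choice that keeps every hop well defined.

```python
import random


def make_depo_example(n, k, seed=0):
    """Hypothetical single-token Depo-style instance.

    Build a random cyclic permutation f over n node names, list the pairs
    (x, f(x)) in shuffled order, and ask for the k-hop successor f^(k)(q).
    """
    rng = random.Random(seed)
    nodes = [f"n{i}" for i in range(n)]
    order = nodes[:]
    rng.shuffle(order)
    f = {order[i]: order[(i + 1) % n] for i in range(n)}  # one cycle through all nodes

    pairs = list(f.items())
    rng.shuffle(pairs)  # prompt order carries no hint about the chain
    q = rng.choice(nodes)
    a = q
    for _ in range(k):  # the model must perform these k hops internally
        a = f[a]

    tokens = ["<bos>"] + [tok for pair in pairs for tok in pair]
    tokens += [f"<query_{k}>", q, "<ans>", a, "<eos>"]
    return tokens, f


tokens, f = make_depo_example(n=6, k=3, seed=1)
print(" ".join(tokens))
```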

Brevo: breadth

Brevo gives the model a directed acyclic graph and asks for recursive dependencies in topological order. The hard part is not one long chain; it is parallel dependency processing across branches.

Capo: capacity

Capo measures reliable storage of synthetic facts, often as bits per parameter. Limited-exposure regimes are important because overtraining can hide architectural differences.

Mano: manipulation

Mano uses modular arithmetic expressions. The model must retrieve learned operation tables and compose them internally. This tests manipulation of knowledge stored in weights rather than only information present in the prompt.

Lano: structure

Lano uses CFG-like sequences with local ambiguity. Correct prediction can require maintaining recursive global structure rather than memorizing nearby tokens.

12. What the results imply

For Transformer-style models, the Part 4.1 paper reports that Canon-ABCD improves reasoning depth by roughly $2$-$4\times$ in the controlled setup, reasoning breadth by about $30\%$, knowledge manipulation length by about $30\%$, and knowledge capacity in limited-exposure factual-storage regimes.

The most defensible interpretation is not that Canon solves every task inside a four-token window. It is that better local representations make later global routing easier.

The NoPE result is especially interesting. NoPE means no positional embedding. Without positional encoding, a Transformer has weak order information. With Canon, NoPE becomes far stronger, often competitive with RoPE+Canon in the reported synthetic setup. A causal convolution injects order-sensitive local structure:

\[h_t \leftarrow h_t+f(h_t,h_{t-1},h_{t-2},h_{t-3}).\]
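A one-layer toy shows the effect. A single attention read at the last position, with keys and values that carry no positional signal, depends only on the multiset of earlier tokens. Reordering them leaves it unchanged. A Canon output at the same position changes with the order. This is a simplified scalar sketch with hypothetical helper names. In a deep causal network, NoPE can still recover some order implicitly, which is why the text calls its order information weak rather than absent.

```python
import math


def nope_attention_last(values, q=1.0):
    """Attention output at the last position, no positional encoding (keys = values)."""
    scores = [q * v for v in values]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def canon_last(values, a):
    """Residual Canon output at the last position; a[r] multiplies values[-1 - r]."""
    return values[-1] + sum(a_r * values[-1 - r] for r, a_r in enumerate(a) if r < len(values))


seq_1 = [1.0, 2.0, 3.0, 5.0]
seq_2 = [3.0, 1.0, 2.0, 5.0]  # same earlier tokens, different order
a = [0.5, 0.3, 0.2, 0.1]
print(math.isclose(nope_attention_last(seq_1), nope_attention_last(seq_2)))  # True
print(canon_last(seq_1, a), canon_last(seq_2, a))  # the two values differ
```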

The paper also studies partial RoPE. With Canon present, reduced-RoPE variants can work well, which matters because heavy RoPE usage can hurt length generalization.

13. Linear models and SSMs

The paper compares Transformers, GLA, Mamba2, and GDN under the same synthetic tasks. A useful takeaway is that local convolution-like components inside some linear/SSM architectures already explain a lot of their behavior.

In the paper’s terminology:

  • Mamba2’s internal conv1d resembles a partial non-residual Canon-B;
  • GLA and GDN implementations also contain conv-like local components;
  • adding Canon systematically makes comparisons fairer because every model receives the same local-transport primitive.

After adding Canon broadly, linear models still tend to lag full-attention Transformers on deep retrieval-heavy reasoning. The diagnosis is not only state size. The harder problem is memory dynamics: compressed recurrent state must preserve and retrieve fine-grained facts across multiple hops without compounding errors.

The Part 4.2 code release extends the story to real-world pretraining recipes and released model families, including LlamaCanon, GLA, GDN, and Mamba2 variants.

14. Related work

Primer

Primer introduced squared ReLU and a depthwise convolution after Q/K/V projection. The Q/K/V convolution part is closest to Canon-B without the residual path:

\[q'=\operatorname{DWConv}(W_qx), \qquad k'=\operatorname{DWConv}(W_kx), \qquad v'=\operatorname{DWConv}(W_vx).\]

Canon generalizes the idea in three ways:

  1. it adds an explicit residual around the local mixer;
  2. it applies the primitive at A/B/C/D, not only Q/K/V;
  3. it studies the primitive across Transformers, linear attention, and SSM-style models.

Longformer-style local attention

Longformer sparsifies attention with sliding windows and task-specific global attention. Canon works on a different axis.

Local attention asks:

\[\text{which nearby tokens should I retrieve from?}\]

Canon asks:

\[\text{what nearby hidden signal should be cheaply propagated?}\]

They can coexist.

Mamba2

Mamba2 is built around state-space duality and selective SSM computation. Its local convolution is a frontend to a recurrent/SSM memory system:

\[x'_t=\operatorname{Conv1D}(x_{t-K+1:t}), \qquad h_t=A_th_{t-1}+B_tx'_t, \qquad y_t=C_t^\top h_t.\]

Canon isolates the local convolutional part as a reusable residual primitive that can be applied outside a specific SSM block.

Uniform attention

Earlier Physics of Language Models work found that uniform averaging over recent tokens could help CFG-style tasks. Canon can be viewed as a learned, channelwise, modular version of that local averaging:

\[\text{uniform local average} \quad\rightarrow\quad \text{learned channelwise local residual convolution}.\]

15. Implementation details from LlamaCanon

The released LlamaCanon helper uses a ShortConvolution module:

ShortConvolution(
    hidden_size=dim,
    kernel_size=config.canon_kernel,
    bias=config.canon_bias,
    activation="silu" if config.canon_activation else None,
    use_fast_conv1d=causal_conv1d_available and config.canon_kernel in [2, 3, 4],
)

It is dimension-last at the interface:

\[x\in\mathbb{R}^{B\times T\times D},\]

then rearranges internally to Conv1d layout:

\[x\in\mathbb{R}^{B\times D\times T}.\]

The helper masks padded positions, uses the fast causal_conv1d kernel when available, and supports decode-time cache updates through a [B,D,K] state.

The code exposes:

  • canon_set, selecting any subset of A, B, C, D;
  • canon_kernel, usually $4$;
  • canon_residual, controlling whether the output is hidden_states + hidden_states2;
  • canon_activation, available but not recommended by the paper for Transformer Canon layers;
  • canon_bias, generally avoided.

For packed or padded batches, Canon must respect the same valid-token mask as attention. Otherwise, a causal convolution can propagate padding artifacts into valid positions.
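One way to respect the mask is to drop every tap that would read across a document boundary. The sketch below uses an explicit loop over offsets rather than a fused kernel. `seq_ids` is a hypothetical per-token document index, and padding gets its own id so that no valid token reads from it. The check confirms that a document packed after another one produces exactly the output it would produce alone.

```python
import torch


def canon_packed(x, weight, seq_ids):
    """Residual causal depthwise mix that never crosses a packed-document boundary.

    x:       [B, T, D] hidden states.
    weight:  [D, K]; weight[:, r] multiplies x[:, t - r] (r = 0 is the current token).
    seq_ids: [B, T] non-negative document id per position (give padding its own id).
    """
    B, T, D = x.shape
    out = x.clone()
    for r in range(weight.shape[1]):
        r_eff = min(r, T)
        # x shifted right by r along time, with zeros where t - r < 0.
        shifted = torch.cat([x.new_zeros((B, r_eff, D)), x[:, : T - r_eff]], dim=1)
        shifted_ids = torch.cat([seq_ids.new_full((B, r_eff), -1), seq_ids[:, : T - r_eff]], dim=1)
        same_doc = (shifted_ids == seq_ids).unsqueeze(-1).to(x.dtype)  # [B, T, 1]
        out = out + weight[:, r] * shifted * same_doc
    return out


torch.manual_seed(0)
D, K = 3, 4
w = torch.randn(D, K)
doc_a, doc_b = torch.randn(1, 5, D), torch.randn(1, 4, D)
packed = torch.cat([doc_a, doc_b], dim=1)
ids = torch.tensor([[0] * 5 + [1] * 4])
out = canon_packed(packed, w, ids)
alone_b = canon_packed(doc_b, w, torch.ones(1, 4, dtype=torch.long))
print(torch.allclose(out[:, 5:], alone_b, atol=1e-6))  # True
```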

16. Practical choices

Initialization

There are several reasonable options:

  1. Default initialization. This matches the released implementation path.
  2. Zero initialization. This makes Canon an exact identity at step zero:

    \[y=x+0=x.\]

    That is useful when retrofitting Canon into an already trained model.

  3. Past-average initialization. For $K=4$, initialize previous offsets to $\frac13$ and current offset to $0$:

    \[y_t=x_t+\frac13(x_{t-1}+x_{t-2}+x_{t-3}).\]

    This tests the local-context hypothesis directly, but it is a design choice rather than the default released setup.
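Options 2 and 3 amount to a few in-place writes on an `nn.Conv1d` weight. The minimal PyTorch sketch below assumes the left-padded, dimension-first layout from the core mixer. The main pitfall is that offset $r=0$ (the current token) lives at PyTorch kernel index $K-1$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

D, K = 8, 4
conv = nn.Conv1d(D, D, K, groups=D, bias=False)  # weight: [D, 1, K]; index K-1 = current token


def canon(x):  # x: [B, D, T]; residual causal depthwise conv
    return x + conv(F.pad(x, (K - 1, 0)))


torch.manual_seed(0)
x = torch.randn(2, D, 10)

# Option 2: zero init -> exact identity at step zero.
with torch.no_grad():
    nn.init.zeros_(conv.weight)
print(torch.equal(canon(x), x))  # True

# Option 3: past-average init -> y_t = x_t + (x_{t-1} + x_{t-2} + x_{t-3}) / 3.
with torch.no_grad():
    conv.weight.fill_(1.0 / 3.0)
    conv.weight[:, :, K - 1] = 0.0  # current offset r = 0 sits at PyTorch index K - 1
expected = x[:, :, 5] + (x[:, :, 4] + x[:, :, 3] + x[:, :, 2]) / 3
print(torch.allclose(canon(x)[:, :, 5], expected, atol=1e-5))  # True
```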

Causal padding

Use left padding:

F.pad(x, (K - 1, 0))

Right padding would either shift outputs incorrectly or leak future information.
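The leak is easy to demonstrate in plain Python with PyTorch's cross-correlation convention. Change only position 2 and check whether the output at position 1 moves. The helper name is hypothetical.

```python
def conv_1d(xs, w, left_pad):
    """Valid cross-correlation (PyTorch convention) after padding K-1 zeros on one side."""
    K = len(w)
    padded = [0.0] * (K - 1) + xs if left_pad else xs + [0.0] * (K - 1)
    return [sum(w[k] * padded[t + k] for k in range(K)) for t in range(len(xs))]


w = [0.1, 0.2, 0.3, 0.4]
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
xs_future_changed = xs[:2] + [100.0] + xs[3:]  # only position 2 changes

for left in (True, False):
    before = conv_1d(xs, w, left)[1]
    after = conv_1d(xs_future_changed, w, left)[1]
    print("left" if left else "right", before == after)
# left True   (output at t=1 ignores t=2)
# right False (output at t=1 reads t=2: a future leak)
```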

Optimized kernels

torch.nn.Conv1d is the generic API. It is not automatically the optimized Dao-AILab causal-conv1d path. The implementation must call that package explicitly, as LlamaCanon’s helper does.
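A dispatch along these lines is one way to make the fast path explicit. This is a sketch under assumptions. It assumes the Dao-AILab package exposes `causal_conv1d_fn(x, weight, bias=None, activation=None)` with `x` as `[B, D, T]` and `weight` as `[D, K]`, and that it accepts CUDA tensors only. Check the package's own documentation for exact dtype and contiguity requirements. Without the package, or on CPU, the sketch falls back to the generic grouped `F.conv1d`. The residual add stays outside this function.

```python
import torch
import torch.nn.functional as F

try:  # optional fused CUDA kernel (assumed API of the causal-conv1d package)
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None


def causal_depthwise_conv(x, weight):
    """x: [B, D, T]; weight: [D, K] (PyTorch order, last column = current token)."""
    K = weight.shape[1]
    if causal_conv1d_fn is not None and x.is_cuda and K in (2, 3, 4):
        return causal_conv1d_fn(x, weight, bias=None, activation=None)
    # Generic fallback: left pad + grouped Conv1d, no fused kernel.
    return F.conv1d(F.pad(x, (K - 1, 0)), weight.unsqueeze(1), groups=x.shape[1])


torch.manual_seed(0)
x, w = torch.randn(2, 3, 5), torch.randn(3, 4)
print(causal_depthwise_conv(x, w).shape)  # torch.Size([2, 3, 5])
```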

17. Open engineering questions

Runtime overhead

Parameter overhead is tiny, but runtime overhead is still real. Multiple small convolutions can add memory traffic and kernel launches. A production implementation would likely fuse Canon with adjacent projections or batch several Canon calls together.

Dynamic Canon

The studied operator uses fixed learned weights. A dynamic version could use input-conditioned local weights:

\[y_t = x_t + \sum_{r=0}^{K-1}a_r(x_t)\odot x_{t-r}.\]

That moves Canon closer to lightweight local attention. It may improve expressivity, but it also changes the clean cost and interpretation.
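To make the trade-off concrete, here is a hypothetical dynamic variant. It is not the studied operator, and computing $a_r(x_t)$ with one linear map of the current token is only one possible choice. Even this simple version raises the parameter count from $DK$ to $D^2K+DK$, which illustrates how it changes the clean cost.

```python
import torch
import torch.nn as nn


class DynamicCanon(nn.Module):
    """Hypothetical input-conditioned Canon: y_t = x_t + sum_r a_r(x_t) * x_{t-r}."""

    def __init__(self, d: int, kernel_size: int = 4):
        super().__init__()
        self.k = kernel_size
        self.to_weights = nn.Linear(d, d * kernel_size)  # D*D*K + D*K parameters

    def forward(self, x):  # x: [B, T, D]
        B, T, D = x.shape
        a = self.to_weights(x).view(B, T, D, self.k)  # a[..., r] multiplies x_{t-r}
        out = x
        for r in range(self.k):
            r_eff = min(r, T)
            # x shifted right by r along time, zeros before the sequence start (causal).
            shifted = torch.cat([x.new_zeros((B, r_eff, D)), x[:, : T - r_eff]], dim=1)
            out = out + a[..., r] * shifted
        return out


torch.manual_seed(0)
m = DynamicCanon(8)
x = torch.randn(2, 6, 8)
print(m(x).shape)  # torch.Size([2, 6, 8])
```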

MoE interaction

Canon-D inside a mixture-of-experts MLP is awkward because neighboring tokens may be routed to different experts. Canon-ABC is easier; Canon-D requires a more careful dispatch design.

Long-range compression

Canon improves local flow. It does not remove the hard problem of preserving high-fidelity information through compressed recurrent state or across very long contexts. The paper’s linear-model results still suggest that full attention remains stronger for some deep in-context reasoning tasks.

18. Summary

Canon Layers are lightweight residual causal convolutions over neighboring token representations:

\[h'_t = h_t + \sum_{r=0}^{K-1}w_r\odot h_{t-r}.\]

The architecture split is clean:

  • attention handles content-addressed global routing;
  • MLPs handle channelwise nonlinear transformation;
  • Canon handles cheap local token-to-token propagation.

The empirical claim from the Canon paper is that this small primitive improves controlled measures of reasoning depth, reasoning breadth, knowledge manipulation, NoPE viability, and several linear/SSM architectures. The implementation claim is equally simple: Canon is depthwise causal Conv1D with a residual path, placed at selected A/B/C/D points inside a block.

Canon is not interesting because “convolution is back.” It is interesting because local horizontal flow is useful enough to deserve its own architectural slot.

References