Source code for grilly.optim.hypergradient

"""
Hypergradient Descent Optimizers

Implements online learning rate adaptation via hypergradient descent.

HypergradientAdamW: Basic hypergradient (Baydin et al. 2018).
    Fixed beta_hyper. Simple but requires tuning beta_hyper.

AutoHypergradientAdamW: OSGM-style auto adjustment (arXiv:2502.11229).
    Self-tuning via AdaGrad-stabilized hypergradients with gradient-norm
    normalization. No manual hypergradient LR tuning needed.
    Optional surprise signal: gradient prediction error exposed as
    `current_surprise` (instant value) and `current_surprise_gain`
    (accumulated, inverted-U shaped) for input-level gain modulation.
    The model scales inputs by (1 + scale * current_surprise_gain),
    amplifying signals when the optimization landscape shifts
    (e.g., SNN phase transitions).

The core idea: the learning rate is treated as a learnable parameter.
At each step, the hypergradient h = -g_k . d_{k-1} / ||g_{k-1}||^2
tells us whether to increase or decrease the learning rate based on
gradient agreement with the previous update direction.

References:
    [1] Baydin et al. "Online Learning Rate Adaptation with Hypergradient
        Descent" (ICLR 2018)
    [2] "Provable and Practical Online Learning Rate Adaptation with
        Hypergradient Descent" (arXiv:2502.11229)
    [3] "Gradient Methods with Online Scaling" (arXiv:2505.23081, 2509.11007)

Uses: adamw-update.glsl (via AdamW base class)
"""

from collections.abc import Iterator

import numpy as np

from .adamw import AdamW


def _collect_grads(param_groups, gradients=None):
    """Collect the current gradient of every parameter, keyed by ``id(param)``.

    For each non-``None`` parameter the gradient is looked up first in
    ``gradients`` (when given) and then in the parameter's ``grad``
    attribute. Parameters for which neither source yields a gradient are
    left out of the result.

    Args:
        param_groups: Iterable of dicts, each holding a ``"params"`` sequence.
        gradients: Optional mapping ``id(param) -> gradient`` that takes
            precedence over ``param.grad``.

    Returns:
        Dict mapping ``id(param)`` to an ``np.ndarray``. ndarray gradients
        and plain sequences/scalars are returned as float32; the ``.data``
        array of a wrapper object is returned with its own dtype. Returned
        arrays may share memory with the caller's gradient buffers, so
        treat them as read-only.
    """
    grads = {}
    for group in param_groups:
        for p in group["params"]:
            if p is None:
                continue
            param_id = id(p)
            grad = gradients.get(param_id) if gradients is not None else None
            if grad is None:
                grad = getattr(p, "grad", None)
            if grad is None:
                continue
            if isinstance(grad, np.ndarray):
                # ndarray exposes its own `.data` attribute (a memoryview), so
                # it must be handled before the wrapper check below; otherwise
                # every array gradient is rebuilt from a memoryview and copied.
                # asarray keeps the float32 result but skips the copy when the
                # array already is float32.
                grad = np.asarray(grad, dtype=np.float32)
            else:
                # Wrapper objects carry the actual array in `.data`.
                if hasattr(grad, "data"):
                    grad = grad.data
                if not isinstance(grad, np.ndarray):
                    grad = np.array(grad, dtype=np.float32)
            grads[param_id] = grad
    return grads
def _compute_update_directions(param_groups, state, step_count, betas, eps):
    """Return the bias-corrected Adam direction of every tracked parameter.

    The direction is ``d = m_hat / (sqrt(v_hat) + eps)`` with
    ``m_hat = exp_avg / (1 - beta1**t)`` and
    ``v_hat = exp_avg_sq / (1 - beta2**t)``.

    Args:
        param_groups: Iterable of dicts, each holding a ``"params"`` sequence.
        state: Mapping ``id(param) -> dict`` with ``"exp_avg"``,
            ``"exp_avg_sq"`` and optionally ``"step"``.
        step_count: Step number used when a state entry has no ``"step"``.
        betas: ``(beta1, beta2)`` moment decay coefficients.
        eps: Term added to the denominator.

    Returns:
        Dict mapping ``id(param)`` to its direction. Parameters that are
        ``None``, lack either moment, or are still at step 0 are omitted.
    """
    beta1, beta2 = betas
    result = {}
    live_params = (
        param
        for group in param_groups
        for param in group["params"]
        if param is not None
    )
    for param in live_params:
        key = id(param)
        entry = state.get(key, {})
        if not ("exp_avg" in entry and "exp_avg_sq" in entry):
            continue
        steps = entry.get("step", step_count)
        if steps == 0:
            # Bias correction would divide by zero before the first update.
            continue
        first_moment = entry["exp_avg"] / (1.0 - beta1**steps)
        second_moment = entry["exp_avg_sq"] / (1.0 - beta2**steps)
        result[key] = first_moment / (np.sqrt(second_moment) + eps)
    return result
class HypergradientAdamW(AdamW):
    """AdamW whose learning rate is adapted online by hypergradient descent.

    Basic scheme of Baydin et al. (2018): a single global learning rate is
    nudged on every step using a fixed meta step size ``beta_hyper``, which
    has to be tuned by hand. ``AutoHypergradientAdamW`` is the self-tuning
    alternative.

    Update rule:
        alpha_{t+1} = alpha_t + beta_hyper * sum(g_t * d_{t-1})

    where ``g_t`` is the current gradient and ``d_{t-1}`` the bias-corrected
    Adam direction stored after the previous step. With ``log_scale=True``
    the same increment is added to ``log(alpha)`` instead.

    Args:
        params: Iterator of parameter arrays to optimize
        lr: Initial learning rate (default: 1e-3)
        betas: Coefficients for running averages (default: (0.9, 0.999))
        eps: Numerical stability term (default: 1e-8)
        weight_decay: Decoupled weight decay (default: 0.01)
        beta_hyper: Hypergradient learning rate (default: 1e-7)
        lr_min: Minimum learning rate clamp (default: 1e-6)
        lr_max: Maximum learning rate clamp (default: 1.0)
        log_scale: If True, adapt log(lr) instead of lr (default: False)
        use_gpu: Whether to use GPU acceleration (default: True)
    """

    def __init__(
        self,
        params: Iterator[np.ndarray],
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        beta_hyper: float = 1e-7,
        lr_min: float = 1e-6,
        lr_max: float = 1.0,
        log_scale: bool = False,
        use_gpu: bool = True,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_gpu=use_gpu,
        )

        # Meta-optimisation settings.
        self.beta_hyper = beta_hyper
        self.log_scale = log_scale
        self.lr_min = lr_min
        self.lr_max = lr_max

        # d_{t-1} per parameter id; empty until the first step has run.
        self._prev_directions = {}
        # Learning rate trace, starting with the initial value.
        self._lr_history = [lr]

    @property
    def current_lr(self):
        """Learning rate currently stored in ``defaults``."""
        return self.defaults["lr"]

    @property
    def lr_history(self):
        """Initial learning rate followed by every adapted value."""
        return self._lr_history

    def step(self, closure=None, gradients=None):
        """Adapt the learning rate, then run one AdamW step.

        Args:
            closure: Forwarded unchanged to ``AdamW.step``.
            gradients: Optional mapping ``id(param) -> gradient``; used for
                the hypergradient and forwarded to ``AdamW.step``.

        Returns:
            Whatever ``AdamW.step`` returns.
        """
        grads_by_id = _collect_grads(self.param_groups, gradients)

        # Only parameters that have both a fresh gradient and a stored
        # previous direction contribute to the hypergradient.
        shared_ids = [pid for pid in grads_by_id if pid in self._prev_directions]
        hypergradient = 0.0
        for pid in shared_ids:
            hypergradient += np.sum(grads_by_id[pid] * self._prev_directions[pid])

        if shared_ids:
            if not self.log_scale:
                candidate = self.defaults["lr"] + self.beta_hyper * hypergradient
                new_lr = float(np.clip(candidate, self.lr_min, self.lr_max))
            else:
                # Multiplicative update: move log(lr) and clamp in log space.
                log_lr = np.log(self.defaults["lr"])
                log_lr += self.beta_hyper * hypergradient
                bounded = np.clip(log_lr, np.log(self.lr_min), np.log(self.lr_max))
                new_lr = float(np.exp(bounded))

            self.defaults["lr"] = new_lr
            for group in self.param_groups:
                group["lr"] = new_lr
            self._lr_history.append(new_lr)

        loss = super().step(closure=closure, gradients=gradients)

        # Remember d_t so the next call can correlate it with g_{t+1}.
        self._prev_directions = _compute_update_directions(
            self.param_groups,
            self.state,
            self._step_count,
            self.defaults["betas"],
            self.defaults["eps"],
        )
        return loss

    def __repr__(self):
        return (
            f"HypergradientAdamW(lr={self.defaults['lr']:.6f}, "
            f"beta_hyper={self.beta_hyper}, "
            f"lr_range=[{self.lr_min}, {self.lr_max}])"
        )
class AutoHypergradientAdamW(AdamW):
    """AdamW with OSGM-style automatic hypergradient adjustment.

    Adapts the learning rate (and optionally momentum beta1) online using
    normalized hypergradients whose meta step size is scaled by a decayed
    running average of squared hypergradients, so the meta learning rate
    needs far less manual tuning than ``HypergradientAdamW.beta_hyper``.

    Based on the OSGM/HDM algorithm, as implemented here:

    Step size hypergradient (how lr should change):
        h_lr = clip(-g_k . d_{k-1} / ||g_{k-1}||^2, -1, 1)
        G_lr = decay * G_lr + (1 - decay) * h_lr^2        (decay = 0.99)
        lr  -= hyper_lr * h_lr / (sqrt(G_lr) + 1e-12)
        The change is limited to 10% of the current lr per step and the
        result is clamped to [lr_min, lr_max].

    Momentum hypergradient (how beta1 should change, adapt_momentum=True):
        h_beta = clip(g_k . m_{k-1} / ||g_{k-1}||^2, -1, 1)
        G_beta = decay * G_beta + (1 - decay) * h_beta^2
        beta1 -= hyper_lr_beta * h_beta / (sqrt(G_beta) + 1e-12)
        The result is clamped to [beta_min, beta_max].

    Adaptation is skipped during the first ``warmup_steps`` steps, while no
    previous update direction is stored, and whenever
    ``||g_{k-1}||^2 <= 1e-12`` (no epsilon is added to the denominator).

    The gradient-norm normalization (/ ||g||^2) removes the dependence on
    gradient scale. The accumulators are exponentially decayed (RMSProp
    style) rather than summed (AdaGrad style), and the hypergradients are
    clipped, so a single outlier step cannot permanently shrink the meta
    learning rate.

    Designed with SNN training in mind, where surrogate gradients are noisy
    and the best learning rate shifts during training.

    Surprise signal (optional, input-level):
        Tracks gradient prediction error as a "surprise" signal and exposes
        it so the model can use it for input gain modulation. It does not
        change the parameter update; it is only computed and stored.

        Instant surprise (gradient prediction error):
            S_instant = tanh(||g_k - EMA(g)||^2 / (EMA(||g||^2) + 1e-12))

        Accumulated surprise (S_bar):
            S_bar = alpha * S_instant + (1 - alpha) * S_bar_prev

        Inverted-U gain (Yerkes-Dodson / "trauma protection"):
            gain = S_bar * exp(-S_bar / trauma_threshold)

        - Low S_bar      -> low gain (nothing interesting)
        - Moderate S_bar -> peak gain (at S_bar = trauma_threshold)
        - High S_bar     -> gain drops again (sustained surprise is damped
          instead of amplified)

        The model reads ``current_surprise_gain`` for input scaling:
            x_effective = x * (1 + scale * optimizer.current_surprise_gain)

    Args:
        params: Iterator of parameter arrays to optimize
        lr: Initial learning rate (default: 1e-3)
        betas: Coefficients for running averages (default: (0.9, 0.999))
        eps: Numerical stability term (default: 1e-8)
        weight_decay: Decoupled weight decay (default: 0.01)
        hyper_lr: Meta-learning rate for step size adaptation (default: 0.01).
            Scaled by the running accumulator, so it is much less sensitive
            than HypergradientAdamW's beta_hyper.
        hyper_lr_beta: Meta-learning rate for momentum adaptation
            (default: 1.0). Only used when adapt_momentum=True.
        lr_min: Minimum learning rate clamp (default: 1e-6)
        lr_max: Maximum learning rate clamp (default: 1.0)
        adapt_momentum: If True, also adapt beta1 via hypergradient
            (default: False)
        track_surprise: If True, compute and expose the gradient surprise
            signal via ``current_surprise_gain`` (default: False).
        surprise_gamma: EMA decay for gradient tracking (default: 0.9).
            Higher = smoother baseline, slower to detect change.
        surprise_alpha: EMA weight of the newest instant surprise in S_bar
            (default: 0.1). Lower = longer memory of surprise.
        trauma_threshold: S_bar level where the gain peaks (default: 0.5).
            Must be non-zero; it is used as a divisor.
        beta_min: Minimum beta1 clamp (default: 0.5)
        beta_max: Maximum beta1 clamp (default: 0.9995)
        warmup_steps: Steps before starting adaptation (default: 10).
            Lets the Adam moments initialize before adapting.
        use_gpu: Whether to use GPU acceleration (default: True)
    """

    # Largest relative learning-rate change allowed in one step.
    # 50% would take 0.0005 to ~0.01 in 7 steps; 10% needs ~24 steps for a
    # 10x change, which is gradual enough for the model to react.
    _MAX_REL_LR_CHANGE = 0.1

    def __init__(
        self,
        params: Iterator[np.ndarray],
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        hyper_lr: float = 0.01,
        hyper_lr_beta: float = 1.0,
        lr_min: float = 1e-6,
        lr_max: float = 1.0,
        adapt_momentum: bool = False,
        track_surprise: bool = False,
        surprise_gamma: float = 0.9,
        surprise_alpha: float = 0.1,
        trauma_threshold: float = 0.5,
        beta_min: float = 0.5,
        beta_max: float = 0.9995,
        warmup_steps: int = 10,
        use_gpu: bool = True,
    ):
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_gpu=use_gpu,
        )
        self.hyper_lr = hyper_lr
        self.hyper_lr_beta = hyper_lr_beta
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.adapt_momentum = adapt_momentum
        self.track_surprise = track_surprise
        self.surprise_gamma = surprise_gamma
        self.surprise_alpha = surprise_alpha
        self.trauma_threshold = trauma_threshold
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.warmup_steps = warmup_steps

        # Running averages of squared hypergradients.
        # Seeded to 1.0 (not 0.0): starting from 0 the first update would give
        # G = (1 - decay) * h^2 and therefore lr_delta = +-hyper_lr /
        # sqrt(1 - decay) (10x hyper_lr for decay = 0.99) no matter how small
        # h is. With the 1.0 seed the first lr_delta is roughly hyper_lr * h.
        self._G_lr = 1.0
        self._G_beta = 1.0
        self._adagrad_eps = 1e-12
        self._meta_decay = 0.99  # RMSProp-style decay for the accumulators

        # Previous step state for hypergradient computation.
        self._prev_directions = {}  # d_{k-1}: Adam update directions
        self._prev_grad_norm_sq = 0.0  # ||g_{k-1}||^2
        self._prev_first_moments = {}  # m_{k-1}: for momentum adaptation

        # Surprise signal state (input-level, not backprop-level).
        # The optimizer only computes these values; the model is expected to
        # read current_surprise_gain and scale its inputs itself.
        self._grad_ema = {}  # EMA of gradients (per-param)
        self._grad_var_ema = 0.0  # EMA of ||g||^2 (scalar)
        self._current_surprise = 0.0  # instant surprise in [0, 1]
        self._s_bar = 0.0  # accumulated surprise
        self._current_gain = 0.0  # inverted-U modulated gain
        self._surprise_history = []  # instant surprise history
        self._s_bar_history = []  # accumulated surprise history

        # History for monitoring / plotting.
        self._lr_history = [lr]
        self._beta1_history = [betas[0]]

    @property
    def current_lr(self):
        """Learning rate currently stored in ``defaults``."""
        return self.defaults["lr"]

    @property
    def current_surprise(self):
        """Instant surprise signal [0, 1]. Raw gradient prediction error."""
        return self._current_surprise

    @property
    def accumulated_surprise(self):
        """Accumulated surprise S_bar (EMA of the instant surprise)."""
        return self._s_bar

    @property
    def current_surprise_gain(self):
        """Inverted-U gain signal for input-level modulation.

        gain = S_bar * exp(-S_bar / trauma_threshold)

        - Low S_bar      -> low gain
        - Moderate S_bar -> peak gain (at S_bar = trauma_threshold)
        - High S_bar     -> gain drops again

        Read this after each optimizer step and pass it to the model:
            x_effective = x * (1 + scale * optimizer.current_surprise_gain)

        Returns 0.0 when surprise tracking is off or during warmup.
        """
        return self._current_gain

    @property
    def lr_history(self):
        """Initial learning rate followed by every adapted value."""
        return self._lr_history

    @property
    def beta1_history(self):
        """Initial beta1 followed by every adapted value."""
        return self._beta1_history

    @property
    def surprise_history(self):
        """Instant surprise of every tracked step."""
        return self._surprise_history

    @property
    def s_bar_history(self):
        """Accumulated surprise of every tracked step."""
        return self._s_bar_history

    @staticmethod
    def _sum_sq_grads(grads):
        """Return the squared L2 norm summed over all gradients in ``grads``."""
        total = 0.0
        for grad in grads.values():
            total += np.sum(grad * grad)
        return total

    def _update_surprise_signal(self, current_grads, grad_norm_sq):
        """Update instant/accumulated surprise, the gain and the gradient EMAs."""
        gamma = self.surprise_gamma

        # EMA(||g||^2); initialised with the first observed norm.
        if self._grad_var_ema == 0.0 and self._step_count == self.warmup_steps:
            self._grad_var_ema = grad_norm_sq
        else:
            self._grad_var_ema = gamma * self._grad_var_ema + (1.0 - gamma) * grad_norm_sq

        # Prediction error ||g_k - EMA(g)||^2. A parameter without a baseline
        # yet contributes its full squared norm.
        prediction_error_sq = 0.0
        for pid, grad in current_grads.items():
            if pid in self._grad_ema:
                diff = grad - self._grad_ema[pid]
                prediction_error_sq += np.sum(diff * diff)
            else:
                prediction_error_sq += np.sum(grad * grad)
        raw_surprise = prediction_error_sq / (self._grad_var_ema + self._adagrad_eps)

        # Squash instant surprise to [0, 1] via tanh.
        self._current_surprise = float(np.tanh(raw_surprise))
        self._surprise_history.append(self._current_surprise)

        # S_bar = alpha * S_instant + (1 - alpha) * S_bar
        alpha = self.surprise_alpha
        self._s_bar = alpha * self._current_surprise + (1.0 - alpha) * self._s_bar
        self._s_bar_history.append(self._s_bar)

        # Inverted-U gain: peaks at S_bar = trauma_threshold, decays above it.
        threshold = self.trauma_threshold
        self._current_gain = float(self._s_bar * np.exp(-self._s_bar / threshold))

        # Update the gradient baseline. The weighted sum allocates a new
        # array, so only the very first entry needs an explicit copy.
        for pid, grad in current_grads.items():
            if pid in self._grad_ema:
                self._grad_ema[pid] = gamma * self._grad_ema[pid] + (1.0 - gamma) * grad
            else:
                self._grad_ema[pid] = grad.copy()

    def _adapt_learning_rate(self, current_grads, norm_sq):
        """Apply one normalized, clipped, rate-limited hypergradient step to lr."""
        # h_lr = -sum(g_k * d_{k-1}) / ||g_{k-1}||^2
        h_lr = 0.0
        for pid, grad in current_grads.items():
            if pid in self._prev_directions:
                h_lr -= np.sum(grad * self._prev_directions[pid])
        h_lr /= norm_sq

        # Clip so a step with a tiny ||g_{k-1}||^2 cannot inflate the
        # accumulator and freeze the meta learning rate.
        h_lr = float(np.clip(h_lr, -1.0, 1.0))

        # Decayed accumulator: old outliers are forgotten over time.
        decay = self._meta_decay
        self._G_lr = decay * self._G_lr + (1.0 - decay) * h_lr * h_lr
        lr_delta = self.hyper_lr * h_lr / (np.sqrt(self._G_lr) + self._adagrad_eps)

        # Rate-limit the relative change per step.
        current_lr = self.defaults["lr"]
        target_lr = current_lr - lr_delta
        max_change = self._MAX_REL_LR_CHANGE * current_lr
        if abs(target_lr - current_lr) > max_change:
            direction = 1.0 if target_lr > current_lr else -1.0
            target_lr = current_lr + direction * max_change

        new_lr = float(np.clip(target_lr, self.lr_min, self.lr_max))
        self.defaults["lr"] = new_lr
        for group in self.param_groups:
            group["lr"] = new_lr
        self._lr_history.append(new_lr)

    def _adapt_beta1(self, current_grads, norm_sq):
        """Apply one normalized, clipped hypergradient step to beta1."""
        # h_beta = sum(g_k * m_{k-1}) / ||g_{k-1}||^2
        h_beta = 0.0
        for pid, grad in current_grads.items():
            if pid in self._prev_first_moments:
                h_beta += np.sum(grad * self._prev_first_moments[pid])
        h_beta /= norm_sq
        h_beta = float(np.clip(h_beta, -1.0, 1.0))

        decay = self._meta_decay
        self._G_beta = decay * self._G_beta + (1.0 - decay) * h_beta * h_beta
        beta_delta = self.hyper_lr_beta * h_beta / (np.sqrt(self._G_beta) + self._adagrad_eps)

        beta1, beta2 = self.defaults["betas"]
        new_beta1 = float(np.clip(beta1 - beta_delta, self.beta_min, self.beta_max))
        self.defaults["betas"] = (new_beta1, beta2)
        for group in self.param_groups:
            group["betas"] = (new_beta1, beta2)
        self._beta1_history.append(new_beta1)

    def step(self, closure=None, gradients=None):
        """Perform an optimization step with OSGM-style auto LR adaptation.

        1. Collect current gradients g_k and their squared norm
        2. Update the surprise signal (if track_surprise=True, after warmup)
        3. After warmup, adapt lr (and beta1 if adapt_momentum=True) from the
           normalized hypergradients
        4. Run the standard AdamW step with the adapted hyperparameters
        5. Store d_k, ||g_k||^2 and m_k for the next step

        Args:
            closure: Forwarded unchanged to ``AdamW.step``.
            gradients: Optional mapping ``id(param) -> gradient``; used for
                the hypergradients and forwarded to ``AdamW.step``.

        Returns:
            Whatever ``AdamW.step`` returns.
        """
        current_grads = _collect_grads(self.param_groups, gradients)
        # Needed by the surprise signal now and as ||g_{k-1}||^2 next step.
        grad_norm_sq = self._sum_sq_grads(current_grads)
        past_warmup = self._step_count >= self.warmup_steps

        if self.track_surprise and past_warmup:
            self._update_surprise_signal(current_grads, grad_norm_sq)

        if (
            past_warmup
            and self._prev_grad_norm_sq > self._adagrad_eps
            and self._prev_directions
        ):
            norm_sq = self._prev_grad_norm_sq
            self._adapt_learning_rate(current_grads, norm_sq)
            if self.adapt_momentum and self._prev_first_moments:
                self._adapt_beta1(current_grads, norm_sq)

        loss = super().step(closure=closure, gradients=gradients)

        # Store state for the next step's hypergradients.
        self._prev_directions = _compute_update_directions(
            self.param_groups,
            self.state,
            self._step_count,
            self.defaults["betas"],
            self.defaults["eps"],
        )
        self._prev_grad_norm_sq = grad_norm_sq

        if self.adapt_momentum:
            self._prev_first_moments = {}
            for group in self.param_groups:
                for p in group["params"]:
                    if p is None:
                        continue
                    pid = id(p)
                    s = self.state.get(pid, {})
                    if "exp_avg" in s:
                        # Copy: the optimizer state may be updated in place.
                        self._prev_first_moments[pid] = s["exp_avg"].copy()

        return loss

    def __repr__(self):
        beta1 = self.defaults["betas"][0]
        parts = [
            f"AutoHypergradientAdamW(lr={self.defaults['lr']:.6f}",
            f"beta1={beta1:.4f}",
            f"hyper_lr={self.hyper_lr}",
            f"lr_range=[{self.lr_min}, {self.lr_max}]",
            f"adapt_momentum={self.adapt_momentum}",
            f"track_surprise={self.track_surprise}",
        ]
        return ", ".join(parts) + ")"