# Source code for grilly.optim.natural_gradient

"""
Natural Gradient Optimizer

Uses: fisher-natural-gradient.glsl

Implements natural gradient descent using Fisher information matrix.
"""

import logging
from collections.abc import Iterator

import numpy as np

from .base import Optimizer


logger = logging.getLogger(__name__)


class NaturalGradient(Optimizer):
    """
    Natural Gradient optimizer using Fisher information matrix.

    Uses: fisher-natural-gradient.glsl

    Implements natural gradient descent:
    - F = Fisher information matrix
    - param = param - lr * F^(-1) * grad

    The Fisher information is tracked element-wise (one value per parameter
    element), so ``F^(-1) * grad`` reduces to an element-wise division.

    Reference: grilly/backend/learning.py natural_gradient
    """

    # Starting value of every element of the Fisher estimate.
    _FISHER_INIT = 1e-6
    # Added to the Fisher estimate before dividing, to avoid division by zero.
    _EPSILON = 1e-8

    def __init__(
        self,
        params: Iterator[np.ndarray],
        lr: float = 1e-3,
        fisher_momentum: float = 0.9,
        use_gpu: bool = True,
    ):
        """
        Initialize Natural Gradient optimizer.

        Args:
            params: Iterator of parameter arrays to optimize
            lr: Learning rate (default: 1e-3)
            fisher_momentum: Momentum for Fisher information estimate (default: 0.9)
            use_gpu: Whether to use GPU acceleration (default: True)
        """
        defaults = {
            "lr": lr,
            "fisher_momentum": fisher_momentum,
        }
        super().__init__(params, defaults)
        self.use_gpu = use_gpu
        self._backend = None
        # Set once backend creation has failed so it is not retried every step.
        self._backend_failed = False

    def _get_backend(self):
        """
        Get or create the backend instance.

        Returns:
            The cached ``grilly.Compute`` instance, or None when it could not
            be imported or constructed. A failure is remembered, so the
            import/construction is attempted at most once per optimizer.
        """
        if self._backend is None and not getattr(self, "_backend_failed", False):
            try:
                from grilly import Compute

                self._backend = Compute()
            except Exception as exc:
                # Best-effort: callers treat a missing backend as "use the CPU".
                self._backend = None
                self._backend_failed = True
                logger.debug("Compute backend unavailable: %s", exc)
        return self._backend

    @classmethod
    def _gpu_update(cls, learning, grad, fisher, lr, fisher_momentum, shape):
        """Run both backend kernels and return ``(new_fisher, update)`` shaped like the parameter."""
        flat_grad = grad.flatten()
        new_fisher = learning.fisher_info_update(
            flat_grad, fisher.flatten(), momentum=fisher_momentum
        )
        new_fisher = np.asarray(new_fisher).reshape(shape)
        update = learning.natural_gradient(
            flat_grad, new_fisher.flatten(), learning_rate=lr, epsilon=cls._EPSILON
        )
        update = np.asarray(update).reshape(shape)
        return new_fisher, update

    @staticmethod
    def _write_back(p, p_data, is_wrapper):
        """Store the updated values on the parameter (wrapper attribute or in-place array)."""
        if is_wrapper:
            # Parameter wrapper or custom class
            p.data = p_data
        elif p_data is not p:
            # p_data is a converted copy; copy the values back into p.
            p[:] = p_data
        # Otherwise p_data *is* p and was already updated in place.

    @staticmethod
    def _clear_grad(p):
        """Reset the parameter's gradient after it has been consumed."""
        if getattr(p, "grad", None) is None:
            return
        if hasattr(p, "zero_grad"):
            p.zero_grad()
        else:
            p.grad = None

    def step(self, closure=None):
        """
        Perform a single optimization step.

        For every parameter that has a gradient, the Fisher estimate is
        updated and the parameter is moved by ``lr * grad / (fisher + eps)``.
        The GPU kernels are used when ``use_gpu`` is set and the backend
        exposes both ``fisher_info_update`` and ``natural_gradient``; if they
        are missing or raise, the CPU implementation is used instead. The
        gradient is cleared after the update.

        Args:
            closure: Optional closure that reevaluates the model and returns loss

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = closure() if closure is not None else None

        # Only touch the backend when GPU use was requested.
        backend = self._get_backend() if self.use_gpu else None
        learning = getattr(backend, "learning", None)
        # Both kernels are required: updating the Fisher estimate without
        # applying it (or the reverse) would silently skip or distort the step.
        gpu_ready = (
            learning is not None
            and hasattr(learning, "fisher_info_update")
            and hasattr(learning, "natural_gradient")
        )

        for group in self.param_groups:
            # Per-group values take precedence over the optimizer defaults.
            lr = group.get("lr", self.defaults["lr"])
            fisher_momentum = group.get(
                "fisher_momentum", self.defaults["fisher_momentum"]
            )

            for p in group["params"]:
                if p is None:
                    continue

                grad = getattr(p, "grad", None)
                if grad is None:
                    continue
                grad = np.asarray(grad)

                # np.ndarray has its own ``.data`` buffer attribute, so plain
                # arrays must not be treated as parameter wrappers.
                is_wrapper = hasattr(p, "data") and not isinstance(p, np.ndarray)
                p_data = p.data if is_wrapper else p
                if not isinstance(p_data, np.ndarray):
                    p_data = np.array(p_data, dtype=np.float32)

                state = self.state[id(p)]
                if "fisher" not in state:
                    state["fisher"] = np.full(
                        p_data.shape, self._FISHER_INIT, dtype=np.float32
                    )
                fisher = state["fisher"]

                if gpu_ready:
                    try:
                        new_fisher, update = self._gpu_update(
                            learning, grad, fisher, lr, fisher_momentum, p_data.shape
                        )
                    except Exception as exc:
                        # Nothing has been committed yet, so the CPU path
                        # below starts from the untouched Fisher estimate.
                        logger.debug(
                            "GPU Natural Gradient update failed: %s, falling back to CPU",
                            exc,
                        )
                    else:
                        p_data -= update  # natural_gradient already includes lr
                        state["fisher"] = new_fisher
                        self._write_back(p, p_data, is_wrapper)
                        self._clear_grad(p)
                        continue

                # CPU fallback
                # Update Fisher information (simplified - use gradient squared)
                new_fisher = fisher_momentum * fisher + (1 - fisher_momentum) * (grad**2)
                # Apply natural gradient: F^(-1) * grad ≈ grad / (fisher + eps)
                natural_grad = grad / (new_fisher + self._EPSILON)
                p_data -= lr * natural_grad
                state["fisher"] = new_fisher
                self._write_back(p, p_data, is_wrapper)
                self._clear_grad(p)

        return loss