"""
NLMS (Normalized Least Mean Squares) Optimizer
Uses: nlms-update.glsl
Reference: ref/brain/specialist.py NLMSExpertHead
"""
from collections.abc import Iterator
import numpy as np
from .base import Optimizer
class NLMS(Optimizer):
    """
    NLMS (Normalized Least Mean Squares) optimizer.

    Intended GPU kernel: nlms-update.glsl (not dispatched by ``step`` yet;
    the update currently runs in NumPy).
    Reference: ref/brain/specialist.py NLMSExpertHead

    Classic NLMS adapts weights with a step normalized by the input power:

    - w = w + mu * error * x / (||x||^2 + eps)

    An optimizer only sees gradients, so ``step`` treats each parameter's
    gradient as a stand-in for ``error * x`` and normalizes by the squared
    gradient norm instead:

    - w = w - mu * grad / (||grad||^2 + eps)

    Every parameter keeps its own step size ``mu``. It starts at the group's
    ``lr`` and is multiplied by ``lr_decay`` after each update, never going
    below ``lr_min`` through decay.
    """

    def __init__(
        self,
        params: Iterator[np.ndarray],
        lr: float = 0.5,
        lr_decay: float = 0.99995,
        lr_min: float = 0.1,
        eps: float = 1e-6,
        use_gpu: bool = True,
    ):
        """
        Initialize NLMS optimizer.

        Args:
            params: Iterator of parameter arrays to optimize
            lr: Initial learning rate (mu) (default: 0.5)
            lr_decay: Multiplicative learning rate decay applied after every
                update of a parameter (default: 0.99995)
            lr_min: Minimum learning rate; decay stops here (default: 0.1)
            eps: Small constant added to the squared gradient norm for
                numerical stability (default: 1e-6)
            use_gpu: Whether GPU acceleration is requested (default: True).
                Stored on the instance; ``step`` currently always performs
                the update with NumPy.
        """
        defaults = {
            "lr": lr,
            "lr_decay": lr_decay,
            "lr_min": lr_min,
            "eps": eps,
        }
        super().__init__(params, defaults)
        self.use_gpu = use_gpu
        # Created lazily by _get_backend(); stays None when unavailable.
        self._backend = None

    def _get_backend(self):
        """
        Get or create the compute backend instance.

        Returns:
            The cached ``grilly.Compute`` instance, or None when it cannot
            be imported or constructed.
        """
        if self._backend is None:
            try:
                from grilly import Compute

                self._backend = Compute()
            except Exception:
                # Deliberate best-effort: any import or construction failure
                # simply means "no backend available".
                self._backend = None
        return self._backend

    def step(self, closure=None):
        """
        Perform a single optimization step.

        For every parameter that carries a non-None ``grad`` attribute, apply
        ``w -= mu * grad / (||grad||^2 + eps)``, decay that parameter's
        ``mu`` and clear the gradient. Parameters that are None or have no
        gradient are skipped.

        Args:
            closure: Optional closure that reevaluates the model and returns loss

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Read every hyperparameter from the group first so per-group
            # overrides are honored; fall back to the optimizer defaults.
            lr = group.get("lr", self.defaults["lr"])
            lr_decay = group.get("lr_decay", self.defaults["lr_decay"])
            lr_min = group.get("lr_min", self.defaults["lr_min"])
            eps = group.get("eps", self.defaults["eps"])

            for p in group["params"]:
                if p is None:
                    continue

                # State is keyed by object identity.
                # NOTE(review): relies on self.state yielding an empty mapping
                # on first access (presumably a defaultdict in the base
                # class) - confirm against Optimizer.
                state = self.state[id(p)]
                if not state:
                    state["mu"] = lr
                    state["mu_initial"] = lr
                    state["update_count"] = 0

                # NOTE(review): a plain np.ndarray cannot carry a ``grad``
                # attribute, so such parameters are always skipped here;
                # presumably callers pass wrappers or ndarray subclasses.
                grad = getattr(p, "grad", None)
                if grad is None:
                    continue
                grad = np.asarray(grad)

                # np.ndarray has its own ``data`` attribute (a raw buffer),
                # so only non-ndarray objects count as parameter wrappers.
                is_wrapper = hasattr(p, "data") and not isinstance(p, np.ndarray)
                p_data = p.data if is_wrapper else p
                if not isinstance(p_data, np.ndarray):
                    p_data = np.array(p_data, dtype=np.float32)

                # Real NLMS: w = w + mu * error * x / (||x||^2 + eps).
                # Here grad stands in for error * x, so the step is
                # normalized by ||grad||^2 instead of ||x||^2.
                mu = state["mu"]
                grad_flat = grad.ravel()
                norm_sq = np.dot(grad_flat, grad_flat) + eps
                p_data -= (mu / norm_sq) * grad

                if is_wrapper:
                    p.data = p_data
                elif p_data is not p:
                    # p_data is a converted copy; write the result back.
                    p[:] = p_data
                # Otherwise p_data *is* p and was already updated in place
                # (this also keeps 0-d arrays working, where p[:] would fail).

                # Decay the step size, never dropping below lr_min.
                if mu > lr_min:
                    state["mu"] = max(mu * lr_decay, lr_min)
                state["update_count"] += 1

                # Clear gradient after update
                if getattr(p, "grad", None) is not None:
                    if hasattr(p, "zero_grad"):
                        p.zero_grad()
                    else:
                        p.grad = None

        return loss