"""
Learning Rate Schedulers
Implements various learning rate scheduling strategies to match PyTorch's API.
"""
import math
class LRScheduler:
    """
    Base class for epoch-based learning rate schedulers.

    Subclasses implement get_lr(), which returns one learning rate per optimizer
    param group for the current ``self.last_epoch``; step() writes those values
    back into ``optimizer.param_groups``.
    """

    def __init__(self, optimizer, last_epoch=-1):
        """
        Initialize base scheduler.

        Args:
            optimizer: Wrapped optimizer. ``optimizer.param_groups`` must be a list
                of dicts, each containing an ``"lr"`` entry.
            last_epoch: The index of last epoch (default: -1). -1 starts a fresh
                schedule; any other value resumes an existing one.

        Raises:
            TypeError: If ``optimizer.param_groups`` is not a list.
            KeyError: If a param group has no ``"lr"`` entry.
        """
        self.optimizer = optimizer
        if not isinstance(optimizer.param_groups, list):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        for i, group in enumerate(optimizer.param_groups):
            if "lr" not in group:
                raise KeyError(
                    f"param 'lr' is not specified in param_groups[{i}] when resuming an optimizer"
                )
            if last_epoch == -1:
                # Fresh schedule: remember the starting lr so that a later resume
                # (last_epoch != -1) can recover it after "lr" has been decayed.
                group["initial_lr"] = group["lr"]
        # When resuming, "lr" already holds a scheduled (decayed) value, so prefer the
        # recorded starting lr; groups that never recorded one fall back to "lr".
        self.base_lrs = [
            group.get("initial_lr", group["lr"]) for group in optimizer.param_groups
        ]
        self.last_epoch = last_epoch
        # Apply the learning rate for epoch last_epoch + 1 right away.
        self.step()

    def state_dict(self):
        """Return the scheduler state as a dict (every attribute except the optimizer)."""
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore scheduler attributes from a dict produced by state_dict()."""
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """Return the learning rates applied by the most recent step(), one per param group."""
        return self._last_lr

    def get_lr(self):
        """
        Compute learning rate using chainable form of the scheduler.

        Must be implemented by subclasses; return a list with one lr per param group.
        """
        raise NotImplementedError

    def step(self, epoch=None):
        """
        Perform a scheduler step.

        Args:
            epoch: Optional epoch number to use instead of incrementing ``last_epoch``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        lrs = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group["lr"] = lr
        # Read back from the optimizer so the cache reflects exactly what was applied.
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
class StepLR(LRScheduler):
    """
    Multiply each param group's learning rate by ``gamma`` once every
    ``step_size`` epochs.
    Matches torch.optim.lr_scheduler.StepLR
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        """
        Set up a StepLR schedule.

        Args:
            optimizer: Wrapped optimizer
            step_size: Number of epochs between consecutive decays
            gamma: Factor applied to the learning rate at each decay (default: 0.1)
            last_epoch: The index of last epoch (default: -1)
        """
        # Both attributes must exist before the base constructor runs its first step().
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return per-group learning rates for ``self.last_epoch`` (chainable form)."""
        epoch = self.last_epoch
        # Epoch 0 never decays; afterwards decay on every multiple of step_size.
        decay_due = epoch != 0 and epoch % self.step_size == 0
        current = [group["lr"] for group in self.optimizer.param_groups]
        if decay_due:
            return [lr * self.gamma for lr in current]
        return current
class CosineAnnealingLR(LRScheduler):
    """
    Anneal each param group's learning rate along a cosine curve between its
    base lr and ``eta_min`` with half-period ``T_max``.
    Matches torch.optim.lr_scheduler.CosineAnnealingLR
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        """
        Set up a cosine annealing schedule.

        Args:
            optimizer: Wrapped optimizer
            T_max: Maximum number of iterations (half-period of the cosine)
            eta_min: Minimum learning rate (default: 0)
            last_epoch: The index of last epoch (default: -1)
        """
        # Needed by get_lr(), which the base constructor calls via step().
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Return per-group learning rates for ``self.last_epoch`` (chainable form).

        Each step rescales the current lr by the ratio of consecutive cosine terms.
        """
        epoch = self.last_epoch
        if epoch == 0:
            return self.base_lrs
        groups = self.optimizer.param_groups
        past_trough = (epoch - 1 - self.T_max) % (2 * self.T_max) == 0
        if past_trough:
            # The previous epoch sat at the cosine minimum, where the ratio below
            # would divide by zero; add the upward increment explicitly instead.
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, groups)
            ]
        scaled = []
        for group in groups:
            ratio = (1 + math.cos(math.pi * epoch / self.T_max)) / (
                1 + math.cos(math.pi * (epoch - 1) / self.T_max)
            )
            scaled.append(ratio * (group["lr"] - self.eta_min) + self.eta_min)
        return scaled
class ReduceLROnPlateau:
    """
    Reduce learning rate when a metric has stopped improving.
    Matches torch.optim.lr_scheduler.ReduceLROnPlateau

    step() takes the monitored metric. When it fails to improve for more than
    ``patience`` consecutive steps, every param group's lr is multiplied by
    ``factor`` (never below that group's ``min_lr``). The next ``cooldown``
    steps are then not counted as bad epochs.
    """

    def __init__(
        self,
        optimizer,
        mode="min",
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=0,
        eps=1e-8,
    ):
        """
        Initialize ReduceLROnPlateau scheduler.

        Args:
            optimizer: Wrapped optimizer
            mode: One of 'min' or 'max'. In 'min' mode, lr will be reduced when
                the quantity monitored has stopped decreasing (default: 'min')
            factor: Factor by which the learning rate will be reduced (default: 0.1)
            patience: Number of epochs with no improvement after which learning rate
                will be reduced (default: 10)
            threshold: Threshold for measuring the new optimum (default: 1e-4)
            threshold_mode: One of 'rel', 'abs' (default: 'rel')
            cooldown: Number of epochs to wait before resuming normal operation
                after lr has been reduced (default: 0)
            min_lr: A lower bound on the learning rate; a scalar, or a list/tuple
                with one value per param group (default: 0)
            eps: Minimal decay applied to lr; smaller updates are skipped (default: 1e-8)

        Raises:
            ValueError: If factor >= 1.0, min_lr has the wrong length, or mode /
                threshold_mode is unknown.
            TypeError: If ``optimizer.param_groups`` is not a list.
        """
        if factor >= 1.0:
            raise ValueError("Factor should be < 1.0.")
        self.factor = factor
        if not isinstance(optimizer.param_groups, list):
            raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
        self.optimizer = optimizer
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError(
                    f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
                )
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
        self.patience = patience
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0
        self.mode_worse = None
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode)
        self._reset()
        # Same contract as LRScheduler.get_last_lr(): the lrs currently on the optimizer.
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def _reset(self):
        """Reset best value, num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """
        Perform a scheduler step based on metric.

        Args:
            metrics: The metric to monitor (anything convertible with float()).
            epoch: Optional epoch number to use instead of incrementing ``last_epoch``.
        """
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.in_cooldown:
            self.cooldown_counter -= 1
            # Bad epochs during cooldown are ignored.
            self.num_bad_epochs = 0
        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def get_last_lr(self):
        """Return the learning rates on the optimizer as of the last step() (or construction)."""
        return self._last_lr

    def _reduce_lr(self, epoch):
        """
        Multiply every param group's lr by ``factor``, clamped to its min_lr.

        Updates smaller than ``eps`` are skipped. ``epoch`` is accepted for
        signature compatibility and is not used here.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group["lr"])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group["lr"] = new_lr

    @property
    def in_cooldown(self):
        """True while the scheduler is still in its post-reduction cooldown period."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """
        Return True if metric ``a`` improves on ``best``.

        'rel' compares against best scaled by (1 -/+ threshold); 'abs' compares
        against best shifted by threshold.
        """
        if self.mode == "min" and self.threshold_mode == "rel":
            rel_epsilon = 1.0 - self.threshold
            return a < best * rel_epsilon
        elif self.mode == "min" and self.threshold_mode == "abs":
            return a < best - self.threshold
        elif self.mode == "max" and self.threshold_mode == "rel":
            rel_epsilon = self.threshold + 1.0
            return a > best * rel_epsilon
        else:  # mode == 'max' and threshold_mode == 'abs'
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate mode/threshold_mode and set ``mode_worse`` (the initial 'best')."""
        if mode not in {"min", "max"}:
            raise ValueError("mode " + mode + " is unknown!")
        if threshold_mode not in {"rel", "abs"}:
            raise ValueError("threshold mode " + threshold_mode + " is unknown!")
        if mode == "min":
            self.mode_worse = float("inf")
        else:  # mode == 'max'
            self.mode_worse = -float("inf")

    def state_dict(self):
        """Return the scheduler state as a dict (every attribute except the optimizer)."""
        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore scheduler attributes from a dict produced by state_dict()."""
        self.__dict__.update(state_dict)
        # Re-validate mode settings and recompute mode_worse from the loaded values.
        self._init_is_better(
            mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
        )
class OneCycleLR(LRScheduler):
    """
    Sets the learning rate according to the 1cycle learning rate policy.
    Matches torch.optim.lr_scheduler.OneCycleLR

    Phase 1 anneals each group's lr from ``max_lr / div_factor`` up to
    ``max_lr``. Phase 2 anneals it down to ``(max_lr / div_factor) /
    final_div_factor``, ending at step ``total_steps - 1``. With
    ``cycle_momentum`` enabled, momentum (or Adam-style beta1) moves inversely.
    """

    def __init__(
        self,
        optimizer,
        max_lr,
        total_steps=None,
        epochs=None,
        steps_per_epoch=None,
        pct_start=0.3,
        anneal_strategy="cos",
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=25.0,
        final_div_factor=1e4,
        last_epoch=-1,
    ):
        """
        Initialize OneCycleLR scheduler.

        Args:
            optimizer: Wrapped optimizer
            max_lr: Upper learning rate boundary in the cycle
            total_steps: Total number of steps in the cycle (optional)
            epochs: Number of epochs to train for (optional)
            steps_per_epoch: Number of steps per epoch (optional)
            pct_start: Percentage of the cycle spent increasing the learning rate (default: 0.3)
            anneal_strategy: Specifies the annealing strategy: 'cos' or 'linear' (default: 'cos')
            cycle_momentum: If True, momentum is cycled inversely (default: True)
            base_momentum: Lower momentum boundary in the cycle (default: 0.85)
            max_momentum: Upper momentum boundary in the cycle (default: 0.95)
            div_factor: Determines the initial learning rate via initial_lr = max_lr/div_factor (default: 25)
            final_div_factor: Determines the minimum learning rate via min_lr = initial_lr/final_div_factor (default: 1e4)
            last_epoch: The index of last epoch (default: -1)

        Raises:
            ValueError: On missing/non-positive step counts, pct_start outside
                [0, 1], an unknown anneal_strategy, or cycle_momentum with an
                optimizer whose defaults have neither "momentum" nor "betas".
        """
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            # Type check first so non-numeric input gets this ValueError rather
            # than a TypeError from the comparison.
            if not isinstance(total_steps, int) or total_steps <= 0:
                raise ValueError(f"Expected positive integer total_steps, but got {total_steps}")
            self.total_steps = total_steps
        else:
            if epochs is None or steps_per_epoch is None:
                raise ValueError("You must define both epochs and steps_per_epoch")
            if not isinstance(epochs, int) or epochs <= 0:
                raise ValueError(f"Expected positive integer epochs, but got {epochs}")
            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
                raise ValueError(
                    f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
                )
            self.total_steps = epochs * steps_per_epoch

        # Validate pct_start before it is used to size the phases.
        if pct_start < 0 or pct_start > 1:
            raise ValueError(f"Expected float between 0 and 1 pct_start, but got {pct_start}")
        # Phase 1 ends at step `step_size_up`. Phase 2 spans `step_size_down` steps
        # after that, so it ends at step total_steps - 1. Either length may be 0.
        self.step_size_up = float(pct_start * self.total_steps) - 1
        self.step_size_down = float(self.total_steps - self.step_size_up) - 1

        if anneal_strategy not in ["cos", "linear"]:
            raise ValueError(
                f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
            )
        elif anneal_strategy == "cos":
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == "linear":
            self.anneal_func = self._annealing_linear

        # NOTE(review): _format_param is not visible in this chunk of the file;
        # presumably defined elsewhere in this class and returns one value per
        # param group — confirm.
        self.max_lrs = self._format_param("max_lr", optimizer, max_lr)
        self.initial_lrs = [peak / div_factor for peak in self.max_lrs]
        self.min_lrs = [start / final_div_factor for start in self.initial_lrs]

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if "momentum" not in optimizer.defaults and "betas" not in optimizer.defaults:
                raise ValueError(
                    "optimizer must support momentum or betas with cycle_momentum option enabled"
                )
            # Adam-style optimizers expose momentum as beta1 of "betas".
            self.use_beta1 = "betas" in optimizer.defaults
            self.max_momentums = self._format_param("max_momentum", optimizer, max_momentum)
            self.base_momentums = self._format_param("base_momentum", optimizer, base_momentum)
        super().__init__(optimizer, last_epoch)

    def _annealing_cos(self, start, end, pct):
        """Cosine annealing from start to end as pct goes from 0.0 to 1.0."""
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        """Linear annealing from start to end as pct goes from 0.0 to 1.0."""
        return (end - start) * pct + start

    def _phase_pct(self, step_num):
        """Return ``(in_warmup, pct)``: whether step_num is in phase 1, and progress through its phase."""
        if step_num <= self.step_size_up:
            in_warmup, start, length = True, 0, self.step_size_up
        else:
            in_warmup, start, length = False, self.step_size_up, self.step_size_down
        # A phase can have zero length: pct_start * total_steps == 1 makes phase 1
        # a single step; pct_start == 1 leaves nothing for phase 2. Treat such a
        # phase as complete instead of dividing by zero.
        if length <= 0:
            return in_warmup, 1.0
        return in_warmup, (step_num - start) / length

    def get_lr(self):
        """
        Compute the learning rate for every param group at the current step.

        When cycle_momentum is enabled, this also writes the matching momentum
        (or beta1) into each param group as a side effect.

        Raises:
            ValueError: If stepped beyond total_steps.
        """
        step_num = self.last_epoch
        if step_num > self.total_steps:
            raise ValueError(
                f"Tried to step {step_num + 1} times. The specified number of total steps is {self.total_steps}"
            )
        in_warmup, pct = self._phase_pct(step_num)

        if in_warmup:
            lr_bounds = zip(self.initial_lrs, self.max_lrs)
        else:
            lr_bounds = zip(self.max_lrs, self.min_lrs)
        lrs = [self.anneal_func(start, end, pct) for start, end in lr_bounds]

        if self.cycle_momentum:
            # Momentum moves opposite to the lr: max -> base in phase 1, base -> max in phase 2.
            if in_warmup:
                momentum_bounds = zip(self.max_momentums, self.base_momentums)
            else:
                momentum_bounds = zip(self.base_momentums, self.max_momentums)
            momentums = [self.anneal_func(start, end, pct) for start, end in momentum_bounds]
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                if self.use_beta1:
                    # Adam-style: replace beta1, keep beta2.
                    betas = param_group["betas"]
                    param_group["betas"] = (momentum, betas[1])
                else:
                    param_group["momentum"] = momentum
        return lrs