from __future__ import annotations
import torch
import torch.nn as nn
from ..line_search import create_line_search_solver
from ..numerical_optimizer import NumericalOptimizer, LineSearchOptimizer
from ..scaling_matrix_calculator import GaussNewtonBlockApproximation
import warnings
class LevenbergMarquardt(NumericalOptimizer):
"""
Heavily inspired by https://github.com/hahnec/torchimize/blob/master/torchimize/optimizer/gna_opt.py
and the matlab implementation of 'learnlm' https://es.mathworks.com/help/deeplearning/ref/trainlm.html#d126e69092
Parameters
----------
model: nn.Module
The model to be optimized
lr_init: float
Maximum learning rate in backtracking line search, if the learning rate is set as constant, this will be the value used.
lr_method: str
Method to use to initialize the learning rate before applying line search.
mu: float
Initial value for the coefficient used when adding a diagonal matrix to the Hessian approximation.
mu_dec: float
Factor with which to decrease the coefficient of the diagonal matrix if the previous iteration didn't improve the model.
mu_max: float
Factor with which to increase the coefficient of the diagonal matrix if the previous iteration improved the model.
use_diagonal: bool
Whether to use the diagonal of the Hessian approximation instead of an identity matrix to adjust the Hessian matrix.
c1: float
Coefficient of the sufficient increase condition in backtracking line search.
c2: float
Coefficient used in the second condition for wolfe conditions.
tau: float
Factor used to reduce the step size in each step of the backtracking line search.
line_search_method: str
Method used for line search, options are "backtrack" and "constant".
line_search_cond: str
Condition to be used in backtracking line search, options are "armijo", "wolfe", "strong-wolfe" and "goldstein".
solver: str
Method to use to invert the hessian.
batch_size: int
Size of the amount of data to use at a time to calculate the hessian matrix.
"""
def __init__(
self,
model: nn.Module,
lr_init: float = 1,
lr_method: str | None = None,
mu: float = 0.001,
mu_dec: float = 0.1,
mu_max: float = 1e10,
fletcher: bool = False,
solver: str = "solve",
batch_size: int | None = None,
):
self.fletcher = fletcher
damping = "fletcher" if fletcher else "identity"
super().__init__(
model,
scaling_matrix=GaussNewtonBlockApproximation(model=model, batch_size=batch_size, damping=damping, mu=mu),
lr_init=lr_init,
lr_method=lr_method,
solver=solver,
)
self.mu = mu
self.mu_dec = mu_dec
self.mu_max = mu_max
self.prev_loss = None
if fletcher and solver == "solve":
warnings.warn("Using 'solve' with fletcher's method usually doesn't work very well. Try using 'pinv' instead.")
def step(self, x: torch.Tensor, y: torch.Tensor, loss_fn: nn.Module):
super().step(x, y, loss_fn)
with torch.inference_mode():
pred_y = self._model(x)
loss = loss_fn(pred_y, y)
self.update(loss)
def update(self, loss: torch.Tensor):
loss_val = loss.detach().item()
if self.prev_loss is None:
self.prev_loss = loss_val
self._prev_params = [p.detach().clone() for p in self._params]
elif loss_val <= self.prev_loss:
self.prev_loss = loss_val
self._prev_params = [p.detach().clone() for p in self._params]
self.mu *= self.mu_dec
else:
self._params = self._prev_params
self.mu /= self.mu_dec
if self.mu >= self.mu_max:
self.mu = self.mu_max
self.scaling_matrix.mu = self.mu
[docs]
class LevenbergMarquardtLS(LineSearchOptimizer):
"""
Heavily inspired by https://github.com/hahnec/torchimize/blob/master/torchimize/optimizer/gna_opt.py
and the matlab implementation of 'learnlm' https://es.mathworks.com/help/deeplearning/ref/trainlm.html#d126e69092
Parameters
----------
model: nn.Module
The model to be optimized
lr_init: float
Maximum learning rate in backtracking line search, if the learning rate is set as constant, this will be the value used.
lr_method: str
Method to use to initialize the learning rate before applying line search.
mu: float
Initial value for the coefficient used when adding a diagonal matrix to the Hessian approximation.
mu_dec: float
Factor with which to decrease the coefficient of the diagonal matrix if the previous iteration didn't improve the model.
mu_max: float
Factor with which to increase the coefficient of the diagonal matrix if the previous iteration improved the model.
use_diagonal: bool
Whether to use the diagonal of the Hessian approximation instead of an identity matrix to adjust the Hessian matrix.
c1: float
Coefficient of the sufficient increase condition in backtracking line search.
c2: float
Coefficient used in the second condition for wolfe conditions.
tau: float
Factor used to reduce the step size in each step of the backtracking line search.
line_search_method: str
Method used for line search, options are "backtrack" and "constant".
line_search_cond: str
Condition to be used in backtracking line search, options are "armijo", "wolfe", "strong-wolfe" and "goldstein".
solver: str
Method to use to invert the hessian.
batch_size: int
Size of the amount of data to use at a time to calculate the hessian matrix.
"""
def __init__(
self,
model: nn.Module,
lr_init: float = 1,
lr_method: str | None = None,
mu: float = 0.001,
mu_dec: float = 0.1,
mu_max: float = 1e10,
fletcher: bool = False,
c1: float = 1e-4,
c2: float = 0.9,
tau: float = 0.1,
line_search_method: str = "backtrack",
line_search_cond: str = "armijo",
solver: str = "solve",
batch_size: int | None = None,
):
self.fletcher = fletcher
damping = "fletcher" if fletcher else "identity"
super().__init__(
model,
scaling_matrix=GaussNewtonBlockApproximation(model=model, batch_size=batch_size, damping=damping, mu=mu),
lr_init=lr_init,
lr_method=lr_method,
line_search=create_line_search_solver(method=line_search_method, condition=line_search_cond, c1=c1, c2=c2, tau=tau),
solver=solver,
)
self.mu = mu
self.mu_dec = mu_dec
self.mu_max = mu_max
self.prev_loss = None
if fletcher and solver == "solve":
warnings.warn("Using 'solve' with fletcher's method usually doesn't work very well. Try using 'pinv' instead.")
[docs]
def step(self, x: torch.Tensor, y: torch.Tensor, loss_fn: nn.Module):
super().step(x, y, loss_fn)
with torch.inference_mode():
pred_y = self._model(x)
loss = loss_fn(pred_y, y)
self.update(loss)
[docs]
def update(self, loss: torch.Tensor):
loss_val = loss.detach().item()
if self.prev_loss is None:
self.prev_loss = loss_val
self._prev_params = [p.detach().clone() for p in self._params]
elif loss_val <= self.prev_loss:
self.prev_loss = loss_val
self._prev_params = [p.detach().clone() for p in self._params]
self.mu *= self.mu_dec
else:
self._params = self._prev_params
self.mu /= self.mu_dec
if self.mu >= self.mu_max:
self.mu = self.mu_max
self.scaling_matrix.mu = self.mu