Source code for torch_numopt.second_order_optimizer

from __future__ import annotations
from abc import ABC
import torch
from functools import reduce
from .utils import param_reshape_like, param_flatten
from .line_search_optimizer import LineSearchOptimizer
from torch.autograd.functional import hessian
from torch.func import functional_call
from copy import copy


class SecondOrderOptimizer(LineSearchOptimizer, ABC):
    """
    Class for Optimization methods using second derivative information.

    Parameters
    ----------
    model: nn.Module
        The model to be optimized
    lr_init: float
        Maximum learning rate in backtracking line search, if the learning rate is set as constant,
        this will be the value used.
    lr_method: str
        Method to use to initialize the learning rate before applying line search.
    line_search_cond, line_search_method, c1, c2, tau
        Forwarded unchanged to ``LineSearchOptimizer``.
    batch_size: int
        Size of the amount of data to use at a time to calculate the hessian matrix.
        ``None`` means the whole dataset is processed in a single pass.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        lr_init: float,
        lr_method: str | None = None,
        line_search_cond="armijo",
        line_search_method="const",
        c1: float = 1e-4,
        c2: float = 0.9,
        tau: float = 0.1,
        batch_size: int | None = None,
    ):
        super().__init__(
            model,
            lr_init=lr_init,
            lr_method=lr_method,
            line_search_cond=line_search_cond,
            line_search_method=line_search_method,
            c1=c1,
            c2=c2,
            tau=tau,
        )

        self.batch_size = batch_size

    @staticmethod
    def _prod(dims):
        """Product of the integers in ``dims``; an empty sequence yields 1."""
        return reduce(lambda a, b: a * b, dims, 1)

    def _batch_slices(self, n_samples):
        """
        Slices that partition ``n_samples`` rows into chunks of ``self.batch_size``.

        A single slice covering everything is returned when no batch size is set
        or when the batch size is at least as large as the dataset.
        """
        if self.batch_size is None or self.batch_size >= n_samples:
            return [slice(0, n_samples)]

        return [slice(start, start + self.batch_size) for start in range(0, n_samples, self.batch_size)]

    def exact_hessian(self, x, y, loss_fn, vectorize=True):
        """
        Calculation of the exact hessian of the Neural network given a dataset.

        Only the diagonal blocks of the Hessian are kept, i.e. one square matrix
        per model parameter tensor, in the order of ``self._model.parameters()``.

        Parameters
        ----------
        x: torch.Tensor
            Input dataset for calculating the loss.
        y: torch.Tensor
            Target dataset for calculating the loss.
        loss_fn: torch.Module
            Loss function for which to calculate the hessian. It must expose a
            ``reduction`` attribute.
        vectorize: boolean
            Use vectorization in pytorch's implementation of the hessian calculation.

        Returns
        -------
        list of torch.Tensor
            One square Hessian block per parameter tensor.
        """

        model_params = tuple(self._model.parameters())

        # Work on a copy so the caller's loss object keeps its reduction mode.
        loss_fn = copy(loss_fn)
        is_mean = loss_fn.reduction == "mean"
        if is_mean:
            # Summing per batch and rescaling afterwards lets batches be added up.
            loss_fn.reduction = "sum"
        # NOTE(review): this assumes "mean" averages over samples (len(x)); losses
        # that average over every output element would need a different factor -
        # confirm against the losses actually used.
        scale = 1 / len(x) if is_mean else 1

        def batch_blocks(x_batch, y_batch):
            def eval_model(*input_params):
                out = functional_call(self._model, dict(zip(self._param_keys, input_params)), x_batch)
                return loss_fn(out, y_batch)

            h_full = torch.autograd.functional.hessian(
                eval_model, model_params, create_graph=True, vectorize=vectorize
            )
            # h_full[i][j] is the block for parameters i and j; keep the diagonal ones.
            return [self._reshape_hessian(h_full[i][i]) * scale for i in range(len(h_full))]

        h_list = None
        for rows in self._batch_slices(len(x)):
            h_list_batch = batch_blocks(x[rows], y[rows])

            # Aggregate result
            if h_list is None:
                h_list = h_list_batch
            else:
                h_list = [prev_h + batch_h for prev_h, batch_h in zip(h_list, h_list_batch)]

        return h_list

    def approx_hessian_gn(self, x, y, loss_fn, vectorize=True):
        r"""
        Calculation of an approximate hessian of the Neural network given a dataset as in the
        Gauss-Newton algorithm.

        The approximate Hessian is calculated as the square of the Jacobian of the residual of
        every data point with respect to the parameters.

        Let the loss function be, for example the MSE:

        :math:`\mathcal{L}(x,y;\theta) = \sum^{N}_{i=1} (f(x_i; \theta) - y_i)^2 = \sum^{N}_{i=1} r_i`

        Then the Jacobian of the residuals will be the matrix:

        :math:`(J_{\theta}[\mathcal{L}])_{i,j} = \dfrac{\partial r_i}{\partial \theta_j}`

        Then, we will approximate the hessian as the product of the Jacobian with its transpose,
        noting that the result will be a square matrix with size :math:`p\times p` with :math:`p`
        being the number of parameters of the model:

        :math:`H_{\theta}[\mathcal{L}] \approx J_{\theta}[\mathcal{L}]^{\intercal} \cdot J_{\theta}[\mathcal{L}]`

        The approximation is computed independently for every parameter tensor.

        Parameters
        ----------
        x: torch.Tensor
            Input dataset for calculating the loss.
        y: torch.Tensor
            Target dataset for calculating the loss.
        loss_fn: torch.Module
            Loss function for which to calculate the hessian. It must expose a
            ``reduction`` attribute.
        vectorize: boolean
            Use vectorization in pytorch's implementation of the jacobian calculation.

        Returns
        -------
        list of torch.Tensor
            One square approximate-Hessian block per parameter tensor.
        """

        model_params = tuple(self._model.parameters())

        scale = 1 / len(x) if loss_fn.reduction == "mean" else 1

        # Unreduced copy of the loss: one value per residual instead of a scalar.
        residual_fn = copy(loss_fn)
        residual_fn.reduction = "none"

        def batch_blocks(x_batch, y_batch):
            def get_residuals(*input_params):
                out = functional_call(self._model, dict(zip(self._param_keys, input_params)), x_batch)
                return residual_fn(out, y_batch)

            j_list = torch.autograd.functional.jacobian(
                get_residuals, model_params, create_graph=False, vectorize=vectorize
            )

            blocks = []
            for j, param in zip(j_list, model_params):
                # Each jacobian has shape (*residual_shape, *param.shape). Collapse
                # it to (n_residuals, n_params) so J^T J is always (p, p), whatever
                # the rank of the residual tensor or of the parameter.
                n_residuals = self._prod(j.shape[: j.dim() - param.dim()])
                j = j.reshape(n_residuals, param.numel())
                blocks.append((j.T @ j) * scale)
            return blocks

        h_list = None
        for rows in self._batch_slices(len(x)):
            h_list_batch = batch_blocks(x[rows], y[rows])

            # Aggregate result
            if h_list is None:
                h_list = h_list_batch
            else:
                h_list = [prev_h + batch_h for prev_h, batch_h in zip(h_list, h_list_batch)]

        return h_list

    def hutchinson_diagonal(self, x, y, loss_fn, vectorize=True, n_samples=1, as_matrix=False):
        r"""
        Stochastic estimate of the diagonal of the hessian of the loss.

        For each of ``n_samples`` draws, a Rademacher vector :math:`z` (entries
        :math:`\pm 1` with equal probability) is sampled and :math:`z \odot (Hz)`
        is accumulated; the returned estimate is the average over the draws.

        Parameters
        ----------
        x: torch.Tensor
            Input dataset for calculating the loss.
        y: torch.Tensor
            Target dataset for calculating the loss.
        loss_fn: callable
            Loss function, called as ``loss_fn(model_output, y)``.
        vectorize: boolean
            Currently unused by this method.
        n_samples: int
            Number of Rademacher vectors to average over. Must be at least 1.
        as_matrix: boolean
            If true, return one diagonal matrix per parameter tensor instead of
            tensors shaped like the parameters.

        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 1.
        """

        if n_samples < 1:
            raise ValueError("n_samples must be at least 1.")

        model_params = tuple(self._model.parameters())
        params_flat = torch.hstack([p.flatten() for p in model_params])

        def eval_model(*input_params):
            out = functional_call(self._model, dict(zip(self._param_keys, input_params)), x)
            return loss_fn(out, y)

        h_diag_flat = torch.zeros_like(params_flat)
        for _ in range(n_samples):
            # Rademacher sample
            z_flat = 2 * torch.bernoulli(torch.full_like(params_flat, 0.5)) - 1
            z = tuple(param_reshape_like(z_flat, model_params))

            # Pytorch documentation recommends doing (vH)^T instead of Hv directly
            _, hz = torch.autograd.functional.vhp(eval_model, model_params, v=z, create_graph=False)
            hz_flat = torch.hstack([h.flatten() for h in hz])

            h_diag_flat += z_flat * hz_flat

        h_diag_flat /= n_samples

        h_diag = param_reshape_like(h_diag_flat, model_params)

        if as_matrix:
            return [torch.diag(h.flatten()) for h in h_diag]

        return h_diag

    @staticmethod
    def _reshape_hessian(hess: torch.Tensor):
        r"""
        Procedure to reshape a misshapen hessian matrix.

        The input is expected to be an array of size :math:`(X,Y,...,X,Y,...)` and the output
        will be a square matrix of size :math:`(X \cdot Y \cdots, X \cdot Y \cdots)`.
        A 0-dimensional input (hessian with respect to a scalar parameter) becomes
        a :math:`1 \times 1` matrix.

        Parameters
        ----------
        hess: torch.Tensor
            Misshapen hessian matrix.

        Raises
        ------
        ValueError
            If the number of dimensions is odd, or the two halves of the shape
            do not contain the same number of elements.
        """

        if hess.dim() == 2:
            return hess

        if hess.dim() % 2 != 0:
            raise ValueError("Hessian has a weird ass shape.")

        # Divide shape in two halves, multiply each half to get total size
        half = hess.dim() // 2
        n_rows = reduce(lambda a, b: a * b, hess.shape[:half], 1)
        n_cols = reduce(lambda a, b: a * b, hess.shape[half:], 1)

        if n_rows != n_cols:
            raise ValueError("Something weird happened with the hessian size")

        return hess.reshape(n_rows, n_cols)