Source code for torch_numopt.first_order.conjugate_gradient

from __future__ import annotations
from typing import Iterable
import torch
import torch.nn as nn
from ..line_search import create_line_search_solver
from ..numerical_optimizer import NumericalOptimizer, LineSearchOptimizer
from ..scaling_matrix_calculator import *
from ..utils import param_reshape_like


class ConjugateGradient(NumericalOptimizer):
    """
    Heavily inspired by https://github.com/hahnec/torchimize/blob/master/torchimize/optimizer/gna_opt.py
    https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf
    https://arxiv.org/abs/2201.08568

    Parameters
    ----------

    model: nn.Module
        The model to be optimized
    lr_init: float
        Maximum learning rate in backtracking line search, if the learning rate is set as constant, this will be the value used.
    lr_method: str
        Method to use to initialize the learning rate before applying line search.
    c1: float
        Coefficient of the sufficient increase condition in backtracking line search.
    c2: float
        Coefficient used in the second condition for wolfe conditions.
    tau: float
        Factor used to reduce the step size in each step of the backtracking line search.
    line_search_method: str
        Method used for line search, options are "backtrack" and "constant".
    line_search_cond: str
        Condition to be used in backtracking line search, options are "armijo", "wolfe", "strong-wolfe" and "goldstein".
    cg_method: str
        Formula used to calculate the conjugate gradient, options are "FR", "PR" and "PRP+".
    """

    def __init__(
        self,
        model: nn.Module,
        lr_init: float = 1,
        lr_method: str | None = None,
        cg_method: str = "PRP+",
    ):
        super().__init__(
            model,
            scaling_matrix=NaiveIdentityCalculator(model=model),
            lr_init=lr_init,
            lr_method=lr_method,
        )

        # Conjugate gradient memory
        self.cg_method = cg_method

    def get_step_direction(self, d_p_list, _):
        """ """

        if self.prev_grad is None:
            return d_p_list

        grad = torch.hstack([i.flatten() for i in d_p_list])
        prev_grad = torch.hstack([i.flatten() for i in self.prev_grad])
        prev_step = torch.hstack([i.flatten() for i in self.prev_step_dir])

        res = -grad
        prev_res = -prev_grad

        eps = torch.finfo(res.dtype).eps
        match self.cg_method:
            case "FR":
                beta = torch.dot(res, res) / (torch.dot(prev_res, prev_res) + eps)
            case "PR":
                beta = torch.dot(res, res - prev_res) / (torch.dot(prev_res, prev_res) + eps)
            case "PRP+":
                beta = torch.dot(res, res - prev_res) / (torch.dot(prev_res, prev_res) + eps)
                beta = torch.relu(beta)
            case "HS":
                beta = torch.dot(res, res - prev_res) / (torch.dot(prev_step, res - prev_res) + eps)
            case "DY":
                beta = torch.dot(res, res) / (torch.dot(-prev_step, res - prev_res) + eps)
            case _:
                raise ValueError("Incorrect conjugate gradient method, try 'FR', 'PR' or 'PRP+', 'HS', 'DY'.")

        # Invert sign since we update the weights like x - lr*step
        next_dir = param_reshape_like(grad - beta * res, d_p_list)
        return next_dir


[docs] class ConjugateGradientLS(LineSearchOptimizer): """ Heavily inspired by https://github.com/hahnec/torchimize/blob/master/torchimize/optimizer/gna_opt.py https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf https://arxiv.org/abs/2201.08568 Parameters ---------- model: nn.Module The model to be optimized lr_init: float Maximum learning rate in backtracking line search, if the learning rate is set as constant, this will be the value used. lr_method: str Method to use to initialize the learning rate before applying line search. c1: float Coefficient of the sufficient increase condition in backtracking line search. c2: float Coefficient used in the second condition for wolfe conditions. tau: float Factor used to reduce the step size in each step of the backtracking line search. line_search_method: str Method used for line search, options are "backtrack" and "constant". line_search_cond: str Condition to be used in backtracking line search, options are "armijo", "wolfe", "strong-wolfe" and "goldstein". cg_method: str Formula used to calculate the conjugate gradient, options are "FR", "PR" and "PRP+". """ def __init__( self, model: nn.Module, lr_init: float = 1, lr_method: str = None, c1: float = 1e-4, c2: float = 0.9, tau: float = 0.1, line_search_method: str = "backtrack", line_search_cond: str = "armijo", cg_method: str = "PRP+", **kwargs, ): super().__init__( model, scaling_matrix=NaiveIdentityCalculator(model=model), lr_init=lr_init, lr_method=lr_method, line_search=create_line_search_solver(method=line_search_method, condition=line_search_cond, c1=c1, c2=c2, tau=tau), ) # Conjugate gradient memory self.cg_method = cg_method
[docs] def get_step_direction(self, d_p_list, _): """ """ if self.prev_grad is None: return d_p_list grad = torch.hstack([i.flatten() for i in d_p_list]) prev_grad = torch.hstack([i.flatten() for i in self.prev_grad]) prev_step = torch.hstack([i.flatten() for i in self.prev_step_dir]) res = -grad prev_res = -prev_grad eps = torch.finfo(res.dtype).eps match self.cg_method: case "FR": beta = torch.dot(res, res) / (torch.dot(prev_res, prev_res) + eps) case "PR": beta = torch.dot(res, res - prev_res) / (torch.dot(prev_res, prev_res) + eps) case "PRP+": beta = torch.dot(res, res - prev_res) / (torch.dot(prev_res, prev_res) + eps) beta = torch.relu(beta) case "HS": beta = torch.dot(res, res - prev_res) / (torch.dot(prev_step, res - prev_res) + eps) case "DY": beta = torch.dot(res, res) / (torch.dot(-prev_step, res - prev_res) + eps) case _: raise ValueError("Incorrect conjugate gradient method, try 'FR', 'PR' or 'PRP+', 'HS', 'DY'.") # Invert sign since we update the weights like x - lr*step next_dir = param_reshape_like(grad - beta * res, d_p_list) return next_dir