Source code for torch_numopt.utils

""" """

import torch
import torch.linalg


def param_sizes(params: list):
    """
    Collects the shape of each entry in a list of parameters.

    Parameters
    ----------

    params: list
        List of matrices (anything exposing a ``shape`` attribute).

    Returns
    -------

    shapes: list
        The ``shape`` of every element of `params`, in the same order.
    """

    shapes = []
    for entry in params:
        shapes.append(entry.shape)
    return shapes
def param_reshape_like(params_flat: torch.Tensor, params: list):
    """
    Splits a flat vector into consecutive chunks and reshapes each chunk to
    match the corresponding entry of `params`.

    Parameters
    ----------

    params_flat: Tensor
        Vector holding all the parameter values one after another.
    params: list
        List of matrices whose shapes (and order) are used as templates.

    Returns
    -------

    reshaped_params: list
        One tensor per entry of `params`, each with that entry's shape.
    """

    reshaped = []
    start = 0
    for template in params:
        # Number of scalars consumed by this template.
        stop = start + int(template.flatten().shape[0])
        chunk = params_flat[start:stop]
        reshaped.append(chunk.reshape(template.shape))
        start = stop
    return reshaped
[docs] def param_flatten(params: list): return torch.hstack(_param_flatten_rec(params))
def _param_flatten_rec(params: list): all_params = [] for i in params: if isinstance(i, torch.Tensor): all_params.append(i.flatten()) else: all_params += param_flatten(i) return all_params
def fix_stability(mat: torch.Tensor):
    """
    Procedure to adjust a matrix by adding a very small value to the diagonal
    to avoid numerical instability problems.

    The value added is the machine epsilon of the dtype of `mat`, so `mat`
    must have a floating point dtype. The input is not modified.

    Parameters
    ----------

    mat: torch.Tensor
        Ill conditioned matrix.

    Returns
    -------

    fixed_mat: torch.Tensor
        (Hopefully) Well conditioned matrix, with the same dtype and device
        as `mat`.
    """

    eps = torch.finfo(mat.dtype).eps
    # Build the identity in the dtype of `mat`; with the default dtype a
    # half-precision input would be silently promoted to float32.
    identity = torch.eye(mat.shape[0], device=mat.device, dtype=mat.dtype)
    return mat + identity * eps
def pinv_svd_trunc(mat: torch.Tensor, thresh: float = 1e-4):
    """
    Procedure to calculate the pseudoinverse of a matrix by using truncated
    SVD in order to maintain numerical stability.

    Singular values strictly below `thresh` are treated as zero, i.e. their
    reciprocal is replaced by 0 instead of a huge (or infinite) value.

    Parameters
    ----------

    mat: torch.Tensor
        Problematic matrix that we want to invert. It may be square or
        rectangular.
    thresh: float
        Absolute threshold applied to the singular values (the S matrix in
        the SVD procedure).

    Returns
    -------

    inverted_mat: torch.Tensor
        Pseudoinverse of the input matrix; for an input of shape (m, n) the
        result has shape (n, m).
    """

    # Reduced SVD: with full_matrices=True the factors of a non-square matrix
    # have incompatible shapes for the product below. For square input both
    # variants coincide.
    U, S, Vt = torch.linalg.svd(mat, full_matrices=False)

    # Invert only the singular values that pass the threshold.
    S_inv_trunc = torch.where(S < thresh, torch.zeros_like(S), 1.0 / S)

    # Scaling the columns of V is equivalent to V @ diag(S_inv_trunc) without
    # materialising the diagonal matrix.
    return (Vt.T * S_inv_trunc) @ U.T