# Source code for ttfemesh.tt_tools.tensor_cross

from typing import Callable, List, Optional

import numpy as np
import teneva


class TTCrossConfig:
    """
    Configuration class for the tensor train cross approximation algorithm.

    Stores the cross-approximation options together with the ANOVA
    initialization settings, and converts them to the keyword names used by
    teneva via ``to_dict``.

    Args:
        cache (dict, optional): Cache for storing requested function values.
            Stored by reference (not copied). Defaults to None.
        info (dict, optional): Stores TTCross run information. Stored by
            reference (not copied). Defaults to None.
        num_sweeps (int, optional): Number of sweeps for DMRG. Defaults to 10.
        rel_stagnation_tol (float, optional): Relative stagnation tolerance.
            Defaults to 1e-4.
        max_func_calls (Optional[int], optional): Maximum number of function
            calls. Defaults to None.
        cache_calls_factor (int, optional): If the number of calls to cache is
            this factor times larger than number of function calls, TTCross
            stops. Defaults to 5.
        num_anova_init (int, optional): Number of training indices for ANOVA
            initialization. Defaults to 1000.
        anova_order (int, optional): Order of the ANOVA decomposition.
            Defaults to 2.
        verbose (bool, optional): Verbose output. Defaults to False.
    """

    def __init__(
        self,
        cache: Optional[dict] = None,
        info: Optional[dict] = None,
        num_sweeps: int = 10,
        rel_stagnation_tol: float = 1e-4,
        max_func_calls: Optional[int] = None,
        cache_calls_factor: int = 5,
        num_anova_init: int = 1000,
        anova_order: int = 2,
        verbose: bool = False,
    ):
        self.cache = cache
        self.info = info
        self.num_sweeps = num_sweeps
        self.rel_stagnation_tol = rel_stagnation_tol
        self.max_func_calls = max_func_calls
        self.cache_calls_factor = cache_calls_factor
        self.num_anova_init = num_anova_init
        self.anova_order = anova_order
        self.verbose = verbose

    def to_dict(self):
        """
        Convert all attributes of the configuration to a dictionary that is
        passed to teneva.

        Attribute -> key mapping:
            cache              -> "cache"
            info               -> "info"
            num_sweeps         -> "nswp"
            rel_stagnation_tol -> "e"
            verbose            -> "log"
            max_func_calls     -> "m"
            cache_calls_factor -> "m_cache_scale"
            num_anova_init     -> "num_anova_init"
            anova_order        -> "anova_order"

        A new dictionary is built on every call. ``cache`` and ``info`` are
        placed in it by reference.

        Returns:
            dict: A dictionary representation of the configuration.
        """
        # NOTE(review): "num_anova_init" and "anova_order" are ANOVA
        # initialization settings rather than cross-approximation options.
        # If this dict is forwarded verbatim to a teneva routine that does not
        # accept them, callers must remove these keys first. Confirm against
        # the call sites.
        kwargs = {
            "cache": self.cache,
            "info": self.info,
            "nswp": self.num_sweeps,
            "e": self.rel_stagnation_tol,
            "log": self.verbose,
            "m": self.max_func_calls,
            "m_cache_scale": self.cache_calls_factor,
            "num_anova_init": self.num_anova_init,
            "anova_order": self.anova_order,
        }
        return kwargs
def gen_teneva_indices(
    num_indices: int,
    tensor_shape: List[int],
    *,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """
    Generate random indices for a tensor of shape tensor_shape.

    Column ``j`` of the result is drawn uniformly, with replacement, from
    ``range(tensor_shape[j])``, so every row is a valid multi-index into the
    tensor.

    Args:
        num_indices (int): Number of indices to generate.
        tensor_shape (List[int]): Shape of the tensor (size of each mode).
        rng (Optional[np.random.Generator], optional): Generator to draw from,
            for reproducible sampling. Defaults to None, which uses NumPy's
            global legacy random state via ``np.random.choice`` (the original
            behavior, reproducible with ``np.random.seed``).

    Returns:
        np.ndarray: Random indices of shape (num_indices, len(tensor_shape)).
    """
    # Materialize once so any iterable of mode sizes is accepted and len() works.
    mode_sizes = list(tensor_shape)
    if not mode_sizes:
        # np.vstack raises on an empty sequence; a tensor with no modes has
        # empty multi-indices.
        return np.empty((num_indices, 0), dtype=int)

    if rng is None:
        columns = [np.random.choice(k, num_indices) for k in mode_sizes]
    else:
        columns = [rng.integers(0, k, size=num_indices) for k in mode_sizes]
    # Stack one row per mode, then transpose to one row per sampled index.
    return np.vstack(columns).T
def anova_init_tensor_train(
    oracle: Callable[[np.ndarray], np.ndarray],
    train_indices: np.ndarray,
    order: int = 2,
) -> List[np.ndarray]:
    """
    Build an initial tensor train from an ANOVA fit of sampled oracle values.

    The oracle is called once on the full ``train_indices`` array. Its output
    serves as the training values passed to ``teneva.anova`` along with the
    indices and ``order``.

    Args:
        oracle (Callable[[np.ndarray], np.ndarray]): Oracle function, evaluated
            on ``train_indices``.
        train_indices (np.ndarray): Training indices.
        order (int, optional): Order of the ANOVA decomposition. Defaults to 2.

    Returns:
        List[np.ndarray]: TT-cores produced by ``teneva.anova``.
    """
    return teneva.anova(train_indices, oracle(train_indices), order=order)
def tensor_train_cross_approximation(
    oracle: Callable[[np.ndarray], np.ndarray],
    tt_init: List[np.ndarray],
    **kwargs
) -> List[np.ndarray]:
    """
    Approximate a tensor train from oracle evaluations via ``teneva.cross``.

    This is a thin pass-through. ``oracle`` and ``tt_init`` are forwarded
    positionally, and every keyword argument is forwarded unchanged.

    Args:
        oracle (Callable[[np.ndarray], np.ndarray]): Oracle function.
        tt_init (List[np.ndarray]): Initial tensor train.
        **kwargs: Extra keyword arguments forwarded to ``teneva.cross``.

    Returns:
        List[np.ndarray]: TT-cores returned by ``teneva.cross``.
    """
    approx_cores = teneva.cross(oracle, tt_init, **kwargs)
    return approx_cores
def error_on_indices(
    oracle: Callable[[np.ndarray], np.ndarray],
    approx_tt: List[np.ndarray],
    test_indices: np.ndarray,
) -> float:
    """
    Measure how well a tensor train matches the oracle on given indices.

    The oracle is called once on ``test_indices`` to obtain reference values.
    ``teneva.accuracy_on_data`` then compares the approximation against them.

    Args:
        oracle (Callable[[np.ndarray], np.ndarray]): Oracle function.
        approx_tt (List[np.ndarray]): Approximated tensor train cores.
        test_indices (np.ndarray): Test indices.

    Returns:
        float: Relative error reported by ``teneva.accuracy_on_data``.
    """
    reference_values = oracle(test_indices)
    return teneva.accuracy_on_data(approx_tt, test_indices, reference_values)
def error_on_random_indices(
    oracle: Callable[[np.ndarray], np.ndarray],
    approx_tt: List[np.ndarray],
    num_test_indices: int,
    tensor_shape: List[int],
) -> float:
    """
    Measure tensor-train accuracy on freshly sampled random indices.

    First draws ``num_test_indices`` random multi-indices for ``tensor_shape``
    with ``gen_teneva_indices``. Then delegates to ``error_on_indices``.

    Args:
        oracle (Callable[[np.ndarray], np.ndarray]): Oracle function.
        approx_tt (List[np.ndarray]): Approximated tensor train cores.
        num_test_indices (int): Number of test indices.
        tensor_shape (List[int]): Shape of the tensor.

    Returns:
        float: Relative error of the approximated tensor train.
    """
    return error_on_indices(
        oracle, approx_tt, gen_teneva_indices(num_test_indices, tensor_shape)
    )