# Source code for qulearn.fat

import logging

import torch

from .datagen import DataGenFat as Datagen
from .trainer import SupervisedTrainer as Trainer
from .types import QModel, Tensor


def fat_shattering_dim(
    model: QModel,
    datagen: Datagen,
    trainer: Trainer,
    dmin: int,
    dmax: int,
    gamma: float = 0.0,
    dstep: int = 1,
) -> int:
    """
    Estimate the fat-shattering dimension for a model with a given architecture.

    Candidate set sizes ``dmin, dmin + dstep, ...`` (never exceeding ``dmax``)
    are tried in increasing order with :func:`check_shattering`. The search
    stops at the first size that is not shattered.

    :param model: The model.
    :type model: QModel
    :param datagen: The (synthetic) data generator.
    :type datagen: Datagen
    :param trainer: The trainer.
    :type trainer: Trainer
    :param dmin: Iteration start for dimension check.
    :type dmin: int
    :param dmax: Iteration stop for dimension check (including).
    :type dmax: int
    :param gamma: The margin value. Defaults to 0.0 (pseudo-dim).
    :type gamma: float, optional
    :param dstep: Dimension iteration step size. Defaults to 1.
    :type dstep: int
    :return: The estimated fat-shattering dimension: the largest candidate
        size that was checked and found shattered. If the very first
        candidate (``dmin``) is not shattered, ``dmin`` is returned and a
        warning is logged. If no candidate lies in the range at all (for
        example ``dmin > dmax``), nothing is checked and ``dmax`` is returned.
    :rtype: int
    :raises ValueError: If ``dstep`` is zero (raised by ``range``).
    """
    # Largest candidate size confirmed as shattered so far; None until the
    # first successful check.
    largest_shattered = None

    for d in range(dmin, dmax + 1, dstep):
        if check_shattering(model, datagen, trainer, d, gamma):
            largest_shattered = d
            continue

        if largest_shattered is None:
            # The first candidate already failed, so no smaller size has been
            # verified; report dmin and flag it to the user.
            logging.basicConfig(level=logging.WARNING)
            logging.warning(f"Stopped at dmin = {dmin}.")
            return dmin
        return largest_shattered

    # Every candidate in the range was shattered (or the range was empty).
    logging.basicConfig(level=logging.WARNING)
    logging.warning(f"Reached dmax = {dmax}.")

    if largest_shattered is None:
        # Empty range: nothing was checked, keep returning the cap.
        return dmax
    # With dstep > 1 the last candidate may be smaller than dmax; return the
    # size that was actually verified instead of the unchecked cap.
    return largest_shattered
def check_shattering(
    model: QModel, datagen: Datagen, trainer: Trainer, d: int, gamma: float
) -> bool:
    """
    Check if the model shatters a given dimension d with margin value gamma.

    A single data set is requested from ``datagen.gen_data(d)`` and its
    ``"X"``, ``"b"`` and ``"r"`` entries are read. For each row index of
    ``r`` the model is trained once per row index of ``b`` and then evaluated
    on ``X``. A row of ``r`` succeeds when, for every row of ``b``, each
    prediction is ``>= r + gamma`` where ``b`` is 1 and ``<= r - gamma`` where
    ``b`` is 0. The first failing prediction abandons the current row of
    ``r``.

    :param model: The model.
    :type model: QModel
    :param datagen: The (synthetic) data generator.
    :type datagen: Datagen
    :param trainer: The trainer.
    :type trainer: Trainer
    :param d: Size of data set to shatter.
    :type d: int
    :param gamma: The margin value.
    :type gamma: float
    :return: True if the model shatters a random data set of size d,
        False otherwise.
    :rtype: bool
    """
    sample = datagen.gen_data(d)
    inputs, labels, offsets = sample["X"], sample["b"], sample["r"]

    for r_idx in range(len(offsets)):
        for b_idx in range(len(labels)):
            loader = datagen.data_to_loader(sample, r_idx, b_idx)
            # The same loader is handed to the trainer twice
            # (presumably as training and validation data -- TODO confirm).
            trainer.train(model, loader, loader)

            if _margin_violated(
                model(inputs), labels, offsets, b_idx, r_idx, gamma
            ):
                # This row of r cannot be realised; try the next one.
                break
        else:
            # No labelling broke the margin for this row of r.
            return True

    return False


def _margin_violated(predictions, labels, offsets, b_idx, r_idx, gamma):
    """Return True at the first prediction that misses its margin target."""
    for i, pred in enumerate(predictions):
        label = labels[b_idx, i]
        if label == 1:
            # Label 1: prediction must clear the offset by at least gamma.
            if not (pred >= offsets[r_idx, i] + gamma):
                return True
        elif label == 0:
            # Label 0: prediction must stay below the offset by at least gamma.
            if not (pred <= offsets[r_idx, i] - gamma):
                return True
    return False
def normalize_const(weights: Tensor, gamma: float, Rx: float) -> float:
    """
    Compute a normalization constant given a tensor of weights
    and the margin parameter gamma.

    Rationale: the fat-shattering dimension of a linear classifier,
    with weights bounded by Rw and data bounded by Rx, is bounded by
    <= Rw^2*Rx^2/gamma^2. Hence, normalizing the fat-shattering dimension
    of a model with unbounded weights compares it to the best linear
    classifier with the same weight norm.

    :param weights: Tensor of weights
    :type weights: Tensor
    :param gamma: Margin parameter.
    :type gamma: float
    :param Rx: Estimated 2-radius of input data.
    :type Rx: float
    :return: A positive real-valued normalization constant.
    :rtype: float
    :raises ZeroDivisionError: If ``gamma`` is a Python number equal to zero.
    """
    # torch.norm is deprecated; vector_norm with dim=None treats the tensor
    # as flattened, i.e. the 2-norm over all weight entries.
    weight_norm = torch.linalg.vector_norm(weights, ord=2).item()
    return weight_norm**2 * Rx**2 / gamma**2