Source code for qulearn.memory

import logging

import numpy as np
import torch

from .datagen import DataGenCapacity as Datagen
from .trainer import SupervisedTrainer as Trainer
from .types import Capacity, Model


def _bits_from_mre(mre: float) -> int:
    """
    Converts a mean relative error into a non-negative whole number of bits.

    Evaluates ``int(log2(1 / mre))`` clipped below at 0, with the degenerate
    inputs handled explicitly instead of raising:

    * NaN or infinite ``mre`` (e.g. a diverged fit) yields 0 bits.
    * ``mre >= 1`` yields 0 bits (the logarithm would be non-positive).
    * ``mre`` at or below the smallest positive normal float (including an
      exact 0.0 from a perfect fit) is clamped to that value, so the
      reciprocal and its logarithm stay finite.

    :param mre: The mean relative error.
    :type mre: float
    :return: The number of bits ``m`` such that ``mre`` is roughly ``2^(-m)``.
    :rtype: int
    """
    if not np.isfinite(mre) or mre >= 1.0:
        return 0
    # Without the clamp, 1.0 / 0.0 raises ZeroDivisionError and a subnormal
    # mre overflows to inf, which int() cannot convert.
    clamped = max(mre, np.finfo(float).tiny)
    return max(int(np.log2(1.0 / clamped)), 0)


def memory(
    model: Model,
    datagen: Datagen,
    trainer: Trainer,
    Nmin: int,
    Nmax: int,
    Nstep: int = 1,
    early_stop: bool = True,
    stop_count: int = 1,
) -> Capacity:
    """
    Estimates the memory capacity of a model over a range of values of N.

    For every ``N`` in ``range(Nmin, Nmax + 1, Nstep)`` the model is fitted to
    random labels via :func:`fit_rand_labels`, the resulting mean relative
    error ``mre`` is converted to a whole number of bits ``m`` and the
    capacity ``C = N * m`` is recorded.

    A non-finite or zero ``mre`` no longer aborts the sweep: it is mapped to a
    finite ``m`` (see :func:`_bits_from_mre`), while the raw ``mre`` is still
    stored in the result so the caller can inspect it.

    :param model: The model.
    :type model: Model
    :param datagen: The (synthetic) data generator.
    :type datagen: Datagen
    :param trainer: The trainer.
    :type trainer: Trainer
    :param Nmin: The minimum value of N.
    :type Nmin: int
    :param Nmax: The maximum value of N, included.
    :type Nmax: int
    :param Nstep: Step size for N. Defaults to 1.
    :type Nstep: int, optional
    :param early_stop: Stops early once the capacity has failed to exceed the
        capacity of the preceding iteration ``stop_count`` times in a row.
        Defaults to True.
    :type early_stop: bool, optional
    :param stop_count: See early_stop. Defaults to 1.
    :type stop_count: int, optional
    :return: List of tuples (N, mre=2^(-m), m, N*m).
    :rtype: Capacity

    .. seealso:: arXiv:1908.01364.
    """
    # A named logger leaves the application's logging configuration alone;
    # the record still propagates to whatever handlers the caller installed.
    logger = logging.getLogger(__name__)

    capacities = []
    previous_capacity = 0
    stalled = 0
    for N in range(Nmin, Nmax + 1, Nstep):
        mre = fit_rand_labels(model, datagen, trainer, N)
        m = _bits_from_mre(mre)
        C = N * m
        capacities.append((N, mre, m, C))

        # N == Nmax is never counted as a stall: the sweep ends there anyway.
        if early_stop and C <= previous_capacity and N != Nmax:
            stalled += 1
            if stalled >= stop_count:
                logger.warning("Stopping early, capacity not improving.")
                break
        else:
            stalled = 0
        previous_capacity = C

    return capacities
def fit_rand_labels(model: Model, datagen: Datagen, trainer: Trainer, N: int) -> float:
    """
    Fits random labels to a model and returns the mean relative error.

    One data set is generated with ``datagen.gen_data(N)``. For each of the
    ``datagen.num_samples`` label samples ``s`` the model is trained on the
    loader built for that sample and then evaluated on all inputs ``X``. The
    element-wise error ``|(Y[s] - y_pred) / y_pred|`` (relative to the
    *prediction*, not to the label) is averaged per sample, and the per-sample
    values are averaged again to give the result.

    .. note:: The same ``model`` object is trained on every label sample in
       turn; this function does not re-initialise it between samples.

    :param model: The model.
    :type model: Model
    :param datagen: The data generation class.
    :type datagen: Datagen
    :param trainer: The trainer.
    :type trainer: Trainer
    :param N: The number of inputs.
    :type N: int
    :return: The mean relative error.
    :rtype: float
    """
    data = datagen.gen_data(N)
    X = data["X"]
    Y = data["Y"]

    mre_sample = []
    for s in range(datagen.num_samples):
        loader = datagen.data_to_loader(data, s)
        # The same loader is passed twice -- presumably as both training and
        # validation data; confirm against the trainer's signature.
        trainer.train(model, loader, loader)

        # Pure evaluation: only the scalar error is kept, so skip building an
        # autograd graph for the forward pass.
        with torch.no_grad():
            y_pred = model(X)
            mre = torch.mean(torch.abs((Y[s] - y_pred) / y_pred))
        mre_sample.append(mre.item())

    mre_N = torch.mean(torch.tensor(mre_sample))
    return mre_N.item()