def __call__(self, img):
    """
    Converts a PIL image into a uint8 tensor laid out as 3xHxW.

    Images in a mode other than RGB (e.g. L, RGBA, P, YCbCr) are converted to RGB first. Without this, 1- or 4-byte
    modes fail in `.view(height, width, 3)` with an opaque error, and 3-byte non-RGB modes (YCbCr, LAB, HSV) would be
    silently reinterpreted as RGB.

    Args:
        img (PIL.Image.Image): Input image; any subclass (e.g. JpegImageFile) is accepted.

    Returns:
        torch.Tensor: uint8 tensor of shape (3, height, width).
    """
    vassert(isinstance(img, Image.Image), 'Input is not a PIL.Image')
    if img.mode != 'RGB':
        img = img.convert('RGB')
    width, height = img.size
    # tobytes() yields row-major interleaved RGB bytes: H x W x 3.
    img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
    # Channels-first layout expected by PyTorch models.
    img = img.permute(2, 0, 1)
    return img
def glob_samples_paths(path, samples_find_deep, samples_find_ext, samples_ext_lossy=None, verbose=True):
    """
    Collects paths of sample files with the requested extensions under a directory.

    Args:
        path (str): Root directory to search.
        samples_find_deep (bool): Descend into subdirectories if `True`; otherwise only files directly inside `path`
            are considered.
        samples_find_ext (str): Comma-separated list of extensions to match (e.g. "png,jpg"). Matching is
            case-insensitive and a leading dot on each entry is optional.
        samples_ext_lossy (str or None): Comma-separated list of extensions regarded as lossy; only used to add a
            warning to the summary message. Same normalization as `samples_find_ext`.
        verbose (bool): Print progress messages.

    Returns:
        list: Sorted list of resolved (`os.path.realpath`) file paths.
    """
    vassert(type(samples_find_ext) is str and samples_find_ext != '', 'Sample extensions not specified')
    vassert(
        samples_ext_lossy is None or type(samples_ext_lossy) is str, 'Lossy sample extensions can be None or string'
    )
    vprint(verbose, f'Looking for samples {"recursively" if samples_find_deep else "non-recursively"} in "{path}" '
                    f'with extensions {samples_find_ext}')

    def parse_extensions(spec):
        # File extensions below are lowercased and dot-stripped, so normalize user input ("PNG", ".jpg") the same way.
        return {e.strip().lower().lstrip('.') for e in spec.split(',')} - {''}

    find_exts = parse_extensions(samples_find_ext)
    lossy_exts = parse_extensions(samples_ext_lossy) if samples_ext_lossy is not None else set()

    have_lossy = False
    files = []
    for root, _, filenames in os.walk(path):
        for filename in filenames:
            ext = os.path.splitext(filename)[1].lower()
            if ext.startswith('.'):
                ext = ext[1:]
            if ext not in find_exts:
                continue
            if ext in lossy_exts:
                have_lossy = True
            files.append(os.path.realpath(os.path.join(root, filename)))
        if not samples_find_deep:
            # os.walk is top-down: the first entry is `path` itself, so stop instead of walking the whole tree.
            break
    files.sort()
    vprint(verbose, f'Found {len(files)} samples'
                    f'{", some are lossy-compressed - this may affect metrics" if have_lossy else ""}')
    return files
def convert_features_tuple_to_dict(self, features):
    """
    Rebuilds the feature-name -> feature mapping from the tuple returned by the forward function.

    A tuple is the only compound return type of `forward` that JIT tracing supports, so features come back
    positionally, in the same order as `self.features_list`.

    Args:
        features (tuple): Output of the forward function, one entry per name in `self.features_list`.

    Returns:
        dict: Feature name -> corresponding entry of `features`.
    """
    matches_forward_output = type(features) is tuple and len(features) == len(self.features_list)
    vassert(matches_forward_output, 'Features must be the output of forward function')
    return dict(zip(self.features_list, features))
def create_sample_similarity(name, cuda=True, **kwargs):
    """
    Instantiates a registered sample similarity measure in evaluation mode.

    Args:
        name (str): Key in `SAMPLE_SIMILARITY_REGISTRY`.
        cuda (bool): Move the instance to CUDA after construction. Default: `True`.
        **kwargs: Forwarded to the registered class constructor; `verbose` is also read for logging.

    Returns:
        The constructed sample similarity module.
    """
    vassert(name in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" not registered')
    vprint(get_kwarg('verbose', kwargs), f'Creating sample similarity "{name}"')
    similarity_cls = SAMPLE_SIMILARITY_REGISTRY[name]
    instance = similarity_cls(name, **kwargs)
    instance.eval()
    if cuda:
        instance.cuda()
    return instance
def prepare_input_descriptor_from_input_id(input_id, **kwargs):
    """
    Builds an input descriptor dict either from an input slot number or from a registered dataset name.

    Args:
        input_id (int or str): Integer slot (resolved by `make_input_descriptor_from_int` from kwargs), or the name
            of a dataset registered in `DATASETS_REGISTRY`.
        **kwargs: Forwarded to `make_input_descriptor_from_int` for integer ids.

    Returns:
        dict: Input descriptor.
    """
    is_slot_index = type(input_id) is int
    is_registered_name = type(input_id) is str and input_id in DATASETS_REGISTRY
    vassert(is_slot_index or is_registered_name,
            'Input can be either integer (1 or 2) specifying the first or the second set of kwargs, or a string as a '
            'shortcut for registered datasets')
    if is_slot_index:
        return make_input_descriptor_from_int(input_id, **kwargs)
    return make_input_descriptor_from_str(input_id)
def create_feature_extractor(name, list_features, cuda=True, **kwargs):
    """
    Instantiates a registered feature extractor in evaluation mode.

    Args:
        name (str): Key in `FEATURE_EXTRACTORS_REGISTRY`.
        list_features (list or tuple): Names of features the extractor should return.
        cuda (bool): Move the instance to CUDA after construction. Default: `True`.
        **kwargs: Forwarded to the registered class constructor; `verbose` is also read for logging.

    Returns:
        The constructed feature extractor module.
    """
    vassert(name in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" not registered')
    vprint(get_kwarg('verbose', kwargs), f'Creating feature extractor "{name}" with features {list_features}')
    extractor_cls = FEATURE_EXTRACTORS_REGISTRY[name]
    extractor = extractor_cls(name, list_features, **kwargs)
    extractor.eval()
    if cuda:
        extractor.cuda()
    return extractor
def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021):
    """
    Pre-generates `num_samples` random tensors of shape `dimensions` deterministically from `seed`.

    A private `torch.Generator` is used, so the global RNG state is not touched. The previous approach called
    `torch.manual_seed`, which also reseeds the CUDA generators, and then restored only the CPU state. A fresh CPU
    generator seeded with `seed` yields the same values as the reseeded default CPU generator did.

    Args:
        num_samples (int): Number of samples to generate.
        *dimensions (int): Shape of each sample.
        dtype (torch.dtype): Only `torch.uint8` is supported.
        seed (int): Seed for the private generator. Default: 2021.
    """
    vassert(dtype == torch.uint8, 'Unsupported dtype')
    generator = torch.Generator()
    generator.manual_seed(seed)
    # NOTE: the upper bound of torch.randint is exclusive, so the value 255 never occurs; kept unchanged so that
    # previously generated data remains bit-for-bit reproducible.
    self.imgs = torch.randint(0, 255, (num_samples, *dimensions), dtype=dtype, generator=generator)
def make_input_descriptor_from_str(input_str):
    """
    Builds an input descriptor for a registered dataset name.

    The cache name is the dataset name itself; the generative-model fields are filled from the `input1_*` entries
    of `DEFAULTS`.

    Args:
        input_str (str): Name registered in `DATASETS_REGISTRY`.

    Returns:
        dict: Input descriptor.
    """
    vassert(type(input_str) is str and input_str in DATASETS_REGISTRY,
            f'Supported input str: {list(DATASETS_REGISTRY.keys())}')
    descriptor = {'input': input_str, 'input_cache_name': input_str}
    for field in ('z_type', 'z_size', 'num_classes', 'num_samples'):
        descriptor[f'input_model_{field}'] = DEFAULTS[f'input1_model_{field}']
    return descriptor
def __init__(self, name):
    """
    Base class for sample similarity measures usable from :func:`calculate_metrics`.

    Args:
        name (str): Unique name of the subclassed sample similarity measure; must match the name passed to
            :func:`register_sample_similarity`.
    """
    super(SampleSimilarityBase, self).__init__()
    name_is_str = type(name) is str
    vassert(name_is_str, 'Sample similarity name must be a string')
    self.name = name
def kid_features_to_metric(features_1, features_2, **kwargs):
    """
    Computes Kernel Inception Distance (mean and std over random subsets) from two feature matrices.

    `kid_subsets` times, `kid_subset_size` rows are drawn without replacement from each input. `polynomial_mmd` is
    evaluated on each pair, and the mean and standard deviation of the results are reported.

    Args:
        features_1 (torch.Tensor): 2-D features of the first input, one row per sample.
        features_2 (torch.Tensor): 2-D features of the second input, same number of columns as `features_1`.
        **kwargs: Reads `kid_subsets`, `kid_subset_size`, `kid_degree`, `kid_gamma`, `kid_coef0`, `rng_seed`,
            `verbose` via `get_kwarg`.

    Returns:
        dict: `{KEY_METRIC_KID_MEAN: float, KEY_METRIC_KID_STD: float}`.
    """
    assert torch.is_tensor(features_1) and features_1.dim() == 2
    assert torch.is_tensor(features_2) and features_2.dim() == 2
    assert features_1.shape[1] == features_2.shape[1]

    kid_subsets = get_kwarg('kid_subsets', kwargs)
    kid_subset_size = get_kwarg('kid_subset_size', kwargs)
    verbose = get_kwarg('verbose', kwargs)

    n_samples_1, n_samples_2 = len(features_1), len(features_2)
    # Subsets are drawn without replacement, so each input must have at least kid_subset_size rows.
    vassert(
        n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size,
        f'KID subset size {kid_subset_size} cannot be larger than the number of samples (input_1: {n_samples_1}, '
        f'input_2: {n_samples_2}). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to '
        f'proceed.')

    features_1 = features_1.cpu().numpy()
    features_2 = features_2.cpu().numpy()

    # Kernel hyperparameters are loop-invariant.
    kid_degree = get_kwarg('kid_degree', kwargs)
    kid_gamma = get_kwarg('kid_gamma', kwargs)
    kid_coef0 = get_kwarg('kid_coef0', kwargs)

    mmds = np.zeros(kid_subsets)
    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))

    for i in tqdm(range(kid_subsets), disable=not verbose, leave=False, unit='subsets',
                  desc='Kernel Inception Distance'):
        f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)]
        f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)]
        mmds[i] = polynomial_mmd(f1, f2, kid_degree, kid_gamma, kid_coef0)

    out = {
        KEY_METRIC_KID_MEAN: float(np.mean(mmds)),
        KEY_METRIC_KID_STD: float(np.std(mmds)),
    }

    vprint(
        verbose, f'Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]} ± {out[KEY_METRIC_KID_STD]}'
    )

    return out
def forward(self, in0, in1):
    """
    Computes a learned perceptual distance between two image batches.

    Both inputs are normalized with `self.normalize` and optionally resized, then passed through `self.net`. For each
    of the `self.L` layers, the two feature maps are normalized with `normalize_tensor`. Their squared difference is
    weighted by `self.lins[k].model` and spatially averaged, and the per-layer results are summed.

    Args:
        in0 (torch.Tensor): Bx3xHxW batch.
        in1 (torch.Tensor): Bx3xHxW batch.

    Returns:
        Sum over layers of the spatially averaged weighted differences.
    """
    vassert(torch.is_tensor(in0) and torch.is_tensor(in1), 'Inputs must be torch tensors')
    vassert(in0.dim() == 4 and in0.shape[1] == 3, 'Input 0 is not Bx3xHxW')
    vassert(in1.dim() == 4 and in1.shape[1] == 3, 'Input 1 is not Bx3xHxW')
    if self.sample_similarity_dtype is not None:
        expected_dtype = self.SUPPORTED_DTYPES.get(self.sample_similarity_dtype, None)
        vassert(
            expected_dtype is not None and in0.dtype == expected_dtype and in1.dtype == expected_dtype,
            f'Unexpected input dtype ({in0.dtype})')

    x0, x1 = self.normalize(in0), self.normalize(in1)
    target_size = self.sample_similarity_resize
    if target_size is not None:
        x0 = self.resize(x0, target_size)
        x1 = self.resize(x1, target_size)

    layers0 = self.net.forward(x0)
    layers1 = self.net.forward(x1)

    per_layer = []
    for layer_idx in range(self.L):
        f0 = normalize_tensor(layers0[layer_idx])
        f1 = normalize_tensor(layers1[layer_idx])
        sq_diff = (f0 - f1) ** 2
        per_layer.append(spatial_average(self.lins[layer_idx].model(sq_diff)))
    return sum(per_layer)
def make_input_descriptor_from_int(input_int, **kwargs):
    """
    Builds an input descriptor from the `input1_*` or `input2_*` kwargs.

    If the chosen input is the name of a registered dataset, that name overrides the cache name.

    Args:
        input_int (int): Input slot, 1 or 2.
        **kwargs: Source of the `input{N}` and `input{N}_*` values, read via `get_kwarg`.

    Returns:
        dict: Input descriptor.
    """
    vassert(input_int in (1, 2), 'Supported input slots: 1, 2')
    prefix = f'input{input_int}'
    input_value = get_kwarg(prefix, kwargs)
    descriptor = {'input': input_value}
    for suffix in ('cache_name', 'model_z_type', 'model_z_size', 'model_num_classes', 'model_num_samples'):
        descriptor[f'input_{suffix}'] = get_kwarg(f'{prefix}_{suffix}', kwargs)
    if type(input_value) is str and input_value in DATASETS_REGISTRY:
        descriptor['input_cache_name'] = input_value
    return descriptor
def get_featuresdict_from_generative_model(gen_model, feat_extractor, num_samples, batch_size, cuda, rng_seed, verbose):
    """
    Samples a generative model in batches and collects extracted features.

    Noise comes from `NOISE_SOURCE_REGISTRY[gen_model.z_type]`. For conditional models (`num_classes > 0`), uniform
    random class labels are appended. Both are drawn from one `np.random.RandomState(rng_seed)`.

    Args:
        gen_model (GenerativeModelBase): Model to sample from.
        feat_extractor (FeatureExtractorBase): Applied to each batch of generated samples.
        num_samples (int): Total number of samples to generate.
        batch_size (int): Samples per batch (clamped to `num_samples`).
        cuda (bool): Run the model and inputs on CUDA.
        rng_seed (int): Seed for noise and label generation.
        verbose (bool): Show progress.

    Returns:
        dict: Feature name -> CPU tensor of features for all samples, concatenated along dim 0.
    """
    vassert(isinstance(gen_model, GenerativeModelBase), 'Input can only be a GenerativeModel instance')
    vassert(
        isinstance(feat_extractor, FeatureExtractorBase), 'Feature extractor is not a subclass of FeatureExtractorBase'
    )

    if batch_size > num_samples:
        batch_size = num_samples

    out = None

    rng = np.random.RandomState(rng_seed)

    if cuda:
        gen_model.cuda()

    with tqdm(disable=not verbose, leave=False, unit='samples', total=num_samples, desc='Processing samples') as t, \
            torch.no_grad():
        for sample_start in range(0, num_samples, batch_size):
            sample_end = min(sample_start + batch_size, num_samples)
            sz = sample_end - sample_start

            noise = NOISE_SOURCE_REGISTRY[gen_model.z_type](rng, (sz, gen_model.z_size))
            if cuda:
                noise = noise.cuda(non_blocking=True)
            gen_args = [noise]
            if gen_model.num_classes > 0:
                # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`, so draws are unchanged.
                cond_labels = torch.from_numpy(rng.randint(low=0, high=gen_model.num_classes, size=(sz,), dtype=int))
                if cuda:
                    cond_labels = cond_labels.cuda(non_blocking=True)
                gen_args.append(cond_labels)

            fakes = gen_model(*gen_args)
            features = feat_extractor(fakes)
            featuresdict = feat_extractor.convert_features_tuple_to_dict(features)

            # Append per-batch chunks instead of rebuilding lists each batch (which was quadratic).
            if out is None:
                out = {k: [v.cpu()] for k, v in featuresdict.items()}
            else:
                for k, chunks in out.items():
                    chunks.append(featuresdict[k].cpu())

            t.update(sz)

    vprint(verbose, 'Processing samples')

    out = {k: torch.cat(v, dim=0) for k, v in out.items()}

    return out
def prepare_input_from_descriptor(input_desc, **kwargs):
    """
    Resolves the `input` field of an input descriptor into a Dataset or a GenerativeModelBase.

    Accepted values of `input_desc['input']`:
      - a Dataset or GenerativeModelBase instance (returned as-is);
      - a name registered in `DATASETS_REGISTRY` (instantiated under `datasets_root`);
      - a directory of sample files (wrapped in `ImagesPathDataset`);
      - a path to an `.onnx` file (wrapped in `GenerativeModelONNX`);
      - a path to a `.pth` TorchScript module (loaded on CPU and wrapped in `GenerativeModelModuleWrapper`).

    Args:
        input_desc (dict): Input descriptor with `input` and `input_model_*` fields.
        **kwargs: Reads `datasets_root`, `datasets_download`, `samples_find_deep`, `samples_find_ext`,
            `samples_ext_lossy`, `verbose` via `get_kwarg`.

    Returns:
        Dataset or GenerativeModelBase.
    """
    bad_input = False
    source = input_desc['input']
    if type(source) is str:
        if source in DATASETS_REGISTRY:
            datasets_root = get_kwarg('datasets_root', kwargs)
            datasets_download = get_kwarg('datasets_download', kwargs)
            fn_instantiate = DATASETS_REGISTRY[source]
            if datasets_root is None:
                datasets_root = os.path.join(torch.hub._get_torch_home(), 'fidelity_datasets')
            os.makedirs(datasets_root, exist_ok=True)
            source = fn_instantiate(datasets_root, datasets_download)
        elif os.path.isdir(source):
            samples_dir = source
            samples_find_deep = get_kwarg('samples_find_deep', kwargs)
            samples_find_ext = get_kwarg('samples_find_ext', kwargs)
            samples_ext_lossy = get_kwarg('samples_ext_lossy', kwargs)
            verbose = get_kwarg('verbose', kwargs)
            paths = glob_samples_paths(samples_dir, samples_find_deep, samples_find_ext, samples_ext_lossy, verbose)
            # Report the searched directory, not the (empty) list of found paths.
            vassert(len(paths) > 0, f'No samples found in {samples_dir} with samples_find_deep={samples_find_deep}')
            source = ImagesPathDataset(paths)
        elif os.path.isfile(source) and source.endswith('.onnx'):
            source = GenerativeModelONNX(
                source,
                input_desc['input_model_z_size'],
                input_desc['input_model_z_type'],
                input_desc['input_model_num_classes']
            )
        elif os.path.isfile(source) and source.endswith('.pth'):
            source = torch.jit.load(source, map_location='cpu')
            source = GenerativeModelModuleWrapper(
                source,
                input_desc['input_model_z_size'],
                input_desc['input_model_z_type'],
                input_desc['input_model_num_classes']
            )
        else:
            bad_input = True
    elif isinstance(source, Dataset) or isinstance(source, GenerativeModelBase):
        pass
    else:
        bad_input = True
    vassert(
        not bad_input,
        f'Input descriptor "input" field can be either an instance of Dataset, GenerativeModelBase class, or a string, '
        f'such as a path to a name of a registered dataset ({", ".join(DATASETS_REGISTRY.keys())}), a directory with '
        f'file samples, or a path to an ONNX or PTH (JIT) module'
    )
    return source
def __init__(self, module, z_size, z_type, num_classes, make_copy=False, make_eval=True, cuda=None):
    """
    Adapts an arbitrary generative :class:`torch.nn.Module` to the :class:`GenerativeModelBase` interface.

    Args:
        module (torch.nn.Module): Generative model that maps a batch of noise samples to generated samples.
        z_size (int): Positive size of the noise dimension.
        z_type (str): Noise type used by the model (see :ref:`registry <Registry>` for preregistered noise types and
            :func:`register_noise_source` for adding one).
        num_classes (int): Number of classes of a conditional model; zero for unconditional models.
        make_copy (bool): Deep-copy the module before wrapping it. Default: `False`.
        make_eval (bool): Put the module into :class:`torch.nn.Module` evaluation mode. Default: `True`.
        cuda (bool): `True` moves the module to CUDA, `False` moves it to CPU, `None` leaves it where it is.
            Default: `None`.
    """
    super().__init__()
    vassert(isinstance(module, torch.nn.Module), 'Not an instance of torch.nn.Module')
    vassert(type(z_size) is int and z_size > 0, 'z_size must be a positive integer')
    vassert(z_type in ('normal', 'unit', 'uniform_0_1'), f'z_type={z_type} not implemented')
    vassert(type(num_classes) is int and num_classes >= 0, 'num_classes must be a non-negative integer')

    wrapped = copy.deepcopy(module) if make_copy else module
    if make_eval:
        wrapped.eval()
    if cuda is not None:
        wrapped = wrapped.cuda() if cuda else wrapped.cpu()
    self.module = wrapped

    self._z_size = z_size
    self._z_type = z_type
    self._num_classes = num_classes
def extract_featuresdict_from_input_id(input_id, feat_extractor, **kwargs):
    """
    Resolves an input id and extracts features from it.

    A Dataset input is processed by `get_featuresdict_from_dataset`. Any other resolved input is treated as a
    generative model and sampled `input_model_num_samples` times via `get_featuresdict_from_generative_model`.

    Args:
        input_id (int or str): Input slot or registered dataset name.
        feat_extractor (FeatureExtractorBase): Feature extractor to apply.
        **kwargs: Reads `batch_size`, `cuda`, `rng_seed`, `verbose`, `save_cpu_ram` via `get_kwarg`, and is forwarded
            to input resolution.

    Returns:
        dict: Feature name -> tensor of features.
    """
    batch_size = get_kwarg('batch_size', kwargs)
    cuda = get_kwarg('cuda', kwargs)
    rng_seed = get_kwarg('rng_seed', kwargs)
    verbose = get_kwarg('verbose', kwargs)

    source = prepare_input_from_id(input_id, **kwargs)

    if not isinstance(source, Dataset):
        input_desc = prepare_input_descriptor_from_input_id(input_id, **kwargs)
        num_samples = input_desc['input_model_num_samples']
        vassert(type(num_samples) is int and num_samples > 0, 'Number of samples must be positive')
        return get_featuresdict_from_generative_model(
            source, feat_extractor, num_samples, batch_size, cuda, rng_seed, verbose
        )

    save_cpu_ram = get_kwarg('save_cpu_ram', kwargs)
    return get_featuresdict_from_dataset(source, feat_extractor, batch_size, cuda, save_cpu_ram, verbose)
def register_feature_extractor(name, cls):
    """
    Registers a new feature extractor.

    Args:
        name (str): Unique name of the feature extractor. Must not have surrounding whitespace or contain '/' or '\\'.
        cls (type): A subclass of :class:`FeatureExtractorBase` (the class itself, not an instance), implementing a
            new feature extractor.
    """
    vassert(type(name) is str, 'Feature extractor must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" is already registered')
    # isinstance guard: issubclass() raises an opaque TypeError when given a non-class (e.g. an instance).
    vassert(
        isinstance(cls, type) and issubclass(cls, FeatureExtractorBase),
        'Feature extractor class must be subclassed from FeatureExtractorBase'
    )
    FEATURE_EXTRACTORS_REGISTRY[name] = cls
def register_sample_similarity(name, cls):
    """
    Registers a new sample similarity measure.

    Args:
        name (str): Unique name of the sample similarity measure. Must not have surrounding whitespace or contain
            '/' or '\\'.
        cls (type): A subclass of :class:`SampleSimilarityBase` (the class itself, not an instance), implementing a
            new sample similarity measure.
    """
    vassert(type(name) is str, 'Sample similarity must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" is already registered')
    # isinstance guard: issubclass() raises an opaque TypeError when given a non-class (e.g. an instance).
    vassert(
        isinstance(cls, type) and issubclass(cls, SampleSimilarityBase),
        'Sample similarity class must be subclassed from SampleSimilarityBase'
    )
    SAMPLE_SIMILARITY_REGISTRY[name] = cls
def prepare_inputs_as_datasets(
        input,
        samples_find_deep=False,
        samples_find_ext=DEFAULTS['samples_find_ext'],
        samples_ext_lossy=DEFAULTS['samples_ext_lossy'],
        datasets_root=None,
        datasets_download=True,
        verbose=True
):
    """
    Resolves an input into a Dataset.

    A registered dataset name is instantiated under `datasets_root`, which defaults to `<torch home>/fidelity_datasets`.
    A directory is scanned for sample files and wrapped in `ImagesPathDataset`. Any other string raises ValueError.
    Non-string inputs are returned unchanged after `check_input`.

    Args:
        input: Dataset name, directory path, or an object accepted by `check_input`.
        samples_find_deep (bool): Search subdirectories when `input` is a directory.
        samples_find_ext (str): Comma-separated sample file extensions.
        samples_ext_lossy (str or None): Comma-separated extensions considered lossy (used for a warning only).
        datasets_root (str or None): Download/cache location for registered datasets.
        datasets_download (bool): Passed to the dataset constructor.
        verbose (bool): Print progress messages.

    Returns:
        The resolved input.
    """
    check_input(input)
    if type(input) is str:
        if input in DATASETS_REGISTRY:
            fn_instantiate = DATASETS_REGISTRY[input]
            if datasets_root is None:
                datasets_root = os.path.join(torch.hub._get_torch_home(), 'fidelity_datasets')
            os.makedirs(datasets_root, exist_ok=True)
            input = fn_instantiate(datasets_root, datasets_download)
        elif os.path.isdir(input):
            samples_dir = input
            paths = glob_samples_paths(samples_dir, samples_find_deep, samples_find_ext, samples_ext_lossy, verbose)
            # Report the searched directory, not the (empty) list of found paths.
            vassert(len(paths) > 0, f'No samples found in {samples_dir} with samples_find_deep={samples_find_deep}')
            input = ImagesPathDataset(paths)
        else:
            raise ValueError(f'Unknown format of input string "{input}"')
    return input
def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'):
    """
    Squared MMD estimate from precomputed m x m kernel matrices.

    Based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py, changed to avoid computing
    the full kernel matrix at once.

    Args:
        K_XX, K_XY, K_YY: Square kernel matrices of identical shape (m, m).
        unit_diagonal (bool): Assume the diagonals of K_XX and K_YY are all ones instead of reading them.
        mmd_est (str): 'biased', 'unbiased', or 'u-statistic'.

    Returns:
        The MMD^2 estimate.
    """
    vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'), 'Invalid value of mmd_est')

    m = K_XX.shape[0]
    for kernel in (K_XX, K_XY, K_YY):
        assert kernel.shape == (m, m)

    # Within-set sums exclude the diagonal, subtracted row by row rather than building the off-diagonal matrix.
    if unit_diagonal:
        diag_X, diag_Y = 1, 1
        sum_diag_X, sum_diag_Y = m, m
    else:
        diag_X, diag_Y = np.diagonal(K_XX), np.diagonal(K_YY)
        sum_diag_X, sum_diag_Y = diag_X.sum(), diag_Y.sum()

    Kt_XX_sum = (K_XX.sum(axis=1) - diag_X).sum()
    Kt_YY_sum = (K_YY.sum(axis=1) - diag_Y).sum()
    K_XY_sum = K_XY.sum(axis=0).sum()

    mm = m * m
    if mmd_est == 'biased':
        return ((Kt_XX_sum + sum_diag_X) / mm
                + (Kt_YY_sum + sum_diag_Y) / mm
                - 2 * K_XY_sum / mm)

    within = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
    if mmd_est == 'unbiased':
        return within - 2 * K_XY_sum / mm
    return within - 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))
def get_featuresdict_from_dataset(input, feat_extractor, batch_size, cuda, save_cpu_ram, verbose):
    """
    Runs a feature extractor over a Dataset in batches and collects the features.

    Args:
        input (Dataset): Samples to process.
        feat_extractor (FeatureExtractorBase): Applied to each batch.
        batch_size (int): Samples per batch (clamped to the dataset length).
        cuda (bool): Move batches to CUDA (and pin host memory in the DataLoader).
        save_cpu_ram (bool): Use no DataLoader worker processes if `True`.
        verbose (bool): Show progress.

    Returns:
        dict: Feature name -> CPU tensor of features for all samples, concatenated along dim 0.
    """
    vassert(isinstance(input, Dataset), 'Input can only be a Dataset instance')
    vassert(
        isinstance(feat_extractor, FeatureExtractorBase), 'Feature extractor is not a subclass of FeatureExtractorBase'
    )

    if batch_size > len(input):
        batch_size = len(input)

    num_workers = 0 if save_cpu_ram else min(4, 2 * multiprocessing.cpu_count())

    dataloader = DataLoader(
        input,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=cuda,
    )

    out = None

    with tqdm(disable=not verbose, leave=False, unit='samples', total=len(input), desc='Processing samples') as t:
        for batch in dataloader:
            # The last batch may be smaller than batch_size (drop_last=False); count the actual samples.
            num_in_batch = len(batch)
            if cuda:
                batch = batch.cuda(non_blocking=True)

            with torch.no_grad():
                features = feat_extractor(batch)
            featuresdict = feat_extractor.convert_features_tuple_to_dict(features)

            # Append per-batch chunks instead of rebuilding lists each batch (which was quadratic).
            if out is None:
                out = {k: [v.cpu()] for k, v in featuresdict.items()}
            else:
                for k, chunks in out.items():
                    chunks.append(featuresdict[k].cpu())

            t.update(num_in_batch)

    vprint(verbose, 'Processing samples')

    out = {k: torch.cat(v, dim=0) for k, v in out.items()}

    return out
def register_feature_extractor(name, cls):
    r"""
    Register a new feature extractor (useful for extending metrics beyond Inception 2D feature extractor).
    Args:
        name: str
            A unique name of the feature extractor, which will be available for use as a value of the
            "feature_extractor" argument. See calculate_metrics function. Must not have surrounding whitespace or
            contain '/' or '\'.
        cls: subclass(FeatureExtractorBase)
            A class (not an instance) subclassed from FeatureExtractorBase, implementing a new feature extractor.
    """
    vassert(type(name) is str, 'Feature extractor must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" is already registered')
    # isinstance guard: issubclass() raises an opaque TypeError when given a non-class (e.g. an instance).
    vassert(
        isinstance(cls, type) and issubclass(cls, FeatureExtractorBase),
        'Feature extractor class must be subclassed from FeatureExtractorBase')
    FEATURE_EXTRACTORS_REGISTRY[name] = cls
def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'):
    """
    Estimates squared MMD from three square kernel matrices of identical shape (m, m).

    Args:
        K_XX, K_XY, K_YY: Kernel matrices within X, between X and Y, and within Y.
        unit_diagonal (bool): Treat the diagonals of K_XX and K_YY as all ones.
        mmd_est (str): One of 'biased', 'unbiased', 'u-statistic'.

    Returns:
        The MMD^2 estimate.
    """
    vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'), 'Invalid value of mmd_est')
    m = K_XX.shape[0]
    assert K_XX.shape == (m, m)
    assert K_XY.shape == (m, m)
    assert K_YY.shape == (m, m)

    def offdiag_total(kernel, diagonal):
        # Sum of the kernel without its diagonal, computed from row sums to avoid a masked copy.
        return (kernel.sum(axis=1) - diagonal).sum()

    if not unit_diagonal:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)
        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()
    else:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m

    Kt_XX_sum = offdiag_total(K_XX, diag_X)
    Kt_YY_sum = offdiag_total(K_YY, diag_Y)
    K_XY_sum = K_XY.sum(axis=0).sum()

    if mmd_est == 'biased':
        estimate = ((Kt_XX_sum + sum_diag_X) / (m * m)
                    + (Kt_YY_sum + sum_diag_Y) / (m * m)
                    - 2 * K_XY_sum / (m * m))
    else:
        estimate = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
        if mmd_est == 'unbiased':
            estimate = estimate - 2 * K_XY_sum / (m * m)
        else:
            estimate = estimate - 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))
    return estimate
def __init__(self, name, features_list):
    """
    Base constructor for feature extractors: validates and stores the extractor name and requested features.

    Args:
        name (str): Name of the feature extractor.
        features_list (list or tuple): Requested feature names; each must appear in
            `self.get_provided_features_list()`, and no name may repeat.
    """
    super(FeatureExtractorBase, self).__init__()
    vassert(type(name) is str, 'Feature extractor name must be a string')
    vassert(type(features_list) in (list, tuple), 'Wrong features list type')
    all_provided = all(feature in self.get_provided_features_list() for feature in features_list)
    vassert(all_provided, 'Requested features are not on the list of provided')
    has_no_duplicates = len(set(features_list)) == len(features_list)
    vassert(has_no_duplicates, 'Duplicate features requested')
    self.name = name
    self.features_list = features_list
def register_dataset(name, fn_create):
    """
    Registers a new input source.

    Args:
        name (str): Unique name of the input source. Must not have surrounding whitespace or contain '/' or '\\'.
        fn_create (callable): A constructor of a :class:`~torch:torch.utils.data.Dataset` instance. Callable arguments:

            - `root` (str): Location where the dataset files may be downloaded.
            - `download` (bool): Whether to perform downloading or rely on the cached version.
    """
    vassert(type(name) is str, 'Dataset must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in DATASETS_REGISTRY, f'Dataset "{name}" is already registered')
    vassert(
        callable(fn_create),
        'Dataset must be provided as a callable (function, lambda) with 2 arguments: root (str), download (bool)'
    )
    DATASETS_REGISTRY[name] = fn_create
def register_interpolation(name, fn_interpolate):
    """
    Registers a new sample interpolation method.

    Args:
        name (str): Unique name of the interpolation method. Must not have surrounding whitespace or contain '/' or
            '\\'.
        fn_interpolate (callable): Sample interpolation function. Callable arguments:

            - `a` (torch.Tensor): batch of the first endpoint samples.
            - `b` (torch.Tensor): batch of the second endpoint samples.
            - `t` (float): interpolation coefficient in the range [0,1].
    """
    vassert(type(name) is str, 'Interpolation must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in INTERPOLATION_REGISTRY, f'Interpolation "{name}" is already registered')
    vassert(
        callable(fn_interpolate),
        'Interpolation must be provided as a callable (function, lambda) with 3 arguments: a, b, t'
    )
    INTERPOLATION_REGISTRY[name] = fn_interpolate
def register_noise_source(name, fn_generate):
    """
    Registers a new noise source, which can generate samples to be used as inputs to generative models.

    Args:
        name (str): Unique name of the noise source. Must not have surrounding whitespace or contain '/' or '\\'.
        fn_generate (callable): Generator of random samples of a specified type and shape. Callable arguments:

            - `rng` (numpy.random.RandomState): random number generator state, initialized with \
                :paramref:`~calculate_metrics.seed`.
            - `shape` (torch.Size): shape of the tensor of random samples.
    """
    vassert(type(name) is str, 'Noise source must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in NOISE_SOURCE_REGISTRY, f'Noise source "{name}" is already registered')
    vassert(
        callable(fn_generate),
        'Noise source must be provided as a callable (function, lambda) with 2 arguments: rng, shape'
    )
    NOISE_SOURCE_REGISTRY[name] = fn_generate
def register_dataset(name, fn_create):
    r"""
    Register a new input source (useful for ground truth or reference datasets).
    Args:
        name: str
            A unique name of the input source, which will be available for use as a positional input argument. See
            calculate_metrics function. Must not have surrounding whitespace or contain '/' or '\'.
        fn_create: callable(root, download)
            A constructor of a torch.utils.data.Dataset instance. `root` (str) is a location where the dataset may be
            downloaded; `download` (bool) tells whether downloading is allowed.
    """
    vassert(type(name) is str, 'Dataset must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Check both delimiters explicitly: os.path.sep alone misses '/' on Windows and '\' on POSIX.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in DATASETS_REGISTRY, f'Dataset "{name}" is already registered')
    vassert(
        callable(fn_create),
        'Dataset must be provided as a callable (function, lambda) with 2 arguments: root (str), download (bool)'
    )
    DATASETS_REGISTRY[name] = fn_create
def __init__(self, path_onnx, z_size, z_type, num_classes):
    """
    Adapts an :obj:`ONNX<torch:torch.onnx>` generative model to the :class:`GenerativeModelBase` interface.

    Args:
        path_onnx (str): Path to the generative model file in :obj:`ONNX<torch:torch.onnx>` format.
        z_size (int): Positive size of the noise dimension.
        z_type (str): Noise type used by the model (see :ref:`registry <Registry>` for preregistered noise types and
            :func:`register_noise_source` for adding one).
        num_classes (int): Number of classes of a conditional model; zero for unconditional models.
    """
    super().__init__()
    vassert(os.path.isfile(path_onnx), f'Model file not found at "{path_onnx}"')
    vassert(type(z_size) is int and z_size > 0, 'z_size must be a positive integer')
    vassert(z_type in ('normal', 'unit', 'uniform_0_1'), f'z_type={z_type} not implemented')
    vassert(type(num_classes) is int and num_classes >= 0, 'num_classes must be a non-negative integer')

    # onnxruntime is an optional dependency: explain how to install it, then re-raise the original ImportError.
    try:
        import onnxruntime
    except ImportError:
        # This message may be removed if onnxruntime becomes a unified package with embedded CUDA dependencies,
        # like for example pytorch
        print(
            '====================================================================================================\n'
            'Loading ONNX models in PyTorch requires ONNX runtime package, which we did not want to include in\n'
            'torch_fidelity package requirements.txt. The two relevant pip packages are:\n'
            ' - onnxruntime (pip install onnxruntime), or\n'
            ' - onnxruntime-gpu (pip install onnxruntime-gpu).\n'
            'If you choose to install "onnxruntime", you will be able to run inference on CPU only - this may be\n'
            'slow. With "onnxruntime-gpu" speed is not an issue, but at run time you might face CUDA toolkit\n'
            'versions incompatibility, which can only be resolved by recompiling onnxruntime-gpu from source.\n'
            'Alternatively, use calculate_metrics API and pass an instance of GenerativeModelBase as an input.\n'
            '===================================================================================================='
        )
        raise

    session = onnxruntime.InferenceSession(path_onnx)
    self.ort_session = session
    self.input_names = [node.name for node in session.get_inputs()]
    self._z_size = z_size
    self._z_type = z_type
    self._num_classes = num_classes
def forward(self, *args):
    """
    Runs the ONNX session on tensor inputs and returns its first output as a tensor.

    Positional arguments are bound to the model inputs in the order of `self.input_names` and converted with
    `self.to_numpy`. The first session output is converted back to a tensor on the device of the first argument.

    Returns:
        torch.Tensor: First output of the ONNX model.
    """
    num_args = len(args)
    vassert(
        num_args == len(self.input_names),
        f'Number of input arguments {num_args} does not match ONNX model: {self.input_names}'
    )
    vassert(all(torch.is_tensor(a) for a in args), 'All model inputs must be tensors')
    feed = {input_name: self.to_numpy(arg) for input_name, arg in zip(self.input_names, args)}
    first_output = self.ort_session.run(None, feed)[0]
    vassert(isinstance(first_output, np.ndarray), 'Invalid output of ONNX model')
    return torch.from_numpy(first_output).to(device=args[0].device)