示例#1
0
 def __call__(self, img):
     """
     Converts a PIL image into a uint8 tensor of shape 3xHxW in RGB channel order.

     Images in any other PIL mode (e.g. "L", "RGBA", "P") are converted to RGB first. Without this, a mode whose
     pixel size is not 3 bytes fails in the reshape below with an opaque size error, and a 3-byte mode such as
     "YCbCr" would be silently reinterpreted as RGB.

     Args:
         img (PIL.Image.Image): Input image; subclasses such as those returned by `Image.open` are accepted.

     Returns:
         torch.Tensor: uint8 tensor of shape 3xHxW.
     """
     vassert(isinstance(img, Image.Image), 'Input is not a PIL.Image')
     if img.mode != 'RGB':
         img = img.convert('RGB')
     width, height = img.size
     # In RGB mode the raw bytes are row-major HxWx3
     img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
     # HWC -> CHW
     return img.permute(2, 0, 1)
示例#2
0
def glob_samples_paths(path, samples_find_deep, samples_find_ext, samples_ext_lossy=None, verbose=True):
    """
    Collects paths of sample files with the requested extensions found in a directory.

    Args:

        path (str): Directory to search.

        samples_find_deep (bool): Descend into subdirectories if `True`; otherwise only files located directly in
            `path` are considered.

        samples_find_ext (str): Comma-separated list of accepted extensions (e.g. "png,jpg"). Matching is
            case-insensitive, and a leading dot on an entry (".png") is ignored.

        samples_ext_lossy (str or None): Comma-separated list of extensions treated as lossy; if any collected file
            has one of them, a warning is printed when `verbose`. Default: `None`.

        verbose (bool): Print progress messages. Default: `True`.

    Returns:

        list of str: Sorted list of real (symlink-resolved) file paths.
    """
    vassert(type(samples_find_ext) is str and samples_find_ext != '', 'Sample extensions not specified')
    vassert(
        samples_ext_lossy is None or type(samples_ext_lossy) is str, 'Lossy sample extensions can be None or string'
    )
    vprint(verbose, f'Looking for samples {"recursively" if samples_find_deep else "non-recursively"} in "{path}" '
                    f'with extensions {samples_find_ext}')

    def _parse_ext_list(spec):
        # Normalize user-supplied extensions exactly like file extensions are normalized below: trimmed,
        # lower-cased, without a leading dot. Empty entries (e.g. from "png,,jpg") are dropped.
        parsed = (a.strip().lower().lstrip('.') for a in spec.split(','))
        return {a for a in parsed if a != ''}

    find_exts = _parse_ext_list(samples_find_ext)
    lossy_exts = _parse_ext_list(samples_ext_lossy) if samples_ext_lossy is not None else set()

    have_lossy = False
    files = []
    for root, dirs, names in os.walk(path):
        if not samples_find_deep:
            # os.walk is top-down by default, so clearing `dirs` in place prevents descending: only `path` itself
            # is visited instead of walking the whole tree and discarding every subdirectory.
            dirs.clear()
        for name in names:
            ext = os.path.splitext(name)[1].lower()
            if ext.startswith('.'):
                ext = ext[1:]
            if ext not in find_exts:
                continue
            if ext in lossy_exts:
                have_lossy = True
            files.append(os.path.realpath(os.path.join(root, name)))
    files.sort()
    vprint(verbose, f'Found {len(files)} samples'
                    f'{", some are lossy-compressed - this may affect metrics" if have_lossy else ""}')
    return files
示例#3
0
 def convert_features_tuple_to_dict(self, features):
     """
     Re-attaches feature names to the tuple produced by the forward function.

     A tuple is the only compound return type of forward that JIT tracing supports, so the names are restored here
     positionally, following the order of `self.features_list`.

     Args:
         features (tuple): Output of the forward function, one entry per name in `self.features_list`.

     Returns:
         dict: Mapping from feature name to the matching entry of `features`.
     """
     is_forward_output = type(features) is tuple and len(features) == len(self.features_list)
     vassert(is_forward_output, 'Features must be the output of forward function')
     return dict(zip(self.features_list, features))
示例#4
0
def create_sample_similarity(name, cuda=True, **kwargs):
    """
    Instantiates a registered sample similarity measure and switches it to evaluation mode.

    Args:
        name (str): Name under which the sample similarity class was registered.
        cuda (bool): Call `.cuda()` on the created module if `True`. Default: `True`.
        **kwargs: Forwarded to the class constructor; `verbose` is also read to control logging.

    Returns:
        The object returned by the registered class constructor, after `.eval()` (and optionally `.cuda()`).
    """
    vassert(name in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" not registered')
    verbose = get_kwarg('verbose', kwargs)
    vprint(verbose, f'Creating sample similarity "{name}"')
    instance = SAMPLE_SIMILARITY_REGISTRY[name](name, **kwargs)
    instance.eval()
    if cuda:
        instance.cuda()
    return instance
示例#5
0
def prepare_input_descriptor_from_input_id(input_id, **kwargs):
    """
    Builds an input descriptor dict from an input identifier.

    Args:
        input_id (int or str): Either an integer input slot selecting the `input1*` / `input2*` kwargs, or the name
            of a registered dataset.
        **kwargs: Keyword arguments consulted when `input_id` is an integer.

    Returns:
        dict: Input descriptor.
    """
    is_slot = type(input_id) is int
    is_registered_name = type(input_id) is str and input_id in DATASETS_REGISTRY
    vassert(is_slot or is_registered_name,
            'Input can be either integer (1 or 2) specifying the first or the second set of kwargs, or a string as a '
            'shortcut for registered datasets')
    if not is_slot:
        return make_input_descriptor_from_str(input_id)
    return make_input_descriptor_from_int(input_id, **kwargs)
示例#6
0
def create_feature_extractor(name, list_features, cuda=True, **kwargs):
    """
    Instantiates a registered feature extractor and switches it to evaluation mode.

    Args:
        name (str): Name under which the feature extractor class was registered.
        list_features (list): Names of the features to extract, forwarded to the constructor.
        cuda (bool): Call `.cuda()` on the created module if `True`. Default: `True`.
        **kwargs: Forwarded to the class constructor; `verbose` is also read to control logging.

    Returns:
        The object returned by the registered class constructor, after `.eval()` (and optionally `.cuda()`).
    """
    vassert(name in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" not registered')
    verbose = get_kwarg('verbose', kwargs)
    vprint(verbose, f'Creating feature extractor "{name}" with features {list_features}')
    extractor = FEATURE_EXTRACTORS_REGISTRY[name](name, list_features, **kwargs)
    extractor.eval()
    if cuda:
        extractor.cuda()
    return extractor
示例#7
0
 def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021):
     """
     Pre-generates `num_samples` random tensors of shape `dimensions`, deterministically from `seed`.

     The CPU RNG state of torch is saved before seeding and restored afterwards, even if generation fails.
     Note that `torch.manual_seed` also reseeds CUDA generators, whose state is not restored here.

     Args:
         num_samples (int): Number of samples to generate.
         *dimensions (int): Shape of each sample.
         dtype (torch.dtype): Must be `torch.uint8`.
         seed (int): Seed used for generation. Default: 2021.
     """
     vassert(dtype == torch.uint8, 'Unsupported dtype')
     saved_state = torch.get_rng_state()
     try:
         torch.manual_seed(seed)
         shape = (num_samples, *dimensions)
         # NOTE(review): the upper bound of torch.randint is exclusive, so values lie in [0, 254] and 255 never
         # occurs. Left unchanged because a different bound would change the generated data for a given seed.
         self.imgs = torch.randint(0, 255, shape, dtype=dtype)
     finally:
         torch.set_rng_state(saved_state)
示例#8
0
def make_input_descriptor_from_str(input_str):
    """
    Builds an input descriptor for the name of a registered dataset.

    The dataset name doubles as the cache name. The generative-model fields are filled from the `input1_model_*`
    entries of DEFAULTS so that the descriptor has the same keys as one built from an input slot.

    Args:
        input_str (str): Name of a registered dataset.

    Returns:
        dict: Input descriptor.
    """
    vassert(type(input_str) is str and input_str in DATASETS_REGISTRY,
            f'Supported input str: {list(DATASETS_REGISTRY.keys())}')
    descriptor = {'input': input_str, 'input_cache_name': input_str}
    for field in ('z_type', 'z_size', 'num_classes', 'num_samples'):
        descriptor[f'input_model_{field}'] = DEFAULTS[f'input1_model_{field}']
    return descriptor
示例#9
0
    def __init__(self, name):
        """
        Base class for sample similarity measures usable in :func:`calculate_metrics`.

        Args:

            name (str): Unique name of the sample similarity subclass; must equal the name it was registered under
                with :func:`register_sample_similarity`.
        """
        super(SampleSimilarityBase, self).__init__()
        name_is_str = type(name) is str
        vassert(name_is_str, 'Sample similarity name must be a string')
        self.name = name
示例#10
0
def kid_features_to_metric(features_1, features_2, **kwargs):
    """
    Computes the Kernel Inception Distance (KID) between two sets of features.

    For each of `kid_subsets` iterations, `kid_subset_size` rows are drawn without replacement from each feature
    set, and the value of `polynomial_mmd` on that pair of subsets is recorded. The mean and standard deviation of
    these values are returned.

    Args:
        features_1 (torch.Tensor): 2D tensor of features, one row per sample.
        features_2 (torch.Tensor): 2D tensor of features with the same number of columns as `features_1`.
        **kwargs: Reads `kid_subsets`, `kid_subset_size`, `kid_degree`, `kid_gamma`, `kid_coef0`, `rng_seed`
            and `verbose`.

    Returns:
        dict: `{KEY_METRIC_KID_MEAN: float, KEY_METRIC_KID_STD: float}`.
    """
    assert torch.is_tensor(features_1) and features_1.dim() == 2
    assert torch.is_tensor(features_2) and features_2.dim() == 2
    assert features_1.shape[1] == features_2.shape[1]

    kid_subsets = get_kwarg('kid_subsets', kwargs)
    kid_subset_size = get_kwarg('kid_subset_size', kwargs)
    verbose = get_kwarg('verbose', kwargs)

    n_samples_1, n_samples_2 = len(features_1), len(features_2)
    # Subsets are drawn without replacement, so each input must have at least kid_subset_size samples
    vassert(
        n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size,
        f'KID subset size {kid_subset_size} cannot be larger than the number of samples (input_1: {n_samples_1}, '
        f'input_2: {n_samples_2}). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to '
        f'proceed.')

    features_1 = features_1.cpu().numpy()
    features_2 = features_2.cpu().numpy()

    mmds = np.zeros(kid_subsets)
    rng = np.random.RandomState(get_kwarg('rng_seed', kwargs))

    for i in tqdm(range(kid_subsets),
                  disable=not verbose,
                  leave=False,
                  unit='subsets',
                  desc='Kernel Inception Distance'):
        f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)]
        f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)]
        mmds[i] = polynomial_mmd(
            f1,
            f2,
            get_kwarg('kid_degree', kwargs),
            get_kwarg('kid_gamma', kwargs),
            get_kwarg('kid_coef0', kwargs),
        )

    out = {
        KEY_METRIC_KID_MEAN: float(np.mean(mmds)),
        KEY_METRIC_KID_STD: float(np.std(mmds)),
    }

    vprint(
        verbose,
        f'Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]} ± {out[KEY_METRIC_KID_STD]}'
    )

    return out
    def forward(self, in0, in1):
        """
        Computes the learned perceptual distance between two batches of images.

        Both batches are normalized, optionally resized, and passed through `self.net`. For each of the `self.L`
        feature layers, squared differences of the `normalize_tensor`-ed features are weighted by `self.lins[kk]`
        and spatially averaged. The per-layer results are summed.

        Args:
            in0 (torch.Tensor): Batch of shape Bx3xHxW.
            in1 (torch.Tensor): Batch of shape Bx3xHxW.

        Returns:
            torch.Tensor: Sum of the per-layer spatially averaged distances.
        """
        vassert(torch.is_tensor(in0) and torch.is_tensor(in1), 'Inputs must be torch tensors')
        vassert(in0.dim() == 4 and in0.shape[1] == 3, 'Input 0 is not Bx3xHxW')
        vassert(in1.dim() == 4 and in1.shape[1] == 3, 'Input 1 is not Bx3xHxW')
        if self.sample_similarity_dtype is not None:
            dtype = self.SUPPORTED_DTYPES.get(self.sample_similarity_dtype, None)
            # Separate a misconfigured dtype name from a mismatching input, and report both input dtypes
            vassert(dtype is not None, f'Unsupported sample similarity dtype "{self.sample_similarity_dtype}"')
            vassert(in0.dtype == dtype and in1.dtype == dtype,
                    f'Unexpected input dtype (input 0: {in0.dtype}, input 1: {in1.dtype}, expected: {dtype})')
        in0_input = self.normalize(in0)
        in1_input = self.normalize(in1)

        if self.sample_similarity_resize is not None:
            in0_input = self.resize(in0_input, self.sample_similarity_resize)
            in1_input = self.resize(in1_input, self.sample_similarity_resize)

        outs0 = self.net.forward(in0_input)
        outs1 = self.net.forward(in1_input)

        layer_scores = []
        for kk in range(self.L):
            feat0 = normalize_tensor(outs0[kk])
            feat1 = normalize_tensor(outs1[kk])
            layer_scores.append(spatial_average(self.lins[kk].model((feat0 - feat1) ** 2)))
        return sum(layer_scores)
示例#12
0
def make_input_descriptor_from_int(input_int, **kwargs):
    """
    Builds an input descriptor from the kwargs of input slot 1 or 2.

    The values are read with `get_kwarg` from the `input{N}`, `input{N}_cache_name` and `input{N}_model_*` keys.
    If the input is the name of a registered dataset, that name is used as the cache name.

    Args:
        input_int (int): Input slot, 1 or 2.
        **kwargs: Keyword arguments holding the slot settings.

    Returns:
        dict: Input descriptor.
    """
    vassert(input_int in (1, 2), 'Supported input slots: 1, 2')
    prefix = f'input{input_int}'
    source = get_kwarg(prefix, kwargs)
    cache_name = get_kwarg(f'{prefix}_cache_name', kwargs)
    if type(source) is str and source in DATASETS_REGISTRY:
        cache_name = source
    return {
        'input': source,
        'input_cache_name': cache_name,
        'input_model_z_type': get_kwarg(f'{prefix}_model_z_type', kwargs),
        'input_model_z_size': get_kwarg(f'{prefix}_model_z_size', kwargs),
        'input_model_num_classes': get_kwarg(f'{prefix}_model_num_classes', kwargs),
        'input_model_num_samples': get_kwarg(f'{prefix}_model_num_samples', kwargs),
    }
示例#13
0
def get_featuresdict_from_generative_model(gen_model, feat_extractor, num_samples, batch_size, cuda, rng_seed, verbose):
    """
    Draws samples from a generative model and extracts their features.

    Noise (and, for conditional models, class labels) is drawn from a `numpy.random.RandomState(rng_seed)`, fed to
    `gen_model`, and the generated samples are passed through `feat_extractor`.

    Args:
        gen_model (GenerativeModelBase): Model to sample from.
        feat_extractor (FeatureExtractorBase): Feature extractor applied to generated samples.
        num_samples (int): Total number of samples to draw (positive).
        batch_size (int): Samples per generator call; clamped to `num_samples`.
        cuda (bool): Run generation on CUDA.
        rng_seed (int): Seed of the noise / label random generator.
        verbose (bool): Show a progress bar.

    Returns:
        dict: Feature name -> CPU tensor of features of all samples, concatenated along dim 0.
    """
    vassert(isinstance(gen_model, GenerativeModelBase), 'Input can only be a GenerativeModel instance')
    vassert(
        isinstance(feat_extractor, FeatureExtractorBase), 'Feature extractor is not a subclass of FeatureExtractorBase'
    )
    # Without this, 0 leads to range(..., step=0) and a negative value to an empty result dict
    vassert(num_samples > 0, 'Number of samples must be positive')

    if batch_size > num_samples:
        batch_size = num_samples

    rng = np.random.RandomState(rng_seed)

    if cuda:
        gen_model.cuda()

    out = None

    with tqdm(disable=not verbose, leave=False, unit='samples', total=num_samples, desc='Processing samples') as t, \
            torch.no_grad():
        for sample_start in range(0, num_samples, batch_size):
            sample_end = min(sample_start + batch_size, num_samples)
            sz = sample_end - sample_start

            noise = NOISE_SOURCE_REGISTRY[gen_model.z_type](rng, (sz, gen_model.z_size))
            if cuda:
                noise = noise.cuda(non_blocking=True)
            gen_args = [noise]
            if gen_model.num_classes > 0:
                # np.int was removed in NumPy 1.24; int64 gives a torch.LongTensor of labels
                cond_labels = torch.from_numpy(
                    rng.randint(low=0, high=gen_model.num_classes, size=(sz,), dtype=np.int64)
                )
                if cuda:
                    cond_labels = cond_labels.cuda(non_blocking=True)
                gen_args.append(cond_labels)

            fakes = gen_model(*gen_args)
            features = feat_extractor(fakes)
            featuresdict = feat_extractor.convert_features_tuple_to_dict(features)

            # Accumulate per-batch chunks in lists (amortized O(1)) instead of rebuilding lists every batch
            if out is None:
                out = {k: [] for k in featuresdict}
            for k in out:
                out[k].append(featuresdict[k].cpu())
            t.update(sz)

    vprint(verbose, 'Processing samples')

    return {k: torch.cat(v, dim=0) for k, v in out.items()}
示例#14
0
def prepare_input_from_descriptor(input_desc, **kwargs):
    """
    Resolves the "input" field of an input descriptor into a Dataset or a generative model.

    Accepted values of `input_desc['input']`:

        - the name of a registered dataset: instantiated under `datasets_root` (by default the
          "fidelity_datasets" directory inside the torch hub home);
        - a directory: an `ImagesPathDataset` over the sample files found in it;
        - a path to an ".onnx" file: loaded as `GenerativeModelONNX`;
        - a path to a ".pth" file: loaded with `torch.jit.load` and wrapped in `GenerativeModelModuleWrapper`;
        - an instance of `Dataset` or `GenerativeModelBase`: returned unchanged.

    Args:
        input_desc (dict): Input descriptor; its `input_model_*` fields are used for generative models.
        **kwargs: Reads `datasets_root`, `datasets_download`, `samples_find_deep`, `samples_find_ext`,
            `samples_ext_lossy` and `verbose` where needed.

    Returns:
        The resolved Dataset or generative model.
    """
    bad_input = False
    source = input_desc['input']
    if type(source) is str:
        if source in DATASETS_REGISTRY:
            datasets_root = get_kwarg('datasets_root', kwargs)
            datasets_download = get_kwarg('datasets_download', kwargs)
            fn_instantiate = DATASETS_REGISTRY[source]
            if datasets_root is None:
                datasets_root = os.path.join(torch.hub._get_torch_home(), 'fidelity_datasets')
            os.makedirs(datasets_root, exist_ok=True)
            source = fn_instantiate(datasets_root, datasets_download)
        elif os.path.isdir(source):
            samples_find_deep = get_kwarg('samples_find_deep', kwargs)
            samples_find_ext = get_kwarg('samples_find_ext', kwargs)
            samples_ext_lossy = get_kwarg('samples_ext_lossy', kwargs)
            verbose = get_kwarg('verbose', kwargs)
            paths = glob_samples_paths(source, samples_find_deep, samples_find_ext, samples_ext_lossy, verbose)
            # Report the searched directory, not the (empty) list of found paths
            vassert(len(paths) > 0, f'No samples found in {source} with samples_find_deep={samples_find_deep}')
            source = ImagesPathDataset(paths)
        elif os.path.isfile(source) and source.endswith('.onnx'):
            source = GenerativeModelONNX(
                source,
                input_desc['input_model_z_size'],
                input_desc['input_model_z_type'],
                input_desc['input_model_num_classes']
            )
        elif os.path.isfile(source) and source.endswith('.pth'):
            source = torch.jit.load(source, map_location='cpu')
            source = GenerativeModelModuleWrapper(
                source,
                input_desc['input_model_z_size'],
                input_desc['input_model_z_type'],
                input_desc['input_model_num_classes']
            )
        else:
            bad_input = True
    elif not isinstance(source, (Dataset, GenerativeModelBase)):
        bad_input = True
    vassert(
        not bad_input,
        f'Input descriptor "input" field can be either an instance of Dataset, GenerativeModelBase class, or a string, '
        f'such as a name of a registered dataset ({", ".join(DATASETS_REGISTRY.keys())}), a directory with '
        f'file samples, or a path to an ONNX or PTH (JIT) module'
    )
    return source
示例#15
0
    def __init__(self, module, z_size, z_type, num_classes, make_copy=False, make_eval=True, cuda=None):
        """
        Adapts an arbitrary generative :class:`torch.nn.Module` to the :class:`GenerativeModelBase` interface.

        Args:

            module (torch.nn.Module): Generative model that maps a batch of noise samples to generated samples.

            z_size (int): Size of the noise dimension (positive integer).

            z_type (str): Noise type of the model, one of 'normal', 'unit', 'uniform_0_1' (see
                :ref:`registry <Registry>` and :func:`register_noise_source`).

            num_classes (int): Number of classes of a conditional model; zero for unconditional models.

            make_copy (bool): Wrap a deep copy of `module` instead of `module` itself. Default: `False`.

            make_eval (bool): Put the wrapped module into evaluation mode. Default: `True`.

            cuda (bool): Move the wrapped module to CUDA if `True`, to CPU if `False`, leave it in place if `None`.
                Default: `None`.
        """
        super().__init__()
        vassert(isinstance(module, torch.nn.Module), 'Not an instance of torch.nn.Module')
        vassert(type(z_size) is int and z_size > 0, 'z_size must be a positive integer')
        vassert(z_type in ('normal', 'unit', 'uniform_0_1'), f'z_type={z_type} not implemented')
        vassert(type(num_classes) is int and num_classes >= 0, 'num_classes must be a non-negative integer')

        wrapped = copy.deepcopy(module) if make_copy else module
        if make_eval:
            wrapped.eval()
        if cuda is not None:
            wrapped = wrapped.cuda() if cuda else wrapped.cpu()
        self.module = wrapped

        self._z_size = z_size
        self._z_type = z_type
        self._num_classes = num_classes
示例#16
0
def extract_featuresdict_from_input_id(input_id, feat_extractor, **kwargs):
    """
    Resolves an input identifier and extracts features from it.

    Dataset inputs are processed with `get_featuresdict_from_dataset`. Any other resolved input is passed to
    `get_featuresdict_from_generative_model`, with the number of samples taken from the input descriptor.

    Args:
        input_id (int or str): Input slot (1 or 2) or the name of a registered dataset.
        feat_extractor (FeatureExtractorBase): Feature extractor to apply.
        **kwargs: Reads `batch_size`, `cuda`, `rng_seed`, `verbose`, and `save_cpu_ram` for datasets.

    Returns:
        dict: Feature name -> tensor of features.
    """
    batch_size = get_kwarg('batch_size', kwargs)
    cuda = get_kwarg('cuda', kwargs)
    rng_seed = get_kwarg('rng_seed', kwargs)
    verbose = get_kwarg('verbose', kwargs)
    source = prepare_input_from_id(input_id, **kwargs)
    if not isinstance(source, Dataset):
        descriptor = prepare_input_descriptor_from_input_id(input_id, **kwargs)
        num_samples = descriptor['input_model_num_samples']
        vassert(type(num_samples) is int and num_samples > 0, 'Number of samples must be positive')
        return get_featuresdict_from_generative_model(
            source, feat_extractor, num_samples, batch_size, cuda, rng_seed, verbose
        )
    save_cpu_ram = get_kwarg('save_cpu_ram', kwargs)
    return get_featuresdict_from_dataset(source, feat_extractor, batch_size, cuda, save_cpu_ram, verbose)
示例#17
0
def register_feature_extractor(name, cls):
    """
    Registers a new feature extractor.

    Args:

        name (str): Unique name of the feature extractor. Must have no leading or trailing whitespace and must not
            contain a slash or a backslash.

        cls (type): Class (not an instance) subclassed from :class:`FeatureExtractorBase`, implementing a new feature
            extractor.
    """
    vassert(type(name) is str, 'Feature extractor must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them), so a name accepted on one
    # OS is accepted on all, as the message promises.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" is already registered')
    # issubclass() raises TypeError for non-classes; check that first to report a readable error instead
    vassert(
        isinstance(cls, type) and issubclass(cls, FeatureExtractorBase),
        'Feature extractor class must be subclassed from FeatureExtractorBase'
    )
    FEATURE_EXTRACTORS_REGISTRY[name] = cls
示例#18
0
def register_sample_similarity(name, cls):
    """
    Registers a new sample similarity measure.

    Args:

        name (str): Unique name of the sample similarity measure. Must have no leading or trailing whitespace and
            must not contain a slash or a backslash.

        cls (type): Class (not an instance) subclassed from :class:`SampleSimilarityBase`, implementing a new sample
            similarity measure.
    """
    vassert(type(name) is str, 'Sample similarity must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them), so a name accepted on one
    # OS is accepted on all, as the message promises.
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" is already registered')
    # issubclass() raises TypeError for non-classes; check that first to report a readable error instead
    vassert(
        isinstance(cls, type) and issubclass(cls, SampleSimilarityBase),
        'Sample similarity class must be subclassed from SampleSimilarityBase'
    )
    SAMPLE_SIMILARITY_REGISTRY[name] = cls
示例#19
0
def prepare_inputs_as_datasets(
        input, samples_find_deep=False, samples_find_ext=DEFAULTS['samples_find_ext'],
        samples_ext_lossy=DEFAULTS['samples_ext_lossy'], datasets_root=None, datasets_download=True, verbose=True
):
    """
    Turns a string input into a Dataset; other (already validated) inputs are returned unchanged.

    Args:
        input: Name of a registered dataset, a directory with sample files, or an object accepted by `check_input`.
        samples_find_deep (bool): Search subdirectories of a directory input.
        samples_find_ext (str): Comma-separated list of sample file extensions.
        samples_ext_lossy (str): Comma-separated list of extensions considered lossy.
        datasets_root (str or None): Where registered datasets are stored; `None` selects the "fidelity_datasets"
            directory inside the torch hub home.
        datasets_download (bool): Passed to the registered dataset constructor.
        verbose (bool): Print progress messages.

    Returns:
        The resolved input.

    Raises:
        ValueError: If `input` is a string that is neither a registered dataset nor a directory.
    """
    check_input(input)
    if type(input) is str:
        if input in DATASETS_REGISTRY:
            fn_instantiate = DATASETS_REGISTRY[input]
            if datasets_root is None:
                datasets_root = os.path.join(torch.hub._get_torch_home(), 'fidelity_datasets')
            os.makedirs(datasets_root, exist_ok=True)
            input = fn_instantiate(datasets_root, datasets_download)
        elif os.path.isdir(input):
            paths = glob_samples_paths(input, samples_find_deep, samples_find_ext, samples_ext_lossy, verbose)
            # Report the searched directory, not the (empty) list of found paths
            vassert(len(paths) > 0, f'No samples found in {input} with samples_find_deep={samples_find_deep}')
            input = ImagesPathDataset(paths)
        else:
            raise ValueError(f'Unknown format of input string "{input}"')
    return input
示例#20
0
def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'):
    """
    Squared Maximum Mean Discrepancy estimate from precomputed m x m kernel matrices.

    Based on https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd.py

    Args:
        K_XX, K_XY, K_YY (numpy.ndarray): Kernel matrices within X, between X and Y, and within Y, all m x m.
        unit_diagonal (bool): Assume the diagonals of K_XX and K_YY are all ones instead of reading them.
        mmd_est (str): 'biased' (all entries, divided by m^2), 'unbiased' (within-set terms without the diagonal,
            divided by m(m-1); cross term divided by m^2), or 'u-statistic' (like 'unbiased', but the cross term also
            drops the diagonal of K_XY and is divided by m(m-1)).

    Returns:
        The MMD^2 estimate (numpy scalar).
    """
    vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'),
            'Invalid value of mmd_est')

    m = K_XX.shape[0]
    for kernel in (K_XX, K_XY, K_YY):
        assert kernel.shape == (m, m)

    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
    else:
        diag_X, diag_Y = np.diagonal(K_XX), np.diagonal(K_YY)
        sum_diag_X, sum_diag_Y = diag_X.sum(), diag_Y.sum()

    # Totals of the within-set kernels with the diagonal removed, and of the cross kernel
    offdiag_XX_total = (K_XX.sum(axis=1) - diag_X).sum()
    offdiag_YY_total = (K_YY.sum(axis=1) - diag_Y).sum()
    cross_total = K_XY.sum(axis=0).sum()

    mm = m * m
    if mmd_est == 'biased':
        return ((offdiag_XX_total + sum_diag_X) / mm + (offdiag_YY_total + sum_diag_Y) / mm
                - 2 * cross_total / mm)

    pairs = m * (m - 1)
    within = (offdiag_XX_total + offdiag_YY_total) / pairs
    if mmd_est == 'unbiased':
        return within - 2 * cross_total / mm
    return within - 2 * (cross_total - np.trace(K_XY)) / pairs
示例#21
0
def get_featuresdict_from_dataset(input, feat_extractor, batch_size, cuda, save_cpu_ram, verbose):
    """
    Extracts features from every sample of a Dataset.

    Args:
        input (Dataset): Dataset whose items are batched by a DataLoader into tensors.
        feat_extractor (FeatureExtractorBase): Feature extractor to apply.
        batch_size (int): Batch size; clamped to the dataset length.
        cuda (bool): Pin memory and move batches to CUDA.
        save_cpu_ram (bool): Load data in the main process (no DataLoader workers).
        verbose (bool): Show a progress bar.

    Returns:
        dict: Feature name -> CPU tensor of features of all samples, concatenated along dim 0.
    """
    vassert(isinstance(input, Dataset), 'Input can only be a Dataset instance')
    vassert(
        isinstance(feat_extractor, FeatureExtractorBase), 'Feature extractor is not a subclass of FeatureExtractorBase'
    )

    if batch_size > len(input):
        batch_size = len(input)

    num_workers = 0 if save_cpu_ram else min(4, 2 * multiprocessing.cpu_count())

    dataloader = DataLoader(
        input,
        batch_size=batch_size,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=cuda,
    )

    out = None

    with tqdm(disable=not verbose, leave=False, unit='samples', total=len(input), desc='Processing samples') as t:
        for batch in dataloader:
            # With drop_last=False the final batch can be smaller than batch_size; count what was actually loaded
            num_in_batch = len(batch)
            if cuda:
                batch = batch.cuda(non_blocking=True)

            with torch.no_grad():
                features = feat_extractor(batch)
            featuresdict = feat_extractor.convert_features_tuple_to_dict(features)

            # Accumulate per-batch chunks in lists (amortized O(1)) instead of rebuilding lists every batch
            if out is None:
                out = {k: [] for k in featuresdict}
            for k in out:
                out[k].append(featuresdict[k].cpu())
            t.update(num_in_batch)

    vprint(verbose, 'Processing samples')

    return {k: torch.cat(v, dim=0) for k, v in out.items()}
示例#22
0
def register_feature_extractor(name, cls):
    r"""
    Register a new feature extractor (useful for extending metrics beyond Inception 2D feature extractor).
    Args:
        name: str
            A unique name of the feature extractor, which will be available for use as a value of the
            "feature_extractor" argument. See calculate_metrics function. Must have no leading or trailing
            whitespace and must not contain a slash or a backslash.
        cls: subclass(FeatureExtractorBase)
            A class (not an instance) subclassed from FeatureExtractorBase, implementing a new feature extractor.
    """
    vassert(type(name) is str, 'Feature extractor must be given a name')
    vassert(name.strip() == name,
            'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them)
    vassert('/' not in name and '\\' not in name,
            'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in FEATURE_EXTRACTORS_REGISTRY,
            f'Feature extractor "{name}" is already registered')
    # issubclass() raises TypeError for non-classes; check that first to report a readable error instead
    vassert(
        isinstance(cls, type) and issubclass(cls, FeatureExtractorBase),
        'Feature extractor class must be subclassed from FeatureExtractorBase')
    FEATURE_EXTRACTORS_REGISTRY[name] = cls
示例#23
0
def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est='unbiased'):
    """
    Squared Maximum Mean Discrepancy estimate from precomputed m x m kernel matrices.

    Args:
        K_XX, K_XY, K_YY (numpy.ndarray): Kernel matrices within X, between X and Y, and within Y, all m x m.
        unit_diagonal (bool): Assume the diagonals of K_XX and K_YY are all ones instead of reading them.
        mmd_est (str): One of 'biased', 'unbiased', 'u-statistic'. 'biased' divides all totals by m^2; the other
            two divide within-set totals without the diagonal by m(m-1); 'u-statistic' also drops the diagonal of
            K_XY and divides the cross term by m(m-1) instead of m^2.

    Returns:
        The MMD^2 estimate (numpy scalar).
    """
    vassert(mmd_est in ('biased', 'unbiased', 'u-statistic'),
            'Invalid value of mmd_est')

    m = K_XX.shape[0]
    assert all(k.shape == (m, m) for k in (K_XX, K_XY, K_YY))

    if not unit_diagonal:
        diag_X = np.diagonal(K_XX)
        diag_Y = np.diagonal(K_YY)
        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()
    else:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m

    xx_wo_diag = (K_XX.sum(axis=1) - diag_X).sum()
    yy_wo_diag = (K_YY.sum(axis=1) - diag_Y).sum()
    xy_total = K_XY.sum(axis=0).sum()

    if mmd_est == 'biased':
        estimate = ((xx_wo_diag + sum_diag_X) / (m * m) + (yy_wo_diag + sum_diag_Y) / (m * m)
                    - 2 * xy_total / (m * m))
    elif mmd_est == 'unbiased':
        estimate = (xx_wo_diag + yy_wo_diag) / (m * (m - 1)) - 2 * xy_total / (m * m)
    else:
        estimate = ((xx_wo_diag + yy_wo_diag) / (m * (m - 1))
                    - 2 * (xy_total - np.trace(K_XY)) / (m * (m - 1)))
    return estimate
 def __init__(self, name, features_list):
     """
     Base class for feature extractors.

     Args:
         name (str): Name of the feature extractor.
         features_list (list or tuple): Names of the requested features; each must be reported by
             `self.get_provided_features_list()` and must appear only once.
     """
     super(FeatureExtractorBase, self).__init__()
     vassert(type(name) is str, 'Feature extractor name must be a string')
     vassert(type(features_list) in (list, tuple), 'Wrong features list type')
     all_provided = all(feature in self.get_provided_features_list() for feature in features_list)
     vassert(all_provided, 'Requested features are not on the list of provided')
     no_duplicates = len(set(features_list)) == len(features_list)
     vassert(no_duplicates, 'Duplicate features requested')
     self.name = name
     self.features_list = features_list
示例#25
0
def register_dataset(name, fn_create):
    """
    Registers a new input source.

    Args:

        name (str): Unique name of the input source. Must have no leading or trailing whitespace and must not
            contain a slash or a backslash.

        fn_create (callable): A constructor of a :class:`~torch:torch.utils.data.Dataset` instance. Callable arguments:

            - `root` (str): Location where the dataset files may be downloaded.
            - `download` (bool): Whether to perform downloading or rely on the cached version.
    """
    vassert(type(name) is str, 'Dataset must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them)
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in DATASETS_REGISTRY, f'Dataset "{name}" is already registered')
    # `root` is a path string, not a bool
    vassert(
        callable(fn_create),
        'Dataset must be provided as a callable (function, lambda) with 2 arguments: root, download'
    )
    DATASETS_REGISTRY[name] = fn_create
示例#26
0
def register_interpolation(name, fn_interpolate):
    """
    Registers a new sample interpolation method.

    Args:

        name (str): Unique name of the interpolation method. Must have no leading or trailing whitespace and must
            not contain a slash or a backslash.

        fn_interpolate (callable): Sample interpolation function. Callable arguments:

            - `a` (torch.Tensor): batch of the first endpoint samples.
            - `b` (torch.Tensor): batch of the second endpoint samples.
            - `t` (float): interpolation coefficient in the range [0,1].
    """
    vassert(type(name) is str, 'Interpolation must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them)
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in INTERPOLATION_REGISTRY, f'Interpolation "{name}" is already registered')
    vassert(
        callable(fn_interpolate),
        'Interpolation must be provided as a callable (function, lambda) with 3 arguments: a, b, t'
    )
    INTERPOLATION_REGISTRY[name] = fn_interpolate
示例#27
0
def register_noise_source(name, fn_generate):
    """
    Registers a new noise source, which can generate samples to be used as inputs to generative models.

    Args:

        name (str): Unique name of the noise source. Must have no leading or trailing whitespace and must not
            contain a slash or a backslash.

        fn_generate (callable): Generator of random samples of a specified type and shape. Callable arguments:

            - `rng` (numpy.random.RandomState): random number generator state, initialized with \
                :paramref:`~calculate_metrics.seed`.
            - `shape` (torch.Size): shape of the tensor of random samples.
    """
    vassert(type(name) is str, 'Noise source must be given a name')
    vassert(name.strip() == name, 'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them)
    vassert('/' not in name and '\\' not in name, 'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in NOISE_SOURCE_REGISTRY, f'Noise source "{name}" is already registered')
    vassert(
        callable(fn_generate),
        'Noise source must be provided as a callable (function, lambda) with 2 arguments: rng, shape'
    )
    NOISE_SOURCE_REGISTRY[name] = fn_generate
示例#28
0
def register_dataset(name, fn_create):
    r"""
    Register a new input source (useful for ground truth or reference datasets).
    Args:
        name: str
            A unique name of the input source, which will be available for use as a positional input argument. See
            calculate_metrics function. Must have no leading or trailing whitespace and must not contain a slash or
            a backslash.
        fn_create: callable(root, download)
            A constructor of torch.utils.data.Dataset instance. `root` is a location where the dataset may be
            downloaded; `download` tells whether to download or rely on a cached copy.
    """
    vassert(type(name) is str, 'Dataset must be given a name')
    vassert(name.strip() == name,
            'Name must not have leading or trailing whitespaces')
    # Reject both delimiters on every platform (os.path.sep alone is only one of them)
    vassert('/' not in name and '\\' not in name,
            'Name must not contain path delimiters (slash/backslash)')
    vassert(name not in DATASETS_REGISTRY,
            f'Dataset "{name}" is already registered')
    # `root` is a path, not a bool
    vassert(
        callable(fn_create),
        'Dataset must be provided as a callable (function, lambda) with 2 arguments: root, download'
    )
    DATASETS_REGISTRY[name] = fn_create
    def __init__(self, path_onnx, z_size, z_type, num_classes):
        """
        Wraps :obj:`ONNX<torch:torch.onnx>` generative model, implements the :class:`GenerativeModelBase` interface.

        Args:

            path_onnx (str): Path to a generative model in :obj:`ONNX<torch:torch.onnx>` format.

            z_size (int): Size of the noise dimension of the generative model (positive integer).

            z_type (str): Type of the noise used by the generative model, one of 'normal', 'unit', 'uniform_0_1'
                (see :ref:`registry <Registry>` and :func:`register_noise_source`).

            num_classes (int): Number of classes used by a conditional generative model; zero for unconditional
                models.
        """
        super().__init__()
        vassert(os.path.isfile(path_onnx), f'Model file not found at "{path_onnx}"')
        vassert(type(z_size) is int and z_size > 0, 'z_size must be a positive integer')
        vassert(z_type in ('normal', 'unit', 'uniform_0_1'), f'z_type={z_type} not implemented')
        vassert(type(num_classes) is int and num_classes >= 0, 'num_classes must be a non-negative integer')
        try:
            import onnxruntime
        except ImportError:
            # This message may be removed if onnxruntime becomes a unified package with embedded CUDA dependencies,
            # like for example pytorch
            print(
                '====================================================================================================\n'
                'Loading ONNX models in PyTorch requires ONNX runtime package, which we did not want to include in\n'
                'torch_fidelity package requirements.txt. The two relevant pip packages are:\n'
                ' - onnxruntime       (pip install onnxruntime), or\n'
                ' - onnxruntime-gpu   (pip install onnxruntime-gpu).\n'
                'If you choose to install "onnxruntime", you will be able to run inference on CPU only - this may be\n'
                'slow. With "onnxruntime-gpu" speed is not an issue, but at run time you might face CUDA toolkit\n'
                'versions incompatibility, which can only be resolved by recompiling onnxruntime-gpu from source.\n'
                'Alternatively, use calculate_metrics API and pass an instance of GenerativeModelBase as an input.\n'
                '===================================================================================================='
            )
            raise
        # Since ONNX Runtime 1.9, builds with execution providers beyond CPU (e.g. onnxruntime-gpu) raise ValueError
        # unless `providers` is given explicitly. Request all providers of this build; the list returned by
        # get_available_providers() is ordered from highest to lowest default priority.
        self.ort_session = onnxruntime.InferenceSession(path_onnx, providers=onnxruntime.get_available_providers())
        self.input_names = [a.name for a in self.ort_session.get_inputs()]
        self._z_size = z_size
        self._z_type = z_type
        self._num_classes = num_classes
 def forward(self, *args):
     """
     Runs the ONNX model on the given tensors.

     Args:
         *args (torch.Tensor): One tensor per ONNX model input, in the order of `self.input_names`.

     Returns:
         torch.Tensor: The first output of the ONNX session, moved to the device of the first argument.
     """
     vassert(
         len(args) == len(self.input_names),
         f'Number of input arguments {len(args)} does not match ONNX model: {self.input_names}'
     )
     vassert(all(torch.is_tensor(a) for a in args), 'All model inputs must be tensors')
     feed = {input_name: self.to_numpy(arg) for input_name, arg in zip(self.input_names, args)}
     first_output = self.ort_session.run(None, feed)[0]
     vassert(isinstance(first_output, np.ndarray), 'Invalid output of ONNX model')
     return torch.from_numpy(first_output).to(device=args[0].device)