Example #1
    def z_generator(self,
                    shape,
                    distribution_fn=tf.compat.v1.random.uniform,
                    minval=-1.0,
                    maxval=1.0,
                    stddev=1.0,
                    name=None):
        """Random noise distributions as TF op.

        Args:
          shape: A 1-D integer Tensor or Python array.
          distribution_fn: Function that creates a Tensor. If the function
            accepts any of the arguments 'minval', 'maxval' or 'stddev', they
            are passed to it.
          minval: The lower bound on the range of random values to generate.
          maxval: The upper bound on the range of random values to generate.
          stddev: The standard deviation of a normal distribution.
          name: A name for the operation.

        Returns:
          Tensor with the given shape and dtype tf.float32.
        """
        return utils.call_with_accepted_args(distribution_fn,
                                             shape=shape,
                                             minval=minval,
                                             maxval=maxval,
                                             stddev=stddev,
                                             name=name)
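Every example on this page routes keyword arguments through `call_with_accepted_args`; its name and the docstring above imply that it forwards only the keyword arguments the callee actually declares. The utility's own source is not shown here, so the following is only a minimal sketch of such a helper, assuming it filters kwargs via `inspect.signature`, not the repository's actual implementation:

import inspect

def call_with_accepted_args(fn, **kwargs):
    """Calls `fn` with only the keyword arguments it accepts (sketch).

    The repository's real utility may differ, e.g. in how it treats
    callables that declare their own **kwargs.
    """
    accepted = inspect.signature(fn).parameters
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    return fn(**filtered)

With this behavior, `z_generator` can expose `minval`, `maxval` and `stddev` at once and still work with both uniform and normal distribution functions: whichever of those arguments the chosen `distribution_fn` does not declare are simply dropped.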
Example #2
def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    from utils import call_with_accepted_args
    inception = call_with_accepted_args(models.Inception3,
                                        num_classes=1008, aux_logits=False, init_weights=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    if os.path.exists(FID_WEIGHTS_LOCAL):
        state_dict = torch.load(FID_WEIGHTS_LOCAL)
    else:
        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception
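Example #2 also shows that the helper works when the callable is a class rather than a function: `models.Inception3` is passed directly, and inspecting a class's signature reflects its `__init__`, so constructor keywords that a given torchvision version does not accept (such as `init_weights` on older releases) could simply be dropped instead of raising a `TypeError`. That motivation is an inference, not stated in the source. A self-contained sketch of the behavior, using a made-up `DummyModel` and the `call_with_accepted_args` sketch from above:

class DummyModel:
    # Made-up stand-in for a model class; not part of the repository.
    def __init__(self, num_classes=1000, aux_logits=True):
        self.num_classes = num_classes
        self.aux_logits = aux_logits

# `init_weights` is not a parameter of DummyModel.__init__, so it is
# filtered out rather than causing a TypeError.
model = call_with_accepted_args(DummyModel,
                                num_classes=1008,
                                aux_logits=False,
                                init_weights=False)
print(model.num_classes, model.aux_logits)  # -> 1008 False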
Example #3
    def batch_norm(self, inputs, **kwargs):
        if self._batch_norm_fn is None:
            return inputs
        args = kwargs.copy()
        args["inputs"] = inputs
        if "use_sn" not in args:
            args["use_sn"] = self._spectral_norm
        return utils.call_with_accepted_args(self._batch_norm_fn, **args)
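Here the wrapper injects a default `use_sn` into the kwargs unconditionally and relies on the filtering to keep that safe for batch-norm functions that do not declare such a parameter. A short self-contained sketch with made-up batch-norm functions (the real signatures are not shown in the source):

def bn_with_sn(inputs, use_sn=False):
    # Made-up batch norm that understands spectral-norm switching.
    return ("bn+sn", inputs, use_sn)

def bn_plain(inputs):
    # Made-up batch norm without a `use_sn` parameter.
    return ("bn", inputs)

for fn in (bn_with_sn, bn_plain):
    args = {"inputs": 1.0, "use_sn": True}
    print(call_with_accepted_args(fn, **args))
# -> ('bn+sn', 1.0, True)
# -> ('bn', 1.0)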
Example #4
def compute_penalty(mode='none', **kwargs):
    """Returns the penalty selected by `mode`, forwarding only accepted kwargs."""
    _mapping = {
        'none': no_penalty,
        'gp': gradient_penalty,
        'cr': consistency,
        'bcr': balanced_consistency
    }
    fn = _mapping[mode]
    return call_with_accepted_args(fn, **kwargs)
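The same lookup-then-dispatch pattern recurs in Examples #5 to #7: select a callable from a mapping (or take it as an argument) and let `call_with_accepted_args` drop whatever the selected function does not take, so one kwargs superset can serve callees with different signatures. A self-contained demonstration with made-up stand-ins (`fake_gp`, `fake_none`), since the real penalty functions' parameters are not shown in the source; it runs against the `call_with_accepted_args` sketch given after Example #1:

def fake_gp(real, fake, weight=10.0):
    # Stand-in for gradient_penalty; parameter names are illustrative only.
    return weight * abs(real - fake)

def fake_none():
    # Stand-in for no_penalty: accepts no arguments.
    return 0.0

_demo_mapping = {"gp": fake_gp, "none": fake_none}

def demo_compute_penalty(mode="none", **kwargs):
    return call_with_accepted_args(_demo_mapping[mode], **kwargs)

# One superset of kwargs works for every mode; arguments a given penalty
# does not declare are dropped instead of raising a TypeError.
print(demo_compute_penalty(mode="gp", real=1.0, fake=0.4, weight=10.0))   # 6.0
print(demo_compute_penalty(mode="none", real=1.0, fake=0.4, weight=10.0)) # 0.0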
Example #5
def get_augment(mode='none', **kwargs):
    """Returns the augmentation selected by `mode`, forwarding only accepted kwargs."""
    _mapping = {
        'none': NoAugment,
        'gaussian': Gaussian,
        'hflip': HorizontalFlipLayer,
        'hfrt': HorizontalFlipRandomCrop,
        'color_jitter': ColorJitterLayer,
        'cutout': CutOut,
        'simclr': simclr,
        'simclr_hq': simclr_hq,
        'simclr_hq_cutout': simclr_hq_cutout,
        'diffaug': diffaug,
    }
    fn = _mapping[mode]
    return call_with_accepted_args(fn, **kwargs)
Example #6
def get_losses(fn=non_saturating, **kwargs):
    """Returns the losses for the discriminator and generator."""
    return utils.call_with_accepted_args(fn, **kwargs)
Example #7
def get_penalty_loss(fn=no_penalty, **kwargs):
    """Returns the penalty loss."""
    return utils.call_with_accepted_args(fn, **kwargs)