Esempio n. 1
0
    def z_generator(self,
                    shape,
                    distribution_fn=tf.random.uniform,
                    minval=-1.0,
                    maxval=1.0,
                    stddev=1.0,
                    name=None):
        """Samples a random-noise tensor as a TF op.

        Every keyword below is offered to `distribution_fn` through
        `utils.call_with_accepted_args`. As originally documented, only the
        arguments that `distribution_fn` actually has are passed to it. For
        example, `minval`/`maxval` reach a uniform sampler, and `stddev`
        reaches a normal sampler.

        Args:
          shape: A 1-D integer Tensor or Python array giving the output shape.
          distribution_fn: Function that creates a Tensor. Defaults to
            `tf.random.uniform`.
          minval: Lower bound on the range of random values to generate.
          maxval: Upper bound on the range of random values to generate.
          stddev: Standard deviation of a normal distribution.
          name: Optional name for the operation.

        Returns:
          The Tensor produced by `distribution_fn` for `shape`. With the
          default `tf.random.uniform` (no dtype is passed), it is tf.float32.
        """
        candidate_kwargs = dict(
            shape=shape,
            minval=minval,
            maxval=maxval,
            stddev=stddev,
            name=name,
        )
        return utils.call_with_accepted_args(
            distribution_fn, **candidate_kwargs)
Esempio n. 2
0
 def batch_norm(self, inputs, **kwargs):
   """Applies the configured batch-norm function to `inputs`, if there is one.

   Args:
     inputs: Value to normalize. It is returned unchanged when
       `self._batch_norm_fn` is None.
     **kwargs: Extra keyword arguments offered to `self._batch_norm_fn` via
       `utils.call_with_accepted_args`. If `use_sn` is not among them, it
       defaults to `self._spectral_norm`.

   Returns:
     `inputs` itself, or the result of calling `self._batch_norm_fn`.
   """
   norm_fn = self._batch_norm_fn
   if norm_fn is not None:
     call_args = dict(kwargs, inputs=inputs)
     # Only consult the spectral-norm flag when the caller did not choose.
     if "use_sn" not in call_args:
       call_args["use_sn"] = self._spectral_norm
     return utils.call_with_accepted_args(norm_fn, **call_args)
   return inputs
Esempio n. 3
0
def get_penalty_loss(fn=no_penalty, **kwargs):
  """Computes the penalty loss with the selected penalty function.

  Args:
    fn: Penalty function to evaluate. Defaults to `no_penalty`.
    **kwargs: Candidate keyword arguments, handed to `fn` through
      `utils.call_with_accepted_args` (presumably only those `fn` accepts
      are forwarded; confirm against that helper).

  Returns:
    The penalty loss, i.e. whatever `fn` returns.
  """
  penalty_loss = utils.call_with_accepted_args(fn, **kwargs)
  return penalty_loss
Esempio n. 4
0
def get_losses(fn=non_saturating, **kwargs):
    """Computes the losses for the discriminator and generator.

    Args:
      fn: Loss function to evaluate. Defaults to `non_saturating`.
      **kwargs: Candidate keyword arguments, handed to `fn` through
        `utils.call_with_accepted_args` (presumably only those `fn` accepts
        are forwarded; confirm against that helper).

    Returns:
      Whatever `fn` returns: the discriminator and generator losses.
    """
    losses = utils.call_with_accepted_args(fn, **kwargs)
    return losses