# Test method from tensor2tensor's common_layers tests; it assumes
# `import tensorflow as tf` and
# `from tensor2tensor.layers import common_layers`
# at module level, inside a `tf.test.TestCase` subclass.
def testSampleFromDiscretizedMixLogistic(self):
    batch = 2
    height = 4
    width = 4
    num_mixtures = 5
    seed = 42
    logits = tf.concat(  # assign all probability mass to first component
        [tf.ones([batch, height, width, 1]) * 1e8,
         tf.zeros([batch, height, width, num_mixtures - 1])],
        axis=-1)
    locs = tf.random_uniform([batch, height, width, num_mixtures * 3],
                             minval=-.9, maxval=.9)
    log_scales = tf.ones([batch, height, width, num_mixtures * 3]) * -1e8
    coeffs = tf.atanh(tf.zeros([batch, height, width, num_mixtures * 3]))
    pred = tf.concat([logits, locs, log_scales, coeffs], axis=-1)

    locs_0 = locs[..., :3]
    expected_sample = tf.clip_by_value(locs_0, -1., 1.)

    actual_sample = common_layers.sample_from_discretized_mix_logistic(
        pred, seed=seed)
    actual_sample_val, expected_sample_val = self.evaluate(
        [actual_sample, expected_sample])
    # Use a loose tolerance: samples numerically differ, as the actual
    # implementation clips log-scales so they always contribute to sampling.
    self.assertAllClose(actual_sample_val, expected_sample_val, atol=1e-2)
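
For reference, here is a minimal standalone sketch of calling the sampler outside the test harness. The shapes mirror the test above; `pred` packs 10 * num_mixtures channels along its last axis (num_mixtures mixture logits, then 3 * num_mixtures each of means, log-scales, and RGB coupling coefficients).

import tensorflow as tf
from tensor2tensor.layers import common_layers

batch, height, width, num_mixtures = 2, 4, 4, 5
# Any real-valued parameter tensor works for a smoke test; the sampler
# itself splits the last axis into logits, locs, log_scales, and coeffs.
pred = tf.random_uniform([batch, height, width, 10 * num_mixtures],
                         minval=-1., maxval=1.)
sample = common_layers.sample_from_discretized_mix_logistic(pred, seed=42)
# `sample` has shape [batch, height, width, 3], with values in [-1, 1].
with tf.Session() as sess:
    print(sess.run(sample).shape)  # (2, 4, 4, 3)

In tensor2tensor models, the same function backs `sample` methods like the ones below.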
def sample(self, features):
    """Run the model and extract samples.

    Args:
      features: a map of string to `Tensor`.

    Returns:
      samples: an integer `Tensor`.
      logits: a list of `Tensor`s, one per datashard.
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
    """
    logits, losses = self(features)  # pylint: disable=not-callable

    samples = common_layers.sample_from_discretized_mix_logistic(logits,
                                                                 seed=None)
    return samples, logits, losses
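
The variant below, from tensor2tensor's `Imagetransformer` model, takes the DMOL path only when the model's configured likelihood is the discretized mixture of logistics, and otherwise defers to the parent class's sampler: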
def sample(self, features):
    """Run the model and extract samples.

    Args:
      features: a map of string to `Tensor`.

    Returns:
      samples: an integer `Tensor`.
      logits: a list of `Tensor`s, one per datashard.
      losses: a dictionary: {loss-name (string): floating point `Scalar`}.
    """
    if self._hparams.likelihood == cia.DistributionType.DMOL:
        logits, losses = self(features)  # pylint: disable=not-callable
        samples = common_layers.sample_from_discretized_mix_logistic(
            logits, seed=None)
        return samples, logits, losses

    return super(Imagetransformer, self).sample(features)
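
A rough sketch of the design choice behind that branch, with `likelihood`, `num_mixtures`, and `output_channels` as hypothetical stand-ins rather than tensor2tensor variables: the likelihood determines the depth of the model head, since the DMOL sampler expects ten parameters per mixture component while a categorical head emits one logit per intensity level.

from tensor2tensor.layers import common_image_attention as cia

likelihood = cia.DistributionType.DMOL  # or cia.DistributionType.CAT
num_mixtures = 10  # hypothetical hparam value

if likelihood == cia.DistributionType.DMOL:
    # DMOL head: num_mixtures logits + 3 * num_mixtures each of means,
    # log-scales, and coupling coefficients = 10 * num_mixtures channels.
    output_channels = 10 * num_mixtures
else:
    # Categorical head: one logit per discrete 8-bit intensity value.
    output_channels = 256

print(output_channels)  # 100 on the DMOL branch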