def testSampleFromDiscretizedMixLogistic(self):
  """Sampling collapses onto the dominant component's clipped locations.

  Nearly all mixture weight goes to the first component (logit 1e8 vs. 0),
  and every component gets a log-scale of -1e8. The sampled values are then
  expected to match `locs[..., :3]` clipped to [-1, 1], up to a loose
  absolute tolerance.
  """
  batch, height, width = 2, 4, 4
  num_mixtures = 5
  seed = 42
  spatial_shape = [batch, height, width]
  per_channel_shape = spatial_shape + [num_mixtures * 3]

  # A huge logit on component 0 and zeros elsewhere: all probability mass
  # is assigned to the first component.
  dominant_logit = tf.ones(spatial_shape + [1]) * 1e8
  other_logits = tf.zeros(spatial_shape + [num_mixtures - 1])
  mixture_logits = tf.concat([dominant_logit, other_logits], axis=-1)

  locations = tf.random_uniform(per_channel_shape, minval=-.9, maxval=.9)
  # Very negative log-scales, intended to make the sampling noise negligible.
  scales_log = tf.ones(per_channel_shape) * -1e8
  # atanh(0) == 0, so every coefficient input is zero.
  coefficients = tf.atanh(tf.zeros(per_channel_shape))
  params = tf.concat(
      [mixture_logits, locations, scales_log, coefficients], axis=-1)

  # The expected sample is the first three location channels, clipped.
  first_component_locs = locations[..., :3]
  expected_sample = tf.clip_by_value(first_component_locs, -1., 1.)
  actual_sample = common_layers.sample_from_discretized_mix_logistic(
      params, seed=seed)

  actual_value, expected_value = self.evaluate(
      [actual_sample, expected_sample])
  # Loose absolute tolerance: the implementation clips log-scales, so they
  # always contribute some noise and the samples differ numerically.
  self.assertAllClose(actual_value, expected_value, atol=1e-2)
def testSampleFromDiscretizedMixLogistic(self):
  """Checks that a sample equals the first component's clipped locations.

  The first mixture component receives essentially all of the weight and
  all log-scales are set to -1e8. The draw is compared against
  `clip(locs[..., :3], -1, 1)` with `atol=1e-2`.
  """
  batch = 2
  height = 4
  width = 4
  num_mixtures = 5
  seed = 42
  num_params = num_mixtures * 3

  def _shape(channels):
    # Full [batch, height, width, channels] shape for a channel count.
    return [batch, height, width, channels]

  # Assign all probability mass to the first component.
  logits = tf.concat(
      [tf.ones(_shape(1)) * 1e8, tf.zeros(_shape(num_mixtures - 1))],
      axis=-1)
  locs = tf.random_uniform(_shape(num_params), minval=-.9, maxval=.9)
  log_scales = tf.ones(_shape(num_params)) * -1e8
  coeffs = tf.atanh(tf.zeros(_shape(num_params)))
  pred = tf.concat([logits, locs, log_scales, coeffs], axis=-1)

  expected_sample = tf.clip_by_value(locs[..., :3], -1., 1.)
  actual_sample = common_layers.sample_from_discretized_mix_logistic(
      pred, seed=seed)
  actual_val, expected_val = self.evaluate([actual_sample, expected_sample])

  # Lenient tolerance on purpose: the implementation clips log-scales, so
  # they always add some noise to the sample.
  self.assertAllClose(actual_val, expected_val, atol=1e-2)
def sample(self, features):
  """Run the model and extract samples.

  The DMOL path applies only when `self._hparams.likelihood` equals
  `cia.DistributionType.DMOL`. In that case the model output is passed to
  `common_layers.sample_from_discretized_mix_logistic`. For any other
  likelihood, sampling is delegated to the parent class's `sample`, so
  non-DMOL outputs are not misinterpreted as mixture-of-logistics
  parameters.

  Args:
    features: a map of string to `Tensor`.

  Returns:
    samples: a `Tensor` of samples. On the DMOL path these come from
      `sample_from_discretized_mix_logistic`. NOTE(review): presumably
      continuous values in [-1, 1] rather than integers; confirm.
      Otherwise, whatever the parent class's `sample` returns.
    logits: the model output; a list of `Tensor`s, one per datashard.
    losses: a dictionary: {loss-name (string): floating point `Scalar`}.
  """
  # NOTE(review): assumes the enclosing class is `Imagetransformer` and that
  # its hparams define `likelihood`; confirm against the class definition.
  if self._hparams.likelihood == cia.DistributionType.DMOL:
    logits, losses = self(features)  # pylint: disable=not-callable
    # No explicit op-level seed is passed for sampling.
    samples = common_layers.sample_from_discretized_mix_logistic(
        logits, seed=None)
    return samples, logits, losses

  return super(Imagetransformer, self).sample(features)
def sample(self, features):
  """Run the model and extract samples.

  Models whose likelihood is not `cia.DistributionType.DMOL` fall back to
  the parent class's `sample`. Otherwise the model output is sampled with
  `common_layers.sample_from_discretized_mix_logistic`.

  Args:
    features: a map of string to `Tensor`.

  Returns:
    samples: a `Tensor` of samples. On the DMOL path these are produced by
      `sample_from_discretized_mix_logistic`. NOTE(review): likely continuous
      values rather than integers; confirm.
    logits: a list of `Tensor`s, one per datashard.
    losses: a dictionary: {loss-name (string): floating point `Scalar`}.
  """
  uses_dmol = self._hparams.likelihood == cia.DistributionType.DMOL
  if not uses_dmol:
    # Non-DMOL likelihoods keep the parent class's sampling behavior.
    return super(Imagetransformer, self).sample(features)

  model_output, loss_dict = self(features)  # pylint: disable=not-callable
  # No explicit op-level seed is passed for sampling.
  drawn = common_layers.sample_from_discretized_mix_logistic(
      model_output, seed=None)
  return drawn, model_output, loss_dict