def testDiscretizedMixLogisticLoss(self):
    """Check discretized_mix_logistic_loss against a one-component reference.

    The mixture logits give the first component a logit of 1e8 and every
    other component a logit of 0, and the coefficients are atanh(0) == 0.
    The reference loss is computed by hand from the first three location
    and log-scale channels only, as the negative log of the difference of
    two sigmoid CDF values, summed over the last axis.
    """
    batch, height, width = 2, 4, 4
    channels = 3
    num_mixtures = 5
    spatial = [batch, height, width]
    per_param_depth = num_mixtures * 3

    # Assign all probability mass to the first component.
    dominant_logit = tf.ones(spatial + [1]) * 1e8
    other_logits = tf.zeros(spatial + [num_mixtures - 1])
    logits = tf.concat([dominant_logit, other_logits], axis=-1)

    locs = tf.random_uniform(
        spatial + [per_param_depth], minval=-.9, maxval=.9)
    log_scales = tf.random_uniform(
        spatial + [per_param_depth], minval=-1., maxval=1.)
    coeffs = tf.atanh(tf.zeros(spatial + [per_param_depth]))
    pred = tf.concat([logits, locs, log_scales, coeffs], axis=-1)

    # Keep labels inside (-.9, .9) so they avoid the edge cases where the
    # 8-bit value is 0 or 255.
    labels = tf.random_uniform(
        spatial + [channels], minval=-.9, maxval=.9)

    # Reference computation using only the first component's parameters.
    first_locs = locs[..., :3]
    first_log_scales = log_scales[..., :3]
    offsets = labels - first_locs
    inverse_scale = tf.exp(-first_log_scales)
    half_bin = 1. / 255.
    upper_arg = inverse_scale * (offsets + half_bin)
    lower_arg = inverse_scale * (offsets - half_bin)
    upper_cdf = tf.nn.sigmoid(upper_arg)
    lower_cdf = tf.nn.sigmoid(lower_arg)
    bin_log_prob = tf.log(upper_cdf - lower_cdf)
    expected_loss = -tf.reduce_sum(bin_log_prob, axis=-1)

    actual_loss = common_layers.discretized_mix_logistic_loss(
        pred=pred, labels=labels)

    # Evaluate both tensors in one call so they share the same random draws.
    evaluated = self.evaluate([actual_loss, expected_loss])
    actual_value, expected_value = evaluated
    self.assertAllClose(actual_value, expected_value, rtol=1e-5)
def testDmlLoss(self, batch, height, width, num_mixtures, reduce_sum):
    """Check dml_loss against discretized_mix_logistic_loss on real labels.

    dml_loss returns a (numerator, denominator) pair; their ratio is
    compared with discretized_mix_logistic_loss evaluated on the labels
    converted by convert_rgb_to_symmetric_real, divided by the number of
    channels and, when reduce_sum is truthy, averaged with tf.reduce_mean.

    Args:
      batch: Size of the first dimension of the predictions and labels.
      height: Size of the second dimension of the predictions and labels.
      width: Size of the third dimension of the predictions and labels.
      num_mixtures: Number of mixture components; the prediction depth is
        num_mixtures * 10.
      reduce_sum: Forwarded to dml_loss; also selects whether the expected
        loss is reduced to its mean.

    NOTE(review): the arguments are presumably supplied by a
    parameterized-test decorator that is not visible here -- confirm.
    """
    channels = 3
    spatial = [batch, height, width]

    # Ten prediction values per mixture component.
    pred = tf.random_normal(spatial + [num_mixtures * 10])
    # Integer labels drawn from [0, 256).
    labels = tf.random_uniform(
        spatial + [channels], minval=0, maxval=256, dtype=tf.int32)

    numerator, denominator = common_layers.dml_loss(
        pred=pred, labels=labels, reduce_sum=reduce_sum)
    actual_loss = numerator / denominator

    symmetric_labels = common_layers.convert_rgb_to_symmetric_real(labels)
    unnormalized_loss = common_layers.discretized_mix_logistic_loss(
        pred=pred, labels=symmetric_labels)
    expected_loss = unnormalized_loss / channels
    if reduce_sum:
        expected_loss = tf.reduce_mean(expected_loss)

    # Evaluate both tensors in one call so they share the same random draws.
    evaluated = self.evaluate([actual_loss, expected_loss])
    actual_value, expected_value = evaluated
    self.assertAllClose(actual_value, expected_value)