Example #1
0
    def loss(self, targets, logits, train_config, training, aux_output=None):
        """Quantize the target colors to coarse bins and score the logits.

        The ground-truth colors are reduced from 8 bits to 3 bits per channel
        with `base_utils.convert_bits`, and each channel triplet is folded into
        a single bin index with `base_utils.labels_to_bins`. The result is
        scored against `logits` with `self.image_loss`.

        Args:
          targets: mapping of target tensors. Reads `'targets'`, or
            `'targets_<downsample_res>'` when `train_config['downsample']` is
            truthy.
          logits: decoder logits passed to `self.image_loss`.
          train_config: dict-like config. Reads `'downsample'` (default False)
            and `'downsample_res'` (default 64).
          training: accepted for interface compatibility; not read here.
          aux_output: optional mapping. If it holds a non-None
            `'encoder_logits'` entry, an additional encoder loss is computed
            against the same bins.

        Returns:
          A `(loss, extras)` pair. `extras` is `{}`, or
          `{'encoder': encoder_loss}` when encoder logits were supplied.
        """
        use_downsampled = train_config.get('downsample', False)
        resolution = train_config.get('downsample_res', 64)
        target_key = ('targets_%d' % resolution) if use_downsampled else 'targets'
        ground_truth = targets[target_key]

        extras = {} if aux_output is None else aux_output

        # 8-bit colors -> 3 bits per channel -> one bin id per channel triplet.
        coarse = base_utils.convert_bits(ground_truth, n_bits_in=8, n_bits_out=3)
        bin_ids = base_utils.labels_to_bins(
            coarse, self.num_symbols_per_channel)

        main_loss = self.image_loss(logits, bin_ids)
        encoder_logits = extras.get('encoder_logits')
        if encoder_logits is not None:
            return main_loss, {'encoder': self.image_loss(encoder_logits, bin_ids)}
        return main_loss, {}
Example #2
0
  def test_bins_to_labels_random(self):
    """Round-trip random per-channel labels through bins and back.

    Draws random int32 labels in [0, num_symbols) with shape (1, 64, 64, 3).
    They are folded into bin indices with `labels_to_bins` and recovered
    with `bins_to_labels`. The recovered labels must match the originals.
    """
    num_symbols = 8
    labels_t = tf.random.uniform(shape=(1, 64, 64, 3), minval=0,
                                 maxval=num_symbols, dtype=tf.int32)
    labels_np = labels_t.numpy()
    bins_t = base_utils.labels_to_bins(
        labels_t, num_symbols_per_channel=num_symbols)

    inv_labels_t = base_utils.bins_to_labels(
        bins_t, num_symbols_per_channel=num_symbols)
    inv_labels_np = inv_labels_t.numpy()
    # Same tolerances as np.allclose. Unlike assertTrue(np.allclose(...)),
    # a failure reports the mismatching entries (and any shape mismatch).
    np.testing.assert_allclose(inv_labels_np, labels_np, rtol=1e-5, atol=1e-8)
Example #3
0
  def test_labels_to_bins(self):
    """Check the labels <-> bins mapping over every possible channel triplet.

    Enumerating all triplets in lexicographic order must yield the bin
    indices 0 .. num_symbols**3 - 1 in order. `bins_to_labels` must then
    invert `labels_to_bins`.
    """
    n_bits = 3
    # Derive the per-channel symbol count from n_bits so changing n_bits
    # keeps the test self-consistent (previously 8 and 512 were hard-coded).
    num_symbols = 2**n_bits
    symbols = np.arange(num_symbols)

    labels = np.array(list(itertools.product(symbols, repeat=3)))
    labels_t = tf.convert_to_tensor(labels, dtype=tf.float32)
    bins_t = base_utils.labels_to_bins(
        labels_t, num_symbols_per_channel=num_symbols)
    bins_np = bins_t.numpy()
    # Tolerances match np.allclose, but failures report the offending entries.
    np.testing.assert_allclose(bins_np, np.arange(num_symbols**3),
                               rtol=1e-5, atol=1e-8)

    inv_labels_t = base_utils.bins_to_labels(
        bins_t, num_symbols_per_channel=num_symbols)
    inv_labels_np = inv_labels_t.numpy()
    np.testing.assert_allclose(inv_labels_np, labels, rtol=1e-5, atol=1e-8)
Example #4
0
    def decoder(self, inputs, z, training):
        """Decode masked colors, conditioned on `z`, into per-pixel logits.

        Steps:
          1. The 8-bit colors in `inputs` are coarsened to 3 bits per channel.
          2. Each channel triplet is folded into one bin index.
          3. The bin indices are one-hot encoded with `self.num_symbols`
             classes and embedded with `self.pixel_embed_layer`.
          4. The outer decoder consumes the embedding together with `z`.
          5. The inner decoder additionally sees the outer decoder's output.
          6. The result is normalized and projected to logits.

        Args:
          inputs: color tensor with 8-bit values; presumably
            (batch, H, W, 3) -- TODO confirm against callers.
          z: conditioning representation passed to both decoder stacks (the
            grayscale representation, per the original docstring).
          training: forwarded to the outer and inner decoders.

        Returns:
          Output of `self.final_dense` with a new axis inserted at
          position -2.
        """
        coarse = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)
        bin_ids = base_utils.labels_to_bins(
            coarse, self.num_symbols_per_channel)
        one_hot = tf.one_hot(bin_ids, depth=self.num_symbols)

        embedded = self.pixel_embed_layer(one_hot)
        upper = self.outer_decoder((embedded, z), training=training)
        inner = self.inner_decoder((embedded, upper, z), training=training)

        projected = self.final_dense(self.final_norm(inner))
        return tf.expand_dims(projected, axis=-2)