def loss(self, targets, logits, train_config, training, aux_output=None):
    """Compute the log-likelihood loss against coarsely quantized targets.

    Args:
        targets: Mapping of ground-truth images. Read under 'targets', or
            under 'targets_<downsample_res>' when downsampling is enabled.
        logits: Predictions forwarded to `self.image_loss`.
        train_config: Config queried with `.get` for 'downsample'
            (default False) and 'downsample_res' (default 64).
        training: Not used by this method.
        aux_output: Optional mapping. When it holds a non-None
            'encoder_logits' entry, an encoder loss is computed as well.

    Returns:
        A `(loss, aux_losses)` tuple. `aux_losses` is `{'encoder': enc_loss}`
        when encoder logits are available and an empty dict otherwise.
    """
    use_downsampled = train_config.get('downsample', False)
    resolution = train_config.get('downsample_res', 64)

    # Pick the ground truth matching the resolution the model was run at.
    target_key = ('targets_%d' % resolution) if use_downsampled else 'targets'
    labels = targets[target_key]

    extras = {} if aux_output is None else aux_output

    # Reduce each 8-bit channel to 3 bits, then fold every channel triplet
    # into a single bin index.
    coarse = base_utils.convert_bits(labels, n_bits_in=8, n_bits_out=3)
    labels = base_utils.labels_to_bins(coarse, self.num_symbols_per_channel)

    main_loss = self.image_loss(logits, labels)

    aux_losses = {}
    enc_logits = extras.get('encoder_logits')
    if enc_logits is not None:
        aux_losses['encoder'] = self.image_loss(enc_logits, labels)
    return main_loss, aux_losses
def test_bins_to_labels_random(self):
    """Check that bins_to_labels inverts labels_to_bins on random labels.

    Random integer labels in [0, 8) of shape (1, 64, 64, 3) are binned with
    8 symbols per channel and mapped back; the result must match the input.
    """
    original = tf.random.uniform(
        shape=(1, 64, 64, 3), minval=0, maxval=8, dtype=tf.int32)
    expected = original.numpy()

    binned = base_utils.labels_to_bins(original, num_symbols_per_channel=8)
    restored = base_utils.bins_to_labels(binned, num_symbols_per_channel=8)
    recovered = restored.numpy()

    self.assertTrue(np.allclose(recovered, expected))
def test_labels_to_bins(self):
    """Check labels_to_bins on every 3-bit channel triplet, and its inverse.

    All 8 * 8 * 8 triplets are enumerated in `itertools.product` order; the
    test expects them to map onto bins 0..511 in that same order, and
    expects bins_to_labels to recover the original triplets.
    """
    n_bits = 3
    symbols = np.arange(2**n_bits)
    # Same enumeration as product(symbols, symbols, symbols).
    labels = np.array(list(itertools.product(symbols, repeat=3)))

    labels_tensor = tf.convert_to_tensor(labels, dtype=tf.float32)
    bins_tensor = base_utils.labels_to_bins(
        labels_tensor, num_symbols_per_channel=8)
    observed_bins = bins_tensor.numpy()
    self.assertTrue(np.allclose(observed_bins, np.arange(512)))

    restored = base_utils.bins_to_labels(
        bins_tensor, num_symbols_per_channel=8)
    recovered = restored.numpy()
    self.assertTrue(np.allclose(labels, recovered))
def decoder(self, inputs, z, training):
    """Decode the grayscale representation and masked colors into logits.

    Args:
        inputs: Color input; quantized from 8 to 3 bits per channel before
            being embedded. Presumably an (..., H, W, 3) image batch --
            TODO confirm against the caller.
        z: Conditioning representation passed to both decoders.
        training: Forwarded as the `training` flag to both decoders.

    Returns:
        The output of `self.final_dense` with an extra axis inserted at
        position -2.
    """
    # Quantize to 3 bits per channel.
    coarse = base_utils.convert_bits(inputs, n_bits_in=8, n_bits_out=3)

    # Fold each channel triplet into a single symbol index.
    symbols = base_utils.labels_to_bins(coarse, self.num_symbols_per_channel)

    # Expand the symbol indices into one-hot vectors of depth
    # `self.num_symbols` before embedding them.
    one_hot = tf.one_hot(symbols, depth=self.num_symbols)
    embedded = self.pixel_embed_layer(one_hot)

    outer = self.outer_decoder((embedded, z), training=training)
    inner = self.inner_decoder((embedded, outer, z), training=training)

    logits = self.final_dense(self.final_norm(inner))
    return tf.expand_dims(logits, axis=-2)