def post_process_image(self, image):
  """Post-process predicted color bins into a coarse 8-bit RGB image.

  Each bin index is decoded by `base_utils.bins_to_labels` into a per-channel
  label triplet. Those labels are then rescaled to 8 bits and cast to uint8.

  NOTE(review): the original docstring described the input as (H, W, 512).
  `bins_to_labels` operates on integer bin indices, so this presumably
  expects already-argmax'ed bins in [0, num_symbols_per_channel**3). Confirm
  this against the caller.

  Args:
    image: Tensor of integer bin indices.

  Returns:
    uint8 tensor of per-channel values rescaled to 8 bits.
  """
  num_symbols = self.num_symbols_per_channel
  image = base_utils.bins_to_labels(
      image, num_symbols_per_channel=num_symbols)
  # The largest channel label is num_symbols - 1. Use the number of bits
  # needed to represent it rather than a hard-coded 3, which is only right
  # for up to 8 symbols. With larger values the rescale would exceed 255
  # before the uint8 cast. For num_symbols == 8 this is still 3 bits.
  n_bits_in = (int(num_symbols) - 1).bit_length()
  image = base_utils.convert_bits(image, n_bits_in=n_bits_in, n_bits_out=8)
  image = tf.cast(image, dtype=tf.uint8)
  return image
def test_bins_to_labels_random(self):
  """Random labels survive a labels_to_bins -> bins_to_labels round trip."""
  labels_t = tf.random.uniform(
      shape=(1, 64, 64, 3), minval=0, maxval=8, dtype=tf.int32)
  labels_np = labels_t.numpy()
  bins_t = base_utils.labels_to_bins(labels_t, num_symbols_per_channel=8)
  inv_labels_t = base_utils.bins_to_labels(bins_t, num_symbols_per_channel=8)
  inv_labels_np = inv_labels_t.numpy()
  # Labels are integers, so compare exactly. assert_array_equal also fails
  # on a shape mismatch, which np.allclose would hide by broadcasting, and
  # it reports the differing elements on failure.
  np.testing.assert_array_equal(inv_labels_np, labels_np)
def test_labels_to_bins(self):
  """Every 3-bit label triplet maps to a unique bin and back again."""
  n_bits = 3
  bins = np.arange(2**n_bits)
  # Enumerate all (R, G, B) triplets in lexicographic order. They are
  # expected to map to bins 0 .. 511 in that same order.
  triplets = itertools.product(bins, bins, bins)
  labels = np.array(list(triplets))
  labels_t = tf.convert_to_tensor(labels, dtype=tf.float32)
  bins_t = base_utils.labels_to_bins(labels_t, num_symbols_per_channel=8)
  bins_np = bins_t.numpy()
  # Exact, shape-checked comparisons. np.allclose would broadcast
  # mismatched shapes and give no diagnostic on failure.
  np.testing.assert_array_equal(bins_np, np.arange(512))
  inv_labels_t = base_utils.bins_to_labels(bins_t, num_symbols_per_channel=8)
  inv_labels_np = inv_labels_t.numpy()
  np.testing.assert_array_equal(inv_labels_np, labels)