def decompress(self, strings, indexes):
  """Decompresses a tensor.

  Reconstructs the quantized tensor from bit strings produced by `compress()`.

  Args:
    strings: `tf.Tensor` containing the compressed bit strings.
    indexes: `tf.Tensor` specifying the scalar distribution for each output
      element. See class docstring for examples.

  Returns:
    A `tf.Tensor` of the same shape as `indexes` (without the optional channel
    dimension).
  """
  # Convert eagerly so Python/NumPy string inputs are accepted and a wrong
  # dtype is reported here rather than from inside the decoder op.
  strings = tf.convert_to_tensor(strings, dtype=tf.string)
  indexes = _add_offset_indexes(indexes, self._num_noise_levels)
  indexes = self._normalize_indexes(indexes)
  flat_indexes = self._flatten_indexes(indexes)
  # The per-string decode shape is the trailing `coding_rank` dimensions of
  # the flattened indexes. `shape[-0:]` is the *entire* shape in Python
  # slicing, so a coding rank of zero must explicitly yield an empty shape.
  symbols_shape = tf.shape(flat_indexes)
  if self.coding_rank:
    decode_shape = symbols_shape[-self.coding_rank:]
  else:
    decode_shape = symbols_shape[:0]
  handle = gen_ops.create_range_decoder(strings, self.cdf)
  handle, symbols = gen_ops.entropy_decode_index(
      handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
  # Finalizing the decoder yields a flag that must be True; fail loudly
  # otherwise.
  sanity = gen_ops.entropy_decode_finalize(handle)
  tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
  # Add back the CDF offset of the distribution selected by each element's
  # index.
  symbols += tf.gather(self.cdf_offset, flat_indexes)
  offset = self._offset_from_indexes(indexes)
  return tf.cast(symbols, self.bottleneck_dtype) + offset
def decompress(self, strings, broadcast_shape):
  """Decodes bit strings from `compress()` back into a quantized tensor.

  The caller has to supply the middle portion of the output shape in
  `broadcast_shape`; it is not inferred from `strings`.

  Args:
    strings: `tf.Tensor` holding the compressed bit strings.
    broadcast_shape: Iterable of ints. The slice of the output shape located
      between the shape of `strings` (leading) and `self.prior_shape`
      (trailing). It must equal the corresponding part of the shape of the
      tensor originally passed to `compress()`.

  Returns:
    A `tf.Tensor` with shape
    `strings.shape + broadcast_shape + self.prior_shape`.
  """
  strings = tf.convert_to_tensor(strings, dtype=tf.string)
  broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32)
  prior_shape = self.prior_shape_tensor

  # Each string is decoded into `broadcast_shape` rows of a flattened prior;
  # the result is later reshaped to the full output shape.
  num_prior_elements = tf.reduce_prod(prior_shape)
  decode_shape = tf.concat([broadcast_shape, [num_prior_elements]], 0)
  output_shape = tf.concat(
      [tf.shape(strings), broadcast_shape, prior_shape], 0)

  decoder = gen_ops.create_range_decoder(strings, self.cdf)
  decoder, decoded = gen_ops.entropy_decode_channel(
      decoder, decode_shape, self.cdf_offset.dtype)
  finalized_ok = gen_ops.entropy_decode_finalize(decoder)
  tf.debugging.assert_equal(
      finalized_ok, True, message="Sanity check failed.")

  # Shift by the CDF offset, restore the output layout, then apply the
  # optional quantization offset.
  decoded = tf.reshape(decoded + self.cdf_offset, output_shape)
  outputs = tf.cast(decoded, self.bottleneck_dtype)
  quantization_offset = self.quantization_offset
  return (outputs if quantization_offset is None
          else outputs + quantization_offset)
def decompress(self, strings, indexes):
  """Reconstructs a quantized tensor from the output of `compress()`.

  Args:
    strings: `tf.Tensor` of compressed bit strings.
    indexes: `tf.Tensor` choosing, for every output element, which scalar
      distribution it is decoded with. The class docstring has examples.

  Returns:
    A `tf.Tensor` shaped like `indexes` (minus the optional channel
    dimension), cast to `self.dtype`.
  """
  strings = tf.convert_to_tensor(strings, dtype=tf.string)
  flat_indexes = self._flatten_indexes(self._normalize_indexes(indexes))

  # Trailing `coding_rank` dims of the index shape form the per-string decode
  # shape. Rank 0 is special-cased because `shape[-0:]` is the full shape.
  index_shape = tf.shape(flat_indexes)
  rank = self.coding_rank
  decode_shape = index_shape[-rank:] if rank else index_shape[:0]

  decoder = gen_ops.create_range_decoder(strings, self.cdf)
  decoder, symbols = gen_ops.entropy_decode_index(
      decoder, flat_indexes, decode_shape, self.cdf_offset.dtype)
  finalized_ok = gen_ops.entropy_decode_finalize(decoder)
  tf.debugging.assert_equal(
      finalized_ok, True, message="Sanity check failed.")

  # Undo the per-distribution CDF offset chosen by each element's index.
  per_element_offset = tf.gather(self.cdf_offset, flat_indexes)
  return tf.cast(symbols + per_element_offset, self.dtype)