Example #1
0
    def decompress(self, strings, indexes):
        """Decompresses a tensor.

        Reconstructs the quantized tensor from bit strings produced by
        `compress()`.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          indexes: `tf.Tensor` specifying the scalar distribution for each output
            element. See class docstring for examples.

        Returns:
          A `tf.Tensor` of the same shape as `indexes` (without the optional
          channel dimension).
        """
        # Convert up front (as the other `decompress` implementations do) so a
        # non-string argument fails with a clear dtype error.
        strings = tf.convert_to_tensor(strings, dtype=tf.string)
        indexes = _add_offset_indexes(indexes, self._num_noise_levels)
        indexes = self._normalize_indexes(indexes)
        flat_indexes = self._flatten_indexes(indexes)
        symbols_shape = tf.shape(flat_indexes)
        # The decode shape is the trailing `coding_rank` dimensions of the
        # flattened index shape. `symbols_shape[-0:]` would return the *whole*
        # shape, so a coding rank of zero must explicitly yield an empty shape.
        coding_rank = self.coding_rank
        if coding_rank:
            decode_shape = symbols_shape[-coding_rank:]
        else:
            decode_shape = symbols_shape[:0]
        handle = gen_ops.create_range_decoder(strings, self.cdf)
        handle, symbols = gen_ops.entropy_decode_index(handle, flat_indexes,
                                                       decode_shape,
                                                       self.cdf_offset.dtype)
        sanity = gen_ops.entropy_decode_finalize(handle)
        tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
        # Shift each decoded symbol by the CDF offset selected by its index.
        symbols += tf.gather(self.cdf_offset, flat_indexes)
        offset = self._offset_from_indexes(indexes)
        return tf.cast(symbols, self.bottleneck_dtype) + offset
  def decompress(self, strings, broadcast_shape):
    """Decompresses a tensor.

    Rebuilds the quantized tensor from the bit strings emitted by `compress()`.
    The caller must supply part of the output shape via `broadcast_shape`.

    Args:
      strings: `tf.Tensor` containing the compressed bit strings.
      broadcast_shape: Iterable of ints. The part of the output tensor shape
        between the shape of `strings` on the left and `self.prior_shape` on the
        right. This must match the shape of the input to `compress()`.

    Returns:
      A `tf.Tensor` of shape `strings.shape + broadcast_shape +
      self.prior_shape`.
    """
    strings = tf.convert_to_tensor(strings, dtype=tf.string)
    broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32)
    prior_shape = self.prior_shape_tensor

    # Final layout: string dims, then broadcast dims, then prior dims.
    output_shape = tf.concat(
        [tf.shape(strings), broadcast_shape, prior_shape], 0)
    # The decoder is given the prior dimensions collapsed into one axis.
    num_prior_elements = tf.reduce_prod(prior_shape)
    decode_shape = tf.concat([broadcast_shape, [num_prior_elements]], 0)

    decoder = gen_ops.create_range_decoder(strings, self.cdf)
    decoder, decoded = gen_ops.entropy_decode_channel(
        decoder, decode_shape, self.cdf_offset.dtype)
    finalized_ok = gen_ops.entropy_decode_finalize(decoder)
    tf.debugging.assert_equal(
        finalized_ok, True, message="Sanity check failed.")

    decoded = tf.reshape(decoded + self.cdf_offset, output_shape)
    outputs = tf.cast(decoded, self.bottleneck_dtype)
    offset = self.quantization_offset
    return outputs if offset is None else outputs + offset
    def decompress(self, strings, indexes):
        """Decompresses a tensor.

        Rebuilds the quantized tensor from the bit strings emitted by
        `compress()`.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          indexes: `tf.Tensor` specifying the scalar distribution for each output
            element. See class docstring for examples.

        Returns:
          A `tf.Tensor` of the same shape as `indexes` (without the optional
          channel dimension).
        """
        strings = tf.convert_to_tensor(strings, dtype=tf.string)
        flat_indexes = self._flatten_indexes(self._normalize_indexes(indexes))

        # Decode shape = trailing `coding_rank` dims of the index shape. A rank
        # of zero must give an empty shape; a bare `[-0:]` slice would return
        # every dim, hence the explicit branch.
        index_shape = tf.shape(flat_indexes)
        rank = self.coding_rank
        decode_shape = index_shape[-rank:] if rank else index_shape[:0]

        decoder = gen_ops.create_range_decoder(strings, self.cdf)
        decoder, decoded = gen_ops.entropy_decode_index(
            decoder, flat_indexes, decode_shape, self.cdf_offset.dtype)
        finalized_ok = gen_ops.entropy_decode_finalize(decoder)
        tf.debugging.assert_equal(
            finalized_ok, True, message="Sanity check failed.")

        # Shift each decoded symbol by the CDF offset of its own distribution.
        decoded = decoded + tf.gather(self.cdf_offset, flat_indexes)
        return tf.cast(decoded, self.dtype)