Example #1
0
    def decompress(self, strings, broadcast_shape):
        """Decompresses a tensor.

        Rebuilds the quantized tensor from the bit strings emitted by
        `compress()`. Part of the output shape is not recoverable from the bit
        strings themselves, so the caller supplies it via `broadcast_shape`.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          broadcast_shape: Iterable of ints. The part of the output tensor shape
            between the shape of `strings` on the left and `self.prior_shape`
            on the right. This must match the shape of the input to
            `compress()`.

        Returns:
          A `tf.Tensor` of shape `strings.shape + broadcast_shape +
          self.prior_shape`.
        """
        strings = tf.convert_to_tensor(strings, dtype=tf.string)
        broadcast_shape = tf.convert_to_tensor(broadcast_shape, dtype=tf.int32)

        # Shape decoded from each individual bit string: the caller-provided
        # broadcast part followed by the prior's own shape.
        per_string_shape = tf.concat(
            [broadcast_shape, self.prior_shape_tensor], 0)
        # Full output shape: the batch of strings in front of the above.
        full_shape = tf.concat([tf.shape(strings), per_string_shape], 0)

        indexes, offset = self._compute_indexes_and_offset(broadcast_shape)

        decoder = gen_ops.create_range_decoder(strings, self.cdf)
        # One distribution index per output element.
        expanded_indexes = tf.broadcast_to(indexes, full_shape)
        decoder, symbols = gen_ops.entropy_decode_index(
            decoder, expanded_indexes, per_string_shape,
            self.cdf_offset.dtype)

        # `entropy_decode_finalize` yields a flag that must be True; a False
        # value presumably indicates corrupt or mismatched input.
        is_consistent = gen_ops.entropy_decode_finalize(decoder)
        tf.debugging.assert_equal(
            is_consistent, True, message="Sanity check failed.")

        # Shift the decoded symbols by the CDF offset of their distribution.
        symbols = symbols + tf.gather(self.cdf_offset, indexes)
        return tf.cast(symbols, self.bottleneck_dtype) + offset
Example #2
0
    def decompress(self, strings, indexes):
        """Decompresses a tensor.

        Reconstructs the quantized tensor from bit strings produced by
        `compress()`.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          indexes: `tf.Tensor` specifying the scalar distribution for each
            output element. See class docstring for examples.

        Returns:
          A `tf.Tensor` of the same shape as `indexes` (without the optional
          channel dimension).
        """
        # Convert up front so non-string input fails here with a clear dtype
        # error rather than deep inside the decoder op.
        strings = tf.convert_to_tensor(strings, dtype=tf.string)
        # NOTE(review): `_add_offset_indexes` is defined elsewhere; presumably
        # it appends the noise-level offset index component -- confirm.
        indexes = _add_offset_indexes(indexes, self._num_noise_levels)
        indexes = self._normalize_indexes(indexes)
        flat_indexes = self._flatten_indexes(indexes)
        symbols_shape = tf.shape(flat_indexes)
        # Each bit string decodes the trailing `coding_rank` dimensions. Rank 0
        # must be handled explicitly: `shape[-0:]` is `shape[0:]`, i.e. the
        # whole shape, not an empty one.
        if self.coding_rank:
            decode_shape = symbols_shape[-self.coding_rank:]
        else:
            decode_shape = symbols_shape[:0]
        handle = gen_ops.create_range_decoder(strings, self.cdf)
        handle, symbols = gen_ops.entropy_decode_index(
            handle, flat_indexes, decode_shape, self.cdf_offset.dtype)
        # `entropy_decode_finalize` yields a flag that must be True; a False
        # value presumably indicates corrupt or mismatched input.
        sanity = gen_ops.entropy_decode_finalize(handle)
        tf.debugging.assert_equal(sanity, True, message="Sanity check failed.")
        # Shift the decoded symbols by the CDF offset of their distribution.
        symbols += tf.gather(self.cdf_offset, flat_indexes)
        offset = self._offset_from_indexes(indexes)
        return tf.cast(symbols, self.bottleneck_dtype) + offset
    def decompress(self, strings, indexes):
        """Decompresses a tensor.

        Rebuilds the quantized tensor from the bit strings emitted by
        `compress()`, using `indexes` to select the distribution each element
        was coded with.

        Args:
          strings: `tf.Tensor` containing the compressed bit strings.
          indexes: `tf.Tensor` specifying the scalar distribution for each
            output element. See class docstring for examples.

        Returns:
          A `tf.Tensor` of the same shape as `indexes` (without the optional
          channel dimension).
        """
        strings = tf.convert_to_tensor(strings, dtype=tf.string)
        normalized = self._normalize_indexes(indexes)
        flat_indexes = self._flatten_indexes(normalized)

        # Keep only the trailing `coding_rank` dimensions. A rank of zero needs
        # an explicit empty slice, since `[-0:]` would select everything.
        index_shape = tf.shape(flat_indexes)
        rank = self.coding_rank
        decode_shape = index_shape[-rank:] if rank else index_shape[:0]

        decoder = gen_ops.create_range_decoder(strings, self.cdf)
        decoder, symbols = gen_ops.entropy_decode_index(
            decoder, flat_indexes, decode_shape, self.cdf_offset.dtype)

        # `entropy_decode_finalize` yields a flag that must be True; a False
        # value presumably indicates corrupt or mismatched input.
        is_consistent = gen_ops.entropy_decode_finalize(decoder)
        tf.debugging.assert_equal(
            is_consistent, True, message="Sanity check failed.")

        # Shift the decoded symbols by the CDF offset of their distribution.
        symbols = symbols + tf.gather(self.cdf_offset, flat_indexes)
        return tf.cast(symbols, self.dtype)