Example No. 1
    def compression_forward(self, x, yclass):
        """
        Forward pass through encoder, hyperprior, and decoder.

        Inputs
        x:  Input image. Format (N,C,H,W), range [0,1],
            or [-1,1] if args.normalize_image is True
            torch.Tensor

        Outputs
        intermediates: NamedTuple of intermediate values
        """
        image_dims = tuple(x.size()[1:])  # (C,H,W)

        if self.model_mode == ModelModes.EVALUATION and not self.training:
            # Pad spatial dims to a multiple of the encoder's total downsampling factor
            n_encoder_downsamples = self.Encoder.n_downsampling_layers
            factor = 2**n_encoder_downsamples
            x = utils.pad_factor(x, x.size()[2:], factor)

        # Encoder forward pass
        y = self.Encoder(x)

        if self.model_mode == ModelModes.EVALUATION and not self.training:
            # Likewise pad latents to a multiple of the hyperencoder's downsampling factor
            n_hyperencoder_downsamples = self.Hyperprior.analysis_net.n_downsampling_layers
            factor = 2**n_hyperencoder_downsamples
            y = utils.pad_factor(y, y.size()[2:], factor)

        hyperinfo = self.Hyperprior(y, spatial_shape=x.size()[2:])

        latents_quantized = hyperinfo.decoded
        total_nbpp = hyperinfo.total_nbpp
        total_qbpp = hyperinfo.total_qbpp

        yhat_class = self.Classi(latents_quantized)
        #yhat_class = self.Classi(y)

        # Use quantized latents as input to G
        reconstruction = self.Generator(latents_quantized)

        if self.args.normalize_input_image is True:
            reconstruction = torch.tanh(reconstruction)

        # Undo padding
        if self.model_mode == ModelModes.EVALUATION and not self.training:
            reconstruction = reconstruction[:, :, :image_dims[1], :image_dims[2]]

        intermediates = Intermediates(x, yhat_class, yclass, reconstruction,
                                      latents_quantized, total_nbpp,
                                      total_qbpp)

        return intermediates, hyperinfo
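
The evaluation-mode branches above pad the input so its spatial dimensions are divisible by the total downsampling factor, then crop the reconstruction back to the original size at the end. A minimal, self-contained sketch of that pad-then-crop round trip; `pad_to_factor` below is an illustrative stand-in, not the actual `utils.pad_factor`:

import torch
import torch.nn.functional as F

def pad_to_factor(x, factor):
    # Pad H and W up to the next multiple of `factor` (illustrative stand-in for utils.pad_factor)
    H, W = x.shape[2:]
    pad_h = (factor - H % factor) % factor
    pad_w = (factor - W % factor) % factor
    return F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')

x = torch.rand(1, 3, 250, 374)                # arbitrary image size
factor = 2 ** 4                               # e.g. four stride-2 downsampling layers
x_padded = pad_to_factor(x, factor)           # -> (1, 3, 256, 384)
reconstruction = x_padded                     # placeholder for Encoder -> Hyperprior -> Generator
reconstruction = reconstruction[:, :, :x.shape[2], :x.shape[3]]   # undo padding
assert reconstruction.shape == x.shape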
Example No. 2
    def compress(self, x, silent=False):
        """
        * Pass image through encoder to obtain latents: x -> Encoder() -> y
        * Pass latents through hyperprior encoder to obtain hyperlatents:
          y -> hyperencoder() -> z
        * Encode hyperlatents via nonparametric entropy model.
        * Pass hyperlatents through mean-scale hyperprior decoder to obtain mean,
          scale over latents: z -> hyperdecoder() -> (mu, sigma).
        * Encode latents via entropy model derived from (mean, scale) hyperprior output.
        """

        assert self.model_mode == ModelModes.EVALUATION and (
            self.training is False), (
                f'Set model mode to {ModelModes.EVALUATION} for compression.')

        spatial_shape = tuple(x.size()[2:])

        if self.model_mode == ModelModes.EVALUATION and not self.training:
            n_encoder_downsamples = self.Encoder.n_downsampling_layers
            factor = 2**n_encoder_downsamples
            x = utils.pad_factor(x, x.size()[2:], factor)

        # Encoder forward pass
        y = self.Encoder(x)

        if self.model_mode == ModelModes.EVALUATION and not self.training:
            n_hyperencoder_downsamples = self.Hyperprior.analysis_net.n_downsampling_layers
            factor = 2**n_hyperencoder_downsamples
            y = utils.pad_factor(y, y.size()[2:], factor)

        compression_output = self.Hyperprior.compress_forward(y, spatial_shape)

        # Attained rate: each encoded word in the bitstream occupies 32 bits
        n_pixels = np.prod(spatial_shape)
        attained_hbpp = 32 * len(compression_output.hyperlatents_encoded) / n_pixels
        attained_lbpp = 32 * len(compression_output.latents_encoded) / n_pixels
        attained_bpp = attained_hbpp + attained_lbpp
        if silent is False:
            self.logger.info('[ESTIMATED]')
            self.logger.info(f'BPP: {compression_output.total_bpp:.3f}')
            self.logger.info(
                f'HL BPP: {compression_output.hyperlatent_bpp:.3f}')
            self.logger.info(f'L BPP: {compression_output.latent_bpp:.3f}')

            self.logger.info('[ATTAINED]')
            self.logger.info(f'BPP: {attained_bpp:.3f}')
            self.logger.info(f'HL BPP: {attained_hbpp:.3f}')
            self.logger.info(f'L BPP: {attained_lbpp:.3f}')

        return attained_bpp, compression_output
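
The "attained" rate logged above is plain bookkeeping: each word in the encoded hyperlatent and latent streams occupies 32 bits, and the total is divided by the number of pixels in the original image. A toy computation with made-up stream lengths, only to illustrate the arithmetic:

import numpy as np

spatial_shape = (768, 512)                       # (H, W) of the original image
n_hyperlatent_words, n_latent_words = 300, 9000  # hypothetical encoded stream lengths

n_pixels = np.prod(spatial_shape)
attained_hbpp = 32 * n_hyperlatent_words / n_pixels
attained_lbpp = 32 * n_latent_words / n_pixels
attained_bpp = attained_hbpp + attained_lbpp
print(f'BPP: {attained_bpp:.3f} (HL {attained_hbpp:.3f} + L {attained_lbpp:.3f})')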
Example No. 3
def vec_ans_index_decoder(encoded,
                          indices,
                          cdf,
                          cdf_length,
                          cdf_offset,
                          precision,
                          coding_shape,
                          overflow_width=OVERFLOW_WIDTH,
                          **kwargs):
    """
    Reverse op of `vec_ans_index_encoder`. Decodes ans-encoded bitstring into a decoded 
    message tensor.
    Arguments (`indices`, `cdf`, `cdf_length`, `cdf_offset`, `precision`) must be 
    identical to the inputs to `vec_ans_index_encoder` used to generate the encoded tensor.
    """

    original_shape = indices.shape
    B, n_channels, *_ = original_shape
    message = vrans.unflatten(encoded, coding_shape)
    indices = indices.astype(np.int32)
    cdf_index = indices

    max_overflow = (1 << overflow_width) - 1
    overflow_cdf_size = (1 << overflow_width) + 1
    overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, :]

    enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
    dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(
        overflow_cdf,
        np.ones_like(overflow_cdf) * len(overflow_cdf))
    overflow_codec = base_codec(enc_statfun_overflow, dec_statfun_overflow,
                                overflow_width)

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches with
        # batch elements for B > 1 ...

        if (original_shape[2] % PATCH_SIZE[0] != 0) or (original_shape[3] % PATCH_SIZE[1] != 0):
            indices = utils.pad_factor(torch.Tensor(indices), original_shape[2:],
                                       factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
        padded_shape = indices.shape
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)
        cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels)
        coding_shape = cdf_index.shape[1:]

    symbols = []
    _, overflow_pop = substack(codec=overflow_codec, view_fun=overflow_view)

    for i in range(len(cdf_index)):
        cdf_index_i = cdf_index[i]
        cdf_i = cdf[cdf_index_i]
        cdf_length_i = cdf_length[cdf_index_i]

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)

        message, value = symbol_pop(message)

        max_value_i = cdf_length_i - 2
        of_mask = value == max_value_i

        if np.any(of_mask):

            message, val = overflow_pop(message, overflow_width, of_mask)
            val = cast2u64(val)
            widths = val

            cond_mask = val == max_overflow
            while np.any(cond_mask):
                message, val = overflow_pop(message, overflow_width, of_mask)
                val = cast2u64(val)
                widths = np.where(cond_mask, widths + val, widths)
                cond_mask = val == max_overflow

            overflow = np.zeros_like(val)
            cond_mask = widths != 0

            # Reassemble the overflow value from its overflow_width-bit chunks
            counter = 0
            while np.any(cond_mask):
                message, val = overflow_pop(message, overflow_width, of_mask)
                val = cast2u64(val)
                assert np.all(val <= max_overflow)

                op = overflow | (val << (counter * overflow_width))
                overflow = np.where(cond_mask, op, overflow)
                widths = np.where(cond_mask, widths - 1, widths)
                cond_mask = widths != 0
                counter += 1

            overflow_broadcast = value
            overflow_broadcast[of_mask] = overflow
            overflow = overflow_broadcast
            value = np.where(of_mask, overflow >> 1, value)
            cond_mask = np.logical_and(of_mask, overflow & 1)
            value = np.where(cond_mask, -value - 1, value)
            cond_mask = np.logical_and(of_mask, np.logical_not(overflow & 1))
            value = np.where(cond_mask, value + max_value_i, value)

        symbol = value + cdf_offset[cdf_index_i]
        symbols.append(symbol)

    if B == 1:
        decoded = compression_utils.reconstitute(np.stack(symbols, axis=0),
                                                 padded_shape, unfolded_shape)

        if tuple(decoded.shape) != tuple(original_shape):
            decoded = decoded[:, :, :original_shape[2], :original_shape[3]]
    else:
        decoded = np.stack(symbols, axis=0)
    return decoded
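
The overflow branch above undoes the mapping used on the encoder side: values outside [0, max_value) are clamped to max_value and the excursion is folded into a non-negative "overflow" integer whose parity records the sign. A small NumPy round trip of just that mapping, with toy values (the chunked bit-level coding of the overflow integer is omitted here):

import numpy as np

max_value = 4                             # toy per-element max_value = cdf_length - 2
values = np.array([-3, -1, 0, 2, 4, 7])   # shifted symbols, some outside [0, max_value)

# Encoder-side mapping (cf. vec_ans_index_buffered_encoder)
overflow = np.zeros_like(values)
of_lower = values < 0
of_upper = values >= max_value
overflow = np.where(of_lower, -2 * values - 1, overflow)
overflow = np.where(of_upper, 2 * (values - max_value), overflow)
clamped = np.where(of_lower | of_upper, max_value, values)

# Decoder-side inverse (cf. the overflow branch of vec_ans_index_decoder)
of_mask = clamped == max_value
is_odd = (overflow & 1).astype(bool)
decoded = np.where(of_mask, overflow >> 1, clamped)
decoded = np.where(of_mask & is_odd, -decoded - 1, decoded)
decoded = np.where(of_mask & ~is_odd, decoded + max_value, decoded)

assert np.array_equal(decoded, values)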
Example No. 4
def vec_ans_index_buffered_encoder(symbols,
                                   indices,
                                   cdf,
                                   cdf_length,
                                   cdf_offset,
                                   precision,
                                   coding_shape,
                                   overflow_width=OVERFLOW_WIDTH,
                                   **kwargs):
    """
    Vectorized version of `ans_index_encoder`. Incurs constant bit overhead, 
    but is faster.

    ANS-encodes unbounded integer data using an indexed probability table.
    """

    instructions = []

    symbols_shape = symbols.shape
    B, n_channels = symbols_shape[:2]
    symbols = symbols.astype(np.int32)
    indices = indices.astype(np.int32)
    cdf_index = indices

    max_overflow = (1 << overflow_width) - 1
    overflow_cdf_size = (1 << overflow_width) + 1
    overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, None,
                                                                 None, :]

    enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
    dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(
        overflow_cdf,
        np.ones_like(overflow_cdf) * len(overflow_cdf))
    overflow_push, overflow_pop = base_codec(enc_statfun_overflow,
                                             dec_statfun_overflow,
                                             overflow_width)

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    # Map values with tracked probabilities to range [0, ..., max_value]
    values = symbols - cdf_offset[cdf_index]

    # If outside of this range, map value to non-negative integer overflow.
    overflow = np.zeros_like(values)
    of_mask_lower = values < 0
    overflow = np.where(of_mask_lower, -2 * values - 1, overflow)
    of_mask_upper = values >= max_value
    overflow = np.where(of_mask_upper, 2 * (values - max_value), overflow)
    values = np.where(np.logical_or(of_mask_lower, of_mask_upper), max_value,
                      values)

    assert bool(np.all(values >= 0)), (
        "Invalid shifted value for current symbol - values must be non-negative."
    )

    assert bool(np.all(values < cdf_length[cdf_index] - 1)), (
        "Invalid shifted value for current symbol - outside cdf index bounds.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches with
        # batch elements for B > 1 ...
        if (symbols_shape[2] % PATCH_SIZE[0] != 0) or (symbols_shape[3] % PATCH_SIZE[1] != 0):
            values = utils.pad_factor(torch.Tensor(values), symbols_shape[2:],
                                      factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:],
                                       factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            overflow = utils.pad_factor(torch.Tensor(overflow), symbols_shape[2:],
                                        factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

        assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0)
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)

        values, _ = compression_utils.decompose(values, n_channels)
        overflow, _ = compression_utils.decompose(overflow, n_channels)
        cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels)
        coding_shape = values.shape[1:]
        assert coding_shape == cdf_index.shape[1:]

    # LIFO - last item in buffer is first item decompressed
    for i in range(len(cdf_index)):  # loop over batch dimension
        # Bin of discrete CDF that value belongs to
        value_i = values[i]
        cdf_index_i = cdf_index[i]
        cdf_i = cdf[cdf_index_i]
        cdf_length_i = cdf_length[cdf_index_i]
        max_value_i = cdf_length_i - 2

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)

        start, freq = enc_statfun(value_i)
        instructions.append((start, freq, False, precision, 0))
        """
        Encode overflows here
        """
        # No-op
        empty_start = np.zeros_like(value_i).astype(np.uint)
        empty_freq = np.ones_like(value_i).astype(np.uint)

        overflow_i = overflow[i]
        of_mask = value_i == max_value_i

        if np.any(of_mask):

            widths = np.zeros_like(value_i)
            cond_mask = (overflow_i >> (widths * overflow_width)) != 0

            while np.any(cond_mask):
                widths = np.where(cond_mask, widths + 1, widths)
                cond_mask = (overflow_i >> (widths * overflow_width)) != 0

            val = widths
            cond_mask = val >= max_overflow
            while np.any(cond_mask):
                print('Warning: Undefined behaviour.')
                val_push = cast2u64(max_overflow)
                overflow_start, overflow_freq = enc_statfun_overflow(val_push)
                start = overflow_start[of_mask]
                freq = overflow_freq[of_mask]
                instructions.append(
                    (start, freq, True, int(overflow_width), of_mask))
                # val[cond_mask] -= max_overflow
                val = np.where(cond_mask, val - max_overflow, val)
                cond_mask = val >= max_overflow

            val_push = cast2u64(val)
            overflow_start, overflow_freq = enc_statfun_overflow(val_push)
            start = overflow_start[of_mask]
            freq = overflow_freq[of_mask]
            instructions.append(
                (start, freq, True, int(overflow_width), of_mask))

            # Emit the overflow value as a sequence of overflow_width-bit chunks
            counter = 0
            cond_mask = widths != 0
            while np.any(cond_mask):
                encoding = (overflow_i >>
                            (counter * overflow_width)) & max_overflow
                val = np.where(cond_mask, encoding, val)
                val_push = cast2u64(val)
                overflow_start, overflow_freq = enc_statfun_overflow(val_push)
                start = overflow_start[of_mask]
                freq = overflow_freq[of_mask]
                instructions.append(
                    (start, freq, True, int(overflow_width), of_mask))
                widths = np.where(cond_mask, widths - 1, widths)
                cond_mask = widths != 0
                counter += 1

    return instructions, coding_shape
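
Before this chunk loop, the encoder determines how many overflow_width-bit digits each overflow value needs (`widths`) and then pushes one digit per iteration; the decoder later reassembles the value from those digits. A standalone NumPy sketch of that digit decomposition and reassembly, with toy values and outside the ANS machinery:

import numpy as np

overflow_width = 4
max_overflow = (1 << overflow_width) - 1
overflow = np.array([0, 3, 17, 300], dtype=np.int64)   # toy overflow values

# Number of overflow_width-bit digits per element (cf. the `widths` loop above)
widths = np.zeros_like(overflow)
cond_mask = (overflow >> (widths * overflow_width)) != 0
while np.any(cond_mask):
    widths = np.where(cond_mask, widths + 1, widths)
    cond_mask = (overflow >> (widths * overflow_width)) != 0

# Split into digits (encoder side) and reassemble (decoder side)
reassembled = np.zeros_like(overflow)
for counter in range(int(widths.max())):
    digit = (overflow >> (counter * overflow_width)) & max_overflow
    active = counter < widths
    shifted = reassembled | (digit << (counter * overflow_width))
    reassembled = np.where(active, shifted, reassembled)

assert np.array_equal(reassembled, overflow)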
Example No. 5
def vec_ans_index_decoder(encoded,
                          indices,
                          cdf,
                          cdf_length,
                          cdf_offset,
                          precision,
                          coding_shape,
                          overflow_width=OVERFLOW_WIDTH,
                          **kwargs):
    """
    Reverse op of `vec_ans_index_encoder`. Decodes ans-encoded bitstring into a decoded 
    message tensor.
    Arguments (`indices`, `cdf`, `cdf_length`, `cdf_offset`, `precision`) must be 
    identical to the inputs to `vec_ans_index_encoder` used to generate the encoded tensor.
    """

    original_shape = indices.shape
    B, n_channels, *_ = original_shape
    message = vrans.unflatten(encoded, coding_shape)
    indices = indices.astype(np.int32)
    cdf_index = indices

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches with
        # batch elements for B > 1 ...

        if (original_shape[2] % PATCH_SIZE[0] != 0) or (original_shape[3] % PATCH_SIZE[1] != 0):
            indices = utils.pad_factor(torch.Tensor(indices), original_shape[2:],
                                       factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
        padded_shape = indices.shape
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)
        cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels)
        coding_shape = cdf_index.shape[1:]

    symbols = []
    for i in range(len(cdf_index)):
        cdf_index_i = cdf_index[i]
        cdf_i = cdf[cdf_index_i]
        cdf_length_i = cdf_length[cdf_index_i]

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)

        message, value = symbol_pop(message)
        symbol = value + cdf_offset[cdf_index_i]
        symbols.append(symbol)

    if B == 1:
        decoded = compression_utils.reconstitute(np.stack(symbols, axis=0),
                                                 padded_shape, unfolded_shape)

        if tuple(decoded.shape) != tuple(original_shape):
            decoded = decoded[:, :, :original_shape[2], :original_shape[3]]
    else:
        decoded = np.stack(symbols, axis=0)
    return decoded
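
Inside the per-row loop, `cdf[cdf_index_i]` and `cdf_length[cdf_index_i]` are ordinary NumPy fancy indexing: each element selects the row of the CDF table named by its index, and the decoded bin is then shifted back into symbol range by `cdf_offset`. A toy illustration of that indexed lookup with a made-up table and no entropy coding:

import numpy as np

# Hypothetical table: 3 quantized CDFs over at most 5 bins (values are illustrative)
cdf = np.array([[0,  10,  40, 128, 256, 256],
                [0,  64, 128, 192, 256, 256],
                [0,   5,  20, 200, 250, 256]], dtype=np.int32)
cdf_length = np.array([6, 6, 6], dtype=np.int32)
cdf_offset = np.array([-2, -1, -3], dtype=np.int32)

cdf_index_i = np.array([0, 2, 1, 0], dtype=np.int32)    # per-element distribution index
cdf_i = cdf[cdf_index_i]                                # (4, 6): one CDF row per element
cdf_length_i = cdf_length[cdf_index_i]
decoded_bins = np.array([1, 3, 0, 4], dtype=np.int32)   # stand-in for `value` from symbol_pop
symbols = decoded_bins + cdf_offset[cdf_index_i]        # map bins back to signed symbols
print(cdf_i.shape, symbols)                             # (4, 6) [-1  0 -1  2]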
Example No. 6
def vec_ans_index_buffered_encoder(symbols,
                                   indices,
                                   cdf,
                                   cdf_length,
                                   cdf_offset,
                                   precision,
                                   coding_shape,
                                   overflow_width=OVERFLOW_WIDTH,
                                   **kwargs):
    """
    Vectorized version of `ans_index_encoder`. Incurs constant bit overhead, 
    but is faster.

    ANS-encodes unbounded integer data using an indexed probability table.
    """

    instructions = []

    symbols_shape = symbols.shape
    B, n_channels = symbols_shape[:2]
    symbols = symbols.astype(np.int32)
    indices = indices.astype(np.int32)
    cdf_index = indices

    assert bool(np.all(cdf_index >= 0)) and bool(
        np.all(cdf_index < cdf.shape[0])), ("Invalid index.")

    max_value = cdf_length[cdf_index] - 2

    assert bool(np.all(max_value >= 0)) and bool(
        np.all(max_value < cdf.shape[1] - 1)), ("Invalid max length.")

    # Map values with tracked probabilities to range [0, ..., max_value]
    values = symbols - cdf_offset[cdf_index]

    # If outside of this range, map value to non-negative integer overflow.
    overflow = np.zeros_like(values)

    of_mask = values < 0
    overflow = np.where(of_mask, -2 * values - 1, overflow)
    values = np.where(of_mask, max_value, values)

    of_mask = values >= max_value
    overflow = np.where(of_mask, 2 * (values - max_value), overflow)
    values = np.where(of_mask, max_value, values)

    assert bool(np.all(values >= 0)), (
        "Invalid shifted value for current symbol - values must be non-negative."
    )

    assert bool(np.all(values < cdf_length[cdf_index] - 1)), (
        "Invalid shifted value for current symbol - outside cdf index bounds.")

    if B == 1:
        # Vectorize on patches - there's probably a way to interlace patches with
        # batch elements for B > 1 ...
        if (symbols_shape[2] % PATCH_SIZE[0] != 0) or (symbols_shape[3] % PATCH_SIZE[1] != 0):
            values = utils.pad_factor(torch.Tensor(values), symbols_shape[2:],
                                      factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
            indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:],
                                       factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

        assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0)
        assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)

        values, _ = compression_utils.decompose(values, n_channels)
        cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels)
        coding_shape = values.shape[1:]

    # LIFO - last item in buffer is first item decompressed
    for i in range(len(cdf_index)):  # loop over batch dimension
        # Bin of discrete CDF that value belongs to
        value_i = values[i]
        cdf_index_i = cdf_index[i]
        cdf_i = cdf[cdf_index_i]
        cdf_i_length = cdf_length[cdf_index_i]

        enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
        dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_i_length)
        symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun,
                                             precision)

        start, freq = enc_statfun(value_i)
        instructions.append((start, freq, False))
        """
        Encode overflows here
        """
        # of_mask = values == max_value
        # widths = np.zeros_like(values)
        # widths[of_mask] = 1

        # overflow =

        # cond_mask = overflow >> (widths * overflow_width) != 0
        # while np.all(cond_mask) is False:
        #     widths[cond_mask] += 1

        # val = widths

    return instructions, coding_shape
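
The buffered encoder returns `instructions` instead of writing the bitstream directly; the LIFO comment is the key constraint: the ANS stack pops items in the reverse of the order they were pushed, so the flush order and the decoder's read order have to be mirror images of each other. A toy list-as-stack illustration of that property, independent of the actual rANS push/pop; the string records are placeholders:

# Pops come back in the reverse of push order (stack / LIFO discipline).
instructions = ['row_0', 'row_1', 'row_2']    # hypothetical per-row records

stack = []
for inst in instructions:                     # push in buffer order
    stack.append(inst)

pop_order = [stack.pop() for _ in instructions]
assert pop_order == list(reversed(instructions))   # last buffered item comes out first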