# Module-level imports needed by the methods below.
import torch
import torch.nn.functional as F

from compressai.ans import RansDecoder


def _decompress_ar(
    self, y_string, y_hat, params, height, width, kernel_size, padding
):
    cdf = self.gaussian_conditional.quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
    offsets = self.gaussian_conditional.offset.tolist()

    decoder = RansDecoder()
    decoder.set_stream(y_string)

    # Warning: this is slow due to the auto-regressive nature of the
    # decoding... See more recent publications where an auto-regressive
    # module is run on chunks of channels for faster decoding...
    for h in range(height):
        for w in range(width):
            # Only perform the 5x5 convolution on a cropped tensor
            # centered in (h, w).
            y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
            # Using the unmasked weights is safe here: every position at
            # or after (h, w) in raster-scan order has not been decoded
            # yet and is still zero in y_hat, so this matches the masked
            # convolution used at encode time.
            ctx_p = F.conv2d(
                y_crop,
                self.context_prediction.weight,
                bias=self.context_prediction.bias,
            )

            # 1x1 conv for the entropy parameters prediction network, so
            # we only keep the elements in the "center".
            p = params[:, :, h : h + 1, w : w + 1]
            gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
            scales_hat, means_hat = gaussian_params.chunk(2, 1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            rv = decoder.decode_stream(
                indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
            )
            rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
            rv = self.gaussian_conditional.dequantize(rv, means_hat)

            hp = h + padding
            wp = w + padding
            y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv
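# For reference: `self.context_prediction` is typically a causal
# ("type A") masked 5x5 convolution in the PixelCNN sense, so that at
# encode time each position only sees already-decoded neighbors.
# A minimal sketch, assuming an interface similar to the MaskedConv2d
# layer shipped in `compressai.layers` (not the authoritative
# implementation):
import torch.nn as nn


class MaskedConv2d(nn.Conv2d):
    """Conv2d whose weights are masked so that the center position and
    every position after it in raster-scan order never contribute
    (mask type 'A'; type 'B' keeps the center)."""

    def __init__(self, *args, mask_type: str = "A", **kwargs):
        super().__init__(*args, **kwargs)
        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid mask type "{mask_type}"')
        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        # Zero out the center (type 'A' only) and everything after it.
        self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0
        self.mask[:, :, h // 2 + 1 :] = 0

    def forward(self, x):
        # Re-apply the mask on every call so gradient updates cannot
        # repopulate the masked positions.
        self.weight.data *= self.mask
        return super().forward(x)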
def decompress(self, strings, shape):
    assert isinstance(strings, list) and len(strings) == 2

    # FIXME: we don't respect the default entropy coder and directly call
    # the range ANS decoder from _decompress_ar.

    z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    # Initialize y_hat to zeros, and pad it so we can directly work with
    # sub-tensors of size (N, C, kernel_size, kernel_size).
    y_hat = torch.zeros(
        (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
        device=z_hat.device,
    )

    # Decode each sample of the batch with the (slow) auto-regressive
    # helper above; y_hat is filled in place.
    for i, y_string in enumerate(strings[0]):
        self._decompress_ar(
            y_string,
            y_hat[i : i + 1],
            params[i : i + 1],
            y_height,
            y_width,
            kernel_size,
            padding,
        )

    # Crop the decoding padding back off.
    y_hat = y_hat[:, :, padding:-padding, padding:-padding]
    x_hat = self.g_s(y_hat).clamp_(0, 1)
    return {'x_hat': x_hat}
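# Hedged usage sketch (assumptions, not part of the original code):
# `net` is a trained model exposing this compress/decompress API, e.g.
# CompressAI's JointAutoregressiveHierarchicalPriors, where
# `net.compress(x)` returns {"strings": [y_strings, z_strings],
# "shape": z_spatial_shape}; the 256x256 input size is arbitrary.
def _example_roundtrip(net):
    net = net.eval()
    x = torch.rand(1, 3, 256, 256)  # dummy input image in [0, 1]
    with torch.no_grad():
        out_enc = net.compress(x)
        out_dec = net.decompress(out_enc["strings"], out_enc["shape"])
    return out_dec["x_hat"]  # reconstruction, clamped to [0, 1]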