import warnings

import torch
import torch.nn.functional as F

from compressai.ans import BufferedRansEncoder


def _compress_ar(self, y_hat, params, height, width, kernel_size, padding):
    """Autoregressively encode the padded latent ``y_hat`` in raster order,
    buffering all symbols and running the rANS encoder once at the end."""
    cdf = self.gaussian_conditional.quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
    offsets = self.gaussian_conditional.offset.tolist()

    encoder = BufferedRansEncoder()
    symbols_list = []
    indexes_list = []

    # Warning, this is slow...
    # TODO: profile the calls to the bindings...
    masked_weight = self.context_prediction.weight * self.context_prediction.mask
    for h in range(height):
        for w in range(width):
            y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
            ctx_p = F.conv2d(
                y_crop,
                masked_weight,
                bias=self.context_prediction.bias,
            )

            # 1x1 conv for the entropy parameters prediction network, so
            # we only keep the elements in the "center"
            p = params[:, :, h : h + 1, w : w + 1]
            gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
            gaussian_params = gaussian_params.squeeze(3).squeeze(2)
            scales_hat, means_hat = gaussian_params.chunk(2, 1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)

            y_crop = y_crop[:, :, padding, padding]
            y_q = self.gaussian_conditional.quantize(y_crop, "symbols", means_hat)
            y_hat[:, :, h + padding, w + padding] = y_q + means_hat

            symbols_list.extend(y_q.squeeze().tolist())
            indexes_list.extend(indexes.squeeze().tolist())

    encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)

    string = encoder.flush()
    return string
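
# How _compress_ar is typically driven (a sketch, not verbatim from this
# listing): the public compress() entry point runs the analysis and
# hyperprior transforms once, then calls the helper on one batch element
# at a time. Names mirror the compress() variants below; the exact wrapper
# shape is an assumption.
#
#     y_strings = []
#     for i in range(y.size(0)):
#         string = self._compress_ar(
#             y_hat[i : i + 1],
#             params[i : i + 1],
#             y_height,
#             y_width,
#             kernel_size,
#             padding,
#         )
#         y_strings.append(string)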

# Older variant of compress(): one encode_with_indexes() call per spatial
# location, paying the Python-to-bindings overhead for every symbol.
def compress(self, x):
    y = self.g_a(x)
    z = self.h_a(y)

    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    y_hat = F.pad(y, (padding, padding, padding, padding))

    # yapf: enable
    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
    # pylint: enable=protected-access

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size]
                ctx_params = self.context_prediction(y_crop)

                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                # (the crop has batch dimension 1, hence index 0, not i)
                ctx_p = ctx_params[0:1, :, padding:padding + 1, padding:padding + 1]
                p = params[i:i + 1, :, h:h + 1, w:w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p), dim=1))
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                y_q = torch.round(y_crop - means_hat)
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    0, :, padding, padding]

                encoder.encode_with_indexes(
                    y_q[0, :, padding, padding].int().tolist(),
                    indexes[0, :].squeeze().int().tolist(),
                    cdf, cdf_lengths, offsets)

        string = encoder.flush()
        y_strings.append(string)
    # yapf: disable

    return {'strings': [y_strings, z_strings], 'shape': z.size()[-2:]}
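
# Note on the variant above: encode_with_indexes() is invoked once per (h, w)
# position, so the binding-call overhead scales with y_height * y_width. The
# buffered variants in this listing amortize that cost to a single call, in
# outline:
#
#     symbols_list.extend(y_q[0, :, padding, padding].int().tolist())
#     indexes_list.extend(indexes[0, :].squeeze().int().tolist())
#     ...
#     # single call after the raster scan
#     encoder.encode_with_indexes(symbols_list, indexes_list, cdf,
#                                 cdf_lengths, offsets)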

# Later variant of compress(): buffers symbols/indexes, encodes once per
# image, and warns when run on GPU.
def compress(self, x):
    if next(self.parameters()).device != torch.device("cpu"):
        warnings.warn(
            "Inference on GPU is not recommended for the autoregressive "
            "models (the entropy coder is run sequentially on CPU)."
        )

    y = self.g_a(x)
    z = self.h_a(y)

    z_strings = self.entropy_bottleneck.compress(z)
    z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

    params = self.h_s(z_hat)

    s = 4  # scaling factor between z and y
    kernel_size = 5  # context prediction kernel size
    padding = (kernel_size - 1) // 2

    y_height = z_hat.size(2) * s
    y_width = z_hat.size(3) * s

    y_hat = F.pad(y, (padding, padding, padding, padding))

    # pylint: disable=protected-access
    cdf = self.gaussian_conditional._quantized_cdf.tolist()
    cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
    offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
    # pylint: enable=protected-access

    # Zero out the non-causal taps explicitly: calling F.conv2d directly
    # bypasses MaskedConv2d.forward(), which is what normally applies the
    # causal mask to the weights.
    masked_weight = self.context_prediction.weight * self.context_prediction.mask

    y_strings = []
    for i in range(y.size(0)):
        encoder = BufferedRansEncoder()
        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        symbols_list = []
        indexes_list = []
        for h in range(y_height):
            for w in range(y_width):
                y_crop = y_hat[
                    i : i + 1, :, h : h + kernel_size, w : w + kernel_size
                ]
                ctx_p = F.conv2d(
                    y_crop,
                    masked_weight,
                    bias=self.context_prediction.bias,
                )

                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[i : i + 1, :, h : h + 1, w : w + 1]
                gaussian_params = self.entropy_parameters(
                    torch.cat((p, ctx_p), dim=1)
                )
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)

                y_q = torch.round(y_crop - means_hat)
                # the crop has batch dimension 1, hence index 0, not i
                y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[
                    0, :, padding, padding
                ]

                symbols_list.extend(y_q[0, :, padding, padding].int().tolist())
                indexes_list.extend(indexes[0, :].squeeze().int().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        y_strings.append(string)

    return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
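
# Example usage (a sketch; assumes this listing corresponds to CompressAI's
# autoregressive mbt2018 model, which is not stated in the listing itself):
#
#     import torch
#     from compressai.zoo import mbt2018
#
#     net = mbt2018(quality=3, pretrained=True).eval()
#     x = torch.rand(1, 3, 256, 256)  # dummy RGB batch
#     with torch.no_grad():
#         out = net.compress(x)
#     y_strings, z_strings = out["strings"]
#     shape = out["shape"]  # spatial size of z_hat, needed by decompress()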