def test_forward_inference(self):
    gaussian_conditional = GaussianConditional(None)
    gaussian_conditional.eval()

    x = torch.rand(1, 128, 32, 32)
    scales = torch.rand(1, 128, 32, 32)

    y, y_likelihoods = gaussian_conditional(x, scales)
    assert y.shape == x.shape
    assert y_likelihoods.shape == x.shape

    assert (y == torch.round(x)).all()
class Hyperprior(CompressionModel):
    def __init__(self, planes: int = 192, mid_planes: int = 192):
        super().__init__(entropy_bottleneck_channels=mid_planes)
        self.hyper_encoder = HyperEncoder(planes, mid_planes, planes)
        self.hyper_decoder_mean = HyperDecoder(planes, mid_planes, planes)
        self.hyper_decoder_scale = HyperDecoderWithQReLU(planes, mid_planes, planes)
        self.gaussian_conditional = GaussianConditional(None)

    def forward(self, y):
        z = self.hyper_encoder(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        _, y_likelihoods = self.gaussian_conditional(y, scales, means)
        y_hat = quantize_ste(y - means) + means
        return y_hat, {"y": y_likelihoods, "z": z_likelihoods}

    def compress(self, y):
        z = self.hyper_encoder(y)

        z_string = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_string, z.size()[-2:])

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales)
        y_string = self.gaussian_conditional.compress(y, indexes, means)
        y_hat = self.gaussian_conditional.quantize(y, "dequantize", means)

        return y_hat, {"strings": [y_string, z_string], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes, z_hat.dtype, means)

        return y_hat
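# A minimal smoke test of the Hyperprior module above (a sketch, not part of the
# original code). It assumes HyperEncoder, HyperDecoder, HyperDecoderWithQReLU and
# quantize_ste are importable from the surrounding codebase, and that the latent
# has the expected `planes` channels.
import torch

hyperprior = Hyperprior(planes=192, mid_planes=192)
y = torch.rand(1, 192, 16, 16)  # latent produced by some analysis transform
y_hat, likelihoods = hyperprior(y)
assert y_hat.shape == y.shape
assert set(likelihoods) == {"y", "z"}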
def test_forward_training_mean(self):
    gaussian_conditional = GaussianConditional(None)
    x = torch.rand(1, 128, 32, 32)
    scales = torch.rand(1, 128, 32, 32)
    means = torch.rand(1, 128, 32, 32)

    y, y_likelihoods = gaussian_conditional(x, scales, means)
    assert y.shape == x.shape
    assert y_likelihoods.shape == x.shape

    assert ((y - x) <= 0.5).all()
    assert ((y - x) >= -0.5).all()
    assert (y != torch.round(x)).any()
def test_forward_training(self):
    gaussian_conditional = GaussianConditional(None)
    x = torch.rand(1, 128, 32, 32)
    scales = torch.rand(1, 128, 32, 32)

    y, y_likelihoods = gaussian_conditional(x, scales)
    assert isinstance(gaussian_conditional, EntropyModel)
    assert y.shape == x.shape
    assert y_likelihoods.shape == x.shape

    assert ((y - x) <= 0.5).all()
    assert ((y - x) >= -0.5).all()
    assert (y != torch.round(x)).any()
def test_scripting(self):
    gaussian_conditional = GaussianConditional(None)
    x = torch.rand(1, 128, 32, 32)
    scales = torch.rand(1, 128, 32, 32)
    means = torch.rand(1, 128, 32, 32)

    torch.manual_seed(32)
    y0 = gaussian_conditional(x, scales, means)

    m = torch.jit.script(gaussian_conditional)

    torch.manual_seed(32)
    y1 = m(x, scales, means)

    assert torch.allclose(y0[0], y1[0])
    assert torch.allclose(y0[1], y1[1])
def test_invalid_scale_table(self):
    with pytest.raises(ValueError):
        GaussianConditional(1)

    with pytest.raises(ValueError):
        GaussianConditional([])

    with pytest.raises(ValueError):
        GaussianConditional(())

    with pytest.raises(ValueError):
        GaussianConditional(torch.rand(10))

    with pytest.raises(ValueError):
        GaussianConditional([2, 1])

    with pytest.raises(ValueError):
        GaussianConditional([0, 1, 2])

    with pytest.raises(ValueError):
        GaussianConditional([], scale_bound=None)

    with pytest.raises(ValueError):
        GaussianConditional([], scale_bound=-0.1)
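# For contrast with the invalid inputs above, a sketch of a valid scale table:
# a non-empty, strictly positive, increasing sequence. The exponentially spaced
# table below mirrors the common get_scale_table default (0.11 to 256 over 64
# levels); those constants are an assumption here, not part of the tests above.
import math
import torch

scale_table = torch.exp(torch.linspace(math.log(0.11), math.log(256), 64)).tolist()
gaussian_conditional = GaussianConditional(scale_table)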
class JointAutoregressiveHierarchicalPriors(CompressionModel):
    r"""Joint Autoregressive Hierarchical Priors model from D. Minnen, J. Balle,
    G.D. Toderici: `"Joint Autoregressive and Hierarchical Priors for Learned
    Image Compression" <https://arxiv.org/abs/1809.02736>`_, Adv. in Neural
    Information Processing Systems 31 (NeurIPS 2018).

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N=192, M=192, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        self.g_a = nn.Sequential(
            conv(3, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, M, kernel_size=5, stride=2),
        )

        self.g_s = nn.Sequential(
            deconv(M, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, 3, kernel_size=5, stride=2),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
        )

        self.h_s = nn.Sequential(
            deconv(N, M, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )

        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )

        self.context_prediction = MaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        y = self.g_a(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        y_hat = F.pad(y, (padding, padding, padding, padding))

        y_strings = []
        for i in range(y.size(0)):
            string = self._compress_ar(
                y_hat[i : i + 1], params, y_height, y_width, kernel_size, padding
            )
            y_strings.append(string)

        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def _compress_ar(self, y_hat, params, height, width, kernel_size, padding):
        cdf = self.gaussian_conditional.quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
        offsets = self.gaussian_conditional.offset.tolist()

        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []

        # Warning, this is slow...
        # TODO: profile the calls to the bindings...
        for h in range(height):
            for w in range(width):
                y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
                ctx_p = F.conv2d(
                    y_crop,
                    self.context_prediction.weight,
                    bias=self.context_prediction.bias,
                )

                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[:, :, h : h + 1, w : w + 1]
                gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                gaussian_params = gaussian_params.squeeze(3).squeeze(2)
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)

                y_crop = y_crop[:, :, padding, padding]
                y_q = self.gaussian_conditional.quantize(y_crop, "symbols", means_hat)
                y_hat[:, :, h + padding, w + padding] = y_q + means_hat

                symbols_list.extend(y_q.squeeze().tolist())
                indexes_list.extend(indexes.squeeze().tolist())

        encoder.encode_with_indexes(
            symbols_list, indexes_list, cdf, cdf_lengths, offsets
        )

        string = encoder.flush()
        return string

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2

        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        # FIXME: we don't respect the default entropy coder and directly call the
        # range ANS decoder
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # initialize y_hat to zeros, and pad it so we can directly work with
        # sub-tensors of size (N, C, kernel_size, kernel_size)
        y_hat = torch.zeros(
            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device,
        )

        for i, y_string in enumerate(strings[0]):
            self._decompress_ar(
                y_string,
                y_hat[i : i + 1],
                params,
                y_height,
                y_width,
                kernel_size,
                padding,
            )

        y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}

    def _decompress_ar(
        self, y_string, y_hat, params, height, width, kernel_size, padding
    ):
        cdf = self.gaussian_conditional.quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional.cdf_length.tolist()
        offsets = self.gaussian_conditional.offset.tolist()

        decoder = RansDecoder()
        decoder.set_stream(y_string)

        # Warning: this is slow due to the auto-regressive nature of the
        # decoding... See more recent publication where they use an
        # auto-regressive module on chunks of channels for faster decoding...
        for h in range(height):
            for w in range(width):
                # only perform the 5x5 convolution on a cropped tensor
                # centered in (h, w)
                y_crop = y_hat[:, :, h : h + kernel_size, w : w + kernel_size]
                ctx_p = F.conv2d(
                    y_crop,
                    self.context_prediction.weight,
                    bias=self.context_prediction.bias,
                )
                # 1x1 conv for the entropy parameters prediction network, so
                # we only keep the elements in the "center"
                p = params[:, :, h : h + 1, w : w + 1]
                gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                scales_hat, means_hat = gaussian_params.chunk(2, 1)

                indexes = self.gaussian_conditional.build_indexes(scales_hat)
                rv = decoder.decode_stream(
                    indexes.squeeze().tolist(), cdf, cdf_lengths, offsets
                )
                rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                rv = self.gaussian_conditional.dequantize(rv, means_hat)

                hp = h + padding
                wp = w + padding
                y_hat[:, :, hp : hp + 1, wp : wp + 1] = rv

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)
        super().update(force=force)

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        super().load_state_dict(state_dict)
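# A sketch of a CPU round trip through the autoregressive model defined above,
# not part of the original class. It assumes conv/deconv/GDN/MaskedConv2d, the
# rANS bindings and get_scale_table are importable, and that the input spatial
# size is a multiple of 64 (total downsampling of g_a followed by h_a).
import torch

net = JointAutoregressiveHierarchicalPriors(N=192, M=192)
net.eval()
net.update(force=True)  # build the quantized CDFs for both entropy models

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    out = net.compress(x)
    rec = net.decompress(out["strings"], out["shape"])
assert rec["x_hat"].shape == x.shape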
class ScaleHyperprior(CompressionModel):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        super().load_state_dict(state_dict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)
        super().update(force=force)

    def compress(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
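# The forward pass above returns the y and z likelihoods alongside x_hat. A common
# way to turn them into a rate estimate is to sum -log2 likelihoods and normalise
# by the pixel count; this snippet is a sketch of that bookkeeping, not part of
# the original class, and assumes the model's building blocks are importable.
import torch

net = ScaleHyperprior(N=128, M=192)
x = torch.rand(1, 3, 256, 256)
out = net(x)
num_pixels = x.size(0) * x.size(2) * x.size(3)
bpp = sum((-torch.log2(l).sum() / num_pixels) for l in out["likelihoods"].values())
print(f"estimated rate: {bpp.item():.4f} bpp")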
class JointAutoregressiveHierarchicalPriors(CompressionModel):
    r"""Joint Autoregressive Hierarchical Priors model from D. Minnen, J. Balle,
    G.D. Toderici: `"Joint Autoregressive and Hierarchical Priors for Learned
    Image Compression" <https://arxiv.org/abs/1809.02736>`_, Adv. in Neural
    Information Processing Systems 31 (NeurIPS 2018).

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N=192, M=192, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        self.g_a = nn.Sequential(
            conv(3, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN(N),
            conv(N, M, kernel_size=5, stride=2),
        )

        self.g_s = nn.Sequential(
            deconv(M, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN(N, inverse=True),
            deconv(N, 3, kernel_size=5, stride=2),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
        )

        self.h_s = nn.Sequential(
            deconv(N, M, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )

        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )

        self.context_prediction = MaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        y_hat = self.gaussian_conditional._quantize(  # pylint: disable=protected-access
            y, 'noise' if self.training else 'dequantize')
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1))
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)

        return {
            'x_hat': x_hat,
            'likelihoods': {
                'y': y_likelihoods,
                'z': z_likelihoods
            },
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict['g_a.0.weight'].size(0)
        M = state_dict['g_a.6.weight'].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        y = self.g_a(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        y_hat = F.pad(y, (padding, padding, padding, padding))

        # pylint: disable=protected-access
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
        offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()
        # pylint: enable=protected-access

        y_strings = []
        for i in range(y.size(0)):
            encoder = BufferedRansEncoder()
            # Warning, this is slow...
            # TODO: profile the calls to the bindings...
            for h in range(y_height):
                for w in range(y_width):
                    y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size]
                    ctx_params = self.context_prediction(y_crop)

                    # 1x1 conv for the entropy parameters prediction network, so
                    # we only keep the elements in the "center"
                    ctx_p = ctx_params[i:i + 1, :, padding:padding + 1, padding:padding + 1]
                    p = params[i:i + 1, :, h:h + 1, w:w + 1]
                    gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                    scales_hat, means_hat = gaussian_params.chunk(2, 1)

                    indexes = self.gaussian_conditional.build_indexes(scales_hat)
                    y_q = torch.round(y_crop - means_hat)
                    y_hat[i, :, h + padding, w + padding] = \
                        (y_q + means_hat)[i, :, padding, padding]

                    encoder.encode_with_indexes(
                        y_q[i, :, padding, padding].int().tolist(),
                        indexes[i, :].squeeze().int().tolist(),
                        cdf, cdf_lengths, offsets)
            string = encoder.flush()
            y_strings.append(string)

        return {'strings': [y_strings, z_strings], 'shape': z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2

        # FIXME: we don't respect the default entropy coder and directly call the
        # range ANS decoder
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # initialize y_hat to zeros, and pad it so we can directly work with
        # sub-tensors of size (N, C, kernel_size, kernel_size)
        y_hat = torch.zeros(
            (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding),
            device=z_hat.device)

        decoder = RansDecoder()

        # pylint: disable=protected-access
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist()
        offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist()

        # Warning: this is slow due to the auto-regressive nature of the
        # decoding... See more recent publication where they use an
        # auto-regressive module on chunks of channels for faster decoding...
        for i, y_string in enumerate(strings[0]):
            decoder.set_stream(y_string)

            for h in range(y_height):
                for w in range(y_width):
                    # only perform the 5x5 convolution on a cropped tensor
                    # centered in (h, w)
                    y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size]
                    # ctx_params = self.context_prediction(torch.round(y_crop))
                    ctx_params = self.context_prediction(y_crop)
                    # 1x1 conv for the entropy parameters prediction network, so
                    # we only keep the elements in the "center"
                    ctx_p = ctx_params[i:i + 1, :, padding:padding + 1, padding:padding + 1]
                    p = params[i:i + 1, :, h:h + 1, w:w + 1]
                    gaussian_params = self.entropy_parameters(torch.cat((p, ctx_p), dim=1))
                    scales_hat, means_hat = gaussian_params.chunk(2, 1)

                    indexes = self.gaussian_conditional.build_indexes(scales_hat)
                    rv = decoder.decode_stream(
                        indexes[i, :].squeeze().int().tolist(),
                        cdf, cdf_lengths, offsets)
                    rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
                    rv = self.gaussian_conditional._dequantize(rv, means_hat)

                    y_hat[i, :, h + padding:h + padding + 1,
                          w + padding:w + padding + 1] = rv

        y_hat = y_hat[:, :, padding:-padding, padding:-padding]
        # pylint: enable=protected-access

        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {'x_hat': x_hat}

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)
        super().update(force=force)

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck, 'entropy_bottleneck',
            ['_quantized_cdf', '_offset', '_cdf_length'],
            state_dict)
        update_registered_buffers(
            self.gaussian_conditional, 'gaussian_conditional',
            ['_quantized_cdf', '_offset', '_cdf_length', 'scale_table'],
            state_dict)
        super().load_state_dict(state_dict)
class ScaleHyperprior(CompressionModel):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_ Int. Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        # quantize z and estimate its likelihoods for the rate-distortion loss
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        # even when deriving y_hat for encoding, z_hat must be decoded first,
        # since the scales conditioning y come from it
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)

        return {
            'x_hat': x_hat,
            'likelihoods': {
                'y': y_likelihoods,
                'z': z_likelihoods
            },
        }

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck, 'entropy_bottleneck',
            ['_quantized_cdf', '_offset', '_cdf_length'],
            state_dict)
        update_registered_buffers(
            self.gaussian_conditional, 'gaussian_conditional',
            ['_quantized_cdf', '_offset', '_cdf_length', 'scale_table'],
            state_dict)
        super().load_state_dict(state_dict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict['g_a.0.weight'].size(0)
        M = state_dict['g_a.6.weight'].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)
        super().update(force=force)

    # Key part
    def compress(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))

        # quantize z, estimate its rate, then entropy-encode it
        z_strings = self.entropy_bottleneck.compress(z)
        # decoded z (still needed at compression time)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        # h_s output from the decoded z (still needed at compression time)
        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        # entropy-encode y; the indexes already depend on z_hat
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {'strings': [y_strings, z_strings], 'shape': z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        # same as in compress: build the indexes, which depend on z_hat
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        # entropy-decode y using indexes that depend on z_hat
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
        # run g_s to obtain the reconstructed image
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {'x_hat': x_hat}
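# Sketch of the checkpoint round trip enabled by from_state_dict above: N and M
# are recovered from the stored g_a weights, so only the state_dict is needed.
# The instantiation below is illustrative; any checkpoint with the same layout
# would work, and the sub-modules (conv, deconv, GDN, ...) are assumed importable.
net = ScaleHyperprior(N=128, M=192)
restored = ScaleHyperprior.from_state_dict(net.state_dict())
assert restored.N == net.N and restored.M == net.M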
class HSIC(CompressionModel):
    def __init__(self, N=128, M=192, K=5, **kwargs):  # 'cuda:0' or 'cpu'
        super().__init__(entropy_bottleneck_channels=N, **kwargs)
        # super(DSIC, self).__init__()
        # self.entropy_bottleneck1 = CompressionModel(entropy_bottleneck_channels=N)
        # self.entropy_bottleneck2 = CompressionModel(entropy_bottleneck_channels=N)
        self.gaussian1 = GaussianMixtureConditional(K=K)
        self.gaussian2 = GaussianMixtureConditional(K=K)
        self.N = int(N)
        self.M = int(M)
        self.K = int(K)

        # Define sub-modules
        self.encoder1 = Encoder1(N, M)
        self.encoder2 = Encoder2(N, M)
        self.decoder1 = Decoder1(N, M)
        self.decoder2 = Decoder2(N, M)

        # Components needed for picture 2
        # # hyper
        # self._h_a1 = encode_hyper(N=N, M=M)
        # self._h_a2 = encode_hyper(N=N, M=M)
        # self._h_s1 = gmm_hyper_y1(N=N, M=M, K=K)
        # self._h_s2 = gmm_hyper_y2(N=N, M=M, K=K)
        ######################################################################
        self.h_a1 = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
        )
        self.h_s1 = nn.Sequential(
            deconv(N, M, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )
        self.entropy_parameters1 = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )
        self.context_prediction1 = MaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1)
        self.gaussian_conditional1 = GaussianConditional(None)

        self.h_a2 = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(N, N, stride=2, kernel_size=5),
        )
        self.h_s2 = nn.Sequential(
            deconv(N, M, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            deconv(M, M * 3 // 2, stride=2, kernel_size=5),
            nn.LeakyReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
        )
        self.entropy_parameters2 = nn.Sequential(
            nn.Conv2d(M * 18 // 3, M * 10 // 3, 1),  # (M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),
        )
        self.context_prediction2 = MaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1)
        self.gaussian_conditional2 = GaussianConditional(None)

    def forward(self, x1, x2, h_matrix):
        # Forward structure: picture 1 branch
        y1, g1_1, g1_2, g1_3 = self.encoder1(x1)
        z1 = self.h_a1(y1)
        # print(z1.device)
        # FIXME: entropy_bottleneck1/2 are referenced below but never defined above
        # (only self.entropy_bottleneck comes from the parent class); they presumably
        # need to be standalone EntropyBottleneck(N) modules.
        z1_hat, z1_likelihoods = self.entropy_bottleneck1(z1)  # change:
        params1 = self.h_s1(z1_hat)

        y1_hat = self.gaussian_conditional1._quantize(  # pylint: disable=protected-access
            y1, 'noise' if self.training else 'dequantize')
        # used twice (2*M channels): also conditions the picture 2 entropy parameters
        ctx_params1 = self.context_prediction1(y1_hat)
        gaussian_params1 = self.entropy_parameters1(
            torch.cat((params1, ctx_params1), dim=1))
        scales_hat1, means_hat1 = gaussian_params1.chunk(2, 1)
        _, y1_likelihoods = self.gaussian_conditional1(y1, scales_hat1, means=means_hat1)
        # gmm1 = self._h_s1(z1_hat)  # the three GMM parameter sets
        # y1_hat, y1_likelihoods = self.gaussian1(y1, gmm1[0], gmm1[1], gmm1[2])  # sigma
        x1_hat, g1_4, g1_5, g1_6 = self.decoder1(y1_hat)

        #############################################
        # Picture 2 encoder
        x1_warp = kornia.warp_perspective(x1, h_matrix, (x1.size()[-2], x1.size()[-1]))
        y2 = self.encoder2(x1_warp, x2)
        # end encoder

        # hyper for pic2
        z2 = self.h_a2(y2)
        z2_hat, z2_likelihoods = self.entropy_bottleneck2(z2)  # change
        params2 = self.h_s2(z2_hat)
        y2_hat = self.gaussian_conditional2._quantize(  # pylint: disable=protected-access
            y2, 'noise' if self.training else 'dequantize')
        ctx_params2 = self.context_prediction2(y2_hat)
        gaussian_params2 = self.entropy_parameters2(
            torch.cat((params2, ctx_params2, ctx_params1), dim=1))
        scales_hat2, means_hat2 = gaussian_params2.chunk(2, 1)
        _, y2_likelihoods = self.gaussian_conditional2(y2, scales_hat2, means=means_hat2)
        # gmm2 = self._h_s2(z2_hat, y1_hat)  # the three GMM parameter sets
        # y2_hat, y2_likelihoods = self.gaussian2(y2, gmm2[0], gmm2[1], gmm2[2])  # temporary, to be replaced by the GMM
        # end hyper for pic2

        # Picture 2 decoder
        x1_hat_warp = kornia.warp_perspective(
            x1_hat, h_matrix, (x1_hat.size()[-2], x1_hat.size()[-1]))
        x2_hat = self.decoder2(y2_hat, x1_hat_warp)
        # end decoder

        # print(x1.size())
        return {
            'x1_hat': x1_hat,
            'x2_hat': x2_hat,
            'likelihoods': {
                'y1': y1_likelihoods,
                'y2': y2_likelihoods,
                'z1': z1_likelihoods,
                'z2': z2_likelihoods,
            }
        }
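# A hypothetical training-mode smoke test for HSIC (a sketch, not part of the
# original code). It assumes Encoder1/2, Decoder1/2 and GaussianMixtureConditional
# are importable from this codebase, that the entropy_bottleneck1/2 attributes
# noted in the FIXME above exist, and that h_matrix is a batch of 3x3 homographies
# as expected by kornia.warp_perspective.
import torch

model = HSIC(N=128, M=192, K=5)
x1 = torch.rand(2, 3, 256, 256)
x2 = torch.rand(2, 3, 256, 256)
h_matrix = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)  # identity homographies
out = model(x1, x2, h_matrix)
print(out['x1_hat'].shape, out['x2_hat'].shape, sorted(out['likelihoods']))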