from compressai.entropy_models import GaussianConditional
from compressai.models import CompressionModel
from compressai.ops import quantize_ste

# HyperEncoder, HyperDecoder and HyperDecoderWithQReLU are the helper
# sub-networks defined alongside this class (as in CompressAI's
# scale-space-flow video model); they are assumed to be in scope here.


class Hyperprior(CompressionModel):
    def __init__(self, planes: int = 192, mid_planes: int = 192):
        super().__init__(entropy_bottleneck_channels=mid_planes)
        self.hyper_encoder = HyperEncoder(planes, mid_planes, planes)
        self.hyper_decoder_mean = HyperDecoder(planes, mid_planes, planes)
        self.hyper_decoder_scale = HyperDecoderWithQReLU(planes, mid_planes, planes)
        self.gaussian_conditional = GaussianConditional(None)

    def forward(self, y):
        z = self.hyper_encoder(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        _, y_likelihoods = self.gaussian_conditional(y, scales, means)
        # Straight-through-estimator quantization, centered on the predicted mean.
        y_hat = quantize_ste(y - means) + means
        return y_hat, {"y": y_likelihoods, "z": z_likelihoods}

    def compress(self, y):
        z = self.hyper_encoder(y)

        z_string = self.entropy_bottleneck.compress(z)
        # The decoder only ever sees the decoded z_hat, so the encoder must
        # derive scales/means from the same decoded tensor.
        z_hat = self.entropy_bottleneck.decompress(z_string, z.size()[-2:])

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales)
        y_string = self.gaussian_conditional.compress(y, indexes, means)
        y_hat = self.gaussian_conditional.quantize(y, "dequantize", means)

        return y_hat, {"strings": [y_string, z_string], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)

        scales = self.hyper_decoder_scale(z_hat)
        means = self.hyper_decoder_mean(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes, z_hat.dtype, means)

        return y_hat
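To make the data flow concrete, here is a small, hypothetical smoke test of the block above. It assumes the HyperEncoder/HyperDecoder/HyperDecoderWithQReLU helpers are in scope, that the latent y comes from some upstream encoder, and example shapes throughout; the explicit update_scale_table call is a precaution, since how CompressionModel.update() handles the GaussianConditional varies across CompressAI versions.

import torch
from compressai.models.google import get_scale_table  # compressai.models.base in newer versions

prior = Hyperprior(planes=192, mid_planes=192)
# Populate the coding tables before calling compress()/decompress().
prior.gaussian_conditional.update_scale_table(get_scale_table())
prior.update(force=True)
prior.eval()

y = torch.randn(1, 192, 16, 16)  # stand-in latent with example spatial size

y_hat, likelihoods = prior(y)       # training-style pass: STE quantization + rates
y_hat_enc, out = prior.compress(y)  # actual range coding
y_hat_dec = prior.decompress(out["strings"], out["shape"])
assert torch.allclose(y_hat_enc, y_hat_dec)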
import torch
import torch.nn as nn

from compressai.entropy_models import GaussianConditional
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.google import get_scale_table  # compressai.models.base in newer versions
from compressai.models.utils import conv, deconv, update_registered_buffers


class ScaleHyperprior(CompressionModel):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int. Conf. on Learning
    Representations (ICLR), 2018.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)

        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        # Quantize z and estimate its rate for the rate-distortion loss.
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        # Even on the encoding side, z_hat must be decoded first to produce y_hat.
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)
        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def load_state_dict(self, state_dict):
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        update_registered_buffers(
            self.gaussian_conditional,
            "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict,
        )
        super().load_state_dict(state_dict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        net = cls(N, M)
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        self.gaussian_conditional.update_scale_table(scale_table, force=force)
        super().update(force=force)

    # The key methods: actual entropy coding to and from a bitstream.
    def compress(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))

        # z is quantized, rate-estimated and entropy-coded in one step.
        z_strings = self.entropy_bottleneck.compress(z)
        # The decoded z_hat is needed even while compressing, so that both
        # sides derive the scales from exactly the same tensor.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        # Entropy-code y; the indexes are conditioned on z_hat.
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        # As in compress(): build the indexes, conditioned on z_hat.
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        # Entropy-decode y using those indexes.
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
        # Run the synthesis network g_s to obtain the reconstructed image.
        x_hat = self.g_s(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}
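A hypothetical end-to-end round trip through the model follows. N=128, M=192 and the 256x256 input are example values only (256 is divisible by the model's total downsampling factor of 64: 16x from g_a, a further 4x from h_a).

import torch

net = ScaleHyperprior(N=128, M=192)
net.update()  # build the scale table and CDFs before range coding
net.eval()

x = torch.rand(1, 3, 256, 256)
out_enc = net.compress(x)
out_dec = net.decompress(out_enc["strings"], out_enc["shape"])
x_hat = out_dec["x_hat"]

# Rate in bits per pixel, measured from the actual bitstream lengths.
num_pixels = x.size(0) * x.size(2) * x.size(3)
bpp = 8.0 * sum(len(s[0]) for s in out_enc["strings"]) / num_pixels
print(f"{bpp:.4f} bpp")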
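During training, the forward() outputs feed a rate-distortion objective: the likelihood tensors give the rate term and x_hat the distortion term. Below is a minimal sketch of the standard loss used with such models; the function name is hypothetical and lmbda = 0.01 is just an example trade-off value.

import math

import torch
import torch.nn.functional as F

def rate_distortion_loss(output, x, lmbda=0.01):
    # Rate term: total -log2(likelihood) over both latents, per pixel.
    N, _, H, W = x.size()
    num_pixels = N * H * W
    bpp = sum(
        l.log().sum() / (-math.log(2) * num_pixels)
        for l in output["likelihoods"].values()
    )
    # Distortion term: MSE, scaled to the usual 0-255 range convention.
    mse = F.mse_loss(output["x_hat"], x)
    return lmbda * 255 ** 2 * mse + bpp

# Usage: out = net(x); loss = rate_distortion_loss(out, x); loss.backward()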