def __init__(self):
    """Assemble the full compression pipeline: encoder, importance map,
    quantizer, probability classifier, decoder, and the KITTI image-stats
    normalize/denormalize converters.
    """
    super().__init__()
    self.enc = Encoder.create_module_from_const()
    self.dec = Decoder.create_module_from_const()
    self.importance_map_layer = ImportanceMapMult(
        use_map=True, info_channels=config.quantizer_num_of_channels
    )
    self.quantizer = Quantizer(
        num_centers=config.quantizer_num_of_centers,
        centers_initial_range=config.quantizer_center_init_range,
        # NOTE(review): regularization factor and sigma are hard-coded here
        # while the sibling settings come from config — consider moving
        # them to config as well.
        centers_regularization_factor=0.1,
        sigma=0.1,
        init_centers_uniformly=True,
    )
    self.prob_classif = ProbClassifier(
        classifier_in_3d_channels=1,
        classifier_out_3d_channels=config.quantizer_num_of_centers,
        receptive_field=config.quantizer_kernel_w_h,
    )
    # Correctly spelled attributes; the original misspelled names are kept
    # below as aliases so any existing callers continue to work.
    self.normalize = ChangeImageStatsToKitti(direction=ChangeState.NORMALIZE)
    self.denormalize = ChangeImageStatsToKitti(
        direction=ChangeState.DENORMALIZE)
    self.noramlize = self.normalize  # backward-compat alias (original typo)
    self.denoramlize = self.denormalize  # backward-compat alias (original typo)
def test_norm_flow_3channels(self):
    """Round-trip check of the KITTI stats converters on a 3-channel tensor:
    denormalize maps ~N(0, 1) input to the expected per-channel KITTI-like
    mean/variance, and normalize maps it back to ~zero mean / unit variance,
    preserving shape throughout.
    """
    INPUT_CHANNELS = 3
    normalize = ChangeImageStatsToKitti(ChangeState.NORMALIZE)
    denormalize = ChangeImageStatsToKitti(ChangeState.DENORMALIZE)
    x = torch.randn([2, INPUT_CHANNELS, 50, 50])
    # Sanity: randn input is approximately standard normal (loose deltas
    # because the sample is random).
    self.assertAlmostEqual(torch.mean(x).data, 0.0, delta=0.2)
    self.assertAlmostEqual(torch.var(x).data, 1.0, delta=0.3)
    denorm = denormalize(x)
    # Denormalization must not change the tensor shape.
    self.assertTrue(tuple(denorm.shape) == (2, INPUT_CHANNELS, 50, 50))
    # Per-channel means should land near the expected dataset statistics;
    # wide absolute tolerance since the input is random.
    self.assertTrue(
        torch.allclose(
            torch.mean(denorm, dim=(0, 2, 3)),
            torch.tensor(
                [93.70454143384742, 98.28243432206516, 94.84678088809876],
                dtype=torch.float32,
            ),
            atol=30,
            rtol=0,
        ))
    # Per-channel variances checked with relative tolerance instead.
    self.assertTrue(
        torch.allclose(
            torch.var(denorm, dim=(0, 2, 3)),
            torch.tensor(
                [5411.79935676, 5758.60456747, 5890.31451232],
                dtype=torch.float32,
            ),
            atol=0,
            rtol=0.5,
        ))
    normalized = normalize(denorm)
    # Normalizing the denormalized tensor should recover ~N(0, 1)
    # per channel.
    self.assertTrue(
        torch.allclose(
            torch.mean(normalized, dim=(0, 2, 3)),
            torch.tensor(
                [0, 0, 0],
                dtype=torch.float32,
            ),
            atol=0.3,
            rtol=0,
        ))
    self.assertTrue(
        torch.allclose(
            torch.var(normalized, dim=(0, 2, 3)),
            torch.tensor(
                [1, 1, 1],
                dtype=torch.float32,
            ),
            atol=0.3,
            rtol=0,
        ))
    # Round trip preserves shape as well.
    self.assertTrue(tuple(normalized.shape) == (2, INPUT_CHANNELS, 50, 50))
def __init__(
    self,
    conv2d_1: Dict,
    uberresblocks: Dict,
    post_uberblock_resblock: Dict,
    prelast_conv2d: Dict,
    last_conv2d: Dict,
):
    """Build the decoder: an initial conv stage, a stack of uber res-blocks
    followed by one plain res-block, and a final conv + transposed-conv
    stage that denormalizes back to KITTI image statistics.

    Each argument is a kwargs dict forwarded to the corresponding layer.
    """
    super().__init__()
    # Initial convolution stage.
    self.pre_uberblock_model = Conv2dReluBatch2d(**conv2d_1)
    # Residual trunk: N uber res-blocks, then one plain res-block.
    trunk = [
        UberResBlock(**uberresblocks["uberresblock"])
        for _ in range(uberresblocks["num_of_uberresblocks"])
    ]
    trunk.append(ResBlock(**post_uberblock_resblock))
    self.pre_sum_model = nn.Sequential(*trunk)
    # Output stage: conv, upsampling transposed conv, then denormalize.
    self.post_sum_model = nn.Sequential(
        Conv2dReluBatch2d(**prelast_conv2d),
        nn.ConvTranspose2d(**last_conv2d),
        ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
    )
def __init__(
    self,
    conv2d_1: Dict,
    conv2d_2: Dict,
    uberresblocks: Dict,
    prelast_resblock: Dict,
    last_conv2d: Dict,
):
    """Build the encoder: KITTI normalization plus two conv stages, a stack
    of uber res-blocks followed by one plain res-block, and a final conv.

    Each argument is a kwargs dict forwarded to the corresponding layer.
    """
    super().__init__()
    # Head: normalize incoming images to KITTI stats, then two conv stages.
    head = [
        ChangeImageStatsToKitti(direction=ChangeState.NORMALIZE),
        Conv2dReluBatch2d(**conv2d_1),
        Conv2dReluBatch2d(**conv2d_2),
    ]
    # Residual trunk: N uber res-blocks, then one plain res-block.
    trunk = [
        UberResBlock(**uberresblocks["uberresblock"])
        for _ in range(uberresblocks["num_of_uberresblocks"])
    ]
    trunk.append(ResBlock(**prelast_resblock))
    self.pre_res_model = nn.Sequential(*head)
    self.res_model = nn.Sequential(*trunk)
    self.post_sum_model = nn.Conv2d(**last_conv2d)
def __init__(self, m_feat, layer_ids, layer_wgts):
    """Feature/perceptual loss over a hooked feature extractor.

    Args:
        m_feat: feature-extractor network, indexable by layer.
        layer_ids: indices of the layers to hook for feature and gram losses.
        layer_wgts: per-layer weights applied to the hooked-layer losses.
    """
    super().__init__()
    self.m_feat = m_feat
    self.loss_features = [self.m_feat[i] for i in layer_ids]
    # detach=False keeps gradients flowing through the hooked activations.
    self.hooks = hook_outputs(self.loss_features, detach=False)
    self.wgts = layer_wgts
    # One metric name per loss component: pixel loss, then per-layer
    # feature and gram losses.
    self.metric_names = (
        ['pixel']
        + [f'feat_{i}' for i in range(len(layer_ids))]
        + [f'gram_{i}' for i in range(len(layer_ids))]
    )
    # Correctly spelled attribute; the original misspelled name is kept as
    # an alias so any existing callers continue to work.
    self.normalize = ChangeImageStatsToKitti(direction=ChangeState.NORMALIZE)
    self.noramlize = self.normalize  # backward-compat alias (original typo)
def main():
    """Run a short mock training loop of the side-information autoencoder on
    random images with KITTI-like statistics, printing the loss periodically.
    """
    si_autoencoder = SideInformationAutoEncoder(config.use_si_flag)
    loss_manager = LossManager()
    si_net_loss = loss_manager.create_si_net_loss()
    optimizer = torch.optim.Adam(si_autoencoder.parameters(), lr=1e-4)
    # Mock a stereo pair: map standard-normal noise to KITTI-like stats.
    denormalize = ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE)
    x = denormalize(torch.randn(1, 3, 192, 144))
    y = denormalize(torch.randn(1, 3, 192, 144))
    num_steps = 1
    for step in range(num_steps):
        (
            x_reconstructed,
            x_dec,
            x_pc,
            importance_map_mult_weights,
            x_quantizer_index_of_closest_center,
        ) = si_autoencoder(x=x, y=y)
        bit_cost_loss_value = loss_manager.get_bit_cost_loss(
            pc_output=x_pc,
            quantizer_closest_center_index=x_quantizer_index_of_closest_center,
            importance_map_mult_weights=importance_map_mult_weights,
            beta_factor=config.beta,
            target_bit_cost=config.H_target,
        )
        # The SI-net loss only contributes when side information is enabled.
        if config.use_si_flag == SiNetChannelIn.WithSideInformation:
            si_net_loss_value = si_net_loss(x_reconstructed, x)
        else:
            si_net_loss_value = 0
        autoencoder_loss_value = Distortions._calc_dist(
            x_dec,
            x,
            distortion=config.autoencoder_loss_distortion_to_minimize,
            cast_to_int=False,
        )
        # Convex mix of autoencoder and SI losses, plus the bit-cost term.
        total_loss = (
            autoencoder_loss_value * (1 - config.si_loss_weight_alpha)
            + si_net_loss_value * config.si_loss_weight_alpha
            + bit_cost_loss_value
        )
        if step % 100 == 0:
            print(step, total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
def __init__(self, in_channels: SiNetChannelIn):
    """Build the SI network: a conv/LeakyReLU/BatchNorm head, a body of
    NOF_INTERNAL_LAYERS dilated res-blocks whose dilation doubles per layer,
    and a 1x1-conv tail that denormalizes to KITTI image statistics.
    """
    super().__init__()
    head = [
        nn.Conv2d(
            in_channels=in_channels.value,
            out_channels=32,
            kernel_size=[3, 3],
            padding_mode="replicate",
            padding=[1, 1],
        ),
        nn.LeakyReLU(negative_slope=self.NEG_SLOPE),
        nn.BatchNorm2d(32, eps=1e-03, momentum=0.1, affine=True,
                       track_running_stats=True,
                       ),
    ]
    # Dilation doubles each internal layer: 2, 4, 8, ...
    body = [
        DilatedResBlock(
            in_channels=32,
            out_channels=32,
            kernel_size=[3, 3],
            dilation=2 ** (layer_idx + 1),
            negative_slope=self.NEG_SLOPE)
        for layer_idx in range(self.NOF_INTERNAL_LAYERS)
    ]
    tail = [
        nn.Conv2d(in_channels=32, out_channels=3, kernel_size=[1, 1],),
        ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
    ]
    self.pre_model = nn.Sequential(*head)
    self.internal_model = nn.Sequential(*body)
    self.post_model = nn.Sequential(*tail)
    self._weight_init()
def __init__(self, in_channels: SiNetChannelIn):
    """Build the DSIN variant of the SI network: a Conv2dDSIN head, a body
    of Conv2dDSIN blocks whose dilation doubles per layer, and a tail of
    Conv2dDSIN + 1x1 conv that denormalizes to KITTI image statistics.
    """
    super().__init__()
    head = [
        Conv2dDSIN(
            in_channels=in_channels.value,
            out_channels=32,
            kernel_size=[3, 3],
            dilation=1,
            negative_slope=self.NEG_SLOPE),
    ]
    # Dilation doubles each internal layer: 2, 4, 8, ...
    body = [
        Conv2dDSIN(
            in_channels=32,
            out_channels=32,
            kernel_size=[3, 3],
            dilation=2 ** (layer_idx + 1),
            negative_slope=self.NEG_SLOPE)
        for layer_idx in range(self.NOF_INTERNAL_LAYERS)
    ]
    tail = [
        Conv2dDSIN(
            in_channels=32,
            out_channels=32,
            kernel_size=[3, 3],
            dilation=1,
            negative_slope=self.NEG_SLOPE),
        nn.Conv2d(in_channels=32, out_channels=3, kernel_size=[1, 1],),
        ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
    ]
    self.pre_model = nn.Sequential(*head)
    self.internal_model = nn.Sequential(*body)
    self.post_model = nn.Sequential(*tail)
    self._weight_init()