Example #1
    def __init__(self):
        super().__init__()
        self.enc = Encoder.create_module_from_const()
        self.dec = Decoder.create_module_from_const()
        self.importance_map_layer = ImportanceMapMult(
            use_map=True, info_channels=config.quantizer_num_of_channels
        )

        self.quantizer = Quantizer(
            num_centers=config.quantizer_num_of_centers,
            centers_initial_range=config.quantizer_center_init_range,
            centers_regularization_factor=0.1,
            sigma=0.1,
            init_centers_uniformly=True,
        )

        self.prob_classif = ProbClassifier(
            classifier_in_3d_channels=1,
            classifier_out_3d_channels=config.quantizer_num_of_centers,
            receptive_field=config.quantizer_kernel_w_h,
        )

        # map images with KITTI statistics to zero mean / unit variance and back
        self.noramlize = ChangeImageStatsToKitti(
            direction=ChangeState.NORMALIZE)
        self.denoramlize = ChangeImageStatsToKitti(
            direction=ChangeState.DENORMALIZE)
Example #2
    def test_norm_flow_3channels(self):
        INPUT_CHANNELS = 3
        normalize = ChangeImageStatsToKitti(ChangeState.NORMALIZE)
        denormalize = ChangeImageStatsToKitti(ChangeState.DENORMALIZE)

        x = torch.randn([2, INPUT_CHANNELS, 50, 50])

        self.assertAlmostEqual(torch.mean(x).item(), 0.0, delta=0.2)
        self.assertAlmostEqual(torch.var(x).item(), 1.0, delta=0.3)

        denorm = denormalize(x)
        self.assertTrue(tuple(denorm.shape) == (2, INPUT_CHANNELS, 50, 50))
        self.assertTrue(
            torch.allclose(
                torch.mean(denorm, dim=(0, 2, 3)),
                torch.tensor(
                    [93.70454143384742, 98.28243432206516, 94.84678088809876],
                    dtype=torch.float32,
                ),
                atol=30,
                rtol=0,
            ))

        self.assertTrue(
            torch.allclose(
                torch.var(denorm, dim=(0, 2, 3)),
                torch.tensor(
                    [5411.79935676, 5758.60456747, 5890.31451232],
                    dtype=torch.float32,
                ),
                atol=0,
                rtol=0.5,
            ))

        normalized = normalize(denorm)

        self.assertTrue(
            torch.allclose(
                torch.mean(normalized, dim=(0, 2, 3)),
                torch.tensor(
                    [0, 0, 0],
                    dtype=torch.float32,
                ),
                atol=0.3,
                rtol=0,
            ))

        self.assertTrue(
            torch.allclose(
                torch.var(normalized, dim=(0, 2, 3)),
                torch.tensor(
                    [1, 1, 1],
                    dtype=torch.float32,
                ),
                atol=0.3,
                rtol=0,
            ))

        self.assertTrue(tuple(normalized.shape) == (2, INPUT_CHANNELS, 50, 50))
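
The implementation of ChangeImageStatsToKitti itself does not appear in these examples. Below is a minimal sketch consistent with the round-trip test above: the KittiStatsSketch name, the redeclared ChangeState enum, and the use of register_buffer are assumptions made only to keep the snippet self-contained, while the per-channel mean/variance values are taken from the expected tensors in the test.

import enum

import torch
import torch.nn as nn


class ChangeState(enum.Enum):
    NORMALIZE = "normalize"
    DENORMALIZE = "denormalize"


class KittiStatsSketch(nn.Module):
    """Illustrative stand-in for ChangeImageStatsToKitti: shift/scale images
    to or from per-channel KITTI statistics (values from the test above)."""

    def __init__(self, direction: ChangeState = ChangeState.NORMALIZE):
        super().__init__()
        self.direction = direction
        # per-channel KITTI mean and std, shaped for broadcasting over (N, C, H, W)
        self.register_buffer(
            "mean", torch.tensor([93.70, 98.28, 94.85]).view(1, 3, 1, 1))
        self.register_buffer(
            "std", torch.tensor([5411.80, 5758.60, 5890.31]).sqrt().view(1, 3, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.direction is ChangeState.NORMALIZE:
            # KITTI-like pixels -> roughly zero mean, unit variance per channel
            return (x - self.mean) / self.std
        # zero mean / unit variance -> KITTI-like pixel statistics
        return x * self.std + self.mean

Under this sketch, KittiStatsSketch(ChangeState.NORMALIZE) undoes KittiStatsSketch(ChangeState.DENORMALIZE) up to floating-point error, which is exactly the property the test checks.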
Example #3
    def __init__(
        self,
        conv2d_1: Dict,
        uberresblocks: Dict,
        post_uberblock_resblock: Dict,
        prelast_conv2d: Dict,
        last_conv2d: Dict,
    ):
        super().__init__()

        # first deconv layers
        self.pre_uberblock_model = Conv2dReluBatch2d(**conv2d_1)

        # stack of uber res-blocks (count taken from the config dict)
        layers = []
        for i in range(uberresblocks["num_of_uberresblocks"]):
            layers.append(UberResBlock(**uberresblocks["uberresblock"]))
        # resblock after the last uber-block
        layers.append(ResBlock(**post_uberblock_resblock))
        self.pre_sum_model = nn.Sequential(*layers)

        self.post_sum_model = nn.Sequential(
            Conv2dReluBatch2d(**prelast_conv2d),
            nn.ConvTranspose2d(**last_conv2d),
            ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
        )
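
The decoder's forward pass is not part of this example. The pre_sum / post_sum naming suggests a long skip connection summed around the uber res-block stack; a hedged sketch of that assumed wiring (a guess at the structure, not the project's actual forward method) could look like:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumed wiring: the *_sum names hint at a residual sum around the
        # uber res-block stack before the final upsampling and de-normalization
        pre = self.pre_uberblock_model(x)
        summed = pre + self.pre_sum_model(pre)
        return self.post_sum_model(summed)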
Example #4
    def __init__(
        self,
        conv2d_1: Dict,
        conv2d_2: Dict,
        uberresblocks: Dict,
        prelast_resblock: Dict,
        last_conv2d: Dict,
    ):
        super().__init__()
        pre_res_layers = [
            ChangeImageStatsToKitti(direction=ChangeState.NORMALIZE)
        ]
        # first conv layers
        pre_res_layers.extend(
            [Conv2dReluBatch2d(**conv2d_1),
             Conv2dReluBatch2d(**conv2d_2)])

        # stack of uber res-blocks (count taken from the config dict)
        res_layers = []
        for i in range(uberresblocks["num_of_uberresblocks"]):
            res_layers.append(UberResBlock(**uberresblocks["uberresblock"]))

        # resblock after the last uber-block
        res_layers.append(ResBlock(**prelast_resblock))

        self.pre_res_model = nn.Sequential(*pre_res_layers)
        self.res_model = nn.Sequential(*res_layers)
        self.post_sum_model = nn.Conv2d(**last_conv2d)
Example #5
    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])
        self.noramlize = ChangeImageStatsToKitti(direction=ChangeState.NORMALIZE)
Example #6
File: main.py  Project: barakbeilin/tDSIN
def main():
    si_autoencoder = SideInformationAutoEncoder(config.use_si_flag)
    loss_manager = LossManager()
    si_net_loss = loss_manager.create_si_net_loss()
    optimizer = torch.optim.Adam(si_autoencoder.parameters(), lr=1e-4)
    denoramlize = ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE)
    # change image stats to mock kitti images
    x = denoramlize(torch.randn(1, 3, 192, 144))
    y = denoramlize(torch.randn(1, 3, 192, 144))

    B = 1
    for t in range(B):
        (
            x_reconstructed,
            x_dec,
            x_pc,
            importance_map_mult_weights,
            x_quantizer_index_of_closest_center,
        ) = si_autoencoder(x=x, y=y)

        bit_cost_loss_value = loss_manager.get_bit_cost_loss(
            pc_output=x_pc,
            quantizer_closest_center_index=x_quantizer_index_of_closest_center,
            importance_map_mult_weights=importance_map_mult_weights,
            beta_factor=config.beta,
            target_bit_cost=config.H_target,
        )
        si_net_loss_value = (
            si_net_loss(x_reconstructed, x)
            if config.use_si_flag == SiNetChannelIn.WithSideInformation
            else 0
        )
        autoencoder_loss_value = Distortions._calc_dist(
            x_dec,
            x,
            distortion=config.autoencoder_loss_distortion_to_minimize,
            cast_to_int=False,
        )
        total_loss = (
            autoencoder_loss_value * (1 - config.si_loss_weight_alpha)
            + si_net_loss_value * config.si_loss_weight_alpha
            + bit_cost_loss_value
        )

        if t % 100 == 0:
            print(t, total_loss.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
Example #7
    def __init__(self, in_channels: SiNetChannelIn):
        super().__init__()

        internal_layers = [
            DilatedResBlock(
                in_channels=32,
                out_channels=32,
                kernel_size=[3, 3],
                dilation=2 ** (i + 1),
                negative_slope=self.NEG_SLOPE)
            for i in range(self.NOF_INTERNAL_LAYERS)]

        pre_layers = [
            nn.Conv2d(
                in_channels=in_channels.value,
                out_channels=32,
                kernel_size=[3, 3],
                padding_mode="replicate",
                padding=[1, 1],
            ),
            nn.LeakyReLU(negative_slope=self.NEG_SLOPE),
            nn.BatchNorm2d(32,
                           eps=1e-03,
                           momentum=0.1,
                           affine=True,
                           track_running_stats=True,
                           ),
        ]

        post_layers = [
            nn.Conv2d(in_channels=32, out_channels=3, kernel_size=[1, 1],),
            ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
        ]

        # self.layers = pre_layers + post_layers + internal_layers

        self.pre_model = nn.Sequential(*pre_layers)
        self.internal_model = nn.Sequential(*internal_layers)
        self.post_model = nn.Sequential(*post_layers)
        self._weight_init()
Example #8
    def __init__(self, in_channels: SiNetChannelIn):
        super().__init__()

        internal_layers = [
            Conv2dDSIN(
                in_channels=32,
                out_channels=32,
                kernel_size=[3, 3],
                dilation=2 ** (i + 1),
                negative_slope=self.NEG_SLOPE)
            for i in range(self.NOF_INTERNAL_LAYERS)]

        pre_layers = [
            Conv2dDSIN(
                in_channels=in_channels.value,
                out_channels=32,
                kernel_size=[3, 3],
                dilation=1,
                negative_slope=self.NEG_SLOPE),
        ]

        post_layers = [
            Conv2dDSIN(
                in_channels=32,
                out_channels=32,
                kernel_size=[3, 3],
                dilation=1,
                negative_slope=self.NEG_SLOPE),
            nn.Conv2d(in_channels=32, out_channels=3, kernel_size=[1, 1],),
            ChangeImageStatsToKitti(direction=ChangeState.DENORMALIZE),
        ]

        # self.layers = pre_layers + post_layers + internal_layers

        self.pre_model = nn.Sequential(*pre_layers)
        self.internal_model = nn.Sequential(*internal_layers)
        self.post_model = nn.Sequential(*post_layers)
        self._weight_init()
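
Examples #7 and #8 only show how the three sub-models are built. Assuming they are simply chained in build order (the project's actual forward method may differ, for example by adding a residual connection), a minimal sketch would be:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumed forward: pre block -> dilated internal blocks -> 1x1 projection
        # and de-normalization back to KITTI image statistics
        x = self.pre_model(x)
        x = self.internal_model(x)
        return self.post_model(x)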