Ejemplo n.º 1
0
    def __init__(self, args, classes=21, dataset='pascal'):
        """Build the encoder-decoder segmentation network.

        :param args: configuration namespace forwarded to the EESPNet backbone
        :param classes: number of segmentation classes
        :param dataset: dataset key selecting the decoder feature width
        """
        super().__init__()

        # =============================================================
        #        ENCODER: ImageNet EESPNet, classifier/level-5 removed
        # =============================================================
        self.base_net = EESPNet(args)  # imagenet model
        del self.base_net.classifier
        del self.base_net.level5
        del self.base_net.level5_0
        enc_config = self.base_net.config

        # =============================================================
        #        DECODER: bottom-up segmentation head
        # =============================================================
        # base number of decoder feature maps, chosen per dataset
        width_per_dataset = {
            'pascal': 16,
            'city': 16,
            'coco': 32,
            'greenhouse': 16,
            'ishihara': 16,
            'camvid': 16,
        }
        base_planes = width_per_dataset[dataset]
        dec_planes = [base_planes * f for f in (4, 3, 2)] + [classes]
        proj_planes = min(classes // 2, base_planes)

        self.bu_dec_l1 = EfficientPyrPool(in_planes=enc_config[3],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[0])
        self.bu_dec_l2 = EfficientPyrPool(in_planes=dec_planes[0],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[1])
        self.bu_dec_l3 = EfficientPyrPool(in_planes=dec_planes[1],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[2])
        # final stage emits raw class scores, hence no BN/PReLU tail
        self.bu_dec_l4 = EfficientPyrPool(in_planes=dec_planes[2],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[3],
                                          last_layer_br=False)

        # point-wise projections that merge encoder features into the decoder
        self.merge_enc_dec_l2 = EfficientPWConv(enc_config[2], dec_planes[0])
        self.merge_enc_dec_l3 = EfficientPWConv(enc_config[1], dec_planes[1])
        self.merge_enc_dec_l4 = EfficientPWConv(enc_config[0], dec_planes[2])

        def _bn_act(channels):
            # BatchNorm + PReLU pair applied after each encoder/decoder merge
            return nn.Sequential(nn.BatchNorm2d(channels), nn.PReLU(channels))

        self.bu_br_l2 = _bn_act(dec_planes[0])
        self.bu_br_l3 = _bn_act(dec_planes[1])
        self.bu_br_l4 = _bn_act(dec_planes[2])

        self.init_params()
Ejemplo n.º 2
0
class ESPNetv2Segmentation(nn.Module):
    '''
    ESPNetv2 encoder-decoder architecture for semantic segmentation.

    An ImageNet EESPNet serves as the encoder (its classifier and level-5
    blocks are removed) and a light-weight bottom-up decoder built from
    EfficientPyrPool blocks produces per-pixel class scores at the input
    resolution.
    '''

    def __init__(self, args, classes=21, dataset='pascal'):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbone
        :param classes: number of segmentation classes
        :param dataset: dataset name selecting the decoder feature width
        '''
        super().__init__()

        # =============================================================
        #                       BASE NETWORK
        # =============================================================
        self.base_net = EESPNet(args)  # imagenet model
        del self.base_net.classifier
        del self.base_net.level5
        del self.base_net.level5_0
        config = self.base_net.config

        # =============================================================
        #                   SEGMENTATION NETWORK
        # =============================================================
        # base decoder feature width per dataset
        dec_feat_dict = {
            'pascal': 16,
            'city': 16,
            'coco': 32
        }
        base_dec_planes = dec_feat_dict[dataset]
        dec_planes = [4 * base_dec_planes, 3 * base_dec_planes,
                      2 * base_dec_planes, classes]
        pyr_plane_proj = min(classes // 2, base_dec_planes)

        self.bu_dec_l1 = EfficientPyrPool(in_planes=config[3], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[0])
        self.bu_dec_l2 = EfficientPyrPool(in_planes=dec_planes[0], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[1])
        self.bu_dec_l3 = EfficientPyrPool(in_planes=dec_planes[1], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[2])
        # final stage emits raw class scores, hence no BN/PReLU tail
        self.bu_dec_l4 = EfficientPyrPool(in_planes=dec_planes[2], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[3], last_layer_br=False)

        # point-wise projections merging encoder features into the decoder
        self.merge_enc_dec_l2 = EfficientPWConv(config[2], dec_planes[0])
        self.merge_enc_dec_l3 = EfficientPWConv(config[1], dec_planes[1])
        self.merge_enc_dec_l4 = EfficientPWConv(config[0], dec_planes[2])

        self.bu_br_l2 = nn.Sequential(nn.BatchNorm2d(dec_planes[0]),
                                      nn.PReLU(dec_planes[0]))
        self.bu_br_l3 = nn.Sequential(nn.BatchNorm2d(dec_planes[1]),
                                      nn.PReLU(dec_planes[1]))
        self.bu_br_l4 = nn.Sequential(nn.BatchNorm2d(dec_planes[2]),
                                      nn.PReLU(dec_planes[2]))

        self.init_params()

    def upsample(self, x):
        '''Bilinearly up-sample *x* by a factor of two.'''
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

    def init_params(self):
        '''
        Initialize parameters in-place: Kaiming-normal for convolutions,
        constants for batch-norm, small-std normal for linear layers.
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def _trainable_params(modules):
        '''Yield requires-grad parameters of Conv2d/BatchNorm2d/PReLU submodules.'''
        for module in modules:
            for _, sub in module.named_modules():
                if isinstance(sub, (nn.Conv2d, nn.BatchNorm2d, nn.PReLU)):
                    for p in sub.parameters():
                        if p.requires_grad:
                            yield p

    def get_basenet_params(self):
        '''Yield the trainable parameters of the encoder (base network).'''
        yield from self._trainable_params([self.base_net])

    def get_segment_params(self):
        '''Yield the trainable parameters of the decoder (segmentation head).'''
        modules_seg = [self.bu_dec_l1, self.bu_dec_l2, self.bu_dec_l3, self.bu_dec_l4,
                       self.merge_enc_dec_l4, self.merge_enc_dec_l3, self.merge_enc_dec_l2,
                       self.bu_br_l4, self.bu_br_l3, self.bu_br_l2]
        yield from self._trainable_params(modules_seg)

    def forward(self, x):
        '''
        :param x: input RGB image batch, shape (N, C, H, W)
        :return: per-pixel class scores, shape (N, classes, H, W)
        '''
        x_size = (x.size(2), x.size(3))
        enc_out_l1 = self.base_net.level1(x)  # 112
        if not self.base_net.input_reinforcement:
            # the backbone does not re-inject the raw input; release it
            del x
            x = None

        enc_out_l2 = self.base_net.level2_0(enc_out_l1, x)  # 56

        enc_out_l3 = self.base_net.level3_0(enc_out_l2, x)  # down-sample
        for layer in self.base_net.level3:
            enc_out_l3 = layer(enc_out_l3)

        enc_out_l4 = self.base_net.level4_0(enc_out_l3, x)  # down-sample
        for layer in self.base_net.level4:
            enc_out_l4 = layer(enc_out_l4)

        # bottom-up decoding
        bu_out = self.bu_dec_l1(enc_out_l4)

        # decoding block: merge level-3 encoder features
        bu_out = self.upsample(bu_out)
        bu_out = self.bu_br_l2(self.merge_enc_dec_l2(enc_out_l3) + bu_out)
        bu_out = self.bu_dec_l2(bu_out)

        # decoding block: merge level-2 encoder features
        bu_out = self.upsample(bu_out)
        bu_out = self.bu_br_l3(self.merge_enc_dec_l3(enc_out_l2) + bu_out)
        bu_out = self.bu_dec_l3(bu_out)

        # decoding block: merge level-1 encoder features
        bu_out = self.upsample(bu_out)
        bu_out = self.bu_br_l4(self.merge_enc_dec_l4(enc_out_l1) + bu_out)
        bu_out = self.bu_dec_l4(bu_out)

        return F.interpolate(bu_out, size=x_size, mode='bilinear', align_corners=True)
Ejemplo n.º 3
0
    def __init__(self, args, extra_layer):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbone
        :param extra_layer: factory that builds one extra down-sampling
                            detection layer given (in_channels, out_channels)
        '''
        super(ESPNetv2SSD512, self).__init__()

        # ---- base network: EESPNet trunk without its classification head
        self.basenet = EESPNet(args)
        del self.basenet.classifier

        # channel configuration: first four backbone stages plus stage six,
        # extended with the widths of the SSD-specific extra layers
        base_net_config = self.basenet.config
        config = base_net_config[:4] + [base_net_config[5]]
        config += [1024, 512, 256, 128]

        # ---- extra layers for detection
        self.extra_level6 = extra_layer(config[4], config[5])
        self.extra_level7 = extra_layer(config[5], config[6])
        self.extra_level8 = extra_layer(config[6], config[7])

        self.extra_level9 = nn.Sequential(
            nn.Conv2d(config[7], config[8], kernel_size=3, stride=2,
                      bias=False, padding=1),
            nn.ReLU(inplace=True))

        # ---- extra layers for bottom-up decoding
        from nn_layers.efficient_pyramid_pool import EfficientPyrPool

        # every pyramid block projects to out_planes // red_factor channels
        red_factor = 5
        self.bu_4x4_8x8 = EfficientPyrPool(
            in_planes=config[5] + config[6],
            proj_planes=config[5] // red_factor,
            out_planes=config[5])
        self.bu_8x8_16x16 = EfficientPyrPool(
            in_planes=config[4] + config[5],
            proj_planes=config[4] // red_factor,
            out_planes=config[4])
        self.bu_16x16_32x32 = EfficientPyrPool(
            in_planes=config[4] + config[3],
            proj_planes=config[3] // red_factor,
            out_planes=config[3])
        self.bu_32x32_64x64 = EfficientPyrPool(
            in_planes=config[3] + config[2],
            proj_planes=config[2] // red_factor,
            out_planes=config[2])

        self.config = config
Ejemplo n.º 4
0
class ESPNetv2SSD512(nn.Module):
    '''
    ESPNetv2-based multi-scale feature extractor for SSD-512 detection.

    An EESPNet trunk produces the backbone features, extra down-sampling
    layers extend the pyramid down to 1x1, and EfficientPyrPool blocks
    refine the maps in a bottom-up (coarse-to-fine) pass.
    '''

    def __init__(self, args, extra_layer):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbone
        :param extra_layer: factory that builds one extra down-sampling
                            detection layer given (in_channels, out_channels)
        '''
        super(ESPNetv2SSD512, self).__init__()

        # =============================================================
        #                       BASE NETWORK
        # =============================================================

        self.basenet = EESPNet(args)

        # delete the classification layer
        del self.basenet.classifier

        # retrieve the basenet configuration: first four stages plus stage six
        base_net_config = self.basenet.config
        config = base_net_config[:4] + [base_net_config[5]]

        # add channel widths for the SSD-specific extra layers
        config += [1024, 512, 256, 128]

        # =============================================================
        #                EXTRA LAYERS for DETECTION
        # =============================================================

        self.extra_level6 = extra_layer(config[4], config[5])
        self.extra_level7 = extra_layer(config[5], config[6])
        self.extra_level8 = extra_layer(config[6], config[7])

        self.extra_level9 = nn.Sequential(
            nn.Conv2d(config[7],
                      config[8],
                      kernel_size=3,
                      stride=2,
                      bias=False,
                      padding=1), nn.ReLU(inplace=True))

        # =============================================================
        #                EXTRA LAYERS for Bottom-up decoding
        # =============================================================

        from nn_layers.efficient_pyramid_pool import EfficientPyrPool

        # each pyramid block projects to out_planes // red_factor channels
        in_features = config[5] + config[6]
        out_features = config[5]
        red_factor = 5
        self.bu_4x4_8x8 = EfficientPyrPool(in_planes=in_features,
                                           proj_planes=out_features // red_factor,
                                           out_planes=out_features)

        in_features = config[4] + config[5]
        out_features = config[4]
        self.bu_8x8_16x16 = EfficientPyrPool(in_planes=in_features,
                                             proj_planes=out_features // red_factor,
                                             out_planes=out_features)

        in_features = config[4] + config[3]
        out_features = config[3]
        self.bu_16x16_32x32 = EfficientPyrPool(in_planes=in_features,
                                               proj_planes=out_features // red_factor,
                                               out_planes=out_features)

        in_features = config[3] + config[2]
        out_features = config[2]
        self.bu_32x32_64x64 = EfficientPyrPool(in_planes=in_features,
                                               proj_planes=out_features // red_factor,
                                               out_planes=out_features)

        self.config = config

    def up_sample(self, x, size):
        '''Bilinearly resize *x* to the spatial dims of *size* (an NCHW size).'''
        return F.interpolate(x,
                             size=(size[2], size[3]),
                             align_corners=True,
                             mode='bilinear')

    def forward(self, x, is_train=True):
        '''
        :param x: input RGB image batch
        :param is_train: unused; kept for call-site compatibility
        :return: seven feature maps ordered finest-to-coarsest:
                 (64x64, 32x32, 16x16, 8x8, 4x4, 2x2, 1x1)
        '''
        out_256x256 = self.basenet.level1(x)  # 112
        if not self.basenet.input_reinforcement:
            # the backbone does not re-inject the raw input; release it
            del x
            x = None

        out_128x128 = self.basenet.level2_0(out_256x256, x)  # 56

        out_64x64 = self.basenet.level3_0(out_128x128, x)  # down-sample
        for layer in self.basenet.level3:
            out_64x64 = layer(out_64x64)

        # Detection network
        out_32x32 = self.basenet.level4_0(out_64x64, x)  # down-sample
        for layer in self.basenet.level4:
            out_32x32 = layer(out_32x32)

        out_16x16 = self.basenet.level5_0(out_32x32, x)  # down-sample
        for layer in self.basenet.level5:
            out_16x16 = layer(out_16x16)

        # Detection network's extra layers
        out_8x8 = self.extra_level6(out_16x16)
        out_4x4 = self.extra_level7(out_8x8)
        out_2x2 = self.extra_level8(out_4x4)
        out_1x1 = self.extra_level9(out_2x2)

        # bottom-up decoding
        ## 4x4 and 8x8
        out_4x4_8x8 = self.up_sample(out_4x4, out_8x8.size())
        out_4x4_8x8 = torch.cat((out_4x4_8x8, out_8x8), dim=1)
        out_8x8_epp = self.bu_4x4_8x8(out_4x4_8x8)

        ## 8x8 and 16x16
        out_8x8_16x16 = self.up_sample(out_8x8_epp, out_16x16.size())
        out_8x8_16x16 = torch.cat((out_8x8_16x16, out_16x16), dim=1)
        out_16x16_epp = self.bu_8x8_16x16(out_8x8_16x16)

        ## 16x16 and 32x32
        out_16x16_32x32 = self.up_sample(out_16x16_epp, out_32x32.size())
        out_16x16_32x32 = torch.cat((out_16x16_32x32, out_32x32), dim=1)
        out_32x32_epp = self.bu_16x16_32x32(out_16x16_32x32)

        ## 32x32 and 64x64
        out_32x32_64x64 = self.up_sample(out_32x32_epp, out_64x64.size())
        out_32x32_64x64 = torch.cat((out_32x32_64x64, out_64x64), dim=1)
        out_64x64_epp = self.bu_32x32_64x64(out_32x32_64x64)

        return out_64x64_epp, out_32x32_epp, out_16x16_epp, out_8x8_epp, out_4x4, out_2x2, out_1x1
Ejemplo n.º 5
0
class ESPNetv2SSD300(nn.Module):
    '''
    ESPNetv2-based multi-scale feature extractor for SSD-300 detection.

    An EESPNet trunk produces the backbone features, extra down-sampling
    layers extend the pyramid down to 1x1, and EfficientPyrPool blocks
    refine the maps in a bottom-up (coarse-to-fine) pass.
    '''

    def __init__(self, args, extra_layer):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbone
        :param extra_layer: factory that builds one extra down-sampling
                            detection layer given (in_channels, out_channels)
        '''
        super(ESPNetv2SSD300, self).__init__()

        # =============================================================
        #                       BASE NETWORK
        # =============================================================

        self.basenet = EESPNet(args)

        # delete the classification layer
        del self.basenet.classifier

        # retrieve the basenet configuration: first four stages plus stage six
        base_net_config = self.basenet.config
        config = base_net_config[:4] + [base_net_config[5]]

        # add channel widths for the SSD-specific extra layers
        config += [1024, 512, 256]

        # =============================================================
        #                EXTRA LAYERS for DETECTION
        # =============================================================

        self.extra_level6 = extra_layer(config[4], config[5])
        self.extra_level7 = extra_layer(config[5], config[6])

        self.extra_level8 = nn.Sequential(
            nn.Conv2d(config[6],
                      config[6],
                      kernel_size=3,
                      stride=2,
                      bias=False,
                      padding=1), nn.BatchNorm2d(config[6]),
            nn.ReLU(inplace=True),
            nn.Conv2d(config[6],
                      config[7],
                      kernel_size=2,
                      stride=2,
                      bias=False), nn.ReLU(inplace=True))

        # =============================================================
        #                EXTRA LAYERS for Bottom-up decoding
        # =============================================================

        from nn_layers.efficient_pyramid_pool import EfficientPyrPool

        # each pyramid block projects to out_planes // red_factor channels
        in_features = config[5] + config[6]
        out_features = config[5]
        red_factor = 5
        self.bu_3x3_5x5 = EfficientPyrPool(in_planes=in_features,
                                           proj_planes=out_features // red_factor,
                                           out_planes=out_features)

        in_features = config[4] + config[5]
        out_features = config[4]
        self.bu_5x5_10x10 = EfficientPyrPool(in_planes=in_features,
                                             proj_planes=out_features // red_factor,
                                             out_planes=out_features)

        in_features = config[4] + config[3]
        out_features = config[3]
        self.bu_10x10_19x19 = EfficientPyrPool(in_planes=in_features,
                                               proj_planes=out_features // red_factor,
                                               out_planes=out_features)

        in_features = config[3] + config[2]
        out_features = config[2]
        self.bu_19x19_38x38 = EfficientPyrPool(in_planes=in_features,
                                               proj_planes=out_features // red_factor,
                                               out_planes=out_features)

        self.config = config

    def up_sample(self, x, size):
        '''Bilinearly resize *x* to the spatial dims of *size* (an NCHW size).'''
        return F.interpolate(x,
                             size=(size[2], size[3]),
                             align_corners=True,
                             mode='bilinear')

    def forward(self, x, is_train=True):
        '''
        :param x: input RGB image batch
        :param is_train: unused; kept for call-site compatibility
        :return: six feature maps ordered finest-to-coarsest:
                 (38x38, 19x19, 10x10, 5x5, 3x3, 1x1)
        '''
        out_150x150 = self.basenet.level1(x)  # 112
        if not self.basenet.input_reinforcement:
            # the backbone does not re-inject the raw input; release it
            del x
            x = None

        out_75x75 = self.basenet.level2_0(out_150x150, x)  # 56

        out_38x38 = self.basenet.level3_0(out_75x75, x)  # down-sample
        for layer in self.basenet.level3:
            out_38x38 = layer(out_38x38)

        # Detection network
        out_19x19 = self.basenet.level4_0(out_38x38, x)  # down-sample
        for layer in self.basenet.level4:
            out_19x19 = layer(out_19x19)

        out_10x10 = self.basenet.level5_0(out_19x19, x)  # down-sample
        for layer in self.basenet.level5:
            out_10x10 = layer(out_10x10)

        # Detection network's extra layers
        out_5x5 = self.extra_level6(out_10x10)
        out_3x3 = self.extra_level7(out_5x5)
        out_1x1 = self.extra_level8(out_3x3)

        # bottom-up decoding
        ## 3x3 and 5x5
        out_3x3_5x5 = self.up_sample(out_3x3, out_5x5.size())
        out_3x3_5x5 = torch.cat((out_3x3_5x5, out_5x5), dim=1)
        out_5x5_epp = self.bu_3x3_5x5(out_3x3_5x5)

        ## 5x5 and 10x10
        out_5x5_10x10 = self.up_sample(out_5x5_epp, out_10x10.size())
        out_5x5_10x10 = torch.cat((out_5x5_10x10, out_10x10), dim=1)
        out_10x10_epp = self.bu_5x5_10x10(out_5x5_10x10)

        ## 10x10 and 19x19
        out_10x10_19x19 = self.up_sample(out_10x10_epp, out_19x19.size())
        out_10x10_19x19 = torch.cat((out_10x10_19x19, out_19x19), dim=1)
        out_19x19_epp = self.bu_10x10_19x19(out_10x10_19x19)

        ## 19x19 and 38x38
        out_19x19_38x38 = self.up_sample(out_19x19_epp, out_38x38.size())
        out_19x19_38x38 = torch.cat((out_19x19_38x38, out_38x38), dim=1)
        out_38x38_epp = self.bu_19x19_38x38(out_19x19_38x38)

        return out_38x38_epp, out_19x19_epp, out_10x10_epp, out_5x5_epp, out_3x3, out_1x1
Ejemplo n.º 6
0
    def __init__(self, args, extra_layer):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbone
        :param extra_layer: factory that builds one extra down-sampling
                            detection layer given (in_channels, out_channels)
        '''
        super(ESPNetv2SSD, self).__init__()

        # ---- base network: EESPNet trunk without its classification head
        self.basenet = EESPNet(args)
        del self.basenet.classifier

        # channel configuration: first four backbone stages plus stage six,
        # extended with the widths of the SSD-specific extra layers
        base_net_config = self.basenet.config
        config = base_net_config[:4] + [base_net_config[5]]
        config += [512, 256, 128]

        # ---- extra layers for detection
        self.extra_level6 = extra_layer(config[4], config[5])
        self.extra_level7 = extra_layer(config[5], config[6])
        self.extra_level8 = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(config[6], config[7], kernel_size=1, stride=1,
                      bias=False),
            nn.ReLU(inplace=True))

        # ---- extra layers for bottom-up decoding
        from nn_layers.efficient_pyramid_pool import EfficientPyrPool

        # every pyramid block projects to out_planes // red_factor channels
        red_factor = 5
        self.bu_3x3_5x5 = EfficientPyrPool(
            in_planes=config[5] + config[6],
            proj_planes=config[5] // red_factor,
            out_planes=config[5],
            scales=[2.0, 1.0])
        self.bu_5x5_10x10 = EfficientPyrPool(
            in_planes=config[4] + config[5],
            proj_planes=config[4] // red_factor,
            out_planes=config[4],
            scales=[2.0, 1.0, 0.5])
        self.bu_10x10_19x19 = EfficientPyrPool(
            in_planes=config[4] + config[3],
            proj_planes=config[3] // red_factor,
            out_planes=config[3],
            scales=[2.0, 1.0, 0.5, 0.25])
        self.bu_19x19_38x38 = EfficientPyrPool(
            in_planes=config[3] + config[2],
            proj_planes=config[2] // red_factor,
            out_planes=config[2],
            scales=[2.0, 1.0, 0.5, 0.25])

        self.config = config
Ejemplo n.º 7
0
    def __init__(self, args, classes=21, dataset='pascal', dense_fuse=False, trainable_fusion=True):
        '''
        :param args: configuration namespace forwarded to the EESPNet backbones
        :param classes: number of segmentation classes
        :param dataset: dataset key selecting the decoder feature width
        :param dense_fuse: flag stored on the instance for later use
        :param trainable_fusion: whether the RGB/depth fusion gates learn
                                 their weights
        '''
        super().__init__()

        # =============================================================
        #        ENCODERS: twin EESPNet trunks (RGB + depth)
        # =============================================================
        # RGB branch: ImageNet model, classifier / level-5 removed
        self.base_net = EESPNet(args)
        del self.base_net.classifier
        del self.base_net.level5
        del self.base_net.level5_0
        config = self.base_net.config

        # depth branch: same backbone with a single input channel
        depth_args = copy.deepcopy(args)
        depth_args.channels = 1
        self.depth_base_net = EESPNet(depth_args)
        del self.depth_base_net.classifier
        del self.depth_base_net.level5
        del self.depth_base_net.level5_0

        # gates fusing RGB and depth features at each encoder level
        self.fusion_gate_level1 = FusionGate(nchannel=32, is_trainable=trainable_fusion)
        self.fusion_gate_level2 = FusionGate(nchannel=128, is_trainable=trainable_fusion)
        self.fusion_gate_level3 = FusionGate(nchannel=256, is_trainable=trainable_fusion)
        self.fusion_gate_level4 = FusionGate(nchannel=512, is_trainable=trainable_fusion)

        # =============================================================
        #        DECODER: bottom-up segmentation head
        # =============================================================
        # base number of decoder feature maps, chosen per dataset
        width_per_dataset = {
            'pascal': 16,
            'city': 16,
            'coco': 32,
            'greenhouse': 16,
            'ishihara': 16,
            'sun': 16,
            'camvid': 16,
        }
        base_planes = width_per_dataset[dataset]
        dec_planes = [base_planes * f for f in (4, 3, 2)] + [classes]
        proj_planes = min(classes // 2, base_planes)

        self.bu_dec_l1 = EfficientPyrPool(in_planes=config[3],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[0])
        self.bu_dec_l2 = EfficientPyrPool(in_planes=dec_planes[0],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[1])
        self.bu_dec_l3 = EfficientPyrPool(in_planes=dec_planes[1],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[2])
        # final stage emits raw class scores, hence no BN/PReLU tail
        self.bu_dec_l4 = EfficientPyrPool(in_planes=dec_planes[2],
                                          proj_planes=proj_planes,
                                          out_planes=dec_planes[3],
                                          last_layer_br=False)

        # point-wise projections that merge encoder features into the decoder
        self.merge_enc_dec_l2 = EfficientPWConv(config[2], dec_planes[0])
        self.merge_enc_dec_l3 = EfficientPWConv(config[1], dec_planes[1])
        self.merge_enc_dec_l4 = EfficientPWConv(config[0], dec_planes[2])

        def _bn_act(channels):
            # BatchNorm + PReLU pair applied after each encoder/decoder merge
            return nn.Sequential(nn.BatchNorm2d(channels), nn.PReLU(channels))

        self.bu_br_l2 = _bn_act(dec_planes[0])
        self.bu_br_l3 = _bn_act(dec_planes[1])
        self.bu_br_l4 = _bn_act(dec_planes[2])

        self.init_params()
        self.dense_fuse = dense_fuse
Ejemplo n.º 8
0
class ESPDNetSegmentation(nn.Module):
    '''
    ESPDNet architecture for semantic segmentation.

    An RGB EESPNet encoder and a one-channel depth EESPNet encoder are fused
    level by level through FusionGate modules; the fused features are then
    decoded bottom-up with EfficientPyrPool blocks and encoder skip
    connections, and the logits are upsampled back to the input resolution.

    NOTE(review): the original source carried a second, misplaced ``__init__``
    (a copy of the uncertainty-estimation variant's constructor) appended
    after ``forward``.  It silently overrode this constructor and built
    modules (``aux_decoder``, ``trav_module``, forward hooks) that this
    class's ``forward`` never uses, so it has been removed.
    '''

    def __init__(self, args, classes=21, dataset='pascal', dense_fuse=False, trainable_fusion=True):
        '''
        :param args: configuration object passed to EESPNet; a deep copy with
                     ``channels = 1`` is used for the depth encoder
        :param classes: number of segmentation classes (final decoder planes)
        :param dataset: dataset key selecting the base decoder width
        :param dense_fuse: if True, fuse RGB/depth after every EESP unit in
                           levels 3-4 instead of only once per level
        :param trainable_fusion: whether the fusion gates have learnable weights
        '''
        super().__init__()

        # =============================================================
        #                       BASE NETWORK
        # =============================================================
        #
        # RGB encoder (ImageNet model); the classification head and the
        # level-5 blocks are not needed for segmentation.
        #
        self.base_net = EESPNet(args)
        del self.base_net.classifier
        del self.base_net.level5
        del self.base_net.level5_0
        config = self.base_net.config

        #
        # Depth encoder: same architecture, single input channel.
        #
        tmp_args = copy.deepcopy(args)
        tmp_args.channels = 1
        self.depth_base_net = EESPNet(tmp_args)
        del self.depth_base_net.classifier
        del self.depth_base_net.level5
        del self.depth_base_net.level5_0

        # Channel counts mirror the EESPNet level outputs (levels 1-4).
        self.fusion_gate_level1 = FusionGate(nchannel=32, is_trainable=trainable_fusion)
        self.fusion_gate_level2 = FusionGate(nchannel=128, is_trainable=trainable_fusion)
        self.fusion_gate_level3 = FusionGate(nchannel=256, is_trainable=trainable_fusion)
        self.fusion_gate_level4 = FusionGate(nchannel=512, is_trainable=trainable_fusion)

        #=============================================================
        #                   SEGMENTATION NETWORK
        #=============================================================
        # Base number of decoder feature planes per dataset.
        dec_feat_dict = {
            'pascal': 16,
            'city': 16,
            'coco': 32,
            'greenhouse': 16,
            'ishihara': 16,
            'sun': 16,
            'camvid': 16
        }
        base_dec_planes = dec_feat_dict[dataset]
        dec_planes = [4 * base_dec_planes, 3 * base_dec_planes, 2 * base_dec_planes, classes]
        pyr_plane_proj = min(classes // 2, base_dec_planes)

        self.bu_dec_l1 = EfficientPyrPool(in_planes=config[3], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[0])
        self.bu_dec_l2 = EfficientPyrPool(in_planes=dec_planes[0], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[1])
        self.bu_dec_l3 = EfficientPyrPool(in_planes=dec_planes[1], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[2])
        # Last stage emits raw class logits: no final BatchNorm/PReLU.
        self.bu_dec_l4 = EfficientPyrPool(in_planes=dec_planes[2], proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[3], last_layer_br=False)

        # Point-wise projections of the encoder skip connections onto the
        # decoder plane counts.
        self.merge_enc_dec_l2 = EfficientPWConv(config[2], dec_planes[0])
        self.merge_enc_dec_l3 = EfficientPWConv(config[1], dec_planes[1])
        self.merge_enc_dec_l4 = EfficientPWConv(config[0], dec_planes[2])

        # BatchNorm + PReLU applied after each skip-connection merge.
        self.bu_br_l2 = nn.Sequential(nn.BatchNorm2d(dec_planes[0]),
                                      nn.PReLU(dec_planes[0]))
        self.bu_br_l3 = nn.Sequential(nn.BatchNorm2d(dec_planes[1]),
                                      nn.PReLU(dec_planes[1]))
        self.bu_br_l4 = nn.Sequential(nn.BatchNorm2d(dec_planes[2]),
                                      nn.PReLU(dec_planes[2]))

        self.init_params()
        self.dense_fuse = dense_fuse

    def upsample(self, x):
        '''2x bilinear upsampling used between decoder stages.'''
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

    def init_params(self):
        '''
        Function to initialize the parameters
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def _trainable_params(modules):
        '''Yield trainable parameters of Conv2d/BatchNorm2d/PReLU submodules.'''
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.PReLU)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_basenet_params(self):
        '''Trainable parameters of the RGB encoder.'''
        return self._trainable_params([self.base_net])

    def get_depth_encoder_params(self):
        '''Trainable parameters of the depth encoder.'''
        return self._trainable_params([self.depth_base_net])

    def get_segment_params(self):
        '''Trainable parameters of the segmentation decoder.'''
        return self._trainable_params([
            self.bu_dec_l1, self.bu_dec_l2, self.bu_dec_l3, self.bu_dec_l4,
            self.merge_enc_dec_l4, self.merge_enc_dec_l3, self.merge_enc_dec_l2,
            self.bu_br_l4, self.bu_br_l3, self.bu_br_l2])

    def forward(self, x, x_d=None):
        '''
        :param x: Receives the input RGB image
        :param x_d: Receives the input Depth image (RGB-only mode when None)
        :return: per-pixel class scores upsampled to the input resolution
        '''
        x_size = (x.size(2), x.size(3))  # input spatial size (H, W)

        #
        # First conv (level 1)
        #
        enc_out_l1 = self.base_net.level1(x)  # 112
        if not self.base_net.input_reinforcement:
            # The raw image is only reused when input reinforcement is
            # enabled inside the strided EESP blocks.
            del x
            x = None

        if x_d is not None:
            d_enc_out_l1 = self.depth_base_net.level1(x_d)  # Depth
            # Fusion level 1
            enc_out_l1 = self.fusion_gate_level1(enc_out_l1, d_enc_out_l1)

        #
        # Second layer (Strided EESP)
        #
        enc_out_l2 = self.base_net.level2_0(enc_out_l1, x)  # 56
        if x_d is not None:
            d_enc_out_l2 = self.depth_base_net.level2_0(d_enc_out_l1)
            # Fusion level 2
            enc_out_l2 = self.fusion_gate_level2(enc_out_l2, d_enc_out_l2)

        #
        # Third layer (Strided EESP)
        #
        enc_out_l3_0 = self.base_net.level3_0(enc_out_l2, x)  # down-sample -> 28
        if x_d is not None:
            d_enc_out_l3_0 = self.depth_base_net.level3_0(d_enc_out_l2)  # down-sample -> 28
        #
        # EESP units
        #
        for i, (layer, dlayer) in enumerate(zip(self.base_net.level3, self.depth_base_net.level3)):
            if i == 0:
                enc_out_l3 = layer(enc_out_l3_0)
                if x_d is not None:
                    d_enc_out_l3 = dlayer(d_enc_out_l3_0)
                    if self.dense_fuse:
                        enc_out_l3 = self.fusion_gate_level3(enc_out_l3, d_enc_out_l3)
            else:
                # BUGFIX: originally ``dlayer(enc_out_l3)`` — the RGB feature
                # was run through the depth-branch layer.  Use the RGB layer,
                # consistent with the level-4 loop below.
                enc_out_l3 = layer(enc_out_l3)
                if x_d is not None:
                    d_enc_out_l3 = dlayer(d_enc_out_l3)
                    if self.dense_fuse:
                        enc_out_l3 = self.fusion_gate_level3(enc_out_l3, d_enc_out_l3)

        if x_d is not None and not self.dense_fuse:
            # Fusion level 3 (single fusion at the end of the level)
            enc_out_l3 = self.fusion_gate_level3(enc_out_l3, d_enc_out_l3)

        #
        # Fourth layer (Strided EESP)
        #
        enc_out_l4_0 = self.base_net.level4_0(enc_out_l3, x)  # down-sample -> 14
        if x_d is not None:
            d_enc_out_l4_0 = self.depth_base_net.level4_0(d_enc_out_l3)  # down-sample -> 14
        #
        # EESP units
        #
        for i, (layer, dlayer) in enumerate(zip(self.base_net.level4, self.depth_base_net.level4)):
            if i == 0:
                enc_out_l4 = layer(enc_out_l4_0)
                if x_d is not None:
                    d_enc_out_l4 = dlayer(d_enc_out_l4_0)
                    if self.dense_fuse:
                        enc_out_l4 = self.fusion_gate_level4(enc_out_l4, d_enc_out_l4)
            else:
                enc_out_l4 = layer(enc_out_l4)
                if x_d is not None:
                    d_enc_out_l4 = dlayer(d_enc_out_l4)
                    if self.dense_fuse:
                        enc_out_l4 = self.fusion_gate_level4(enc_out_l4, d_enc_out_l4)

        if x_d is not None and not self.dense_fuse:
            # Fusion level 4
            enc_out_l4 = self.fusion_gate_level4(enc_out_l4, d_enc_out_l4)

        # *** The 5th level is classification-only and was removed ***

        # bottom-up decoding
        bu_out = self.bu_dec_l1(enc_out_l4)

        # Decoding block: merge with level-3 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l3_proj = self.merge_enc_dec_l2(enc_out_l3)
        bu_out = enc_out_l3_proj + bu_out
        bu_out = self.bu_br_l2(bu_out)
        bu_out = self.bu_dec_l2(bu_out)

        # Decoding block: merge with level-2 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l2_proj = self.merge_enc_dec_l3(enc_out_l2)
        bu_out = enc_out_l2_proj + bu_out
        bu_out = self.bu_br_l3(bu_out)
        bu_out = self.bu_dec_l3(bu_out)

        # Decoding block: merge with level-1 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l1_proj = self.merge_enc_dec_l4(enc_out_l1)
        bu_out = enc_out_l1_proj + bu_out
        bu_out = self.bu_br_l4(bu_out)
        bu_out = self.bu_dec_l4(bu_out)

        return F.interpolate(bu_out, size=x_size, mode='bilinear', align_corners=True)
class ESPDNetUEwithTraversability(nn.Module):
    '''
    This class defines the ESPDNet architecture for Semantic Segmentation
    with uncertainty estimation and a traversability-probability head.

    aux_layer : A number of the layer from which the auxiliary branch is made.
      0: bu_dec_l1, 1: bu_dec_l2, 2: bu_dec_l3
    '''
    def __init__(self,
                 args,
                 classes=21,
                 dataset='pascal',
                 dense_fuse=False,
                 trainable_fusion=True,
                 aux_layer=2,
                 fix_pyr_plane_proj=False,
                 in_channels=32,
                 spatial=True):
        '''
        :param args: configuration object passed to EESPNet; a deep copy with
                     ``channels = 1`` is used for the depth encoder
        :param classes: number of segmentation classes
        :param dataset: dataset key selecting the base decoder width
        :param dense_fuse: fuse RGB/depth after every EESP unit in levels 3-4
        :param trainable_fusion: whether the fusion gates are learnable
        :param aux_layer: decoder stage feeding the auxiliary branch (0-2);
                          values outside that range leave ``aux_decoder``
                          undefined and the hook registration below fails
        :param fix_pyr_plane_proj: use ``base_dec_planes`` for the pyramid
                                   projection instead of ``min(classes // 2, ...)``
        :param in_channels: input channels of the traversability estimator
        :param spatial: forwarded to LabelProbEstimator
        '''
        super().__init__()

        # =============================================================
        #                       BASE NETWORK
        # =============================================================
        #
        # RGB encoder (ImageNet model); the classification head and the
        # level-5 blocks are not needed for segmentation.
        #
        self.base_net = EESPNet(args)
        del self.base_net.classifier
        del self.base_net.level5
        del self.base_net.level5_0
        config = self.base_net.config

        #
        # Depth encoder: same architecture, single input channel.
        #
        tmp_args = copy.deepcopy(args)
        tmp_args.channels = 1
        self.depth_base_net = EESPNet(tmp_args)
        del self.depth_base_net.classifier
        del self.depth_base_net.level5
        del self.depth_base_net.level5_0

        # Channel counts mirror the EESPNet level outputs (levels 1-4).
        self.fusion_gate_level1 = FusionGate(nchannel=32,
                                             is_trainable=trainable_fusion)
        self.fusion_gate_level2 = FusionGate(nchannel=128,
                                             is_trainable=trainable_fusion)
        self.fusion_gate_level3 = FusionGate(nchannel=256,
                                             is_trainable=trainable_fusion)
        self.fusion_gate_level4 = FusionGate(nchannel=512,
                                             is_trainable=trainable_fusion)

        #=============================================================
        #                   SEGMENTATION NETWORK
        #=============================================================
        # Base number of decoder feature planes per dataset.
        dec_feat_dict = {
            'pascal': 16,
            'city': 16,
            'coco': 32,
            'greenhouse': 16,
            'ishihara': 16,
            'sun': 16,
            'camvid': 16,
            'forest': 16
        }
        base_dec_planes = dec_feat_dict[dataset]
        dec_planes = [
            4 * base_dec_planes, 3 * base_dec_planes, 2 * base_dec_planes,
            classes
        ]
        if fix_pyr_plane_proj:
            pyr_plane_proj = base_dec_planes
        else:
            pyr_plane_proj = min(classes // 2, base_dec_planes)

        self.bu_dec_l1 = EfficientPyrPool(in_planes=config[3],
                                          proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[0])
        self.bu_dec_l2 = EfficientPyrPool(in_planes=dec_planes[0],
                                          proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[1])
        self.bu_dec_l3 = EfficientPyrPool(in_planes=dec_planes[1],
                                          proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[2])
        # Last stage emits raw class logits: no final BatchNorm/PReLU.
        self.bu_dec_l4 = EfficientPyrPool(in_planes=dec_planes[2],
                                          proj_planes=pyr_plane_proj,
                                          out_planes=dec_planes[3],
                                          last_layer_br=False)

        # Point-wise projections of the encoder skip connections onto the
        # decoder plane counts.
        self.merge_enc_dec_l2 = EfficientPWConv(config[2], dec_planes[0])
        self.merge_enc_dec_l3 = EfficientPWConv(config[1], dec_planes[1])
        self.merge_enc_dec_l4 = EfficientPWConv(config[0], dec_planes[2])

        # BatchNorm + PReLU applied after each skip-connection merge.
        self.bu_br_l2 = nn.Sequential(nn.BatchNorm2d(dec_planes[0]),
                                      nn.PReLU(dec_planes[0]))
        self.bu_br_l3 = nn.Sequential(nn.BatchNorm2d(dec_planes[1]),
                                      nn.PReLU(dec_planes[1]))
        self.bu_br_l4 = nn.Sequential(nn.BatchNorm2d(dec_planes[2]),
                                      nn.PReLU(dec_planes[2]))

        # Auxiliary branch: decodes logits directly from an intermediate
        # decoder stage for uncertainty estimation.
        self.aux_layer = aux_layer
        if 0 <= aux_layer < 3:
            self.aux_decoder = EfficientPyrPool(
                in_planes=dec_planes[aux_layer],
                proj_planes=pyr_plane_proj,
                out_planes=dec_planes[3],
                last_layer_br=False)

        self.init_params()
        self.dense_fuse = dense_fuse

        #
        # Traversability module
        #
        self.trav_module = LabelProbEstimator(in_channels=in_channels,
                                              spatial=spatial)

        # Capture intermediate features of the main and auxiliary decoders
        # via forward hooks; ``forward`` reads them back for the
        # traversability head.
        self.activation = {}

        def get_activation(name):
            def hook(model, input, output):
                # detach(): the cached features are not back-propagated through.
                self.activation[name] = output.detach()

            return hook

        # NOTE(review): assumes EfficientPyrPool exposes a ``merge_layer``
        # sequential whose index 2 is the feature of interest — confirm
        # against the EfficientPyrPool definition.
        self.bu_dec_l4.merge_layer[2].register_forward_hook(
            get_activation('output_main'))
        self.aux_decoder.merge_layer[2].register_forward_hook(
            get_activation('output_aux'))

    def upsample(self, x):
        '''2x bilinear upsampling used between decoder stages.'''
        return F.interpolate(x,
                             scale_factor=2,
                             mode='bilinear',
                             align_corners=True)

    def init_params(self):
        '''
        Function to initialize the parameters
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def _trainable_params(modules):
        '''Yield trainable parameters of Conv2d/BatchNorm2d/PReLU submodules.'''
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.PReLU)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_basenet_params(self):
        '''Trainable parameters of the RGB encoder.'''
        return self._trainable_params([self.base_net])

    def get_depth_encoder_params(self):
        '''Trainable parameters of the depth encoder.'''
        return self._trainable_params([self.depth_base_net])

    def get_segment_params(self):
        '''Trainable parameters of the segmentation decoder.'''
        return self._trainable_params([
            self.bu_dec_l1, self.bu_dec_l2, self.bu_dec_l3, self.bu_dec_l4,
            self.merge_enc_dec_l4, self.merge_enc_dec_l3,
            self.merge_enc_dec_l2, self.bu_br_l4, self.bu_br_l3, self.bu_br_l2
        ])

    def get_classification_layer_params(self):
        '''
        Trainable parameters of the classification layers.

        NOTE(review): in the original code this yields exactly the same
        parameter set as :meth:`get_segment_params`; that behavior is kept.
        '''
        return self.get_segment_params()

    def forward(self, x, x_d=None):
        '''
        :param x: Receives the input RGB image
        :param x_d: Receives the input Depth image (RGB-only mode when None)
        :return: tuple of (main segmentation output, auxiliary output,
                 traversability probability), the first two upsampled to the
                 input resolution
        '''
        x_size = (x.size(2), x.size(3))  # input spatial size (H, W)

        #
        # First conv (level 1)
        #
        enc_out_l1 = self.base_net.level1(x)  # 112
        if not self.base_net.input_reinforcement:
            # The raw image is only reused when input reinforcement is
            # enabled inside the strided EESP blocks.
            del x
            x = None

        if x_d is not None:
            d_enc_out_l1 = self.depth_base_net.level1(x_d)  # Depth
            # Fusion level 1
            enc_out_l1 = self.fusion_gate_level1(enc_out_l1, d_enc_out_l1)

        #
        # Second layer (Strided EESP)
        #
        enc_out_l2 = self.base_net.level2_0(enc_out_l1, x)  # 56
        if x_d is not None:
            d_enc_out_l2 = self.depth_base_net.level2_0(d_enc_out_l1)
            # Fusion level 2
            enc_out_l2 = self.fusion_gate_level2(enc_out_l2, d_enc_out_l2)

        #
        # Third layer (Strided EESP)
        #
        enc_out_l3_0 = self.base_net.level3_0(enc_out_l2,
                                              x)  # down-sample -> 28
        if x_d is not None:
            d_enc_out_l3_0 = self.depth_base_net.level3_0(
                d_enc_out_l2)  # down-sample -> 28
        #
        # EESP units
        #
        for i, (layer, dlayer) in enumerate(
                zip(self.base_net.level3, self.depth_base_net.level3)):
            if i == 0:
                enc_out_l3 = layer(enc_out_l3_0)
                if x_d is not None:
                    d_enc_out_l3 = dlayer(d_enc_out_l3_0)
                    if self.dense_fuse:
                        enc_out_l3 = self.fusion_gate_level3(
                            enc_out_l3, d_enc_out_l3)
            else:
                # BUGFIX: originally ``dlayer(enc_out_l3)`` — the RGB feature
                # was run through the depth-branch layer.  Use the RGB layer,
                # consistent with the level-4 loop below.
                enc_out_l3 = layer(enc_out_l3)
                if x_d is not None:
                    d_enc_out_l3 = dlayer(d_enc_out_l3)
                    if self.dense_fuse:
                        enc_out_l3 = self.fusion_gate_level3(
                            enc_out_l3, d_enc_out_l3)

        if x_d is not None and not self.dense_fuse:
            # Fusion level 3 (single fusion at the end of the level)
            enc_out_l3 = self.fusion_gate_level3(enc_out_l3, d_enc_out_l3)

        #
        # Fourth layer (Strided EESP)
        #
        enc_out_l4_0 = self.base_net.level4_0(enc_out_l3,
                                              x)  # down-sample -> 14
        if x_d is not None:
            d_enc_out_l4_0 = self.depth_base_net.level4_0(
                d_enc_out_l3)  # down-sample -> 14
        #
        # EESP units
        #
        for i, (layer, dlayer) in enumerate(
                zip(self.base_net.level4, self.depth_base_net.level4)):
            if i == 0:
                enc_out_l4 = layer(enc_out_l4_0)
                if x_d is not None:
                    d_enc_out_l4 = dlayer(d_enc_out_l4_0)
                    if self.dense_fuse:
                        enc_out_l4 = self.fusion_gate_level4(
                            enc_out_l4, d_enc_out_l4)
            else:
                enc_out_l4 = layer(enc_out_l4)
                if x_d is not None:
                    d_enc_out_l4 = dlayer(d_enc_out_l4)
                    if self.dense_fuse:
                        enc_out_l4 = self.fusion_gate_level4(
                            enc_out_l4, d_enc_out_l4)

        if x_d is not None and not self.dense_fuse:
            # Fusion level 4
            enc_out_l4 = self.fusion_gate_level4(enc_out_l4, d_enc_out_l4)

        # *** The 5th level is classification-only and was removed ***

        # bottom-up decoding; the auxiliary branch taps the stage selected
        # by self.aux_layer (assumed to be 0, 1, or 2 — otherwise aux_out is
        # never assigned and the return below raises).
        bu_out = self.bu_dec_l1(enc_out_l4)
        if self.aux_layer == 0:
            aux_out = self.aux_decoder(bu_out)

        # Decoding block: merge with level-3 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l3_proj = self.merge_enc_dec_l2(enc_out_l3)
        bu_out = enc_out_l3_proj + bu_out
        bu_out = self.bu_br_l2(bu_out)
        bu_out = self.bu_dec_l2(bu_out)
        if self.aux_layer == 1:
            aux_out = self.aux_decoder(bu_out)

        # Decoding block: merge with level-2 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l2_proj = self.merge_enc_dec_l3(enc_out_l2)
        bu_out = enc_out_l2_proj + bu_out
        bu_out = self.bu_br_l3(bu_out)
        bu_out = self.bu_dec_l3(bu_out)
        if self.aux_layer == 2:
            aux_out = self.aux_decoder(bu_out)

        # Decoding block: merge with level-1 skip connection
        bu_out = self.upsample(bu_out)
        enc_out_l1_proj = self.merge_enc_dec_l4(enc_out_l1)
        bu_out = enc_out_l1_proj + bu_out
        bu_out = self.bu_br_l4(bu_out)
        bu_out = self.bu_dec_l4(bu_out)

        #
        # Traversability module: concatenate the hooked main/aux decoder
        # features and estimate per-pixel traversability probability.
        #
        # NOTE(review): these two interpolations omit align_corners (unlike
        # the ones in the return) — possibly unintentional; kept as-is.
        main_feature = F.interpolate(self.activation['output_main'],
                                     size=x_size,
                                     mode='bilinear')
        aux_feature = F.interpolate(self.activation['output_aux'],
                                    size=x_size,
                                    mode='bilinear')

        feature = torch.cat((main_feature, aux_feature), dim=1)

        prob_output = self.trav_module(feature)

        # Return the outputs as a tuple
        return (F.interpolate(bu_out,
                              size=x_size,
                              mode='bilinear',
                              align_corners=True),
                F.interpolate(aux_out,
                              size=x_size,
                              mode='bilinear',
                              align_corners=True), prob_output)