예제 #1
파일: FPT.py 프로젝트: zhwzhong/FPT-1
    def __init__(self, feature_dim, with_norm='none', upsample_method='bilinear',
                 in_channels=(256, 512, 1024, 2048)):
        """Build the FPT neck: per-level transformer modules plus projection convs.

        Args:
            feature_dim: Channel width every pyramid level is projected to. It is
                also the ``d_model``/``d_k``/``d_v`` of the ``SelfTrans`` modules
                and the channel width given to ``GroundTrans``/``RenderTrans``.
            with_norm: ``'group_norm'``, ``'batch_norm'`` or ``'none'``. When not
                ``'none'``, every projection conv is created without bias and is
                followed by the chosen norm layer. GroupNorm uses 32 groups, so
                ``feature_dim`` must then be divisible by 32.
            upsample_method: ``'nearest'`` or ``'bilinear'``. This is the
                interpolation mode of ``self.fpn_upsample``, a fixed 2x
                upsampling.
            in_channels: Channel counts of the four backbone inputs projected by
                ``fpn_p2_1x1`` .. ``fpn_p5_1x1``, in (p2, p3, p4, p5) order. The
                default (256, 512, 1024, 2048) matches the original layout
                (presumably ResNet C2-C5 features; confirm for other backbones).

        Raises:
            AssertionError: If ``upsample_method`` or ``with_norm`` is not one of
                the accepted values.
        """
        from functools import partial

        super(FPT, self).__init__()
        self.feature_dim = feature_dim

        assert upsample_method in ['nearest', 'bilinear'], (
            'unsupported upsample_method: %r' % (upsample_method,))
        # A functools.partial over the module-level F.interpolate keeps the module
        # picklable. A nested closure stored on self makes torch.save(model) fail.
        # align_corners may only be passed for interpolating modes, hence None
        # for 'nearest'.
        self.fpn_upsample = partial(
            F.interpolate,
            scale_factor=2,
            mode=upsample_method,
            align_corners=False if upsample_method == 'bilinear' else None)

        assert with_norm in ['group_norm', 'batch_norm', 'none'], (
            'unsupported with_norm: %r' % (with_norm,))
        if with_norm == 'batch_norm':
            norm = nn.BatchNorm2d
        elif with_norm == 'group_norm':
            norm = partial(nn.GroupNorm, 32)
        else:
            norm = None  # plain convs (with bias), no norm layer

        def self_trans():
            return SelfTrans(n_head=1, n_mix=2, d_model=feature_dim,
                             d_k=feature_dim, d_v=feature_dim)

        def ground_trans():
            return GroundTrans(in_channels=feature_dim, inter_channels=None,
                               mode='dot', dimension=2, bn_layer=True)

        def render_trans():
            return RenderTrans(channels_high=feature_dim,
                               channels_low=feature_dim, upsample=False)

        def conv_block(in_ch, kernel_size, padding=0):
            # Conv to feature_dim channels. When a norm layer follows, the conv
            # bias is redundant and is dropped.
            if norm is None:
                return nn.Conv2d(in_ch, feature_dim, kernel_size, padding=padding)
            return nn.Sequential(
                nn.Conv2d(in_ch, feature_dim, kernel_size, padding=padding,
                          bias=False),
                norm(feature_dim))

        # Attribute names and creation order are identical to the original
        # layout. This keeps state_dict keys and parameter registration order
        # unchanged.
        # One SelfTrans per pyramid level.
        self.st_p5 = self_trans()
        self.st_p4 = self_trans()
        self.st_p3 = self_trans()
        self.st_p2 = self_trans()

        # One GroundTrans per pair of levels named in the attribute.
        self.gt_p4_p5 = ground_trans()
        self.gt_p3_p4 = ground_trans()
        self.gt_p3_p5 = ground_trans()
        self.gt_p2_p3 = ground_trans()
        self.gt_p2_p4 = ground_trans()
        self.gt_p2_p5 = ground_trans()

        # One RenderTrans per pair of levels named in the attribute.
        self.rt_p5_p4 = render_trans()
        self.rt_p5_p3 = render_trans()
        self.rt_p5_p2 = render_trans()
        self.rt_p4_p3 = render_trans()
        self.rt_p4_p2 = render_trans()
        self.rt_p3_p2 = render_trans()

        # 1x1 lateral projections of the backbone inputs to feature_dim.
        c2, c3, c4, c5 = in_channels
        self.fpn_p5_1x1 = conv_block(c5, 1)
        self.fpn_p4_1x1 = conv_block(c4, 1)
        self.fpn_p3_1x1 = conv_block(c3, 1)
        self.fpn_p2_1x1 = conv_block(c2, 1)

        # 3x3 output convs taking feature_dim * 5 channels. Presumably these are
        # five concatenated feature_dim-wide maps built in forward (not shown
        # here); confirm.
        self.fpt_p5 = conv_block(feature_dim * 5, 3, padding=1)
        self.fpt_p4 = conv_block(feature_dim * 5, 3, padding=1)
        self.fpt_p3 = conv_block(feature_dim * 5, 3, padding=1)
        self.fpt_p2 = conv_block(feature_dim * 5, 3, padding=1)

        self.initialize()
예제 #2
파일: FPN.py 프로젝트: zcl912/FPT
    def __init__(self, dim_in_top, dim_in_lateral):
        """Build a lateral module: a conv stack on the lateral input plus transformers.

        Args:
            dim_in_top: Channel count of the top input. It is also the output
                width ``self.dim_out`` of ``conv_lateral``.
            dim_in_lateral: Channel count of the lateral input consumed by the
                first conv of ``conv_lateral``.
        """
        super().__init__()
        self.dim_in_top = dim_in_top
        self.dim_in_lateral = dim_in_lateral
        self.dim_out = dim_in_top

        # conv_lateral is a 3x3 conv (dim_in_lateral -> dim_out), an optional
        # GroupNorm, a 3x3 conv (dim_out -> dim_out), then ReLU.
        # The second conv receives the first conv's dim_out-channel output, so
        # its in_channels must be dim_out. Declaring dim_in_lateral there makes
        # forward fail with a channel mismatch whenever dim_in_lateral != dim_out.
        layers = [nn.Conv2d(dim_in_lateral, self.dim_out, 3, 1, 1, bias=False)]
        if cfg.FPN.USE_GN:
            layers.append(
                nn.GroupNorm(net_utils.get_group_gn(self.dim_out),
                             self.dim_out,
                             eps=cfg.GROUP_NORM.EPSILON))
        layers.append(nn.Conv2d(self.dim_out, self.dim_out, 3, 1, 1, bias=False))
        layers.append(nn.ReLU(inplace=True))
        self.conv_lateral = nn.Sequential(*layers)

        # NOTE(review): _init_weights() runs before self.st / self.gt exist, so it
        # cannot (re)initialize them. The order is kept as in the original;
        # confirm it is intended.
        self._init_weights()
        self.st = SelfTrans(n_head=1,
                            n_mix=4,
                            d_model=cfg.FPN.DIM,
                            d_k=cfg.FPN.DIM,
                            d_v=cfg.FPN.DIM)
        self.gt = GroundTrans(in_channels=cfg.FPN.DIM,
                              inter_channels=None,
                              mode='dot',
                              dimension=2,
                              bn_layer=True)
예제 #3
파일: FPN.py 프로젝트: zcl912/FPT
    def __init__(self,
                 conv_body_func,
                 fpn_level_info,
                 P2only=False,
                 fpt_rendering=False):
        """Build the FPN/FPT pyramid neck and its backbone.

        Args:
            conv_body_func: Callable invoked with no arguments. Its result is
                stored as ``self.conv_body`` (e.g. a ResNet).
            fpn_level_info: Backbone-level description. This method reads
                ``blobs`` (length only), ``dims`` (lateral channel counts;
                ``dims[0]`` feeds ``conv_top``) and ``spatial_scales``.
            P2only: If True, ``self.spatial_scale`` is collapsed from a list to
                its last entry.
            fpt_rendering: If True, also build ``fpt_rendering_conv1_modules``
                (stride-2 convs) and ``fpt_rendering_conv2_modules`` (stride-1
                convs).
        """
        super().__init__()
        self.fpn_level_info = fpn_level_info
        self.P2only = P2only
        self.fpt_rendering = fpt_rendering
        self.st = SelfTrans(n_head=1,
                            n_mix=4,
                            d_model=cfg.FPN.DIM,
                            d_k=cfg.FPN.DIM,
                            d_v=cfg.FPN.DIM)
        self.rt = RenderTrans(channels_high=cfg.FPN.DIM,
                              channels_low=cfg.FPN.DIM,
                              upsample=False)
        self.dim_out = fpn_dim = cfg.FPN.DIM
        min_level, max_level = get_min_max_levels()
        # (min_level - 2) backbone blobs are excluded from the pyramid stages.
        self.num_backbone_stages = len(fpn_level_info.blobs) - (min_level - 2)
        fpn_dim_lateral = fpn_level_info.dims
        self.spatial_scale = []

        def group_norm():
            # GroupNorm over fpn_dim channels, configured from cfg.
            return nn.GroupNorm(net_utils.get_group_gn(fpn_dim),
                                fpn_dim,
                                eps=cfg.GROUP_NORM.EPSILON)

        # 1x1 projection of fpn_level_info.dims[0] channels to fpn_dim. It is
        # built once per branch; the previous unconditional construction
        # beforehand was always overwritten.
        if cfg.FPN.USE_GN:
            self.conv_top = nn.Sequential(
                nn.Conv2d(fpn_dim_lateral[0], fpn_dim, 1, 1, 0, bias=False),
                group_norm())
        else:
            self.conv_top = nn.Conv2d(fpn_dim_lateral[0], fpn_dim, 1, 1, 0)

        self.ground_lateral_modules = nn.ModuleList()
        self.posthoc_modules = nn.ModuleList()

        for i in range(self.num_backbone_stages - 1):
            self.ground_lateral_modules.append(
                ground_lateral_module(fpn_dim, fpn_dim_lateral[i + 1]))

        for i in range(self.num_backbone_stages):
            if cfg.FPN.USE_GN:
                posthoc = nn.Sequential(
                    nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=False),
                    group_norm(),
                    nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=False),
                    nn.ReLU(inplace=True))
            else:
                # ModuleList.append() accepts a single module. The layers must be
                # wrapped in a Sequential; passing them as three separate
                # arguments raises TypeError.
                posthoc = nn.Sequential(
                    nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=False),
                    nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=False),
                    nn.ReLU(inplace=True))
            self.posthoc_modules.append(posthoc)

            self.spatial_scale.append(fpn_level_info.spatial_scales[i])

        if self.fpt_rendering:
            self.fpt_rendering_conv1_modules = nn.ModuleList()
            self.fpt_rendering_conv2_modules = nn.ModuleList()

            for _ in range(self.num_backbone_stages - 1):
                if cfg.FPN.USE_GN:
                    self.fpt_rendering_conv1_modules.append(
                        nn.Sequential(
                            nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1, bias=True),
                            group_norm(),
                            nn.ReLU(inplace=True)))
                    self.fpt_rendering_conv2_modules.append(
                        nn.Sequential(
                            nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1, bias=True),
                            group_norm(),
                            nn.ReLU(inplace=True)))
                else:
                    self.fpt_rendering_conv1_modules.append(
                        nn.Conv2d(fpn_dim, fpn_dim, 3, 2, 1))
                    self.fpt_rendering_conv2_modules.append(
                        nn.Conv2d(fpn_dim, fpn_dim, 3, 1, 1))

        if not cfg.FPN.EXTRA_CONV_LEVELS and max_level == 6:
            # Kernel-1, stride-2 max-pool, i.e. plain stride-2 subsampling; it
            # adds a level at half the current first scale.
            self.maxpool_p6 = nn.MaxPool2d(kernel_size=1, stride=2, padding=0)
            self.spatial_scale.insert(0, self.spatial_scale[0] * 0.5)

        if cfg.FPN.EXTRA_CONV_LEVELS and max_level > 5:
            self.extra_pyramid_modules = nn.ModuleList()
            dim_in = fpn_level_info.dims[0]
            for _ in range(6, max_level + 1):
                # append() registers the conv. Calling the ModuleList itself
                # would invoke its undefined forward() instead.
                self.extra_pyramid_modules.append(
                    nn.Conv2d(dim_in, fpn_dim, 3, 2, 1))
                dim_in = fpn_dim
                self.spatial_scale.insert(0, self.spatial_scale[0] * 0.5)

        if self.P2only:
            # Collapse the list to its last entry (presumably the P2 scale).
            self.spatial_scale = self.spatial_scale[-1]

        self._init_weights()

        # Built after _init_weights(), so that call cannot touch the backbone.
        self.conv_body = conv_body_func()  # e.g resnet