コード例 #1
0
    def __init__(self, config):
        """Build a PSP segmentation network with a dictionary-based head.

        Submodules created here (their wiring lives in ``forward``, which is
        not part of this method; the order listed is presumably the data
        flow): backbone -> ``transform_psp`` midnet -> upsampling
        ``seg_decoder`` emitting ``dict_number`` channels -> ``dict_net``
        (``transform_dict(dict_number, dict_length)``) -> 1x1 ``dict_conv``
        from ``dict_length`` to ``class_number`` channels -> sigmoid.

        Args:
            config: experiment configuration. Backbone, upsample, midnet and
                dictionary settings are read from ``config.model``; the
                dataset name from ``config.dataset.name``.

        Raises:
            ValueError: if ``dict_number`` or ``dict_length`` is smaller than
                ``class_number``, or if ``upsample_type`` is neither
                ``'duc'`` nor ``'bilinear'``.
        """
        super().__init__()
        self.config = config
        self.name = self.__class__.__name__
        self.backbone = backbone(config.model)

        # A backbone learning-rate ratio of exactly 0 triggers freeze_layer
        # on the backbone.
        if hasattr(self.config.model, 'backbone_lr_ratio'):
            backbone_lr_ratio = self.config.model.backbone_lr_ratio
            if backbone_lr_ratio == 0:
                freeze_layer(self.backbone)

        self.upsample_type = self.config.model.upsample_type
        self.upsample_layer = self.config.model.upsample_layer
        self.class_number = self.config.model.class_number
        self.input_shape = self.config.model.input_shape
        self.dataset_name = self.config.dataset.name
        self.midnet_pool_sizes = self.config.model.midnet_pool_sizes
        self.midnet_scale = self.config.model.midnet_scale

        self.midnet_in_channels = self.backbone.get_feature_map_channel(
            self.upsample_layer)
        self.midnet_out_channels = self.config.model.midnet_out_channels
        # input_shape[0:2] is used as the spatial size (presumably (H, W) --
        # confirm against the config).
        self.midnet_out_size = self.backbone.get_feature_map_size(
            self.upsample_layer, self.input_shape[0:2])

        self.midnet = transform_psp(self.midnet_pool_sizes,
                                    self.midnet_scale,
                                    self.midnet_in_channels,
                                    self.midnet_out_channels,
                                    self.midnet_out_size)

        self.dict_number = self.config.model.dict_number
        # Explicit check rather than assert so it still runs under python -O.
        if self.dict_number < self.class_number:
            raise ValueError(
                'dict number %d should be greater than or equal to '
                'class number %d' % (self.dict_number, self.class_number))

        # The PSP midnet outputs 2 * midnet_out_channels channels, which is
        # the decoder's input width.
        if self.upsample_type == 'duc':
            # DUC upsampling factor: 2 ** upsample_layer.
            r = 2**self.upsample_layer
            self.seg_decoder = upsample_duc(
                2 * self.midnet_out_channels, self.dict_number, r)
        elif self.upsample_type == 'bilinear':
            self.seg_decoder = upsample_bilinear(
                2 * self.midnet_out_channels, self.dict_number,
                self.input_shape[0:2])
        else:
            # Raise explicitly: ``assert False`` is stripped under python -O
            # and would leave ``seg_decoder`` undefined.
            raise ValueError('unknown upsample type %s' % self.upsample_type)

        self.dict_length = self.config.model.dict_length
        if self.dict_length < self.class_number:
            raise ValueError(
                'dict length %d should be greater than or equal to '
                'class number %d' % (self.dict_length, self.class_number))
        self.dict_net = transform_dict(self.dict_number, self.dict_length)
        # 1x1 projection from dict_length channels to class_number channels.
        self.dict_conv = TN.Conv2d(in_channels=self.dict_length,
                                   out_channels=self.class_number,
                                   kernel_size=1)
        self.dict_sig = TN.Sigmoid()
コード例 #2
0
ファイル: psp_global.py プロジェクト: nowrin0102/torchseg
    def __init__(self, config):
        """Build a PSP segmentation network with an additional global decoder.

        Submodules created here: backbone, ``transform_psp`` midnet, an
        upsampling ``seg_decoder`` (DUC or bilinear) emitting
        ``class_number`` channels, and ``global_decoder`` built by
        ``transform_global`` from ``config.model.gnet_dilation_sizes``.

        Args:
            config: experiment configuration. Model settings are read from
                ``config.model``; the dataset name from
                ``config.dataset.name``.

        Raises:
            ValueError: if ``upsample_type`` is neither ``'duc'`` nor
                ``'bilinear'``.
        """
        super().__init__()
        self.config = config
        self.name = self.__class__.__name__
        self.backbone = backbone(config.model)

        # A backbone learning-rate ratio of exactly 0 triggers freeze_layer
        # on the backbone.
        if hasattr(self.config.model, 'backbone_lr_ratio'):
            backbone_lr_ratio = self.config.model.backbone_lr_ratio
            if backbone_lr_ratio == 0:
                freeze_layer(self.backbone)

        self.upsample_type = self.config.model.upsample_type
        self.upsample_layer = self.config.model.upsample_layer
        self.class_number = self.config.model.class_number
        self.input_shape = self.config.model.input_shape
        self.dataset_name = self.config.dataset.name
        self.midnet_pool_sizes = self.config.model.midnet_pool_sizes
        self.midnet_scale = self.config.model.midnet_scale

        self.midnet_in_channels = self.backbone.get_feature_map_channel(
            self.upsample_layer)
        self.midnet_out_channels = self.config.model.midnet_out_channels
        # input_shape[0:2] is used as the spatial size (presumably (H, W) --
        # confirm against the config).
        self.midnet_out_size = self.backbone.get_feature_map_size(
            self.upsample_layer, self.input_shape[0:2])

        self.midnet = transform_psp(self.midnet_pool_sizes, self.midnet_scale,
                                    self.midnet_in_channels,
                                    self.midnet_out_channels,
                                    self.midnet_out_size)

        # The PSP midnet outputs 2 * midnet_out_channels channels, which is
        # the decoder's input width.
        if self.upsample_type == 'duc':
            # DUC upsampling factor: 2 ** upsample_layer.
            r = 2**self.upsample_layer
            self.seg_decoder = upsample_duc(2 * self.midnet_out_channels,
                                            self.class_number, r)
        elif self.upsample_type == 'bilinear':
            self.seg_decoder = upsample_bilinear(2 * self.midnet_out_channels,
                                                 self.class_number,
                                                 self.input_shape[0:2])
        else:
            # Raise explicitly: ``assert False`` is stripped under python -O
            # and would leave ``seg_decoder`` undefined.
            raise ValueError('unknown upsample type %s' % self.upsample_type)

        self.gnet_dilation_sizes = self.config.model.gnet_dilation_sizes
        self.global_decoder = transform_global(self.gnet_dilation_sizes,
                                               self.class_number)
コード例 #3
0
ファイル: merge_seg.py プロジェクト: nowrin0102/torchseg
    def __init__(self, config):
        """Assemble the edge/segmentation merge network.

        Two suffix heads (edge and segmentation) are built on the midnet
        output. A 1x1 ``conv_bn_relu`` then maps 512 + 512 channels back to
        512, and a final suffix head maps those 512 channels to
        ``class_number`` outputs. How they are wired together is defined in
        ``forward`` (not part of this method).

        Args:
            config: experiment configuration; reads ``config.model`` and
                ``config.dataset``.
        """
        super().__init__()
        self.config = config
        self.name = type(self).__name__
        model_cfg = config.model
        self.backbone = backbone(model_cfg)

        # Optionally freeze the backbone when the config requests it.
        if hasattr(model_cfg, 'backbone_freeze') and model_cfg.backbone_freeze:
            print('freeze backbone weights' + '*' * 30)
            freeze_layer(self.backbone)

        self.upsample_layer = model_cfg.upsample_layer
        self.class_number = model_cfg.class_number
        self.input_shape = model_cfg.input_shape
        dataset_cfg = config.dataset
        self.dataset_name = dataset_cfg.name
        self.ignore_index = dataset_cfg.ignore_index
        self.edge_class_num = dataset_cfg.edge_class_num

        self.midnet_input_shape = self.backbone.get_output_shape(
            self.upsample_layer, self.input_shape)
        # Midnet output width: twice entry 1 of its input shape (presumably
        # the channel axis -- confirm).
        self.midnet_out_channels = 2 * self.midnet_input_shape[1]
        self.midnet = get_midnet(self.config, self.midnet_input_shape,
                                 self.midnet_out_channels)

        # Keep the construction order below. It fixes parameter registration
        # order and, assuming the helpers initialise weights randomly, RNG
        # consumption.
        branch_width = 512  # per-branch feature width (original author's note)
        self.branch_edge = get_suffix_net(self.config,
                                          self.midnet_out_channels,
                                          self.edge_class_num)
        self.branch_seg = get_suffix_net(self.config,
                                         self.midnet_out_channels,
                                         self.class_number)
        self.feature_conv = conv_bn_relu(in_channels=2 * branch_width,
                                         out_channels=branch_width,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        self.seg = get_suffix_net(self.config, branch_width,
                                  self.class_number)
コード例 #4
0
    def __init__(self, config):
        """Build a segmentation network with separate segmentation and edge heads.

        ``config.model.edge_seg_order`` (default ``'same'``) selects how the
        heads are built:

        * ``'same'``  -- both decoders take the midnet output width.
        * ``'later'`` -- the seg decoder takes the midnet output width; the
          edge decoder takes a 512-channel input.
        * ``'first'`` -- the edge decoder takes the midnet output width; the
          seg path adds ``feature_conv`` (midnet width -> 512) and
          ``seg_conv`` (512 + 512 -> 512) before a seg decoder with a
          512-channel input.

        Args:
            config: experiment configuration; reads ``config.model`` and
                ``config.dataset``.

        Raises:
            ValueError: if ``edge_seg_order`` is set to a value other than
                ``'same'``, ``'first'`` or ``'later'``.
        """
        super().__init__()
        self.config = config
        self.name = self.__class__.__name__
        self.backbone = backbone(config.model)
        if hasattr(self.config.model, 'backbone_freeze'):
            if self.config.model.backbone_freeze:
                print('freeze backbone weights' + '*' * 30)
                freeze_layer(self.backbone)

        self.upsample_layer = self.config.model.upsample_layer
        self.class_number = self.config.model.class_number
        self.input_shape = self.config.model.input_shape
        self.dataset_name = self.config.dataset.name
        self.ignore_index = self.config.dataset.ignore_index
        self.edge_class_num = self.config.dataset.edge_class_num

        self.midnet_input_shape = self.backbone.get_output_shape(
            self.upsample_layer, self.input_shape)
        # Midnet output width: twice entry 1 of its input shape (presumably
        # the channel axis -- confirm).
        self.midnet_out_channels = 2 * self.midnet_input_shape[1]

        self.midnet = get_midnet(self.config, self.midnet_input_shape,
                                 self.midnet_out_channels)

        if hasattr(self.config.model, 'edge_seg_order'):
            self.edge_seg_order = self.config.model.edge_seg_order
            print('the edge and seg order is %s' % self.edge_seg_order,
                  '*' * 30)
            # Explicit check rather than assert. Under python -O an assert is
            # skipped, and an invalid value would silently fall through to
            # the 'first' branch below.
            if self.edge_seg_order not in ('same', 'first', 'later'):
                raise ValueError('unexpected edge seg order %s' %
                                 self.edge_seg_order)
        else:
            self.edge_seg_order = 'same'

        if self.edge_seg_order == 'same':
            self.seg_decoder = get_suffix_net(self.config,
                                              self.midnet_out_channels,
                                              self.class_number)
            self.edge_decoder = get_suffix_net(self.config,
                                               self.midnet_out_channels,
                                               self.edge_class_num)
        elif self.edge_seg_order == 'later':
            self.seg_decoder = get_suffix_net(self.config,
                                              self.midnet_out_channels,
                                              self.class_number)
            # NOTE(review): 512-channel input, presumably a feature map from
            # the seg path -- confirm in forward().
            self.edge_decoder = get_suffix_net(self.config, 512,
                                               self.edge_class_num)
        else:  # 'first'
            self.edge_decoder = get_suffix_net(self.config,
                                               self.midnet_out_channels,
                                               self.edge_class_num)
            self.feature_conv = conv_bn_relu(
                in_channels=self.midnet_out_channels,
                out_channels=512,
                kernel_size=1,
                stride=1,
                padding=0)
            # in_channels = 512 + 512: presumably two 512-channel feature
            # maps concatenated in forward() (feature_conv's output plus a
            # 512-channel map from the edge path) -- confirm.
            self.seg_conv = conv_bn_relu(in_channels=512 + 512,
                                         out_channels=512,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
            self.seg_decoder = get_suffix_net(self.config, 512,
                                              self.class_number)
コード例 #5
0
ファイル: cross_merge.py プロジェクト: nowrin0102/torchseg
    def __init__(self, config):
        """Assemble the cross-merge network.

        An initial seg/edge head pair (``seg0``/``edge0``) is built on the
        midnet output. Then ``cross_merge_times`` stages are built; each
        stage has:

        * a 1x1 ``conv_bn_relu`` per branch reducing 512 channels to 256;
        * seg and edge suffix heads taking 2 * 256 input channels
          (presumably the two reduced maps concatenated in ``forward`` --
          confirm).

        Args:
            config: experiment configuration; reads ``config.model`` and
                ``config.dataset``.
        """
        super().__init__()
        self.config = config
        self.name = type(self).__name__
        model_cfg = config.model
        self.backbone = backbone(model_cfg)

        # Optionally freeze the backbone when the config requests it.
        if hasattr(model_cfg, 'backbone_freeze') and model_cfg.backbone_freeze:
            print('freeze backbone weights' + '*' * 30)
            freeze_layer(self.backbone)

        self.upsample_layer = model_cfg.upsample_layer
        self.class_number = model_cfg.class_number
        self.input_shape = model_cfg.input_shape
        dataset_cfg = config.dataset
        self.dataset_name = dataset_cfg.name
        self.ignore_index = dataset_cfg.ignore_index
        self.edge_class_num = dataset_cfg.edge_class_num
        self.cross_merge_times = model_cfg.cross_merge_times

        self.midnet_input_shape = self.backbone.get_output_shape(
            self.upsample_layer, self.input_shape)
        self.midnet_out_channels = 2 * self.midnet_input_shape[1]
        self.midnet = get_midnet(self.config, self.midnet_input_shape,
                                 self.midnet_out_channels)

        self.seg0 = get_suffix_net(self.config, self.midnet_out_channels,
                                   self.class_number)
        self.edge0 = get_suffix_net(self.config, self.midnet_out_channels,
                                    self.edge_class_num)

        in_width = 512
        reduced_width = in_width // 2
        merged_width = 2 * reduced_width

        def make_reducer():
            # 1x1 conv-bn-relu applied to one branch before the merge.
            return conv_bn_relu(in_channels=in_width,
                                out_channels=reduced_width,
                                kernel_size=1,
                                stride=1,
                                padding=0)

        seg_reducers, edge_reducers = [], []
        seg_heads, edge_heads = [], []
        # TODO: use a PSP module instead of a plain conv here.
        for _ in range(self.cross_merge_times):
            # Per-stage construction order (seg reducer, edge reducer, seg
            # head, edge head) is kept deliberately fixed.
            seg_reducers.append(make_reducer())
            edge_reducers.append(make_reducer())
            seg_heads.append(
                get_suffix_net(self.config, merged_width, self.class_number))
            edge_heads.append(
                get_suffix_net(self.config, merged_width,
                               self.edge_class_num))

        self.seg_list = TN.ModuleList(seg_heads)
        self.edge_list = TN.ModuleList(edge_heads)
        self.seg_conv_list = TN.ModuleList(seg_reducers)
        self.edge_conv_list = TN.ModuleList(edge_reducers)
コード例 #6
0
    def __init__(self, config):
        """Build a segmentation network with auxiliary per-layer edge heads.

        Submodules created here:

        * backbone, midnet and segmentation ``decoder``;
        * one 1x1 edge head per backbone layer ``1 .. upsample_layer - 1``
          (``edge_aux_list``);
        * a 1x1 ``edge_fusion_conv`` over ``edge_class_num *
          (upsample_layer - 1)`` channels.

        Also populates ``self.optimizer_params``, a list of parameter-group
        dicts carrying an ``lr_mult`` key:

        * ``lr_mult`` 1: backbone (trainable params only), aux edge heads and
          edge fusion conv;
        * ``lr_mult`` 10: midnet and decoder.

        Args:
            config: experiment configuration; reads ``config.model`` and
                ``config.dataset``.
        """
        super().__init__()
        self.config = config
        self.name = self.__class__.__name__

        use_none_layer = config.model.use_none_layer
        self.backbone = backbone(config.model, use_none_layer=use_none_layer)

        if hasattr(self.config.model, 'backbone_freeze'):
            if self.config.model.backbone_freeze:
                freeze_layer(self.backbone)

        self.upsample_layer = self.config.model.upsample_layer
        self.class_number = self.config.model.class_number
        self.input_shape = self.config.model.input_shape
        self.dataset_name = self.config.dataset.name
        self.ignore_index = self.config.dataset.ignore_index
        self.edge_class_num = self.config.dataset.edge_class_num

        self.midnet_input_shape = self.backbone.get_output_shape(
            self.upsample_layer, self.input_shape)
        # Midnet output width: twice entry 1 of its input shape (presumably
        # the channel axis -- confirm).
        self.midnet_out_channels = 2 * self.midnet_input_shape[1]

        self.midnet = get_midnet(self.config, self.midnet_input_shape,
                                 self.midnet_out_channels)

        self.decoder = get_suffix_net(config, self.midnet_out_channels,
                                      self.class_number)

        layer_shapes = self.backbone.get_layer_shapes(self.input_shape)
        print('layer shapes', layer_shapes)
        # One 1x1 edge head per backbone layer 1 .. upsample_layer - 1, each
        # reading entry 1 of that layer's shape as its input width.
        # NOTE(review): with upsample_layer <= 1 no heads are built and
        # edge_fusion_conv below gets in_channels <= 0 -- confirm configs
        # always use upsample_layer >= 2.
        self.edge_aux_list = TN.ModuleList([
            TN.Conv2d(in_channels=layer_shapes[i][1],
                      out_channels=self.edge_class_num,
                      kernel_size=1)
            for i in range(1, self.upsample_layer)
        ])
        # Input width matches all aux edge maps stacked on the channel axis
        # (presumably concatenated in forward()).
        self.edge_fusion_conv = TN.Conv2d(
            in_channels=self.edge_class_num * (self.upsample_layer - 1),
            out_channels=self.edge_class_num,
            kernel_size=1)

        # Parameter groups are materialised as lists. Module.parameters()
        # returns a one-shot generator, so any earlier pass over these groups
        # (e.g. counting parameters) would leave the optimizer with empty
        # groups.
        self.optimizer_params = [
            {'params': [p for p in self.backbone.parameters()
                        if p.requires_grad],
             'lr_mult': 1},
            {'params': list(self.edge_aux_list.parameters()), 'lr_mult': 1},
            {'params': list(self.edge_fusion_conv.parameters()),
             'lr_mult': 1},
            {'params': list(self.midnet.parameters()), 'lr_mult': 10},
            {'params': list(self.decoder.parameters()), 'lr_mult': 10},
        ]