def __init__(self, configer, loss_dict=None, flag=""):
        """Assemble a classifier whose backbone is taken from ``slim``.

        Args:
            configer: configuration accessor; every ``network.*`` key read here
                is namespaced by the optional flag prefix.
            loss_dict: loss mapping stored as ``self.valid_loss_dict``; when
                None it is read from ``LOSS_TYPE`` using ``loss.loss_type``.
            flag: optional name that prefixes the ``network.*`` keys as
                ``<flag>_``; an empty string means no prefix.
        """
        super(ClsModel, self).__init__()
        self.configer = configer
        # A non-empty flag turns into a "<flag>_" prefix for config keys.
        self.flag = "{}_".format(flag) if len(flag) > 0 else flag
        prefix = self.flag

        # The backbone is built without its own classifier head.
        backbone_builder = slim.__dict__[configer.get('network.{}backbone'.format(prefix))]
        self.backbone = backbone_builder(
            pretrained=configer.get('network.{}pretrained'.format(prefix)),
            has_classifier=False,
        )

        # Optional 1x1 conv projecting backbone channels down to network.<flag>fc_dim.
        out_dim = configer.get('network.{}fc_dim'.format(prefix), default=None)
        if out_dim is None:
            self.reduction = None
            fc_dim = self.backbone.num_features
        else:
            self.reduction = nn.Conv2d(self.backbone.num_features, out_dim, 1)
            fc_dim = out_dim

        # One linear head per entry of data.num_classes, each taking fc_dim inputs.
        linear_type = configer.get('network', '{}linear_type'.format(prefix))
        self.linear_list = nn.ModuleList(
            [ModuleHelper.Linear(linear_type)(fc_dim, n_cls)
             for n_cls in configer.get('data.num_classes')]
        )

        # Optional Linear + BatchNorm1d projection to feat_dim (enabled by default).
        if configer.get('network.{}embed'.format(prefix), default=True):
            feat_dim = configer.get('network', '{}feat_dim'.format(prefix))
            self.embed = nn.Sequential(nn.Linear(fc_dim, feat_dim), nn.BatchNorm1d(feat_dim))
        else:
            self.embed = None

        # BatchNorm1d over fc_dim features; its bias is zeroed and excluded from training.
        bn = nn.BatchNorm1d(fc_dim)
        nn.init.zeros_(bn.bias)
        bn.bias.requires_grad = False
        self.bn = bn

        if loss_dict is None:
            loss_dict = LOSS_TYPE[configer.get('loss', 'loss_type')]
        self.valid_loss_dict = loss_dict
 def __init__(self, configer, flag="", target_class=1):
     """Build the deployment classification model.

     Args:
         configer: configuration accessor; ``network.*`` keys are read with the
             optional flag prefix.
         flag: optional name prefixing the ``network.*`` keys as ``<flag>_``;
             an empty string means no prefix.
         target_class: not referenced in this constructor.
     """
     super(DeployClsModel, self).__init__()
     self.configer = configer
     # A non-empty flag turns into a "<flag>_" prefix for config keys.
     self.flag = "{}_".format(flag) if len(flag) > 0 else flag
     prefix = self.flag

     selector = BackboneSelector(configer)
     self.backbone = selector.get_backbone(
         backbone_type=configer.get('network.{}backbone'.format(prefix)),
         rm_last_stride=configer.get('network', '{}rm_last_stride'.format(prefix), default=False),
     )

     # Optional 1x1 conv projecting backbone channels down to network.<flag>fc_dim.
     out_dim = configer.get('network.{}fc_dim'.format(prefix), default=None)
     if out_dim is None:
         self.reduction = None
         fc_dim = self.backbone.num_features
     else:
         self.reduction = nn.Conv2d(self.backbone.num_features, out_dim, 1)
         fc_dim = out_dim

     # Optional BatchNorm2d over fc_dim channels; only built when fc_bn is truthy.
     # NOTE(review): the default here is falsy; confirm it matches the training
     # model's fc_bn default so checkpoints load without missing/unexpected bn keys.
     use_bn = configer.get('network.{}fc_bn'.format(prefix), default=None)
     self.bn = nn.BatchNorm2d(fc_dim) if use_bn else None

     # Per-source classifier heads are only needed when scores or CAMs are
     # extracted; otherwise self.linear_lists is never assigned here.
     wants_score = self.configer.get('deploy.extract_score', default=False)
     if wants_score or self.configer.get('deploy.extract_cam', default=False):
         self.linear_lists = nn.ModuleList()
         for src in range(self.configer.get('data', 'num_data_sources')):
             src_linear_type = self.configer.get('network', '{}src{}_linear_type'.format(prefix, src))
             heads = nn.ModuleList(
                 [ModuleHelper.Linear(src_linear_type)(fc_dim, n_cls)
                  for n_cls in self.configer.get('data.src{}_num_classes'.format(src))]
             )
             self.linear_lists.append(heads)
Example #3
0
    def __init__(self, configer, loss_dict=None, flag=""):
        """Build the multi-source classification model.

        Args:
            configer: configuration accessor. ``network.*`` keys are read with
                the optional flag prefix (``network.<flag>_<key>``).
            loss_dict: loss mapping stored as ``self.valid_loss_dict``; when
                None it is read from ``LOSS_TYPE`` using ``loss.loss_type``.
            flag: optional name prefixing the ``network.*`` keys as ``<flag>_``.
                An empty string (or None) means no prefix.
        """
        super(ClsModel, self).__init__()
        self.configer = configer
        # A non-empty flag becomes a "<flag>_" prefix; "" and None mean no prefix.
        self.flag = "{}_".format(flag) if flag else ""
        prefix = self.flag

        self.backbone = BackboneSelector(self.configer).get_backbone(
            backbone_type=self.configer.get('network.{}backbone'.format(prefix)),
            pretrained_model=self.configer.get('network.{}pretrained'.format(prefix)),
            rm_last_stride=self.configer.get(
                'network.{}rm_last_stride'.format(prefix), default=False))

        # Optional 1x1 conv projecting backbone channels down to network.<flag>fc_dim.
        self.reduction = None
        fc_dim_out = self.configer.get('network.{}fc_dim'.format(prefix), default=None)
        fc_dim = self.backbone.num_features
        if fc_dim_out is not None:
            self.reduction = nn.Conv2d(fc_dim, fc_dim_out, 1)
            fc_dim = fc_dim_out

        # Optional BatchNorm1d over fc_dim features (on by default); its bias is
        # zeroed and excluded from gradient updates.
        self.bn = None
        if self.configer.get('network.{}fc_bn'.format(prefix), default=True):
            self.bn = nn.BatchNorm1d(fc_dim)
            nn.init.zeros_(self.bn.bias)
            self.bn.bias.requires_grad = False

        # Optional ReLU (on by default).
        self.relu = None
        if self.configer.get('network.{}fc_relu'.format(prefix), default=True):
            self.relu = nn.ReLU()

        # One ModuleList of linear heads per data source, one head per entry of
        # data.src<i>_num_classes.
        self.linear_lists = nn.ModuleList()
        for source in range(self.configer.get('data', 'num_data_sources')):
            linear_type = self.configer.get(
                'network', '{}src{}_linear_type'.format(prefix, source))
            linear_list = nn.ModuleList()
            for num_classes in self.configer.get('data.src{}_num_classes'.format(source)):
                linear_list.append(ModuleHelper.Linear(linear_type)(fc_dim, num_classes))
            self.linear_lists.append(linear_list)

        # Optional extra head sized by data.global_num_classes (read once).
        self.global_linear = None
        global_num_classes = self.configer.get('data.global_num_classes', default=None)
        if global_num_classes is not None:
            global_linear_type = self.configer.get(
                'network', '{}global_linear_type'.format(prefix))
            self.global_linear = ModuleHelper.Linear(global_linear_type)(
                fc_dim, global_num_classes)

        # Read the flag-specific key first, then fall back to the unprefixed
        # network.embed_after_norm, so configs that only set the shared key
        # behave exactly as before.
        self.embed_after_norm = self.configer.get(
            'network.{}embed_after_norm'.format(prefix),
            default=self.configer.get('network.embed_after_norm', default=True))

        # Optional embedding: Linear to feat_dim, followed by L2 normalisation
        # or BatchNorm1d depending on network.<flag>embed_norm_type. Any other
        # value adds no normalisation layer.
        self.embed = None
        if self.configer.get('network.{}embed'.format(prefix), default=True):
            feat_dim = self.configer.get('network', '{}feat_dim'.format(prefix))
            embed = [nn.Linear(fc_dim, feat_dim)]
            embed_norm_type = self.configer.get('network.{}embed_norm_type'.format(prefix))
            if embed_norm_type == 'L2':
                embed.append(LpNormalize(p=2, dim=1))
            elif embed_norm_type == 'BN':
                embed.append(nn.BatchNorm1d(feat_dim))
            self.embed = nn.Sequential(*embed)

        if loss_dict is None:
            loss_dict = LOSS_TYPE[self.configer.get('loss', 'loss_type')]
        self.valid_loss_dict = loss_dict