Beispiel #1
0
    def __init__(self, num_classes, camera_num, view_num, cfg, factory,
                 rearrange):
        """Build the backbone, both branch heads, classifiers and bottlenecks.

        Args:
            num_classes: Output width of every classifier head.
            camera_num: Passed to the backbone as ``camera``; forced to 0
                when ``cfg.MODEL.SIE_CAMERA`` is falsy.
            view_num: Passed to the backbone as ``view``; forced to 0 when
                ``cfg.MODEL.SIE_VIEW`` is falsy.
            cfg: Configuration object; its MODEL, INPUT, TEST and SOLVER
                sections are read.
            factory: Mapping from ``cfg.MODEL.TRANSFORMER_TYPE`` to a
                callable that constructs the backbone.
            rearrange: Stored unchanged on ``self.rearrange``.
        """
        super(build_transformer_local, self).__init__()
        weights_file = cfg.MODEL.PRETRAIN_PATH
        pretrain_mode = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.in_planes = 768

        print('using Transformer_type: {} as a backbone'.format(
            cfg.MODEL.TRANSFORMER_TYPE))

        # The camera / view counts only reach the backbone when the matching
        # SIE switch is on; otherwise 0 is passed.
        num_cameras = camera_num if cfg.MODEL.SIE_CAMERA else 0
        num_views = view_num if cfg.MODEL.SIE_VIEW else 0

        make_backbone = factory[cfg.MODEL.TRANSFORMER_TYPE]
        self.base = make_backbone(img_size=cfg.INPUT.SIZE_TRAIN,
                                  sie_xishu=cfg.MODEL.SIE_COE,
                                  local_feature=cfg.MODEL.JPM,
                                  camera=num_cameras,
                                  view=num_views,
                                  stride_size=cfg.MODEL.STRIDE_SIZE,
                                  drop_path_rate=cfg.MODEL.DROP_PATH)

        if pretrain_mode == 'imagenet':
            self.base.load_param(weights_file)
            print('Loading pretrained ImageNet model......from {}'.format(
                weights_file))

        # b1 and b2 each hold an independent deep copy of the backbone's
        # last block followed by a deep copy of its final norm layer.
        last_block = self.base.blocks[-1]
        final_norm = self.base.norm
        self.b1 = nn.Sequential(copy.deepcopy(last_block),
                                copy.deepcopy(final_norm))
        self.b2 = nn.Sequential(copy.deepcopy(last_block),
                                copy.deepcopy(final_norm))

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # Margin-based losses: a single head built with scale s, margin m.
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_type = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_type = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_type = AMSoftmax
            else:
                head_type = CircleLoss
            self.classifier = head_type(self.in_planes,
                                        self.num_classes,
                                        s=scale,
                                        m=margin)
        else:
            # Any other loss type: five bias-free linear heads, created and
            # initialised one after another in attribute order.
            for attr in ('classifier', 'classifier_1', 'classifier_2',
                         'classifier_3', 'classifier_4'):
                linear_head = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
                linear_head.apply(weights_init_classifier)
                setattr(self, attr, linear_head)

        # Five BatchNorm bottlenecks whose bias is excluded from gradients.
        for attr in ('bottleneck', 'bottleneck_1', 'bottleneck_2',
                     'bottleneck_3', 'bottleneck_4'):
            batch_norm = nn.BatchNorm1d(self.in_planes)
            batch_norm.bias.requires_grad_(False)
            batch_norm.apply(weights_init_kaiming)
            setattr(self, attr, batch_norm)

        self.shuffle_groups = cfg.MODEL.SHUFFLE_GROUP
        print('using shuffle_groups size:{}'.format(self.shuffle_groups))
        self.shift_num = cfg.MODEL.SHIFT_NUM
        print('using shift_num size:{}'.format(self.shift_num))
        self.divide_length = cfg.MODEL.DEVIDE_LENGTH
        print('using divide_length size:{}'.format(self.divide_length))
        self.rearrange = rearrange
Beispiel #2
0
    def __init__(self, num_classes, camera_num, view_num, cfg, factory):
        """Build the transformer backbone, classifier head and bottleneck.

        Args:
            num_classes: Output width of the classifier head.
            camera_num: Passed to the backbone as ``camera``; forced to 0
                when ``cfg.MODEL.SIE_CAMERA`` is falsy.
            view_num: Passed to the backbone as ``view``; forced to 0 when
                ``cfg.MODEL.SIE_VIEW`` is falsy.
            cfg: Configuration object; its MODEL, INPUT, TEST and SOLVER
                sections are read.
            factory: Mapping from ``cfg.MODEL.TRANSFORMER_TYPE`` to a
                callable that constructs the backbone.
        """
        super(build_transformer, self).__init__()
        model_path = cfg.MODEL.PRETRAIN_PATH
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.in_planes = 768

        transformer_type = cfg.MODEL.TRANSFORMER_TYPE
        print('using Transformer_type: {} as a backbone'.format(
            transformer_type))

        # The camera / view counts only reach the backbone when the matching
        # SIE switch is on; otherwise 0 is passed.
        if not cfg.MODEL.SIE_CAMERA:
            camera_num = 0
        if not cfg.MODEL.SIE_VIEW:
            view_num = 0

        self.base = factory[transformer_type](
            img_size=cfg.INPUT.SIZE_TRAIN,
            sie_xishu=cfg.MODEL.SIE_COE,
            camera=camera_num,
            view=view_num,
            stride_size=cfg.MODEL.STRIDE_SIZE,
            drop_path_rate=cfg.MODEL.DROP_PATH,
            drop_rate=cfg.MODEL.DROP_OUT,
            attn_drop_rate=cfg.MODEL.ATT_DROP_RATE)
        # This backbone type is paired with a 384-wide head instead of 768.
        if transformer_type == 'deit_small_patch16_224_TransReID':
            self.in_planes = 384
        if pretrain_choice == 'imagenet':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # Margin-based losses share one constructor signature: (in, out, s, m).
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_type = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_type = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_type = AMSoftmax
            else:
                head_type = CircleLoss
            self.classifier = head_type(self.in_planes,
                                        self.num_classes,
                                        s=scale,
                                        m=margin)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BatchNorm bottleneck whose bias is excluded from gradients.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
Beispiel #3
0
class build_transformer_local(nn.Module):
    """Transformer model with one global branch and four local branches.

    The backbone output goes through ``self.b1`` for the global feature and,
    cut into four consecutive token groups, through the shared ``self.b2``
    for the local features. Each of the five branches owns a BatchNorm
    bottleneck and a classifier head.
    """

    # ID loss types whose classifier head is called as head(features, label).
    _MARGIN_LOSSES = ('arcface', 'cosface', 'amsoftmax', 'circle')
    # Attribute suffixes: '' is the global branch, the rest the local ones.
    _BRANCH_SUFFIXES = ('', '_1', '_2', '_3', '_4')

    def __init__(self, num_classes, camera_num, view_num, cfg, factory,
                 rearrange):
        """Build the backbone, both branch heads, classifiers and bottlenecks.

        Args:
            num_classes: Output width of every classifier head.
            camera_num: Passed to the backbone as ``camera``; forced to 0
                when ``cfg.MODEL.SIE_CAMERA`` is falsy.
            view_num: Passed to the backbone as ``view``; forced to 0 when
                ``cfg.MODEL.SIE_VIEW`` is falsy.
            cfg: Configuration object; its MODEL, INPUT, TEST and SOLVER
                sections are read.
            factory: Mapping from ``cfg.MODEL.TRANSFORMER_TYPE`` to a
                callable that constructs the backbone.
            rearrange: When truthy, ``forward`` passes the backbone tokens
                through ``shuffle_unit`` before splitting them.
        """
        super(build_transformer_local, self).__init__()
        model_path = cfg.MODEL.PRETRAIN_PATH
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.in_planes = 768

        print('using Transformer_type: {} as a backbone'.format(
            cfg.MODEL.TRANSFORMER_TYPE))

        # The camera / view counts only reach the backbone when the matching
        # SIE switch is on; otherwise 0 is passed.
        if not cfg.MODEL.SIE_CAMERA:
            camera_num = 0
        if not cfg.MODEL.SIE_VIEW:
            view_num = 0

        self.base = factory[cfg.MODEL.TRANSFORMER_TYPE](
            img_size=cfg.INPUT.SIZE_TRAIN,
            sie_xishu=cfg.MODEL.SIE_COE,
            local_feature=cfg.MODEL.JPM,
            camera=camera_num,
            view=view_num,
            stride_size=cfg.MODEL.STRIDE_SIZE,
            drop_path_rate=cfg.MODEL.DROP_PATH)

        if pretrain_choice == 'imagenet':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        # b1 and b2 each hold an independent deep copy of the backbone's
        # last block followed by a deep copy of its final norm layer.
        block = self.base.blocks[-1]
        layer_norm = self.base.norm
        self.b1 = nn.Sequential(copy.deepcopy(block),
                                copy.deepcopy(layer_norm))
        self.b2 = nn.Sequential(copy.deepcopy(block),
                                copy.deepcopy(layer_norm))

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in self._MARGIN_LOSSES:
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
        # One head per branch for EVERY loss type. Previously the margin
        # losses only got the global head, so forward() crashed in training
        # when it tried to return the four local scores.
        for suffix in self._BRANCH_SUFFIXES:
            setattr(self, 'classifier' + suffix, self._make_classifier(cfg))

        for suffix in self._BRANCH_SUFFIXES:
            setattr(self, 'bottleneck' + suffix, self._make_bottleneck())

        self.shuffle_groups = cfg.MODEL.SHUFFLE_GROUP
        print('using shuffle_groups size:{}'.format(self.shuffle_groups))
        self.shift_num = cfg.MODEL.SHIFT_NUM
        print('using shift_num size:{}'.format(self.shift_num))
        self.divide_length = cfg.MODEL.DEVIDE_LENGTH
        print('using divide_length size:{}'.format(self.divide_length))
        self.rearrange = rearrange

    def _make_classifier(self, cfg):
        """Return a new classifier head matching ``self.ID_LOSS_TYPE``."""
        if self.ID_LOSS_TYPE == 'arcface':
            head_type = Arcface
        elif self.ID_LOSS_TYPE == 'cosface':
            head_type = Cosface
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            head_type = AMSoftmax
        elif self.ID_LOSS_TYPE == 'circle':
            head_type = CircleLoss
        else:
            head = nn.Linear(self.in_planes, self.num_classes, bias=False)
            head.apply(weights_init_classifier)
            return head
        return head_type(self.in_planes,
                         self.num_classes,
                         s=cfg.SOLVER.COSINE_SCALE,
                         m=cfg.SOLVER.COSINE_MARGIN)

    def _make_bottleneck(self):
        """Return a new BatchNorm1d whose bias is excluded from gradients."""
        bottleneck = nn.BatchNorm1d(self.in_planes)
        bottleneck.bias.requires_grad_(False)
        bottleneck.apply(weights_init_kaiming)
        return bottleneck

    def forward(self, x, label=None, cam_label=None, view_label=None):
        """Compute the global and local features and, in training, scores.

        Args:
            x: Input batch handed to the backbone.
            label: Identity labels; only consumed by the margin-based heads
                (arcface / cosface / amsoftmax / circle) during training.
            cam_label: Forwarded to the backbone.
            view_label: Forwarded to the backbone.

        Returns:
            Training mode: ``(scores, feats)`` where both are lists ordered
            global, local 1..4; ``feats`` are the features before the
            bottlenecks (used for the triplet loss per the original note).
            Eval mode: one tensor concatenating, along dim 1, the global
            feature and the four local features divided by 4, taken after
            the bottlenecks when ``self.neck_feat == 'after'`` and before
            them otherwise.
        """
        features = self.base(x, cam_label=cam_label, view_label=view_label)

        # Global branch: first token of the b1 output
        # (original author's shape note for b1's output: [64, 129, 768]).
        global_feat = self.b1(features)[:, 0]

        # Local branch: the non-leading tokens are cut into consecutive
        # groups of patch_length tokens each.
        feature_length = features.size(1) - 1
        patch_length = feature_length // self.divide_length
        token = features[:, 0:1]

        if self.rearrange:
            patches = shuffle_unit(features, self.shift_num,
                                   self.shuffle_groups)
        else:
            patches = features[:, 1:]

        # NOTE(review): exactly four groups are consumed regardless of
        # divide_length, matching the four local heads.
        local_feats = []
        for idx in range(len(self._BRANCH_SUFFIXES) - 1):
            group = patches[:, idx * patch_length:(idx + 1) * patch_length]
            # The leading token is prepended to every group before the
            # shared b2 head; the group's feature is b2's first output token.
            group = self.b2(torch.cat((token, group), dim=1))
            local_feats.append(group[:, 0])

        raw_feats = [global_feat] + local_feats
        bn_feats = [
            getattr(self, 'bottleneck' + suffix)(feat)
            for suffix, feat in zip(self._BRANCH_SUFFIXES, raw_feats)
        ]

        if self.training:
            heads = [
                getattr(self, 'classifier' + suffix)
                for suffix in self._BRANCH_SUFFIXES
            ]
            if self.ID_LOSS_TYPE in self._MARGIN_LOSSES:
                scores = [
                    head(feat, label) for head, feat in zip(heads, bn_feats)
                ]
            else:
                scores = [head(feat) for head, feat in zip(heads, bn_feats)]
            return scores, raw_feats

        chosen = bn_feats if self.neck_feat == 'after' else raw_feats
        return torch.cat([chosen[0]] + [feat / 4 for feat in chosen[1:]],
                         dim=1)

    def load_param(self, trained_path):
        """Copy a saved parameter dict into this model.

        Every ``'module.'`` substring is removed from the saved keys before
        they are looked up in this model's state dict.

        Args:
            trained_path: Checkpoint location handed to ``torch.load``.

        Raises:
            KeyError: If a saved key has no counterpart in this model.
        """
        # Load onto the CPU so a checkpoint saved from a GPU also opens on
        # machines without one; copy_ moves the data to each parameter's
        # own device.
        param_dict = torch.load(trained_path, map_location='cpu')
        own_state = self.state_dict()  # built once, not once per key
        for key, value in param_dict.items():
            own_state[key.replace('module.', '')].copy_(value)
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        """Copy a saved parameter dict into this model, keys taken verbatim.

        Args:
            model_path: Checkpoint location handed to ``torch.load``.

        Raises:
            KeyError: If a saved key has no counterpart in this model.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        own_state = self.state_dict()  # built once, not once per key
        for key, value in param_dict.items():
            own_state[key].copy_(value)
        print('Loading pretrained model for finetuning from {}'.format(
            model_path))
Beispiel #4
0
    def __init__(self, num_classes, cfg):
        """Build the CNN backbone, pooling layer, classifier and bottleneck.

        Args:
            num_classes: Output width of the classifier head.
            cfg: Configuration object; its MODEL, TEST and SOLVER sections
                are read.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` names no supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        uses_efficientnet = 'efficientnet' in model_name

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif uses_efficientnet:
            self.base = EfficientNet.from_pretrained(model_name)
            print('using {} as a backbone'.format(model_name))
            # Head width is read off the network's own final FC layer.
            self.in_planes = self.base._fc.in_features
        else:
            # Fail here with a clear message; continuing would only crash
            # later on the missing self.base / self.in_planes attributes.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        # Per the original author's note, this load step must not be applied
        # to EfficientNet backbones (they are built via from_pretrained).
        if pretrain_choice == 'imagenet' and not uses_efficientnet:
            # Original author's note: wrapping self.base in nn.DataParallel
            # first is only needed when loading distributed weights.
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # Margin-based losses share one constructor signature: (in, out, s, m).
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_type = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_type = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_type = AMSoftmax
            else:
                head_type = CircleLoss
            self.classifier = head_type(self.in_planes,
                                        self.num_classes,
                                        s=scale,
                                        m=margin)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BatchNorm bottleneck whose bias is excluded from gradients.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
Beispiel #5
0
class build_transformer(nn.Module):
    """Transformer backbone followed by a BatchNorm bottleneck and one
    classifier head."""

    def __init__(self, num_classes, camera_num, view_num, cfg, factory):
        """Build the transformer backbone, classifier head and bottleneck.

        Args:
            num_classes: Output width of the classifier head.
            camera_num: Passed to the backbone as ``camera``; forced to 0
                when ``cfg.MODEL.SIE_CAMERA`` is falsy.
            view_num: Passed to the backbone as ``view``; forced to 0 when
                ``cfg.MODEL.SIE_VIEW`` is falsy.
            cfg: Configuration object; its MODEL, INPUT, TEST and SOLVER
                sections are read.
            factory: Mapping from ``cfg.MODEL.TRANSFORMER_TYPE`` to a
                callable that constructs the backbone.
        """
        super(build_transformer, self).__init__()
        model_path = cfg.MODEL.PRETRAIN_PATH
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.in_planes = 768

        transformer_type = cfg.MODEL.TRANSFORMER_TYPE
        print('using Transformer_type: {} as a backbone'.format(
            transformer_type))

        # The camera / view counts only reach the backbone when the matching
        # SIE switch is on; otherwise 0 is passed.
        if not cfg.MODEL.SIE_CAMERA:
            camera_num = 0
        if not cfg.MODEL.SIE_VIEW:
            view_num = 0

        self.base = factory[transformer_type](
            img_size=cfg.INPUT.SIZE_TRAIN,
            sie_xishu=cfg.MODEL.SIE_COE,
            camera=camera_num,
            view=view_num,
            stride_size=cfg.MODEL.STRIDE_SIZE,
            drop_path_rate=cfg.MODEL.DROP_PATH,
            drop_rate=cfg.MODEL.DROP_OUT,
            attn_drop_rate=cfg.MODEL.ATT_DROP_RATE)
        # This backbone type is paired with a 384-wide head instead of 768.
        if transformer_type == 'deit_small_patch16_224_TransReID':
            self.in_planes = 384
        if pretrain_choice == 'imagenet':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # Margin-based losses share one constructor signature: (in, out, s, m).
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_type = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_type = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_type = AMSoftmax
            else:
                head_type = CircleLoss
            self.classifier = head_type(self.in_planes,
                                        self.num_classes,
                                        s=scale,
                                        m=margin)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BatchNorm bottleneck whose bias is excluded from gradients.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

    def forward(self, x, label=None, cam_label=None, view_label=None):
        """Run the backbone and bottleneck; in training also the classifier.

        Args:
            x: Input batch handed to the backbone.
            label: Identity labels; only consumed by the margin-based heads
                (arcface / cosface / amsoftmax / circle) during training.
            cam_label: Forwarded to the backbone.
            view_label: Forwarded to the backbone.

        Returns:
            Training mode: ``(cls_score, global_feat)`` with ``global_feat``
            being the backbone output before the bottleneck (used for the
            triplet loss per the original note).
            Eval mode: the feature after the bottleneck when
            ``self.neck_feat == 'after'``, else the backbone output.
        """
        global_feat = self.base(x, cam_label=cam_label, view_label=view_label)
        feat = self.bottleneck(global_feat)

        if self.training:
            if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax',
                                     'circle'):
                cls_score = self.classifier(feat, label)
            else:
                cls_score = self.classifier(feat)
            return cls_score, global_feat

        if self.neck_feat == 'after':
            return feat
        return global_feat

    def load_param(self, trained_path):
        """Copy a saved parameter dict into this model.

        Every ``'module.'`` substring is removed from the saved keys before
        they are looked up in this model's state dict.

        Args:
            trained_path: Checkpoint location handed to ``torch.load``.

        Raises:
            KeyError: If a saved key has no counterpart in this model.
        """
        # Load onto the CPU so a checkpoint saved from a GPU also opens on
        # machines without one; copy_ moves the data to each parameter's
        # own device.
        param_dict = torch.load(trained_path, map_location='cpu')
        own_state = self.state_dict()  # built once, not once per key
        for key, value in param_dict.items():
            own_state[key.replace('module.', '')].copy_(value)
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        """Copy a saved parameter dict into this model, keys taken verbatim.

        Args:
            model_path: Checkpoint location handed to ``torch.load``.

        Raises:
            KeyError: If a saved key has no counterpart in this model.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        own_state = self.state_dict()  # built once, not once per key
        for key, value in param_dict.items():
            own_state[key].copy_(value)
        print('Loading pretrained model for finetuning from {}'.format(
            model_path))
    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling layer, classifier and bottleneck.

        For the ``'HRnet'`` backbone an extra Conv1d and four 1x1 Conv2d
        projection modules are created as well.

        Args:
            num_classes: Output width of the classifier head.
            cfg: Configuration object; its MODEL, TEST and SOLVER sections
                are read (and the whole object is passed to ``get_cls_net``
                for HRnet).

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` names no supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.model_name = model_name

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif model_name == 'efficientnet_b4':
            # NOTE(review): the config name says b4 but the weights loaded
            # (and the message printed) are efficientnet-b2 -- confirm that
            # this mismatch is intentional before changing it.
            print('using efficientnet_b2 as a backbone')
            self.base = EfficientNet.from_pretrained('efficientnet-b2',
                                                     advprop=False)
            self.in_planes = self.base._fc.in_features
        elif model_name == 'efficientnet_b7':
            print('using efficientnet_b7 as a backbone')
            self.base = EfficientNet.from_pretrained('efficientnet-b7',
                                                     advprop=True)
            self.in_planes = self.base._fc.in_features
        elif model_name == 'resnext101_ibn_a':
            self.in_planes = 2048
            self.base = resnext101_ibn_a(last_stride)
            print('using resnext101_ibn_a as a backbone')
        elif model_name == 'HRnet':
            self.in_planes = 2048
            self.base = get_cls_net(cfg, pretrained=model_path)
        else:
            # Fail here with a clear message; continuing would only crash
            # later on the missing self.base / self.in_planes attributes.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        # These backbones are excluded from load_param: the EfficientNet ones
        # are built via from_pretrained and HRnet already received model_path
        # through its ``pretrained`` argument.
        self_loading = ('efficientnet_b4', 'efficientnet_b7', 'HRnet')
        if pretrain_choice == 'imagenet' and model_name not in self_loading:
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # Margin-based losses share one constructor signature: (in, out, s, m).
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_type = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_type = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_type = AMSoftmax
            else:
                head_type = CircleLoss
            self.classifier = head_type(self.in_planes,
                                        self.num_classes,
                                        s=scale,
                                        m=margin)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BatchNorm bottleneck whose bias is excluded from gradients.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

        if self.model_name == 'HRnet':
            self.attention_tconv = nn.Conv1d(self.in_planes, 1, 3, padding=1)
            # upsample0..3: bias-free 1x1 convolutions projecting 32 / 64 /
            # 128 / 256 input channels to in_planes channels.
            for index, channels in enumerate((32, 64, 128, 256)):
                projection = nn.Sequential(
                    nn.Conv2d(channels,
                              self.in_planes,
                              kernel_size=1,
                              stride=1,
                              bias=False))
                setattr(self, 'upsample{}'.format(index), projection)
Beispiel #7
0
class Backbone(nn.Module):
    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling layer, ID classifier and BN neck.

        Args:
            num_classes: Output size of the identity classifier.
            cfg: Config object. Reads ``cfg.MODEL.*`` (backbone name, stride,
                pretrain path/choice, neck, pooling, ID loss type),
                ``cfg.SOLVER.COSINE_SCALE`` / ``COSINE_MARGIN`` and
                ``cfg.TEST.NECK_FEAT``.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` names no supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        is_efficientnet = False
        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif 'efficientnet' in model_name:
            is_efficientnet = True
            self.base = EfficientNet.from_pretrained(model_name)
            print('using {} as a backbone'.format(model_name))
            # Feature width is taken from the input size of EfficientNet's
            # own final fully connected layer.
            self.in_planes = self.base._fc.in_features
        else:
            # Fail here with a clear message; previously this branch only
            # printed and the constructor crashed later on a missing
            # ``self.base`` / ``self.in_planes``.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        # The original author's note says this loading step must not run for
        # EfficientNet (it is built with ``from_pretrained`` above), so it is
        # skipped for that family.
        if pretrain_choice == 'imagenet' and not is_efficientnet:
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # All margin-based heads share one constructor signature.
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_cls = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_cls = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_cls = AMSoftmax
            else:
                head_cls = CircleLoss
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=scale,
                                       m=margin)
        else:
            # Plain linear (softmax) head for any other loss type.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BN neck: the bias is excluded from gradient updates.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

    def forward(self,
                x,
                label=None):  # label is unused if self.cos_layer == 'no'
        x = self.base(x)
        global_feat = self.gap(x)
        global_feat = global_feat.view(global_feat.shape[0],
                                       -1)  # flatten to (bs, 2048)
        feat = self.bottleneck(global_feat)

        if self.neck == 'no':
            feat = global_feat
        elif self.neck == 'bnneck':
            feat = self.bottleneck(global_feat)

        if self.training:
            if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax',
                                     'circle'):
                cls_score = self.classifier(feat, label)
            else:
                cls_score = self.classifier(feat)

            return cls_score, global_feat
        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return feat
            else:
                # print("Test with feature before BN")
                return global_feat

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        # if isinstance(self,torch.nn.DataParallel):
        #     self = self.module
        # if isinstance(param_dict,torch.nn.DataParallel):
        #     param_dict = param_dict.module
        # self = nn.DataParallel(self)
        for i in param_dict:
            if 'classifier' in i or 'arcface' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])
        # self.load_state_dict(param_dict)
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)

        # self = nn.DataParallel(self)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        print('Loading pretrained model for finetuning from {}'.format(
            model_path))
    def __init__(self, num_classes, cfg):
        """Build a ResNet-family backbone, pooling, ID classifier and BN neck.

        Args:
            num_classes: Output size of the identity classifier.
            cfg: Config object. Reads ``cfg.MODEL.*`` (backbone name, stride,
                pretrain path/choice, neck, pooling, ID loss type),
                ``cfg.SOLVER.COSINE_SCALE`` / ``COSINE_MARGIN`` and
                ``cfg.TEST.NECK_FEAT``.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` names no supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        # Author's note: resnet_ibn differs from the regular resnet in that
        # it uses InstanceNorm inside the Bottleneck.
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        # Author's note: se_resnet_ibn additionally puts an SE layer in the
        # Bottleneck to apply attention over channels.
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        # Author's note on resnet_ibn_b vs resnet_ibn_a:
        # - resnet_ibn_a splits the channels in two before the Bottleneck
        #   output, applies InstanceNorm to one part and BatchNorm to the
        #   other, then concatenates them.
        # - resnet_ibn_b replaces the BatchNorm after conv1 with InstanceNorm
        #   and applies BatchNorm followed by InstanceNorm before the
        #   Bottleneck output.
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        else:
            # Fail here with a clear message; previously this branch only
            # printed and the constructor crashed later on a missing
            # ``self.base`` / ``self.in_planes``.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        if pretrain_choice == 'imagenet':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # All margin-based heads share one constructor signature.
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_cls = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_cls = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_cls = AMSoftmax
            else:
                head_cls = CircleLoss
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=scale,
                                       m=margin)
        else:
            # Plain linear (softmax) head for any other loss type.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BN neck: the bias is excluded from gradient updates.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
class Backbone(nn.Module):
    def __init__(self, num_classes, cfg):
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.model_name = model_name
        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck, frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride, frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif model_name == 'efficientnet_b4':
            print('using efficientnet_b2 as a backbone') 
            self.base = EfficientNet.from_pretrained('efficientnet-b2', advprop=False) 
            self.in_planes = self.base._fc.in_features
        elif model_name == 'efficientnet_b7':
            print('using efficientnet_b7 as a backbone')
            self.base = EfficientNet.from_pretrained('efficientnet-b7', advprop=True)
            self.in_planes = self.base._fc.in_features
        elif model_name == 'resnext101_ibn_a':
            self.in_planes = 2048
            self.base = resnext101_ibn_a(last_stride)
            print('using resnext101_ibn_a as a backbone')
            
        elif model_name == 'HRnet':
            self.in_planes = 2048
            self.base = get_cls_net(cfg, pretrained = model_path)  
            
        else:
            print('unsupported backbone! but got {}'.format(model_name))

        if pretrain_choice == 'imagenet' and model_name != 'efficientnet_b4' and model_name != 'efficientnet_b7' and  model_name != 'HRnet':
       
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE == 'arcface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Arcface(self.in_planes, self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'cosface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Cosface(self.in_planes, self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN))
            self.classifier = AMSoftmax(self.in_planes, self.num_classes,
                                        s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'circle':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN))
            self.classifier = CircleLoss(self.in_planes, self.num_classes,
                                        s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN)
        else:
            self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
            self.classifier.apply(weights_init_classifier)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        if self.model_name == 'HRnet':
            self.attention_tconv = nn.Conv1d(self.in_planes, 1, 3, padding=1)                        
            self.upsample0 = nn.Sequential(
                                    nn.Conv2d(32, self.in_planes, kernel_size=1, stride=1, bias=False),
                                        )
            self.upsample1 = nn.Sequential(
                                    nn.Conv2d(64, self.in_planes, kernel_size=1, stride=1, bias=False),
                                        )
            self.upsample2 = nn.Sequential(
                                    nn.Conv2d(128, self.in_planes, kernel_size=1, stride=1, bias=False),
                                        )
            self.upsample3 = nn.Sequential(
                                    nn.Conv2d(256, self.in_planes, kernel_size=1, stride=1, bias=False),
                                        )
            
    def forward(self, x, label=None):  # label is unused if self.cos_layer == 'no'
        if self.model_name == 'HRnet':
            y_list = self.base(x)
            
            global_feat0 = self.gap(self.upsample0(y_list[0])) 
            global_feat1 = self.gap(self.upsample1(y_list[1])) 
            global_feat2 = self.gap(self.upsample2(y_list[2])) 
            global_feat3 = self.gap(self.upsample3(y_list[3])) 
            weight_ori = torch.cat([global_feat0, global_feat1, global_feat2, global_feat3], dim=2)
            weight_ori = weight_ori.view(weight_ori.shape[0], weight_ori.shape[1], -1)
            attention_feat = F.relu(self.attention_tconv(weight_ori))
            attention_feat = torch.squeeze(attention_feat)
            weight = F.sigmoid(attention_feat)
            weight = F.normalize(weight, p=1, dim=1)

            weight = torch.unsqueeze(weight, 1)
            weight = weight.expand_as(weight_ori)
            global_feat = torch.mul(weight_ori, weight)
            global_feat = global_feat.sum(-1)
            global_feat = global_feat.view(global_feat.shape[0], -1) #flatten to (bs, 2048)
            feat = self.bottleneck(global_feat)
            if self.training:
                if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
                    cls_score = self.classifier(feat, label)
                else:
                    cls_score = self.classifier(feat)
                return cls_score, global_feat  # global feature for triplet loss
            else:
                return feat
        else:
            if self.model_name =='efficientnet_b4' or self.model_name =='efficientnet_b7':
                x = self.base.extract_features(x)
            else:
                x = self.base(x)
            global_feat = self.gap(x)
            global_feat = global_feat.view(global_feat.shape[0], -1)  # flatten to (bs, 2048)

            if self.neck == 'no':
                feat = global_feat
            elif self.neck == 'bnneck':
                feat = self.bottleneck(global_feat)

            if self.training:
                if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
                    #print(feat,label)
                    cls_score = self.classifier(feat, label)
                else:
                    cls_score = self.classifier(feat)

                return cls_score, global_feat
            else:
                if self.neck_feat == 'after':
                    # print("Test with feature after BN")
                    return feat
                else:
                    # print("Test with feature before BN")
                    return global_feat

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        '''
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in param_dict.items():
            name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
            new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
        '''
        for i in param_dict:
            if 'classifier' in i or 'arcface' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])
            #self.state_dict()[i].copy_(new_state_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        print('Loading pretrained model for finetuning from {}'.format(model_path))
class Backbone(nn.Module):
    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling, ID classifier and BN neck, then load weights.

        Args:
            num_classes: Output size of the identity classifier.
            cfg: Config object. Reads ``cfg.MODEL.*`` (backbone name, stride,
                pretrain path/choice, neck, pooling, ID loss type),
                ``cfg.SOLVER.COSINE_SCALE`` / ``COSINE_MARGIN`` and
                ``cfg.TEST.NECK_FEAT``.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` names no supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.model_name = model_name

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet152':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               layers=[3, 8, 36, 3])
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride,
                                           frozen_stages=cfg.MODEL.FROZEN)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'efficientnet_b7':
            print('using efficientnet_b7 as a backbone')
            self.base = EfficientNet.from_pretrained('efficientnet-b7',
                                                     advprop=False)
            self.in_planes = self.base._fc.in_features
        elif model_name == 'densenet169_ibn_a':
            self.in_planes = 1664
            self.base = densenet169_ibn_a()
            print('using densenet169_ibn_a as a backbone')
        elif model_name == 'resnest50':
            self.in_planes = 2048
            self.base = resnest50(last_stride)
            print('using resnest50 as a backbone')
        elif model_name == 'resnest101':
            self.in_planes = 2048
            self.base = resnest101(last_stride)
            print('using resnest101 as a backbone')
        elif model_name == 'resnest200':
            self.in_planes = 2048
            self.base = resnest200(last_stride)
            print('using resnest200 as a backbone')
        elif model_name == 'resnest269':
            self.in_planes = 2048
            self.base = resnest269(last_stride)
            print('using resnest269 as a backbone')
        elif model_name == 'resnext101_ibn_a':
            self.in_planes = 2048
            self.base = resnext101_ibn_a()
            print('using resnext101_ibn_a as a backbone')
        else:
            # Fail here with a clear message; previously this branch only
            # printed and the constructor crashed later on a missing
            # ``self.in_planes`` / ``self.base``.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        if cfg.MODEL.POOLING_METHOD == 'gempoolP':
            print('using GeMP pooling')
            self.gap = GeneralizedMeanPoolingP()
        elif cfg.MODEL.POOLING_METHOD == 'gempool':
            print('using GeM pooling')
            self.gap = GeM(freeze_p=False)
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'):
            # All margin-based heads share one constructor signature.
            scale = cfg.SOLVER.COSINE_SCALE
            margin = cfg.SOLVER.COSINE_MARGIN
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     scale, margin))
            if self.ID_LOSS_TYPE == 'arcface':
                head_cls = Arcface
            elif self.ID_LOSS_TYPE == 'cosface':
                head_cls = Cosface
            elif self.ID_LOSS_TYPE == 'amsoftmax':
                head_cls = AMSoftmax
            else:
                head_cls = CircleLoss
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=scale,
                                       m=margin)
        else:
            # Plain linear (softmax) head for any other loss type.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        # BN neck: the bias is excluded from gradient updates.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

        # Weight loading happens last so that the 'self' choice can also
        # fill the neck built above. EfficientNet is excluded from the
        # 'imagenet' path because it is built with from_pretrained.
        if pretrain_choice == 'imagenet' and model_name != 'efficientnet_b7':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))
        elif pretrain_choice == 'self':
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            param_dict = torch.load(model_path, map_location='cpu')
            # Build the state dict once instead of once per key; its tensors
            # alias the live parameters, so copy_ writes through.
            own_state = self.state_dict()
            for key in param_dict:
                if 'classifier' in key:
                    continue
                own_state[key].copy_(param_dict[key])
            print('Loading finetune model......from {}'.format(model_path))

    def forward(self,
                x,
                label=None,
                cam_label=None):  # label is unused if self.cos_layer == 'no'
        if self.model_name == 'efficientnet_b7':
            x = self.base.extract_features(x)
        else:
            x = self.base(x)
        global_feat = nn.functional.avg_pool2d(x, x.shape[2:4])
        global_feat = global_feat.view(global_feat.shape[0],
                                       -1)  # flatten to (bs, 2048)

        if self.neck == 'no':
            feat = global_feat
        elif self.neck == 'bnneck':
            feat = self.bottleneck(global_feat)

        if self.training:
            if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax',
                                     'circle'):
                cls_score = self.classifier(feat, label)
            else:
                cls_score = self.classifier(feat)

            return cls_score, global_feat  # global feature for triplet loss
        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return feat
            else:
                # print("Test with feature before BN")
                return global_feat

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path, map_location='cpu')
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']
        for i in param_dict:
            if 'classifier' in i or 'arcface' in i:
                continue
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))
    def __init__(self, num_classes, cfg):
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        self.model_name = model_name

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet152':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               layers=[3, 8, 36, 3])
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride,
                                           frozen_stages=cfg.MODEL.FROZEN)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'efficientnet_b7':
            print('using efficientnet_b7 as a backbone')
            self.base = EfficientNet.from_pretrained('efficientnet-b7',
                                                     advprop=False)
            self.in_planes = self.base._fc.in_features
        elif model_name == 'densenet169_ibn_a':
            self.in_planes = 1664
            self.base = densenet169_ibn_a()
            print('using densenet169_ibn_a as a backbone')
        elif model_name == 'resnest50':
            self.in_planes = 2048
            self.base = resnest50(last_stride)
            print('using resnest50 as a backbone')
        elif model_name == 'resnest101':
            self.in_planes = 2048
            self.base = resnest101(last_stride)
            print('using resnest101 as a backbone')
        elif model_name == 'resnest200':
            self.in_planes = 2048
            self.base = resnest200(last_stride)
            print('using resnest200 as a backbone')
        elif model_name == 'resnest269':
            self.in_planes = 2048
            self.base = resnest269(last_stride)
            print('using resnest269 as a backbone')
        elif model_name == 'resnext101_ibn_a':
            self.in_planes = 2048
            self.base = resnext101_ibn_a()
            print('using resnext101_ibn_a as a backbone')
        else:
            print('unsupported backbone! but got {}'.format(model_name))

        if cfg.MODEL.POOLING_METHOD == 'gempoolP':
            print('using GeMP pooling')
            self.gap = GeneralizedMeanPoolingP()
        elif cfg.MODEL.POOLING_METHOD == 'gempool':
            print('using GeM pooling')
            self.gap = GeM(freeze_p=False)
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        if self.ID_LOSS_TYPE == 'arcface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Arcface(self.in_planes,
                                      self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE,
                                      m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'cosface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Cosface(self.in_planes,
                                      self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE,
                                      m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = AMSoftmax(self.in_planes,
                                        self.num_classes,
                                        s=cfg.SOLVER.COSINE_SCALE,
                                        m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'circle':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = CircleLoss(self.in_planes,
                                         self.num_classes,
                                         s=cfg.SOLVER.COSINE_SCALE,
                                         m=cfg.SOLVER.COSINE_MARGIN)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

        if pretrain_choice == 'imagenet' and model_name != 'efficientnet_b7':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))
        elif pretrain_choice == 'self':
            param_dict = torch.load(model_path, map_location='cpu')
            for i in param_dict:
                if 'classifier' in i:
                    continue
                self.state_dict()[i].copy_(param_dict[i])
            print('Loading finetune model......from {}'.format(model_path))
    def __init__(self, num_classes, camera_num, view_num, cfg, factory):
        """Assemble the transformer backbone, pooling, ID classifier and BN neck.

        Args:
            num_classes: Number of identity classes for the classifier head.
            camera_num: Camera count handed to the backbone; replaced by 0
                unless ``cfg.MODEL.CAMERA_EMBEDDING`` is truthy.
            view_num: Viewpoint count handed to the backbone; replaced by 0
                unless ``cfg.MODEL.VIEWPOINT_EMBEDDING`` is truthy.
            cfg: Configuration object; the ``MODEL``, ``INPUT``, ``SOLVER``
                and ``TEST`` entries read below must be present.
            factory: Mapping from ``cfg.MODEL.Transformer_TYPE`` to a callable
                that constructs the backbone.
        """
        super(build_transformer, self).__init__()
        model_path = cfg.MODEL.PRETRAIN_PATH
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        print('using Transformer_type: {} as a backbone'.format(
            cfg.MODEL.Transformer_TYPE))

        # A count of 0 is passed to the backbone when the corresponding
        # embedding is disabled in the config.
        if not cfg.MODEL.CAMERA_EMBEDDING:
            camera_num = 0
        if not cfg.MODEL.VIEWPOINT_EMBEDDING:
            view_num = 0

        self.base = factory[cfg.MODEL.Transformer_TYPE](
            img_size=cfg.INPUT.SIZE_TRAIN,
            aie_xishu=cfg.MODEL.AIE_COE,
            local_feature=cfg.MODEL.LOCAL_F,
            camera=camera_num,
            view=view_num,
            stride_size=cfg.MODEL.STRIDE_SIZE,
            drop_path_rate=cfg.MODEL.DROP_PATH)

        if pretrain_choice == 'imagenet':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.in_planes = self.base.embed_dim
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        # Pick the margin-based head class; the names are resolved only in
        # the branch that is actually taken, exactly like the original chain.
        if self.ID_LOSS_TYPE == 'arcface':
            head_cls = Arcface
        elif self.ID_LOSS_TYPE == 'cosface':
            head_cls = Cosface
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            head_cls = AMSoftmax
        elif self.ID_LOSS_TYPE == 'circle':
            head_cls = CircleLoss
        else:
            head_cls = None

        if head_cls is None:
            # Any other loss type falls back to a bias-free linear classifier.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)
        else:
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=cfg.SOLVER.COSINE_SCALE,
                                       m=cfg.SOLVER.COSINE_MARGIN)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

        if pretrain_choice == 'self':
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            param_dict = torch.load(model_path, map_location='cpu')
            # state_dict() is built once: its tensors share storage with the
            # module parameters, so copy_ below writes straight into them.
            own_state = self.state_dict()
            for key, value in param_dict.items():
                # Classifier weights are skipped and keep their fresh init.
                if 'classifier' in key:
                    continue
                own_state[key].copy_(value)
            print('Loading finetune model......from {}'.format(model_path))
Beispiel #13
0
    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling layer, ID classifier and BN neck.

        Args:
            num_classes: Number of identity classes for the classifier head.
            cfg: Configuration object; the ``MODEL``, ``SOLVER`` and ``TEST``
                entries read below must be present.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` is not one of the supported
                backbones. The original code only printed a message and went
                on without ``self.base``.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        self.model_name = model_name
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        # NOTE(review): this config key is spelled with a digit zero
        # ("LAYERN0RM"); it is kept as-is because it has to match the key
        # defined in the project's config.
        self.layernorm = cfg.MODEL.LAYERN0RM
        if self.layernorm:
            self.ln = nn.LayerNorm([cfg.MODEL.FEAT_SIZE])

        # efficientnet_b0 .. efficientnet_b8 are all built the same way.
        efficientnet_names = tuple(
            'efficientnet_b{}'.format(i) for i in range(9))

        if model_name == 'resnet50':
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif model_name in efficientnet_names:
            # 'efficientnet_b3' -> 'efficientnet-b3': the same identifier the
            # per-variant branches used to pass to from_pretrained.
            self.base = EfficientNet.from_pretrained(
                model_name.replace('_', '-'))
        elif model_name == 'densenet121':  # feat size 1024
            self.base = densenet121(pretrained=True)
        elif model_name == 'densenet161':
            self.base = densenet161(pretrained=True)
        elif model_name == 'densenet169':
            self.base = densenet169(pretrained=True)
        elif model_name == 'nasnet':
            self.base = ft_net_NAS()  # feat size 4032
        elif model_name == 'osnet':
            self.base = osnet.build_osnet_backbone(model_path)
        elif model_name == 'sknet':
            self.base = SKNet101()
        elif model_name == 'cbam':
            self.base = resnet101_cbam()
        elif model_name == 'non_local':
            self.base = Non_Local_101(last_stride)
        elif model_name == 'inceptionv4':  # feat size 1536
            self.base = inception()
        elif model_name == 'mgn':
            self.base = MGN_res50()
        else:
            # Fail here instead of leaving the module without ``self.base``.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))
        # The feature width is taken from the config, not from the backbone.
        self.in_planes = cfg.MODEL.FEAT_SIZE

        # Only the resnet101 variants load a local ImageNet checkpoint here.
        if (pretrain_choice == 'imagenet' and model_path
                and 'resnet101' in model_name):
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        # Pick the margin-based head class; names are resolved only in the
        # branch that is taken, as in the original if/elif chain.
        if self.ID_LOSS_TYPE == 'arcface':
            head_cls = Arcface
        elif self.ID_LOSS_TYPE == 'curricularface':
            head_cls = CurricularFace
        elif self.ID_LOSS_TYPE == 'cosface':
            head_cls = Cosface
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            head_cls = AMSoftmax
        elif self.ID_LOSS_TYPE == 'circle':
            head_cls = CircleLoss
        else:
            head_cls = None

        if head_cls is None:
            # Any other loss type falls back to a bias-free linear classifier.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)
        else:
            # The curricularface banner has always had an extra "loss" word.
            if self.ID_LOSS_TYPE == 'curricularface':
                banner = 'using {} loss with s:{}, m: {}'
            else:
                banner = 'using {} with s:{}, m: {}'
            print(banner.format(self.ID_LOSS_TYPE,
                                cfg.SOLVER.COSINE_SCALE,
                                cfg.SOLVER.COSINE_MARGIN))
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=cfg.SOLVER.COSINE_SCALE,
                                       m=cfg.SOLVER.COSINE_MARGIN)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
Beispiel #14
0
class Backbone(nn.Module):
    """Re-ID network: backbone -> pooling -> optional BN neck -> ID classifier.

    The backbone, pooling layer and classifier head are selected from ``cfg``.
    In training mode ``forward`` returns ``(cls_score, global_feat)``; in eval
    mode it returns a single feature tensor.
    """

    # Model names that are built through ``EfficientNet.from_pretrained``.
    _EFFICIENTNET_NAMES = tuple(
        'efficientnet_b{}'.format(i) for i in range(9))
    # ID losses whose classifier head is called with the labels as well.
    _MARGIN_LOSSES = ('arcface', 'cosface', 'amsoftmax', 'circle',
                      'curricularface')

    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling layer, ID classifier and BN neck.

        Args:
            num_classes: Number of identity classes for the classifier head.
            cfg: Configuration object; the ``MODEL``, ``SOLVER`` and ``TEST``
                entries read below must be present.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` is not a supported backbone.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        self.model_name = model_name
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        # NOTE(review): this config key is spelled with a digit zero
        # ("LAYERN0RM"); it is kept as-is because it has to match the key
        # defined in the project's config.
        self.layernorm = cfg.MODEL.LAYERN0RM
        if self.layernorm:
            self.ln = nn.LayerNorm([cfg.MODEL.FEAT_SIZE])

        self.base = self._build_base(cfg, model_name, last_stride)
        # The feature width is taken from the config, not from the backbone.
        self.in_planes = cfg.MODEL.FEAT_SIZE

        # Only the resnet101 variants load a local ImageNet checkpoint here.
        if (pretrain_choice == 'imagenet' and model_path
                and 'resnet101' in model_name):
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)

        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        self.classifier = self._build_classifier(cfg)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

    def _build_base(self, cfg, model_name, last_stride):
        """Return the feature extractor selected by ``model_name``."""
        if model_name == 'resnet50':
            base = ResNet(last_stride=last_stride,
                          block=Bottleneck,
                          frozen_stages=cfg.MODEL.FROZEN,
                          layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            base = resnet101_ibn_a(last_stride,
                                   frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif model_name in self._EFFICIENTNET_NAMES:
            # 'efficientnet_b3' -> 'efficientnet-b3': the same identifier the
            # per-variant branches used to pass to from_pretrained.
            base = EfficientNet.from_pretrained(model_name.replace('_', '-'))
        elif model_name == 'densenet121':  # feat size 1024
            base = densenet121(pretrained=True)
        elif model_name == 'densenet161':
            base = densenet161(pretrained=True)
        elif model_name == 'densenet169':
            base = densenet169(pretrained=True)
        elif model_name == 'nasnet':
            base = ft_net_NAS()  # feat size 4032
        elif model_name == 'osnet':
            base = osnet.build_osnet_backbone(cfg.MODEL.PRETRAIN_PATH)
        elif model_name == 'sknet':
            base = SKNet101()
        elif model_name == 'cbam':
            base = resnet101_cbam()
        elif model_name == 'non_local':
            base = Non_Local_101(last_stride)
        elif model_name == 'inceptionv4':  # feat size 1536
            base = inception()
        elif model_name == 'mgn':
            base = MGN_res50()
        else:
            # Fail here instead of leaving the module without ``self.base``.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))
        return base

    def _build_classifier(self, cfg):
        """Return the ID classifier head selected by ``self.ID_LOSS_TYPE``."""
        # Names are resolved only in the branch that is taken.
        if self.ID_LOSS_TYPE == 'arcface':
            head_cls = Arcface
        elif self.ID_LOSS_TYPE == 'curricularface':
            head_cls = CurricularFace
        elif self.ID_LOSS_TYPE == 'cosface':
            head_cls = Cosface
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            head_cls = AMSoftmax
        elif self.ID_LOSS_TYPE == 'circle':
            head_cls = CircleLoss
        else:
            head_cls = None

        if head_cls is None:
            # Any other loss type falls back to a bias-free linear classifier.
            classifier = nn.Linear(self.in_planes,
                                   self.num_classes,
                                   bias=False)
            classifier.apply(weights_init_classifier)
            return classifier

        # The curricularface banner has always had an extra "loss" word.
        if self.ID_LOSS_TYPE == 'curricularface':
            banner = 'using {} loss with s:{}, m: {}'
        else:
            banner = 'using {} with s:{}, m: {}'
        print(banner.format(self.ID_LOSS_TYPE, cfg.SOLVER.COSINE_SCALE,
                            cfg.SOLVER.COSINE_MARGIN))
        return head_cls(self.in_planes,
                        self.num_classes,
                        s=cfg.SOLVER.COSINE_SCALE,
                        m=cfg.SOLVER.COSINE_MARGIN)

    def forward(self,
                x,
                label=None):  # label is only used by the margin-based heads
        """Run the model on a batch.

        Args:
            x: Input batch, passed to the backbone unchanged.
            label: Labels for the batch; forwarded to the classifier only for
                the margin-based ID losses during training.

        Returns:
            Training mode: ``(cls_score, global_feat)``, with ``global_feat``
            passed through ``self.ln`` when layer norm is enabled.
            Eval mode: the post-neck feature when ``self.neck_feat`` is
            ``'after'``, otherwise the pooled feature, again passed through
            ``self.ln`` when layer norm is enabled.

        Raises:
            NotImplementedError: For the ``'mgn'`` backbone, whose forward
                pass was never implemented (it used to return ``None``).
        """
        if self.model_name == 'mgn':
            # todo MGN
            raise NotImplementedError(
                'forward is not implemented for the mgn backbone')

        x = self.base(x)
        global_feat = self.gap(x)
        # Flatten the pooled map to 2-D: (batch, features).
        global_feat = global_feat.view(global_feat.shape[0], -1)

        # Apply the BN neck exactly once per call. The original ran it twice
        # for 'bnneck', which updated the running statistics twice per step.
        if self.neck == 'no':
            feat = global_feat
        else:
            feat = self.bottleneck(global_feat)

        if self.training:
            # Errors are no longer swallowed: a failing classifier raises
            # instead of making forward silently return None.
            if self.ID_LOSS_TYPE in self._MARGIN_LOSSES:
                cls_score = self.classifier(feat, label)
            else:
                cls_score = self.classifier(feat)
            if self.layernorm:
                return cls_score, self.ln(global_feat)
            return cls_score, global_feat

        # Eval: choose the feature taken after or before the BN neck.
        out = feat if self.neck_feat == 'after' else global_feat
        if self.layernorm:
            # The 'after' branch used to drop this return value.
            return self.ln(out)
        return out

    def _copy_params(self, checkpoint_path):
        """Copy every non-classifier tensor of a checkpoint into this model."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Loading onto the CPU avoids depending on the device the checkpoint
        # was saved from; copy_ moves the data to the parameter's device.
        param_dict = torch.load(checkpoint_path, map_location='cpu')
        # Built once: these tensors share storage with the module parameters.
        own_state = self.state_dict()
        for key, value in param_dict.items():
            if 'classifier' in key or 'arcface' in key:
                continue
            # Strip the 'module.' prefix found in the checkpoint keys.
            own_state[key.replace('module.', '')].copy_(value)

    def load_param(self, trained_path):
        """Load trained weights, skipping classifier / arcface entries."""
        self._copy_params(trained_path)
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        """Load weights for fine-tuning, skipping classifier / arcface entries."""
        self._copy_params(model_path)
        print('Loading pretrained model for finetuning from {}'.format(
            model_path))
Beispiel #15
0
    def __init__(self, num_classes, cfg):
        """Build the backbone, pooling, optional PCB head, classifier and neck.

        Args:
            num_classes: Number of identity classes for the classifier head.
            cfg: Configuration object; the ``MODEL``, ``SOLVER`` and ``TEST``
                entries read below must be present.

        Raises:
            ValueError: If ``cfg.MODEL.NAME`` is not a supported backbone.
                The original code printed a message and then failed later
                because ``self.base`` / ``self.in_planes`` were never set.
        """
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        self.cfg = cfg
        self.model_name = model_name
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        if model_name == 'resnet50':
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
        elif model_name == 'resnet50_ibn_a':
            self.base = resnet50_ibn_a(last_stride)
        elif model_name == 'resnet101_ibn_a':
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
        elif model_name == 'se_resnet101_ibn_a':
            self.base = se_resnet101_ibn_a(last_stride)
        elif model_name == 'se_resnet50_ibn_a':
            self.base = se_resnet50_ibn_a(last_stride)
        elif model_name == 'resnet101_ibn_b':
            self.base = resnet101_ibn_b(last_stride)
        elif model_name == 'resnet50_ibn_b':
            self.base = resnet50_ibn_b(last_stride)
        elif model_name == 'resnest50':
            self.base = resnest50(last_stride)
        elif model_name == 'resnest50_ibn':
            self.base = resnest50_ibn(last_stride)
        elif model_name == 'resnest101':
            self.base = resnest101(last_stride)
        elif model_name == 'resnest101_ibn':
            self.base = resnest101_ibn(last_stride)
        elif model_name == 'efficientnet_b7':
            # NOTE(review): the 'efficientnet_b7' name loads the b0 weights.
            # The log line below also says b0, so this looks deliberate and
            # is left unchanged -- confirm before relying on it.
            self.base = EfficientNet.from_pretrained('efficientnet-b0')
        else:
            # Fail here instead of crashing later on a missing attribute.
            raise ValueError(
                'unsupported backbone! but got {}'.format(model_name))

        if model_name == 'efficientnet_b7':
            self.in_planes = self.base._fc.in_features
            print('using efficientnet-b0 as a backbone')
        else:
            # Every other supported backbone uses a feature width of 2048.
            self.in_planes = 2048
            # Logging the configured name also fixes the se_resnet50_ibn_a
            # branch, which used to announce 'se_resnet101_ibn_a'.
            print('using {} as a backbone'.format(model_name))

        if pretrain_choice == 'imagenet' and model_name != 'efficientnet_b7':
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)
        if cfg.MODEL.IF_USE_PCB:
            self.pcb = PCB(cfg,
                           256,
                           num_classes,
                           0.5,
                           self.in_planes,
                           cut_at_pooling=False)
        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE

        # Pick the margin-based head class; names are resolved only in the
        # branch that is taken, as in the original if/elif chain.
        if self.ID_LOSS_TYPE == 'arcface':
            head_cls = Arcface
        elif self.ID_LOSS_TYPE == 'cosface':
            head_cls = Cosface
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            head_cls = AMSoftmax
        elif self.ID_LOSS_TYPE == 'circle':
            head_cls = CircleLoss
        else:
            head_cls = None

        if head_cls is None:
            # Any other loss type falls back to a bias-free linear classifier.
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)
        else:
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = head_cls(self.in_planes,
                                       self.num_classes,
                                       s=cfg.SOLVER.COSINE_SCALE,
                                       m=cfg.SOLVER.COSINE_MARGIN)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
Beispiel #16
0
class Backbone(nn.Module):
    def __init__(self, num_classes, cfg):
        super(Backbone, self).__init__()
        last_stride = cfg.MODEL.LAST_STRIDE
        model_path = cfg.MODEL.PRETRAIN_PATH
        model_name = cfg.MODEL.NAME
        self.cfg = cfg
        self.model_name = model_name
        pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        # self.in_planes = 1280
        # model_weight_b0 = EfficientNet.from_pretrained('efficientnet-b0')
        # model_weight_b0.to('cuda')
        # mm = nn.Sequential(*model_weight_b0.named_children())
        # self.base = model_weight_b0.extract_features
        #
        # from IPython import embed
        # embed()
        #print('using efficientnet-b0 as a backbone')

        if model_name == 'resnet50':
            self.in_planes = 2048
            self.base = ResNet(last_stride=last_stride,
                               block=Bottleneck,
                               frozen_stages=cfg.MODEL.FROZEN,
                               layers=[3, 4, 6, 3])
            print('using resnet50 as a backbone')
        elif model_name == 'resnet50_ibn_a':
            self.in_planes = 2048
            self.base = resnet50_ibn_a(last_stride)
            print('using resnet50_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_a':
            self.in_planes = 2048
            self.base = resnet101_ibn_a(last_stride,
                                        frozen_stages=cfg.MODEL.FROZEN)
            print('using resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet101_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet101_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'se_resnet50_ibn_a':
            self.in_planes = 2048
            self.base = se_resnet50_ibn_a(last_stride)
            print('using se_resnet101_ibn_a as a backbone')
        elif model_name == 'resnet101_ibn_b':
            self.in_planes = 2048
            self.base = resnet101_ibn_b(last_stride)
            print('using resnet101_ibn_b as a backbone')
        elif model_name == 'resnet50_ibn_b':
            self.in_planes = 2048
            self.base = resnet50_ibn_b(last_stride)
            print('using resnet50_ibn_b as a backbone')
        elif model_name == 'resnest50':
            self.in_planes = 2048
            self.base = resnest50(last_stride)
            print('using resnest50 as a backbone')
        elif model_name == 'resnest50_ibn':
            self.in_planes = 2048
            self.base = resnest50_ibn(last_stride)
            print('using resnest50_ibn as a backbone')
        elif model_name == 'resnest101':
            self.in_planes = 2048
            self.base = resnest101(last_stride)
            print('using resnest101 as a backbone')
        elif model_name == 'resnest101_ibn':
            self.in_planes = 2048
            self.base = resnest101_ibn(last_stride)
            print('using resnest101_ibn as a backbone')
        elif model_name == 'efficientnet_b7':
            # self.in_planes = 1280
            #
            # model_weight_b0 = EfficientNet.from_pretrained('efficientnet-b0')
            # model_weight_b0.to('cuda')
            # self.base = model_weight_b0.extract_features
            self.base = EfficientNet.from_pretrained('efficientnet-b0')
            self.in_planes = self.base._fc.in_features
            print('using efficientnet-b0 as a backbone')
        else:
            print('unsupported backbone! but got {}'.format(model_name))

        if pretrain_choice == 'imagenet' and model_name != 'efficientnet_b7':
            # if model_name == 'efficientnet_b7':
            #     state_dict = torch.load(model_path)
            #     # self.base.load_state_dict(state_dict)
            #     if 'state_dict' in state_dict:
            #         param_dict = state_dict['state_dict']
            #     for i in param_dict:
            #         if 'fc' in i:
            #             continue
            #         self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
            # else:
            self.base.load_param(model_path)
            print('Loading pretrained ImageNet model......from {}'.format(
                model_path))

        if cfg.MODEL.POOLING_METHOD == 'GeM':
            print('using GeM pooling')
            self.gap = GeM()
        else:
            self.gap = nn.AdaptiveAvgPool2d(1)
        if cfg.MODEL.IF_USE_PCB:
            self.pcb = PCB(cfg,
                           256,
                           num_classes,
                           0.5,
                           self.in_planes,
                           cut_at_pooling=False)
        self.num_classes = num_classes
        self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE
        if self.ID_LOSS_TYPE == 'arcface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Arcface(self.in_planes,
                                      self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE,
                                      m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'cosface':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = Cosface(self.in_planes,
                                      self.num_classes,
                                      s=cfg.SOLVER.COSINE_SCALE,
                                      m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'amsoftmax':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = AMSoftmax(self.in_planes,
                                        self.num_classes,
                                        s=cfg.SOLVER.COSINE_SCALE,
                                        m=cfg.SOLVER.COSINE_MARGIN)
        elif self.ID_LOSS_TYPE == 'circle':
            print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,
                                                     cfg.SOLVER.COSINE_SCALE,
                                                     cfg.SOLVER.COSINE_MARGIN))
            self.classifier = CircleLoss(self.in_planes,
                                         self.num_classes,
                                         s=cfg.SOLVER.COSINE_SCALE,
                                         m=cfg.SOLVER.COSINE_MARGIN)
        else:
            self.classifier = nn.Linear(self.in_planes,
                                        self.num_classes,
                                        bias=False)
            self.classifier.apply(weights_init_classifier)

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        # from IPython import embed
        # embed()
    def forward(self,
                x,
                label=None):  # label is unused if self.cos_layer == 'no'
        # device = 'cuda'
        # x.to(device)
        # from IPython import  embed
        # embed()
        #print("x.shape",x.shape)
        if 'efficientnet_b7' == self.model_name:
            x = self.base.extract_features(x)
        else:

            x = self.base(x)
        if self.cfg.MODEL.IF_USE_PCB:
            pcb_out = self.pcb(x)

        # print("x.shape",x.shape)
        global_feat = self.gap(x)
        # print("global_feat.shape",global_feat.shape)
        # print("pcb_out.shape",pcb_out.shape)
        global_feat = global_feat.view(global_feat.shape[0],
                                       -1)  # flatten to (bs, 2048)
        feat = self.bottleneck(global_feat)

        if self.neck == 'no':
            feat = global_feat
        elif self.neck == 'bnneck':
            feat = self.bottleneck(global_feat)

        if self.training:
            if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax',
                                     'circle'):
                cls_score = self.classifier(feat, label)
            else:
                cls_score = self.classifier(feat)
            if self.cfg.MODEL.IF_USE_PCB:
                return cls_score, global_feat, pcb_out
            else:
                return cls_score, global_feat
        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                if self.cfg.MODEL.IF_USE_PCB:
                    return feat, pcb_out
                else:
                    return feat
            else:
                # print("Test with feature before BN")
                if self.cfg.MODEL.IF_USE_PCB:
                    return global_feat, pcb_out
                else:
                    return global_feat

    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        for i in param_dict:
            if 'classifier' in i or 'arcface' in i:
                continue
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        print('Loading pretrained model for finetuning from {}'.format(
            model_path))