Example #1
    def __init__(self, cfg):
        super().__init__()
        self.register_buffer(
            "pixel_mean",
            torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
        self._cfg = cfg
        # backbone
        self.backbone = build_backbone(cfg)

        # head
        pool_type = cfg.MODEL.HEADS.POOL_LAYER
        if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
        elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
        elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
        elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
        elif pool_type == "identity": pool_layer = nn.Identity()
        else:
            raise KeyError(
                f"{pool_type} is invalid, please choose from "
                f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'."
            )

        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
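The pool-layer selection above can equivalently be written as a lookup table. A minimal alternative sketch, assuming the same pooling classes the snippet already imports (the POOL_LAYERS name is an illustration, not part of the original):

        # equivalent dispatch via a lookup table; same behaviour as the if/elif chain above
        POOL_LAYERS = {
            'avgpool':    FastGlobalAvgPool2d,
            'maxpool':    lambda: nn.AdaptiveMaxPool2d(1),
            'gempool':    GeneralizedMeanPoolingP,
            'avgmaxpool': AdaptiveAvgMaxPool2d,
            'identity':   nn.Identity,
        }
        if pool_type not in POOL_LAYERS:
            raise KeyError(f"{pool_type} is invalid, please choose from {sorted(POOL_LAYERS)}")
        pool_layer = POOL_LAYERS[pool_type]()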
Example #2
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer(
            "pixel_mean",
            torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # fmt: off
        # backbone
        bn_norm = cfg.MODEL.BACKBONE.NORM
        num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
        with_se = cfg.MODEL.BACKBONE.WITH_SE
        # fmt: on

        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(backbone.conv1, backbone.bn1,
                                      backbone.relu, backbone.maxpool,
                                      backbone.layer1, backbone.layer2,
                                      backbone.layer3[0])
        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        res_g_conv5 = backbone.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024,
                       512,
                       bn_norm,
                       num_splits,
                       False,
                       with_se,
                       downsample=nn.Sequential(
                           nn.Conv2d(1024, 2048, 1, bias=False),
                           get_norm(bn_norm, 2048, num_splits))),
            Bottleneck(2048, 512, bn_norm, num_splits, False, with_se),
            Bottleneck(2048, 512, bn_norm, num_splits, False, with_se))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())

        # branch1
        self.b1 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_g_conv5))
        self.b1_head = build_reid_heads(cfg)

        # branch2
        self.b2 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_p_conv5))
        self.b2_head = build_reid_heads(cfg)
        self.b21_head = build_reid_heads(cfg)
        self.b22_head = build_reid_heads(cfg)

        # branch3
        self.b3 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_p_conv5))
        self.b3_head = build_reid_heads(cfg)
        self.b31_head = build_reid_heads(cfg)
        self.b32_head = build_reid_heads(cfg)
        self.b33_head = build_reid_heads(cfg)
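Worth noting in the example above: res_p_conv5 rebuilds layer4 with a stride-1 downsample (a plain 1x1 conv), so the part branches keep conv4's spatial resolution while still initializing from layer4's pretrained weights. A quick, purely illustrative check of why that load_state_dict call succeeds:

        # the rebuilt block and backbone.layer4 share parameter names and shapes;
        # stride changes only a conv's spatial behaviour, never its weight shape
        for (n1, p1), (n2, p2) in zip(res_p_conv5.named_parameters(),
                                      backbone.layer4.named_parameters()):
            assert n1 == n2 and p1.shape == p2.shape, (n1, n2)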
Example #3
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        self.heads = build_heads(cfg)
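For context, a minimal sketch of how the registered pixel_mean/pixel_std buffers are typically consumed when preprocessing a batch. The forward-side code below is an assumption modeled on the common fastreid Baseline pattern, not part of the snippet above:

    def preprocess_image(self, batched_inputs):
        # normalize with the buffers registered in __init__; buffers move with the
        # module, so this works unchanged on CPU and GPU
        images = batched_inputs["images"] if isinstance(batched_inputs, dict) else batched_inputs
        return images.sub_(self.pixel_mean).div_(self.pixel_std)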
Example #4
    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.heads = build_heads(cfg)

        self.loss_kwargs = {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    },
                    'center': {
                        'num_classes': cfg.MODEL.LOSSES.CENTER.NUM_CLASSES,
                        'feat_dim': cfg.MODEL.LOSSES.CENTER.FEAT_DIM,
                        'scale': cfg.MODEL.LOSSES.CENTER.SCALE
                    }
                }

        loss_names = self.loss_kwargs['loss_names']
        if 'CenterLoss' in loss_names:
            if self.loss_kwargs['center']['num_classes'] == 0:
                self.loss_kwargs['center']['num_classes'] = cfg.MODEL.HEADS.NUM_CLASSES
            self.center_loss = CenterLoss(
                num_class=self.loss_kwargs['center']['num_classes'],
                num_feature=self.loss_kwargs['center']['feat_dim']
            )

        pixel_mean = cfg.MODEL.PIXEL_MEAN
        pixel_std = cfg.MODEL.PIXEL_STD
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
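Example #4 only stores the hyperparameters. Below is a hedged sketch of how such a loss_kwargs dict is typically consumed in a losses() method; the import path, the outputs keys 'cls_outputs'/'features', and the loss-function signatures are assumptions modeled on fastreid's Baseline, and the circle/cosface/center terms would follow the same pattern:

    # assumed import: from fastreid.modeling.losses import cross_entropy_loss, triplet_loss
    def losses(self, outputs, gt_labels):
        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce = self.loss_kwargs['ce']
            loss_dict['loss_cls'] = cross_entropy_loss(
                outputs['cls_outputs'], gt_labels, ce['eps'], ce['alpha']) * ce['scale']

        if 'TripletLoss' in loss_names:
            tri = self.loss_kwargs['tri']
            loss_dict['loss_triplet'] = triplet_loss(
                outputs['features'], gt_labels, tri['margin'],
                tri['norm_feat'], tri['hard_mining']) * tri['scale']

        return loss_dict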
Example #5
    def test_fusebn(self):
        cfg = get_cfg()
        cfg.defrost()
        cfg.MODEL.BACKBONE.NAME = 'build_repvgg_backbone'
        cfg.MODEL.BACKBONE.DEPTH = 'B1g2'
        cfg.MODEL.BACKBONE.PRETRAIN = False
        model = build_backbone(cfg)
        model.eval()

        test_inp = torch.randn((1, 3, 256, 128))

        y = model(test_inp)

        model.deploy(mode=True)
        fused_y = model(test_inp)

        print("final error:", torch.max(torch.abs(fused_y - y)).item())
Example #6
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        # backbone
        self.backbone = build_backbone(cfg)

        # head
        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            pool_layer = nn.Identity()

        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
Example #7
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        self.use_clothes = cfg.MODEL.LOSSES.USE_CLOTHES

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer(
            "pixel_mean",
            torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)
        # head
        self.heads = build_heads(cfg)
        # Train with clothes ids
        if self.use_clothes:
            self.clo_heads = build_heads(cfg, True)
Example #8
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer(
            "pixel_mean",
            torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        self.heads = build_heads(cfg)

        self.has_extra_bn = cfg.MODEL.BACKBONE.EXTRA_BN
        if self.has_extra_bn:
            self.heads_extra_bn = get_norm(cfg.MODEL.BACKBONE.NORM,
                                           cfg.MODEL.BACKBONE.FEAT_DIM)
Example #9
    def __init__(self, cfg):
        super().__init__()
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
        self._cfg = cfg
        # backbone
        self.backbone = build_backbone(cfg)

        # head
        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            pool_layer = nn.Identity()

        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.heads = build_reid_heads(cfg, in_feat, num_classes, pool_layer)
Example #10
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer(
            "pixel_mean",
            torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

        # head
        self.heads = build_heads(cfg)

        if "CenterLoss" in cfg.MODEL.LOSSES.NAME:
            self.center_loss = CenterLoss(cfg)
            if self._cfg.MODEL.DEVICE == "cuda":
                self.center_loss = self.center_loss.cuda()
Example #11
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        self.other_dataset = cfg.META.DATA.NAMES != ""

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg) # resnet or mobilenet

        if self._cfg.MODEL.NORM.TYPE_BACKBONE == 'Task_norm':
            for module in self.backbone.modules():
                if isinstance(module, TaskNormI):
                    module.register_extra_weights()


        self.heads = build_reid_heads(cfg)
Example #12
    def from_config(cls, cfg):
        backbone = build_backbone(cfg)
        heads = build_heads(cfg)
        return {
            'backbone': backbone,
            'heads': heads,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs':
                {
                    # loss name
                    'loss_names': cfg.MODEL.LOSSES.NAME,

                    # loss hyperparameters
                    'ce': {
                        'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                        'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                        'scale': cfg.MODEL.LOSSES.CE.SCALE
                    },
                    'tri': {
                        'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                        'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                        'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                        'scale': cfg.MODEL.LOSSES.TRI.SCALE
                    },
                    'circle': {
                        'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                    },
                    'cosface': {
                        'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                        'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                        'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                    }
                }
        }
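A hedged sketch of how such a from_config classmethod is usually wired to the constructor. The keyword-only __init__ and the @configurable decorator follow the detectron2-style pattern that fastreid adopts; treat the exact import path and class name as assumptions:

# assumed import: from fastreid.config import configurable
class Baseline(nn.Module):
    @configurable
    def __init__(self, *, backbone, heads, pixel_mean, pixel_std, loss_kwargs=None):
        super().__init__()
        self.backbone = backbone
        self.heads = heads
        self.loss_kwargs = loss_kwargs
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        ...  # returns the dict shown above

# Baseline(cfg) then routes cfg through from_config and unpacks the returned dict into __init__.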
Example #13
    def __init__(self, cfg):
        super().__init__()
        self.register_buffer(
            "pixel_mean",
            torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
        self._cfg = cfg

        # backbone
        bn_norm = cfg.MODEL.BACKBONE.NORM
        num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
        with_se = cfg.MODEL.BACKBONE.WITH_SE

        backbone = build_backbone(cfg)
        self.backbone = nn.Sequential(backbone.conv1, backbone.bn1,
                                      backbone.relu, backbone.maxpool,
                                      backbone.layer1, backbone.layer2,
                                      backbone.layer3[0])
        res_conv4 = nn.Sequential(*backbone.layer3[1:])
        res_g_conv5 = backbone.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024,
                       512,
                       bn_norm,
                       num_splits,
                       False,
                       with_se,
                       downsample=nn.Sequential(
                           nn.Conv2d(1024, 2048, 1, bias=False),
                           get_norm(bn_norm, 2048, num_splits))),
            Bottleneck(2048, 512, bn_norm, num_splits, False, with_se),
            Bottleneck(2048, 512, bn_norm, num_splits, False, with_se))
        res_p_conv5.load_state_dict(backbone.layer4.state_dict())

        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
            pool_layer = nn.AdaptiveAvgPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
            pool_layer = nn.AdaptiveMaxPool2d(1)
        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
            pool_layer = GeneralizedMeanPoolingP()
        else:
            pool_layer = nn.Identity()

        # head
        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        # branch1
        self.b1 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_g_conv5))
        self.b1_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)

        self.b1_head = build_reid_heads(cfg, in_feat, num_classes,
                                        nn.Identity())

        # branch2
        self.b2 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_p_conv5))
        self.b2_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b2_head = build_reid_heads(cfg, in_feat, num_classes,
                                        nn.Identity())

        self.b21_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b21_head = build_reid_heads(cfg, in_feat, num_classes,
                                         nn.Identity())

        self.b22_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b22_head = build_reid_heads(cfg, in_feat, num_classes,
                                         nn.Identity())

        # branch3
        self.b3 = nn.Sequential(copy.deepcopy(res_conv4),
                                copy.deepcopy(res_p_conv5))
        self.b3_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b3_head = build_reid_heads(cfg, in_feat, num_classes,
                                        nn.Identity())

        self.b31_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b31_head = build_reid_heads(cfg, in_feat, num_classes,
                                         nn.Identity())

        self.b32_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b32_head = build_reid_heads(cfg, in_feat, num_classes,
                                         nn.Identity())

        self.b33_pool = self._build_pool_reduce(pool_layer, reduce_dim=in_feat)
        self.b33_head = build_reid_heads(cfg, in_feat, num_classes,
                                         nn.Identity())
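The _build_pool_reduce helper is referenced above but not shown in this snippet. A plausible sketch of such a pool-and-reduce block, modeled on common MGN implementations (an assumption, not the original definition):

    def _build_pool_reduce(self, pool_layer, input_dim=2048, reduce_dim=256):
        # pool to 1x1, then shrink the channel dimension with a 1x1 conv + BN + ReLU
        return nn.Sequential(
            pool_layer,
            nn.Conv2d(input_dim, reduce_dim, 1, bias=False),
            nn.BatchNorm2d(reduce_dim),
            nn.ReLU(inplace=True),
        )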
Example #14
    def from_config(cls, cfg):
        bn_norm = cfg.MODEL.BACKBONE.NORM
        with_se = cfg.MODEL.BACKBONE.WITH_SE

        all_blocks = build_backbone(cfg)

        # backbone
        backbone = nn.Sequential(all_blocks.conv1, all_blocks.bn1,
                                 all_blocks.relu, all_blocks.maxpool,
                                 all_blocks.layer1, all_blocks.layer2,
                                 all_blocks.layer3[0])
        res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
        res_g_conv5 = all_blocks.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024,
                       512,
                       bn_norm,
                       False,
                       with_se,
                       downsample=nn.Sequential(
                           nn.Conv2d(1024, 2048, 1, bias=False),
                           get_norm(bn_norm, 2048))),
            Bottleneck(2048, 512, bn_norm, False, with_se),
            Bottleneck(2048, 512, bn_norm, False, with_se))
        res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())

        # branch
        neck1 = nn.Sequential(copy.deepcopy(res_conv4),
                              copy.deepcopy(res_g_conv5))
        b1_head = build_heads(cfg)

        # branch2
        neck2 = nn.Sequential(copy.deepcopy(res_conv4),
                              copy.deepcopy(res_p_conv5))
        b2_head = build_heads(cfg)
        b21_head = build_heads(cfg)
        b22_head = build_heads(cfg)

        # branch3
        neck3 = nn.Sequential(copy.deepcopy(res_conv4),
                              copy.deepcopy(res_p_conv5))
        b3_head = build_heads(cfg)
        b31_head = build_heads(cfg)
        b32_head = build_heads(cfg)
        b33_head = build_heads(cfg)

        return {
            'backbone': backbone,
            'neck1': neck1,
            'neck2': neck2,
            'neck3': neck3,
            'b1_head': b1_head,
            'b2_head': b2_head,
            'b21_head': b21_head,
            'b22_head': b22_head,
            'b3_head': b3_head,
            'b31_head': b31_head,
            'b32_head': b32_head,
            'b33_head': b33_head,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs': {
                # loss name
                'loss_names': cfg.MODEL.LOSSES.NAME,

                # loss hyperparameters
                'ce': {
                    'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                    'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                    'scale': cfg.MODEL.LOSSES.CE.SCALE
                },
                'tri': {
                    'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                    'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                    'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                    'scale': cfg.MODEL.LOSSES.TRI.SCALE
                },
                'circle': {
                    'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                    'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                    'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                },
                'cosface': {
                    'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                    'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                    'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                }
            }
        }
Example #15
    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer(
            "pixel_mean",
            torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer(
            "pixel_std",
            torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.teacher_net = build_backbone(cfg)
        self.student_net = build_backbone(cfg)
        self.D_Net = cam_Classifier(2048, 2).apply(weights_init_kaiming)
        if 'Dis_loss_cam' in self._cfg.MODEL.LOSSES.NAME:
            if "Hazy_DukeMTMC" in self._cfg.TDATASETS.NAMES:
                camid = 8
            elif "Hazy_Market1501" in self._cfg.TDATASETS.NAMES:
                camid = 6
            else:
                raise ValueError(
                    f"Dis_loss_cam expects Hazy_DukeMTMC or Hazy_Market1501, "
                    f"got {self._cfg.TDATASETS.NAMES}")
            self.D_Net = CamClassifier(2048, camid)
        elif 'Dis_loss' in self._cfg.MODEL.LOSSES.NAME:
            if self._cfg.MODEL.PARAM.Dis_net == "cam_Classifier":
                self.D_Net = cam_Classifier(2048,
                                            2).apply(weights_init_kaiming)
            elif self._cfg.MODEL.PARAM.Dis_net == "cam_Classifier_1024":
                self.D_Net = cam_Classifier_1024(2048,
                                                 2).apply(weights_init_kaiming)
            elif self._cfg.MODEL.PARAM.Dis_net == "cam_Classifier_1024_nobias":
                self.D_Net = cam_Classifier_1024_nobias(
                    2048, 2).apply(weights_init_kaiming)
            elif self._cfg.MODEL.PARAM.Dis_net == "cam_Classifier_fc":
                self.D_Net = cam_Classifier_fc(2048,
                                               2).apply(weights_init_kaiming)
            elif self._cfg.MODEL.PARAM.Dis_net == "cam_Classifier_fc_nobias_in_last_layer":
                self.D_Net = cam_Classifier_fc_nobias_in_last_layer(
                    2048, 2).apply(weights_init_kaiming)

        self.D_Net = self.D_Net.to(torch.device(cfg.MODEL.DEVICE))
        self.CrossEntropy_loss = nn.CrossEntropyLoss().to(
            torch.device(cfg.MODEL.DEVICE))
        self.bn = nn.BatchNorm2d(2048)
        self.bn.bias.requires_grad_(False)
        self.bn.apply(weights_init_kaiming)

        # head
        pool_type = cfg.MODEL.HEADS.POOL_LAYER
        if pool_type == 'avgpool': pool_layer = FastGlobalAvgPool2d()
        elif pool_type == 'maxpool': pool_layer = nn.AdaptiveMaxPool2d(1)
        elif pool_type == 'gempool': pool_layer = GeneralizedMeanPoolingP()
        elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
        elif pool_type == "identity": pool_layer = nn.Identity()
        else:
            raise KeyError(
                f"{pool_type} is invalid, please choose from "
                f"'avgpool', 'maxpool', 'gempool', 'avgmaxpool' and 'identity'."
            )

        in_feat = cfg.MODEL.HEADS.IN_FEAT
        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
        self.teacher_heads = build_reid_heads(cfg, in_feat, num_classes,
                                              pool_layer)
        self.student_heads = build_reid_heads(cfg, in_feat, num_classes,
                                              pool_layer)