Example 1
    normalize,
])

end = time.time()

if dataset == 'sysu':
    # testing set
    query_img, query_label, query_cam = process_query_sysu(data_path,
                                                           mode=args.mode)
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path,
                                                          mode=args.mode,
                                                          trial=0)

elif dataset == 'regdb':
    # training set
    trainset = RegDBData(data_path, args.trial, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                    trainset.train_thermal_label)

    # testing set
    query_img, query_label = process_test_regdb(data_path,
                                                trial=args.trial,
                                                modal='visible')
    gall_img, gall_label = process_test_regdb(data_path,
                                              trial=args.trial,
                                              modal='thermal')

    gallset = TestData(gall_img,
                       gall_label,
                       transform=transform_test,
Example 2
File: test.py Project: Sun915/MCALN
            .format(cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19],
                    mAP_pool, mINP_pool))

elif dataset == 'regdb':

    for trial in range(10):
        test_trial = trial + 1
        model_path = checkpoint_path + args.resume
        #model_path = checkpoint_path + 'regdb_awg_p4_n8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial)
        if os.path.isfile(model_path):
            print('==> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(model_path)
            net.load_state_dict(checkpoint['net'])

        # training set
        trainset = RegDBData(data_path, test_trial, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label = process_test_regdb(data_path,
                                                    trial=test_trial,
                                                    modal='visible')
        gall_img, gall_label = process_test_regdb(data_path,
                                                  trial=test_trial,
                                                  modal='thermal')

        gallset = TestData(gall_img,
                           gall_label,
                           transform=transform_test,
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch Cross-Modality Training')
    parser.add_argument('--dataset',
                        default='sysu',
                        help='dataset name: regdb or sysu')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
    parser.add_argument(
        '--arch',
        default='resnet50',
        type=str,
        help=
        'network baseline: resnet18, resnet50, se_resnet50, pcb_rpp, cbam, or pcb_pyramid'
    )
    parser.add_argument(
        '--resume',
        '-r',
        default=
        'sysu_id_epoch10_baseline_gloid_debug_pyramid_attention_drop_0.0_lr_1.0e-02_dim_512_resnet50_best.t',
        type=str,
        help='resume from checkpoint')
    parser.add_argument('--test-only', action='store_true', help='test only')
    parser.add_argument('--model_path',
                        default='save_model/',
                        type=str,
                        help='model save path')
    parser.add_argument('--save_epoch',
                        default=20,
                        type=int,
                        metavar='s',
                        help='save model every save_epoch epochs (default: 20)')
    parser.add_argument('--log_path',
                        default='log/',
                        type=str,
                        help='log save path')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--low-dim',
                        default=512,
                        type=int,
                        metavar='D',
                        help='feature dimension')
    parser.add_argument('--img_w',
                        default=144,
                        type=int,
                        metavar='imgw',
                        help='img width')
    parser.add_argument('--img_h',
                        default=288,
                        type=int,
                        metavar='imgh',
                        help='img height')
    parser.add_argument('--batch-size',
                        default=32,
                        type=int,
                        metavar='B',
                        help='training batch size')
    parser.add_argument('--test-batch',
                        default=5,
                        type=int,
                        metavar='tb',
                        help='testing batch size')
    parser.add_argument(
        '--method',
        default='id',
        type=str,
        metavar='m',
        help=
        'method type: id, triplet, sphere, id_triplet, id_triplet_center, or id_inter_loss'
    )
    parser.add_argument('--drop',
                        default=0.0,
                        type=float,
                        metavar='drop',
                        help='dropout ratio')
    parser.add_argument('--trial',
                        default=1,
                        type=int,
                        metavar='t',
                        help='trial (only for RegDB dataset)')
    parser.add_argument('--gpu',
                        default='0,2',
                        type=str,
                        help='gpu device ids for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--mode',
                        default='all',
                        type=str,
                        help='all or indoor')
    ## added by zc
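    # NOTE: argparse's type=bool is unreliable: any non-empty string,
    # including 'False', parses as True; action='store_true' is the
    # usual idiom for boolean flags.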
    parser.add_argument('--with_se',
                        default=False,
                        type=bool,
                        help='whether to add an SEModule to the last conv layer')
    parser.add_argument('--with_rpp',
                        default=False,
                        type=bool,
                        help='whether to add RPP with PCB to the last conv layer')
    parser.add_argument('--use_cbam',
                        default=False,
                        type=bool,
                        help='whether to add CBAM to the last conv layer')
    parser.add_argument('--reduction',
                        default=16,
                        type=int,
                        help='SEModule reduction ratio')
    parser.add_argument('--with_labelsmooth',
                        default=False,
                        type=bool,
                        help='whether to add label smoothing to the loss function')
    parser.add_argument('--with_model_neck',
                        default='no',
                        type=str,
                        help='whether to add a BNNeck to the model')
    # Balanced weight of center loss
    Center_weight_loss = 0.01
    ## end of zc additions
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    np.random.seed(0)

    dataset = args.dataset
    if dataset == 'sysu':
        data_path = "/home/zhangc/projects/Datasets/test2/"
        log_path = args.log_path + 'sysu_log/'
        test_mode = [1, 2]  # thermal to visible
    elif dataset == 'regdb':
        data_path = "/home/zhangc/projects/Datasets/RegDB/"
        log_path = args.log_path + 'regdb_log/'
        test_mode = [2, 1]  # visible to thermal
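    # test_mode[0] selects the gallery modality and test_mode[1] the query
    # modality when calling net(input, input, mode) in test() below.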

    suffix_id = '_DEBUG_'
    lamda_ap = 0
    lamda_cen = 1

    checkpoint_path = args.model_path

    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)

    if args.method == 'id':
        suffix = dataset + '_id_epoch10_' + suffix_id + '_debug_pyramid_attention'  # wr0_serial_train_newepoch
    elif args.method == 'sphere':
        suffix = dataset + '_sphere'
    elif args.method == 'triplet':
        suffix = dataset + '_triplet_debug_sampler2_epoch20'
    elif args.method == 'id_triplet':
        suffix = dataset + '_id_triplet_wr'
    elif args.method == 'id_triplet_center':
        suffix = dataset + '_id_triplet_center'
    elif args.method == 'id_inter_loss':
        suffix = dataset + '_inter_loss_id_apdist_1_1_gloid_pyramid_attention_sampler2'  # debug_cbamp_debug_cbamp_
    if args.with_labelsmooth:
        suffix = suffix + '_ls'

    # suffix = suffix + '_cenloss_{}'.format(Center_weight_loss)
    suffix = suffix + '_drop_{}'.format(args.drop)
    suffix = suffix + '_lr_{:1.1e}'.format(args.lr)
    suffix = suffix + '_dim_{}'.format(args.low_dim)
    if not args.optim == 'sgd':
        suffix = suffix + '_' + args.optim
    suffix = suffix + '_' + args.arch
    if args.with_rpp:
        suffix = suffix + '_' + 'use_rpp'

    if args.with_se:
        suffix = suffix + '_' + 'use_senet'

    if args.use_cbam:
        suffix = suffix + '_' + 'use_cbam'

    if dataset == 'regdb':
        suffix = suffix + '_trial_{}'.format(args.trial)

    test_log_file = open(log_path + suffix + '.txt', "w")
    sys.stdout = Logger(log_path + suffix + '_os.txt')
    summary_writer = SummaryWriter(
        os.path.join(
            log_path,
            'tensorboard_log_' + suffix_id + '_debug_pyramid_attention'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0
    feature_dim = args.low_dim

    if args.arch == 'pcb_pyramid':
        num_part = 10
        suffix = suffix + '_' + 'pcb_pyramid'
    # elif args.arch=='pcb_cro':
    #     num_part=11
    else:
        num_part = 6  # pcb_rpp method
    print('==> Loading data..')
    # Data loading code
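    # ImageNet channel statistics, the standard normalization for
    # ResNet-family backbones pretrained on ImageNet.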
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Pad(10),
        transforms.RandomCrop((args.img_h, args.img_w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((args.img_h, args.img_w)),
        transforms.ToTensor(),
        normalize,
    ])

    end = time.time()
    if dataset == 'sysu':
        # training set
        # trainset = SYSUData_tri(data_path, transform=transform_train)
        # trainset = SYSUData(data_path, transform=transform_train)
        # generate the idx of each person identity
        # color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

        # testing set
        query_img, query_label, query_cam = process_query_sysu(data_path,
                                                               mode=args.mode)
        gall_img, gall_label, gall_cam = process_gallery_sysu(data_path,
                                                              mode=args.mode,
                                                              trial=0)

    elif dataset == 'regdb':
        # training set
        trainset = RegDBData(data_path, args.trial, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label = process_test_regdb(data_path,
                                                    trial=args.trial,
                                                    modal='visible')
        gall_img, gall_label = process_test_regdb(data_path,
                                                  trial=args.trial,
                                                  modal='thermal')

    gallset = TestData(gall_img,
                       gall_label,
                       transform=transform_test,
                       img_size=(args.img_w, args.img_h))
    queryset = TestData(query_img,
                        query_label,
                        transform=transform_test,
                        img_size=(args.img_w, args.img_h))

    # testing data loader
    gall_loader = data.DataLoader(gallset,
                                  batch_size=args.test_batch,
                                  shuffle=False,
                                  num_workers=args.workers)
    query_loader = data.DataLoader(queryset,
                                   batch_size=args.test_batch,
                                   shuffle=False,
                                   num_workers=args.workers)

    n_class = 395  #len(np.unique(trainset.train_color_label))
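    # 395 matches the number of training identities in SYSU-MM01; hardcoding
    # it keeps the classifier shape consistent with the checkpoint even when
    # the training set is not loaded.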
    nquery = len(query_label)
    ngall = len(gall_label)

    print('Dataset {} statistics:'.format(dataset))
    print('  ------------------------------')
    print('  subset   | # ids | # images')
    print('  ------------------------------')
    # print('  visible  | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label)))
    # print('  thermal  | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label)))
    print('  ------------------------------')
    print('  query    | {:5d} | {:8d}'.format(len(np.unique(query_label)),
                                              nquery))
    print('  gallery  | {:5d} | {:8d}'.format(len(np.unique(gall_label)),
                                              ngall))
    print('  ------------------------------')
    print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

    print('==> Building model..')
    net = embed_net(args.low_dim,
                    n_class,
                    drop=args.drop,
                    arch=args.arch,
                    neck=args.with_model_neck,
                    with_se=args.with_se,
                    with_rpp=args.with_rpp,
                    use_cbam=args.use_cbam,
                    reduction=args.reduction)
    net.to(device)
    cudnn.benchmark = True

    if len(args.resume) > 0:
        model_path = checkpoint_path + args.resume
        if os.path.isfile(model_path):
            print('==> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(model_path)
            # start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['net'])
            print('==> loaded checkpoint {} (epoch {})'.format(
                args.resume, checkpoint['epoch']))
            start_epoch = int(checkpoint['epoch'])
        else:
            print('==> no checkpoint found at {}'.format(args.resume))

    if args.method == 'id':
        if args.with_labelsmooth:
            criterion = CrossEntropyLabelSmooth(n_class)
            print("---------using ID loss with label smoothing-------")
        else:
            criterion = nn.CrossEntropyLoss()
            # criterion = FocalLoss(gamma=2)
            print("---------using ID loss only-------")
            triplet_criterion = TripletLoss(margin=0.3)
            triplet_criterion.to(device)
            center_criterion = CenterLoss(num_classes=n_class,
                                          feat_dim=args.low_dim,
                                          size_average=True).to(device)
            center_criterion_cro = CenterLoss_cro(num_classes=n_class,
                                                  feat_dim=args.low_dim,
                                                  size_average=True).to(device)
        criterion.to(device)

    elif args.method == 'sphere':
        sphere_criterion = OhemSphereLoss(args.low_dim, n_class)
        sphere_criterion.to(device)
        print("---------using sphere loss -------")

    elif args.method == 'triplet':
        triplet_criterion = TripletLoss(margin=0.3)
        triplet_criterion.to(device)
        print("---------using triplet loss-------")

    ###########################
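    # Parameters whose ids appear in ignored_params get their own optimizer
    # groups (full lr below); the remaining pretrained backbone weights fall
    # into base_params and are fine-tuned at 0.1 * lr.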
    ignored_params = list(map(id, net.feature.parameters())) + list(
        map(id, net.classifier.parameters())
    )  # + list(map(id, net.pcb_classifier.parameters()))\
    # + list(map(id, net.visible_net.visible.avgpool.parameters()))+ list(map(id, net.thermal_net.thermal.avgpool.parameters()))
    base_params = filter(lambda p: id(p) not in ignored_params,
                         net.parameters())
    if args.optim == 'sgd':

        optimizer = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)

    elif args.optim == 'adam':
        optimizer = optim.Adam([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                               weight_decay=5e-4)

    ####################################################################################################
    def pcb_train(net):
        ignored_params = list(
            map(id, net.classifier.parameters())
        )  # +list(map(id, net.feature.parameters())) #+ list(map(id, net.attention.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.arch in ('pcb_rpp', 'pcb_pyramid'):
            optimizer_pcb = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_pcb

    def pcb_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.glo_classifier.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_pcb_cbam = optim.SGD(
            [{
                'params': base_params,
                'lr': args.lr
            }, {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.glo_classifier.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.classifier.parameters(),
                'lr': args.lr * 10
            }],
            weight_decay=5e-4,
            momentum=0.9,
            nesterov=True)

        return optimizer_pcb_cbam

    def id_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_id_cbam = optim.SGD([
            {
                'params': base_params,
                'lr': args.lr
            },
            {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            },
            {
                'params': net.feature.parameters(),
                'lr': args.lr
            },
            {
                'params': net.classifier.parameters(),
                'lr': args.lr
            },
            {
                'params': center_criterion.parameters(),
                'lr': args.lr * 10
            },
        ],
                                      weight_decay=5e-4,
                                      momentum=0.9,
                                      nesterov=True)

        return optimizer_id_cbam

    def inter_loss_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.glo_classifier.parameters())
        )  # +list(map(id, net.feature.parameters()))+list(map(id, net.glo_feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.method == 'id_inter_loss':
            optimizer_inter = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.glo_classifier.parameters(),
                        'lr': args.lr * 10
                    },
                    # {'params': center_criterion.parameters(), 'lr': args.lr*10},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_inter

    def full_train(net):

        ignored_params = list(map(id, net.feature.parameters())) + list(map(id, net.pcb_classifier.parameters())) \
                         + list(map(id, net.avgpool.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        optimizer_full = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.pcb_classifier.parameters(),
            'lr': args.lr
        }, {
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                   weight_decay=5e-4,
                                   momentum=0.9,
                                   nesterov=True)
        return optimizer_full

    def rpp_train(net):
        optimizer_rpp = optim.SGD([{
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                  weight_decay=5e-4,
                                  momentum=0.9,
                                  nesterov=True)

        return optimizer_rpp

    #################################################################################################
    features_blobs_vis = []
    features_blobs_ther = []
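    # The hooks below append each branch's feature maps here on every
    # forward pass; attention_map_show registers them to collect features
    # for CAM visualization.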

    def hook_feature_vis(module, input, output):
        features_blobs_vis.append(output.data.cpu().numpy())

    def hook_feature_ther(module, input, output):
        features_blobs_ther.append(output.data.cpu().numpy())

    def attention_map_show(net, img, feature, label, batch_idx):
        # grad_cam_v = GradCam(model=net, target_layer_names=["layer4"], use_cuda=True)
        # ther_net=net.thermal_net.thermal
        # vis_net = net.visible_net.visible
        # vis_net.eval()
        # ther_net.eval()
        net._modules.get("visible_net").register_forward_hook(hook_feature_vis)
        net._modules.get("thermal_net").register_forward_hook(
            hook_feature_ther)

        # vis_params = list(vis_net.parameters())
        # vis_weight_softmax = np.squeeze(vis_params[-2].cpu().data.numpy())
        # ther_params = list(ther_net.parameters())
        # ther_weight_softmax = np.squeeze(ther_params[-2].cpu().data.numpy())

        parm = {}
        for name, parameters in net.named_parameters():
            # print(name, ':', parameters.size())
            parm[name] = parameters.cpu().detach().numpy()

        params = list(net.parameters())
        # vis_weight_softmax = np.squeeze(vis_params[-2].cpu().data.numpy())
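        # Composing the classifier weights with the feature-embedding weights
        # gives an effective linear map from backbone feature channels to
        # class scores, as required for CAM.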
        weight_softmax = np.matmul(parm['classifier.classifier.1.weight'],
                                   parm['feature.feat_block.0.weight'])
        #-7 -15
        label_list = [6, 130, 282]
        label_t = label.data.cpu().numpy()
        # for i,tmp_label in label_t:
        #     if tmp_label in label_list:

        vis_CAMs = returncCAM(feature, weight_softmax, label[0].item())
        show_img = np.transpose(img.tolist(), [1, 2, 0])  #1,2,0
        # show_img=img[:,:,[1,2,0]]
        plt.imshow(show_img)
        plt.axis('off')
        # plt.savefig('./image/attention/exp_IRori_' + str(label) + '_' + str(batch_idx) + '.eps')
        # plt.savefig('./image/attention/exp_IRori_' + str(label) + '_' + str(batch_idx) + '.png')
        plt.show()
        plt.figure()
        target_index = None
        # mask_v = grad_cam_v(torch.unsqueeze(img,0), target_index)
        # mask_t = net.grad_cam_t(img2, target_index)
        show_cam(show_img, vis_CAMs, label, batch_idx)

    def returncCAM(feature, weight, class_idx):
        size_upsample = (288, 144)
        nc, h, w = feature.shape
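        # Standard CAM: CAM_c[h, w] = sum_k weight[class_idx, k] * feature[k, h, w],
        # min-max scaled to [0, 255] below.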
        output_cam = []
        tmpw = weight[class_idx]
        print('label:', class_idx)
        tmpf = feature.reshape((nc, h * w)).data.cpu().numpy()
        cam = tmpw.dot(tmpf)
        cam = cam.reshape(h, w)
        cam_img = (cam - cam.min()) / (cam.max() - cam.min())  # normalize
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))

        return output_cam

    def show_cam(img, CAMs, label, batch_idx):
        height, width, _ = img.shape
        heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)),
                                    cv2.COLORMAP_JET)
        # heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        img = np.float32(img)

        cam = heatmap * 0.9 + img * 0.4
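        # NOTE: with the renormalization below commented out, overlay values
        # can exceed 1; plt.imshow clips float RGB data to [0, 1].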
        # cam = cam / np.max(cam)
        # cv2.imshow("output", img)
        b, g, r = cv2.split(img)
        img2 = cv2.merge([r, g, b])
        img3 = cv2.merge([g, b, r])
        # plt.subplot(131);
        # plt.imshow(img)
        # plt.axis('off')
        # plt.subplot(132);
        # plt.imshow(heatmap)
        # plt.axis('off')
        # plt.subplot(133);
        # plt.imshow(cam)
        # plt.axis('off')
        # plt.show()

        plt.imshow(cam)
        plt.axis('off')
        plt.savefig('./image/attention/baseline_RGB_' + str(label.numpy()) +
                    '_' + str(batch_idx) + '.eps')
        plt.savefig('./image/attention/baseline_RGB_' + str(label.numpy()) +
                    '_' + str(batch_idx) + '.png')
        plt.show()

        # cv2.imwrite("img" + ".jpg", np.uint8(img2))
        # cv2.imwrite("heatmap" + ".jpg", np.uint8(heatmap))
        # cv2.imwrite("cam" + ".jpg", cam)

    ##############################################################################################

    def test(net, epoch):
        # switch to evaluation mode
        net.eval()
        print('Extracting Gallery Feature...')
        start = time.time()
        ptr = 0
        gall_feat = np.zeros((ngall, args.low_dim))  # args.low_dim
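        # Features are written batch-by-batch into the preallocated matrix;
        # ptr tracks the current row offset.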
        # label_g = []
        #
        with torch.no_grad():
            for batch_idx, (input, label) in enumerate(gall_loader):
                batch_num = input.size(0)
                img = input[0]

                input = Variable(input.cuda())
                feat_pool, feat = net(input, input, test_mode[0])
                attention_map_show(net, img, feat_pool[0], label, batch_idx)
                # label_g.append(label)
                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                ptr = ptr + batch_num
        print('Extracting Time:\t {:.3f}'.format(time.time() - start))
        # # # #
        # # switch to evaluation mode
        # net.eval()
        # print('Extracting Query Feature...')
        # start = time.time()
        # ptr = 0
        #
        # query_feat = np.zeros((nquery, args.low_dim))
        # #
        # params = list(net.parameters())
        # with torch.no_grad():
        #     for batch_idx, (input, label) in enumerate(query_loader):
        #         batch_num = input.size(0)
        #         input = Variable(input.cuda())
        #         img = input[0]
        #         feat_pool, feat = net(input, input, test_mode[1])
        #
        #         attention_map_show(net, img, feat_pool[0], label,batch_idx)
        #         query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
        #         ptr = ptr + batch_num
        # print('Extracting Time:\t {:.3f}'.format(time.time() - start))
        #
        img = input[0]

        start = time.time()

        # compute the similarity
        # distmat = np.matmul(query_feat, np.transpose(gall_feat))

        # evaluation
        # if dataset == 'regdb':
        #     cmc, mAP = eval_regdb(-distmat, query_label, gall_label)
        # elif dataset == 'sysu':
        #     cmc, mAP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam,max_rank = 50)
        # print('Evaluation Time:\t {:.3f}'.format(time.time() - start))

        # uni_que = np.unique(query_label)
        # n_q, n_g = 0, 0
        # output = []
        # labels = []
        # out_q = []
        # out_g = []
        # label_q = []
        # label_g = []
        # num_label = 30
        # for i in range(num_label):
        #     label_id = uni_que[i]
        #     index_q = [i for i, a in enumerate(query_label) if a == label_id]
        #     index_g = [i for i, a in enumerate(gall_label) if a == label_id]
        #     out_q.extend(query_feat[index_q, :])
        #     out_g.extend(gall_feat[index_g, :])
        #     # tmp_q = query_label[index_q]
        #     # tmp_g = gall_label[index_g]
        #     label_q.extend(query_label[index_q])
        #     label_g.extend(gall_label[index_g])
        #     # n_q += len(tmp_q)
        #     # n_g += len(tmp_g)
        #     # labels.extend(np.concatenate((tmp_q, tmp_g), 0))
        # output = np.concatenate((out_q, out_g), 0)
        # labels = np.concatenate((label_q, label_g), 0)
        # n_q = len(out_q)
        # n_g = len(out_g)
        # figure = draw_cluster(output, labels, n_q, n_g, num_label)
        # #
        # suffix_id = 'DEBUG_Baseline+glo_id+center_'+str(num_label)
        # plt.savefig(
        #     '/home/zhangc/projects/cross_modality_V/BDTR_modal_loss_cbam_debug/image/cluster/' + '_train_' + suffix_id + '.eps')
        # plt.show()
        # plt.pause(1)
        # plt.close()
        # return cmc, mAP

    #######################################################################################################
    # training

    # def train_model(net, optimizer, final_epoch, flag='normal'):
    #     # training
    #     print('==> Start Training...')
    #     best_acc = 0
    #     for epoch in range(start_epoch, final_epoch - start_epoch + 1):
    #
    #         # train(net, epoch, optimizer, flag,ini_id,ini_modal)
    #
    #         if epoch % 2 == 0:  # epoch > 0 and
    #             print('Test Epoch: {}'.format(epoch))
    #             print('Test Epoch: {}'.format(epoch), file=test_log_file)
    #             # testing
    #             cmc, mAP = test(net, epoch)
    #             print('cmc:',cmc)
    #             print('FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'.format(
    #                 cmc[0], cmc[4], cmc[9], cmc[19], mAP))
    #             print('FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'.format(
    #                 cmc[0], cmc[4], cmc[9], cmc[19], mAP), file=test_log_file)
    #             test_log_file.flush()
    #
    #             # save model
    #             if cmc[0] > best_acc:  # not the real best for sysu-mm01
    #                 best_acc = cmc[0]
    #                 state = {
    #                     'net': net.state_dict(),
    #                     'cmc': cmc,
    #                     'mAP': mAP,
    #                     'epoch': epoch,
    #                 }
    #                 torch.save(state, checkpoint_path + suffix + '_best.t')
    #
    #             # save model every 20 epochs
    #             if epoch > 10 and epoch % args.save_epoch == 0:
    #                 state = {
    #                     'net': net.state_dict(),
    #                     'cmc': cmc,
    #                     'mAP': mAP,
    #                     'epoch': epoch,
    #                 }
    #                 torch.save(state, checkpoint_path + suffix + '_epoch_{}.t'.format(epoch))
    #
    #         # torch.save(state, checkpoint_path + suffix + '_latest.t')
    #     return net

    def train_model(net, optimizer, final_epoch, flag='normal'):
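        # Training is stripped out here: the loop only runs test(), whose
        # attention_map_show call drives the CAM visualization.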
        print('==> Start Training...')
        best_acc = 0
        for epoch in range(start_epoch, final_epoch - start_epoch + 1):
            test(net, epoch)
        return net

    #########################################################################

    ini_id = 10
    ini_modal = 10
    #######################################################################################################

    if args.arch == 'pcb_rpp' or args.arch == 'pcb_pyramid':
        # print('-------epoch for pcb_training--------')
        # optimizer_pcb = pcb_train(net)
        # model=train_model(net, optimizer_pcb, 60, flag='pcb')

        if args.with_rpp:
            print('-------epoch for rpp_training--------')
            model = net.convert_to_rpp().cuda()
            optimizer_rpp = rpp_train(model)
            model = train_model(model, optimizer_rpp, 40, flag='rpp')

            print('-------epoch for full_training--------')
            optimizer_full = full_train(model)
            train_model(model, optimizer_full, 80, flag='full')
        elif args.method == 'id_inter_loss':

            if args.use_cbam:
                print('-------epoch for pcb_cbam training--------')
                optimizer_pcb_cbam = pcb_cbam_train(net)
                train_model(net, optimizer_pcb_cbam, 90, flag='pcb_cbam')
            else:
                print('-------epoch for id_inter_loss_training--------')
                optimizer_inter = inter_loss_train(net)
                train_model(net, optimizer_inter, 100, flag='id_inter_loss')
            # optimizer_pcb = pcb_train(net)
            # train_model(net, optimizer_pcb, 60, flag='pcb')

        else:

            print('-------epoch for pcb_training--------')
            optimizer_pcb = pcb_train(net)
            train_model(net, optimizer_pcb, 90, flag='pcb')

    else:
        if args.use_cbam:
            print('-------epoch for resnet_cbam training--------')
            optimizer_id_cbam = id_cbam_train(net)
            train_model(net, optimizer_id_cbam, 90, flag='pcb_cbam')
        else:
            train_model(net, optimizer, 500, flag='normal')
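
# A minimal invocation sketch for the script above, assuming the full file
# ends with an `if __name__ == '__main__': main()` guard (the checkpoint
# name here is a placeholder):
#   python test.py --dataset sysu --arch resnet50 \
#       --resume <your_checkpoint>.t --gpu 0 --mode all
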
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch Cross-Modality Training')
    parser.add_argument('--dataset',
                        default='sysu',
                        help='dataset name: regdb or sysu')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
    parser.add_argument(
        '--arch',
        default='resnet50',
        type=str,
        help=
        'network baseline: resnet18, resnet50, se_resnet50, pcb_rpp, cbam, or pcb_pyramid'
    )
    parser.add_argument(
        '--resume',
        '-r',
        default=
        'sysu_id_epoch10_PPAM_gloid_expcenter_1m0000__debug_pyramid_attention_drop_0.0_lr_1.0e-02_dim_512_resnet50_use_cbam_best.t',
        type=str,
        help='resume from checkpoint')
    parser.add_argument('--test-only', action='store_true', help='test only')
    parser.add_argument('--model_path',
                        default='save_model/',
                        type=str,
                        help='model save path')
    parser.add_argument('--save_epoch',
                        default=20,
                        type=int,
                        metavar='s',
                        help='save model every save_epoch epochs (default: 20)')
    parser.add_argument('--log_path',
                        default='log/',
                        type=str,
                        help='log save path')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--low-dim',
                        default=512,
                        type=int,
                        metavar='D',
                        help='feature dimension')
    parser.add_argument('--img_w',
                        default=144,
                        type=int,
                        metavar='imgw',
                        help='img width')
    parser.add_argument('--img_h',
                        default=288,
                        type=int,
                        metavar='imgh',
                        help='img height')
    parser.add_argument('--batch-size',
                        default=32,
                        type=int,
                        metavar='B',
                        help='training batch size')
    parser.add_argument('--test-batch',
                        default=1,
                        type=int,
                        metavar='tb',
                        help='testing batch size')
    parser.add_argument(
        '--method',
        default='id',
        type=str,
        metavar='m',
        help=
        'method type: id, triplet, sphere, id_triplet, id_triplet_center, or id_inter_loss'
    )
    parser.add_argument('--drop',
                        default=0.0,
                        type=float,
                        metavar='drop',
                        help='dropout ratio')
    parser.add_argument('--trial',
                        default=1,
                        type=int,
                        metavar='t',
                        help='trial (only for RegDB dataset)')
    parser.add_argument('--gpu',
                        default='0,2',
                        type=str,
                        help='gpu device ids for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--mode',
                        default='all',
                        type=str,
                        help='all or indoor')
    ## added by zc
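    # NOTE: as above, type=bool does not parse 'False' correctly with
    # argparse; any non-empty string yields True.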
    parser.add_argument('--with_se',
                        default=False,
                        type=bool,
                        help='whether to add an SEModule to the last conv layer')
    parser.add_argument('--with_rpp',
                        default=False,
                        type=bool,
                        help='whether to add RPP with PCB to the last conv layer')
    parser.add_argument('--use_cbam',
                        default=True,
                        type=bool,
                        help='whether to add CBAM to the last conv layer')
    parser.add_argument('--reduction',
                        default=16,
                        type=int,
                        help='SEModule reduction ratio')
    parser.add_argument('--with_labelsmooth',
                        default=False,
                        type=bool,
                        help='whether to add label smoothing to the loss function')
    parser.add_argument('--with_model_neck',
                        default='no',
                        type=str,
                        help='whether to add a BNNeck to the model')
    # Balanced weight of center loss
    Center_weight_loss = 0.01
    ## end of zc additions
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    np.random.seed(0)

    dataset = args.dataset
    if dataset == 'sysu':
        data_path = "/home/zhangc/projects/Datasets/SYSU-MM01/"  #"/home/zhangc/projects/Datasets/test/"
        log_path = args.log_path + 'sysu_log/'
        test_mode = [1, 2]  # thermal to visible
    elif dataset == 'regdb':
        data_path = "/home/zhangc/projects/Datasets/RegDB/"
        log_path = args.log_path + 'regdb_log/'
        test_mode = [2, 1]  # visible to thermal

    suffix_id = '_DEBUG_'
    lamda_ap = 0
    lamda_cen = 1

    checkpoint_path = args.model_path

    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)

    if args.method == 'id':
        suffix = dataset + '_id_epoch10_' + suffix_id + '_debug_pyramid_attention'  # wr0_serial_train_newepoch
    elif args.method == 'sphere':
        suffix = dataset + '_sphere'
    elif args.method == 'triplet':
        suffix = dataset + '_triplet_debug_sampler2_epoch20'
    elif args.method == 'id_triplet':
        suffix = dataset + '_id_triplet_wr'
    elif args.method == 'id_triplet_center':
        suffix = dataset + '_id_triplet_center'
    elif args.method == 'id_inter_loss':
        suffix = dataset + '_inter_loss_id_apdist_1_1_gloid_pyramid_attention_sampler2'  # debug_cbamp_debug_cbamp_
    if args.with_labelsmooth:
        suffix = suffix + '_ls'

    # suffix = suffix + '_cenloss_{}'.format(Center_weight_loss)
    suffix = suffix + '_drop_{}'.format(args.drop)
    suffix = suffix + '_lr_{:1.1e}'.format(args.lr)
    suffix = suffix + '_dim_{}'.format(args.low_dim)
    if not args.optim == 'sgd':
        suffix = suffix + '_' + args.optim
    suffix = suffix + '_' + args.arch
    if args.with_rpp:
        suffix = suffix + '_' + 'use_rpp'

    if args.with_se:
        suffix = suffix + '_' + 'use_senet'

    if args.use_cbam:
        suffix = suffix + '_' + 'use_cbam'

    if dataset == 'regdb':
        suffix = suffix + '_trial_{}'.format(args.trial)

    test_log_file = open(log_path + '.txt', "w")
    cmc_log_file = open('all' + '_cmc.txt', "w")
    sys.stdout = Logger(log_path + suffix + '_os.txt')
    summary_writer = SummaryWriter(
        os.path.join(
            log_path,
            'tensorboard_log_' + suffix_id + '_debug_pyramid_attention'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0
    feature_dim = args.low_dim

    if args.arch == 'pcb_pyramid':
        num_part = 10
        suffix = suffix + '_' + 'pcb_pyramid'
    # elif args.arch=='pcb_cro':
    #     num_part=11
    else:
        num_part = 6  # pcb_rpp method
    print('==> Loading data..')
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Pad(10),
        transforms.RandomCrop((args.img_h, args.img_w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((args.img_h, args.img_w)),
        transforms.ToTensor(),
        normalize,
    ])

    end = time.time()
    if dataset == 'sysu':
        # training set
        # trainset = SYSUData_tri(data_path, transform=transform_train)
        trainset = SYSUData(data_path, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label, query_cam = process_query_sysu(data_path,
                                                               mode=args.mode)
        gall_img, gall_label, gall_cam = process_gallery_sysu(data_path,
                                                              mode=args.mode,
                                                              trial=0)

    elif dataset == 'regdb':
        # training set
        trainset = RegDBData(data_path, args.trial, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label = process_test_regdb(data_path,
                                                    trial=args.trial,
                                                    modal='visible')
        gall_img, gall_label = process_test_regdb(data_path,
                                                  trial=args.trial,
                                                  modal='thermal')

    gallset = TestData_test(gall_img,
                            gall_label,
                            transform=transform_test,
                            img_size=(args.img_w, args.img_h))
    queryset = TestData_test(query_img,
                             query_label,
                             transform=transform_test,
                             img_size=(args.img_w, args.img_h))

    # testing data loader
    gall_loader = data.DataLoader(gallset,
                                  batch_size=args.test_batch,
                                  shuffle=False,
                                  num_workers=args.workers)
    query_loader = data.DataLoader(queryset,
                                   batch_size=args.test_batch,
                                   shuffle=False,
                                   num_workers=args.workers)

    n_class = 395  #len(np.unique(trainset.train_color_label))
    nquery = len(query_label)
    ngall = len(gall_label)

    print('Dataset {} statistics:'.format(dataset))
    print('  ------------------------------')
    print('  subset   | # ids | # images')
    print('  ------------------------------')
    print('  visible  | {:5d} | {:8d}'.format(n_class,
                                              len(trainset.train_color_label)))
    print('  thermal  | {:5d} | {:8d}'.format(
        n_class, len(trainset.train_thermal_label)))
    print('  ------------------------------')
    print('  query    | {:5d} | {:8d}'.format(len(np.unique(query_label)),
                                              nquery))
    print('  gallery  | {:5d} | {:8d}'.format(len(np.unique(gall_label)),
                                              ngall))
    print('  ------------------------------')
    print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

    print('==> Building model..')
    net = embed_net(args.low_dim,
                    n_class,
                    drop=args.drop,
                    arch=args.arch,
                    neck=args.with_model_neck,
                    with_se=args.with_se,
                    with_rpp=args.with_rpp,
                    use_cbam=args.use_cbam,
                    reduction=args.reduction)
    net.to(device)
    cudnn.benchmark = True

    if len(args.resume) > 0:
        model_path = checkpoint_path + args.resume
        if os.path.isfile(model_path):
            print('==> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(model_path)
            # start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['net'])
            print('==> loaded checkpoint {} (epoch {})'.format(
                args.resume, checkpoint['epoch']))
            start_epoch = int(checkpoint['epoch'])
        else:
            print('==> no checkpoint found at {}'.format(args.resume))

    if args.method == 'id':
        if args.with_labelsmooth:
            criterion = CrossEntropyLabelSmooth(n_class)
            print("---------using ID loss with label smoothing-------")
        else:
            criterion = nn.CrossEntropyLoss()
            # criterion = FocalLoss(gamma=2)
            print("---------using ID loss only-------")
            triplet_criterion = TripletLoss(margin=0.3)
            triplet_criterion.to(device)
            center_criterion = CenterLoss(num_classes=n_class,
                                          feat_dim=args.low_dim,
                                          size_average=True).to(device)
            center_criterion_cro = CenterLoss_cro(num_classes=n_class,
                                                  feat_dim=args.low_dim,
                                                  size_average=True).to(device)
        criterion.to(device)

    elif args.method == 'sphere':
        sphere_criterion = OhemSphereLoss(args.low_dim, n_class)
        sphere_criterion.to(device)
        print("---------using sphere loss -------")

    elif args.method == 'triplet':
        triplet_criterion = TripletLoss(margin=0.3)
        triplet_criterion.to(device)
        print("---------using triplet loss-------")

    ###########################
    ignored_params = list(map(id, net.feature.parameters())) + list(
        map(id, net.classifier.parameters())
    )  # + list(map(id, net.pcb_classifier.parameters()))\
    # + list(map(id, net.visible_net.visible.avgpool.parameters()))+ list(map(id, net.thermal_net.thermal.avgpool.parameters()))
    base_params = filter(lambda p: id(p) not in ignored_params,
                         net.parameters())
    if args.optim == 'sgd':

        optimizer = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)

    elif args.optim == 'adam':
        optimizer = optim.Adam([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                               weight_decay=5e-4)

    ####################################################################################################
    def pcb_train(net):
        ignored_params = list(
            map(id, net.classifier.parameters())
        )  # +list(map(id, net.feature.parameters())) #+ list(map(id, net.attention.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.arch in ('pcb_rpp', 'pcb_pyramid'):
            optimizer_pcb = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_pcb

    def pcb_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.glo_classifier.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_pcb_cbam = optim.SGD(
            [{
                'params': base_params,
                'lr': args.lr
            }, {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.glo_classifier.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.classifier.parameters(),
                'lr': args.lr * 10
            }],
            weight_decay=5e-4,
            momentum=0.9,
            nesterov=True)

        return optimizer_pcb_cbam

    def id_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_id_cbam = optim.SGD([
            {
                'params': base_params,
                'lr': args.lr
            },
            {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            },
            {
                'params': net.feature.parameters(),
                'lr': args.lr
            },
            {
                'params': net.classifier.parameters(),
                'lr': args.lr
            },
            {
                'params': center_criterion.parameters(),
                'lr': args.lr * 10
            },
        ],
                                      weight_decay=5e-4,
                                      momentum=0.9,
                                      nesterov=True)

        return optimizer_id_cbam

    def inter_loss_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.glo_classifier.parameters())
        )  # +list(map(id, net.feature.parameters()))+list(map(id, net.glo_feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.method == 'id_inter_loss':
            optimizer_inter = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.glo_classifier.parameters(),
                        'lr': args.lr * 10
                    },
                    # {'params': center_criterion.parameters(), 'lr': args.lr*10},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_inter

    def full_train(net):

        ignored_params = list(map(id, net.feature.parameters())) + list(map(id, net.pcb_classifier.parameters())) \
                         + list(map(id, net.avgpool.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        optimizer_full = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.pcb_classifier.parameters(),
            'lr': args.lr
        }, {
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                   weight_decay=5e-4,
                                   momentum=0.9,
                                   nesterov=True)
        return optimizer_full

    def rpp_train(net):
        optimizer_rpp = optim.SGD([{
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                  weight_decay=5e-4,
                                  momentum=0.9,
                                  nesterov=True)

        return optimizer_rpp

    #################################################################################################
    def attention_map_show(net, img, feature, classifier, mode):
        grad_cam_v = GradCam(model=net,
                             target_layer_names=["layer4"],
                             use_cuda=True)
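        # Grad-CAM over the last residual block ('layer4'): the mask
        # highlights the spatial regions that drive the prediction.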

        show_img = np.transpose(img.tolist(), [1, 2, 0])
        plt.imshow(show_img)
        plt.figure()
        target_index = None
        mask_v = grad_cam_v(torch.unsqueeze(img, 0), target_index)
        # mask_t = net.grad_cam_t(img2, target_index)
        show_cam_on_image(img, mask_v, mode=mode)

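GradCam and show_cam_on_image come from the repo's Grad-CAM utilities and are not shown in this snippet. As a rough idea of what the overlay step does, here is a pure-matplotlib sketch (overlay_cam and all toy data below are hypothetical, not the repo's implementation):

import numpy as np
import matplotlib.pyplot as plt

def overlay_cam(img_hwc, mask, alpha=0.5, out_path='cam_overlay.png'):
    """Blend a [0, 1] attention mask over an HWC image in [0, 1]."""
    heat = plt.cm.jet(mask)[..., :3]  # HxW mask -> HxWx3 heatmap
    blend = np.clip(alpha * heat + (1 - alpha) * img_hwc, 0.0, 1.0)
    plt.imsave(out_path, blend)
    return blend

# Toy usage: a random image and a centered Gaussian blob as the "attention".
toy_img = np.random.rand(288, 144, 3)
yy, xx = np.mgrid[0:288, 0:144]
toy_mask = np.exp(-(((yy - 144) / 60.0) ** 2 + ((xx - 72) / 30.0) ** 2))
overlay_cam(toy_img, toy_mask / toy_mask.max())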
    ##############################################################################################

    def test(net, epoch):
        # switch to evaluation mode
        net.eval()
        print('Extracting Gallery Feature...')
        start = time.time()
        ptr = 0
        gall_feat = np.zeros((ngall, args.low_dim))  # args.low_dim
        gall_file = []
        label_g = []
        with torch.no_grad():
            for batch_idx, (input, label, img_file) in enumerate(gall_loader):
                batch_num = input.size(0)
                input = Variable(input.cuda())
                feat_pool, feat = net(input, input, test_mode[0])
                label_g.append(label)
                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                gall_file.append(img_file)
                ptr = ptr + batch_num
        print('Extracting Time:\t {:.3f}'.format(time.time() - start))

        # switch to evaluation mode
        net.eval()
        print('Extracting Query Feature...')
        start = time.time()
        ptr = 0
        query_feat = np.zeros((nquery, args.low_dim))
        query_file = []
        with torch.no_grad():
            for batch_idx, (input, label, img_file) in enumerate(query_loader):
                batch_num = input.size(0)
                input = Variable(input.cuda())

                feat_pool, feat = net(input, input, test_mode[1])
                query_feat[ptr:ptr +
                           batch_num, :] = feat.detach().cpu().numpy()
                query_file.append(img_file)
                ptr = ptr + batch_num
        print('Extracting Time:\t {:.3f}'.format(time.time() - start))
        img = input[0]
        # attention_map_show(net, img, feat[0], net.classifier, 'query')
        start = time.time()
        # compute the similarity
        distmat = np.matmul(query_feat, np.transpose(gall_feat))

        # evaluation
        if dataset == 'regdb':
            cmc, mAP = eval_regdb(-distmat, query_label, gall_label)
        elif dataset == 'sysu':
            # cmc, mAP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam)
            cmc, mAP = eval_sysu_test(-distmat, query_label, gall_label,
                                      query_cam, gall_cam, np.array(query_img),
                                      np.array(gall_img))
        print('Evaluation Time:\t {:.3f}'.format(time.time() - start))

        # uni_que = np.unique(query_label)
        # out_q=[]
        # out_g=[]
        # label_q=[]
        # label_g=[]
        # num_label=20
        # for i in range(num_label):
        #     label_id = uni_que[i]
        #     index_q = [i for i, a in enumerate(query_label) if a == label_id]
        #     index_g = [i for i, a in enumerate(gall_label) if a == label_id]
        #     out_q.extend(query_feat[index_q[:10], :])
        #     out_g.extend(gall_feat[index_g[:10], :])
        #     # tmp_q = query_label[index_q]
        #     # tmp_g = gall_label[index_g]
        #     label_q.extend(query_label[index_q[:10]])
        #     label_g.extend(gall_label[index_g[:10]])
        #     # n_q += len(tmp_q)
        #     # n_g += len(tmp_g)
        #     # labels.extend(np.concatenate((tmp_q, tmp_g), 0))
        # output=np.concatenate((out_q, out_g), 0)
        # labels=np.concatenate((label_q, label_g), 0)
        # n_q = len(out_q)
        # n_g = len(out_g)
        # figure = draw_cluster(output, labels, n_q,n_g,num_label)
        # #
        # suffix_id = 'DEBUG_PPAM+glo_id +expcen_'+str(num_label)
        # plt.savefig(
        #     '/home/zhangc/projects/cross_modality_V/BDTR_modal_loss_cbam_debug/image/cluster/' + '_train_' + suffix_id + '.eps')
        # plt.show()
        # plt.pause(1)
        # plt.close()
        return cmc, mAP
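
test() scores every query against the whole gallery with one matrix product; when the network L2-normalizes its embeddings, that inner product is cosine similarity, and eval_sysu/eval_regdb consume the negated matrix as a distance. The core ranking step in isolation (toy features and labels, hypothetical sizes):

import numpy as np

rng = np.random.default_rng(0)
n_q, n_g, dim = 4, 10, 512
q_feat = rng.standard_normal((n_q, dim))
g_feat = rng.standard_normal((n_g, dim))
q_label = rng.integers(0, 3, n_q)
g_label = rng.integers(0, 3, n_g)

# L2-normalize so the inner product equals cosine similarity.
q_feat /= np.linalg.norm(q_feat, axis=1, keepdims=True)
g_feat /= np.linalg.norm(g_feat, axis=1, keepdims=True)

distmat = -np.matmul(q_feat, g_feat.T)  # smaller = more similar
order = np.argsort(distmat, axis=1)     # gallery indices, best match first
rank1 = (g_label[order[:, 0]] == q_label).mean()
print('Rank-1: {:.2%}'.format(rank1))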

    #######################################################################################################
    # training

    def train_model(net, optimizer, final_epoch, flag='normal'):
        # training
        print('==> Start Training...')
        best_acc = 0
        for epoch in range(start_epoch, final_epoch + 1):

            # train(net, epoch, optimizer, flag,ini_id,ini_modal)

            if epoch % 2 == 0:  # epoch > 0 and
                print('Test Epoch: {}'.format(epoch))
                print('Test Epoch: {}'.format(epoch), file=test_log_file)
                # testing
                cmc, mAP = test(net, epoch)
                print(
                    'FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'
                    .format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
                print(
                    'FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'
                    .format(cmc[0], cmc[4], cmc[9], cmc[19], mAP),
                    file=test_log_file)
                print(cmc[:50], file=cmc_log_file)
                np.savetxt("./cmc_all_0_mar.txt", cmc, fmt='%f', delimiter=',')
                test_log_file.flush()
                cmc_log_file.flush()
                # save model
                if cmc[0] > best_acc:  # not the real best for sysu-mm01
                    best_acc = cmc[0]
                    state = {
                        'net': net.state_dict(),
                        'cmc': cmc,
                        'mAP': mAP,
                        'epoch': epoch,
                    }
                    torch.save(state, checkpoint_path + suffix + '_best.t')

                # save model every args.save_epoch epochs (20 by default)
                if epoch > 10 and epoch % args.save_epoch == 0:
                    state = {
                        'net': net.state_dict(),
                        'cmc': cmc,
                        'mAP': mAP,
                        'epoch': epoch,
                    }
                    torch.save(
                        state,
                        checkpoint_path + suffix + '_epoch_{}.t'.format(epoch))

            # torch.save(state, checkpoint_path + suffix + '_latest.t')
        return net

    #########################################################################

    ini_id = 10
    ini_modal = 10
    # def train_model(net, optimizer, final_epoch, flag='normal'):
    #     print('==> Start Training...')
    #     best_acc = 0
    #     for epoch in range(start_epoch, final_epoch - start_epoch + 1):
    #         test(net, epoch)
    #     return net
    #######################################################################################################

    if args.arch == 'pcb_rpp' or args.arch == 'pcb_pyramid':
        # print('-------epoch for pcb training--------')
        # optimizer_pcb = pcb_train(net)
        # model=train_model(net, optimizer_pcb, 60, flag='pcb')

        if args.with_rpp:
            print('-------epoch for rpp training--------')
            model = net.convert_to_rpp().cuda()
            optimizer_rpp = rpp_train(model)
            model = train_model(model, optimizer_rpp, 40, flag='rpp')

            print('-------epoch for full training--------')
            optimizer_full = full_train(model)
            train_model(model, optimizer_full, 80, flag='full')
        elif args.method == 'id_inter_loss':

            if args.use_cbam:
                print('-------epoch for pcb_cbam training--------')
                optimizer_pcb_cbam = pcb_cbam_train(net)
                train_model(net, optimizer_pcb_cbam, 90, flag='pcb_cbam')
            else:
                print('-------epoch for id_inter_loss_training--------')
                optimizer_inter = inter_loss_train(net)
                train_model(net, optimizer_inter, 100, flag='id_inter_loss')
            # optimizer_pcb = pcb_train(net)
            # train_model(net, optimizer_pcb, 60, flag='pcb')

        else:

            print('-------epoch for pcb training--------')
            optimizer_pcb = pcb_train(net)
            train_model(net, optimizer_pcb, 90, flag='pcb')

    else:
        if args.use_cbam:
            print('-------epoch for resnet_cbam training--------')
            optimizer_id_cbam = id_cbam_train(net)
            train_model(net, optimizer_id_cbam, 90, flag='id_cbam')
        else:
            train_model(net, optimizer, 500, flag='normal')
Example 5
0
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch Cross-Modality Training')
    parser.add_argument('--dataset',
                        default='sysu',
                        help='dataset name: regdb or sysu')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
    parser.add_argument(
        '--arch',
        default='resnet50',
        type=str,
        help=
        'network baseline:resnet18 or resnet50 or se_resnet50 or pcb_rpp or cbam or pcb_pyramid'
    )
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        type=str,
                        help='resume from checkpoint')
    parser.add_argument('--test-only', action='store_true', help='test only')
    parser.add_argument('--model_path',
                        default='save_model/',
                        type=str,
                        help='model save path')
    parser.add_argument('--save_epoch',
                        default=20,
                        type=int,
                        metavar='s',
                        help='save model every save_epoch epochs (default: 20)')
    parser.add_argument('--log_path',
                        default='log/',
                        type=str,
                        help='log save path')
    parser.add_argument('--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--low-dim',
                        default=512,
                        type=int,
                        metavar='D',
                        help='feature dimension')
    parser.add_argument('--img_w',
                        default=144,
                        type=int,
                        metavar='imgw',
                        help='img width')
    parser.add_argument('--img_h',
                        default=288,
                        type=int,
                        metavar='imgh',
                        help='img height')
    parser.add_argument('--batch-size',
                        default=32,
                        type=int,
                        metavar='B',
                        help='training batch size')
    parser.add_argument('--test-batch',
                        default=64,
                        type=int,
                        metavar='tb',
                        help='testing batch size')
    parser.add_argument(
        '--method',
        default='id',
        type=str,
        metavar='m',
        help=
        'method type:id or triplet or sphere or id_triplet or id_triplet_center or id_inter_loss'
    )
    parser.add_argument('--drop',
                        default=0.0,
                        type=float,
                        metavar='drop',
                        help='dropout ratio')
    parser.add_argument('--trial',
                        default=1,
                        type=int,
                        metavar='t',
                        help='trial (only for RegDB dataset)')
    parser.add_argument('--gpu',
                        default='0,2',
                        type=str,
                        help='gpu device ids for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--mode',
                        default='indoor',
                        type=str,
                        help='all or indoor')
    ##add by zc
    parser.add_argument('--with_se',
                        default=False,
                        type=bool,
                        help='whether to add an SEModule to the last conv layer')
    parser.add_argument('--with_rpp',
                        default=False,
                        type=bool,
                        help='whether to add RPP (with PCB) to the last conv layer')
    parser.add_argument('--use_cbam',
                        default=True,
                        type=bool,
                        help='whether to add CBAM to the last conv layer')
    parser.add_argument('--reduction',
                        default=16,
                        type=int,
                        help='SEModule reduction ratio')
    parser.add_argument('--with_labelsmooth',
                        default=False,
                        type=bool,
                        help='whether to add label smoothing to the loss function')
    parser.add_argument('--with_model_neck',
                        default='no',
                        type=str,
                        help='whether to add a BNNeck to the model')
    # Balanced weight of center loss
    Center_weight_loss = 0.01
    ##end with zc
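
A caveat about the boolean flags above: argparse's type=bool applies Python's bool() to the raw string, so any non-empty value, including the literal string 'False', parses as True. A quick demonstration, with the usual action='store_true' alternative:

import argparse

p = argparse.ArgumentParser()
p.add_argument('--with_se', default=False, type=bool)
print(p.parse_args(['--with_se', 'False']).with_se)  # True: bool('False')

# The usual pattern for an on/off switch:
p2 = argparse.ArgumentParser()
p2.add_argument('--with_se', action='store_true')
print(p2.parse_args([]).with_se)             # False
print(p2.parse_args(['--with_se']).with_se)  # True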
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # hard-coded; overrides args.gpu
    np.random.seed(0)

    dataset = args.dataset
    if dataset == 'sysu':
        data_path = "/home/zhangc/projects/Datasets/SYSU-MM01/"
        log_path = args.log_path + 'sysu_log/'
        test_mode = [1, 2]  # thermal to visible
    elif dataset == 'regdb':
        data_path = "/home/zhangc/projects/Datasets/RegDB/"
        log_path = args.log_path + 'regdb_log/'
        test_mode = [2, 1]  # visible to thermal

    suffix_id = 'Debug_1m1_indoor_Avg_Atte'
    lamda_ap = 0
    lamda_cen = 1

    checkpoint_path = args.model_path

    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path)

    if args.method == 'id':
        suffix = dataset + '_id_epoch10_' + suffix_id + '_debug_pyramid_attention'  #wr0_serial_train_newepoch
    elif args.method == 'sphere':
        suffix = dataset + '_sphere'
    elif args.method == 'triplet':
        suffix = dataset + '_triplet_debug_sampler2_epoch20'
    elif args.method == 'id_triplet':
        suffix = dataset + '_id_triplet_wr'
    elif args.method == 'id_triplet_center':
        suffix = dataset + '_id_triplet_center'
    elif args.method == 'id_inter_loss':
        suffix = dataset + '_inter_loss_id_apdist_1_1_gloid_pyramid_attention_sampler2'  #debug_cbamp_debug_cbamp_
    if args.with_labelsmooth:
        suffix = suffix + '_ls'

    # suffix = suffix + '_cenloss_{}'.format(Center_weight_loss)
    suffix = suffix + '_drop_{}'.format(args.drop)
    suffix = suffix + '_lr_{:1.1e}'.format(args.lr)
    suffix = suffix + '_dim_{}'.format(args.low_dim)
    if not args.optim == 'sgd':
        suffix = suffix + '_' + args.optim
    suffix = suffix + '_' + args.arch
    if args.with_rpp:
        suffix = suffix + '_' + 'use_rpp'

    if args.with_se:
        suffix = suffix + '_' + 'use_senet'

    if args.use_cbam:
        suffix = suffix + '_' + 'use_cbam'

    if dataset == 'regdb':
        suffix = suffix + '_trial_{}'.format(args.trial)

    test_log_file = open(log_path + suffix + '.txt', "w")
    sys.stdout = Logger(log_path + suffix + '_os.txt')
    summary_writer = SummaryWriter(
        os.path.join(
            log_path,
            'tensorboard_log_' + suffix_id + '_debug_pyramid_attention'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0
    feature_dim = args.low_dim

    if args.arch == 'pcb_pyramid':
        num_part = 10
        suffix = suffix + '_' + 'pcb_pyramid'
    # elif args.arch=='pcb_cro':
    #     num_part=11
    else:
        num_part = 6  #pcb_rpp method
    print('==> Loading data..')
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Pad(10),
        transforms.RandomCrop((args.img_h, args.img_w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((args.img_h, args.img_w)),
        transforms.ToTensor(),
        normalize,
    ])
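
Both pipelines start with ToPILImage, so they expect HWC uint8 arrays, which is what the dataset classes yield. A quick shape check of the training pipeline on a dummy image (sizes taken from the img_h/img_w defaults above):

import numpy as np
import torchvision.transforms as transforms

toy_train_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(10),
    transforms.RandomCrop((288, 144)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dummy = np.random.randint(0, 256, (288, 144, 3), dtype=np.uint8)
out = toy_train_tf(dummy)
print(out.shape, out.dtype)  # torch.Size([3, 288, 144]) torch.float32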

    end = time.time()
    if dataset == 'sysu':
        # training set
        # trainset = SYSUData_tri(data_path, transform=transform_train)
        trainset = SYSUData(data_path, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label, query_cam = process_query_sysu(data_path,
                                                               mode=args.mode)
        gall_img, gall_label, gall_cam = process_gallery_sysu(data_path,
                                                              mode=args.mode,
                                                              trial=0)

    elif dataset == 'regdb':
        # training set
        trainset = RegDBData(data_path, args.trial, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label,
                                        trainset.train_thermal_label)

        # testing set
        query_img, query_label = process_test_regdb(data_path,
                                                    trial=args.trial,
                                                    modal='visible')
        gall_img, gall_label = process_test_regdb(data_path,
                                                  trial=args.trial,
                                                  modal='thermal')

    gallset = TestData(gall_img,
                       gall_label,
                       transform=transform_test,
                       img_size=(args.img_w, args.img_h))
    queryset = TestData(query_img,
                        query_label,
                        transform=transform_test,
                        img_size=(args.img_w, args.img_h))

    # testing data loader
    gall_loader = data.DataLoader(gallset,
                                  batch_size=args.test_batch,
                                  shuffle=False,
                                  num_workers=args.workers)
    query_loader = data.DataLoader(queryset,
                                   batch_size=args.test_batch,
                                   shuffle=False,
                                   num_workers=args.workers)

    n_class = len(np.unique(trainset.train_color_label))
    nquery = len(query_label)
    ngall = len(gall_label)

    print('Dataset {} statistics:'.format(dataset))
    print('  ------------------------------')
    print('  subset   | # ids | # images')
    print('  ------------------------------')
    print('  visible  | {:5d} | {:8d}'.format(n_class,
                                              len(trainset.train_color_label)))
    print('  thermal  | {:5d} | {:8d}'.format(
        n_class, len(trainset.train_thermal_label)))
    print('  ------------------------------')
    print('  query    | {:5d} | {:8d}'.format(len(np.unique(query_label)),
                                              nquery))
    print('  gallery  | {:5d} | {:8d}'.format(len(np.unique(gall_label)),
                                              ngall))
    print('  ------------------------------')
    print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

    print('==> Building model..')
    net = embed_net(args.low_dim,
                    n_class,
                    drop=args.drop,
                    arch=args.arch,
                    neck=args.with_model_neck,
                    with_se=args.with_se,
                    with_rpp=args.with_rpp,
                    use_cbam=args.use_cbam,
                    reduction=args.reduction)
    net.to(device)
    cudnn.benchmark = True

    if len(args.resume) > 0:
        model_path = checkpoint_path + args.resume
        if os.path.isfile(model_path):
            print('==> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(model_path)
            # start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['net'])
            print('==> loaded checkpoint {} (epoch {})'.format(
                args.resume, checkpoint['epoch']))
            start_epoch = int(checkpoint['epoch'])
        else:
            print('==> no checkpoint found at {}'.format(args.resume))
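
The checkpoint is a plain dict, so resuming is just torch.save/torch.load plus load_state_dict. One portable detail: a checkpoint written on GPU loads on a CPU-only machine only if map_location is given. A minimal round trip (toy module and file name, hypothetical):

import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)
state = {'net': toy_model.state_dict(), 'epoch': 42}
torch.save(state, 'toy_best.t')

ckpt = torch.load('toy_best.t', map_location='cpu')  # safe on CPU-only hosts
toy_model.load_state_dict(ckpt['net'])
print('resumed at epoch', int(ckpt['epoch']))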

    if args.method == 'id':
        if args.with_labelsmooth:
            criterion = CrossEntropyLabelSmooth(n_class)
            print("---------using ID loss with label smoothing-------")
        else:
            criterion = nn.CrossEntropyLoss()
            # criterion = FocalLoss(gamma=2)
            print("---------using ID loss only-------")
            triplet_criterion = TripletLoss(margin=0.7)
            triplet_criterion.to(device)
            center_criterion = CenterLoss(num_classes=n_class,
                                          feat_dim=args.low_dim,
                                          size_average=True).to(device)
            center_criterion_cro = CenterLoss_cro(num_classes=n_class,
                                                  feat_dim=args.low_dim,
                                                  size_average=True).to(device)
        criterion.to(device)

    elif args.method == 'id_inter_loss':
        #oim_loss = OIMLoss(args.low_dim, n_class, 1, 0.5).cuda()
        center_criterion = CenterLoss(num_classes=n_class,
                                      feat_dim=args.low_dim,
                                      size_average=True).to(device)
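
CenterLoss (and the cross-modal CenterLoss_cro variant) lives in the repo's loss module and is not shown here. The standard formulation it follows keeps one learnable center per identity and penalizes each feature's squared distance to its own center; a minimal sketch of that idea, not the repo's exact code:

import torch
import torch.nn as nn

class SimpleCenterLoss(nn.Module):
    """L = mean_i ||x_i - c_{y_i}||^2 with one learnable center per class."""
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, labels, feats):
        centers = self.centers[labels]  # (B, D): each sample's own center
        return ((feats - centers) ** 2).sum(dim=1).mean()

# The centers are parameters, which is why the optimizer factories below give
# center_criterion.parameters() its own boosted learning-rate group.
crit = SimpleCenterLoss(num_classes=395, feat_dim=512)
loss = crit(torch.randint(0, 395, (8,)), torch.randn(8, 512))
print(loss.item())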
###########################
    ignored_params = list(map(id, net.feature.parameters())) + list(
        map(id, net.classifier.parameters())
    )  #+ list(map(id, net.pcb_classifier.parameters()))\
    # + list(map(id, net.visible_net.visible.avgpool.parameters()))+ list(map(id, net.thermal_net.thermal.avgpool.parameters()))
    base_params = filter(lambda p: id(p) not in ignored_params,
                         net.parameters())
    if args.optim == 'sgd':

        optimizer = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)

    elif args.optim == 'adam':
        optimizer = optim.Adam([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.classifier.parameters(),
            'lr': args.lr
        }],
                               weight_decay=5e-4)

####################################################################################################

    def pcb_train(net):
        ignored_params = list(
            map(id, net.classifier.parameters())
        )  #+list(map(id, net.feature.parameters())) #+ list(map(id, net.attention.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.arch in ('pcb_rpp', 'pcb_pyramid'):
            optimizer_pcb = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_pcb

    def pcb_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.glo_classifier.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_pcb_cbam = optim.SGD(
            [{
                'params': base_params,
                'lr': args.lr
            }, {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.glo_classifier.parameters(),
                'lr': args.lr * 10
            }, {
                'params': net.classifier.parameters(),
                'lr': args.lr * 10
            }],
            weight_decay=5e-4,
            momentum=0.9,
            nesterov=True)

        return optimizer_pcb_cbam

    def id_cbam_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.attention.parameters())) + list(
                map(id, net.feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        # if args.arch == 'pcb_rpp'or 'pcb_pyramid':
        optimizer_id_cbam = optim.SGD([
            {
                'params': base_params,
                'lr': args.lr
            },
            {
                'params': net.attention.parameters(),
                'lr': args.lr * 10
            },
            {
                'params': net.feature.parameters(),
                'lr': args.lr
            },
            {
                'params': net.classifier.parameters(),
                'lr': args.lr
            },
            {
                'params': center_criterion.parameters(),
                'lr': args.lr * 10
            },
        ],
                                      weight_decay=5e-4,
                                      momentum=0.9,
                                      nesterov=True)

        return optimizer_id_cbam

    def inter_loss_train(net):
        ignored_params = list(map(id, net.classifier.parameters())) + list(
            map(id, net.glo_classifier.parameters())
        )  #+list(map(id, net.feature.parameters()))+list(map(id, net.glo_feature.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        if args.method == 'id_inter_loss':
            optimizer_inter = optim.SGD(
                [
                    {
                        'params': base_params,
                        'lr': args.lr
                    },
                    # {'params': net.feature.parameters(), 'lr': args.lr},
                    {
                        'params': net.glo_classifier.parameters(),
                        'lr': args.lr * 10
                    },
                    # {'params': center_criterion.parameters(), 'lr': args.lr*10},
                    {
                        'params': net.classifier.parameters(),
                        'lr': args.lr * 10
                    }
                ],
                weight_decay=5e-4,
                momentum=0.9,
                nesterov=True)

        return optimizer_inter

    def full_train(net):

        ignored_params = list(map(id, net.feature.parameters()))  + list(map(id, net.pcb_classifier.parameters()))\
            + list(map(id, net.avgpool.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             net.parameters())
        optimizer_full = optim.SGD([{
            'params': base_params,
            'lr': 0.1 * args.lr
        }, {
            'params': net.feature.parameters(),
            'lr': args.lr
        }, {
            'params': net.pcb_classifier.parameters(),
            'lr': args.lr
        }, {
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                   weight_decay=5e-4,
                                   momentum=0.9,
                                   nesterov=True)
        return optimizer_full

    def rpp_train(net):
        optimizer_rpp = optim.SGD([{
            'params': net.avgpool.parameters(),
            'lr': args.lr
        }],
                                  weight_decay=5e-4,
                                  momentum=0.9,
                                  nesterov=True)

        return optimizer_rpp

    '''
    def lr_scheduler(optimizer, epoch):  # new2
        warmup_epoch = 30
        warmup_lr = 1e-5
        lr_steps = [60, 90, 130]
        start_lr = 1e-2
        lr_factor = 0.1

        if epoch <= warmup_epoch:  # lr warmup
            warmup_scale = (start_lr / warmup_lr) ** (1.0 / warmup_epoch)
            lr = warmup_lr * (warmup_scale ** epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            optimizer.defaults['lr'] = lr
        else:  # lr jump
            for i, el in enumerate(lr_steps):
                if epoch == el:
                    lr = start_lr * (lr_factor ** (i + 1))
                    # logger.info('====> LR is set to: {}'.format(lr))
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    optimizer.defaults['lr'] = lr
        lrs = [round(el['lr'], 6) for el in optimizer.param_groups]
        return optimizer, lrs

    
    def warmup_fn(optimizer, epoch):#new1
        if epoch<=30:
            lr=0.1*args.lr*epoch
        elif epoch>30 and epoch <=60:
            lr=args.lr
        elif epoch>60 and epoch<=90:
            lr=args.lr*0.1
        elif epoch>90 and epoch <=120:
            lr=args.lr*0.01

        optimizer.param_groups[0]['lr'] =lr
        optimizer.param_groups[1]['lr'] =lr
        optimizer.param_groups[2]['lr'] =lr

        return lr

    '''
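
The disabled lr_scheduler above warms the LR up geometrically from warmup_lr to start_lr over warmup_epoch epochs, then multiplies it by lr_factor at each milestone. Its cumulative effect, assuming it is called once per epoch, collapses to this closed form (lr_at is a hypothetical helper):

def lr_at(epoch, warmup_epoch=30, warmup_lr=1e-5, start_lr=1e-2,
          lr_steps=(60, 90, 130), lr_factor=0.1):
    if epoch <= warmup_epoch:
        scale = (start_lr / warmup_lr) ** (1.0 / warmup_epoch)
        return warmup_lr * scale ** epoch
    lr = start_lr
    for step in lr_steps:
        if epoch >= step:
            lr *= lr_factor
    return lr

for e in (0, 15, 30, 59, 60, 90, 130):
    print(e, '{:.2e}'.format(lr_at(e)))
# epoch 30 reaches 1.00e-02, then 1e-3 / 1e-4 / 1e-5 at epochs 60 / 90 / 130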

    def adjust_learning_rate(optimizer, epoch, flag):
        """Step the LR down 10x after epochs 10 and 20 and again from epoch 30; per-group scaling depends on flag"""
        if epoch <= 10:
            lr = args.lr
        elif epoch <= 20:
            lr = args.lr * 0.1
        elif epoch < 30:
            lr = args.lr * 0.01
        else:
            lr = args.lr * 0.001

        if flag == 'pcb':  #args.arch=='pcb_rpp'
            optimizer.param_groups[0]['lr'] = lr
            # optimizer.param_groups[1]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr * 10

        elif flag == 'rpp':  #args.with_rpp
            optimizer.param_groups[0]['lr'] = lr
        elif flag == 'full':  #full_train
            optimizer.param_groups[0]['lr'] = 0.1 * lr
            optimizer.param_groups[1]['lr'] = lr
            optimizer.param_groups[2]['lr'] = lr
            optimizer.param_groups[3]['lr'] = lr
        elif flag == 'normal':
            optimizer.param_groups[0]['lr'] = 0.1 * lr
            optimizer.param_groups[1]['lr'] = lr
            optimizer.param_groups[2]['lr'] = lr
        elif flag == 'id_inter_loss':
            optimizer.param_groups[0]['lr'] = lr
            # optimizer.param_groups[1]['lr'] = lr
            # optimizer.param_groups[2]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr * 10
            optimizer.param_groups[2]['lr'] = lr * 10
        elif flag == 'pcb_cbam':
            optimizer.param_groups[0]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr * 10
            optimizer.param_groups[2]['lr'] = lr * 10
            optimizer.param_groups[3]['lr'] = lr * 10

        elif flag == 'id_cbam':
            optimizer.param_groups[0]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr * 10
            optimizer.param_groups[2]['lr'] = lr
            optimizer.param_groups[3]['lr'] = lr
            optimizer.param_groups[4]['lr'] = lr * 10

        return lr
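
adjust_learning_rate addresses optimizer.param_groups by position, so the indices used for each flag must line up with the group order built in the matching optimizer factory (for 'pcb_cbam': base, attention, glo_classifier, classifier). A compact check of that coupling with toy parameters:

import torch
import torch.optim as optim

toy_params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(4)]
toy_opt = optim.SGD([{'params': [p], 'lr': 0.01} for p in toy_params],
                    momentum=0.9)

def set_pcb_cbam_lrs(opt, lr):
    # index 0: base params; indices 1-3: boosted heads (10x), as above
    opt.param_groups[0]['lr'] = lr
    for i in (1, 2, 3):
        opt.param_groups[i]['lr'] = lr * 10

set_pcb_cbam_lrs(toy_opt, 0.001)
print([g['lr'] for g in toy_opt.param_groups])  # [0.001, 0.01, 0.01, 0.01]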

    def adjust_lr(optimizer, ep, flag='normal'):
        if ep < 40:
            lr = 1e-3 * (ep // 5 + 1)
        elif ep < 60:
            lr = 1e-2
        elif ep < 80:
            lr = 1e-3
        else:
            lr = 1e-4
        for p in optimizer.param_groups:
            p['lr'] = lr

        return lr
#################################################################################################

    def train(net, epoch, optimizer, flag, ini_id, ini_modal):
        current_lr = adjust_learning_rate(optimizer, epoch, flag)
        # current_lr = warmup_fn(optimizer, epoch)
        # _, current_lr = lr_scheduler(optimizer, epoch)
        train_loss = AverageMeter()
        data_time = AverageMeter()
        batch_time = AverageMeter()
        precisions = AverageMeter()
        correct = 0
        total = 0

        print('==> Preparing Data Loader...')
        # identity sampler
        # sampler = IdentitySampler(trainset.train_color_label, \
        #                           trainset.train_thermal_label, color_pos, thermal_pos, args.batch_size)
        sampler = RandomIdentitySampler_alignedreid(trainset.train_color_label, \
                                  trainset.train_thermal_label, color_pos, thermal_pos, args.batch_size,4)

        # sampler = triplet_Sampler_reid(trainset.train_color_label, \
        #                           trainset.train_thermal_label, color_pos, thermal_pos, args.batch_size,8)
        trainset.cIndex = sampler.index1  # color index
        trainset.tIndex = sampler.index2  # thermal index
        # trainset.anchor_pos_rgb = sampler.anchor_pos_rgb
        # trainset.anchor_pos_ir = sampler.anchor_pos_ir
        trainloader = data.DataLoader(trainset, batch_size=args.batch_size, \
                                      sampler=sampler, num_workers=args.workers, drop_last=True)

        # trainset=SYSU_triplet_dataset(data_folder=data_path)
        # trainloader=data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,drop_last = True)
        # switch to train mode
        net.train()
        model_static = net.visible_net.visible.state_dict()
        end = time.time()

        for batch_idx, (input1, input2, label1,
                        label2) in enumerate(trainloader):
            # for batch_idx, (anchor_r,anchor_t,input1, input2, label1, label2,an_label1,an_label2) in enumerate(trainloader):
            ####show attention map
            # img1=input1
            # img2=input2
            # show_img=img1[-1,:,:,:]
            # plt.imshow(show_img)
            # plt.figure()
            # target_index = None
            # mask_v = net.grad_cam_v(img1, target_index)
            # mask_t = net.grad_cam_t(img2, target_index)
            # show_cam_on_image(img1[-1,:,:,:], mask_v,epoch,mode="visible")
            # show_cam_on_image(img2[-1,:,:,:], mask_t, epoch,mode="thermal")

            input1 = Variable(input1.cuda())
            input2 = Variable(input2.cuda())

            label1 = Variable(label1.cuda())
            label2 = Variable(label2.cuda())

            labels = torch.cat((label1, label2), 0)
            labels = Variable(labels.cuda())
            data_time.update(time.time() - end)

            output, feat = net(input1, input2)

            global_step = epoch * len(trainloader) + batch_idx

            summary_writer.add_scalar('lr', current_lr, global_step)

            if args.method == 'id':

                if args.arch == 'pcb_rpp' or args.arch == 'pcb_pyramid':
                    outputs, v_global_feat, t_global_feat, feat_p = output
                    loss = criterion(outputs[0], labels)
                    prec, = accuracy(outputs[0], labels.data, topk=(20, ))
                    correct_lc = (outputs[0].max(1)[1]).eq(labels).sum().item()
                    for i in range(1, num_part):
                        loss += criterion(outputs[i], labels)
                        tmp_prec, = accuracy(outputs[i],
                                             labels.data,
                                             topk=(20, ))
                        prec += tmp_prec
                        # prec1 = torch.mean(prec[0])
                        correct_lc += (
                            outputs[i].max(1)[1]).eq(labels).sum().item()
                    prec1 = prec[0] * 1.0 / num_part
                    correct += correct_lc * 1.0 / num_part
                else:
                    # net.classifier.to(device)
                    # output=net.classifier(output,labels)

                    id_loss = criterion(output, labels)
                    glo_modal, _, _, _ = triplet_criterion(feat, feat, labels)
                    feat_glo_v = torch.split(feat, args.batch_size, 0)[0]
                    feat_glo_t = torch.split(feat, args.batch_size, 0)[1]
                    modal_glo_v, _, dist_ap_v, dist_an_v = triplet_criterion(
                        feat_glo_v, feat_glo_v, label1)
                    modal_glo_t, _, dist_ap_t, dist_an_t = triplet_criterion(
                        feat_glo_t, feat_glo_t, label2)

                    intra_modal = modal_glo_v + modal_glo_t
                    inter_modal_glo_v, _, inter_dist_ap_v, inter_dist_an_v = triplet_criterion(
                        feat_glo_v, feat_glo_t, label1)
                    inter_modal_glo_t, _, inter_dist_ap_t, inter_dist_an_t = triplet_criterion(
                        feat_glo_t, feat_glo_v, label1)
                    inter_modal = inter_modal_glo_v + inter_modal_glo_t
                    # ap_dist_loss=torch.mean(inter_dist_ap_v)+torch.mean(inter_dist_ap_t)+torch.mean(dist_ap_v)+torch.mean(dist_ap_t)

                    apn_intra = torch.max(dist_an_v, dist_an_t)
                    apn_inter = torch.min(inter_dist_ap_v, inter_dist_ap_t)
                    inter_intra_pn = triplet_criterion.dist_l2_apn(
                        feat_glo_v, feat_glo_t)

                    modal_loss = glo_modal + 0 * intra_modal + 0 * inter_modal + 0 * inter_intra_pn

                    center_loss = center_criterion(labels, feat)
                    # center_loss = center_criterion_cro(label1, label2,feat_glo_v,feat_glo_t)
                    delta_id = np.maximum(
                        1.0 * (ini_id - id_loss.cpu().detach().numpy()) /
                        (id_loss.cpu().detach().numpy() + 1e-12), 0)
                    delta_modal = np.maximum(
                        1.0 * (ini_modal - modal_loss.cpu().detach().numpy()) /
                        (modal_loss.cpu().detach().numpy() + 1e-12), 0)

                    if not delta_id and not delta_modal:
                        alpha = 1
                    else:
                        alpha = delta_id * 1.0 / (delta_id + delta_modal)
                    # alpha is computed but unused below; the final loss keeps
                    # the fixed lamda_* weights

                    loss = 1 * id_loss + lamda_ap * modal_loss + lamda_cen * center_loss
                    summary_writer.add_scalar('modal_loss', modal_loss.item(),
                                              global_step)
                    summary_writer.add_scalar('id_loss', id_loss.item(),
                                              global_step)
                    summary_writer.add_scalar('center_loss',
                                              center_loss.item(), global_step)
                    summary_writer.add_scalar('total_loss', loss.item(),
                                              global_step)
                    prec, = accuracy(output.data, labels.data, topk=(20, ))
                    prec1 = prec[0]
                    correct += (output.max(1)[1]).eq(labels).sum().item()
                    ini_id = id_loss.cpu().detach().numpy()
                    ini_modal = modal_loss.cpu().detach().numpy()
                # _, predicted = outputs.max(1)
                # correct += predicted.eq(labels).sum().item()
                # revise by zc

            elif args.method == 'id_inter_loss':
                #oim_loss = OIMLoss(n_class, n_class, 30, 0.5).cuda()
                # loss = id_inter_loss(outputs,v_global_feat,t_global_feat,labels)
                outputs, v_global_feat, t_global_feat, feat_p = output
                loss, prec = oim_inter_center_loss(center_criterion, outputs,
                                                   feat, feat_p, v_global_feat,
                                                   t_global_feat, labels,
                                                   num_part, summary_writer,
                                                   global_step)

                # prec1=prec
                # correct_lc=prec*6
                prec, = accuracy(outputs[0], labels.data, topk=(20, ))
                correct_lc = (outputs[0].max(1)[1]).eq(labels).sum().item()
                for i in range(1, num_part):
                    tmp_prec, = accuracy(outputs[i], labels.data, topk=(20, ))
                    prec += tmp_prec
                    # prec1 = torch.mean(prec[0])
                    correct_lc += (
                        outputs[i].max(1)[1]).eq(labels).sum().item()
                prec1 = prec[0] * 1.0 / num_part
                correct += correct_lc * 1.0 / num_part

            elif args.method == 'sphere':

                loss = sphere_criterion(feat, labels)
                prec, = accuracy(output.data, labels.data, topk=(20, ))
                prec1 = prec[0]

            elif args.method == 'triplet':
                if args.arch == "pcb_rpp" or args.arch == 'pcb_pyramid':
                    loss, prec = triplet_criterion(feat_p[0], labels)
                    correct_lc = (feat_p[0].max(1)[1]).eq(labels).sum().item()
                    for i in range(1, num_part):
                        loss_tmp, prec_tmp = triplet_criterion(
                            feat_p[i], labels)
                        loss += loss_tmp
                        prec += prec_tmp
                        #
                        correct_lc += (
                            outputs[i].max(1)[1]).eq(labels).sum().item()
                    prec1 = prec * 1.0 / num_part
                    correct += correct_lc * 1.0 / num_part
                else:

                    loss, prec1 = triplet_criterion(feat, labels)
                    correct += prec1

            elif args.method == 'id_triplet':
                loss = 0.3 * id_criterion(
                    outputs, labels) + 0.7 * triplet_criterion(feat, labels)[0]
                prec, = accuracy(outputs.data, labels.data, topk=(20, ))
                prec1 = prec[0]

            elif args.method == 'id_triplet_center':
                loss = id_criterion(outputs, labels) + triplet_criterion(
                    feat, labels)[0] + Center_weight_loss * center_criterion(
                        feat, labels)
                prec, = accuracy(outputs.data, labels.data, topk=(20, ))
                prec1 = prec[0]

            total += labels.size(0)
            # acc_avg=(outputs.max(1)[1]==labels).float().mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), labels.size(0))  # loss.item()
            precisions.update(prec1, labels.size(0))
            # total += labels.size(0)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if batch_idx % 10 == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time: {:.3f} ({:.3f})\t'
                      'Data: {:.3f} ({:.3f})\t'
                      'lr:{}\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'rank1 {:.2%} ({:.2%})\t'
                      'acc {:.2%}\t '.format(epoch, batch_idx,
                                             len(trainloader), batch_time.val,
                                             batch_time.avg, data_time.val,
                                             data_time.avg, current_lr,
                                             train_loss.val, train_loss.avg,
                                             precisions.val, precisions.avg,
                                             1. * correct / total))

            # if epoch % 2 == 0 and batch_idx==len(trainloader):
            #     figure=draw_cluster(net.l2norm(output),labels,32,32)
            #
            #     plt.savefig('/home/zhangc/projects/cross_modality_V/BDTR_modal_loss_cbam_debug/image/cluster/'+str(epoch)+'_train_'+suffix_id+'.jpg')
            #     plt.show()
            #     plt.pause(1)
            #     plt.close()
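
Inside the 'id' branch above, delta_id and delta_modal measure how much each loss dropped relative to the previous iteration, and alpha turns that into a normalized weight for the ID term (computed but left unused, since the final loss keeps the fixed lamda_* weights). The arithmetic in isolation (adaptive_alpha is a hypothetical helper):

def adaptive_alpha(ini_id, id_loss, ini_modal, modal_loss):
    """Weight for the ID term from the relative drop of each loss."""
    delta_id = max((ini_id - id_loss) / (id_loss + 1e-12), 0.0)
    delta_modal = max((ini_modal - modal_loss) / (modal_loss + 1e-12), 0.0)
    if not delta_id and not delta_modal:
        return 1.0  # neither loss improved: fall back to full weight
    return delta_id / (delta_id + delta_modal)

# ID loss dropped 5.0 -> 4.0 (25% relative), modal loss 2.0 -> 1.9 (~5%):
print(round(adaptive_alpha(5.0, 4.0, 2.0, 1.9), 3))  # 0.826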
##############################################################################################
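
RandomIdentitySampler_alignedreid, used in train() above, draws the same identities for the RGB and IR streams in every batch (P identities times K instances per modality) so that label1 and label2 line up. A stripped-down sketch of that cross-modal identity sampling (sample_cross_modal_batch and the toy index lists are hypothetical):

import numpy as np

def sample_cross_modal_batch(color_pos, thermal_pos, num_ids, k, rng):
    """Pick num_ids identities, then k RGB and k IR samples of each."""
    ids = rng.choice(len(color_pos), num_ids, replace=False)
    idx_rgb, idx_ir = [], []
    for pid in ids:
        idx_rgb.extend(rng.choice(color_pos[pid], k, replace=True))
        idx_ir.extend(rng.choice(thermal_pos[pid], k, replace=True))
    return np.array(idx_rgb), np.array(idx_ir)  # aligned label order

rng = np.random.default_rng(0)
toy_color_pos = [list(range(i * 10, i * 10 + 10)) for i in range(6)]
toy_thermal_pos = [list(range(i * 8, i * 8 + 8)) for i in range(6)]
rgb_idx, ir_idx = sample_cross_modal_batch(toy_color_pos, toy_thermal_pos,
                                           num_ids=4, k=4, rng=rng)
print(rgb_idx.shape, ir_idx.shape)  # (16,) (16,)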

    def test(net, epoch):
        # switch to evaluation mode
        net.eval()
        print('Extracting Gallery Feature...')
        start = time.time()
        ptr = 0
        gall_feat = np.zeros((ngall, args.low_dim))  #args.low_dim
        label_g = []
        with torch.no_grad():
            for batch_idx, (input, label) in enumerate(gall_loader):
                batch_num = input.size(0)
                input = Variable(input.cuda())
                feat_pool, feat = net(input, input, test_mode[0])
                label_g.append(label)
                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                ptr = ptr + batch_num
        print('Extracting Time:\t {:.3f}'.format(time.time() - start))

        # switch to evaluation mode
        net.eval()
        print('Extracting Query Feature...')
        start = time.time()
        ptr = 0
        query_feat = np.zeros((nquery, args.low_dim))
        label_q = []
        with torch.no_grad():
            for batch_idx, (input, label) in enumerate(query_loader):
                batch_num = input.size(0)
                input = Variable(input.cuda())
                feat_pool, feat = net(input, input, test_mode[1])

                query_feat[ptr:ptr +
                           batch_num, :] = feat.detach().cpu().numpy()
                label_q.append(label)
                ptr = ptr + batch_num
        print('Extracting Time:\t {:.3f}'.format(time.time() - start))

        start = time.time()
        # compute the similarity
        distmat = np.matmul(query_feat, np.transpose(gall_feat))

        # output=np.concatenate((query_feat,gall_feat),0)
        # labels=label_q+label_g
        # figure = draw_cluster(output, labels, len(label_q),len(label_g))
        # #
        # plt.savefig('/home/zhangc/projects/cross_modality_V/BDTR_modal_loss_cbam_debug/image/cluster/' + str(
        #     epoch) + '_test_'+suffix_id+'.jpg')
        # plt.show()
        # plt.pause(1)
        # plt.close()

        # evaluation
        if dataset == 'regdb':
            cmc, mAP = eval_regdb(-distmat, query_label, gall_label)
        elif dataset == 'sysu':
            cmc, mAP = eval_sysu(-distmat, query_label, gall_label, query_cam,
                                 gall_cam)
        print('Evaluation Time:\t {:.3f}'.format(time.time() - start))

        # uni_que = np.unique(query_label)
        # n_q, n_g = 0, 0
        # output = []
        # labels = []
        # for i in range(30):
        #     label_id = uni_que[randrange(0, len(uni_que))]
        #     index_q = [i for i, a in enumerate(query_label) if a == label_id]
        #     index_g = [i for i, a in enumerate(gall_label) if a == label_id]
        #     output.extend(np.concatenate((query_feat[index_q, :], gall_feat[index_g, :]), 0))
        #     tmp_q = query_label[index_q]
        #     tmp_g = gall_label[index_g]
        #     n_q += len(tmp_q)
        #     n_g += len(tmp_g)
        #     labels.extend(np.concatenate((tmp_q, tmp_g), 0))
        #
        # figure = draw_cluster(output, labels, n_q, n_g)
        # #
        # suffix_id = 'DEBUG_30'
        # plt.savefig(
        #     '/home/zhangc/projects/cross_modality_V/BDTR_modal_loss_cbam_debug/image/cluster/'  + '_train_' + suffix_id + '.jpg')
        # plt.show()
        # plt.pause(1)
        # plt.close()
        return cmc, mAP
#######################################################################################################
# training

    def train_model(net, optimizer, final_epoch, flag='normal'):
        # training
        print('==> Start Training...')
        best_acc = 0
        for epoch in range(start_epoch, final_epoch + 1):

            train(net, epoch, optimizer, flag, ini_id, ini_modal)

            if epoch % 2 == 0:  #epoch > 0 and
                print('Test Epoch: {}'.format(epoch))
                print('Test Epoch: {}'.format(epoch), file=test_log_file)
                # testing
                cmc, mAP = test(net, epoch)
                print(
                    'FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'
                    .format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))
                print(
                    'FC:   Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}'
                    .format(cmc[0], cmc[4], cmc[9], cmc[19], mAP),
                    file=test_log_file)
                test_log_file.flush()

                # save model
                if cmc[0] > best_acc:  # not the real best for sysu-mm01
                    best_acc = cmc[0]
                    state = {
                        'net': net.state_dict(),
                        'cmc': cmc,
                        'mAP': mAP,
                        'epoch': epoch,
                    }
                    torch.save(state, checkpoint_path + suffix + '_best.t')

                # save model every args.save_epoch epochs (20 by default)
                if epoch > 10 and epoch % args.save_epoch == 0:
                    state = {
                        'net': net.state_dict(),
                        'cmc': cmc,
                        'mAP': mAP,
                        'epoch': epoch,
                    }
                    torch.save(
                        state,
                        checkpoint_path + suffix + '_epoch_{}.t'.format(epoch))

            # torch.save(state, checkpoint_path + suffix + '_latest.t')
        return net


#########################################################################

    ini_id = 10
    ini_modal = 10
    #######################################################################################################

    if args.arch == 'pcb_rpp' or args.arch == 'pcb_pyramid':
        # print('-------epoch for pcb training--------')
        # optimizer_pcb = pcb_train(net)
        # model=train_model(net, optimizer_pcb, 60, flag='pcb')

        if args.with_rpp:
            print('-------epoch for rpp training--------')
            model = net.convert_to_rpp().cuda()
            optimizer_rpp = rpp_train(model)
            model = train_model(model, optimizer_rpp, 40, flag='rpp')

            print('-------epoch for full training--------')
            optimizer_full = full_train(model)
            train_model(model, optimizer_full, 80, flag='full')
        elif args.method == 'id_inter_loss':

            if args.use_cbam:
                print('-------epoch for pcb_cbam training--------')
                optimizer_pcb_cbam = pcb_cbam_train(net)
                train_model(net, optimizer_pcb_cbam, 90, flag='pcb_cbam')
            else:
                print('-------epoch for id_inter_loss_training--------')
                optimizer_inter = inter_loss_train(net)
                train_model(net, optimizer_inter, 100, flag='id_inter_loss')
            # optimizer_pcb = pcb_train(net)
            # train_model(net, optimizer_pcb, 60, flag='pcb')

        else:

            print('-------epoch for pcb training--------')
            optimizer_pcb = pcb_train(net)
            train_model(net, optimizer_pcb, 90, flag='pcb')

    else:
        if args.use_cbam:
            print('-------epoch for resnet_cbam training--------')
            optimizer_id_cbam = id_cbam_train(net)
            train_model(net, optimizer_id_cbam, 90, flag='id_cbam')
        else:
            train_model(net, optimizer, 500, flag='normal')