def checkpoint_loader(model, path, eval_only=False):
    checkpoint = load_checkpoint(path)
    pretrained_dict = checkpoint['state_dict']
    if isinstance(model, nn.DataParallel):
        parallel = True
        model = model.module.cpu()
    else:
        parallel = False

    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    if eval_only:
        # drop classifier weights so only the backbone is loaded
        keys_to_del = [key for key in pretrained_dict if 'classifier' in key]
        for key in keys_to_del:
            del pretrained_dict[key]
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)

    start_epoch = checkpoint['epoch']
    best_top1 = checkpoint['best_top1']

    if parallel:
        model = nn.DataParallel(model).cuda()

    return model, start_epoch, best_top1
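A minimal, self-contained sketch of the filter-and-update loading pattern that checkpoint_loader uses above (the model and key names here are illustrative, not from the source):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
# a checkpoint with one matching key and one stale key
pretrained_dict = {'0.weight': torch.zeros(8, 4), 'old.bias': torch.ones(2)}

model_dict = model.state_dict()
# keep only keys that exist in the target model (shapes assumed compatible)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)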
Example No. 2
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    valset = Feeder(args.val_feat_path,
                    args.val_knn_graph_path,
                    args.val_label_path,
                    args.seed,
                    args.k_at_hop,
                    args.active_connection,
                    train=False)
    valloader = DataLoader(valset,
                           batch_size=args.batch_size,
                           num_workers=args.workers,
                           shuffle=False,
                           pin_memory=True)

    ckpt = load_checkpoint(args.checkpoint)
    net = model.gcn()
    net.load_state_dict(ckpt['state_dict'])
    net = net.cuda()

    knn_graph = valset.knn_graph
    knn_graph_dict = list()
    for neighbors in knn_graph:
        knn_graph_dict.append(dict())
        for n in neighbors[1:]:
            knn_graph_dict[-1][n] = []

    criterion = nn.CrossEntropyLoss().cuda()
    edges, scores = validate(valloader, net, criterion)

    np.save('edges', edges)
    np.save('scores', scores)
    #edges=np.load('edges.npy')
    #scores = np.load('scores.npy')

    clusters = graph_propagation(edges,
                                 scores,
                                 max_sz=900,
                                 step=0.6,
                                 pool='avg')
    final_pred = clusters2labels(clusters, len(valset))
    labels = valset.labels

    print('------------------------------------')
    print('Number of nodes: ', len(labels))
    print('Precision   Recall   F-Score   NMI')
    p, r, f = bcubed(final_pred, labels)
    nmi = normalized_mutual_info_score(final_pred, labels)
    print(('{:.4f}    ' * 4).format(p, r, f, nmi))

    labels, final_pred = single_remove(labels, final_pred)
    print('------------------------------------')
    print('After removing singleton clusters, number of nodes: ', len(labels))
    print('Precision   Recall   F-Score   NMI')
    p, r, f = bcubed(final_pred, labels)
    nmi = normalized_mutual_info_score(final_pred, labels)
    print(('{:.4f}    ' * 4).format(p, r, f, nmi))
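A plausible sketch of the clusters2labels helper called above, inferred from its call site (clusters assumed to be a list of node-index collections; this is not the project's actual implementation):

import numpy as np

def clusters2labels_sketch(clusters, num_nodes):
    labels = np.full(num_nodes, -1, dtype=np.int64)  # -1 marks unassigned nodes
    for cluster_id, nodes in enumerate(clusters):
        for node in nodes:
            labels[node] = cluster_id
    return labels

print(clusters2labels_sketch([[0, 2], [1]], num_nodes=4))  # [ 0  1  0 -1]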
Example No. 3
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
                                  (240, 240)
    dataset, num_classes, train_loader, val_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers)

    # Create model

    img_branch = models.create(args.arch,
                               cut_layer=args.cut_layer,
                               num_classes=num_classes)

    args.resume = "/mnt/lustre/renjiawei/DAIN_py/logs/Resnet50-single_view-split1/model_best.pth.tar"

    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        # img_high_level.load_state_dict(checkpoint['state_dict_img'])
        # diff_high_level.load_state_dict(checkpoint['state_dict_diff'])
        img_branch.load_state_dict(checkpoint['state_dict_img'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(
            start_epoch, best_top1))

    img_branch = nn.DataParallel(img_branch).cuda()
    # img_branch = nn.DataParallel(img_branch)
    img_branch.train(False)

    x = torch.randn(64, 1, 224, 224, requires_grad=True)

    torch_out = torch.onnx._export(
        img_branch,  # model being run
        x,  # model input (or a tuple for multiple inputs)
        "super_resolution.onnx",
        # where to save the model (can be a file or file-like object)
        export_params=True
    )  # store the trained parameter weights inside the model file
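torch.onnx._export used above is a private helper; the supported entry point is torch.onnx.export. A minimal sketch with an illustrative model (input shape and file name are placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU())
model.eval()
dummy_input = torch.randn(1, 1, 224, 224)

# export_params=True stores the trained weights inside the .onnx file
torch.onnx.export(model, dummy_input, "model.onnx", export_params=True)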
Example No. 4
def test_cycle_gan(**kwargs):
    opt._parse(kwargs)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    # Write standard output into file
    sys.stdout = Logger(os.path.join(opt.save_dir, 'log_test.txt'))

    print('========user config========')
    pprint(opt._state_dict())
    print('===========end=============')
    if opt.use_gpu:
        print('currently using GPU')
        torch.cuda.manual_seed_all(opt.seed)
    else:
        print('currently using cpu')

    pin_memory = True if opt.use_gpu else False
    print('initializing dataset {}'.format(opt.dataset_mode))
    dataset = UnalignedDataset(opt)
    testloader = DataLoader(dataset,
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.workers,
                            pin_memory=pin_memory)

    summaryWriter = SummaryWriter(os.path.join(opt.save_dir,
                                               'tensorboard_log'))

    print('initializing model ... ')
    netG_A, netG_B, netD_A, netD_B = load_checkpoint(opt)
    start_epoch = opt.start_epoch
    if opt.use_gpu:
        netG_A = torch.nn.DataParallel(netG_A).cuda()
        netG_B = torch.nn.DataParallel(netG_B).cuda()
        netD_A = torch.nn.DataParallel(netD_A).cuda()
        netD_B = torch.nn.DataParallel(netD_B).cuda()

    # get tester
    cycleganTester = Tester(opt, netG_A, netG_B, netD_A, netD_B, summaryWriter)

    for epoch in range(start_epoch, opt.max_epoch):
        # test over whole dataset
        cycleganTester.test(epoch, testloader)
def main(args):
    batch_time = AverageMeter()
    end = time.time()

    checkpoint = load_checkpoint(args.resume)  # load the checkpoint to evaluate
    print('pool_features:', args.pool_feature)
    epoch = checkpoint['epoch']

    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=args.data, root=args.data_root, net=args.net,
                      checkpoint=checkpoint, batch_size=args.batch_size,
                      nThreads=args.nThreads, pool_feature=args.pool_feature)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)  # pairwise similarity
    if args.gallery_eq_query:
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))  # remove self-matches on the diagonal

    print('labels', query_labels)
    print('feature:', gallery_feature)

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)

    result = '  '.join(['%.4f' % k for k in recall_ks])  # format Recall@K values for printing
    print('Epoch-%d' % epoch, result)
    batch_time.update(time.time() - end)

    print('Epoch-%d\t' % epoch,
          'Time {batch_time.avg:.3f}\t'.format(batch_time=batch_time))

    import matplotlib.pyplot as plt
    import torchvision
    import numpy as np

    similarity = torch.mm(gallery_feature, gallery_feature.t())
    print(similarity.size())  # the bare expression had no effect; print it instead

    #draw Feature Map
    img = torchvision.utils.make_grid(similarity.cpu()).numpy()  # move to CPU before converting to numpy
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
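make_grid expects image tensors, so plotting the raw similarity matrix directly is simpler; a minimal sketch with a random stand-in matrix (the .cpu() call guards against features living on the GPU):

import matplotlib.pyplot as plt
import torch

similarity = torch.rand(8, 8)  # stand-in for the gallery-gallery similarities
plt.imshow(similarity.cpu().numpy(), cmap='viridis')
plt.colorbar(label='similarity')
plt.show()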
def eval(ckp_path=None, model=None):
    args = Config()
    if ckp_path is not None:
        checkpoint = load_checkpoint(ckp_path, args)
    else:
        checkpoint = model
        checkpoint.eval()
    # print(args.pool_feature)

    gallery_feature, gallery_labels, query_feature, query_labels = \
        Model2Feature(data=args.data, model=checkpoint,
                      batch_size=args.batch_size, nThreads=args.nThreads,
                      pool_feature=args.pool_feature)

    sim_mat = pairwise_similarity(query_feature, gallery_feature)
    if args.gallery_eq_query:
        sim_mat = sim_mat - torch.eye(sim_mat.size(0))

    recall_ks = Recall_at_ks(sim_mat,
                             query_ids=query_labels,
                             gallery_ids=gallery_labels,
                             data=args.data)
    if ckp_path is None:
        checkpoint.train()
    return recall_ks
Example No. 7
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Create data loaders
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
                                  (240, 240)
    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.workers, args.combine_trainval)

    # Create model

    img_branch = models.create(args.arch,
                               cut_layer=args.cut_layer,
                               num_classes=num_classes,
                               num_features=args.features)
    diff_branch = models.create(args.arch,
                                cut_layer=args.cut_layer,
                                num_classes=num_classes,
                                num_features=args.features)

    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        img_branch.load_state_dict(checkpoint['state_dict_img'])
        diff_branch.load_state_dict(checkpoint['state_dict_diff'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(
            start_epoch, best_top1))

    img_branch = nn.DataParallel(img_branch).cuda()
    diff_branch = nn.DataParallel(diff_branch).cuda()
    # img_branch = nn.DataParallel(img_branch)
    # diff_branch = nn.DataParallel(diff_branch)

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()
    # criterion = nn.CrossEntropyLoss()

    # Evaluator
    evaluator = Evaluator(img_branch, diff_branch, criterion)
    if args.evaluate:
        # print("Validation:")
        # top1, _ = evaluator.evaluate(val_loader)
        # print("Validation acc: {:.1%}".format(top1))
        print("Test:")
        top1, (gt, pred) = evaluator.evaluate(test_loader)
        print("Test acc: {:.1%}".format(top1))
        from confusion_matrix import plot_confusion_matrix
        plot_confusion_matrix(gt, pred, dataset.classes, args.logs_dir)
        return

    img_param_groups = [
        {
            'params': img_branch.module.low_level_modules.parameters(),
            'lr_mult': 0.1
        },
        {
            'params': img_branch.module.high_level_modules.parameters(),
            'lr_mult': 0.1
        },
        {
            'params': img_branch.module.classifier.parameters(),
            'lr_mult': 1
        },
    ]

    diff_param_groups = [
        {
            'params': diff_branch.module.low_level_modules.parameters(),
            'lr_mult': 0.1
        },
        {
            'params': diff_branch.module.high_level_modules.parameters(),
            'lr_mult': 0.1
        },
        {
            'params': diff_branch.module.classifier.parameters(),
            'lr_mult': 1
        },
    ]

    img_optimizer = torch.optim.SGD(img_param_groups,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    diff_optimizer = torch.optim.SGD(diff_param_groups,
                                     lr=args.lr,
                                     momentum=args.momentum,
                                     weight_decay=args.weight_decay,
                                     nesterov=True)

    # Trainer
    trainer = Trainer(img_branch, diff_branch, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        step_size = args.step_size
        lr = args.lr * (0.1**(epoch // step_size))
        for g in img_optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        for g in diff_optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, img_optimizer, diff_optimizer)
        if epoch < args.start_save:
            continue
        top1, _ = evaluator.evaluate(val_loader)

        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict_img': img_branch.module.state_dict(),
                'state_dict_diff': diff_branch.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  top1: {:5.1%}  best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    img_branch.module.load_state_dict(checkpoint['state_dict_img'])
    diff_branch.module.load_state_dict(checkpoint['state_dict_diff'])
    top1, (gt, pred) = evaluator.evaluate(test_loader)
    from confusion_matrix import plot_confusion_matrix
    plot_confusion_matrix(gt, pred, dataset.classes, args.logs_dir)
    print('\n * Test Accuracy: {:5.1%}\n'.format(top1))
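The adjust_lr closure above implements step decay with per-group multipliers; PyTorch optimizers keep unknown keys such as lr_mult in their param groups, which is what makes the idiom work. A self-contained sketch of the same pattern (parameter names are illustrative):

import torch

w1, w2 = torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD(
    [{'params': [w1], 'lr_mult': 0.1},   # pre-trained layers: 1/10 of the base lr
     {'params': [w2], 'lr_mult': 1.0}],  # freshly initialized classifier
    lr=0.01, momentum=0.9)

def adjust_lr(base_lr, epoch, step_size=40):
    lr = base_lr * (0.1 ** (epoch // step_size))  # decay 10x every step_size epochs
    for g in optimizer.param_groups:
        g['lr'] = lr * g.get('lr_mult', 1.0)

adjust_lr(0.01, epoch=45)
print([g['lr'] for g in optimizer.param_groups])  # [0.0001, 0.001]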
Example No. 8
def main(args):
    # s_ = time.time()

    save_dir = args.save_dir
    mkdir_if_missing(save_dir)

    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    display(args)
    start = 0

    model = models.create(args.net, pretrained=True, dim=args.dim)

    # for vgg and densenet
    if args.resume is None:
        model_dict = model.state_dict()

    else:
        # resume model
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)

    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # freeze BN
    if args.freeze_BN:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')

    # Fine-tune the model: the learning rate for pre-trained parameter is 1/10
    new_param_ids = set(map(id, model.module.classifier.parameters()))

    new_params = [
        p for p in model.module.parameters() if id(p) in new_param_ids
    ]

    base_params = [
        p for p in model.module.parameters() if id(p) not in new_param_ids
    ]

    param_groups = [{
        'params': base_params,
        'lr_mult': 0.0
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]

    print('initial model is saved at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    criterion = losses.create(args.loss,
                              margin=args.margin,
                              alpha=args.alpha,
                              base=args.loss_base).cuda()

    # Decor_loss = losses.create('decor').cuda()
    data = DataSet.create(args.data,
                          ratio=args.ratio,
                          width=args.width,
                          origin_width=args.origin_width,
                          root=args.data_root)

    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True,
        num_workers=args.nThreads)

    # save the train information

    for epoch in range(start, args.epochs):

        train(epoch=epoch,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              train_loader=train_loader,
              args=args)

        if epoch == 1:
            # raise the base params' multiplier after warm-up ('lr_mult' is the key set in param_groups above)
            optimizer.param_groups[0]['lr_mult'] = 0.1

        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            },
                            is_best=False,
                            fpath=osp.join(
                                args.save_dir,
                                'ckp_ep' + str(epoch + 1) + '.pth.tar'))
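set_bn_eval is passed to model.apply() above; a plausible sketch of such a hook, assuming the project's helper follows the common freeze-BN pattern:

import torch.nn as nn

def set_bn_eval(m):
    # model.apply() visits every submodule; switching BN layers to eval mode
    # freezes their running mean/var during fine-tuning
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()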
Example No. 9
    parser.add_argument('--train_flag', type=bool, default=True)  # caution: type=bool treats any non-empty string as True
    parser.add_argument('--classnum', type=int, default=20)
    parser.add_argument('--L', type=int, default=8)
    parser.add_argument(
        '--pool_feature',
        type=ast.literal_eval,
        default=True,
        required=False,
        help='if True extract feature from the last pool layer')

    args = parser.parse_args()

    args.resume = osp.join(args.save_dir, 'ckp_ep500.pth.tar')

    checkpoint = load_checkpoint(args.resume)
    #print(args.pool_feature)
    epoch = checkpoint['epoch']
    #if args.train_flag:

    #ipdb.set_trace()
    train_feature, train_labels, test_feature, test_labels = \
        Sequence2Feature_test(data=args.data, root=args.data_root, net=args.net,
                              checkpoint=checkpoint, in_dim=args.in_dim,
                              middle_dim=args.middle_dim, out_dim=args.out_dim,
                              batch_size=args.batch_size, nThreads=args.nThreads,
                              train_flag=False)
    #print(train_feature)
    train_feature = train_feature.numpy()

    test_feature = test_feature.numpy()

    savedata_mat = "TransFeatures.mat"
    savedatapath = os.path.join(args.data_root, savedata_mat)
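The --pool_feature flag above parses booleans with ast.literal_eval rather than type=bool for a reason: argparse's type=bool converts any non-empty string, including 'False', to True. A minimal demonstration:

import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument('--pool_feature', type=ast.literal_eval, default=True)
print(parser.parse_args(['--pool_feature', 'False']).pool_feature)  # False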
Example No. 10
def main(args):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

    # log file

    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    dataset, num_classes, train_loader, query_loader, gallery_loader = \
        get_data(args.dataset, args.split, args.data_dir,
                 args.batch_size, args.seq_len, args.seq_srd,
                 args.workers, args.train_mode)

    # create CNN model
    cnn_model = models.create(args.a1, num_features=args.features, dropout=args.dropout)

    # create ATT model
    input_num = cnn_model.feat.in_features
    output_num = args.features
    att_model = models.create(args.a2, input_num, output_num)

    # create classifier model
    class_num = 2
    classifier_model = models.create(args.a3, output_num, class_num)

    # CUDA acceleration model
    cnn_model = torch.nn.DataParallel(cnn_model).cuda()
    att_model = att_model.cuda()
    classifier_model = classifier_model.cuda()

    # Loss function

    criterion_oim = OIMLoss(args.features, num_classes,
                            scalar=args.oim_scalar, momentum=args.oim_momentum)
    criterion_veri = PairLoss(args.sampling_rate)
    criterion_oim.cuda()
    criterion_veri.cuda()

    # Optimizer
    base_param_ids = set(map(id, cnn_model.module.base.parameters()))
    new_params = [p for p in cnn_model.parameters() if
                  id(p) not in base_param_ids]

    param_groups1 = [
        {'params': cnn_model.module.base.parameters(), 'lr_mult': 1},
        {'params': new_params, 'lr_mult': 1}]
    param_groups2 = [
        {'params': att_model.parameters(), 'lr_mult': 1},
        {'params': classifier_model.parameters(), 'lr_mult': 1}]

    optimizer1 = torch.optim.SGD(param_groups1, lr=args.lr1,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    optimizer2 = torch.optim.SGD(param_groups2, lr=args.lr2,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    # Schedule Learning rate
    def adjust_lr1(epoch):
        lr = args.lr1 * (0.1 ** (epoch/args.lr1step))
        print(lr)
        for g in optimizer1.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr2(epoch):
        lr = args.lr2 * (0.01 ** (epoch//args.lr2step))
        print(lr)
        for g in optimizer2.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr3(epoch):
        lr = args.lr3 * (0.000001 ** (epoch //args.lr3step))
        print(lr)
        return lr

    best_top1 = 0
    start_epoch = args.start_epoch
    if args.evaluate == 1:
        print('Evaluate:')
        evaluator = ATTEvaluator(cnn_model, att_model, classifier_model, args.train_mode, criterion_veri)
        top1, mAP = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)

    elif args.test == 1:
        print('Test:')
        checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
        cnn_model.load_state_dict(checkpoint1['state_dict'])
        checkpoint2 = load_checkpoint(osp.join(args.logs_dir, 'attmodel_best.pth.tar'))
        att_model.load_state_dict(checkpoint2['state_dict'])
        checkpoint3 = load_checkpoint(osp.join(args.logs_dir, 'clsmodel_best.pth.tar'))
        classifier_model.load_state_dict(checkpoint3['state_dict'])
        evaluator = ATTEvaluator(cnn_model, att_model, classifier_model, args.train_mode, criterion_veri)
        mAP, top1, top5, top10, top20 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)

    else:
        tensorboard_test_logdir = osp.join(args.logs_dir, 'test_log')
        writer = SummaryWriter(log_dir=tensorboard_test_logdir)
        if args.resume == 1:
            checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))
            cnn_model.load_state_dict(checkpoint1['state_dict'])
            checkpoint2 = load_checkpoint(osp.join(args.logs_dir, 'att_checkpoint.pth.tar'))
            att_model.load_state_dict(checkpoint2['state_dict'])
            checkpoint3 = load_checkpoint(osp.join(args.logs_dir, 'cls_checkpoint.pth.tar'))
            classifier_model.load_state_dict(checkpoint3['state_dict'])
            start_epoch = checkpoint1['epoch']
            best_top1 = checkpoint1['best_top1']
            print("=> Start epoch {}  best top1 {:.1%}"
                  .format(start_epoch, best_top1))
        # Trainer
        tensorboard_train_logdir = osp.join(args.logs_dir, 'train_log')
        trainer = SEQTrainer(cnn_model, att_model, classifier_model, criterion_veri, criterion_oim, args.train_mode, args.lr3, tensorboard_train_logdir)
        # Evaluator
        if args.train_mode == 'cnn':
            evaluator = CNNEvaluator(cnn_model, args.train_mode)
        elif args.train_mode == 'cnn_rnn':
            evaluator = ATTEvaluator(cnn_model, att_model, classifier_model, args.train_mode, criterion_veri)
        else:
            raise RuntimeError('unknown train_mode: an Evaluator is required')

        for epoch in range(start_epoch, args.epochs):
            adjust_lr1(epoch)
            adjust_lr2(epoch)
            rate = adjust_lr3(epoch)
            trainer.train(epoch, train_loader, optimizer1, optimizer2, rate)

            if epoch % 1 == 0:
                mAP, top1, top5, top10, top20 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
                writer.add_scalar('test/mAP', mAP, epoch+1)
                writer.add_scalar('test/top1', top1, epoch+1)
                writer.add_scalar('test/top5', top5, epoch+1)
                writer.add_scalar('test/top10', top10, epoch+1)
                writer.add_scalar('test/top20', top20, epoch+1)
                is_best = top1 > best_top1
                if is_best:
                    best_top1 = top1

                save_cnn_checkpoint({
                    'state_dict': cnn_model.state_dict(),
                    'epoch': epoch + 1,
                    'best_top1': best_top1,
                }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))

                if args.train_mode == 'cnn_rnn':
                    save_att_checkpoint({
                        'state_dict': att_model.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    }, is_best, fpath=osp.join(args.logs_dir, 'att_checkpoint.pth.tar'))

                    save_cls_checkpoint({
                        'state_dict': classifier_model.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    }, is_best, fpath=osp.join(args.logs_dir, 'cls_checkpoint.pth.tar'))

        print('Test: ')
        checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
        cnn_model.load_state_dict(checkpoint1['state_dict'])
        checkpoint2 = load_checkpoint(osp.join(args.logs_dir, 'attmodel_best.pth.tar'))
        att_model.load_state_dict(checkpoint2['state_dict'])
        checkpoint3 = load_checkpoint(osp.join(args.logs_dir, 'clsmodel_best.pth.tar'))
        classifier_model.load_state_dict(checkpoint3['state_dict'])
        evaluator = ATTEvaluator(cnn_model, att_model, classifier_model, args.train_mode, criterion_veri)
        mAP, top1, top5, top10, top20 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
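A minimal sketch of the per-epoch metric logging pattern above (this uses torch.utils.tensorboard; the original may import SummaryWriter from tensorboardX, and the metric values here are dummies):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='logs/test_log')  # illustrative path
for epoch in range(3):
    writer.add_scalar('test/top1', 0.5 + 0.1 * epoch, epoch + 1)  # dummy metric
writer.close()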
Example No. 11
def main():
    # load the hyper-parameter
    with open('config.yml', encoding='utf-8') as f:
        CONFIG_DICT = yaml.safe_load(
            f
        )  # CONFIG_DICT is a dict that involves train_param, test_param and save_dir
    TRAIN_PARAM = CONFIG_DICT['train']
    TEST_PARAM = CONFIG_DICT['test']
    SAVE_DIR = CONFIG_DICT['save_path']
    os.environ['CUDA_VISIBLE_DEVICES'] = TRAIN_PARAM['gpu_device']
    torch.manual_seed(TRAIN_PARAM['seed'])

    if not TRAIN_PARAM['evaluate']:
        sys.stdout = Logging(osp.join(SAVE_DIR['log_dir'], 'log_train.txt'))
    else:
        sys.stdout = Logging(osp.join(SAVE_DIR['log_dir'], 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(TRAIN_PARAM))

    # GPU use Y/N
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        print("Currently using GPU {}".format(TRAIN_PARAM['gpu_device']))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(TRAIN_PARAM['seed'])
    else:
        use_cpu = True

    print("Initializing dataset {}".format(TRAIN_PARAM['dataset']))
    # load data
    dataset = Datasets.init_dataset(name=TRAIN_PARAM['dataset'],
                                    root=TRAIN_PARAM['root'])

    pin_memory = True if use_gpu else False

    # define the transform methods
    train_transform = T.Compose([
        T.RandomSizedRectCrop(width=TRAIN_PARAM['width'],
                              height=TRAIN_PARAM['height']),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_transform = T.Compose([
        T.RectScale(width=TRAIN_PARAM['width'], height=TRAIN_PARAM['height']),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_loader = DataLoader(dataset=ImageDataset(dataset.train,
                                                   transform=train_transform),
                              sampler=RandomIdentitySampler(
                                  dataset.train,
                                  num_instances=TRAIN_PARAM['num_instances']),
                              batch_size=TRAIN_PARAM['train_batch'],
                              num_workers=TRAIN_PARAM['workers'],
                              pin_memory=pin_memory,
                              drop_last=True)

    query_loader = DataLoader(dataset=ImageDataset(dataset=dataset.query,
                                                   transform=test_transform),
                              batch_size=TEST_PARAM['test_batch'],
                              shuffle=False,
                              num_workers=TEST_PARAM['test_workers'],
                              pin_memory=pin_memory,
                              drop_last=False)

    gallery_loader = DataLoader(dataset=ImageDataset(dataset=dataset.gallery,
                                                     transform=test_transform),
                                batch_size=TEST_PARAM['test_batch'],
                                shuffle=False,
                                num_workers=TEST_PARAM['test_workers'],
                                pin_memory=pin_memory,
                                drop_last=False)

    # load model
    print("Initializing model: {}".format(TRAIN_PARAM['arch']))
    model = models.init_model(name=TRAIN_PARAM['arch'],
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'})
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    # load loss functions
    # use softmax / trihard only if they appear in TRAIN_PARAM['losses']; otherwise set them to None
    criterion_xent = loss_fuc.init_losses(
        name='softmax', num_classes=dataset.num_train_pids,
        use_gpu=use_gpu) if 'softmax' in TRAIN_PARAM['losses'] else None

    criterion_trihard = loss_fuc.init_losses(
        name='trihard', margin=TRAIN_PARAM['margin']
    ) if 'trihard' in TRAIN_PARAM['losses'] else None

    # load optim
    optim = optimizer.init_optim(optim=TRAIN_PARAM['optim'],
                                 params=model.parameters(),
                                 lr=TRAIN_PARAM['lr'],
                                 weight_decay=TRAIN_PARAM['weight_decay'])
    if TRAIN_PARAM['step_size'] > 0:
        scheduler = lr_scheduler.StepLR(optimizer=optim,
                                        step_size=TRAIN_PARAM['step_size'],
                                        gamma=TRAIN_PARAM['gamma'])
    start_epoch = TRAIN_PARAM['start_epoch']

    # resume or not
    if TRAIN_PARAM['resume']:
        checkpoint = load_checkpoint(TRAIN_PARAM['resume'])
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(
            start_epoch, best_top1))

    if use_gpu:
        model = nn.DataParallel(model).cuda()
        if criterion_trihard is not None:
            criterion_trihard.cuda()
        if criterion_xent is not None:
            criterion_xent.cuda()

    # test or not
    if TRAIN_PARAM['evaluate']:
        print("Evaluate only")
        # test(model, query_loader, gallery_loader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    # instance the class Trainer
    trainer = Trainer(model=model,
                      criterion_xent=criterion_xent,
                      criterion_trihard=criterion_trihard,
                      eval=TRAIN_PARAM['triHard_only'])

    # start train
    for epoch in range(TRAIN_PARAM['start_epoch'], TRAIN_PARAM['max_epoch']):
        start_train_time = time.time()
        trainer.train(epoch=epoch,
                      optimizer=optim,
                      data_loader=train_loader,
                      use_gpu=use_gpu,
                      print_freq=TRAIN_PARAM['print_freq'])
        train_time += round(time.time() - start_train_time)

        # evaluate periodically, and always at the final epoch
        if ((epoch + 1) > TEST_PARAM['start_eval']
                and TEST_PARAM['eval_step'] > 0
                and (epoch + 1) % TEST_PARAM['eval_step'] == 0) \
                or (epoch + 1) == TRAIN_PARAM['max_epoch']:
            print("==> Test")
            rank1 = test(model, query_loader, gallery_loader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(TRAIN_PARAM['checkpoint_dir'],
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
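The seed-everything preamble recurs in nearly every example here; a small helper capturing it (the function name is illustrative):

import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)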
def main(args):
    print(args.p_lambda)
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)

    print("DRO:", args.DRO)

    # sys.stdout: output from console
    # sys.stderr: exceptions from python
    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt')) #sys.stdout --> 'log.txt'
    sys.stderr = logging.Logger(os.path.join(save_dir, 'error.txt')) #sys.stderr --> 'error.txt'

    display(args)
    start = 0

    model = models.create(args.net, pretrained=True, dim=args.dim)


    save_checkpoint({
        'state_dict': model.state_dict(),
        'epoch': 0,
    }, is_best=False, fpath=osp.join(args.save_dir, 'ckp_ep'+ str(start) + '.pth.tar'))
    # for vgg and densenet

    if args.resume is None:
        model_dict = model.state_dict()
    else:
        # resume model
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)


    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # freeze BN
    if args.freeze_BN:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval) # m represents default layers.
    else:
        print(40*'#', 'BatchNorm NOT frozen')


    optimizer = torch.optim.Adam(model.module.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    print("--------------------------:", args.p_lambda)
    criterion = DRO.create(args.DRO, loss=args.loss, margin=args.margin,
                           alpha=args.alpha, beta=args.beta,
                           p_lambda=args.p_lambda, p_lambda_neg=args.p_lambda_neg,
                           K=args.K, select_TOPK_all=args.select_TOPK_all,
                           p_choice=args.p_choice,
                           truncate_p=args.truncate_p).cuda()

    # Decor_loss = losses.create('decode').cuda()
    print("Train, RAE:", args.mode)
    data = DataSet.create(args.data, ratio=args.ratio, width=args.width, origin_width=args.origin_width, root=args.data_root, RAE=args.mode)

    train_loader = torch.utils.data.DataLoader(
        data.train, batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train, num_instances=args.num_instances),
        drop_last=True, pin_memory=True, num_workers=args.nThreads)


    # save the train information

    for epoch in range(start, args.epochs):


        train(epoch=epoch, model=model, criterion=criterion,
              optimizer=optimizer, train_loader=train_loader, args=args)

        if epoch == 1:
            optimizer.param_groups[0]['lr_mul'] = 0.1  # likely a typo for 'lr_mult'; nothing reads this key here
        
        if (epoch+1) % args.save_step == 0 or epoch==0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch+1),
            }, is_best=False, fpath=osp.join(args.save_dir, 'ckp_ep' + str(epoch + 1) + '.pth.tar'))
Example No. 13
                    default=512,
                    help='Dimension of Embedding Feature')
parser.add_argument('-batch_size', type=int, default=64)
parser.add_argument('--nThreads',
                    '-j',
                    default=16,
                    type=int,
                    metavar='N',
                    help='number of data loading threads (default: 16)')

args = parser.parse_args()

PATH = args.r
model = models.create('vgg', dim=args.dim, pretrained=False)

resume = load_checkpoint(PATH)
epoch = resume['epoch']
model.load_state_dict(resume['state_dict'])

# model = torch.load(args.r)
model.classifier = torch.nn.Sequential()
model = torch.nn.DataParallel(model).cuda()

data = DataSet.create(args.data)

if args.data == 'shop':
    gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                 batch_size=64,
                                                 shuffle=False,
                                                 drop_last=False)
    query_loader = torch.utils.data.DataLoader(data.query,
Example No. 14
def load_best_checkpoint(cnn_model, siamese_model):
    checkpoint0 = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
    cnn_model.load_state_dict(checkpoint0['state_dict'])

    checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'siamesemodel_best.pth.tar'))
    siamese_model.load_state_dict(checkpoint1['state_dict'])
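load_checkpoint itself is project code; a plausible minimal sketch of what such a helper wraps, assuming it is a thin torch.load wrapper returning the saved dict:

import os.path as osp
import torch

def load_checkpoint_sketch(fpath):
    if not osp.isfile(fpath):
        raise IOError('no checkpoint found at {}'.format(fpath))
    # map_location='cpu' keeps loading robust on machines without a GPU
    return torch.load(fpath, map_location='cpu')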
Example No. 15
def main(args):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # log file
    if args.evaluate == 1:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_test.txt'))
    else:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # from reid.data import get_data ,
    dataset, num_classes, train_loader, query_loader, gallery_loader = \
        get_data(args.dataset, args.split, args.data_dir,
                 args.batch_size, args.seq_len, args.seq_srd,
                 args.workers, args.train_mode)

    # create CNN model
    cnn_model = models.create(args.a1, num_features=args.features, dropout=args.dropout)

    # create ATT model
    input_num = cnn_model.feat.in_features  # 2048
    output_num = args.features  # 128
    att_model = models.create(args.a2, input_num, output_num)

    # create classifier model
    class_num = 2
    classifier_model = models.create(args.a3,  output_num, class_num)

    # CUDA acceleration model

    cnn_model = torch.nn.DataParallel(cnn_model).to(device)
    att_model = att_model.to(device)
    classifier_model = classifier_model.to(device)

    criterion_oim = OIMLoss(args.features, num_classes,
                            scalar=args.oim_scalar, momentum=args.oim_momentum)
    criterion_veri = PairLoss(args.sampling_rate)
    criterion_oim.to(device)
    criterion_veri.to(device)

    # Optimizer
    base_param_ids = set(map(id, cnn_model.module.base.parameters()))
    new_params = [p for p in cnn_model.parameters() if
                  id(p) not in base_param_ids]

    param_groups1 = [
        {'params': cnn_model.module.base.parameters(), 'lr_mult': 1},
        {'params': new_params, 'lr_mult': 1}]
    param_groups2 = [
        {'params': att_model.parameters(), 'lr_mult': 1},
        {'params': classifier_model.parameters(), 'lr_mult': 1}]

    optimizer1 = torch.optim.SGD(param_groups1, lr=args.lr1,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    optimizer2 = torch.optim.SGD(param_groups2, lr=args.lr2,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)
    # optimizer1 = torch.optim.Adam(param_groups1, lr=args.lr1, weight_decay=args.weight_decay)
    #
    # optimizer2 = torch.optim.Adam(param_groups2, lr=args.lr2, weight_decay=args.weight_decay)

    # Schedule Learning rate
    def adjust_lr1(epoch):
        lr = args.lr1 * (0.1 ** (epoch/args.lr1step))
        print(lr)
        for g in optimizer1.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr2(epoch):
        lr = args.lr2 * (0.01 ** (epoch//args.lr2step))
        print(lr)
        for g in optimizer2.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr3(epoch):
        lr = args.lr3 * (0.000001 ** (epoch //args.lr3step))
        print(lr)
        return lr

    # Trainer
    trainer = SEQTrainer(cnn_model, att_model, classifier_model, criterion_veri, criterion_oim, args.train_mode, args.lr3)

    # Evaluator
    if args.train_mode == 'cnn':
        evaluator = CNNEvaluator(cnn_model, args.train_mode)
    elif args.train_mode == 'cnn_rnn':
        evaluator = ATTEvaluator(cnn_model, att_model, classifier_model, args.train_mode)

    else:
        raise RuntimeError('unknown train_mode: an Evaluator is required')

    best_top1 = 0
    if args.evaluate == 1:  # evaluate
        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
        cnn_model.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'attmodel_best.pth.tar'))
        att_model.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'clsmodel_best.pth.tar'))
        classifier_model.load_state_dict(checkpoint['state_dict'])

        top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)

    else:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_lr1(epoch)
            adjust_lr2(epoch)
            rate = adjust_lr3(epoch)
            trainer.train(epoch, train_loader, optimizer1, optimizer2, rate)

            if (epoch+1) % 3 == 0 or (epoch+1) == args.epochs:

                top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
                is_best = top1 > best_top1
                if is_best:
                    best_top1 = top1

                save_cnn_checkpoint({
                    'state_dict': cnn_model.state_dict(),
                    'epoch': epoch + 1,
                    'best_top1': best_top1,
                }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))

                if args.train_mode == 'cnn_rnn':
                    save_att_checkpoint({
                        'state_dict': att_model.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    }, is_best, fpath=osp.join(args.logs_dir, 'att_checkpoint.pth.tar'))

                    save_cls_checkpoint({
                        'state_dict': classifier_model.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    }, is_best, fpath=osp.join(args.logs_dir, 'cls_checkpoint.pth.tar'))
Example No. 16
def main(args):

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # log file
    if args.evaluate == 1:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_test.txt'))
    else:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    # from reid.data import get_data ,
    dataset, num_classes, train_loader, query_loader, gallery_loader = \
        get_data(args.dataset, args.split, args.data_dir,
                 args.batch_size, args.seq_len, args.seq_srd,
                 args.workers, args.train_mode)

    # create CNN model
    cnn_model = models.create(args.a1,
                              num_features=args.features,
                              dropout=args.dropout)

    # create ATT model
    input_num = cnn_model.feat.in_features  # 2048
    output_num = args.features  # 128
    att_model = models.create(args.a2, input_num, output_num)

    # create classifier model
    class_num = 2
    classifier_model = models.create(args.a3, output_num, class_num)

    # CUDA acceleration model

    cnn_model = torch.nn.DataParallel(cnn_model).to(device)
    att_model = att_model.to(device)
    classifier_model = classifier_model.to(device)

    criterion_oim = OIMLoss(args.features,
                            num_classes,
                            scalar=args.oim_scalar,
                            momentum=args.oim_momentum)
    criterion_veri = PairLoss(args.sampling_rate)
    criterion_oim.to(device)
    criterion_veri.to(device)

    # Optimizer
    base_param_ids = set(map(id, cnn_model.module.base.parameters()))
    new_params = [
        p for p in cnn_model.parameters() if id(p) not in base_param_ids
    ]

    param_groups1 = [{
        'params': cnn_model.module.base.parameters(),
        'lr_mult': 1
    }, {
        'params': new_params,
        'lr_mult': 1
    }]
    param_groups2 = [{
        'params': att_model.parameters(),
        'lr_mult': 1
    }, {
        'params': classifier_model.parameters(),
        'lr_mult': 1
    }]

    optimizer1 = torch.optim.SGD(param_groups1,
                                 lr=args.lr1,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    optimizer2 = torch.optim.SGD(param_groups2,
                                 lr=args.lr2,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    # optimizer1 = torch.optim.Adam(param_groups1, lr=args.lr1, weight_decay=args.weight_decay)
    #
    # optimizer2 = torch.optim.Adam(param_groups2, lr=args.lr2, weight_decay=args.weight_decay)

    # Schedule Learning rate
    def adjust_lr1(epoch):
        lr = args.lr1 * (0.1**(epoch / args.lr1step))
        print(lr)
        for g in optimizer1.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr2(epoch):
        lr = args.lr2 * (0.01**(epoch // args.lr2step))
        print(lr)
        for g in optimizer2.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr3(epoch):
        lr = args.lr3 * (0.000001**(epoch // args.lr3step))
        print(lr)
        return lr

    # Trainer
    trainer = SEQTrainer(cnn_model, att_model, classifier_model,
                         criterion_veri, criterion_oim, args.train_mode,
                         args.lr3)

    # Evaluator
    if args.train_mode == 'cnn':
        evaluator = CNNEvaluator(cnn_model, args.train_mode)
    elif args.train_mode == 'cnn_rnn':
        evaluator = ATTEvaluator(cnn_model, att_model, classifier_model,
                                 args.train_mode)

    else:
        raise RuntimeError('unknown train_mode: an Evaluator is required')

    best_top1 = 0
    if args.evaluate == 1:  # evaluate
        checkpoint = load_checkpoint(
            osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
        cnn_model.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(
            osp.join(args.logs_dir, 'attmodel_best.pth.tar'))
        att_model.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(
            osp.join(args.logs_dir, 'clsmodel_best.pth.tar'))
        classifier_model.load_state_dict(checkpoint['state_dict'])

        top1 = evaluator.evaluate(query_loader, gallery_loader,
                                  dataset.queryinfo, dataset.galleryinfo)

    else:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_lr1(epoch)
            adjust_lr2(epoch)
            rate = adjust_lr3(epoch)
            trainer.train(epoch, train_loader, optimizer1, optimizer2, rate)

            if (epoch + 1) % 3 == 0 or (epoch + 1) == args.epochs:

                top1 = evaluator.evaluate(query_loader, gallery_loader,
                                          dataset.queryinfo,
                                          dataset.galleryinfo)
                is_best = top1 > best_top1
                if is_best:
                    best_top1 = top1

                save_cnn_checkpoint(
                    {
                        'state_dict': cnn_model.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    },
                    is_best,
                    fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))

                if args.train_mode == 'cnn_rnn':
                    save_att_checkpoint(
                        {
                            'state_dict': att_model.state_dict(),
                            'epoch': epoch + 1,
                            'best_top1': best_top1,
                        },
                        is_best,
                        fpath=osp.join(args.logs_dir,
                                       'att_checkpoint.pth.tar'))

                    save_cls_checkpoint(
                        {
                            'state_dict': classifier_model.state_dict(),
                            'epoch': epoch + 1,
                            'best_top1': best_top1,
                        },
                        is_best,
                        fpath=osp.join(args.logs_dir,
                                       'cls_checkpoint.pth.tar'))
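The save_*_checkpoint calls above all follow the same convention: write the latest checkpoint, then copy it to a *_best file when is_best is set, which is what the model_best.pth.tar loads at test time rely on. A plausible sketch:

import os.path as osp
import shutil
import torch

def save_checkpoint_sketch(state, is_best, fpath='checkpoint.pth.tar'):
    torch.save(state, fpath)
    if is_best:
        # keep a separate copy of the best-so-far weights
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))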
def main(args):

    # s_ = time.time()
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)
    #sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    writer = SummaryWriter('log/' + args.log_name)
    display(args)
    start = 0

    model = models.create(args.net, pretrain=True, dim=args.dim)

    # for vgg and densenet
    if args.resume is None:
        model_dict = model.state_dict()

    else:
        # resume model
        print('load model from {}'.format(args.resume))
        model = load_checkpoint(args.resume, args)
        start = 80

    model = torch.nn.DataParallel(model)
    model = model.cuda()

    #freeze vgg layers

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    criterion = losses.create(args.loss,
                              margin=args.margin,
                              alpha=args.alpha,
                              beta=args.beta).cuda()

    data = DataSet.create(args.data)

    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True)

    # save the train information

    for epoch in range(start, args.epochs):

        # if epoch == 5:
        #     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr/100)
        #     print(args.lr/100)

        train(
            writer,
            epoch=epoch,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            train_loader=train_loader,
            args=args,
        )

        if epoch == 800:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr / 10,
                                         weight_decay=args.weight_decay)

        if (epoch + 1) % args.save_step == 0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            },
                            is_best=False,
                            fpath=osp.join(
                                args.save_dir,
                                'ckp_ep' + str(epoch + 1) + '.pth.tar'))
Example No. 18
def main(args):
    # s_ = time.time()
    print(torch.cuda.get_device_properties(device=0).total_memory)
    torch.cuda.empty_cache()
    print(args)
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)
    num_txt = len(glob.glob(save_dir + "/*.txt"))
    sys.stdout = logging.Logger(
        os.path.join(save_dir, "log_" + str(num_txt) + ".txt"))
    display(args)
    start = 0

    model = models.create(args.net,
                          pretrained=args.pretrained,
                          dim=args.dim,
                          self_supervision_rot=args.self_supervision_rot)
    all_pretrained = glob.glob(save_dir + "/*.pth.tar")

    if (args.resume is None) or (len(all_pretrained) == 0):
        model_dict = model.state_dict()

    else:
        # resume model
        all_pretrained_epochs = sorted(
            [int(x.split("/")[-1][6:-8]) for x in all_pretrained])
        args.resume = os.path.join(
            save_dir, "ckp_ep" + str(all_pretrained_epochs[-1]) + ".pth.tar")
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)

    model = torch.nn.DataParallel(model)
    model = model.cuda()
    fake_centers_dir = os.path.join(args.save_dir, "fake_center.npy")

    if np.sum(["train_1.txt" in x
               for x in glob.glob(args.save_dir + "/**/*")]) == 0:
        if args.rot_only:
            create_fake_labels(None, None, args)

        else:
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)

            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=100,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)

            train_feature, train_labels = extract_features(
                model,
                fake_train_loader,
                print_freq=1e5,
                metric=None,
                pool_feature=args.pool_feature,
                org_feature=True)

            create_fake_labels(train_feature, train_labels, args)

            del train_feature

            fake_centers = "k-means++"

            torch.cuda.empty_cache()

    elif os.path.exists(fake_centers_dir):
        fake_centers = np.load(fake_centers_dir)
    else:
        fake_centers = "k-means++"

    time.sleep(60)

    model.train()

    # freeze BN
    if args.freeze_BN and args.pretrained:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')

    # Fine-tune the model: the learning rate for pre-trained parameter is 1/10
    new_param_ids = set(map(id, model.module.classifier.parameters()))
    new_rot_param_ids = set()
    if args.self_supervision_rot:
        new_rot_param_ids = set(
            map(id, model.module.classifier_rot.parameters()))
        print(new_rot_param_ids)

    new_params = [
        p for p in model.module.parameters() if id(p) in new_param_ids
    ]

    new_rot_params = [
        p for p in model.module.parameters() if id(p) in new_rot_param_ids
    ]

    base_params = [
        p for p in model.module.parameters()
        if (id(p) not in new_param_ids) and (id(p) not in new_rot_param_ids)
    ]

    param_groups = [{
        'params': base_params
    }, {
        'params': new_params
    }, {
        'params': new_rot_params,
        'lr': args.rot_lr
    }]

    print('initial model is saved at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = losses.create(args.loss,
                              margin=args.margin,
                              alpha=args.alpha,
                              beta=args.beta,
                              base=args.loss_base).cuda()

    data = dataset.Dataset(args.data,
                           ratio=args.ratio,
                           width=args.width,
                           origin_width=args.origin_width,
                           root=args.save_dir,
                           self_supervision_rot=args.self_supervision_rot,
                           rot_bt=args.rot_bt,
                           corruption=1,
                           args=args)
    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True,
        num_workers=args.nThreads)
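    # FastRandomIdentitySampler draws num_instances samples per (pseudo)
    # identity, so every batch contains positive pairs for the metric loss.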

    # training loop

    for epoch in range(start, args.epochs):

        train(epoch=epoch,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              train_loader=train_loader,
              args=args)

        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            if use_gpu:  # module-level flag, defined outside this snippet
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            },
                            is_best=False,
                            fpath=osp.join(
                                args.save_dir,
                                'ckp_ep' + str(epoch + 1) + '.pth.tar'))

        if ((epoch + 1) % args.up_step == 0) and (not args.rot_only):
            # rewrite train_1.txt file
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)
            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)
            train_feature, train_labels = extract_features(
                model,
                fake_train_loader,
                print_freq=1e5,
                metric=None,
                pool_feature=args.pool_feature,
                org_feature=(args.dim % 64 != 0))
            fake_centers = create_fake_labels(train_feature,
                                              train_labels,
                                              args,
                                              init_centers=fake_centers)
            del train_feature
            torch.cuda.empty_cache()
            time.sleep(60)
            np.save(fake_centers_dir, fake_centers)
            # reload data
            data = dataset.Dataset(
                args.data,
                ratio=args.ratio,
                width=args.width,
                origin_width=args.origin_width,
                root=args.save_dir,
                self_supervision_rot=args.self_supervision_rot,
                rot_bt=args.rot_bt,
                corruption=1,
                args=args)

            train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                sampler=FastRandomIdentitySampler(
                    data.train, num_instances=args.num_instances),
                drop_last=True,
                pin_memory=True,
                num_workers=args.nThreads)

            # test on testing data
            # extract_recalls(data=args.data, data_root=args.data_root, width=args.width, net=args.net, checkpoint=None,
            #         dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature,
            #         gallery_eq_query=args.gallery_eq_query, model=model)
            model.train()
            if (args.freeze_BN is True) and (args.pretrained):
                print(40 * '#', '\n BatchNorm frozen')
                model.apply(set_bn_eval)
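
The pseudo-label refresh above alternates feature extraction with clustering, warm-starting each round from the previous centers. A minimal sketch of that warm-start pattern with scikit-learn's KMeans (create_fake_labels is not shown in this example, so these names and shapes are illustrative assumptions):

import numpy as np
from sklearn.cluster import KMeans

def refresh_pseudo_labels(features, n_clusters, init_centers="k-means++"):
    # First round: "k-means++" seeding; later rounds: warm-start from the
    # centers returned by the previous refresh (as fake_centers is above).
    km = KMeans(n_clusters=n_clusters, init=init_centers, n_init=1)
    labels = km.fit_predict(features)
    return labels, km.cluster_centers_

features = np.random.rand(500, 64).astype(np.float32)
labels, centers = refresh_pseudo_labels(features, n_clusters=10)
labels, centers = refresh_pseudo_labels(features, 10, init_centers=centers)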
Esempio n. 19
        'log_interval': args.log_interval,
        'model_outfile': args.model_outfile,
        'lr_reduce_factor': args.lr_reduce_factor,
        'patience': args.patience,
        'tensorboard': args.tensorboard,
        'run_label': args.run_label,
        'logger': logger
    }
    trainer = TrainerFactory.get_trainer(args.dataset, model, embedding,
                                         train_loader, trainer_config,
                                         train_evaluator, test_evaluator,
                                         dev_evaluator)

    if not args.skip_training:
        total_params = sum(param.numel() for param in model.parameters()
                           if param.requires_grad)
        logger.info('Total number of trainable parameters: %s', total_params)
        trainer.train(args.epochs)

    _, _, state_dict, _, _ = load_checkpoint(args.model_outfile)

    for k, tensor in state_dict.items():
        state_dict[k] = tensor.to(device)

    model.load_state_dict(state_dict)
    if dev_loader:
        evaluate_dataset('dev', dataset_cls, model, embedding, dev_loader,
                         args.batch_size, args.device)
    evaluate_dataset('test', dataset_cls, model, embedding, test_loader,
                     args.batch_size, args.device, args.keep_results)
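
The per-tensor .to(device) loop above can also be handled at load time. A minimal sketch, assuming the checkpoint is a plain state dict saved with torch.save (the filename is illustrative):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps every storage in the checkpoint onto `device`,
# equivalent to moving each tensor after loading.
state_dict = torch.load('model.pth.tar', map_location=device)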
Esempio n. 20
def main(args):
    if not torch.cuda.is_available():
        args.cpu_only = True
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.cpu_only:
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True

    # Logs directory
    mkdir_if_missing(args.logs_dir)
    if not args.eval_only:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    # Data
    train_dataset, test_dataset, num_classes = get_datasets(
        args.dataset, args.data_dir)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=args.workers,
                             shuffle=False,
                             pin_memory=True)

    # Model
    model = WideResNet(args.depth,
                       args.width,
                       num_classes,
                       dropout_rate=args.dropout)
    criterion = nn.CrossEntropyLoss()

    start_epoch, best_prec1 = 0, 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['model'])
        start_epoch = checkpoint['epoch'] + 1
        best_prec1 = checkpoint['best_prec1']
        print("=> Load from {}, start epoch {}, best prec1 {:.2%}".format(
            args.resume, start_epoch, best_prec1))

    if not args.cpu_only:
        model = DataParallel(model).cuda()
        criterion = criterion.cuda()

    # Optimizer
    if args.optim_method == 'sgd':
        optimizer = SGD(model.parameters(),
                        lr=args.lr,
                        nesterov=True,
                        momentum=0.9,
                        weight_decay=args.weight_decay)
    else:
        optimizer = Adam(model.parameters(), lr=args.lr)

    # Evaluation only
    if args.eval_only:
        evaluate(start_epoch - 1, test_loader, model, criterion, args.cpu_only)
        return

    # Training
    epoch_steps = json.loads(args.epoch_steps)[::-1]
    for epoch in range(start_epoch, args.epochs):
        # Adjust learning rate: epoch_steps is reversed (descending), so
        # the first milestone already passed determines the decay power;
        # break so smaller milestones do not overwrite it.
        power = 0
        for i, step in enumerate(epoch_steps):
            if epoch >= step:
                power = len(epoch_steps) - i
                break
        lr = args.lr * (args.lr_decay_ratio**power)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Training
        train(epoch, train_loader, model, criterion, optimizer, args.cpu_only)
        prec1 = evaluate(epoch, test_loader, model, criterion, args.cpu_only)
        is_best = prec1 > best_prec1
        best_prec1 = max(best_prec1, prec1)

        # Save checkpoint
        checkpoint = {'epoch': epoch, 'best_prec1': best_prec1}
        if args.cpu_only:
            checkpoint['model'] = model.state_dict()
        else:
            checkpoint['model'] = model.module.state_dict()
        save_checkpoint(checkpoint, is_best,
                        osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {}  Prec1: {:.2%}  Best: {:.2%}{}\n'.format(
            epoch, prec1, best_prec1, ' *' if is_best else ''))
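
The hand-rolled milestone decay above (lr scaled by ratio**power, where power counts the milestones already passed) matches PyTorch's built-in MultiStepLR. A minimal equivalent sketch with illustrative milestones:

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# lr is multiplied by gamma at each milestone, so after epoch 120 the
# effective lr is 0.1 * 0.2**2 -- the same ratio**power rule as above.
scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

for epoch in range(200):
    # train(...) and evaluate(...) would run here
    scheduler.step()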
Esempio n. 21
def main(args):
    if not os.path.exists(args.logs_dir):
        os.mkdir(args.logs_dir)
    if not os.path.exists(args.tensorboard_dir):
        os.mkdir(args.tensorboard_dir)
    tensorboardWrite = SummaryWriter(log_dir=args.tensorboard_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # log file
    if args.evaluate == 1:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_test.txt'))
    else:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    print("Initializing dataset {}".format(args.dataset))
    # from reid.data import get_data
    dataset, num_classes, train_loader, query_loader, gallery_loader = \
        get_data(args, args.dataset, args.split, args.data_dir,
                 args.batch_size, args.seq_len, args.seq_srd,
                 args.workers)
    print('[len] train: {}, query: {}, gallery: {}'.format(*list(map(len, [train_loader, query_loader, gallery_loader]))))

    # create CNN model
    # cnn_model = models.create(args.a1, args.flow1, args.flow2, num_features=args.features, dropout=args.dropout)
    cnn_model_flow = [models.create(args.a1, args.flow1, num_features=args.features, dropout=args.dropout)]
    if any(args.flow2):  # a second input stream is configured
        cnn_model_flow.append(models.create(args.a1, args.flow2, num_features=args.features, dropout=args.dropout))
    # cnn_model_flow1 = cnn_model_flow1.cuda()
    # cnn_model_flow2 = cnn_model_flow2.cuda()


    # create ATT model
    input_num = cnn_model_flow[0].feat.in_features  # 2048
    output_num = args.features  # 128
    att_model = models.create(args.a2, input_num, output_num)
    # att_model.cuda()

    # # ------peixian:tow attmodel------
    # att_model_flow1 = models.create(args.a2, input_num, output_num)
    # att_model_flow2 = models.create(args.a2, input_num, output_num)
    # # --------------------------------

    # create classifier model
    class_num = 2
    classifier_model = models.create(args.a3,  output_num, class_num)
    # classifier_model.cuda()

    # CUDA acceleration model

    # cnn_model = torch.nn.DataParallel(cnn_model).to(device)
    # # ------peixian:tow attmodel------
    # for att_model in [att_model_flow1, att_model_flow2]:
    #     att_model = att_model.to(device)
    # # --------------------------------
    att_model = att_model.cuda()
    classifier_model = classifier_model.cuda()

    # cnn_model = torch.nn.DataParallel(cnn_model).cuda()
    # cnn_model_flow1 = torch.nn.DataParallel(cnn_model_flow1,device_ids=[0,1,2])
    # cnn_model_flow2 = torch.nn.DataParallel(cnn_model_flow2,device_ids=[0,1,2])
    
    cnn_model_flow[0].cuda()
    cnn_model_flow[0] = torch.nn.DataParallel(cnn_model_flow[0],device_ids=[0])
    if len(cnn_model_flow) > 1:
        cnn_model_flow[1].cuda()
        cnn_model_flow[1] = torch.nn.DataParallel(cnn_model_flow[1],device_ids=[0])



    # att_model = torch.nn.DataParallel(att_model,device_ids=[1,2,3])
    # classifier_model = torch.nn.DataParallel(classifier_model,device_ids=[1,2,3])


    criterion_oim = OIMLoss(args.features, num_classes,
                            scalar=args.oim_scalar, momentum=args.oim_momentum)
    criterion_veri = PairLoss(args.sampling_rate)
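    # OIMLoss (Online Instance Matching) scores each feature against a
    # lookup table of identity features; PairLoss supervises the pairwise
    # verification branch.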
    criterion_oim.cuda()
    criterion_veri.cuda()

    # criterion_oim.cuda()
    # criterion_veri.cuda()

    # Optimizer
    optimizer1 = []
    # cnn_model_flow = [cnn_model_flow1, cnn_model_flow2]
    # one SGD optimizer per flow
    for i in range(len(cnn_model_flow)):
        base_param_ids = set(map(id, cnn_model_flow[i].module.base.parameters()))
        new_params = [p for p in cnn_model_flow[i].module.parameters() if
                      id(p) not in base_param_ids]

        param_groups1 = [
            {'params': cnn_model_flow[i].module.base.parameters(), 'lr_mult': 1},
            {'params': new_params, 'lr_mult': 1}]

        optimizer1.append(torch.optim.SGD(param_groups1, lr=args.lr1,
                                          momentum=args.momentum,
                                          weight_decay=args.weight_decay,
                                          nesterov=True))
    
    param_groups2 = [
        {'params': att_model.parameters(), 'lr_mult': 1},
        {'params': classifier_model.parameters(), 'lr_mult': 1}]                        
    optimizer2 = torch.optim.SGD(param_groups2, lr=args.lr2,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    # optimizer1 = torch.optim.Adam(param_groups1, lr=args.lr1, weight_decay=args.weight_decay)
    #
    # optimizer2 = torch.optim.Adam(param_groups2, lr=args.lr2, weight_decay=args.weight_decay)

    # Schedule Learning rate
    def adjust_lr1(epoch):
        # true division here gives a smooth per-epoch exponential decay
        # (the schedules below use // for stepped decay instead)
        lr = args.lr1 * (0.1 ** (epoch / args.lr1step))
        print(lr)
        for o in optimizer1:
            for g in o.param_groups:
                g['lr'] = lr * g.get('lr_mult', 1)

    def adjust_lr2(epoch):
        lr = args.lr2 * (0.01 ** (epoch//args.lr2step))
        print(lr)
        for g in optimizer2.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)
        # # peixian:  two attmodel:
        # for o in optimizer2:
        #     for g in o.param_groups:
        #         g['lr'] = lr * g.get('lr_mult', 1)
        # #

    def adjust_lr3(epoch):
        lr = args.lr3 * (0.000001 ** (epoch // args.lr3step))
        print(lr)
        return lr

    # Trainer
    trainer = SEQTrainer(cnn_model_flow, att_model, classifier_model, criterion_veri, criterion_oim, args.lr3, args.flow1rate)


    # Evaluator
    evaluator = ATTEvaluator(cnn_model_flow, att_model, classifier_model, args.flow1rate)

    best_top1 = 0
    if args.evaluate == 1 or args.pretrain == 1:  # evaluate
        for i in range(len(cnn_model_flow)):
            checkpoint = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best_flow' + str(i) + '.pth.tar'))
            cnn_model_flow[i].module.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'attmodel_best.pth.tar'))
        att_model.load_state_dict(checkpoint['state_dict'])

        checkpoint = load_checkpoint(osp.join(args.logs_dir, 'clsmodel_best.pth.tar'))
        classifier_model.load_state_dict(checkpoint['state_dict'])

        top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)
        # top1 = evaluator.evaluate(query_loader, gallery_loader,dataset.num_tracklet)

    if args.evaluate == 0:
        for epoch in range(args.start_epoch, args.epochs):
            adjust_lr1(epoch)
            adjust_lr2(epoch)
            rate = adjust_lr3(epoch)
            trainer.train(epoch, train_loader, optimizer1, optimizer2, rate, tensorboardWrite)

            if (epoch+1) % 1 == 0 or (epoch+1) == args.epochs:  # evaluate every epoch

                top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.queryinfo, dataset.galleryinfo)

                is_best = top1 > best_top1
                if is_best:
                    best_top1 = top1
                for i in range(len(cnn_model_flow)):
                    save_cnn_checkpoint({
                        'state_dict': cnn_model_flow[i].module.state_dict(),
                        'epoch': epoch + 1,
                        'best_top1': best_top1,
                    }, is_best, index=i, fpath=osp.join(args.logs_dir, 'cnn_checkpoint_flow' + str(i) + '.pth.tar'))

                save_att_checkpoint({
                    'state_dict': att_model.state_dict(),
                    'epoch': epoch + 1,
                    'best_top1': best_top1,
                }, is_best, fpath=osp.join(args.logs_dir, 'att_checkpoint.pth.tar'))

                save_cls_checkpoint({
                    'state_dict': classifier_model.state_dict(),
                    'epoch': epoch + 1,
                    'best_top1': best_top1,
                }, is_best, fpath=osp.join(args.logs_dir, 'cls_checkpoint.pth.tar'))
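
The adjust_lr helpers above rely on a custom 'lr_mult' key that torch.optim stores in each param group but never interprets itself. A minimal self-contained sketch of the pattern (model and values are illustrative):

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(
    [{'params': model.weight, 'lr_mult': 1.0},
     {'params': model.bias, 'lr_mult': 0.1}],
    lr=0.01, momentum=0.9)

def adjust_lr(epoch, base_lr=0.01, step=20):
    # Stepped exponential decay; each group is then scaled by its own
    # 'lr_mult', which the optimizer carries along untouched.
    lr = base_lr * (0.1 ** (epoch // step))
    for g in optimizer.param_groups:
        g['lr'] = lr * g.get('lr_mult', 1)

adjust_lr(epoch=25)  # the two groups now train at 0.001 and 0.0001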
Esempio n. 22
def main(args):
    s_ = time.time()

    # save training logs
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)

    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    display(args)
    start = 0

    model = models.create(args.net, pretrained=True, dim=args.dim)

    if args.r is None:
        model_dict = model.state_dict()
        # orthogonal init
        if args.init == 'orth':
            w = model_dict['classifier.0.weight']
            model_dict['classifier.0.weight'] = torch.nn.init.orthogonal_(w)
        else:
            print('initializing the FC layer with Kaiming normal init')
            w = model_dict['classifier.0.weight']
            model_dict['classifier.0.weight'] = torch.nn.init.kaiming_normal_(
                w)

        # zero bias
        model_dict['classifier.0.bias'] = torch.zeros(args.dim)
        model.load_state_dict(model_dict)
    else:
        # resume model
        chk_pt = load_checkpoint(args.r)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # freeze BN
    if args.BN == 1:
        print(40 * '#', 'BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')
    # Fine-tune the model: the learning rate for pre-trained parameter is 1/10

    new_param_ids = set(map(id, model.module.classifier.parameters()))

    new_params = [
        p for p in model.module.parameters() if id(p) in new_param_ids
    ]

    base_params = [
        p for p in model.module.parameters() if id(p) not in new_param_ids
    ]

    # 'lr_mult' is a custom key that plain Adam ignores; it only takes
    # effect if something reads it and scales each group's 'lr', so as
    # written both groups train at args.lr until the optimizer is rebuilt
    # at the dataset milestones below.
    param_groups = [{
        'params': base_params,
        'lr_mult': 0.0
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]

    print('initial model is saved at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.loss == 'center-nca':
        criterion = losses.create(args.loss, alpha=args.alpha).cuda()
    elif args.loss == 'cluster-nca':
        criterion = losses.create(args.loss, alpha=args.alpha,
                                  beta=args.beta).cuda()
    elif args.loss == 'neighbour':
        criterion = losses.create(args.loss, k=args.k,
                                  margin=args.margin).cuda()
    elif args.loss == 'nca':
        criterion = losses.create(args.loss, alpha=args.alpha, k=args.k).cuda()
    elif args.loss == 'triplet':
        criterion = losses.create(args.loss, alpha=args.alpha).cuda()
    elif args.loss == 'bin' or args.loss == 'ori_bin':
        criterion = losses.create(args.loss,
                                  margin=args.margin,
                                  alpha=args.alpha)
    else:
        criterion = losses.create(args.loss).cuda()

    # Decor_loss = losses.create('decor').cuda()
    data = DataSet.create(args.data, root=None)

    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.BatchSize,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True,
        num_workers=args.nThreads)

    # save the train information
    epoch_list = list()
    loss_list = list()
    pos_list = list()
    neg_list = list()

    for epoch in range(start, args.epochs):
        epoch_list.append(epoch)

        running_loss = 0.0
        running_pos = 0.0
        running_neg = 0.0

        if epoch == 1:
            optimizer.param_groups[0]['lr_mult'] = 0.1  # custom key; a no-op unless something applies it to 'lr'

        if (epoch == 1000 and args.data == 'car') or \
                (epoch == 550 and args.data == 'cub') or \
                (epoch == 100 and args.data in ['shop', 'jd']):

            param_groups = [{
                'params': base_params,
                'lr_mult': 0.1
            }, {
                'params': new_params,
                'lr_mult': 1.0
            }]

            optimizer = torch.optim.Adam(param_groups,
                                         lr=0.1 * args.lr,
                                         weight_decay=args.weight_decay)

        for i, batch in enumerate(train_loader, 0):
            # unpack the batch and move it to the GPU (the old Variable
            # wrappers are no-ops in modern PyTorch)
            inputs, labels = batch
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()

            embed_feat = model(inputs)

            loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels)

            # decor_loss = Decor_loss(embed_feat)

            # loss += args.theta * decor_loss

            if not isinstance(loss, torch.Tensor):
                print('Loss is not a tensor; skipping backward for this batch')
                continue

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_neg += dist_an
            running_pos += dist_ap

            if epoch == 0 and i == 0:
                print(50 * '#')
                print('Train Begin -- HA-HA-HA-HA-AH-AH-AH-AH --')

        loss_list.append(running_loss)
        pos_list.append(running_pos / (i + 1))
        neg_list.append(running_neg / (i + 1))

        print(
            '[Epoch %03d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
            % (epoch + 1, running_loss / (i + 1), inter_, dist_ap, dist_an))

        if (epoch + 1) % args.save_step == 0:
            if use_gpu:  # module-level flag, defined outside this snippet
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            },
                            is_best=False,
                            fpath=osp.join(
                                args.save_dir,
                                'ckp_ep' + str(epoch + 1) + '.pth.tar'))

    np.savez(os.path.join(save_dir, "result.npz"),
             epoch=epoch_list,
             loss=loss_list,
             pos=pos_list,
             neg=neg_list)
    t = time.time() - s_
    print('training takes %.2f hours' % (t / 3600))
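
The state-dict surgery above (re-initialising classifier.0.weight and zeroing the bias before load_state_dict) can also be applied directly to the layer. A sketch assuming the classifier's first module is a plain nn.Linear (sizes illustrative):

import torch.nn as nn

def init_fc(fc, scheme='orth'):
    # Orthogonal or Kaiming-normal weights, zero bias, applied in place.
    if scheme == 'orth':
        nn.init.orthogonal_(fc.weight)
    else:
        nn.init.kaiming_normal_(fc.weight)
    nn.init.zeros_(fc.bias)

init_fc(nn.Linear(512, 64), scheme='orth')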