Code Example #1
File: Sequence2Feature.py  Project: wanpeng16/RVSML
def Sequence2Feature_testvec(data, net, checkpoint, in_dim, middle_dim, out_dim=512, root=None, nThreads=16, batch_size=100, train_flag=True, **kargs):
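    # Build the embedding network, load the checkpoint weights, and extract features for both the train and test sequence splits.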
    dataset_name = data
    model = models.create(net, in_dim, middle_dim, out_dim, pretrained=False)
    # resume = load_checkpoint(ckp_path)
    resume = checkpoint
    model.load_state_dict(resume['state_dict'])
    model = torch.nn.DataParallel(model).cuda()
    datatrain = DataSet.create_vec(root=root,train_flag=True)

    gallery_loader = torch.utils.data.DataLoader(
        datatrain.seqdata, batch_size=batch_size, shuffle=False,
        drop_last=False, pin_memory=True, num_workers=nThreads)

    pool_feature = False

    train_feature, train_labels = extract_features(model, gallery_loader, print_freq=1e5, metric=None, pool_feature=pool_feature)
        
    datatest = DataSet.create_vec(root=root,train_flag=False)
    query_loader = torch.utils.data.DataLoader(
        datatest.seqdata, batch_size=batch_size,
        shuffle=False, drop_last=False,
        pin_memory=True, num_workers=nThreads)
    test_feature, test_labels = extract_features(model, query_loader, print_freq=1e5, metric=None, pool_feature=pool_feature)

    
    return train_feature, train_labels, test_feature, test_labels
    #return gallery_feature, gallery_labels, query_feature, query_labels
Code Example #2
def Model2Feature(data,
                  net,
                  checkpoint,
                  dim=512,
                  width=224,
                  root=None,
                  nThreads=16,
                  batch_size=100,
                  pool_feature=False,
                  **kargs):
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    # resume = load_checkpoint(ckp_path)
    resume = checkpoint
    model.load_state_dict(resume['state_dict'])
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
    data = DataSet.create(data, width=width, root=root)

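    # 'shop' and 'jd_test' have separate gallery/query splits; for other datasets the single gallery set serves as both query and gallery.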
    if dataset_name in ['shop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     drop_last=False,
                                                     pin_memory=True,
                                                     num_workers=nThreads)

        query_loader = torch.utils.data.DataLoader(data.query,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False,
                                                   pin_memory=True,
                                                   num_workers=nThreads)

        gallery_feature, gallery_labels = extract_features(
            model,
            gallery_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels = extract_features(
            model,
            query_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)

    else:
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  pin_memory=True,
                                                  num_workers=nThreads)
        features, labels = extract_features(model,
                                            data_loader,
                                            print_freq=1e5,
                                            metric=None,
                                            pool_feature=pool_feature)
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
Code Example #3
def Model2Feature(data,
                  model,
                  nThreads=16,
                  batch_size=32,
                  pool_feature=False,
                  **kargs):
    dataset_name = data

    data = DataSet.create(data)

    if dataset_name in ['inshop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     drop_last=False)

        query_loader = torch.utils.data.DataLoader(data.query,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False)

        gallery_feature, gallery_labels = extract_features(
            model,
            gallery_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels = extract_features(
            model,
            query_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)

    else:
        data_loader = torch.utils.data.DataLoader(data.test,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  num_workers=nThreads)
        features, labels = extract_features(model,
                                            data_loader,
                                            print_freq=1e5,
                                            metric=None,
                                            pool_feature=pool_feature)
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
Code Example #4
def Model2Feature(data,
                  net,
                  checkpoint,
                  dim=512,
                  width=224,
                  root=None,
                  nThreads=16,
                  batch_size=100,
                  pool_feature=False,
                  **kargs):
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    except RuntimeError:
        print('checkpoint load failed: the state dict does not match '
              'the model, retrying with strict=False')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(data, width=width, root=root)

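    # Extract features from the training split and from the gallery (test) split using non-shuffled loaders.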
    train_loader = torch.utils.data.DataLoader(data.train,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               drop_last=False,
                                               pin_memory=True,
                                               num_workers=nThreads)
    test_loader = torch.utils.data.DataLoader(data.gallery,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=False,
                                              pin_memory=True,
                                              num_workers=nThreads)

    train_feature, train_labels \
        = extract_features(model, train_loader, print_freq=1e4,
                           metric=None, pool_feature=pool_feature)
    test_feature, test_labels \
        = extract_features(model, test_loader, print_freq=1e4,
                           metric=None, pool_feature=pool_feature)

    return train_feature, train_labels, test_feature, test_labels
Code Example #5
File: Sequence2Feature.py  Project: wanpeng16/RVSML
def Sequence2Feature(data, net, checkpoint, in_dim, middle_dim, out_dim=512, root=None, nThreads=16, batch_size=100, train_flag=True, **kargs):
    dataset_name = data
    model = models.create(net, in_dim, middle_dim, out_dim, pretrained=False)
    # resume = load_checkpoint(ckp_path)
    resume = checkpoint
    model.load_state_dict(resume['state_dict'])
    model = torch.nn.DataParallel(model).cuda()
    datatrain = DataSet.create_seq(root=root,train_flag=True)
    
    if dataset_name in ['shop', 'jd_test','seq','seqfull']:
        gallery_loader = torch.utils.data.DataLoader(
            datatrain.seqdata, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)

        pool_feature = False

        train_feature, train_labels = extract_features(model, gallery_loader, print_freq=1e5, metric=None, pool_feature=pool_feature)
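        # In training mode the test split is skipped and small placeholder tensors are returned instead.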
        if train_flag:
            test_feature = torch.zeros(2,1)
            test_labels = torch.zeros(2,1)
            #test_labels.append(1)
        else:
            datatest = DataSet.create_seq(root=root,train_flag=False)
            query_loader = torch.utils.data.DataLoader(
                datatest.seqdata, batch_size=batch_size,
                shuffle=False, drop_last=False,
                pin_memory=True, num_workers=nThreads)
            test_feature, test_labels = extract_features(model, query_loader, print_freq=1e5, metric=None, pool_feature=pool_feature)

    else:
        data_loader = torch.utils.data.DataLoader(
            data.test, batch_size=batch_size,
            shuffle=False, drop_last=False, pin_memory=True,
            num_workers=nThreads)
        features, labels = extract_features(model, data_loader, print_freq=1e5, metric=None, pool_feature=pool_feature)
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    #if train_flag:
    #    return train_feature, train_labels
    #else:
    return train_feature, train_labels, test_feature, test_labels
Code Example #6
def Model2Feature(data,
                  net,
                  checkpoint,
                  root=None,
                  nThreads=16,
                  batch_size=100,
                  pool_feature=False,
                  **kargs):
    dataset_name = data
    model = models.create(net, pretrained=False, normalized=True)
    # resume = load_checkpoint(ckp_path)
    resume = checkpoint

    model.load_state_dict(resume['state_dict'])
    model.eval()
    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(name=data, root=root, set_name='test')

    if dataset_name in ['shop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     drop_last=False,
                                                     pin_memory=True,
                                                     num_workers=nThreads)

        query_loader = torch.utils.data.DataLoader(data.query,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False,
                                                   pin_memory=True,
                                                   num_workers=nThreads)

        gallery_feature, gallery_labels = extract_features(
            model,
            gallery_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels = extract_features(
            model,
            query_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)

    else:  #here

        print('using else')
        data_loader = torch.utils.data.DataLoader(data.test,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  drop_last=True,
                                                  pin_memory=True,
                                                  num_workers=nThreads)

        features, labels = extract_features(model,
                                            data_loader,
                                            pool_feature=pool_feature)

        # identical? (gallery and query share the same features)
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
Code Example #7
                               transform=transform_train, index=index_train)
        testfolder = CIFAR100(root=traindir, train=False,
                              download=True, transform=transform_test, index=index)
    else:
        trainfolder = ImageFolder(traindir, transform_train, index=index_train)
        testfolder = ImageFolder(testdir, transform_test, index=index)

    train_loader = torch.utils.data.DataLoader(
        trainfolder, batch_size=128, shuffle=False, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        testfolder, batch_size=128, shuffle=False, drop_last=False)
    print('Test %d\t' % task_id)

    model = torch.load(models[task_id])

    train_embeddings_cl, train_labels_cl = extract_features(
        model, train_loader)
    val_embeddings_cl, val_labels_cl = extract_features(
        model, test_loader)

    # Test for each task
    for i in index_train:
        ind_cl = np.where(i == train_labels_cl)[0]
        embeddings_tmp = train_embeddings_cl[ind_cl]
        class_label.append(i)
        class_mean.append(np.mean(embeddings_tmp, axis=0))

    if task_id > 0 and args.mapping_test:
        model_old = torch.load(models[task_id-1])
        train_embeddings_cl_old, train_labels_cl_old = extract_features(
            model_old, train_loader)
Code Example #8
File: test.py  Project: rajatagrawal-dev/SDC-IL
        testfolder = ImageFolder(testdir, transform_test, index=index)

    train_loader = torch.utils.data.DataLoader(trainfolder,
                                               batch_size=128,
                                               shuffle=False,
                                               drop_last=False)
    test_loader = torch.utils.data.DataLoader(testfolder,
                                              batch_size=128,
                                              shuffle=False,
                                              drop_last=False)
    print('Test %d\t' % task_id)

    model = torch.load(models[task_id])

    train_embeddings_cl, train_labels_cl = extract_features(model,
                                                            train_loader,
                                                            print_freq=32,
                                                            metric=None)
    val_embeddings_cl, val_labels_cl = extract_features(model,
                                                        test_loader,
                                                        print_freq=32,
                                                        metric=None)

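    # Collect all raw training inputs into a single numpy array.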
    class_data = []
    for i, data in enumerate(train_loader, 0):
        inputs, pids = data
        if len(class_data) == 0:
            class_data = inputs.numpy()
        else:
            class_data = np.vstack((class_data, inputs.numpy()))

    #print("class data shape is: ", class_data.shape)
Code Example #9
File: test.py  Project: yuanmengzhixing/Deep_metric
name = temp[-1][:-10]
if args.test == 1:
    data = DataSet.create(args.data, train=False)
    data_loader = torch.utils.data.DataLoader(data.test,
                                              batch_size=128,
                                              shuffle=False,
                                              drop_last=False)
else:
    data = DataSet.create(args.data, test=False)
    data_loader = torch.utils.data.DataLoader(data.train,
                                              batch_size=128,
                                              shuffle=False,
                                              drop_last=False)

features, labels = extract_features(model,
                                    data_loader,
                                    print_freq=1e5,
                                    metric=None)

num_class = len(set(labels))

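# Negated pairwise distances are used as the similarity matrix for Recall@K.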
sim_mat = -pairwise_distance(features)
if args.data == 'product':

    result = Recall_at_ks_products(sim_mat,
                                   query_ids=labels,
                                   gallery_ids=labels)
else:
    result = Recall_at_ks(sim_mat, query_ids=labels, gallery_ids=labels)
    result = ['%.4f' % r for r in result]
    temp = '  '
    result = temp.join(result)
Code Example #10
                               transform=transform_train, index=index_train)
        testfolder = CIFAR100(root=traindir, train=False,
                              download=True, transform=transform_test, index=index)
    else:
        trainfolder = ImageFolder(traindir, transform_train, index=index_train)
        testfolder = ImageFolder(testdir, transform_test, index=index)

    train_loader = torch.utils.data.DataLoader(
        trainfolder, batch_size=128, shuffle=False, drop_last=False)
    test_loader = torch.utils.data.DataLoader(
        testfolder, batch_size=128, shuffle=False, drop_last=False)
    print('Test %d\t' % task_id)

    model = torch.load(models[task_id])

    train_embeddings_cl, train_labels_cl = extract_features(
        model, train_loader, print_freq=32, metric=None)
    val_embeddings_cl, val_labels_cl = extract_features(
        model, test_loader, print_freq=32, metric=None)

    # Test for each task
    for i in index_train:
        ind_cl = np.where(i == train_labels_cl)[0]
        embeddings_tmp = train_embeddings_cl[ind_cl]
        class_label.append(i)
        class_mean.append(np.mean(embeddings_tmp, axis=0))

    if task_id > 0 and args.mapping_test:
        model_old = torch.load(models[task_id-1])
        train_embeddings_cl_old, train_labels_cl_old = extract_features(
            model_old, train_loader, print_freq=32, metric=None)
Code Example #11
    resume = load_checkpoint(args.resume)
    model.load_state_dict(resume['state_dict'])
    model = torch.nn.DataParallel(model).cuda()

    data = DataSet.create(args.data, width=args.width, root=args.data_root)

    gallery_loader = torch.utils.data.DataLoader(
        data.train, batch_size=args.batch_size, shuffle=False,
        drop_last=False, pin_memory=True, num_workers=args.nThreads)

    query_loader = torch.utils.data.DataLoader(
        data.test, batch_size=args.batch_size,
        shuffle=False, drop_last=False,
        pin_memory=True, num_workers=args.nThreads)

    gallery_feature, gallery_labels = extract_features(
        model, gallery_loader, print_freq=1e5, metric=None, pool_feature=args.pool_feature)
    query_feature, query_labels = extract_features(
        model, query_loader, print_freq=1e5, metric=None, pool_feature=args.pool_feature)

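    # Similarity of every test (query) embedding to every train (gallery) embedding.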
    sim_mat = pairwise_similarity(query_feature, gallery_feature)

if args.whales:
    whales_preds = make_whales_predictions(sim_mat, gallery_labels)
    make_whales_sub_file(whales_preds)
else:
    recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels,
                            gallery_ids=gallery_labels, data=args.data)

    result = '  '.join(['%.4f' % k for k in recall_ks])

print('Epoch-%d' % epoch, result)
Code Example #12
def Model2Feature(data,
                  net,
                  checkpoint,
                  dim=512,
                  width=224,
                  root=None,
                  Retrieval_visualization=False,
                  nThreads=16,
                  batch_size=100,
                  pool_feature=False,
                  **kargs):
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    # resume = load_checkpoint(ckp_path)

    resume = checkpoint
    # model.load_state_dict(resume['state_dict'])

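    # Keep only the checkpoint weights whose keys exist in the current model, then load the merged state dict.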
    net_dict = model.state_dict()
    weights = resume['state_dict']
    pretrained_dict = {k: v for k, v in weights.items() if k in net_dict}
    net_dict.update(pretrained_dict)
    model.load_state_dict(net_dict)

    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(data, width=width, root=root)

    if dataset_name in ['shop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     drop_last=False,
                                                     pin_memory=True,
                                                     num_workers=nThreads)

        query_loader = torch.utils.data.DataLoader(data.query,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False,
                                                   pin_memory=True,
                                                   num_workers=nThreads)

        gallery_feature, gallery_labels, img_name = extract_features(
            model,
            gallery_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels, img_name = extract_features(
            model,
            query_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)

    else:
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  pin_memory=True,
                                                  num_workers=nThreads)

        # For retrieval visualization the data loader is created with shuffling enabled
        if Retrieval_visualization:
            data_loader_shuffled = torch.utils.data.DataLoader(
                data.gallery,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                pin_memory=True,
                num_workers=nThreads)

        else:
            data_loader_shuffled = torch.utils.data.DataLoader(
                data.gallery,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=nThreads)

        features, labels, img_name = extract_features(
            model,
            data_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)
        features_shuffled, labels_shuffled, img_name_shuffled = extract_features(
            model,
            data_loader_shuffled,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature)

        query_feature, query_labels = features, labels
        gallery_feature, gallery_labels = features_shuffled, labels_shuffled

    return gallery_feature, gallery_labels, query_feature, query_labels, img_name, img_name_shuffled
Code Example #13
File: pool_test.py  Project: yuanpengcheng/VGG_dml
model = torch.nn.DataParallel(model).cuda()

data = DataSet.create(args.data)

if args.data == 'shop':
    gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                 batch_size=64,
                                                 shuffle=False,
                                                 drop_last=False)
    query_loader = torch.utils.data.DataLoader(data.query,
                                               batch_size=64,
                                               shuffle=False,
                                               drop_last=False)

    gallery_feature, gallery_labels = extract_features(model,
                                                       gallery_loader,
                                                       print_freq=1e5,
                                                       metric=None)
    query_feature, query_labels = extract_features(model,
                                                   query_loader,
                                                   print_freq=1e5,
                                                   metric=None)

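    # Rank gallery items by similarity to each query and report Recall@K under the In-Shop protocol.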
    sim_mat = pairwise_similarity(x=query_feature, y=gallery_feature)
    result = Recall_at_ks_shop(sim_mat,
                               query_ids=query_labels,
                               gallery_ids=gallery_labels)

elif args.data == 'jd':
    if args.test == 1:
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=64,
Code Example #14
File: train.py  Project: Tramtadama/Deep_metric
def main(args):
    # s_ = time.time()

    save_dir = args.save_dir
    mkdir_if_missing(save_dir)

    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    display(args)
    start = 0

    model = models.create(args.net, pretrained=True, dim=args.dim)

    # for vgg and densenet
    if args.resume is None:
        model_dict = model.state_dict()

    else:
        # resume model
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)

    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # freeze BN
    if args.freeze_BN is True:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40*'#', 'BatchNorm NOT frozen')
        
    # Fine-tune the model: the learning rate for pre-trained parameters is 1/10
    new_param_ids = set(map(id, model.module.classifier.parameters()))

    new_params = [p for p in model.module.parameters() if
                  id(p) in new_param_ids]

    base_params = [p for p in model.module.parameters() if
                   id(p) not in new_param_ids]

    param_groups = [
                {'params': base_params, 'lr_mult': 0.0},
                {'params': new_params, 'lr_mult': 1.0}]

    print('initial model is saved at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups, lr=args.lr,
                                 weight_decay=args.weight_decay)

    # criterion = losses.create(args.loss, margin=args.margin, alpha=args.alpha, base=args.loss_base).cuda()
    criterion = torch.nn.TripletMarginLoss(margin=1)

    # Decor_loss = losses.create('decor').cuda()

    features_data = DataSet.create(args.data, ratio=args.ratio, width=args.width, origin_width=args.origin_width, root=args.data_root)
    features_loader = torch.utils.data.DataLoader(
        features_data.train, batch_size=args.batch_size, shuffle=False,
        drop_last=False, pin_memory=True, num_workers=args.nThreads)
    # save the train information

    for epoch in range(start, args.epochs):

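        # Every 10 epochs, re-extract features and rebuild the training set and loader around them.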
        if epoch%10 == 0:
            features, _ = extract_features(
                model, features_loader, print_freq=1e5, metric=None, pool_feature=False)

            data = DataSet.create(args.data, features=features, ratio=args.ratio, width=args.width, origin_width=args.origin_width, root=args.data_root)

            train_loader = torch.utils.data.DataLoader(
                data.train, batch_size=args.batch_size,
                sampler=FastRandomIdentitySampler(data.train, num_instances=args.num_instances),
                drop_last=True, pin_memory=True, num_workers=args.nThreads)

        train(epoch=epoch, model=model, criterion=criterion,
              optimizer=optimizer, train_loader=train_loader, args=args)

        if epoch == 1:
            optimizer.param_groups[0]['lr_mult'] = 0.1
        
        if (epoch+1) % args.save_step == 0 or epoch==0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch+1),
            }, is_best=False, fpath=osp.join(args.save_dir, 'ckp_ep' + str(epoch + 1) + '.pth.tar'))
Code Example #15
def Model2Feature(data,
                  net,
                  checkpoint,
                  dim=512,
                  width=224,
                  root=None,
                  nThreads=16,
                  batch_size=100,
                  pool_feature=False,
                  model=None,
                  org_feature=False,
                  args=None):
    dataset_name = data
    if model is None:
        model = models.create(net, dim=dim, pretrained=False)
        resume = checkpoint
        model.load_state_dict(resume['state_dict'], strict=False)
        model = torch.nn.DataParallel(model).cuda()
    data = dataset.Dataset(data,
                           width=width,
                           root=root,
                           mode="test",
                           self_supervision_rot=0,
                           args=args)

    if dataset_name in ['shop', 'jd_test', 'cifar']:
        gallery_loader = torch.utils.data.DataLoader(data.gallery,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     drop_last=False,
                                                     pin_memory=True,
                                                     num_workers=nThreads)

        query_loader = torch.utils.data.DataLoader(data.query,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False,
                                                   pin_memory=True,
                                                   num_workers=nThreads)

        gallery_feature, gallery_labels = extract_features(
            model,
            gallery_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature,
            org_feature=org_feature)
        query_feature, query_labels = extract_features(
            model,
            query_loader,
            print_freq=1e5,
            metric=None,
            pool_feature=pool_feature,
            org_feature=org_feature)
        if org_feature:
            norm = query_feature.norm(dim=1, p=2, keepdim=True)
            query_feature = query_feature.div(norm.expand_as(query_feature))
            print("feature normalized 1")
            norm = gallery_feature.norm(dim=1, p=2, keepdim=True)
            gallery_feature = gallery_feature.div(
                norm.expand_as(gallery_feature))
            print("feature normalized 2")
    else:
        data_loader = torch.utils.data.DataLoader(data.gallery,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  pin_memory=True,
                                                  num_workers=nThreads)
        features, labels = extract_features(model,
                                            data_loader,
                                            print_freq=1e5,
                                            metric=None,
                                            pool_feature=pool_feature,
                                            org_feature=org_feature)
        if org_feature:
            norm = features.norm(dim=1, p=2, keepdim=True)
            features = features.div(norm.expand_as(features))
            print("feature normalized")
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
Code Example #16
def main(args):
    # s_ = time.time()
    print(torch.cuda.get_device_properties(device=0).total_memory)
    torch.cuda.empty_cache()
    print(args)
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)
    num_txt = len(glob.glob(save_dir + "/*.txt"))
    sys.stdout = logging.Logger(
        os.path.join(save_dir, "log_" + str(num_txt) + ".txt"))
    display(args)
    start = 0

    model = models.create(args.net,
                          pretrained=args.pretrained,
                          dim=args.dim,
                          self_supervision_rot=args.self_supervision_rot)
    all_pretrained = glob.glob(save_dir + "/*.pth.tar")

    if (args.resume is None) or (len(all_pretrained) == 0):
        model_dict = model.state_dict()

    else:
        # resume model
        all_pretrained_epochs = sorted(
            [int(x.split("/")[-1][6:-8]) for x in all_pretrained])
        args.resume = os.path.join(
            save_dir, "ckp_ep" + str(all_pretrained_epochs[-1]) + ".pth.tar")
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)

    model = torch.nn.DataParallel(model)
    model = model.cuda()
    fake_centers_dir = os.path.join(args.save_dir, "fake_center.npy")

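    # If no pseudo-label file has been written yet, create fake (cluster) labels, extracting features first unless args.rot_only is set.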
    if np.sum(["train_1.txt" in x
               for x in glob.glob(args.save_dir + "/**/*")]) == 0:
        if args.rot_only:
            create_fake_labels(None, None, args)

        else:
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)

            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=100,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)

            train_feature, train_labels = extract_features(
                model,
                fake_train_loader,
                print_freq=1e5,
                metric=None,
                pool_feature=args.pool_feature,
                org_feature=True)

            create_fake_labels(train_feature, train_labels, args)

            del train_feature

            fake_centers = "k-means++"

            torch.cuda.empty_cache()

    elif os.path.exists(fake_centers_dir):
        fake_centers = np.load(fake_centers_dir)
    else:
        fake_centers = "k-means++"

    time.sleep(60)

    model.train()

    # freeze BN
    if (args.freeze_BN is True) and (args.pretrained):
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')

    # Fine-tune the model: the learning rate for pre-trained parameters is 1/10
    new_param_ids = set(map(id, model.module.classifier.parameters()))
    new_rot_param_ids = set()
    if args.self_supervision_rot:
        new_rot_param_ids = set(
            map(id, model.module.classifier_rot.parameters()))
        print(new_rot_param_ids)

    new_params = [
        p for p in model.module.parameters() if id(p) in new_param_ids
    ]

    new_rot_params = [
        p for p in model.module.parameters() if id(p) in new_rot_param_ids
    ]

    base_params = [
        p for p in model.module.parameters()
        if (id(p) not in new_param_ids) and (id(p) not in new_rot_param_ids)
    ]

    param_groups = [{
        'params': base_params
    }, {
        'params': new_params
    }, {
        'params': new_rot_params,
        'lr': args.rot_lr
    }]

    print('initial model is saved at %s' % save_dir)

    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = losses.create(args.loss,
                              margin=args.margin,
                              alpha=args.alpha,
                              beta=args.beta,
                              base=args.loss_base).cuda()

    data = dataset.Dataset(args.data,
                           ratio=args.ratio,
                           width=args.width,
                           origin_width=args.origin_width,
                           root=args.save_dir,
                           self_supervision_rot=args.self_supervision_rot,
                           rot_bt=args.rot_bt,
                           corruption=1,
                           args=args)
    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True,
        num_workers=args.nThreads)

    # save the train information

    for epoch in range(start, args.epochs):

        train(epoch=epoch,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              train_loader=train_loader,
              args=args)

        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            },
                            is_best=False,
                            fpath=osp.join(
                                args.save_dir,
                                'ckp_ep' + str(epoch + 1) + '.pth.tar'))

        if ((epoch + 1) % args.up_step == 0) and (not args.rot_only):
            # rewrite train_1.txt file
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)
            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)
            train_feature, train_labels = extract_features(
                model,
                fake_train_loader,
                print_freq=1e5,
                metric=None,
                pool_feature=args.pool_feature,
                org_feature=(args.dim % 64 != 0))
            fake_centers = create_fake_labels(train_feature,
                                              train_labels,
                                              args,
                                              init_centers=fake_centers)
            del train_feature
            torch.cuda.empty_cache()
            time.sleep(60)
            np.save(fake_centers_dir, fake_centers)
            # reload data
            data = dataset.Dataset(
                args.data,
                ratio=args.ratio,
                width=args.width,
                origin_width=args.origin_width,
                root=args.save_dir,
                self_supervision_rot=args.self_supervision_rot,
                rot_bt=args.rot_bt,
                corruption=1,
                args=args)

            train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                sampler=FastRandomIdentitySampler(
                    data.train, num_instances=args.num_instances),
                drop_last=True,
                pin_memory=True,
                num_workers=args.nThreads)

            # test on testing data
            # extract_recalls(data=args.data, data_root=args.data_root, width=args.width, net=args.net, checkpoint=None,
            #         dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature,
            #         gallery_eq_query=args.gallery_eq_query, model=model)
            model.train()
            if (args.freeze_BN is True) and (args.pretrained):
                print(40 * '#', '\n BatchNorm frozen')
                model.apply(set_bn_eval)
Code Example #17
def main(args):
    num_class_dict = {'cub': int(100), 'car': int(98)}
    # save the training logs
    log_dir = os.path.join(args.checkpoints, args.log_dir)
    mkdir_if_missing(log_dir)

    sys.stdout = logging.Logger(os.path.join(log_dir, 'log.txt'))
    display(args)

    if args.r is None:
        model = models.create(args.net, Embed_dim=args.dim)
        # load part of the model
        model_dict = model.state_dict()
        # print(model_dict)
        if args.net == 'bn':
            pretrained_dict = torch.load('pretrained_models/bn_inception-239d2248.pth')
        else:
            pretrained_dict = torch.load('pretrained_models/inception_v3_google-1a9a5a14.pth')

        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

        model_dict.update(pretrained_dict)

        # orth init
        if args.init == 'orth':
            print('initialize the FC layer orthogonally')
            _, _, v = torch.svd(model_dict['Embed.linear.weight'])
            model_dict['Embed.linear.weight'] = v.t()

        # zero bias
        model_dict['Embed.linear.bias'] = torch.zeros(args.dim)

        model.load_state_dict(model_dict)
    else:
        # resume model
        model = torch.load(args.r)

    model = model.cuda()

    # compute the cluster centers for each class here

    def normalize(x):
        norm = x.norm(dim=1, p=2, keepdim=True)
        x = x.div(norm.expand_as(x))
        return x

    data = DataSet.create(args.data, root=None, test=False)

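    # Initialize the class centers either by clustering extracted features or with random L2-normalized vectors.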
    if args.center_init == 'cluster':
        data_loader = torch.utils.data.DataLoader(
            data.train, batch_size=args.BatchSize, shuffle=False, drop_last=False)

        features, labels = extract_features(model, data_loader, print_freq=32, metric=None)
        features = [feature.resize_(1, args.dim) for feature in features]
        features = torch.cat(features)
        features = features.numpy()
        labels = np.array(labels)

        centers, center_labels = cluster_(features, labels, n_clusters=args.n_cluster)
        center_labels = [int(center_label) for center_label in center_labels]

        centers = Variable(torch.FloatTensor(centers).cuda(),  requires_grad=True)
        center_labels = Variable(torch.LongTensor(center_labels)).cuda()
        print(40*'#', '\n Clustering Done')

    else:
        center_labels = int(args.n_cluster) * list(range(num_class_dict[args.data]))
        center_labels = Variable(torch.LongTensor(center_labels).cuda())

        centers = normalize(torch.rand(num_class_dict[args.data]*args.n_cluster, args.dim))
        centers = Variable(centers.cuda(), requires_grad=True)

    torch.save(model, os.path.join(log_dir, 'model.pkl'))
    print('initial model is saved at %s' % log_dir)

    # fine-tune the model: the learning rate for pre-trained parameters is 1/10
    new_param_ids = set(map(id, model.Embed.parameters()))

    new_params = [p for p in model.parameters() if
                  id(p) in new_param_ids]

    base_params = [p for p in model.parameters() if
                   id(p) not in new_param_ids]
    param_groups = [
                {'params': base_params, 'lr_mult': 0.1},
                {'params': new_params, 'lr_mult': 1.0},
                {'params': centers, 'lr_mult': 1.0}]

    optimizer = torch.optim.Adam(param_groups, lr=args.lr,
                                 weight_decay=args.weight_decay)

    cluster_counter = np.zeros([num_class_dict[args.data], args.n_cluster])
    criterion = losses.create(args.loss, alpha=args.alpha, centers=centers,
                              center_labels=center_labels, cluster_counter=cluster_counter).cuda()

    # random sampling to generate mini-batch
    train_loader = torch.utils.data.DataLoader(
        data.train, batch_size=args.BatchSize, shuffle=True, drop_last=False)

    # save the train information
    epoch_list = list()
    loss_list = list()
    pos_list = list()
    neg_list = list()

    # _mask = Variable(torch.ByteTensor(np.ones([2, 4]))).cuda()
    dtype = torch.ByteTensor
    _mask = torch.ones(int(num_class_dict[args.data]), args.n_cluster).type(dtype)
    _mask = Variable(_mask).cuda()

    for epoch in range(args.start, args.epochs):
        epoch_list.append(epoch)

        running_loss = 0.0
        running_pos = 0.0
        running_neg = 0.0
        to_zero(cluster_counter)

        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            # wrap them in Variable
            inputs = Variable(inputs.cuda())

            # type of labels is Variable cuda.Longtensor
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            # centers.zero_grad()
            embed_feat = model(inputs)

            # update network weight
            loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels, _mask)
            loss.backward()
            optimizer.step()

            centers.data = normalize(centers.data)

            running_loss += loss.data[0]
            running_neg += dist_an
            running_pos += dist_ap

            if epoch == 0 and i == 0:
                print(50 * '#')
                print('Train Begin -- HA-HA-HA')
            if i % 10 == 9:
                print('[Epoch %05d Iteration %2d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
                      % (epoch + 1,  i+1, loss.data[0], inter_, dist_ap, dist_an))
        # cluster number counter show here
        print(cluster_counter)
        loss_list.append(running_loss)
        pos_list.append(running_pos / i)
        neg_list.append(running_neg / i)
        # update the _mask to make the cluster with only 1 or no member to be silent
        # _mask = Variable(torch.FloatTensor(cluster_counter) > 1).cuda()
        # cluster_distribution = torch.sum(_mask, 1).cpu().data.numpy().tolist()
        # print(cluster_distribution)
        # print('[Epoch %05d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
        #       % (epoch + 1, running_loss, inter_, dist_ap, dist_an))

        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(log_dir, '%d_model.pkl' % epoch))
    np.savez(os.path.join(log_dir, "result.npz"), epoch=epoch_list, loss=loss_list, pos=pos_list, neg=neg_list)
Code Example #18
# coding=utf-8
from __future__ import absolute_import, print_function
import argparse
import torch
from torch.backends import cudnn
from evaluations import extract_features, pairwise_distance, pairwise_similarity
from evaluations import Recall_at_ks, Recall_at_ks_products, Recall_at_ks_shop
import models
import DataSet
import os
import numpy as np
cudnn.benchmark = True
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
im1 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t18085/239/1572160811/242071/9e3b6d97/5ad06c21Nd73ffab7.jpg'
im2 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t17857/121/1655327696/242539/1771960e/5ad06c69N5b34d078.jpg'
from PIL import Image
im1 = Image.open(im1)
im2 = Image.open(im2)
im1.save('1.jpg')
im2.save('2.jpg')
r = '/opt/intern/users/xunwang/checkpoints/bin/jd/512-BN-alpha40/135_model.pth'
PATH = r
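# Recreate the VGG embedding model, wrap it in DataParallel, and load the trained weights.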
model = models.create('vgg', dim=512, pretrained=False)
model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load(PATH))
model = model.cuda()
data = DataSet.create('jd')
data_loader = torch.utils.data.DataLoader(data.gallery,
                                          batch_size=64,
                                          shuffle=False,