def Sequence2Feature_testvec(data, net, checkpoint, in_dim, middle_dim, out_dim=512, root=None, nThreads=16, batch_size=100, train_flag=True, **kargs):
    """Embed the train and test splits of a vector-sequence dataset.

    Builds ``net`` via ``models.create``, restores its weights from
    ``checkpoint['state_dict']`` and runs ``extract_features`` over both
    splits produced by ``DataSet.create_vec``.

    Returns:
        (train_feature, train_labels, test_feature, test_labels)
    """
    dataset_name = data
    seq_model = models.create(net, in_dim, middle_dim, out_dim, pretrained=False)
    resume = checkpoint
    seq_model.load_state_dict(resume['state_dict'])
    seq_model = torch.nn.DataParallel(seq_model).cuda()

    def _loader_for(split):
        # Feature extraction must keep a stable order: no shuffling,
        # no dropped tail batch.
        return torch.utils.data.DataLoader(
            split.seqdata, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)

    pool_feature = False
    train_split = DataSet.create_vec(root=root, train_flag=True)
    train_feature, train_labels = extract_features(
        seq_model, _loader_for(train_split), print_freq=1e5,
        metric=None, pool_feature=pool_feature)
    test_split = DataSet.create_vec(root=root, train_flag=False)
    test_feature, test_labels = extract_features(
        seq_model, _loader_for(test_split), print_freq=1e5,
        metric=None, pool_feature=pool_feature)
    return train_feature, train_labels, test_feature, test_labels
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Restore ``net`` from ``checkpoint`` and embed the evaluation splits.

    For 'shop' / 'jd_test' the gallery and query splits are embedded
    separately; for every other dataset the gallery split is embedded once
    and returned as both gallery and query.

    Returns:
        (gallery_feature, gallery_labels, query_feature, query_labels)
    """
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    model.load_state_dict(checkpoint['state_dict'])
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()
    data = DataSet.create(data, width=width, root=root)

    def _embed(split):
        # Deterministic loader + feature extraction for one split.
        loader = torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=False, drop_last=False,
            pin_memory=True, num_workers=nThreads)
        return extract_features(model, loader, print_freq=1e5,
                                metric=None, pool_feature=pool_feature)

    if dataset_name in ['shop', 'jd_test']:
        gallery_feature, gallery_labels = _embed(data.gallery)
        query_feature, query_labels = _embed(data.query)
    else:
        features, labels = _embed(data.gallery)
        gallery_feature, gallery_labels = features, labels
        query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
def Model2Feature(data, model, nThreads=16, batch_size=32, pool_feature=False, **kargs):
    """Embed a dataset with an already-constructed ``model``.

    'inshop' / 'jd_test' use their separate gallery and query splits; any
    other dataset is embedded once from its test split and that single
    result is returned for both gallery and query.

    Returns:
        (gallery_feature, gallery_labels, query_feature, query_labels)
    """
    dataset_name = data
    data = DataSet.create(data)
    if dataset_name in ['inshop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size,
            shuffle=False, drop_last=False)
        query_loader = torch.utils.data.DataLoader(
            data.query, batch_size=batch_size,
            shuffle=False, drop_last=False)
        gallery_feature, gallery_labels = extract_features(
            model, gallery_loader, print_freq=1e5,
            metric=None, pool_feature=pool_feature)
        query_feature, query_labels = extract_features(
            model, query_loader, print_freq=1e5,
            metric=None, pool_feature=pool_feature)
        return gallery_feature, gallery_labels, query_feature, query_labels
    # Single-split datasets: gallery and query coincide.
    eval_loader = torch.utils.data.DataLoader(
        data.test, batch_size=batch_size, shuffle=False,
        drop_last=False, num_workers=nThreads)
    features, labels = extract_features(
        model, eval_loader, print_freq=1e5,
        metric=None, pool_feature=pool_feature)
    return features, labels, features, labels
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Load ``net`` weights from ``checkpoint`` and embed train and gallery splits.

    Args:
        data: dataset name understood by ``DataSet.create``.
        net: architecture name understood by ``models.create``.
        checkpoint: loaded checkpoint dict with a 'state_dict' entry.
        dim: embedding dimension for the created model.
        width: input image width for the dataset transforms.
        root: optional dataset root directory.
        nThreads: DataLoader worker count.
        batch_size: DataLoader batch size.
        pool_feature: forwarded to ``extract_features``.

    Returns:
        (train_feature, train_labels, test_feature, test_labels)
    """
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    try:
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    # BUGFIX: was a bare ``except:`` which also swallowed KeyboardInterrupt,
    # SystemExit and unrelated errors; a strict load mismatch raises
    # RuntimeError, so catch exactly that and retry non-strict.
    except RuntimeError:
        print('load checkpoint failed, the state in the '
              'checkpoint is not matched with the model, '
              'try to reload checkpoint with unstrict mode')
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(data, width=width, root=root)
    # Deterministic evaluation loaders: no shuffling, keep the tail batch.
    train_loader = torch.utils.data.DataLoader(
        data.train, batch_size=batch_size, shuffle=False, drop_last=False,
        pin_memory=True, num_workers=nThreads)
    test_loader = torch.utils.data.DataLoader(
        data.gallery, batch_size=batch_size, shuffle=False, drop_last=False,
        pin_memory=True, num_workers=nThreads)
    train_feature, train_labels = extract_features(
        model, train_loader, print_freq=1e4, metric=None,
        pool_feature=pool_feature)
    test_feature, test_labels = extract_features(
        model, test_loader, print_freq=1e4, metric=None,
        pool_feature=pool_feature)
    return train_feature, train_labels, test_feature, test_labels
def Sequence2Feature(data, net, checkpoint, in_dim, middle_dim, out_dim=512, root=None, nThreads=16, batch_size=100, train_flag=True, **kargs):
    """Embed a sequence dataset with a checkpointed model.

    Always extracts features for the training split.  When ``train_flag`` is
    True the test slots of the return tuple are filled with small placeholder
    tensors (interface kept 4-wide); otherwise the test split is embedded too.

    Args:
        data: dataset name; must be one of 'shop', 'jd_test', 'seq', 'seqfull'.
        net: architecture name for ``models.create``.
        checkpoint: loaded checkpoint dict with a 'state_dict' entry.
        in_dim, middle_dim, out_dim: model layer sizes.

    Returns:
        (train_feature, train_labels, test_feature, test_labels)

    Raises:
        ValueError: for an unsupported dataset name.
    """
    dataset_name = data
    model = models.create(net, in_dim, middle_dim, out_dim, pretrained=False)
    resume = checkpoint
    model.load_state_dict(resume['state_dict'])
    model = torch.nn.DataParallel(model).cuda()
    if dataset_name not in ['shop', 'jd_test', 'seq', 'seqfull']:
        # BUGFIX: the old fallback dereferenced ``data.test`` although
        # ``data`` is still the dataset-name *string* here, and the final
        # return then referenced names never assigned on that path -- the
        # branch could only crash.  Fail fast with a clear message instead.
        raise ValueError(
            'Sequence2Feature: unsupported dataset %r' % dataset_name)
    datatrain = DataSet.create_seq(root=root, train_flag=True)
    gallery_loader = torch.utils.data.DataLoader(
        datatrain.seqdata, batch_size=batch_size, shuffle=False,
        drop_last=False, pin_memory=True, num_workers=nThreads)
    pool_feature = False
    train_feature, train_labels = extract_features(
        model, gallery_loader, print_freq=1e5, metric=None,
        pool_feature=pool_feature)
    if train_flag:
        # Caller only wants training features; keep the 4-tuple interface
        # with placeholder tensors.
        test_feature = torch.zeros(2, 1)
        test_labels = torch.zeros(2, 1)
    else:
        datatest = DataSet.create_seq(root=root, train_flag=False)
        query_loader = torch.utils.data.DataLoader(
            datatest.seqdata, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        test_feature, test_labels = extract_features(
            model, query_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
    return train_feature, train_labels, test_feature, test_labels
def Model2Feature(data, net, checkpoint, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Restore a normalized model from ``checkpoint`` and embed the test set.

    Returns:
        (gallery_feature, gallery_labels, query_feature, query_labels);
        for datasets other than 'shop'/'jd_test' the gallery and query
        results are the same objects.
    """
    dataset_name = data
    model = models.create(net, pretrained=False, normalized=True)
    resume = checkpoint
    model.load_state_dict(resume['state_dict'])
    model.eval()  # inference only: fixes dropout / batch-norm behaviour
    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(name=data, root=root, set_name='test')
    if dataset_name in ['shop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        query_loader = torch.utils.data.DataLoader(
            data.query, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        gallery_feature, gallery_labels = extract_features(
            model, gallery_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels = extract_features(
            model, query_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
    else:
        print('using else')
        # BUGFIX: this evaluation loader used shuffle=True and drop_last=True,
        # which silently discarded the last partial batch (biasing any metric
        # computed from the features) and randomised extraction order for no
        # benefit.  Every other evaluation loader in this file uses
        # shuffle=False, drop_last=False; made consistent.
        data_loader = torch.utils.data.DataLoader(
            data.test, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        features, labels = extract_features(
            model, data_loader, pool_feature=pool_feature)
        # Gallery and query are the same split here.
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
transform=transform_train, index=index_train) testfolder = CIFAR100(root=traindir, train=False, download=True, transform=transform_test, index=index) else: trainfolder = ImageFolder(traindir, transform_train, index=index_train) testfolder = ImageFolder(testdir, transform_test, index=index) train_loader = torch.utils.data.DataLoader( trainfolder, batch_size=128, shuffle=False, drop_last=False) test_loader = torch.utils.data.DataLoader( testfolder, batch_size=128, shuffle=False, drop_last=False) print('Test %d\t' % task_id) model = torch.load(models[task_id]) train_embeddings_cl, train_labels_cl = extract_features( model, train_loader) val_embeddings_cl, val_labels_cl = extract_features( model, test_loader) # Test for each task for i in index_train: ind_cl = np.where(i == train_labels_cl)[0] embeddings_tmp = train_embeddings_cl[ind_cl] class_label.append(i) class_mean.append(np.mean(embeddings_tmp, axis=0)) if task_id > 0 and args.mapping_test: model_old = torch.load(models[task_id-1]) train_embeddings_cl_old, train_labels_cl_old = extract_features( model_old, train_loader)
testfolder = ImageFolder(testdir, transform_test, index=index) train_loader = torch.utils.data.DataLoader(trainfolder, batch_size=128, shuffle=False, drop_last=False) test_loader = torch.utils.data.DataLoader(testfolder, batch_size=128, shuffle=False, drop_last=False) print('Test %d\t' % task_id) model = torch.load(models[task_id]) train_embeddings_cl, train_labels_cl = extract_features(model, train_loader, print_freq=32, metric=None) val_embeddings_cl, val_labels_cl = extract_features(model, test_loader, print_freq=32, metric=None) class_data = [] for i, data in enumerate(train_loader, 0): inputs, pids = data if class_data == []: class_data = inputs.numpy() else: class_data = np.vstack((class_data, inputs.numpy())) #print("class data shape is: ", class_data.shape)
name = temp[-1][:-10] if args.test == 1: data = DataSet.create(args.data, train=False) data_loader = torch.utils.data.DataLoader(data.test, batch_size=128, shuffle=False, drop_last=False) else: data = DataSet.create(args.data, test=False) data_loader = torch.utils.data.DataLoader(data.train, batch_size=128, shuffle=False, drop_last=False) features, labels = extract_features(model, data_loader, print_freq=1e5, metric=None) num_class = len(set(labels)) sim_mat = -pairwise_distance(features) if args.data == 'product': result = Recall_at_ks_products(sim_mat, query_ids=labels, gallery_ids=labels) else: result = Recall_at_ks(sim_mat, query_ids=labels, gallery_ids=labels) result = ['%.4f' % r for r in result] temp = ' ' result = temp.join(result)
transform=transform_train, index=index_train) testfolder = CIFAR100(root=traindir, train=False, download=True, transform=transform_test, index=index) else: trainfolder = ImageFolder(traindir, transform_train, index=index_train) testfolder = ImageFolder(testdir, transform_test, index=index) train_loader = torch.utils.data.DataLoader( trainfolder, batch_size=128, shuffle=False, drop_last=False) test_loader = torch.utils.data.DataLoader( testfolder, batch_size=128, shuffle=False, drop_last=False) print('Test %d\t' % task_id) model = torch.load(models[task_id]) train_embeddings_cl, train_labels_cl = extract_features( model, train_loader, print_freq=32, metric=None) val_embeddings_cl, val_labels_cl = extract_features( model, test_loader, print_freq=32, metric=None) # Test for each task for i in index_train: ind_cl = np.where(i == train_labels_cl)[0] embeddings_tmp = train_embeddings_cl[ind_cl] class_label.append(i) class_mean.append(np.mean(embeddings_tmp, axis=0)) if task_id > 0 and args.mapping_test: model_old = torch.load(models[task_id-1]) train_embeddings_cl_old, train_labels_cl_old = extract_features( model_old, train_loader, print_freq=32, metric=None)
resume = load_checkpoint(args.resume) model.load_state_dict(resume['state_dict']) model = torch.nn.DataParallel(model).cuda() data = DataSet.create(args.data, width=args.width, root=args.data_root) gallery_loader = torch.utils.data.DataLoader( data.train, batch_size=args.batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=args.nThreads) query_loader = torch.utils.data.DataLoader( data.test, batch_size=args.batch_size, shuffle=False, drop_last=False, pin_memory=True, num_workers=args.nThreads) gallery_feature, gallery_labels = extract_features( model, gallery_loader, print_freq=1e5, metric=None, pool_feature=args.pool_feature) query_feature, query_labels = extract_features( model, query_loader, print_freq=1e5, metric=None, pool_feature=args.pool_feature) sim_mat = pairwise_similarity(query_feature, gallery_feature) if args.whales==True: whales_preds = make_whales_predictions(sim_mat, gallery_labels) make_whales_sub_file(whales_preds) else: recall_ks = Recall_at_ks(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels, data=args.data) result = ' '.join(['%.4f' % k for k in recall_ks]) print('Epoch-%d' % epoch, result)
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, Retrieval_visualization=False, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Checkpointed feature extraction that also returns image names.

    When ``Retrieval_visualization`` is True the gallery features come from a
    second, shuffled pass over the data (random order for visualisation);
    otherwise the single deterministic pass is reused.

    Returns:
        (gallery_feature, gallery_labels, query_feature, query_labels,
         img_name, img_name_shuffled)
    """
    dataset_name = data
    model = models.create(net, dim=dim, pretrained=False)
    # Load only the checkpoint tensors whose names exist in the fresh model,
    # so checkpoints carrying extra heads can still be restored.
    net_dict = model.state_dict()
    weights = checkpoint['state_dict']
    pretrained_dict = {k: v for k, v in weights.items() if k in net_dict}
    net_dict.update(pretrained_dict)
    model.load_state_dict(net_dict)
    model = torch.nn.DataParallel(model).cuda()
    data = DataSet.create(data, width=width, root=root)
    if dataset_name in ['shop', 'jd_test']:
        gallery_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        query_loader = torch.utils.data.DataLoader(
            data.query, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        gallery_feature, gallery_labels, img_name = extract_features(
            model, gallery_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
        query_feature, query_labels, img_name = extract_features(
            model, query_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
        # BUGFIX: this branch never produced img_name_shuffled, so the final
        # return raised NameError for 'shop'/'jd_test'.  There is no shuffled
        # pass here, so reuse the plain image-name list.
        img_name_shuffled = img_name
    else:
        data_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        features, labels, img_name = extract_features(
            model, data_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature)
        if Retrieval_visualization:
            # Retrieval visualization wants the gallery in random order.
            data_loader_shuffled = torch.utils.data.DataLoader(
                data.gallery, batch_size=batch_size, shuffle=True,
                drop_last=False, pin_memory=True, num_workers=nThreads)
            features_shuffled, labels_shuffled, img_name_shuffled = extract_features(
                model, data_loader_shuffled, print_freq=1e5, metric=None,
                pool_feature=pool_feature)
        else:
            # PERF: the old code built an identical (unshuffled) second loader
            # and extracted the same features twice; reuse the first pass.
            features_shuffled, labels_shuffled, img_name_shuffled = features, labels, img_name
        query_feature, query_labels = features, labels
        gallery_feature, gallery_labels = features_shuffled, labels_shuffled
    return gallery_feature, gallery_labels, query_feature, query_labels, img_name, img_name_shuffled
model = torch.nn.DataParallel(model).cuda() data = DataSet.create(args.data) if args.data == 'shop': gallery_loader = torch.utils.data.DataLoader(data.gallery, batch_size=64, shuffle=False, drop_last=False) query_loader = torch.utils.data.DataLoader(data.query, batch_size=64, shuffle=False, drop_last=False) gallery_feature, gallery_labels = extract_features(model, gallery_loader, print_freq=1e5, metric=None) query_feature, query_labels = extract_features(model, query_loader, print_freq=1e5, metric=None) sim_mat = pairwise_similarity(x=query_feature, y=gallery_feature) result = Recall_at_ks_shop(sim_mat, query_ids=query_labels, gallery_ids=gallery_labels) elif args.data == 'jd': if args.test == 1: data_loader = torch.utils.data.DataLoader(data.gallery, batch_size=64,
def main(args):
    """Fine-tune a pretrained embedding net with a triplet loss.

    The backbone starts frozen (lr_mult 0.0) while the new classifier head
    trains at full rate; after the first epoch the backbone is unfrozen at
    1/10 of the learning rate.  Training features (and the dataset built
    from them) are refreshed every 10 epochs; checkpoints are written every
    ``args.save_step`` epochs.
    """
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)
    sys.stdout = logging.Logger(os.path.join(save_dir, 'log.txt'))
    display(args)
    start = 0
    model = models.create(args.net, pretrained=True, dim=args.dim)
    # for vgg and densenet
    if args.resume is None:
        model_dict = model.state_dict()
    else:
        # resume model
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    # freeze BN
    if args.freeze_BN is True:
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')
    # Fine-tune the model: pre-trained backbone frozen at first, new
    # classifier head at the full learning rate.
    new_param_ids = set(map(id, model.module.classifier.parameters()))
    new_params = [p for p in model.module.parameters()
                  if id(p) in new_param_ids]
    base_params = [p for p in model.module.parameters()
                   if id(p) not in new_param_ids]
    param_groups = [
        {'params': base_params, 'lr_mult': 0.0},
        {'params': new_params, 'lr_mult': 1.0}]
    print('initial model is save at %s' % save_dir)
    optimizer = torch.optim.Adam(param_groups, lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = torch.nn.TripletMarginLoss(margin=1)
    features_data = DataSet.create(args.data, ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root)
    features_loader = torch.utils.data.DataLoader(
        features_data.train, batch_size=args.batch_size, shuffle=False,
        drop_last=False, pin_memory=True, num_workers=args.nThreads)
    # save the train information
    for epoch in range(start, args.epochs):
        # Refresh the cached features (and the dataset built from them)
        # every 10 epochs.
        if epoch % 10 == 0:
            features, _ = extract_features(
                model, features_loader, print_freq=1e5, metric=None,
                pool_feature=False)
            data = DataSet.create(args.data, features=features,
                                  ratio=args.ratio, width=args.width,
                                  origin_width=args.origin_width,
                                  root=args.data_root)
        train_loader = torch.utils.data.DataLoader(
            data.train, batch_size=args.batch_size,
            sampler=FastRandomIdentitySampler(
                data.train, num_instances=args.num_instances),
            drop_last=True, pin_memory=True, num_workers=args.nThreads)
        train(epoch=epoch, model=model, criterion=criterion,
              optimizer=optimizer, train_loader=train_loader, args=args)
        if epoch == 1:
            # BUGFIX: the key was misspelled 'lr_mul', so the backbone's
            # multiplier silently stayed 0.0 and it was never unfrozen; the
            # param groups above use the key 'lr_mult'.
            optimizer.param_groups[0]['lr_mult'] = 0.1
        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            }, is_best=False,
                fpath=osp.join(args.save_dir,
                               'ckp_ep' + str(epoch + 1) + '.pth.tar'))
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, model=None, org_feature=False, args=None):
    """Embed evaluation splits, optionally constructing the model first.

    If ``model`` is None, the network is built from ``net`` and restored
    (non-strict) from ``checkpoint``.  When ``org_feature`` is True the raw
    features are L2-normalised row-wise before being returned.

    Returns:
        (gallery_feature, gallery_labels, query_feature, query_labels);
        outside 'shop'/'jd_test'/'cifar' the gallery and query results are
        the same objects.
    """
    def _l2_normalized(feat, msg):
        # Row-wise L2 normalisation of an (N, dim) feature matrix.
        # ``msg`` keeps the original progress prints byte-identical.
        norm = feat.norm(dim=1, p=2, keepdim=True)
        feat = feat.div(norm.expand_as(feat))
        print(msg)
        return feat

    dataset_name = data
    if model is None:
        model = models.create(net, dim=dim, pretrained=False)
        resume = checkpoint
        model.load_state_dict(resume['state_dict'], strict=False)
        model = torch.nn.DataParallel(model).cuda()
    data = dataset.Dataset(data, width=width, root=root, mode="test",
                           self_supervision_rot=0, args=args)
    if dataset_name in ['shop', 'jd_test', 'cifar']:
        gallery_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        query_loader = torch.utils.data.DataLoader(
            data.query, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        gallery_feature, gallery_labels = extract_features(
            model, gallery_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature, org_feature=org_feature)
        query_feature, query_labels = extract_features(
            model, query_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature, org_feature=org_feature)
        if org_feature:
            # Deduplicated: the normalisation code was copied three times.
            query_feature = _l2_normalized(query_feature,
                                           "feature normalized 1")
            gallery_feature = _l2_normalized(gallery_feature,
                                             "feature normalized 2")
    else:
        data_loader = torch.utils.data.DataLoader(
            data.gallery, batch_size=batch_size, shuffle=False,
            drop_last=False, pin_memory=True, num_workers=nThreads)
        features, labels = extract_features(
            model, data_loader, print_freq=1e5, metric=None,
            pool_feature=pool_feature, org_feature=org_feature)
        if org_feature:
            features = _l2_normalized(features, "feature normalized")
        gallery_feature, gallery_labels = query_feature, query_labels = features, labels
    return gallery_feature, gallery_labels, query_feature, query_labels
def main(args):
    """Train an embedding model on pseudo ("fake") k-means labels, optionally
    combined with a self-supervised rotation head.

    Resumes from the newest ``ckp_ep*.pth.tar`` in ``args.save_dir`` when one
    exists, generates pseudo-labels from extracted features if none are on
    disk yet, and re-clusters every ``args.up_step`` epochs during training.
    """
    # s_ = time.time()
    print(torch.cuda.get_device_properties(device=0).total_memory)
    torch.cuda.empty_cache()
    print(args)
    save_dir = args.save_dir
    mkdir_if_missing(save_dir)
    # Number the log file after the existing logs so reruns never overwrite.
    num_txt = len(glob.glob(save_dir + "/*.txt"))
    sys.stdout = logging.Logger(
        os.path.join(save_dir, "log_" + str(num_txt) + ".txt"))
    display(args)
    start = 0
    model = models.create(args.net,
                          pretrained=args.pretrained,
                          dim=args.dim,
                          self_supervision_rot=args.self_supervision_rot)
    all_pretrained = glob.glob(save_dir + "/*.pth.tar")
    if (args.resume is None) or (len(all_pretrained) == 0):
        model_dict = model.state_dict()
    else:
        # resume model: pick the checkpoint with the highest epoch number.
        # NOTE(review): the epoch is sliced out of "ckp_ep<N>.pth.tar" via
        # [6:-8] -- revisit if the checkpoint naming scheme ever changes.
        all_pretrained_epochs = sorted(
            [int(x.split("/")[-1][6:-8]) for x in all_pretrained])
        args.resume = os.path.join(
            save_dir, "ckp_ep" + str(all_pretrained_epochs[-1]) + ".pth.tar")
        print('load model from {}'.format(args.resume))
        chk_pt = load_checkpoint(args.resume)
        weight = chk_pt['state_dict']
        start = chk_pt['epoch']
        model.load_state_dict(weight)
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    fake_centers_dir = os.path.join(args.save_dir, "fake_center.npy")
    # No pseudo-label file ("train_1.txt") on disk yet: create the initial
    # fake labels.  Otherwise reuse previously saved cluster centers.
    if np.sum(["train_1.txt" in x
               for x in glob.glob(args.save_dir + "/**/*")]) == 0:
        if args.rot_only:
            create_fake_labels(None, None, args)
        else:
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)
            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=100,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)
            train_feature, train_labels = extract_features(
                model, fake_train_loader, print_freq=1e5, metric=None,
                pool_feature=args.pool_feature, org_feature=True)
            create_fake_labels(train_feature, train_labels, args)
            # Free the large feature matrix before training starts.
            del train_feature
        fake_centers = "k-means++"
        torch.cuda.empty_cache()
    elif os.path.exists(fake_centers_dir):
        fake_centers = np.load(fake_centers_dir)
    else:
        fake_centers = "k-means++"
    # NOTE(review): presumably lets the filesystem / worker processes settle
    # after label generation -- confirm whether this wait is still required.
    time.sleep(60)
    model.train()
    # freeze BN
    if (args.freeze_BN is True) and (args.pretrained):
        print(40 * '#', '\n BatchNorm frozen')
        model.apply(set_bn_eval)
    else:
        print(40 * '#', 'BatchNorm NOT frozen')
    # Fine-tune the model: the learning rate for pre-trained parameter is 1/10
    new_param_ids = set(map(id, model.module.classifier.parameters()))
    new_rot_param_ids = set()
    if args.self_supervision_rot:
        new_rot_param_ids = set(
            map(id, model.module.classifier_rot.parameters()))
        print(new_rot_param_ids)
    new_params = [
        p for p in model.module.parameters() if id(p) in new_param_ids
    ]
    new_rot_params = [
        p for p in model.module.parameters() if id(p) in new_rot_param_ids
    ]
    base_params = [
        p for p in model.module.parameters()
        if (id(p) not in new_param_ids) and (id(p) not in new_rot_param_ids)
    ]
    # The rotation head gets its own learning rate; the other two groups
    # fall back to the optimizer default (args.lr).
    param_groups = [{
        'params': base_params
    }, {
        'params': new_params
    }, {
        'params': new_rot_params,
        'lr': args.rot_lr
    }]
    print('initial model is save at %s' % save_dir)
    optimizer = torch.optim.Adam(param_groups,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = losses.create(args.loss,
                              margin=args.margin,
                              alpha=args.alpha,
                              beta=args.beta,
                              base=args.loss_base).cuda()
    # Training data reads the pseudo-labels from save_dir (corruption=1).
    data = dataset.Dataset(args.data,
                           ratio=args.ratio,
                           width=args.width,
                           origin_width=args.origin_width,
                           root=args.save_dir,
                           self_supervision_rot=args.self_supervision_rot,
                           rot_bt=args.rot_bt,
                           corruption=1,
                           args=args)
    train_loader = torch.utils.data.DataLoader(
        data.train,
        batch_size=args.batch_size,
        sampler=FastRandomIdentitySampler(data.train,
                                          num_instances=args.num_instances),
        drop_last=True,
        pin_memory=True,
        num_workers=args.nThreads)
    # save the train information
    for epoch in range(start, args.epochs):
        train(epoch=epoch,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              train_loader=train_loader,
              args=args)
        if (epoch + 1) % args.save_step == 0 or epoch == 0:
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint({
                'state_dict': state_dict,
                'epoch': (epoch + 1),
            }, is_best=False,
                fpath=osp.join(
                    args.save_dir,
                    'ckp_ep' + str(epoch + 1) + '.pth.tar'))
        # Periodically re-cluster the current features into fresh
        # pseudo-labels and rebuild the training set from them.
        if ((epoch + 1) % args.up_step == 0) and (not args.rot_only):
            # rewrite train_1.txt file
            data = dataset.Dataset(args.data,
                                   ratio=args.ratio,
                                   width=args.width,
                                   origin_width=args.origin_width,
                                   root=args.data_root,
                                   self_supervision_rot=0,
                                   mode="test",
                                   rot_bt=args.rot_bt,
                                   corruption=args.corruption,
                                   args=args)
            fake_train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                shuffle=False,
                drop_last=False,
                pin_memory=True,
                num_workers=args.nThreads)
            train_feature, train_labels = extract_features(
                model, fake_train_loader, print_freq=1e5, metric=None,
                pool_feature=args.pool_feature,
                org_feature=(args.dim % 64 != 0))
            fake_centers = create_fake_labels(train_feature,
                                              train_labels,
                                              args,
                                              init_centers=fake_centers)
            del train_feature
            torch.cuda.empty_cache()
            time.sleep(60)
            np.save(fake_centers_dir, fake_centers)
            # reload data
            data = dataset.Dataset(
                args.data,
                ratio=args.ratio,
                width=args.width,
                origin_width=args.origin_width,
                root=args.save_dir,
                self_supervision_rot=args.self_supervision_rot,
                rot_bt=args.rot_bt,
                corruption=1,
                args=args)
            train_loader = torch.utils.data.DataLoader(
                data.train,
                batch_size=args.batch_size,
                sampler=FastRandomIdentitySampler(
                    data.train, num_instances=args.num_instances),
                drop_last=True,
                pin_memory=True,
                num_workers=args.nThreads)
            # test on testing data
            # extract_recalls(data=args.data, data_root=args.data_root, width=args.width, net=args.net, checkpoint=None,
            #                 dim=args.dim, batch_size=args.batch_size, nThreads=args.nThreads, pool_feature=args.pool_feature,
            #                 gallery_eq_query=args.gallery_eq_query, model=model)
            model.train()
            if (args.freeze_BN is True) and (args.pretrained):
                print(40 * '#', '\n BatchNorm frozen')
                model.apply(set_bn_eval)
def main(args):
    """Train an embedding model with learnable per-class cluster centers.

    Builds (or resumes) the model, initialises ``args.n_cluster`` centers per
    class -- either by clustering extracted features or randomly -- and then
    optimises model and centers jointly with a center-aware loss.

    NOTE(review): uses legacy PyTorch idioms (``Variable``, ``loss.data[0]``,
    ``ByteTensor``) -- this file targets an old torch version.
    """
    num_class_dict = {'cub': int(100), 'car': int(98)}
    # save training logs
    log_dir = os.path.join(args.checkpoints, args.log_dir)
    mkdir_if_missing(log_dir)
    sys.stdout = logging.Logger(os.path.join(log_dir, 'log.txt'))
    display(args)
    if args.r is None:
        model = models.create(args.net, Embed_dim=args.dim)
        # load part of the model: keep only pretrained tensors whose names
        # also exist in the freshly built network.
        model_dict = model.state_dict()
        # print(model_dict)
        if args.net == 'bn':
            pretrained_dict = torch.load(
                'pretrained_models/bn_inception-239d2248.pth')
        else:
            pretrained_dict = torch.load(
                'pretrained_models/inception_v3_google-1a9a5a14.pth')
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        # orth init
        if args.init == 'orth':
            print('initialize the FC layer orthogonally')
            _, _, v = torch.svd(model_dict['Embed.linear.weight'])
            model_dict['Embed.linear.weight'] = v.t()
        # zero bias
        model_dict['Embed.linear.bias'] = torch.zeros(args.dim)
        model.load_state_dict(model_dict)
    else:
        # resume model
        model = torch.load(args.r)
    model = model.cuda()

    # compute the cluster centers for each class here
    def normalize(x):
        # Row-wise L2 normalisation of an (N, dim) matrix.
        norm = x.norm(dim=1, p=2, keepdim=True)
        x = x.div(norm.expand_as(x))
        return x

    data = DataSet.create(args.data, root=None, test=False)
    if args.center_init == 'cluster':
        # Initialise centers by clustering features of the current model.
        data_loader = torch.utils.data.DataLoader(
            data.train, batch_size=args.BatchSize,
            shuffle=False, drop_last=False)
        features, labels = extract_features(model, data_loader,
                                            print_freq=32, metric=None)
        features = [feature.resize_(1, args.dim) for feature in features]
        features = torch.cat(features)
        features = features.numpy()
        labels = np.array(labels)
        centers, center_labels = cluster_(features, labels,
                                          n_clusters=args.n_cluster)
        center_labels = [int(center_label)
                         for center_label in center_labels]
        # Centers are trainable parameters; labels are fixed.
        centers = Variable(torch.FloatTensor(centers).cuda(),
                           requires_grad=True)
        center_labels = Variable(torch.LongTensor(center_labels)).cuda()
        print(40 * '#', '\n Clustering Done')
    else:
        # Random initialisation: n_cluster centers per class on the
        # unit sphere.
        center_labels = int(args.n_cluster) * \
            list(range(num_class_dict[args.data]))
        center_labels = Variable(torch.LongTensor(center_labels).cuda())
        centers = normalize(torch.rand(
            num_class_dict[args.data] * args.n_cluster, args.dim))
        centers = Variable(centers.cuda(), requires_grad=True)
    torch.save(model, os.path.join(log_dir, 'model.pkl'))
    print('initial model is save at %s' % log_dir)
    # fine tune the model: the learning rate for pre-trained parameter is 1/10
    new_param_ids = set(map(id, model.Embed.parameters()))
    new_params = [p for p in model.parameters()
                  if id(p) in new_param_ids]
    base_params = [p for p in model.parameters()
                   if id(p) not in new_param_ids]
    param_groups = [
        {'params': base_params, 'lr_mult': 0.1},
        {'params': new_params, 'lr_mult': 1.0},
        {'params': centers, 'lr_mult': 1.0}]
    optimizer = torch.optim.Adam(param_groups, lr=args.lr,
                                 weight_decay=args.weight_decay)
    # Per-(class, cluster) assignment counts, reset every epoch.
    cluster_counter = np.zeros([num_class_dict[args.data], args.n_cluster])
    criterion = losses.create(args.loss, alpha=args.alpha, centers=centers,
                              center_labels=center_labels,
                              cluster_counter=cluster_counter).cuda()
    # random sampling to generate mini-batch
    train_loader = torch.utils.data.DataLoader(
        data.train, batch_size=args.BatchSize,
        shuffle=True, drop_last=False)
    # save the train information
    epoch_list = list()
    loss_list = list()
    pos_list = list()
    neg_list = list()
    # _mask = Variable(torch.ByteTensor(np.ones([2, 4]))).cuda()
    # Mask of active clusters; starts with every cluster enabled.
    dtype = torch.ByteTensor
    _mask = torch.ones(int(num_class_dict[args.data]),
                       args.n_cluster).type(dtype)
    _mask = Variable(_mask).cuda()
    for epoch in range(args.start, args.epochs):
        epoch_list.append(epoch)
        running_loss = 0.0
        running_pos = 0.0
        running_neg = 0.0
        to_zero(cluster_counter)
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            # wrap them in Variable
            inputs = Variable(inputs.cuda())
            # type of labels is Variable cuda.Longtensor
            labels = Variable(labels).cuda()
            optimizer.zero_grad()
            # centers.zero_grad()
            embed_feat = model(inputs)
            # update network weight
            loss, inter_, dist_ap, dist_an = criterion(embed_feat, labels,
                                                       _mask)
            loss.backward()
            optimizer.step()
            # Re-project the centers onto the unit sphere after each step.
            centers.data = normalize(centers.data)
            running_loss += loss.data[0]
            running_neg += dist_an
            running_pos += dist_ap
            if epoch == 0 and i == 0:
                print(50 * '#')
                print('Train Begin -- HA-HA-HA')
            if i % 10 == 9:
                print('[Epoch %05d Iteration %2d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
                      % (epoch + 1, i+1, loss.data[0], inter_, dist_ap, dist_an))
                # cluster number counter show here
                print(cluster_counter)
        loss_list.append(running_loss)
        pos_list.append(running_pos / i)
        neg_list.append(running_neg / i)
        # update the _mask to make the cluster with only 1 or no member to be silent
        # _mask = Variable(torch.FloatTensor(cluster_counter) > 1).cuda()
        # cluster_distribution = torch.sum(_mask, 1).cpu().data.numpy().tolist()
        # print(cluster_distribution)
        # print('[Epoch %05d]\t Loss: %.3f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
        #       % (epoch + 1, running_loss, inter_, dist_ap, dist_an))
        if epoch % args.save_step == 0:
            torch.save(model, os.path.join(log_dir, '%d_model.pkl' % epoch))
            np.savez(os.path.join(log_dir, "result.npz"), epoch=epoch_list,
                     loss=loss_list, pos=pos_list, neg=neg_list)
# coding=utf-8 from __future__ import absolute_import, print_function import argparse import torch from torch.backends import cudnn from evaluations import extract_features, pairwise_distance, pairwise_similarity from evaluations import Recall_at_ks, Recall_at_ks_products, Recall_at_ks_shop import models import DataSet import os import numpy as np cudnn.benchmark = True os.environ["CUDA_VISIBLE_DEVICES"] = "0" im1 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t18085/239/1572160811/242071/9e3b6d97/5ad06c21Nd73ffab7.jpg' im2 = '/opt/intern/users/xunwang/jd-comp/images/P/img/jfs/t17857/121/1655327696/242539/1771960e/5ad06c69N5b34d078.jpg' from PIL import Image im1 = Image.open(im1) im2 = Image.open(im2) im1.save('1.jpg') im2.save('2.jpg') r = '/opt/intern/users/xunwang/checkpoints/bin/jd/512-BN-alpha40/135_model.pth' PATH = r model = models.create('vgg', dim=512, pretrained=False) model = torch.nn.DataParallel(model) model.load_state_dict(torch.load(PATH)) model = model.cuda() data = DataSet.create('jd') data_loader = torch.utils.data.DataLoader(data.gallery, batch_size=64, shuffle=False,