def main():
    """Evaluate a trained CSN (type-conditioned triplet) model.

    Loads ``csn_model_best.pth``, then reports compatibility AUC and
    fill-in-the-blank accuracy on the test split.
    """
    # Hyperparameters
    img_size = 112
    emb_size = 64
    device = torch.device("cuda")

    # Dataloader
    # NOTE(review): RandomHorizontalFlip makes test-time evaluation
    # nondeterministic — confirm it is intended in an eval transform.
    transform = torchvision.transforms.Compose([
        # Fix: `Scale` was a deprecated alias removed from torchvision;
        # `Resize` is the documented drop-in replacement.
        torchvision.transforms.Resize((img_size, img_size)),
        torchvision.transforms.CenterCrop(112),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])
    # train_dataset is only used below for its `conditions` list.
    train_dataset = TripletDataset(
        root_dir="../../data/images/",
        data_dir="../../data/",
        transform=transform,
    )
    test_auc_dataset = CategoryDataset(
        root_dir="../../data/images/",
        data_dir="../../data/",
        transform=transform,
        use_mean_img=True,
        data_file="test_no_dup_with_category_3more_name.json",
        neg_samples=True,
    )

    # Model
    tnet = CompatModel(
        emb_size,
        n_conditions=len(train_dataset.conditions) // 2,
        learnedmask=True,
        prein=False,
    )
    tnet.load_state_dict(torch.load("./csn_model_best.pth"))
    tnet = tnet.to(device)
    tnet.eval()
    embeddingnet = tnet.embeddingnet

    # Test
    auc = test_compatibility_auc(test_auc_dataset, embeddingnet)
    print("AUC: {:.4f}".format(auc))
    # Name kept exactly as defined elsewhere in the project (typo included).
    fitb_accuracy = test_fitb_quesitons(test_auc_dataset, embeddingnet)
    print("Fitb Accuracy: {:.4f}".format(fitb_accuracy))
# Dataloader
train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader = prepare_dataloaders(
    root_dir="../../data/images",
    data_dir="../../data",
    img_size=299,
    batch_size=12,
    use_mean_img=False,
    neg_samples=False,
    collate_fn=lstm_collate_fn,
)

# Model
model = CompatModel(emb_size=emb_size, need_rep=True,
                    vocabulary=len(train_dataset.vocabulary))
# Fix: the original assigned the result to a stray name `mode`; it only
# worked because nn.Module.to() also moves the module in place.
model = model.to(device)


# Train process
def train(model, device, train_loader, val_loader, comment):
    """Train `model` with SGD + halving StepLR for `epochs` epochs.

    NOTE(review): this definition is truncated in the visible source —
    only the start of the per-batch loop is shown below.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=2e-1, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    for epoch in range(1, epochs + 1):
        # Train phase
        total_loss = 0
        # NOTE(review): stepping the scheduler at the start of the epoch is
        # the pre-1.1 PyTorch convention; left as-is to preserve the
        # original learning-rate schedule.
        scheduler.step()
        model.train()
        for batch_num, input_data in enumerate(train_loader, 1):
            lengths, images, names, offsets, set_ids, labels, is_compat = input_data
            image_seqs = images.to(device)  # (20+, 3, 224, 224)
    # Tail of a dataset-constructor call whose opening line is outside this
    # view — by its keyword arguments it builds the validation split
    # (presumably a CategoryDataset; confirm against the full file).
    root_dir="../../data/images/",
    data_dir="../../data/",
    transform=transform,
    use_mean_img=True,
    data_file="valid_no_dup_with_category_3more_name.json",
    neg_samples=True,
)

# Model: type-conditioned embedding net; one learned mask per condition,
# using half of the train-set condition list (pairs -> conditions).
tnet = CompatModel(
    emb_size,
    n_conditions=len(train_dataset.conditions) // 2,
    learnedmask=True,
    prein=False,
)
tnet = tnet.to(device)


def accuracy(dista, distb):
    """Return the fraction of triplets where `dista` exceeds `distb`.

    Both arguments are 1-D tensors of pairwise distances; with margin=0
    this is simply mean(dista > distb).
    """
    margin = 0
    pred = (dista - distb - margin).cpu().data
    return (pred > 0).sum().item() / dista.size(0)


# Hyperparameters
criterion = torch.nn.MarginRankingLoss(margin=0.2)
# Only optimize parameters that require gradients (frozen layers excluded).
parameters = filter(lambda p: p.requires_grad, tnet.parameters())
optimizer = torch.optim.Adam(parameters, lr=5e-5)
n_parameters = sum([p.data.nelement() for p in tnet.parameters()])
logging.info(" + Number of params: {}".format(n_parameters))
# Dataloader
train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader = prepare_dataloaders(
    root_dir="../../data/images",
    data_dir="../../data",
    img_size=299,
    use_mean_img=False,
    neg_samples=False,
    collate_fn=lstm_collate_fn,
)

# Restore model parameters
model = CompatModel(emb_size=emb_size, need_rep=False,
                    vocabulary=len(train_dataset.vocabulary))
model.load_state_dict(torch.load('model.pth'))
model.to(device)
model.eval()

# Compute feature or Load extracted feature
if os.path.exists("test_features.pkl"):
    print("Found test_features.pkl...")
    # Fix: close the file deterministically — the original leaked the
    # handle from an anonymous open().
    with open('./test_features.pkl', 'rb') as f:
        test_features = pickle.load(f)
else:
    print("Extract cnn features...")
    test_features = {}
    for input_data in tqdm(test_loader):
        lengths, images, names, offsets, set_ids, labels, is_compat = input_data
        image_seqs = images.to(device)
        with torch.no_grad():
            emb_seqs = model.encoder_cnn(image_seqs)
        # NOTE(review): source is truncated here; batch_ids is assumed to
        # reset once per batch (loop level) — confirm against the full file.
        batch_ids = []