def create_model(num_masks: int, embedding_size: int, learned_masks: bool, disjoint_masks: bool, use_gpu: bool, pretrained: bool = True) -> CS_Tripletnet:
    """Build the conditional-similarity triplet network.

    Wraps a ResNet-18 feature extractor in a ConditionalSimNet (one mask
    per similarity condition) and stacks a CS_Tripletnet on top.

    Args:
        num_masks: number of similarity conditions / masks.
        embedding_size: dimensionality of the learned embedding.
        learned_masks: whether masks are learned (vs. fixed).
        disjoint_masks: whether masks are pre-initialized as disjoint.
        use_gpu: move the network to CUDA when True.
        pretrained: load ImageNet weights into the backbone.

    Returns:
        The assembled CS_Tripletnet (on GPU when ``use_gpu`` is True).
    """
    backbone = Resnet_18.resnet18(
        pretrained=pretrained, embedding_size=embedding_size)
    conditional_net = ConditionalSimNet(
        backbone,
        n_conditions=num_masks,
        embedding_size=embedding_size,
        learnedmask=learned_masks,
        prein=disjoint_masks)
    net: CS_Tripletnet = CS_Tripletnet(
        conditional_net, num_concepts=num_masks, use_cuda=use_gpu)
    if use_gpu:
        net.cuda()
    return net
def main():
    """Train/evaluate the conditional-similarity triplet network (UT-Zappos50K).

    Reads CLI options from the module-level ``parser``, builds the train /
    val / test triplet loaders, constructs the CSN triplet network, then
    either evaluates a saved checkpoint (``--test``) or runs the training
    loop, checkpointing on best validation accuracy.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name)

    # ImageNet channel statistics (the backbone is an ImageNet-pretrained ResNet-18).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    global conditions
    if args.conditions is not None:
        conditions = args.conditions
    else:
        conditions = [0, 1, 2, 3]

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
    print('Loading Train Dataset')
    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'train',
                           n_triplets=args.num_traintriplets,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    print('Loading Test Dataset')
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'test',
                           n_triplets=160000,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=64, shuffle=True, **kwargs)
    print('Loading Val Dataset')
    val_loader = torch.utils.data.DataLoader(
        TripletImageLoader('data', 'ut-zap50k-images', 'filenames.json',
                           conditions, 'val',
                           n_triplets=80000,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=64, shuffle=True, **kwargs)

    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    csn_model = ConditionalSimNet(model, n_conditions=args.num_concepts,
                                  embedding_size=args.dim_embed,
                                  learnedmask=args.learned, prein=args.prein)
    global mask_var
    # Exposed globally; presumably read by train()/plotting code elsewhere -- confirm.
    mask_var = csn_model.masks.weight
    tnet = CS_Tripletnet(csn_model, args.num_concepts)
    if args.cuda:
        tnet.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Fix: restore the tracked global best accuracy. Previously the
            # value was assigned to a dead local ``best_prec1``, so resuming
            # never restored it (the sibling script assigns it to best_acc).
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    # Only optimize parameters that are still trainable.
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print(' + Number of params: {}'.format(n_parameters))

    if args.test:
        # Fix: load the checkpoint of the run named on the command line
        # instead of the hard-coded 'new_context_4/' experiment (which also
        # produced a doubled slash), matching the final evaluation below.
        checkpoint = torch.load('runs/%s/' % (args.name) + 'model_best.pth.tar')
        tnet.load_state_dict(checkpoint['state_dict'])
        test_acc = test(test_loader, tnet, criterion, 1)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet, criterion, epoch)
        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

    # Final evaluation with the best checkpoint of this run.
    checkpoint = torch.load('runs/%s/' % (args.name) + 'model_best.pth.tar')
    tnet.load_state_dict(checkpoint['state_dict'])
    test_acc = test(test_loader, tnet, criterion, 1)
def main():
    """Train/evaluate the type-specific embedding network (Polyvore Outfits).

    Reads CLI options from the module-level ``parser``, loads item metadata,
    builds the train / valid / test loaders, constructs the type-specific
    triplet network, then either evaluates (``--test``) or trains,
    checkpointing on best validation accuracy.
    """
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # ImageNet channel statistics (the backbone is an ImageNet-pretrained ResNet-18).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    fn = os.path.join(args.datadir, 'polyvore_outfits',
                      'polyvore_item_metadata.json')
    # Fix: close the metadata file after reading (was json.load(open(fn, 'r'))).
    with open(fn, 'r') as f:
        meta_data = json.load(f)
    text_feature_dim = 6000
    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}

    # Fix: transforms.Scale is the deprecated alias of transforms.Resize;
    # same behavior, and consistent with the companion training script.
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader(args, 'test', meta_data,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    # One type-specific space per category pair exposed by the dataset.
    csn_model = TypeSpecificNet(args, model, len(test_loader.dataset.typespaces))
    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    tnet = Tripletnet(args, csn_model, text_feature_dim, criterion)
    if args.cuda:
        tnet.cuda()

    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader(args, 'train', meta_data,
                           text_dim=text_feature_dim,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        TripletImageLoader(args, 'valid', meta_data,
                           transform=transforms.Compose([
                               transforms.Resize(112),
                               transforms.CenterCrop(112),
                               transforms.ToTensor(),
                               normalize,
                           ])),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    best_acc = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # encoding='latin1' keeps Python-2-era pickled checkpoints loadable.
            checkpoint = torch.load(args.resume, encoding='latin1')
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    if args.test:
        test_acc = test(test_loader, tnet)
        sys.exit()

    # Only optimize parameters that are still trainable.
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print(' + Number of params: {}'.format(n_parameters))

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet)
        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

    # Final evaluation with the best checkpoint of this run.
    checkpoint = torch.load('runs/%s/' % (args.name) + 'model_best.pth.tar')
    tnet.load_state_dict(checkpoint['state_dict'])
    test_acc = test(test_loader, tnet)
# (continuation) trailing arguments of a parser.add_argument('--visdom', ...)
# call whose opening line lies above this chunk.
dest='visdom', action='store_true', help='Use visdom to track and plot')
parser.add_argument('--conditions', nargs='*', type=int,
                    help='Set of similarity notions')
parser.set_defaults(test=False)
parser.set_defaults(learned=False)
parser.set_defaults(prein=False)
parser.set_defaults(visdom=False)
args = parser.parse_args()

# Four similarity conditions; this also sizes the mask bank of the CSN below.
conditions = [0, 1, 2, 3]
model = Resnet_18.resnet18(pretrained=False, embedding_size=args.dim_embed)
csn_model = ConditionalSimNet(model, n_conditions=len(conditions),
                              embedding_size=args.dim_embed,
                              learnedmask=args.learned, prein=args.prein)


class lp_net(nn.Module):
    """Wrapper module that forwards (image, condition) pairs to an embedding
    net.  NOTE(review): the name suggests a linear-probe head -- confirm
    against the full file."""

    def __init__(self, embeddingnet):
        super(lp_net, self).__init__()
        # Underlying conditional embedding module (ConditionalSimNet-style:
        # returns embedding plus three norm terms, per the unpacking below).
        self.embeddingnet = embeddingnet

    def forward(self, x, c):
        # Delegate to the wrapped net, which yields the conditioned embedding
        # and mask/embedding norm regularization terms.
        # NOTE(review): no return statement is visible here -- the method body
        # appears to continue beyond this chunk; confirm against the full file.
        embedded, masknorm_norm, embed_norm, tot_embed_norm = self.embeddingnet(
            x, c)
# Sentence-BERT encoder used to produce fixed-size text features.
text_model = SentenceTransformer('distilbert-base-nli-mean-tokens')


def text_embedding(sentences):
    """Encode ``sentences`` with the module-level SentenceTransformer and
    return the resulting embeddings."""
    sentence_embeddings = text_model.encode(sentences)
    return sentence_embeddings


# 768 matches the distilbert sentence-embedding width assumed here.
text_feature_dim = 768
dim_embed = 64  # target embedding dimensionality for both branches
device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_model = Resnet_18.resnet18(pretrained=True, embedding_size=dim_embed)
image_model = image_model.to(device)
# Presumably maps the 768-d text features into the dim_embed-d space -- confirm
# against EmbedBranch's definition.
text_branch = EmbedBranch(text_feature_dim, dim_embed)
text_branch = text_branch.to(device)


def load_data(graph, node, node_list):
    """Load citation network dataset (cora only for now)"""
    print('Loading dataset...')
    # idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), dtype=np.dtype(str))
    # Node features computed from the image and text branches above.
    features = makeFeature(node, graph, image_model, text_branch)
    # features = sp.csr_matrix(pre_features)
    # build graph
    # NOTE(review): the function is truncated at the end of this chunk; the
    # remainder (graph construction / return value) continues beyond view.