Пример #1
0
def load_trained_csn(state_path, n_conditions=4, embedding_size=64):
    """Load a trained Conditional Similarity Network from a checkpoint.

    A ResNet-18 backbone is wrapped in a ``ConditionalSimNet`` and then in a
    ``CS_Tripletnet``; the checkpoint weights are loaded into the triplet
    wrapper, ``eval()`` is called on it, and its ``embeddingnet`` attribute
    is returned.

    Args:
        state_path: Path (or file-like object) accepted by ``torch.load``.
            The loaded object must expose a ``'state_dict'`` entry.
        n_conditions: Forwarded positionally to ``ConditionalSimNet``.
        embedding_size: Forwarded positionally to ``ConditionalSimNet``.

    Returns:
        The ``embeddingnet`` attribute of the restored triplet network.
    """
    cnn = resnet_18.resnet18()
    csn = ConditionalSimNet(cnn, n_conditions, embedding_size)
    tnet = CS_Tripletnet(csn)
    # This function never moves the model to a GPU itself, so map the stored
    # tensors onto the CPU: a checkpoint written by a CUDA run then also
    # loads on hosts without CUDA instead of raising inside torch.load.
    ckpt = torch.load(state_path, map_location='cpu')
    tnet.load_state_dict(ckpt['state_dict'])
    tnet.eval()
    return tnet.embeddingnet
Пример #2
0
    def create_model(num_masks: int, embedding_size: int, learned_masks: bool,
                     disjoint_masks: bool, use_gpu: bool,
                     pretrained: bool = True) -> CS_Tripletnet:
        """Assemble the triplet network used for conditional similarity.

        Args:
            num_masks: Passed as ``n_conditions`` to ``ConditionalSimNet`` and
                as ``num_concepts`` to ``CS_Tripletnet``.
            embedding_size: Embedding size given to both the ResNet-18
                backbone and ``ConditionalSimNet``.
            learned_masks: Passed as ``learnedmask`` to ``ConditionalSimNet``.
            disjoint_masks: Passed as ``prein`` to ``ConditionalSimNet``.
            use_gpu: Passed as ``use_cuda`` to ``CS_Tripletnet``; when true,
                ``.cuda()`` is additionally called on the result.
            pretrained: Passed through to ``Resnet_18.resnet18``.

        Returns:
            The constructed ``CS_Tripletnet``.
        """
        # Backbone that produces the raw embedding.
        backbone = Resnet_18.resnet18(pretrained=pretrained,
                                      embedding_size=embedding_size)

        # Conditional masking layer on top of the backbone.
        csn_kwargs = dict(n_conditions=num_masks,
                          embedding_size=embedding_size,
                          learnedmask=learned_masks,
                          prein=disjoint_masks)
        conditional_net = ConditionalSimNet(backbone, **csn_kwargs)

        # Triplet wrapper around the conditional network.
        net: CS_Tripletnet = CS_Tripletnet(conditional_net,
                                           num_concepts=num_masks,
                                           use_cuda=use_gpu)
        if not use_gpu:
            return net

        net.cuda()
        return net
Пример #3
0
def main():
    """Train a conditional similarity triplet network and test the best model.

    Steps, in order:
      1. Parse the module-level ``parser`` and seed torch (and CUDA if used).
      2. Build train / test / val ``TripletImageLoader`` data loaders over
         ``data/ut-zap50k-images``.
      3. Build a ResNeXt-50 backed ``ConditionalSimNet`` wrapped in a
         ``CS_Tripletnet``.
      4. Optionally resume from ``args.resume``.
      5. With ``args.test``: load a fixed best-model checkpoint, run ``test``
         on the test loader and exit.
      6. Otherwise train for epochs ``args.start_epoch..args.epochs``, call
         ``save_checkpoint`` after each epoch, then reload
         ``runs/<args.name>/model_best.pth.tar`` and run ``test`` on the
         test loader.

    Side effects: rebinds the globals ``args``, ``best_acc``, ``conditions``,
    ``mask_var`` and, when ``args.visdom`` is set, ``plotter``.
    ``best_acc`` is read before being assigned in the training loop, so it is
    expected to be initialised at module level.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    global conditions
    if args.conditions is not None:
        conditions = args.conditions
    else:
        conditions = [0, 1, 2, 3]

    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}

    def make_loader(split, n_triplets, batch_size, augment=False):
        # All splits share the resize/crop/normalise pipeline; only the
        # training split adds the random horizontal flip.
        steps = [transforms.Resize(112), transforms.CenterCrop(112)]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [transforms.ToTensor(), normalize]
        dataset = TripletImageLoader('data',
                                     'ut-zap50k-images',
                                     'filenames.json',
                                     conditions,
                                     split,
                                     n_triplets=n_triplets,
                                     transform=transforms.Compose(steps))
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           **kwargs)

    print('Loading Train Dataset')
    train_loader = make_loader('train', args.num_traintriplets,
                               args.batch_size, augment=True)
    print('Loading Test Dataset')
    test_loader = make_loader('test', 160000, 16)
    print('Loading Val Dataset')
    val_loader = make_loader('val', 80000, 16)

    model = Resnet_models.resnext50_32x4d()
    csn_model = ConditionalSimNet(model,
                                  n_conditions=args.num_concepts,
                                  embedding_size=args.dim_embed,
                                  learnedmask=args.learned,
                                  prein=args.prein)
    global mask_var
    mask_var = csn_model.masks.weight
    tnet = CS_Tripletnet(csn_model, args.num_concepts)
    if args.cuda:
        tnet.cuda()

    # When running without CUDA, remap stored tensors to the CPU so that
    # checkpoints written by a GPU run can still be loaded.
    load_location = None if args.cuda else 'cpu'

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=load_location)
            args.start_epoch = checkpoint['epoch']
            # Restore the best metric into the global the training loop
            # compares against; it is stored under 'best_prec1' below.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    if args.test:
        # NOTE(review): the run name is hard-coded here rather than taken
        # from args.name -- confirm this is intentional.
        checkpoint = torch.load('runs/%s/' % ('new_context_4/') +
                                'model_best.pth.tar',
                                map_location=load_location)
        tnet.load_state_dict(checkpoint['state_dict'])
        test(test_loader, tnet, criterion, 1)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)

    # Final evaluation of the best checkpoint on the held-out test split.
    checkpoint = torch.load('runs/%s/' % (args.name) + 'model_best.pth.tar',
                            map_location=load_location)
    tnet.load_state_dict(checkpoint['state_dict'])
    test(test_loader, tnet, criterion, 1)
def main():
    """Command-line entry point: train, validate and log a triplet network.

    Steps, in order:
      1. Parse command-line options and seed numpy / torch (and CUDA if used).
      2. Build train / test / val ``TripletImageLoader`` data loaders over
         ``data/ut-zap50k-images``.
      3. Build a pretrained ResNet-18 backed ``ConditionalSimNet`` wrapped in
         a ``CS_Tripletnet``.
      4. Optionally resume from ``--resume``.
      5. With ``--test``: run ``test`` on the test loader, print the result
         and exit.
      6. Otherwise train for epochs ``start_epoch..epochs``, calling
         ``save_ckpt`` after each epoch, then save ``model.pth`` and pickle
         the per-epoch log to ``<out>/<MMDD_HHMM>/log.pkl``.
    """
    parser = argparse.ArgumentParser(description='This is a WIP program')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
                        help='number of start epoch (default: 1)')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # The flag turns CUDA off; the previous help text claimed the opposite.
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--margin', type=float, default=0.2, metavar='M',
                        help='margin for triplet loss (default: 0.2)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--name', default='Conditional_Similarity_Network', type=str,
                        help='name of experiment')
    parser.add_argument('--embed_loss', type=float, default=5e-3, metavar='M',
                        help='parameter for loss for embedding norm')
    parser.add_argument('--mask_loss', type=float, default=5e-4, metavar='M',
                        help='parameter for loss for mask norm')
    parser.add_argument('--num_traintriplets', type=int, default=100000, metavar='N',
                        help='how many unique training triplets (default: 100000)')
    parser.add_argument('--dim_embed', type=int, default=64, metavar='N',
                        help='how many dimensions in embedding (default: 64)')
    parser.add_argument('--test', dest='test', action='store_true',
                        help='To only run inference on test set')
    parser.add_argument('--learned', dest='learned', action='store_true',
                        help='To learn masks from random initialization')
    parser.add_argument('--prein', dest='prein', action='store_true',
                        help='To initialize masks to be disjoint')
    parser.add_argument('--conditions', nargs='*', type=int,
                        help='Set of similarity notions')
    parser.add_argument('--out', type=str, default='result',
                        help='dir to save models and log')
    parser.set_defaults(test=False)
    parser.set_defaults(learned=False)
    parser.set_defaults(prein=False)
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    if args.conditions is not None:
        conditions = args.conditions
    else:
        conditions = list(range(4))

    # torchvision renamed Scale to Resize and later removed Scale; prefer
    # Resize and fall back to Scale only on installs that predate it.
    resize = getattr(T, 'Resize', None) or T.Scale

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

    def make_loader(split, n_triplets, augment=False):
        # All splits share the resize/crop/normalise pipeline; only the
        # training split adds the random horizontal flip.
        steps = [resize(112), T.CenterCrop(112)]
        if augment:
            steps.append(T.RandomHorizontalFlip())
        steps += [T.ToTensor(), normalize]
        dataset = TripletImageLoader('data', 'ut-zap50k-images',
                                     'filenames.json', conditions, split,
                                     n_triplets=n_triplets,
                                     transform=T.Compose(steps))
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    train_loader = make_loader('train', args.num_traintriplets, augment=True)
    test_loader = make_loader('test', 160000)
    val_loader = make_loader('val', 80000)

    model = resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    csn_model = ConditionalSimNet(model, n_conditions=len(conditions),
                                  embedding_size=args.dim_embed,
                                  learnedmask=args.learned, prein=args.prein)
    triplet_net = CS_Tripletnet(csn_model)
    if args.cuda:
        triplet_net.cuda()

    # Best validation accuracy seen so far; replaced by the stored value
    # when resuming so an earlier best is not forgotten.
    best_acc = .0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # When running without CUDA, remap stored tensors to the CPU so
            # checkpoints written by a GPU run can still be loaded.
            checkpoint = torch.load(
                args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            triplet_net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    parameters = filter(lambda p: p.requires_grad, triplet_net.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    n_param = sum(p.data.nelement() for p in triplet_net.parameters())
    print('# of parameters: {}'.format(n_param))
    print('\n\n')

    if args.test:
        import sys
        # Same argument order as the validation call in the training loop.
        test_loss, test_acc = test(args, test_loader, triplet_net, criterion,
                                   args.epochs + 1)
        print('accuracy: {}, loss: {}'.format(test_acc.avg, test_loss.avg))
        sys.exit()

    # One output directory per run, named after the start time.
    root = os.path.join(args.out, dt.now().strftime('%m%d_%H%M'))
    if not os.path.exists(root):
        os.makedirs(root)
    log = dict()
    start_time = dt.now()
    for epoch in tqdm(range(args.start_epoch, args.epochs + 1), desc='total'):
        adjust_lr(args, optimizer, epoch)
        loss_acc_log = train(args, train_loader, triplet_net,
                             criterion, optimizer, epoch)
        log['epoch_{}_train'.format(epoch)] = loss_acc_log
        losses, accs = test(args, val_loader, triplet_net, criterion, epoch)
        log['epoch_{}_val'.format(epoch)] = {
            'loss': losses.avg, 'acc': 100. * accs.avg}
        tqdm.write('[validation]\nloss: {:.4f}\tacc: {:.2f}%\n'.format(
            losses.avg, 100. * accs.avg))

        is_best = accs.avg > best_acc
        best_acc = max(accs.avg, best_acc)
        save_ckpt(root, triplet_net.state_dict(), is_best)

    end_time = dt.now()
    print('\ntraining finished.')
    print('started at {}, ended at {}, duration is {} hours'.format(
        start_time.strftime('%m%d, %H:%M'), end_time.strftime('%m%d, %H:%M'),
        (end_time - start_time).total_seconds() / 3600.))
    save_ckpt(root, triplet_net.state_dict(), filename='model.pth')
    log_filepath = os.path.join(root, 'log.pkl')
    with open(log_filepath, 'wb') as f:
        pickle.dump(log, f)
    print('log files saved at {}'.format(log_filepath))
    transform=transform,
    use_mean_img=True,
    data_file="test_no_dup_with_category_3more_name.json",
    neg_samples=True,
)

# Build the encoder, the conditional similarity network on top of it, and
# the triplet wrapper, then restore the trained weights for inference.
model = resnet18(pretrained=True, embedding_size=emb_size)
csn_model = ConditionalSimNet(
    model,
    n_conditions=len(train_dataset.conditions) // 2,
    embedding_size=emb_size,
    learnedmask=True,
    prein=False,
)
tnet = CS_Tripletnet(csn_model)
tnet = tnet.to(device)
# map_location places the stored tensors on the target device, so a
# checkpoint saved from a GPU run also loads when `device` is the CPU.
# load_state_dict copies into the existing parameters, so the model stays on
# `device` and does not need to be moved a second time.
tnet.load_state_dict(torch.load("./csn_model_best.pth", map_location=device))
tnet.eval()
# The inner network, exposed under its own name for the code below.
embeddingnet = tnet.embeddingnet


# Test !!!!!!
def test_compatibility_auc(test_auc_dataset, embeddingnet):
    scores = []
    targets = []
    for i in range(len(test_auc_dataset)):
        print("#{}/{}\r".format(i, len(test_auc_dataset)), end="", flush=True)

        conditions = torch.tensor(