示例#1
0
def main():
    """Train and evaluate the type-specific triplet network on Polyvore Outfits.

    Parses CLI options into the global ``args`` and builds test/train/valid
    ``DataLoader``s over ``TripletImageLoader``. Wraps a ResNet-18 embedding in
    ``TypeSpecificNet`` + ``Tripletnet`` and optionally resumes from
    ``args.resume``.

    With ``args.test`` set, runs one pass over the test split and exits the
    process. Otherwise trains for epochs ``args.start_epoch .. args.epochs``
    (inclusive) and calls ``save_checkpoint`` after each epoch. It then reloads
    ``runs/<name>/model_best.pth.tar`` and evaluates it on the test split.
    """
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Standard ImageNet RGB mean/std.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    fn = os.path.join(args.datadir, 'polyvore_outfits',
                      'polyvore_item_metadata.json')
    # Context manager closes the metadata file instead of leaking the handle.
    with open(fn, 'r') as meta_file:
        meta_data = json.load(meta_file)
    text_feature_dim = 6000
    kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}

    # transforms.Resize is used in place of transforms.Scale, its deprecated
    # alias, which newer torchvision releases no longer ship.
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            args,
            'test',
            meta_data,
            transform=transforms.Compose([
                transforms.Resize(112),
                transforms.CenterCrop(112),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
    # The third argument is the number of type spaces the test dataset reports.
    csn_model = TypeSpecificNet(args, model,
                                len(test_loader.dataset.typespaces))

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    tnet = Tripletnet(args, csn_model, text_feature_dim, criterion)
    if args.cuda:
        tnet.cuda()

    # Only the training split gets text_dim and random horizontal flips.
    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            args,
            'train',
            meta_data,
            text_dim=text_feature_dim,
            transform=transforms.Compose([
                transforms.Resize(112),
                transforms.CenterCrop(112),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            args,
            'valid',
            meta_data,
            transform=transforms.Compose([
                transforms.Resize(112),
                transforms.CenterCrop(112),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    best_acc = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # encoding='latin1' is presumably for checkpoints pickled under
            # Python 2 -- confirm before removing.
            checkpoint = torch.load(args.resume, encoding='latin1')
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    if args.test:
        test(test_loader, tnet)
        sys.exit()

    # Optimise only parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, tnet.parameters())
    optimizer = optim.Adam(parameters, lr=args.lr)
    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    for epoch in range(args.start_epoch, args.epochs + 1):
        # update learning rate
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(val_loader, tnet)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)

    # Evaluate the best checkpoint on the test split (model_best is presumably
    # written by save_checkpoint when is_best is True).
    checkpoint = torch.load('runs/%s/' % (args.name) + 'model_best.pth.tar')
    tnet.load_state_dict(checkpoint['state_dict'])
    test(test_loader, tnet)
示例#2
0
def main():
    """Train a triplet network on dataset ``args.name``, then fit a classifier
    on the best embedding.

    Parses CLI options into the global ``args`` and opens a ``SummaryWriter``
    (global ``writer``). Builds the four loaders returned by
    ``get_TripletDataset`` and instantiates the network named by ``args.net``,
    wrapped in ``Classifier`` when ``args.use_fc``. Optionally resumes from
    ``args.resume``, then trains with SGD + MarginRankingLoss, tracking the
    best triplet test accuracy in the global ``best_acc``.

    Afterwards the best checkpoint is reloaded into a freshly built model and
    passed to ``classifier``.
    """
    global args, best_acc, writer
    args = parser.parse_args()
    writer = SummaryWriter(comment='_' + args.name + '_triplet_network')
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {
        'num_workers': 1 if args.name == 'stl10' else 4,
        'pin_memory': True
    } if args.cuda else {}

    train_triplet_loader, test_triplet_loader, train_loader, test_loader = \
        get_TripletDataset(args.name, args.batch_size, **kwargs)

    # NOTE(review): args.net is spliced into Python source and exec'd, so any
    # expression passed on the command line is executed. Acceptable only for a
    # trusted local CLI; a {name: constructor} lookup would be safer.
    cmd = "model=%s()" % args.net
    local_dict = locals()
    exec(cmd, globals(), local_dict)
    model = local_dict['model']
    print(args.use_fc)
    if not args.use_fc:
        tnet = Tripletnet(model)
    else:
        tnet = Tripletnet(Classifier(model))
    if args.cuda:
        tnet.cuda()

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so a resumed run does not replace
            # model_best with a worse model.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.SGD(tnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))
    time_string = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
    log_directory = "runs/%s/" % (time_string + '_' + args.name)

    with Context(os.path.join(log_directory, args.log), parallel=True):
        for epoch in range(start_epoch, args.epochs + 1):
            # train for one epoch
            train(train_triplet_loader, tnet, criterion, optimizer, epoch)
            # evaluate on validation set
            acc = test(test_triplet_loader, tnet, criterion, epoch)

            # remember best acc and save checkpoint
            is_best = acc > best_acc
            best_acc = max(acc, best_acc)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': tnet.state_dict(),
                    'best_prec1': best_acc,
                }, is_best)

        checkpoint_file = 'runs/%s/' % (args.name) + 'model_best.pth.tar'
        assert os.path.isfile(checkpoint_file), 'Nothing to load...'
        checkpoint_cl = torch.load(checkpoint_file)
        # Rebuild the same architecture from scratch and load the best weights.
        cmd = "model_cl=%s()" % args.net
        exec(cmd, globals(), local_dict)
        model_cl = local_dict['model_cl']
        if not args.use_fc:
            tnet = Tripletnet(model_cl)
        else:
            tnet = Tripletnet(Classifier(model_cl))
        tnet.load_state_dict(checkpoint_cl['state_dict'])
        # NOTE(review): this rebuilt tnet is never moved to CUDA here; confirm
        # that `classifier` handles device placement.
        classifier(tnet.embeddingnet
                   if not args.use_fc else tnet.embeddingnet.embedding,
                   train_loader,
                   test_loader,
                   writer,
                   logdir=log_directory)

    writer.close()
def main():
    """Train a triplet network on image triplets listed in text files under '.'.

    Parses CLI options into the global ``args``. Builds train/valid loaders
    from ``./filenames_filename.txt`` and ``./triplets_{train,valid}_name.txt``
    with images resized to 128x128, and wraps a small two-conv CNN in
    ``Tripletnet``.

    Optionally resumes from ``args.resume``, builds the optimizer selected by
    ``args.opt``, and trains with a checkpoint after every epoch. The global
    ``best_acc`` tracks the best validation accuracy.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            '.',
            './filenames_filename.txt',
            './triplets_train_name.txt',
            transform=transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    valid_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            '.',
            './filenames_filename.txt',
            './triplets_valid_name.txt',
            transform=transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    class Net(nn.Module):
        """Two-conv embedding network: 3x128x128 input -> 20-d output."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
            self.conv1_drop = nn.Dropout2d()  # defined but not used in forward
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            # 128 -conv5-> 124 -pool-> 62 -conv5-> 58 -pool-> 29;
            # 20 channels * 29 * 29 = 16820 flattened features.
            self.fc1 = nn.Linear(16820, 128)
            self.fc2 = nn.Linear(128, 20)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 16820)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            return self.fc2(x)

    model = Net()
    tnet = Tripletnet(model)
    if args.cuda:
        tnet.cuda()

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so resuming keeps model_best monotone.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)

    def make_optimizer(model, opt, lr, weight_decay, momentum, nesterov=True):
        """Build the optimizer named ``opt`` over ``model``'s parameters.

        Supported names: 'SGD', 'AMSGRAD' (Adam with amsgrad=True),
        'Ranger' (only parameters with requires_grad) and 'RMS' (RMSprop).

        Raises:
            ValueError: if ``opt`` is not one of the supported names.
        """
        if opt == 'SGD':
            return torch.optim.SGD(model.parameters(),
                                   lr=lr,
                                   weight_decay=weight_decay,
                                   momentum=momentum,
                                   nesterov=nesterov)
        if opt == 'AMSGRAD':
            return torch.optim.Adam(model.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    amsgrad=True)
        if opt == 'Ranger':
            return Ranger(params=filter(lambda p: p.requires_grad,
                                        model.parameters()),
                          lr=lr)
        if opt == 'RMS':
            return torch.optim.RMSprop(model.parameters(),
                                       lr=lr,
                                       alpha=0.99,
                                       eps=1e-08,
                                       weight_decay=weight_decay,
                                       momentum=momentum,
                                       centered=False)
        raise ValueError(
            "unsupported optimizer {!r}; expected one of "
            "'SGD', 'AMSGRAD', 'Ranger', 'RMS'".format(opt))

    optimizer = make_optimizer(tnet,
                               opt=args.opt,
                               lr=args.lr,
                               weight_decay=args.weight_decay,
                               momentum=args.momentum,
                               nesterov=True)
    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(valid_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)
示例#4
0
def main():
    """Train a triplet network on MNIST-style triplets (``MNIST_t``).

    Parses CLI options into the global ``args`` and optionally creates a
    Visdom plotter (global ``plotter``). Builds train/test loaders from
    ``./data`` and wraps a two-conv CNN in ``Tripletnet``. Optionally resumes
    from ``args.resume``, then trains with SGD + MarginRankingLoss and
    checkpoints after every epoch. The global ``best_acc`` tracks the best
    test accuracy.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    if args.visdom:
        global plotter
        plotter = VisdomLinePlotter(env_name=args.name)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        MNIST_t(
            './data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        MNIST_t(
            './data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    class Net(nn.Module):
        """Two-conv embedding network: 1-channel image -> 10-d output."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            # 320 = 20 * 4 * 4, which assumes 28x28 inputs
            # (28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4) -- confirm.
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            return self.fc2(x)

    model = Net()
    tnet = Tripletnet(model)
    if args.cuda:
        tnet.cuda()

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so resuming keeps model_best monotone.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.SGD(tnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(test_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)
示例#5
0
def main():
    """Continue training a triplet network on ``../video_segmentation/multi_mask``.

    Parses CLI options into the global ``args`` and creates a Visdom plotter
    (global ``plotter``). Builds train/test loaders, creates the backbone via
    ``models.setup(args)``, and wraps it in ``Tripletnet``.

    The weights are then force-loaded from ``./out/TripletNet/checkpoint.pth.tar``
    and ``best_acc`` from ``./out/model_best.pth.tar``. Epochs up to the
    hard-coded ``checkpoint_epoch`` are skipped; during those epochs the
    ExponentialLR schedule is still advanced so the learning rate is replayed.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    global plotter
    plotter = VisdomLinePlotter(env_name=args.name)

    # Normalize on RGB Value
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # NOTE(review): only size[0] is used below (square resize); size[1] = 256
    # for non-inception archs is currently ignored -- confirm intent.
    if args.arch.startswith('inception'):
        size = (299, 299)
    else:
        size = (224, 256)

    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            '../video_segmentation/multi_mask',
            train=True,
            transform=transforms.Compose([
                transforms.Resize((size[0], size[0])),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            '../video_segmentation/multi_mask',
            train=False,
            transform=transforms.Compose([
                transforms.Resize((size[0], size[0])),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    print("=> creating model '{}'".format(args.arch))
    model = models.setup(args)

    tnet = Tripletnet(model, args)
    if args.cuda:
        tnet.cuda()

    # optionally resume from a checkpoint
    # NOTE(review): the hard-coded checkpoint load further below overrides
    # whatever is loaded here.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.SGD(tnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    # NOTE(review): hard-coded resume point and checkpoint paths.
    checkpoint_epoch = 13
    print('load checkpoint ' + str(checkpoint_epoch))
    tnet.load_state_dict(
        torch.load('./out/TripletNet/checkpoint.pth.tar')['state_dict'])
    best_acc = torch.load('./out/model_best.pth.tar')['best_prec1']

    for epoch in range(1, args.epochs + 1):
        # Advance the LR schedule on every epoch, including the skipped ones,
        # so the learning rate after the skip equals what an uninterrupted
        # run reaches: one decay every `decay_epoch` epochs, starting at
        # epoch 1.
        if (epoch - 1) % args.decay_epoch == 0:
            scheduler.step()

        if epoch <= checkpoint_epoch:
            continue

        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(test_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)
def main():
    """Train a triplet network on a class subset of CUB (``CUB_t``).

    Parses CLI options into the global ``args`` and creates a Visdom plotter
    (global ``plotter``). Builds triplet train/test loaders for the first
    ``num_classes`` classes plus a kNN evaluation set, and wraps a two-conv
    CNN sized for 64x64 inputs in ``Tripletnet``.

    Each epoch trains, evaluates triplet accuracy and kNN accuracy, and saves
    a checkpoint. Every ``args.triplet_freq`` epochs it regenerates the
    training triplets and resets the sampler. The global ``best_acc`` tracks
    the best triplet test accuracy.
    """
    global args, best_acc
    args = parser.parse_args()
    data_path = args.data
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    global plotter
    plotter = VisdomLinePlotter(env_name=args.name)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    num_classes = 5
    num_triplets = num_classes * 64
    train_data_set = CUB_t(data_path, n_train_triplets=num_triplets, train=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ]),
                           num_classes=num_classes)
    train_loader = torch.utils.data.DataLoader(
        train_data_set,
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        CUB_t(data_path, n_test_triplets=num_classes * 16, train=False,
              transform=transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Normalize((0.1307,), (0.3081,))
              ]),
              num_classes=num_classes),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    kNN_loader = CUB_t_kNN(data_path, train=False,
                           n_test=args.kNN_test_size, n_train=args.kNN_train_size,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ]))

    # Input side length in pixels (assumed square -- confirm against CUB_t).
    im_len = 64
    # Flattened size after two (conv5 -> maxpool2) stages. Integer division is
    # required: nn.Linear and Tensor.view need int sizes, and `/` yields
    # floats in Python 3. 64 -> 30 -> 13, so fc1_len = 13 * 13 * 20 = 3380.
    h1_len = (im_len - 4) // 2
    h2_len = (h1_len - 4) // 2
    fc1_len = h2_len * h2_len * 20

    class Net(nn.Module):
        """Two-conv embedding network: 3-channel image -> 10-d output."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
            self.conv2_drop = nn.Dropout2d()
            self.fc1 = nn.Linear(fc1_len, 50)
            self.fc2 = nn.Linear(50, 10)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
            x = x.view(-1, fc1_len)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, training=self.training)
            return self.fc2(x)

    model = Net()
    tnet = Tripletnet(model)
    if args.cuda:
        tnet.cuda()

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so resuming keeps model_best monotone.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.SGD(tnet.parameters(), lr=args.lr, momentum=args.momentum)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    # Number of batches per pass; integer division keeps this an int as it
    # was under Python 2 semantics.
    sampler = OurSampler(num_classes, num_triplets // args.batch_size)

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch, sampler)
        # evaluate on validation set
        acc = test(test_loader, tnet, criterion, epoch)
        acc_kNN = test_kNN(kNN_loader, tnet, epoch, args.kNN_k)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': tnet.state_dict(),
            'best_prec1': best_acc,
        }, is_best)

        # reset sampler and regenerate triplets every few epochs
        if epoch % args.triplet_freq == 0:
            # NOTE(review): `hard_frac` is not defined in this function; it
            # must exist at module level or this raises NameError.
            train_data_set.regenerate_triplet_list(num_triplets, sampler,
                                                   num_triplets * hard_frac)
            # then reset sampler
            sampler.Reset()
# --- Test-split data pipeline ------------------------------------------------
dataset_test = Data_loader(matches_test, descript_test, transform=transform)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
)
print('Len dataloader test:{}'.format(len(dataloader_test)))

# --- Model: embedding network moved to `device`, wrapped in Tripletnet -------
model = Get_model().to(device)
tnet = Tripletnet(model)

# --- Losses and optimiser ----------------------------------------------------
criterion1 = torch.nn.MarginRankingLoss(margin=0.2)
criterion2 = torch.nn.MSELoss()
optimizer = optim.SGD(tnet.parameters(), momentum=0.5, lr=0.001)


# accuracy calculations
def accuracy(dista, distb):
    """Return the fraction of samples where ``dista < distb``, plus the mask.

    Args:
        dista: tensor whose first dimension is the batch size N; trailing
            dimensions are flattened per sample.
        distb: tensor with the same leading dimension N, flattened the same
            way.

    Returns:
        ``(acc, y)``. ``y`` is a float ndarray of shape ``(N, k)`` holding 1.0
        where ``dista < distb`` and 0.0 elsewhere. ``acc`` is
        ``y.sum(axis=0) / N``, an ndarray of shape ``(k,)``.

    Raises:
        ValueError: if the batch is empty (N == 0).
    """
    n = dista.shape[0]
    if n == 0:
        raise ValueError('accuracy() requires a non-empty batch')
    dista = dista.cpu().detach().numpy().reshape(n, -1)
    distb = distb.cpu().detach().numpy().reshape(n, -1)
    y = np.zeros(dista.shape)
    y[dista < distb] = 1
    # Vectorised column sum; same values as the builtin sum over rows.
    return y.sum(axis=0) / n, y


# Train function
def train(tnet, dataloader, criterion1, criterion2, optimizer, epoch):
    """Switch ``tnet`` into training mode.

    NOTE(review): as captured here the body ends after ``tnet.train()``, and
    ``dataloader``, ``criterion1``, ``criterion2``, ``optimizer`` and
    ``epoch`` are unused. This looks like a truncated excerpt -- confirm
    against the full source before relying on it.
    """
    tnet.train()
示例#8
0
def main():
    """Train, or run prediction with, a network over precomputed embeddings.

    Parses CLI options into the global ``args`` and creates a Visdom plotter
    (global ``plotter``). Flags select the variant:

    * ``args.binary_classify``: uses ``bin_collate_wrapper``, a 3*emb_size
      input layer, BCEWithLogitsLoss, ``bin_train``/``bin_test`` and the bare
      network. Otherwise: ``collate_wrapper``, a 2*emb_size input layer,
      MarginRankingLoss, ``train``/``test`` and a ``Tripletnet`` wrapper.
    * ``args.cnn``: uses ``CNNNet`` instead of the MLP ``Net``.
    * ``args.pred``: skips the training loader, seeds numpy/random and sets
      ``cudnn.deterministic``, then evaluates, predicts on the test loader and
      returns without training.

    The global ``best_acc`` tracks the best test accuracy across epochs.
    """
    global args, best_acc
    print('pid:', os.getpid())
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if not args.cuda:
        print('no cuda!')
    else:
        print('we\'ve got cuda!')
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    global plotter
    plotter = VisdomLinePlotter(env_name=args.name)
    if args.pred:
        np.random.seed(args.seed)  # Numpy module.
        random.seed(args.seed)  # Python random module.
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
    ######################
    base_path = args.base_path
    embed_size = args.emb_size
    ######################
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

    # Prefix of the embedding files used for training (alternative: 'test').
    m = 'train'
    # Triplet JSON used for training (alternative: 'test').
    trainortest = 'small_train'
    if args.binary_classify:
        coll = bin_collate_wrapper
    else:
        coll = collate_wrapper
    if not args.pred:
        print('loading training data...')
        train_loader = torch.utils.data.DataLoader(
            TripletEmbedLoader(args, base_path, m + '_embed_index.csv',
                               trainortest + '.json', 'train',
                               m + '_embeddings.pt'),
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=coll,
            **kwargs)
    print('loading testing data...')
    # Keep test order stable in prediction mode.
    if args.pred:
        shuff = False
    else:
        shuff = True
    test_loader = torch.utils.data.DataLoader(
        TripletEmbedLoader(args, base_path, 'test_embed_index.csv',
                           'test.json', 'train', 'test_embeddings.pt'),
        batch_size=args.batch_size,
        shuffle=shuff,
        collate_fn=coll,
        **kwargs)

    class Net(nn.Module):
        """MLP over concatenated embeddings -> single scalar output."""

        def __init__(self, embed_size):
            super(Net, self).__init__()
            if args.binary_classify:
                self.nfc1 = nn.Linear(embed_size * 3, 480)
            else:
                self.nfc1 = nn.Linear(embed_size * 2, 480)
            self.nfc2 = nn.Linear(480, 320)
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 1)
            if args.binary_classify:
                # Defined but unused in forward: the raw logit is returned
                # and BCEWithLogitsLoss applies the sigmoid itself.
                self.out = nn.Sigmoid()

        def forward(self, x):
            x = F.relu(self.nfc1(x))
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu(self.nfc2(x))
            x = F.dropout(x, p=0.75, training=self.training)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, p=0.75, training=self.training)
            return self.fc2(x)

    if args.cnn:
        model = CNNNet()
    else:
        model = Net(embed_size)
    if args.cuda:
        model.cuda()
    if args.binary_classify:
        tnet = model
    else:
        tnet = Tripletnet(model, args)
    print('net built.')
    if args.cuda:
        tnet.cuda()
    print('tnet.cuda()')

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so resuming keeps model_best monotone.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if not args.pred:
        cudnn.benchmark = True
    if args.binary_classify:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.Adam(tnet.parameters(), lr=args.lr)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    if args.pred:
        print('testing...')
        test(test_loader, tnet, criterion, 0)
        print('predicting...')
        predict(test_loader, tnet)
        # Prediction-only run is complete. Return normally rather than
        # exit(1), which would report failure to the calling shell.
        return

    print('start training!')
    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch, then evaluate on the test set
        if args.binary_classify:
            bin_train(train_loader, tnet, criterion, optimizer, epoch)
            acc = bin_test(test_loader, tnet, criterion, epoch)
        else:
            train(train_loader, tnet, criterion, optimizer, epoch)
            acc = test(test_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)
示例#9
0
def main():
    """Train a VGG-based triplet network on Caltech triplets.

    Parses CLI options into the global ``args``. Builds train/val loaders
    from the path/index files under ``../caltech-data/triplet_data`` with
    images resized to 224x224, and wraps ``Vgg_Net`` in ``Tripletnet``.
    Optionally resumes from ``args.resume``, then trains with
    SGD + MarginRankingLoss and checkpoints after every epoch. The global
    ``best_acc`` tracks the best validation accuracy.
    """
    global args, best_acc
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    root_dir = "../caltech-data/"
    triplet_data_dir = os.path.join(root_dir, "triplet_data")
    train_triplet_path_file = os.path.join(triplet_data_dir,
                                           "triplet_paths_train.txt")
    train_triplet_idx_file = os.path.join(triplet_data_dir,
                                          "triplet_index_train.txt")
    val_triplet_path_file = os.path.join(triplet_data_dir,
                                         "triplet_paths_val.txt")
    val_triplet_idx_file = os.path.join(triplet_data_dir,
                                        "triplet_index_val.txt")

    train_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            filenames_filename=train_triplet_path_file,
            triplets_file_name=train_triplet_idx_file,
            transform=transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    test_loader = torch.utils.data.DataLoader(
        TripletImageLoader(
            filenames_filename=val_triplet_path_file,
            triplets_file_name=val_triplet_idx_file,
            transform=transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    embedingnet = Vgg_Net()
    tnet = Tripletnet(embedingnet)
    print(tnet)
    if args.cuda:
        tnet.cuda()

    # Epoch to start from; overridden when resuming from a checkpoint.
    start_epoch = 1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            start_epoch = checkpoint['epoch']
            # Restore the global best so resuming keeps model_best monotone.
            best_acc = checkpoint['best_prec1']
            tnet.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = torch.nn.MarginRankingLoss(margin=args.margin)
    optimizer = optim.SGD(tnet.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)

    n_parameters = sum(p.data.nelement() for p in tnet.parameters())
    print('  + Number of params: {}'.format(n_parameters))

    for epoch in range(start_epoch, args.epochs + 1):
        # train for one epoch
        train(train_loader, tnet, criterion, optimizer, epoch)
        # evaluate on validation set
        acc = test(test_loader, tnet, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': tnet.state_dict(),
                'best_prec1': best_acc,
            }, is_best)