# Example #1
def main():
    """Meta-train MAML on MiniImageNet, validating each epoch and
    checkpointing the best model.

    Relies on module-level names: ``args`` (parsed CLI options), ``Meta``,
    ``MiniImageNet``, ``FewShotDataloader``, ``setup_seed`` and ``tqdm``.
    """
    # Seed every RNG in play so runs are reproducible.
    torch.manual_seed(1234)
    torch.cuda.manual_seed_all(1234)
    np.random.seed(1234)
    setup_seed(1234)

    print(args)

    # 4-conv (64-channel) few-shot backbone followed by a linear classifier
    # over args.n_way classes; 64 * 5 * 5 is the flattened feature size.
    config = [
        ('conv2d', [64, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 64 * 5 * 5])
    ]

    maml = Meta(args, config)

    # ------------------------------------------
    dataset_train = MiniImageNet(phase='train')
    dataset_val = MiniImageNet(phase='val')

    dloader_train = FewShotDataloader(
        dataset=dataset_train,
        nKnovel=args.n_way,
        nKbase=0,
        nExemplars=args.k_spt,  # num training examples per novel category
        nTestNovel=args.n_way * args.k_spt,  # num test examples for all the novel categories
        nTestBase=0,  # num test examples for all the base categories
        batch_size=args.task_num,
        num_workers=4,
        epoch_size=args.task_num * 1000,  # num of batches per epoch
    )

    dloader_val = FewShotDataloader(
        dataset=dataset_val,
        nKnovel=args.n_way,
        nKbase=0,
        nExemplars=args.k_qry,  # num training examples per novel category
        nTestNovel=args.k_qry * args.n_way,  # num test examples for all the novel categories
        nTestBase=0,  # num test examples for all the base categories
        batch_size=1,
        num_workers=0,
        epoch_size=1000,  # num of batches per epoch
    )

    # ---------------------------------------------------

    max_val_acc = 0

    # exist_ok=True closes the race between an exists() check and makedirs().
    os.makedirs(args.save_path, exist_ok=True)

    for epoch in range(1, args.epoch + 1):
        # Meta-training: fetch meta_batchsz episodes each iteration.
        train_accuracies = []

        for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
            data_support, labels_support, data_query, labels_query, _, _ = [
                x.cuda() for x in batch]

            data_support = data_support.float()
            data_query = data_query.float()
            labels_support = labels_support.long()
            labels_query = labels_query.long()

            accs, loss_cls = maml.forward(data_support, labels_support,
                                          data_query, labels_query)
            train_accuracies.append(accs)

            if i % 100 == 0:
                # Compute the running average only when it is logged.
                avg_accs = np.array(train_accuracies).mean(axis=0).astype(np.float16) * 100
                # Log both epoch and batch index (the original message
                # labelled the batch index as the epoch).
                print('Train Epoch: {} Batch: {} \tLoss_cls: {:.4f}'.format(epoch, i, loss_cls),
                      '\ttraining acc:', avg_accs)

        # Evaluate on the validation split.
        val_accuracies = []

        for batch in tqdm(dloader_val(epoch)):
            data_support, labels_support, data_query, labels_query, _, _ = [
                x.cuda() for x in batch]

            # batch_size is 1: drop the leading batch dimension.
            data_support = data_support.float().squeeze(0)
            data_query = data_query.float().squeeze(0)
            labels_support = labels_support.long().squeeze(0)
            labels_query = labels_query.long().squeeze(0)

            # accs has shape [update_step + 1].
            accs = maml.finetunning(data_support, labels_support,
                                    data_query, labels_query)
            val_accuracies.append(accs)

        # Mean accuracy after the final inner-loop update step, in percent.
        val_acc_avg = np.array(val_accuracies).mean(axis=0).astype(np.float16)[-1] * 100

        # NOTE(review): this reuses the support set of the *last* validation
        # batch — presumably intentional; confirm against Meta.load_selfvars.
        maml.load_selfvars(data_support)

        if val_acc_avg > max_val_acc:
            max_val_acc = val_acc_avg

            state = {'epoch': epoch + 1, 'model': maml.net.state_dict(),
                     'optimizer': maml.meta_optim.state_dict()}
            # Fixed filename; the original applied a no-op .format(epoch)
            # to a placeholder-free string.
            torch.save(state, os.path.join(args.save_path, 'best_model.pth.tar'))

            print('Validation Epoch: {}\t\t\tAccuracy: {:.2f}  % (Best)'
                  .format(epoch, val_acc_avg))
        else:
            print('Validation Epoch: {}\t\t\tAccuracy: {:.2f} %'
                  .format(epoch, val_acc_avg))

        # Periodic checkpoint (every other epoch) for resuming.
        if epoch % 2 == 0:
            state = {'epoch': epoch + 1, 'model': maml.net.state_dict(),
                     'optimizer': maml.meta_optim.state_dict()}
            torch.save(state, os.path.join(args.save_path,
                                           'epoch_{}.pth.tar'.format(epoch)))
# Example #2 (fragment: the enclosing definitions are truncated)
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        embedding_net.load_state_dict(checkpoint['embedding'])
        seg_head.load_state_dict(checkpoint['head'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return embedding_net, seg_head, optimizer, start_epoch


if __name__ == '__main__':
    setup_seed(1234)
    parser = argparse.ArgumentParser()
    parser.add_argument('--cls-head',
                        type=str,
                        default='R2D2',
                        help='R2D2 or SVM')
    parser.add_argument('--num-epoch',
                        type=int,
                        default=50,
                        help='number of training epochs')
    parser.add_argument('--save-epoch',
                        type=int,
                        default=2,
                        help='frequency of model saving')
    parser.add_argument('--train-shot',
                        type=int,
# Example #3
def main():
    """Meta-test a trained MAML checkpoint on tieredImageNet, reporting
    running mean accuracy with a 95% confidence interval.

    Relies on module-level names: ``args`` (parsed CLI options), ``Meta``,
    ``tieredImageNet``, ``FewShotDataloader``, ``setup_seed`` and ``tqdm``.
    """
    # Seed every RNG in play so evaluation is reproducible.
    torch.manual_seed(1234)
    torch.cuda.manual_seed_all(1234)
    np.random.seed(1234)
    setup_seed(1234)

    print(args)

    # 4-conv (64-channel) few-shot backbone followed by a linear classifier
    # over args.n_way classes; 64 * 5 * 5 is the flattened feature size.
    config = [
        ('conv2d', [64, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 64 * 5 * 5])
    ]

    maml = Meta(args, config)
    # Restore the meta-trained weights before evaluation.
    maml.net.load_state_dict(torch.load(args.load_path)['model'])

    # ------------------------------------------

    dataset_test = tieredImageNet(phase='test')
    dloader_test = FewShotDataloader(
        dataset=dataset_test,
        nKnovel=args.n_way,
        nKbase=0,
        nExemplars=args.k_spt,  # num training examples per novel category
        nTestNovel=args.k_qry * args.n_way,  # num test examples for all the novel categories
        nTestBase=0,  # num test examples for all the base categories
        batch_size=1,
        num_workers=1,
        epoch_size=args.epoch,  # num of batches per epoch
    )

    # ---------------------------------------------------

    test_accuracies = []

    for i, batch in enumerate(tqdm(dloader_test()), 1):
        data_support, labels_support, data_query, labels_query, _, _ = [
            x.cuda() for x in batch]

        # batch_size is 1: drop the leading batch dimension.
        data_support = data_support.float().squeeze(0)
        data_query = data_query.float().squeeze(0)
        labels_support = labels_support.long().squeeze(0)
        labels_query = labels_query.long().squeeze(0)

        # accs has shape [update_step + 1]; the last entry is the accuracy
        # after the final inner-loop adaptation step.
        accs = maml.finetunning(data_support, labels_support,
                                data_query, labels_query)
        test_accuracies.append(accs)

        if i % 50 == 0:
            # Compute running statistics only when they are logged
            # (the original recomputed them every episode).
            final_accs = np.array(test_accuracies)[:, -1] * 100
            avg = final_accs.mean().astype(np.float16)
            std = np.std(final_accs)
            # i already counts completed episodes (enumerate starts at 1),
            # so the 95% CI denominator is sqrt(i), not sqrt(i + 1).
            ci95 = 1.96 * std / np.sqrt(i)
            # Stray ')' removed from the original format string.
            print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} %'
                  .format(i, args.epoch, avg, ci95))
# Example #4
def main():
    """Meta-train MAML on Office-31, validating each epoch and
    checkpointing the best model.

    Relies on module-level names: ``args`` (parsed CLI options), ``Meta``,
    ``Office31Dataset``, ``DataLoader``, ``setup_seed`` and ``tqdm``.
    """
    # Seed every RNG in play so runs are reproducible.
    torch.manual_seed(1234)
    torch.cuda.manual_seed_all(1234)
    np.random.seed(1234)
    setup_seed(1234)

    print(args)

    # 4-conv (64-channel) few-shot backbone followed by a linear classifier
    # over args.n_way classes; 64 * 5 * 5 is the flattened feature size.
    config = [('conv2d', [64, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [64, 64, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [64]), ('max_pool2d', [2, 2, 0]),
              ('conv2d', [64, 64, 3, 3, 1, 0]), ('relu', [True]), ('bn', [64]),
              ('max_pool2d', [2, 2, 0]), ('conv2d', [64, 64, 3, 3, 1, 0]),
              ('relu', [True]), ('bn', [64]), ('max_pool2d', [2, 1, 0]),
              ('flatten', []), ('linear', [args.n_way, 64 * 5 * 5])]

    maml = Meta(args, config)

    max_val_acc = 0

    # exist_ok=True closes the race between an exists() check and makedirs().
    os.makedirs(args.save_path, exist_ok=True)

    for epoch in range(1, args.epoch + 1):
        # Meta-training: fetch meta_batchsz episodes each iteration.
        train_accuracies = []

        train_dataset = Office31Dataset(num_classes=args.n_way,
                                        num_support=args.k_spt,
                                        num_query=args.k_qry,
                                        num_epoch=32,
                                        phase='train')
        dloader_train = DataLoader(train_dataset,
                                   shuffle=True,
                                   num_workers=32,
                                   batch_size=args.task_num)

        # Start at 1 so i counts completed batches — matches the MiniImageNet
        # script and avoids logging immediately on the very first batch.
        for i, batch in enumerate(tqdm(dloader_train), 1):
            data_support, _, labels_support, data_query, _, labels_query = [
                x.cuda() for x in batch]

            data_support = data_support.float()
            data_query = data_query.float()
            labels_support = labels_support.long()
            labels_query = labels_query.long()

            accs, loss_cls = maml.forward(data_support, labels_support,
                                          data_query, labels_query)
            train_accuracies.append(accs)

            if i % 100 == 0:
                # Compute the running average only when it is logged.
                avg_accs = np.array(train_accuracies).mean(axis=0).astype(
                    np.float16) * 100
                # Log both epoch and batch index (the original message
                # labelled the batch index as the epoch).
                print('Train Epoch: {} Batch: {} \tLoss_cls: {:.4f}'.format(
                    epoch, i, loss_cls), '\ttraining acc:', avg_accs)

        # Evaluate on the validation split.
        val_accuracies = []

        val_dataset = Office31Dataset(num_classes=args.n_way,
                                      num_support=args.k_spt,
                                      num_query=args.k_qry,
                                      num_epoch=500,
                                      phase='val')
        dloader_val = DataLoader(val_dataset,
                                 shuffle=False,
                                 num_workers=32,
                                 batch_size=1)

        for batch in tqdm(dloader_val):
            data_support, _, labels_support, data_query, _, labels_query = [
                x.cuda() for x in batch]

            # batch_size is 1: drop the leading batch dimension.
            data_support = data_support.float().squeeze(0)
            data_query = data_query.float().squeeze(0)
            labels_support = labels_support.long().squeeze(0)
            labels_query = labels_query.long().squeeze(0)

            # accs has shape [update_step + 1].
            accs = maml.finetunning(data_support, labels_support,
                                    data_query, labels_query)
            val_accuracies.append(accs)

        # Mean accuracy after the final inner-loop update step, in percent.
        val_acc_avg = np.array(val_accuracies).mean(axis=0).astype(
            np.float16)[-1] * 100

        # NOTE(review): this reuses the support set of the *last* validation
        # batch — presumably intentional; confirm against Meta.load_selfvars.
        maml.load_selfvars(data_support)

        if val_acc_avg > max_val_acc:
            max_val_acc = val_acc_avg

            state = {
                'epoch': epoch + 1,
                'model': maml.net.state_dict(),
                'optimizer': maml.meta_optim.state_dict()
            }
            # Fixed filename; the original applied a no-op .format(epoch)
            # to a placeholder-free string.
            torch.save(state, os.path.join(args.save_path, 'best_model.pth.tar'))

            print('Validation Epoch: {}\t\t\tAccuracy: {:.2f}  % (Best)'
                  .format(epoch, val_acc_avg))
        else:
            print('Validation Epoch: {}\t\t\tAccuracy: {:.2f} %'
                  .format(epoch, val_acc_avg))

        # Periodic checkpoint (every other epoch) for resuming.
        if epoch % 2 == 0:
            state = {
                'epoch': epoch + 1,
                'model': maml.net.state_dict(),
                'optimizer': maml.meta_optim.state_dict()
            }
            torch.save(
                state,
                os.path.join(args.save_path, 'epoch_{}.pth.tar'.format(epoch)))
# Example #5
def main():
    """Meta-test a trained MAML checkpoint on Office-31, reporting running
    mean accuracy with a 95% confidence interval.

    Relies on module-level names: ``args`` (parsed CLI options), ``Meta``,
    ``Office31Dataset``, ``DataLoader``, ``setup_seed`` and ``tqdm``.
    """
    # Seed every RNG in play so evaluation is reproducible.
    torch.manual_seed(1234)
    torch.cuda.manual_seed_all(1234)
    np.random.seed(1234)
    setup_seed(1234)

    print(args)

    # 4-conv (64-channel) few-shot backbone followed by a linear classifier
    # over args.n_way classes; 64 * 5 * 5 is the flattened feature size.
    config = [
        ('conv2d', [64, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [64, 64, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [64]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 64 * 5 * 5])
    ]

    maml = Meta(args, config)
    # Restore the meta-trained weights before evaluation.
    maml.net.load_state_dict(torch.load(args.load_path)['model'])

    # ---------------------------------------------------

    test_dataset = Office31Dataset(num_classes=args.n_way,
                                   num_support=args.k_spt,
                                   num_query=args.k_qry,
                                   num_epoch=args.epoch,
                                   phase='test')
    dloader_test = DataLoader(test_dataset, shuffle=True, num_workers=32,
                              batch_size=1)

    test_accuracies = []

    # Start at 1 so i counts completed episodes — matches the tieredImageNet
    # test script; the CI denominator sqrt(i) equals the original sqrt(i + 1)
    # under the original 0-based index.
    for i, batch in enumerate(tqdm(dloader_test), 1):
        data_support, _, labels_support, data_query, _, labels_query = [
            x.cuda() for x in batch]

        # batch_size is 1: drop the leading batch dimension.
        data_support = data_support.float().squeeze(0)
        data_query = data_query.float().squeeze(0)
        labels_support = labels_support.long().squeeze(0)
        labels_query = labels_query.long().squeeze(0)

        # accs has shape [update_step + 1]; the last entry is the accuracy
        # after the final inner-loop adaptation step.
        accs = maml.finetunning(data_support, labels_support,
                                data_query, labels_query)
        test_accuracies.append(accs)

        if i % 50 == 0:
            # Compute running statistics only when they are logged
            # (the original recomputed them every episode).
            final_accs = 100 * np.array(test_accuracies)[:, -1]
            avg = final_accs.mean().astype(np.float16)
            std = np.std(final_accs)
            ci95 = 1.96 * std / np.sqrt(i)
            # Stray ')' removed from the original format string.
            print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} %'
                  .format(i, args.epoch, avg, ci95))