Example #1
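This entry point wraps a plain ConvNet in Learner and from then on treats it like any other nn.Module: it is moved to the GPU, restored from a checkpoint via load_state_dict, and handed to the training or evaluation routine. parser, ConvNet, train, and evaluation are defined elsewhere in the script.
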
def main():
    global args
    args = parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152 on stackoverflow
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_ID)

    exp_name = args.name

    kwargs = {'num_workers': 4}

    # create model, use Learner to wrap it
    model = Learner(ConvNet())
    model = model.cuda()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['iter']
            prec = checkpoint['prec']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.mode == 'train':
        train(model, exp_name, kwargs)
    else:
        evaluation(model, args.data_dir, args.batch_size, kwargs)
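
Note that the Learner wrapper stays a regular nn.Module here: cuda(), state_dict-based checkpointing, and ordinary forward calls all work unchanged, with the functional interface only coming into play in the later examples.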
Example #2
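A smoke test that wraps torchvision's resnet18 in Learner and runs one stateless forward pass through functional, skipping gracefully when torchvision is not installed.
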
import warnings

import torch


def test_resnet():
    try:
        from torchvision.models.resnet import resnet18
    except ImportError:
        # torchvision is optional; warn and skip instead of failing
        warnings.warn('torchvision not installed, resnet test skipped')
        return
    from metann import Learner
    net = Learner(resnet18())
    # one stateless forward pass with the model's own parameters
    print(net.functional(net.parameters(), True, torch.randn(3, 3, 224, 224)))
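
Across these examples the calling convention of Learner.functional is consistent: the parameter iterable comes first, then a boolean training flag (controlling batch-norm behavior), then the model inputs, i.e. functional(params, training, x).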
Example #3
File: models.py Project: yhqjohn/maml
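The MAML-style pattern that motivates Learner: with no vars argument the wrapped module runs on its own parameters, while an external parameter list reroutes the forward pass through functional.
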
class CNN(nn.Module):
    def __init__(self, config):
        super(CNN, self).__init__()
        # wrap the generated network so it can also be run functionally
        self.module = Learner(get_cnn(config))

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            # stateful call: use the module's own parameters
            return self.module(x)
        else:
            # functional call: use the externally supplied parameters
            return self.module.functional(vars, bn_training, x)
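
This two-path forward is exactly what a MAML inner loop needs. Below is a minimal sketch of one adaptation step; x_support, y_support, x_query, y_query, and inner_lr are illustrative placeholders rather than names from the original project.

model = CNN(config)
criterion = nn.CrossEntropyLoss()

fast_weights = list(model.parameters())
loss = criterion(model(x_support, vars=fast_weights), y_support)
# create_graph=True keeps the update differentiable so the meta-gradient
# can flow through it (second-order MAML)
grads = torch.autograd.grad(loss, fast_weights, create_graph=True)
fast_weights = [w - inner_lr * g for w, g in zip(fast_weights, grads)]
# evaluate the adapted parameters on the query set
query_loss = criterion(model(x_query, vars=fast_weights), y_query)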
Example #4
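A convergence test that performs 500 steps of manual gradient descent with no optimizer: every forward pass goes through functional with an explicit parameter list, which is rebuilt after each update (Flatten is a small reshape helper defined alongside the test, not torch.nn.Flatten).
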
def test_learner():
    # `device` is assumed to be defined at module level in the original;
    # defined here so the snippet is self-contained
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = Learner(
        nn.Sequential(
            nn.Conv2d(3, 3, 3),
            nn.Conv2d(3, 3, 3),
            Flatten(),
            nn.Linear(3, 4),
        )).to(device)
    x = torch.randn(3, 3, 5, 5).to(device)
    y = torch.randint(0, 4, (3, )).to(device)
    criterion = nn.CrossEntropyLoss()
    params = list(net.parameters())
    # manual SGD through the functional interface: the module's stored
    # parameters are never touched, only the external `params` list
    for i in range(500):
        outs = net.functional(params, True, x)
        loss = criterion(outs, y)
        grads = torch.autograd.grad(loss, params)
        with torch.no_grad():
            params = [(a - 0.01 * b).requires_grad_()
                      for a, b in zip(params, grads)]
    print(loss)
    assert loss <= 0.05
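
The torch.no_grad() block keeps each update out of the autograd graph and requires_grad_() turns every new tensor back into a leaf, so graph memory does not accumulate over the 500 steps; the assertion simply checks that the manual loop overfits the tiny batch.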
Example #5
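A distributed synchronous MAML run on Mini-ImageNet: every rank builds the same CNN from a layer-config list, wraps it in Learner inside a Meta module, and keeps the workers in sync by averaging parameters at each epoch boundary and gradients after every backward pass.
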
def run(rank, size, args):
    """ Distributed Synchronous SGD Example """

    device = torch.device(args.device)

    # (op, args) tuples consumed by get_cnn to build the network
    config = [
        ('conv2d', [3, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 1]),
        ('flatten', ),
        ('linear', [32 * 5 * 5, 5]),
    ]

    train_dataset = l2l.vision.datasets.MiniImagenet(root='./data',
                                                     mode='train')
    # valid_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='test')
    # train_loader = task_loader(train_dataset, args.n_way, args.k_shot, args.k_query, 10000,
    #                            batch_size=args.task_num//args.world_size)
    # test_loader = task_loader(test_dataset, args.n_way, args.k_shot, args.k_query, 1024,
    #                            batch_size=args.task_num//args.world_size)

    net = get_cnn(config)  # to be changed
    model = Meta(update_lr=args.update_lr,
                 meta_lr=args.meta_lr,
                 update_step=args.update_step,
                 update_step_test=args.update_step_test,
                 learner=Learner(net)).to(device)
    average_model(model)
    optimizer = model.meta_optim

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    print('Total trainable tensors:', num)

    # num_batches = ceil(len(train_set) / float(args.batch_size))

    # for epoch in range(args.epoch):
    for epoch in range(args.epoch // 10000):
        epoch_loss = 0.0
        average_model(model)
        train_loader = task_loader(train_dataset,
                                   args.n_way,
                                   args.k_shot,
                                   args.k_query,
                                   10000,
                                   batch_size=args.task_num // args.world_size)
        for step, data in enumerate(train_loader):
            # data = tuple(map(lambda x: slc(to_device(relabel(x), device)), data))
            data = [[x.to(device) for x in collate(a) + collate(b)]
                    for a, b in data]
            optimizer.zero_grad()
            if step * args.task_num % 120 == 0:
                with model.logging:
                    loss = model(data)
                accs = model.accs()
                print('\rRank ', dist.get_rank(), 'step:', step,
                      '\ttraining acc:', accs)
            else:
                loss = model(data)
            loss.backward()
            average_gradients(model)
            optimizer.step()

            # if epoch % 5 == 0:  # evaluation
            if step * args.task_num % 2000 == 0:
                accs_all_test = []
                test_loader = task_loader(test_dataset,
                                          args.n_way,
                                          args.k_shot,
                                          args.k_query,
                                          1024,
                                          batch_size=args.task_num //
                                          args.world_size)
                model.eval()
                for data_test in test_loader:
                    data_test = [[
                        x.to(device) for x in collate(a) + collate(b)
                    ] for a, b in data_test]
                    with model.logging:
                        # data_test = tuple(map(lambda x: slc(to_device(relabel(x), device)), data_test))
                        loss = model(data_test)
                        loss.backward()
                        # accs = model.accs()
                        accs_all_test.append(model.log['corrects'])
                        optimizer.zero_grad()

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Rank ', dist.get_rank(), ', epoch ', epoch, ': ',
                      'Test acc:', accs)
                optimizer.zero_grad()
                del data_test
                model.train()