Code example #1
# Imports assumed by this snippet; OmniglotNet, the custom MNIST dataset,
# and the train() helper come from the surrounding project.
import os

import numpy as np
import torch
from setproctitle import setproctitle
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms


def main(exp, batch_size, lr, gpu):
    setproctitle(exp)
    output = '../output/{}'.format(exp)
    os.makedirs(output, exist_ok=True)  # no error if the directory already exists
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    loss_fn = CrossEntropyLoss()

    net = OmniglotNet(10, loss_fn)
    # NOTE: load weights from pre-trained model
    net.load_state_dict(
        torch.load('../output/maml-mnist-10way-5shot/train_iter_4500.pth'))
    net.cuda()
    opt = Adam(net.parameters(), lr=lr)

    train_dataset = MNIST('../data/mnist/mnist_png',
                          transform=transforms.ToTensor())
    test_dataset = MNIST('../data/mnist/mnist_png',
                         split='test',
                         transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1,
                              pin_memory=True)
    val_loader = DataLoader(test_dataset,  # evaluate on the held-out split
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True)
    num_epochs = 10
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    for epoch in range(num_epochs):
        # train for 1 epoch
        t_loss, t_acc, v_loss, v_acc = train(train_loader, val_loader, net,
                                             loss_fn, opt, epoch)
        train_loss += t_loss
        train_acc += t_acc
        val_loss += v_loss
        val_acc += v_acc
        # persist running metrics after each epoch
        np.save('{}/train_loss.npy'.format(output), np.array(train_loss))
        np.save('{}/train_acc.npy'.format(output), np.array(train_acc))
        np.save('{}/val_loss.npy'.format(output), np.array(val_loss))
        np.save('{}/val_acc.npy'.format(output), np.array(val_acc))
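
The train() helper called above is not shown in this snippet. Below is a minimal sketch of the contract it implies, assuming it runs one epoch over train_loader, evaluates on val_loader, and returns per-batch loss and accuracy lists; the body is an assumption, not the project's actual code.

# Hypothetical sketch of the train() helper assumed by main() above.
def train(train_loader, val_loader, net, loss_fn, opt, epoch):
    t_loss, t_acc, v_loss, v_acc = [], [], [], []
    net.train()
    for in_, target in train_loader:
        in_, target = in_.cuda(), target.cuda()
        out = net(in_)
        loss = loss_fn(out, target)
        opt.zero_grad()
        loss.backward()
        opt.step()
        t_loss.append(loss.item())
        t_acc.append((out.argmax(dim=1) == target).float().mean().item())
    net.eval()
    with torch.no_grad():
        for in_, target in val_loader:
            in_, target = in_.cuda(), target.cuda()
            out = net(in_)
            v_loss.append(loss_fn(out, target).item())
            v_acc.append((out.argmax(dim=1) == target).float().mean().item())
    return t_loss, t_acc, v_loss, v_acc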
Code example #2
File: maml.py  Project: tailintalent/mela
    def test(self):
        num_in_channels = 1 if self.dataset == 'mnist' else 3
        test_net = OmniglotNet(self.num_classes, self.loss_fn, num_in_channels)
        mtr_loss, mtr_acc, mval_loss, mval_acc = 0.0, 0.0, 0.0, 0.0
        # Select ten tasks randomly from the test set to evaluate on
        for _ in range(10):
            # Make a test net with same parameters as our current net
            test_net.copy_weights(self.net)
            if self.is_cuda:
                test_net.cuda()
            test_opt = SGD(test_net.parameters(), lr=self.inner_step_size)
            task = self.get_task('../data/{}'.format(self.dataset),
                                 self.num_classes,
                                 self.num_inst,
                                 split='test')
            # Train on the train examples, using the same number of updates as in training
            train_loader = get_data_loader(task,
                                           self.inner_batch_size,
                                           split='train')
            for i in range(self.num_inner_updates):
                in_, target = next(iter(train_loader))
                loss, _ = forward_pass(test_net,
                                       in_,
                                       target,
                                       is_cuda=self.is_cuda)
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            # Evaluate the trained model on train and val examples
            tloss, tacc = evaluate(test_net, train_loader)
            val_loader = get_data_loader(task,
                                         self.inner_batch_size,
                                         split='val')
            vloss, vacc = evaluate(test_net, val_loader)
            mtr_loss += tloss
            mtr_acc += tacc
            mval_loss += vloss
            mval_acc += vacc

        # Average over the ten sampled tasks
        mtr_loss /= 10
        mtr_acc /= 10
        mval_loss /= 10
        mval_acc /= 10

        print('-------------------------')
        print('Meta train:', mtr_loss, mtr_acc)
        print('Meta val:', mval_loss, mval_acc)
        print('-------------------------')
        del test_net
        return mtr_loss, mtr_acc, mval_loss, mval_acc
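
The test() routine above relies on an evaluate() helper that is not shown. A plausible sketch of what it computes, assuming the loss function is stored on the net and the helper returns average loss and accuracy over a loader (the names and attribute access are assumptions):

# Hypothetical sketch of the evaluate() helper used in test() above.
def evaluate(net, loader):
    net.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for in_, target in loader:
            in_, target = in_.cuda(), target.cuda()
            out = net(in_)
            total_loss += net.loss_fn(out, target).item() * target.size(0)
            total_correct += (out.argmax(dim=1) == target).sum().item()
            total_seen += target.size(0)
    return total_loss / total_seen, total_correct / total_seen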
Code example #3
def count_correct(pred, target):
    ''' count number of correct predictions in a batch '''
    pairs = [int(x == y) for (x, y) in zip(pred, target)]
    return sum(pairs)

def train_step(task):
    train_loader = get_data_loader(task)
    ##### Test net before training, should be random accuracy ####
    print('Before training update', evaluate(net, train_loader))
    for i in range(10):
        loss, _ = forward(net, train_loader)
        print('Loss', loss.data.cpu().numpy())
        opt.zero_grad()
        loss.backward()
        opt.step()
        print('Iter ', i, evaluate(net, train_loader))
    ##### Test net after training, should be better than random ####
    print('After training update', evaluate(net, train_loader))

# Script
for i in range(5):
    print('Run ', i)
    net = OmniglotNet(num_classes, loss_fn=CrossEntropyLoss())
    if torch.cuda.is_available():
        net.cuda()
    opt = SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    #opt = Adam(net.weights.values(), lr=1)
    task = OmniglotTask('/home/vashisht/data/omniglot', num_classes, num_shot)
    train_step(task)
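
The forward() helper used in train_step() is not defined in this snippet. A minimal sketch under the assumption that it draws a single batch from the loader and returns the loss together with the raw network output (every name here is an assumption):

# Hypothetical sketch of the forward() helper used in train_step() above.
def forward(net, loader):
    in_, target = next(iter(loader))
    if torch.cuda.is_available():
        in_, target = in_.cuda(), target.cuda()
    out = net(in_)
    loss = net.loss_fn(out, target)
    return loss, out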

Code example #4
File: maml.py  Project: ml-lab/pytorch-maml
class MetaLearner(object):
    def __init__(self, dataset, num_classes, num_inst, meta_batch_size,
                 meta_step_size, inner_batch_size, inner_step_size,
                 num_updates, num_inner_updates, loss_fn):
        super(MetaLearner, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.num_inst = num_inst
        self.meta_batch_size = meta_batch_size
        self.meta_step_size = meta_step_size
        self.inner_batch_size = inner_batch_size
        self.inner_step_size = inner_step_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.loss_fn = loss_fn

        # Make the nets
        # TODO: don't actually need two nets
        num_input_channels = 1 if self.dataset == 'mnist' else 3
        self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
        self.net.cuda()
        self.fast_net = InnerLoop(num_classes, self.loss_fn,
                                  self.num_inner_updates, self.inner_step_size,
                                  self.inner_batch_size, self.meta_batch_size,
                                  num_input_channels)
        self.fast_net.cuda()
        self.opt = Adam(self.net.parameters())

    def get_task(self, root, n_cl, n_inst, split='train'):
        if 'mnist' in root:
            return MNISTTask(root, n_cl, n_inst, split)
        elif 'omniglot' in root:
            return OmniglotTask(root, n_cl, n_inst, split)
        else:
            raise ValueError('Unknown dataset: {}'.format(root))

    def meta_update(self, task, ls):
        print('\n Meta update \n')
        loader = get_data_loader(task, self.inner_batch_size, split='val')
        in_, target = next(iter(loader))
        # We use a dummy forward / backward pass to get the correct grads into self.net
        loss, out = forward_pass(self.net, in_, target)
        # Unpack the list of grad dicts
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for (k, v) in self.net.named_parameters():
            def get_closure():
                # bind the current parameter name so each hook looks up its own grad
                key = k

                def replace_grad(grad):
                    return gradients[key]

                return replace_grad

            hooks.append(v.register_hook(get_closure()))
        # Compute grads for current step, replace with summed gradients as defined by hook
        self.opt.zero_grad()
        loss.backward()
        # Update the net parameters with the accumulated gradient according to optimizer
        self.opt.step()
        # Remove the hooks before next training phase
        for h in hooks:
            h.remove()

    def test(self):
        num_in_channels = 1 if self.dataset == 'mnist' else 3
        test_net = OmniglotNet(self.num_classes, self.loss_fn, num_in_channels)
        mtr_loss, mtr_acc, mval_loss, mval_acc = 0.0, 0.0, 0.0, 0.0
        # Select ten tasks randomly from the test set to evaluate on
        for _ in range(10):
            # Make a test net with same parameters as our current net
            test_net.copy_weights(self.net)
            test_net.cuda()
            test_opt = SGD(test_net.parameters(), lr=self.inner_step_size)
            task = self.get_task('../data/{}'.format(self.dataset),
                                 self.num_classes,
                                 self.num_inst,
                                 split='test')
            # Train on the train examples, using the same number of updates as in training
            train_loader = get_data_loader(task,
                                           self.inner_batch_size,
                                           split='train')
            for i in range(self.num_inner_updates):
                in_, target = next(iter(train_loader))
                loss, _ = forward_pass(test_net, in_, target)
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            # Evaluate the trained model on train and val examples
            tloss, tacc = evaluate(test_net, train_loader)
            val_loader = get_data_loader(task,
                                         self.inner_batch_size,
                                         split='val')
            vloss, vacc = evaluate(test_net, val_loader)
            mtr_loss += tloss
            mtr_acc += tacc
            mval_loss += vloss
            mval_acc += vacc

        # Average over the ten sampled tasks
        mtr_loss /= 10
        mtr_acc /= 10
        mval_loss /= 10
        mval_acc /= 10

        print('-------------------------')
        print('Meta train:', mtr_loss, mtr_acc)
        print('Meta val:', mval_loss, mval_acc)
        print('-------------------------')
        del test_net
        return mtr_loss, mtr_acc, mval_loss, mval_acc

    def _train(self, exp):
        ''' debugging function: learn two tasks '''
        task1 = self.get_task('../data/{}'.format(self.dataset),
                              self.num_classes, self.num_inst)
        task2 = self.get_task('../data/{}'.format(self.dataset),
                              self.num_classes, self.num_inst)
        for it in range(self.num_updates):
            grads = []
            for task in [task1, task2]:
                # Make sure fast net always starts with base weights
                self.fast_net.copy_weights(self.net)
                _, g = self.fast_net.forward(task)
                grads.append(g)
            self.meta_update(task, grads)

    def train(self, exp):
        tr_loss, tr_acc, val_loss, val_acc = [], [], [], []
        mtr_loss, mtr_acc, mval_loss, mval_acc = [], [], [], []
        for it in range(self.num_updates):
            # Evaluate on test tasks
            mt_loss, mt_acc, mv_loss, mv_acc = self.test()
            mtr_loss.append(mt_loss)
            mtr_acc.append(mt_acc)
            mval_loss.append(mv_loss)
            mval_acc.append(mv_acc)
            # Collect a meta batch update
            grads = []
            tloss, tacc, vloss, vacc = 0.0, 0.0, 0.0, 0.0
            for i in range(self.meta_batch_size):
                task = self.get_task('../data/{}'.format(self.dataset),
                                     self.num_classes, self.num_inst)
                self.fast_net.copy_weights(self.net)
                metrics, g = self.fast_net.forward(task)
                (trl, tra, vall, vala) = metrics
                grads.append(g)
                tloss += trl
                tacc += tra
                vloss += vall
                vacc += vala

            # Perform the meta update
            print('Meta update', it)
            self.meta_update(task, grads)

            # Save a model snapshot every now and then
            if it % 500 == 0:
                torch.save(self.net.state_dict(),
                           '../output/{}/train_iter_{}.pth'.format(exp, it))

            # Record metrics averaged over the meta-batch
            tr_loss.append(tloss / self.meta_batch_size)
            tr_acc.append(tacc / self.meta_batch_size)
            val_loss.append(vloss / self.meta_batch_size)
            val_acc.append(vacc / self.meta_batch_size)

            np.save('../output/{}/tr_loss.npy'.format(exp), np.array(tr_loss))
            np.save('../output/{}/tr_acc.npy'.format(exp), np.array(tr_acc))
            np.save('../output/{}/val_loss.npy'.format(exp),
                    np.array(val_loss))
            np.save('../output/{}/val_acc.npy'.format(exp), np.array(val_acc))

            np.save('../output/{}/meta_tr_loss.npy'.format(exp),
                    np.array(mtr_loss))
            np.save('../output/{}/meta_tr_acc.npy'.format(exp),
                    np.array(mtr_acc))
            np.save('../output/{}/meta_val_loss.npy'.format(exp),
                    np.array(mval_loss))
            np.save('../output/{}/meta_val_acc.npy'.format(exp),
                    np.array(mval_acc))
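
The gradient-replacement trick in meta_update() is the crux of this implementation: a dummy forward/backward pass populates each parameter's .grad, and per-parameter hooks swap in the gradients summed across the meta-batch before the optimizer step. The self-contained toy below illustrates the register_hook semantics it relies on; it is unrelated to the project's own classes.

import torch

# A hook registered on a tensor receives the gradient during backward();
# if it returns a tensor, that tensor replaces the computed gradient.
w = torch.nn.Parameter(torch.ones(3))
precomputed = torch.tensor([10.0, 20.0, 30.0])  # stands in for the summed meta-batch grads
h = w.register_hook(lambda grad: precomputed)
loss = (w * 2).sum()  # dummy pass; the true gradient would be [2., 2., 2.]
loss.backward()
print(w.grad)         # tensor([10., 20., 30.]) -- the replaced gradient
h.remove()            # remove the hook before the next phase, as meta_update() does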