예제 #1
0
    def __init__(self, dataset, num_classes, num_inst, meta_batch_size,
                 meta_step_size, inner_batch_size, inner_step_size,
                 num_updates, num_inner_updates, loss_fn):
        """Store the MAML hyper-parameters, build both nets on the GPU and
        create the Adam meta optimizer.

        Args:
            dataset: dataset name; ``'mnist'`` means 1 input channel, any
                other value means 3.
            num_classes: passed as the first argument to both nets.
            num_inst: stored as ``self.num_inst`` (presumably instances per
                class when sampling tasks -- confirm against the caller).
            meta_batch_size: stored and forwarded to ``InnerLoop``.
            meta_step_size: learning rate of the Adam meta optimizer.
            inner_batch_size: forwarded to ``InnerLoop``.
            inner_step_size: forwarded to ``InnerLoop``.
            num_updates: stored as ``self.num_updates``.
            num_inner_updates: forwarded to ``InnerLoop``.
            loss_fn: loss shared by ``net`` and ``fast_net``.
        """
        # NOTE(review): super(self.__class__, self) recurses forever if this
        # class is ever subclassed; consider naming the class explicitly.
        super(self.__class__, self).__init__()

        # Hyper-parameters, kept on the instance unchanged.
        self.dataset, self.num_classes, self.num_inst = dataset, num_classes, num_inst
        self.meta_batch_size, self.meta_step_size = meta_batch_size, meta_step_size
        self.inner_batch_size, self.inner_step_size = inner_batch_size, inner_step_size
        self.num_updates, self.num_inner_updates = num_updates, num_inner_updates
        self.loss_fn = loss_fn

        if self.dataset == 'mnist':
            in_channels = 1
        else:
            in_channels = 3

        # Two nets: `net` holds the meta parameters, `fast_net` runs the
        # per-task inner loop. TODO: a single net would suffice.
        self.net = OmniglotNet(num_classes, self.loss_fn, in_channels)
        self.net.cuda()
        self.fast_net = InnerLoop(num_classes, self.loss_fn,
                                  self.num_inner_updates, self.inner_step_size,
                                  self.inner_batch_size, self.meta_batch_size,
                                  in_channels)
        self.fast_net.cuda()

        # The meta optimizer only updates `net`, never `fast_net`.
        self.opt = Adam(self.net.parameters(), lr=meta_step_size)
예제 #2
0
파일: maml.py 프로젝트: ml-lab/pytorch-maml
    def __init__(self, dataset, num_classes, num_inst, meta_batch_size,
                 meta_step_size, inner_batch_size, inner_step_size,
                 num_updates, num_inner_updates, loss_fn):
        """Store the MAML hyper-parameters, build both nets on the GPU and
        create the Adam meta optimizer.

        Args:
            dataset: dataset name; ``'mnist'`` means 1 input channel, any
                other value means 3.
            num_classes: passed as the first argument to both nets.
            num_inst: stored as ``self.num_inst`` (presumably instances per
                class when sampling tasks -- confirm against the caller).
            meta_batch_size: stored and forwarded to ``InnerLoop``.
            meta_step_size: learning rate of the Adam meta optimizer.
            inner_batch_size: forwarded to ``InnerLoop``.
            inner_step_size: forwarded to ``InnerLoop``.
            num_updates: stored as ``self.num_updates``.
            num_inner_updates: forwarded to ``InnerLoop``.
            loss_fn: loss shared by ``net`` and ``fast_net``.
        """
        # NOTE(review): super(self.__class__, self) recurses forever if this
        # class is ever subclassed; consider naming the class explicitly.
        super(self.__class__, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.num_inst = num_inst
        self.meta_batch_size = meta_batch_size
        self.meta_step_size = meta_step_size
        self.inner_batch_size = inner_batch_size
        self.inner_step_size = inner_step_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.loss_fn = loss_fn

        # Make the nets: `net` holds the meta parameters, `fast_net` runs the
        # per-task inner loop.
        # TODO: don't actually need two nets
        num_input_channels = 1 if self.dataset == 'mnist' else 3
        self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
        self.net.cuda()
        self.fast_net = InnerLoop(num_classes, self.loss_fn,
                                  self.num_inner_updates, self.inner_step_size,
                                  self.inner_batch_size, self.meta_batch_size,
                                  num_input_channels)
        self.fast_net.cuda()
        # Optimizes only `net`. The meta step size must be passed explicitly:
        # without `lr=` Adam silently uses its own default learning rate and
        # the `meta_step_size` argument has no effect.
        self.opt = Adam(self.net.parameters(), lr=meta_step_size)
    def Evaluate(self, individual):
        """
        Adapted inner loop of Model Agnostic Meta-Learning to be Baldwinian
        a la Fernando et al. 2018.
        Source: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py

        The outer loop is completed by NES over generations. For each of the
        ``meta_batch_size`` tasks in ``self.tasks`` (keyed ``'0'``, ``'1'``,
        ...), the individual's network weights are copied into one inner-loop
        learner, which is then run on that task. The last metric of each run is
        added to ``individual['fitness']``. All four metrics are stored under
        ``'t<i>_tr_loss'``, ``'t<i>_tr_acc'``, ``'t<i>_val_loss'`` and
        ``'t<i>_val_acc'``. Finally ``self.Record_Performance`` is called.

        Args:
            individual: dict-like holding at least ``'network'`` and
                ``'fitness'``; it is updated in place.
        """
        # NOTE(review): passes self.num_updates (not num_inner_updates) as the
        # learner's update count -- confirm this is intended.
        learner = InnerLoop(self.num_classes, self.loss_fn, self.num_updates,
                            self.inner_step_size, self.inner_batch_size,
                            self.meta_batch_size, self.num_input_channels)
        suffixes = ("_tr_loss", "_tr_acc", "_val_loss", "_val_acc")

        for task_idx in range(self.meta_batch_size):
            # Every task starts again from the individual's evolved weights.
            learner.copy_weights(individual['network'])
            results = learner.forward(self.tasks[str(task_idx)])

            # Metrics are (tr_loss, tr_acc, val_loss, val_acc); fitness uses
            # the last one. NOTE(review): fitness is accumulated, never reset
            # here -- the caller presumably zeroes it beforehand.
            individual['fitness'] += results[-1]

            prefix = "t" + str(task_idx)
            for pos, suffix in enumerate(suffixes):
                individual[prefix + suffix] = results[pos]

        self.Record_Performance(individual)
예제 #4
0
    def __init__(self,
                 meta_dataset,
                 fs_dataset,
                 K,
                 meta_lr,
                 inner_lr,
                 layer,
                 hidden,
                 tissue_num,
                 meta_batch_size,
                 inner_batch_size,
                 num_updates,
                 num_inner_updates,
                 tissue_index_list,
                 patience=3,
                 num_trials=10):
        """Store hyper-parameters and build the meta model, its Adam optimizer
        and the inner-loop learner, then move both models to the GPU.

        Args:
            meta_dataset: stored as ``self.meta_dataset``.
            fs_dataset: needs a ``.feature`` array whose ``shape[1]`` becomes
                ``self.feature_num`` (the model input size).
            K, tissue_num, tissue_index_list, meta_batch_size,
            inner_batch_size, num_updates, patience, num_trials: stored as
                attributes of the same name.
            meta_lr: Adam learning rate of ``observed_tissue_model``.
            inner_lr: forwarded to ``InnerLoop``.
            layer, hidden: architecture arguments for ``mlp`` and ``InnerLoop``.
            num_inner_updates: forwarded to ``InnerLoop``.
        """
        # NOTE(review): super(self.__class__, self) recurses forever if this
        # class is ever subclassed; consider naming the class explicitly.
        super(self.__class__, self).__init__()

        # Datasets and the input width derived from the few-shot features.
        self.meta_dataset = meta_dataset
        self.fs_dataset = fs_dataset
        self.feature_num = fs_dataset.feature.shape[1]

        # Training schedule.
        self.meta_batch_size, self.inner_batch_size = meta_batch_size, inner_batch_size
        self.num_updates, self.num_inner_updates = num_updates, num_inner_updates
        self.num_trials, self.patience = num_trials, patience

        # Model / optimisation hyper-parameters.
        self.K = K
        self.meta_lr, self.inner_lr = meta_lr, inner_lr
        self.layer, self.hidden = layer, hidden
        self.tissue_index_list, self.tissue_num = tissue_index_list, tissue_num

        self.observed_tissue_model = mlp(self.feature_num, layer, hidden)
        self.observed_opt = torch.optim.Adam(self.observed_tissue_model.parameters(),
                                             lr=self.meta_lr,
                                             betas=(0.9, 0.99),
                                             eps=1e-05)
        self.inner_net = InnerLoop(self.num_inner_updates, self.inner_lr,
                                   self.feature_num, layer, hidden)

        for module in (self.observed_tissue_model, self.inner_net):
            module.cuda()
예제 #5
0
    def __init__(self, log, tb_writer, args):
        """Create the meta-learner: both nets on ``device``, the Adam meta
        optimizer, the task sampler and the replay memory.

        Args:
            log: stored as ``self.log``.
            tb_writer: stored as ``self.tb_writer`` (presumably a TensorBoard
                writer).
            args: options object; ``args.meta_lr`` is the Adam learning rate,
                and the whole object is handed to the nets and ``BatchSampler``.
        """
        # NOTE(review): super(self.__class__, self) recurses forever if this
        # class is ever subclassed; consider naming the class explicitly.
        super(self.__class__, self).__init__()
        self.log, self.tb_writer, self.args = log, tb_writer, args
        self.loss_fn = MSELoss()

        # Meta (slow) net and per-task inner-loop (fast) net.
        self.net = OmniglotNet(self.loss_fn, args).to(device)
        self.fast_net = InnerLoop(self.loss_fn, args).to(device)

        # Adam updates the meta net only.
        self.opt = Adam(self.net.parameters(), lr=args.meta_lr)
        self.sampler = BatchSampler(args)
        self.memory = ReplayBuffer()
예제 #6
0
def Evaluate(individual, num_tasks=10, dataset='omniglot', num_classes=5,
             num_inst=10):
    """
    Adapted inner loop of Model Agnostic Meta-Learning to be Baldwinian
    a la Fernando et al. 2018.
    Source: https://github.com/katerakelly/pytorch-maml/blob/master/src/maml.py

    The outer loop is completed by NES over generations. This function samples
    ``num_tasks`` fresh tasks from ``../data/<dataset>``. For each task, the
    individual's network weights are copied into an inner-loop learner, the
    learner is run on the task, and the resulting metrics are printed. The
    adapted weights are never written back to ``individual``.

    Args:
        individual: dict-like holding the evolved network under ``'network'``.
        num_tasks: number of tasks to sample and evaluate (default 10).
        dataset: dataset directory name under ``../data`` (default
            ``'omniglot'``).
        num_classes: classes per task; also the learner's first argument
            (default 5).
        num_inst: third argument to ``Get_Task`` (default 10).

    Returns:
        list: the metrics returned by the inner loop for each task, in order.
    """
    # Remaining InnerLoop arguments are positional as before:
    # (loss, 3, 0.01, 100, 10, 3). NOTE(review): by analogy with other
    # InnerLoop call sites these are presumably inner updates, inner step
    # size, inner batch size, meta batch size and input channels -- confirm.
    inner_net = InnerLoop(num_classes, CrossEntropyLoss(), 3, 0.01, 100, 10, 3)
    task_root = "../data/{}".format(dataset)

    all_metrics = []
    for _ in range(num_tasks):
        task = Get_Task(task_root, num_classes, num_inst)

        # Each task starts from the individual's weights.
        inner_net.copy_weights(individual['network'])
        metrics = inner_net.forward(task)

        # Metrics are presumably (tr_loss, tr_acc, val_loss, val_acc); fitness
        # should use validation accuracy, which is not wired in here yet.
        print(metrics)
        all_metrics.append(metrics)

    return all_metrics
예제 #7
0
파일: mpm.py 프로젝트: lnhust/MPM
    def __init__(self,
                 num_way,
                 num_shot,
                 classifier_param,
                 loss_fn,
                 inner_step_size,
                 num_inner_updates,
                 has_bias=False,
                 alpha=None,
                 num_layer=1):
        """Build the conv embedding net and the inner-loop test net (both on
        the GPU) and keep them, plus the classifier, in ``self.model``.

        Args:
            num_way: stored as ``self.num_way`` (presumably classes per
                episode).
            num_shot: stored as ``self.num_shot`` (presumably examples per
                class).
            classifier_param: stored as ``self.model['classifier_net']`` and
                forwarded to ``InnerLoop``.
            loss_fn: forwarded to ``InnerLoop``.
            inner_step_size: forwarded to ``InnerLoop``.
            num_inner_updates: forwarded to ``InnerLoop``.
            has_bias: forwarded to ``InnerLoop``.
            alpha: forwarded to ``InnerLoop`` and stored as ``self.alpha``.
                ``None`` (the default) means ``[5, 1, 1]``. A fresh list is
                built per instance, so mutating one instance's ``alpha`` can
                no longer leak into other instances through a shared default.
            num_layer: forwarded to ``InnerLoop``.
        """
        # NOTE(review): super(self.__class__, self) recurses forever if this
        # class is ever subclassed; consider naming the class explicitly.
        super(self.__class__, self).__init__()
        if alpha is None:
            alpha = [5, 1, 1]
        self.num_way = num_way
        self.num_shot = num_shot
        self.num_inner_updates = num_inner_updates
        self.inner_step_size = inner_step_size
        self.alpha = alpha

        # Conv embedding network configuration.
        net_param_opt = {
            'userelu': False,
            'in_planes': 3,
            'out_planes': [64, 64, 64, 64],
            'num_stages': 4,
            'num_dim': 64 * 5 * 5,
            'num_Kall': 64
        }
        num_dim = net_param_opt['num_dim']

        self.model = {}
        self.model['embedding_net'] = ConvNet(net_param_opt)
        self.model['embedding_net'].cuda()

        self.model['classifier_net'] = classifier_param

        # NOTE(review): the second argument is hard-coded to 5 although
        # `num_way` is available; confirm whether it should be num_way.
        self.model['test_net'] = InnerLoop(num_dim, 5, loss_fn,
                                           self.num_inner_updates,
                                           self.inner_step_size, 1,
                                           classifier_param, has_bias,
                                           self.alpha, num_layer)
        self.model['test_net'].cuda()
예제 #8
0
파일: maml.py 프로젝트: ml-lab/pytorch-maml
class MetaLearner(object):
    """MAML meta-learner: the meta net ``net`` is updated with Adam from
    gradients that the inner-loop net ``fast_net`` produces on sampled tasks.
    """

    def __init__(self, dataset, num_classes, num_inst, meta_batch_size,
                 meta_step_size, inner_batch_size, inner_step_size,
                 num_updates, num_inner_updates, loss_fn):
        """Store hyper-parameters, build both nets on the GPU and create the
        Adam meta optimizer.

        Args:
            dataset: dataset name; used in ``../data/<dataset>`` task paths and
                ``'mnist'`` selects 1 input channel (3 otherwise).
            num_classes: classes per task / net outputs.
            num_inst: ``n_inst`` argument used when sampling tasks.
            meta_batch_size: tasks per meta update.
            meta_step_size: learning rate of the Adam meta optimizer.
            inner_batch_size: batch size of the data loaders.
            inner_step_size: inner-loop (and test fine-tuning) learning rate.
            num_updates: number of meta-training iterations.
            num_inner_updates: inner-loop gradient steps per task.
            loss_fn: loss shared by the nets.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MetaLearner is subclassed.
        super(MetaLearner, self).__init__()
        self.dataset = dataset
        self.num_classes = num_classes
        self.num_inst = num_inst
        self.meta_batch_size = meta_batch_size
        self.meta_step_size = meta_step_size
        self.inner_batch_size = inner_batch_size
        self.inner_step_size = inner_step_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.loss_fn = loss_fn

        # Make the nets
        # TODO: don't actually need two nets
        num_input_channels = 1 if self.dataset == 'mnist' else 3
        self.net = OmniglotNet(num_classes, self.loss_fn, num_input_channels)
        self.net.cuda()
        self.fast_net = InnerLoop(num_classes, self.loss_fn,
                                  self.num_inner_updates, self.inner_step_size,
                                  self.inner_batch_size, self.meta_batch_size,
                                  num_input_channels)
        self.fast_net.cuda()
        # Only `net` is meta-optimized. `lr=meta_step_size` is required:
        # without it the argument is ignored and Adam uses its default lr.
        self.opt = Adam(self.net.parameters(), lr=meta_step_size)

    def get_task(self, root, n_cl, n_inst, split='train'):
        """Return a task for the dataset named in ``root``.

        Raises:
            ValueError: if ``root`` contains neither 'mnist' nor 'omniglot'.
        """
        if 'mnist' in root:
            return MNISTTask(root, n_cl, n_inst, split)
        if 'omniglot' in root:
            return OmniglotTask(root, n_cl, n_inst, split)
        print('Unknown dataset')
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers still catch it; the message now names the bad root.
        raise ValueError('Unknown dataset: {}'.format(root))

    def meta_update(self, task, ls):
        """Apply one Adam step to ``self.net`` using gradients summed over ``ls``.

        Args:
            task: task whose 'val' split supplies the dummy forward batch.
            ls: list of dicts mapping parameter name -> gradient, one per task.
        """
        print('\n Meta update \n')
        loader = get_data_loader(task, self.inner_batch_size, split='val')
        in_, target = next(iter(loader))
        # A dummy forward/backward pass routes gradients into self.net; the
        # hooks below swap them for the meta-batch gradients.
        loss, _ = forward_pass(self.net, in_, target)
        # Sum the per-task gradient dicts key by key.
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}

        def make_hook(key):
            # Bind `key` now so each hook returns its own parameter's grad.
            def replace_grad(grad):
                return gradients[key]
            return replace_grad

        hooks = []
        try:
            for k, v in self.net.named_parameters():
                hooks.append(v.register_hook(make_hook(k)))
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        finally:
            # Always remove the hooks; a leftover hook would silently replace
            # the gradients of every later backward pass.
            for h in hooks:
                h.remove()

    def test(self, num_tasks=10):
        """Fine-tune copies of the meta net on random test tasks and report
        mean train/val loss and accuracy.

        For each task, a copy of ``self.net`` takes ``num_inner_updates`` SGD
        steps on the task's 'train' split, then is evaluated on its 'train'
        and 'val' splits.

        Args:
            num_tasks: number of test tasks to average over (default 10).

        Returns:
            tuple: mean (train loss, train acc, val loss, val acc).

        Raises:
            ValueError: if ``num_tasks`` is smaller than 1.
        """
        if num_tasks < 1:
            raise ValueError('num_tasks must be >= 1, got {}'.format(num_tasks))
        num_in_channels = 1 if self.dataset == 'mnist' else 3
        test_net = OmniglotNet(self.num_classes, self.loss_fn, num_in_channels)
        data_root = '../data/{}'.format(self.dataset)
        mtr_loss, mtr_acc, mval_loss, mval_acc = 0.0, 0.0, 0.0, 0.0
        for _ in range(num_tasks):
            # Reset the test net to the current meta weights for every task.
            test_net.copy_weights(self.net)
            test_net.cuda()
            test_opt = SGD(test_net.parameters(), lr=self.inner_step_size)
            task = self.get_task(data_root, self.num_classes, self.num_inst,
                                 split='test')
            # Train with the same number of updates as the inner loop.
            train_loader = get_data_loader(task, self.inner_batch_size,
                                           split='train')
            for _ in range(self.num_inner_updates):
                in_, target = next(iter(train_loader))
                loss, _ = forward_pass(test_net, in_, target)
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            tloss, tacc = evaluate(test_net, train_loader)
            val_loader = get_data_loader(task, self.inner_batch_size,
                                         split='val')
            vloss, vacc = evaluate(test_net, val_loader)
            mtr_loss += tloss
            mtr_acc += tacc
            mval_loss += vloss
            mval_acc += vacc

        mtr_loss = mtr_loss / num_tasks
        mtr_acc = mtr_acc / num_tasks
        mval_loss = mval_loss / num_tasks
        mval_acc = mval_acc / num_tasks

        print('-------------------------')
        print('Meta train: {} {}'.format(mtr_loss, mtr_acc))
        print('Meta val: {} {}'.format(mval_loss, mval_acc))
        print('-------------------------')
        del test_net
        return mtr_loss, mtr_acc, mval_loss, mval_acc

    def _train(self, exp):
        """Debugging helper: meta-train repeatedly on the same two tasks.

        ``exp`` is accepted for parity with ``train`` and is unused.
        """
        data_root = '../data/{}'.format(self.dataset)
        task1 = self.get_task(data_root, self.num_classes, self.num_inst)
        task2 = self.get_task(data_root, self.num_classes, self.num_inst)
        for _ in range(self.num_updates):
            grads = []
            for task in (task1, task2):
                # Make sure fast net always starts with base weights
                self.fast_net.copy_weights(self.net)
                _, g = self.fast_net.forward(task)
                grads.append(g)
            self.meta_update(task2, grads)

    def train(self, exp):
        """Run meta-training and save learning curves under ``../output/<exp>``.

        Each iteration evaluates on test tasks, collects a meta-batch of inner
        loop gradients, applies a meta update, snapshots the weights every 500
        iterations, and re-saves every curve as ``.npy``.
        """
        tr_loss, tr_acc, val_loss, val_acc = [], [], [], []
        mtr_loss, mtr_acc, mval_loss, mval_acc = [], [], [], []
        data_root = '../data/{}'.format(self.dataset)
        for it in range(self.num_updates):
            # Evaluate on test tasks
            mt_loss, mt_acc, mv_loss, mv_acc = self.test()
            mtr_loss.append(mt_loss)
            mtr_acc.append(mt_acc)
            mval_loss.append(mv_loss)
            mval_acc.append(mv_acc)
            # Collect a meta batch update
            grads = []
            tloss, tacc, vloss, vacc = 0.0, 0.0, 0.0, 0.0
            for _ in range(self.meta_batch_size):
                task = self.get_task(data_root, self.num_classes, self.num_inst)
                self.fast_net.copy_weights(self.net)
                metrics, g = self.fast_net.forward(task)
                trl, tra, vall, vala = metrics
                grads.append(g)
                tloss += trl
                tacc += tra
                vloss += vall
                vacc += vala

            # Perform the meta update (the last sampled task drives the
            # dummy forward pass).
            print('Meta update {}'.format(it))
            self.meta_update(task, grads)

            # Save a model snapshot every now and then
            if it % 500 == 0:
                torch.save(self.net.state_dict(),
                           '../output/{}/train_iter_{}.pth'.format(exp, it))

            tr_loss.append(tloss / self.meta_batch_size)
            tr_acc.append(tacc / self.meta_batch_size)
            val_loss.append(vloss / self.meta_batch_size)
            val_acc.append(vacc / self.meta_batch_size)

            curves = (('tr_loss', tr_loss), ('tr_acc', tr_acc),
                      ('val_loss', val_loss), ('val_acc', val_acc),
                      ('meta_tr_loss', mtr_loss), ('meta_tr_acc', mtr_acc),
                      ('meta_val_loss', mval_loss), ('meta_val_acc', mval_acc))
            for name, values in curves:
                np.save('../output/{}/{}.npy'.format(exp, name),
                        np.array(values))
예제 #9
0
class MetaLearner(object):
    """Tissue meta-learner: an MLP meta model trained with MAML-style meta
    updates on observed tissues and scored by few-shot fine-tuning on an
    unseen tissue.
    """

    def __init__(self,
                 meta_dataset,
                 fs_dataset,
                 K,
                 meta_lr,
                 inner_lr,
                 layer,
                 hidden,
                 tissue_num,
                 meta_batch_size,
                 inner_batch_size,
                 num_updates,
                 num_inner_updates,
                 tissue_index_list,
                 patience=3,
                 num_trials=10):
        """Store hyper-parameters and build the meta model, its Adam optimizer
        and the inner-loop learner (both moved to the GPU).

        Args:
            meta_dataset: observed-tissue data with ``.feature``/``.label``.
            fs_dataset: unseen-tissue data with ``.feature``/``.label``;
                ``feature.shape[1]`` becomes the model input size.
            K: ``K`` argument of the data-loader helpers.
            meta_lr: Adam learning rate of the meta model.
            inner_lr: learning rate of the inner loop and of fine-tuning.
            layer, hidden: architecture arguments for ``mlp``/``InnerLoop``.
            tissue_num, tissue_index_list: passed to
                ``get_observed_data_loader``.
            meta_batch_size: tasks per meta update.
            inner_batch_size: batch size of observed-tissue loaders.
            num_updates: number of meta-training epochs.
            num_inner_updates: gradient steps per inner loop / fine-tuning.
            patience, num_trials: stored only (unused in this class).
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MetaLearner is subclassed.
        super(MetaLearner, self).__init__()

        self.meta_dataset = meta_dataset
        self.fs_dataset = fs_dataset

        self.meta_batch_size = meta_batch_size
        self.inner_batch_size = inner_batch_size
        self.num_updates = num_updates
        self.num_inner_updates = num_inner_updates
        self.num_trials = num_trials
        self.patience = patience
        self.feature_num = self.fs_dataset.feature.shape[1]

        self.K = K
        self.meta_lr = meta_lr
        self.inner_lr = inner_lr
        self.layer = layer
        self.hidden = hidden
        self.tissue_index_list = tissue_index_list
        self.tissue_num = tissue_num

        self.observed_tissue_model = mlp(self.feature_num, layer, hidden)
        self.observed_opt = torch.optim.Adam(
            self.observed_tissue_model.parameters(),
            lr=self.meta_lr,
            betas=(0.9, 0.99),
            eps=1e-05)
        self.inner_net = InnerLoop(self.num_inner_updates, self.inner_lr,
                                   self.feature_num, layer, hidden)

        self.observed_tissue_model.cuda()
        self.inner_net.cuda()

    def zero_shot_test(self, unseen_train_loader, unseen_vali_loader,
                       unseen_test_loader):
        """Evaluate a copy of the meta model, without fine-tuning, on the
        unseen tissue's train / validation / test loaders.

        Returns:
            tuple: (train, validation, test) results of ``evaluate_cv`` plus a
            fourth element that is always NaN. The code previously returned
            ``np.mean(tissue_loss)``, but ``tissue_loss`` was never defined, so
            every call raised NameError. NaN keeps the 4-tuple shape callers
            unpack.
        """
        unseen_tissue_model = mlp(self.feature_num, self.layer, self.hidden)

        # First need to copy the original meta learning model
        unseen_tissue_model.copy_weights(self.observed_tissue_model)
        unseen_tissue_model.cuda()
        unseen_tissue_model.eval()

        train_performance = evaluate_cv(unseen_tissue_model,
                                        unseen_train_loader)
        vali_performance = evaluate_cv(unseen_tissue_model, unseen_vali_loader)
        test_performance = evaluate_cv(unseen_tissue_model, unseen_test_loader)

        return train_performance, vali_performance, test_performance, float(
            'nan')

    def meta_update(self, test_loader, ls):
        """Apply one Adam step to the meta model using gradients summed over ``ls``.

        Args:
            test_loader: loader whose first batch drives the dummy forward pass.
            ls: list of dicts mapping parameter name -> gradient, one per task.
        """
        in_, target = next(iter(test_loader))

        # A dummy forward/backward pass routes gradients into the meta model;
        # the hooks below swap them for the meta-batch gradients.
        loss, _ = forward_pass(self.observed_tissue_model, in_, target)

        # Sum the per-task gradient dicts key by key.
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}

        def make_hook(key):
            # Bind `key` now so each hook returns its own parameter's grad.
            def replace_grad(grad):
                return gradients[key]
            return replace_grad

        hooks = []
        try:
            for k, v in self.observed_tissue_model.named_parameters():
                hooks.append(v.register_hook(make_hook(k)))
            self.observed_opt.zero_grad()
            loss.backward()
            self.observed_opt.step()
        finally:
            # Always remove the hooks; a leftover hook would silently replace
            # the gradients of every later backward pass.
            for h in hooks:
                h.remove()

    def unseen_tissue_learn(self, unseen_train_loader, unseen_test_loader):
        """Fine-tune a copy of the meta model on the unseen tissue with SGD for
        ``num_inner_updates`` steps, then evaluate it.

        Returns:
            tuple: (train loss, train Pearson, train Spearman, test loss,
            test Pearson, test Spearman, test predictions, test labels), as
            produced by ``evaluate_new``.
        """
        unseen_tissue_model = mlp(self.feature_num, self.layer, self.hidden)

        # First need to copy the original meta learning model
        unseen_tissue_model.copy_weights(self.observed_tissue_model)
        unseen_tissue_model.cuda()
        # NOTE(review): the model stays in eval mode while being fine-tuned;
        # confirm this is intended (e.g. to freeze dropout behaviour).
        unseen_tissue_model.eval()

        unseen_opt = torch.optim.SGD(unseen_tissue_model.parameters(),
                                     lr=self.inner_lr)

        for _ in range(self.num_inner_updates):
            in_, target = next(iter(unseen_train_loader))
            loss, _ = forward_pass(unseen_tissue_model, in_, target)
            unseen_opt.zero_grad()
            loss.backward()
            unseen_opt.step()

        # Test on the rest of cell lines in this tissue (unseen_test_loader)
        mtrain_loss, mtrain_pear_corr, mtrain_spearman_corr, _, _ = evaluate_new(
            unseen_tissue_model, unseen_train_loader, 1)
        mtest_loss, mtest_pear_corr, mtest_spearman_corr, test_prediction, test_true_label = evaluate_new(
            unseen_tissue_model, unseen_test_loader, 0)

        return (mtrain_loss, mtrain_pear_corr, mtrain_spearman_corr,
                mtest_loss, mtest_pear_corr, mtest_spearman_corr,
                test_prediction, test_true_label)

    def train(self):
        """Meta-train for ``num_updates`` epochs, scoring few-shot adaptation to
        the unseen tissue after every epoch.

        Returns:
            tuple: test Pearson at the epoch with the lowest few-shot train
            loss; test Pearson and test Spearman at the epoch with the highest
            train Pearson; test Spearman at the epoch with the highest train
            Spearman; and a deep copy of the meta model taken at the
            highest-train-Pearson epoch (``''`` if no epoch beat the initial
            value).
        """
        import copy

        def r3(x):
            # Round to 3 decimals for logging.
            return float('%.3f' % x)

        best_train_corr_model = ''

        unseen_train_loader, unseen_test_loader = get_unseen_data_loader(
            self.fs_dataset.feature, self.fs_dataset.label, self.K)

        # Best-so-far few-shot training scores and the epochs they occurred at.
        best_fewshot_train_corr, best_fewshot_train_loss = -2, 1000
        best_fewshot_train_spearman_corr = -2
        best_train_loss_epoch, best_train_corr_epoch = 0, 0
        best_train_spearman_corr_epoch = 0

        train_loss = np.zeros((self.num_updates, ))
        train_corr = np.zeros((self.num_updates, ))
        train_spearman_corr = np.zeros((self.num_updates, ))
        test_loss = np.zeros((self.num_updates, ))
        test_corr = np.zeros((self.num_updates, ))
        test_spearman_corr = np.zeros((self.num_updates, ))

        for epoch in range(self.num_updates):

            # Collect a meta batch update
            grads = []
            meta_train_loss = np.zeros((self.meta_batch_size, ))
            meta_train_corr = np.zeros((self.meta_batch_size, ))
            meta_val_loss = np.zeros((self.meta_batch_size, ))
            meta_val_corr = np.zeros((self.meta_batch_size, ))

            # NOTE(review): weights are copied once per meta-batch, not per
            # task; this presumes inner_net.forward leaves its own parameters
            # untouched -- confirm.
            self.inner_net.copy_weights(self.observed_tissue_model)
            for i in range(self.meta_batch_size):
                observed_train_loader, observed_test_loader = get_observed_data_loader(
                    self.meta_dataset.feature, self.meta_dataset.label,
                    self.tissue_index_list, self.K, self.inner_batch_size,
                    self.tissue_num)

                metrics, g = self.inner_net.forward(observed_train_loader,
                                                    observed_test_loader)
                grads.append(g)
                (meta_train_loss[i], meta_train_corr[i], meta_val_loss[i],
                 meta_val_corr[i]) = metrics

            # Perform the meta update
            self.meta_update(observed_test_loader, grads)

            # Evaluate K shot test tasks
            (train_loss[epoch], train_corr[epoch], train_spearman_corr[epoch],
             test_loss[epoch], test_corr[epoch], test_spearman_corr[epoch],
             _, _) = self.unseen_tissue_learn(unseen_train_loader,
                                              unseen_test_loader)

            if train_loss[epoch] < best_fewshot_train_loss:
                best_fewshot_train_loss = train_loss[epoch]
                best_train_loss_epoch = epoch

            if train_corr[epoch] > best_fewshot_train_corr:
                best_fewshot_train_corr = train_corr[epoch]
                best_train_corr_epoch = epoch
                # Snapshot the weights: keeping a reference to the live model
                # would let later epochs overwrite the "best" one.
                best_train_corr_model = copy.deepcopy(
                    self.observed_tissue_model)

            if train_spearman_corr[epoch] > best_fewshot_train_spearman_corr:
                best_fewshot_train_spearman_corr = train_spearman_corr[epoch]
                best_train_spearman_corr_epoch = epoch

            print('Few shot {} train_loss: {} train_pearson: {} '
                  'train_spearman: {} test_loss: {} test_pearson: {} '
                  'test_spearman: {}'.format(
                      epoch, r3(train_loss[epoch]), r3(train_corr[epoch]),
                      r3(train_spearman_corr[epoch]), r3(test_loss[epoch]),
                      r3(test_corr[epoch]), r3(test_spearman_corr[epoch])))

        best_train_loss_test_corr = test_corr[best_train_loss_epoch]
        best_train_corr_test_corr = test_corr[best_train_corr_epoch]
        best_train_corr_test_scorr = test_spearman_corr[best_train_corr_epoch]
        best_train_scorr_test_scorr = test_spearman_corr[
            best_train_spearman_corr_epoch]

        # The last label previously read 'best_train_scorr_test_corr' although
        # the value is a test Spearman correlation.
        print('--trial summerize-- best_train_loss_test_corr: {} '
              'best_train_corr_test_corr {} best_train_corr_test_scorr {} '
              'best_train_scorr_test_scorr {}'.format(
                  r3(best_train_loss_test_corr),
                  r3(best_train_corr_test_corr),
                  r3(best_train_corr_test_scorr),
                  r3(best_train_scorr_test_scorr)))

        return (best_train_loss_test_corr, best_train_corr_test_corr,
                best_train_corr_test_scorr, best_train_scorr_test_scorr,
                best_train_corr_model)
예제 #10
0
파일: maml.py 프로젝트: ajabri/pytorch-maml
 def init_proc(self):
     """Create the module-level ``net`` InnerLoop learner and move it to the GPU.

     ``net`` is declared global, so the learner is bound at module level
     (presumably a per-process initializer for a worker pool -- confirm).
     """
     global net
     loop_args = (self.num_classes, self.loss_fn, self.num_inner_updates,
                  self.inner_step_size, self.inner_batch_size,
                  self.meta_batch_size, self.num_input_channels)
     net = InnerLoop(*loop_args)
     net.cuda()
예제 #11
0
class MetaLearner(object):
    """Meta-learner over sequential tasks: the meta net ``net`` is updated
    with Adam from gradients the inner-loop net ``fast_net`` computes on
    episodes replayed from memory.
    """

    def __init__(self, log, tb_writer, args):
        """Create both nets on ``device``, the Adam meta optimizer, the
        episode sampler and the replay memory.

        Args:
            log: stored as ``self.log``.
            tb_writer: stored as ``self.tb_writer`` (presumably TensorBoard).
            args: options; uses ``meta_lr``, ``n_agent``, ``fast_lr``,
                ``fast_num_update`` and ``meta_batch_size``.
        """
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as MetaLearner is subclassed.
        super(MetaLearner, self).__init__()
        self.log = log
        self.tb_writer = tb_writer
        self.args = args
        self.loss_fn = MSELoss()

        self.net = OmniglotNet(self.loss_fn, args).to(device)

        self.fast_net = InnerLoop(self.loss_fn, args).to(device)

        self.opt = Adam(self.net.parameters(), lr=args.meta_lr)
        self.sampler = BatchSampler(args)
        self.memory = ReplayBuffer()

    def meta_update(self, episode_i, ls):
        """Apply one Adam step to ``self.net`` using gradients summed over ``ls``.

        Args:
            episode_i: episode whose agent-0 observations/rewards drive the
                dummy forward pass.
            ls: list of dicts mapping parameter name -> gradient.
        """
        in_ = episode_i.observations[:, :, 0]
        target = episode_i.rewards[:, :, 0]

        # A dummy forward/backward pass routes gradients into self.net; the
        # hooks below swap them for the meta-batch gradients.
        loss, _ = forward_pass(self.net, in_, target)

        # Sum the per-task gradient dicts key by key.
        gradients = {k: sum(d[k] for d in ls) for k in ls[0].keys()}

        def make_hook(key):
            # Bind `key` now so each hook returns its own parameter's grad.
            def replace_grad(grad):
                return gradients[key]
            return replace_grad

        hooks = []
        try:
            for k, v in self.net.named_parameters():
                hooks.append(v.register_hook(make_hook(k)))
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
        finally:
            # Always remove the hooks; a leftover hook would silently replace
            # the gradients of every later backward pass.
            for h in hooks:
                h.remove()

    def test(self, i_task, episode_i_):
        """For each agent, fine-tune a copy of the meta net on the stored
        episode of task ``i_task - 1``, print its losses, and visualize the
        predictions it makes on ``episode_i_``.
        """
        predictions_ = []
        # The same stored episode is the adaptation data for every agent.
        episode_i = self.memory.storage[i_task - 1]
        for i_agent in range(self.args.n_agent):
            test_net = OmniglotNet(self.loss_fn, self.args).to(device)

            # Make a test net with same parameters as our current net
            test_net.copy_weights(self.net)
            test_opt = SGD(test_net.parameters(), lr=self.args.fast_lr)

            # Train on this agent's slice of the stored episode
            for _ in range(self.args.fast_num_update):
                in_ = episode_i.observations[:, :, i_agent]
                target = episode_i.rewards[:, :, i_agent]
                loss, _ = forward_pass(test_net, in_, target)
                print("loss {} at {}".format(loss, i_task))
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()

            # Evaluate the trained model on train and val examples
            tloss, _ = evaluate(test_net, episode_i, i_agent)
            vloss, prediction_ = evaluate(test_net, episode_i_, i_agent)
            # NOTE(review): dividing by 10 looks like a leftover from
            # averaging over 10 test tasks; confirm what `evaluate` returns.
            mtr_loss = tloss / 10.
            mval_loss = vloss / 10.

            print('-------------------------')
            print('Meta train:', mtr_loss)
            print('Meta val:', mval_loss)
            print('-------------------------')
            del test_net

            predictions_.append(prediction_)

        visualize(episode_i, episode_i_, predictions_, i_task, self.args)

    def train(self, num_tasks=10000):
        """Run the sequential meta-training loop over ``num_tasks`` tasks.

        Each task's episodes are sampled and stored in memory. Evaluation runs
        once ``len(self.memory) > 1``, and meta updates run once it exceeds 2.

        Args:
            num_tasks: number of tasks to iterate over (default 10000).
        """
        for i_task in range(num_tasks):
            # Sample episode from current task
            self.sampler.reset_task(i_task)
            episodes = self.sampler.sample()

            # Add to memory
            self.memory.add(i_task, episodes)

            # Evaluate on test tasks
            if len(self.memory) > 1:
                self.test(i_task, episodes)

            # Collect a meta batch update
            if len(self.memory) > 2:
                meta_grads = []
                for i in range(self.args.meta_batch_size):
                    # First element: the two most recent tasks; the rest are
                    # sampled from memory.
                    if i == 0:
                        episodes_i = self.memory.storage[i_task - 1]
                        episodes_i_ = self.memory.storage[i_task]
                    else:
                        episodes_i, episodes_i_ = self.memory.sample()

                    self.fast_net.copy_weights(self.net)
                    for i_agent in range(self.args.n_agent):
                        meta_grad = self.fast_net.forward(
                            episodes_i, episodes_i_, i_agent)
                        meta_grads.append(meta_grad)

                # Perform the meta update
                self.meta_update(episodes_i, meta_grads)