예제 #1
0
파일: maml.py 프로젝트: tailintalent/mela
    def test(self, num_tasks=10):
        """Meta-test the current network on randomly sampled test tasks.

        For each of ``num_tasks`` tasks drawn from the ``'test'`` split, a
        scratch ``OmniglotNet`` is reset to ``self.net``'s weights, adapted
        with plain SGD (``lr=self.inner_step_size``) for
        ``self.num_inner_updates`` steps on the task's ``'train'`` examples,
        and then evaluated on the task's ``'train'`` and ``'val'`` examples.

        Args:
            num_tasks: number of test tasks to sample and average over.
                Defaults to 10, the count previously hard-coded.

        Returns:
            ``(mtr_loss, mtr_acc, mval_loss, mval_acc)``: the per-task mean of
            the post-adaptation train loss/accuracy and val loss/accuracy as
            returned by ``evaluate``.

        Raises:
            ValueError: if ``num_tasks`` is not positive (the mean would be
                undefined).
        """
        if num_tasks <= 0:
            raise ValueError('num_tasks must be positive, got %r' % (num_tasks,))

        num_in_channels = 1 if self.dataset == 'mnist' else 3
        # One scratch network is reused; copy_weights() resets it from
        # self.net at the start of every task.
        test_net = OmniglotNet(self.num_classes, self.loss_fn, num_in_channels)
        mtr_loss, mtr_acc, mval_loss, mval_acc = 0.0, 0.0, 0.0, 0.0
        for _task_idx in range(num_tasks):
            test_net.copy_weights(self.net)
            if self.is_cuda:
                test_net.cuda()
            # A fresh optimizer per task so no optimizer state carries over.
            test_opt = SGD(test_net.parameters(), lr=self.inner_step_size)
            task = self.get_task('../data/{}'.format(self.dataset),
                                 self.num_classes,
                                 self.num_inst,
                                 split='test')
            # Adapt on the train examples, using the same number of updates
            # as in meta-training.
            train_loader = get_data_loader(task,
                                           self.inner_batch_size,
                                           split='train')
            for _step in range(self.num_inner_updates):
                # Each step draws the first batch of a fresh pass over the
                # loader. next(iter(...)) works on Python 2 and 3, whereas
                # the iterator's .next() method exists only on Python 2.
                in_, target = next(iter(train_loader))
                loss, _ = forward_pass(test_net,
                                       in_,
                                       target,
                                       is_cuda=self.is_cuda)
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()
            # Evaluate the adapted model on the task's train and val examples.
            tloss, tacc = evaluate(test_net, train_loader)
            val_loader = get_data_loader(task,
                                         self.inner_batch_size,
                                         split='val')
            vloss, vacc = evaluate(test_net, val_loader)
            mtr_loss += tloss
            mtr_acc += tacc
            mval_loss += vloss
            mval_acc += vacc

        # Accumulators start as floats, so this is true division on Python 2 too.
        mtr_loss = mtr_loss / num_tasks
        mtr_acc = mtr_acc / num_tasks
        mval_loss = mval_loss / num_tasks
        mval_acc = mval_acc / num_tasks

        # Single-string print() calls print identically on Python 2 and 3;
        # '%s' matches the str() + space-join of the old print statements.
        print('-------------------------')
        print('Meta train: %s %s' % (mtr_loss, mtr_acc))
        print('Meta val: %s %s' % (mval_loss, mval_acc))
        print('-------------------------')
        return mtr_loss, mtr_acc, mval_loss, mval_acc
예제 #2
0
    def test(self, i_task, episode_i_):
        """Adapt a copy of the meta-network per agent and visualize predictions.

        For each agent index in ``range(self.args.n_agent)`` a fresh
        ``OmniglotNet`` is created on ``device``, initialised with
        ``self.net``'s weights, and fine-tuned with SGD
        (``lr=self.args.fast_lr``) for ``self.args.fast_num_update`` steps on
        that agent's slice of the stored episode
        ``self.memory.storage[i_task - 1]``. The adapted network is then
        evaluated on that stored episode and on ``episode_i_``; the
        predictions on ``episode_i_`` are collected and passed to
        ``visualize``.

        Args:
            i_task: task index; the adaptation episode is read from
                ``self.memory.storage[i_task - 1]``.
            episode_i_: episode evaluated after adaptation; its per-agent
                predictions are forwarded to ``visualize``.

        Returns:
            None. Side effects: prints losses and calls ``visualize``.
        """
        # The adaptation episode is the same for every agent, so fetch it
        # once. This also keeps it bound for visualize() when n_agent == 0.
        # NOTE(review): i_task == 0 reads storage[-1] (the last entry);
        # confirm callers always pass i_task >= 1.
        episode_i = self.memory.storage[i_task - 1]

        predictions_ = []
        for i_agent in range(self.args.n_agent):
            # Make a test net with same parameters as our current net
            test_net = OmniglotNet(self.loss_fn, self.args).to(device)
            test_net.copy_weights(self.net)
            test_opt = SGD(test_net.parameters(), lr=self.args.fast_lr)

            # This agent's slice along axis 2; loop-invariant, so taken once.
            in_ = episode_i.observations[:, :, i_agent]
            target = episode_i.rewards[:, :, i_agent]

            # Train on the train examples, using the same number of updates as in training
            for _step in range(self.args.fast_num_update):
                loss, _ = forward_pass(test_net, in_, target)
                print("loss {} at {}".format(loss, i_task))
                test_opt.zero_grad()
                loss.backward()
                test_opt.step()

            # Evaluate the adapted model on the stored episode ("Meta train")
            # and on episode_i_ ("Meta val").
            tloss, _ = evaluate(test_net, episode_i, i_agent)
            vloss, prediction_ = evaluate(test_net, episode_i_, i_agent)

            # Each loss comes from a single evaluation, so report it as-is;
            # the former "/ 10." averaged nothing and understated it tenfold.
            print('-------------------------')
            print('Meta train:', tloss)
            print('Meta val:', vloss)
            print('-------------------------')
            del test_net

            predictions_.append(prediction_)

        visualize(episode_i, episode_i_, predictions_, i_task, self.args)