Exemplo n.º 1
0
    def compute_outer_grads(
            self,
            task: Tuple[torch.Tensor],
            n_tasks: int,
            meta_train: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Adapt the head on the support set and evaluate it on the query set.

        The encoder is run without gradients on the support set and the
        functional copy of ``self.head`` is adapted by ``self.inner_loop``.
        The query set is then encoded and classified with the adapted head.
        When ``meta_train`` is true, the query loss scaled by ``1 / n_tasks``
        is back-propagated, so gradients accumulate in ``.grad`` across the
        tasks of a meta-batch.

        Args:
            task: ``(support_input, support_target, query_input, query_target)``.
            n_tasks: Number of tasks in the meta-batch; used to average the
                accumulated outer gradients.
            meta_train: If true, track higher-order gradients through the inner
                loop and call ``backward`` on the query loss.

        Returns:
            ``(query_output, query_loss)``. ``query_loss`` has already been
            divided by ``len(query_input)``.
        """
        support_input, support_target, query_input, query_target = task

        with higher.innerloop_ctx(self.head,
                                  self.in_optim,
                                  copy_initial_weights=False,
                                  track_higher_grads=meta_train) as (fhead,
                                                                     diffopt):
            # No graph is built for the support features: the inner loop only
            # needs gradients w.r.t. the head.
            with torch.no_grad():
                support_feature = self.encoder(support_input)
            # inner loop (adapt)
            self.inner_loop(fhead, diffopt, support_feature, support_target)

            # evaluate on the query set
            with torch.set_grad_enabled(meta_train):
                query_feature = self.encoder(query_input)
                query_output = fhead(query_feature)
                query_loss = self.loss_function(query_output, query_target)
                # Out-of-place: an in-place op on a tensor inside the autograd
                # graph fails if some backward node saved that tensor.
                query_loss = query_loss / len(query_input)

            # compute gradients when in the meta-training stage
            if meta_train:
                (query_loss / n_tasks).backward()

        return query_output, query_loss
Exemplo n.º 2
0
    def train(self, niter=3):
        """Run ``niter`` MAML meta-training iterations on ``self.meta_inputs``.

        For every task in the meta-batch, a functional copy of the network is
        adapted with ``n_inner_iter`` SGD steps on the support set. The query
        cross-entropy is then back-propagated through the unrolled inner loop
        (``copy_initial_weights=False``). Gradients from all tasks accumulate
        before a single Adam step on the network parameters.

        Args:
            niter: Number of outer (meta) iterations.

        Raises:
            NotImplementedError: If ``self.jit`` is set.
        """
        if self.jit:
            raise NotImplementedError()

        net, _ = self.get_module()
        net.train()
        x_spt, y_spt, x_qry, y_qry = self.meta_inputs
        meta_opt = optim.Adam(net.parameters(), lr=1e-3)

        # Loop invariants, hoisted out of the outer loop. Plain SGD without
        # momentum keeps no state, so one instance can be reused;
        # innerloop_ctx wraps it in a fresh differentiable optimizer per task.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
        task_num = x_spt.size(0)

        for _ in range(niter):
            meta_opt.zero_grad()
            for i in range(task_num):
                with higher.innerloop_ctx(
                        net, inner_opt,
                        copy_initial_weights=False) as (fnet, diffopt):
                    for _ in range(n_inner_iter):
                        spt_logits = fnet(x_spt[i])
                        spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                        diffopt.step(spt_loss)

                    qry_logits = fnet(x_qry[i])
                    qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                    # Accumulates meta-gradients into net's .grad fields.
                    qry_loss.backward()

            meta_opt.step()
Exemplo n.º 3
0
    def fit(self, dataset_train):
        """ The learner's fit function over the train set of a task.

        Only the first batch yielded by ``dataset_train`` is used. The
        learner is adapted on it for ``self.n_inner_iter`` cross-entropy
        gradient steps (higher-order gradients are not tracked), and the
        adapted functional model is wrapped in a ``MyPredictor``.

        Args:
            dataset_train : a tf.data.Dataset object. Iterates over the training
                            examples (support set).
        Returns:
            predictor : An instance of MyPredictor that is initialized with
                the fine-tuned learner's weights in this case.
        Raises:
            ValueError: If ``dataset_train`` yields no batch.
        """
        self.learner.train()
        for images, labels in dataset_train:
            images, labels = self.process_task(images, labels)
            with higher.innerloop_ctx(self.learner,
                                      self.optimizer,
                                      track_higher_grads=False) as (fnet,
                                                                    diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(self.n_inner_iter):
                    spt_logits = fnet(images)
                    spt_loss = F.cross_entropy(spt_logits, labels)
                    diffopt.step(spt_loss)

                predictor = MyPredictor(fnet)
            # Only the first support batch is used for adaptation.
            return predictor
        # Reaching this point means the dataset was empty: fail with a clear
        # message instead of an UnboundLocalError on `predictor`.
        raise ValueError("dataset_train yielded no support batch to fit on")
Exemplo n.º 4
0
    def evaluate(self, dataloader, updates, mini_batch_size):
        """Adapt on replayed memory, then report metrics on ``dataloader``.

        ``updates`` mini-batches of size ``mini_batch_size`` are read from
        ``self.memory`` and used as a support set. One ``self.inner_optimizer``
        step is taken per mini-batch on a functional copy of ``self.pn``
        (higher-order gradients are not tracked). The adapted copy is then
        evaluated without gradients on every batch of ``dataloader``. Support
        and test metrics are logged.

        Args:
            dataloader: Iterable of ``(text, labels)`` test batches.
            updates: Number of support mini-batches read from memory.
            mini_batch_size: Size of each support mini-batch.

        Returns:
            ``(acc, prec, rec, f1)`` computed on the test batches.
        """
        support_set = []
        for _ in range(updates):
            text, labels = self.memory.read_batch(batch_size=mini_batch_size)
            support_set.append((text, labels))

        with higher.innerloop_ctx(self.pn,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):

            # Inner loop: one adaptation step per support mini-batch.
            task_predictions, task_labels = [], []
            support_loss = []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                # Named `representation` so the builtin `repr` is not shadowed.
                representation = fpn(input_dict, out_from='transformers')
                # `self.nm` is not part of the functional copy, so it is not
                # adapted here; its output scales the representation element-wise.
                modulation = self.nm(input_dict)
                output = fpn(representation * modulation, out_from='linear')
                loss = self.loss_fn(output, labels)
                diffopt.step(loss)
                pred = models.utils.make_prediction(output.detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())

            acc, prec, rec, f1 = models.utils.calculate_metrics(
                task_predictions, task_labels)

            logger.info(
                'Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                'recall = {:.4f}, F1 score = {:.4f}'.format(
                    np.mean(support_loss), acc, prec, rec, f1))

            all_losses, all_predictions, all_labels = [], [], []

            # Evaluate the adapted copy; no gradients are needed here.
            for text, labels in dataloader:
                labels = torch.tensor(labels).to(self.device)
                input_dict = self.pn.encode_text(text)
                with torch.no_grad():
                    representation = fpn(input_dict, out_from='transformers')
                    modulation = self.nm(input_dict)
                    output = fpn(representation * modulation, out_from='linear')
                    loss = self.loss_fn(output, labels)
                all_losses.append(loss.item())
                pred = models.utils.make_prediction(output.detach())
                all_predictions.extend(pred.tolist())
                all_labels.extend(labels.tolist())

        acc, prec, rec, f1 = models.utils.calculate_metrics(
            all_predictions, all_labels)
        logger.info(
            'Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec,
                                       f1))

        return acc, prec, rec, f1
Exemplo n.º 5
0
def train(step_idx, data, net, inner_opt_builder, meta_opt, n_inner_iter):
    """Main meta-training step.

    Each task's functional copy of ``net`` is adapted for ``n_inner_iter``
    steps on its support set with MSE loss. The query-set MSE is then
    back-propagated through the unrolled adaptation
    (``copy_initial_weights=False``), so gradients accumulate over tasks.
    The mean query loss is logged to wandb under ``train_loss`` before
    ``meta_opt`` takes a single step.
    """
    support_x, support_y, query_x, query_y = data
    n_tasks = support_x.size()[0]
    query_size = query_x.size(1)  # computed but not used below

    inner_opt = inner_opt_builder.inner_opt

    per_task_losses = []
    meta_opt.zero_grad()
    for task in range(n_tasks):
        inner_ctx = higher.innerloop_ctx(
            net,
            inner_opt,
            copy_initial_weights=False,
            override=inner_opt_builder.overrides,
        )
        with inner_ctx as (fnet, diffopt):
            for _ in range(n_inner_iter):
                support_loss = F.mse_loss(fnet(support_x[task]), support_y[task])
                diffopt.step(support_loss)
            query_loss = F.mse_loss(fnet(query_x[task]), query_y[task])
            per_task_losses.append(query_loss.detach().cpu().numpy())
            query_loss.backward()

    wandb.log({"train_loss": np.mean(per_task_losses)}, step=step_idx)
    meta_opt.step()
Exemplo n.º 6
0
    def run_batches(self, batches, train=True, meta_train=True):
        """Adapt on every batch but the last, then back-propagate the last one.

        A functional copy of ``self.model`` takes one ``self.inner_optimizer``
        step per batch in ``batches[:-1]``. The loss on ``batches[-1]`` is then
        back-propagated through those steps (``copy_initial_weights=False``).
        Every batch first goes through ``self.shuffle_labels`` with shufflers
        built once per call from ``self._shuffler_factory``. cuDNN is disabled
        for the whole call (presumably because some cuDNN kernels lack double
        backward support; confirm).

        ``train`` and ``meta_train`` are accepted but not used here.

        Returns:
            One ``{"loss": float, "metric": ...}`` dict per batch, in order.
        """
        device = next(self.model.parameters()).device
        shufflers = {name: factory() for name, factory in self._shuffler_factory.items()}
        metrics = []

        def run_one(fmodel, batch, position):
            # Shuffle, move to the model's device, then run the forward pass.
            prepared = move_to_device(self.shuffle_labels(batch, shufflers), device)
            outputs = fmodel(**prepared, **self.forward_kwargs(position))
            return outputs["loss"], outputs["metric"]

        last_position = len(batches) - 1
        with torch.backends.cudnn.flags(enabled=False):
            with higher.innerloop_ctx(self.model,
                                      self.inner_optimizer,
                                      copy_initial_weights=False) as (fmodel, diffopt):
                for position, batch in enumerate(batches[:-1]):
                    loss, metric = run_one(fmodel, batch, position)
                    diffopt.step(loss)
                    metrics.append({"loss": loss.item(), "metric": metric})

                loss, metric = run_one(fmodel, batches[-1], last_position)
                loss.backward()
                metrics.append({"loss": loss.item(), "metric": metric})

        return metrics
Exemplo n.º 7
0
def maml_val(tepoch, model, inner_criterion, outer_criterion, inner_optimizer,
             num_adapt_steps):
    """Evaluate MAML-style adaptation on a validation loader.

    Each batch is reshaped into a single task. The task is adapted for
    ``num_adapt_steps`` differentiable-optimizer steps on its ``'train'``
    split and scored with ``outer_criterion`` on its ``'test'`` split. The
    running mean query loss is reported through ``tepoch.set_postfix``.

    Args:
        tepoch: Iterable of batches exposing ``set_description`` and
            ``set_postfix`` (presumably a tqdm wrapper; confirm). Each batch
            is a dict whose ``'train'`` and ``'test'`` entries are
            ``[inputs, targets]`` lists, which are reshaped in place.
        model: Model adapted through ``higher.innerloop_ctx``.
        inner_criterion: Loss used for the adaptation steps.
        outer_criterion: Loss reported on the query split.
        inner_optimizer: Optimizer wrapped by ``higher`` for the inner loop.
        num_adapt_steps: Number of inner-loop steps per task.

    Returns:
        Mean query loss over all tasks as a float, or ``None`` if ``tepoch``
        yielded no batches.
    """
    # NOTE(review): train mode is kept during validation (e.g. BatchNorm uses
    # batch statistics during adaptation); confirm this is intended.
    model.train()
    total_loss = 0.0
    n_tasks = 0
    mean_loss = None
    for batch in tepoch:
        tepoch.set_description("Validation")

        # One task per batch: inputs -> (1, N, 6, 36, 36), targets -> (1, N, 1).
        batch['train'][0] = batch['train'][0].view(1, -1, 6, 36, 36)
        batch['test'][0] = batch['test'][0].view(1, -1, 6, 36, 36)
        batch['train'][1] = batch['train'][1].view(1, -1, 1)
        batch['test'][1] = batch['test'][1].view(1, -1, 1)

        inputs, targets = batch['train']
        test_inputs, test_targets = batch['test']

        for task_input, target, test_input, test_target in zip(
                inputs, targets, test_inputs, test_targets):
            # Validation never back-propagates into `model`, so the inner-loop
            # graph needed for higher-order gradients is not kept.
            with higher.innerloop_ctx(model,
                                      inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (fmodel,
                                                                    diffopt):
                for _ in range(num_adapt_steps):
                    inner_loss = inner_criterion(fmodel(task_input), target)
                    diffopt.step(inner_loss)
                test_logit = fmodel(test_input).detach()
                test_loss = outer_criterion(test_logit, test_target)
                total_loss += test_loss.item()
                n_tasks += 1

        # Running mean over the tasks seen so far. Dividing the running sum
        # by len(tepoch) would under-report the loss until the last batch.
        mean_loss = total_loss / n_tasks if n_tasks else None
        tepoch.set_postfix(loss=mean_loss)
    return mean_loss
    '''
Exemplo n.º 8
0
def test(step_idx, data, net, inner_opt_builder, n_inner_iter):
    """Main meta-testing (evaluation) step.

    For every task, a functional copy of ``net`` is adapted on the support
    set for ``n_inner_iter`` steps without tracking higher-order gradients,
    and its query-set MSE is recorded. The mean query loss and the half-width
    of its 95% Student-t confidence interval are logged to wandb as
    ``test_loss`` and ``test_err``.

    Args:
        step_idx: Step index passed to ``wandb.log``.
        data: ``(x_spt, y_spt, x_qry, y_qry)``, tasks along dimension 0.
        net: Model whose functional copies are adapted per task.
        inner_opt_builder: Provides ``inner_opt`` and ``overrides`` for
            ``higher.innerloop_ctx``.
        n_inner_iter: Number of adaptation steps per task.

    Returns:
        The mean query loss over all tasks.
    """
    x_spt, y_spt, x_qry, y_qry = data
    task_num = x_spt.size()[0]

    inner_opt = inner_opt_builder.inner_opt

    qry_losses = []
    for i in range(task_num):
        with higher.innerloop_ctx(
            net, inner_opt, track_higher_grads=False, override=inner_opt_builder.overrides,
        ) as (
            fnet,
            diffopt,
        ):
            for _ in range(n_inner_iter):
                spt_pred = fnet(x_spt[i])
                spt_loss = F.mse_loss(spt_pred, y_spt[i])
                diffopt.step(spt_loss)
            qry_pred = fnet(x_qry[i])
            qry_loss = F.mse_loss(qry_pred, y_qry[i])
            qry_losses.append(qry_loss.detach().cpu().numpy())
    avg_qry_loss = np.mean(qry_losses)
    # Half-width of the 95% t-interval around the mean (df = tasks - 1).
    # NOTE: with a single task df is 0 and the standard error is undefined,
    # so the error bar comes out as NaN.
    _low, high = st.t.interval(
        0.95, len(qry_losses) - 1, loc=avg_qry_loss, scale=st.sem(qry_losses)
    )
    test_metrics = {"test_loss": avg_qry_loss, "test_err": high - avg_qry_loss}
    wandb.log(test_metrics, step=step_idx)
    return avg_qry_loss
    def meta_forward(self, data):
        """Meta-update through a single differentiable look-ahead step.

        A functional copy of ``self.model`` takes one ``self.optimizer`` step
        on the sum of the ``'loss_cls'`` entries computed on ``data``. The sum
        of the ``'loss_sigmoid'`` entries on the next meta batch is then
        back-propagated through that step and ``self.optimizer_meta`` is
        stepped. ``self.learner`` is in train mode during the update and is
        left in eval mode afterwards.
        """
        self.learner.train()
        self.optimizer.zero_grad()
        self.optimizer_meta.zero_grad()

        # copy_initial_weights=False so gradients flow back to the real
        # parameters of self.model (the learner is treated as part of it).
        with higher.innerloop_ctx(self.model, self.optimizer,
                                  copy_initial_weights=False) as (fmodel, diffopt):
            # Look-ahead step on the classification losses.
            cls_losses = fmodel(data)
            diffopt.step(sum([value for name, value in cls_losses.items() if 'loss_cls' in name]))

            # Meta objective on a fresh meta batch, differentiated through the step.
            meta_batch = next(self._meta_data_loader_iter)
            sigmoid_losses = fmodel(meta_batch)
            meta_objective = sum(
                [value for name, value in sigmoid_losses.items() if 'loss_sigmoid' in name])
            meta_objective.backward()
            self.optimizer_meta.step()

        self.learner.eval()
        self.log_meta_info()
Exemplo n.º 10
0
    def get_meta_grad(self, features, edges, labels, train_iters):
        """Return the gradient of a mixed loss w.r.t. ``self.adj_changes``.

        ``self.gcn`` is trained functionally (via ``higher``) with a fresh Adam
        optimizer for ``train_iters`` steps. The resulting model's loss is the
        weighted sum of:

        * ``lambda_`` times the cross-entropy on nodes selected by
          ``select_index(labels, -1, same=False)`` against ``labels``;
        * ``1 - lambda_`` times the cross-entropy on nodes selected with
          ``same=True`` against ``self.labels_self_training``.

        That loss is differentiated through the unrolled training w.r.t.
        ``self.adj_changes``. Presumably ``select_index`` picks labelled and
        unlabelled (-1) nodes respectively; confirm against its definition.
        """
        model = self.gcn
        loss = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        with higher.innerloop_ctx(model, optimizer) as (fmodel, diffopt):

            for _ in range(train_iters):
                pre = fmodel(features, edges)
                idx = select_index(labels, -1, same=False)
                pre, Y = pre[idx], labels[idx]

                cost = loss(pre, Y)

                diffopt.step(cost)

            pre = fmodel(features, edges)
            idx = select_index(labels, -1, same=False)
            sudo_idx = select_index(labels, -1, same=True)

            cost = 0

            # `+=` so both terms are summed when 0 < lambda_ < 1; `=+` would be
            # an assignment of a unary plus that discards the first term.
            if self.lambda_ > 0:
                cost += self.lambda_ * loss(pre[idx], labels[idx])

            if (1 - self.lambda_) > 0:
                cost += (1 - self.lambda_) * loss(pre[sudo_idx], self.labels_self_training[sudo_idx])

            return torch.autograd.grad(cost, self.adj_changes, retain_graph=False)[0]
Exemplo n.º 11
0
 def testSameInitialWeightsPostPatch(self):
     """Check that monkey patching preserves parameter names, order and values.

     The functional copy produced by ``higher.innerloop_ctx`` must expose the
     same named parameters, in the same order and with equal values, as the
     reference network's fast weights. This is checked both by name and by
     position.
     """
     reference = list(self.reference_net.get_fast_weights().items())
     reference_values = [value for _, value in reference]
     with higher.innerloop_ctx(self.target_net, self.opt) as (fnet, _):
         patched = list(fnet.named_parameters())
         patched_values = fnet.parameters()
         self.assertEqual(
             len(reference),
             len(patched),
             msg=("Length mismatched between reference net parameter count "
                  "({}) and target ({}).".format(len(reference),
                                                 len(patched))))
         for (ref_name, ref_p), (target_name, target_p) in zip(reference, patched):
             self.assertEqual(
                 ref_name,
                 target_name,
                 msg="Name mismatch or parameter misalignment ('{}' vs '{}')"
                 .format(ref_name, target_name))
             self.assertTrue(
                 torch.equal(ref_p, target_p),
                 msg="Parameter value inequality for {}".format(ref_name))
         for position, (ref_p, target_p) in enumerate(
                 zip(reference_values, patched_values)):
             self.assertTrue(
                 torch.equal(ref_p, target_p),
                 msg="Parameter misalignment in position {}.".format(position))
Exemplo n.º 12
0
    def compute_outer_grads(
            self,
            task: Tuple[torch.Tensor],
            n_tasks: int,
            meta_train: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Adapt on the support set, evaluate on the query set, record outer grads.

        A functional copy of ``self.model`` is adapted by ``self.inner_loop``
        without tracking higher-order gradients. When ``meta_train`` is true,
        a first-order outer gradient is appended to ``self.grad_list``: one
        tensor per parameter, equal to (initial weights - adapted weights) /
        ``n_tasks``, i.e. the Reptile update direction.

        Args:
            task: ``(support_input, support_target, query_input, query_target)``.
            n_tasks: Number of tasks in the meta-batch; scales the outer grad.
            meta_train: If true, build the query graph and record outer grads.

        Returns:
            ``(query_output, query_loss)``. ``query_loss`` has already been
            divided by ``len(query_input)``.
        """
        support_input, support_target, query_input, query_target = task

        with higher.innerloop_ctx(self.model,
                                  self.in_optim,
                                  track_higher_grads=False) as (fmodel,
                                                                diffopt):
            # inner loop (adapt)
            self.inner_loop(fmodel, diffopt, support_input, support_target)

            # evaluate on the query set
            with torch.set_grad_enabled(meta_train):
                query_output = fmodel(query_input)
                query_loss = self.loss_function(query_output, query_target)
                # Out-of-place to avoid in-place edits of an autograd tensor.
                query_loss = query_loss / len(query_input)

            # compute gradients when in the meta-training stage
            if meta_train:
                # self.model's parameters are untouched by the inner loop
                # (copy_initial_weights defaults to True), so this is theta - theta'.
                self.grad_list.append([
                    (p.data - fast_p.data) / n_tasks
                    for p, fast_p in zip(self.model.parameters(),
                                         fmodel.parameters())
                ])

        return query_output, query_loss
Exemplo n.º 13
0
    def update_fn(
        self,
        images: torch.FloatTensor,
        labels: torch.FloatTensor,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """One optimization step.

        Jointly updates ``self.task_weights`` and the parameters of ``model``:

        1. A differentiable look-ahead step (via ``higher``) is taken on the
           task losses weighted by ``self.task_weights``.
        2. The look-ahead losses, scaled by ``self.preference_weights`` and
           normalized to sum to one, are compared to the uniform distribution
           with a KL divergence. Its gradient w.r.t. ``self.task_weights``
           drives a gradient step with rate ``self.flags.meta_lr``; the result
           is clamped at zero and renormalized to sum to one.
        3. The original (pre-step) task losses are re-weighted with the new
           task weights. Their gradient w.r.t. the initial parameters is
           written into ``model``'s ``.grad`` fields and ``optimizer`` steps.
        """
        # Re-enabled on every call: the torch.no_grad() update below replaces
        # self.task_weights with a new tensor that does not require grad.
        self.task_weights.requires_grad_()  # make task weights require grad

        with higher.innerloop_ctx(model, optimizer) as (fmodel, doptimizer):
            # Assumes fmodel(images, labels) returns one loss per task, aligned
            # with self.task_weights (TODO confirm against the model).
            task_losses = fmodel(images, labels)
            weighted_loss = torch.sum(task_losses * self.task_weights)
            # Differentiable look-ahead step: the updated parameters depend on
            # self.task_weights through weighted_loss.
            doptimizer.step(weighted_loss)

            # Losses after the look-ahead step, scaled by preference weights and
            # normalized into a distribution over tasks.
            new_task_losses = fmodel(images, labels) * self.preference_weights
            normalized_losses = new_task_losses / torch.sum(new_task_losses)
            # KL(normalized_losses || uniform over K tasks), since p*log(p*K) = p*log(p/(1/K)).
            # NOTE(review): a normalized loss of exactly 0 gives 0 * log(0) = NaN.
            kl_divergence = torch.sum(normalized_losses * torch.log(normalized_losses * len(self.task_weights)))
            task_weight_grads = torch.autograd.grad(kl_divergence, self.task_weights)[0]

            # gradient step on task weights
            # Clamped at zero, then renormalized to sum to one.
            # NOTE(review): if every weight clamps to 0 the division yields NaN.
            with torch.no_grad():
                self.task_weights = torch.clamp(self.task_weights - self.flags.meta_lr * task_weight_grads, min=0)
                self.task_weights = self.task_weights / torch.sum(self.task_weights)

            # compute gradients using new task weights
            # Reuses the pre-step task_losses; gradients are taken w.r.t. the
            # functional copy's initial (time=0) parameters.
            new_weighted_loss = torch.sum(task_losses * self.task_weights)
            param_grads = torch.autograd.grad(new_weighted_loss, fmodel.parameters(time=0))

        # Apply the gradients to the real model; the look-ahead parameters are
        # discarded. Relies on model.parameters() and fmodel.parameters(time=0)
        # yielding parameters in the same order.
        optimizer.zero_grad()
        for index, param in enumerate(model.parameters()):
            param.grad = param_grads[index]
        optimizer.step()
Exemplo n.º 14
0
    def forward(self, target_repr, init):
        """Turn ``target_repr`` into a set by inner-loop optimisation (DSPN).

        target_repr: Representation that the predicted set should match.
        FloatTensor of size (batch_size, repr_channels). It can come from a set
        processed with the same encoder as self.encoder (auto-encoder), or from a
        different input entirely (normal supervised learning), such as an image
        encoded into a feature vector.

        Starting from ``init`` (wrapped as an ``nn.Parameter`` inside an
        ``InnerSet``), ``self.iters`` differentiable SGD steps (momentum 0.5,
        lr ``self.lr``) minimise the squared error between
        ``self.encoder(set)`` and ``target_repr``.

        Returns:
            ``(intermediate_sets, repr_losses, grad_norms)``:
            * ``intermediate_sets``: the initial set parameter, then
              ``fset.mask`` after each step;
            * ``repr_losses``: the summed squared-error loss of each step;
            * ``grad_norms``: one empty tuple per step (placeholder only).

            NOTE(review): later entries of ``intermediate_sets`` are
            ``fset.mask`` rather than the set itself; confirm this is intended.
        """
        current_set = nn.Parameter(init)
        inner_set = InnerSet(current_set)

        intermediate_sets = [current_set]
        repr_losses = []
        grad_norms = []

        # Gradients are forced on so the inner loop works even if the caller
        # runs this forward under torch.no_grad().
        with torch.enable_grad():
            inner_opt = torch.optim.SGD(inner_set.parameters(), lr=self.lr, momentum=0.5)
            with higher.innerloop_ctx(inner_set, inner_opt) as (fset, diffopt):
                for _ in range(self.iters):
                    mismatch = self.encoder(fset()) - target_repr
                    step_loss = (mismatch ** 2).sum()
                    diffopt.step(step_loss)
                    intermediate_sets.append(fset.mask)
                    repr_losses.append(step_loss)
                    grad_norms.append(())

        return intermediate_sets, repr_losses, grad_norms
Exemplo n.º 15
0
        def closure(model, optimizer, *args):
            """This function will be evaluated on all GPUs.

            Unrolls ``self.args.nadapt`` differentiable training steps on the
            poisoned batch (``inputs`` / ``labels``) and evaluates
            ``criterion`` on the target images (``targets`` /
            ``intended_classes``) with the fully adapted weights. That target
            loss is back-propagated (``retain_graph=self.retain``).

            Returns:
                ``(target_loss, prediction)`` detached on CPU. ``prediction``
                is the number of correctly classified poisoned samples under
                the adapted weights.
            """  # noqa: D401
            # When the model is frozen only its last child module goes to train
            # mode; otherwise the whole model does.
            if model.frozen:
                list(model.children())[-1].train()
            else:
                model.train()
            batch_size = inputs.shape[0]

            # Poison batch and target images share one forward pass; the
            # first `batch_size` rows belong to the poison batch.
            data = torch.cat((inputs, targets), dim=0)

            # Wrap the model into a meta-object that allows for meta-learning steps via monkeypatching:
            with higher.innerloop_ctx(model,
                                      optimizer,
                                      copy_initial_weights=False) as (fmodel,
                                                                      fopt):
                for _ in range(self.args.nadapt):
                    outputs = fmodel(data)
                    poison_loss = criterion(outputs[:batch_size], labels)

                    fopt.step(poison_loss)

                # Forward once more with the fully adapted weights. The outputs
                # from inside the loop predate the last step, so the final
                # unrolled step would otherwise never influence the target loss.
                # This also keeps `outputs` defined when nadapt == 0.
                outputs = fmodel(data)

            prediction = (outputs[:batch_size].data.argmax(
                dim=1) == labels).sum()

            target_loss = criterion(outputs[batch_size:], intended_classes)
            target_loss.backward(retain_graph=self.retain)

            return target_loss.detach().cpu(), prediction.detach().cpu()
Exemplo n.º 16
0
    def learn(self,
              timesteps,
              update_interval=None,
              track_higher_grads=False,
              lr_scheduler=None,
              step_callback=None,
              interval_callback=None,
              reward_aggregation='episodic'):
        """Interact with ``self.env`` for ``timesteps`` steps, updating periodically.

        Every transition is stored in a ``Memory``. Every ``update_interval``
        steps, ``self.update`` is called with the wrapped policy and optimizer,
        and the memory is cleared. After the loop, the wrapped policy's state
        dict is loaded back into ``self.policy``.

        NOTE: ``step_callback`` and ``interval_callback`` receive ``locals()``,
        so the local variable names in this method are part of their contract.

        Args:
            timesteps: Total number of environment steps (cast with ``int``).
            update_interval: Steps between updates; defaults to
                ``self.update_interval``.
            track_higher_grads: Forwarded to ``innerloop_ctx``. When true, the
                wrapped policy is kept as ``self.meta_policy``; otherwise
                ``self.meta_policy`` is set to ``None``.
            lr_scheduler: Optional zero-argument callable invoked after each update.
            step_callback: Optional callable given ``locals()`` after every step.
            interval_callback: Optional callable given ``locals()`` after every update.
            reward_aggregation: ``'episodic'`` returns per-episode reward sums;
                any other value returns per-update-interval reward sums.

        Returns:
            List of reward sums.
            * ``'episodic'``: the trailing entry (the episode still in
              progress when the loop ended) is dropped, unless it is the only
              entry.
            * Otherwise: the interval sums, including the trailing partial
              interval.
        """
        if update_interval is None:
            update_interval = self.update_interval
        state = self.env.reset()
        memory = Memory()
        # The last element of each list is the running sum for the current
        # episode / update interval.
        episodic_rewards = [0.]
        interval_rewards = [0.]
        # This context wraps the policy and optimizer to track parameter updates
        # over time such that d Params(time=t) / d Params(time=t-n) can be calculated.
        # If not tracking higher gradients, a dummy context is used which does
        # nothing.
        with innerloop_ctx(self.policy,
                           self.optimizer,
                           track_higher_grads=track_higher_grads,
                           copy_initial_weights=False) as (policy, optimizer):

            for t in trange(1, int(timesteps) + 1, leave=False):
                # Running policy:
                memory.states.append(state)
                action, logprob = policy.predict(state)
                state, reward, done, info = self.env.step(action)
                episodic_rewards[-1] += reward
                interval_rewards[-1] += reward
                if done:
                    state = self.env.reset()
                    episodic_rewards.append(0.)
                memory.actions.append(action)
                memory.logprobs.append(logprob)
                memory.rewards.append(reward)
                memory.is_terminals.append(done)

                if step_callback is not None:
                    step_callback(locals())

                # update if it's time
                if t % update_interval == 0:
                    interval_rewards.append(0.)
                    loss = self.update(policy, memory, self.epochs, optimizer,
                                       self.summary)
                    if interval_callback is not None:
                        interval_callback(locals())
                    if lr_scheduler is not None:
                        lr_scheduler()
                    memory.clear()

            self.meta_policy = policy if track_higher_grads else None
        # Copy the (possibly functional) policy's final weights back into self.policy.
        self.policy.load_state_dict(policy.state_dict())

        if reward_aggregation == 'episodic':
            return episodic_rewards[:-1 if len(episodic_rewards) > 1 else None]
        else:
            return interval_rewards
Exemplo n.º 17
0
    def evaluate(self, dataloader, updates, mini_batch_size):
        """Adapt on replayed memory, then measure ranking accuracy on ``dataloader``.

        ``self.pn`` is put in train mode. ``updates`` support mini-batches of
        ``(text, label, candidates)`` are read from ``self.memory``, and a
        functional copy of ``self.pn`` takes one ``self.inner_optimizer`` step
        on each (higher-order gradients are not tracked). The adapted copy is
        then scored without gradients on every batch of ``dataloader``.
        Support and test metrics are logged.

        Returns:
            The accuracy on the ``dataloader`` batches.
        """
        self.pn.train()

        support_set = []
        for _ in range(updates):
            text, label, candidates = self.memory.read_batch(batch_size=mini_batch_size)
            support_set.append((text, label, candidates))

        def score(fpn, pair_texts, pair_relations, ranking_label):
            # Encode (text, relation) pairs, run the network, compute the loss.
            input_dict = self.pn.encode_text(list(zip(pair_texts, pair_relations)))
            output = fpn(input_dict)
            targets = torch.tensor(ranking_label).float().unsqueeze(1).to(self.device)
            return output, self.loss_fn(output, targets)

        with higher.innerloop_ctx(self.pn, self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpn, diffopt):

            # Adaptation: one step per support mini-batch.
            support_predictions, support_labels, support_losses = [], [], []
            for text, label, candidates in support_set:
                pair_texts, pair_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                output, loss = score(fpn, pair_texts, pair_relations, ranking_label)
                diffopt.step(loss)
                pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
                support_losses.append(loss.item())
                support_predictions.extend(pred.tolist())
                support_labels.extend(true_labels.tolist())

            support_acc = models.utils.calculate_accuracy(support_predictions, support_labels)
            logger.info('Support set metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(
                np.mean(support_losses), support_acc))

            # Evaluation of the adapted copy.
            test_losses, test_predictions, test_labels = [], [], []
            for text, label, candidates in dataloader:
                pair_texts, pair_relations, ranking_label = datasets.utils.replicate_rel_data(
                    text, label, candidates)
                with torch.no_grad():
                    output, loss = score(fpn, pair_texts, pair_relations, ranking_label)
                test_losses.append(loss.item())
                pred, true_labels = models.utils.make_rel_prediction(output, ranking_label)
                test_predictions.extend(pred.tolist())
                test_labels.extend(true_labels.tolist())

        acc = models.utils.calculate_accuracy(test_predictions, test_labels)
        logger.info('Test metrics: Loss = {:.4f}, accuracy = {:.4f}'.format(np.mean(test_losses), acc))

        return acc
Exemplo n.º 18
0
    def few_shot_testing(self,
                         train_dataset,
                         eval_dataset,
                         increment_counters=False,
                         split="test"):
        """Fine-tune on a few samples at a time, evaluating after every step.

        A functional copy of ``self.pln`` (with ``self.pln`` in train mode and
        ``self.rln`` in eval mode) takes one ``self.inner_optimizer`` step per
        batch of ``train_dataset``. Higher-order gradients are not tracked.
        After each step, the copy is evaluated on ``eval_dataset`` via
        ``self.evaluate`` and everything is passed to ``self.log_few_shot``.
        The accumulated predictions/labels restart once a mini-batch worth of
        samples has been seen. Results are saved in the learner's ``metrics``
        attribute (presumably by ``log_few_shot``; confirm).

        Parameters
        ---
        train_dataset: Dataset
            Examples the model is trained on before each evaluation.
        eval_dataset: Dataset
            Examples the model is evaluated on.
        increment_counters: bool
            If True, update online metrics and current iteration counters.
        split: str
            Split name forwarded to ``few_shot_preparation`` and ``log_few_shot``.
        """
        self.logger.info(
            f"few shot testing on dataset {self.config.testing.eval_dataset} "
            f"with {len(train_dataset)} samples")
        train_dataloader, eval_dataloader = self.few_shot_preparation(
            train_dataset, eval_dataset, split=split)
        window_predictions, window_labels = [], []
        with higher.innerloop_ctx(self.pln,
                                  self.inner_optimizer,
                                  copy_initial_weights=False,
                                  track_higher_grads=False) as (fpln, diffopt):
            self.pln.train()
            self.rln.eval()
            # `batch_datasets` rather than `datasets` so that no module of
            # that name is shadowed inside this method.
            for step, (text, labels, batch_datasets) in enumerate(train_dataloader):
                labels = torch.tensor(labels).to(self.device)
                logits = self.forward(text, labels, fpln)["logits"]
                diffopt.step(self.loss_fn(logits, labels))

                window_predictions.extend(
                    model_utils.make_prediction(logits.detach()).tolist())
                window_labels.extend(labels.tolist())
                dataset_results = self.evaluate(dataloader=eval_dataloader,
                                                prediction_network=fpln)
                self.log_few_shot(window_predictions, window_labels, batch_datasets,
                                  dataset_results, increment_counters, text, step,
                                  split=split)
                # Rebind (rather than clear) the window lists, in case
                # log_few_shot keeps references to the previous ones.
                offset = step * self.config.testing.few_shot_batch_size
                if offset % self.mini_batch_size == 0 and step > 0:
                    window_predictions, window_labels = [], []
        self.few_shot_end()
Exemplo n.º 19
0
def test_maml(db, net, device, lr_finetune, epoch, log):
    """Evaluate MAML on the test split and append one summary record to ``log``.

    For every test batch, each task gets a functional copy of ``net`` that is
    adapted on the task's support set with a fixed number of SGD steps
    (step size ``lr_finetune``), without tracking higher-order gradients.
    The adapted copy is then scored on the task's query set. The
    meta-parameters of ``net`` itself are never updated here.

    Most research papers using MAML for this task add an extra test-time
    fine-tuning stage; it is intentionally omitted for simplicity.

    Args:
        db: data source exposing ``x_test``, ``batchsz`` and ``next('test')``.
        net: the meta-learned model. It is put in training mode, presumably so
            that normalisation layers use batch statistics (TODO confirm).
        device: unused in this function.
        lr_finetune: learning rate of the inner-loop SGD optimizer.
        epoch: zero-based epoch index; reported as ``epoch + 1``.
        log: list that receives one dict with the mean per-sample query loss
            and accuracy (in percent).
    """
    net.train()
    num_batches = db.x_test.shape[0] // db.batchsz
    adapt_steps = 5  # inner-loop SGD steps per task

    per_sample_losses = []
    per_sample_hits = []

    for _batch in range(num_batches):
        support_x, support_y, query_x, query_y = db.next('test')

        # Support tensors are expected to be 5-D: (tasks, set, channels, h, w).
        n_tasks, _set_size, _channels, _height, _width = support_x.size()
        _query_size = query_x.size(1)

        # TODO: share the inner-loop code with `train` instead of duplicating it.
        inner_opt = torch.optim.SGD(net.parameters(), lr=lr_finetune)

        for t in range(n_tasks):
            with higher.innerloop_ctx(
                    net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Adapt the functional copy to this task's support set.
                for _step in range(adapt_steps):
                    support_loss = F.cross_entropy(fnet(support_x[t]),
                                                   support_y[t])
                    diffopt.step(support_loss)

                # Per-sample loss and correctness of the adapted parameters.
                query_logits = fnet(query_x[t]).detach()
                losses = F.cross_entropy(query_logits,
                                         query_y[t],
                                         reduction='none')
                per_sample_losses.append(losses.detach())
                predictions = query_logits.argmax(dim=1)
                per_sample_hits.append((predictions == query_y[t]).detach())

    mean_loss = torch.cat(per_sample_losses).mean().item()
    accuracy = 100. * torch.cat(per_sample_hits).float().mean().item()
    print(
        f'[Epoch {epoch+1:.2f}] Test Loss: {mean_loss:.2f} | Acc: {accuracy:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': mean_loss,
        'acc': accuracy,
        'mode': 'test',
        'time': time.time(),
    })
    def outer_loop(self, batch, is_train):
        """Run one outer (meta) iteration over a batch of tasks with implicit gradients.

        For each task, a functional copy of ``self.network`` is adapted on the
        task's training split with ``self.args.n_inner`` calls to
        ``self.inner_loop``. Higher-order gradients through these steps are
        not tracked. The adapted copy is then evaluated on the task's test split.

        When ``is_train`` is true, each task's meta-gradient comes from
        ``self.neumann_approx``. It receives the gradients of the inner
        (train-split) loss and the outer (test-split) loss w.r.t. the adapted
        parameters. The per-task gradients are averaged with uniform weights
        by ``mix_grad``, written into the network by ``apply_grad``, and
        applied with ``self.outer_optimizer``.

        Returns:
            ``(loss_log, acc_log, grad_log)`` when training, otherwise
            ``(loss_log, acc_log)``. The losses/accuracies are per-task values
            summed and divided by ``self.batch_size``.
        """
        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        mean_loss = 0
        mean_acc = 0
        task_grads = []
        task_losses = []

        tasks = zip(train_inputs, train_targets, test_inputs, test_targets)
        for train_x, train_y, test_x, test_y in tasks:
            with higher.innerloop_ctx(self.network,
                                      self.inner_optimizer,
                                      track_higher_grads=False) as (fmodel,
                                                                    diffopt):
                for _ in range(self.args.n_inner):
                    self.inner_loop(fmodel, diffopt, train_x, train_y)

                # Inner objective at the adapted parameters (needed for the
                # implicit gradient).
                inner_loss = F.cross_entropy(fmodel(train_x), train_y)

                test_logits = fmodel(test_x)
                outer_loss = F.cross_entropy(test_logits, test_y)
                mean_loss += outer_loss.item() / self.batch_size

                with torch.no_grad():
                    mean_acc += get_accuracy(
                        test_logits, test_y).item() / self.batch_size

                if not is_train:
                    continue

                adapted = list(fmodel.parameters(time=-1))
                # create_graph=True keeps the inner gradient differentiable,
                # presumably for Hessian-vector products inside
                # neumann_approx (TODO confirm).
                inner_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(inner_loss, adapted,
                                        create_graph=True))
                outer_grad = torch.nn.utils.parameters_to_vector(
                    torch.autograd.grad(outer_loss, adapted))
                task_grads.append(
                    self.neumann_approx(inner_grad, outer_grad, adapted))
                task_losses.append(outer_loss.item())

        if not is_train:
            return mean_loss, mean_acc

        self.outer_optimizer.zero_grad()
        uniform = torch.ones(len(task_grads))
        uniform = uniform / torch.sum(uniform)
        grad_log = apply_grad(self.network, mix_grad(task_grads, uniform))
        self.outer_optimizer.step()
        return mean_loss, mean_acc, grad_log
Exemplo n.º 21
0
    def train_on_loader(self, loader):
        """Meta-train ``self.backbone`` with MAML for one pass over ``loader``.

        For each episode, a functional copy of the backbone is adapted with
        ``self.n_inner_iter`` differentiable SGD steps (lr ``self.inner_lr``)
        on the support set. The query loss of the adapted copy is then
        back-propagated through those steps into ``self.backbone``
        (``copy_initial_weights=False``), and ``self.optimizer`` takes one
        step per episode.

        Returns:
            dict with the mean query loss (``"train_loss"``, float) and the
            mean query accuracy in percent (``"train_acc"``) over all episodes.
        """
        # For now, one task per meta-update.
        task_num = 1

        qry_losses = []
        qry_accs = []

        self.backbone.train()
        for episode in tqdm(loader):
            episode = episode[0]  # undo collate
            support_set = episode["support_set"].cuda(non_blocking=False)
            query_set = episode["query_set"].cuda(non_blocking=False)

            # Unpack support and query shapes separately. Previously the query
            # unpack overwrote nclasses/c/h/w before they were used to flatten
            # the support set.
            ss, s_classes, s_c, s_h, s_w = support_set.size()
            qs, q_classes, q_c, q_h, q_w = query_set.size()

            # Episode-relative labels 0..nclasses-1, repeated once per shot;
            # meant to match the (shot, class) order of the flattened sets.
            # TODO: use episode['targets']
            support_relative_labels = torch.arange(episode['nclasses']).view(1, -1).repeat(
                episode['support_size'], 1).cuda().view(-1)
            query_relative_labels = torch.arange(episode['nclasses']).view(1, -1).repeat(
                episode['query_size'], 1).cuda().view(-1)
            querysz = query_relative_labels.size(0)

            inner_opt = torch.optim.SGD(self.backbone.parameters(), lr=self.inner_lr)

            self.optimizer.zero_grad()
            for _ in range(task_num):
                with higher.innerloop_ctx(self.backbone, inner_opt,
                                          copy_initial_weights=False) as (fnet, diffopt):
                    for _ in range(self.n_inner_iter):
                        spt_logits = fnet(
                            support_set.view(ss * s_classes, s_c, s_h, s_w)
                        ).view(ss * s_classes, -1)
                        spt_loss = F.cross_entropy(spt_logits, support_relative_labels)
                        diffopt.step(spt_loss)

                    qry_logits = fnet(query_set.view(qs * q_classes, q_c, q_h, q_w))
                    qry_loss = F.cross_entropy(qry_logits, query_relative_labels)
                    qry_losses.append(qry_loss.detach())
                    qry_acc = (qry_logits.argmax(
                        dim=1) == query_relative_labels).sum().item() / querysz
                    qry_accs.append(qry_acc)

                    # Back-propagate through the inner updates into the backbone.
                    qry_loss.backward()

            self.optimizer.step()

        mean_loss = sum(qry_losses) / len(qry_losses)
        mean_acc = 100. * sum(qry_accs) / len(qry_accs)

        return {"train_loss": mean_loss.item(), "train_acc": mean_acc}
Exemplo n.º 22
0
def meta_train_mbrl_reacher(policy, ml3_loss, dmodel, env, task_loss_fn, goals,
                            n_outer_iter, n_inner_iter, time_horizon,
                            exp_folder):
    """Meta-train the learned loss ``ml3_loss`` on a model-based reacher task.

    For each outer iteration and each goal, the following happens:
    - the policy is reset;
    - a functional copy of the policy takes one differentiable SGD step on
      the loss that ``ml3_loss`` predicts from a rollout in ``dmodel``/``env``;
    - the task loss of a fresh rollout with the updated copy is
      back-propagated, reaching ``ml3_loss`` through the differentiable step.

    ``meta_opt`` (Adam over ``ml3_loss``) steps once per outer iteration.
    The meta-loss state dict is then saved to
    ``{exp_folder}/ml3_loss_reacher.pt``.

    When ``outer_i`` is a multiple of 100, a real-environment rollout is
    printed for each goal. When it is also a multiple of 300 and at most
    3000, that rollout's data is used to retrain ``dmodel``.

    Raises:
        ValueError: if ``n_inner_iter`` is smaller than 1 (no task loss would
            ever be computed).
    """
    if n_inner_iter < 1:
        raise ValueError(f"n_inner_iter must be >= 1, got {n_inner_iter}")

    goals = torch.Tensor(goals)

    meta_opt = torch.optim.Adam(ml3_loss.parameters(),
                                lr=ml3_loss.learning_rate)

    for outer_i in range(n_outer_iter):
        # set gradient with respect to meta loss parameters to 0
        meta_opt.zero_grad()
        # Running task loss for logging only. It is accumulated as a Python
        # float so the autograd graph of each goal is not kept alive.
        all_loss = 0.0
        for goal in goals:
            goal = torch.Tensor(goal)
            policy.reset()
            inner_opt = torch.optim.SGD(policy.parameters(),
                                        lr=policy.learning_rate)
            # NOTE(review): each iteration opens a fresh innerloop_ctx from
            # `policy`, whose parameters are not updated inside this loop as
            # far as visible here. The iterations therefore do not build on
            # each other; only the last `task_loss` / `fpolicy` are used
            # below. If multi-step adaptation is intended, the loop probably
            # belongs inside the context.
            for _ in range(n_inner_iter):
                inner_opt.zero_grad()
                with higher.innerloop_ctx(
                        policy, inner_opt,
                        copy_initial_weights=False) as (fpolicy, diffopt):
                    # use current meta loss to update model
                    s_tr, a_tr, g_tr = fpolicy.roll_out(
                        goal, time_horizon, dmodel, env)
                    meta_input = torch.cat(
                        [s_tr[:-1].detach(), a_tr,
                         g_tr.detach()], dim=1)
                    pred_task_loss = ml3_loss(meta_input).mean()
                    diffopt.step(pred_task_loss)
                    # compute task loss
                    s, a, g = fpolicy.roll_out(goal, time_horizon, dmodel, env)
                    task_loss = task_loss_fn(a, s[:], goal).mean()

            # collect losses for logging (detached value only)
            all_loss += task_loss.item()
            # backprop grad wrt to task loss
            task_loss.backward()

            if outer_i % 100 == 0:
                # Roll out in the real environment to monitor training and to
                # collect data for the dynamics model update.
                states, actions, _ = fpolicy.roll_out(goal,
                                                      time_horizon,
                                                      dmodel,
                                                      env,
                                                      real_rollout=True)
                print("meta iter: {} loss: {}".format(outer_i, (torch.mean(
                    (states[-1, :2] - goal[:2])**2))))
                if outer_i % 300 == 0 and outer_i < 3001:
                    # update dynamics model under current optimal policy
                    dmodel.train(torch.Tensor(states), torch.Tensor(actions))

        # step optimizer to update meta loss network
        meta_opt.step()
        torch.save(ml3_loss.state_dict(), f'{exp_folder}/ml3_loss_reacher.pt')
Exemplo n.º 23
0
def tmp():
    """One teacher/student meta-distillation step (a sketch, not runnable as-is).

    The student takes one differentiable step on a distillation loss
    against the teacher's soft predictions. The student's loss on labelled
    data after that step is back-propagated through the step into the
    teacher, which is then updated. Finally the adapted student weights are
    copied back into ``student``.

    NOTE(review): both optimizers are ``...`` placeholders. ``student``,
    ``teacher``, ``step``, the data/label tensors, ``soft_ce``, ``ce``,
    ``supervised_teacher_update_freq``, ``supervise_teacher``,
    ``clear_output`` and ``plot_predictions`` are not defined here; they
    must come from the enclosing scope.
    """
    teacher_optimizer = ...  # placeholder: the teacher's optimizer
    student_optimizer = ...  # placeholder: the student's (inner) optimizer

    # Put both models in training mode *before* the student is patched.
    # NOTE(review): higher appears to copy the module's attributes (including
    # `training`) into `smodel` when the context is entered, so calling
    # `student.train()` afterwards may not affect `smodel`. The plotting code
    # below leaves `student` in eval mode for the next step. Confirm against
    # the installed higher version.
    student.train()
    teacher.train()

    with higher.innerloop_ctx(student,
                              student_optimizer) as (smodel, sdiffopt):

        teacher_logits = teacher(unsupervised_batch)
        student_logits = smodel(unsupervised_batch)
        distillation_loss = soft_ce(
            torch.log_softmax(student_logits, dim=1),
            torch.softmax(teacher_logits, dim=1),
        )
        print("Distillation loss:", distillation_loss.item())

        # Differentiable student update towards the teacher's soft labels.
        sdiffopt.step(distillation_loss)

        student_logits = smodel(student_data)
        student_logits.squeeze_(dim=1)
        student_loss = ce(student_logits, student_labels)

        print("Student loss:", student_loss.item())

        # Clear stale teacher gradients from earlier steps, so the update below
        # uses only the meta-gradient of this student loss.
        teacher_optimizer.zero_grad()
        student_loss.backward()
        print("Teacher grad: {} +- {}".format(teacher.fc1.weight.grad.mean(),
                                              teacher.fc1.weight.grad.std()))
        teacher_optimizer.step()

        if step % supervised_teacher_update_freq == 0:
            supervise_teacher(teacher_data, teacher_labels)

        clear_output(wait=True)
        with torch.no_grad():
            student.eval()
            teacher.eval()
            plot_predictions(
                (
                    unsupervised_data.numpy(),
                    student(unsupervised_data).numpy().argmax(1),
                    "Student",
                ),
                (
                    unsupervised_data.numpy(),
                    teacher(unsupervised_data).numpy().argmax(1),
                    "Teacher",
                ),
            )

        # Persist the adapted (fast) student weights into the real module.
        with torch.no_grad():
            for old_p, new_p in zip(student.parameters(), smodel.parameters()):
                old_p.copy_(new_p)
Exemplo n.º 24
0
    def testGradientCorrectness(self,
                                _,
                                model_builder,
                                opt_builder,
                                kwargs=None):
        """Compare higher's unrolled gradients against finite differences.

        Each trial does the following:
        - draw two random batches;
        - take one differentiable optimizer step on the first batch;
        - differentiate the loss on the second batch w.r.t. the initial
          parameters (``time=0``).

        A trial passes when every gradient matches the corresponding value
        from ``finite_difference`` within rtol = atol = 0.1. Parameters with
        no gradient are trusted to be unused. At least ``pass_ratio`` of the
        trials must pass.
        """
        opt_kwargs = {} if kwargs is None else kwargs
        step_size = .1
        model = model_builder(self)
        fd_eps = 1e-3

        n_trials = 10
        n_passed = 0
        pass_ratio = .6  # proportion of tests that should pass

        for _trial in range(n_trials):
            batches = [torch.rand(10, 4) for _ in range(2)]

            def reference_loss():
                # Non-differentiable replay of the same update on a deep copy.
                clone = copy.deepcopy(model)
                ref_opt = opt_builder(clone.parameters(), lr=step_size,
                                      **opt_kwargs)
                for x in batches[:-1]:
                    ref_opt.zero_grad()
                    clone(x).pow(2).sum().backward()
                    ref_opt.step()
                return clone(batches[-1]).pow(2).sum()

            fd_grads = finite_difference(model, reference_loss, fd_eps)

            opt = opt_builder(model.parameters(), lr=step_size, **opt_kwargs)
            with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
                for x in batches[:-1]:
                    diffopt.step(fmodel(x).pow(2).sum())
                final_loss = fmodel(batches[-1]).pow(2).sum()
                grads = torch.autograd.grad(final_loss,
                                            fmodel.parameters(time=0),
                                            allow_unused=True)
                verdicts = []
                for g, fg in zip(grads, fd_grads):
                    if g is None:
                        # trusting that the tensor shouldn't have been used...
                        verdicts.append(True)
                        continue
                    self.assertFalse(torch.any(torch.isnan(g)),
                                     "NaNs found in gradient.")
                    verdicts.append(torch.allclose(g, fg, 1e-1, 1e-1))
                if all(verdicts):
                    n_passed += 1
        self.assertTrue(
            n_passed / n_trials >= pass_ratio,
            msg="Proportion of successful finite gradient checks below {:.0f}% "
            "threshold ({:.0f}%).".format(pass_ratio * 100,
                                          100 * n_passed / n_trials))
Exemplo n.º 25
0
def test_episodic_loader_inner_loop_per_task_good_accumulator(debug_test=True):
    """Smoke-test episodic MAML on mini-ImageNet with per-task gradient accumulation.

    For each meta-batch, every task is adapted on its support set inside
    ``higher.innerloop_ctx``. The query loss of the adapted model is
    back-propagated immediately, so only one task's graph is alive at a time.
    The accumulated gradients are then applied with a single outer Adam step.
    The mean query loss is printed per episode.

    ``debug_test`` is currently unused.
    """
    import torch
    import torch.optim as optim
    from automl.child_models.learner_from_opt_as_few_shot_paper import Learner
    import automl.child_models.learner_from_opt_as_few_shot_paper
    import higher

    ## get args for test
    args = get_args_for_mini_imagenet()

    ## get base model that meta-lstm/maml use
    base_model = Learner(image_size=args.image_size, bn_eps=args.bn_eps, bn_momentum=args.bn_momentum, n_classes=args.n_classes).to(args.device)

    ## get meta-set
    meta_train_loader, _, _ = get_meta_set_loaders_miniImagenet(args)

    ## start episodic training
    outer_opt = optim.Adam(base_model.parameters(), lr=1e-2)
    base_model.train()
    for episode, (spt_x, spt_y, qry_x, qry_y) in enumerate(meta_train_loader):
        assert spt_x.size(1) == args.k_shot * args.n_classes
        assert qry_x.size(1) == args.k_eval * args.n_classes
        ## Inner optimizer (for maml); higher makes its steps differentiable
        inner_opt = torch.optim.SGD(base_model.parameters(), lr=1e-1)
        nb_tasks = spt_x.size(0)
        assert nb_tasks == args.meta_batch_size_train
        meta_losses = []
        for t in range(nb_tasks):
            ## Get support & query set for the current task
            spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x[t], spt_y[t], qry_x[t], qry_y[t]
            with higher.innerloop_ctx(base_model, inner_opt,
                                      copy_initial_weights=args.copy_initial_weights,
                                      track_higher_grads=args.track_higher_grads) as (fmodel, diffopt):
                ## Inner loop adaptation
                for _ in range(args.nb_inner_train_steps):
                    fmodel.train()
                    # base/child model forward pass
                    spt_logits_t = fmodel(spt_x_t)
                    inner_loss = args.criterion(spt_logits_t, spt_y_t)
                    # inner-opt update
                    diffopt.step(inner_loss)
                ## Evaluate the adapted model on the query set for the current task
                qry_logits_t = fmodel(qry_x_t)
                qry_loss_t = args.criterion(qry_logits_t, qry_y_t)
                ## Accumulate gradients wrt meta-params for each task; backprop
                ## per task keeps only one task's graph in memory
                qry_loss_t.backward()
            ## Detached copy for logging (frees the history)
            meta_losses.append(qry_loss_t.detach())
        ## do outer step
        outer_opt.step()
        outer_opt.zero_grad()
        mean_meta_loss = (sum(meta_losses) / len(meta_losses)).item()
        print(f'[episode={episode}] meta_loss = {mean_meta_loss}')
Exemplo n.º 26
0
    def maml_train(self, model, inner_batch, outer_batches):
        """Run one MAML meta-training step and back-propagate its objective.

        ``model`` must be in training mode. It is wrapped by
        ``higher.innerloop_ctx`` with ``copy_initial_weights=False``, with
        cuDNN disabled for the duration (presumably because some cuDNN
        kernels lack double-backward support — TODO confirm). The wrapped
        model is adapted with ``self.inner_steps`` differentiable steps on
        ``inner_batch``, then evaluated on each batch in ``outer_batches``.

        The back-propagated objective is the last inner-step loss plus the
        mean outer loss. Gradients accumulate in ``model``'s parameters for
        the caller's optimizer.

        Returns:
            dict with ``"loss"``: the scalar value of that objective.
        """
        assert model.training
        result = {}
        with higher.innerloop_ctx(
                model, self.inner_opt,
                copy_initial_weights=False, device=self.device,
        ) as (fmodel, diffopt), torch.backends.cudnn.flags(enabled=False):
            # Inner loop: adapt the functional model on the inner batch.
            for _ in range(self.inner_steps):
                inner_loss = fmodel(inner_batch)["loss"]
                diffopt.step(inner_loss)
            logger.info(f"Inner loss: {inner_loss.item()}")

            # Outer loss: average over the outer batches with the adapted weights.
            mean_outer_loss = torch.Tensor([0.0]).to(self.device)
            with torch.set_grad_enabled(model.training):
                for outer_batch in outer_batches:
                    mean_outer_loss += fmodel(outer_batch)["loss"]
            mean_outer_loss.div_(len(outer_batches))
            logger.info(f"Outer loss: {mean_outer_loss.item()}")

            final_loss = inner_loss + mean_outer_loss
            final_loss.backward()

            # Drop the functional model early to release memory
            # (the original author was unsure whether this helps).
            del fmodel
            import gc

            gc.collect()

        result["loss"] = final_loss.item()
        return result
Exemplo n.º 27
0
def train(dataset: AudioNShot, net: TorchModule, device, meta_optimizer: Optimizer, epoch_num, log):
    """Meta-train ``net`` with MAML for one epoch on ``dataset``.

    For every meta-batch, each task is adapted with ``INNER_ITERATIONS``
    differentiable SGD steps (lr 0.1) on its support set. The query loss of
    the adapted model is back-propagated into ``net``
    (``copy_initial_weights=False``), and ``meta_optimizer`` steps once per
    meta-batch. Progress is printed every 4 batches, and one record per batch
    is appended to ``log`` with plain-float loss/accuracy values.

    Args:
        device: unused in this function.
        epoch_num: epoch index; batch progress is reported as a fractional epoch.
    """
    net.train()
    # NOTE(review): this is the first dimension of x_train (the original
    # comment called it the batch size). If that dimension counts samples
    # rather than meta-batches, it should probably be divided by the batch
    # size — confirm.
    iterations = dataset.x_train.shape[0]

    for batch_index in range(iterations):
        start_time = time.time()

        # Get support and query sets
        x_support, y_support, x_query, y_query = dataset.next()
        task_num, _set_size, _channels, _sample_size = x_support.size()
        query_size = x_query.size(1)

        # Inner optimizer; higher turns its steps into differentiable updates.
        inner_optimizer = SGD(net.parameters(), lr=1e-1)

        query_losses = []
        query_accuracies = []
        meta_optimizer.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_optimizer, copy_initial_weights=False) as (fnet, diffopt):
                for _ in range(INNER_ITERATIONS):
                    support_outputs = fnet(x_support[i])
                    support_loss = F.cross_entropy(support_outputs, y_support[i])
                    diffopt.step(support_loss)

                query_outputs = fnet(x_query[i])
                query_loss = F.cross_entropy(query_outputs, y_query[i])
                query_losses.append(query_loss.detach())
                query_accuracy = (query_outputs.argmax(dim=1) == y_query[i]).sum().item() / query_size
                query_accuracies.append(query_accuracy)

                query_loss.backward()

        meta_optimizer.step()
        # Plain floats, so the log does not keep (possibly GPU) tensors alive.
        mean_query_loss = (sum(query_losses) / task_num).item()
        mean_query_accuracy = 100. * sum(query_accuracies) / task_num
        progress = epoch_num + float(batch_index) / iterations
        iteration_time = time.time() - start_time
        if batch_index % 4 == 0:
            print(f'[Epoch {progress:.2f}] Train Loss: {mean_query_loss:.2f} | Acc: {mean_query_accuracy:.2f} | Time: {iteration_time:.2f}')

        log.append({
            'epoch': progress,
            'loss': mean_query_loss,
            'acc': mean_query_accuracy,
            'mode': 'train',
            'time': time.time(),
        })
Exemplo n.º 28
0
    def train_on_batch(self, tasks_batch):
        """Run one meta-update over ``tasks_batch`` and return averaged statistics.

        Each task is a ``(sprt, qry)`` pair of iterables of ``(x, y)`` batches
        whose ``.dataset`` length gives the number of samples. For each task:
        - the learner is adapted on ``sprt`` for ``self.inner_steps``
          differentiable steps;
        - the adapted learner is evaluated on ``qry``;
        - the query loss is back-propagated through the adaptation.

        ``self.meta_opt`` steps once for the whole batch.

        The losses are assumed to use ``reduction='sum'``. Each is therefore
        normalised by the number of samples it was summed over: the query
        set size for the meta loss, and ``inner_steps * support set size``
        for the inner loss.

        Returns:
            ``(avg_inner_loss, avg_meta_loss, avg_accuracy)`` as floats.
            ``avg_accuracy`` is NaN when accuracy is not computed.
        """
        # sprt should never intersect with qry! So only shuffle the task at
        # creation.
        inner_losses, meta_losses, accuracies = [], [], []
        self.meta_opt.zero_grad()
        for task in tasks_batch:
            with higher.innerloop_ctx(
                    self.learner, self.inner_opt, copy_initial_weights=False
                    ) as (f_learner, diff_opt):
                meta_loss, inner_loss, task_accuracies = 0, 0, []
                sprt, qry = task
                f_learner.train()
                for _ in range(self.inner_steps):
                    step_loss = 0
                    for x, y in sprt:
                        # sprt is an iterator returning batches
                        y_pred = f_learner(x)
                        step_loss += self.inner_loss(y_pred, y)
                    inner_loss += step_loss.detach()
                    diff_opt.step(step_loss)

                f_learner.eval()
                for x, y in qry:
                    y_pred = f_learner(x)  # Use the updated model for that task
                    # Accumulate the loss over the query set of this task.
                    meta_loss += self.meta_loss(y_pred, y)
                    if self._compute_accuracy:
                        _, indices = y_pred.max(dim=1)
                        # Cast before dividing so the batch accuracy is a float.
                        acc = (y == indices).float().sum() / y.size(0)
                        task_accuracies.append(acc)

                # Per-sample averages. meta_loss was summed over the query set;
                # inner_loss was summed over inner_steps passes of the support
                # set. Out-of-place division is used because detach() shares
                # storage with meta_loss, which is still back-propagated below.
                meta_losses.append(meta_loss.detach() / len(qry.dataset))
                inner_losses.append(
                    inner_loss / (self.inner_steps * len(sprt.dataset)))
                if self._compute_accuracy:
                    accuracies.append(torch.tensor(task_accuracies).mean())

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                meta_loss.backward()

        self.meta_opt.step()
        avg_inner_loss = torch.tensor(inner_losses).mean().item()
        avg_meta_loss = torch.tensor(meta_losses).mean().item()
        avg_accuracy = torch.tensor(accuracies).mean().item()
        return avg_inner_loss, avg_meta_loss, avg_accuracy
Exemplo n.º 29
0
    def testCtxManager(self, _, model_builder):
        """Gradients flow from the final fast weights back to the initial ones.

        Ten differentiable SGD steps are taken inside ``higher.innerloop_ctx``.
        The sum of all final parameters is then differentiated w.r.t. the
        parameters at ``time=0``, and every resulting gradient must exist.
        """
        model = model_builder(self)
        sgd = torch.optim.SGD(model.parameters(), lr=self.lr)

        with higher.innerloop_ctx(model, sgd) as (fmodel, diffopt):
            for _step in range(10):
                batch = torch.rand(8, 4)
                diffopt.step(fmodel(batch).pow(2).sum())
            total = sum(p.sum() for p in fmodel.parameters())
            initial_params = fmodel.parameters(time=0)
            final_grads = torch.autograd.grad(total, initial_params)

        for grad in final_grads:
            self.assertIsNotNone(grad)
Exemplo n.º 30
0
    def outer_loop(self, batch, is_train):
        """Run one MAML outer iteration over a batch of tasks.

        Gradients on ``self.network`` are cleared first. For each task, a
        functional copy of the network is adapted with ``self.args.n_inner``
        calls to ``self.inner_loop`` and scored on the task's test split.
        Higher-order gradients are tracked only when ``is_train`` is true.

        When training, the gradient of each task's test loss w.r.t. the
        initial parameters (``time=0``) is collected. The per-task gradients
        are averaged with uniform weights by ``mix_grad``, written into the
        network by ``apply_grad``, and applied with ``self.outer_optimizer``.

        Returns:
            ``(loss_log, acc_log, grad_log)`` when training, otherwise
            ``(loss_log, acc_log)``. Both logs are per-task values summed and
            divided by ``self.batch_size``.
        """
        self.network.zero_grad()

        train_inputs, train_targets, test_inputs, test_targets = self.unpack_batch(
            batch)

        total_loss = 0
        total_acc = 0
        per_task_grads = []
        per_task_losses = []

        for sup_x, sup_y, qry_x, qry_y in zip(train_inputs, train_targets,
                                              test_inputs, test_targets):
            with higher.innerloop_ctx(
                    self.network,
                    self.inner_optimizer,
                    track_higher_grads=is_train) as (fmodel, diffopt):
                for _ in range(self.args.n_inner):
                    self.inner_loop(fmodel, diffopt, sup_x, sup_y)

                qry_logits = fmodel(qry_x)
                qry_loss = F.cross_entropy(qry_logits, qry_y)
                total_loss += qry_loss.item() / self.batch_size

                with torch.no_grad():
                    total_acc += get_accuracy(
                        qry_logits, qry_y).item() / self.batch_size

                if is_train:
                    # Meta-gradient w.r.t. the pre-adaptation parameters.
                    per_task_grads.append(
                        torch.autograd.grad(qry_loss,
                                            fmodel.parameters(time=0)))
                    per_task_losses.append(qry_loss.item())

        if not is_train:
            return total_loss, total_acc

        weights = torch.ones(len(per_task_grads))
        weights = weights / torch.sum(weights)
        grad_log = apply_grad(self.network, mix_grad(per_task_grads, weights))
        self.outer_optimizer.step()
        return total_loss, total_acc, grad_log