Example #1
def evaluate_fn_vampire(
        x: torch.Tensor, f_hyper_net: higher.patch._MonkeyPatchBase,
        f_base_net: higher.patch._MonkeyPatchBase
) -> typing.List[torch.Tensor]:
    """
    """
    logits = [None] * config['num_models']
    for model_id in range(config['num_models']):
        base_net_params = f_hyper_net.forward()
        logits_temp = f_base_net.forward(x, params=base_net_params)

        logits[model_id] = logits_temp

    return logits
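The function above returns one logit tensor per Monte Carlo-sampled model (config['num_models'] of them). Below is a minimal sketch of how such a list could be collapsed into a single ensemble prediction, assuming the common choice of averaging softmax probabilities; the helper name and the toy tensors are illustrative and not part of the excerpt.

import typing

import torch


def ensemble_predict(logits_list: typing.List[torch.Tensor]) -> torch.Tensor:
    """Average per-model softmax probabilities, then take the arg-max class."""
    probs = torch.stack([torch.softmax(lg, dim=1) for lg in logits_list],
                        dim=0)
    return probs.mean(dim=0).argmax(dim=1)


# toy check: 3 sampled models, 5 query points, 4 classes
fake_logits = [torch.randn(5, 4) for _ in range(3)]
print(ensemble_predict(fake_logits))  # tensor of 5 predicted class indices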
Example #2
def evaluate_fn_maml(
        x: torch.Tensor, f_hyper_net: higher.patch._MonkeyPatchBase,
        f_base_net: higher.patch._MonkeyPatchBase) -> torch.Tensor:
    """A base function to evaluate on the validation subset

    Args:
        x: the data in the validation subset
        f_hyper_net: the adapted meta-parameter
        f_base_net: the functional form of the base net

    Return: the logits of the prediction
    """
    base_net_params = f_hyper_net.forward()
    logits = f_base_net.forward(x, params=base_net_params)

    return logits
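A short usage sketch of evaluate_fn_maml follows. The hyper-net here is a toy identity-style module whose forward call simply returns the base net's parameters as a list (mirroring the identity hyper-nets this kind of functional code is usually paired with); the IdentityHyperNet class, the Linear base net and the tensor shapes are assumptions made only for illustration.

import higher
import torch


class IdentityHyperNet(torch.nn.Module):
    """Toy hyper-net: stores the base net's parameters, returns them as a list."""

    def __init__(self, base_net: torch.nn.Module) -> None:
        super().__init__()
        self.params = torch.nn.ParameterList(
            [torch.nn.Parameter(p.detach().clone())
             for p in base_net.parameters()])

    def forward(self) -> list:
        return [p for p in self.params]


base_net = torch.nn.Linear(in_features=8, out_features=3)
hyper_net = IdentityHyperNet(base_net=base_net)

# functional (monkey-patched) forms, as in the snippets above
f_hyper_net = higher.patch.monkeypatch(module=hyper_net,
                                       copy_initial_weights=False)
f_base_net = higher.patch.monkeypatch(module=base_net,
                                      copy_initial_weights=False)

x = torch.randn(4, 8)
logits = evaluate_fn_maml(x=x, f_hyper_net=f_hyper_net, f_base_net=f_base_net)
print(logits.shape)  # torch.Size([4, 3])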
Example #3
    def predict(
        self, x: torch.Tensor, f_hyper_net: higher.patch._MonkeyPatchBase,
        f_base_net: higher.patch._MonkeyPatchBase
    ) -> typing.List[torch.Tensor]:
        """Make prediction

        Args:
            x: input data
            f_hyper_net: task-specific meta-model
            f_base_net: functional form of the base neural network

        Returns: a list of logits predicted by the task-specific meta-model
        """
        logits = [None] * self.config['num_models']
        for model_id in range(self.config['num_models']):
            base_net_params = f_hyper_net.forward()
            logits_temp = f_base_net.forward(x, params=base_net_params)

            logits[model_id] = logits_temp

        return logits
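Because predict keeps the per-model logits separate, the spread of the ensemble can also serve as a rough uncertainty signal. A small illustrative sketch, using the entropy of the model-averaged class distribution as the measure; the helper below is an assumption and not part of the class above.

import typing

import torch


def predictive_entropy(logits_list: typing.List[torch.Tensor]) -> torch.Tensor:
    """Entropy of the model-averaged class distribution, one value per input."""
    probs = torch.stack([torch.softmax(lg, dim=1) for lg in logits_list],
                        dim=0).mean(dim=0)
    return -(probs * probs.clamp_min(1e-12).log()).sum(dim=1)


# toy check: 3 sampled models, 5 query points, 4 classes
print(predictive_entropy([torch.randn(5, 4) for _ in range(3)]))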
Example #4
def adapt_to_episode_innerloop(
        x: torch.Tensor, y: torch.Tensor, hyper_net: torch.nn.Module,
        f_base_net: higher.patch._MonkeyPatchBase,
        kl_div_fn: typing.Callable) -> higher.patch._MonkeyPatchBase:
    """Also known as inner loop

    Args:
      x, y: training data and label
      hyper_net: the model of interest
      f_base_net: the functional of the base model
      kl_divergence_loss: function calculating
        the KL divergence between variational posterior and prior
    
    Return: a MonkeyPatch module
    """
    f_hyper_net = higher.patch.monkeypatch(
        module=hyper_net,
        copy_initial_weights=False,
        track_higher_grads=config['train_flag'])

    hyper_net_params = [p for p in hyper_net.parameters()]

    for _ in range(config['num_inner_updates']):
        q_params = f_hyper_net.fast_params  # list of parameters/tensors

        # accumulate gradients over the Monte Carlo-sampled models
        grad_accum = [0] * len(q_params)

        for _ in range(config['num_models']):
            base_net_params = f_hyper_net.forward()
            y_logits = f_base_net.forward(x, params=base_net_params)
            cls_loss = torch.nn.functional.cross_entropy(input=y_logits,
                                                         target=y)

            # KL divergence between variational posterior and prior
            kl_div = kl_div_fn(p=hyper_net_params, q=q_params)

            loss = cls_loss + kl_div * config['KL_weight']

            if config['first_order']:
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    retain_graph=config['train_flag'])
            else:
                all_grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=q_params,
                    create_graph=config['train_flag'])

            # accumulate gradients
            for i in range(len(all_grads)):
                grad_accum[i] = (grad_accum[i] +
                                 all_grads[i] / config['num_models'])

        new_q_params = []
        for param, grad in zip(q_params, grad_accum):
            new_q_params.append(
                higher.optim._add(tensor=param,
                                  a1=-config['inner_lr'],
                                  a2=grad))

        f_hyper_net.update_params(new_q_params)

    return f_hyper_net
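The snippet above leaves kl_div_fn unspecified. Below is a minimal sketch of what such a function could look like, under the assumption that each parameter list stores the means of a diagonal Gaussian followed by its log-standard-deviations; both the layout and the KL(q || p) direction are assumptions rather than something taken from the excerpt.

import typing

import torch


def kl_div_fn(p: typing.List[torch.Tensor],
              q: typing.List[torch.Tensor]) -> torch.Tensor:
    """KL(q || p) between diagonal Gaussians whose parameters are given as
    flat lists laid out as [means..., log_stds...] (assumed layout)."""
    n = len(p) // 2
    kl = 0.
    for mu_p, mu_q, log_std_p, log_std_q in zip(p[:n], q[:n], p[n:], q[n:]):
        var_p = torch.exp(2 * log_std_p)
        var_q = torch.exp(2 * log_std_q)
        kl = kl + 0.5 * torch.sum(
            (var_q + (mu_q - mu_p) ** 2) / var_p - 1
            + 2 * (log_std_p - log_std_q))
    return kl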
Example #5
    def adapt_to_episode(
            self,
            x: torch.Tensor,
            y: torch.Tensor,
            hyper_net: torch.nn.Module,
            f_base_net: higher.patch._MonkeyPatchBase,
            train_flag: bool = True) -> higher.patch._MonkeyPatchBase:
        """Inner-loop for MAML-like algorithm

        Args:
            x, y: training data and corresponding labels
            hyper_net: the meta-model
            f_base_net: the functional form of the based neural network
            kl_div_fn: function that calculates the KL divergence

        Returns: the task-specific meta-model
        """
        # convert hyper_net to its functional form
        f_hyper_net = higher.patch.monkeypatch(module=hyper_net,
                                               copy_initial_weights=False,
                                               track_higher_grads=train_flag)

        hyper_net_params = [p for p in hyper_net.parameters()]

        for _ in range(self.config['num_inner_updates']):
            # accumulate gradients of Monte Carlo sampling
            grads_accum = [0] * len(hyper_net_params)

            q_params = f_hyper_net.fast_params  # parameters of the task-specific hyper_net

            # KL divergence
            KL_div = self.KL_divergence(p=hyper_net_params, q=q_params)

            for _ in range(self.config['num_models']):
                base_net_params = f_hyper_net.forward()
                y_logits = f_base_net.forward(x, params=base_net_params)
                cls_loss = torch.nn.functional.cross_entropy(input=y_logits,
                                                             target=y)

                loss = cls_loss + self.config['KL_weight'] * KL_div

                if self.config['first_order']:
                    grads = torch.autograd.grad(outputs=loss,
                                                inputs=q_params,
                                                retain_graph=True)
                else:
                    grads = torch.autograd.grad(outputs=loss,
                                                inputs=q_params,
                                                create_graph=True)

                # accumulate gradients
                for i in range(len(grads)):
                    grads_accum[i] = (grads_accum[i] +
                                      grads[i] / self.config['num_models'])

            new_q_params = []
            for param, grad in zip(q_params, grads_accum):
                new_q_params.append(
                    higher.optim._add(tensor=param,
                                      a1=-self.config['inner_lr'],
                                      a2=grad))

            f_hyper_net.update_params(new_q_params)

        return f_hyper_net
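Finally, a hedged sketch of how pieces like these are typically wired into the outer (meta) update: adapt on an episode's support set, score the query set, and backpropagate through the adaptation into the meta-parameters. Here episode_loader, base_net, hyper_net and the Adam settings are assumptions; adapt_to_episode_innerloop, evaluate_fn_vampire and kl_div_fn refer to the module-level code above, with config assumed to be defined in that module.

import higher
import torch

meta_optimizer = torch.optim.Adam(params=hyper_net.parameters(), lr=1e-3)

for x_support, y_support, x_query, y_query in episode_loader:  # assumed loader
    # functional form of the base net for this episode
    f_base_net = higher.patch.monkeypatch(module=base_net,
                                          copy_initial_weights=False)

    # inner loop: adapt the meta-parameters on the support set
    f_hyper_net = adapt_to_episode_innerloop(x=x_support, y=y_support,
                                             hyper_net=hyper_net,
                                             f_base_net=f_base_net,
                                             kl_div_fn=kl_div_fn)

    # outer loss on the query set, averaged over the Monte Carlo-sampled models
    logits = evaluate_fn_vampire(x=x_query, f_hyper_net=f_hyper_net,
                                 f_base_net=f_base_net)
    outer_loss = sum(torch.nn.functional.cross_entropy(input=lg, target=y_query)
                     for lg in logits) / len(logits)

    meta_optimizer.zero_grad()
    outer_loss.backward()
    meta_optimizer.step()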