Example #1
    def adapt(self,
              inputs,
              targets,
              is_classification_task=None,
              num_adaptation_steps=1,
              step_size=0.1,
              first_order=False):
        if is_classification_task is None:
            is_classification_task = (not targets.dtype.is_floating_point)
        params = None

        results = {
            'inner_losses': np.zeros((num_adaptation_steps, ),
                                     dtype=np.float32)
        }

        for step in range(num_adaptation_steps):
            logits = self.model(inputs, params=params)
            inner_loss = self.loss_function(logits, targets)
            results['inner_losses'][step] = inner_loss.item()

            if (step == num_adaptation_steps - 1) and is_classification_task:
                results['accuracy_before'] = compute_accuracy(logits, targets)
                #print(step, logits,targets,results['accuracy_before'])

            self.model.zero_grad()
            params = gradient_update_parameters(
                self.model,
                inner_loss,
                step_size=step_size,
                params=params,
                first_order=(not self.model.training) or first_order)

        return params, results
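The inner loop above delegates the actual update to torchmeta's `gradient_update_parameters`, which takes a functional SGD step on a copy of the parameters instead of mutating the model. A minimal, self-contained sketch of that update in plain PyTorch (the toy linear model and data below are hypothetical stand-ins):

    import torch
    import torch.nn.functional as F
    from collections import OrderedDict

    torch.manual_seed(0)
    params = OrderedDict(
        weight=torch.randn(5, 3, requires_grad=True),  # 3 features -> 5 classes
        bias=torch.zeros(5, requires_grad=True))
    inputs = torch.randn(4, 3)                         # one task's support set
    targets = torch.randint(0, 5, (4,))
    step_size = 0.1

    for step in range(3):                              # num_adaptation_steps
        logits = F.linear(inputs, params['weight'], params['bias'])
        inner_loss = F.cross_entropy(logits, targets)
        # create_graph=True keeps this step differentiable, so the outer loop
        # can backpropagate through it (second-order MAML); a first-order
        # variant would pass create_graph=False.
        grads = torch.autograd.grad(inner_loss, list(params.values()),
                                    create_graph=True)
        params = OrderedDict((name, p - step_size * g)
                             for (name, p), g in zip(params.items(), grads))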
Example #2
    def get_outer_loss(self, batch):
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point)
        results = {
            'num_tasks': num_tasks,
            'inner_losses': np.zeros((self.num_adaptation_steps, num_tasks),
                                     dtype=np.float32),
            'outer_losses': np.zeros((num_tasks,), dtype=np.float32),
            'mean_outer_loss': 0.
        }
        if is_classification_task:
            results.update({
                'accuracies_before': np.zeros((num_tasks,), dtype=np.float32),
                'accuracies_after': np.zeros((num_tasks,), dtype=np.float32)
            })

        mean_outer_loss = torch.tensor(0., device=self.device)

        mean_outer_losses = []

        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, adaptation_results = self.adapt(
                train_inputs,
                train_targets,
                is_classification_task=is_classification_task,
                num_adaptation_steps=self.num_adaptation_steps,
                step_size=self.step_size,
                first_order=self.first_order)

            results['inner_losses'][:, task_id] = adaptation_results[
                'inner_losses']
            if is_classification_task:
                results['accuracies_before'][task_id] = adaptation_results[
                    'accuracy_before']

            with torch.set_grad_enabled(self.model.training):
                test_logits = self.model(test_inputs, params=params)
                outer_loss = self.loss_function(test_logits, test_targets)
                results['outer_losses'][task_id] = outer_loss.item()
                mean_outer_loss += outer_loss
                mean_outer_losses.append(outer_loss)

            if is_classification_task:
                results['accuracies_after'][task_id] = compute_accuracy(
                    test_logits, test_targets)

        mean_outer_loss.div_(num_tasks)
        results['mean_outer_loss'] = mean_outer_loss.item()

        return mean_outer_loss, results, mean_outer_losses
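The scalar `mean_outer_loss` returned above is still attached to the autograd graph, so a single backward pass reaches the shared initialisation through every task's inner-loop updates. A self-contained toy illustration of that accumulation pattern (`w` stands in for the meta-parameters):

    import torch

    w = torch.tensor(1.0, requires_grad=True)
    mean_outer_loss = torch.tensor(0.)
    per_task_losses = []
    for x in [2.0, 3.0]:              # two toy "tasks"
        task_loss = (w * x) ** 2      # stands in for one task's query loss
        mean_outer_loss = mean_outer_loss + task_loss
        per_task_losses.append(task_loss)
    mean_outer_loss.div_(2)
    mean_outer_loss.backward()        # gradients flow to w from both tasks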
Example #3
    def get_outer_losses(self, task_weighting: TaskWeightingBase,
                         batch) -> Tuple[torch.Tensor, dict]:
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point)
        results = {
            'num_tasks': num_tasks,
            'inner_losses': np.zeros((self.num_adaptation_steps,
                                      num_tasks), dtype=np.float32),
            'outer_losses': np.zeros((num_tasks,), dtype=np.float32),
            'mean_outer_loss': 0.
        }
        if is_classification_task:
            results.update({
                'accuracies_before': np.zeros((num_tasks,), dtype=np.float32),
                'accuracies_after': np.zeros((num_tasks,), dtype=np.float32)
            })

        outer_losses = torch.zeros(num_tasks, device=self.device)
        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, adaptation_results = self.adapt(train_inputs, train_targets,
                                                    is_classification_task=is_classification_task,
                                                    num_adaptation_steps=self.num_adaptation_steps,
                                                    step_size=self.step_size, first_order=self.first_order)

            results['inner_losses'][:, task_id] = adaptation_results['inner_losses']
            if is_classification_task:
                results['accuracies_before'][task_id] = adaptation_results['accuracy_before']

            with torch.set_grad_enabled(self.model.training):
                test_logits = self.model(test_inputs, params=params)
                outer_losses_for_each_image = \
                    self.loss_function(test_logits, test_targets, reduction='none')
                outer_loss = task_weighting.compute_weighted_losses_for_each_image(task_id, outer_losses_for_each_image)
                results['outer_losses'][task_id] = outer_loss.item()
                outer_losses[task_id] = outer_loss

            if is_classification_task:
                results['accuracies_after'][task_id] = compute_accuracy(
                    test_logits, test_targets)

        results['mean_outer_loss'] = outer_losses.mean().item()

        return outer_losses, results
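`TaskWeightingBase` comes from the surrounding project and is not defined here. A minimal stand-in consistent with the call site (per-image losses in, one scalar per task out) could look like this uniform baseline:

    import torch

    class UniformTaskWeighting:
        """Hypothetical baseline weighting: plain average of per-image losses."""

        def compute_weighted_losses_for_each_image(self, task_id, losses):
            # `losses` holds one unreduced loss per query image, because the
            # example calls self.loss_function(..., reduction='none').
            return losses.mean()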
Example #4
    def adapt(self, inputs, targets, is_classification_task=None,
              num_adaptation_steps=1, step_size=0.1, first_order=False,
              write_params=False, reset_params=False):
        if is_classification_task is None:
            is_classification_task = (not targets.dtype.is_floating_point)
        params = OrderedDict(self.model.meta_named_parameters())

        for key in params.keys():
            if isinstance(params[key], BatchParameter) and params[key].expanding:
                params[key] = params[key].expanded(inputs.size(0))

        results = {}

        state = None
        for step in range(num_adaptation_steps):
            if self.warp_model is not None:
                self.warp_model.set_listening(True)

            logits = self.model(inputs, params=params)
            inner_loss = self.loss_function(logits, targets, reduction="none")

            inner_loss = inner_loss.sum(dim=0)

            if (step == 0) and is_classification_task:
                results['accuracy_before'] = compute_accuracy(logits, targets)

            params = gradient_update_parameters_warp(self.model, inner_loss,
                warp_model=self.warp_model, step_size=step_size, params=params,
                first_order=(not self.model.training) or first_order, state=state)

            if write_params:
                old_params = OrderedDict(self.model.meta_named_parameters())
                with torch.no_grad():
                    for key in old_params.keys():
                        # old_params[key].copy_(params[key])
                        old_params[key].detach_().requires_grad_(True)

            if reset_params:
                for module in self.model.modules():
                    if isinstance(module, BatchLinear):
                        module.reset_parameters()

            if self.warp_model is not None:
                self.warp_model.set_listening(False)

        return params, results
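`BatchParameter`, `BatchLinear`, `gradient_update_parameters_warp`, and the warp model are project-specific and not shown. For orientation only, a hypothetical `BatchParameter` consistent with the `expanding`/`expanded(...)` usage above might be:

    import torch.nn as nn

    class BatchParameter(nn.Parameter):
        """Hypothetical sketch of the interface assumed by the example."""
        expanding = True

        def expanded(self, batch_size):
            # Tile the parameter along a new leading batch dimension so each
            # batch element adapts its own copy; gradients still flow back
            # to the single shared tensor underneath.
            return self.unsqueeze(0).expand(batch_size, *self.shape)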
Example #5
    def get_outer_loss(self, batch, is_classification_task=False):
        """
        Performs one full training iteration on a batch of tasks.

        For each task in the batch:
        - Adapt the model parameters on the task's training inputs and targets
        - Evaluate the test loss on a batch of test inputs and targets
        - Update the mean test loss across tasks

        Parameters
        ----------
        batch : dict
            A dict mapping the keys 'train' and 'test' to their respective
            batches of Tasks. Each Task contains inputs and targets.

        Returns
        -------
        torch.Tensor
            The average test loss across tasks in the batch (a scalar tensor)

        dict
            A dict with relevant training statistics ('outer_losses',
            'accuracies_before', 'accuracies_after') as numpy arrays
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point) or is_classification_task
        results = {
            'num_tasks': num_tasks,
            'outer_losses': np.zeros((num_tasks,), dtype=np.float32),
            'mean_outer_loss': 0.
        }
        if is_classification_task:
            results.update({
                'accuracies_before': np.zeros((num_tasks,), dtype=np.float32),
                'accuracies_after': np.zeros((num_tasks,), dtype=np.float32)
            })

        mean_outer_loss = torch.tensor(0., device=self.device)
        failed_adaptation_Xs = torch.empty(0, device=self.device)
        failed_adaptation_ys = torch.empty(0, device=self.device, dtype=torch.long)
        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, _ = self.adapt(train_inputs, train_targets)

            if self.mixup_alpha > 0:
                test_inputs, test_targets, test_targets_b, lam = self.mixup_data(test_inputs, test_targets)
                # The corresponding interpolated outputs will be given by the line below
                # test_targets = lam * test_targets + (1. - lam) * test_targets_b
            
            with torch.set_grad_enabled(self.model.training):
                    
                test_logits = self.model(test_inputs, params=params)
                if is_classification_task and test_targets.ndim == 2 \
                        and test_targets.shape[-1] == 2:
                    # If `test_targets` has 2 columns for a classification task, then the second
                    # column will be treated as the weights for each point when computing the loss.
                    test_targets, weights = test_targets[:, 0].long(), test_targets[:, 1]
                    if self.mixup_alpha > 0:
                        outer_loss = self.mixup_criterion(self.loss_function, test_logits, test_targets, test_targets_b, lam)
                    else:
                        outer_loss = self.loss_function(test_logits, test_targets, reduction='none')
                        outer_loss = torch.mean(outer_loss * weights)
                    # print("Sum of weights:", torch.sum(weights))
                    # print("query point:", test_inputs[-1], test_targets[-1])
                    # print("Query point NML prob:", F.softmax(test_logits, -1)[-1,test_targets[-1]])
                elif len(test_targets.shape) == 1:
                    outer_loss = self.loss_function(test_logits, test_targets)
                else:
                    raise Exception(f"Invalid target shape: {test_targets.shape}. "
                        + "Must have either 1 or 2 columns.")
                results['outer_losses'][task_id] = outer_loss.item()
                mean_outer_loss += outer_loss

            if is_classification_task:
                results['accuracies_after'][task_id] = compute_accuracy(
                    test_logits, test_targets)

                with torch.no_grad():
                    logits = self.model(train_inputs, params=params)
                    incorrect = torch.argmax(logits, dim=-1) != train_targets
                    failed_adaptation_Xs = torch.cat([failed_adaptation_Xs, train_inputs[incorrect]], axis=0)
                    failed_adaptation_ys = torch.cat([failed_adaptation_ys, train_targets[incorrect]])

        mean_outer_loss.div_(num_tasks)
        results['mean_outer_loss'] = mean_outer_loss.item()
        if is_classification_task:
            results['failed_adaptations'] = (failed_adaptation_Xs.cpu().numpy(), failed_adaptation_ys.cpu().numpy())

        return mean_outer_loss, results
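`mixup_data` and `mixup_criterion` are not shown in the example, but their call sites match the standard mixup formulation of Zhang et al. (2018). An assumed implementation in that style:

    import numpy as np
    import torch

    def mixup_data(x, y, alpha=1.0):
        # Convex combination of each example with a randomly permuted partner.
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
        index = torch.randperm(x.size(0), device=x.device)
        mixed_x = lam * x + (1.0 - lam) * x[index]
        return mixed_x, y, y[index], lam

    def mixup_criterion(criterion, pred, y_a, y_b, lam):
        # Loss against both label sets, weighted by the mixing coefficient.
        return lam * criterion(pred, y_a) + (1.0 - lam) * criterion(pred, y_b)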
Example #6
    def get_outer_loss(self, batch, eval_mode=False, repeats=32, write_params=False):
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (not test_targets.dtype.is_floating_point)
        results = {
            'num_tasks': num_tasks,
            #'inner_losses': np.zeros((self.num_adaptation_steps,
            #    num_tasks), dtype=np.float32),
            'outer_losses': np.zeros((num_tasks,), dtype=np.float32),
            'mean_outer_loss': 0.
        }
        if is_classification_task:
            results.update({
                'accuracies_before': np.zeros((num_tasks,), dtype=np.float32),
                'accuracies_after': np.zeros((num_tasks,), dtype=np.float32)
            })

        mean_outer_loss = torch.tensor(0., device=self.device)

        train_inputs = batch["train"][0]
        train_targets = batch["train"][1]
        test_inputs = batch["test"][0]
        test_targets = batch["test"][1]

        """
        if self.ensembler is not None and not self.model.training:
            train_inputs = train_inputs.repeat_interleave(self.ensemble_size, dim=0)
            train_targets = train_targets.repeat_interleave(self.ensemble_size, dim=0)
            test_inputs = test_inputs.repeat_interleave(self.ensemble_size, dim=0)
        """

        if self.ensembler is not None and not self.model.training:
            repeats = self.ensemble_size
            chunk = train_inputs.size(0) // repeats
            for i in range(repeats):
                sl = slice(i * chunk, (i + 1) * chunk)
                train_input_exp = train_inputs[sl].detach().repeat_interleave(repeats, dim=0)
                train_target_exp = train_targets[sl].repeat_interleave(repeats, dim=0)
                test_input_exp = test_inputs[sl].detach().repeat_interleave(repeats, dim=0)
                train_input_exp.requires_grad_(True)
                test_input_exp.requires_grad_(True)

                params, adaptation_results = self.adapt(
                    train_input_exp, train_target_exp,
                    is_classification_task=is_classification_task,
                    num_adaptation_steps=self.num_adaptation_steps,
                    step_size=self.step_size, first_order=self.first_order)

                with torch.set_grad_enabled(False):
                    test_logits = self.model(test_input_exp, params=params)
                    test_logits = test_logits.view(-1, repeats, *test_logits.size()[1:])
                    test_logits = test_logits.sum(dim=1)
                    outer_loss = self.loss_function(test_logits, test_targets[sl])
                    mean_outer_loss += outer_loss

                if is_classification_task:
                    results['accuracies_after'][sl] = compute_accuracy(
                        test_logits, test_targets[sl])

                torch.cuda.empty_cache()

            mean_outer_loss.div_(repeats)
            results['mean_outer_loss'] = mean_outer_loss.item()

        else:
            params, adaptation_results = self.adapt(train_inputs, train_targets,
                is_classification_task=is_classification_task,
                num_adaptation_steps=self.num_adaptation_steps,
                step_size=self.step_size, first_order=self.first_order, write_params=write_params)

            #results['inner_losses'][:] = adaptation_results['inner_losses']
            if is_classification_task:
                results['accuracies_before'] = adaptation_results['accuracy_before']

            with torch.set_grad_enabled(self.model.training):
                if self.ensembler is not None and not self.model.training:
                    logits_chunks = []
                    for i in range(test_inputs.size(0) // self.ensemble_size):
                        logits_chunks.append(self.model(
                            test_inputs[i * self.ensemble_size:(i + 1) * self.ensemble_size],
                            params=params))

                    temp = torch.cat(logits_chunks, dim=0)
                    temp = temp.view(-1, self.ensemble_size, *temp.size()[1:])
                    temp = temp.transpose(0, 1).reshape(self.ensemble_size, -1, temp.size(-1))
                    temp = self.ensembler(temp)
                    test_logits = temp.sum(dim=0)
                else:
                    test_logits = self.model(test_inputs, params=params)
                outer_loss = self.loss_function(test_logits, test_targets)
                mean_outer_loss += outer_loss

            if is_classification_task:
                results['accuracies_after'] = compute_accuracy(
                    test_logits, test_targets)

            results['mean_outer_loss'] = mean_outer_loss.item()

        return mean_outer_loss, results
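The evaluation branch pools logits from `ensemble_size` replicated copies of each input with a view/sum pattern. A small self-contained check of that reshape with toy sizes:

    import torch

    ensemble_size, groups, num_classes = 4, 2, 5
    logits = torch.randn(ensemble_size * groups, num_classes)
    pooled = logits.view(-1, ensemble_size, num_classes).sum(dim=1)
    assert pooled.shape == (groups, num_classes)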
Example #7
    def adapt(self,
              inputs,
              targets,
              is_classification_task=None,
              num_adaptation_steps=1,
              step_size=0.1,
              first_order=False,
              start_params=None,
              **kwargs):
        """
        Performs `num_adaptation_steps` gradient steps on the given training
        inputs and targets. Does not modify `self.model`.

        Returns
        -------
        OrderedDict
            The model params after taking gradient steps.
            (self.model is not modified)

        dict
            Relevant training statistics ('inner_losses' and 'accuracy_before') as numpy arrays
        """
        if is_classification_task is None:
            is_classification_task = (not targets.dtype.is_floating_point)
        params = start_params

        results = {
            'inner_losses': np.zeros((num_adaptation_steps, ),
                                     dtype=np.float32)
        }

        for step in range(num_adaptation_steps):
            logits = self.model(inputs, params=params)
            if is_classification_task and targets.ndim == 2 \
                    and targets.shape[-1] == 2:
                # If `targets` has 2 columns for a classification task, then the second
                # column will be treated as the weights for each point when computing the loss.
                targets, weights = targets[:, 0].long(), targets[:, 1]
                inner_loss = self.loss_function(logits,
                                                targets,
                                                reduction='none')
                inner_loss = torch.sum(
                    inner_loss * weights) / torch.sum(weights)
            elif len(targets.shape) == 1:
                inner_loss = self.loss_function(logits, targets)
            else:
                raise ValueError(
                    f"Invalid target shape: {targets.shape}. " +
                    "Must have either 1 or 2 columns.")
            results['inner_losses'][step] = inner_loss.item()

            if (step == 0) and is_classification_task:
                results['accuracy_before'] = compute_accuracy(logits, targets)

            self.model.zero_grad()
            params = gradient_update_parameters(
                self.model,
                inner_loss,
                step_size=step_size,
                params=params,
                num_layers=self.num_finetuning_layers,
                first_order=(not self.model.training) or first_order)

        return params, results
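The two-column branch computes a weighted average of per-point losses, normalised by the total weight. A self-contained numeric check of that computation with toy logits and weighted targets:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)
    targets = torch.tensor([[0., 0.5], [2., 1.0], [1., 0.25], [0., 0.25]])
    labels, weights = targets[:, 0].long(), targets[:, 1]
    per_point = F.cross_entropy(logits, labels, reduction='none')
    inner_loss = (per_point * weights).sum() / weights.sum()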
Example #8
    def get_outer_loss(self, batch, is_classification_task=False):
        """
        Performs one full MAML training iteration (inner and outer loop)
        on a batch of tasks.

        For each task in the batch:
        - Do `self.num_adaptation_steps` gradient steps on a batch of training inputs and targets
        - Evaluate the test loss on a batch of test inputs and targets
        - Update the mean test loss across tasks

        Parameters
        ----------
        batch : dict
            A dict mapping the keys 'train' and 'test' to their respective
            batches of Tasks. Each Task contains inputs and targets.

        Returns
        -------
        torch.Tensor
            The average test loss across tasks in the batch (a scalar tensor)

        dict
            A dict with relevant training statistics ('inner_losses', 'outer_losses', 
            'accuracies_before', 'accuracies_after') as numpy arrays
        """
        if 'test' not in batch:
            raise RuntimeError('The batch does not contain any test dataset.')

        _, test_targets = batch['test']
        num_tasks = test_targets.size(0)
        is_classification_task = (
            not test_targets.dtype.is_floating_point) or is_classification_task
        results = {
            'num_tasks': num_tasks,
            'inner_losses': np.zeros((self.num_adaptation_steps, num_tasks),
                                     dtype=np.float32),
            'outer_losses': np.zeros((num_tasks,), dtype=np.float32),
            'mean_outer_loss': 0.
        }
        if is_classification_task:
            results.update({
                'accuracies_before': np.zeros((num_tasks,), dtype=np.float32),
                'accuracies_after': np.zeros((num_tasks,), dtype=np.float32)
            })

        mean_outer_loss = torch.tensor(0., device=self.device)
        mean_query_acc = torch.tensor(0., device=self.device)
        mean_l2_penalty = torch.tensor(0., device=self.device)
        # failed_adaptation_Xs = torch.empty(0, device=self.device)
        # failed_adaptation_ys = torch.empty(0, device=self.device, dtype=torch.long)
        for task_id, (train_inputs, train_targets, test_inputs, test_targets) \
                in enumerate(zip(*batch['train'], *batch['test'])):
            params, adaptation_results = self.adapt(
                train_inputs,
                train_targets,
                is_classification_task=is_classification_task,
                num_adaptation_steps=self.num_adaptation_steps,
                step_size=self.step_size,
                first_order=self.first_order)

            results['inner_losses'][:, task_id] = adaptation_results[
                'inner_losses']
            if is_classification_task:
                results['accuracies_before'][task_id] = adaptation_results[
                    'accuracy_before']

            with torch.set_grad_enabled(self.model.training):
                test_logits = self.model(test_inputs, params=params)
                if is_classification_task and test_targets.ndim == 2 \
                        and test_targets.shape[-1] == 2:
                    # If `test_targets` has 2 columns for a classification task, then the second
                    # column will be treated as the weights for each point when computing the loss.
                    test_targets, weights = test_targets[:, 0].long(), test_targets[:, 1]
                    outer_loss = self.loss_function(test_logits,
                                                    test_targets,
                                                    reduction='none')
                    outer_loss = torch.sum(
                        outer_loss * weights) / torch.sum(weights)
                    # print("Sum of weights:", torch.sum(weights))
                    # print("query point:", test_inputs[-1], test_targets[-1])
                    # print("Query point NML prob:", F.softmax(test_logits, -1)[-1,test_targets[-1]])
                elif len(test_targets.shape) == 1:
                    outer_loss = self.loss_function(test_logits, test_targets)
                else:
                    raise ValueError(
                        f"Invalid target shape: {test_targets.shape}. " +
                        "Must have either 1 or 2 columns.")

                results['outer_losses'][task_id] = outer_loss.item()
                mean_outer_loss += outer_loss

            if is_classification_task:
                results['accuracies_after'][task_id] = compute_accuracy(
                    test_logits, test_targets)

                # Store which training points we're incorrect on after adaptation
                """
                with torch.no_grad():
                    logits = self.model(train_inputs, params=params)
                    target_labels = train_targets[:,0].long() if train_targets.shape[-1] == 2 else train_targets
                    incorrect = torch.argmax(logits, dim=-1) != target_labels
                    failed_adaptation_Xs = torch.cat([failed_adaptation_Xs, train_inputs[incorrect]], axis=0)
                    failed_adaptation_ys = torch.cat([failed_adaptation_ys, target_labels[incorrect]])

                    mean_query_acc += 1 - int(incorrect[0])
                """

        mean_outer_loss.div_(num_tasks)
        mean_query_acc.div_(num_tasks)
        mean_l2_penalty.div_(num_tasks)

        if self.weight_decay_lambda:
            # Pre-adaptation weight decay
            l2_penalty = self._get_l2_penalty(
                self.model.meta_named_parameters(), self.weight_decay_lambda)
            mean_outer_loss += l2_penalty

        results['mean_outer_loss'] = mean_outer_loss.item()
        results['mean_query_acc'] = mean_query_acc.item()
        results['mean_l2_penalty'] = mean_l2_penalty.item()
        # if is_classification_task:
        #     results['failed_adaptations'] = (failed_adaptation_Xs.cpu().numpy(), failed_adaptation_ys.cpu().numpy())

        return mean_outer_loss, results
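`_get_l2_penalty` is not shown; a plausible method consistent with its call site (an iterator of named meta-parameters plus a scalar coefficient) would be:

    def _get_l2_penalty(self, named_parameters, weight_decay_lambda):
        # Sum of squared parameter values, scaled by the decay coefficient;
        # the result stays on the autograd graph so the penalty also trains
        # the initialisation.
        penalty = sum(param.pow(2).sum() for _, param in named_parameters)
        return weight_decay_lambda * penalty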