    def inner_loop_adapt(self,
                         task,
                         num_updates=None,
                         analysis=False,
                         iter=None):
        # adapt means doing the complete inner loop update
        measurements_trajectory = defaultdict(list)
        if analysis:
            # records the gradient norm at every inner loop step
            grad_norm_by_step = []
            grad_quantiles_by_step = defaultdict(list)

        # only the final classifier layer is adapted in this inner loop
        adapted_param_dict = OrderedDict()
        adapted_param_dict['classifier.fully_connected.weight'] = \
            self._model.classifier.fully_connected.weight
        adapted_param_dict['classifier.fully_connected.bias'] = \
            self._model.classifier.fully_connected.bias

        modulation = self._embedding_model(task, return_task_embedding=False)

        if num_updates is None:
            # if not specified, apply the inner loop update self._num_updates times
            num_updates = self._num_updates

        for i in range(num_updates):
            # the model here is just a functional template; all parameters are
            # passed in explicitly through adapted_param_dict and modulation
            adapted_param_dict, measurements, grad_list = \
                self.inner_loop_one_step_gradient_descent(task=task,
                                                          adapted_param_dict=adapted_param_dict,
                                                          modulation=modulation,
                                                          return_grad_list=analysis)
            # add this step's measurement to its trajectory
            for key in measurements.keys():
                measurements_trajectory[key].append(measurements[key])

            if analysis:
                grad_norm_by_step.append(get_grad_norm(grad_list))
                grad_quantiles_by_step[i + 1].extend(
                    get_grad_quantiles(grad_list))

        # compute the train loss after the last adaptation step
        with torch.no_grad():
            preds = self._model(task.x,
                                modulation=modulation,
                                update_params=adapted_param_dict)
            loss = self._inner_loss_func(preds, task.y)
            measurements_trajectory['loss'].append(loss.item())
            if self.is_classification:
                measurements_trajectory['accu'].append(accuracy(preds, task.y))

        info_dict = None
        if analysis:
            info_dict = {}
            info_dict['grad_norm_by_step'] = grad_norm_by_step
            info_dict['grad_quantiles_by_step'] = grad_quantiles_by_step
            info_dict['modulation'] = modulation
        return adapted_param_dict, measurements_trajectory, info_dict
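
    # A minimal sketch (as comments; the outer-loop names below are hypothetical)
    # of how this method is typically consumed in a meta-training step, cf. the
    # run() method in the later example:
    #
    #   adapted_params, train_traj, _ = algo.inner_loop_adapt(train_task)
    #   test_preds = algo.predict_without_adapt(train_task, test_task.x,
    #                                           param_dict=adapted_params)
    #   outer_loss = outer_loss_func(test_preds, test_task.y)
    #   outer_loss.backward()  # the meta-gradient flows back through the inner loop
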
    def inner_loop_adapt(self,
                         task,
                         num_updates=None,
                         analysis=False,
                         iter=None):
        # adapt means doing the complete inner loop update
        measurements_trajectory = defaultdict(list)
        if analysis:
            # records the gradient norm at every inner loop step
            grad_norm_by_step = []
            grad_quantiles_by_step = defaultdict(list)

        adapted_param_dict = self._model.param_dict  # parameters to be updated in the inner loop
        if isinstance(self._embedding_model, LSTMAttentionEmbeddingModel):
            layer_modulations = self._embedding_model(
                task, return_task_embedding=False, iter=iter)
        else:
            layer_modulations = self._embedding_model(
                task, return_task_embedding=False)
        # if not specified, apply the inner loop update self._num_updates times
        if num_updates is None:
            num_updates = self._num_updates

        for i in range(num_updates):
            # the model here is just a functional template; all parameters are
            # passed in explicitly through param_dict and layer_modulations
            adapted_param_dict, measurements, grad_list = \
                self.inner_loop_one_step_gradient_descent(
                    task=task,
                    layer_modulations=layer_modulations,
                    param_dict=adapted_param_dict,
                    return_grad_list=analysis)
            # add this step's measurement to its trajectory
            for key in measurements.keys():
                measurements_trajectory[key].append(measurements[key])

            if analysis:
                grad_norm_by_step.append(get_grad_norm(grad_list))
                grad_quantiles_by_step[i + 1].extend(
                    get_grad_quantiles(grad_list))

        # compute the train loss after the last adaptation step
        with torch.no_grad():
            preds = self._model(task.x,
                                params=adapted_param_dict,
                                layer_modulations=layer_modulations)
            loss = self._inner_loss_func(preds, task.y)
            measurements_trajectory['loss'].append(loss.item())
            if self.is_classification:
                measurements_trajectory['accu'].append(accuracy(preds, task.y))

        info_dict = None
        if analysis:
            info_dict = {}
            info_dict['grad_norm_by_step'] = grad_norm_by_step
            info_dict['grad_quantiles_by_step'] = grad_quantiles_by_step
            info_dict['layer_modulations'] = layer_modulations
        return adapted_param_dict, measurements_trajectory, info_dict
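
    # Note on the two variants above: this one adapts all fast weights in
    # self._model.param_dict and feeds them to the functional forward pass as
    # params=, together with per-layer modulations (layer_modulations=); the
    # variant further above adapts only the final classifier layer and passes it
    # as update_params= with a single modulation. The LSTMAttentionEmbeddingModel
    # additionally receives the outer-loop iteration index (iter=) when computing
    # its modulations.
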
    def inner_loop_one_step_gradient_descent(self,
                                             task,
                                             adapted_param_dict,
                                             modulation,
                                             return_grad_list=False):
        """Apply one step of gradient descent on self._inner_loss_func,
        based on data in the single task from argument task
        with respect to parameters in param_dict
        with step-size `self._fast_lr`, and returns
            the updated parameters
            loss before adaptation
            gradient if return_grad_list=True
        """
        preds = self._model(task.x,
                            modulation=modulation,
                            update_params=adapted_param_dict)
        loss = self._inner_loss_func(preds, task.y)

        measurements = {}
        measurements['loss'] = loss.item()
        if self.is_classification:
            measurements['accu'] = accuracy(preds, task.y)

        create_graph = not self._first_order
        grad_list = torch.autograd.grad(loss,
                                        adapted_param_dict.values(),
                                        create_graph=create_graph,
                                        allow_unused=False)
        # allow_unused=False: raise an error if any entry of adapted_param_dict was
        # not used in computing the loss (its gradient would otherwise be None).

        clip_grad = (self._inner_loop_grad_clip > 0)
        if clip_grad:
            clip_grad_list = []
        for (name, param), grad in zip(adapted_param_dict.items(), grad_list):
            # grad will be torch.Tensor
            assert grad is not None
            if clip_grad:
                grad = soft_clip(grad,
                                 clip_value=self._inner_loop_grad_clip,
                                 slope=self._inner_loop_soft_clip_slope)
                clip_grad_list.append(grad)
            adapted_param_dict[name] = param - self._fast_lr * grad

        if return_grad_list:
            if clip_grad:
                grad_list = clip_grad_list
        else:
            grad_list = None
        return adapted_param_dict, measurements, grad_list
Example 4
    def _update_measurements(self, task, loss, preds):
        self._count_iters += 1.0
        self._cum_loss += loss.data.cpu().numpy()
        if self._collect_accuracies:
            self._cum_accuracy += accuracy(preds, task.y).data.cpu().numpy()
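
    # A minimal sketch (a hypothetical helper, not part of the original class) of
    # how these running sums would typically be turned into averages for logging:
    #
    #   def _report_measurements(self):
    #       measurements = {'loss': self._cum_loss / self._count_iters}
    #       if self._collect_accuracies:
    #           measurements['accu'] = self._cum_accuracy / self._count_iters
    #       return measurements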
Example 5
    def run(self, dataset_iterator, is_training=False, start=1, stop=1):
        # looping through the entire meta_dataset once
        sum_train_measurements_trajectory_over_meta_set = defaultdict(float)
        sum_test_measurements_before_adapt_over_meta_set = defaultdict(float)
        sum_test_measurements_after_adapt_over_meta_set = defaultdict(float)
        n_tasks = 0

        for i, (train_task_batch, test_task_batch) in tqdm(
                enumerate(dataset_iterator,
                          start=start if is_training else 1)):

            if is_training and i == stop:
                return {
                    'train_loss_trajectory':
                    divide_measurements(
                        sum_train_measurements_trajectory_over_meta_set,
                        n=n_tasks),
                    'test_loss_before':
                    divide_measurements(
                        sum_test_measurements_before_adapt_over_meta_set,
                        n=n_tasks),
                    'test_loss_after':
                    divide_measurements(
                        sum_test_measurements_after_adapt_over_meta_set,
                        n=n_tasks)
                }

            # each iteration yields a batch of train tasks and the matching batch of test tasks
            train_measurements_trajectory_over_batch = defaultdict(list)
            test_measurements_before_adapt_over_batch = defaultdict(list)
            test_measurements_after_adapt_over_batch = defaultdict(list)
            analysis = (i % self._log_interval == 0 or i == 1)
            modulation_analysis = hasattr(self._algorithm, '_embedding_model') and \
                                       isinstance(self._algorithm._embedding_model,
                                                  LSTMAttentionEmbeddingModel)

            if analysis and is_training:
                grad_norm_by_step_over_batch = []
                grad_quantiles_by_step_over_batch = defaultdict(list)
                if modulation_analysis:
                    task_modulation_params_over_batch = []

            batch_size = len(train_task_batch)
            sum_test_loss_after_adapt = 0.0
            for train_task, test_task in zip(train_task_batch,
                                             test_task_batch):
                # evaluate the test loss before adapting on the train task
                with torch.no_grad():
                    test_pred_before_adapt = self._algorithm.predict_without_adapt(
                        train_task, test_task.x)
                    test_loss_before_adapt = self._outer_loss_func(
                        test_pred_before_adapt, test_task.y)
                    test_measurements_before_adapt_over_batch['loss'].append(
                        test_loss_before_adapt.item())
                    if self._algorithm.is_classification:
                        test_measurements_before_adapt_over_batch[
                            'accu'].append(
                                accuracy(test_pred_before_adapt, test_task.y))

                # adapt to train_task in the inner loop
                adapted_param_dict, train_measurements_trajectory, info_dict = \
                        self._algorithm.inner_loop_adapt(train_task, analysis=analysis and is_training, iter=i)

                for key, measurements in train_measurements_trajectory.items():
                    train_measurements_trajectory_over_batch[key].append(
                        measurements)

                if analysis and is_training:
                    grad_norm_by_step = info_dict['grad_norm_by_step']
                    grad_quantiles_by_step = info_dict[
                        'grad_quantiles_by_step']
                    grad_norm_by_step_over_batch.append(grad_norm_by_step)
                    for step, quantiles in grad_quantiles_by_step.items():
                        grad_quantiles_by_step_over_batch[step].append(
                            quantiles)
                    if modulation_analysis:
                        task_modulation_params = info_dict['layer_modulations']
                        task_modulation_params_over_batch.append(
                            task_modulation_params)

                test_pred_after_adapt = self._algorithm.predict_without_adapt(
                    train_task, test_task.x, param_dict=adapted_param_dict)
                test_loss_after_adapt = self._outer_loss_func(
                    test_pred_after_adapt, test_task.y)
                sum_test_loss_after_adapt += test_loss_after_adapt

                test_measurements_after_adapt_over_batch['loss'].append(
                    test_loss_after_adapt.item())
                if self._algorithm.is_classification:
                    test_measurements_after_adapt_over_batch['accu'].append(
                        accuracy(test_pred_after_adapt, test_task.y))

            update_sum_measurements_trajectory(
                sum_train_measurements_trajectory_over_meta_set,
                train_measurements_trajectory_over_batch)
            update_sum_measurements(
                sum_test_measurements_before_adapt_over_meta_set,
                test_measurements_before_adapt_over_batch)
            update_sum_measurements(
                sum_test_measurements_after_adapt_over_meta_set,
                test_measurements_after_adapt_over_batch)
            n_tasks += batch_size
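
            # Sketch of the meta-update performed in the block below (the standard
            # MAML outer step, writing theta for the meta-parameters of
            # self._algorithm._model and beta for the outer learning rate):
            #   theta <- theta - beta * grad_theta (1/B) * sum_over_batch L_test(theta'_task)
            # where theta'_task comes from inner_loop_adapt and B = batch_size.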

            if is_training:
                avg_test_loss_after_adapt = sum_test_loss_after_adapt / batch_size
                # torch.mean(torch.stack(test_measurements_after_adapt_over_batch['loss'])) # make list a torch.tensor
                self._outer_optimizer.zero_grad()
                # backprop here flows all the way back to the initialization parameters
                avg_test_loss_after_adapt.backward()
                outer_grad_norm_before_clip = get_grad_norm_from_parameters(
                    self._algorithm._model.parameters())
                self._writer.add_scalar('outer_grad/model_norm/before_clip',
                                        outer_grad_norm_before_clip, i)
                if self._outer_loop_grad_norm > 0.:
                    clip_grad_norm_(self._algorithm._model.parameters(),
                                    self._outer_loop_grad_norm)
                    #clip_grad_norm_(self._algorithm._embedding_model.parameters(), self._outer_loop_grad_norm)
                self._outer_optimizer.step()

            # logging
            # (i % self._log_interval == 0 or i == 1)
            if analysis and is_training:
                self.log_output(i,
                                train_measurements_trajectory_over_batch,
                                test_measurements_before_adapt_over_batch,
                                test_measurements_after_adapt_over_batch,
                                write_tensorboard=is_training)

                if is_training:
                    self.write_gradient_info_to_board(
                        i, grad_norm_by_step_over_batch,
                        grad_quantiles_by_step_over_batch)
                    if modulation_analysis:
                        metadata = [
                            str(t.task_info['task_id'])
                            for t in train_task_batch
                        ]
                        self.write_embeddings_output_to_board(
                            task_modulation_params_over_batch, metadata, i)

            # Save model
            if (i % self._save_interval == 0 or i == 1) and is_training:
                save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i)
                save_path = os.path.join(self._save_folder, save_name)
                with open(save_path, 'wb') as f:
                    torch.save(self._algorithm.state_dict(), f)

        results = {
            'train_loss_trajectory':
            divide_measurements(
                sum_train_measurements_trajectory_over_meta_set, n=n_tasks),
            'test_loss_before':
            divide_measurements(
                sum_test_measurements_before_adapt_over_meta_set, n=n_tasks),
            'test_loss_after':
            divide_measurements(
                sum_test_measurements_after_adapt_over_meta_set, n=n_tasks)
        }

        if not is_training:
            self.log_output(start,
                            results['train_loss_trajectory'],
                            results['test_loss_before'],
                            results['test_loss_after'],
                            write_tensorboard=True,
                            meta_val=True)

        return results
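
    # A minimal sketch of how run() might be driven for meta-training with periodic
    # meta-validation; the trainer/iterator/step names below are hypothetical, not
    # part of this class:
    #
    #   step = 1
    #   for epoch in range(num_epochs):
    #       trainer.run(train_dataset_iterator, is_training=True,
    #                   start=step, stop=step + steps_per_epoch)
    #       step += steps_per_epoch
    #       trainer.run(val_dataset_iterator, is_training=False, start=step)
    #
    # Each call returns the averaged 'train_loss_trajectory', 'test_loss_before',
    # and 'test_loss_after' measurements over the tasks it consumed.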