Example #1
    def forward(self, episodes, updates=1, testing=False):
        support_losses = []
        query_losses, query_accuracies, query_precisions, query_recalls, query_f1s = [], [], [], [], []
        n_episodes = len(episodes)

        for episode_id, episode in enumerate(episodes):

            # Re-initialize the task-specific output layer for this episode's label set
            self.initialize_output_layer(episode.n_classes)

            # The support loader is expected to yield a single batch per episode
            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)

            if self.proto_maml:
                output_repr = self.learner(batch_x, batch_len)
                init_weights, init_bias = self._initialize_with_proto_weights(output_repr, batch_y, episode.n_classes)
            else:
                init_weights, init_bias = 0, 0

            # cuDNN RNNs do not support double backward, so cuDNN is disabled
            # whenever higher-order gradients must flow through an RNN learner
            with torch.backends.cudnn.flags(enabled=self.fomaml or testing or not isinstance(self.learner, RNNSequenceModel)), \
                 higher.innerloop_ctx(self.learner, self.learner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=(not self.fomaml and not testing)) as (flearner, diffopt):

                all_predictions, all_labels = [], []
                self.train()
                flearner.train()
                flearner.zero_grad()

                # Inner-loop adaptation on the support batch
                for i in range(updates):
                    output = flearner(batch_x, batch_len)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    batch_y = batch_y.view(-1)
                    loss = self.learner_loss[episode.base_task](output, batch_y)

                    # Update the output layer parameters with a manual SGD step
                    # (kept differentiable so meta-gradients can flow through it)
                    output_weight_grad, output_bias_grad = torch.autograd.grad(
                        loss, [self.output_layer_weight, self.output_layer_bias], retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad

                    # Update the shared parameters through the differentiable optimizer
                    diffopt.step(loss)

                # Evaluate support predictions from the final inner-loop step;
                # positions labelled -1 (padding) are excluded from the metrics
                relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                pred = make_prediction(output[relevant_indices].detach()).cpu()
                all_predictions.extend(pred)
                all_labels.extend(batch_y[relevant_indices].cpu())

                support_loss = loss.item()

                accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions,
                                                                                all_labels, binary=False)

                logger.info('Episode {}/{}, task {} [support_set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                            'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                        support_loss, accuracy, precision, recall, f1_score))

                query_loss = 0.0
                all_predictions, all_labels = [], []

                # Disable dropout
                for module in flearner.modules():
                    if isinstance(module, nn.Dropout):
                        module.eval()

                for n_batch, (batch_x, batch_len, batch_y) in enumerate(episode.query_loader):
                    batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)
                    output = flearner(batch_x, batch_len)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    batch_y = batch_y.view(-1)
                    loss = self.learner_loss[episode.base_task](output, batch_y)

                    if not testing:
                        if self.fomaml:
                            # First-order MAML: gradients w.r.t. the adapted (post-update) parameters
                            meta_grads = torch.autograd.grad(loss, [p for p in flearner.parameters() if p.requires_grad], retain_graph=self.proto_maml)
                        else:
                            # Full MAML: gradients w.r.t. the initial parameters (time=0),
                            # backpropagating through the inner-loop updates
                            meta_grads = torch.autograd.grad(loss, [p for p in flearner.parameters(time=0) if p.requires_grad], retain_graph=self.proto_maml)
                        if self.proto_maml:
                            # Add the contribution flowing through the prototype-based init
                            proto_grads = torch.autograd.grad(loss, [p for p in self.learner.parameters() if p.requires_grad])
                            meta_grads = [mg + pg for (mg, pg) in zip(meta_grads, proto_grads)]
                    query_loss += loss.item()

                    relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                    pred = make_prediction(output[relevant_indices].detach()).cpu()
                    all_predictions.extend(pred)
                    all_labels.extend(batch_y[relevant_indices].cpu())

                query_loss /= n_batch + 1

            accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions,
                                                                            all_labels, binary=False)

            logger.info('Episode {}/{}, task {} [query set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                        'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                    query_loss, accuracy, precision, recall, f1_score))
            support_losses.append(support_loss)
            query_losses.append(query_loss)
            query_accuracies.append(accuracy)
            query_precisions.append(precision)
            query_recalls.append(recall)
            query_f1s.append(f1_score)

            if not testing:
                # Accumulate meta-gradients (computed on the final query batch) into
                # the shared learner; they are averaged over episodes after the loop
                for param, meta_grad in zip([p for p in self.learner.parameters() if p.requires_grad], meta_grads):
                    if param.grad is not None:
                        param.grad += meta_grad.detach()
                    else:
                        param.grad = meta_grad.detach()

        # Average the accumulated gradients
        if not testing:
            for param in self.learner.parameters():
                if param.requires_grad:
                    param.grad /= len(query_accuracies)

        if testing:
            return support_losses, query_accuracies, query_precisions, query_recalls, query_f1s
        else:
            return query_losses, query_accuracies, query_precisions, query_recalls, query_f1s
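
The `_initialize_with_proto_weights` helper used above is not shown. Under the usual ProtoMAML construction, a linear output layer reproducing negative squared Euclidean distances to class prototypes has weights 2*c_k and bias -||c_k||^2. A minimal sketch under that assumption (the function name, tensor shapes, and the -1 ignore label are assumptions, not the repository's actual code):

import torch

def init_proto_weights(support_repr, support_labels, n_classes):
    # Hypothetical ProtoMAML-style output-layer initialization.
    # support_repr: (batch, seq_len, hidden); support_labels: (batch, seq_len)
    hidden = support_repr.size(-1)
    flat_repr = support_repr.reshape(-1, hidden)
    flat_labels = support_labels.reshape(-1)
    prototypes = torch.zeros(n_classes, hidden, device=support_repr.device)
    for c in range(n_classes):
        mask = flat_labels == c
        if mask.any():
            prototypes[c] = flat_repr[mask].mean(dim=0)
    # -||x - c_k||^2 = 2 x.c_k - ||c_k||^2 (up to an x-only term),
    # so the equivalent linear layer is:
    weights = 2 * prototypes
    bias = -(prototypes ** 2).sum(dim=1)
    return weights, bias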
Example #2
    def forward(self, episodes, updates=1, testing=False):
        query_losses, query_accuracies, query_precisions, query_recalls, query_f1s = [], [], [], [], []
        n_episodes = len(episodes)

        for episode_id, episode in enumerate(episodes):

            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)

            self.train()
            support_repr, support_label = [], []

            batch_x_repr = self.learner(batch_x, batch_len)
            support_repr.append(batch_x_repr)
            support_label.append(batch_y)

            prototypes = self._build_prototypes(support_repr, support_label, episode.n_classes)

            # Run on query
            query_loss = 0.0
            all_predictions, all_labels = [], []
            # Disable dropout so query representations are deterministic
            for module in self.learner.modules():
                if isinstance(module, nn.Dropout):
                    module.eval()

            for n_batch, (batch_x, batch_len, batch_y) in enumerate(episode.query_loader):
                batch_x, batch_len, batch_y = self.vectorize(batch_x, batch_len, batch_y)
                batch_x_repr = self.learner(batch_x, batch_len)
                output = self._normalized_distances(prototypes, batch_x_repr)
                output = output.view(output.size()[0] * output.size()[1], -1)
                batch_y = batch_y.view(-1)
                loss = self.loss_fn[episode.base_task](output, batch_y)
                query_loss += loss.item()

                if not testing:
                    self.optimizer.zero_grad()
                    # retain_graph: the prototype graph is reused across query batches
                    loss.backward(retain_graph=True)
                    self.optimizer.step()
                    self.lr_scheduler.step()

                relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
                all_predictions.extend(make_prediction(output[relevant_indices]).cpu())
                all_labels.extend(batch_y[relevant_indices].cpu())

            query_loss /= n_batch + 1

            # Calculate metrics
            accuracy, precision, recall, f1_score = utils.calculate_metrics(all_predictions,
                                                                            all_labels, binary=False)

            logger.info('Episode {}/{}, task {} [query set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                        'recall = {:.5f}, F1 score = {:.5f}'.format(episode_id + 1, n_episodes, episode.task_id,
                                                                    query_loss, accuracy, precision, recall, f1_score))

            query_losses.append(query_loss)
            query_accuracies.append(accuracy)
            query_precisions.append(precision)
            query_recalls.append(recall)
            query_f1s.append(f1_score)

        return query_losses, query_accuracies, query_precisions, query_recalls, query_f1s
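
`_build_prototypes` and `_normalized_distances` are defined elsewhere in the class. In the standard prototypical-network reading they reduce to class-mean embeddings and negative squared distances used as logits; a self-contained sketch under those assumptions (the bodies below are illustrative, not the actual implementation):

import torch

def build_prototypes(reprs, labels, n_classes):
    # Hypothetical: class prototype = mean support embedding per label
    data = torch.cat([r.reshape(-1, r.size(-1)) for r in reprs], dim=0)
    lbls = torch.cat([l.reshape(-1) for l in labels], dim=0)
    prototypes = torch.zeros(n_classes, data.size(-1), device=data.device)
    for c in range(n_classes):
        mask = lbls == c
        if mask.any():
            prototypes[c] = data[mask].mean(dim=0)
    return prototypes  # (n_classes, hidden)

def normalized_distances(prototypes, query_repr):
    # Hypothetical: negative squared Euclidean distance as class logits
    diff = query_repr.unsqueeze(-2) - prototypes   # (..., n_classes, hidden)
    return -diff.pow(2).sum(dim=-1)                # (..., n_classes)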
Example #3
    def forward(self, episodes, updates=1, testing=False):
        support_losses = []
        query_losses, query_accuracies, query_precisions, query_recalls, query_f1s = [], [], [], [], []
        n_episodes = len(episodes)

        for episode_id, episode in enumerate(episodes):
            self.initialize_output_layer(episode.n_classes)

            # Fresh per-episode optimizer over the shared learner and the newly
            # initialized output layer (plain fine-tuning baseline, no meta-updates)
            params = [p for p in self.parameters() if p.requires_grad] + \
                     [p for p in self.output_layer.parameters() if p.requires_grad]
            optimizer = optim.Adam(params,
                                   lr=self.learner_lr,
                                   weight_decay=self.weight_decay)

            batch_x, batch_len, batch_y = next(iter(episode.support_loader))
            batch_x, batch_len, batch_y = self.vectorize(
                batch_x, batch_len, batch_y)

            self.train()

            all_predictions, all_labels = [], []

            output = self.learner(batch_x, batch_len)
            output = self.output_layer(output)
            output = output.view(output.size()[0] * output.size()[1], -1)
            batch_y = batch_y.view(-1)
            loss = self.learner_loss[episode.base_task](output, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            relevant_indices = torch.nonzero(batch_y != -1).view(-1).detach()
            pred = make_prediction(output[relevant_indices].detach()).cpu()
            all_predictions.extend(pred)
            all_labels.extend(batch_y[relevant_indices].cpu())

            support_loss = loss.item()

            accuracy, precision, recall, f1_score = utils.calculate_metrics(
                all_predictions, all_labels, binary=False)

            logger.info(
                'Episode {}/{}, task {} [support_set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                'recall = {:.5f}, F1 score = {:.5f}'.format(
                    episode_id + 1, n_episodes, episode.task_id, support_loss,
                    accuracy, precision, recall, f1_score))

            query_loss = 0.0
            all_predictions, all_labels = [], []

            # At test time the query pass is pure evaluation
            if testing:
                self.eval()

            for n_batch, (batch_x, batch_len,
                          batch_y) in enumerate(episode.query_loader):
                batch_x, batch_len, batch_y = self.vectorize(
                    batch_x, batch_len, batch_y)
                output = self.learner(batch_x, batch_len)
                output = self.output_layer(output)
                output = output.view(output.size()[0] * output.size()[1], -1)
                batch_y = batch_y.view(-1)
                loss = self.learner_loss[episode.base_task](output, batch_y)

                if not testing:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                query_loss += loss.item()

                relevant_indices = torch.nonzero(
                    batch_y != -1).view(-1).detach()
                pred = make_prediction(output[relevant_indices].detach()).cpu()
                all_predictions.extend(pred)
                all_labels.extend(batch_y[relevant_indices].cpu())

            query_loss /= n_batch + 1

            accuracy, precision, recall, f1_score = utils.calculate_metrics(
                all_predictions, all_labels, binary=False)

            logger.info(
                'Episode {}/{}, task {} [query set]: Loss = {:.5f}, accuracy = {:.5f}, precision = {:.5f}, '
                'recall = {:.5f}, F1 score = {:.5f}'.format(
                    episode_id + 1, n_episodes, episode.task_id, query_loss,
                    accuracy, precision, recall, f1_score))

            support_losses.append(support_loss)
            query_losses.append(query_loss)
            query_accuracies.append(accuracy)
            query_precisions.append(precision)
            query_recalls.append(recall)
            query_f1s.append(f1_score)

        if testing:
            return support_losses, query_accuracies, query_precisions, query_recalls, query_f1s
        else:
            return query_losses, query_accuracies, query_precisions, query_recalls, query_f1s
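
All three variants share the same episode-driven signature. A hedged sketch of an outer loop that could drive the MAML variant of Example #1, relying on the fact that its forward pass leaves episode-averaged meta-gradients in the learner's .grad fields when testing=False; `model`, `sample_episodes`, `num_meta_steps` and `meta_batch_size` are illustrative assumptions supplied by the caller:

import torch

# Hypothetical outer loop, not part of the code above
meta_optimizer = torch.optim.Adam(model.learner.parameters(), lr=1e-3)
for meta_step in range(num_meta_steps):
    episodes = sample_episodes(meta_batch_size)
    meta_optimizer.zero_grad()
    # forward() accumulates averaged meta-gradients into .grad as a side effect
    query_losses, accuracies, precisions, recalls, f1s = model(episodes, updates=5)
    meta_optimizer.step()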