Example #1
0
 def net_eval(self, target_set, ptracker):
     """Score a query (target) set against the current class prototypes.

     Args:
         target_set: ``(targets_x, targets_y)`` pair of query inputs and
             integer class labels.
         ptracker: performance tracker; receives the predicted labels, the
             true labels and the loss via ``add_task_performance``.

     Returns:
         The loss tensor. A zero tensor on ``self.device`` is returned when
         the target set is empty.
     """
     if len(target_set[0]) == 0:
         return torch.tensor(0.).to(self.device)

     targets_x, targets_y = target_set
     targets_h = self.backbone(targets_x)
     proto_h, proto_y = self.get_prototypes()

     relation_pairs = self.construct_pairs(proto_h, targets_h)

     # One score per (query, prototype) pair, reshaped so each row holds the
     # scores of one query against all ``n_way`` prototypes.
     n_way = len(proto_y)
     scores = self.relation_module(relation_pairs).view(-1, n_way)

     if self.loss_type == 'mse':
         # Pin the one-hot width to ``n_way`` so it always matches ``scores``,
         # even when the query set does not contain every class.
         # (``Variable`` is deprecated; plain tensors carry autograd state.)
         targets_y_onehot = uu.onehot(targets_y, dim=n_way).float().to(self.device)
         loss = self.loss_fn(scores, targets_y_onehot)
     else:
         loss = self.strategy.apply_outer_loss(self.loss_fn, scores, targets_y)

     # Predicted class = index of the highest-scoring prototype.
     pred_y = scores.argmax(1)

     ptracker.add_task_performance(
         pred_y.detach().cpu().numpy(),
         targets_y.detach().cpu().numpy(),
         loss.detach().cpu().numpy())

     return loss
    def net_eval(self, target_set, ptracker):
        """Evaluate the GP classifier on a query (target) set.

        For every output class the GP is conditioned on the cached support
        features ``self.support_h`` and that class's support targets from
        ``self.support_y_onehots`` (presumably set by ``net_train``). Each
        class yields a sigmoid-squashed predictive mean for the targets; the
        class with the largest value is the prediction.

        Args:
            target_set: ``(target_x, target_y)`` pair of query inputs and
                integer labels.
            ptracker: performance tracker; receives the predicted labels, the
                true labels and the loss via ``add_task_performance``.

        Returns:
            The summed negative marginal log-likelihood over all class
            outputs, computed under ``torch.no_grad``. A zero tensor on
            ``self.device`` is returned when the target set is empty.
        """
        if len(target_set[0]) == 0:
            return torch.tensor(0.).to(self.device)

        target_x, target_y = target_set

        with torch.no_grad():
            self.gpmodel.eval()
            self.likelihood.eval()
            self.backbone.eval()

            target_h = self.forward(target_x).detach()

            total_losses = []
            predictions_list = []
            for idx in range(self.output_dim[self.mode]):
                # NOTE(review): assumes ``support_y_onehots`` is indexable per
                # class (e.g. the column tuple from ``.split(1, 1)``) — confirm.
                self.gpmodel.set_train_data(
                    inputs=self.support_h,
                    targets=self.support_y_onehots[idx].squeeze(),
                    strict=False)
                prediction = self.likelihood(self.gpmodel(target_h))
                predictions_list.append(torch.sigmoid(prediction.mean))
                # The loss is the negative MLL on the conditioning data.
                output = self.gpmodel(*self.gpmodel.train_inputs)
                loss = -self.mll(output, self.gpmodel.train_targets)
                total_losses.append(loss)

            pred_y = torch.stack(predictions_list).argmax(0)
            loss = torch.stack(total_losses).sum(0)

            ptracker.add_task_performance(pred_y.detach().cpu().numpy(),
                                          target_y.detach().cpu().numpy(),
                                          loss.detach().cpu().numpy())

        # Return the loss on this path too, matching the empty-set early
        # return above (previously this path returned None).
        return loss
Example #3
0
 def net_train(self, support_set):  # innerloop / adaptation step
     """Adapt to a support set (inner loop / adaptation step).

     Embeds the support inputs with the backbone, lets the strategy update
     the features and labels, flattens the features per sample, then stores:

     * ``self.G`` -- the result of ``self.encode_training_set`` on the
       flattened support features;
     * ``self.support_y_onehot`` -- the float one-hot support labels, moved
       to ``self.device``.

     Args:
         support_set: support data; ``self.strategy.update_support_set`` must
             turn it into an ``(x, y)`` pair.
     """
     x, y = self.strategy.update_support_set(support_set)
     z = self.backbone.forward(x)
     z, y = self.strategy.update_support_features((z, y))
     # Flatten everything except the leading (sample) dimension.
     z = z.contiguous().view(len(z), -1)
     self.G = self.encode_training_set(z)
     # ``Variable`` is deprecated: tensors carry autograd state directly, and
     # one-hot labels never require gradients.
     self.support_y_onehot = utils.onehot(y).float().to(self.device)
 def _category_onehot(self, category_columns_for_onehot):
     """Add a one-hot encoded feature for each listed category column.

     For every column ``col``, ``col + "_onehot"`` is appended to
     ``self.onehot_col`` and ``self.dataset`` is re-mapped with
     ``utils.onehot(row, col, self.category_summary_dict[col])``.

     Args:
         category_columns_for_onehot: iterable of column names to encode.
     """
     def _make_encoder(column, summary):
         # Bind the column and its summary now. A lambda closing over the
         # loop variable would see only the *last* column if the mapping
         # function were ever invoked after the loop had advanced.
         return lambda row: utils.onehot(row, column, summary)

     for col in category_columns_for_onehot:
         self.onehot_col.append(col + "_onehot")
         self.dataset = self.dataset.map(
             _make_encoder(col, self.category_summary_dict[col]),
             num_parallel_calls=config.NUM_PARALLEL)
Example #5
0
 def net_train(self, support_set):
     """Cache support-set features and labels for later GP conditioning.

     The GP model and likelihood are switched to training mode and the
     backbone to evaluation mode. The support inputs are then embedded with
     ``self.forward`` (detached from the graph) and passed with their labels
     through the strategy's feature update. Afterwards these are stored:

     * ``self.support_y_onehots`` -- ``uu.onehot`` of the labels with
       ``fill_with=-1`` and ``dim=self.output_dim[self.mode]``;
     * ``self.support_h`` -- the (strategy-updated) support features.
     """
     self.gpmodel.train()
     self.likelihood.train()
     self.backbone.eval()

     sx, sy = self.strategy.update_support_set(support_set)
     features = self.forward(sx).detach()
     features, sy = self.strategy.update_support_features((features, sy))

     n_classes = self.output_dim[self.mode]
     self.support_y_onehots = uu.onehot(sy, fill_with=-1, dim=n_classes)
     self.support_h = features
Example #6
0
def get_batch(params, idxs, input_idxs, input_lens):
    """Gather a batch across categories and sort it by decreasing length.

    Args:
        params: config object; ``params.cldc_label_size`` is the one-hot width
            and ``params.cuda`` selects moving the tensors to the GPU.
        idxs: per-category sample indices; position ``i`` holds the indices
            for category/label ``i``. Categories with a falsy entry are skipped.
        input_idxs: per-category token-index arrays, indexed by ``idxs[i]``.
        input_lens: per-category length arrays, indexed by ``idxs[i]``.

    Returns:
        ``(x, lens, y, y_onehot)`` sorted longest-first. ``x`` and ``y`` are
        ``LongTensor``s, ``y_onehot`` is a float ``Tensor``, and ``lens`` stays
        a numpy array (never moved to the GPU). Returns four ``None``s when
        ``is_list_empty(idxs)`` is true.
    """
    if is_list_empty(idxs):
        return None, None, None, None

    xs, lens, labels, label_onehots = [], [], [], []
    for label, sample_ids in enumerate(idxs):
        if not sample_ids:
            continue
        count = len(sample_ids)
        xs.append(input_idxs[label][sample_ids])
        lens.append(input_lens[label][sample_ids])
        labels.append([label] * count)
        # Repeat the category's one-hot row once per selected sample.
        label_onehots.append(onehot(label, params.cldc_label_size).expand(count, -1))

    x_all = np.concatenate(xs)
    lens_all = np.concatenate(lens)
    y_all = np.concatenate(labels)
    yoh_all = np.concatenate(label_onehots)

    # Longest sequences first.
    order = np.argsort(-lens_all)
    out_lens = lens_all[order]
    out_x = torch.LongTensor(x_all[order])
    out_y = torch.LongTensor(y_all[order])
    out_yoh = torch.Tensor(yoh_all[order])

    if params.cuda:
        out_x, out_y, out_yoh = out_x.cuda(), out_y.cuda(), out_yoh.cuda()

    return out_x, out_lens, out_y, out_yoh
    def meta_train(self, task, ptracker):
        """Meta-train the backbone and GP on each (support, target) pair of a task.

        Both the support set and the query (target) set are fed to the model.
        For every ``(support_set, target_set)`` pair in ``task``:

        1. The support and target inputs/labels are concatenated (when the
           target set is non-empty) and embedded with ``self.forward``.
        2. For each output class, the GP is fit on all features against that
           class's one-hot column. The summed negative marginal
           log-likelihood is the loss.
        3. If targets exist, the GP is re-conditioned on the support rows
           only. Sigmoid predictive means for the target features are scored
           with ``self.loss_fn``, the loss is multiplied by that score, and
           performance is recorded in ``ptracker``.
        4. One optimizer step is taken.

        Args:
            task: iterable of ``(support_set, target_set)`` pairs, each an
                ``(x, y)`` pair.
            ptracker: performance tracker updated via ``add_task_performance``.
        """
        self.mode = 'train'
        self.train()
        self.net_reset()
        # NOTE(review): this list is re-created inside the loop below and never
        # read at this level within the visible code.
        total_losses = []

        for support_set, target_set in task:
            self.backbone.train()
            self.gpmodel.train()
            self.likelihood.train()

            support_set = self.strategy.update_support_set(support_set)
            support_x, support_y = support_set
            target_x, target_y = target_set
            # Number of leading rows that belong to the support set; used below
            # to slice support rows back out of the combined tensors.
            support_n = len(support_y)

            # Combine target and support set
            if len(target_x) > 0:
                all_x = torch.cat((support_x, target_x), dim=0)
                all_y = torch.cat((support_y, target_y), dim=0)
            else:
                all_x = support_x
                all_y = support_y

            all_h = self.forward(all_x)
            # NOTE(review): the [:support_n] slicing below assumes this call
            # keeps row order and the support rows first — TODO confirm.
            all_h, all_y = self.strategy.update_support_features(
                (all_h, all_y))
            # One one-hot column per class (split along dim 1), filled with -1.
            all_y_onehots = uu.onehot(all_y,
                                      fill_with=-1,
                                      dim=self.output_dim[self.mode]).split(
                                          1, 1)

            self.optimizer.zero_grad()

            # Per-class negative marginal log-likelihood on support + target.
            total_losses = []
            for idx in range(self.output_dim[self.mode]):
                self.gpmodel.set_train_data(
                    inputs=all_h,
                    targets=all_y_onehots[idx].squeeze(),
                    strict=False)
                output = self.gpmodel(*self.gpmodel.train_inputs)
                loss = -self.mll(output, self.gpmodel.train_targets)
                total_losses.append(loss)

            loss = torch.stack(total_losses).sum(0)

            if len(target_x) > 0:
                # NOTE(review): no_grad is disabled here, so the prediction term
                # below keeps its graph and contributes gradients.
                #                 with torch.no_grad():
                self.gpmodel.eval()
                self.likelihood.eval()
                self.backbone.eval()

                # Target features are detached: the prediction term does not
                # backpropagate through the backbone on the target inputs.
                target_h = self.forward(target_x).detach()

                # Condition on support rows only and predict the targets.
                predictions_list = list()
                for idx in range(self.output_dim[self.mode]):
                    self.gpmodel.set_train_data(
                        inputs=all_h[:support_n],
                        targets=all_y_onehots[idx].squeeze()[:support_n],
                        strict=False)
                    prediction = self.likelihood(self.gpmodel(target_h))
                    predictions_list.append(torch.sigmoid(prediction.mean))

                # Presumably (n_targets, n_classes) after the transpose — confirm
                # that prediction.mean is 1-D over targets.
                predictions_list = torch.stack(predictions_list).T

                # NOTE(review): the classification loss *multiplies* the GP loss
                # rather than being added to it — confirm this is intentional.
                loss *= self.loss_fn(predictions_list, target_y)

                pred_y = predictions_list.argmax(1)

                ptracker.add_task_performance(pred_y.detach().cpu().numpy(),
                                              target_y.detach().cpu().numpy(),
                                              loss.detach().cpu().numpy())

            loss.backward()
            self.optimizer.step()