Example #1
    def evaluate(self, epoch=None):
        saving_path = './results/' + self.model_str
        preds = []
        y_true = []
        self.model.eval()
        for batch_ix, item in enumerate(self.test_loader):
            labels = [each_item['label'] for each_item in item]
            out = self.model(item)
            labels = self._one_hot(labels, self.train_data.n_rel)
            preds.append(out.cpu().detach().numpy())
            y_true.append(labels.cpu().numpy())
            try:
                assert (labels.cpu().numpy().shape
                        == out.cpu().detach().numpy().shape)
            except AssertionError:
                pdb.set_trace()
        preds = np.concatenate(preds, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        precision, recall = precision_recall_compute_multi(y_true, preds)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_precision.npy'.format(self.timestamp, epoch)),
            precision)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_recall.npy'.format(self.timestamp, epoch)),
            recall)
        # self.train()
        self.model.train()

        return
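All four evaluators call a `self._one_hot` helper that is not shown in these snippets. As a point of reference, here is a minimal standalone sketch of what such a helper typically does, assuming `labels` is a list of integer relation indices (or 0-dim tensors) and `n_rel` is the number of relation classes; the real implementation in the source project may differ.

import torch
import torch.nn.functional as F

def one_hot(labels, n_rel):
    # Convert a list of integer class indices into an (N, n_rel) float
    # matrix with a single 1 per row.
    idx = torch.as_tensor([int(l) for l in labels], dtype=torch.long)
    return F.one_hot(idx, num_classes=n_rel).float()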
Example #2
    def evaluate_all(self, epoch=None):
        saving_path = './results/' + self.model_str
        preds = []
        y_true = []
        self.model.eval()
        for batch_ix, item in enumerate(self.test_loader):
            labels = [each_item['label'] for each_item in item]
            out = self.model(item)
            # out_bag = [o[0] for o in out]
            # labels = [label[-1] for label in labels]
            labels = torch.cat(labels, dim=0)
            out = torch.cat(out, dim=0)
            labels = self._one_hot(labels, self.train_data.n_rel)
            preds.append(out.cpu().detach().numpy())
            y_true.append(labels.cpu().numpy())
            try:
                assert (labels.cpu().numpy().shape
                        == out.cpu().detach().numpy().shape)
            except AssertionError:
                pdb.set_trace()
        preds = np.concatenate(preds, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        precision, recall = precision_recall_compute_multi(y_true, preds)
        f1 = compute_max_f1(precision, recall)
        if not os.path.exists(saving_path):
            os.mkdir(saving_path)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_all_precision.npy'.format(self.timestamp, epoch)),
            precision)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_all_recall.npy'.format(self.timestamp, epoch)),
            recall)
        # self.train()
        self.model.train()

        return f1
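`precision_recall_compute_multi` and `compute_max_f1` are project-specific helpers that do not appear in these snippets. A rough sketch of the ranking-based computation they presumably stand for, assuming `y_true` is a multi-hot matrix, `preds` holds the matching scores, and column 0 is the NA relation (these assumptions mirror the usual distant-supervision convention, not the project's actual code):

import numpy as np

def precision_recall_curve_multi(y_true, scores):
    # Rank every non-NA (instance, relation) score from high to low and
    # accumulate precision/recall along that ranking.
    y = y_true[:, 1:].reshape(-1)
    s = scores[:, 1:].reshape(-1)
    order = np.argsort(-s)
    hits = y[order]
    tp = np.cumsum(hits)
    precision = tp / np.arange(1, len(hits) + 1)
    recall = tp / max(y.sum(), 1)
    return precision, recall

def max_f1(precision, recall):
    # Best F1 over all cut-off points of the ranked list.
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    return float(f1.max())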
Example #3
    def evaluate_bag(self, epoch=None):
        saving_path = './results/' + self.model_str
        preds = []
        y_true = []
        self.model.eval()
        for batch_ix, item in enumerate(self.test_loader):
            # labels : list of tensor shape []
            labels = [each_item['label'] for each_item in item]
            # out : list of tensor shape (bag_size, n_rel)
            out = self.model(item)
            out_bag = torch.stack([o[0] for o in out])
            labels = [label[-1] for label in labels]
            labels = self._one_hot(labels, self.train_data.n_rel)
            preds.append(out_bag.cpu().detach().numpy())
            y_true.append(labels.cpu().numpy())
            try:
                assert (labels.cpu().numpy().shape
                        == out_bag.cpu().detach().numpy().shape)
            except AssertionError:
                pdb.set_trace()
        preds = np.concatenate(preds, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        precision, recall = precision_recall_compute_multi(y_true, preds)
        f1 = compute_max_f1(precision, recall)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_bag_precision.npy'.format(self.timestamp, epoch)),
            precision)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_bag_recall.npy'.format(self.timestamp, epoch)),
            recall)
        self.model.train()

        return f1
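In `evaluate_bag`, each batch element is treated as a bag: the model is assumed to return one `(bag_size, n_rel)` score matrix per bag, the first row `o[0]` is kept as the bag-level score, and the bag's last label stands in for the bag label. A toy illustration of those shapes (the bag sizes and `n_rel` below are made up):

import torch

n_rel = 5
# Two bags with 3 and 2 sentences; one (bag_size, n_rel) score matrix each.
out = [torch.randn(3, n_rel), torch.randn(2, n_rel)]
labels = [torch.tensor([2, 2, 2]), torch.tensor([4, 4])]

out_bag = torch.stack([o[0] for o in out])    # (n_bags, n_rel) bag-level scores
bag_labels = [label[-1] for label in labels]  # one label per bag
print(out_bag.shape, bag_labels)              # torch.Size([2, 5]) [tensor(2), tensor(4)]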
Example #4
    def evaluate_binary(self, epoch=None, max_f1=0):
        '''
        Models trained with a binary loss use a sigmoid output layer, so the
        raw logits are sufficient to rank predictions.

        :param epoch: epoch index, used when naming the saved result files.
        :return: the maximum F1 computed from the precision-recall curve.
        '''
        assert self.model_str in self.binary_loss_models
        saving_path = './results/' + self.model_str
        self.model.eval()
        correct_num = 0
        correct_num_with_rel = 0
        tot = 0
        tot_with_rel = 0
        # pdb.set_trace()
        preds = []
        y_true = []
        for batch_ix, item in enumerate(self.test_loader):
            labels = [each_item['label'] for each_item in item]
            out = self.model(item)
            labels = multi_hot_label(labels, self.train_data.n_rel)
            preds.append(out.cpu().detach().numpy())
            y_true.append(labels.cpu().numpy())
            # remove NA
            labels = labels[:, 1:]
            out = out[:, 1:].view(-1, self.train_data.n_rel - 1)
            # cast predictions
            pred = (out > 0.5).float()
            out_shape = out.shape
            if labels.shape != pred.shape:
                pdb.set_trace()
            correct_num += int(torch.sum((labels == pred)))
            correct_num_with_rel += int(torch.sum(labels * pred))
            tot += int(out_shape[0] * out_shape[1])
            tot_with_rel += int(torch.sum(labels))

        correct_num_without_rel = int(correct_num) - int(correct_num_with_rel)
        tot_without_rel = int(tot) - int(tot_with_rel)
        print('Total_with_rel:{}'.format(tot_with_rel))
        print('Total_without_rel:{}'.format(tot_without_rel))
        print('Correct_with_rel:{}'.format(correct_num_with_rel))
        print('Correct_without_rel:{}'.format(correct_num_without_rel))
        print("Precision_ALL : {}".format(float(correct_num) / tot))
        print("Precision_REL : {}".format(
            float(correct_num_with_rel) / tot_with_rel))
        preds = np.concatenate(preds, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        precision, recall = precision_recall_compute_multi(y_true, preds)
        f1 = compute_max_f1(precision, recall)
        if not os.path.exists(saving_path):
            os.mkdir(saving_path)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_all_precision.npy'.format(self.timestamp, epoch)),
            precision)
        np.save(
            os.path.join(
                saving_path,
                '{}_Epoch_{}_all_recall.npy'.format(self.timestamp, epoch)),
            recall)
        # no saving now.
        self.model.train()
        return f1
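`evaluate_binary` relies on a `multi_hot_label` helper rather than `_one_hot`, since under a binary (sigmoid) loss an entity pair can carry several relations at once. A minimal sketch of such a helper, assuming `labels` is a list of 1-D tensors of relation indices, one per instance; the project's own helper may differ in signature and details.

import torch

def multi_hot(labels, n_rel):
    # labels: list of 1-D LongTensors holding the relation indices of one
    # instance each; returns an (n_instances, n_rel) float matrix with a 1
    # at every relation that applies.
    out = torch.zeros(len(labels), n_rel)
    for i, rel_ids in enumerate(labels):
        out[i, rel_ids.long()] = 1.0
    return out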