Example 1
    def test(dataset):
        # # load BERT and GAN
        # load_gan_model(D, G, config['gan_save_path'])
        # if args.fine_tune:
        #     load_model(E, path=config['bert_save_path'], model_name='bert')
        #
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)  # number of batches, not individual samples
        result = dict()

        # Loss function
        detection_loss = torch.nn.CrossEntropyLoss().to(device)

        model.eval()

        all_detection_preds = []
        all_detection_logit = []
        total_loss = 0

        for sample in tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_detection_logit.append(logit)
                all_detection_preds.append(torch.argmax(logit, 1))
                total_loss += detection_loss(logit, y.long())

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length] class indices
        all_binary_y = (all_y != 0).long()  # [length] label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length]
        # all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze())  # [length, 1]
        all_detection_logit = torch.cat(all_detection_logit, 0).cpu()

        # compute loss
        result['detection_loss'] = total_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_preds, all_binary_y)

        # y_score = all_detection_preds.squeeze().tolist()
        y_score = all_detection_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        test_logit = all_detection_logit.tolist()
        result['test_logit'] = test_logit

        result['eer'] = eer
        result['all_detection_preds'] = all_detection_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result
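
Both examples derive an equal error rate through `metrics.cal_eer`, whose implementation is not shown in these snippets. Below is a minimal sketch, assuming EER is read off the ROC curve at the point where the false positive and false negative rates cross; the helper name and the ROC-based approach are assumptions for illustration, not the repository's actual code.

    import numpy as np
    from sklearn.metrics import roc_curve

    def equal_error_rate(y_true, y_score):
        # Sweep thresholds via the ROC curve; EER is where FPR == FNR.
        fpr, tpr, _ = roc_curve(y_true, y_score)
        fnr = 1 - tpr
        idx = np.nanargmin(np.abs(fnr - fpr))  # index of the closest crossing point
        return (fpr[idx] + fnr[idx]) / 2
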
Example 2
    def test(dataset):
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)  # number of batches, not individual samples
        result = dict()

        # Loss function
        detection_loss = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()

        all_detection_preds = []
        all_class_preds = []
        all_features = []

        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector

            with torch.no_grad():
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # n_class > 2 means a classifier is trained alongside the discriminator
                if n_class > 2:
                    f_vector, discriminator_output, classification_output = D(
                        real_feature, return_feature=True)
                    all_detection_preds.append(discriminator_output)
                    all_class_preds.append(classification_output)

                # only run the discriminator prediction
                else:
                    f_vector, discriminator_output = D.detect_only(
                        real_feature, return_feature=True)
                    all_detection_preds.append(discriminator_output)
                if args.do_vis:
                    all_features.append(f_vector)

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length] class indices
        all_binary_y = (all_y != 0).long()  # [length] label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]

        # compute loss; squeeze the [length, 1] predictions so BCELoss sees matching shapes
        detection_loss = detection_loss(all_detection_preds.squeeze(-1),
                                        all_binary_y.float())
        result['detection_loss'] = detection_loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds,
                                            0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds,
                                         all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(
                all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(
                metrics.classification_report(
                    all_y, all_class_preds,
                    target_names=processor.id_to_label))

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result
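
Example 2 calls a `convert_to_int_by_threshold` helper that is not included in the snippet. A minimal sketch, assuming it thresholds the sigmoid discriminator scores at 0.5; both the default cut-off and the exact signature are assumptions.

    import torch

    def convert_to_int_by_threshold(scores, threshold=0.5):
        # Map discriminator outputs in [0, 1] to {0, 1} integer labels.
        # 0.5 is an assumed default; the original helper may use a different cut-off.
        return (scores >= threshold).long()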