Ejemplo n.º 1
0
    def test(dataset):
        """Evaluate the saved classifier on ``dataset`` and gather test metrics.

        Loads the weights stored under ``config['model_save_path']`` into the
        enclosing ``model``, runs it over ``dataset`` without gradients and
        collects the classification report, loss, EER, AUC and the raw
        labels / predictions / scores.  Labels, predictions and scores are
        also written to the enclosing ``freeze_data`` dict.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, y)``; its ``dataset`` attribute is
                indexed as ``[:, -1]`` to read the labels.

        Returns:
            dict: Metric name -> value.
        """
        load_model(model, path=config['model_save_path'], model_name='bert')
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        # Length of the loader (number of batches); used below to average the
        # accumulated per-batch loss.
        n_sample = len(test_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        all_pred = []
        total_loss = 0
        all_logit = []
        for sample in tqdm.tqdm(test_dataloader):
            token, mask, type_ids, y = (i.to(device) for i in sample)

            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                total_loss += classified_loss(logit, y.long())

        # Labels come from the last column of the underlying array; they line
        # up with the predictions because the loader does not shuffle.
        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()

        # classification report
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)
        result.update(report)
        # EER is only meaningful for binary classification: the score is the
        # softmax probability of class index 1.
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample
        result['all_y'] = all_y.tolist()
        result['all_pred'] = all_pred.tolist()
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        result['all_binary_y'] = all_binary_y

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_pred.tolist()
        freeze_data['test_score'] = y_score
        return result
Ejemplo n.º 2
0
    def eval(dataset):
        """Run encoder ``E`` and the ``pos`` head over ``dataset`` and score OOS detection.

        Both modules are put in eval mode and run without gradients.  The
        two-column output of ``pos`` is treated as logits: argmax gives the
        hard prediction and the softmax probability of index 1 is the score
        used for EER / AUC.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, pos1, pos2, pos_mask, y)``; its
                ``dataset`` attribute is indexed as ``[:, -4]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, accuracy, precision/recall/F-score plus the
            binary labels, predictions and scores.
        """
        dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2)
        result = dict()

        # Kept under its own name so the criterion is not overwritten by the
        # loss value it produces.
        criterion = torch.nn.CrossEntropyLoss().to(device)

        pos.eval()
        E.eval()

        all_detection_preds = []
        all_detection_logit = []

        for sample in tqdm(dev_dataloader):
            # pos_mask and y are part of every batch but are not needed here.
            token, mask, type_ids, pos1, pos2, pos_mask, y = (i.to(device) for i in sample)

            # -------------------------evaluate D------------------------- #
            with torch.no_grad():
                # Only the pooled output of E is used as the sentence feature.
                _, pooled_output = E(token, mask, type_ids)

                out = pos(pos1, pos2, pooled_output)
                all_detection_logit.append(out)
                all_detection_preds.append(torch.argmax(out, 1))

        all_y = LongTensor(dataset.dataset[:, -4].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_logit = torch.cat(all_detection_logit, 0).cpu()

        # Compute the loss over the whole set at once.
        result['detection_loss'] = criterion(all_detection_logit, all_binary_y.long())

        logger.info(
            metrics.classification_report(all_binary_y, all_detection_preds, target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_preds, all_binary_y)

        y_score = all_detection_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result
Ejemplo n.º 3
0
    def test(dataset):
        """Evaluate the ``detector`` head on ``dataset`` and gather test metrics.

        Reloads the GAN weights (and, when ``args.fine_tune`` is set, the
        encoder weights), encodes every batch with ``E`` and scores the pooled
        feature with ``detector``.  With ``args.loss == 'v1'`` the detector
        output is treated as a probability (BCE loss, thresholded
        predictions); otherwise it is treated as logits (cross-entropy loss,
        argmax predictions).

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, y)``; its ``dataset`` attribute is
                indexed as ``[:, -1]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, accuracy, precision/recall/F-score plus the
            labels, predictions and scores.  Labels, predictions and scores
            are also written to the enclosing ``freeze_data`` dict.
        """
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        result = dict()

        # Loss function
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()
        detector.eval()

        use_v1 = args.loss == 'v1'
        all_detection_preds = []
        all_class_preds = []
        all_features = []
        all_logit = []

        for sample in tqdm.tqdm(test_dataloader):
            token, mask, type_ids, y = (i.to(device) for i in sample)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                _, pooled_output = E(token, mask, type_ids)

                # NOTE(review): n_class > 2 (detector plus classifier) collects
                # nothing in this loop, so the torch.cat calls below would fail
                # on empty lists in that configuration -- confirm whether the
                # multi-class path is still meant to be supported.
                if n_class <= 2:
                    detector_out = detector(pooled_output)
                    if use_v1:
                        all_detection_preds.append(detector_out)
                    else:
                        all_logit.append(detector_out)
                        all_detection_preds.append(
                            torch.argmax(detector_out, 1))

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()

        # Compute the loss.
        if use_v1:
            all_detection_binary_preds = convert_to_int_by_threshold(
                all_detection_preds.squeeze())
            # BCELoss needs input and target of the same size; flatten the
            # predictions so they match the 1-D label vector.
            loss = detection_loss(all_detection_preds.reshape(-1),
                                  all_binary_y.float())
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()
            loss = detection_loss_v2(all_logit, all_y.long())
        result['detection_loss'] = loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds,
                                            0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds,
                                         all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(
                all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(
                metrics.classification_report(
                    all_y, all_class_preds,
                    target_names=processor.id_to_label))

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        # NOTE(review): when args.loss != 'v1' these scores are the argmax
        # labels rather than probabilities, so EER / AUC are computed from
        # hard decisions -- confirm this is intended.
        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        # No features are collected in the loop above, and torch.cat rejects
        # an empty list, so only export features when there are any.
        if args.do_vis and all_features:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result
Ejemplo n.º 4
0
    def test(dataset):
        """Evaluate the saved joint model's detection output on ``dataset``.

        Loads the weights stored under ``config['model_save_path']`` into the
        enclosing ``model``, runs it with ``return_feature=True`` and uses the
        discriminator output as the detection score (BCE loss, thresholded
        predictions).

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, y)``; its ``dataset`` attribute is
                indexed as ``[:, -1]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, accuracy, precision/recall/F-score plus the
            labels, predictions and scores (and the features when
            ``args.do_vis`` is set).  Labels, predictions and scores are also
            written to the enclosing ``freeze_data`` dict.
        """
        load_model(model, path=config['model_save_path'], model_name='bert')
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()
        model.eval()

        # Loss function
        bce_criterion = torch.nn.BCELoss().to(device)
        all_detection_preds = []
        all_features = []
        # NOTE(review): nothing is ever added to total_loss in this function,
        # so result['loss'] is always zero; the key is kept for callers.
        total_loss = 0
        for sample in tqdm.tqdm(test_dataloader):
            token, mask, type_ids, y = (i.to(device) for i in sample)

            with torch.no_grad():
                f_vector, discriminator_output, _ = model(
                    token, mask, type_ids, return_feature=True)
                # Flatten to 1-D instead of squeeze(): squeezing a batch of
                # size 1 yields a 0-d tensor, which torch.cat cannot join.
                all_detection_preds.append(discriminator_output.reshape(-1))
                if args.do_vis:
                    all_features.append(f_vector)

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        # Compute the loss.
        result['detection_loss'] = bce_criterion(all_detection_preds,
                                                 all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds,
                                                   all_y)

        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['all_pred'] = all_detection_binary_preds
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result
Ejemplo n.º 5
0
    def eval(dataset):
        """Score the discriminator ``D`` on a validation ``dataset``.

        Encodes every batch with ``E`` and feeds the pooled feature to ``D``.
        When ``n_class > 2`` the classifier output of ``D`` is evaluated as
        well; otherwise only ``D.detect_only`` is used.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, y)``; its ``dataset`` attribute is
                indexed as ``[:, -1]`` to read the labels.

        Returns:
            dict: Detection loss, EER, AUC, accuracy, precision/recall/F-score
            plus the binary labels, predictions and scores (and class loss /
            accuracy when ``n_class > 2``).  Labels, predictions and scores
            are also written to the enclosing ``freeze_data`` dict.
        """
        dev_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2)
        result = dict()

        # Loss functions; kept under their own names so the criterion is not
        # overwritten by the loss value it produces.
        bce_criterion = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()

        all_detection_preds = []
        all_class_preds = []

        for sample in tqdm.tqdm(dev_dataloader):
            token, mask, type_ids, y = (i.to(device) for i in sample)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                _, pooled_output = E(token, mask, type_ids)

                if n_class > 2:
                    # More than two classes: evaluate the classifier as well
                    # as the discriminator.
                    _, discriminator_output, classification_output = D(pooled_output, return_feature=True)
                    all_class_preds.append(classification_output)
                else:
                    # Discriminator only.
                    _, discriminator_output = D.detect_only(pooled_output, return_feature=True)
                all_detection_preds.append(discriminator_output)

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze())

        # Compute the loss.
        result['detection_loss'] = bce_criterion(all_detection_preds.squeeze(), all_binary_y.float())

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds, all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(metrics.classification_report(all_y, all_class_preds, target_names=processor.id_to_label))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc

        freeze_data['valid_all_y'] = all_y
        # NOTE(review): 'vaild_all_pred' looks like a misspelling of
        # 'valid_all_pred'; left unchanged because readers of this key are not
        # visible here -- confirm before renaming.
        freeze_data['vaild_all_pred'] = all_detection_binary_preds
        freeze_data['valid_score'] = y_score

        return result
Ejemplo n.º 6
0
    def test(dataset):
        """Evaluate encoder ``E`` plus the ``pos`` head on ``dataset``.

        No weights are reloaded here; ``E`` and ``pos`` are used in their
        current state.  The output of ``pos`` is treated as a probability:
        it is thresholded for the hard predictions and scored with BCE loss.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, pos1, pos2, pos_mask, y)``; its
                ``dataset`` attribute is indexed as ``[:, -4]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, accuracy, precision/recall/F-score plus the
            binary labels, predictions and scores.
        """
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        result = dict()

        # Loss function; kept under its own name so the criterion is not
        # overwritten by the loss value it produces.
        bce_criterion = torch.nn.BCELoss().to(device)

        pos.eval()
        E.eval()

        all_detection_preds = []

        for sample in tqdm(test_dataloader):
            # pos_mask and y are part of every batch but are not needed here.
            token, mask, type_ids, pos1, pos2, pos_mask, y = (i.to(device) for i in sample)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                _, pooled_output = E(token, mask, type_ids)
                all_detection_preds.append(pos(pos1, pos2, pooled_output))

        all_y = LongTensor(dataset.dataset[:, -4].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        # Compute the loss.  BCELoss needs input and target of the same size;
        # flatten the predictions so they match the 1-D label vector.
        result['detection_loss'] = bce_criterion(all_detection_preds.reshape(-1),
                                                 all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result
    def test(dataset):
        """Evaluate the anchor-based discriminator ``D`` on ``dataset``.

        Reloads the GAN weights (and, when ``args.fine_tune`` is set, the
        encoder weights).  Each discriminator output is turned into a score
        ``d_ood / (d_ind + d_ood)``, where ``d_ood`` and ``d_ind`` are the
        ``triplet_loss`` values against the two anchors built from
        ``anchor0`` and ``anchor1``; the score is then thresholded for the
        hard predictions.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, knowledge_tag, y)`` for
                ``args.dataset == 'smp'`` or ``(token, mask, type_ids, y)``
                for ``'oos-eval'``; its ``dataset`` attribute is indexed as
                ``[:, -1]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, FPR95, accuracy, precision/recall/F-score
            plus the labels, predictions and scores.  Labels, predictions and
            scores are also written to the enclosing ``freeze_data`` dict.

        Raises:
            ValueError: If ``args.dataset`` is neither ``'smp'`` nor
                ``'oos-eval'``.
        """
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset, batch_size=args.predict_batch_size, shuffle=False, num_workers=2)
        result = dict()

        # Loss functions; kept under their own names so the criterion is not
        # overwritten by the loss value it produces.
        bce_criterion = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()

        # The anchors do not depend on the batch, so build them once.
        anchor_ood = torch.zeros(args.num_outcomes, dtype=torch.float).to(device) + torch.tensor(anchor0, dtype=torch.float).to(device)
        anchor_ind = torch.zeros(args.num_outcomes, dtype=torch.float).to(device) + torch.tensor(anchor1, dtype=torch.float).to(device)

        all_detection_preds = []
        all_class_preds = []
        all_features = []

        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            if args.dataset == 'smp':
                token, mask, type_ids, knowledge_tag, y = sample
            elif args.dataset == 'oos-eval':
                token, mask, type_ids, y = sample
            else:
                # Fail with a clear message instead of an unbound-name error.
                raise ValueError(
                    'Unsupported dataset: {!r}'.format(args.dataset))

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                _, pooled_output = E(token, mask, type_ids)

                if n_class > 2:
                    # More than two classes: evaluate the classifier as well
                    # as the discriminator.
                    f_vector, discriminator_output, classification_output = D(pooled_output, return_feature=True)
                    all_class_preds.append(classification_output)
                else:
                    # Discriminator only.
                    f_vector, discriminator_output = D.detect_only(pooled_output, return_feature=True)

                discriminator_output = discriminator_output.log_softmax(1).exp()

                if args.do_vis:
                    all_features.append(f_vector)

                # One score per sample: d_ood relative to the sum of both
                # triplet-loss values.
                for output in discriminator_output:
                    d_ood = triplet_loss(anchor_ood, output, skewness=args.positive_skew)
                    d_ind = triplet_loss(anchor_ind, output, skewness=args.negative_skew)
                    all_detection_preds.append(d_ood / (d_ind + d_ood))

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos

        all_detection_preds = FloatTensor(all_detection_preds).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(all_detection_preds.squeeze())

        # Compute the loss.
        result['detection_loss'] = bce_criterion(all_detection_preds, all_binary_y.float())

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds, 0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds, all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(metrics.classification_report(all_y, all_class_preds, target_names=processor.id_to_label))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds, all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['fpr95'] = ErrorRateAt95Recall(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result
Ejemplo n.º 8
0
    score = np.sum(data, axis=1) / len(data[0])
    # print(score)
    label = tools.convert_to_int_by_threshold(score)
    # print(label)
    return label, score


if __name__ == '__main__':
    # Script entry point: soft-vote the stored outputs found under `path`
    # and print the resulting detection metrics against the saved binary
    # labels.
    all_binary_y = np.load('../data/smp/all_binary_y.npy')

    path = '../output/base-gan-oodp-alpha0.85'
    data = get_data(path)
    label, score = soft_voting(data)
    for title, values in (("label", label), ('score', score)):
        print(title)
        print(values)

    print(metrics.binary_classification_report(all_binary_y, label))

    # The helper returns four values; the last one is not reported.
    oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
        label, all_binary_y)
    for title, value in (('oos_ind_precision', oos_ind_precision),
                         ('oos_ind_recall', oos_ind_recall),
                         ('oos_ind_fscore', oos_ind_fscore)):
        print(title, value)

    # Score-based metrics, each computed from the same labels and scores and
    # printed in a fixed order.
    for title, scorer in (('fpr95', tools.ErrorRateAt95Recall),
                          ('eer', metrics.cal_eer),
                          ('auc', tools.roc_auc_score)):
        print(title, scorer(all_binary_y, score))
Ejemplo n.º 9
0
    def test(dataset):
        """Evaluate the binary detector ``model`` on ``dataset``.

        No weights are reloaded here; ``model`` is used in its current state.
        The model output is treated as logits: argmax gives the hard
        prediction and the softmax probability of index 1 is the score used
        for EER / AUC.

        Args:
            dataset: Dataset whose batches unpack to
                ``(token, mask, type_ids, y)``; its ``dataset`` attribute is
                indexed as ``[:, -1]`` to read the labels.

        Returns:
            dict: Loss, EER, AUC, accuracy, precision/recall/F-score plus the
            binary labels, predictions, scores and raw logits.
            ``'detection_loss'`` is the sum of the per-batch losses; it is
            not averaged.
        """
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        result = dict()

        # Loss function
        detection_loss = torch.nn.CrossEntropyLoss().to(device)

        model.eval()

        all_detection_preds = []
        all_detection_logit = []
        total_loss = 0

        for sample in tqdm(test_dataloader):
            token, mask, type_ids, y = (i.to(device) for i in sample)

            # -------------------------evaluate D------------------------- #
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_detection_logit.append(logit)
                all_detection_preds.append(torch.argmax(logit, 1))
                total_loss += detection_loss(logit, y.long())

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_logit = torch.cat(all_detection_logit, 0).cpu()

        # Loss accumulated over all batches.
        result['detection_loss'] = total_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_preds, all_binary_y)

        y_score = all_detection_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['test_logit'] = all_detection_logit.tolist()

        result['eer'] = eer
        result['all_detection_preds'] = all_detection_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result