def main(args):
    """Train / evaluate / test a GAN-based out-of-scope (OOS) intent detector.

    Stages, each enabled by a flag on ``args``:

    * ``args.do_train`` -- adversarially train the discriminator ``D`` and the
      generator ``G``, and train a separate ``detector`` on BERT (``E``)
      features of real sentences plus generated features. BERT is only
      updated when ``args.fine_tune`` is set. The dev split is evaluated after
      every epoch and ``EarlyStopping`` on the dev EER decides when the GAN
      (and BERT, when fine-tuning) is checkpointed.
    * ``args.do_eval`` -- log detection metrics on the dev split.
    * ``args.do_test`` -- reload the GAN/BERT checkpoints, log metrics on the
      test split and write a ROC curve, the result dump, the error cases, a
      confusion matrix and (with ``args.do_vis``) a t-SNE feature plot to
      ``args.output_dir``.

    Finally the module-level ``freeze_data`` dict (defined outside this
    function) is pickled to ``<output_dir>/freeze_data.pkl``. Dev and test
    labels/scores are written to ``valid_score.csv`` / ``test_score.csv`` for
    the splits that were actually scored during this run.

    Label id 0 is treated as OOS; every other label id is in-scope.

    Args:
        args: parsed command-line namespace. ``config = vars(args)`` aliases
            its attribute dict, so ``gan_save_path``, ``bert_save_path`` and
            ``n_class`` are added to ``args`` as well.

    Raises:
        ValueError: if ``args.dataset`` is neither 'oos-eval' nor 'smp'.
    """
    logger.info('Checking...')
    SEED = args.seed
    logger.info('seed: {}'.format(SEED))
    logger.info('model: {}'.format(args.model))
    check_manual_seed(SEED)
    check_args(args)

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # Dataset locations: the data.ini section named after args.dataset.
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor.
    # Join the dataset directory and the configured file name into one path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    # vars() returns the namespace's own attribute dict, so the keys added
    # below also become attributes of ``args``.
    config = vars(args)
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = import_module('model.' + args.model)
    model_d = import_module('model.' + 'detector')

    D = model.Discriminator(config)
    G = model.Generator(config)
    E = BertModel.from_pretrained(
        bert_config['PreTrainModelDir'])  # Bert encoder
    if args.loss == 'v1':
        detector = model_d.Detector(config)
    else:
        detector = model_d.Detector_v2(config)

    logger.info('Discriminator: {}'.format(D))
    logger.info('Generator: {}'.format(G))
    logger.info('Detector: {}'.format(detector))

    # BERT weights are trainable only when fine-tuning.
    for param in E.parameters():
        param.requires_grad = bool(args.fine_tune)

    D.to(device)
    G.to(device)
    E.to(device)
    detector.to(device)

    global_step = 0

    def sample_noise(n):
        """Draw generator input noise z ~ N(0, 1) for ``n`` samples.

        The 'lstm_gan' and 'cnn_gan' generators take a (n, 32, G_z_dim)
        tensor; every other generator takes (n, G_z_dim). The tensor is
        returned as built by ``FloatTensor``; callers move it to ``device``
        when needed.
        """
        if args.model == 'lstm_gan' or args.model == 'cnn_gan':
            shape = (n, 32, args.G_z_dim)
        else:
            shape = (n, args.G_z_dim)
        return FloatTensor(np.random.normal(0, 1, shape))

    def train(train_dataset, dev_dataset):
        """Train D/G adversarially and the detector on real + generated features.

        Per batch:
          1. encode the sentences with BERT (``E``);
          2. update ``D`` to score real features as 1 and generated ones as 0;
          3. update ``G`` so that ``D`` scores generated features as 1 (the
             feature-matching term is computed and logged but weighted 0);
          4. update ``detector`` with generated features labelled 0 (OOS) and
             real features labelled ``y``. When ``args.fine_tune`` is set, BERT
             is stepped with the gradient of this detector loss only.

        After every epoch the dev set (if given) is evaluated and
        ``EarlyStopping`` on the negated dev EER decides whether to checkpoint
        the GAN/BERT or to stop. Per-epoch loss curves are stored in
        ``freeze_data`` and the best dev EER in the global ``best_dev``.

        Returns:
            dict: ``{'all_features': ndarray}`` with D's feature vectors of
            every real training batch (all epochs) when ``args.do_vis``;
            otherwise an empty dict.
        """
        global best_dev
        nonlocal global_step

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        n_sample = len(train_dataloader)  # number of batches per epoch
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss functions
        adversarial_loss = torch.nn.BCELoss().to(device)
        adversarial_loss_v2 = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(), lr=args.G_lr)
        optimizer_D = torch.optim.Adam(D.parameters(), lr=args.D_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)
        optimizer_detector = torch.optim.Adam(detector.parameters(),
                                              lr=args.detector_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []
        detector_total_train_loss = []

        all_features = []
        result = dict()

        for epoch in range(args.n_epoch):

            # evaluate() switches E and detector to eval mode, so restore
            # training mode at the start of every epoch.
            G.train()
            D.train()
            E.train()
            detector.train()

            G_train_loss = 0
            D_fake_loss = 0
            D_real_loss = 0
            FM_train_loss = 0
            D_class_loss = 0  # never updated; kept for the logged/saved curve
            detector_train_loss = 0

            for sample in tqdm.tqdm(train_dataloader):
                token, mask, type_ids, y = (t.to(device) for t in sample)
                batch = len(token)

                # Targets for the adversarial losses: 1 = real, 0 = fake.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                _, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                # ---- train D on real features ----
                optimizer_D.zero_grad()
                real_f_vector, discriminator_output, _ = D(
                    real_feature, return_feature=True)
                real_loss = adversarial_loss(discriminator_output, valid_label)
                # The detector loss below backpropagates through real_feature
                # (and hence E) again, so the graph must be kept.
                real_loss.backward(retain_graph=True)

                if args.do_vis:
                    # Move to CPU so features from every epoch do not pile up
                    # in device memory.
                    all_features.append(real_f_vector.detach().cpu())

                # ---- train D on generated features ----
                fake_feature = G(sample_noise(batch).to(device)).detach()
                fake_discriminator_output = D.detect_only(fake_feature)
                fake_loss = adversarial_loss(fake_discriminator_output,
                                             fake_label)
                fake_loss.backward()
                optimizer_D.step()

                # ---- train G ----
                optimizer_G.zero_grad()
                fake_f_vector, D_decision = D.detect_only(
                    G(sample_noise(batch).to(device)), return_feature=True)
                gd_loss = adversarial_loss(D_decision, valid_label)
                # Feature matching: L1 distance between the mean real and mean
                # generated feature vectors (logged, but weighted 0 in g_loss).
                fm_loss = torch.abs(
                    torch.mean(real_f_vector.detach(), 0) -
                    torch.mean(fake_f_vector, 0)).mean()
                g_loss = gd_loss + 0 * fm_loss
                g_loss.backward()
                optimizer_G.step()

                # Drop the BERT gradients produced by real_loss so that E is
                # updated by the detector loss alone.
                optimizer_E.zero_grad()

                # ---- train detector ----
                optimizer_detector.zero_grad()
                fake_feature = G(sample_noise(batch).to(device)).detach()
                if args.loss == 'v1':
                    # BCE detector: generated samples are OOS (target 0).
                    loss_fake = adversarial_loss(detector(fake_feature),
                                                 fake_label)
                    # NOTE(review): y has shape (batch,) while fake_label is
                    # (batch, 1); confirm the detector output shape matches
                    # this target. BCE also requires y in {0, 1}, i.e. a
                    # binary data file.
                    loss_real = adversarial_loss(detector(real_feature),
                                                 y.float())
                else:
                    # Cross-entropy detector: class 0 = OOS/fake. view(-1)
                    # (not squeeze) keeps a 1-D target for a batch of one.
                    loss_fake = adversarial_loss_v2(
                        detector(fake_feature), fake_label.long().view(-1))
                    loss_real = adversarial_loss_v2(detector(real_feature),
                                                    y.long())
                if args.detect_loss == 'v1':
                    detector_loss = args.beta * loss_fake + (
                        1 - args.beta) * loss_real
                else:
                    detector_loss = args.sigma * (args.beta * loss_fake +
                                                  loss_real)
                detector_loss.backward()
                optimizer_detector.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_fake_loss += fake_loss.detach()
                D_real_loss += real_loss.detach()
                G_train_loss += g_loss.detach() + fm_loss.detach()
                FM_train_loss += fm_loss.detach()
                # detach() so the running sum does not keep every batch's
                # autograd graph alive for the whole epoch.
                detector_train_loss += detector_loss.detach()

            logger.info('[Epoch {}] Train: D_fake_loss: {}'.format(
                epoch, D_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_real_loss: {}'.format(
                epoch, D_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_class_loss: {}'.format(
                epoch, D_class_loss / n_sample))
            logger.info('[Epoch {}] Train: G_train_loss: {}'.format(
                epoch, G_train_loss / n_sample))
            logger.info('[Epoch {}] Train: FM_train_loss: {}'.format(
                epoch, FM_train_loss / n_sample))
            logger.info('[Epoch {}] Train: detector_train_loss: {}'.format(
                epoch, detector_train_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            D_total_fake_loss.append(D_fake_loss / n_sample)
            D_total_real_loss.append(D_real_loss / n_sample)
            D_total_class_loss.append(D_class_loss / n_sample)
            G_total_train_loss.append(G_train_loss / n_sample)
            FM_total_train_loss.append(FM_train_loss / n_sample)
            detector_total_train_loss.append(detector_train_loss / n_sample)

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = evaluate(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # early_stopping returns
                #    1 -> new best score: save the model
                #    0 -> no improvement: nothing to save
                #   -1 -> no improvement for more than `patience` epochs: stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                if signal == 1:
                    save_gan_model(D, G, config['gan_save_path'])
                    if args.fine_tune:
                        save_model(E,
                                   path=config['bert_save_path'],
                                   model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        # With patience >= n_epoch early stopping cannot end training early;
        # the final weights are saved, overwriting any earlier best checkpoint.
        if args.patience >= args.n_epoch:
            save_gan_model(D, G, config['gan_save_path'])
            if args.fine_tune:
                save_model(E, path=config['bert_save_path'], model_name='bert')

        freeze_data['D_total_fake_loss'] = D_total_fake_loss
        freeze_data['D_total_real_loss'] = D_total_real_loss
        freeze_data['D_total_class_loss'] = D_total_class_loss
        freeze_data['G_total_train_loss'] = G_total_train_loss
        freeze_data['FM_total_train_loss'] = FM_total_train_loss
        freeze_data['valid_real_loss'] = valid_detection_loss
        freeze_data['valid_oos_ind_precision'] = valid_oos_ind_precision
        freeze_data['valid_oos_ind_recall'] = valid_oos_ind_recall
        freeze_data['valid_oos_ind_f_score'] = valid_oos_ind_f_score

        best_dev = -early_stopping.best_score

        if args.do_vis:
            result['all_features'] = torch.cat(all_features, 0).cpu().numpy()
        return result

    def evaluate(dataset):
        """Score ``dataset`` with BERT + detector and compute detection metrics.

        Only ``E`` and ``detector`` are switched to eval mode. Label 0 is OOS
        and every other label is in-scope. The score is oriented towards the
        class the detector was trained to output for real in-scope samples
        (1), so a higher score means "more in-scope".

        Returns:
            dict with 'detection_loss', 'eer', 'all_detection_binary_preds',
            'detection_acc', 'all_binary_y', 'oos_ind_precision',
            'oos_ind_recall', 'oos_ind_f_score', 'y_score' and 'auc' (plus
            'class_loss' / 'class_acc' when n_class > 2).

        Side effects:
            stores the dev labels, binary predictions and scores in
            ``freeze_data``.
        """
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        result = dict()

        # Loss functions
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        E.eval()
        detector.eval()

        all_detection_preds = []
        all_class_preds = []
        all_logit = []

        for sample in tqdm.tqdm(dev_dataloader):
            token, mask, type_ids, _ = (t.to(device) for t in sample)

            with torch.no_grad():
                # BERT encodes each sentence into a pooled feature vector.
                _, real_feature = E(token, mask, type_ids)

                # NOTE(review): with more than two classes the discriminator
                # was meant to classify as well, but that path is disabled:
                # nothing is collected and the torch.cat below will fail.
                if n_class > 2:
                    continue

                detector_out = detector(real_feature)
                if args.loss == 'v1':
                    all_detection_preds.append(detector_out)
                else:
                    all_logit.append(detector_out)
                    all_detection_preds.append(torch.argmax(detector_out, 1))

        # Labels are read straight from the dataset (the loader does not
        # shuffle); the last column holds the label id, 0 = OOS.
        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # 1 = in-scope, 0 = OOS
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        if args.loss == 'v1':
            all_detection_binary_preds = convert_to_int_by_threshold(
                all_detection_preds.squeeze())
            y_score = all_detection_preds.view(-1).tolist()
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()
            # Rank by P(class 1 = in-scope) instead of the argmax label so that
            # EER / AUC / FPR95 are computed from a continuous score.
            y_score = torch.nn.functional.softmax(all_logit,
                                                  dim=1)[:, 1].tolist()

        # Compute the loss.
        if args.loss == 'v1':
            loss = detection_loss(all_detection_preds, all_binary_y.float())
        else:
            loss = detection_loss_v2(all_logit, all_y.long())
        result['detection_loss'] = loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds,
                                            0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds,
                                         all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(
                all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(
                metrics.classification_report(
                    all_y, all_class_preds,
                    target_names=processor.id_to_label))

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc

        freeze_data['valid_all_y'] = all_y
        freeze_data['vaild_all_pred'] = all_detection_binary_preds
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        """Reload the GAN/BERT checkpoints and score ``dataset`` like evaluate().

        NOTE(review): the detector is neither saved nor reloaded anywhere in
        this function, so test() uses its in-memory weights (those after the
        last training epoch, or untrained ones without --do_train), not the
        ones matching the best GAN checkpoint.

        Returns:
            the same keys as evaluate(), plus 'all_y' (label ids), 'score'
            (same list as 'y_score') and, when ``args.do_vis``,
            'all_features' (D's feature vectors of the test sentences).

        Side effects:
            stores the test labels, binary predictions and scores in
            ``freeze_data``.
        """
        # load BERT and GAN
        load_gan_model(D, G, config['gan_save_path'])
        if args.fine_tune:
            load_model(E, path=config['bert_save_path'], model_name='bert')

        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        result = dict()

        # Loss functions
        detection_loss = torch.nn.BCELoss().to(device)
        detection_loss_v2 = torch.nn.CrossEntropyLoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

        G.eval()
        D.eval()
        E.eval()
        detector.eval()

        all_detection_preds = []
        all_class_preds = []
        all_features = []
        all_logit = []

        for sample in tqdm.tqdm(test_dataloader):
            token, mask, type_ids, _ = (t.to(device) for t in sample)

            with torch.no_grad():
                # BERT encodes each sentence into a pooled feature vector.
                _, real_feature = E(token, mask, type_ids)

                if args.do_vis:
                    # Same D feature space as get_fake_feature(), so real and
                    # generated points can share one t-SNE plot.
                    f_vector, _ = D.detect_only(real_feature,
                                                return_feature=True)
                    all_features.append(f_vector)

                # NOTE(review): the n_class > 2 path is disabled (see
                # evaluate()); nothing is collected for it.
                if n_class > 2:
                    continue

                detector_out = detector(real_feature)
                if args.loss == 'v1':
                    all_detection_preds.append(detector_out)
                else:
                    all_logit.append(detector_out)
                    all_detection_preds.append(torch.argmax(detector_out, 1))

        # Labels are read straight from the dataset (the loader does not
        # shuffle); the last column holds the label id, 0 = OOS.
        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # 1 = in-scope, 0 = OOS
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        if args.loss == 'v1':
            all_detection_binary_preds = convert_to_int_by_threshold(
                all_detection_preds.squeeze())
            y_score = all_detection_preds.view(-1).tolist()
        else:
            all_detection_binary_preds = all_detection_preds
            all_logit = torch.cat(all_logit, 0).cpu()
            # Rank by P(class 1 = in-scope) instead of the argmax label so that
            # EER / AUC / FPR95 / ROC are computed from a continuous score.
            y_score = torch.nn.functional.softmax(all_logit,
                                                  dim=1)[:, 1].tolist()

        # Compute the loss.
        if args.loss == 'v1':
            loss = detection_loss(all_detection_preds, all_binary_y.float())
        else:
            loss = detection_loss_v2(all_logit, all_y.long())
        result['detection_loss'] = loss

        if n_class > 2:
            class_one_hot_preds = torch.cat(all_class_preds,
                                            0).detach().cpu()  # one hot label
            class_loss = classified_loss(class_one_hot_preds,
                                         all_y)  # compute loss
            all_class_preds = torch.argmax(class_one_hot_preds, 1)  # label
            class_acc = metrics.ind_class_accuracy(
                all_class_preds, all_y, oos_index=0)  # accuracy for ind class
            logger.info(
                metrics.classification_report(
                    all_y, all_class_preds,
                    target_names=processor.id_to_label))

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        if n_class > 2:
            result['class_loss'] = class_loss
            result['class_acc'] = class_acc
        if args.do_vis:
            result['all_features'] = torch.cat(all_features, 0).cpu().numpy()

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result

    def get_fake_feature(num_output):
        """Generate ``num_output`` fake samples and return D's features for them.

        Samples are drawn in chunks of ``args.predict_batch_size`` without
        gradients; the result is a numpy array with one row per sample.
        """
        G.eval()
        fake_features = []
        batch = args.predict_batch_size
        with torch.no_grad():
            for start in range(0, num_output, batch):
                end = min(num_output, start + batch)
                fake_feature = G(sample_noise(end - start))
                f_vector, _ = D.detect_only(fake_feature, return_feature=True)
                fake_features.append(f_vector)
        return torch.cat(fake_features, 0).cpu().numpy()

    def read_split(split):
        """Read one split ('train' / 'val' / 'test') of the data file.

        For the non-binary oos-eval file the matching 'oos_<split>' part is
        read together with the split itself.
        """
        if (not config['data_file'].startswith('binary')
                and config['dataset'] == 'oos-eval'):
            return processor.read_dataset(data_path, [split, 'oos_' + split])
        return processor.read_dataset(data_path, [split])

    if args.do_train:
        text_train_set = read_split('train')
        text_dev_set = read_split('val')

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        text_dev_set = read_split('val')

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = evaluate(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        text_test_set = read_split('test')

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # Write out the error cases.
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir,
                                  'test_cases.csv'), processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        if args.do_vis:
            # Real test features plus half as many generated ones:
            # [length + length // 2, feature_dim]
            features = np.concatenate([
                test_result['all_features'],
                get_fake_feature(len(test_dataset) // 2)
            ],
                                      axis=0)
            # 2-D t-SNE embedding: [n_points, 2]
            features = TSNE(n_components=2, verbose=1,
                            n_jobs=-1).fit_transform(features)
            # Labels per point; generated points are labelled -1: [n_points, 1]
            if n_class > 2:
                labels = np.concatenate([
                    test_result['all_y'],
                    np.array([-1] * (len(test_dataset) // 2))
                ], 0).reshape((-1, 1))
            else:
                labels = np.concatenate([
                    test_result['all_binary_y'],
                    np.array([-1] * (len(test_dataset) // 2))
                ], 0).reshape((-1, 1))
            # [n_points, 3]: x, y, label
            data = np.concatenate([features, labels], 1)
            fig = scatter_plot(data, processor)
            fig.savefig(os.path.join(args.output_dir, 'plot.png'))
            fig.show()
            freeze_data['feature_label'] = data

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'),
              'wb') as f:
        pickle.dump(freeze_data, f)

    # evaluate()/test() fill these keys; skip the tables for splits that were
    # not scored in this run instead of failing with a KeyError.
    if 'valid_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'valid_y': freeze_data['valid_all_y'],
                'valid_score': freeze_data['valid_score'],
            })
        df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    if 'test_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'test_y': freeze_data['test_all_y'],
                'test_score': freeze_data['test_score']
            })
        df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))
# Example #2
# 0
def main(args):
    """Train, evaluate and/or test a TextCNN intent classifier.

    Which phases run is controlled by ``args.do_train``, ``args.do_eval`` and
    ``args.do_test``. Artifacts (checkpoint, curves, ROC / confusion-matrix
    plots, ``freeze_data.pkl``, score CSVs) are written under
    ``args.output_dir``; unless ``args.result == 'no'`` the summary metrics in
    ``gross_result`` are written to ``<args.result>_gross_result.csv``.

    Relies on names defined elsewhere in this module (``logger``, ``device``,
    ``freeze_data``, ``gross_result``, ``metrics`` helpers, ...).

    Raises:
        ValueError: if ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    logger.info('Checking...')
    SEED = args.seed
    check_manual_seed(SEED)
    check_args(args)
    logger.info('seed: {}'.format(args.seed))
    gross_result['seed'] = args.seed

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor: join the data directory and file name into one path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    config = vars(args)  # the args namespace as a dict (same underlying object)
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = TextCNN(bert_config, n_class)  # Bert encoder
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    # Number of optimizer updates performed so far (shown in the logs).
    global_step = 0

    def train(train_dataset, dev_dataset):
        """Train ``model`` with optional gradient accumulation.

        When ``dev_dataset`` is given, the model is evaluated after every epoch
        and saved whenever ``EarlyStopping`` signals a new best dev accuracy;
        training stops early when it signals that patience is exhausted.
        """
        nonlocal global_step
        accumulation_steps = args.gradient_accumulation_steps
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        # Defined unconditionally so freeze_data can record them even when
        # there is no dev set.
        valid_loss = []
        valid_ind_class_acc = []
        iteration = 0
        for epoch in range(args.n_epoch):

            model.train()

            total_loss = 0
            for step, sample in enumerate(tqdm.tqdm(train_dataloader)):
                token, mask, type_ids, y = (t.to(device) for t in sample)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                # Scale so the accumulated gradient averages over the window.
                (loss / accumulation_steps).backward()
                # Update once every `accumulation_steps` batches. The counter
                # must advance on every batch: keying on global_step (which
                # only moves after an update) never fires when
                # accumulation_steps > 1.
                if (step + 1) % accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Apply gradients of a trailing, incomplete accumulation window
            # instead of leaking them into the next epoch.
            if n_sample % accumulation_steps != 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            epoch_loss = total_loss / n_sample
            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                epoch, epoch_loss))
            logger.info('-' * 30)

            train_loss.append(epoch_loss)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # Early-stopping signal:
                #    1 -> new best score, save the model
                #    0 -> no improvement, nothing to save
                #   -1 -> no improvement and patience exceeded, stop training
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                if signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        # NOTE(review): presumably guarantees a checkpoint exists when early
        # stopping cannot trigger within n_epoch epochs — confirm.
        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss

    def _run_inference(dataset):
        """Run ``model`` over ``dataset`` in eval mode without gradients.

        Returns:
            ``(all_y, all_pred, all_logit, mean_loss)``: gold labels read from
            the last column of ``dataset.dataset``, argmax predictions and raw
            logits (all CPU tensors), and the cross-entropy loss averaged over
            batches as a Python float.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        all_pred = []
        all_logit = []
        total_loss = 0
        for sample in tqdm.tqdm(dataloader):
            token, mask, type_ids, y = (t.to(device) for t in sample)

            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                # .item() keeps a plain float instead of a device tensor.
                total_loss += classified_loss(logit, y.long()).item()

        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()
        return all_y, all_pred, all_logit, total_loss / len(dataloader)

    def eval(dataset):
        """Evaluate on a dev split and record its predictions in freeze_data."""
        all_y, all_pred, all_logit, mean_loss = _run_inference(dataset)
        all_binary_y = (all_y != 0).long()  # label 0 is oos

        result = dict()
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        result.update(report)
        # Softmax probability of class 1 is the detection score; EER/AUC on it
        # are only meaningful for binary classification.
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = mean_loss

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        result['all_binary_y'] = all_binary_y

        freeze_data['valid_all_y'] = all_y
        # Key spelling ('vaild') kept as-is: downstream readers may rely on it.
        freeze_data['vaild_all_pred'] = all_pred
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        """Reload the saved checkpoint and evaluate it on a test split."""
        load_model(model, path=config['model_save_path'], model_name='bert')
        all_y, all_pred, all_logit, mean_loss = _run_inference(dataset)
        all_binary_y = (all_y != 0).long()  # label 0 is oos

        result = dict()
        # classification report
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)
        result.update(report)
        # EER is only meaningful for binary classification.
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = mean_loss
        result['all_y'] = all_y.tolist()
        result['all_pred'] = all_pred.tolist()
        result['all_binary_y'] = all_binary_y

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_pred.tolist()
        freeze_data['test_score'] = y_score

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        return result

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        eval_fpr95 = ErrorRateAt95Recall(eval_result['all_binary_y'],
                                         eval_result['y_score'])
        logger.info('eval_fpr95: {}'.format(eval_fpr95))
        gross_result['eval_eer'] = eval_result['eer']
        gross_result['eval_auc'] = eval_result['auc']
        gross_result['eval_fpr95'] = eval_fpr95
        gross_result['eval_oos_ind_precision'] = eval_result[
            'oos_ind_precision']
        gross_result['eval_oos_ind_recall'] = eval_result['oos_ind_recall']
        gross_result['eval_oos_ind_f_score'] = eval_result['oos_ind_f_score']

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        # Persist results once, before plotting, so they survive a plot failure.
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        test_fpr95 = ErrorRateAt95Recall(test_result['all_binary_y'],
                                         test_result['y_score'])
        logger.info('test_fpr95: {}'.format(test_fpr95))

        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))

        gross_result['test_eer'] = test_result['eer']
        gross_result['test_auc'] = test_result['auc']
        gross_result['test_fpr95'] = test_fpr95
        gross_result['test_oos_ind_precision'] = test_result[
            'oos_ind_precision']
        gross_result['test_oos_ind_recall'] = test_result['oos_ind_recall']
        gross_result['test_oos_ind_f_score'] = test_result['oos_ind_f_score']

        # Dump misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_y'], test_result['all_pred'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'),
              'wb') as f:
        pickle.dump(freeze_data, f)

    # The score CSVs need eval()/test() to have run; skip them otherwise
    # instead of failing with KeyError.
    if 'valid_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'valid_y': freeze_data['valid_all_y'],
                'valid_score': freeze_data['valid_score'],
            })
        df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    if 'test_all_y' in freeze_data:
        df = pd.DataFrame(
            data={
                'test_y': freeze_data['test_all_y'],
                'test_score': freeze_data['test_score']
            })
        df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))

    if args.result != 'no':
        pd_result = pd.DataFrame(gross_result)
        # NOTE(review): seeds 16 and 8192 look like the first and last seeds of
        # an external sweep (the file is recreated on 16 and summarised on
        # 8192); confirm against the launching script.
        if args.seed == 16:
            pd_result.to_csv(args.result + '_gross_result.csv', index=False)
        else:
            pd_result.to_csv(args.result + '_gross_result.csv',
                             index=False,
                             mode='a',
                             header=False)
        if args.seed == 8192:
            print(args.result)
            std_mean(args.result + '_gross_result.csv')
# Example #3
# 0
def main(args):
    """Train, evaluate and/or test a POS-embedding OOS detector over BERT.

    A BERT encoder ``E`` produces a pooled sentence feature which, together
    with the POS inputs, is scored by ``Pos_emb``; that score is trained with
    BCE against the batch label ``y``. Phases are selected with
    ``args.do_train`` / ``args.do_eval`` / ``args.do_test``. Testing writes an
    ROC curve, confusion matrix, misclassified cases and a ``test_result``
    file under ``args.output_dir``, and appends a row to ``beta_log.txt``.

    Relies on names defined elsewhere in this module (``logger``, ``device``,
    ``metrics`` helpers, ...).

    Raises:
        ValueError: if ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    check_manual_seed(args.seed)
    logger.info('seed: {}'.format(args.seed))

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor: join the data directory and file name into one path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    # Log the size of every split stored in the data file, then the labels.
    with open(data_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
    for split_name in data:
        logger.info('{} : {}'.format(split_name, len(data[split_name])))
    with open(label_path, 'r', encoding='utf-8') as fp:
        logger.info(json.load(fp))

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
        logger.info('OOSProcessor')
    elif args.dataset == 'smp':
        # POS-aware processor, used instead of SMPProcessor.
        processor = PosSMPProcessor(bert_config, maxlen=32)
        logger.info('SMPProcessor')
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Adds label_to_id and id_to_label to the processor.
    processor.load_label(label_path)
    processor.load_pos('data/pos.json')
    logger.info("label_to_id: {}".format(processor.label_to_id))
    logger.info("id_to_label: {}".format(processor.id_to_label))

    n_class = len(processor.id_to_label)
    config = vars(args)  # the args namespace as a dict (same underlying object)
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    from model.pos_emb_v2 import Pos_emb
    E = BertModel.from_pretrained(
        bert_config['PreTrainModelDir'])  # Bert encoder
    config['pos_dim'] = args.pos_dim
    config['batch_size'] = args.train_batch_size
    config['n_pos'] = len(processor.pos)
    config['device'] = device
    config['nhead'] = 2
    config['num_layers'] = 1
    config['maxlen'] = processor.maxlen
    logger.info('config: {}'.format(config))
    logger.info('pos: {}'.format(processor.pos))
    pos = Pos_emb(config)

    # BERT weights are trainable only when fine-tuning.
    for param in E.parameters():
        param.requires_grad = bool(args.fine_tune)

    pos.to(device)
    E.to(device)

    # Number of optimizer updates performed so far (shown in the logs).
    global_step = 0

    def train(train_dataset, dev_dataset):
        """Train ``pos`` (and ``E`` when fine-tuning), early-stopping on dev EER.

        Stores the best dev EER in the module-level name ``best_dev``.
        """
        global best_dev
        nonlocal global_step

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer_pos = torch.optim.Adam(pos.parameters(), lr=args.pos_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        train_loss = []
        iteration = 0

        for epoch in range(args.n_epoch):
            logger.info('***********************************')
            logger.info('epoch: {}'.format(epoch))

            # Initialize model state
            pos.train()
            E.train()

            total_loss = 0
            for sample in tqdm(train_dataloader):
                token, mask, type_ids, pos1, pos2, _pos_mask, y = (
                    t.to(device) for t in sample)

                optimizer_E.zero_grad()
                optimizer_pos.zero_grad()
                _, pooled_output = E(token, mask, type_ids)

                out = pos(pos1, pos2, pooled_output)
                loss = adversarial_loss(out, y.float())
                loss.backward()
                # .item() accumulates a plain float rather than a device tensor.
                total_loss += loss.item()

                if args.fine_tune:
                    optimizer_E.step()

                optimizer_pos.step()
                global_step += 1

            epoch_loss = total_loss / n_sample
            logger.info('[Epoch {}] Train: loss: {}'.format(epoch, epoch_loss))
            logger.info(
                '---------------------------------------------------------------------------'
            )
            train_loss.append(epoch_loss)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # Early-stopping signal:
                #    1 -> new best score
                #    0 -> no improvement
                #   -1 -> no improvement and patience exceeded, stop training
                # EER is negated because lower is better (presumably
                # EarlyStopping treats higher scores as better).
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                # NOTE(review): checkpointing on signal == 1 is disabled, so
                # test() scores the final in-memory weights, not the best ones.

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        best_dev = -early_stopping.best_score
        # Plot the training-loss curve
        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)

    def _predict(dataset):
        """Score ``dataset`` with ``E`` + ``pos`` and compute detection metrics.

        Gold labels come from column -4 of ``dataset.dataset``; label 0 is
        treated as OOS and everything else as in-domain.
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        result = dict()

        bce = torch.nn.BCELoss().to(device)

        pos.eval()
        E.eval()

        all_detection_preds = []
        for sample in tqdm(dataloader):
            token, mask, type_ids, pos1, pos2, _pos_mask, _y = (
                t.to(device) for t in sample)

            # BERT encodes each sentence into a feature vector, then pos scores it.
            with torch.no_grad():
                _, pooled_output = E(token, mask, type_ids)
                all_detection_preds.append(pos(pos1, pos2, pooled_output))

        all_y = LongTensor(dataset.dataset[:, -4].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        # Compute loss
        result['detection_loss'] = bce(all_detection_preds,
                                       all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result

    def eval(dataset):
        """Evaluate the current models on a dev split."""
        return _predict(dataset)

    def test(dataset):
        """Evaluate the current in-memory models on a test split.

        Reloading a saved BERT/GAN checkpoint is intentionally disabled here.
        """
        return _predict(dataset)

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = PosOOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = PosOOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = PosOOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = PosOOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # Dump misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor, test_result['y_score'])

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        # Append one summary row per run; write the header only for a new file.
        beta_log_path = 'beta_log.txt'
        write_header = not os.path.exists(beta_log_path)
        with open(beta_log_path, 'a', encoding='utf-8') as f:
            if write_header:
                f.write('seed\tdataset\tdev_eer\ttest_eer\tdata_size\n')
            # NOTE(review): best_dev is a module-level global set by train();
            # with do_train off it must already exist at module level or this
            # raises NameError.
            line = '\t'.join([
                str(config['seed']),
                str(config['data_file']),
                str(best_dev),
                str(test_result['eer']), '100'
            ])
            f.write(line + '\n')
# Example #4
# 0
def main(args):
    """Train / evaluate / test a BERT-based binary OOS (out-of-scope) detector.

    The phases that run are selected by ``args.do_train``, ``args.do_eval``
    and ``args.do_test``. Model/data settings come from ``args`` together with
    ``config/bert.ini`` and ``config/data.ini``.

    Args:
        args: parsed command-line arguments. ``vars(args)`` returns the
            object's own ``__dict__``, so entries added to ``config`` below
            also appear as attributes on ``args``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'oos-eval'`` nor ``'smp'``.
    """
    logger.info('Checking...')
    check_manual_seed(args.seed)
    check_args(args)

    logger.info('Loading config...')
    # NOTE(review): sibling entry points load this file with ``Config``;
    # confirm ``BertConfig`` here is the project's ini-backed config class.
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # Dataset-specific section (oos-eval / smp) of the data config.
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Join the data directory and the configured file name into one path.
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    # Populates processor.label_to_id and processor.id_to_label.
    processor.load_label(label_path)

    n_class = len(processor.id_to_label)
    config = vars(args)  # dict view of the arguments
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = BertClassifier(bert_config, config)  # Bert encoder
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    # Number of optimizer updates performed so far (used in log headers).
    global_step = 0

    def read_split(split):
        """Read the raw examples of ``split`` ('train', 'val' or 'test').

        For the oos-eval dataset (unless a 'binary*' data file is used) the
        matching ``'oos_' + split`` partition is read as well.
        """
        if (not config['data_file'].startswith('binary')
                and config['dataset'] == 'oos-eval'):
            return processor.read_dataset(data_path, [split, 'oos_' + split])
        return processor.read_dataset(data_path, [split])

    def log_detection_result(prefix, result, tag='oos'):
        """Log ``result`` and its headline metrics as '<prefix>_<metric>: v'."""
        logger.info(result)
        logger.info('{}_eer: {}'.format(prefix, result['eer']))
        logger.info('{}_{}_ind_precision: {}'.format(
            prefix, tag, result['oos_ind_precision']))
        logger.info('{}_{}_ind_recall: {}'.format(
            prefix, tag, result['oos_ind_recall']))
        logger.info('{}_{}_ind_f_score: {}'.format(
            prefix, tag, result['oos_ind_f_score']))
        logger.info('{}_auc: {}'.format(prefix, result['auc']))
        logger.info('{}_fpr95: {}'.format(
            prefix,
            ErrorRateAt95Recall(result['all_binary_y'], result['y_score'])))

    def run_detection(dataset, collect_features=False):
        """Score ``dataset`` with the current model and compute OOS metrics.

        Returns a dict of predictions, labels, scores and metrics. When
        ``collect_features`` is true it also holds ``'all_features'`` (the
        concatenated feature vectors as a numpy array).
        """
        dataloader = DataLoader(dataset,
                                batch_size=args.predict_batch_size,
                                shuffle=False,
                                num_workers=2)
        criterion = torch.nn.BCELoss().to(device)
        model.eval()

        detection_preds = []
        features = []
        for batch_sample in tqdm.tqdm(dataloader):
            token, mask, type_ids, _ = (t.to(device) for t in batch_sample)
            with torch.no_grad():
                f_vector, discriminator_output, _ = model(
                    token, mask, type_ids, return_feature=True)
            # One detection score per sample. reshape(-1) (unlike squeeze())
            # keeps a final batch of size 1 one-dimensional so torch.cat works.
            detection_preds.append(discriminator_output.reshape(-1))
            if collect_features:
                features.append(f_vector)

        # Label 0 is the out-of-scope class; any other label is in-domain.
        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()
        all_detection_preds = torch.cat(detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        detection_loss = criterion(all_detection_preds, all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)
        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)
        ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds,
                                                   all_y)

        result = {'detection_loss': detection_loss}
        if collect_features:
            result['all_features'] = torch.cat(features, 0).cpu().numpy()
        result['ind_class_acc'] = ind_class_acc
        # Mean BCE over the whole split (a real value for the valid-loss curve).
        result['loss'] = detection_loss.item()
        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        return result

    def evaluate(dataset):
        """Evaluate the in-memory model on ``dataset`` (no checkpoint reload)."""
        return run_detection(dataset)

    def test(dataset):
        """Reload the saved checkpoint, score ``dataset`` and record outputs."""
        load_model(model, path=config['model_save_path'], model_name='bert')
        result = run_detection(dataset, collect_features=args.do_vis)
        result['score'] = result['y_score']
        result['all_pred'] = result['all_detection_binary_preds']

        freeze_data['test_all_y'] = result['all_y'].tolist()
        freeze_data['test_all_pred'] = result[
            'all_detection_binary_preds'].tolist()
        freeze_data['test_score'] = result['y_score']
        return result

    def train(train_dataset, dev_dataset):
        """Train the classifier; with a dev set, early-stop on dev EER."""
        nonlocal global_step

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        adversarial_loss = torch.nn.BCELoss().to(device)
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        # Bound unconditionally: both are stored in freeze_data below even
        # when no dev set is supplied.
        valid_loss = []
        valid_ind_class_acc = []
        iteration = 0
        micro_step = 0  # mini-batches processed, drives gradient accumulation
        for epoch in range(args.n_epoch):
            model.train()

            total_loss = 0
            for batch_sample in tqdm.tqdm(train_dataloader):
                token, mask, type_ids, y = (t.to(device) for t in batch_sample)
                _, discriminator_output, _ = model(
                    token, mask, type_ids, return_feature=True)
                if args.BCE:
                    # Binary target: 1 for in-domain, 0 for OOS (label 0).
                    loss = adversarial_loss(discriminator_output.reshape(-1),
                                            (y != 0.0).float())
                else:
                    loss = classified_loss(discriminator_output.squeeze(),
                                           y.long())
                total_loss += loss.item()
                (loss / args.gradient_accumulation_steps).backward()

                # Count mini-batches, not optimizer steps: gating on
                # global_step (which only grows inside this branch) would
                # never fire when gradient_accumulation_steps > 1.
                micro_step += 1
                if micro_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            epoch_loss = total_loss / n_sample
            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                epoch, epoch_loss))
            logger.info('-' * 30)

            train_loss.append(epoch_loss)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = evaluate(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # early_stopping returns 1 -> improved, save the model;
                # 0 -> no improvement yet; -1 -> patience exceeded, stop.
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                if signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                log_detection_result('valid', eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        # When patience >= n_epoch, also save the final weights (this
        # overwrites any checkpoint saved by early stopping).
        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        freeze_data['valid_loss'] = valid_loss

    if args.do_train:
        train_dataset = OOSDataset(
            processor.convert_to_ids(read_split('train')))
        dev_dataset = OOSDataset(processor.convert_to_ids(read_split('val')))
        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        dev_dataset = OOSDataset(processor.convert_to_ids(read_split('val')))
        eval_result = evaluate(dev_dataset)
        log_detection_result('eval', eval_result)

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        test_dataset = OOSDataset(
            processor.convert_to_ids(read_split('test')))
        test_result = test(test_dataset)
        log_detection_result('test', test_result, tag='ood')

        # NOTE: error-case export (output_cases) is disabled for this entry point.

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)
def main(args):
    check_manual_seed(args.seed)
    logger.info('seed: {}'.format(args.seed))

    logger.info('Loading config...')
    bert_config = Config('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # 把目录和文件名合成一个路径
    label_path = data_path.replace('.json', '.label')
    with open(data_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
        for type in data:
            logger.info('{} : {}'.format(type, len(data[type])))
    with open(label_path, 'r', encoding='utf-8') as fp:
        logger.info(json.load(fp))

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
        logger.info('OOSProcessor')
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
        logger.info('SMPProcessor')
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    processor.load_label(
        label_path)  # Adding label_to_id and id_to_label ot processor.
    logger.info("label_to_id: {}".format(processor.label_to_id))
    logger.info("id_to_label: {}".format(processor.id_to_label))

    n_class = len(processor.id_to_label)
    config = vars(args)  # 返回参数字典
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    D_detect = Discriminator(config)
    D_g = Discriminator(config)
    G = Generator(config)
    E = BertModel.from_pretrained(
        bert_config['PreTrainModelDir'])  # Bert encoder

    if args.fine_tune:
        for param in E.parameters():
            param.requires_grad = True
    else:
        for param in E.parameters():
            param.requires_grad = False

    D_detect.to(device)
    D_g.to(device)
    G.to(device)
    E.to(device)

    global_step = 0

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer_G = torch.optim.Adam(G.parameters(),
                                       lr=args.G_lr)  # optimizer for generator
        optimizer_D_detect = torch.optim.Adam(
            D_detect.parameters(),
            lr=args.D_detect_lr)  # optimizer for discriminator
        optimizer_D_g = torch.optim.Adam(D_g.parameters(), lr=args.D_g_lr)
        optimizer_E = AdamW(E.parameters(), args.bert_lr)

        G_total_train_loss = []
        D_total_fake_loss = []
        D_total_real_loss = []
        FM_total_train_loss = []
        D_total_class_loss = []
        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        for i in range(args.n_epoch):
            logger.info('***********************************')
            logger.info('epoch: {}'.format(i))

            # Initialize model state
            G.train()
            D_detect.train()
            D_g.train()
            E.train()

            D_g_real_loss = 0
            D_g_fake_loss = 0
            D_detect_real_loss = 0
            D_detect_fake_loss = 0
            G_loss = 0

            for sample in tqdm(train_dataloader):
                sample = (i.to(device) for i in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                all_g_D_g_loss = 0
                D_gen_real_loss = None
                D_gen_fake_loss = None

                # the label used to train generator and discriminator.
                valid_label = FloatTensor(batch, 1).fill_(1.0).detach()
                fake_label = FloatTensor(batch, 1).fill_(0.0).detach()

                optimizer_E.zero_grad()
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                for gan_i in range(args.time):
                    # ------------------------- train D_g -------------------------#
                    # train on D_g real
                    id_sample = (y == 1.0)
                    weight = torch.ones(len(id_sample)).to(
                        device) - id_sample * 1.0  # 除去id损失, 只用ood数据
                    real_loss_func = torch.nn.BCELoss(weight=weight).to(device)
                    optimizer_D_g.zero_grad()
                    D_gen_real_discriminator_output, f_vector = D_g(
                        real_feature)
                    # D_gen_real_loss = adversarial_loss(D_gen_real_discriminator_output, valid_label) # 判别器对真实样本的损失
                    D_gen_real_loss = real_loss_func(
                        D_gen_real_discriminator_output.squeeze(),
                        valid_label.squeeze())

                    # train on D_g fake
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch, args.G_z_dim))).to(device)
                    fake_feature = G(z).detach()
                    D_gen_fake_discriminator_output, f_vector = D_g(
                        fake_feature)
                    D_gen_fake_loss = adversarial_loss(
                        D_gen_fake_discriminator_output.squeeze(),
                        fake_label.squeeze())  # 判别器对假样本的损失

                    D_gen_loss = D_gen_real_loss + D_gen_fake_loss
                    D_gen_loss.backward(retain_graph=True)  # 保存计算图,生成器还要使用
                    optimizer_D_g.step()

                    # ------------------------- train G -------------------------#
                    list_g_D_g_loss = []
                    for gi in range(args.g_time):
                        optimizer_G.zero_grad()
                        z = FloatTensor(
                            np.random.normal(0, 1,
                                             (batch, args.G_z_dim))).to(device)
                        fake_feature = G(z).detach()
                        D_gen_fake_discriminator_output, f_vector = D_g(
                            fake_feature)
                        g_D_g_loss = adversarial_loss(
                            D_gen_fake_discriminator_output.squeeze(),
                            valid_label.squeeze())  # 生成器欺骗 D_g, 认为是真实样本
                        g_D_g_loss.backward()
                        optimizer_G.step()
                        all_g_D_g_loss += g_D_g_loss.detach()
                        list_g_D_g_loss.append(g_D_g_loss)

                # ------------------------- train D_detect_ood -------------------------#
                # train on real(detect real sample)
                optimizer_D_detect.zero_grad()
                ood_real_detect_discriminator_output, f_vector = D_detect(
                    real_feature)
                ood_real_detect_loss = adversarial_loss(
                    ood_real_detect_discriminator_output.squeeze(),
                    (y != 0.0).float())  # ood 判别器对真实样本的损失

                # train on fake(detect fake sample) fake sample is fake id -> ood
                z = FloatTensor(np.random.normal(
                    0, 1, (batch, args.G_z_dim))).to(device)
                fake_feature = G(z).detach()
                ood_fake_detect_discriminator_output, f_vector = D_detect(
                    fake_feature)
                ood_fake_detect_loss = adversarial_loss(
                    ood_fake_detect_discriminator_output.squeeze(),
                    fake_label.squeeze())  # 假ood认为是ood样本

                D_detect_loss = args.beta * ood_real_detect_loss + (
                    1 - args.beta) * ood_fake_detect_loss  # 真实样本与假ood样本影响比例
                D_detect_loss.backward()
                optimizer_D_detect.step()

                if args.fine_tune:
                    optimizer_E.step()

                global_step += 1

                D_g_real_loss += D_gen_real_loss.detach()
                D_g_fake_loss += D_gen_fake_loss.detach()
                D_detect_real_loss += ood_real_detect_loss.detach()
                D_detect_fake_loss += ood_fake_detect_loss.detach()
                G_loss += all_g_D_g_loss

            logger.info('[Epoch {}] Train: D_g_real_loss: {}'.format(
                i, D_g_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_g_fake_loss: {}'.format(
                i, D_g_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: D_detect_real_loss: {}'.format(
                i, D_detect_real_loss / n_sample))
            logger.info('[Epoch {}] Train: D_detect_fake_loss: {}'.format(
                i, D_detect_fake_loss / n_sample))
            logger.info('[Epoch {}] Train: G_loss: {}'.format(
                i, G_loss / n_sample))
            logger.info(
                '---------------------------------------------------------------------------'
            )

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # 1 表示要保存模型
                # 0 表示不需要保存模型
                # -1 表示不需要模型,且超过了patience,需要early stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                # elif signal == 0:
                #     pass
                # elif signal == 1:
                #     save_gan_model(D, G, config['gan_save_path'])
                #     if args.fine_tune:
                #         save_model(E, path=config['bert_save_path'], model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        best_dev = -early_stopping.best_score

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()

        detection_loss = torch.nn.BCELoss().to(device)

        D_detect.eval()
        E.eval()

        all_detection_preds = []

        for sample in tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                sequence_output, pooled_output = E(token, mask, type_ids)
                real_feature = pooled_output

                discriminator_output, f_vector = D_detect(real_feature)
                all_detection_preds.append(discriminator_output)

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length, n_class]
        all_binary_y = (all_y != 0).long()  # [length, 1] label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]

        # 计算损失
        detection_loss = detection_loss(all_detection_preds,
                                        all_binary_y.float())
        result['detection_loss'] = detection_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result

    def test(dataset):
        """Score ``dataset`` with the detector and collect OOS-detection metrics.

        The encoder ``E`` and detector ``D_detect`` are used with their current
        in-memory weights; nothing is reloaded from disk here.

        Args:
            dataset: an ``OOSDataset``. ``dataset.dataset`` is indexed as a
                2-D array whose last column is the integer label; label 0 is
                the out-of-scope (oos) class.

        Returns:
            dict with keys ``detection_loss``, ``eer``, ``detection_acc``,
            ``oos_ind_precision``, ``oos_ind_recall``, ``oos_ind_f_score``,
            ``auc``, ``y_score`` (raw detector scores as a list),
            ``all_detection_binary_preds``, ``all_binary_y`` (1 for labels
            != 0, i.e. in-domain; 0 for oos), ``all_y`` (the original integer
            labels) and, when ``args.do_vis`` is set, ``all_features`` (numpy
            array of the detector's feature vectors).
        """
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        result = dict()

        # Loss function for the binary in-domain / oos detection head.
        detection_criterion = torch.nn.BCELoss().to(device)

        D_detect.eval()
        E.eval()

        all_detection_preds = []
        all_features = []

        for sample in tqdm(test_dataloader):
            token, mask, type_ids, _ = (i.to(device) for i in sample)

            with torch.no_grad():
                # The pooled BERT output is the sentence feature the detector
                # scores; the token-level sequence output is not needed here.
                _, pooled_output = E(token, mask, type_ids)
                discriminator_output, f_vector = D_detect(pooled_output)
                all_detection_preds.append(discriminator_output)
                if args.do_vis:
                    all_features.append(f_vector)

        # Gold labels come from the dataset's last column: shape [length].
        all_y = LongTensor(dataset.dataset[:, -1].astype(int)).cpu()
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        # NOTE(review): BCELoss below requires the concatenated detector output
        # to have the same shape as all_binary_y — confirm D_detect's output.
        all_detection_preds = torch.cat(all_detection_preds, 0).cpu()
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())

        # Compute the loss
        result['detection_loss'] = detection_criterion(all_detection_preds,
                                                       all_binary_y.float())

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        # Multi-class labels are needed by the visualisation step (n_class > 2).
        result['all_y'] = all_y
        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        return result

    def get_fake_feature(num_output):
        """Generate ``num_output`` fake samples and return their detector features.

        Standard-normal noise of width ``args.G_z_dim`` is drawn in chunks of at
        most ``args.predict_batch_size`` rows. Each chunk is passed through the
        generator ``G`` and then ``D_detect``, and only the feature-vector
        part of the detector output is kept. The chunks are concatenated along
        dim 0 and returned as a numpy array.
        """
        G.eval()
        step = args.predict_batch_size
        chunks = []
        with torch.no_grad():
            for lo in range(0, num_output, step):
                hi = min(num_output, lo + step)
                noise = FloatTensor(
                    np.random.normal(0, 1, size=(hi - lo, args.G_z_dim)))
                _, feature_vec = D_detect(G(noise))
                chunks.append(feature_vec)
            return torch.cat(chunks, 0).cpu().numpy()

    if args.do_train:
        # Decide which raw splits make up the training and validation sets.
        if config['data_file'].startswith('binary'):
            train_splits, dev_splits = ['train'], ['val']
        elif config['dataset'] == 'oos-eval':
            # oos-eval stores its out-of-scope samples in separate splits.
            train_splits = ['train', 'oos_train']
            dev_splits = ['val', 'oos_val']
        elif config['dataset'] == 'smp':
            train_splits, dev_splits = ['train'], ['val']

        text_train_set = processor.read_dataset(data_path, train_splits)
        text_dev_set = processor.read_dataset(data_path, dev_splits)

        # Tokenise both sets and wrap them as datasets for training.
        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        # Choose the validation split(s) that match the configured data file.
        if config['data_file'].startswith('binary'):
            eval_splits = ['val']
        elif config['dataset'] == 'oos-eval':
            eval_splits = ['val', 'oos_val']
        elif config['dataset'] == 'smp':
            eval_splits = ['val']
        text_dev_set = processor.read_dataset(data_path, eval_splits)

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        # Log each scalar metric using its message template and result key.
        for eval_msg_fmt, eval_key in (
                ('eval_eer: {}', 'eer'),
                ('eval_oos_ind_precision: {}', 'oos_ind_precision'),
                ('eval_oos_ind_recall: {}', 'oos_ind_recall'),
                ('eval_oos_ind_f_score: {}', 'oos_ind_f_score'),
                ('eval_auc: {}', 'auc'),
        ):
            logger.info(eval_msg_fmt.format(eval_result[eval_key]))
        eval_fpr95 = ErrorRateAt95Recall(eval_result['all_binary_y'],
                                         eval_result['y_score'])
        logger.info('eval_fpr95: {}'.format(eval_fpr95))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        # Pick the test split(s) that match the configured data file.
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # Write out per-sample predictions so error cases can be inspected.
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir,
                                  'test_cases.csv'), processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

        # Append one tab-separated summary row per run to beta_log.txt, which
        # lives in the current working directory. The header row is written
        # only when the file does not exist yet.
        beta_log_path = 'beta_log.txt'
        write_header = not os.path.exists(beta_log_path)
        with open(beta_log_path, 'a', encoding='utf-8') as f:
            if write_header:
                f.write('seed\tbeta\tdataset\tdev_eer\ttest_eer\tdata_size\n')
            line = '\t'.join([
                str(config['seed']),
                str(config['beta']),
                str(config['data_file']),
                # NOTE(review): best_dev is defined outside this section
                # (presumably by training); confirm it exists without do_train.
                str(best_dev),
                str(test_result['eer']),
                '100',  # NOTE(review): data_size is hard-coded; confirm intent.
            ])
            f.write(line + '\n')

        if args.do_vis:
            # Embed the real test features plus len(test_dataset) // 2
            # generated features in 2-D with t-SNE; generated points get -1.
            n_fake = len(test_dataset) // 2
            features = np.concatenate(
                [test_result['all_features'],
                 get_fake_feature(n_fake)],
                axis=0)  # [length + n_fake, feature_dim]
            features = TSNE(n_components=2, verbose=1,
                            n_jobs=-1).fit_transform(
                                features)  # [length + n_fake, 2]
            # With more than two classes keep the original class ids for the
            # real points; otherwise use the binary in-domain/oos labels.
            real_labels = (test_result['all_y']
                           if n_class > 2 else test_result['all_binary_y'])
            labels = np.concatenate(
                [real_labels, np.array([-1] * n_fake)],
                0).reshape((-1, 1))  # [length + n_fake, 1]
            # [length + n_fake, 3]
            data = np.concatenate([features, labels], 1)
            fig = scatter_plot(data, processor)
            fig.savefig(os.path.join(args.output_dir, 'plot.png'))
            fig.show()