Example #1
    def from_pretrained(cls,
                        pretrained_model_name,
                        state_dict=None,
                        cache_dir='',
                        *inputs,
                        **kwargs):
        serialization_dir = os.path.join(cache_dir, pretrained_model_name)
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path)

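        # older TensorFlow-style checkpoints name LayerNorm parameters
        # 'gamma'/'beta'; rename them to PyTorch's 'weight'/'bias'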
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

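        # recursively load each submodule under its prefix so that key
        # mismatches are collected into the lists above instead of raising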
        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(state_dict, prefix, local_metadata,
                                         True, missing_keys, unexpected_keys,
                                         error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

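        # checkpoint keys carry a 'bert.' prefix; start from that prefix when
        # the model itself is the bare encoder (no .bert attribute)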
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".
                format(model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys))
        return model
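A minimal usage sketch for this classmethod (an assumption for illustration: it presumes cache_dir/<model_name>/ already holds the CONFIG_NAME and WEIGHTS_NAME files, and BertForSequenceClassification stands in for any subclass defining from_pretrained this way):

    # hypothetical call; loads ./cache/bert-base-uncased/{config,weights}
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', cache_dir='./cache')
    model.eval()  # the loaded model is ready for inference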
Example #2
def main(args):
    logger.info('Checking...')
    SEED = args.seed
    check_manual_seed(SEED)
    check_args(args)
    logger.info('seed: {}'.format(args.seed))
    gross_result['seed'] = args.seed

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # join DataDir and the data file into one path
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    processor.load_label(
        label_path)  # Add label_to_id and id_to_label to the processor.

    n_class = len(processor.id_to_label)
    config = vars(args)  # the parsed args as a plain dict
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = TextCNN(bert_config, n_class)  # BERT-backed TextCNN classifier
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    global_step = 0

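    # train/eval/test below are closures over model, config and global_step,
    # sharing state without passing it explicitly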
    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (t.to(device) for t in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                logits = model(token, mask, type_ids)
                loss = classified_loss(logits, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                global_step += 1
                # gradients accumulate across batches; step the optimizer
                # once every `gradient_accumulation_steps` batches
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # early_stopping returns 1 to save the model,
                # 0 to keep training without saving, and
                # -1 when patience is exhausted and training should stop
                signal = early_stopping(eval_result['accuracy'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                # logger.info(eval_result)

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        if dev_dataset:
            freeze_data['valid_loss'] = valid_loss

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        all_pred = []
        all_logit = []
        total_loss = 0
        for sample in tqdm.tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                total_loss += classified_loss(logit, y.long())

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        result.update(report)
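        # use the class-1 probability as the detection score; as noted in
        # test(), EER over it is only meaningful for binary classification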
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        result['all_binary_y'] = all_binary_y

        freeze_data['valid_all_y'] = all_y
        freeze_data['valid_all_pred'] = all_pred
        freeze_data['valid_score'] = y_score

        return result

    def test(dataset):
        load_model(model, path=config['model_save_path'], model_name='bert')
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        all_pred = []
        total_loss = 0
        all_logit = []
        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_logit.append(logit)
                all_pred.append(torch.argmax(logit, 1))
                total_loss += classified_loss(logit, y.long())

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        all_pred = torch.cat(all_pred, 0).cpu()
        all_logit = torch.cat(all_logit, 0).cpu()

        # classification report
        ind_class_acc = metrics.ind_class_accuracy(all_pred, all_y)
        report = metrics.classification_report(all_y,
                                               all_pred,
                                               output_dict=True)
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_pred, all_binary_y)
        result.update(report)
        # EER is only meaningful for binary classification
        y_score = all_logit.softmax(1)[:, 1].tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        result['eer'] = eer
        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample
        result['all_y'] = all_y.tolist()
        result['all_pred'] = all_pred.tolist()
        result['all_binary_y'] = all_binary_y

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_pred.tolist()
        freeze_data['test_score'] = y_score

        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['auc'] = roc_auc_score(all_binary_y, y_score)
        result['y_score'] = y_score
        return result

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        # logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))
        gross_result['eval_eer'] = eval_result['eer']
        gross_result['eval_auc'] = eval_result['auc']
        gross_result['eval_fpr95'] = ErrorRateAt95Recall(
            eval_result['all_binary_y'], eval_result['y_score'])
        gross_result['eval_oos_ind_precision'] = eval_result[
            'oos_ind_precision']
        gross_result['eval_oos_ind_recall'] = eval_result['oos_ind_recall']
        gross_result['eval_oos_ind_f_score'] = eval_result['oos_ind_f_score']

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))
        # logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))

        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        gross_result['test_eer'] = test_result['eer']
        gross_result['test_auc'] = test_result['auc']
        gross_result['test_fpr95'] = ErrorRateAt95Recall(
            test_result['all_binary_y'], test_result['y_score'])
        gross_result['test_oos_ind_precision'] = test_result[
            'oos_ind_precision']
        gross_result['test_oos_ind_recall'] = test_result['oos_ind_recall']
        gross_result['test_oos_ind_f_score'] = test_result['oos_ind_f_score']

        # dump misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_y'], test_result['all_pred'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor)

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)

    with open(os.path.join(config['output_dir'], 'freeze_data.pkl'),
              'wb') as f:
        pickle.dump(freeze_data, f)
    df = pd.DataFrame(
        data={
            'valid_y': freeze_data['valid_all_y'],
            'valid_score': freeze_data['valid_score'],
        })
    df.to_csv(os.path.join(config['output_dir'], 'valid_score.csv'))

    df = pd.DataFrame(
        data={
            'test_y': freeze_data['test_all_y'],
            'test_score': freeze_data['test_score']
        })
    df.to_csv(os.path.join(config['output_dir'], 'test_score.csv'))

    if args.result != 'no':
        pd_result = pd.DataFrame(gross_result)
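        # seed 16 apparently marks the first run of the seed sweep (write a
        # new CSV with header); later seeds append, and seed 8192 aggregates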
        if args.seed == 16:
            pd_result.to_csv(args.result + '_gross_result.csv', index=False)
        else:
            pd_result.to_csv(args.result + '_gross_result.csv',
                             index=False,
                             mode='a',
                             header=False)
        if args.seed == 8192:
            print(args.result)
            std_mean(args.result + '_gross_result.csv')
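The EarlyStopping helper is imported from the project's utils and not shown; a minimal sketch consistent with the 1/0/-1 contract and the best_score attribute used in these examples (an assumption, not the actual implementation):

    class EarlyStopping:
        """Return 1 on improvement (save), 0 otherwise, -1 once patience runs out."""

        def __init__(self, patience, logger=None):
            self.patience = patience
            self.logger = logger
            self.best_score = None
            self.counter = 0

        def __call__(self, score):
            if self.best_score is None or score > self.best_score:
                self.best_score = score  # new best: caller should save
                self.counter = 0
                return 1
            self.counter += 1
            if self.counter >= self.patience:
                return -1  # patience exhausted: caller should stop
            return 0  # no improvement yet: keep training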
Example #3
def main(args):
    logger.info('Checking...')
    check_manual_seed(args.seed)
    check_args(args)

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # join DataDir and the data file into one path
    label_path = data_path.replace('.json', '.label')

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    processor.load_label(
        label_path)  # Add label_to_id and id_to_label to the processor.

    n_class = len(processor.id_to_label)
    config = vars(args)  # the parsed args as a plain dict
    config['model_save_path'] = os.path.join(args.output_dir, 'save',
                                             'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    model = BertClassifier(bert_config, config)  # BERT-based classifier
    if args.fine_tune:
        model.unfreeze_bert_encoder()
    else:
        model.freeze_bert_encoder()
    model.to(device)

    global_step = 0

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps,
                                      shuffle=True,
                                      num_workers=2)

        nonlocal global_step
        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.lr)

        train_loss = []
        if dev_dataset:
            valid_loss = []
            valid_ind_class_acc = []
        iteration = 0
        for i in range(args.n_epoch):

            model.train()

            total_loss = 0
            for sample in tqdm.tqdm(train_dataloader):
                sample = (t.to(device) for t in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                f_vector, discriminator_output, classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                if args.BCE:
                    loss = adversarial_loss(discriminator_output,
                                            (y != 0.0).float())
                else:
                    loss = classified_loss(discriminator_output, y.long())
                total_loss += loss.item()
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                global_step += 1
                # gradients accumulate across batches; step the optimizer
                # once every `gradient_accumulation_steps` batches
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()

            logger.info('[Epoch {}] Train: train_loss: {}'.format(
                i, total_loss / n_sample))
            logger.info('-' * 30)

            train_loss.append(total_loss / n_sample)
            iteration += 1

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_loss.append(eval_result['loss'])
                valid_ind_class_acc.append(eval_result['ind_class_acc'])

                # early_stopping returns 1 to save the model,
                # 0 to keep training without saving, and
                # -1 when patience is exhausted and training should stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                elif signal == 0:
                    pass
                elif signal == 1:
                    save_model(model,
                               path=config['model_save_path'],
                               model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)
        if dev_dataset:
            draw_curve(valid_loss, iteration, 'valid_loss', args.output_dir)
            draw_curve(valid_ind_class_acc, iteration,
                       'valid_ind_class_accuracy', args.output_dir)

        if args.patience >= args.n_epoch:
            save_model(model,
                       path=config['model_save_path'],
                       model_name='bert')

        freeze_data['train_loss'] = train_loss
        if dev_dataset:
            freeze_data['valid_loss'] = valid_loss

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        detection_loss = torch.nn.BCELoss().to(device)
        all_detection_preds = []
        all_class_preds = []
        all_pred = []
        all_logit = []
        total_loss = 0
        for sample in tqdm.tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            with torch.no_grad():
                f_vector, discriminator_output, classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                all_detection_preds.append(discriminator_output)

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]
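        # convert_to_int_by_threshold presumably binarizes the sigmoid scores
        # (e.g. at 0.5) into binary in-domain / oos predictions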
        # compute the detection loss over the whole dev set
        detection_loss = detection_loss(all_detection_preds,
                                        all_binary_y.float())
        result['detection_loss'] = detection_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds,
                                                   all_y)

        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result

    def test(dataset):
        load_model(model, path=config['model_save_path'], model_name='bert')
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)
        result = dict()
        model.eval()

        # Loss function
        classified_loss = torch.nn.CrossEntropyLoss().to(device)
        detection_loss = torch.nn.BCELoss().to(device)
        all_detection_preds = []
        all_features = []
        all_pred = []
        total_loss = 0
        all_logit = []
        for sample in tqdm.tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            with torch.no_grad():
                f_vector, discriminator_output, classification_output = model(
                    token, mask, type_ids, return_feature=True)
                discriminator_output = discriminator_output.squeeze()
                all_detection_preds.append(discriminator_output)
                if args.do_vis:
                    all_features.append(f_vector)

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length]; label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]

        # compute the detection loss over the whole test set
        detection_loss = detection_loss(all_detection_preds,
                                        all_binary_y.float())
        result['detection_loss'] = detection_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_preds.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        if args.do_vis:
            all_features = torch.cat(all_features, 0).cpu().numpy()
            result['all_features'] = all_features

        ind_class_acc = metrics.ind_class_accuracy(all_detection_binary_preds,
                                                   all_y)

        result['ind_class_acc'] = ind_class_acc
        result['loss'] = total_loss / n_sample

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['all_y'] = all_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['score'] = y_score
        result['y_score'] = y_score
        result['all_pred'] = all_detection_binary_preds
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        freeze_data['test_all_y'] = all_y.tolist()
        freeze_data['test_all_pred'] = all_detection_binary_preds.tolist()
        freeze_data['test_score'] = y_score

        return result

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])

        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))

        # dump misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        # output_cases(texts, test_result['all_y'], test_result['all_pred'],
        #              os.path.join(args.output_dir, 'test_cases.csv'), processor, test_result['test_logit'])

        # confusion matrix
        plot_confusion_matrix(test_result['all_y'], test_result['all_pred'],
                              args.output_dir)
Example #4
            # print(layer.weights)
            new_model_weights.extend(layer.weights)
        print(len(new_model_weights))

        # for i in range(len(mapping)):
        #     print(new_model_weights[i], mapping[i])

        if len(new_model_weights) != len(values):
            raise ValueError(
                'Expecting %s weights, but provide a list of %s weights.' %
                (len(new_model_weights), len(values)))

        K.batch_set_value(zip(new_model_weights, values))


if __name__ == '__main__':
    model_path = "/Users/James/Study/pretrained_models/bert/chinese-bert_chinese_wwm_L-12_H-768_A-12/bert_model.ckpt"
    config_dict = "/Users/James/Study/pretrained_models/bert/chinese-bert_chinese_wwm_L-12_H-768_A-12/bert_config.json"

    bert_config = BertConfig(config_dict)

    bert = BertEncoder(bert_config)
    bert.summary()
    # bert.build(input_shape=[(None,None),(None,None)])
    bert.load_weights_from_checkpoint(model_path)

    # print(bert.layers)
    # print(len(bert.layers))

    # bert.load_weights_from_checkpoint(model_path)
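Not shown above is how `values` is built; with TensorFlow's standard checkpoint reader that step might look like this (a sketch; `mapping` is the list of checkpoint variable names referenced in the commented-out loop):

    # hypothetical: read one numpy array per mapped checkpoint variable
    import tensorflow as tf

    reader = tf.train.load_checkpoint(model_path)
    values = [reader.get_tensor(name) for name in mapping]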
Example #5
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_name",
                        default='GBert-predict',
                        type=str,
                        required=False,
                        help="model name")
    parser.add_argument("--data_dir",
                        default='../data',
                        type=str,
                        required=False,
                        help="The input data dir.")
    parser.add_argument("--pretrain_dir",
                        default='../saved/GBert-pretraining',
                        type=str,
                        required=False,
                        help="pretraining model")
    parser.add_argument("--train_file",
                        default='data-multi-visit.pkl',
                        type=str,
                        required=False,
                        help="training data file.")
    parser.add_argument(
        "--output_dir",
        default='../saved/',
        type=str,
        required=False,
        help="The output directory where the model checkpoints will be written."
    )

    # Other parameters
    parser.add_argument("--use_pretrain",
                        default=False,
                        action='store_true',
                        help="is use pretrain")
    parser.add_argument("--graph",
                        default=False,
                        action='store_true',
                        help="if use ontology embedding")
    parser.add_argument("--therhold",
                        default=0.3,
                        type=float,
                        help="therhold.")
    parser.add_argument(
        "--max_seq_length",
        default=55,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=True,
                        action='store_true',
                        help="Whether to run on the dev set.")
    parser.add_argument("--do_test",
                        default=True,
                        action='store_true',
                        help="Whether to run on the test set.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=1203,
                        help="random seed for initialization")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")

    args = parser.parse_args()
    args.output_dir = os.path.join(args.output_dir, args.model_name)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
    #     raise ValueError(
    #         "Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading Dataset")
    tokenizer, (train_dataset, eval_dataset, test_dataset) = load_dataset(args)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=1)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=SequentialSampler(eval_dataset),
                                 batch_size=1)
    test_dataloader = DataLoader(test_dataset,
                                 sampler=SequentialSampler(test_dataset),
                                 batch_size=1)

    print('Loading Model: ' + args.model_name)
    # config = BertConfig(vocab_size_or_config_json_file=len(tokenizer.vocab.word2idx), side_len=train_dataset.side_len)
    # config.graph = args.graph
    # model = SeperateBertTransModel(config, tokenizer.dx_voc, tokenizer.rx_voc)
    if args.use_pretrain:
        logger.info("Use Pretraining model")
        model = GBERT_Predict.from_pretrained(args.pretrain_dir,
                                              tokenizer=tokenizer)
    else:
        config = BertConfig(
            vocab_size_or_config_json_file=len(tokenizer.vocab.word2idx))
        config.graph = args.graph
        model = GBERT_Predict(config, tokenizer)
    logger.info('# of model parameters: ' + str(get_n_params(model)))

    model.to(device)

    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model itself
    rx_output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")

    # Prepare optimizer
    # num_train_optimization_steps = int(
    #     len(train_dataset) / args.train_batch_size) * args.num_train_epochs
    # param_optimizer = list(model.named_parameters())
    # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # optimizer_grouped_parameters = [
    #     {'params': [p for n, p in param_optimizer if not any(
    #         nd in n for nd in no_decay)], 'weight_decay': 0.01},
    #     {'params': [p for n, p in param_optimizer if any(
    #         nd in n for nd in no_decay)], 'weight_decay': 0.0}
    # ]

    # optimizer = BertAdam(optimizer_grouped_parameters,
    #                      lr=args.learning_rate,
    #                      warmup=args.warmup_proportion,
    #                      t_total=num_train_optimization_steps)
    optimizer = Adam(model.parameters(), lr=args.learning_rate)

    global_step = 0
    if args.do_train:
        writer = SummaryWriter(args.output_dir)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", 1)

        dx_acc_best, rx_acc_best = 0, 0
        acc_name = 'prauc'
        dx_history = {'prauc': []}
        rx_history = {'prauc': []}

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            print('')
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            prog_iter = tqdm(train_dataloader, leave=False, desc='Training')
            model.train()
            for _, batch in enumerate(prog_iter):
                batch = tuple(t.to(device) for t in batch)
                input_ids, dx_labels, rx_labels = batch
                input_ids, dx_labels, rx_labels = input_ids.squeeze(
                    dim=0), dx_labels.squeeze(dim=0), rx_labels.squeeze(dim=0)
                loss, rx_logits = model(input_ids,
                                        dx_labels=dx_labels,
                                        rx_labels=rx_labels,
                                        epoch=global_step)
                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += 1
                nb_tr_steps += 1

                # Display loss
                prog_iter.set_postfix(loss='%.4f' % (tr_loss / nb_tr_steps))

                optimizer.step()
                optimizer.zero_grad()

            writer.add_scalar('train/loss', tr_loss / nb_tr_steps, global_step)
            global_step += 1

            if args.do_eval:
                print('')
                logger.info("***** Running eval *****")
                model.eval()
                dx_y_preds = []
                dx_y_trues = []
                rx_y_preds = []
                rx_y_trues = []
                for eval_input in tqdm(eval_dataloader, desc="Evaluating"):
                    eval_input = tuple(t.to(device) for t in eval_input)
                    input_ids, dx_labels, rx_labels = eval_input
                    input_ids, dx_labels, rx_labels = input_ids.squeeze(
                    ), dx_labels.squeeze(), rx_labels.squeeze(dim=0)
                    with torch.no_grad():
                        loss, rx_logits = model(input_ids,
                                                dx_labels=dx_labels,
                                                rx_labels=rx_labels)
                        rx_y_preds.append(t2n(torch.sigmoid(rx_logits)))
                        rx_y_trues.append(t2n(rx_labels))
                        # dx_y_preds.append(t2n(torch.sigmoid(dx_logits)))
                        # dx_y_trues.append(
                        #     t2n(dx_labels.view(-1, len(tokenizer.dx_voc.word2idx))))
                        # rx_y_preds.append(t2n(torch.sigmoid(rx_logits))[
                        #                   :, tokenizer.rx_singe2multi])
                        # rx_y_trues.append(
                        #     t2n(rx_labels)[:, tokenizer.rx_singe2multi])

                print('')
                # dx_acc_container = metric_report(np.concatenate(dx_y_preds, axis=0), np.concatenate(dx_y_trues, axis=0),
                #                                  args.therhold)
                rx_acc_container = metric_report(
                    np.concatenate(rx_y_preds, axis=0),
                    np.concatenate(rx_y_trues, axis=0), args.therhold)
                for k, v in rx_acc_container.items():
                    writer.add_scalar('eval/{}'.format(k), v, global_step)

                if rx_acc_container[acc_name] > rx_acc_best:
                    rx_acc_best = rx_acc_container[acc_name]
                    # save model
                    torch.save(model_to_save.state_dict(),
                               rx_output_model_file)

        with open(os.path.join(args.output_dir, 'bert_config.json'),
                  'w',
                  encoding='utf-8') as fout:
            fout.write(model.config.to_json_string())

    if args.do_test:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", len(test_dataset))
        logger.info("  Batch size = %d", 1)

        def test(task=0):
            # Load a trained model that you have fine-tuned
            model_state_dict = torch.load(rx_output_model_file)
            model.load_state_dict(model_state_dict)
            model.to(device)

            model.eval()
            y_preds = []
            y_trues = []
            for test_input in tqdm(test_dataloader, desc="Testing"):
                test_input = tuple(t.to(device) for t in test_input)
                input_ids, dx_labels, rx_labels = test_input
                input_ids, dx_labels, rx_labels = input_ids.squeeze(
                ), dx_labels.squeeze(), rx_labels.squeeze(dim=0)
                with torch.no_grad():
                    loss, rx_logits = model(input_ids,
                                            dx_labels=dx_labels,
                                            rx_labels=rx_labels)
                    y_preds.append(t2n(torch.sigmoid(rx_logits)))
                    y_trues.append(t2n(rx_labels))

            print('')
            acc_container = metric_report(np.concatenate(y_preds, axis=0),
                                          np.concatenate(y_trues, axis=0),
                                          args.therhold)

            # save report
            if args.do_train:
                for k, v in acc_container.items():
                    writer.add_scalar('test/{}'.format(k), v, 0)

            return acc_container

        test(task=0)
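A typical invocation of this entry point might be (script name and paths are assumptions):

    # python run_gbert_predict.py --model_name GBert-predict --use_pretrain \
    #     --do_train --data_dir ../data --pretrain_dir ../saved/GBert-pretraining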
Example #6
        'weight_decay': 0.01,
        'decay_filter': lambda x: ('layernorm' not in x.name.lower()
                                   and 'bias' not in x.name.lower()),
    }),
    'Momentum': edict({
        'learning_rate': 2e-5,
        'momentum': 0.9,
    }),
})

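# BERT-base shape (12 layers, 12 heads, hidden size 768); the dtype /
# compute_type split keeps parameters in float32 while computation runs
# in float16 (mixed precision).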
bert_net_cfg = BertConfig(
    seq_length=128,
    vocab_size=21128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    use_relative_positions=False,
    dtype=ts.float32,
    compute_type=ts.float16,
)
Example #7
def main(args):
    check_manual_seed(args.seed)
    logger.info('seed: {}'.format(args.seed))

    logger.info('Loading config...')
    bert_config = BertConfig('config/bert.ini')
    bert_config = bert_config(args.bert_type)

    # for oos-eval dataset
    data_config = Config('config/data.ini')
    data_config = data_config(args.dataset)

    # Prepare data processor
    data_path = os.path.join(data_config['DataDir'],
                             data_config[args.data_file])  # join DataDir and the data file into one path
    label_path = data_path.replace('.json', '.label')
    with open(data_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp)
        for split in data:
            logger.info('{} : {}'.format(split, len(data[split])))
    with open(label_path, 'r', encoding='utf-8') as fp:
        logger.info(json.load(fp))

    if args.dataset == 'oos-eval':
        processor = OOSProcessor(bert_config, maxlen=32)
        logger.info('OOSProcessor')
    elif args.dataset == 'smp':
        processor = SMPProcessor(bert_config, maxlen=32)
        # processor = PosSMPProcessor(bert_config, maxlen=32)
        logger.info('SMPProcessor')
    else:
        raise ValueError('The dataset {} is not supported.'.format(
            args.dataset))

    processor.load_label(
        label_path)  # Add label_to_id and id_to_label to the processor.
    logger.info("label_to_id: {}".format(processor.label_to_id))
    logger.info("id_to_label: {}".format(processor.id_to_label))

    n_class = len(processor.id_to_label)
    config = vars(args)  # the parsed args as a plain dict
    config['gan_save_path'] = os.path.join(args.output_dir, 'save', 'gan.pt')
    config['bert_save_path'] = os.path.join(args.output_dir, 'save', 'bert.pt')
    config['n_class'] = n_class

    logger.info('config:')
    logger.info(config)

    # E = BertModel.from_pretrained(bert_config['PreTrainModelDir'])  # Bert encoder
    model = BertClassifier(bert_config, n_class)  # BERT-based classifier

    if args.fine_tune:
        for param in model.parameters():
            param.requires_grad = True
    else:
        for param in model.parameters():
            param.requires_grad = False

    model.to(device)

    global_step = 0

    def train(train_dataset, dev_dataset):
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=args.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        global best_dev
        nonlocal global_step

        n_sample = len(train_dataloader)
        early_stopping = EarlyStopping(args.patience, logger=logger)
        # Loss function
        adversarial_loss = torch.nn.BCELoss().to(device)

        # Optimizers
        optimizer = AdamW(model.parameters(), args.bert_lr)

        valid_detection_loss = []
        valid_oos_ind_precision = []
        valid_oos_ind_recall = []
        valid_oos_ind_f_score = []

        train_loss = []
        iteration = 0

        for i in range(args.n_epoch):
            logger.info('***********************************')
            logger.info('epoch: {}'.format(i))

            # put the model into training mode
            model.train()

            total_loss = 0
            for sample in tqdm(train_dataloader):
                sample = (t.to(device) for t in sample)
                token, mask, type_ids, y = sample
                batch = len(token)

                optimizer.zero_grad()
                logit = model(token, mask, type_ids)
                loss = adversarial_loss(logit, y.float())
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            logger.info('[Epoch {}] Train: loss: {}'.format(
                i, total_loss / n_sample))
            train_loss.append(total_loss / n_sample)
            iteration += 1
            logger.info(
                '---------------------------------------------------------------------------'
            )

            if dev_dataset:
                logger.info(
                    '#################### eval result at step {} ####################'
                    .format(global_step))
                eval_result = eval(dev_dataset)

                valid_detection_loss.append(eval_result['detection_loss'])
                valid_oos_ind_precision.append(
                    eval_result['oos_ind_precision'])
                valid_oos_ind_recall.append(eval_result['oos_ind_recall'])
                valid_oos_ind_f_score.append(eval_result['oos_ind_f_score'])

                # early_stopping returns 1 to save the model,
                # 0 to keep training without saving, and
                # -1 when patience is exhausted and training should stop
                signal = early_stopping(-eval_result['eer'])
                if signal == -1:
                    break
                # elif signal == 0:
                #     pass
                # elif signal == 1:
                #     save_gan_model(D, G, config['gan_save_path'])
                #     if args.fine_tune:
                #         save_model(E, path=config['bert_save_path'], model_name='bert')

                logger.info(eval_result)
                logger.info('valid_eer: {}'.format(eval_result['eer']))
                logger.info('valid_oos_ind_precision: {}'.format(
                    eval_result['oos_ind_precision']))
                logger.info('valid_oos_ind_recall: {}'.format(
                    eval_result['oos_ind_recall']))
                logger.info('valid_oos_ind_f_score: {}'.format(
                    eval_result['oos_ind_f_score']))
                logger.info('valid_auc: {}'.format(eval_result['auc']))
                logger.info('valid_fpr95: {}'.format(
                    ErrorRateAt95Recall(eval_result['all_binary_y'],
                                        eval_result['y_score'])))

        from utils.visualization import draw_curve
        draw_curve(train_loss, iteration, 'train_loss', args.output_dir)

        best_dev = -early_stopping.best_score

    def eval(dataset):
        dev_dataloader = DataLoader(dataset,
                                    batch_size=args.predict_batch_size,
                                    shuffle=False,
                                    num_workers=2)
        n_sample = len(dev_dataloader)
        result = dict()

        detection_loss = torch.nn.BCELoss().to(device)

        model.eval()

        all_detection_preds = []
        all_detection_logit = []
        total_loss = 0

        for sample in tqdm(dev_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # -------------------------evaluate D------------------------- #
            # BERT encode sentence to feature vector
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_detection_logit.append(logit)
                all_detection_preds.append(logit)
                total_loss += detection_loss(logit, y.float()).item()

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length], label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]
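        # convert_to_int_by_threshold presumably binarizes the scores,
        # e.g. score >= 0.5 -> 1 (in-domain), else 0 (oos).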
        all_detection_logit = torch.cat(all_detection_logit, 0).cpu()

        # record the accumulated detection loss
        result['detection_loss'] = total_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_logit.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)
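        # EER: the error rate at the threshold where the false acceptance
        # and false rejection rates are equal (lower is better).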

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result

    def test(dataset):
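        # test() mirrors eval() above, but additionally keeps the raw logits
        # (result['test_logit']) so misclassified cases can be dumped later.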
        # # load BERT and GAN
        # load_gan_model(D, G, config['gan_save_path'])
        # if args.fine_tune:
        #     load_model(E, path=config['bert_save_path'], model_name='bert')
        #
        test_dataloader = DataLoader(dataset,
                                     batch_size=args.predict_batch_size,
                                     shuffle=False,
                                     num_workers=2)
        n_sample = len(test_dataloader)  # number of batches, not samples
        result = dict()

        # Loss function
        detection_loss = torch.nn.BCELoss().to(device)

        model.eval()

        all_detection_preds = []
        all_detection_logit = []
        total_loss = 0

        for sample in tqdm(test_dataloader):
            sample = (i.to(device) for i in sample)
            token, mask, type_ids, y = sample
            batch = len(token)

            # --------------------- evaluate detector D --------------------- #
            # BERT encodes the sentence into a feature vector
            with torch.no_grad():
                logit = model(token, mask, type_ids)
                all_detection_logit.append(logit)
                all_detection_preds.append(logit)
                total_loss += detection_loss(logit, y.float()).item()

        all_y = LongTensor(
            dataset.dataset[:, -1].astype(int)).cpu()  # [length]
        all_binary_y = (all_y != 0).long()  # [length], label 0 is oos
        all_detection_preds = torch.cat(all_detection_preds,
                                        0).cpu()  # [length, 1]
        all_detection_binary_preds = convert_to_int_by_threshold(
            all_detection_preds.squeeze())  # [length, 1]
        all_detection_logit = torch.cat(all_detection_logit, 0).cpu()

        # record the accumulated detection loss
        result['detection_loss'] = total_loss

        logger.info(
            metrics.classification_report(all_binary_y,
                                          all_detection_binary_preds,
                                          target_names=['oos', 'in']))

        # report
        oos_ind_precision, oos_ind_recall, oos_ind_fscore, _ = metrics.binary_recall_fscore(
            all_detection_binary_preds, all_binary_y)
        detection_acc = metrics.accuracy(all_detection_binary_preds,
                                         all_binary_y)

        y_score = all_detection_logit.squeeze().tolist()
        eer = metrics.cal_eer(all_binary_y, y_score)

        test_logit = all_detection_logit.tolist()
        result['test_logit'] = test_logit

        result['eer'] = eer
        result['all_detection_binary_preds'] = all_detection_binary_preds
        result['detection_acc'] = detection_acc
        result['all_binary_y'] = all_binary_y
        result['oos_ind_precision'] = oos_ind_precision
        result['oos_ind_recall'] = oos_ind_recall
        result['oos_ind_f_score'] = oos_ind_fscore
        result['y_score'] = y_score
        result['auc'] = roc_auc_score(all_binary_y, y_score)

        return result

    if args.do_train:
        if config['data_file'].startswith('binary'):
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_train_set = processor.read_dataset(data_path,
                                                    ['train', 'oos_train'])
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_train_set = processor.read_dataset(data_path, ['train'])
            text_dev_set = processor.read_dataset(data_path, ['val'])
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

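        # The processor maps text to token ids; OOSDataset wraps the features
        # so the DataLoader can yield (token, mask, type_ids, label) batches.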
        train_features = processor.convert_to_ids(text_train_set)
        train_dataset = OOSDataset(train_features)
        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)

        train(train_dataset, dev_dataset)

    if args.do_eval:
        logger.info(
            '#################### eval result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_dev_set = processor.read_dataset(data_path, ['val'])
        elif config['dataset'] == 'oos-eval':
            text_dev_set = processor.read_dataset(data_path,
                                                  ['val', 'oos_val'])
        elif config['dataset'] == 'smp':
            text_dev_set = processor.read_dataset(data_path, ['val'])
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        dev_features = processor.convert_to_ids(text_dev_set)
        dev_dataset = OOSDataset(dev_features)
        eval_result = eval(dev_dataset)
        logger.info(eval_result)
        logger.info('eval_eer: {}'.format(eval_result['eer']))
        logger.info('eval_oos_ind_precision: {}'.format(
            eval_result['oos_ind_precision']))
        logger.info('eval_oos_ind_recall: {}'.format(
            eval_result['oos_ind_recall']))
        logger.info('eval_oos_ind_f_score: {}'.format(
            eval_result['oos_ind_f_score']))
        logger.info('eval_auc: {}'.format(eval_result['auc']))
        logger.info('eval_fpr95: {}'.format(
            ErrorRateAt95Recall(eval_result['all_binary_y'],
                                eval_result['y_score'])))

    if args.do_test:
        logger.info(
            '#################### test result at step {} ####################'.
            format(global_step))
        if config['data_file'].startswith('binary'):
            text_test_set = processor.read_dataset(data_path, ['test'])
        elif config['dataset'] == 'oos-eval':
            text_test_set = processor.read_dataset(data_path,
                                                   ['test', 'oos_test'])
        elif config['dataset'] == 'smp':
            text_test_set = processor.read_dataset(data_path, ['test'])
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        test_features = processor.convert_to_ids(text_test_set)
        test_dataset = OOSDataset(test_features)
        test_result = test(test_dataset)
        logger.info(test_result)
        logger.info('test_eer: {}'.format(test_result['eer']))
        logger.info('test_ood_ind_precision: {}'.format(
            test_result['oos_ind_precision']))
        logger.info('test_ood_ind_recall: {}'.format(
            test_result['oos_ind_recall']))
        logger.info('test_ood_ind_f_score: {}'.format(
            test_result['oos_ind_f_score']))
        logger.info('test_auc: {}'.format(test_result['auc']))
        logger.info('test_fpr95: {}'.format(
            ErrorRateAt95Recall(test_result['all_binary_y'],
                                test_result['y_score'])))
        my_plot_roc(test_result['all_binary_y'], test_result['y_score'],
                    os.path.join(args.output_dir, 'roc_curve.png'))
        save_result(test_result, os.path.join(args.output_dir, 'test_result'))

        # dump misclassified cases
        if config['dataset'] == 'oos-eval':
            texts = [line[0] for line in text_test_set]
        elif config['dataset'] == 'smp':
            texts = [line['text'] for line in text_test_set]
        else:
            raise ValueError('The dataset {} is not supported.'.format(
                args.dataset))

        output_cases(texts, test_result['all_binary_y'],
                     test_result['all_detection_binary_preds'],
                     os.path.join(args.output_dir, 'test_cases.csv'),
                     processor, test_result['test_logit'])

        # confusion matrix
        plot_confusion_matrix(test_result['all_binary_y'],
                              test_result['all_detection_binary_preds'],
                              args.output_dir)

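        # Append one tab-separated summary row per run; the header is only
        # written when the log file is created for the first time.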
        beta_log_path = 'beta_log.txt'
        log_exists = os.path.exists(beta_log_path)
        with open(beta_log_path, 'a', encoding='utf-8') as f:
            if not log_exists:
                f.write('seed\tdataset\tdev_eer\ttest_eer\tdata_size\n')
            line = '\t'.join([
                str(config['seed']),
                str(config['data_file']),
                str(best_dev),
                str(test_result['eer']), '100'
            ])
            f.write(line + '\n')
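
# For reference, a minimal early-stopping helper consistent with the signal
# convention used above (1 = save, 0 = no-op, -1 = stop) and with the
# `best_score` attribute read after training. This is an illustrative sketch,
# not the project's actual implementation.
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.best_score = None
        self.counter = 0

    def __call__(self, score):
        # Higher score is better; callers pass -EER so that a lower EER
        # yields a higher score.
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
            return 1   # improved: save the model
        self.counter += 1
        if self.counter >= self.patience:
            return -1  # patience exceeded: stop training
        return 0       # no improvement yet: keep going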