# Example no. 1
def bilstm_crf_train_eval(config,
                          import_model,
                          train_examples,
                          dev_examples=None,
                          test_examples=None):
    """Train a BiLSTM-CRF model and evaluate it on optional dev/test splits.

    Args:
        config: experiment configuration object; ``vocab_size`` is set here
            from the processor vocabulary as a side effect.
        import_model: module exposing a ``Model(config)`` class.
        train_examples: training examples (required).
        dev_examples: optional dev examples; when falsy no dev loader is built
            and ``model_train`` receives ``None``.
        test_examples: optional test examples; when falsy the test phase is
            skipped entirely.
    """
    processor = DataProcessor(config.data_dir, config.do_lower_case)
    word2id = processor.get_vocab()
    config.vocab_size = len(word2id)

    def _build_loader(examples, shuffle):
        # Shared helper: featurize examples and wrap them in a DataLoader
        # (the original repeated this stanza three times).
        features = convert_examples_to_features_crf(
            examples=examples,
            word2id=word2id,
            label_list=config.label_list,
        )
        return DataLoader(BuildDataSet(features),
                          batch_size=config.batch_size,
                          shuffle=shuffle)

    train_loader = _build_loader(train_examples, shuffle=True)
    # Evaluation loaders should not shuffle: the original used shuffle=True
    # for the dev loader, which only randomizes batch order and hurts
    # reproducibility of evaluation runs.
    dev_loader = (_build_loader(dev_examples, shuffle=False)
                  if dev_examples else None)
    test_loader = (_build_loader(test_examples, shuffle=False)
                   if test_examples else None)

    logger.info("self config:\n {}".format(config_to_json_string(config)))

    model = import_model.Model(config).to(config.device)
    best_model = model_train(config,
                             model,
                             train_loader,
                             dev_loader,
                             to_crf=True)
    # Bug fix: the original called model_test unconditionally, handing it a
    # None loader whenever test_examples was absent.
    if test_loader is not None:
        model_test(config, best_model, test_loader, to_crf=True)
# Example no. 2
def Task(config):
    """Run the full pipeline described by ``config``: load data, cross-validate
    the selected model, log dev metrics, and save test predictions to CSV."""
    # Pin the CUDA device before anything allocates GPU memory.
    if config.device.type == 'cuda':
        torch.cuda.set_device(config.device_id)

    bert_tokenizer = BertTokenizer.from_pretrained(
        config.tokenizer_file, do_lower_case=config.do_lower_case)
    data_processor = DataProcessor(data_dir=config.data_dir,
                                   do_lower_case=config.do_lower_case,
                                   language=config.language,
                                   do_preprocessing=config.do_preprocessing,
                                   split=config.split)

    # Derive label / speaker metadata from the data and stash it on config.
    config.class_list = data_processor.get_labels()
    config.num_labels = len(config.class_list)
    if config.speaker_tag:
        config.speaker_dict = data_processor.get_speaker_map(
            config.speaker_threshold)
    else:
        config.speaker_dict = None

    train_set = data_processor.get_train_examples()
    dev_set = data_processor.get_dev_examples()
    test_set = data_processor.get_test_examples(config.test_file)

    net = MODEL_CLASSES[config.use_model](config)

    logger.info("self config: {}\n".format(config_to_json_string(config)))

    model_example, metrics_result, predict_label = cross_validation(
        config=config,
        train_examples=train_set,
        dev_examples=dev_set,
        model=net,
        tokenizer=bert_tokenizer,
        pattern=config.pattern,
        test_examples=test_set)

    if config.pattern != 'k_fold':
        # Single split: metrics come back as plain scalars.
        dev_acc, dev_f1 = metrics_result[0], metrics_result[1]
    else:
        # K-fold: average per-fold metrics and merge per-fold predictions.
        logger.info('K({})-fold models dev acc: {}'.format(
            config.k_fold, metrics_result[0]))
        logger.info('K({})-fold models dev f1: {}'.format(
            config.k_fold, metrics_result[1]))
        dev_acc = np.array(metrics_result[0]).mean()
        dev_f1 = np.array(metrics_result[1]).mean()
        predict_label = combined_result(predict_label, pattern='average')

    logger.info("dev evaluate average Acc: {}, F1:{}".format(dev_acc, dev_f1))
    # Timestamped result file name encodes language and dev F1.
    file_name = '{}_{}_{:>.6f}.csv'.format(
        strftime("%m%d-%H%M%S", localtime()), config.language, dev_f1)
    predict_to_save(predict_label,
                    path=config.result_save_path,
                    file=file_name,
                    prob_threshold=config.prob_threshold)