Beispiel #1
0
def bilstm_crf_train_eval(config,
                          import_model,
                          train_examples,
                          dev_examples=None,
                          test_examples=None):
    """Train a BiLSTM-CRF model and, if a test split is given, test it.

    Args:
        config: run configuration. Reads ``data_dir``, ``do_lower_case``,
            ``label_list``, ``batch_size`` and ``device``. This function
            also sets ``config.vocab_size`` from the loaded vocabulary.
        import_model: module exposing a ``Model(config)`` class.
        train_examples: examples used to build the training loader.
        dev_examples: optional examples for the dev loader handed to
            ``model_train``. When empty or None, no dev loader is built.
        test_examples: optional examples for the test loader. When empty
            or None, the test step is skipped.
    """

    def _crf_loader(examples, shuffle):
        # Featurize one split with the shared vocabulary and wrap it in a DataLoader.
        features = convert_examples_to_features_crf(
            examples=examples,
            word2id=word2id,
            label_list=config.label_list,
        )
        return DataLoader(BuildDataSet(features),
                          batch_size=config.batch_size,
                          shuffle=shuffle)

    processor = DataProcessor(config.data_dir, config.do_lower_case)
    word2id = processor.get_vocab()
    config.vocab_size = len(word2id)

    train_loader = _crf_loader(train_examples, shuffle=True)
    dev_loader = _crf_loader(dev_examples, shuffle=True) if dev_examples else None
    test_loader = _crf_loader(test_examples, shuffle=False) if test_examples else None

    logger.info("self config:\n {}".format(config_to_json_string(config)))

    model = import_model.Model(config).to(config.device)
    best_model = model_train(config,
                             model,
                             train_loader,
                             dev_loader,
                             to_crf=True)
    # Without a test split there is nothing to test. Previously a None
    # loader was handed to model_test here.
    if test_loader is not None:
        model_test(config, best_model, test_loader, to_crf=True)


def train_dev_test(
    config,
    model,
    tokenizer,
    train_data=None,
    dev_data=None,
    test_examples=None,
):
    """Optionally train on ``train_data``, score on ``dev_data``, then predict.

    The model-specific feature converter, dataset builder and loader
    factory are looked up in ``MODEL_CLASSES[config.use_model]``.

    Args:
        config: run configuration. Reads ``use_model``, ``class_list``,
            ``pad_size``, ``to_pair``, ``speaker_dict``, ``device`` and
            ``batch_size``. Sets ``train_num_examples`` and
            ``dev_num_examples`` when training runs.
        model: model to copy. A deep copy is trained, so ``model`` itself
            is not trained here.
        tokenizer: tokenizer passed to the feature converter.
        train_data: training examples. Training is skipped when falsy.
        dev_data: optional dev examples. They are used for dev metrics
            (only when training runs), and as the prediction set when
            ``test_examples`` is None.
        test_examples: optional examples to predict on.

    Returns:
        ``(best_model, (dev_acc, dev_f1), predict_label)``.
        ``best_model`` is None when no training ran. The dev metrics stay
        ``0.`` unless training ran with ``dev_data``. ``predict_label``
        stays ``[]`` when there is nothing to predict on.
    """
    dev_acc = 0.
    dev_f1 = 0.
    predict_label = []

    # Train on a private copy of the caller's model.
    model_example = copy.deepcopy(model).to(config.device)
    best_model = None
    convert_to_features, build_data_set, data_loader = MODEL_CLASSES[config.use_model]

    def _make_loader(examples, shuffle):
        # Featurize one split and wrap it in the model-specific loader.
        features = convert_to_features(
            examples=examples,
            tokenizer=tokenizer,
            label_list=config.class_list,
            max_length=config.pad_size,
            to_pair=config.to_pair,
            speaker_dict=config.speaker_dict
        )
        dataset = build_data_set(features, config=config)
        return data_loader(dataset, device=config.device,
                           batch_size=config.batch_size, shuffle=shuffle)

    if train_data:
        config.train_num_examples = len(train_data)
        train_loader = _make_loader(train_data, shuffle=True)

        # Load and convert the dev data, if any.
        if dev_data is not None:
            config.dev_num_examples = len(dev_data)
            dev_loader = _make_loader(dev_data, shuffle=True)
        else:
            dev_loader = None

        best_model = model_train(config, model_example, train_loader, dev_loader)

        if dev_data is not None:
            dev_acc, dev_f1 = model_metrics(config, best_model, dev_loader)

    if test_examples is not None or dev_data is not None:
        if test_examples is None:
            test_examples = dev_data
        test_loader = _make_loader(test_examples, shuffle=False)
        # Predict with the model returned by model_train (the same one the
        # dev metrics above were computed on). Fall back to the untrained
        # copy when no training happened.
        eval_model = best_model if best_model is not None else model_example
        predict_label = model_evaluate(config, eval_model, test_loader, test=True)

    return best_model, (dev_acc, dev_f1), predict_label