Пример #1
0
def inference(configs):
    """
    Run inference mode.

    Loads the input through ``load_bertology_input``, loads the trainer and
    hands both to ``trainer.inference`` together with
    ``configs.output_conllu_path``.

    Args:
        configs: run configuration; ``configs.command`` must be ``'inference'``.

    Raises:
        RuntimeError: if ``configs.command`` is not ``'inference'``.
    """
    in_inference_mode = configs.command == 'inference'
    if not in_inference_mode:
        raise RuntimeError('Not in inference mode')

    data_loader, conllu = load_bertology_input(configs)

    with Timer('load trainer'):
        trainer = load_trainer(configs)

    with Timer('inference'):
        trainer.inference(
            data_loader,
            conllu,
            output_conllu_path=configs.output_conllu_path,
        )

    print(f'INFERENCE output file saved in {configs.output_conllu_path}')
Пример #2
0
def train(configs) -> None:
    """
    Run training mode, optionally followed by an evaluation on the test file.

    Derives ``configs.eval_interval`` and ``configs.max_train_steps`` from the
    number of training batches, loads the trainer, saves the configuration and
    trains. When ``configs.test_after_train`` is set, the process with
    ``local_rank`` -1 or 0 then re-targets ``configs`` at the test file and
    calls ``dev``.

    Args:
        configs: run configuration; it is mutated in place (see above).

    Returns:
        None

    Raises:
        RuntimeError: if ``configs.command`` is not ``'train'``, or if
            ``test_after_train`` and ``no_output`` are both enabled.
    """
    if configs.command != 'train':
        raise RuntimeError('Not in train mode')

    # Fail fast: testing right after training needs the saved model, which
    # cannot exist when no_output is set. Rejecting the combination here (on
    # every rank) avoids finishing a whole training run only to raise at the end.
    if configs.test_after_train and configs.no_output:
        raise RuntimeError('no_output为True时无法训练后立刻测试')

    with Timer('Load data set'):
        # Only BERTology-style input is supported for now
        train_data_loader, _, dev_data_loader, dev_conllu = load_bertology_input(configs)

    logger.info(f'train batch size: {configs.train_batch_size}')
    logger.info(f'train data batch num: {len(train_data_loader)}')
    # Number of steps between two dev evaluations
    configs.eval_interval = len(train_data_loader) * configs.eval_epoch
    logger.info(f'eval interval: {configs.eval_interval}')
    # Note: this value affects the learning-rate warm-up
    configs.max_train_steps = len(train_data_loader) * configs.max_train_epochs
    logger.info(f'max steps: {configs.max_train_steps}')
    # Early stopping: per the original note, stop when there is no improvement
    # after the configured number of epochs.
    # NOTE(review): the message below says "steps" but logs early_stop_epochs.
    if configs.early_stop:
        logger.info(f'early stop steps: {configs.early_stop_epochs}\n')
    else:
        logger.info(f'do not use early stop, training will last {configs.max_train_epochs} epochs')
    with Timer('Load trainer'):
        trainer = load_trainer(configs)

    save_config_to_yaml(configs)

    with Timer('Train'):
        trainer.train(train_data_loader, dev_data_loader, dev_conllu)
    logger.info('Train Complete!')
    logger.info(f'Best Model Parameters and Configs saved in {configs.output_model_dir}')

    if configs.test_after_train and configs.local_rank in [-1, 0]:
        # Location where the best model was saved
        configs.saved_model_path = configs.output_model_dir
        # Gold test file
        configs.input_conllu_path = os.path.join(configs.data_dir, configs.test_file)
        configs.output_conllu_path = configs.test_output_path
        # dev() accepts this command value in addition to 'dev'
        configs.command = 'test_after_train'
        dev(configs)
Пример #3
0
def dev(configs):
    """
    Run evaluation mode.

    Gold input file: ``configs.input_conllu_path``; dev output file:
    ``configs.output_conllu_path``. The UAS and LAS values returned by
    ``trainer.dev`` are printed.

    Args:
        configs: run configuration; ``configs.command`` must be ``'dev'`` or
            ``'test_after_train'``.

    Raises:
        RuntimeError: if ``configs.command`` is neither of the accepted values.
    """
    accepted_commands = ('dev', 'test_after_train')
    if configs.command not in accepted_commands:
        raise RuntimeError('Not in dev mode')

    data_loader, gold_conllu = load_bertology_input(configs)

    with Timer('Load trainer'):
        trainer = load_trainer(configs)

    with Timer('dev'):
        uas, las = trainer.dev(
            data_loader,
            gold_conllu,
            input_conllu_path=configs.input_conllu_path,
            output_conllu_path=configs.output_conllu_path,
        )

    print(f'DEV output file saved in {configs.output_conllu_path}')
    print(f'DEV metrics:\nUAS:{uas}\nLAS:{las}')
Пример #4
0
def main():
    """
    Entry point: parse the arguments, prepare the run and dispatch on
    ``configs.command`` (``'train'``, ``'dev'`` or ``'inference'``).

    Any other command value falls through without running anything.
    """
    with Timer('Parse args'):
        # Load the argument settings
        configs = parse_args()

    # Add the configuration needed for multi-GPU runs:
    # setup CUDA, GPU & distributed training
    setup_for_multi_gpu(configs)
    # set_seed must run after n_gpu has been set
    set_seed(configs)

    # Train mode needs an output folder for run results, the config file,
    # model parameters, etc.; only local_rank -1 / 0 sets it up.
    if configs.command == 'train' and configs.local_rank in (-1, 0):
        setup_output_dir(configs)

    mode = configs.command
    if mode == 'train':
        # (original note) testing on the test set right after training is supported
        train(configs)
        return
    if mode == 'dev':
        dev(configs)
        return
    if mode == 'inference':
        inference(configs)
def load_and_cache_examples(args,
                            conllu_file_path,
                            graph_vocab,
                            tokenizer,
                            training=False):
    """
    Build the dataset for one CoNLL-U file, or load it from the on-disk cache.

    The cache lives in ``<args.data_dir>/cached`` and is only read or written
    when ``args.use_cache`` is set and ``args.command == 'train'``. Writing is
    further restricted to ``args.local_rank`` -1 or 0.

    Args:
        args: run configuration. When ``args.use_pos`` is set and the dataset
            is built (not loaded from cache), ``args.pos_label_pad_idx`` and
            ``args.pos_label_num`` are set on it.
        conllu_file_path: path of the CoNLL-U file to load.
        graph_vocab: vocabulary object providing ``get_labels()``.
        tokenizer: tokenizer providing ``vocab``, ``cls_token``, ``sep_token``,
            ``pad_token`` and ``convert_tokens_to_ids``.
        training: whether the file is training data; selects the example set
            type and whether a new POS list is generated.

    Returns:
        A ``(data_set, conllu_file)`` tuple.
    """
    logger = get_logger(args.log_name)
    word_vocab = tokenizer.vocab if args.encoder_type == 'bertology' else None
    processor = CoNLLUProcessor(args, graph_vocab, word_vocab)
    label_list = graph_vocab.get_labels()

    # The directory is needed by the cache AND by the POS tokenizer, so it must
    # be resolved whenever either feature is on (previously it was only bound
    # under use_cache, which broke use_pos without use_cache).
    cached_dir = None
    cached_dataset = None
    cached_conllu = None
    if args.use_cache or args.use_pos:
        cached_dir = pathlib.Path(args.data_dir) / 'cached'
        # exist_ok avoids a crash when several ranks create it concurrently.
        # presumably get_pos_tokenizer stores/reads its POS list here -- verify.
        cached_dir.mkdir(parents=True, exist_ok=True)
    if args.use_cache:
        file_name = pathlib.Path(conllu_file_path).name
        cached_dataset = cached_dir / \
            f'{file_name}_{args.encoder_type}_pos-{args.use_pos}_len-{args.max_seq_len}-dataset.torch.cache'
        cached_conllu = cached_dir / f'{file_name}-conllu.pickle.cache'

    if args.use_cache and args.command == 'train' and cached_dataset.is_file():
        # Load from cache.
        # SECURITY: pickle.load / torch.load execute arbitrary code embedded in
        # the file; the cache directory must only contain trusted files.
        logger.info("Loading cached file")
        if cached_conllu.is_file():
            with open(str(cached_conllu), 'rb') as f:
                conllu_file = pickle.load(f)
        else:
            conllu_file, _ = load_conllu_file(conllu_file_path)
        with Timer('Load cached data set'):
            data_set = torch.load(cached_dataset)
        # NOTE(review): on a cache hit args.pos_label_pad_idx / pos_label_num
        # are NOT set here even when use_pos is on (the original had that code
        # disabled) -- confirm they are provided elsewhere.
        return data_set, conllu_file

    conllu_file, conllu_data = load_conllu_file(conllu_file_path)

    pos_tokenizer = None
    if args.use_pos:
        # A new pos_list is generated only when training=True
        pos_tokenizer = get_pos_tokenizer(new_pos_list=training,
                                          file_path=cached_dir,
                                          conllu_data=conllu_data)
        args.pos_label_pad_idx = pos_tokenizer.get_idx('<PAD>')
        args.pos_label_num = pos_tokenizer.get_label_num()

    split_tag = 'train' if training else 'dev|infer'
    is_xlnet = args.encoder_type in ['xlnet']

    with Timer(f'Create {split_tag} example'):
        examples = processor.create_bert_example(
            conllu_data,
            'train' if training else 'dev',
            args.max_seq_len,
            training=training,
        )

    with Timer(f'Convert {split_tag} example to features'):
        features = convert_examples_to_features(
            examples,
            label_list,
            args.max_seq_len,
            tokenizer,
            # xlnet has a cls token at the end
            cls_token_at_end=is_xlnet,
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if is_xlnet else 0,
            sep_token=tokenizer.sep_token,
            # roberta uses an extra separator b/w pairs of sentences,
            # cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            sep_token_extra=bool(args.encoder_type in ['roberta']),
            # pad on the left for xlnet
            pad_on_left=is_xlnet,
            pad_token=tokenizer.convert_tokens_to_ids(
                [tokenizer.pad_token])[0],
            pad_token_segment_id=4 if is_xlnet else 0,
            pos_tokenizer=pos_tokenizer)

    # Convert to Tensors and build dataset
    with Timer(f'{split_tag} Features to Dataset'):
        data_set = feature_to_dataset(features)

    # Only one process per node (local_rank -1 / 0) writes the cache files.
    if args.local_rank in [-1, 0] and args.use_cache and args.command == 'train':
        with open(str(cached_conllu), 'wb') as f:
            pickle.dump(conllu_file, f)
        with Timer('Save data set'):
            torch.save(data_set, cached_dataset)
        logger.info("Saved dateset into cached file %s",
                    str(cached_dataset))

    return data_set, conllu_file