Пример #1
0
def main():
    """Run the sequence-labeling pipeline: setup, then train / eval / predict.

    Options come from ``Argparser().get_training_arguments()`` and select
    which phases run (``do_train``, ``do_eval``, ``do_predict``). Labels are
    taken from ``CnerDataset``. Checkpoints used for evaluation and
    prediction are looked up under ``opts.output_dir``; results are saved
    under a directory named after each checkpoint.
    """
    opts = Argparser().get_training_arguments()
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    # The test split is only consumed by the prediction phase. Load it up
    # front (so a bad path fails before training starts) but only when
    # prediction was actually requested.
    test_dataset = None
    if opts.do_predict:
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    opts.label2id = CnerDataset.label2id()
    opts.id2label = CnerDataset.id2label()

    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(opts.pretrained_model_path,
                                          num_labels=opts.num_labels,
                                          label2id=opts.label2id,
                                          id2label=opts.id2label)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)

    # trainer
    logger.info("initializing traniner")
    # Keep only the part after the first '-' of each tagged label
    # (labels without a '-' are skipped).
    labels = {
        label.split('-')[1]
        for label in CnerDataset.get_labels() if '-' in label
    }
    metrics = [
        SequenceLabelingScore(labels=labels, average='micro', schema='BIOS')
    ]
    trainer = SequenceLabelingTrainer(opts=opts,
                                      model=model,
                                      tokenizer=tokenizer,
                                      metrics=metrics,
                                      logger=logger)
    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    # do eval
    if opts.do_eval:
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        if opts.eval_all_checkpoints:
            # Replaces (rather than extends) the list built above.
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Last path component, independent of the OS separator and of
            # any trailing separator (a plain split("/") breaks on both).
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    # do predict
    if opts.do_predict:
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)
Пример #2
0
def main():
    """Run the global-pointer sequence-labeling pipeline.

    Extends the training argument parser with the global-pointer options
    (``--decode_thresh``, ``--pe_dim``, ``--use_rope``), builds the
    tokenizer, datasets, config and model, and then runs whichever of the
    ``do_train`` / ``do_eval`` / ``do_predict`` phases were requested.
    Checkpoints used for evaluation and prediction are looked up under
    ``opts.output_dir``; results are saved under a directory named after
    each checkpoint.
    """
    parser = Argparser.get_training_parser()
    group = parser.add_argument_group(title="global pointer",
                                      description="Global pointer")
    group.add_argument("--decode_thresh", type=float, default=0.0)
    group.add_argument('--pe_dim',
                       default=64,
                       type=int,
                       help='The dim of Positional embedding')
    group.add_argument('--use_rope', action='store_true')
    opts = parser.parse_args_from_parser(parser)
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    # The test split is only consumed by the prediction phase. Load it up
    # front (so a bad path fails before training starts) but only when
    # prediction was actually requested.
    test_dataset = None
    if opts.do_predict:
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    opts.label2id = CnerDataset.label2id()
    opts.id2label = CnerDataset.id2label()
    # model
    logger.info("initializing model and config")
    config, unused_kwargs = config_class.from_pretrained(
        opts.pretrained_model_path,
        return_unused_kwargs=True,
        pe_dim=opts.pe_dim,
        use_rope=opts.use_rope,
        num_labels=opts.num_labels,
        id2label=opts.id2label,
        label2id=opts.label2id,
        decode_thresh=opts.decode_thresh,
        max_seq_length=512)
    # By default `from_dict` only sets a value when the key already exists
    # on the config, so force-set the leftover keyword arguments here.
    for key, value in unused_kwargs.items():
        setattr(config, key, value)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)
    # trainer
    logger.info("initializing traniner")
    metrics = [
        SequenceLabelingScore(CnerDataset.get_labels(),
                              schema='BIOS',
                              average='micro')
    ]
    trainer = SequenceLabelingTrainer(opts=opts,
                                      model=model,
                                      tokenizer=tokenizer,
                                      metrics=metrics,
                                      logger=logger)
    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    # do eval
    if opts.do_eval:
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        if opts.eval_all_checkpoints:
            # Replaces (rather than extends) the list built above.
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Last path component, independent of the OS separator and of
            # any trailing separator (a plain split("/") breaks on both).
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    # do predict
    if opts.do_predict:
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)
Пример #3
0
def main():
    """Run the text-classification pipeline: setup, then train / eval / predict.

    Options come from ``Argparser().get_training_arguments()`` and select
    which phases run (``do_train``, ``do_eval``, ``do_predict``). The model
    is scored with ``MattewsCorrcoef``. Checkpoints used for evaluation and
    prediction are looked up under ``opts.output_dir``; results are saved
    under a directory named after each checkpoint.
    """
    opts = Argparser().get_training_arguments()
    logger = Logger(opts=opts)
    # device
    logger.info("initializing device")
    opts.device, opts.device_num = prepare_device(opts.device_id)
    seed_everything(opts.seed)
    config_class, model_class, tokenizer_class = MODEL_CLASSES[opts.model_type]
    # data processor
    logger.info("initializing data processor")
    tokenizer = tokenizer_class.from_pretrained(
        opts.pretrained_model_path, do_lower_case=opts.do_lower_case)
    train_dataset = load_data(opts.train_input_file, opts.data_dir, "train",
                              tokenizer, opts.train_max_seq_length)
    dev_dataset = load_data(opts.eval_input_file, opts.data_dir, "dev",
                            tokenizer, opts.eval_max_seq_length)
    opts.num_labels = train_dataset.num_labels
    # model
    logger.info("initializing model and config")
    config = config_class.from_pretrained(opts.pretrained_model_path,
                                          num_labels=opts.num_labels)
    model = model_class.from_pretrained(opts.pretrained_model_path,
                                        config=config)
    model.to(opts.device)
    # trainer
    logger.info("initializing traniner")
    trainer = TextClassifierTrainer(
        opts=opts,
        model=model,
        tokenizer=tokenizer,
        metrics=[MattewsCorrcoef(num_classes=opts.num_labels)],
        logger=logger)
    # do train
    if opts.do_train:
        trainer.train(train_data=train_dataset, dev_data=dev_dataset)

    # do eval
    if opts.do_eval:
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        if opts.eval_all_checkpoints:
            # Replaces (rather than extends) the list built above.
            checkpoints = find_all_checkpoints(checkpoint_dir=opts.output_dir)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Last path component, independent of the OS separator and of
            # any trailing separator (a plain split("/") breaks on both).
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.evaluate(dev_data=dev_dataset,
                             save_result=True,
                             save_dir=prefix)

    # do predict
    if opts.do_predict:
        # The test split is only needed for prediction, so load it here.
        test_dataset = load_data(opts.test_input_file, opts.data_dir, "test",
                                 tokenizer, opts.test_max_seq_length)
        checkpoints = []
        if opts.checkpoint_predict_code is not None:
            checkpoint = os.path.join(opts.output_dir,
                                      opts.checkpoint_predict_code)
            check_dir(checkpoint)
            checkpoints.append(checkpoint)
        logger.info("Predict the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            prefix = os.path.basename(os.path.normpath(checkpoint))
            model = model_class.from_pretrained(checkpoint, config=config)
            model.to(opts.device)
            trainer.model = model
            trainer.predict(test_data=test_dataset,
                            save_result=True,
                            save_dir=prefix)