Example #1
def args_parse(config_file=''):
    parser = argparse.ArgumentParser(description="fast-bbdl")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("--opts", help="Modify config options using the command-line key value", default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    config_file = args.config_file or config_file

    if config_file != "":
        cfg.merge_from_file(get_abs_path('configs', config_file))
    if args.opts:
        # --opts defaults to None; yacs' merge_from_list requires a list
        cfg.merge_from_list(args.opts)
    cfg.freeze()

    name = cfg.MODEL.NAME

    output_dir = cfg.OUTPUT_DIR

    logger = setup_logger(name, get_abs_path(output_dir), 0)
    logger.info(args)

    if config_file != '':
        logger.info("Loaded configuration file {}".format(config_file))
        with open(get_abs_path('configs', config_file), 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)

    logger.info("Running with config:\n{}".format(cfg))
    return cfg
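The --opts flag collects trailing KEY VALUE pairs that yacs merges on top of the file-based config. A minimal standalone sketch of that mechanism (the real cfg lives in bbcm.config; this one is built inline for illustration):

import argparse
from yacs.config import CfgNode as CN

cfg = CN()
cfg.SOLVER = CN()
cfg.SOLVER.BATCH_SIZE = 32

parser = argparse.ArgumentParser()
parser.add_argument("--opts", default=None, nargs=argparse.REMAINDER)
args = parser.parse_args(["--opts", "SOLVER.BATCH_SIZE", "16"])

cfg.merge_from_list(args.opts)  # yacs coerces the string '16' back to int
cfg.freeze()
print(cfg.SOLVER.BATCH_SIZE)    # -> 16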
Example #2
def make_loaders(cfg, get_loader_fn, **kwargs):
    if cfg.DATASETS.TRAIN == '':
        train_loader = None
    else:
        train_loader = get_loader_fn(get_abs_path(cfg.DATASETS.TRAIN),
                                     batch_size=cfg.SOLVER.BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=cfg.DATALOADER.NUM_WORKERS,
                                     **kwargs)
    if cfg.DATASETS.VALID == '':
        valid_loader = None
    else:
        valid_loader = get_loader_fn(get_abs_path(cfg.DATASETS.VALID),
                                     batch_size=cfg.TEST.BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=cfg.DATALOADER.NUM_WORKERS,
                                     **kwargs)
    if cfg.DATASETS.TEST == '':
        test_loader = None
    else:
        test_loader = get_loader_fn(get_abs_path(cfg.DATASETS.TEST),
                                    batch_size=cfg.TEST.BATCH_SIZE,
                                    shuffle=False,
                                    num_workers=cfg.DATALOADER.NUM_WORKERS,
                                    **kwargs)
    return train_loader, valid_loader, test_loader
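make_loaders is agnostic about get_loader_fn as long as it accepts a path plus batch_size, shuffle and num_workers. A hypothetical stand-in built on torch.utils.data (the project's real one is get_csc_loader; JsonDataset here is invented for the sketch):

import json
from torch.utils.data import DataLoader, Dataset

class JsonDataset(Dataset):
    # Wraps a JSON file holding a list of samples.
    def __init__(self, fp):
        with open(fp, encoding='utf-8') as f:
            self.samples = json.load(f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def get_loader_fn(fp, batch_size=32, shuffle=False, num_workers=0, **kwargs):
    return DataLoader(JsonDataset(fp), batch_size=batch_size,
                      shuffle=shuffle, num_workers=num_workers)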
Example #3
def load_model(args):
    from bbcm.config import cfg
    cfg.merge_from_file(get_abs_path('configs', args.config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    file_dir = get_abs_path("checkpoints", cfg.MODEL.NAME)
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(os.path.join(
            file_dir, args.ckpt_fn),
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        model = SoftMaskedBertModel.load_from_checkpoint(os.path.join(
            file_dir, args.ckpt_fn),
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)

    return model
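load_model only reads args.config_file and args.ckpt_fn, so any attribute holder works; a usage sketch reusing the illustrative file names from Example #4 below:

from argparse import Namespace

args = Namespace(config_file='csc/train_SoftMaskedBert.yml',
                 ckpt_fn='epoch=02-val_loss=0.02904.ckpt')
model = load_model(args)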
Example #4
def load_model_directly(ckpt_file, config_file):
    # Example:
    # ckpt_fn = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' (find in checkpoints)
    # config_file = 'csc/train_SoftMaskedBert.yml' (find in configs)

    from bbcm.config import cfg
    cp = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)

    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(cp,
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model
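The call corresponding to the docstring comment above, using its illustrative paths:

model = load_model_directly('SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt',
                            'csc/train_SoftMaskedBert.yml')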
Example #5
def main():
    cfg = args_parse("csc/train_macbert4csc.yml")

    # Preprocess the data first if the training file does not exist yet
    if not os.path.exists(get_abs_path(cfg.DATASETS.TRAIN)):
        preproc()
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    if cfg.MODEL.NAME in ["bert4csc", "macbert4csc"]:
        model = BertForCsc(cfg, tokenizer)
    else:
        model = SoftMaskedBertModel(cfg, tokenizer)

    if len(cfg.MODEL.WEIGHTS) > 0:
        ckpt_path = get_abs_path(cfg.OUTPUT_DIR, cfg.MODEL.WEIGHTS)
        # load_from_checkpoint is a classmethod that returns a new instance,
        # so the result must be assigned back instead of discarded
        model = model.load_from_checkpoint(ckpt_path, cfg=cfg, tokenizer=tokenizer)

    loaders = make_loaders(cfg, get_csc_loader, tokenizer=tokenizer)
    ckpt_callback = ModelCheckpoint(monitor='val_loss',
                                    dirpath=get_abs_path(cfg.OUTPUT_DIR),
                                    filename='{epoch}-{val_loss:.2f}',
                                    save_top_k=1,
                                    mode='min')
    train(cfg, model, loaders, ckpt_callback)
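main() is the training entry point; presumably the script runs it under the usual guard (not shown in this excerpt):

if __name__ == '__main__':
    main()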
Example #6
def convert(fn, model_name):
    """
    从保存的ckpt文件中取出模型的state_dict用于迁移。
    Args:
        fn: ckpt文件的文件名
        model_name: 模型名,应与yml中的一致。

    Returns:

    """
    file_dir = get_abs_path("checkpoints", model_name)
    state_dict = torch.load(os.path.join(file_dir, fn))['state_dict']
    new_state_dict = OrderedDict()
    if model_name in ['bert4csc', 'macbert4csc']:
        for k, v in state_dict.items():
            # drop the first five characters of each key, i.e. the Lightning
            # submodule prefix (e.g. 'bert.')
            new_state_dict[k[5:]] = v
    else:
        new_state_dict = state_dict
    torch.save(new_state_dict, os.path.join(file_dir, 'pytorch_model.bin'))
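A hypothetical call exporting a trained macbert4csc checkpoint to a HuggingFace-style weights file (the file name is invented here, following the ModelCheckpoint pattern from Example #5):

convert('epoch=2-val_loss=0.03.ckpt', 'macbert4csc')  # hypothetical checkpoint name
# -> writes checkpoints/macbert4csc/pytorch_model.bin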
Example #7
def train(config, model, loaders, ckpt_callback=None):
    """
    训练
    Args:
        config: 配置
        model: 模型
        loaders: 各个数据的loader,包含train,valid,test
        ckpt_callback: 按需保存模型的callback,如为空则默认每个epoch保存一次模型。
    Returns:
        None
    """
    train_loader, valid_loader, test_loader = loaders
    trainer = pl.Trainer(
        max_epochs=config.SOLVER.MAX_EPOCHS,
        gpus=None if config.MODEL.DEVICE == 'cpu' else config.MODEL.GPU_IDS,
        accumulate_grad_batches=config.SOLVER.ACCUMULATE_GRAD_BATCHES,
        checkpoint_callback=ckpt_callback)
    # Training runs only if all of the following hold:
    # 1. the config file requests training
    # 2. train_loader is not None
    # 3. train_loader contains data
    if 'train' in config.MODE and train_loader and len(train_loader) > 0:
        if valid_loader and len(valid_loader) > 0:
            trainer.fit(model, train_loader, valid_loader)
        else:
            trainer.fit(model, train_loader)
    # Whether to run the test phase follows the same logic as training
    if 'test' in config.MODE and test_loader and len(test_loader) > 0:
        if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
            ckpt_path = ckpt_callback.best_model_path
        elif len(config.MODEL.WEIGHTS) > 0:
            ckpt_path = get_abs_path(config.OUTPUT_DIR, config.MODEL.WEIGHTS)
        else:
            ckpt_path = None
        print(ckpt_path)
        if (ckpt_path is not None) and os.path.exists(ckpt_path):
            model.load_state_dict(torch.load(ckpt_path)['state_dict'])
        trainer.test(model, test_loader)
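train() keys both phases off config.MODE. A hypothetical evaluation-only run: with MODE reduced to just 'test', fitting is skipped and the checkpoint named in MODEL.WEIGHTS is evaluated (assuming cfg, model and loaders were built as in Example #5, and that MODE is a list of phase names):

cfg.defrost()        # yacs configs must be unfrozen before editing
cfg.MODE = ['test']
cfg.freeze()
train(cfg, model, loaders)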
Example #8
def preproc():
    rst_items = []
    convertor = opencc.OpenCC('tw2sp.json')
    test_items = proc_test_set(get_abs_path('datasets', 'csc'), convertor)
    for item in read_data(get_abs_path('datasets', 'csc')):
        rst_items += proc_item(item, convertor)
    for item in read_confusion_data(get_abs_path('datasets', 'csc')):
        rst_items += proc_confusion_item(item)

    # Split into train and dev sets
    dev_set_len = len(rst_items) // 10
    print(len(rst_items))
    random.seed(666)
    random.shuffle(rst_items)
    dump_json(rst_items[:dev_set_len],
              get_abs_path('datasets', 'csc', 'dev.json'))
    dump_json(rst_items[dev_set_len:],
              get_abs_path('datasets', 'csc', 'train.json'))
    dump_json(test_items, get_abs_path('datasets', 'csc', 'test.json'))
    gc.collect()
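After preproc() finishes, the three splits land under datasets/csc/ as train.json (roughly 90% of the shuffled items), dev.json (10%) and test.json. A quick sanity check of the split sizes, assuming the project's get_abs_path helper is in scope:

import json

for split in ('train', 'dev', 'test'):
    with open(get_abs_path('datasets', 'csc', split + '.json'), encoding='utf-8') as f:
        print(split, len(json.load(f)))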