def args_parse(config_file=''):
    """Parse command-line arguments, build and freeze the global config.

    Args:
        config_file: fallback config path (relative to ``configs``) used
            when ``--config_file`` is not supplied on the command line.

    Returns:
        The frozen global ``cfg`` node.
    """
    parser = argparse.ArgumentParser(description="fast-bbdl")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("--opts",
                        help="Modify config options using the command-line key value",
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # The CLI value takes precedence over the function's default.
    config_file = args.config_file or config_file
    if config_file != "":
        cfg.merge_from_file(get_abs_path('configs', config_file))
    # BUGFIX: --opts defaults to None and yacs' merge_from_list calls
    # len() on its argument, raising TypeError on None — only merge
    # when overrides were actually given.
    if args.opts:
        cfg.merge_from_list(args.opts)
    cfg.freeze()

    name = cfg.MODEL.NAME
    output_dir = cfg.OUTPUT_DIR
    logger = setup_logger(name, get_abs_path(output_dir), 0)
    logger.info(args)
    if config_file != '':
        logger.info("Loaded configuration file {}".format(config_file))
        # Echo the raw config file contents into the log for reproducibility.
        with open(get_abs_path('configs', config_file), 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))
    return cfg
def make_loaders(cfg, get_loader_fn, **kwargs):
    """Build the train/valid/test data loaders declared in the config.

    Args:
        cfg: config node providing DATASETS.*, SOLVER.BATCH_SIZE,
            TEST.BATCH_SIZE and DATALOADER.NUM_WORKERS.
        get_loader_fn: factory called as
            ``get_loader_fn(path, batch_size=..., shuffle=..., num_workers=..., **kwargs)``.
        **kwargs: extra keyword arguments forwarded to ``get_loader_fn``.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``; an entry is
        ``None`` when its dataset path in the config is empty.
    """
    def _build(path, batch_size, shuffle):
        # An empty dataset path in the config means "skip this split".
        if path == '':
            return None
        return get_loader_fn(get_abs_path(path),
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                             **kwargs)

    # Only the training split is shuffled and uses the solver batch size.
    train_loader = _build(cfg.DATASETS.TRAIN, cfg.SOLVER.BATCH_SIZE, True)
    valid_loader = _build(cfg.DATASETS.VALID, cfg.TEST.BATCH_SIZE, False)
    test_loader = _build(cfg.DATASETS.TEST, cfg.TEST.BATCH_SIZE, False)
    return train_loader, valid_loader, test_loader
def load_model(args):
    """Restore a CSC model from the checkpoint named in ``args``.

    Args:
        args: namespace with ``config_file`` (under ``configs``) and
            ``ckpt_fn`` (under ``checkpoints/<MODEL.NAME>``).

    Returns:
        The model in eval mode, moved to ``cfg.MODEL.DEVICE``.
    """
    from bbcm.config import cfg
    cfg.merge_from_file(get_abs_path('configs', args.config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    ckpt_path = os.path.join(
        get_abs_path("checkpoints", cfg.MODEL.NAME), args.ckpt_fn)
    # Pick the model class matching the configured architecture.
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model_cls = BertForCsc
    else:
        model_cls = SoftMaskedBertModel
    model = model_cls.load_from_checkpoint(
        ckpt_path, cfg=cfg, tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model
def load_model_directly(ckpt_file, config_file):
    """Load a CSC model given explicit checkpoint and config file names.

    Example:
        ckpt_file = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' (find in checkpoints)
        config_file = 'csc/train_SoftMaskedBert.yml' (find in configs)

    Returns:
        The model in eval mode, moved to ``cfg.MODEL.DEVICE``.
    """
    from bbcm.config import cfg
    ckpt_path = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    # Select the architecture class from the configured model name.
    model_cls = (BertForCsc
                 if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']
                 else SoftMaskedBertModel)
    model = model_cls.load_from_checkpoint(
        ckpt_path, cfg=cfg, tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model
def main():
    """Entry point: preprocess data if needed, build the model, train."""
    cfg = args_parse("csc/train_macbert4csc.yml")
    # Preprocess the raw data first if the training file does not exist yet.
    if not os.path.exists(get_abs_path(cfg.DATASETS.TRAIN)):
        preproc()
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    if cfg.MODEL.NAME in ["bert4csc", "macbert4csc"]:
        model = BertForCsc(cfg, tokenizer)
    else:
        model = SoftMaskedBertModel(cfg, tokenizer)
    if len(cfg.MODEL.WEIGHTS) > 0:
        ckpt_path = get_abs_path(cfg.OUTPUT_DIR, cfg.MODEL.WEIGHTS)
        # BUGFIX: load_from_checkpoint is a classmethod that returns a NEW
        # model instance; the original call discarded the result, so the
        # configured weights were never actually used. Keep the returned
        # instance.
        model = type(model).load_from_checkpoint(
            ckpt_path, cfg=cfg, tokenizer=tokenizer)
    loaders = make_loaders(cfg, get_csc_loader, tokenizer=tokenizer)
    # Keep only the single best checkpoint by validation loss.
    ckpt_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=get_abs_path(cfg.OUTPUT_DIR),
        filename='{epoch}-{val_loss:.2f}',
        save_top_k=1,
        mode='min'
    )
    train(cfg, model, loaders, ckpt_callback)
def convert(fn, model_name):
    """Extract the state_dict from a saved ckpt file for transfer use.

    Args:
        fn: checkpoint file name.
        model_name: model name; must match the one in the yml config.

    Returns:
        None. Writes ``pytorch_model.bin`` into the model's checkpoint dir.
    """
    file_dir = get_abs_path("checkpoints", model_name)
    state_dict = torch.load(os.path.join(file_dir, fn))['state_dict']
    if model_name in ['bert4csc', 'macbert4csc']:
        # Strip the first 5 characters from every key — presumably the
        # LightningModule attribute prefix (e.g. "bert."); TODO confirm.
        new_state_dict = OrderedDict(
            (key[5:], value) for key, value in state_dict.items())
    else:
        new_state_dict = state_dict
    torch.save(new_state_dict, os.path.join(file_dir, 'pytorch_model.bin'))
def train(config, model, loaders, ckpt_callback=None):
    """Run training and, when configured, testing.

    Args:
        config: experiment configuration.
        model: the model to fit.
        loaders: (train, valid, test) data loaders.
        ckpt_callback: checkpoint callback; when None, the trainer falls
            back to its default of saving once per epoch.

    Returns:
        None
    """
    train_loader, valid_loader, test_loader = loaders
    trainer = pl.Trainer(
        max_epochs=config.SOLVER.MAX_EPOCHS,
        gpus=None if config.MODEL.DEVICE == 'cpu' else config.MODEL.GPU_IDS,
        accumulate_grad_batches=config.SOLVER.ACCUMULATE_GRAD_BATCHES,
        checkpoint_callback=ckpt_callback)

    def _has_data(loader):
        # A split is usable only when its loader exists and is non-empty.
        return loader is not None and len(loader) > 0

    # Train only when the config asks for it AND the train split has data.
    if 'train' in config.MODE and _has_data(train_loader):
        fit_args = [model, train_loader]
        if _has_data(valid_loader):
            fit_args.append(valid_loader)
        trainer.fit(*fit_args)

    # Same gating logic for testing.
    if 'test' in config.MODE and _has_data(test_loader):
        # Prefer the best checkpoint from this run, then configured
        # weights, otherwise test the in-memory model as-is.
        ckpt_path = None
        if ckpt_callback and ckpt_callback.best_model_path:
            ckpt_path = ckpt_callback.best_model_path
        elif config.MODEL.WEIGHTS:
            ckpt_path = get_abs_path(config.OUTPUT_DIR, config.MODEL.WEIGHTS)
        print(ckpt_path)
        if ckpt_path is not None and os.path.exists(ckpt_path):
            model.load_state_dict(torch.load(ckpt_path)['state_dict'])
        trainer.test(model, test_loader)
def preproc():
    """Build train/dev/test json files from the raw CSC datasets."""
    convertor = opencc.OpenCC('tw2sp.json')
    data_dir = get_abs_path('datasets', 'csc')
    test_items = proc_test_set(data_dir, convertor)
    rst_items = []
    for item in read_data(data_dir):
        rst_items += proc_item(item, convertor)
    for item in read_confusion_data(data_dir):
        rst_items += proc_confusion_item(item)
    # Split into training and dev portions (first tenth goes to dev).
    dev_set_len = len(rst_items) // 10
    print(len(rst_items))
    # Fixed seed so the train/dev split is reproducible across runs.
    random.seed(666)
    random.shuffle(rst_items)
    dump_json(rst_items[:dev_set_len],
              get_abs_path('datasets', 'csc', 'dev.json'))
    dump_json(rst_items[dev_set_len:],
              get_abs_path('datasets', 'csc', 'train.json'))
    dump_json(test_items,
              get_abs_path('datasets', 'csc', 'test.json'))
    gc.collect()