def main(cfg: DictConfig) -> None:
    """Parse user-provided texts with a trained model and write out the trees.

    Loads the checkpoint at ``cfg.model_path``, restores the hyperparameters
    stored in it, builds the parser, runs it over the texts in ``cfg.input``
    and prints one linearized tree per line to ``cfg.output`` (or to stdout
    when ``cfg.output`` is None).

    Args:
        cfg: Run configuration. Fields read here: ``model_path``, ``input``,
            ``language``, ``encoder``, ``eval_batch_size``, ``num_workers``,
            ``amp`` and ``output``.

    Raises:
        AssertionError: If ``model_path`` or ``input`` is missing, if
            ``language`` is not "english" or "chinese", or if a predicted
            tree is None.
    """
    assert cfg.model_path is not None, "Need to specify model_path for testing."
    assert cfg.input is not None
    assert cfg.language in ("english", "chinese")
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # load the model checkpoint
    model_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % model_path)
    checkpoint = load_model(model_path)
    restore_hyperparams(checkpoint["cfg"], cfg)
    vocabs = checkpoint["vocabs"]

    model = Parser(vocabs, cfg)
    model.load_state_dict(checkpoint["model_state"])
    device, _ = get_device()
    model.to(device)
    # Inference only: put the module in evaluation mode so that any
    # train-only behavior (e.g. dropout) is disabled while parsing.
    model.eval()
    log.info("\n" + str(model))
    log.info("#parameters = %d" % sum(p.numel() for p in model.parameters()))

    input_file = hydra.utils.to_absolute_path(cfg.input)
    ds = UserProvidedTexts(input_file, cfg.language, vocabs, cfg.encoder)
    loader = DataLoader(
        ds,
        batch_size=cfg.eval_batch_size,
        collate_fn=form_batch,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    env = Environment(loader, model.encoder, subbatch_max_tokens=9999999)
    state = env.reset()

    oup = (
        sys.stdout
        if cfg.output is None
        else open(hydra.utils.to_absolute_path(cfg.output), "wt")
    )
    time_start = time()

    try:
        with torch.no_grad():  # type: ignore
            while True:
                with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                    actions, _ = model(state)

                state, done = env.step(actions)
                if not done:
                    continue

                # the current batch is finished: emit its trees
                for tree in env.pred_trees:
                    assert tree is not None
                    print(tree.linearize(), file=oup)

                # load the next batch
                try:
                    with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                        state = env.reset()
                except EpochEnd:
                    # no next batch available (complete)
                    log.info("Time elapsed: %f" % (time() - time_start))
                    break
    finally:
        # Close (and thereby flush) the output file even if parsing fails;
        # never close the interpreter's stdout.
        if oup is not sys.stdout:
            oup.close()

    if cfg.output is not None:
        log.info("Parse trees saved to %s" % cfg.output)
def main(cfg: DictConfig) -> None:
    """Evaluate a trained parser on the validation set and the test set.

    The checkpoint at ``cfg.model_path`` supplies the hyperparameters, the
    vocabularies and the model weights. Validation always runs through
    ``validate``; testing uses ``beam_search`` when ``cfg.beam_size > 1`` and
    ``validate`` otherwise. All metrics are reported through the logger.
    """
    assert cfg.model_path is not None, "Need to specify model_path for testing."
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # Bring back the hyperparameters that were used at training time.
    ckpt_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % ckpt_path)
    checkpoint = load_model(ckpt_path)
    restore_hyperparams(checkpoint["cfg"], cfg)
    vocabs = checkpoint["vocabs"]

    def build_loader(path, split):
        # Evaluation loaders share everything except the file and split name.
        loader, _ = create_dataloader(
            hydra.utils.to_absolute_path(path),
            split,
            cfg.encoder,
            vocabs,
            cfg.eval_batch_size,
            cfg.num_workers,
        )
        return loader

    def as_metrics(score):
        # Field order matches the placeholders in the report messages below.
        return (
            score.fscore,
            score.complete_match,
            score.precision,
            score.recall,
        )

    # Data for validation and testing.
    loader_val = build_loader(cfg.path_val, "val")
    loader_test = build_loader(cfg.path_test, "test")

    # Rebuild the network and load the trained weights.
    model = Parser(vocabs, cfg)
    model.load_state_dict(checkpoint["model_state"])
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    n_params = sum(p.numel() for p in model.parameters())
    log.info("#parameters = %d" % n_params)

    # Validation pass.
    log.info("Validating..")
    val_score = validate(loader_val, model, cfg)
    log.info(
        "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
        % as_metrics(val_score)
    )

    # Test pass, with or without beam search.
    log.info("Testing..")
    use_beam = cfg.beam_size > 1
    if use_beam:
        log.info("Performing beam search..")
        test_score = beam_search(loader_test, model, cfg)
    else:
        log.info("Running without beam search..")
        test_score = validate(loader_test, model, cfg)
    log.info(
        "Testing F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
        % as_metrics(test_score)
    )
def train_val(cfg: DictConfig) -> None:
    """Run the epoch loop: optional training, validation, LR scheduling, checkpointing.

    Builds the training loader (which also yields the vocabularies) and the
    validation loader, creates the parser and an RMSprop optimizer, optionally
    resumes from ``cfg.resume``, then iterates from the starting epoch up to
    ``cfg.num_epochs``. After each validation pass the plateau scheduler is
    stepped with the best F1 score so far and ``model_latest.pth`` is saved.
    """
    # Data: the vocabularies come from building the training loader.
    loader_train, vocabs = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_train),
        "train",
        cfg.encoder,
        None,
        cfg.batch_size,
        cfg.num_workers,
    )
    assert vocabs is not None
    loader_val, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )

    # Model.
    model = Parser(vocabs, cfg)
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % count_params(model))

    # Optimizer.
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )

    first_epoch = 0
    if cfg.resume is not None:
        # Continue from a saved checkpoint: weights, epoch counter, optimizer.
        # NOTE(review): cfg.resume is handed to load_model as-is, whereas the
        # dataset paths above go through hydra.utils.to_absolute_path --
        # confirm whether a relative resume path is meant to work.
        resumed = load_model(cfg.resume)
        model.load_state_dict(resumed["model_state"])
        first_epoch = resumed["epoch"] + 1
        optimizer.load_state_dict(resumed["optimizer_state"])
        del resumed

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=cfg.learning_rate_patience,
        cooldown=cfg.learning_rate_cooldown,
        verbose=True,
    )

    # Training / validation loop.
    best_fscore = -1.0
    iters_done = 0

    for epoch in range(first_epoch, cfg.num_epochs):
        log.info("Epoch #%d" % epoch)

        if not cfg.skip_training:
            log.info("Training..")
            iters_done, train_acc, train_loss = train(
                iters_done,
                loader_train,
                model,
                optimizer,
                vocabs["label"],
                cfg,
            )
            log.info(
                "Action accuracy: %.03f, Loss: %.03f" % (train_acc, train_loss)
            )

        log.info("Validating..")
        val_score = validate(loader_val, model, cfg)
        val_metrics = (
            val_score.fscore,
            val_score.complete_match,
            val_score.precision,
            val_score.recall,
        )
        log.info(
            "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
            % val_metrics
        )

        if val_score.fscore > best_fscore:
            log.info("F1 score has improved")
            best_fscore = val_score.fscore

        # The scheduler tracks the best score, so it only sees a plateau
        # when validation F1 stops improving.
        scheduler.step(best_fscore)

        save_checkpoint(
            "model_latest.pth",
            epoch,
            model,
            optimizer,
            val_score.fscore,
            vocabs,
            cfg,
        )