# Example #1
# 0
def main(cfg: DictConfig) -> None:
    """The entry point for parsing user-provided texts.

    Loads the checkpoint at ``cfg.model_path``, restores its hyperparameters
    into ``cfg``, parses the texts in ``cfg.input`` batch by batch and prints
    one linearized tree per line to ``cfg.output`` (or to stdout when
    ``cfg.output`` is None).
    """
    assert cfg.model_path is not None, "Need to specify model_path for testing."
    assert cfg.input is not None
    assert cfg.language in ("english", "chinese")
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # load the model checkpoint
    model_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % model_path)
    checkpoint = load_model(model_path)
    restore_hyperparams(checkpoint["cfg"], cfg)
    vocabs = checkpoint["vocabs"]

    model = Parser(vocabs, cfg)
    model.load_state_dict(checkpoint["model_state"])
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % sum(p.numel() for p in model.parameters()))

    input_file = hydra.utils.to_absolute_path(cfg.input)
    ds = UserProvidedTexts(input_file, cfg.language, vocabs, cfg.encoder)
    loader = DataLoader(
        ds,
        batch_size=cfg.eval_batch_size,
        collate_fn=form_batch,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    env = Environment(loader, model.encoder, subbatch_max_tokens=9999999)
    state = env.reset()

    # Write to stdout unless an output file is given. The explicit UTF-8
    # encoding keeps non-ASCII (e.g. Chinese) trees writable regardless of
    # the platform's locale default.
    if cfg.output is None:
        oup = sys.stdout
    else:
        oup = open(
            hydra.utils.to_absolute_path(cfg.output), "wt", encoding="utf-8"
        )
    time_start = time()

    try:
        with torch.no_grad():  # type: ignore
            while True:
                with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                    actions, _ = model(state)
                state, done = env.step(actions)
                if not done:
                    continue
                # the current batch is fully parsed: emit its trees
                for tree in env.pred_trees:
                    assert tree is not None
                    print(tree.linearize(), file=oup)
                # load the next batch
                try:
                    with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                        state = env.reset()
                except EpochEnd:
                    # no next batch available (complete)
                    log.info("Time elapsed: %f" % (time() - time_start))
                    break
    finally:
        # Close the file we opened (also when parsing fails), never stdout.
        if oup is not sys.stdout:
            oup.close()

    if cfg.output is not None:
        log.info("Parse trees saved to %s" % cfg.output)
def main(cfg: DictConfig) -> None:
    "The entry point for testing"

    def report(template, score):
        # Log F1 / exact match / precision / recall using the given template.
        metrics = (
            score.fscore,
            score.complete_match,
            score.precision,
            score.recall,
        )
        log.info(template % metrics)

    assert cfg.model_path is not None, "Need to specify model_path for testing."
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # Bring back the hyperparameters stored alongside the trained weights.
    ckpt_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % ckpt_path)
    ckpt = load_model(ckpt_path)
    restore_hyperparams(ckpt["cfg"], cfg)

    # Evaluation data: a validation split and a test split, sharing vocabs.
    vocabs = ckpt["vocabs"]
    val_loader = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )[0]
    test_loader = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_test),
        "test",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )[0]

    # Rebuild the parser and load the trained weights onto the device.
    parser = Parser(vocabs, cfg)
    parser.load_state_dict(ckpt["model_state"])
    dev = get_device()[0]
    parser.to(dev)
    log.info("\n" + str(parser))
    n_params = sum(p.numel() for p in parser.parameters())
    log.info("#parameters = %d" % n_params)

    log.info("Validating..")
    report(
        "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f",
        validate(val_loader, parser, cfg),
    )

    log.info("Testing..")
    use_beam = cfg.beam_size > 1
    if not use_beam:
        log.info("Running without beam search..")
        test_score = validate(test_loader, parser, cfg)
    else:
        log.info("Performing beam search..")
        test_score = beam_search(test_loader, parser, cfg)
    report(
        "Testing F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f",
        test_score,
    )
def train_val(cfg: DictConfig) -> None:
    """Train the parser and validate it after every epoch.

    Builds the training and validation dataloaders (the vocabularies are
    returned by the training-set ``create_dataloader`` call), creates the
    model and an RMSprop optimizer, and optionally restores model/optimizer
    state from ``cfg.resume``. Then, for each epoch in
    ``[start_epoch, cfg.num_epochs)``, it trains (unless
    ``cfg.skip_training``), validates, steps the ReduceLROnPlateau scheduler
    and saves ``model_latest.pth``.
    """

    # create dataloaders for training and validation
    loader_train, vocabs = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_train),
        "train",
        cfg.encoder,
        None,
        cfg.batch_size,
        cfg.num_workers,
    )
    assert vocabs is not None
    loader_val, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )

    # create the model
    model = Parser(vocabs, cfg)
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % count_params(model))

    # create the optimizer
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    start_epoch = 0
    if cfg.resume is not None:  # resume training from a checkpoint
        # Resolve like every other user-supplied path: Hydra may have changed
        # the working directory, so a relative path must be anchored to the
        # original one. Absolute paths are returned unchanged.
        resume_path = hydra.utils.to_absolute_path(cfg.resume)
        log.info("Resuming training from %s" % resume_path)
        checkpoint = load_model(resume_path)
        model.load_state_dict(checkpoint["model_state"])
        start_epoch = checkpoint["epoch"] + 1
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        del checkpoint
    # NOTE(review): the scheduler state and best_f1_score below are not
    # restored on resume; they start fresh.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=cfg.learning_rate_patience,
        cooldown=cfg.learning_rate_cooldown,
        verbose=True,
    )

    # start training and validation
    best_f1_score = -1.0
    num_iters = 0

    for epoch in range(start_epoch, cfg.num_epochs):
        log.info("Epoch #%d" % epoch)

        if not cfg.skip_training:
            log.info("Training..")
            num_iters, accuracy_train, loss_train = train(
                num_iters,
                loader_train,
                model,
                optimizer,
                vocabs["label"],
                cfg,
            )
            log.info("Action accuracy: %.03f, Loss: %.03f" %
                     (accuracy_train, loss_train))

        log.info("Validating..")
        f1_score_val = validate(loader_val, model, cfg)

        log.info(
            "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
            % (
                f1_score_val.fscore,
                f1_score_val.complete_match,
                f1_score_val.precision,
                f1_score_val.recall,
            ))

        if f1_score_val.fscore > best_f1_score:
            log.info("F1 score has improved")
            best_f1_score = f1_score_val.fscore

        # the scheduler is stepped with the best score so far, not this
        # epoch's score
        scheduler.step(best_f1_score)

        save_checkpoint(
            "model_latest.pth",
            epoch,
            model,
            optimizer,
            f1_score_val.fscore,
            vocabs,
            cfg,
        )