Ejemplo n.º 1
0
def main(cfg: DictConfig) -> None:
    """The entry point for parsing user-provided texts.

    Loads the checkpoint at ``cfg.model_path``, parses the texts in
    ``cfg.input`` (``cfg.language`` must be "english" or "chinese") and prints
    one linearized tree per parsed text. Output goes to stdout when
    ``cfg.output`` is None, otherwise to the file ``cfg.output`` (UTF-8).
    """

    assert cfg.model_path is not None, "Need to specify model_path for testing."
    assert cfg.input is not None
    assert cfg.language in ("english", "chinese")
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # load the model checkpoint
    model_path = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % model_path)
    checkpoint = load_model(model_path)
    restore_hyperparams(checkpoint["cfg"], cfg)
    vocabs = checkpoint["vocabs"]

    model = Parser(vocabs, cfg)
    model.load_state_dict(checkpoint["model_state"])
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % sum(p.numel() for p in model.parameters()))

    input_file = hydra.utils.to_absolute_path(cfg.input)
    ds = UserProvidedTexts(input_file, cfg.language, vocabs, cfg.encoder)
    loader = DataLoader(
        ds,
        batch_size=cfg.eval_batch_size,
        collate_fn=form_batch,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # NOTE(review): the huge subbatch_max_tokens presumably keeps each batch in
    # a single subbatch during inference — confirm against Environment.
    env = Environment(loader, model.encoder, subbatch_max_tokens=9999999)

    def next_batch():
        "Load the next batch: (True, state), or (False, None) once the input is exhausted"
        try:
            # The Environment holds model.encoder, so reset() may run it; use
            # the same autocast setting as the forward pass for every batch.
            with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                return True, env.reset()
        except EpochEnd:
            return False, None

    # Write UTF-8 explicitly so Chinese trees do not depend on the platform's
    # default encoding.
    oup = (sys.stdout if cfg.output is None else open(
        hydra.utils.to_absolute_path(cfg.output), "wt", encoding="utf-8"))
    time_start = time()

    try:
        # Every batch, including the first, is loaded and parsed without
        # gradient tracking.
        with torch.no_grad():  # type: ignore
            has_batch, state = next_batch()
            while has_batch:
                with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                    actions, _ = model(state)
                state, done = env.step(actions)
                if done:
                    for tree in env.pred_trees:
                        assert tree is not None
                        print(tree.linearize(), file=oup)
                    # load the next batch
                    has_batch, state = next_batch()
    finally:
        # Close (and thereby flush) the output file we opened; never close stdout.
        if oup is not sys.stdout:
            oup.close()
    log.info("Time elapsed: %f" % (time() - time_start))

    if cfg.output is not None:
        log.info("Parse trees saved to %s" % cfg.output)
Ejemplo n.º 2
0
def main(cfg: DictConfig) -> None:
    """The entry point for testing.

    Restores a trained checkpoint (and its training-time hyperparameters),
    reports parsing scores on the validation split, then on the test split,
    using beam search for the test split when ``cfg.beam_size > 1``.
    """
    assert cfg.model_path is not None, "Need to specify model_path for testing."
    log.info("\n" + OmegaConf.to_yaml(cfg))

    def report(template, score):
        # Both metric lines print the same four fields in the same order.
        log.info(template % (
            score.fscore,
            score.complete_match,
            score.precision,
            score.recall,
        ))

    # Bring back the hyperparameters used for training before building anything.
    checkpoint_file = hydra.utils.to_absolute_path(cfg.model_path)
    log.info("Loading the model from %s" % checkpoint_file)
    ckpt = load_model(checkpoint_file)
    restore_hyperparams(ckpt["cfg"], cfg)
    vocabs = ckpt["vocabs"]

    # The validation and test loaders differ only in their data path and split name.
    loaders = {}
    for split, data_path in (("val", cfg.path_val), ("test", cfg.path_test)):
        loaders[split], _ = create_dataloader(
            hydra.utils.to_absolute_path(data_path),
            split,
            cfg.encoder,
            vocabs,
            cfg.eval_batch_size,
            cfg.num_workers,
        )

    # Rebuild the network and load the trained weights into it.
    parser = Parser(vocabs, cfg)
    parser.load_state_dict(ckpt["model_state"])
    device, _ = get_device()
    parser.to(device)
    log.info("\n" + str(parser))
    log.info("#parameters = %d" % sum(p.numel() for p in parser.parameters()))

    log.info("Validating..")
    report(
        "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f",
        validate(loaders["val"], parser, cfg),
    )

    log.info("Testing..")
    if cfg.beam_size > 1:
        log.info("Performing beam search..")
        test_score = beam_search(loaders["test"], parser, cfg)
    else:
        log.info("Running without beam search..")
        test_score = validate(loaders["test"], parser, cfg)
    report(
        "Testing F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f",
        test_score,
    )
Ejemplo n.º 3
0
def train(
    num_iters: int,
    loader: torch.utils.data.DataLoader,  # type: ignore
    model: Parser,
    optimizer: torch.optim.Optimizer,
    label_vocab: List[Label],
    cfg: DictConfig,
) -> Tuple[int, float, float]:
    """Train the model for one epoch.

    Each batch is divided into multiple subbatches (for saving GPU memory).
    Gradients from the subbatches of a batch are accumulated, and a single
    optimization step is performed per batch (not per subbatch). A final batch
    smaller than ``cfg.batch_size`` also gets its optimization step.

    Args:
        num_iters: Number of optimizer steps taken so far (across epochs); while
            it is <= ``cfg.learning_rate_warmup_steps``, ``adjust_lr`` is applied.
        loader: The training data loader.
        model: The parser being trained.
        optimizer: Optimizer over the model parameters.
        label_vocab: Label vocabulary passed to the loss functions.
        cfg: Configuration (decoder type, batch size, grad clipping, logging, ...).

    Returns:
        ``(num_iters, accuracy, loss)``: the updated optimizer-step count, the
        action accuracy (in percent) over the epoch, and the mean per-batch loss.
    """

    model.train()
    env = Environment(loader, model.encoder, cfg.subbatch_max_tokens)
    optimizer.zero_grad()
    state = env.reset()
    device, _ = get_device()
    loss = torch.tensor(0.0, device=device)

    # stats
    batch_losses: List[float] = []  # summed loss of every batch stepped so far
    current_batch_loss = 0.0  # loss accumulated for the batch in progress
    num_examples = 0
    num_correct_actions = 0
    num_total_actions = 0

    time_start = time()

    def optimizer_step() -> None:
        "Apply the gradients accumulated over one batch and count the step"
        nonlocal num_iters
        if num_iters <= cfg.learning_rate_warmup_steps:
            adjust_lr(num_iters, optimizer, cfg)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        num_iters += 1

    while True:
        actions, logits = model(state)  # action generation from partial trees

        if cfg.decoder == "graph":
            # for graph-based decoder, actions: List[Action] are actions at the current step for a subbatch
            gt_actions = env.gt_actions()
            loss += action_loss(logits, gt_actions, label_vocab,
                                cfg.batch_size)

            correct, total = count_actions(actions, gt_actions)
            num_correct_actions += correct
            num_total_actions += total

            state, done = env.step(gt_actions)  # teacher forcing
            if not done:
                continue  # the subbatch still has steps to go
            num_examples += len(env.pred_trees)

        else:
            # for sequence-based decoder, actions: List[List[Action]] are action sequences for all steps
            assert cfg.decoder == "sequence"
            all_gt_actions = env.gt_action_seqs()
            loss = action_seqs_loss(logits, all_gt_actions, label_vocab,
                                    cfg.batch_size)

            correct, total = count_actions(actions, all_gt_actions)
            num_correct_actions += correct
            num_total_actions += total

            num_examples += len(all_gt_actions)

        # a subbatch is finished: accumulate its gradients
        current_batch_loss += loss.item()
        loss.backward()  # type: ignore
        loss = torch.tensor(0.0, device=device)

        # A batch is complete once the example count reaches a multiple of the batch size.
        batch_finished = num_examples % cfg.batch_size == 0
        if batch_finished:
            optimizer_step()
            batch_losses.append(current_batch_loss)
            current_batch_loss = 0.0

        try:
            state = env.reset(force=True)  # load a new batch
        except EpochEnd:
            if not batch_finished:
                # The last batch was smaller than cfg.batch_size; apply its
                # already-computed gradients instead of discarding them.
                optimizer_step()
                batch_losses.append(current_batch_loss)
            accuracy = 100 * num_correct_actions / max(num_total_actions, 1)
            return num_iters, accuracy, np.mean(batch_losses)

        # log training stats every cfg.log_freq completed batches
        num_batches = num_examples // cfg.batch_size
        if batch_finished and num_batches % cfg.log_freq == 0:
            recent_loss = np.mean(batch_losses[-cfg.log_freq:])
            running_accuracy = 100 * num_correct_actions / max(num_total_actions, 1)
            log.info("[%d] Loss: %.03f, Running accuracy: %.03f, Time: %.02f" %
                     (
                         num_examples,
                         recent_loss,
                         running_accuracy,
                         time() - time_start,
                     ))
            time_start = time()
Ejemplo n.º 4
0
def train_val(cfg: DictConfig) -> None:
    """Train the parser and validate it after every epoch.

    Runs epochs ``start_epoch .. cfg.num_epochs - 1``. The training pass is
    skipped when ``cfg.skip_training`` is set. When ``cfg.resume`` is given, the
    model and optimizer states are restored from that checkpoint and training
    continues from the epoch after the saved one. After every epoch the current
    state is saved to "model_latest.pth".
    """

    # create dataloaders for training and validation
    loader_train, vocabs = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_train),
        "train",
        cfg.encoder,
        None,  # no vocabs given for the training split; they come back as the second return value
        cfg.batch_size,
        cfg.num_workers,
    )
    assert vocabs is not None
    loader_val, _ = create_dataloader(
        hydra.utils.to_absolute_path(cfg.path_val),
        "val",
        cfg.encoder,
        vocabs,
        cfg.eval_batch_size,
        cfg.num_workers,
    )

    # create the model
    model = Parser(vocabs, cfg)
    device, _ = get_device()
    model.to(device)
    log.info("\n" + str(model))
    log.info("#parameters = %d" % count_params(model))

    # create the optimizer
    optimizer = torch.optim.RMSprop(
        model.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    start_epoch = 0
    if cfg.resume is not None:  # resume training from a checkpoint
        # Resolve like every other user-supplied path: Hydra may run the job
        # from its own output directory, where a relative path would not exist.
        resume_path = hydra.utils.to_absolute_path(cfg.resume)
        log.info("Resuming training from %s" % resume_path)
        checkpoint = load_model(resume_path)
        model.load_state_dict(checkpoint["model_state"])
        start_epoch = checkpoint["epoch"] + 1
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        del checkpoint  # release the checkpoint before training starts
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=cfg.learning_rate_patience,
        cooldown=cfg.learning_rate_cooldown,
        verbose=True,
    )

    # start training and validation
    best_f1_score = -1.0
    # NOTE(review): num_iters restarts from 0 even when resuming, so train()
    # applies adjust_lr again for the first cfg.learning_rate_warmup_steps
    # steps — confirm this is intended.
    num_iters = 0

    for epoch in range(start_epoch, cfg.num_epochs):
        log.info("Epoch #%d" % epoch)

        if not cfg.skip_training:
            log.info("Training..")
            num_iters, accuracy_train, loss_train = train(
                num_iters,
                loader_train,
                model,
                optimizer,
                vocabs["label"],
                cfg,
            )
            log.info("Action accuracy: %.03f, Loss: %.03f" %
                     (accuracy_train, loss_train))

        log.info("Validating..")
        f1_score_val = validate(loader_val, model, cfg)

        log.info(
            "Validation F1 score: %.03f, Exact match: %.03f, Precision: %.03f, Recall: %.03f"
            % (
                f1_score_val.fscore,
                f1_score_val.complete_match,
                f1_score_val.precision,
                f1_score_val.recall,
            ))

        if f1_score_val.fscore > best_f1_score:
            log.info("F1 score has improved")
            best_f1_score = f1_score_val.fscore

        # NOTE(review): the scheduler is stepped with the best F1 so far rather
        # than this epoch's F1.
        scheduler.step(best_f1_score)

        save_checkpoint(
            "model_latest.pth",
            epoch,
            model,
            optimizer,
            f1_score_val.fscore,
            vocabs,
            cfg,
        )