Example #1

import itertools
import logging
from time import time

import hydra
import torch
from omegaconf import DictConfig

# Parser, FScore, Environment, EpochEnd and evalb are project-local names
# imported from the surrounding parser codebase.
log = logging.getLogger(__name__)  # assumed logger setup


def validate(loader: torch.utils.data.DataLoader, model: Parser,
             cfg: DictConfig) -> FScore:  # type: ignore
    """Run validation/testing without beam search."""

    model.eval()
    # testing requires far less GPU memory than training,
    # so there is no need to split a batch into multiple subbatches
    env = Environment(loader, model.encoder, subbatch_max_tokens=9999999)
    state = env.reset()

    pred_trees = []
    gt_trees = []
    time_start = time()

    with torch.no_grad():  # type: ignore
        while True:
            with torch.cuda.amp.autocast(enabled=cfg.amp):  # type: ignore
                actions, _ = model(state)

            if cfg.decoder == "graph":
                # actions for a single step
                state, done = env.step(actions)
                if not done:
                    continue
            else:
                assert cfg.decoder == "sequence"
                # actions for all steps
                for n_step in itertools.count():
                    a_t = [
                        action_seq[n_step] for action_seq in actions
                        if len(action_seq) > n_step
                    ]
                    _, done = env.step(a_t)
                    if done:
                        break

            pred_trees.extend(env.pred_trees)
            gt_trees.extend(env.gt_trees)

            # load the next batch
            try:
                with torch.cuda.amp.autocast(enabled=cfg.amp):  # type: ignore
                    state = env.reset()
            except EpochEnd:
                # no next batch available: the epoch is complete
                f1_score = evalb(
                    hydra.utils.to_absolute_path("./EVALB"),
                    gt_trees,
                    pred_trees  # type: ignore
                )
                log.info("Time elapsed: %f", time() - time_start)
                return f1_score
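
A minimal usage sketch, not taken from the original project: the Hydra entry point, the conf/config layout and the build_test_loader helper below are hypothetical stand-ins for however the surrounding codebase builds its model and data loader.

@hydra.main(config_path="conf", config_name="config")  # assumed config layout
def main(cfg: DictConfig) -> None:
    model = Parser(cfg)              # assumed Parser constructor
    loader = build_test_loader(cfg)  # hypothetical helper, not defined here
    f1 = validate(loader, model, cfg)
    log.info("Validation F1: %s", f1)


if __name__ == "__main__":
    main()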
Example #2
# As in Example #1, plus: Beam and get_device are project-local, while
# ProgressBar is assumed to come from the progressbar2 package.
from progressbar import ProgressBar  # an assumption, not verified


def beam_search(
        loader: torch.utils.data.DataLoader,
        model: Parser,
        cfg: DictConfig  # type: ignore
) -> FScore:
    """Run validation/testing with beam search."""

    model.eval()
    device, _ = get_device()
    gt_trees = []
    pred_trees = []
    bar = ProgressBar(max_value=len(loader))
    time_start = time()

    with torch.no_grad():  # type: ignore

        for i, data_batch in enumerate(loader):
            # calculate token embeddings
            tokens_emb = model.encoder(
                data_batch["tokens_idx"].to(device=device, non_blocking=True),
                data_batch["tags_idx"].to(device=device, non_blocking=True),
                data_batch["valid_tokens_mask"].to(device=device,
                                                   non_blocking=True),
                data_batch["word_end_mask"].to(device=device,
                                               non_blocking=True),
            )
            # initialize the beam
            beam = Beam(
                data_batch["tokens_word"],
                data_batch["tags"],
                tokens_emb,
                model,
                cfg,
            )
            # keep executing actions and updating the beam
            # until the entire batch is finished
            while not beam.grow():
                pass

            gt_trees.extend(data_batch["trees"])
            pred_trees.extend(beam.best_trees())

            bar.update(i)

    f1_score = evalb(hydra.utils.to_absolute_path("./EVALB"), gt_trees,
                     pred_trees)
    log.info("Time elapsed: %f", time() - time_start)
    return f1_score
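
Since validate and beam_search share a signature and a return type, a thin dispatcher can choose between them; the beam_size field read below is an assumption, not a field either snippet actually consumes.

def evaluate(loader: torch.utils.data.DataLoader, model: Parser,
             cfg: DictConfig) -> FScore:  # type: ignore
    # Fall back to greedy decoding unless a beam width > 1 is configured
    # (beam_size is a hypothetical config field).
    if getattr(cfg, "beam_size", 1) > 1:
        return beam_search(loader, model, cfg)
    return validate(loader, model, cfg)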