def validate(loader: torch.utils.data.DataLoader, model: Parser, cfg: DictConfig) -> FScore:  # type: ignore
    """Run validation/testing without beam search."""
    model.eval()
    # Testing needs far less GPU memory than training, so a batch never has
    # to be split into sub-batches (the token cap is effectively unlimited).
    env = Environment(loader, model.encoder, subbatch_max_tokens=9999999)
    state = env.reset()
    pred_trees = []
    gt_trees = []
    time_start = time()

    with torch.no_grad():  # type: ignore
        while True:
            with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                actions, _ = model(state)

            if cfg.decoder == "graph":
                # Graph decoder: one action per sentence for the current step.
                state, done = env.step(actions)
                if not done:
                    continue
            else:
                assert cfg.decoder == "sequence"
                # Sequence decoder: full action sequences — replay them one
                # step at a time, dropping sentences whose sequence has ended.
                for step_idx in itertools.count():
                    step_actions = [
                        seq[step_idx] for seq in actions if len(seq) > step_idx
                    ]
                    _, done = env.step(step_actions)
                    if done:
                        break

            # The whole batch is finished — collect its trees.
            pred_trees.extend(env.pred_trees)
            gt_trees.extend(env.gt_trees)

            # Load the next batch; EpochEnd signals there is none left.
            try:
                with torch.cuda.amp.autocast(cfg.amp):  # type: ignore
                    state = env.reset()
            except EpochEnd:
                # Epoch complete: score all collected trees with EVALB.
                f1_score = evalb(
                    hydra.utils.to_absolute_path("./EVALB"), gt_trees, pred_trees  # type: ignore
                )
                log.info("Time elapsed: %f" % (time() - time_start))
                return f1_score
def beam_search(
    loader: torch.utils.data.DataLoader, model: Parser, cfg: DictConfig  # type: ignore
) -> FScore:
    """Run validation/testing with beam search."""
    model.eval()
    device, _ = get_device()
    gt_trees = []
    pred_trees = []
    bar = ProgressBar(max_value=len(loader))
    time_start = time()

    with torch.no_grad():  # type: ignore
        for batch_idx, batch in enumerate(loader):

            def to_device(key):
                # Move one batch tensor to the target device asynchronously.
                return batch[key].to(device=device, non_blocking=True)

            # Calculate token embeddings from the encoder inputs.
            tokens_emb = model.encoder(
                to_device("tokens_idx"),
                to_device("tags_idx"),
                to_device("valid_tokens_mask"),
                to_device("word_end_mask"),
            )

            # Seed the beam with the raw words/tags and their embeddings.
            beam = Beam(
                batch["tokens_word"],
                batch["tags"],
                tokens_emb,
                model,
                cfg,
            )

            # Keep executing actions and updating the beam until the entire
            # batch is finished.
            while not beam.grow():
                pass

            gt_trees.extend(batch["trees"])
            pred_trees.extend(beam.best_trees())
            bar.update(batch_idx)

    f1_score = evalb(hydra.utils.to_absolute_path("./EVALB"), gt_trees, pred_trees)
    log.info("Time elapsed: %f" % (time() - time_start))
    return f1_score