def maybe_compute_prediction(state):
    """Run the model on the current batch and record its predictions.

    No-op unless ``gold_path`` is set. Mutates ``state`` in place:
    stores the batch array under ``"arr"``, extends ``"preds"`` and
    ``"_ids"``, and — when ``compute_loss`` is set — keeps the raw
    model scores under ``"scores"`` for later loss computation.
    """
    if not gold_path:
        return
    batch_arr = state["batch"].to_array()
    state["arr"] = batch_arr
    # Batches are expected to be unpadded here; a padded batch would
    # corrupt the per-token predictions below.
    assert batch_arr["mask"].all()
    word_tensor = torch.from_numpy(batch_arr["word_ids"]).long().to(device)
    model.eval()
    scores = model(word_tensor)
    best_paths = LinearCRF(scores).argmax()
    state["preds"].extend(best_paths.tolist())
    state["_ids"].extend(batch_arr["_id"].tolist())
    if compute_loss:
        state["scores"] = scores
def maybe_predict(state):
    """Predict labels for ``eval_samples`` and write Anafora XML output.

    No-op unless ``predict_on_finished`` is set. Runs the model over all
    evaluation samples in bucketed batches, attaches the predicted label
    ids to each sample, groups samples by source document, and writes one
    ``*.TimeNorm.system.completed.xml`` file per document under
    ``artifacts_dir / "time"``, mirroring the corpus directory layout.
    """
    if not predict_on_finished:
        return
    _log.info("Computing predictions")
    model.eval()
    preds, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples), unit="tok")
    # Fix: this is pure inference, so disable autograd — otherwise every
    # forward pass builds a computation graph, wasting time and steadily
    # growing memory across batches.
    with torch.no_grad():
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]), batch_size):
            arr = batch.to_array()
            # Batches are expected to be unpadded; padding would corrupt
            # the per-token predictions.
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
    pbar.close()
    assert len(preds) == len(eval_samples)
    assert len(_ids) == len(eval_samples)
    # _ids index back into eval_samples; reattach predictions to samples.
    for i, preds_ in zip(_ids, preds):
        eval_samples[i]["preds"] = preds_
    group = defaultdict(list)
    for s in eval_samples:
        group[str(s["path"])].append(s)
    _log.info("Writing predictions")
    for doc_path, doc_samples in group.items():
        spans = [x for s in doc_samples for x in s["spans"]]
        labels = [
            config.id2label[x] for s in doc_samples for x in s["preds"]
        ]
        # Strip the corpus root prefix so output paths mirror the corpus
        # layout under artifacts_dir / "time".
        doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
        data = make_anafora(spans, labels, doc_path.name)
        (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True, exist_ok=True)
        data.to_file(
            f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
        )