def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Run one evaluation pass of ``model`` over ``samples``.

    A fresh ``Runner`` is built whose per-batch pipeline converts the batch to
    tensors, switches the model out of training mode, computes the total
    arc/type scores, predicts, evaluates the prediction and counts the items.
    A progress bar and a ``SumReducer`` named ``"counts"`` are attached.

    Args:
        model: Model handed to the scoring handlers.
        vocab: Vocabulary handed to the tensor-conversion and scoring handlers.
        samples: Iterable of dict samples; each must have a ``"words"`` entry
            whose length is used both for bucketing and the progress total.
        device: Device passed to ``batch2tensors``.
        projective: Forwarded to ``predict_batch``.
        multiroot: Forwarded to ``predict_batch``.
        batch_size: Batch size given to ``BucketIterator``.

    Returns:
        The runner's state after the run.
    """

    def sentence_length(sample):
        return len(sample["words"])

    evaluator = Runner()
    pipeline = [
        batch2tensors(device, vocab),
        set_train_mode(model, training=False),
        compute_total_arc_type_scores(model, vocab),
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items(),
    ]
    evaluator.on(Event.BATCH, pipeline)

    total_tokens = sum(sentence_length(sample) for sample in samples)
    progress = ProgressBar(leave=False, total=total_tokens, unit="tok")
    progress.attach_on(evaluator)
    reducer = SumReducer("counts")
    reducer.attach_on(evaluator)

    # Evaluation only: no gradients are needed for anything run below.
    with torch.no_grad():
        batches = BucketIterator(samples, sentence_length, batch_size)
        evaluator.run(batches)

    return evaluator.state
def report_log_ntrees_stats(
    samples: Sequence[dict],
    aa_mask_field: str,
    batch_size: int = 1,
    projective: bool = False,
    multiroot: bool = False,
) -> None:
    """Log min/quartiles/max of the log number of trees allowed per sample.

    For every batch, a score tensor is built that is 0 where the mask stored
    under ``aa_mask_field`` is set and -1e9 elsewhere. The log-partition of a
    ``DepTreeCRF`` over these scores is collected per sample and summarised
    with the module-level ``logger``.

    Args:
        samples: Samples to summarise; each needs ``"words"`` and the
            ``aa_mask_field`` entry.
        aa_mask_field: Name of the sample field holding the allowed-entries
            mask.
        batch_size: Batch size given to ``BucketIterator``.
        projective: Forwarded to ``DepTreeCRF``.
        multiroot: Forwarded to ``DepTreeCRF``.

    Returns:
        None. When no values were collected (e.g. ``samples`` is empty), a
        warning is logged instead of the statistics.
    """
    log_ntrees: list = []
    n_tokens = sum(len(s["words"]) for s in samples)

    # The context manager closes the bar even if a batch raises.
    with tqdm(total=n_tokens, leave=False) as pbar:
        for batch in BucketIterator(samples, lambda s: len(s["words"]), batch_size):
            arr = batch.to_array()
            aaet_mask = torch.from_numpy(arr[aa_mask_field]).bool()
            # 0 for allowed entries, a very negative score for the rest.
            cnt_scores = torch.zeros_like(aaet_mask).float().masked_fill(
                ~aaet_mask, -1e9)
            crf = DepTreeCRF(cnt_scores, projective=projective, multiroot=multiroot)
            log_ntrees.extend(crf.log_partitions().tolist())
            pbar.update(arr["words"].size)

    if not log_ntrees:
        # np.min / np.quantile raise on empty input; there is nothing to report.
        logger.warning("No samples given, cannot report log number of trees stats")
        return

    q1, q2, q3 = np.quantile(log_ntrees, [0.25, 0.5, 0.75])
    logger.info(
        "Log number of trees: min %.2f | q1 %.2f | q2 %.2f | q3 %.2f | max %.2f",
        np.min(log_ntrees),
        q1,
        q2,
        q3,
        np.max(log_ntrees),
    )
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Run one evaluation pass, optionally tracking the mean PPT loss.

    The per-batch pipeline is: convert the batch to tensors, switch the model
    out of training mode, compute total arc/type scores, optionally compute
    the loss against the batch's ``"ppt_mask"``, then predict, evaluate and
    count items. A ``SumReducer`` named ``"counts"`` is always attached; a
    ``MeanReducer`` named ``"mean_ppt_loss"`` is attached only when
    ``compute_loss`` is true.

    Args:
        model: Model handed to the scoring handlers.
        vocab: Vocabulary handed to the tensor-conversion and scoring handlers.
        samples: Iterable of dict samples with a ``"words"`` entry (and a
            ``"ppt_mask"`` entry when ``compute_loss`` is true).
        compute_loss: Whether to compute and reduce the PPT loss.
        device: Device passed to ``batch2tensors``.
        projective: Forwarded to the loss and to ``predict_batch``.
        multiroot: Forwarded to the loss and to ``predict_batch``.
        batch_size: Batch size given to ``BucketIterator``.

    Returns:
        The runner's state after the run.
    """

    def sentence_length(sample):
        return len(sample["words"])

    evaluator = Runner()
    scoring_steps = [
        batch2tensors(device, vocab),
        set_train_mode(model, training=False),
        compute_total_arc_type_scores(model, vocab),
    ]
    evaluator.on(Event.BATCH, scoring_steps)

    @evaluator.on(Event.BATCH)
    def maybe_compute_loss(state):
        # Registered unconditionally; it simply does nothing when loss is off.
        if compute_loss:
            batch_loss = compute_aatrn_loss(
                state["total_arc_type_scores"],
                state["batch"]["ppt_mask"].bool(),
                projective=projective,
                multiroot=multiroot,
            )
            state["ppt_loss"] = batch_loss.item()
            # "size" is stored next to the loss for the mean reducer below
            # (presumably its weight - confirm against MeanReducer).
            state["size"] = state["batch"]["words"].size(0)

    prediction_steps = [
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items(),
    ]
    evaluator.on(Event.BATCH, prediction_steps)

    total_tokens = sum(sentence_length(sample) for sample in samples)
    progress = ProgressBar(leave=False, total=total_tokens, unit="tok")
    progress.attach_on(evaluator)
    count_reducer = SumReducer("counts")
    count_reducer.attach_on(evaluator)
    if compute_loss:
        loss_reducer = MeanReducer("mean_ppt_loss", value="ppt_loss")
        loss_reducer.attach_on(evaluator)

    # Evaluation only: no gradients are needed for anything run below.
    with torch.no_grad():
        batches = BucketIterator(samples, sentence_length, batch_size)
        evaluator.run(batches)

    return evaluator.state
def maybe_predict(state):
    """Predict tags for ``eval_samples`` and write one Anafora XML per document.

    This handler reads names that are not parameters (``predict_on_finished``,
    ``_log``, ``model``, ``eval_samples``, ``batch_size``, ``device``,
    ``config``, ``corpus`` and ``artifacts_dir``), so it only works where
    those are in scope.

    Args:
        state: Runner state passed to the handler; it is not used here.

    Returns:
        None. Does nothing when ``predict_on_finished`` is false. Otherwise
        every sample in ``eval_samples`` gets a ``"preds"`` entry and files
        named ``<doc>.TimeNorm.system.completed.xml`` are written under
        ``artifacts_dir / "time"``.
    """
    if not predict_on_finished:
        return

    _log.info("Computing predictions")
    model.eval()
    preds, _ids = [], []
    n_tokens = sum(len(s["word_ids"]) for s in eval_samples)
    # Pure inference: run without autograd so no graph is kept per batch. The
    # evaluation code elsewhere in this file decodes under torch.no_grad() as
    # well. The context manager also closes the bar if a batch raises.
    with tqdm(total=n_tokens, unit="tok") as pbar, torch.no_grad():
        for batch in BucketIterator(eval_samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            # Batches are bucketed by exact length, so no padding is expected.
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))

    assert len(preds) == len(eval_samples)
    assert len(_ids) == len(eval_samples)
    # Batching reorders samples; "_id" maps each prediction back to its sample.
    for i, preds_ in zip(_ids, preds):
        eval_samples[i]["preds"] = preds_

    # Group samples per source document, keeping their order in eval_samples.
    group = defaultdict(list)
    for s in eval_samples:
        group[str(s["path"])].append(s)

    _log.info("Writing predictions")
    for doc_path, doc_samples in group.items():
        spans = [x for s in doc_samples for x in s["spans"]]
        labels = [
            config.id2label[x] for s in doc_samples for x in s["preds"]
        ]
        # Drop the leading "<corpus path>/" so the output mirrors the corpus
        # layout; this assumes every sample path starts with that prefix.
        doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
        data = make_anafora(spans, labels, doc_path.name)
        (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                         exist_ok=True)
        data.to_file(
            f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
        )
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT.

    Steps, in order: read train/dev/test samples, load the source vocabulary
    and extend its words with the target words, load the source model and its
    parameters, replace the word embedding layer with an extended one, compute
    a ``"ppt_mask"`` for every train and dev sample with the loaded model, and
    finally finetune on the train set with the PPT loss plus an L2 term
    computed by ``compute_l2_loss`` against a copy of the initial parameters.

    Args:
        _log: Logger-like object (``.info`` is used).
        _run: Run object (``.log_scalar`` is used; also passed to the
            logging handlers).
        _rnd: Random generator passed as ``rng`` to the training iterators.
        max_length: Optional dict mapping ``"train"``/``"dev"``/``"test"`` to
            the value passed as ``max_length`` to ``read_samples``.
        artifacts_dir: Directory for the saved vocabulary, model metadata and
            parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory holding ``vocab.yml`` and ``model.yml`` of the
            source model.
        load_params: File name of the source parameters inside ``load_from``.
        device: Device the model and batches are moved to.
        word_emb_path: Word2vec-format file used to extend the embeddings.
        freeze: If true, word and tag embeddings do not require gradients.
        thresh: Threshold given to ``compute_ambiguous_arcs_mask``.
        projective: Forwarded to the mask, stats and loss computations.
        multiroot: Forwarded to the mask, stats and loss computations.
        batch_size: Batch size for every iterator built here.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 term in the training loss.
        max_epoch: Number of finetuning epochs.

    Returns:
        ``dev_accs["las_nopunct"]`` from the last dev evaluation, or ``None``
        when training is interrupted with ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended; the other fields keep the source
    # vocabulary as loaded.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Parameters are loaded onto the CPU first; the model is moved to
    # `device` further below.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Set explicitly because the word embedding layer was just rebuilt.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # The mask is needed on train (training loss) and on dev (eval_on_dev
    # below reads "mean_ppt_loss"); the test set is evaluated without loss.
    for wh in ["train", "dev"]:
        # "_id" is the sample's index so batch outputs can be mapped back.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            # Batches are bucketed by exact length, so no padding is expected.
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective, multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, ppt_mask in zip(runner.state["_ids"], runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size, projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()
    # Detached copy of the starting parameters, handed to compute_l2_loss.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]
        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective, multiroot)
        # Divide by the size of the first mask dimension
        # (presumably the batch size - confirm).
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    # Must be registered after compute_loss so that state["loss"] is set.
    finetuner.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/projective/multiroot/batch_size are not passed
        # here; confirm run_eval receives them some other way.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])

        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # The test set is evaluated once, after the final epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"], compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Training buckets group sentences whose lengths fall in the same
    # 10-token range (1-10, 11-20, ...).
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def run_eval(
    model,
    id2label,
    samples,
    corpus,
    _log,
    device="cpu",
    batch_size=32,
    gold_path="",
    compute_loss=False,
    confusion=False,
):
    """Evaluate the tagger on ``samples`` and/or compute its mean PTST loss.

    When ``gold_path`` is given, tags are predicted for every sample, stored
    under ``"preds"`` in the sample, written as Anafora XML files into a
    temporary directory and scored against ``gold_path`` with ``score_time``.
    When ``compute_loss`` is true, the loss is computed per batch from the
    samples' ``"ptst_mask"`` and averaged by a ``MeanReducer``.

    Args:
        model: Tagger called as ``model(words)``; put in eval mode here.
        id2label: Mapping from predicted label ids to label strings.
        samples: Indexable collection of dict samples with ``"word_ids"`` and
            ``"_id"``; also ``"path"`` and ``"spans"`` when scoring, and
            ``"ptst_mask"`` when computing the loss.
        corpus: Mapping whose ``"path"`` entry is the prefix stripped from
            each sample path.
        _log: Logger-like object (``.info`` is used).
        device: Device the word ids and masks are moved to.
        batch_size: Batch size given to ``BucketIterator``.
        gold_path: Directory with gold annotations; empty disables scoring.
        compute_loss: Whether to compute the mean PTST loss.
        confusion: Forwarded to ``score_time``.

    Returns:
        A pair ``(score, mean_loss)``. ``score`` is ``None`` without
        ``gold_path``; ``mean_loss`` is ``None`` without ``compute_loss``.
        ``(None, None)`` is returned right away when neither is requested.
    """
    if not gold_path and not compute_loss:
        _log.info(
            "Skipping evaluation since gold data isn't provided and loss isn't required"
        )
        return None, None

    runner = Runner()
    runner.state.update({"preds": [], "_ids": []})

    def compute_scores(arr):
        # Batches are bucketed by exact length, so no padding is expected.
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        model.eval()
        return model(words)

    @runner.on(Event.BATCH)
    def maybe_compute_prediction(state):
        if not gold_path:
            return

        arr = state["batch"].to_array()
        state["arr"] = arr
        scores = compute_scores(arr)
        preds = LinearCRF(scores).argmax()

        state["preds"].extend(preds.tolist())
        state["_ids"].extend(arr["_id"].tolist())
        if compute_loss:
            # Shared with maybe_compute_loss to avoid a second forward pass.
            state["scores"] = scores

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        if gold_path:
            # The prediction handler above already ran for this very batch.
            arr, scores = state["arr"], state["scores"]
        else:
            # Always rebuild from the current batch. The runner state outlives
            # a batch, so testing for an existing "arr" key would silently
            # reuse the first batch's array for every later batch.
            arr = state["batch"].to_array()
            state["arr"] = arr
            scores = compute_scores(arr)

        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        crf = LinearCRF(masked_scores)
        crf_z = LinearCRF(scores)
        # log Z(all) - log Z(masked), summed over the batch.
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        state["ptst_loss"] = ptst_loss.item()
        state["size"] = mask.size(0)

    @runner.on(Event.BATCH)
    def set_n_items(state):
        # At least one handler above has stored the current batch's array.
        state["n_items"] = int(state["arr"]["mask"].sum())

    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ptst_loss", value="ptst_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["word_ids"]), batch_size))

    if runner.state["preds"]:
        assert len(runner.state["preds"]) == len(samples)
        assert len(runner.state["_ids"]) == len(samples)
        # Batching reorders samples; "_id" maps predictions back to them.
        for i, preds in zip(runner.state["_ids"], runner.state["preds"]):
            samples[i]["preds"] = preds

    if gold_path:
        # Group samples per source document, keeping their order in `samples`.
        group = defaultdict(list)
        for s in samples:
            group[str(s["path"])].append(s)

        with tempfile.TemporaryDirectory() as dirname:
            dirname = Path(dirname)
            for doc_path, doc_samples in group.items():
                spans = [x for s in doc_samples for x in s["spans"]]
                labels = [id2label[x] for s in doc_samples for x in s["preds"]]
                # Drop the leading "<corpus path>/"; this assumes every
                # sample path starts with that prefix.
                doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
                data = make_anafora(spans, labels, doc_path.name)
                (dirname / doc_path.parent).mkdir(parents=True, exist_ok=True)
                data.to_file(f"{str(dirname / doc_path)}.xml")

            # Score while the temporary directory still exists.
            return (
                score_time(gold_path, str(dirname), confusion),
                runner.state.get("mean_ptst_loss"),
            )

    return None, runner.state.get("mean_ptst_loss")
def finetune(
    _log,
    _run,
    _rnd,
    corpus,
    artifacts_dir="artifacts",
    overwrite=False,
    temperature=1.0,
    freeze_embeddings=True,
    freeze_encoder_up_to=1,
    device="cpu",
    thresh=0.95,
    batch_size=16,
    lr=1e-5,
    max_epoch=5,
    predict_on_finished=False,
):
    """Finetune/train the source model on unlabeled target data.

    Steps, in order: read the training and evaluation samples, load the
    pretrained token classifier and wrap it in ``RoBERTagger``, initialise
    the transition scores, freeze the requested parameters, compute a
    ``"ptst_mask"`` for every training sample, log how many tag sequences the
    masks allow, then finetune with the PTST loss. After every epoch the
    model is evaluated and saved; after training, predictions are optionally
    written under ``artifacts_dir / "time"``.

    Args:
        _log: Logger-like object (``.info`` is used).
        _run: Run object (``.log_scalar`` is used; also passed to the
            logging handlers).
        _rnd: Random generator passed as ``rng`` to the training iterators.
        corpus: Mapping whose ``"path"`` entry is the prefix stripped from
            sample paths when writing predictions.
        artifacts_dir: Output directory.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        temperature: Passed to ``RoBERTagger``.
        freeze_embeddings: Freeze parameters whose name contains
            ``".embeddings."``.
        freeze_encoder_up_to: Freeze encoder layers ``0..freeze_encoder_up_to``
            (inclusive); a negative value freezes none.
        device: Device the model and batches are moved to.
        thresh: Threshold given to ``compute_ambiguous_tag_pairs_mask``.
        batch_size: Batch size for every iterator built here.
        lr: Learning rate of the Adam optimizer.
        max_epoch: Number of finetuning epochs.
        predict_on_finished: Whether to write predictions when training ends.

    Returns:
        The ``"eval_f1"`` value stored by the last evaluation (``None`` if
        no score was available), or ``None`` when training is interrupted
        with ``KeyboardInterrupt``.
    """
    artifacts_dir = Path(artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = read_samples_()
    eval_samples = read_samples_(max_length=None)

    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
    # Forbid "I-X" at the start of a sequence and after "O" or after any
    # label whose type differs from X. Parameters are edited in place, hence
    # no autograd.
    with torch.no_grad():
        for lid, label in config.id2label.items():
            if not label.startswith("I-"):
                continue
            model.start_transition[lid] = -1e9
            for plid, plabel in config.id2label.items():
                if plabel == "O" or plabel[2:] != label[2:]:
                    model.transition[plid, lid] = -1e9

    # Name fragments of the encoder layers to freeze (empty when
    # freeze_encoder_up_to is negative).
    frozen_layer_tags = [
        f".encoder.layer.{i}." for i in range(freeze_encoder_up_to + 1)
    ]
    for name, p in model.named_parameters():
        freeze = (freeze_embeddings and ".embeddings." in name) or any(
            tag in name for tag in frozen_layer_tags)
        if freeze:
            _log.info("Freezing %s", name)
            p.requires_grad_(False)

    model.to(device)

    n_tokens = sum(len(s["word_ids"]) for s in samples)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=n_tokens, unit="tok") as pbar, torch.no_grad():
        for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            # Batches are bucketed by exact length, so no padding is expected.
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
            ptst_masks.extend(ptst_mask.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    # Batching reorders samples; "_id" maps each mask back to its sample.
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Report number of sequences")
    log_total_nseqs, log_nseqs = [], []
    with tqdm(total=n_tokens, leave=False) as pbar:
        for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
            # All-zero scores for every entry versus only the entries the
            # mask keeps; the log-partitions are compared below.
            cnt_scores = torch.zeros_like(ptst_mask).float()
            cnt_scores_masked = cnt_scores.masked_fill(~ptst_mask, -1e9)
            log_total_nseqs.extend(
                LinearCRF(cnt_scores).log_partitions().tolist())
            log_nseqs.extend(
                LinearCRF(cnt_scores_masked).log_partitions().tolist())
            pbar.update(arr["word_ids"].size)

    # Fraction of all sequences that the mask keeps, per sample.
    cov = [math.exp(x - x_) for x, x_ in zip(log_nseqs, log_total_nseqs)]
    _log.info(
        "Number of seqs: min {:.2} ({:.2}%) | med {:.2} ({:.2}%) | max {:.2} ({:.2}%)"
        .format(
            math.exp(min(log_nseqs)),
            100 * min(cov),
            math.exp(median(log_nseqs)),
            100 * median(cov),
            math.exp(max(log_nseqs)),
            100 * max(cov),
        ))

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)

        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)

        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]

        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
        # log Z(all) - log Z(masked), summed over the batch and then divided
        # by the size of the first mask dimension.
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)

        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()

    # Must be registered after compute_loss so that state["loss"] is set.
    finetuner.on(Event.BATCH,
                 [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def evaluate(state):
        # NOTE(review): run_eval is called without `corpus`/`_log` here;
        # confirm it receives them some other way.
        _log.info("Evaluating on train")
        eval_score, loss = run_eval(model, config.id2label, samples,
                                    compute_loss=True)
        if eval_score is not None:
            print_accs(eval_score, on="train", run=_run, step=state["n_iters"])
        _log.info("train_ptst_loss: %.4f", loss)
        _run.log_scalar("train_ptst_loss", loss, step=state["n_iters"])

        _log.info("Evaluating on eval")
        eval_score, _ = run_eval(model, config.id2label, eval_samples)
        if eval_score is not None:
            print_accs(eval_score, on="eval", run=_run, step=state["n_iters"])

        state["eval_f1"] = None if eval_score is None else eval_score["f1"]

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    @finetuner.on(Event.FINISHED)
    def maybe_predict(state):
        if not predict_on_finished:
            return

        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        n_eval_tokens = sum(len(s["word_ids"]) for s in eval_samples)
        # Pure inference: run without autograd so no graph is kept per batch,
        # the same way the mask computation above does.
        with tqdm(total=n_eval_tokens, unit="tok") as pbar, torch.no_grad():
            for batch in BucketIterator(eval_samples,
                                        lambda s: len(s["word_ids"]),
                                        batch_size):
                arr = batch.to_array()
                assert arr["mask"].all()
                words = torch.from_numpy(arr["word_ids"]).long().to(device)
                scores = model(words)
                pred = LinearCRF(scores).argmax()
                preds.extend(pred.tolist())
                _ids.extend(arr["_id"].tolist())
                pbar.update(int(arr["mask"].sum()))

        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_

        # Group samples per source document, keeping their order.
        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)

        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            # Drop the leading "<corpus path>/"; this assumes every sample
            # path starts with that prefix.
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )

    EpochTimer().attach_on(finetuner)
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Training buckets group sequences whose lengths fall in the same
    # 10-token range (1-10, 11-20, ...).
    bucket_key = lambda s: (len(s["word_ids"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples, bucket_key, batch_size, shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state.get("eval_f1")
def report_coverage(corpus,
                    _log,
                    temperature=1.0,
                    device="cpu",
                    batch_size=16,
                    thresh=0.95,
                    gold_path=""):
    """Report coverage of gold tags in the chart.

    A ``"ptst_mask"`` is computed for every sample with the pretrained
    tagger. The gold annotations under ``gold_path`` are then aligned to the
    token spans of each document and turned into B-/I-/O labels, and the
    function logs how many consecutive gold label pairs, and how many whole
    samples, are allowed by the mask.

    Args:
        corpus: Mapping whose ``"path"`` entry is the prefix stripped from
            each sample path to obtain the document key.
        _log: Logger-like object (``.info`` is used).
        temperature: Passed to ``RoBERTagger``.
        device: Device the model and word ids are moved to.
        batch_size: Batch size given to ``BucketIterator``.
        thresh: Threshold given to ``compute_ambiguous_tag_pairs_mask``.
        gold_path: Root directory of the gold annotations. Each directory
            that contains files must contain exactly one, named
            ``*.TimeNorm.gold.completed.xml``.

    Raises:
        ValueError: If a gold directory holds more than one file, a file with
            an unexpected suffix, or an annotation with more than one span.
    """

    def percent(part, whole):
        # Nothing was counted (e.g. gold_path is empty or matched no
        # document): report 0.0 rather than dividing by zero.
        return 100.0 * part / whole if whole else 0.0

    samples = read_samples_()

    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
    # Forbid "I-X" at the start of a sequence and after "O" or after any
    # label whose type differs from X.
    with torch.no_grad():
        for lid, label in config.id2label.items():
            if not label.startswith("I-"):
                continue
            model.start_transition[lid] = -1e9
            for plid, plabel in config.id2label.items():
                if plabel == "O" or plabel[2:] != label[2:]:
                    model.transition[plid, lid] = -1e9

    model.to(device)

    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    n_tokens = sum(len(s["word_ids"]) for s in samples)
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=n_tokens, unit="tok") as pbar, torch.no_grad():
        for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                    batch_size):
            arr = batch.to_array()
            # Batches are bucketed by exact length, so no padding is expected.
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
            ptst_masks.extend(ptst_mask.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))

    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    # Batching reorders samples; "_id" maps each mask back to its sample.
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    _log.info("Reporting coverage of gold labels")
    # Document key = sample path relative to the corpus root.
    group = defaultdict(list)
    for s in samples:
        k = str(s["path"])[len(f"{corpus['path']}/"):]
        group[k].append(s)

    prefix, suffix = f"{gold_path}/", ".TimeNorm.gold.completed.xml"
    n_cov_tp, n_total_tp, n_cov_ts, n_total_ts = 0, 0, 0, 0
    for dirpath, _, filenames in os.walk(gold_path):
        if not filenames:
            continue
        if len(filenames) > 1:
            raise ValueError(f"more than 1 file is found in {dirpath}")
        if not filenames[0].endswith(suffix):
            raise ValueError(
                f"{filenames[0]} doesn't have the expected suffix")

        doc_path = os.path.join(dirpath, filenames[0])
        data = AnaforaData.from_file(doc_path)
        # Strip the gold root and the file suffix to get the document key.
        doc_path = doc_path[len(prefix):-len(suffix)]
        tok_spans = [p for s in group[doc_path] for p in s["spans"]]
        tok_spans.sort()

        # Token index -> gold label, for annotations whose boundaries
        # coincide with token boundaries.
        labeling = {}
        for ann in data.annotations:
            if len(ann.spans) != 1:
                raise ValueError("found annotation with >1 span")
            span = ann.spans[0]
            # First token starting at or after the annotation start ...
            beg = 0
            while beg < len(tok_spans) and tok_spans[beg][0] < span[0]:
                beg += 1
            # ... and first token from there ending at or after its end.
            end = beg
            while end < len(tok_spans) and tok_spans[end][1] < span[1]:
                end += 1
            if (beg < len(tok_spans) and end < len(tok_spans)
                    and tok_spans[beg][0] == span[0]
                    and tok_spans[end][1] == span[1] and beg not in labeling):
                labeling[beg] = f"B-{ann.type}"
                for i in range(beg + 1, end + 1):
                    # Do not overwrite labels set by an earlier annotation.
                    if i not in labeling:
                        labeling[i] = f"I-{ann.type}"

        labels = ["O"] * len(tok_spans)
        for k, v in labeling.items():
            labels[k] = v

        # NOTE(review): `labels` follows the sorted token spans, while
        # `offset` walks the samples in their stored order; this presumably
        # relies on samples already being in document order - confirm.
        offset = 0
        for s in group[doc_path]:
            ts_covd = True
            for i in range(1, len(s["spans"])):
                plab = labels[offset + i - 1]
                lab = labels[offset + i]
                if s["ptst_mask"][i - 1][config.label2id[plab]][
                        config.label2id[lab]]:
                    n_cov_tp += 1
                else:
                    # One uncovered pair makes the whole sequence uncovered.
                    ts_covd = False
                n_total_tp += 1
            if ts_covd:
                n_cov_ts += 1
            n_total_ts += 1
            offset += len(s["spans"])

    _log.info(
        "Number of covered tag pairs: %d out of %d (%.1f%%)",
        n_cov_tp,
        n_total_tp,
        percent(n_cov_tp, n_total_tp),
    )
    _log.info(
        "Number of covered tag sequences: %d out of %d (%.1f%%)",
        n_cov_ts,
        n_total_ts,
        percent(n_cov_ts, n_total_ts),
    )
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser.

    Steps, in order: read train/dev/test samples, build (and save) the
    vocabulary, build the model with ``make_model``, then train with the sum
    of the arc and type cross-entropy losses. After every epoch the model is
    evaluated on dev; when dev improves it is also evaluated on test.

    Args:
        _log: Logger-like object (``.info`` is used).
        _run: Run object passed to the logging and printing helpers.
        _rnd: Random generator passed as ``rng`` to the training iterators.
        artifacts_dir: Directory for the saved vocabulary and parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        max_length: Optional dict mapping ``"train"``/``"dev"``/``"test"`` to
            the value passed as ``max_length`` to ``read_samples``.
        load_types_vocab_from: Optional path to a saved vocabulary whose
            ``"types"`` entry replaces the one built from the samples.
        batch_size: Batch size of the training iterator.
        device: Device the model and batches are moved to.
        lr: Learning rate of the Adam optimizer.
        patience: Passed to ``maybe_stop_early``.
        max_epoch: Maximum number of training epochs.

    Returns:
        ``dev_accs["las_nopunct"]`` of the best dev epoch, or ``None`` when
        training is interrupted with ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    # The vocabulary is built from all three splits, not only from train.
    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # mode="max": the scheduler is stepped with a dev accuracy below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5)

    trainer = Runner()
    # Start below any real count so the first dev evaluation is "better".
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH, [batch2tensors(device, vocab), set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat["heads"], bat["types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        # mask padding heads
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)
        # The padding type gets a very low score
        # (written in place into the model output).
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # One row of head scores per (sentence, token) pair.
        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        # Average only over the positions the mask keeps.
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss
        state["loss"] = loss
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        # Uses the full batch mask, i.e. before the root column was removed.
        state["n_items"] = bat["mask"].long().sum().item()

    # Must be registered after compute_loss so that state["loss"] is set.
    trainer.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/batch_size are not passed here; confirm
        # run_eval receives them some other way.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        scheduler.step(accs["las_nopunct"])

        # Compare on labeled arcs first; unlabeled arcs break ties.
        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False

        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            # "dev_epoch", "dev_accs" and "test_accs" come from an earlier,
            # better epoch.
            _log.info("Not better, the best so far is epoch %d:", state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Test is evaluated only for epochs that improved on dev.
        if not state["better"]:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"], on="test", run=_run, step=state["n_iters"])

    # Registered after the evaluation handlers, which set state["better"]
    # (presumably what both handlers below consult - confirm).
    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir, when="better"),
        ],
    )

    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    # Training buckets group sentences whose lengths fall in the same
    # 10-token range (1-10, 11-20, ...).
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training.

    Steps, in order: read train/dev/test samples, load the source vocabulary
    and extend its words with the target words, load the source model and its
    parameters, replace the word embedding layer with an extended one,
    predict heads and types for every training sample with the loaded model
    (stored as ``"st_heads"`` / ``"st_types"``), then finetune on those
    predictions with arc and type cross-entropy plus an L2 term computed by
    ``compute_l2_loss`` against a copy of the initial parameters.

    Args:
        _log: Logger-like object (``.info`` is used).
        _run: Run object passed to the logging and printing helpers.
        _rnd: Random generator passed as ``rng`` to the training iterators.
        max_length: Optional dict mapping ``"train"``/``"dev"``/``"test"`` to
            the value passed as ``max_length`` to ``read_samples``.
        artifacts_dir: Directory for the saved vocabulary, model metadata and
            parameters.
        overwrite: Passed as ``exist_ok`` when creating ``artifacts_dir``.
        load_from: Directory holding ``vocab.yml`` and ``model.yml`` of the
            source model.
        load_params: File name of the source parameters inside ``load_from``.
        device: Device the model and batches are moved to.
        word_emb_path: Word2vec-format file used to extend the embeddings.
        freeze: If true, word and tag embeddings do not require gradients.
        projective: Forwarded to ``predict_batch``.
        multiroot: Forwarded to ``predict_batch``.
        batch_size: Batch size for every iterator built here.
        lr: Learning rate of the Adam optimizer.
        l2_coef: Weight of the L2 term in the training loss.
        max_epoch: Number of finetuning epochs.

    Returns:
        ``dev_accs["las_nopunct"]`` from the last dev evaluation, or ``None``
        when training is interrupted with ``KeyboardInterrupt``.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # With overwrite=False this raises if the directory already exists.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Only the "words" field is extended; the other fields keep the source
    # vocabulary as loaded.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    # Parameters are loaded onto the CPU first; the model is moved to
    # `device` further below.
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Set explicitly because the word embedding layer was just rebuilt.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Self-training targets are only needed for the training set.
    for wh in ["train"]:
        # "_id" is the sample's index so batch outputs can be mapped back.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"], runner.state["st_heads"],
                                         runner.state["st_types"]):
            # Each predicted tree must have one entry per word of its sample.
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    finetuner = Runner()
    # Detached copy of the starting parameters, handed to compute_l2_loss.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        # The self-training predictions stand in for gold heads and types.
        words, tags, heads, types = bat["words"], bat["tags"], bat["st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        # mask padding heads
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)
        # The padding type gets a very low score
        # (written in place into the model output).
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # One row of head scores per (sentence, token) pair.
        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        # Average only over the positions the mask keeps.
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    # Must be registered after compute_loss so that state["loss"] is set.
    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        # NOTE(review): device/projective/multiroot/batch_size are not passed
        # here; confirm run_eval receives them some other way.
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # The test set is evaluated once, after the final epoch.
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Training buckets group sentences whose lengths fall in the same
    # 10-token range (1-10, 11-20, ...).
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]