def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Evaluate a parser on ``samples``, optionally computing the mean PPT loss.

    Runs the model in eval mode over length-bucketed batches, predicting and
    scoring each batch via the project's ``Runner`` event pipeline, and returns
    the runner's final ``state`` (holding the "counts" accuracy reducer and,
    when ``compute_loss`` is true, "mean_ppt_loss").

    NOTE(review): handlers attached to ``Event.BATCH`` run in registration
    order — do not reorder the ``runner.on`` calls below.
    """
    runner = Runner()
    # Preprocessing handlers: tensorize the batch, switch the model to eval
    # mode, and compute the combined arc+type score tensor used downstream.
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        # Optional PPT loss over the batch's ambiguous-arcs mask.
        if not compute_loss:
            return
        # NOTE(review): unlike the training-side call, no token mask is passed
        # here — presumably eval batches are fully unpadded; confirm against
        # compute_aatrn_loss's signature.
        ppt_loss = compute_aatrn_loss(
            state["total_arc_type_scores"],
            state["batch"]["ppt_mask"].bool(),
            projective=projective,
            multiroot=multiroot,
        )
        state["ppt_loss"] = ppt_loss.item()
        # "size" is the batch size; used by the MeanReducer attached below.
        state["size"] = state["batch"]["words"].size(0)

    # Prediction and scoring handlers run after the (optional) loss handler.
    runner.on(Event.BATCH,
              [predict_batch(projective, multiroot),
               evaluate_batch(),
               get_n_items()])

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    # Accumulates per-batch evaluation counts into state["counts"].
    SumReducer("counts").attach_on(runner)
    if compute_loss:
        # Averages state["ppt_loss"] (weighted by state["size"]) into
        # state["mean_ppt_loss"].
        MeanReducer("mean_ppt_loss", value="ppt_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    """Evaluate a parser on ``samples`` and return the runner's final state.

    The returned state carries the "counts" reducer with accumulated
    evaluation counts over all length-bucketed batches.
    """
    eval_runner = Runner()

    # Full per-batch pipeline, executed in order: tensorize, switch to eval
    # mode, score, predict, evaluate, and record the item count.
    batch_pipeline = [
        batch2tensors(device, vocab),
        set_train_mode(model, training=False),
        compute_total_arc_type_scores(model, vocab),
        predict_batch(projective, multiroot),
        evaluate_batch(),
        get_n_items(),
    ]
    eval_runner.on(Event.BATCH, batch_pipeline)

    token_total = 0
    for sample in samples:
        token_total += len(sample["words"])
    ProgressBar(leave=False, total=token_total, unit="tok").attach_on(eval_runner)
    SumReducer("counts").attach_on(eval_runner)

    sample_length = lambda s: len(s["words"])
    with torch.no_grad():
        eval_runner.run(BucketIterator(samples, sample_length, batch_size))

    return eval_runner.state
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT.

    Pipeline: (1) read samples and extend the source vocabulary with target
    words, (2) load the pretrained parser and extend its word embeddings,
    (3) precompute per-sample PPT ambiguous-arcs masks on train/dev,
    (4) finetune with the PPT loss plus an L2 penalty toward the original
    parameters, evaluating on dev each epoch and on test after the last.

    Returns the best dev LAS (no punctuation) unless interrupted.
    ``_log``/``_run``/``_rnd`` are presumably injected by the experiment
    framework (sacred-style) — logger, run object, and RNG.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # Fails if the directory already exists and overwrite is False.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Add target-language words so embeddings can be extended below.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # Convert string fields to integer ids.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    # Replace the word embedding with one covering the extended vocabulary,
    # initializing new rows from the pretrained word vectors.
    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Optionally freeze the embedding layers during finetuning.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Precompute the PPT ambiguous-arcs mask for every train/dev sample.
    for wh in ["train", "dev"]:
        # Tag each sample with its index so batched results can be written
        # back to the right sample after bucketed (reordered) iteration.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            # Masks are only valid for unpadded batches.
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective,
                                                   multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        # Scatter masks back to their samples by recorded id.
        for i, ppt_mask in zip(runner.state["_ids"],
                               runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size,
                                projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Snapshot the pre-finetuning parameters for the L2-to-origin penalty.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # PPT loss (normalized per sentence) plus weighted L2-to-origin loss.
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]
        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective,
                                      multiroot)
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Only evaluate on test after the final epoch.
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))
    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket sentences by length decile to reduce padding, then shuffle
    # buckets and batches for training.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size,
                       shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def run_eval(
    model,
    id2label,
    samples,
    corpus,
    _log,
    device="cpu",
    batch_size=32,
    gold_path="",
    compute_loss=False,
    confusion=False,
):
    """Evaluate a CRF tagger, returning ``(score, mean_ptst_loss)``.

    When ``gold_path`` is given, predicts tags for every sample, writes
    anafora XML files to a temp dir, and scores them against the gold data.
    When ``compute_loss`` is true, also computes the mean PTST loss.
    Returns ``(None, None)`` if neither is requested.

    NOTE(review): the three batch handlers below share state within a batch
    (``arr``/``scores`` are cached to avoid recomputing) — their registration
    order matters.
    """
    if not gold_path and not compute_loss:
        _log.info(
            "Skipping evaluation since gold data isn't provided and loss isn't required"
        )
        return None, None

    runner = Runner()
    runner.state.update({"preds": [], "_ids": []})

    @runner.on(Event.BATCH)
    def maybe_compute_prediction(state):
        # Decode the argmax tag sequence for each sentence (only needed when
        # predictions will be scored against gold data).
        if not gold_path:
            return
        arr = state["batch"].to_array()
        state["arr"] = arr
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        model.eval()
        scores = model(words)
        preds = LinearCRF(scores).argmax()
        state["preds"].extend(preds.tolist())
        state["_ids"].extend(arr["_id"].tolist())
        if compute_loss:
            # Cache scores so the loss handler doesn't recompute them.
            state["scores"] = scores

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return
        # Reuse the array/scores from the prediction handler when available.
        arr = state["arr"] if "arr" in state else state["batch"].to_array()
        state["arr"] = arr
        if "scores" in state:
            scores = state["scores"]
        else:
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            model.eval()
            scores = model(words)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        # PTST loss: -log Z(masked) + log Z(full), i.e. the negative log
        # probability mass of the allowed tag-pair set.
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        crf = LinearCRF(masked_scores)
        crf_z = LinearCRF(scores)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        state["ptst_loss"] = ptst_loss.item()
        # "size" (batch size) weights the MeanReducer below.
        state["size"] = mask.size(0)

    @runner.on(Event.BATCH)
    def set_n_items(state):
        # state["arr"] is set by one of the two handlers above (the early
        # return at the top guarantees at least one of them ran).
        state["n_items"] = int(state["arr"]["mask"].sum())

    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ptst_loss", value="ptst_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(
            BucketIterator(samples, lambda s: len(s["word_ids"]), batch_size))

    if runner.state["preds"]:
        assert len(runner.state["preds"]) == len(samples)
        assert len(runner.state["_ids"]) == len(samples)
        # Scatter predictions back to their samples by recorded id.
        for i, preds in zip(runner.state["preds"], runner.state["preds"]) if False else zip(
                runner.state["_ids"], runner.state["preds"]):
            samples[i]["preds"] = preds

    if gold_path:
        # Group samples by their source document and emit one anafora XML
        # file per document into a temp dir, then score against gold.
        group = defaultdict(list)
        for s in samples:
            group[str(s["path"])].append(s)
        with tempfile.TemporaryDirectory() as dirname:
            dirname = Path(dirname)
            for doc_path, doc_samples in group.items():
                spans = [x for s in doc_samples for x in s["spans"]]
                labels = [id2label[x] for s in doc_samples for x in s["preds"]]
                # Strip the corpus root prefix to get the relative doc path.
                doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
                data = make_anafora(spans, labels, doc_path.name)
                (dirname / doc_path.parent).mkdir(parents=True, exist_ok=True)
                data.to_file(f"{str(dirname / doc_path)}.xml")
            return (
                score_time(gold_path, str(dirname), confusion),
                runner.state.get("mean_ptst_loss"),
            )

    return None, runner.state.get("mean_ptst_loss")
def finetune(
    _log,
    _run,
    _rnd,
    corpus,
    artifacts_dir="artifacts",
    overwrite=False,
    temperature=1.0,
    freeze_embeddings=True,
    freeze_encoder_up_to=1,
    device="cpu",
    thresh=0.95,
    batch_size=16,
    lr=1e-5,
    max_epoch=5,
    predict_on_finished=False,
):
    """Finetune/train the source model on unlabeled target data.

    Pipeline: (1) load the pretrained RoBERTa timex tagger and wrap it in a
    CRF head, (2) initialize transitions so BIO-invalid tag pairs are
    impossible, (3) optionally freeze embeddings and lower encoder layers,
    (4) precompute the PTST ambiguous tag-pairs mask per sample, (5) finetune
    with the PTST loss, evaluating each epoch and optionally writing anafora
    predictions when training finishes.

    Returns the final eval F1 (or None) unless interrupted.
    """
    artifacts_dir = Path(artifacts_dir)
    # Fails if the directory already exists and overwrite is False.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = read_samples_()
    # Eval uses unabridged samples (no max_length truncation).
    eval_samples = read_samples_(max_length=None)

    model_name = "clulab/roberta-timex-semeval"
    _log.info("Loading %s", model_name)
    config = AutoConfig.from_pretrained(model_name)
    token_clf = AutoModelForTokenClassification.from_pretrained(model_name,
                                                                config=config)
    model = RoBERTagger(token_clf, config.num_labels, temperature)

    # Zero all transitions, then forbid BIO-invalid ones with -1e9:
    # a sequence cannot start with I-X, and I-X may only follow B-X or I-X.
    _log.info("Initializing transitions")
    torch.nn.init.zeros_(model.start_transition)
    torch.nn.init.zeros_(model.transition)
    for lid, label in config.id2label.items():
        if not label.startswith("I-"):
            continue
        with torch.no_grad():
            model.start_transition[lid] = -1e9
        for plid, plabel in config.id2label.items():
            if plabel == "O" or plabel[2:] != label[2:]:
                with torch.no_grad():
                    model.transition[plid, lid] = -1e9

    # Optionally freeze the embedding layers and encoder layers 0..N.
    for name, p in model.named_parameters():
        freeze = False
        if freeze_embeddings and ".embeddings." in name:
            freeze = True
        if freeze_encoder_up_to >= 0:
            for i in range(freeze_encoder_up_to + 1):
                if f".encoder.layer.{i}." in name:
                    freeze = True
        if freeze:
            _log.info("Freezing %s", name)
            p.requires_grad_(False)
    model.to(device)

    # Precompute the PTST ambiguous tag-pairs mask for every sample, keyed
    # back by "_id" since bucketed iteration reorders samples.
    _log.info("Computing ambiguous PTST tag pairs mask")
    model.eval()
    ptst_masks, _ids = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), unit="tok")
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        with torch.no_grad():
            ptst_mask = compute_ambiguous_tag_pairs_mask(model(words), thresh)
        ptst_masks.extend(ptst_mask.tolist())
        _ids.extend(arr["_id"].tolist())
        pbar.update(int(arr["mask"].sum()))
    pbar.close()
    assert len(ptst_masks) == len(samples)
    assert len(_ids) == len(samples)
    for i, ptst_mask in zip(_ids, ptst_masks):
        samples[i]["ptst_mask"] = ptst_mask

    # Report how many tag sequences remain allowed under the mask, using
    # zero scores so log-partition equals the log sequence count.
    _log.info("Report number of sequences")
    log_total_nseqs, log_nseqs = [], []
    pbar = tqdm(total=sum(len(s["word_ids"]) for s in samples), leave=False)
    for batch in BucketIterator(samples, lambda s: len(s["word_ids"]),
                                batch_size):
        arr = batch.to_array()
        assert arr["mask"].all()
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        cnt_scores = torch.zeros_like(ptst_mask).float()
        cnt_scores_masked = cnt_scores.masked_fill(~ptst_mask, -1e9)
        log_total_nseqs.extend(LinearCRF(cnt_scores).log_partitions().tolist())
        log_nseqs.extend(
            LinearCRF(cnt_scores_masked).log_partitions().tolist())
        pbar.update(arr["word_ids"].size)
    pbar.close()
    # Coverage: fraction of all sequences kept by the mask, per sample.
    cov = [math.exp(x - x_) for x, x_ in zip(log_nseqs, log_total_nseqs)]
    _log.info(
        "Number of seqs: min {:.2} ({:.2}%) | med {:.2} ({:.2}%) | max {:.2} ({:.2}%)"
        .format(
            math.exp(min(log_nseqs)),
            100 * min(cov),
            math.exp(median(log_nseqs)),
            100 * median(cov),
            math.exp(max(log_nseqs)),
            100 * max(cov),
        ))

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # PTST loss: -log Z(masked) + log Z(full), normalized per sentence.
        arr = state["batch"].to_array()
        words = torch.from_numpy(arr["word_ids"]).long().to(device)
        mask = torch.from_numpy(arr["mask"]).bool().to(device)
        ptst_mask = torch.from_numpy(arr["ptst_mask"]).bool().to(device)
        model.train()
        scores = model(words, mask)
        masked_scores = scores.masked_fill(~ptst_mask, -1e9)
        # mask passed to LinearCRF shouldn't include the last token
        last_idx = mask.long().sum(dim=1, keepdim=True) - 1
        mask_ = mask.scatter(1, last_idx, False)[:, :-1]
        crf = LinearCRF(masked_scores, mask_)
        crf_z = LinearCRF(scores, mask_)
        ptst_loss = -crf.log_partitions().sum() + crf_z.log_partitions().sum()
        ptst_loss /= mask.size(0)
        state["loss"] = ptst_loss
        state["stats"] = {"ptst_loss": ptst_loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def evaluate(state):
        _log.info("Evaluating on train")
        # Remaining run_eval arguments (corpus, _log, ...) are presumably
        # injected by the experiment framework.
        eval_score, loss = run_eval(model, config.id2label, samples,
                                    compute_loss=True)
        if eval_score is not None:
            print_accs(eval_score, on="train", run=_run,
                       step=state["n_iters"])
        _log.info("train_ptst_loss: %.4f", loss)
        _run.log_scalar("train_ptst_loss", loss, step=state["n_iters"])
        _log.info("Evaluating on eval")
        eval_score, _ = run_eval(model, config.id2label, eval_samples)
        if eval_score is not None:
            print_accs(eval_score, on="eval", run=_run,
                       step=state["n_iters"])
        state["eval_f1"] = None if eval_score is None else eval_score["f1"]

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))

    @finetuner.on(Event.FINISHED)
    def maybe_predict(state):
        # Optionally write anafora predictions for eval_samples at the end.
        if not predict_on_finished:
            return
        _log.info("Computing predictions")
        model.eval()
        preds, _ids = [], []
        pbar = tqdm(total=sum(len(s["word_ids"]) for s in eval_samples),
                    unit="tok")
        for batch in BucketIterator(eval_samples,
                                    lambda s: len(s["word_ids"]), batch_size):
            arr = batch.to_array()
            assert arr["mask"].all()
            words = torch.from_numpy(arr["word_ids"]).long().to(device)
            scores = model(words)
            pred = LinearCRF(scores).argmax()
            preds.extend(pred.tolist())
            _ids.extend(arr["_id"].tolist())
            pbar.update(int(arr["mask"].sum()))
        pbar.close()
        assert len(preds) == len(eval_samples)
        assert len(_ids) == len(eval_samples)
        for i, preds_ in zip(_ids, preds):
            eval_samples[i]["preds"] = preds_
        # Group by source document and emit one anafora file per document.
        group = defaultdict(list)
        for s in eval_samples:
            group[str(s["path"])].append(s)
        _log.info("Writing predictions")
        for doc_path, doc_samples in group.items():
            spans = [x for s in doc_samples for x in s["spans"]]
            labels = [
                config.id2label[x] for s in doc_samples for x in s["preds"]
            ]
            doc_path = Path(doc_path[len(f"{corpus['path']}/"):])
            data = make_anafora(spans, labels, doc_path.name)
            (artifacts_dir / "time" / doc_path.parent).mkdir(parents=True,
                                                             exist_ok=True)
            data.to_file(
                f"{str(artifacts_dir / 'time' / doc_path)}.TimeNorm.system.completed.xml"
            )

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["word_ids"]) for s in samples)
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket by length decile to reduce padding; shuffle for training.
    bucket_key = lambda s: (len(s["word_ids"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples, bucket_key, batch_size, shuffle_bucket=True,
                       rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state.get("eval_f1")
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser.

    Builds (or partially loads) a vocabulary, trains with arc + type
    cross-entropy losses, evaluates on dev each epoch with LR scheduling and
    early stopping, and evaluates on test whenever dev improves. Returns the
    best dev LAS (no punctuation) unless interrupted.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # Fails if the directory already exists and overwrite is False.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        # Optionally reuse the dependency-type vocabulary from another run.
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # Convert string fields to integer ids.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}
    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # Halve the LR when dev LAS (a maximized metric) plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max",
                                                           factor=0.5)

    trainer = Runner()
    # Best-so-far dev counts; -1 so the first epoch always counts as better.
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH,
               [batch2tensors(device, vocab),
                set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat[
            "heads"], bat["types"]
        mask = bat["mask"]
        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        # Forbid predicting the padding type.
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # Flatten to (bsz*slen, ...) for per-token cross-entropy, then
        # average over non-padding tokens only.
        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads,
                                                     reduction="none")
        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types,
                                                      reduction="none")
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss
        state["loss"] = loss
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        state["n_items"] = bat["mask"].long().sum().item()

    trainer.on(Event.BATCH,
               [update_params(opt),
                log_grads(_run, model),
                log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        scheduler.step(accs["las_nopunct"])
        # "Better" = more labeled correct arcs; ties broken by unlabeled.
        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False
        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            # state["test_accs"] exists here: the first epoch is always
            # better (best counts start at -1), so maybe_eval_on_test below
            # has run at least once.
            _log.info("Not better, the best so far is epoch %d:",
                      state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Only evaluate on test when dev improved this epoch.
        if not state["better"]:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"], on="test", run=_run,
                   step=state["n_iters"])

    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir,
                            when="better"),
        ],
    )
    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    # Bucket sentences by length decile to reduce padding; shuffle buckets
    # and batches for training.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size,
                       shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]
def runner():
    """Create and return a fresh ``Runner`` instance."""
    fresh_runner = Runner()
    return fresh_runner
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training.

    Pipeline: (1) read samples and extend the source vocabulary with target
    words, (2) load the pretrained parser and extend its word embeddings,
    (3) predict 1-best trees (heads and types) on the train set to use as
    pseudo-gold, (4) finetune on those trees with arc + type cross-entropy
    plus an L2 penalty toward the original parameters, evaluating on dev
    each epoch and on test after the last.

    Returns the dev LAS (no punctuation) unless interrupted.
    """
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # Fails if the directory already exists and overwrite is False.
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    # Add target-language words so embeddings can be extended below.
    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    # Convert string fields to integer ids.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    # Replace the word embedding with one covering the extended vocabulary,
    # initializing new rows from the pretrained word vectors.
    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    # Optionally freeze the embedding layers during finetuning.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Predict self-training (pseudo-gold) trees on the train set.
    for wh in ["train"]:
        # Tag each sample with its index so batched predictions can be
        # written back to the right sample after bucketed iteration.
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            # Collect predicted heads/types, keyed by sample id.
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(
                BucketIterator(samples[wh], lambda s: len(s["words"]),
                               batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        # Scatter predictions back to their samples by recorded id.
        for i, st_heads, st_types in zip(runner.state["_ids"],
                                         runner.state["st_heads"],
                                         runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Snapshot the pre-finetuning parameters for the L2-to-origin penalty.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # Supervised arc + type cross-entropy against the self-training
        # (pseudo-gold) trees, plus the weighted L2-to-origin loss.
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat[
            "st_heads"], bat["st_types"]
        mask = bat["mask"]
        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2),
                                            -1e9)  # mask padding heads
        # Forbid predicting the padding type.
        # NOTE(review): uses vocab.PAD_TOKEN here where the train() recipe
        # uses Vocab.PAD_TOKEN — presumably the same constant; confirm.
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        # Flatten to (bsz*slen, ...) for per-token cross-entropy, then
        # average over non-padding tokens only.
        arc_scores = rearrange(arc_scores,
                               "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads,
                                                     reduction="none")
        type_scores = rearrange(type_scores,
                                "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types,
                                                      reduction="none")
        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item()
        }

    finetuner.on(
        Event.BATCH,
        [
            get_n_items(),
            update_params(opt),
            log_grads(_run, model),
            log_stats(_run)
        ],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Only evaluate on test after the final epoch.
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs, on="test", run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))
    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket sentences by length decile to reduce padding; shuffle buckets
    # and batches for training.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size,
                       shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
def finetune(
    corpus,
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    load_samples_from=None,
    overwrite=False,
    load_src=None,
    src_key_as_lang=False,
    main_src=None,
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    save_samples=False,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPTX.

    Multi-source variant of PPT: for every source parser in ``load_src``
    (a mapping of src key -> (artifacts dir, params file)), compute the
    ambiguous-arcs mask on train/dev and OR the masks together; then
    finetune the ``main_src`` parser with the PPTX loss plus an L2 penalty
    toward its original parameters. The main source is processed last so
    its vocab/model are the ones kept for finetuning.

    Returns the dev LAS (no punctuation) unless interrupted.
    """
    if max_length is None:
        max_length = {}
    if load_src is None:
        # Default to a single source; it is then necessarily the main one.
        load_src = {"src": ("artifacts", "model.pth")}
        main_src = "src"
    elif main_src not in load_src:
        raise ValueError(f"{main_src} not found in load_src")

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    # Fails if the directory already exists and overwrite is False.
    artifacts_dir.mkdir(exist_ok=overwrite)

    if load_samples_from:
        # Samples saved by a previous run already carry combined pptx_masks.
        _log.info("Loading samples from %s", load_samples_from)
        with open(load_samples_from, "rb") as f:
            samples = pickle.load(f)
    else:
        samples = {
            wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
            for wh in ["train", "dev", "test"]
        }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh,
                  n_toks)

    kv = KeyedVectors.load_word2vec_format(word_emb_path)

    if load_samples_from:
        _log.info(
            "Skipping non-main src because samples are processed and loaded")
        srcs = []
    else:
        srcs = [src for src in load_src if src != main_src]
        if src_key_as_lang and corpus["lang"] in srcs:
            # Never use the target language itself as a source parser.
            _log.info("Removing %s from src parsers because it's the tgt",
                      corpus["lang"])
            srcs.remove(corpus["lang"])
    # Main source goes last so its vocab/model survive the loop below.
    srcs.append(main_src)

    for src_i, src in enumerate(srcs):
        _log.info("Processing src %s [%d/%d]", src, src_i + 1, len(srcs))
        load_from, load_params = load_src[src]

        path = Path(load_from) / "vocab.yml"
        _log.info("Loading %s vocabulary from %s", src, path)
        vocab = load(path.read_text(encoding="utf8"))
        for name in vocab:
            _log.info("Found %d %s", len(vocab[name]), name)

        _log.info("Extending %s vocabulary with target words", src)
        vocab.extend(chain(*samples.values()), ["words"])
        _log.info("Found %d words now", len(vocab["words"]))

        # Per-source numericalized copy; `samples` keeps string form so it
        # can be re-numericalized under each source's vocabulary.
        samples_ = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

        path = Path(load_from) / "model.yml"
        _log.info("Loading %s model from metadata %s", src, path)
        model = load(path.read_text(encoding="utf8"))

        path = Path(load_from) / load_params
        _log.info("Loading %s model parameters from %s", src, path)
        model.load_state_dict(torch.load(path, "cpu"))

        # Extend the source model's word embedding to the extended vocab.
        _log.info("Creating %s extended word embedding layer", src)
        assert model.word_emb.embedding_dim == kv.vector_size
        with torch.no_grad():
            model.word_emb = torch.nn.Embedding.from_pretrained(
                extend_word_embedding(model.word_emb.weight, vocab["words"],
                                      kv))
        model.to(device)

        for wh in ["train", "dev"]:
            if load_samples_from:
                # Loaded samples must already have their combined masks.
                assert all("pptx_mask" in s for s in samples[wh])
                continue

            # Tag each sample with its index so batched results can be
            # written back after bucketed (reordered) iteration.
            for i, s in enumerate(samples_[wh]):
                s["_id"] = i

            runner = Runner()
            runner.state.update({"pptx_masks": [], "_ids": []})
            runner.on(
                Event.BATCH,
                [
                    batch2tensors(device, vocab),
                    set_train_mode(model, training=False),
                    compute_total_arc_type_scores(model, vocab),
                ],
            )

            @runner.on(Event.BATCH)
            def compute_pptx_ambiguous_arcs_mask(state):
                # Masks are only valid for unpadded batches.
                assert state["batch"]["mask"].all()
                scores = state["total_arc_type_scores"]
                pptx_mask = compute_ambiguous_arcs_mask(
                    scores, thresh, projective, multiroot)
                state["pptx_masks"].extend(pptx_mask)
                state["_ids"].extend(state["batch"]["_id"].tolist())
                state["n_items"] = state["batch"]["words"].numel()

            n_toks = sum(len(s["words"]) for s in samples_[wh])
            ProgressBar(total=n_toks, unit="tok").attach_on(runner)

            _log.info(
                "Computing PPTX ambiguous arcs mask for %s set with source %s",
                wh, src)
            with torch.no_grad():
                runner.run(
                    BucketIterator(samples_[wh], lambda s: len(s["words"]),
                                   batch_size))

            assert len(runner.state["pptx_masks"]) == len(samples_[wh])
            assert len(runner.state["_ids"]) == len(samples_[wh])
            # Scatter masks back to their samples by recorded id.
            for i, pptx_mask in zip(runner.state["_ids"],
                                    runner.state["pptx_masks"]):
                samples_[wh][i]["pptx_mask"] = pptx_mask.tolist()

            _log.info("Computing (log) number of trees stats on %s set", wh)
            report_log_ntrees_stats(samples_[wh], "pptx_mask", batch_size,
                                    projective, multiroot)

            # Union this source's mask into the accumulated mask on the
            # shared (string-form) samples.
            _log.info("Combining the ambiguous arcs mask")
            assert len(samples_[wh]) == len(samples[wh])
            for i in range(len(samples_[wh])):
                pptx_mask = torch.tensor(samples_[wh][i]["pptx_mask"])
                assert pptx_mask.dim() == 3
                if "pptx_mask" in samples[wh][i]:
                    old_mask = torch.tensor(samples[wh][i]["pptx_mask"])
                else:
                    # Broadcastable all-False mask for the first source.
                    old_mask = torch.zeros(1, 1, 1).bool()
                samples[wh][i]["pptx_mask"] = (old_mask | pptx_mask).tolist()

    # After the loop, vocab/model belong to the main source (processed last).
    assert src == main_src
    _log.info("Main source is %s", src)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    if save_samples:
        path = artifacts_dir / "samples.pkl"
        _log.info("Saving samples to %s", path)
        with open(path, "wb") as f:
            pickle.dump(samples, f)

    # Numericalize under the main source's vocabulary for finetuning.
    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}
    for wh in ["train", "dev"]:
        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "pptx_mask", batch_size,
                                projective, multiroot)

    # Optionally freeze the embedding layers during finetuning.
    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Snapshot the pre-finetuning parameters for the L2-to-origin penalty.
    origin_params = {
        name: p.clone().detach()
        for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        # PPTX loss (normalized per sentence) plus weighted L2-to-origin.
        mask = state["batch"]["mask"]
        pptx_mask = state["batch"]["pptx_mask"].bool()
        scores = state["total_arc_type_scores"]
        pptx_loss = compute_aatrn_loss(scores, pptx_mask, mask, projective,
                                       multiroot)
        pptx_loss /= mask.size(0)
        loss = pptx_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "pptx_loss": pptx_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH,
                 [update_params(opt),
                  log_grads(_run, model),
                  log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        pptx_loss = eval_state["mean_pptx_loss"]
        _log.info("dev_pptx_loss: %.4f", pptx_loss)
        _run.log_scalar("dev_pptx_loss", pptx_loss, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        # Only evaluate on test after the final epoch.
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"],
                              compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run,
                   step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED,
                 save_state_dict("model", model, under=artifacts_dir))
    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    # Bucket sentences by length decile to reduce padding; shuffle buckets
    # and batches for training.
    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size,
                       shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )
    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]