def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
            predict_batch(projective, multiroot),
            evaluate_batch(),
            get_n_items(),
        ],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)

    with torch.no_grad():
        runner.run(BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return

        ppt_loss = compute_aatrn_loss(
            state["total_arc_type_scores"],
            state["batch"]["ppt_mask"].bool(),
            projective=projective,
            multiroot=multiroot,
        )
        state["ppt_loss"] = ppt_loss.item()
        state["size"] = state["batch"]["words"].size(0)

    runner.on(
        Event.BATCH,
        [predict_batch(projective, multiroot), evaluate_batch(), get_n_items()],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ppt_loss", value="ppt_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
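# A minimal usage sketch (hypothetical call; assumes `model`, `vocab`, and `samples`
# have been built as in `finetune` below):
#
#   state = run_eval(model, vocab, samples["dev"], batch_size=64)
#   print(state["counts"].accs)      # accuracies gathered by SumReducer("counts")
#   print(state["mean_ppt_loss"])    # set when compute_loss=True and batches carry a "ppt_mask"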
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv)
        )

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    # Self-training: predict trees on the training set and store them as targets.
    for wh in ["train"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])

        for i, st_heads, st_types in zip(
            runner.state["_ids"], runner.state["st_heads"], runner.state["st_types"]
        ):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    # Snapshot the pretrained parameters for the L2 penalty used during finetuning.
    origin_params = {
        name: p.clone().detach() for name, p in model.named_parameters()
    }
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = (
            bat["words"],
            bat["tags"],
            bat["st_heads"],
            bat["st_types"],
        )
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)

        # mask padding heads
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item(),
        }

    finetuner.on(
        Event.BATCH,
        [get_n_items(), update_params(opt), log_grads(_run, model), log_stats(_run)],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return

        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )

    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
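# Note: `compute_l2_loss(model, origin_params)` above is a project-local helper; it is
# assumed here to populate state["l2_loss"] with the squared L2 distance between the
# current parameters and the pretrained snapshot, roughly:
#
#   l2_loss = sum((p - origin_params[n]).pow(2).sum()
#                 for n, p in model.named_parameters() if p.requires_grad)
#
# so the `l2_coef * state["l2_loss"]` term keeps the finetuned model close to its
# initialization during self-training.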