def run_eval(
    model,
    vocab,
    samples,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
            predict_batch(projective, multiroot),
            evaluate_batch(),
            get_n_items(),
        ],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)

    with torch.no_grad():
        runner.run(BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
def run_eval(
    model,
    vocab,
    samples,
    compute_loss=True,
    device="cpu",
    projective=False,
    multiroot=True,
    batch_size=32,
):
    runner = Runner()
    runner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model, training=False),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @runner.on(Event.BATCH)
    def maybe_compute_loss(state):
        if not compute_loss:
            return
        ppt_loss = compute_aatrn_loss(
            state["total_arc_type_scores"],
            state["batch"]["ppt_mask"].bool(),
            projective=projective,
            multiroot=multiroot,
        )
        state["ppt_loss"] = ppt_loss.item()
        state["size"] = state["batch"]["words"].size(0)

    runner.on(
        Event.BATCH,
        [predict_batch(projective, multiroot), evaluate_batch(), get_n_items()],
    )

    n_tokens = sum(len(s["words"]) for s in samples)
    ProgressBar(leave=False, total=n_tokens, unit="tok").attach_on(runner)
    SumReducer("counts").attach_on(runner)
    if compute_loss:
        MeanReducer("mean_ppt_loss", value="ppt_loss").attach_on(runner)

    with torch.no_grad():
        runner.run(BucketIterator(samples, lambda s: len(s["words"]), batch_size))

    return runner.state
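# ---------------------------------------------------------------------------
# Note: both run_eval variants above (and the training loops below) batch
# samples with BucketIterator(samples, lambda s: len(s["words"]), batch_size),
# i.e. samples are grouped by a length key so that padding inside a batch
# stays small. The project's BucketIterator/ShuffleIterator are not part of
# this excerpt; the helper below is only a minimal, hypothetical sketch of the
# bucketing idea (no shuffling, no RNG handling), not the actual class.
# ---------------------------------------------------------------------------
from collections import defaultdict
from typing import Callable, Dict, Iterator, List, Sequence


def bucket_batches_sketch(
    samples: Sequence[dict],
    key: Callable[[dict], int],
    batch_size: int,
) -> Iterator[List[dict]]:
    """Group samples by `key` (e.g. sentence length), then yield batches."""
    buckets: Dict[int, List[dict]] = defaultdict(list)
    for s in samples:
        buckets[key(s)].append(s)
    for bucket in buckets.values():
        for i in range(0, len(bucket), batch_size):
            yield bucket[i:i + batch_size]


# Example: each yielded batch contains samples whose "words" have equal length.
# for batch in bucket_batches_sketch(samples, lambda s: len(s["words"]), 32): ...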
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPT."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train", "dev"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"ppt_masks": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
            ],
        )

        @runner.on(Event.BATCH)
        def compute_ppt_ambiguous_arcs_mask(state):
            assert state["batch"]["mask"].all()
            scores = state["total_arc_type_scores"]
            ppt_mask = compute_ambiguous_arcs_mask(scores, thresh, projective, multiroot)
            state["ppt_masks"].extend(ppt_mask.tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing PPT ambiguous arcs mask for %s set", wh)
        with torch.no_grad():
            runner.run(BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["ppt_masks"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, ppt_mask in zip(runner.state["_ids"], runner.state["ppt_masks"]):
            samples[wh][i]["ppt_mask"] = ppt_mask

        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "ppt_mask", batch_size, projective, multiroot)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {name: p.clone().detach() for name, p in model.named_parameters()}
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        ppt_mask = state["batch"]["ppt_mask"].bool()
        scores = state["total_arc_type_scores"]
        ppt_loss = compute_aatrn_loss(scores, ppt_mask, mask, projective, multiroot)
        ppt_loss /= mask.size(0)
        loss = ppt_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "ppt_loss": ppt_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        ppt_loss = eval_state["mean_ppt_loss"]
        _log.info("dev_ppt_loss: %.4f", ppt_loss)
        _run.log_scalar("dev_ppt_loss", ppt_loss, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"], compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )

    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
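# ---------------------------------------------------------------------------
# compute_ambiguous_arcs_mask(scores, thresh, projective, multiroot) is not
# part of this excerpt. In the PPT setup it selects, per sentence, a set of
# arcs whose probability mass reaches `thresh` (0.95 by default); the
# fine-tuning loss then only rewards parses that stay inside this set. As a
# rough, self-contained illustration of the thresholding step alone --
# ignoring tree constraints, type scores, projectivity, and the multiroot flag
# that the real helper handles -- one could keep, for each dependent, the most
# probable heads until their cumulative probability exceeds the threshold.
# ---------------------------------------------------------------------------
import torch


def ambiguous_heads_mask_sketch(arc_scores: torch.Tensor, thresh: float = 0.95) -> torch.Tensor:
    """Toy per-dependent head sets covering `thresh` probability mass.

    arc_scores: (slen, slen) unnormalised scores for one sentence, indexed as
    [head, dependent]. Returns a bool mask of the same shape.
    """
    probs = arc_scores.softmax(dim=0)  # distribution over heads for each dependent
    sorted_probs, order = probs.sort(dim=0, descending=True)
    cum = sorted_probs.cumsum(dim=0)
    # keep every head up to and including the one that pushes the mass past thresh
    keep_sorted = (cum - sorted_probs) < thresh
    mask = torch.zeros_like(probs, dtype=torch.bool)
    mask.scatter_(0, order, keep_sorted)
    return mask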
def train(
    _log,
    _run,
    _rnd,
    artifacts_dir="artifacts",
    overwrite=False,
    max_length=None,
    load_types_vocab_from=None,
    batch_size=16,
    device="cpu",
    lr=0.001,
    patience=5,
    max_epoch=1000,
):
    """Train a self-attention graph-based parser."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    _log.info("Creating vocabulary")
    vocab = Vocab.from_samples(chain(*samples.values()))
    if load_types_vocab_from:
        path = Path(load_types_vocab_from)
        _log.info("Loading types vocab from %s", path)
        vocab["types"] = load(path.read_text(encoding="utf8"))["types"]

    _log.info("Vocabulary created")
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    model = make_model(vocab)
    model.to(device)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5)

    trainer = Runner()
    trainer.state.update({"dev_larcs_nopunct": -1, "dev_uarcs_nopunct": -1})
    trainer.on(Event.BATCH, [batch2tensors(device, vocab), set_train_mode(model)])

    @trainer.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat["heads"], bat["types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(Vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss

        state["loss"] = loss
        arc_loss, type_loss = arc_loss.item(), type_loss.item()
        state["stats"] = {
            "arc_ppl": math.exp(arc_loss),
            "type_ppl": math.exp(type_loss),
        }
        state["extra_stats"] = {"arc_loss": arc_loss, "type_loss": type_loss}
        state["n_items"] = bat["mask"].long().sum().item()

    trainer.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @trainer.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])

        scheduler.step(accs["las_nopunct"])

        if eval_state["counts"].larcs_nopunct > state["dev_larcs_nopunct"]:
            state["better"] = True
        elif eval_state["counts"].larcs_nopunct < state["dev_larcs_nopunct"]:
            state["better"] = False
        elif eval_state["counts"].uarcs_nopunct > state["dev_uarcs_nopunct"]:
            state["better"] = True
        else:
            state["better"] = False

        if state["better"]:
            _log.info("Found new best result on dev!")
            state["dev_larcs_nopunct"] = eval_state["counts"].larcs_nopunct
            state["dev_uarcs_nopunct"] = eval_state["counts"].uarcs_nopunct
            state["dev_accs"] = accs
            state["dev_epoch"] = state["epoch"]
        else:
            _log.info("Not better, the best so far is epoch %d:", state["dev_epoch"])
            print_accs(state["dev_accs"])
            print_accs(state["test_accs"], on="test")

    @trainer.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if not state["better"]:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        state["test_accs"] = eval_state["counts"].accs
        print_accs(state["test_accs"], on="test", run=_run, step=state["n_iters"])

    trainer.on(
        Event.EPOCH_FINISHED,
        [
            maybe_stop_early(patience=patience),
            save_state_dict("model", model, under=artifacts_dir, when="better"),
        ],
    )

    EpochTimer().attach_on(trainer)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(trainer)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )

    _log.info("Starting training")
    try:
        trainer.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return trainer.state["dev_accs"]["las_nopunct"]
def finetune(
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    overwrite=False,
    load_from="artifacts",
    load_params="model.pth",
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    projective=False,
    multiroot=True,
    batch_size=32,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with self-training."""
    if max_length is None:
        max_length = {}

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    samples = {
        wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
        for wh in ["train", "dev", "test"]
    }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    path = Path(load_from) / "vocab.yml"
    _log.info("Loading vocabulary from %s", path)
    vocab = load(path.read_text(encoding="utf8"))
    for name in vocab:
        _log.info("Found %d %s", len(vocab[name]), name)

    _log.info("Extending vocabulary with target words")
    vocab.extend(chain(*samples.values()), ["words"])
    _log.info("Found %d words now", len(vocab["words"]))

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    path = Path(load_from) / "model.yml"
    _log.info("Loading model from metadata %s", path)
    model = load(path.read_text(encoding="utf8"))

    path = Path(load_from) / load_params
    _log.info("Loading model parameters from %s", path)
    model.load_state_dict(torch.load(path, "cpu"))

    _log.info("Creating extended word embedding layer")
    kv = KeyedVectors.load_word2vec_format(word_emb_path)
    assert model.word_emb.embedding_dim == kv.vector_size
    with torch.no_grad():
        model.word_emb = torch.nn.Embedding.from_pretrained(
            extend_word_embedding(model.word_emb.weight, vocab["words"], kv))

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)
    model.to(device)

    for wh in ["train"]:
        for i, s in enumerate(samples[wh]):
            s["_id"] = i

        runner = Runner()
        runner.state.update({"st_heads": [], "st_types": [], "_ids": []})
        runner.on(
            Event.BATCH,
            [
                batch2tensors(device, vocab),
                set_train_mode(model, training=False),
                compute_total_arc_type_scores(model, vocab),
                predict_batch(projective, multiroot),
            ],
        )

        @runner.on(Event.BATCH)
        def save_st_trees(state):
            state["st_heads"].extend(state["pred_heads"].tolist())
            state["st_types"].extend(state["pred_types"].tolist())
            state["_ids"].extend(state["batch"]["_id"].tolist())
            state["n_items"] = state["batch"]["words"].numel()

        n_toks = sum(len(s["words"]) for s in samples[wh])
        ProgressBar(total=n_toks, unit="tok").attach_on(runner)

        _log.info("Computing ST trees for %s set", wh)
        with torch.no_grad():
            runner.run(BucketIterator(samples[wh], lambda s: len(s["words"]), batch_size))

        assert len(runner.state["st_heads"]) == len(samples[wh])
        assert len(runner.state["st_types"]) == len(samples[wh])
        assert len(runner.state["_ids"]) == len(samples[wh])
        for i, st_heads, st_types in zip(runner.state["_ids"], runner.state["st_heads"],
                                         runner.state["st_types"]):
            assert len(samples[wh][i]["words"]) == len(st_heads)
            assert len(samples[wh][i]["words"]) == len(st_types)
            samples[wh][i]["st_heads"] = st_heads
            samples[wh][i]["st_types"] = st_types

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {name: p.clone().detach() for name, p in model.named_parameters()}
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        bat = state["batch"]
        words, tags, heads, types = bat["words"], bat["tags"], bat["st_heads"], bat["st_types"]
        mask = bat["mask"]

        arc_scores, type_scores = model(words, tags, mask, heads)
        arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), -1e9)  # mask padding heads
        type_scores[..., vocab["types"].index(vocab.PAD_TOKEN)] = -1e9

        # remove root
        arc_scores, type_scores = arc_scores[:, :, 1:], type_scores[:, 1:]
        heads, types, mask = heads[:, 1:], types[:, 1:], mask[:, 1:]

        arc_scores = rearrange(arc_scores, "bsz slen1 slen2 -> (bsz slen2) slen1")
        heads = heads.reshape(-1)
        arc_loss = torch.nn.functional.cross_entropy(arc_scores, heads, reduction="none")

        type_scores = rearrange(type_scores, "bsz slen ntypes -> (bsz slen) ntypes")
        types = types.reshape(-1)
        type_loss = torch.nn.functional.cross_entropy(type_scores, types, reduction="none")

        arc_loss = arc_loss.masked_select(mask.reshape(-1)).mean()
        type_loss = type_loss.masked_select(mask.reshape(-1)).mean()
        loss = arc_loss + type_loss + l2_coef * state["l2_loss"]

        state["loss"] = loss
        state["stats"] = {
            "arc_ppl": arc_loss.exp().item(),
            "type_ppl": type_loss.exp().item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {
            "arc_loss": arc_loss.item(),
            "type_loss": type_loss.item(),
        }

    finetuner.on(
        Event.BATCH,
        [get_n_items(), update_params(opt), log_grads(_run, model), log_stats(_run)],
    )

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"])
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )

    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]
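# ---------------------------------------------------------------------------
# All fine-tuning commands snapshot the pretrained weights into origin_params
# and add l2_coef * state["l2_loss"] to the loss, keeping the fine-tuned model
# close to its initialisation. compute_l2_loss itself is not in this excerpt;
# the sketch below shows what such a penalty typically computes. The function
# name and the absence of any extra normalisation are assumptions.
# ---------------------------------------------------------------------------
import torch


def l2_to_origin_sketch(model: torch.nn.Module, origin_params: dict) -> torch.Tensor:
    """Sum of squared distances between current parameters and a frozen snapshot."""
    loss = torch.zeros(())
    for name, p in model.named_parameters():
        if p.requires_grad:
            loss = loss + (p - origin_params[name]).pow(2).sum()
    return loss


# Example usage, mirroring the snapshot taken above:
#   origin_params = {n: p.clone().detach() for n, p in model.named_parameters()}
#   loss = task_loss + l2_coef * l2_to_origin_sketch(model, origin_params)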
def finetune(
    corpus,
    _log,
    _run,
    _rnd,
    max_length=None,
    artifacts_dir="ft_artifacts",
    load_samples_from=None,
    overwrite=False,
    load_src=None,
    src_key_as_lang=False,
    main_src=None,
    device="cpu",
    word_emb_path="wiki.id.vec",
    freeze=False,
    thresh=0.95,
    projective=False,
    multiroot=True,
    batch_size=32,
    save_samples=False,
    lr=1e-5,
    l2_coef=1.0,
    max_epoch=5,
):
    """Finetune a trained model with PPTX."""
    if max_length is None:
        max_length = {}
    if load_src is None:
        load_src = {"src": ("artifacts", "model.pth")}
        main_src = "src"
    elif main_src not in load_src:
        raise ValueError(f"{main_src} not found in load_src")

    artifacts_dir = Path(artifacts_dir)
    _log.info("Creating artifacts directory %s", artifacts_dir)
    artifacts_dir.mkdir(exist_ok=overwrite)

    if load_samples_from:
        _log.info("Loading samples from %s", load_samples_from)
        with open(load_samples_from, "rb") as f:
            samples = pickle.load(f)
    else:
        samples = {
            wh: list(read_samples(which=wh, max_length=max_length.get(wh)))
            for wh in ["train", "dev", "test"]
        }
    for wh in samples:
        n_toks = sum(len(s["words"]) for s in samples[wh])
        _log.info("Read %d %s samples and %d tokens", len(samples[wh]), wh, n_toks)

    kv = KeyedVectors.load_word2vec_format(word_emb_path)

    if load_samples_from:
        _log.info("Skipping non-main src because samples are processed and loaded")
        srcs = []
    else:
        srcs = [src for src in load_src if src != main_src]
        if src_key_as_lang and corpus["lang"] in srcs:
            _log.info("Removing %s from src parsers because it's the tgt", corpus["lang"])
            srcs.remove(corpus["lang"])
    srcs.append(main_src)

    for src_i, src in enumerate(srcs):
        _log.info("Processing src %s [%d/%d]", src, src_i + 1, len(srcs))
        load_from, load_params = load_src[src]

        path = Path(load_from) / "vocab.yml"
        _log.info("Loading %s vocabulary from %s", src, path)
        vocab = load(path.read_text(encoding="utf8"))
        for name in vocab:
            _log.info("Found %d %s", len(vocab[name]), name)

        _log.info("Extending %s vocabulary with target words", src)
        vocab.extend(chain(*samples.values()), ["words"])
        _log.info("Found %d words now", len(vocab["words"]))

        samples_ = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

        path = Path(load_from) / "model.yml"
        _log.info("Loading %s model from metadata %s", src, path)
        model = load(path.read_text(encoding="utf8"))

        path = Path(load_from) / load_params
        _log.info("Loading %s model parameters from %s", src, path)
        model.load_state_dict(torch.load(path, "cpu"))

        _log.info("Creating %s extended word embedding layer", src)
        assert model.word_emb.embedding_dim == kv.vector_size
        with torch.no_grad():
            model.word_emb = torch.nn.Embedding.from_pretrained(
                extend_word_embedding(model.word_emb.weight, vocab["words"], kv))
        model.to(device)

        for wh in ["train", "dev"]:
            if load_samples_from:
                assert all("pptx_mask" in s for s in samples[wh])
                continue

            for i, s in enumerate(samples_[wh]):
                s["_id"] = i

            runner = Runner()
            runner.state.update({"pptx_masks": [], "_ids": []})
            runner.on(
                Event.BATCH,
                [
                    batch2tensors(device, vocab),
                    set_train_mode(model, training=False),
                    compute_total_arc_type_scores(model, vocab),
                ],
            )

            @runner.on(Event.BATCH)
            def compute_pptx_ambiguous_arcs_mask(state):
                assert state["batch"]["mask"].all()
                scores = state["total_arc_type_scores"]
                pptx_mask = compute_ambiguous_arcs_mask(scores, thresh, projective, multiroot)
                state["pptx_masks"].extend(pptx_mask)
                state["_ids"].extend(state["batch"]["_id"].tolist())
                state["n_items"] = state["batch"]["words"].numel()

            n_toks = sum(len(s["words"]) for s in samples_[wh])
            ProgressBar(total=n_toks, unit="tok").attach_on(runner)

            _log.info("Computing PPTX ambiguous arcs mask for %s set with source %s", wh, src)
            with torch.no_grad():
                runner.run(BucketIterator(samples_[wh], lambda s: len(s["words"]), batch_size))

            assert len(runner.state["pptx_masks"]) == len(samples_[wh])
            assert len(runner.state["_ids"]) == len(samples_[wh])
            for i, pptx_mask in zip(runner.state["_ids"], runner.state["pptx_masks"]):
                samples_[wh][i]["pptx_mask"] = pptx_mask.tolist()

            _log.info("Computing (log) number of trees stats on %s set", wh)
            report_log_ntrees_stats(samples_[wh], "pptx_mask", batch_size, projective, multiroot)

            _log.info("Combining the ambiguous arcs mask")
            assert len(samples_[wh]) == len(samples[wh])
            for i in range(len(samples_[wh])):
                pptx_mask = torch.tensor(samples_[wh][i]["pptx_mask"])
                assert pptx_mask.dim() == 3
                if "pptx_mask" in samples[wh][i]:
                    old_mask = torch.tensor(samples[wh][i]["pptx_mask"])
                else:
                    old_mask = torch.zeros(1, 1, 1).bool()
                samples[wh][i]["pptx_mask"] = (old_mask | pptx_mask).tolist()

    assert src == main_src
    _log.info("Main source is %s", src)

    path = artifacts_dir / "vocab.yml"
    _log.info("Saving vocabulary to %s", path)
    path.write_text(dump(vocab), encoding="utf8")

    path = artifacts_dir / "model.yml"
    _log.info("Saving model metadata to %s", path)
    path.write_text(dump(model), encoding="utf8")

    if save_samples:
        path = artifacts_dir / "samples.pkl"
        _log.info("Saving samples to %s", path)
        with open(path, "wb") as f:
            pickle.dump(samples, f)

    samples = {wh: list(vocab.stoi(samples[wh])) for wh in samples}

    for wh in ["train", "dev"]:
        _log.info("Computing (log) number of trees stats on %s set", wh)
        report_log_ntrees_stats(samples[wh], "pptx_mask", batch_size, projective, multiroot)

    model.word_emb.requires_grad_(not freeze)
    model.tag_emb.requires_grad_(not freeze)

    _log.info("Creating optimizer")
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    finetuner = Runner()
    origin_params = {name: p.clone().detach() for name, p in model.named_parameters()}
    finetuner.on(
        Event.BATCH,
        [
            batch2tensors(device, vocab),
            set_train_mode(model),
            compute_l2_loss(model, origin_params),
            compute_total_arc_type_scores(model, vocab),
        ],
    )

    @finetuner.on(Event.BATCH)
    def compute_loss(state):
        mask = state["batch"]["mask"]
        pptx_mask = state["batch"]["pptx_mask"].bool()
        scores = state["total_arc_type_scores"]
        pptx_loss = compute_aatrn_loss(scores, pptx_mask, mask, projective, multiroot)
        pptx_loss /= mask.size(0)
        loss = pptx_loss + l2_coef * state["l2_loss"]
        state["loss"] = loss
        state["stats"] = {
            "pptx_loss": pptx_loss.item(),
            "l2_loss": state["l2_loss"].item(),
        }
        state["extra_stats"] = {"loss": loss.item()}
        state["n_items"] = mask.long().sum().item()

    finetuner.on(Event.BATCH, [update_params(opt), log_grads(_run, model), log_stats(_run)])

    @finetuner.on(Event.EPOCH_FINISHED)
    def eval_on_dev(state):
        _log.info("Evaluating on dev")
        eval_state = run_eval(model, vocab, samples["dev"])
        accs = eval_state["counts"].accs
        print_accs(accs, run=_run, step=state["n_iters"])
        pptx_loss = eval_state["mean_pptx_loss"]
        _log.info("dev_pptx_loss: %.4f", pptx_loss)
        _run.log_scalar("dev_pptx_loss", pptx_loss, step=state["n_iters"])
        state["dev_accs"] = accs

    @finetuner.on(Event.EPOCH_FINISHED)
    def maybe_eval_on_test(state):
        if state["epoch"] != max_epoch:
            return
        _log.info("Evaluating on test")
        eval_state = run_eval(model, vocab, samples["test"], compute_loss=False)
        print_accs(eval_state["counts"].accs, on="test", run=_run, step=state["n_iters"])

    finetuner.on(Event.EPOCH_FINISHED, save_state_dict("model", model, under=artifacts_dir))

    EpochTimer().attach_on(finetuner)
    n_tokens = sum(len(s["words"]) for s in samples["train"])
    ProgressBar(stats="stats", total=n_tokens, unit="tok").attach_on(finetuner)

    bucket_key = lambda s: (len(s["words"]) - 1) // 10
    trn_iter = ShuffleIterator(
        BucketIterator(samples["train"], bucket_key, batch_size, shuffle_bucket=True, rng=_rnd),
        rng=_rnd,
    )

    _log.info("Starting finetuning")
    try:
        finetuner.run(trn_iter, max_epoch)
    except KeyboardInterrupt:
        _log.info("Interrupt detected, training will abort")
    else:
        return finetuner.state["dev_accs"]["las_nopunct"]