def run(
        lr=0.001,
        batsize=20,
        epochs=70,
        embdim=128,
        encdim=400,
        numlayers=1,
        beamsize=5,
        dropout=.5,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.1,
        cosine_restarts=1.,
        seed=123456,
):
    """Train a ``Ranker`` on ``GeoDatasetRank``, evaluate it, and optionally save the run.

    Args:
        lr: Learning rate for the Adam optimizer.
        batsize: Batch size handed to every ``ds.dataloader(...)`` call.
        epochs: ``max_epochs`` for ``q.run_training``; also used as ``t_max``
            of the learning-rate schedule.
        embdim: Only referenced by the commented-out ``TreeRankModel``
            alternative; the active model does not read it.
        encdim: Passed to ``ParikhRankModel`` as its ``embdim`` argument.
        numlayers: Only referenced by the commented-out ``TreeRankModel``.
        beamsize: Beam size of the ``BeamDecoder`` built after saving.
        dropout: Dropout passed to the model.
        wreg: ``weight_decay`` for Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: Not read by this function.
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        smoothing: Smoothing value passed to ``BCELoss``.
        cosine_restarts: ``cycles`` argument of the warmup-cosine schedule;
            a negative value disables the schedule.
        seed: Seed for ``torch`` and ``numpy`` RNGs.

    Returns:
        None. Results are printed; files are only written if the user
        confirms at the interactive prompt.
    """
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("loading data")
    ds = GeoDatasetRank()
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    # do_rare_stats(ds)

    # Alternative model:
    # model = TreeRankModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
    #                       sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder)
    # NOTE(review): `model` must be bound here -- Ranker, the optimizer, gradient
    # clipping and q.save_run below all use it. Confirm this is the intended model.
    model = ParikhRankModel(embdim=encdim,
                            dropout=dropout,
                            sentence_encoder=ds.sentence_encoder,
                            query_encoder=ds.query_encoder)

    ranker = Ranker(model,
                    eval=[BCELoss(mode="logits", smoothing=smoothing)],
                    evalseq=[
                        SeqAccuracies(),
                        TreeAccuracy(tensor2tree=partial(
                            tensor2tree, D=ds.query_encoder.vocab),
                                     orderless={"and", "or"}),
                    ])
    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # 4. define optim
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        model.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    # NOTE(review): both evaluation passes below read the "test" split, even
    # though the first one is labelled "validation".
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip()
    # fullmatch, not match: a prefix match would let "3x" through to int(),
    # which raises ValueError. Raw string avoids the invalid "\d" escape.
    overwrite_given = re.fullmatch(r"\d+", answer) is not None
    if answer.lower() in ("y", "yes") or overwrite_given:
        overwrite = int(answer) if overwrite_given else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=100,
            beamsize=beamsize,
            copy_deep=True,
            eval=[SeqAccuracies()],
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])
        # NOTE(review): `beamlosses` was never defined in this function. The
        # metric names are assumed from the eval/eval_beam configuration of
        # the decoder above -- confirm.
        beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                           "tree_acc_at_last")

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)
        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
def run(
        lr=0.001,
        batsize=50,
        epochs=50,
        embdim=100,
        encdim=100,
        numlayers=1,
        beamsize=1,
        dropout=.2,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=3,
        gradnorm=3.,
        cosine_restarts=1.,
        beta=0.001,
        vib_init=True,
        vib_enc=True,
):
    """Train ``BasicGenModel_VIB`` on ``LCQuaDnoENTDataset``, evaluate, optionally save.

    Args:
        lr: Learning rate for the Adam optimizer.
        batsize: Batch size handed to every ``ds.dataloader(...)`` call.
        epochs: ``max_epochs`` for ``q.run_training``; also used as ``t_max``
            of the learning-rate schedule.
        embdim: ``embdim`` argument of the model.
        encdim: ``hdim`` argument of the model.
        numlayers: ``numlayers`` argument of the model.
        beamsize: ``1`` evaluates with a free-running ``SeqDecoder``; any
            other value evaluates with a ``BeamDecoder`` of that beam size.
        dropout: Dropout passed to the model.
        wreg: ``weight_decay`` for Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` argument of the dataset.
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        cosine_restarts: ``cycles`` argument of the warmup-cosine schedule;
            a negative value disables the schedule.
        beta: ``weight`` of each ``StatePenalty`` added to the training losses.
        vib_init: Passed to the model; also adds a ``StatePenalty`` over
            ``state.mstate.vib.init``.
        vib_enc: Passed to the model; also adds a ``StatePenalty`` on
            ``"mstate.vib.enc"``.

    Returns:
        None. Results are printed; the run is only saved if the user
        confirms at the interactive prompt.
    """
    localargs = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("loading data")
    ds = LCQuaDnoENTDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")
    do_rare_stats(ds)

    model = BasicGenModel_VIB(embdim=embdim,
                              hdim=encdim,
                              dropout=dropout,
                              numlayers=numlayers,
                              sentence_encoder=ds.sentence_encoder,
                              query_encoder=ds.query_encoder,
                              feedatt=True,
                              vib_init=vib_init,
                              vib_enc=vib_enc)

    def tree_acc(orderless):
        # Fresh TreeAccuracy per decoder, always bound to this dataset's query vocab.
        return TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                D=ds.query_encoder.vocab),
                            orderless=orderless)

    # Training objectives: cross-entropy plus the optional VIB penalties.
    losses = [CELoss(ignore_index=0, mode="logprobs")]
    if vib_init:
        losses.append(
            StatePenalty(lambda state: sum(state.mstate.vib.init),
                         weight=beta))
    if vib_enc:
        losses.append(StatePenalty("mstate.vib.enc", weight=beta))

    tfdecoder = SeqDecoder(
        model,
        tf_ratio=1.,
        eval=losses +
        [SeqAccuracies(),
         tree_acc({"select", "count", "ask"})])

    if beamsize == 1:
        freedecoder = SeqDecoder(
            model,
            maxtime=40,
            tf_ratio=0.,
            eval=[SeqAccuracies(),
                  tree_acc({"select", "count", "ask"})])
    else:
        freedecoder = BeamDecoder(
            model,
            maxtime=30,
            beamsize=beamsize,
            eval=[SeqAccuracies(),
                  tree_acc({"select", "count", "ask"})])

    # From here on `losses` names the metric accumulators, not the loss modules.
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("valid", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip()
    # fullmatch, not match: a prefix match would let "3x" through to int(),
    # which raises ValueError. Raw string avoids the invalid "\d" escape.
    overwrite_given = re.fullmatch(r"\d+", answer) is not None
    if answer.lower() in ("y", "yes") or overwrite_given:
        overwrite = int(answer) if overwrite_given else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[tree_acc({"op:and", "SW:concat"})])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        # Sanity check that the reloaded model reproduces the in-memory results.
        assert (testresults == _testresults)
        tt.tock("tested")
def run(
        lr=0.001,
        batsize=20,
        epochs=60,
        embdim=128,
        encdim=256,
        numlayers=1,
        beamsize=5,
        dropout=.25,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.1,
        cosine_restarts=1.,
        seed=123456,
        numcvfolds=6,
        testfold=-1,  # if non-default, must be within number of splits, the chosen value is used for validation
        reorder_random=False,
):
    """Train ``BasicGenModel`` on ``GeoDataset``, with optional cross-validation folds.

    With the default ``testfold=-1`` the model is trained, validated on the
    "test" split, evaluated with a beam decoder, and can be saved
    interactively. With any other ``testfold`` the dataset is built with
    ``numcvfolds`` folds, validation runs on the "valid" split, and the
    function returns right after training.

    Args:
        lr: Learning rate for the Adam optimizer.
        batsize: Batch size handed to every ``ds.dataloader(...)`` call.
        epochs: ``max_epochs`` for ``q.run_training``; also used as ``t_max``
            of the learning-rate schedule.
        embdim: ``embdim`` argument of the model.
        encdim: ``hdim`` argument of the model.
        numlayers: ``numlayers`` argument of the model.
        beamsize: Beam size of the ``BeamDecoder`` used for final evaluation.
        dropout: Dropout passed to the model.
        wreg: ``weight_decay`` for Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` argument of the dataset.
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        smoothing: Smoothing value passed to ``CELoss``.
        cosine_restarts: ``cycles`` argument of the warmup-cosine schedule;
            a negative value disables the schedule.
        seed: Seed for the ``random``, ``torch`` and ``numpy`` RNGs.
        numcvfolds: Number of folds given to the dataset when ``testfold``
            is not ``-1``.
        testfold: Fold passed to the dataset; ``-1`` disables cross-validation.
        reorder_random: Passed through to the dataset.

    Returns:
        ``vlosses[1].get_epoch_error()`` when a ``testfold`` was given,
        otherwise None.
    """
    localargs = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("loading data")
    # -1 is the "no cross-validation" sentinel; the dataset expects None.
    cvfolds = None if testfold == -1 else numcvfolds
    testfold = None if testfold == -1 else testfold
    ds = GeoDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq,
        cvfolds=cvfolds,
        testfold=testfold,
        reorder_random=reorder_random)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")
    do_rare_stats(ds)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    def tree_acc():
        # Fresh TreeAccuracy per decoder, always bound to this dataset's query vocab.
        return TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                D=ds.query_encoder.vocab),
                            orderless={"and"})

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               tree_acc(),
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[SeqAccuracies(), tree_acc()])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[tree_acc()])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    train_on = "train"
    valid_on = "test" if testfold is None else "valid"
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    if testfold is not None:
        # Cross-validation mode: hand the second validation metric to the caller.
        return vlosses[1].get_epoch_error()

    # testing
    # NOTE(review): both evaluation passes below read the "test" split, even
    # though the first one is labelled "validation".
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip()
    # fullmatch, not match: a prefix match would let "3x" through to int(),
    # which raises ValueError. Raw string avoids the invalid "\d" escape.
    overwrite_given = re.fullmatch(r"\d+", answer) is not None
    if answer.lower() in ("y", "yes") or overwrite_given:
        overwrite = int(answer) if overwrite_given else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[tree_acc()])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)
        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
def run(lr=0.001,
        batsize=20,
        epochs=60,
        embdim=128,
        encdim=256,
        numlayers=1,
        beamsize=1,
        dropout=.25,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.,
        cosine_restarts=1.,
        seed=456789,
        p_step=.2,
        p_min=.3,
        ):
    """Train ``BasicGenModel`` (with ``p_step``/``p_min``) on ``GeoDataset``, evaluate, optionally save.

    Args:
        lr: Learning rate for the Adam optimizer.
        batsize: Batch size handed to every ``ds.dataloader(...)`` call.
        epochs: ``max_epochs`` for ``q.run_training``; also used as ``t_max``
            of the learning-rate schedule.
        embdim: ``embdim`` argument of the model.
        encdim: ``hdim`` argument of the model.
        numlayers: ``numlayers`` argument of the model.
        beamsize: ``1`` evaluates with a free-running ``SeqDecoder``; any
            other value evaluates with a ``BeamDecoder`` of that beam size.
        dropout: Dropout passed to the model.
        wreg: ``weight_decay`` for Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` argument of the dataset.
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        smoothing: Smoothing value passed to ``CELoss``.
        cosine_restarts: ``cycles`` argument of the warmup-cosine schedule;
            a negative value disables the schedule.
        seed: Seed for ``torch`` and ``numpy`` RNGs.
        p_step: Passed through to the model.
        p_min: Passed through to the model.

    Returns:
        None. Results are printed; the run is only saved if the user
        confirms at the interactive prompt.
    """
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("loading data")
    ds = GeoDataset(sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")
    do_rare_stats(ds)

    model = BasicGenModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder,
                          feedatt=True, p_step=p_step, p_min=p_min)

    def tree_acc(orderless):
        # Fresh TreeAccuracy per decoder, always bound to this dataset's query vocab.
        return TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                            orderless=orderless)

    losses = [CELoss(ignore_index=0, mode="logprobs", smoothing=smoothing)]
    tfdecoder = SeqDecoder(model, tf_ratio=1.,
                           eval=losses + [SeqAccuracies(), tree_acc({"and", "or"})])
    # From here on `losses` names the metric accumulators, not the loss modules.
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    if beamsize == 1:
        freedecoder = SeqDecoder(model, maxtime=100, tf_ratio=0.,
                                 eval=[SeqAccuracies(), tree_acc({"and", "or"})])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    else:
        freedecoder = BeamDecoder(model, maxtime=100, beamsize=beamsize,
                                  eval=[SeqAccuracies()],
                                  eval_beam=[tree_acc({"and", "or"})])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc", "tree_acc_at_last")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=tfdecoder, dataloader=ds.dataloader("train", batsize),
                         optim=optim, losses=losses, _train_batch=trainbatch, device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch, model=freedecoder, dataloader=ds.dataloader("test", batsize),
                         losses=vlosses, device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done training")

    # testing
    # NOTE(review): both evaluation passes below read the "test" split, even
    # though the first one is labelled "validation".
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize),
                               losses=vlosses, device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize),
                               losses=vlosses, device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input("Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>")
    answer = tosave.strip()
    # fullmatch, not match: a prefix match would let "3x" through to int(),
    # which raises ValueError. Raw string avoids the invalid "\d" escape.
    overwrite_given = re.fullmatch(r"\d+", answer) is not None
    if answer.lower() in ("y", "yes") or overwrite_given:
        overwrite = int(answer) if overwrite_given else None
        p = q.save_run(model, localargs, filepath=__file__, overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model, maxtime=50, beamsize=beamsize,
                                   eval_beam=[tree_acc({"op:and", "SW:concat"})])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder, dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses, device=device)
        print(_testresults)
        # Sanity check that the reloaded model reproduces the in-memory results.
        assert(testresults == _testresults)
        tt.tock("tested")
def run_rerank(
        lr=0.001,
        batsize=20,
        epochs=1,
        embdim=301,  # input size of both LSTM encoders below
        encdim=200,
        numlayers=1,
        beamsize=5,
        dropout=.2,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        cosine_restarts=1.,
        domain="restaurants",
        gensavedp="overnight_basic/run{}",
        genrunid=1,
):
    """Build a ``BeamReranker`` on top of a previously saved generator run.

    Loads the dataset and generator from ``gensavedp.format(genrunid)``,
    wraps the generator with a ``SimpleScoreModel``, runs the reranker on one
    training batch of size 2 and then calls ``sys.exit()``.

    NOTE(review): everything after that ``sys.exit()`` (training, evaluation
    and the save prompt) is currently unreachable. It is kept so that it can
    be re-enabled once the beam-cache step marked as a todo is done.

    Args:
        lr: Learning rate for the Adam optimizer.
        batsize: Batch size for the train/valid/test dataloaders.
        epochs: ``max_epochs`` for ``q.run_training``; also used as ``t_max``
            of the learning-rate schedule.
        embdim: First argument of both ``q.LSTMEncoder`` instances.
        encdim: Each encoder layer gets ``encdim // 2`` units (``bidir=True``).
        numlayers: Number of layers in each encoder.
        beamsize: Beam size of the reranker and of the beam decoders.
        dropout: ``dropout_in`` of both encoders.
        wreg: ``weight_decay`` for Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: Not read by the active code (only by a commented-out line).
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_``.
        cosine_restarts: ``cycles`` argument of the warmup-cosine schedule;
            a negative value disables the schedule.
        domain: Not read by the active code (only by a commented-out line).
        gensavedp: Format string for the saved generator run's path.
        genrunid: Value substituted into ``gensavedp``.

    Returns:
        None. In its current state the function ends the process via
        ``sys.exit()``.
    """
    localargs = locals().copy()
    print(locals())
    gensavedrunp = gensavedp.format(genrunid)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("loading data")
    ds = q.load_dataset(gensavedrunp)
    # ds = OvernightDataset(domain=domain, sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")
    do_rare_stats(ds)

    genmodel, genargs = q.load_run(gensavedrunp)

    inpenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    outenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    # The score model reuses the generator's input and output embeddings.
    scoremodel = SimpleScoreModel(genmodel.inp_emb, genmodel.out_emb,
                                  LSTMEncoderWrapper(inpenc),
                                  LSTMEncoderWrapper(outenc), DotSimilarity())

    model = BeamReranker(genmodel, scoremodel, beamsize=beamsize, maxtime=50)

    # todo: run over whole dataset to populate beam cache
    testbatch = next(iter(ds.dataloader("train", batsize=2)))
    model(testbatch)

    sys.exit()

    # ---- unreachable until the sys.exit() above is removed ----

    def tree_acc():
        # Fresh TreeAccuracy per decoder, always bound to this dataset's query vocab.
        return TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                D=ds.query_encoder.vocab),
                            orderless={"op:and", "SW:concat"})

    tfdecoder = SeqDecoder(TFTransition(model), [
        CELoss(ignore_index=0, mode="logprobs"),
        SeqAccuracies(),
        tree_acc(),
    ])
    freedecoder = BeamDecoder(model,
                              maxtime=50,
                              beamsize=beamsize,
                              eval_beam=[tree_acc()])

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    trainable_params = tfdecoder.named_parameters()
    exclude_params = {"model.model.inp_emb.emb.weight"
                      }  # don't train input embeddings if doing glove
    trainable_params = [
        v for k, v in trainable_params if k not in exclude_params
    ]

    # 4. define optim
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip()
    # fullmatch, not match: a prefix match would let "3x" through to int(),
    # which raises ValueError. Raw string avoids the invalid "\d" escape.
    overwrite_given = re.fullmatch(r"\d+", answer) is not None
    if answer.lower() in ("y", "yes") or overwrite_given:
        overwrite = int(answer) if overwrite_given else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=50,
                                   beamsize=beamsize,
                                   eval_beam=[tree_acc()])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        # Sanity check that the reloaded model reproduces the in-memory results.
        assert (testresults == _testresults)
        tt.tock("tested")