def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",          # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mcdropout=-1,
        version="v3"):
    settings = locals().copy()
    q.pp_dict(settings, indent=3)

    # wandb.init()
    # torch.backends.cudnn.enabled = False
    wandb.init(project="compood_gru_baseline_v3", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    if maxsize < 0:
        if dataset.startswith("cfq"):
            maxsize = 155
        elif dataset.startswith("scan"):
            maxsize = 50
        print(f"maxsize: {maxsize}")

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if "mcd" in dataset.split("/")[1]:
        print(f"Setting patience to -1 because MCD (was {patience})")
        patience = -1

    # if smalltrainvalid:
    if True:    # "mcd" in dataset.split("/")[1]:
        # split off 10% of the training data as an in-distribution test set
        realtrainds = []
        indtestds = []
        splits = [True for _ in range(int(round(len(trainds) * 0.1)))]
        splits = splits + [False for _ in range(len(trainds) - len(splits))]
        random.shuffle(splits)
        for i in range(len(trainds)):
            if splits[i] is True:
                indtestds.append(trainds[i])
            else:
                realtrainds.append(trainds[i])
        trainds = Dataset(realtrainds)
        indtestds = Dataset(indtestds)
        tt.msg("split off 10% of training data for in-distribution test set")
    # else:
    #     indtestds = Dataset([x for x in validds.examples])
    #     tt.msg("using validation set as in-distribution test set")

    tt.msg(f"TRAIN DATA: {len(trainds)}")
    tt.msg(f"DEV DATA: {len(validds)}")
    tt.msg(f"TEST DATA: in-distribution: {len(indtestds)}, OOD: {len(testds)}")

    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    validdl = DataLoader(validds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    testdl = DataLoader(testds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    indtestdl = DataLoader(indtestds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    cell = GRUDecoderCell(hdim,
                          vocab=fldic,
                          inpvocab=inpdic,
                          numlayers=numlayers,
                          dropout=dropout,
                          worddropout=worddropout,
                          mode=mode)
    decoder = SeqDecoderBaseline(cell,
                                 vocab=fldic,
                                 max_size=maxsize,
                                 smoothing=smoothing,
                                 mode=mode,
                                 mcdropout=mcdropout)
    # print(f"one layer of decoder: \n {cell.decoder.block[0]}")
    print(decoder)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = decoder(*batch)
        tt.tock()
        decoder.train(False)
        tt.tick("test")
        out = decoder(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "elemacc", "acc", reduction="mean")
    metricnames = ["treeacc", "decnll", "maxmaxnll", "entropy"]
    tmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    vmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    indxmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    oodxmetrics = make_array_of_metrics(*metricnames, reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(cell))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "elemacc", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        if evaltrain:
            for name, loss in zip(metricnames, tmetrics):
                d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, indxmetrics):
            d["indtest_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, oodxmetrics):
            d["oodtest_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(cell, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=cell, _norm=gradnorm)])

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)
    indtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=indxmetrics,
                           dataloader=indtestdl,
                           device=device)
    oodtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=oodxmetrics,
                           dataloader=testdl,
                           device=device)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    validfs = validfs + [indtestepoch, oodtestepoch]

    # results = evaluate(decoder, indtestds, testds, batsize=batsize, device=device)
    # print(json.dumps(results, indent=4))

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testres = oodtestepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None and patience >= 0:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

    tt.tick("rerunning validation")
    validres = validepoch()
    tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running ID test")
    testres = indtestepoch()
    print(f"ID test tree acc: {testres}")
    tt.tock()

    tt.tick("running OOD test")
    testres = oodtestepoch()
    print(f"OOD test tree acc: {testres}")
    tt.tock()

    results = evaluate(decoder, indtestds, testds, batsize=batsize, device=device)
    print(json.dumps(results, indent=4))

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update(
        {"final_indtest_tree_acc": indxmetrics[0].get_epoch_error()})
    settings.update(
        {"final_oodtest_tree_acc": oodxmetrics[0].get_epoch_error()})

    for k, v in results.items():
        for metric, ve in v.items():
            settings.update({f"{k}_{metric}": ve})

    wandb.config.update(settings)
    q.pp_dict(settings)

    return decoder, indtestds, testds
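
# Minimal usage sketch for the GRU baseline run() above. This is illustrative only:
# it assumes the function lives in its own module (the hybrid script below imports it
# as run_gru) and that GPU 0 is available; the helper name is made up for the example.
def _example_run_gru_baseline():
    # train on the SCAN length split, then recompute the OOD-detection summary
    decoder, indtestds, oodtestds = run(dataset="scan/length", epochs=16, gpu=0)
    results = evaluate(decoder, indtestds, oodtestds,
                       batsize=60, device=torch.device("cuda", 0))
    print(json.dumps(results, indent=4))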
def run(
        lr=0.0001,
        enclrmul=0.1,
        smoothing=0.1,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=1,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        sidedrop=0.0,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mode="normal",          # "normal", "vib", "aib"
        priorweight=1.,
):
    settings = locals().copy()
    q.pp_dict(settings, indent=3)

    # wandb.init()
    wandb.init(project="compgen_set", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    if dataset.startswith("cfq"):
        # filter: keep only "ns:..." and "m0", "m1", ... output tokens, map the rest to id 0
        idmap = torch.arange(fldic.number_of_ids(), device=device)
        for k, v in fldic.D.items():
            if not (k.startswith("ns:") or re.match(r"m\d+", k)):
                idmap[v] = 0
        trainds = trainds.map(lambda x: (x[0], idmap[x[1]])).cache()
        validds = validds.map(lambda x: (x[0], idmap[x[1]])).cache()
        testds = testds.map(lambda x: (x[0], idmap[x[1]])).cache()

    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    validdl = DataLoader(validds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    testdl = DataLoader(testds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    model = SetModel(hdim,
                     vocab=fldic,
                     inpvocab=inpdic,
                     numlayers=numlayers,
                     numheads=numheads,
                     dropout=dropout,
                     sidedrop=sidedrop,
                     bertname=bertname,
                     userelpos=userelpos,
                     useabspos=not userelpos,
                     mode=mode,
                     priorweight=priorweight)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")
    tmetrics = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")
    vmetrics = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")
    xmetrics = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[2],       # track accuracy (more_is_better)
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "priorkl", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], tmetrics[2:]):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], vmetrics[2:]):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(model, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=model, _norm=gradnorm)])

    print("using test data for validation")
    validdl = testdl
    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        # copy the remembered (best) weights back into the model instance that
        # the evaluation partials were built around
        model.load_state_dict(eyt.remembered.state_dict())

    tt.tick("rerunning validation")
    validres = validepoch()
    tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[2].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[2].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[2].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
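
# Standalone sketch of the two-group optimizer pattern used by get_parameters/get_optim
# in the runs above: pretrained-encoder parameters ("encoder_model.*") train at
# lr * enclrmul, everything else at the base lr. The toy module names are made up
# purely for this example.
def _example_paramgroups():
    import torch.nn as nn
    toy = nn.ModuleDict({
        "encoder_model": nn.Linear(4, 4),   # stands in for the pretrained encoder
        "decoder": nn.Linear(4, 4),
    })
    lr, enclrmul = 1e-4, 0.1
    enc, other = [], []
    for name, p in toy.named_parameters():
        (enc if "encoder_model." in name else other).append(p)
    optim = torch.optim.Adam([{"params": enc, "lr": lr * enclrmul},
                              {"params": other}], lr=lr)
    # first group uses the reduced encoder lr, second group uses the base lr
    assert optim.param_groups[0]["lr"] == lr * enclrmul
    assert optim.param_groups[1]["lr"] == lr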
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        tmbatsize=60,
        grubatsize=60,
        tmepochs=16,
        gruepochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",          # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        tmnumlayers=6,
        grunumlayers=2,
        numheads=12,
        tmdropout=0.1,
        grudropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mcdropout=-1,
        version="grutm_v1.1"):
    settings = locals().copy()
    q.pp_dict(settings, indent=3)

    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    # route "gru"-prefixed settings to the GRU run (prefix stripped), dropping "tm" ones
    grusettings = {(k[3:] if k.startswith("gru") else k): v
                   for k, v in settings.items() if not k.startswith("tm")}
    grudecoder, indtestds, oodtestds = run_gru(**grusettings)

    # route "tm"-prefixed settings to the transformer run (prefix stripped), dropping "gru" ones
    tmsettings = {(k[2:] if k.startswith("tm") else k): v
                  for k, v in settings.items() if not k.startswith("gru")}
    tmdecoder, _, _ = run_tm(**tmsettings)

    # create a model that uses tmdecoder to generate output and uses both decoders to measure OOD
    decoder = HybridSeqDecoder(tmdecoder, grudecoder, mcdropout=mcdropout)
    results = evaluate(decoder, indtestds, oodtestds, batsize=tmbatsize, device=device)
    print("Results of the hybrid OOD:")
    print(json.dumps(results, indent=3))

    wandb.init(project="compood_grutm_baseline", config=settings, reinit=True)
    for k, v in results.items():
        for metric, ve in v.items():
            settings.update({f"{k}_{metric}": ve})

    wandb.config.update(settings)
    q.pp_dict(settings)
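
# Sketch of how the hybrid run above splits its settings dict: keys prefixed with
# "gru"/"tm" are forwarded to the corresponding sub-run with the prefix stripped,
# while keys belonging to the other family are dropped. Toy values only.
def _example_split_settings():
    settings = {"lr": 1e-4, "grunumlayers": 2, "tmnumlayers": 6, "gruepochs": 16}
    grusettings = {(k[3:] if k.startswith("gru") else k): v
                   for k, v in settings.items() if not k.startswith("tm")}
    tmsettings = {(k[2:] if k.startswith("tm") else k): v
                  for k, v in settings.items() if not k.startswith("gru")}
    assert grusettings == {"lr": 1e-4, "numlayers": 2, "epochs": 16}
    assert tmsettings == {"lr": 1e-4, "numlayers": 6}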
def run(
        domain="restaurants",
        mode="baseline",        # "baseline", "ltr", "uniform", "binary", "any"
        probthreshold=0.,       # 0. --> parallel, >1. --> serial, 0. < x <= 1. --> semi-parallel
        lr=0.0001,
        enclrmul=0.1,
        batsize=50,
        epochs=1000,
        hdim=366,
        numlayers=6,
        numheads=6,
        dropout=0.1,
        noreorder=False,
        trainonvalid=False,
        seed=87646464,
        gpu=-1,
        patience=-1,
        gradacc=1,
        cosinelr=False,
        warmup=20,
        gradnorm=3,
        validinter=10,
        maxsteps=20,
        maxsize=75,
        testcode=False,
        numbered=False,
):
    settings = locals().copy()
    q.pp_dict(settings)

    wandb.init(project="seqinsert_overnight_v2", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        domain, trainonvalid=trainonvalid, noreorder=noreorder, numbered=numbered)
    tt.tock("loaded")

    tdl_seq = DataLoader(tds_seq, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim, flenc.vocab, numlayers, numheads, dropout,
                               baseline=mode == "baseline")

    if mode == "baseline":
        decoder = SeqDecoderBaseline(tagger, flenc.vocab,
                                     max_steps=maxsteps, max_size=maxsize)
    elif mode == "ltr":
        decoder = SeqInsertionDecoderLTR(tagger, flenc.vocab,
                                         max_steps=maxsteps, max_size=maxsize)
    elif mode == "uniform":
        decoder = SeqInsertionDecoderUniform(tagger, flenc.vocab,
                                             max_steps=maxsteps, max_size=maxsize,
                                             prob_threshold=probthreshold)
    elif mode == "binary":
        decoder = SeqInsertionDecoderBinary(tagger, flenc.vocab,
                                            max_steps=maxsteps, max_size=maxsize,
                                            prob_threshold=probthreshold)
    elif mode == "any":
        decoder = SeqInsertionDecoderAny(tagger, flenc.vocab,
                                         max_steps=maxsteps, max_size=maxsize,
                                         prob_threshold=probthreshold)

    # test run
    if testcode:
        batch = next(iter(tdl_seq))
        # out = tagger(batch[1])
        # out = decoder(*batch)
        decoder.train(False)
        out = decoder(*batch)

    tloss = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))

    def wandb_logger():
        d = {}
        for name, loss in zip(["CE"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=tdl_seq,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=tdl_seq,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    q.run_training(
        run_train_epoch=trainepoch,
        # run_valid_epoch=[trainevalepoch, validepoch],     # [validepoch],
        run_valid_epoch=[validepoch],
        max_epochs=epochs,
        check_stop=[lambda: eyt.check_stop()],
        validinter=validinter)
    tt.tock("done training")

    if eyt.remembered is not None and not trainonvalid:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

    tt.tick("rerunning validation")
    validres = validepoch()
    print(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_CE": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})
    settings.update({"final_train_steps_used": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_steps_used": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_steps_used": xmetrics[1].get_epoch_error()})

    # run different prob_thresholds:
    # thresholds = [0., 0.3, 0.5, 0.6, 0.75, 0.85, 0.9, 0.95, 1.]
    thresholds = [
        0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 1.
    ]
    for threshold in thresholds:
        tt.tick("running test for threshold " + str(threshold))
        decoder.prob_threshold = threshold
        testres = testepoch()
        print(f"Test tree acc for threshold {threshold}: {testres}")
        settings.update(
            {f"_thr{threshold}_acc": xmetrics[0].get_epoch_error()})
        settings.update(
            {f"_thr{threshold}_len": xmetrics[1].get_epoch_error()})
        tt.tock("done")

    wandb.config.update(settings)
    q.pp_dict(settings)
def run(
        lr=0.0001,
        enclrmul=0.1,
        smoothing=0.1,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=-1,
        validinter=1,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        adviters=3,         # adversary updates per main update
        advreset=10,        # reset adversary after this number of epochs
        advcontrib=1.,
        advmaskfrac=0.2,
        advnumsamples=3,
):
    settings = locals().copy()
    q.pp_dict(settings, indent=3)

    # wandb.init()
    wandb.init(project="compgen_set_aib", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl_main = DataLoader(trainds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    traindl_adv = DataLoader(trainds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    validdl = DataLoader(validds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    testdl = DataLoader(testds, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    if trainonvalidonly:
        traindl_main = DataLoader(validds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
        traindl_adv = DataLoader(validds, batch_size=batsize, shuffle=True, collate_fn=autocollate)
        validdl = testdl
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    encoder = TransformerEncoder(hdim,
                                 vocab=inpdic,
                                 numlayers=numlayers,
                                 numheads=numheads,
                                 dropout=dropout,
                                 weightmode=bertname,
                                 userelpos=userelpos,
                                 useabspos=not userelpos)
    advencoder = TransformerEncoder(hdim,
                                    vocab=inpdic,
                                    numlayers=numlayers,
                                    numheads=numheads,
                                    dropout=dropout,
                                    weightmode="vanilla",
                                    userelpos=userelpos,
                                    useabspos=not userelpos)
    setdecoder = SetDecoder(hdim, vocab=fldic, encoder=encoder)
    adv = AdvTagger(advencoder, maskfrac=advmaskfrac, vocab=inpdic)
    model = AdvModel(setdecoder, adv, numsamples=advnumsamples, advcontrib=advcontrib)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl_main))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss_main = make_array_of_metrics("loss", "mainloss", "advloss", "acc", reduction="mean")
    tloss_adv = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    vmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    xmetrics = make_array_of_metrics("loss", "acc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[1],       # track accuracy (more_is_better)
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "mainloss", "advloss", "acc"], tloss_main):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["advloss"], tloss_adv):
            d["train_adv_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], tmetrics[1:]):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], vmetrics[1:]):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim_main = get_optim(model.core, lr, enclrmul)
    optim_adv = get_optim(model.adv, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule_main = q.sched.LRSchedule(optim_main, lr_schedule)
    lr_schedule_adv = q.sched.LRSchedule(optim_adv, lr_schedule)

    trainbatch_main = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.core, _norm=gradnorm)
        ])
    trainbatch_adv = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.adv, _norm=gradnorm)
        ])

    print("using test data for validation")
    validdl = testdl

    trainepoch = partial(adv_train_epoch,
                         main_model=model.main_trainmodel,
                         adv_model=model.adv_trainmodel,
                         main_dataloader=traindl_main,
                         adv_dataloader=traindl_adv,
                         main_optim=optim_main,
                         adv_optim=optim_adv,
                         main_losses=tloss_main,
                         adv_losses=tloss_adv,
                         adviters=adviters,
                         device=device,
                         print_every_batch=True,
                         _main_train_batch=trainbatch_main,
                         _adv_train_batch=trainbatch_adv,
                         on_end=[
                             lambda: lr_schedule_main.step(),
                             lambda: lr_schedule_adv.step()
                         ])

    # eval epochs
    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl_main,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        # copy the remembered (best) weights back into the model instance that
        # the evaluation partials were built around
        model.load_state_dict(eyt.remembered.state_dict())

    tt.tick("rerunning validation")
    validres = validepoch()
    tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss_main[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[1].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
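
# Sketch of the learning-rate multiplier shape built with q.sched in the runs above
# (linear warmup, then cosine decay to zero over the remaining epochs), written as a
# plain function for illustration; the exact warmup shape of q.sched.Linear is an
# assumption, and the actual schedules use q.sched.Linear/Cosine/LRSchedule.
def _example_lr_multiplier(epoch, warmup=3, t_max=16):
    import math
    if epoch < warmup:
        return (epoch + 1) / warmup                         # assumed linear warmup to 1
    progress = (epoch - warmup) / max(1, t_max - warmup)
    return 0.5 * (1. + math.cos(math.pi * progress))        # cosine decay from 1 to 0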