Example #1
def run(traindomains="ALL",
        domain="recipes",
        mincoverage=2,
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        ftlr=0.0001,
        cosinelr=False,
        warmup=0.,
        batsize=30,
        pretrainbatsize=100,
        epochs=100,
        resetmode="none",
        pretrainepochs=100,
        minpretrainepochs=10,
        dropout=0.1,
        decoderdropout=0.5,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="bert-base-uncased",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=30,
        localtest=False,
        printtest=False,
        fullsimplify=True,
        nopretrain=False,
        onlyabstract=False,
        pretrainsetting="all",  # "all", "all+lex", "lex"
        finetunesetting="min",      # "lex", "all", "min"
        ):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    numresets, resetafter, resetevery = 0, 0, 0
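    # Translate the symbolic reset mode into a concrete schedule: the first
    # parameter reset happens after `resetafter` epochs, then once every
    # `resetevery` epochs, for at most `numresets` resets.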
    if resetmode == "none":
        pass
    elif resetmode == "once":
        resetafter = 15
        resetevery = 5
        numresets = 1
    elif resetmode == "more":
        resetafter = 15
        resetevery = 5
        numresets = 3
    elif resetmode == "forever":
        resetafter = 15
        resetevery = 5
        numresets = 1000

    print(f'Resetting: "{resetmode}": {numresets} times, first after {resetafter} epochs, then every {resetevery} epochs')

    # wandb.init(project=f"overnight_joint_pretrain_fewshot_{pretrainsetting}-{finetunesetting}-{domain}",
    #            reinit=True, config=settings)
    if traindomains == "ALL":
        alldomains = {"recipes", "restaurants", "blocks", "calendar", "housing", "publications"}
        traindomains = alldomains - {domain, }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc, generaltokenmask = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, onlyabstract=onlyabstract,
                pretrainsetting=pretrainsetting, finetunesetting=finetunesetting)
    tt.msg(f"{len(tds)/(len(tds) + len(vds)):.2f}/{len(vds)/(len(tds) + len(vds)):.2f} ({len(tds)}/{len(vds)}) train/valid")
    tt.msg(f"{len(ftds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(fvds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(xds)/(len(ftds) + len(fvds) + len(xds)):.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test")
    tdl = DataLoader(tds, batch_size=pretrainbatsize, shuffle=True, collate_fn=partial(autocollate, pad_value=0))
    ftdl = DataLoader(ftds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=0))
    vdl = DataLoader(vds, batch_size=pretrainbatsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    fvdl = DataLoader(fvds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 decoderdropout=decoderdropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 numbeam=numbeam,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab),
                                 generaltokenmask=generaltokenmask,
                                 resetmode=resetmode
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
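    # Encoder (BERT) parameters are trained with a scaled-down learning rate
    # (lr * enclrmul); all other parameters use the base learning rate.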
    paramgroups = [{"params": encparams, "lr": lr * enclrmul},
                   {"params": otherparams}]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    if resetmode != "none":
        minpretrainepochs = resetafter + (numresets - 1) * resetevery
    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=minpretrainepochs,
                         more_is_better=True, remember_f=lambda: deepcopy(trainm.model))

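    # Periodically re-initializes the model during pretraining according to the
    # schedule above; resetothers=[eyt] resets the early stopper along with it.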
    reinit = Reinitializer(trainm.model, resetafter=resetafter, resetevery=resetevery, numresets=numresets, resetothers=[eyt])

    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    print(f"Total number of updates: {t_max}.")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device,
                         on_end=[lambda: eyt.on_epoch_end(), lambda: reinit()])#, lambda: wandb_logger()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=pretrainepochs,
                       check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": ftlr * enclrmul},
                   {"params": otherparams}]

    ftoptim = torch.optim.Adam(paramgroups, lr=ftlr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    # def wandb_logger_ft():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], ftmetrics):
    #         d["ft_train_" + name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], ftvmetrics):
    #         d["ft_valid_" + name] = loss.get_epoch_error()
    #     wandb.log(d)

    t_max = epochs
    print(f"Total number of updates: {t_max}.")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ftoptim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=ftdl, optim=ftoptim, losses=ftmetrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=fvdl, losses=ftvmetrics, device=device,
                         on_end=[])#, lambda: wandb_logger_ft()])

    tt.tick("finetuning")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done finetuning")

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm, dataloader=fvdl, losses=ftvmetrics, device=device)
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=ftxmetrics, device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(input_ids, attention_mask=input_ids != predm.config.pad_token_id,
                                      max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(torch.device("cpu"))) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # wandb.config.update(settings)
    # print(settings)
    return settings
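# Illustrative invocation sketch (not part of the original example; argument
# values are placeholders): run() returns the settings dict augmented with the
# final fine-tuning train/valid/test metrics, so it could be used e.g. as:
#
#     results = run(domain="publications", gpu=0, resetmode="once")
#     print({k: v for k, v in results.items() if k.startswith("test_")})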
Example #2
def run(lr=0.001,
        batsize=20,
        epochs=60,
        embdim=128,
        encdim=256,
        numlayers=1,
        beamsize=1,
        dropout=.25,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.,
        cosine_restarts=1.,
        seed=456789,
        p_step=.2,
        p_min=.3,
        ):
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = GeoDataset(sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)
    model = BasicGenModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder, feedatt=True,
                          p_step=p_step, p_min=p_min)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)
    losses = [CELoss(ignore_index=0, mode="logprobs", smoothing=smoothing)]

    tfdecoder = SeqDecoder(model, tf_ratio=1.,
                           eval=losses + [SeqAccuracies(), TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                          orderless={"and", "or"})])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    if beamsize == 1:
        freedecoder = SeqDecoder(model, maxtime=100, tf_ratio=0.,
                                 eval=[SeqAccuracies(),
                                       TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                    orderless={"and", "or"})])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    else:

        freedecoder = BeamDecoder(model, maxtime=100, beamsize=beamsize,
                                  eval=[SeqAccuracies()],
                                  eval_beam=[TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                    orderless={"and", "or"})])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc", "tree_acc_at_last")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=tfdecoder, dataloader=ds.dataloader("train", batsize), optim=optim, losses=losses,
                         _train_batch=trainbatch, device=device, on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch, model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input("Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> ")
    if tosave.lower() in ("y", "yes") or re.match(r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model, localargs, filepath=__file__, overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model, maxtime=50, beamsize=beamsize,
                                  eval_beam=[TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                          orderless={"op:and", "SW:concat"})])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder, dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses, device=device)
        print(_testresults)
        assert(testresults == _testresults)
        tt.tock("tested")
Example #3
def run(
    lr=0.001,
    domain="restaurants",
    minlr=0.000001,
    enclrmul=0.1,
    hdim=768,
    numlayers=8,
    numheads=12,
    dropout=0.1,
    encdropout=0.1,
    wreg=0.,
    batsize=10,
    epochs=100,
    warmup=0,
    cosinelr=False,
    sustain=0,
    cooldown=0,
    unfreezebertafter=5,
    gradacc=1,
    gradnorm=3,
    patience=15,
    validinter=1,
    seed=87646464,
    gpu=-1,
    mode="leastentropy",
    trainonvalid=False,
    entropylimit=0.,
    noreorder=False,
    # datamode="single",
    # decodemode="single",    # "full", "ltr" (left to right), "single", "entropy-single"
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        domain, trainonvalid=trainonvalid, noreorder=noreorder)
    tt.tock("loaded")

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim,
                               flenc.vocab,
                               numlayers,
                               numheads,
                               dropout,
                               encdropout=encdropout)
    decodermodel = TreeInsertionDecoder(tagger,
                                        seqenc=flenc,
                                        maxsteps=70,
                                        max_tree_size=30,
                                        mode=mode)
    decodermodel.entropylimit = entropylimit

    # batch = next(iter(tdl))
    # out = tagmodel(*batch)

    tmetrics = make_array_of_metrics("loss",
                                     "elemrecall",
                                     "allrecall",
                                     "lowestentropyrecall",
                                     reduction="mean")
    tvmetrics = make_array_of_metrics("treesizes",
                                      "seqlens",
                                      "numsteps",
                                      "treeacc",
                                      reduction="mean")
    vmetrics = make_array_of_metrics("treesizes",
                                     "seqlens",
                                     "numsteps",
                                     "treeacc",
                                     reduction="mean")
    xmetrics = make_array_of_metrics("treesizes",
                                     "seqlens",
                                     "numsteps",
                                     "treeacc",
                                     reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul, wreg)
    print(f"Total number of updates: {t_max}.")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Constant(
            1., steps=sustain) >> q.sched.Cosine(
                low=minlr / lr,
                high=1.,
                steps=t_max - warmup - sustain - cooldown) >> minlr / lr
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
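    # The encoder's learning-rate factor stays at 0 for the first
    # `unfreezebertafter` epochs (frozen BERT), then follows the main schedule;
    # the two schedules below map onto the two optimizer param groups.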
    bert_lr_schedule = q.sched.Constant(0., steps=unfreezebertafter) >> 1.
    bert_lr_schedule = bert_lr_schedule * lr_schedule
    lr_schedule = LambdaLR(optim, (bert_lr_schedule, lr_schedule))

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=decodermodel,
                         dataloader=tdl_seq,
                         optim=optim,
                         losses=tmetrics,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainvalidepoch = partial(
        q.test_epoch,
        model=decodermodel,
        losses=tvmetrics,
        dataloader=tdl_seq,
        device=device,
    )

    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    # validepoch()        # TODO: remove this after debugging

    tt.tick("training")
    q.run_training(
        run_train_epoch=trainepoch,
        run_valid_epoch=validepoch,  #[trainvalidepoch, validepoch],
        max_epochs=epochs,
        check_stop=[lambda: eyt.check_stop()],
        validinter=validinter)
    tt.tock("done training")

    if eyt.remembered is not None:
        decodermodel.tagger = eyt.remembered
        tt.msg("reloaded best")

    tt.tick("trying different entropy limits")
    vmetrics2 = make_array_of_metrics("treesizes",
                                      "seqlens",
                                      "numsteps",
                                      "treeacc",
                                      reduction="mean")
    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vmetrics2,
                         dataloader=vdl_seq,
                         device=device)
    entropylimits = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1., 10][::-1]
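    # Sweep the entropy limit from largest to smallest and report validation
    # metrics for each setting on the best (reloaded) model.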
    for _entropylimit in entropylimits:
        tt.msg(f"entropy limit {_entropylimit}")
        decodermodel.entropylimit = _entropylimit
        tt.msg(validepoch())
    tt.tock("done trying entropy limits")

    tt.tick("testing on test")
    testepoch = partial(
        q.test_epoch,
        model=decodermodel,
        losses=xmetrics,
        dataloader=xdl_seq,
        device=device,
    )
    print(testepoch())
    tt.tock("tested on test")
Example #4
def run(
    lr=0.0001,
    enclrmul=0.1,
    smoothing=0.1,
    gradnorm=3,
    batsize=60,
    epochs=16,
    patience=10,
    validinter=1,
    validfrac=0.1,
    warmup=3,
    cosinelr=False,
    dataset="scan/length",
    maxsize=50,
    seed=42,
    hdim=768,
    numlayers=6,
    numheads=12,
    dropout=0.1,
    sidedrop=0.0,
    bertname="bert-base-uncased",
    testcode=False,
    userelpos=False,
    gpu=-1,
    evaltrain=False,
    trainonvalid=False,
    trainonvalidonly=False,
    recomputedata=False,
    mode="normal",  # "normal", "vib", "aib"
    priorweight=1.,
):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project=f"compgen_set", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    if dataset.startswith("cfq"):
        # filter the output vocabulary: keep schema tokens ("ns:...") and
        # variable tokens ("m0", "m1", ...); map everything else to id 0
        idmap = torch.arange(fldic.number_of_ids(), device=device)
        for k, v in fldic.D.items():
            if not (k.startswith("ns:") or re.match(r"m\d+", k)):
                idmap[v] = 0
        trainds = trainds.map(lambda x: (x[0], idmap[x[1]])).cache()
        validds = validds.map(lambda x: (x[0], idmap[x[1]])).cache()
        testds = testds.map(lambda x: (x[0], idmap[x[1]])).cache()

    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    model = SetModel(hdim,
                     vocab=fldic,
                     inpvocab=inpdic,
                     numlayers=numlayers,
                     numheads=numheads,
                     dropout=dropout,
                     sidedrop=sidedrop,
                     bertname=bertname,
                     userelpos=userelpos,
                     useabspos=not userelpos,
                     mode=mode,
                     priorweight=priorweight)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")
    tmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")
    vmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")
    xmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "priorkl", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(model, lr, enclrmul)
    print(f"Total number of updates: {t_max}.")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=model, _norm=gradnorm)])

    print("using test data for validation")
    validdl = testdl

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        model = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[2].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[2].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[2].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
Example #5
def run(
    lr=0.001,
    batsize=20,
    epochs=60,
    embdim=128,
    encdim=256,
    numlayers=1,
    beamsize=5,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
    numcvfolds=6,
    testfold=-1,  # if non-default, must be within number of splits, the chosen value is used for validation
    reorder_random=False,
):
    localargs = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
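    # testfold == -1 disables cross-validation; otherwise the dataset is split
    # into `numcvfolds` folds and the selected fold is held out for validation.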
    cvfolds = None if testfold == -1 else numcvfolds
    testfold = None if testfold == -1 else testfold
    ds = GeoDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq,
        cvfolds=cvfolds,
        testfold=testfold,
        reorder_random=reorder_random)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless={"and"})
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[
                                 SeqAccuracies(),
                                 TreeAccuracy(tensor2tree=partial(
                                     tensor2tree, D=ds.query_encoder.vocab),
                                              orderless={"and"})
                             ])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[
                                  TreeAccuracy(tensor2tree=partial(
                                      tensor2tree, D=ds.query_encoder.vocab),
                                               orderless={"and"})
                              ])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

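    # With cross-validation enabled, validate on the held-out fold ("valid");
    # otherwise validate directly on the test split.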
    train_on = "train"
    valid_on = "test" if testfold is None else "valid"
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    if testfold is not None:
        return vlosses[1].get_epoch_error()

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> "
    )
    # if True:
    #     overwrite = None
    if tosave.lower() in ("y", "yes") or re.match(r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[
                                       TreeAccuracy(tensor2tree=partial(
                                           tensor2tree,
                                           D=ds.query_encoder.vocab),
                                                    orderless={"and"})
                                   ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
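# Illustrative cross-validation sketch (not part of the original example): with
# `testfold` set, run() returns the validation tree accuracy on the held-out
# fold and skips the final test/saving steps, so a small sweep could look like:
#
#     fold_accs = [run(testfold=i, numcvfolds=6) for i in range(6)]
#     print(sum(fold_accs) / len(fold_accs))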
Example #6
def run(
    lr=0.001,
    batsize=50,
    epochs=50,
    embdim=100,
    encdim=100,
    numlayers=1,
    beamsize=1,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=3,
    gradnorm=3.,
    cosine_restarts=1.,
    beta=0.001,
    vib_init=True,
    vib_enc=True,
):
    localargs = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = LCQuaDnoENTDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel_VIB(embdim=embdim,
                              hdim=encdim,
                              dropout=dropout,
                              numlayers=numlayers,
                              sentence_encoder=ds.sentence_encoder,
                              query_encoder=ds.query_encoder,
                              feedatt=True,
                              vib_init=vib_init,
                              vib_enc=vib_enc)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)
    losses = [CELoss(ignore_index=0, mode="logprobs")]
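    # Variational-information-bottleneck regularization: penalize the model's
    # initial state and/or encoder state, each term weighted by `beta`.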
    if vib_init:
        losses.append(
            StatePenalty(lambda state: sum(state.mstate.vib.init),
                         weight=beta))
    if vib_enc:
        losses.append(StatePenalty("mstate.vib.enc", weight=beta))

    tfdecoder = SeqDecoder(
        model,
        tf_ratio=1.,
        eval=losses + [
            SeqAccuracies(),
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"select", "count", "ask"})
        ])
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    if beamsize == 1:
        freedecoder = SeqDecoder(
            model,
            maxtime=40,
            tf_ratio=0.,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"select", "count", "ask"})
            ])
    else:

        freedecoder = BeamDecoder(
            model,
            maxtime=30,
            beamsize=beamsize,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"select", "count", "ask"})
            ])

    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    # if beamsize >= 3:
    #     vlosses = make_loss_array("seq_acc", "tree_acc", "tree_acc_at3", "tree_acc_at_last")
    # else:
    #     vlosses = make_loss_array("seq_acc", "tree_acc", "tree_acc_at_last")

    # trainable_params = tfdecoder.named_parameters()
    # exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")   # don't train input embeddings if doing glove
    # trainable_params = [v for k, v in trainable_params if k not in exclude_params]

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("valid", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> "
    )
    if tosave.lower() in ("y", "yes") or re.match(r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        assert (testresults == _testresults)
        tt.tock("tested")
Example #7
def run(lr=0.001,
        enclrmul=0.1,
        hdim=768,
        numlayers=8,
        numheads=12,
        dropout=0.1,
        wreg=0.,
        batsize=10,
        epochs=100,
        warmup=0,
        sustain=0,
        cosinelr=False,
        gradacc=1,
        gradnorm=100,
        patience=5,
        validinter=3,
        seed=87646464,
        gpu=-1,
        datamode="full",    # "full", "ltr" (left to right)
        ):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds, vds, xds, tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds("restaurants", mode=datamode)
    tt.tock("loaded")

    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=collate_fn)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=collate_fn)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=collate_fn)

    tdl_seq = DataLoader(tds_seq, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim, flenc.vocab, numlayers, numheads, dropout)
    tagmodel = TreeInsertionTaggerModel(tagger)
    decodermodel = TreeInsertionDecoder(tagger, seqenc=flenc, maxsteps=50, max_tree_size=30,
                                        mode=datamode)
    decodermodel = TreeInsertionDecoderTrainModel(decodermodel)

    # batch = next(iter(tdl))
    # out = tagmodel(*batch)

    tmetrics = make_array_of_metrics("loss", "elemrecall", "seqrecall", reduction="mean")
    tseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{"params": bertparams, "lr": _lr * _enclrmul},
                       {"params": otherparams}]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vmetrics[0], patience=patience, min_epochs=30, more_is_better=True, remember_f=lambda: deepcopy(tagger))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul, wreg)
    print(f"Total number of updates: {t_max}.")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, gradient_accumulation_steps=gradacc,
                                        on_before_optim_step=[lambda : clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch, model=tagmodel,
                                        dataloader=tdl,
                                        optim=optim,
                                        losses=tmetrics,
                                        device=device,
                                        _train_batch=trainbatch,
                                        on_end=[lambda: lr_schedule.step()])

    trainseqepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=tseqmetrics,
                         dataloader=tdl_seq,
                         device=device)

    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    validepoch()        # TODO: remove this after debugging

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=[trainseqepoch, validepoch],
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")
Example #8
def run_rerank(
    lr=0.001,
    batsize=20,
    epochs=1,
    embdim=301,  # not used
    encdim=200,
    numlayers=1,
    beamsize=5,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    cosine_restarts=1.,
    domain="restaurants",
    gensavedp="overnight_basic/run{}",
    genrunid=1,
):
    localargs = locals().copy()
    print(locals())
    gensavedrunp = gensavedp.format(genrunid)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = q.load_dataset(gensavedrunp)
    # ds = OvernightDataset(domain=domain, sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    genmodel, genargs = q.load_run(gensavedrunp)
    # BasicGenModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
    #                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder, feedatt=True)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    inpenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    outenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    scoremodel = SimpleScoreModel(genmodel.inp_emb, genmodel.out_emb,
                                  LSTMEncoderWrapper(inpenc),
                                  LSTMEncoderWrapper(outenc), DotSimilarity())

    model = BeamReranker(genmodel, scoremodel, beamsize=beamsize, maxtime=50)

    # todo: run over whole dataset to populate beam cache
    testbatch = next(iter(ds.dataloader("train", batsize=2)))
    model(testbatch)

    sys.exit()  # NOTE: everything below this point in run_rerank is currently unreachable

    tfdecoder = SeqDecoder(TFTransition(model), [
        CELoss(ignore_index=0, mode="logprobs"),
        SeqAccuracies(),
        TreeAccuracy(tensor2tree=partial(tensor2tree,
                                         D=ds.query_encoder.vocab),
                     orderless={"op:and", "SW:concat"})
    ])
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    freedecoder = BeamDecoder(
        model,
        maxtime=50,
        beamsize=beamsize,
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    trainable_params = tfdecoder.named_parameters()
    exclude_params = {"model.model.inp_emb.emb.weight"
                      }  # don't train input embeddings if doing glove
    trainable_params = [
        v for k, v in trainable_params if k not in exclude_params
    ]

    # 4. define optim
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    if tosave.lower() == "y" or tosave.lower() == "yes" or re.match(
            r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        assert (testresults == _testresults)
        tt.tock("tested")
def run(
    lr=0.001,
    batsize=20,
    epochs=100,
    embdim=100,
    encdim=164,
    numlayers=4,
    numheads=4,
    dropout=.0,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3000.,
    cosine_restarts=1.,
):
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    stemmer = PorterStemmer()
    tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    ds = GeoQueryDataset(sentence_encoder=SequenceEncoder(tokenizer=tokenizer),
                         min_freq=minfreq)

    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = create_model(hdim=encdim,
                         dropout=dropout,
                         numlayers=numlayers,
                         numheads=numheads,
                         sentence_encoder=ds.sentence_encoder,
                         query_encoder=ds.query_encoder)

    model._metrics = [CELoss(ignore_index=0, mode="logprobs"), SeqAccuracies()]

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc")
    vlosses = make_array_of_metrics("loss", "seq_acc")

    # 4. define optim
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max} ({epochs} * {len(train_dl)})")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        model.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=model,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=tfdecoder, dataloader=test_dl, losses=vlosses, device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")
Example #10
0
def run(
        sourcelang="en",
        supportlang="en",
        testlang="en",
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        cosinelr=False,
        warmup=0.,
        batsize=20,
        epochs=100,
        dropout=0.1,
        dropoutdec=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="xlm-roberta-base",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        trainonvalid=False,
        statesimweight=0.,
        probsimweight=0.,
        projmode="simple",  # "simple" or "twolayer"
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    # wandb.init(project=f"overnight_pretrain_bert-{domain}",
    #            reinit=True, config=settings)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")

    nltok_name = encoder
    tds, vds, xds, nltok, flenc = load_multilingual_geoquery(
        sourcelang,
        supportlang,
        testlang,
        nltok_name=nltok_name,
        trainonvalid=trainonvalid)
    tt.msg(
        f"{len(tds)/(len(tds) + len(vds) + len(xds)):.2f}/{len(vds)/(len(tds) + len(vds) + len(xds)):.2f}/{len(xds)/(len(tds) + len(vds) + len(xds)):.2f} ({len(tds)}/{len(vds)}/{len(xds)}) train/valid/test"
    )
    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(
        encoder_name=encoder,
        dec_vocabsize=flenc.vocab.number_of_ids(),
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        dropoutdec=dropoutdec,
        smoothing=smoothing,
        maxlen=maxlen,
        numbeam=numbeam,
        tensor2tree=partial(_tensor2tree, D=flenc.vocab),
        statesimweight=statesimweight,
        probsimweight=probsimweight,
        projmode=projmode,
    )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params
                            if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [
        v for k, v in trainable_params
        if k.startswith("model.model.encoder.model")
    ]
    otherparams = [
        v for k, v in trainable_params
        if not k.startswith("model.model.encoder.model")
    ]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{
        "params": encparams,
        "lr": lr * enclrmul
    }, {
        "params": otherparams
    }]
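    # NOTE: two parameter groups give the pretrained encoder (xlm-roberta here) a scaled-down
    # learning rate of lr * enclrmul, while all remaining (randomly initialised) parameters fall
    # back to the optimizer-level lr passed to Adam below.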

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda:
                         (deepcopy(trainm.model), deepcopy(trainm.model2)))
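    # NOTE: the early stopper tracks validation tree accuracy (the last entry of vmetrics),
    # snapshots both submodels whenever it improves, and is polled via eyt.check_stop() inside
    # q.run_training; the remembered snapshots are restored after training further below.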
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["_train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["_valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()
                                 ])  #, on_end=[lambda: wandb_logger()])

    # validepoch()        # TODO comment out after debugging
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.remembered is not None:
        trainm.model = eyt.remembered[0]
        trainm.model2 = eyt.remembered[1]
        testm.model = eyt.remembered[0]
        testm.model2 = eyt.remembered[1]
    tt.msg("reloaded best")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model2
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # wandb.config.update(settings)
    # print(settings)
    return settings
Example #11
0
def run(
    lr=0.001,
    enclrmul=0.1,
    hdim=768,
    numlayers=8,
    numheads=12,
    dropout=0.1,
    wreg=0.,
    batsize=10,
    epochs=100,
    warmup=0,
    sustain=0,
    cosinelr=False,
    gradacc=1,
    gradnorm=100,
    patience=5,
    validinter=3,
    seed=87646464,
    gpu=-1,
    datamode="single",
    decodemode="single",  # "full", "ltr" (left to right), "single", "entropy-single"
    trainonvalid=False,
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds, vds, xds, tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        "restaurants", mode=datamode, trainonvalid=trainonvalid)
    tt.tock("loaded")

    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=collate_fn)
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=collate_fn)
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=collate_fn)

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim, flenc.vocab, numlayers, numheads, dropout)
    tagmodel = TreeInsertionTaggerModel(tagger)
    decodermodel = TreeInsertionDecoder(tagger,
                                        seqenc=flenc,
                                        maxsteps=50,
                                        max_tree_size=30,
                                        mode=decodemode)
    decodermodel = TreeInsertionDecoderTrainModel(decodermodel)
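    # NOTE: the same TransformerTagger backs two wrappers: tagmodel is trained on the per-slot
    # tagging objective (tdl / tmetrics below), while decodermodel runs iterative tree-insertion
    # decoding and is only used through q.test_epoch to measure tree accuracy on the *_seq loaders.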

    # batch = next(iter(tdl))
    # out = tagmodel(*batch)

    tmetrics = make_array_of_metrics("loss",
                                     "elemrecall",
                                     "allrecall",
                                     "entropyrecall",
                                     reduction="mean")
    vmetrics = make_array_of_metrics("loss",
                                     "elemrecall",
                                     "allrecall",
                                     "entropyrecall",
                                     reduction="mean")
    tseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    vseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vseqmetrics[-1],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul, wreg)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])
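    # NOTE: gradients are accumulated over `gradacc` batches before every optimizer step, and the
    # tagger's gradient norm is clipped to `gradnorm` right before that step is taken.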

    trainepoch = partial(q.train_epoch,
                         model=tagmodel,
                         dataloader=tdl,
                         optim=optim,
                         losses=tmetrics,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainseqepoch = partial(q.test_epoch,
                            model=decodermodel,
                            losses=tseqmetrics,
                            dataloader=tdl_seq,
                            device=device)

    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vseqmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    # validepoch()        # TODO: remove this after debugging

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=[trainseqepoch, validepoch],
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.msg("reloading best")
    if eyt.remembered is not None:
        decodermodel.model.tagger = eyt.remembered
        tagmodel.tagger = eyt.remembered

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decodermodel,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    print(testepoch())
    tt.tock()

    # inspect predictions
    validepoch = partial(q.test_epoch,
                         model=tagmodel,
                         losses=vmetrics,
                         dataloader=vdl,
                         device=device)
    print(validepoch())
    inps, outs = q.eval_loop(tagmodel, vdl, device=device)

    # print(outs)

    doexit = False
    for i in range(len(inps[0])):
        for j in range(len(inps[0][i])):
            ui = input("next? (ENTER for next/anything else to exit)>>>")
            if ui != "":
                doexit = True
                break
            question = " ".join(nltok.convert_ids_to_tokens(inps[0][i][j]))
            out_toks = flenc.vocab.tostr(
                inps[1][i][j].detach().cpu().numpy()).split(" ")

            iscorrect = True

            lines = []
            for k, out_tok in enumerate(out_toks):
                gold_toks_for_k = inps[3][i][j][k].detach().cpu().nonzero()[:,
                                                                            0]
                if len(gold_toks_for_k) > 0:
                    gold_toks_for_k = flenc.vocab.tostr(gold_toks_for_k).split(
                        " ")
                else:
                    gold_toks_for_k = [""]

                isopen = inps[2][i][j][k]
                isopen = isopen.detach().cpu().item()

                pred_tok = outs[1][i][j][k].max(-1)[1].detach().cpu().item()
                pred_tok = flenc.vocab(pred_tok)

                pred_tok_correct = pred_tok in gold_toks_for_k or not isopen
                if not pred_tok_correct:
                    iscorrect = False

                entropy = torch.softmax(outs[1][i][j][k], -1).clamp_min(1e-6)
                entropy = -(entropy * torch.log(entropy)).sum().item()
                lines.append(
                    f"{out_tok:25} [{isopen:1}] >> {f'{pred_tok} ({entropy:.3f})':35} {'!!' if not pred_tok_correct else '  '} [{','.join(gold_toks_for_k) if isopen else ''}]"
                )

            print(f"{question} {'!!WRONG!!' if not iscorrect else ''}")
            for line in lines:
                print(line)

        if doexit:
            break
Example #12
0
def run(
    traindomains="ALL",
    domain="recipes",
    mincoverage=2,
    lr=0.001,
    advlr=-1,
    enclrmul=0.1,
    numbeam=1,
    ftlr=0.0001,
    cosinelr=False,
    warmup=0.,
    batsize=30,
    epochs=100,
    pretrainepochs=100,
    dropout=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="bert-base-uncased",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=30,
    localtest=False,
    printtest=False,
    fullsimplify=True,
    domainstart=False,
    useall=False,
    nopretrain=False,
    entropycontrib=1.,
    advsteps=5,
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    if advlr < 0:
        advlr = lr
    if traindomains == "ALL":
        alldomains = {
            "recipes", "restaurants", "blocks", "calendar", "housing",
            "publications"
        }
        traindomains = alldomains - {
            domain,
        }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc, absflenc = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, add_domain_start=domainstart, useall=useall)
    advds = Dataset(tds.examples)
    tt.msg(
        f"{len(tds)/(len(tds) + len(vds)):.2f}/{len(vds)/(len(tds) + len(vds)):.2f} ({len(tds)}/{len(vds)}) train/valid"
    )
    tt.msg(
        f"{len(ftds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(fvds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(xds)/(len(ftds) + len(fvds) + len(xds)):.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test"
    )
    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=partial(autocollate, pad_value=0))
    advdl = DataLoader(advds,
                       batch_size=batsize,
                       shuffle=True,
                       collate_fn=partial(autocollate, pad_value=0))
    ftdl = DataLoader(ftds,
                      batch_size=batsize,
                      shuffle=True,
                      collate_fn=partial(autocollate, pad_value=0))
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(autocollate, pad_value=0))
    fvdl = DataLoader(fvds,
                      batch_size=batsize,
                      shuffle=False,
                      collate_fn=partial(autocollate, pad_value=0))
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(autocollate, pad_value=0))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, advtrainm, testm = create_model(
        encoder_name=encoder,
        fl_vocab=flenc.vocab,
        abs_fl_vocab=absflenc.vocab,
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        smoothing=smoothing,
        maxlen=maxlen,
        numbeam=numbeam,
        abs_id=absflenc.vocab["@ABS@"],
        entropycontrib=entropycontrib,
    )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "ce", "elem_acc", "tree_acc")
    advmetrics = make_array_of_metrics("adv_loss", "adv_elem_acc",
                                       "adv_tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params
                            if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [
        v for k, v in trainable_params if k.startswith("model.model.encoder")
    ]
    otherparams = [
        v for k, v in trainable_params
        if not k.startswith("model.model.encoder")
    ]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{
        "params": encparams,
        "lr": lr * enclrmul
    }, {
        "params": otherparams
    }]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    advoptim = torch.optim.Adam(advtrainm.parameters(),
                                lr=advlr,
                                weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)
    advclipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        advtrainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
        advlr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
        advlr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)
    advlr_schedule = q.sched.LRSchedule(advoptim, advlr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    advtrainbatch = partial(q.train_batch,
                            on_before_optim_step=[advclipgradnorm])
    trainepoch = partial(
        adv_train_epoch,
        model=trainm,
        dataloader=tdl,
        optim=optim,
        losses=metrics,
        advmodel=advtrainm,
        advdataloader=advdl,
        advoptim=advoptim,
        advlosses=advmetrics,
        _train_batch=trainbatch,
        _adv_train_batch=advtrainbatch,
        device=device,
        on_end=[lambda: lr_schedule.step(), lambda: advlr_schedule.step()],
        advsteps=advsteps)
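    # NOTE: adv_train_epoch is defined elsewhere in this script; judging from its arguments it
    # presumably interleaves `advsteps` updates of the adversarial model (advtrainm on advdl,
    # driven by advoptim) with the regular updates of trainm, and both LR schedules are stepped
    # once per epoch through the on_end hooks.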
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch,
                       run_valid_epoch=validepoch,
                       max_epochs=pretrainepochs,
                       check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "ce", "elem_acc", "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params
                            if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [
        v for k, v in trainable_params if k.startswith("model.model.encoder")
    ]
    otherparams = [
        v for k, v in trainable_params
        if not k.startswith("model.model.encoder")
    ]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{
        "params": encparams,
        "lr": ftlr * enclrmul
    }, {
        "params": otherparams
    }]

    ftoptim = torch.optim.Adam(paramgroups, lr=ftlr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(ftvmetrics[1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ftoptim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=ftdl,
                         optim=ftoptim,
                         losses=ftmetrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=fvdl,
                         losses=ftvmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=fvdl,
                                losses=ftvmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=ftxmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Example #13
0
def run(domain="restaurants",
        lr=0.001,
        ptlr=0.0001,
        enclrmul=0.1,
        cosinelr=False,
        ptcosinelr=False,
        warmup=0.,
        ptwarmup=0.,
        batsize=20,
        ptbatsize=50,
        epochs=100,
        ptepochs=100,
        dropout=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        dataseed=12345678,
        datatemp=0.33,
        ptN=3000,
        tokenmaskp=0.,
        spanmaskp=0.,
        spanmasklamda=2.2,
        treemaskp=0.,
        encoder="bart-large",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        ):
    settings = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_ds(domain=domain, nl_mode=encoder)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    tt.tock("data loaded")

    tt.tick("creating grammar dataset generator")
    pcfg = build_grammar(tds, vds)
    ptds = PCFGDataset(pcfg, N=ptN, seed=seed, temperature=datatemp, maxlen=100)
    tt.tock("created dataset generator")

    tt.tick("creating model")
    trainm, testm, pretrainm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab)
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        print("generated dataset")
        print(ptds[0])
        print(ptds[0])
        allexamples = []
        for i in tqdm(range(len(ptds))):
            allexamples.append(ptds[i])
        uniqueexamples = set([str(x) for x in allexamples])
        print(f"{100*len(uniqueexamples)/len(allexamples)}% unique examples ({len(uniqueexamples)}/{len(allexamples)})")
        ptds.advance_seed()
        print(ptds[0])
        allexamples = list(ptds.examples)
        uniqueexamples2 = set([str(x) for x in allexamples])
        print(f"{100*len(uniqueexamples2)/len(allexamples)}% unique examples ({len(uniqueexamples2)}/{len(allexamples)})")
        print(f"{len(uniqueexamples & uniqueexamples2)}/{len(uniqueexamples | uniqueexamples2)} overlap")
        print("---")
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretraining
    # setup data perturbation
    tokenmasker = TokenMasker(p=tokenmaskp, seed=dataseed) if tokenmaskp > 0 else lambda x: x
    spanmasker = SpanMasker(p=spanmaskp, lamda=spanmasklamda, seed=dataseed) if spanmaskp > 0 else lambda x: x
    treemasker = SubtreeMasker(p=treemaskp, seed=dataseed) if treemaskp > 0 else lambda x: x

    perturbed_ptds = ptds\
        .map(lambda x: (treemasker(x), x))\
        .map(lambda x: (flenc.convert(x[0], "tokens"),
                        flenc.convert(x[1], "tokens")))\
        .map(lambda x: (spanmasker(tokenmasker(x[0])), x[1]))
    perturbed_ptds_tokens = perturbed_ptds
    perturbed_ptds = perturbed_ptds\
        .map(lambda x: (flenc.convert(x[0], "tensor"),
                        flenc.convert(x[1], "tensor")))

    if localtest:
        allex = []
        allperturbedex = []
        _nepo = 50
        print(f"checking {_nepo}, each {ptN} generated examples")
        for _e in tqdm(range(_nepo)):
            for i in range(len(perturbed_ptds_tokens)):
                ex = str(ptds[i])
                perturbed_ex = perturbed_ptds_tokens[i]
                perturbed_ex = f"{' '.join(perturbed_ex[0])}->{' '.join(perturbed_ex[1])}"
                allex.append(ex)
                allperturbedex.append(perturbed_ex)
            ptds.advance_seed()
        uniqueex = set(allex)
        uniqueperturbedex = set(allperturbedex)
        print(f"{len(uniqueex)}/{len(allex)} unique examples")
        print(f"{len(uniqueperturbedex)}/{len(allperturbedex)} unique perturbed examples")

    ptdl = DataLoader(perturbed_ptds, batch_size=ptbatsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    ptmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    ptparams = pretrainm.parameters()
    ptoptim = torch.optim.Adam(ptparams, lr=ptlr, weight_decay=wreg)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(pretrainm.parameters(), gradnorm)  # clip the parameters that ptoptim actually updates
    t_max = ptepochs
    print(f"Total number of pretraining updates: {t_max} .")
    if ptcosinelr:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> q.sched.Cosine(steps=t_max-ptwarmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ptoptim, lr_schedule)

    pttrainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    pttrainepoch = partial(q.train_epoch, model=pretrainm, dataloader=ptdl, optim=ptoptim, losses=ptmetrics,
                         _train_batch=pttrainbatch, device=device, on_end=[lambda: lr_schedule.step(),
                                                                           lambda: ptds.advance_seed()])

    tt.tick("pretraining")
    q.run_training(run_train_epoch=pttrainepoch, max_epochs=ptepochs)
    tt.tock("done pretraining")

    # endregion

    # region finetuning
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": lr * enclrmul},
                   {"params": otherparams}]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=10, more_is_better=True, remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step(), lambda: eyt.on_epoch_end()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device)

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs, check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    tt.tick("testing")
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=xmetrics, device=device)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(input_ids, attention_mask=input_ids != predm.config.pad_token_id,
                                      max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(torch.device("cpu"))) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Example #14
0
def run(
    lr=0.001,
    batsize=20,
    epochs=100,
    embdim=64,
    encdim=128,
    numlayers=1,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    beamsize=1,
    cosine_restarts=1.,
    seed=456789,
):
    # DONE: Porter stemmer
    # DONE: linear attention
    # DONE: grad norm
    # DONE: beam search
    # DONE: lr scheduler
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    stemmer = PorterStemmer()
    tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    ds = GeoQueryDatasetFunQL(
        sentence_encoder=SequenceEncoder(tokenizer=tokenizer),
        min_freq=minfreq)

    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = create_model(embdim=embdim,
                         hdim=encdim,
                         dropout=dropout,
                         numlayers=numlayers,
                         sentence_encoder=ds.sentence_encoder,
                         query_encoder=ds.query_encoder,
                         feedatt=True)

    # model.apply(initializer)

    tfdecoder = SeqDecoder(
        model,
        tf_ratio=1.,
        eval=[
            CELoss(ignore_index=0, mode="logprobs"),
            SeqAccuracies(),
            TreeAccuracy(
                tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
        ])

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    if beamsize == 1:
        freedecoder = SeqDecoder(
            model,
            maxtime=100,
            tf_ratio=0.,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(
                    tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
            ])

        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    else:
        print("Doing beam search!")
        freedecoder = BeamDecoder(
            model,
            beamsize=beamsize,
            maxtime=60,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(
                    tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
            ])

        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max} ({epochs} * {len(train_dl)})")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=tfdecoder, dataloader=test_dl, losses=vlosses, device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")
Example #15
0
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",  # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mcdropout=-1,
        version="v3"):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    # torch.backends.cudnn.enabled = False

    wandb.init(project="compood_gru_baseline_v3",
               config=settings,
               reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    if maxsize < 0:
        if dataset.startswith("cfq"):
            maxsize = 155
        elif dataset.startswith("scan"):
            maxsize = 50
        print(f"maxsize: {maxsize}")

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    if "mcd" in dataset.split("/")[1]:
        print(f"Setting patience to -1 because MCD (was {patience})")
        patience = -1

    # if smalltrainvalid:
    if True:  # "mcd" in dataset.split("/")[1]:
        realtrainds = []
        indtestds = []
        splits = [True for _ in range(int(round(len(trainds) * 0.1)))]
        splits = splits + [False for _ in range(len(trainds) - len(splits))]
        random.shuffle(splits)
        for i in range(len(trainds)):
            if splits[i] is True:
                indtestds.append(trainds[i])
            else:
                realtrainds.append(trainds[i])
        trainds = Dataset(realtrainds)
        indtestds = Dataset(indtestds)
        tt.msg("split off 10% of training data for in-distribution test set")
    # else:
    #     indtestds = Dataset([x for x in validds.examples])
    #     tt.msg("using validation set as in-distribution test set")
    tt.msg(f"TRAIN DATA: {len(trainds)}")
    tt.msg(f"DEV DATA: {len(validds)}")
    tt.msg(f"TEST DATA: in-distribution: {len(indtestds)}, OOD: {len(testds)}")
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    indtestdl = DataLoader(indtestds,
                           batch_size=batsize,
                           shuffle=False,
                           collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    cell = GRUDecoderCell(hdim,
                          vocab=fldic,
                          inpvocab=inpdic,
                          numlayers=numlayers,
                          dropout=dropout,
                          worddropout=worddropout,
                          mode=mode)
    decoder = SeqDecoderBaseline(cell,
                                 vocab=fldic,
                                 max_size=maxsize,
                                 smoothing=smoothing,
                                 mode=mode,
                                 mcdropout=mcdropout)
    # print(f"one layer of decoder: \n {cell.decoder.block[0]}")
    print(decoder)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = decoder(*batch)
        tt.tock()
        decoder.train(False)
        tt.tick("test")
        out = decoder(*batch)
        tt.tock()
        tt.tock("testcode")

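    # metrics: training loss/accuracy, plus tree accuracy and decoder NLL/entropy statistics on train, valid, ID-test and OOD-test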
    tloss = make_array_of_metrics("loss", "elemacc", "acc", reduction="mean")
    metricnames = ["treeacc", "decnll", "maxmaxnll", "entropy"]
    tmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    vmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    indxmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    oodxmetrics = make_array_of_metrics(*metricnames, reduction="mean")

    # region parameters
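    # encoder ("encoder_model.") parameters get lr * enclrmul; all other parameters use the base lr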
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

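    # early stopping on validation tree accuracy (vmetrics[0]); remembers a deepcopy of the best cell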
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(cell))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "elemacc", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        if evaltrain:
            for name, loss in zip(metricnames, tmetrics):
                d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, indxmetrics):
            d["indtest_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, oodxmetrics):
            d["oodtest_" + name] = loss.get_epoch_error()
        wandb.log(d)

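    # optimizer and LR schedule: linear warmup, then either cosine decay towards 0 or a constant factor of 1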
    t_max = epochs
    optim = get_optim(cell, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

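    # clip the cell's gradient norm right before every optimizer step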
    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=cell, _norm=gradnorm)])

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)
    indtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=indxmetrics,
                           dataloader=indtestdl,
                           device=device)
    oodtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=oodxmetrics,
                           dataloader=testdl,
                           device=device)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    validfs = validfs + [indtestepoch, oodtestepoch]

    # results = evaluate(decoder, indtestds, testds, batsize=batsize, device=device)
    # print(json.dumps(results, indent=4))

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testres = oodtestepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None and patience >= 0:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running ID test")
    testres = indtestepoch()
    print(f"ID test tree acc: {testres}")
    tt.tock()

    tt.tick("running OOD test")
    testres = oodtestepoch()
    print(f"OOD test tree acc: {testres}")
    tt.tock()

    results = evaluate(decoder,
                       indtestds,
                       testds,
                       batsize=batsize,
                       device=device)
    print(json.dumps(results, indent=4))

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update(
        {"final_indtest_tree_acc": indxmetrics[0].get_epoch_error()})
    settings.update(
        {"final_oodtest_tree_acc": oodxmetrics[0].get_epoch_error()})
    for k, v in results.items():
        for metric, ve in v.items():
            settings.update({f"{k}_{metric}": ve})

    wandb.config.update(settings)
    q.pp_dict(settings)

    return decoder, indtestds, testds
Example #16
0
def run(
    domain="restaurants",
    mode="baseline",  # "baseline", "ltr", "uniform", "binary"
    probthreshold=0.,  # 0. --> parallel, >1. --> serial, 0.< . <= 1. --> semi-parallel
    lr=0.0001,
    enclrmul=0.1,
    batsize=50,
    epochs=1000,
    hdim=366,
    numlayers=6,
    numheads=6,
    dropout=0.1,
    noreorder=False,
    trainonvalid=False,
    seed=87646464,
    gpu=-1,
    patience=-1,
    gradacc=1,
    cosinelr=False,
    warmup=20,
    gradnorm=3,
    validinter=10,
    maxsteps=20,
    maxsize=75,
    testcode=False,
    numbered=False,
):

    settings = locals().copy()
    q.pp_dict(settings)
    wandb.init(project=f"seqinsert_overnight_v2", config=settings, reinit=True)

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        domain,
        trainonvalid=trainonvalid,
        noreorder=noreorder,
        numbered=numbered)
    tt.tock("loaded")

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim,
                               flenc.vocab,
                               numlayers,
                               numheads,
                               dropout,
                               baseline=mode == "baseline")

    if mode == "baseline":
        decoder = SeqDecoderBaseline(tagger,
                                     flenc.vocab,
                                     max_steps=maxsteps,
                                     max_size=maxsize)
    elif mode == "ltr":
        decoder = SeqInsertionDecoderLTR(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize)
    elif mode == "uniform":
        decoder = SeqInsertionDecoderUniform(tagger,
                                             flenc.vocab,
                                             max_steps=maxsteps,
                                             max_size=maxsize,
                                             prob_threshold=probthreshold)
    elif mode == "binary":
        decoder = SeqInsertionDecoderBinary(tagger,
                                            flenc.vocab,
                                            max_steps=maxsteps,
                                            max_size=maxsize,
                                            prob_threshold=probthreshold)
    elif mode == "any":
        decoder = SeqInsertionDecoderAny(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize,
                                         prob_threshold=probthreshold)

    # test run
    if testcode:
        batch = next(iter(tdl_seq))
        # out = tagger(batch[1])
        # out = decoder(*batch)
        decoder.train(False)
        out = decoder(*batch)

    tloss = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))

    def wandb_logger():
        d = {}
        for name, loss in zip(["CE"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

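    # accumulate gradients over gradacc batches and clip the tagger's gradient norm before each optimizer step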
    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=tdl_seq,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=tdl_seq,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]

    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    q.run_training(
        run_train_epoch=trainepoch,
        # run_valid_epoch=[trainevalepoch, validepoch], #[validepoch],
        run_valid_epoch=[validepoch],
        max_epochs=epochs,
        check_stop=[lambda: eyt.check_stop()],
        validinter=validinter)
    tt.tock("done training")

    if eyt.remembered is not None and not trainonvalid:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        print(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_CE": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})
    settings.update({"final_train_steps_used": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_steps_used": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_steps_used": xmetrics[1].get_epoch_error()})

    # run different prob_thresholds:
    # thresholds = [0., 0.3, 0.5, 0.6, 0.75, 0.85, 0.9, 0.95,  1.]
    thresholds = [
        0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 1.
    ]
    for threshold in thresholds:
        tt.tick("running test for threshold " + str(threshold))
        decoder.prob_threshold = threshold
        testres = testepoch()
        print(f"Test tree acc for threshold {threshold}: {testres}")
        settings.update(
            {f"_thr{threshold}_acc": xmetrics[0].get_epoch_error()})
        settings.update(
            {f"_thr{threshold}_len": xmetrics[1].get_epoch_error()})
        tt.tock("done")

    wandb.config.update(settings)
    q.pp_dict(settings)
Example #17
0
def run(lr=0.01,
        batsize=20,
        epochs=101,
        embdim=100,
        encdim=200,
        numlayers=1,
        dropout=.25,
        wreg=1e-6,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        beamsize=5,
        smoothing=0.,
        fulltest=False,
        cosine_restarts=1.,
        nocopy=True,
        validinter=5,
        ):
    print(locals().copy())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    # stemmer = PorterStemmer()
    # tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    tokenizer = lambda x: x.split()
    ds = GeoQueryDataset(sentence_encoder=SequenceEncoder(tokenizer=tokenizer), min_freq=minfreq)
    dls = ds.dataloader(batsize=batsize)
    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = create_model(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
                             sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder, feedatt=True, nocopy=nocopy)

    tfdecoder = SeqDecoder(TFTransition(model),
                           [CELoss(ignore_index=0, mode="logprobs", smoothing=smoothing),
                            SeqAccuracies()])
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    freedecoder = BeamDecoder(model, beamsize=beamsize, maxtime=60,
                              eval_beam=[BeamSeqAccuracies()])

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc")
    vlosses = make_array_of_metrics(*([f"beam_seq_recall_at{i}" for i in range(1, min(beamsize, 5))] + ["beam_recall"]))

    # 4. define optim
    optim = torch.optim.RMSprop(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        t_max = epochs # * len(train_dl)
        print(f"Total number of updates: {t_max} ({epochs} epochs)")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=tfdecoder, dataloader=train_dl, optim=optim, losses=losses,
                         _train_batch=trainbatch, device=device, on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch, model=freedecoder, dataloader=test_dl, losses=vlosses, device=device)
    # validepoch = partial(q.test_epoch, model=tfdecoder, dataloader=test_dl, losses=vlosses, device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs, validinter=validinter)
    tt.tock("done training")
Example #18
0
def run(
    domain="restaurants",
    lr=0.001,
    enclrmul=0.1,
    cosinelr=False,
    warmup=0.,
    batsize=20,
    epochs=100,
    dropout=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="bart-large",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=50,
    localtest=False,
    printtest=False,
    trainonvalid=False,
):
    settings = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_ds(domain=domain,
                                          nl_mode=encoder,
                                          trainonvalid=trainonvalid)
    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=partial(autocollate, pad_value=1))
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(autocollate, pad_value=1))
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(autocollate, pad_value=1))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 tensor2tree=partial(_tensor2tree,
                                                     D=flenc.vocab))
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

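    # parameter groups: encoder weights ("model.model.encoder*") train at lr * enclrmul, everything else at the base lr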
    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params
                            if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [
        v for k, v in trainable_params if k.startswith("model.model.encoder")
    ]
    otherparams = [
        v for k, v in trainable_params
        if not k.startswith("model.model.encoder")
    ]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{
        "params": encparams,
        "lr": lr * enclrmul
    }, {
        "params": otherparams
    }]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device)

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

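    # optionally run generation on the test set, print each input, predicted tree and gold tree, and report exact-match accuracy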
    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Example #19
0
def run(
    lr=0.001,
    batsize=20,
    epochs=70,
    embdim=128,
    encdim=400,
    numlayers=1,
    beamsize=5,
    dropout=.5,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
):
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = GeoDatasetRank()
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    # do_rare_stats(ds)

    # model = TreeRankModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
    #                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder)
    #
    model = ParikhRankModel(embdim=encdim,
                            dropout=dropout,
                            sentence_encoder=ds.sentence_encoder,
                            query_encoder=ds.query_encoder)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)
    ranker = Ranker(model,
                    eval=[BCELoss(mode="logits", smoothing=smoothing)],
                    evalseq=[
                        SeqAccuracies(),
                        TreeAccuracy(tensor2tree=partial(
                            tensor2tree, D=ds.query_encoder.vocab),
                                     orderless={"and", "or"})
                    ])

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        model.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> "
    )
    # if True:
    #     overwrite = None
    if tosave.lower() == "y" or tosave.lower() == "yes" or re.match(
            r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

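        # wrap the reloaded model in a beam-search decoder for evaluation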
        _freedecoder = BeamDecoder(
            _model,
            maxtime=100,
            beamsize=beamsize,
            copy_deep=True,
            eval=[SeqAccuracies()],
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        beamlosses = make_array_of_metrics("seq_acc", "tree_acc")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
Example #20
0
def run(
    lr=0.0001,
    enclrmul=0.1,
    smoothing=0.1,
    gradnorm=3,
    batsize=60,
    epochs=16,
    patience=-1,
    validinter=1,
    validfrac=0.1,
    warmup=3,
    cosinelr=False,
    dataset="scan/length",
    maxsize=50,
    seed=42,
    hdim=768,
    numlayers=6,
    numheads=12,
    dropout=0.1,
    bertname="bert-base-uncased",
    testcode=False,
    userelpos=False,
    gpu=-1,
    evaltrain=False,
    trainonvalid=False,
    trainonvalidonly=False,
    recomputedata=False,
    adviters=3,  # adversary updates per main update
    advreset=10,  # reset adversary after this number of epochs
    advcontrib=1.,
    advmaskfrac=0.2,
    advnumsamples=3,
):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project=f"compgen_set_aib", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl_main = DataLoader(trainds,
                              batch_size=batsize,
                              shuffle=True,
                              collate_fn=autocollate)
    traindl_adv = DataLoader(trainds,
                             batch_size=batsize,
                             shuffle=True,
                             collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    if trainonvalidonly:
        traindl_main = DataLoader(validds,
                                  batch_size=batsize,
                                  shuffle=True,
                                  collate_fn=autocollate)
        traindl_adv = DataLoader(validds,
                                 batch_size=batsize,
                                 shuffle=True,
                                 collate_fn=autocollate)
        validdl = testdl
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
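    # two TransformerEncoders over the input vocab: the main one with weightmode=bertname, the adversary's with weightmode="vanilla"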
    encoder = TransformerEncoder(hdim,
                                 vocab=inpdic,
                                 numlayers=numlayers,
                                 numheads=numheads,
                                 dropout=dropout,
                                 weightmode=bertname,
                                 userelpos=userelpos,
                                 useabspos=not userelpos)
    advencoder = TransformerEncoder(hdim,
                                    vocab=inpdic,
                                    numlayers=numlayers,
                                    numheads=numheads,
                                    dropout=dropout,
                                    weightmode="vanilla",
                                    userelpos=userelpos,
                                    useabspos=not userelpos)
    setdecoder = SetDecoder(hdim, vocab=fldic, encoder=encoder)
    adv = AdvTagger(advencoder, maskfrac=advmaskfrac, vocab=inpdic)
    model = AdvModel(setdecoder,
                     adv,
                     numsamples=advnumsamples,
                     advcontrib=advcontrib)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl_main))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss_main = make_array_of_metrics("loss",
                                       "mainloss",
                                       "advloss",
                                       "acc",
                                       reduction="mean")
    tloss_adv = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    vmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    xmetrics = make_array_of_metrics("loss", "acc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "mainloss", "advloss", "acc"],
                              tloss_main):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["advloss"], tloss_adv):
            d["train_adv_" + name] = loss.get_epoch_error()
        d["train_acc"] = tmetrics[1].get_epoch_error()
        d["valid_acc"] = vmetrics[1].get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim_main = get_optim(model.core, lr, enclrmul)
    optim_adv = get_optim(model.adv, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule_main = q.sched.LRSchedule(optim_main, lr_schedule)
    lr_schedule_adv = q.sched.LRSchedule(optim_adv, lr_schedule)

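    # separate gradient-clipping hooks for the main model (model.core) and the adversary (model.adv)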
    trainbatch_main = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.core, _norm=gradnorm)
        ])
    trainbatch_adv = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.adv, _norm=gradnorm)
        ])

    print("using test data for validation")
    validdl = testdl

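    # adversarial training epoch: interleaves main-model updates with adviters adversary updates, each with its own optimizer and LR schedule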
    trainepoch = partial(adv_train_epoch,
                         main_model=model.main_trainmodel,
                         adv_model=model.adv_trainmodel,
                         main_dataloader=traindl_main,
                         adv_dataloader=traindl_adv,
                         main_optim=optim_main,
                         adv_optim=optim_adv,
                         main_losses=tloss_main,
                         adv_losses=tloss_adv,
                         adviters=adviters,
                         device=device,
                         print_every_batch=True,
                         _main_train_batch=trainbatch_main,
                         _adv_train_batch=trainbatch_adv,
                         on_end=[
                             lambda: lr_schedule_main.step(),
                             lambda: lr_schedule_adv.step()
                         ])

    # eval epochs
    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl_main,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        model = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss_main[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[1].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)