Example 1
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",  # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mcdropout=-1,
        version="v3"):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    # torch.backends.cudnn.enabled = False

    wandb.init(project=f"compood_gru_baseline_v3",
               config=settings,
               reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    if maxsize < 0:
        if dataset.startswith("cfq"):
            maxsize = 155
        elif dataset.startswith("scan"):
            maxsize = 50
        print(f"maxsize: {maxsize}")

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    if "mcd" in dataset.split("/")[1]:
        print(f"Setting patience to -1 because MCD (was {patience})")
        patience = -1

    # if smalltrainvalid:
    if True:  # "mcd" in dataset.split("/")[1]:
        realtrainds = []
        indtestds = []
        splits = [True for _ in range(int(round(len(trainds) * 0.1)))]
        splits = splits + [False for _ in range(len(trainds) - len(splits))]
        random.shuffle(splits)
        for i in range(len(trainds)):
            if splits[i] is True:
                indtestds.append(trainds[i])
            else:
                realtrainds.append(trainds[i])
        trainds = Dataset(realtrainds)
        indtestds = Dataset(indtestds)
        tt.msg("split off 10% of training data for in-distribution test set")
    # else:
    #     indtestds = Dataset([x for x in validds.examples])
    #     tt.msg("using validation set as in-distribution test set")
    tt.msg(f"TRAIN DATA: {len(trainds)}")
    tt.msg(f"DEV DATA: {len(validds)}")
    tt.msg(f"TEST DATA: in-distribution: {len(indtestds)}, OOD: {len(testds)}")
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    indtestdl = DataLoader(indtestds,
                           batch_size=batsize,
                           shuffle=False,
                           collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    cell = GRUDecoderCell(hdim,
                          vocab=fldic,
                          inpvocab=inpdic,
                          numlayers=numlayers,
                          dropout=dropout,
                          worddropout=worddropout,
                          mode=mode)
    decoder = SeqDecoderBaseline(cell,
                                 vocab=fldic,
                                 max_size=maxsize,
                                 smoothing=smoothing,
                                 mode=mode,
                                 mcdropout=mcdropout)
    # print(f"one layer of decoder: \n {cell.decoder.block[0]}")
    print(decoder)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = decoder(*batch)
        tt.tock()
        decoder.train(False)
        tt.tick("test")
        out = decoder(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "elemacc", "acc", reduction="mean")
    metricnames = ["treeacc", "decnll", "maxmaxnll", "entropy"]
    tmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    vmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    indxmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    oodxmetrics = make_array_of_metrics(*metricnames, reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(cell))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        if evaltrain:
            for name, loss in zip(metricnames, tmetrics):
                d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, indxmetrics):
            d["indtest_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, oodxmetrics):
            d["oodtest_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(cell, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=cell, _norm=gradnorm)])

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)
    indtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=indxmetrics,
                           dataloader=indtestdl,
                           device=device)
    oodtestepoch = partial(q.test_epoch,
                           model=decoder,
                           losses=oodxmetrics,
                           dataloader=testdl,
                           device=device)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    validfs = validfs + [indtestepoch, oodtestepoch]

    # results = evaluate(decoder, indtestds, testds, batsize=batsize, device=device)
    # print(json.dumps(results, indent=4))

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testres = oodtestepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None and patience >= 0:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running ID test")
    testres = indtestepoch()
    print(f"ID test tree acc: {testres}")
    tt.tock()

    tt.tick("running OOD test")
    testres = oodtestepoch()
    print(f"OOD test tree acc: {testres}")
    tt.tock()

    results = evaluate(decoder,
                       indtestds,
                       testds,
                       batsize=batsize,
                       device=device)
    print(json.dumps(results, indent=4))

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update(
        {"final_indtest_tree_acc": indxmetrics[0].get_epoch_error()})
    settings.update(
        {"final_oodtest_tree_acc": oodxmetrics[0].get_epoch_error()})
    for k, v in results.items():
        for metric, ve in v.items():
            settings.update({f"{k}_{metric}": ve})

    wandb.config.update(settings)
    q.pp_dict(settings)

    return decoder, indtestds, testds
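
The mcdropout option above is handled inside SeqDecoderBaseline; the snippet below is only a minimal, self-contained sketch of the general MC-dropout technique it builds on, with made-up names (mc_predict, the toy classifier) that are not part of this repository: dropout layers are kept stochastic at inference time and predictions are averaged over several forward passes, with their variance usable as a simple uncertainty signal.

import torch
import torch.nn as nn


def enable_mc_dropout(model: nn.Module):
    # put the whole model in eval mode, then switch only the Dropout modules back to train mode
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()


def mc_predict(model: nn.Module, x: torch.Tensor, num_samples: int = 10):
    # average class probabilities over several stochastic forward passes;
    # the per-class variance can serve as a rough uncertainty / OOD score
    enable_mc_dropout(model)
    with torch.no_grad():
        probs = torch.stack([torch.softmax(model(x), -1) for _ in range(num_samples)])
    return probs.mean(0), probs.var(0)


# usage on a toy classifier (shapes are illustrative)
toymodel = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 5))
mean_probs, var_probs = mc_predict(toymodel, torch.randn(8, 16), num_samples=20)
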
Example 2
def run(
    lr=0.0001,
    enclrmul=0.1,
    smoothing=0.1,
    gradnorm=3,
    batsize=60,
    epochs=16,
    patience=10,
    validinter=1,
    validfrac=0.1,
    warmup=3,
    cosinelr=False,
    dataset="scan/length",
    maxsize=50,
    seed=42,
    hdim=768,
    numlayers=6,
    numheads=12,
    dropout=0.1,
    sidedrop=0.0,
    bertname="bert-base-uncased",
    testcode=False,
    userelpos=False,
    gpu=-1,
    evaltrain=False,
    trainonvalid=False,
    trainonvalidonly=False,
    recomputedata=False,
    mode="normal",  # "normal", "vib", "aib"
    priorweight=1.,
):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project=f"compgen_set", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    if dataset.startswith("cfq"):
        # filter
        idmap = torch.arange(fldic.number_of_ids(), device=device)
        for k, v in fldic.D.items():
            if not (k.startswith("ns:") or re.match(r"m\d+", k)):
                idmap[v] = 0
        trainds = trainds.map(lambda x: (x[0], idmap[x[1]])).cache()
        validds = validds.map(lambda x: (x[0], idmap[x[1]])).cache()
        testds = testds.map(lambda x: (x[0], idmap[x[1]])).cache()

    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    model = SetModel(hdim,
                     vocab=fldic,
                     inpvocab=inpdic,
                     numlayers=numlayers,
                     numheads=numheads,
                     dropout=dropout,
                     sidedrop=sidedrop,
                     bertname=bertname,
                     userelpos=userelpos,
                     useabspos=not userelpos,
                     mode=mode,
                     priorweight=priorweight)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "priorkl", "acc", reduction="mean")
    tmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")
    vmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")
    xmetrics = make_array_of_metrics("loss",
                                     "priorkl",
                                     "acc",
                                     reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "priorkl", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(model, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=model, _norm=gradnorm)])

    print("using test data for validation")
    validdl = testdl

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        model = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[2].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[2].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[2].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
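
All of these scripts build their optimizer through the same get_parameters/get_optim pattern: parameters whose names contain "encoder_model." train with the learning rate scaled by enclrmul, everything else with the base rate. A standalone sketch of that technique, using an invented toy module purely for illustration:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_model = nn.Linear(16, 16)   # stands in for the pretrained encoder
        self.head = nn.Linear(16, 4)             # task-specific parameters


def grouped_optimizer(model: nn.Module, lr: float = 1e-4, enclrmul: float = 0.1):
    enc_params, other_params = [], []
    for name, p in model.named_parameters():
        (enc_params if "encoder_model." in name else other_params).append(p)
    # the encoder group gets a reduced learning rate, the remaining parameters use the base rate
    return torch.optim.Adam([
        {"params": enc_params, "lr": lr * enclrmul},
        {"params": other_params},
    ], lr=lr)


optim = grouped_optimizer(ToyModel())
print([group["lr"] for group in optim.param_groups])   # [1e-05, 0.0001]
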
Example 3
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        tmbatsize=60,
        grubatsize=60,
        tmepochs=16,
        gruepochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",  # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        tmnumlayers=6,
        grunumlayers=2,
        numheads=12,
        tmdropout=0.1,
        grudropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        mcdropout=-1,
        version="grutm_v1.1"):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    grusettings = {(k[3:] if k.startswith("gru") else k): v
                   for k, v in settings.items() if not k.startswith("tm")}
    grudecoder, indtestds, oodtestds = run_gru(**grusettings)

    tmsettings = {(k[2:] if k.startswith("tm") else k): v
                  for k, v in settings.items() if not k.startswith("gru")}
    tmdecoder, _, _ = run_tm(**tmsettings)

    # create a model that uses tmdecoder to generate output and uses both to measure OOD
    decoder = HybridSeqDecoder(tmdecoder, grudecoder, mcdropout=mcdropout)
    results = evaluate(decoder,
                       indtestds,
                       oodtestds,
                       batsize=tmbatsize,
                       device=device)
    print("Results of the hybrid OOD:")
    print(json.dumps(results, indent=3))

    wandb.init(project=f"compood_grutm_baseline", config=settings, reinit=True)
    for k, v in results.items():
        for metric, ve in v.items():
            settings.update({f"{k}_{metric}": ve})

    wandb.config.update(settings)
    q.pp_dict(settings)
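
The two dict comprehensions above route shared hyperparameters to the GRU and Transformer sub-runs by dropping the other model's keys and stripping the run's own prefix. A small self-contained sketch of that dispatch, with invented setting names:

def split_settings(settings: dict, keep_prefix: str, drop_prefix: str) -> dict:
    # keep shared keys as-is, strip keep_prefix from prefixed keys, drop the other model's keys
    return {(k[len(keep_prefix):] if k.startswith(keep_prefix) else k): v
            for k, v in settings.items() if not k.startswith(drop_prefix)}


settings = {"lr": 1e-4, "grubatsize": 60, "grunumlayers": 2, "tmbatsize": 50, "tmnumlayers": 6}
grusettings = split_settings(settings, keep_prefix="gru", drop_prefix="tm")
tmsettings = split_settings(settings, keep_prefix="tm", drop_prefix="gru")
assert grusettings == {"lr": 1e-4, "batsize": 60, "numlayers": 2}
assert tmsettings == {"lr": 1e-4, "batsize": 50, "numlayers": 6}
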
Example 4
def run(
    domain="restaurants",
    mode="baseline",  # "baseline", "ltr", "uniform", "binary"
    probthreshold=0.,  # 0. --> parallel, >1. --> serial, 0.< . <= 1. --> semi-parallel
    lr=0.0001,
    enclrmul=0.1,
    batsize=50,
    epochs=1000,
    hdim=366,
    numlayers=6,
    numheads=6,
    dropout=0.1,
    noreorder=False,
    trainonvalid=False,
    seed=87646464,
    gpu=-1,
    patience=-1,
    gradacc=1,
    cosinelr=False,
    warmup=20,
    gradnorm=3,
    validinter=10,
    maxsteps=20,
    maxsize=75,
    testcode=False,
    numbered=False,
):

    settings = locals().copy()
    q.pp_dict(settings)
    wandb.init(project=f"seqinsert_overnight_v2", config=settings, reinit=True)

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        domain,
        trainonvalid=trainonvalid,
        noreorder=noreorder,
        numbered=numbered)
    tt.tock("loaded")

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim,
                               flenc.vocab,
                               numlayers,
                               numheads,
                               dropout,
                               baseline=mode == "baseline")

    if mode == "baseline":
        decoder = SeqDecoderBaseline(tagger,
                                     flenc.vocab,
                                     max_steps=maxsteps,
                                     max_size=maxsize)
    elif mode == "ltr":
        decoder = SeqInsertionDecoderLTR(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize)
    elif mode == "uniform":
        decoder = SeqInsertionDecoderUniform(tagger,
                                             flenc.vocab,
                                             max_steps=maxsteps,
                                             max_size=maxsize,
                                             prob_threshold=probthreshold)
    elif mode == "binary":
        decoder = SeqInsertionDecoderBinary(tagger,
                                            flenc.vocab,
                                            max_steps=maxsteps,
                                            max_size=maxsize,
                                            prob_threshold=probthreshold)
    elif mode == "any":
        decoder = SeqInsertionDecoderAny(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize,
                                         prob_threshold=probthreshold)
    else:
        raise Exception(f"unknown mode: {mode}")

    # test run
    if testcode:
        batch = next(iter(tdl_seq))
        # out = tagger(batch[1])
        # out = decoder(*batch)
        decoder.train(False)
        out = decoder(*batch)

    tloss = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))

    def wandb_logger():
        d = {}
        for name, loss in zip(["CE"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=tdl_seq,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=tdl_seq,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]

    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    q.run_training(
        run_train_epoch=trainepoch,
        # run_valid_epoch=[trainevalepoch, validepoch], #[validepoch],
        run_valid_epoch=[validepoch],
        max_epochs=epochs,
        check_stop=[lambda: eyt.check_stop()],
        validinter=validinter)
    tt.tock("done training")

    if eyt.remembered is not None and not trainonvalid:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        print(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_CE": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})
    settings.update({"final_train_steps_used": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_steps_used": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_steps_used": xmetrics[1].get_epoch_error()})

    # run different prob_thresholds:
    # thresholds = [0., 0.3, 0.5, 0.6, 0.75, 0.85, 0.9, 0.95,  1.]
    thresholds = [
        0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 1.
    ]
    for threshold in thresholds:
        tt.tick("running test for threshold " + str(threshold))
        decoder.prob_threshold = threshold
        testres = testepoch()
        print(f"Test tree acc for threshold {threshold}: testres: {testres}")
        settings.update(
            {f"_thr{threshold}_acc": xmetrics[0].get_epoch_error()})
        settings.update(
            {f"_thr{threshold}_len": xmetrics[1].get_epoch_error()})
        tt.tock("done")

    wandb.config.update(settings)
    q.pp_dict(settings)
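
The q.sched.Linear(steps=warmup) >> q.sched.Cosine(...) composition used above amounts to a linear warmup followed by a cosine decay, stepped once per epoch. Below is a rough plain-PyTorch equivalent, sketched with illustrative epoch counts and not the repository's scheduler API:

import math

import torch


def warmup_cosine(epoch: int, warmup: int, total: int) -> float:
    # multiplier for the base lr: linear ramp over `warmup` epochs, then cosine decay towards 0
    if epoch < warmup:
        return (epoch + 1) / warmup
    progress = (epoch - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


toymodel = torch.nn.Linear(8, 8)
optim = torch.optim.Adam(toymodel.parameters(), lr=1e-4)
sched = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda e: warmup_cosine(e, warmup=20, total=1000))
# inside the training loop: optim.step() per batch, sched.step() once per epoch
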
Example 5
def run(
    lr=0.0001,
    enclrmul=0.1,
    smoothing=0.1,
    gradnorm=3,
    batsize=60,
    epochs=16,
    patience=-1,
    validinter=1,
    validfrac=0.1,
    warmup=3,
    cosinelr=False,
    dataset="scan/length",
    maxsize=50,
    seed=42,
    hdim=768,
    numlayers=6,
    numheads=12,
    dropout=0.1,
    bertname="bert-base-uncased",
    testcode=False,
    userelpos=False,
    gpu=-1,
    evaltrain=False,
    trainonvalid=False,
    trainonvalidonly=False,
    recomputedata=False,
    adviters=3,  # adversary updates per main update
    advreset=10,  # reset adversary after this number of epochs
    advcontrib=1.,
    advmaskfrac=0.2,
    advnumsamples=3,
):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project=f"compgen_set_aib", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl_main = DataLoader(trainds,
                              batch_size=batsize,
                              shuffle=True,
                              collate_fn=autocollate)
    traindl_adv = DataLoader(trainds,
                             batch_size=batsize,
                             shuffle=True,
                             collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    if trainonvalidonly:
        traindl_main = DataLoader(validds,
                                  batch_size=batsize,
                                  shuffle=True,
                                  collate_fn=autocollate)
        traindl_adv = DataLoader(validds,
                                 batch_size=batsize,
                                 shuffle=True,
                                 collate_fn=autocollate)
        validdl = testdl
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    encoder = TransformerEncoder(hdim,
                                 vocab=inpdic,
                                 numlayers=numlayers,
                                 numheads=numheads,
                                 dropout=dropout,
                                 weightmode=bertname,
                                 userelpos=userelpos,
                                 useabspos=not userelpos)
    advencoder = TransformerEncoder(hdim,
                                    vocab=inpdic,
                                    numlayers=numlayers,
                                    numheads=numheads,
                                    dropout=dropout,
                                    weightmode="vanilla",
                                    userelpos=userelpos,
                                    useabspos=not userelpos)
    setdecoder = SetDecoder(hdim, vocab=fldic, encoder=encoder)
    adv = AdvTagger(advencoder, maskfrac=advmaskfrac, vocab=inpdic)
    model = AdvModel(setdecoder,
                     adv,
                     numsamples=advnumsamples,
                     advcontrib=advcontrib)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl_main))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss_main = make_array_of_metrics("loss",
                                       "mainloss",
                                       "advloss",
                                       "acc",
                                       reduction="mean")
    tloss_adv = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    vmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    xmetrics = make_array_of_metrics("loss", "acc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "mainloss", "advloss", "acc"],
                              tloss_main):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["advloss"], tloss_adv):
            d["train_adv_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["acc"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim_main = get_optim(model.core, lr, enclrmul)
    optim_adv = get_optim(model.adv, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule_main = q.sched.LRSchedule(optim_main, lr_schedule)
    lr_schedule_adv = q.sched.LRSchedule(optim_adv, lr_schedule)

    trainbatch_main = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.core, _norm=gradnorm)
        ])
    trainbatch_adv = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.adv, _norm=gradnorm)
        ])

    print("using test data for validation")
    validdl = testdl

    trainepoch = partial(adv_train_epoch,
                         main_model=model.main_trainmodel,
                         adv_model=model.adv_trainmodel,
                         main_dataloader=traindl_main,
                         adv_dataloader=traindl_adv,
                         main_optim=optim_main,
                         adv_optim=optim_adv,
                         main_losses=tloss_main,
                         adv_losses=tloss_adv,
                         adviters=adviters,
                         device=device,
                         print_every_batch=True,
                         _main_train_batch=trainbatch_main,
                         _adv_train_batch=trainbatch_adv,
                         on_end=[
                             lambda: lr_schedule_main.step(),
                             lambda: lr_schedule_adv.step()
                         ])

    # eval epochs
    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl_main,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        assert False
        tt.msg("reloading best")
        model = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[1].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
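
The adv_train_epoch used above interleaves adviters adversary updates with each main-model update. The function below is a minimal sketch of that alternating pattern, under the assumption that each model returns a scalar loss when called on an unpacked batch; all names here are placeholders rather than the repository's API.

import itertools


def adversarial_epoch(main_model, adv_model, main_loader, adv_loader,
                      main_optim, adv_optim, adviters=3):
    adv_batches = itertools.cycle(adv_loader)   # adversary cycles through its own dataloader
    for main_batch in main_loader:
        # several adversary updates per main update
        for _ in range(adviters):
            adv_optim.zero_grad()
            adv_loss = adv_model(*next(adv_batches))
            adv_loss.backward()
            adv_optim.step()
        # one main-model update; its loss already folds in the adversary's contribution
        main_optim.zero_grad()
        main_loss = main_model(*main_batch)
        main_loss.backward()
        main_optim.step()
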