Example No. 1
def run_span_borders(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=DEFAULT_EPOCHS,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    warmup=-1.,
    sched="ang",
    savep="exp_bert_span_borders_",
    freezeemb=False,
):
    settings = locals().copy()
    print(locals())
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running span border with BERT")
    tt.tick("loading data")
    data = load_data(which="span/borders")
    trainds, devds, testds = data
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)
    evalds = TensorDataset(*testloader.dataset.tensors[:-1])
    evalloader = DataLoader(evalds, batch_size=batsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:-1])
    evalloader_dev = DataLoader(evalds_dev, batch_size=batsize, shuffle=False)
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    spandet = BorderSpanDetector(bert, dropout=dropout)
    spandet.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    params = []
    for paramname, param in spandet.named_parameters():
        # exclude the BERT word embeddings from the optimizer when freezeemb is set
        if freezeemb and paramname.startswith("bert.embeddings.word_embeddings"):
            continue
        params.append(param)
    optim = BertAdam(params,
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=schedmap[sched])
    losses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(reduction="none"),
        q.SeqAccuracy()
    ]
    xlosses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(reduction="none"),
        q.SeqAccuracy()
    ]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=spandet,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=spandet,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        os.mkdir(savep + str(i))
        savedir = savep + str(i)
        # save model
        # torch.save(spandet, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings
        json.dump(settings, open(os.path.join(savedir, "settings.json"), "w"))
        # save test predictions
        testpreds = q.eval_loop(spandet, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.test.npy"), testpreds)
        # save dev predictions
        testpreds = q.eval_loop(spandet, evalloader_dev, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.dev.npy"), testpreds)
        tt.tock("done")
Example No. 2
def run(
    lr=0.001,
    batsize=20,
    epochs=100,
    embdim=100,
    encdim=164,
    numlayers=4,
    numheads=4,
    dropout=.0,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3000.,
    cosine_restarts=1.,
):
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    stemmer = PorterStemmer()
    tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    ds = GeoQueryDataset(sentence_encoder=SequenceEncoder(tokenizer=tokenizer),
                         min_freq=minfreq)

    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = create_model(hdim=encdim,
                         dropout=dropout,
                         numlayers=numlayers,
                         numheads=numheads,
                         sentence_encoder=ds.sentence_encoder,
                         query_encoder=ds.query_encoder)

    model._metrics = [CELoss(ignore_index=0, mode="logprobs"), SeqAccuracies()]

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc")
    vlosses = make_array_of_metrics("loss", "seq_acc")

    # 4. define optim
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max} ({epochs} * {len(train_dl)})")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        model.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=model,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=tfdecoder, dataloader=test_dl, losses=vlosses, device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")
Example No. 3
def try_dataset():
    tt = q.ticktock("dataset")
    tt.tick("building dataset")
    ds = GeoQueryDataset(sentence_encoder=SentenceEncoder(
        tokenizer=lambda x: x.split()))
    tt.tock("dataset built")
Example No. 4
def run(
    lr=0.001,
    batsize=20,
    epochs=60,
    embdim=128,
    encdim=256,
    numlayers=1,
    beamsize=5,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
    numcvfolds=6,
    testfold=-1,  # if not -1, must be smaller than numcvfolds; the chosen fold is used for validation
    reorder_random=False,
):
    localargs = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    cvfolds = None if testfold == -1 else numcvfolds
    testfold = None if testfold == -1 else testfold
    ds = GeoDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq,
        cvfolds=cvfolds,
        testfold=testfold,
        reorder_random=reorder_random)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless={"and"})
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[
                                 SeqAccuracies(),
                                 TreeAccuracy(tensor2tree=partial(
                                     tensor2tree, D=ds.query_encoder.vocab),
                                              orderless={"and"})
                             ])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[
                                  TreeAccuracy(tensor2tree=partial(
                                      tensor2tree, D=ds.query_encoder.vocab),
                                               orderless={"and"})
                              ])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    train_on = "train"
    valid_on = "test" if testfold is None else "valid"
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    if testfold is not None:
        return vlosses[1].get_epoch_error()

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n>"
    )
    # if True:
    #     overwrite = None
    if tosave.lower() == "y" or tosave.lower() == "yes" or re.match(
            "\d+", tosave.lower()):
        overwrite = int(tosave) if re.match("\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[
                                       TreeAccuracy(tensor2tree=partial(
                                           tensor2tree,
                                           D=ds.query_encoder.vocab),
                                                    orderless={"and"})
                                   ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
Example No. 5
def run_gatedtree(
    lr=0.01,
    gradclip=5.,
    batsize=20,
    epochs=80,
    embdim=200,
    encdim=200,
    numlayer=1,
    cuda=False,
    gpu=0,
    wreg=1e-8,
    dropout=0.5,
    smoothing=0.4,
    goldsmoothing=-0.1,
    which="geo",
    relatt=False,
):
    tt = q.ticktock("script")
    tt.msg("running gated tree decoder")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)

    # region data
    tt.tick("generating data")
    # dss, D = gen_sort_data(seqlen=seqlen, numvoc=numvoc, numex=numex, prepend_inp=False)
    dss, nlD, flD = gen_datasets(which=which)
    tloader, vloader, xloader = [
        torch.utils.data.DataLoader(ds, batch_size=batsize, shuffle=True)
        for ds in dss
    ]
    seqlen = len(dss[0][0][1])
    id2pushpop = torch.zeros(len(flD), dtype=torch.long, device=device)
    id2pushpop[flD["("]] = +1
    id2pushpop[flD[")"]] = -1

    tt.tock("data generated")
    # endregion

    # region model
    tt.tick("building model")
    # source side
    inpemb = q.WordEmb(embdim, worddic=nlD)
    encdims = [encdim] * numlayer
    encoder = q.LSTMEncoder(embdim,
                            *encdims,
                            bidir=False,
                            dropout_in_shared=dropout)

    # target side
    decemb = q.WordEmb(embdim, worddic=flD)
    decinpdim = embdim
    decdims = [decinpdim] + [encdim] * numlayer
    dec_core = [
        GatedTreeLSTMCell(decdims[i - 1], decdims[i], dropout_in=dropout)
        for i in range(1, len(decdims))
    ]
    dec_core = TreeRNNDecoderCellCore(*dec_core)
    if relatt:
        att = ComboAbsRelAttention(ctxdim=encdim, vecdim=encdim)
    else:
        att = BasicAttention()
    out = torch.nn.Sequential(q.WordLinout(encdim, worddic=flD),
                              # torch.nn.Softmax(-1)
                              )
    merge = q.rnn.FwdDecCellMerge(decdims[-1], encdims[-1], outdim=encdim)
    deccell = TreeRNNDecoderCell(emb=decemb,
                                 core=dec_core,
                                 att=att,
                                 out=out,
                                 merge=merge,
                                 id2pushpop=id2pushpop)
    train_dec = q.TFDecoder(deccell)
    test_dec = q.FreeDecoder(deccell, maxtime=seqlen + 10)
    train_encdec = EncDec(inpemb, encoder, train_dec)
    test_encdec = Test_EncDec(inpemb, encoder, test_dec)

    train_encdec.to(device)
    test_encdec.to(device)
    tt.tock("built model")
    # endregion

    # region training
    # losses:
    if smoothing == 0:
        ce = q.loss.CELoss(mode="logits", ignore_index=0)
    elif goldsmoothing < 0.:
        ce = q.loss.SmoothedCELoss(mode="logits",
                                   ignore_index=0,
                                   smoothing=smoothing)
    else:
        ce = q.loss.DiffSmoothedCELoss(mode="logits",
                                       ignore_index=0,
                                       alpha=goldsmoothing,
                                       beta=smoothing)
    acc = q.loss.SeqAccuracy(ignore_index=0)
    elemacc = q.loss.SeqElemAccuracy(ignore_index=0)
    treeacc = TreeAccuracyLambdaDFPar(flD=flD)
    # optim
    optim = torch.optim.RMSprop(train_encdec.parameters(),
                                lr=lr,
                                alpha=0.95,
                                weight_decay=wreg)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_value_(
        train_encdec.parameters(), clip_value=gradclip)
    # lööps
    batchloop = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainloop = partial(
        q.train_epoch,
        model=train_encdec,
        dataloader=tloader,
        optim=optim,
        device=device,
        losses=[q.LossWrapper(ce),
                q.LossWrapper(elemacc),
                q.LossWrapper(acc)],
        print_every_batch=False,
        _train_batch=batchloop)
    validloop = partial(q.test_epoch,
                        model=test_encdec,
                        dataloader=vloader,
                        device=device,
                        losses=[q.LossWrapper(treeacc)],
                        print_every_batch=False)

    tt.tick("training")
    q.run_training(trainloop, validloop, max_epochs=epochs)
    tt.tock("trained")

    tt.tick("testing")
    test_results = validloop(model=test_encdec, dataloader=xloader)
    print("Test results (freerunning): {}".format(test_results))
    test_results = validloop(model=train_encdec, dataloader=xloader)
    print("Test results (TF): {}".format(test_results))
    tt.tock("tested")
    # endregion
    tt.msg("done")
Example No. 6
def train_epoch_distill(model=None,
                        dataloader=None,
                        optim=None,
                        losses=None,
                        device=torch.device("cpu"),
                        tt=q.ticktock("-"),
                        current_epoch=0,
                        max_epochs=0,
                        _train_batch=train_batch_distill,
                        on_start=tuple(),
                        on_end=tuple(),
                        run=False,
                        mbase=None,
                        goldgetter=None):
    """
    Performs an epoch of training on given model, with data from given dataloader, using given optimizer,
    with loss computed based on given losses.
    :param model:
    :param dataloader:
    :param optim:
    :param losses:  list of loss wrappers
    :param device:  device to put batches on
    :param tt:
    :param current_epoch:
    :param max_epochs:
    :param _train_batch:    train batch function, default is train_batch
    :param on_start:
    :param on_end:
    :return:
    """
    # if run is False:
    #     kwargs = locals().copy()
    #     return partial(train_epoch, **kwargs)

    for loss in losses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()

    [e() for e in on_start]

    q.epoch_reset(model)
    if mbase is not None:
        q.epoch_reset(mbase)

    for i, _batch in enumerate(dataloader):
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs,
                             run=True,
                             mbase=mbase,
                             goldgetter=goldgetter)
        tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
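By analogy with the other examples in this file, train_epoch_distill would be bound with functools.partial and handed to q.run_training. A minimal wiring sketch; student, mbase (the teacher), goldgetter, and the loaders, losses, and optimizer are placeholders assumed to exist:

from functools import partial

trainepoch = partial(train_epoch_distill,
                     model=student,          # model being distilled into
                     dataloader=trainloader,
                     optim=optim,
                     losses=trainlosses,
                     device=device,
                     mbase=mbase,            # teacher model, epoch-reset alongside the student
                     goldgetter=goldgetter)  # provides the (soft) targets for distillation
validepoch = partial(q.test_epoch,
                     model=student,
                     dataloader=devloader,
                     losses=devlosses,
                     device=device)
q.run_training(trainepoch, validepoch, max_epochs=10)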
Example No. 7
def run(traindomains="ALL",
        domain="recipes",
        mincoverage=2,
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        ftlr=0.0001,
        cosinelr=False,
        warmup=0.,
        batsize=30,
        epochs=100,
        pretrainepochs=100,
        dropout=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="bert-base-uncased",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=30,
        localtest=False,
        printtest=False,
        fullsimplify=True,
        domainstart=False,
        useall=False,
        nopretrain=False,
        onlyabstract=False,
        uselexicon=False,
        ):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    wandb.init(project="overnight_base_fewshot", reinit=True, config=settings)
    if traindomains == "ALL":
        alldomains = {"recipes", "restaurants", "blocks", "calendar", "housing", "publications"}
        traindomains = alldomains - {domain, }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, add_domain_start=domainstart, useall=useall, onlyabstract=onlyabstract,
                uselexicon=uselexicon)
    tt.msg(f"{len(tds)/(len(tds) + len(vds)):.2f}/{len(vds)/(len(tds) + len(vds)):.2f} ({len(tds)}/{len(vds)}) train/valid")
    tt.msg(f"{len(ftds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(fvds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(xds)/(len(ftds) + len(fvds) + len(xds)):.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test")
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=0))
    ftdl = DataLoader(ftds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=0))
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    fvdl = DataLoader(fvds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=0))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 numbeam=numbeam,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab)
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": lr * enclrmul},
                   {"params": otherparams}]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=10, more_is_better=True, remember_f=lambda: deepcopy(trainm.model))
    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
            d["train_"+name] = loss.get_epoch_error()
        for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
            d["valid_"+name] = loss.get_epoch_error()
        wandb.log(d)
    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device,
                         on_end=[lambda: eyt.on_epoch_end(), lambda: wandb_logger()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=pretrainepochs, check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": ftlr * enclrmul},
                   {"params": otherparams}]

    ftoptim = torch.optim.Adam(paramgroups, lr=ftlr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(ftvmetrics[1], patience=1000, min_epochs=10, more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    def wandb_logger_ft():
        d = {}
        for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], ftmetrics):
            d["ft_train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["seq_acc", "tree_acc"], ftvmetrics):
            d["ft_valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ftoptim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=ftdl, optim=ftoptim, losses=ftmetrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=fvdl, losses=ftvmetrics, device=device,
                         on_end=[lambda: eyt.on_epoch_end(), lambda: wandb_logger_ft()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm, dataloader=fvdl, losses=ftvmetrics, device=device)
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=ftxmetrics, device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(input_ids,
                                 attention_mask=input_ids != predm.config.pad_token_id,
                                 max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(torch.device("cpu"))) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    wandb.config.update(settings)
    # print(settings)
    return settings
Example No. 8
def run_relations(
    lr=DEFAULT_LR,
    dropout=.3,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskentity=False,
    savep="exp_bilstm_rels_",
    test=False,
    datafrac=1.,
    vanillaemb=False,
    gloveemb=True,
    embdim=300,
    dim=300,
    numlayers=2,
    warmup=0.01,
    cycles=0.5,
    sched="cos",
    evalbatsize=-1,
    classweighted=False,
):
    print(locals())
    settings = locals().copy()
    if evalbatsize < 0:
        evalbatsize = batsize
    if test:
        epochs = 0
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    assert not (gloveemb and vanillaemb), "use either GloVe or vanilla embeddings, not both"
    tt = q.ticktock("script")
    tt.msg("running relation classifier with BiLSTM")
    tt.tick("loading data")
    data = load_data(which="rel+borders",
                     retrelD=True,
                     datafrac=datafrac,
                     wordlevel=gloveemb,
                     rettokD=True)
    trainds, devds, testds, relD, tokD = data
    if maskentity:
        trainds, devds, testds = replace_entity_span(trainds, devds, testds)
    else:
        trainds, devds, testds = [
            TensorDataset(ds.tensors[0], ds.tensors[2])
            for ds in [trainds, devds, testds]
        ]
    relcounts = torch.zeros(max(relD.values()) + 1)
    trainrelcounts = torch.bincount(trainds.tensors[1])
    relcounts[:len(trainrelcounts)] += trainrelcounts.float()
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    evalds = TensorDataset(*testloader.dataset.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:1])
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)

    if test:
        evalloader = DataLoader(TensorDataset(*evalloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("making model")
    if vanillaemb:
        bert = BertModel.from_pretrained("bert-base-uncased")
        emb = bert.embeddings.word_embeddings
        tt.msg("using vanilla emb of size {}".format(embdim))
        emb = torch.nn.Embedding(emb.weight.size(0), embdim)
    elif gloveemb:
        emb = q.WordEmb.load_glove("glove.50d", selectD=tokD)
    else:
        bert = BertModel.from_pretrained("bert-base-uncased")
        emb = bert.embeddings.word_embeddings
        embdim = bert.config.hidden_size
    bilstm = q.rnn.LSTMEncoder(embdim,
                               *([dim] * numlayers),
                               bidir=True,
                               dropout_in=dropout)
    # bilstm = torch.nn.LSTM(embdim, dim, batch_first=True, num_layers=numlayers, bidirectional=True, dropout=dropout)
    m = RelationClassifier(emb=emb,
                           bilstm=bilstm,
                           dim=dim * 2,
                           relD=relD,
                           dropout=dropout)
    m.to(device)
    tt.tock("made model")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    params = m.parameters()
    sched = get_schedule(sched,
                         warmup=warmup,
                         t_total=totalsteps,
                         cycles=cycles)
    # optim = BertAdam(params, lr=lr, weight_decay=wreg, warmup=warmup, t_total=totalsteps, schedule=schedmap[sched])
    optim = BertAdam(params, lr=lr, weight_decay=wreg, schedule=sched)
    losses = [
        q.SmoothedCELoss(smoothing=smoothing,
                         weight=1 /
                         relcounts.clamp_min(1e-6) if classweighted else None),
        q.Accuracy()
    ]
    xlosses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=m,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=m,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=m,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        os.mkdir(savep + str(i))
        savedir = savep + str(i)
        # save model
        # torch.save(m, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings
        json.dump(settings, open(os.path.join(savedir, "settings.json"), "w"))
        # save relation dictionary
        # json.dump(relD, open(os.path.join(savedir, "relD.json"), "w"))
        # save test predictions
        testpreds = q.eval_loop(m, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.test.npy"), testpreds)
        testpreds = q.eval_loop(m, evalloader_dev, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.dev.npy"), testpreds)
        tt.msg("saved in {}".format(savedir))
        # save bert-tokenized questions
        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # with open(os.path.join(savedir, "testquestions.txt"), "w") as f:
        #     for batch in evalloader:
        #         ques, io = batch
        #         ques = ques.numpy()
        #         for question in ques:
        #             qstr = " ".join([x for x in tokenizer.convert_ids_to_tokens(question) if x != "[PAD]"])
        #             f.write(qstr + "\n")

        tt.tock("done")
Example No. 9
def run_seq2seq_(
    lr=0.001,
    batsize=32,
    evalbatsize=256,
    epochs=100,
    warmup=5,
    embdim=50,
    encdim=100,
    numlayers=2,
    dropout=.0,
    wreg=1e-6,
    cuda=False,
    gpu=0,
):
    settings = locals().copy()
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt = q.ticktock("script")
    tt.msg("running seq2seq on LC-QuAD")

    tt.tick("loading data")
    xsm, ysm, teststart, tok2act = load_data()
    _tok2act = {ysm.RD[k]: v for k, v in tok2act.items()}

    print("Some examples:")
    for i in range(5):
        print(
            f"{xsm[i]}\n ->{ysm[i]}\n -> {Node.from_transitions(' '.join(ysm[i].split()[1:]), _tok2act)}"
        )

    print("Non-leaf tokens:")
    print({ysm.RD[k]: v for k, v in tok2act.items() if v > 0})

    devstart = teststart - 500
    trainds = torch.utils.data.TensorDataset(
        torch.tensor(xsm.matrix[:devstart]).long(),
        torch.tensor(ysm.matrix[:devstart, :-1]).long(),
        torch.tensor(ysm.matrix[:devstart, 1:]).long())
    valds = torch.utils.data.TensorDataset(
        torch.tensor(xsm.matrix[devstart:teststart]).long(),
        torch.tensor(ysm.matrix[devstart:teststart, :-1]).long(),
        torch.tensor(ysm.matrix[devstart:teststart, 1:]).long())
    testds = torch.utils.data.TensorDataset(
        torch.tensor(xsm.matrix[teststart:]).long(),
        torch.tensor(ysm.matrix[teststart:, :-1]).long(),
        torch.tensor(ysm.matrix[teststart:, 1:]).long())
    tt.msg(
        f"Data splits: train: {len(trainds)}, valid: {len(valds)}, test: {len(testds)}"
    )

    tloader = torch.utils.data.DataLoader(trainds,
                                          batch_size=batsize,
                                          shuffle=True)
    vloader = torch.utils.data.DataLoader(valds,
                                          batch_size=evalbatsize,
                                          shuffle=False)
    xloader = torch.utils.data.DataLoader(testds,
                                          batch_size=evalbatsize,
                                          shuffle=False)
    tt.tock("data loaded")

    # model
    enclayers, declayers = numlayers, numlayers
    decdim = encdim
    xemb = q.WordEmb(embdim, worddic=xsm.D)
    yemb = q.WordEmb(embdim, worddic=ysm.D)
    encdims = [embdim] + [encdim // 2] * enclayers
    xenc = q.LSTMEncoder(embdim,
                         *encdims[1:],
                         bidir=True,
                         dropout_in_shared=dropout)
    decdims = [embdim] + [decdim] * declayers
    dec_core = torch.nn.Sequential(*[
        q.LSTMCell(decdims[i - 1],
                   decdims[i],
                   dropout_in=dropout,
                   dropout_rec=dropout) for i in range(1, len(decdims))
    ])
    yout = q.WordLinout(encdim + decdim, worddic=ysm.D)
    dec_cell = semparse.rnn.LuongCell(emb=yemb,
                                      core=dec_core,
                                      out=yout,
                                      dropout=dropout)
    decoder = q.TFDecoder(dec_cell)
    testdecoder = q.FreeDecoder(dec_cell, maxtime=100)

    m = Seq2Seq(xemb, xenc, decoder)
    testm = Seq2Seq(xemb, xenc, testdecoder, test=True)

    # test model
    tt.tick("running a batch")
    test_y = m(*next(iter(tloader))[:-1])
    q.batch_reset(m)
    test_y = testm(*next(iter(vloader))[:-1])
    q.batch_reset(m)
    tt.tock(f"ran a batch: {test_y.size()}")

    optim = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wreg)
    tlosses = [
        q.CELoss(mode="logits", ignore_index=0),
        q.Accuracy(ignore_index=0),
        q.SeqAccuracy(ignore_index=0)
    ]
    xlosses = [
        q.CELoss(mode="logits", ignore_index=0),
        q.Accuracy(ignore_index=0),
        q.SeqAccuracy(ignore_index=0)
    ]
    tlosses = [q.LossWrapper(l) for l in tlosses]
    vlosses = [q.LossWrapper(l) for l in xlosses]
    xlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=m,
                        dataloader=tloader,
                        optim=optim,
                        losses=tlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=testm,
                      dataloader=vloader,
                      losses=vlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=testm,
                       dataloader=xloader,
                       losses=xlosses,
                       device=device)

    lrplateau = q.util.ReduceLROnPlateau(optim,
                                         mode="max",
                                         factor=.1,
                                         patience=3,
                                         cooldown=1,
                                         warmup=warmup,
                                         threshold=0.,
                                         verbose=True,
                                         eps=1e-9)
    on_after_valid = [lambda: lrplateau.step(vlosses[1].get_epoch_error())]
    _devloop = partial(devloop, on_end=on_after_valid)
    stoptrain = [lambda: all([pg["lr"] <= 1e-7 for pg in optim.param_groups])]

    tt.tick("training")
    q.run_training(trainloop,
                   _devloop,
                   max_epochs=epochs,
                   check_stop=stoptrain)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    settings["testres"] = testres
    tt.tock("tested")

    devres = devloop()
    print(devres, vlosses[0].get_epoch_error())

    return vlosses[1].get_epoch_error()
Example No. 10
def run_span_borders(
    lr=DEFAULT_LR,
    dropout=.3,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    evalbatsize=-1,
    epochs=DEFAULT_EPOCHS,
    smoothing=DEFAULT_SMOOTHING,
    dim=200,
    numlayers=1,
    cuda=False,
    gpu=0,
    savep="exp_bilstm_span_borders_",
    datafrac=1.,
    vanillaemb=False,
    embdim=300,
    sched="cos",
    warmup=0.1,
    cycles=0.5,
):
    settings = locals().copy()
    print(locals())
    if evalbatsize < 0:
        evalbatsize = batsize
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running span border with BiLSTM")
    tt.tick("loading data")
    data = load_data(which="span/borders", datafrac=datafrac)
    trainds, devds, testds = data
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    evalds = TensorDataset(*testloader.dataset.tensors[:-1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:-1])
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("creating model")
    # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert = BertModel.from_pretrained("bert-base-uncased")
    emb = bert.embeddings.word_embeddings
    if vanillaemb:
        tt.msg("using vanilla emb of size {}".format(embdim))
        emb = torch.nn.Embedding(emb.weight.size(0), embdim)
    else:
        embdim = bert.config.hidden_size
    # inpD = tokenizer.vocab
    # q.WordEmb.masktoken = "[PAD]"
    # emb = q.WordEmb(embdim, worddic=inpD)
    bilstm = q.rnn.LSTMEncoder(embdim,
                               *([dim] * numlayers),
                               bidir=True,
                               dropout_in_shared=dropout)
    spandet = BorderSpanDetector(emb, bilstm, dim * 2, dropout=dropout)
    spandet.to(device)
    tt.tock("model created")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    params = spandet.parameters()
    sched = get_schedule(sched,
                         warmup=warmup,
                         t_total=totalsteps,
                         cycles=cycles)
    optim = BertAdam(params, lr=lr, weight_decay=wreg, schedule=sched)
    # optim = torch.optim.Adam(spandet.parameters(), lr=lr, weight_decay=wreg)
    losses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(),
        q.SeqAccuracy()
    ]
    xlosses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(),
        q.SeqAccuracy()
    ]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=spandet,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=spandet,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        os.mkdir(savep + str(i))
        savedir = savep + str(i)
        # save model
        # torch.save(spandet, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings
        json.dump(settings, open(os.path.join(savedir, "settings.json"), "w"))

        outlen = trainloader.dataset.tensors[0].size(1)
        spandet.outlen = outlen

        # save test predictions
        testpreds = q.eval_loop(spandet, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.test.npy"), testpreds)
        # save dev predictions
        testpreds = q.eval_loop(spandet, evalloader_dev, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.dev.npy"), testpreds)
        tt.msg("saved in {}".format(savedir))
        tt.tock("done")
Example No. 11
def load_data(p="../../data/buboqa/data/bertified_dataset.npz",
              which="span/io",
              retrelD=False,
              retrelcounts=False,
              rettokD=False,
              datafrac=1.,
              wordlevel=False):
    """
    :param p:       where the stored matrices are
    :param which:   which data to include in output datasets
                        "span/io": O/I annotated spans,
                        "span/borders": begin and end positions of span
                        "rel+io": what relation (also gives "spanio" outputs to give info where entity is supposed to be (to ignore it))
                        "rel+borders": same, but gives "spanborders" instead
                        "all": everything
    :return:
    """
    tt = q.ticktock("dataloader")
    if wordlevel:
        tt.tick("loading original data word-level stringmatrix")
        wordmat, wordD, (word_devstart, word_teststart) = load_word_mat()
        twordmat = wordmat[:word_devstart]
        vwordmat = wordmat[word_devstart:word_teststart]
        xwordmat = wordmat[word_teststart:]
        tt.tock("loaded stringmatrix")
    tt.tick("loading saved np mats")
    data = np.load(p)
    print(data.keys())
    relD = data["relD"].item()
    revrelD = {v: k for k, v in relD.items()}
    devstart = data["devstart"]
    teststart = data["teststart"]
    if wordlevel:
        assert (devstart == word_devstart)
        assert (teststart == word_teststart)
    tt.tock("mats loaded")
    tt.tick("loading BERT tokenizer")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tt.tock("done")

    if wordlevel:
        tokD = wordD
    else:
        tokD = tokenizer.vocab

    def pp(i):
        tokrow = data["tokmat"][i]
        iorow = [xe - 1 for xe in data["iomat"][i] if xe != 0]
        ioborderrow = data["ioborders"][i]
        rel_i = data["rels"][i]
        tokrow = tokenizer.convert_ids_to_tokens(
            [tok for tok in tokrow if tok != 0])
        # print(" ".join(tokrow))
        print(tabulate([range(len(tokrow)), tokrow, iorow]))
        print(ioborderrow)
        print(revrelD[rel_i])

    # tt.tick("printing some examples")
    # for k in range(10):
    #     print("\nExample {}".format(k))
    #     pp(k)
    # tt.tock("printed some examples")

    # datasets
    tt.tick("making datasets")
    if which == "span/io":
        selection = ["tokmat", "iomat"]
    elif which == "span/borders":
        selection = ["tokmat", "ioborders"]
    elif which == "rel+io":
        selection = ["tokmat", "iomat", "rels"]
    elif which == "rel+borders":
        selection = ["tokmat", "ioborders", "rels"]
    elif which == "all":
        selection = ["tokmat", "iomat", "ioborders", "rels"]
    else:
        raise q.SumTingWongException("unknown which mode: {}".format(which))

    if wordlevel:
        tokmat = wordmat
    else:
        tokmat = data["tokmat"]

    selected = [
        torch.tensor(data[sel] if sel != "tokmat" else tokmat).long()
        for sel in selection
    ]
    tselected = [sel[:devstart] for sel in selected]
    vselected = [sel[devstart:teststart] for sel in selected]
    xselected = [sel[teststart:] for sel in selected]

    if datafrac <= 1.:
        # restrict data such that least relations are unseen
        # get relation counts
        trainrels = data["rels"][:devstart]
        uniquerels, relcounts = np.unique(data["rels"][:devstart],
                                          return_counts=True)
        relcountsD = dict(zip(uniquerels, relcounts))
        relcounter = dict(zip(uniquerels, [0] * len(uniquerels)))
        totalcap = int(datafrac * len(trainrels))
        capperrel = max(relcountsD.values())

        def numberexamplesincluded(capperrel_):
            numberexamplesforcap = np.clip(relcounts, 0, capperrel_).sum()
            return numberexamplesforcap

        while capperrel > 0:  # TODO do binary search
            numexcapped = numberexamplesincluded(capperrel)
            if numexcapped <= totalcap:
                break
            capperrel -= 1

        print("rel count cap is {}".format(capperrel))

        remainids = []
        for i in range(len(trainrels)):
            if len(remainids) >= totalcap:
                break
            if relcounter[trainrels[i]] > capperrel:
                pass
            else:
                relcounter[trainrels[i]] += 1
                remainids.append(i)
        print("{}/{} examples retained".format(len(remainids), len(trainrels)))
        tselected_new = [sel[remainids] for sel in tselected]
        if datafrac == 1.:
            for a, b in zip(tselected_new, tselected):
                assert (np.all(a == b))
        tselected = tselected_new

    traindata = TensorDataset(*tselected)
    devdata = TensorDataset(*vselected)
    testdata = TensorDataset(*xselected)

    ret = (traindata, devdata, testdata)
    if retrelD:
        ret += (relD, )
    if rettokD:
        ret += (tokD, )
    if retrelcounts:
        ret += (data["relcounts"], )
    tt.tock("made datasets")
    return ret
Example No. 12
def run_span_io(
        lr=DEFAULT_LR,
        dropout=.5,
        wreg=DEFAULT_WREG,
        batsize=DEFAULT_BATSIZE,
        epochs=DEFAULT_EPOCHS,
        cuda=False,
        gpu=0,
        balanced=False,
        warmup=-1.,
        sched="ang",  # "lin", "cos"
):
    settings = locals().copy()
    print(locals())
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running span io with BERT")
    tt.tick("loading data")
    data = load_data(which="span/io")
    trainds, devds, testds = data
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)
    # compute balancing hyperparam for BCELoss
    trainios = trainds.tensors[1]
    numberpos = (trainios == 2).float().sum()
    numberneg = (trainios == 1).float().sum()
    if balanced:
        pos_weight = (numberneg / numberpos)
    else:
        pos_weight = None
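    # a pos_weight > 1 upweights the rarer positive IO positions in the loss below
    # (assumption: AutomaskedBCELoss forwards pos_weight to a BCE-with-logits style criterion)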
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    spandet = IOSpanDetector(bert, dropout=dropout)
    spandet.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    optim = BertAdam(spandet.parameters(),
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=schedmap[sched])
    losses = [
        AutomaskedBCELoss(pos_weight=pos_weight),
        AutomaskedBinarySeqAccuracy()
    ]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in losses]
    testlosses = [q.LossWrapper(l) for l in losses]
    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=spandet,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=spandet,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")
Example No. 13
def run_both(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    evalbatsize=-1,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskmention=False,
    warmup=-1.,
    sched="ang",
    cycles=-1.,
    savep="exp_bert_both_",
    test=False,
    freezeemb=False,
    large=False,
    datafrac=1.,
    savemodel=False,
):
    settings = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    if evalbatsize < 0:
        evalbatsize = batsize
    tt.msg("running borders and rel classifier with BERT")
    if test:
        epochs = 0
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    if cycles == -1:
        if sched == "cos":
            cycles = 0.5
        elif sched in ["cosrestart", "coshardrestart"]:
            cycles = 1.0

    # region data
    tt.tick("loading data")
    data = load_data(which="forboth", retrelD=True, datafrac=datafrac)
    trainds, devds, testds, relD = data
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    evalds = TensorDataset(*testloader.dataset.tensors[:1])
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)
    if test:
        evalloader = DataLoader(TensorDataset(*evalloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
    print("number of relations: {}".format(len(relD)))
    # endregion

    # region model
    tt.tick("loading BERT")
    whichbert = "bert-base-uncased"
    if large:
        whichbert = "bert-large-uncased"
    bert = BertModel.from_pretrained(whichbert)
    m = BordersAndRelationClassifier(bert,
                                     relD,
                                     dropout=dropout,
                                     mask_entity_mention=maskmention)
    m.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    assert (initwreg == 0.)
    initl2penalty = InitL2Penalty(bert, factor=q.hyperparam(initwreg))

    params = []
    for paramname, param in m.named_parameters():
        if paramname.startswith("bert.embeddings.word_embeddings"):
            if not freezeemb:
                params.append(param)
        else:
            params.append(param)
    sched = get_schedule(sched,
                         warmup=warmup,
                         t_total=totalsteps,
                         cycles=cycles)
    optim = BertAdam(params, lr=lr, weight_decay=wreg, schedule=sched)
    tmodel = BordersAndRelationLosses(m, cesmoothing=smoothing)
    # xmodel = BordersAndRelationLosses(m, cesmoothing=smoothing)
    # losses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    # xlosses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    tlosses = [q.SelectedLinearLoss(i) for i in range(7)]
    xlosses = [q.SelectedLinearLoss(i) for i in range(7)]
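    # assumption: BordersAndRelationLosses returns 7 loss/metric values per batch and each
    # SelectedLinearLoss(i) selects the i-th one so it can be tracked as a separate metric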
    trainlosses = [q.LossWrapper(l) for l in tlosses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=tmodel,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=tmodel,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=tmodel,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    m.clip_len = True
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    settings["testres"] = testres
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        os.mkdir(savep + str(i))
        savedir = savep + str(i)
        print(savedir)
        # save model
        if savemodel:
            torch.save(m, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings
        json.dump(settings, open(os.path.join(savedir, "settings.json"), "w"))
        # save relation dictionary
        # json.dump(relD, open(os.path.join(savedir, "relD.json"), "w"))
        # save test predictions
        m.clip_len = False
        # TEST data
        testpreds = q.eval_loop(m, evalloader, device=device)
        borderpreds = testpreds[0].cpu().detach().numpy()
        relpreds = testpreds[1].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.test.npy"), borderpreds)
        np.save(os.path.join(savedir, "relpreds.test.npy"), relpreds)
        # DEV data
        testpreds = q.eval_loop(m, evalloader_dev, device=device)
        borderpreds = testpreds[0].cpu().detach().numpy()
        relpreds = testpreds[1].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.dev.npy"), borderpreds)
        np.save(os.path.join(savedir, "relpreds.dev.npy"), relpreds)
        # save bert-tokenized questions
        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # with open(os.path.join(savedir, "testquestions.txt"), "w") as f:
        #     for batch in evalloader:
        #         ques, io = batch
        #         ques = ques.numpy()
        #         for question in ques:
        #             qstr = " ".join([x for x in tokenizer.convert_ids_to_tokens(question) if x != "[PAD]"])
        #             f.write(qstr + "\n")

        tt.tock("done")
Example No. 14
def run_relations(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskentity=False,
    warmup=-1.,
    sched="ang",
    savep="exp_bert_rels_",
    test=False,
    freezeemb=False,
):
    settings = locals().copy()
    if test:
        epochs = 0
    print(locals())
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running relation classifier with BERT")
    tt.tick("loading data")
    data = load_data(which="rel+borders", retrelD=True)
    trainds, devds, testds, relD = data
    if maskentity:
        trainds, devds, testds = replace_entity_span(trainds, devds, testds)
    else:
        trainds, devds, testds = [
            TensorDataset(ds.tensors[0], ds.tensors[2])
            for ds in [trainds, devds, testds]
        ]
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)
    evalds = TensorDataset(*testloader.dataset.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=batsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:1])
    evalloader_dev = DataLoader(evalds_dev, batch_size=batsize, shuffle=False)
    if test:
        evalloader = DataLoader(TensorDataset(*evalloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    m = RelationClassifier(bert, relD, dropout=dropout)
    m.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs

    params = []
    for paramname, param in m.named_parameters():
        if paramname.startswith("bert.embeddings.word_embeddings"):
            if not freezeemb:
                params.append(param)
        else:
            params.append(param)
    optim = BertAdam(params,
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=schedmap[sched],
                     init_weight_decay=initwreg)
    losses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    xlosses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=m,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=m,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=m,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        os.mkdir(savep + str(i))
        savedir = savep + str(i)
        # save model
        # torch.save(m, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings
        json.dump(settings, open(os.path.join(savedir, "settings.json"), "w"))
        # save relation dictionary
        # json.dump(relD, open(os.path.join(savedir, "relD.json"), "w"))
        # save test predictions
        testpreds = q.eval_loop(m, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.test.npy"), testpreds)
        testpreds = q.eval_loop(m, evalloader_dev, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.dev.npy"), testpreds)
        # save bert-tokenized questions
        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # with open(os.path.join(savedir, "testquestions.txt"), "w") as f:
        #     for batch in evalloader:
        #         ques, io = batch
        #         ques = ques.numpy()
        #         for question in ques:
        #             qstr = " ".join([x for x in tokenizer.convert_ids_to_tokens(question) if x != "[PAD]"])
        #             f.write(qstr + "\n")

        tt.tock("done")
Example No. 15
def __init__(
    self,
    name="",
):
    self.tt = q.ticktock(name)
Example No. 16
def run(
    lr=0.001,
    batsize=20,
    epochs=100,
    embdim=64,
    encdim=128,
    numlayers=1,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    beamsize=1,
    cosine_restarts=1.,
    seed=456789,
):
    # DONE: Porter stemmer
    # DONE: linear attention
    # DONE: grad norm
    # DONE: beam search
    # DONE: lr scheduler
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    stemmer = PorterStemmer()
    tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    ds = GeoQueryDatasetFunQL(
        sentence_encoder=SequenceEncoder(tokenizer=tokenizer),
        min_freq=minfreq)

    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = create_model(embdim=embdim,
                         hdim=encdim,
                         dropout=dropout,
                         numlayers=numlayers,
                         sentence_encoder=ds.sentence_encoder,
                         query_encoder=ds.query_encoder,
                         feedatt=True)

    # model.apply(initializer)

    tfdecoder = SeqDecoder(
        model,
        tf_ratio=1.,
        eval=[
            CELoss(ignore_index=0, mode="logprobs"),
            SeqAccuracies(),
            TreeAccuracy(
                tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
        ])

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    if beamsize == 1:
        freedecoder = SeqDecoder(
            model,
            maxtime=100,
            tf_ratio=0.,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(
                    tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
            ])

        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    else:
        print("Doing beam search!")
        freedecoder = BeamDecoder(
            model,
            beamsize=beamsize,
            maxtime=60,
            eval=[
                SeqAccuracies(),
                TreeAccuracy(
                    tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab))
            ])

        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max} ({epochs} * {len(train_dl)})")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=tfdecoder, dataloader=test_dl, losses=vlosses, device=device)

    # 8. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")
Example No. 17
def run(
        lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        distill="glove",  # "rnnlm", "glove"
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False,
        repretrain=False,  # retrain base model instead of loading it
        savepath="rnnlm.base.pt",  # where to save after training
        glovepath="../../../data/glove/glove.300d"):
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    # region base training
    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    if os.path.exists(savepath) and repretrain is False:
        tt.tick("reloading base model")
        with open(savepath, "rb") as f:
            m = torch.load(f)
            m.to(device)
        tt.tock("reloaded base model")
    else:
        tt.tick("preparing training base")
        dims = [embdim] + ([encdim] * numlayers)

        m = RNNLayer_LM(*dims,
                        worddic=D,
                        dropout=dropout,
                        tieweights=tieweights).to(device)

        if test:
            for i, batch in enumerate(train_batches):
                y = m(batch[0])
                if i > 5:
                    break
            print(y.size())

        optim = torch.optim.SGD(m.parameters(), lr=lr)

        train_batch_f = partial(q.train_batch,
                                on_before_optim_step=[
                                    lambda: torch.nn.utils.clip_grad_norm_(
                                        m.parameters(), gradnorm)
                                ])
        lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                         mode="min",
                                                         factor=1 / 4,
                                                         patience=0,
                                                         verbose=True)
        lrp_f = lambda: lrp.step(validloss.get_epoch_error())

        train_epoch_f = partial(q.train_epoch,
                                model=m,
                                dataloader=train_batches,
                                optim=optim,
                                losses=[loss],
                                device=device,
                                _train_batch=train_batch_f)
        valid_epoch_f = partial(q.test_epoch,
                                model=m,
                                dataloader=valid_batches,
                                losses=validlosses,
                                device=device,
                                on_end=[lrp_f])

        tt.tock("prepared training base")
        tt.tick("training base model")
        q.run_training(train_epoch_f,
                       valid_epoch_f,
                       max_epochs=epochs,
                       validinter=1)
        tt.tock("trained base model")

        with open(savepath, "wb") as f:
            torch.save(m, f)

    tt.tick("testing base model")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested base model")
    # endregion

    # region distillation
    tt.tick("preparing training student")
    dims = [embdim] + ([encdim] * numlayers)
    ms = RNNLayer_LM(*dims, worddic=D, dropout=dropout,
                     tieweights=tieweights).to(device)

    loss = q.LossWrapper(q.DistillLoss(temperature=2.))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    optim = torch.optim.SGD(ms.parameters(), lr=lr)

    train_batch_f = partial(
        train_batch_distill,
        on_before_optim_step=[
            lambda: torch.nn.utils.clip_grad_norm_(ms.parameters(), gradnorm)
        ])
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=1 / 4,
                                                     patience=0,
                                                     verbose=True)
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    if distill == "rnnlm":
        mbase = m
        goldgetter = None
    elif distill == "glove":
        mbase = None
        tt.tick("creating gold getter based on glove")
        goldgetter = GloveGoldGetter(glovepath, worddic=D)
        goldgetter.to(device)
        tt.tock("created gold getter")
    else:
        raise q.SumTingWongException("unknown distill mode {}".format(distill))
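    # assumption: with distill == "rnnlm" the trained base LM supplies the soft targets, while
    # GloveGoldGetter derives them from GloVe vectors; train_epoch_distill below receives
    # whichever of mbase/goldgetter is set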

    train_epoch_f = partial(train_epoch_distill,
                            model=ms,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f,
                            mbase=mbase,
                            goldgetter=goldgetter)
    valid_epoch_f = partial(q.test_epoch,
                            model=ms,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("prepared training student")
    tt.tick("training student model")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained student model")

    tt.tick("testing student model")
    testresults = q.test_epoch(model=ms,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested student model")
Example No. 18
def run_normal(lr=0.001,
               gradclip=5.,
               batsize=20,
               epochs=150,
               embdim=100,
               encdim=200,
               numlayer=1,
               cuda=False,
               gpu=0,
               wreg=1e-8,
               dropout=0.5,
               smoothing=0.,
               goldsmoothing=-0.1,
               selfptr=False,
               which="geo"):
    tt = q.ticktock("script")
    tt.msg("running normal att")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)

    # region data
    tt.tick("generating data")
    # dss, D = gen_sort_data(seqlen=seqlen, numvoc=numvoc, numex=numex, prepend_inp=False)
    dss, nlD, flD, rare_nl, rare_fl = gen_datasets(which=which)
    tloader, vloader, xloader = [torch.utils.data.DataLoader(ds, batch_size=batsize, shuffle=True) for ds in dss]
    seqlen = len(dss[0][0][1])
    # merge nlD into flD and make mapper
    nextflDid = max(flD.values()) + 1
    sourcemap = torch.zeros(len(nlD), dtype=torch.long, device=device)
    for k, v in nlD.items():
        if k not in flD:
            flD[k] = nextflDid
            nextflDid += 1
        sourcemap[v] = flD[k]
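    # sourcemap[nl_id] now gives the id of the same token string in the merged flD vocabulary;
    # the pointer-generator output below uses it to map copied source tokens to target ids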
    tt.tock("data generated")
    # endregion

    # region model
    tt.tick("building model")
    # source side
    inpemb = q.UnkReplWordEmb(embdim, worddic=nlD, unk_tokens=rare_nl)
    encdims = [encdim] * numlayer
    encoder = q.LSTMEncoder(embdim, *encdims, bidir=True, dropout_in_shared=dropout)

    # target side
    decemb = q.UnkReplWordEmb(embdim, worddic=flD, unk_tokens=rare_fl)
    decinpdim = embdim
    decdims = [decinpdim] + [encdim] * numlayer
    dec_core = torch.nn.Sequential(
        *[q.rnn.LSTMCell(decdims[i-1], decdims[i], dropout_in=dropout) for i in range(1, len(decdims))]
    )
    att = attention.FwdAttention(decdims[-1], encdim * 2, decdims[-1])
    out = torch.nn.Sequential(
        q.UnkReplWordLinout(decdims[-1]+encdim*2, worddic=flD, unk_tokens=rare_fl),
        # torch.nn.Softmax(-1)
    )
    if selfptr:
        outgate = PointerGeneratorOutGate(decdims[-1] + encdim * 2, encdim, 3)
        out = SelfPointerGeneratorOut(out, sourcemap=sourcemap, gate=outgate)
        selfatt = attention.FwdAttention(decdims[-1], decdims[-1], decdims[-1])
        deccell = SelfPointerGeneratorCell(emb=decemb, core=dec_core, att=att, selfatt=selfatt, out=out)
    else:
        outgate = PointerGeneratorOutGate(decdims[-1] + encdim * 2, encdim, 0)
        out = PointerGeneratorOut(out, sourcemap=sourcemap, gate=outgate)
        deccell = PointerGeneratorCell(emb=decemb, core=dec_core, att=att, out=out)
    train_dec = q.TFDecoder(deccell)
    test_dec = q.FreeDecoder(deccell, maxtime=seqlen+10)
    train_encdec = EncDec(inpemb, encoder, train_dec)
    test_encdec = Test_EncDec(inpemb, encoder, test_dec)

    train_encdec.to(device)
    test_encdec.to(device)
    tt.tock("built model")
    # endregion

    # region training
    # losses:
    if smoothing == 0:
        ce = q.loss.CELoss(mode="probs", ignore_index=0)
    elif goldsmoothing < 0.:
        ce = q.loss.SmoothedCELoss(mode="probs", ignore_index=0, smoothing=smoothing)
    else:
        ce = q.loss.DiffSmoothedCELoss(mode="probs", ignore_index=0, alpha=goldsmoothing, beta=smoothing)
    acc = q.loss.SeqAccuracy(ignore_index=0)
    elemacc = q.loss.SeqElemAccuracy(ignore_index=0)
    trainmodel = TrainModel(train_encdec, [ce, elemacc, acc])
    treeacc = TreeAccuracyPrologPar(flD=flD)
    # optim
    optim = torch.optim.Adam(train_encdec.parameters(), lr=lr, weight_decay=wreg)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_value_(train_encdec.parameters(), clip_value=gradclip)
    # lööps
    batchloop = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainloop = partial(q.train_epoch, model=train_encdec, dataloader=tloader, optim=optim, device=device,
                        losses=[q.LossWrapper(ce), q.LossWrapper(elemacc), q.LossWrapper(acc)],
                        print_every_batch=False, _train_batch=batchloop)
    validloop = partial(q.test_epoch, model=test_encdec, dataloader=vloader, device=device,
                        losses=[q.LossWrapper(treeacc)],
                        print_every_batch=False)

    tt.tick("training")
    q.run_training(trainloop, validloop, max_epochs=epochs)
    tt.tock("trained")

    tt.tick("testing")
    test_results = validloop(model=test_encdec, dataloader=xloader)
    print("Test results (freerunning): {}".format(test_results))
    test_results = validloop(model=train_encdec, dataloader=xloader)
    print("Test results (TF): {}".format(test_results))
    tt.tock("tested")
    # endregion
    tt.msg("done")
Example No. 19
def run(
        sourcelang="en",
        supportlang="en",
        testlang="en",
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        cosinelr=False,
        warmup=0.,
        batsize=20,
        epochs=100,
        dropout=0.1,
        dropoutdec=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="xlm-roberta-base",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        trainonvalid=False,
        statesimweight=0.,
        probsimweight=0.,
        projmode="simple",  # "simple" or "twolayer"
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    # wandb.init(project=f"overnight_pretrain_bert-{domain}",
    #            reinit=True, config=settings)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")

    nltok_name = encoder
    tds, vds, xds, nltok, flenc = load_multilingual_geoquery(
        sourcelang,
        supportlang,
        testlang,
        nltok_name=nltok_name,
        trainonvalid=trainonvalid)
    tt.msg(
        f"{len(tds)/(len(tds) + len(vds) + len(xds)):.2f}/{len(vds)/(len(tds) + len(vds) + len(xds)):.2f}/{len(xds)/(len(tds) + len(vds) + len(xds)):.2f} ({len(tds)}/{len(vds)}/{len(xds)}) train/valid/test"
    )
    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=partial(collate_fn,
                                        pad_value_nl=nltok.pad_token_id))
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(
        encoder_name=encoder,
        dec_vocabsize=flenc.vocab.number_of_ids(),
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        dropoutdec=dropoutdec,
        smoothing=smoothing,
        maxlen=maxlen,
        numbeam=numbeam,
        tensor2tree=partial(_tensor2tree, D=flenc.vocab),
        statesimweight=statesimweight,
        probsimweight=probsimweight,
        projmode=projmode,
    )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params
                            if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [
        v for k, v in trainable_params
        if k.startswith("model.model.encoder.model")
    ]
    otherparams = [
        v for k, v in trainable_params
        if not k.startswith("model.model.encoder.model")
    ]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{
        "params": encparams,
        "lr": lr * enclrmul
    }, {
        "params": otherparams
    }]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda:
                         (deepcopy(trainm.model), deepcopy(trainm.model2)))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["_train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["_valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()
                                 ])  #, on_end=[lambda: wandb_logger()])

    # validepoch()        # TODO comment out after debugging
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.remembered is not None:
        trainm.model = eyt.remembered[0]
        trainm.model2 = eyt.remembered[1]
        testm.model = eyt.remembered[0]
        testm.model2 = eyt.remembered[1]
    tt.msg("reloaded best")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model2
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # wandb.config.update(settings)
    # print(settings)
    return settings
Example No. 20
def run(domain="restaurants",
        lr=0.001,
        ptlr=0.0001,
        enclrmul=0.1,
        cosinelr=False,
        ptcosinelr=False,
        warmup=0.,
        ptwarmup=0.,
        batsize=20,
        ptbatsize=50,
        epochs=100,
        ptepochs=100,
        dropout=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        dataseed=12345678,
        datatemp=0.33,
        ptN=3000,
        tokenmaskp=0.,
        spanmaskp=0.,
        spanmasklamda=2.2,
        treemaskp=0.,
        encoder="bart-large",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        ):
    settings = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_ds(domain=domain, nl_mode=encoder)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    tt.tock("data loaded")

    tt.tick("creating grammar dataset generator")
    pcfg = build_grammar(tds, vds)
    ptds = PCFGDataset(pcfg, N=ptN, seed=seed, temperature=datatemp, maxlen=100)
    tt.tock("created dataset generator")

    tt.tick("creating model")
    trainm, testm, pretrainm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab)
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        print("generated dataset")
        print(ptds[0])
        print(ptds[0])
        allexamples = []
        for i in tqdm(range(len(ptds))):
            allexamples.append(ptds[i])
        uniqueexamples = set([str(x) for x in allexamples])
        print(f"{100*len(uniqueexamples)/len(allexamples)}% unique examples ({len(uniqueexamples)}/{len(allexamples)})")
        ptds.advance_seed()
        print(ptds[0])
        allexamples = list(ptds.examples)
        uniqueexamples2 = set([str(x) for x in allexamples])
        print(f"{100*len(uniqueexamples2)/len(allexamples)}% unique examples ({len(uniqueexamples2)}/{len(allexamples)})")
        print(f"{len(uniqueexamples & uniqueexamples2)}/{len(uniqueexamples | uniqueexamples2)} overlap")
        print("---")
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretraining
    # setup data perturbation
    tokenmasker = TokenMasker(p=tokenmaskp, seed=dataseed) if tokenmaskp > 0 else lambda x: x
    spanmasker = SpanMasker(p=spanmaskp, lamda=spanmasklamda, seed=dataseed) if spanmaskp > 0 else lambda x: x
    treemasker = SubtreeMasker(p=treemaskp, seed=dataseed) if treemaskp > 0 else lambda x: x

    perturbed_ptds = ptds\
        .map(lambda x: (treemasker(x), x))\
        .map(lambda x: (flenc.convert(x[0], "tokens"),
                        flenc.convert(x[1], "tokens")))\
        .map(lambda x: (spanmasker(tokenmasker(x[0])), x[1]))
    perturbed_ptds_tokens = perturbed_ptds
    perturbed_ptds = perturbed_ptds\
        .map(lambda x: (flenc.convert(x[0], "tensor"),
                        flenc.convert(x[1], "tensor")))
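    # net effect: every generated tree becomes a (perturbed, original) pair -- subtree masking on the
    # tree, then token- and span-masking on the token sequence, finally both sides tensorized so that
    # pretrainm can presumably be trained to reconstruct the original from the perturbed input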

    if localtest:
        allex = []
        allperturbedex = []
        _nepo = 50
        print(f"checking {_nepo}, each {ptN} generated examples")
        for _e in tqdm(range(_nepo)):
            for i in range(len(perturbed_ptds_tokens)):
                ex = str(ptds[i])
                perturbed_ex = perturbed_ptds_tokens[i]
                perturbed_ex = f"{' '.join(perturbed_ex[0])}->{' '.join(perturbed_ex[1])}"
                allex.append(ex)
                allperturbedex.append(perturbed_ex)
            ptds.advance_seed()
        uniqueex = set(allex)
        uniqueperturbedex = set(allperturbedex)
        print(f"{len(uniqueex)}/{len(allex)} unique examples")
        print(f"{len(uniqueperturbedex)}/{len(allperturbedex)} unique perturbed examples")

    ptdl = DataLoader(perturbed_ptds, batch_size=ptbatsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    ptmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    ptparams = pretrainm.parameters()
    ptoptim = torch.optim.Adam(ptparams, lr=ptlr, weight_decay=wreg)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)
    t_max = ptepochs
    print(f"Total number of pretraining updates: {t_max} .")
    if ptcosinelr:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> q.sched.Cosine(steps=t_max-ptwarmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ptoptim, lr_schedule)

    pttrainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    pttrainepoch = partial(q.train_epoch, model=pretrainm, dataloader=ptdl, optim=ptoptim, losses=ptmetrics,
                         _train_batch=pttrainbatch, device=device, on_end=[lambda: lr_schedule.step(),
                                                                           lambda: ptds.advance_seed()])

    tt.tick("pretraining")
    q.run_training(run_train_epoch=pttrainepoch, max_epochs=ptepochs)
    tt.tock("done pretraining")

    # endregion

    # region finetuning
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": lr * enclrmul},
                   {"params": otherparams}]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=10, more_is_better=True, remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step(), lambda: eyt.on_epoch_end()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device)

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs, check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    tt.tick("testing")
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=xmetrics, device=device)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(input_ids, attention_mask=input_ids != predm.config.pad_token_id,
                                      max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(torch.device("cpu"))) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        print(f"seq acc: {c/t}")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Example No. 21
def run(trainp="overnight/calendar_train_delex.tsv",
        testp="overnight/calendar_test_delex.tsv",
        batsize=8,
        embdim=50,
        encdim=50,
        maxtime=100,
        lr=.001,
        gpu=0,
        cuda=False,
        epochs=20):
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")
    tt = q.ticktock("script")
    tt.tick("loading data")

    def tokenizer(x: str, splitter: WordSplitter = None) -> List[str]:
        return [xe.text for xe in splitter.split_words(x)]

    reader = OvernightReader(
        partial(tokenizer, splitter=JustSpacesWordSplitter()),
        partial(tokenizer, splitter=JustSpacesWordSplitter()),
        SingleIdTokenIndexer(namespace="nl_tokens"),
        SingleIdTokenIndexer(namespace="fl_tokens"))
    trainds = reader.read(trainp)
    testds = reader.read(testp)
    tt.tock("data loaded")

    tt.tick("building vocabulary")
    vocab = Vocabulary.from_instances(trainds)
    tt.tock("vocabulary built")

    tt.tick("making iterator")
    iterator = BucketIterator(sorting_keys=[("nl", "num_tokens"),
                                            ("fl", "num_tokens")],
                              batch_size=batsize,
                              biggest_batch_first=True)
    iterator.index_with(vocab)
    batch = next(iter(iterator(trainds)))
    #print(batch["id"])
    #print(batch["nl"])
    tt.tock("made iterator")

    # region model
    nl_emb = Embedding(vocab.get_vocab_size(namespace="nl_tokens"),
                       embdim,
                       padding_index=0)
    fl_emb = Embedding(vocab.get_vocab_size(namespace="fl_tokens"),
                       embdim,
                       padding_index=0)
    nl_field_emb = BasicTextFieldEmbedder({"tokens": nl_emb})
    fl_field_emb = BasicTextFieldEmbedder({"tokens": fl_emb})

    encoder = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(embdim, encdim, bidirectional=True, batch_first=True))
    attention = DotProductAttention()

    smodel = Seq2Seq(vocab,
                     nl_field_emb,
                     encoder,
                     maxtime,
                     target_embedding_dim=embdim,
                     attention=attention,
                     target_namespace='fl_tokens',
                     beam_size=1,
                     use_bleu=True)

    smodel_out = smodel(batch["nl"], batch["fl"])

    smodel.to(device)

    optim = torch.optim.Adam(smodel.parameters(), lr=lr)
    trainer = Trainer(model=smodel,
                      optimizer=optim,
                      iterator=iterator,
                      train_dataset=trainds,
                      validation_dataset=testds,
                      num_epochs=epochs,
                      cuda_device=gpu if cuda else -1)

    metrics = trainer.train()

    sys.exit()
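    # note: everything below this point is unreachable because of the sys.exit() above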

    class MModel(Model):
        def __init__(self, nlemb: Embedding, flemb: Embedding,
                     vocab: Vocabulary, **kwargs):
            super(MModel, self).__init__(vocab, **kwargs)
            self.nlemb, self.flemb = nlemb, flemb

        @overrides
        def forward(self, nl: Dict[str, torch.Tensor],
                    fl: Dict[str, torch.Tensor], id: Any):
            nlemb = self.nlemb(nl["tokens"])
            flemb = self.flemb(fl["tokens"])
            print(nlemb.size())
            pass

    m = MModel(nl_emb, fl_emb, vocab)
    batch = next(iter(iterator(trainds)))
    out = m(**batch)
Example No. 22
def adv_train_epoch(model=None,
                    dataloader=None,
                    optim=None,
                    losses=None,
                    advmodel=None,
                    advdataloader=None,
                    advoptim=None,
                    advlosses=None,
                    device=torch.device("cpu"),
                    tt=q.ticktock(" -"),
                    current_epoch=0,
                    max_epochs=0,
                    _train_batch=q.train_batch,
                    _adv_train_batch=q.train_batch,
                    on_start=tuple(),
                    on_end=tuple(),
                    print_every_batch=False,
                    advsteps=1):
    """
    Performs an epoch of adversarial training on given model, with data from given dataloader, using given optimizer,
    with loss computed based on given losses.
    :param model:
    :param dataloader:
    :param optim:
    :param losses:  list of loss wrappers
    :param device:  device to put batches on
    :param tt:
    :param current_epoch:
    :param max_epochs:
    :param _train_batch:    train batch function, default is train_batch
    :param on_start:
    :param on_end:
    :return:
    """
    for loss in losses + advlosses:
        loss.push_epoch_to_history(epoch=current_epoch - 1)
        loss.reset_agg()
        loss.loss.to(device)

    model.to(device)
    advmodel.to(device)

    [e() for e in on_start]

    q.epoch_reset(model)
    q.epoch_reset(advmodel)

    for i, _batch in enumerate(dataloader):
        adviter = iter(advdataloader)
        for j in range(advsteps):
            try:
                _advbatch = next(adviter)
            except StopIteration as e:
                adviter = iter(advdataloader)
                _advbatch = next(adviter)
            ttmsg = _adv_train_batch(batch=_advbatch,
                                     model=advmodel,
                                     optim=advoptim,
                                     losses=advlosses,
                                     device=device,
                                     batch_number=j,
                                     max_batches=0,
                                     current_epoch=current_epoch,
                                     max_epochs=0)
            ttmsg = f"adv:  {ttmsg}"
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
        ttmsg = _train_batch(batch=_batch,
                             model=model,
                             optim=optim,
                             losses=losses,
                             device=device,
                             batch_number=i,
                             max_batches=len(dataloader),
                             current_epoch=current_epoch,
                             max_epochs=max_epochs)
        ttmsg = f"main: {ttmsg}"
        if print_every_batch:
            tt.msg(ttmsg)
        else:
            tt.live(ttmsg)

    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    advttmsg = q.pp_epoch_losses(*advlosses)
    ttmsg = f"\n main: {ttmsg}\n adv:  {advttmsg}"
    return ttmsg
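
A minimal usage sketch for adv_train_epoch above, following the functools.partial pattern used by the other examples. Everything below is a toy placeholder (and the q helpers are assumed to be imported as in the surrounding examples); it only illustrates how the main and adversarial sides are wired together, not a real experiment.

from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

# toy stand-ins for the main and the adversarial setup (all names are placeholders)
model = torch.nn.Linear(4, 2)
advmodel = torch.nn.Linear(4, 2)
trainloader = DataLoader(TensorDataset(torch.randn(40, 4), torch.randint(0, 2, (40,))), batch_size=8)
advloader = DataLoader(TensorDataset(torch.randn(40, 4), torch.randint(0, 2, (40,))), batch_size=8)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
advoptim = torch.optim.Adam(advmodel.parameters(), lr=1e-3)

trainepoch = partial(adv_train_epoch,
                     model=model, dataloader=trainloader, optim=optim,
                     losses=[q.LossWrapper(q.CELoss(mode="logits"))],
                     advmodel=advmodel, advdataloader=advloader, advoptim=advoptim,
                     advlosses=[q.LossWrapper(q.CELoss(mode="logits"))],
                     advsteps=2)
validepoch = partial(q.test_epoch, model=model, dataloader=trainloader,
                     losses=[q.LossWrapper(q.CELoss(mode="logits"))])
q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=3)
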
Ejemplo n.º 23
0
def run(
    lr=0.001,
    enclrmul=0.1,
    hdim=768,
    numlayers=8,
    numheads=12,
    dropout=0.1,
    wreg=0.,
    batsize=10,
    epochs=100,
    warmup=0,
    sustain=0,
    cosinelr=False,
    gradacc=1,
    gradnorm=100,
    patience=5,
    validinter=3,
    seed=87646464,
    gpu=-1,
    datamode="single",
    decodemode="single",  # "full", "ltr" (left to right), "single", "entropy-single"
    trainonvalid=False,
):
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds, vds, xds, tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        "restaurants", mode=datamode, trainonvalid=trainonvalid)
    tt.tock("loaded")

    tdl = DataLoader(tds,
                     batch_size=batsize,
                     shuffle=True,
                     collate_fn=collate_fn)
    vdl = DataLoader(vds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=collate_fn)
    xdl = DataLoader(xds,
                     batch_size=batsize,
                     shuffle=False,
                     collate_fn=collate_fn)

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim, flenc.vocab, numlayers, numheads, dropout)
    tagmodel = TreeInsertionTaggerModel(tagger)
    decodermodel = TreeInsertionDecoder(tagger,
                                        seqenc=flenc,
                                        maxsteps=50,
                                        max_tree_size=30,
                                        mode=decodemode)
    decodermodel = TreeInsertionDecoderTrainModel(decodermodel)

    # batch = next(iter(tdl))
    # out = tagmodel(*batch)

    tmetrics = make_array_of_metrics("loss",
                                     "elemrecall",
                                     "allrecall",
                                     "entropyrecall",
                                     reduction="mean")
    vmetrics = make_array_of_metrics("loss",
                                     "elemrecall",
                                     "allrecall",
                                     "entropyrecall",
                                     reduction="mean")
    tseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    vseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vseqmetrics[-1],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul, wreg)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=tagmodel,
                         dataloader=tdl,
                         optim=optim,
                         losses=tmetrics,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainseqepoch = partial(q.test_epoch,
                            model=decodermodel,
                            losses=tseqmetrics,
                            dataloader=tdl_seq,
                            device=device)

    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vseqmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    # validepoch()        # TODO: remove this after debugging

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=[trainseqepoch, validepoch],
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.msg("reloading best")
    if eyt.remembered is not None:
        decodermodel.model.tagger = eyt.remembered
        tagmodel.tagger = eyt.remembered

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decodermodel,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    print(testepoch())
    tt.tock()

    # inspect predictions
    validepoch = partial(q.test_epoch,
                         model=tagmodel,
                         losses=vmetrics,
                         dataloader=vdl,
                         device=device)
    print(validepoch())
    inps, outs = q.eval_loop(tagmodel, vdl, device=device)

    # print(outs)

    doexit = False
    for i in range(len(inps[0])):
        for j in range(len(inps[0][i])):
            ui = input("next? (ENTER for next/anything else to exit)>>>")
            if ui != "":
                doexit = True
                break
            question = " ".join(nltok.convert_ids_to_tokens(inps[0][i][j]))
            out_toks = flenc.vocab.tostr(
                inps[1][i][j].detach().cpu().numpy()).split(" ")

            iscorrect = True

            lines = []
            for k, out_tok in enumerate(out_toks):
                gold_toks_for_k = inps[3][i][j][k].detach().cpu().nonzero()[:,
                                                                            0]
                if len(gold_toks_for_k) > 0:
                    gold_toks_for_k = flenc.vocab.tostr(gold_toks_for_k).split(
                        " ")
                else:
                    gold_toks_for_k = [""]

                isopen = inps[2][i][j][k]
                isopen = isopen.detach().cpu().item()

                pred_tok = outs[1][i][j][k].max(-1)[1].detach().cpu().item()
                pred_tok = flenc.vocab(pred_tok)

                pred_tok_correct = pred_tok in gold_toks_for_k or not isopen
                if not pred_tok_correct:
                    iscorrect = False

                entropy = torch.softmax(outs[1][i][j][k], -1).clamp_min(1e-6)
                entropy = -(entropy * torch.log(entropy)).sum().item()
                lines.append(
                    f"{out_tok:25} [{isopen:1}] >> {f'{pred_tok} ({entropy:.3f})':35} {'!!' if not pred_tok_correct else '  '} [{','.join(gold_toks_for_k) if isopen else ''}]"
                )

            print(f"{question} {'!!WRONG!!' if not iscorrect else ''}")
            for line in lines:
                print(line)

        if doexit:
            break
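
The q.sched.Linear(steps=warmup) >> q.sched.Cosine(...) >> 0. composition above encodes linear warmup followed by cosine decay to zero, stepped once per epoch via the on_end hook. A rough stand-alone sketch of the same schedule shape using only PyTorch's LambdaLR (the warmup/total values and the model are illustrative, and this does not reproduce the q.sched API):

import math

import torch

def warmup_cosine(warmup, total):
    # multiplicative LR factor: linear 0 -> 1 over `warmup` steps, then cosine 1 -> 0
    def factor(step):
        if step < warmup:
            return step / max(1, warmup)
        progress = (step - warmup) / max(1, total - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return factor

model = torch.nn.Linear(4, 4)                     # placeholder model
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(optim, warmup_cosine(warmup=10, total=100))
# call sched.step() once per epoch, like the on_end=[lambda: lr_schedule.step()] hook above
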
Ejemplo n.º 24
0
def load_ds(dataset="scan/random",
            validfrac=0.1,
            recompute=False,
            bertname="bert-base-uncased"):
    tt = q.ticktock("data")
    tt.tick(f"loading '{dataset}'")
    if bertname.startswith("none"):
        bertname = "bert" + bertname[4:]
    if dataset.startswith("cfq/") or dataset.startswith("scan/mcd"):
        key = f"{dataset}|bertname={bertname}"
        print(f"validfrac is ineffective with dataset '{dataset}'")
    else:
        key = f"{dataset}|validfrac={validfrac}|bertname={bertname}"

    shelfname = os.path.basename(__file__) + ".cache.shelve"
    if not recompute:
        tt.tick(f"loading from shelf (key '{key}')")
        with shelve.open(shelfname) as shelf:
            if key not in shelf:
                recompute = True
                tt.tock("couldn't load from shelf")
            else:
                shelved = shelf[key]
                trainex, validex, testex, fldic = shelved["trainex"], shelved[
                    "validex"], shelved["testex"], shelved["fldic"]
                inpdic = shelved["inpdic"] if "inpdic" in shelved else None
                trainds, validds, testds = Dataset(trainex), Dataset(
                    validex), Dataset(testex)
                tt.tock("loaded from shelf")

    if recompute:
        tt.tick("loading data")
        splits = dataset.split("/")
        dataset, splits = splits[0], splits[1:]
        split = "/".join(splits)
        if dataset == "scan":
            ds = SCANDatasetLoader().load(split, validfrac=validfrac)
        elif dataset == "cfq":
            ds = CFQDatasetLoader().load(split + "/modent")
        else:
            raise Exception(f"Unknown dataset: '{dataset}'")
        tt.tock("loaded data")

        tt.tick("creating tokenizer")
        tokenizer = Tokenizer(bertname=bertname)
        tt.tock("created tokenizer")

        print(len(ds))

        tt.tick("dictionaries")
        inpdic = Vocab()
        inplens, outlens = [], []
        fldic = Vocab()
        for x in ds:
            outtoks = tokenizer.get_out_toks(x[1])
            outlens.append(len(outtoks))
            for tok in outtoks:
                fldic.add_token(tok, seen=x[2] == "train")
            inptoks = tokenizer.get_toks(x[0])
            inplens.append(len(inptoks))
            for tok in inptoks:
                inpdic.add_token(tok, seen=x[2] == "train")
        inpdic.finalize(min_freq=0, top_k=np.inf)
        fldic.finalize(min_freq=0, top_k=np.inf)
        print(
            f"input avg/max length is {np.mean(inplens):.1f}/{max(inplens)}, output avg/max length is {np.mean(outlens):.1f}/{max(outlens)}"
        )
        print(
            f"vocabulary sizes: {len(fldic.D)} at output, {len(inpdic.D)} at input"
        )
        tt.tock()

        tt.tick("tensorizing")
        tokenizer.inpvocab = inpdic
        tokenizer.outvocab = fldic
        trainds = ds.filter(lambda x: x[-1] == "train").map(
            lambda x: x[:-1]).map(
                lambda x: tokenizer.tokenize(x[0], x[1])).cache(True)
        validds = ds.filter(lambda x: x[-1] == "valid").map(
            lambda x: x[:-1]).map(
                lambda x: tokenizer.tokenize(x[0], x[1])).cache(True)
        testds = ds.filter(lambda x: x[-1] == "test").map(
            lambda x: x[:-1]).map(
                lambda x: tokenizer.tokenize(x[0], x[1])).cache(True)
        # ds = ds.map(lambda x: tokenizer.tokenize(x[0], x[1]) + (x[2],)).cache(True)
        tt.tock("tensorized")

        tt.tick("shelving")
        with shelve.open(shelfname) as shelf:
            shelved = {
                "trainex": trainds.examples,
                "validex": validds.examples,
                "testex": testds.examples,
                "fldic": fldic,
                "inpdic": inpdic,
            }
            shelf[key] = shelved
        tt.tock("shelved")

    tt.tock(f"loaded '{dataset}'")
    tt.msg(
        f"#train={len(trainds)}, #valid={len(validds)}, #test={len(testds)}")
    return trainds, validds, testds, fldic, inpdic
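
load_ds above caches its tensorized splits in a shelve file keyed by every argument that affects the result, so changing the dataset, validfrac or bertname triggers a recompute. A generic, stand-alone sketch of that caching pattern (cached, shelfname and compute are illustrative names, not part of the code above):

import shelve

def cached(key, compute, shelfname="cache.shelve", recompute=False):
    # try the shelf first; on a miss (or when recompute is forced), compute and store
    if not recompute:
        with shelve.open(shelfname) as shelf:
            if key in shelf:
                return shelf[key]
    value = compute()
    with shelve.open(shelfname) as shelf:
        shelf[key] = value
    return value

# hypothetical usage: result = cached("scan/length|validfrac=0.1", lambda: expensive_load())
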
Ejemplo n.º 25
0
def run(
    lr=0.001,
    batsize=20,
    epochs=1,
    embdim=301,
    encdim=200,
    numlayers=1,
    beamsize=5,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    cosine_restarts=1.,
    domain="restaurants",
):
    localargs = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = OvernightDataset(
        domain=domain,
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    sentence_rare_tokens = set(
        [ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0, mode="logprobs"),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless={"op:and", "SW:concat"})
                           ])
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    freedecoder = BeamDecoder(
        model,
        maxtime=50,
        beamsize=beamsize,
        tf_ratio=0.,
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    trainable_params = tfdecoder.named_parameters()
    exclude_params = set()
    # don't train input embeddings when using pretrained (GloVe) embeddings
    exclude_params.add("model.model.inp_emb.emb.weight")
    trainable_params = [
        v for k, v in trainable_params if k not in exclude_params
    ]

    # 4. define optim
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n>"
    )
    if tosave.lower() == "y" or tosave.lower() == "yes" or re.match(
            r"\d+", tosave.lower()):
        overwrite = int(tosave) if re.match(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        assert (testresults == _testresults)
        tt.tock("tested")
Ejemplo n.º 26
0
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",  # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=6,
        numheads=12,
        dropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        smalltrainvalid=False,
        version="v2"):

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project=f"compgen_baseline", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainvalidds = None
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if smalltrainvalid:
        realtrainds = []
        trainvalidds = []
        splits = [True for _ in range(int(round(len(trainds) * 0.1)))]
        splits = splits + [False for _ in range(len(trainds) - len(splits))]
        random.shuffle(splits)
        for i in range(len(trainds)):
            if splits[i] is True:
                trainvalidds.append(trainds[i])
            else:
                realtrainds.append(trainds[i])
        trainds = Dataset(realtrainds)
        trainvalidds = Dataset(trainvalidds)
        validds = trainvalidds
        tt.msg(
            "split off 10% of training data instead of original validation data"
        )
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    cell = TransformerDecoderCell(hdim,
                                  vocab=fldic,
                                  inpvocab=inpdic,
                                  numlayers=numlayers,
                                  numheads=numheads,
                                  dropout=dropout,
                                  worddropout=worddropout,
                                  mode=mode,
                                  bertname=bertname,
                                  userelpos=userelpos,
                                  useabspos=not userelpos)
    decoder = SeqDecoderBaseline(cell,
                                 vocab=fldic,
                                 max_size=maxsize,
                                 smoothing=smoothing,
                                 mode=mode)
    print(f"one layer of decoder: \n {cell.decoder.block[0]}")
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = decoder(*batch)
        tt.tock()
        decoder.train(False)
        tt.tick("test")
        out = decoder(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss = make_array_of_metrics("loss", "acc", "elemacc", reduction="mean")
    tmetrics = make_array_of_metrics("treeacc",
                                     "bleu",
                                     "lcsf1",
                                     "nll",
                                     "acc",
                                     "elemacc",
                                     "decnll",
                                     "maxmaxnll",
                                     reduction="mean")
    vmetrics = make_array_of_metrics("treeacc",
                                     "bleu",
                                     "lcsf1",
                                     "nll",
                                     "acc",
                                     "elemacc",
                                     "decnll",
                                     "maxmaxnll",
                                     reduction="mean")
    xmetrics = make_array_of_metrics("treeacc",
                                     "bleu",
                                     "lcsf1",
                                     "nll",
                                     "acc",
                                     "elemacc",
                                     "decnll",
                                     "maxmaxnll",
                                     reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(cell))

    def wandb_logger():
        d = {}
        for name, loss in zip(["loss", "acc"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        if evaltrain:
            for name, loss in zip([
                    "tree_acc", "nll", "acc", "elemacc", "bleu", "decnll",
                    "maxmaxnll"
            ], tmetrics):
                d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip([
                "tree_acc", "nll", "acc", "elemacc", "bleu", "decnll",
                "maxmaxnll"
        ], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        for name, loss in zip([
                "tree_acc", "nll", "acc", "elemacc", "bleu", "decnll",
                "maxmaxnll"
        ], xmetrics):
            d["test_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(cell, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        on_before_optim_step=[lambda: clipgradnorm(_m=cell, _norm=gradnorm)])

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    validfs = validfs + [testepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
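
The get_parameters helper used in the last few examples implements a common fine-tuning recipe: the pretrained encoder gets a smaller learning rate (lr * enclrmul) than the freshly initialized layers. A stand-alone sketch of the same idea with plain PyTorch parameter groups (the module and its submodule names are placeholders):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_model = torch.nn.Linear(8, 8)   # stands in for the pretrained encoder
        self.head = torch.nn.Linear(8, 2)            # freshly initialized part

    def forward(self, x):
        return self.head(self.encoder_model(x))

m, lr, enclrmul = ToyModel(), 1e-4, 0.1
paramgroups = [
    {"params": m.encoder_model.parameters(), "lr": lr * enclrmul},   # reduced LR for the encoder
    {"params": m.head.parameters()},                                 # falls back to the default lr
]
optim = torch.optim.Adam(paramgroups, lr=lr)
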
Ejemplo n.º 27
0
def run(
    domain="restaurants",
    mode="baseline",  # "baseline", "ltr", "uniform", "binary"
    probthreshold=0.,  # 0. --> parallel, >1. --> serial, 0.< . <= 1. --> semi-parallel
    lr=0.0001,
    enclrmul=0.1,
    batsize=50,
    epochs=1000,
    hdim=366,
    numlayers=6,
    numheads=6,
    dropout=0.1,
    noreorder=False,
    trainonvalid=False,
    seed=87646464,
    gpu=-1,
    patience=-1,
    gradacc=1,
    cosinelr=False,
    warmup=20,
    gradnorm=3,
    validinter=10,
    maxsteps=20,
    maxsize=75,
    testcode=False,
    numbered=False,
):

    settings = locals().copy()
    q.pp_dict(settings)
    wandb.init(project=f"seqinsert_overnight_v2", config=settings, reinit=True)

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds(
        domain,
        trainonvalid=trainonvalid,
        noreorder=noreorder,
        numbered=numbered)
    tt.tock("loaded")

    tdl_seq = DataLoader(tds_seq,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)

    # model
    tagger = TransformerTagger(hdim,
                               flenc.vocab,
                               numlayers,
                               numheads,
                               dropout,
                               baseline=mode == "baseline")

    if mode == "baseline":
        decoder = SeqDecoderBaseline(tagger,
                                     flenc.vocab,
                                     max_steps=maxsteps,
                                     max_size=maxsize)
    elif mode == "ltr":
        decoder = SeqInsertionDecoderLTR(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize)
    elif mode == "uniform":
        decoder = SeqInsertionDecoderUniform(tagger,
                                             flenc.vocab,
                                             max_steps=maxsteps,
                                             max_size=maxsize,
                                             prob_threshold=probthreshold)
    elif mode == "binary":
        decoder = SeqInsertionDecoderBinary(tagger,
                                            flenc.vocab,
                                            max_steps=maxsteps,
                                            max_size=maxsize,
                                            prob_threshold=probthreshold)
    elif mode == "any":
        decoder = SeqInsertionDecoderAny(tagger,
                                         flenc.vocab,
                                         max_steps=maxsteps,
                                         max_size=maxsize,
                                         prob_threshold=probthreshold)

    # test run
    if testcode:
        batch = next(iter(tdl_seq))
        # out = tagger(batch[1])
        # out = decoder(*batch)
        decoder.train(False)
        out = decoder(*batch)

    tloss = make_array_of_metrics("loss", reduction="mean")
    tmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", "stepsused", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(tagger))

    def wandb_logger():
        d = {}
        for name, loss in zip(["CE"], tloss):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], tmetrics):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["tree_acc", "stepsused"], vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(
        q.train_batch,
        gradient_accumulation_steps=gradacc,
        on_before_optim_step=[lambda: clipgradnorm(_m=tagger, _norm=gradnorm)])

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=tdl_seq,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=tdl_seq,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]

    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    q.run_training(
        run_train_epoch=trainepoch,
        # run_valid_epoch=[trainevalepoch, validepoch], #[validepoch],
        run_valid_epoch=[validepoch],
        max_epochs=epochs,
        check_stop=[lambda: eyt.check_stop()],
        validinter=validinter)
    tt.tock("done training")

    if eyt.remembered is not None and not trainonvalid:
        tt.msg("reloading best")
        decoder.tagger = eyt.remembered
        tagger = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        print(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=xdl_seq,
                        device=device)
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_CE": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})
    settings.update({"final_train_steps_used": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_steps_used": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_steps_used": xmetrics[1].get_epoch_error()})

    # run different prob_thresholds:
    # thresholds = [0., 0.3, 0.5, 0.6, 0.75, 0.85, 0.9, 0.95,  1.]
    thresholds = [
        0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 0.95, 0.975, 0.99, 1.
    ]
    for threshold in thresholds:
        tt.tick("running test for threshold " + str(threshold))
        decoder.prob_threshold = threshold
        testres = testepoch()
        print(f"Test tree acc for threshold {threshold}: testres: {testres}")
        settings.update(
            {f"_thr{threshold}_acc": xmetrics[0].get_epoch_error()})
        settings.update(
            {f"_thr{threshold}_len": xmetrics[1].get_epoch_error()})
        tt.tock("done")

    wandb.config.update(settings)
    q.pp_dict(settings)
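
The q.EarlyStopper usage in these examples tracks a validation metric, remembers a deepcopy of the best model, and tells run_training to stop after `patience` epochs without improvement; the remembered copy is then reloaded for testing. A tiny stand-alone sketch of that bookkeeping, independent of the q.EarlyStopper API (all names illustrative):

from copy import deepcopy

class BestKeeper:
    def __init__(self, patience, more_is_better=True):
        self.patience, self.more_is_better = patience, more_is_better
        self.best, self.remembered, self.bad_epochs = None, None, 0

    def on_epoch_end(self, value, model):
        improved = self.best is None or \
            (value > self.best if self.more_is_better else value < self.best)
        if improved:
            self.best, self.remembered, self.bad_epochs = value, deepcopy(model), 0
        else:
            self.bad_epochs += 1

    def check_stop(self):
        # stop once `patience` consecutive epochs brought no improvement
        return self.bad_epochs >= self.patience
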
Ejemplo n.º 28
0
def test_epoch(model=None,
               dataloader=None,
               losses=None,
               device=torch.device("cpu"),
               current_epoch=0,
               max_epochs=0,
               print_every_batch=False,
               on_start=tuple(),
               on_start_batch=tuple(),
               on_end_batch=tuple(),
               on_end=tuple()):
    """
    Performs a test epoch. If run=True, runs, otherwise returns partially filled function.
    :param model:
    :param dataloader:
    :param losses:
    :param device:
    :param current_epoch:
    :param max_epochs:
    :param on_start:
    :param on_start_batch:
    :param on_end_batch:
    :param on_end:
    :return:
    """
    tt = q.ticktock("-")
    model.eval()
    q.epoch_reset(model)
    [e() for e in on_start]
    with torch.no_grad():
        for loss_obj in losses:
            loss_obj.push_epoch_to_history()
            loss_obj.reset_agg()
            loss_obj.loss.to(device)
        for i, _batch in enumerate(dataloader):
            [e() for e in on_start_batch]

            _batch = (_batch, ) if not q.issequence(_batch) else _batch
            _batch = q.recmap(
                _batch, lambda x: x.to(device)
                if isinstance(x, torch.Tensor) else x)
            batch = _batch
            numex = batch[0].size(0)

            if q.no_gold(losses):
                batch_in = batch
                gold = None
            else:
                batch_in = batch[:-1]
                gold = batch[-1]

            q.batch_reset(model)
            modelouts = model(*batch_in)

            testlosses = []
            for loss_obj in losses:
                loss_val = loss_obj(modelouts, gold, _numex=numex)
                loss_val = [loss_val
                            ] if not q.issequence(loss_val) else loss_val
                testlosses.extend(loss_val)

            ttmsg = "test - Epoch {}/{} - [{}/{}]: {}".format(
                current_epoch + 1, max_epochs, i + 1, len(dataloader),
                q.pp_epoch_losses(*losses))
            if print_every_batch:
                tt.msg(ttmsg)
            else:
                tt.live(ttmsg)
            [e() for e in on_end_batch]
    tt.stoplive()
    [e() for e in on_end]
    ttmsg = q.pp_epoch_losses(*losses)
    return ttmsg
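
test_epoch follows the convention used throughout these examples: every batch is a tuple whose last element is the gold target and whose remaining elements are model inputs (unless q.no_gold(losses) says otherwise). A small sketch of how a dataset has to be laid out to fit that convention (placeholder tensors):

import torch
from torch.utils.data import DataLoader, TensorDataset

# two input tensors plus a gold tensor as the last element of every batch
inps_a = torch.randn(32, 4)
inps_b = torch.randint(0, 5, (32, 7))
gold = torch.randint(0, 2, (32,))
dl = DataLoader(TensorDataset(inps_a, inps_b, gold), batch_size=8)

batch = next(iter(dl))
batch_in, batch_gold = batch[:-1], batch[-1]      # what test_epoch does internally
# the model is then called as model(*batch_in) and losses as loss(modelouts, batch_gold)
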
Ejemplo n.º 29
0
def run(
    lr=0.001,
    batsize=20,
    epochs=50,
    embdim=100,
    encdim=200,
    numlayers=1,
    dropout=.2,
    wreg=1e-6,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    beamsize=1,
    smoothing=0.,
    fulltest=False,
    cosine_restarts=-1.,
    feedatt=False,
):
    # DONE: Porter stemmer
    # DONE: linear attention
    # DONE: grad norm
    # DONE: beam search
    # DONE: lr scheduler
    tt = q.ticktock("script")
    ttt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    stemmer = PorterStemmer()
    tokenizer = lambda x: [stemmer.stem(xe) for xe in x.split()]
    ds = GeoQueryDataset(sentence_encoder=SentenceEncoder(tokenizer=tokenizer),
                         min_freq=minfreq)
    dls = get_dataloaders(ds, batsize=batsize)
    train_dl = dls["train"]
    test_dl = dls["test"]
    tt.tock("data loaded")

    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    tfdecoder = create_model(embdim=embdim,
                             hdim=encdim,
                             dropout=dropout,
                             numlayers=numlayers,
                             sentence_encoder=ds.sentence_encoder,
                             query_encoder=ds.query_encoder,
                             smoothing=smoothing,
                             feedatt=feedatt)
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    freedecoder = GreedyActionSeqDecoder(tfdecoder.model, maxsteps=50)
    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    losses = [
        q.LossWrapper(q.SelectedLinearLoss(x, reduction=None), name=x)
        for x in ["loss", "any_acc", "seq_acc"]
    ]
    vlosses = [
        q.LossWrapper(q.SelectedLinearLoss(x, reduction=None), name=x)
        for x in ["seq_acc", "tree_acc"]
    ]

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        t_max = epochs * len(train_dl)
        print(f"Total number of updates: {t_max} ({epochs} * {len(train_dl)})")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function (using partial)
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")
Ejemplo n.º 30
0
def run(lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False):
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    tt.tick("creating model")
    dims = [embdim] + ([encdim] * numlayers)

    m = RNNLayer_LM(*dims, worddic=D, dropout=dropout,
                    tieweights=tieweights).to(device)

    if test:
        for i, batch in enumerate(train_batches):
            y = m(batch[0])
            if i > 5:
                break
        print(y.size())

    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    optim = torch.optim.SGD(m.parameters(), lr=lr)

    train_batch_f = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: torch.nn.utils.clip_grad_norm_(m.parameters(), gradnorm)
        ])
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=1 / 4,
                                                     patience=0,
                                                     verbose=True)
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    train_epoch_f = partial(q.train_epoch,
                            model=m,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f)
    valid_epoch_f = partial(q.test_epoch,
                            model=m,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("created model")
    tt.tick("training")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained")

    tt.tick("testing")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")
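
PPLfromCE above presumably converts the cross-entropy loss wrapper into perplexity; for a language model, perplexity is exp of the per-token cross-entropy (in nats). A quick stand-alone check with torch (values are random, just to show the relationship):

import math

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # 4 tokens, vocabulary of 10
gold = torch.randint(0, 10, (4,))
ce = F.cross_entropy(logits, gold)        # mean cross-entropy in nats
ppl = math.exp(ce.item())                 # corresponding perplexity
print(ce.item(), ppl)
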