Exemplo n.º 1
0
def example_usage_full():
    """Demonstrate the full train / validate / test workflow on random data.

    Builds a small linear classifier, creates a random 100-example dataset
    split 70/10/20 into train/valid/test, wraps two cross-entropy losses
    (mean- and sum-reduced) in ``q.LossWrapper``, trains for 50 epochs with
    SGD via ``run_training`` and finally prints the result of
    ``q.test_epoch`` on the test split.
    """
    # 1. define model
    # The model must emit raw (unnormalised) logits: CrossEntropyLoss applies
    # log-softmax internally, so a trailing Softmax layer would normalise the
    # scores twice and flatten the gradients.
    model = torch.nn.Sequential(torch.nn.Linear(5, 5))

    # 2. define data
    x = torch.rand(100, 5)
    y = torch.randint(0, 5, (100, ))
    dataset = torch.utils.data.TensorDataset(x, y)
    traindataset, validdataset, testdataset = torch.utils.data.random_split(
        dataset, [70, 10, 20])
    trainloader = torch.utils.data.DataLoader(traindataset,
                                              batch_size=2,
                                              shuffle=True)
    validloader = torch.utils.data.DataLoader(validdataset,
                                              batch_size=2,
                                              shuffle=False)
    testloader = torch.utils.data.DataLoader(testdataset,
                                             batch_size=2,
                                             shuffle=False)

    # 3. define losses and wrap them
    loss = q.LossWrapper(torch.nn.CrossEntropyLoss(reduction="mean"))
    loss2 = q.LossWrapper(torch.nn.CrossEntropyLoss(reduction="sum"))

    # 4. define optim
    optim = torch.optim.SGD(model.parameters(), lr=1.0)

    # 5. other options (device, ...)
    device = torch.device("cpu")

    # 6. define training function (using partial)
    trainepoch = partial(q.train_epoch,
                         model=model,
                         dataloader=trainloader,
                         optim=optim,
                         losses=[loss, loss2],
                         device=device)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=model,
                         dataloader=validloader,
                         losses=[loss, loss2],
                         device=device)

    # 8. run training
    run_training(run_train_epoch=trainepoch,
                 run_valid_epoch=validepoch,
                 max_epochs=50)

    # 9. run test function
    testresults = q.test_epoch(model=model,
                               dataloader=testloader,
                               losses=[loss, loss2],
                               device=device)
    print(testresults)
Exemplo n.º 2
0
def run(
    lr=0.001,
    batsize=20,
    epochs=70,
    embdim=128,
    encdim=400,
    numlayers=1,
    beamsize=5,
    dropout=.5,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
):
    """Train and evaluate a ``ParikhRankModel`` wrapped in a ``Ranker``.

    The script seeds torch and numpy, loads ``GeoDatasetRank``, trains with
    Adam (optionally with a cosine schedule stepped once per epoch), reports
    metrics on the "test" split and then interactively offers to save the
    run, reload it and dump beam-decoder predictions as JSON.

    Args:
        lr: Adam learning rate.
        batsize: batch size passed to every ``ds.dataloader`` call.
        epochs: number of training epochs; also the horizon of the schedule.
        embdim: accepted but not used in this function.
        encdim: passed to ``ParikhRankModel`` as its ``embdim`` argument.
        numlayers: accepted but not used in this function.
        beamsize: beam size of the ``BeamDecoder`` built after saving.
        dropout: dropout passed to the model.
        wreg: Adam ``weight_decay``.
        cuda: if true, run on ``torch.device("cuda", gpu)``, else on CPU.
        gpu: CUDA device index used when ``cuda`` is true.
        minfreq: accepted but not used in this function.
        gradnorm: max norm for gradient clipping before each optimizer step.
        smoothing: smoothing argument of the ``BCELoss`` used by the ranker.
        cosine_restarts: ``cycles`` argument of the warm-restart cosine
            schedule; a negative value disables the schedule.
        seed: seed for ``torch.manual_seed`` and ``np.random.seed``.
    """
    # Must stay the first statement: it snapshots only the call arguments.
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")
    tt.tick("loading data")
    ds = GeoDatasetRank()
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    model = ParikhRankModel(embdim=encdim,
                            dropout=dropout,
                            sentence_encoder=ds.sentence_encoder,
                            query_encoder=ds.query_encoder)

    ranker = Ranker(model,
                    eval=[BCELoss(mode="logits", smoothing=smoothing)],
                    evalseq=[
                        SeqAccuracies(),
                        TreeAccuracy(tensor2tree=partial(
                            tensor2tree, D=ds.query_encoder.vocab),
                                     orderless={"and", "or"})
                    ])

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # 4. define optim
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule (stepped at the end of every training epoch)
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(model.parameters(), gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    # NOTE(review): per-epoch validation runs on the "test" split.
    validepoch = partial(q.test_epoch,
                         model=ranker,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)

    # 8. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    # NOTE(review): both reports below evaluate the "test" split, although
    # the first is labelled "validation" -- confirm whether a "valid" split
    # was intended.
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=ranker,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip().lower()
    # fullmatch (not match): an answer such as "12abc" must count as "No"
    # instead of passing the check and then crashing inside int().
    overwrite_match = re.fullmatch(r"\d+", answer)
    if answer not in ("y", "yes") and overwrite_match is None:
        return

    overwrite = int(answer) if overwrite_match else None
    p = q.save_run(model,
                   localargs,
                   filepath=__file__,
                   overwrite=overwrite)
    q.save_dataset(ds, p)
    _model, _localargs = q.load_run(p)
    _ds = q.load_dataset(p)

    _freedecoder = BeamDecoder(
        _model,
        maxtime=100,
        beamsize=beamsize,
        copy_deep=True,
        eval=[SeqAccuracies()],
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    # testing
    tt.tick("testing reloaded")
    # NOTE(review): `beamlosses` is not defined anywhere in this function;
    # unless it is a module-level name this line raises NameError -- confirm.
    _testresults = q.test_epoch(model=_freedecoder,
                                dataloader=_ds.dataloader("test", batsize),
                                losses=beamlosses,
                                device=device)
    print(_testresults)
    tt.tock("tested")

    # save predictions
    _, testpreds = q.eval_loop(_freedecoder,
                               ds.dataloader("test",
                                             batsize=batsize,
                                             shuffle=False),
                               device=device)
    testout = get_outputs_for_save(testpreds)
    _, trainpreds = q.eval_loop(_freedecoder,
                                ds.dataloader("train",
                                              batsize=batsize,
                                              shuffle=False),
                                device=device)
    trainout = get_outputs_for_save(trainpreds)

    with open(os.path.join(p, "trainpreds.json"), "w") as f:
        ujson.dump(trainout, f)

    with open(os.path.join(p, "testpreds.json"), "w") as f:
        ujson.dump(testout, f)
Exemplo n.º 3
0
def run(
    lr=0.001,
    batsize=50,
    epochs=50,
    embdim=100,
    encdim=100,
    numlayers=1,
    beamsize=1,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=3,
    gradnorm=3.,
    cosine_restarts=1.,
    beta=0.001,
    vib_init=True,
    vib_enc=True,
):
    """Train and evaluate a ``BasicGenModel_VIB`` on ``LCQuaDnoENTDataset``.

    Training uses a teacher-forced ``SeqDecoder`` (``tf_ratio=1.``);
    evaluation uses a free-running ``SeqDecoder`` when ``beamsize == 1`` and
    a ``BeamDecoder`` otherwise. After training, metrics are reported on the
    "valid" and "test" splits and the user is asked whether to save the run;
    a saved run is reloaded and re-evaluated.

    Args:
        lr: Adam learning rate.
        batsize: batch size passed to every ``ds.dataloader`` call.
        epochs: number of training epochs; also the horizon of the schedule.
        embdim: ``embdim`` of the model.
        encdim: passed to the model as ``hdim``.
        numlayers: ``numlayers`` of the model.
        beamsize: 1 selects the free-running ``SeqDecoder``, anything else a
            ``BeamDecoder`` with that beam size.
        dropout: dropout passed to the model.
        wreg: Adam ``weight_decay``.
        cuda: if true, run on ``torch.device("cuda", gpu)``, else on CPU.
        gpu: CUDA device index used when ``cuda`` is true.
        minfreq: ``min_freq`` of the dataset.
        gradnorm: max norm for gradient clipping before each optimizer step.
        cosine_restarts: ``cycles`` argument of the warm-restart cosine
            schedule; a negative value disables the schedule.
        beta: weight of the ``StatePenalty`` terms.
        vib_init: passed to the model; also adds a ``StatePenalty`` over
            ``state.mstate.vib.init``.
        vib_enc: passed to the model; also adds a ``StatePenalty`` on
            ``"mstate.vib.enc"``.
    """
    # Must stay the first statement: it snapshots only the call arguments.
    localargs = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")
    tt.tick("loading data")
    ds = LCQuaDnoENTDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)

    model = BasicGenModel_VIB(embdim=embdim,
                              hdim=encdim,
                              dropout=dropout,
                              numlayers=numlayers,
                              sentence_encoder=ds.sentence_encoder,
                              query_encoder=ds.query_encoder,
                              feedatt=True,
                              vib_init=vib_init,
                              vib_enc=vib_enc)

    def make_accuracy_metrics():
        # Fresh metric objects per decoder, exactly as the decoders expect.
        return [
            SeqAccuracies(),
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"select", "count", "ask"})
        ]

    # Training criteria: cross-entropy plus the optional VIB penalties.
    criteria = [CELoss(ignore_index=0, mode="logprobs")]
    if vib_init:
        criteria.append(
            StatePenalty(lambda state: sum(state.mstate.vib.init),
                         weight=beta))
    if vib_enc:
        criteria.append(StatePenalty("mstate.vib.enc", weight=beta))

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=criteria + make_accuracy_metrics())
    if beamsize == 1:
        freedecoder = SeqDecoder(model,
                                 maxtime=40,
                                 tf_ratio=0.,
                                 eval=make_accuracy_metrics())
    else:
        freedecoder = BeamDecoder(model,
                                  maxtime=30,
                                  beamsize=beamsize,
                                  eval=make_accuracy_metrics())

    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule (stepped at the end of every training epoch)
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(),
                                              gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    # NOTE(review): per-epoch validation runs on the "test" split.
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("test", batsize),
                         losses=vlosses,
                         device=device)

    # 8. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("valid", batsize),
                               losses=vlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip().lower()
    # fullmatch (not match): an answer such as "12abc" must count as "No"
    # instead of passing the check and then crashing inside int().
    overwrite_match = re.fullmatch(r"\d+", answer)
    if answer not in ("y", "yes") and overwrite_match is None:
        return

    overwrite = int(answer) if overwrite_match else None
    p = q.save_run(model,
                   localargs,
                   filepath=__file__,
                   overwrite=overwrite)
    q.save_dataset(ds, p)
    _model, _localargs = q.load_run(p)
    _ds = q.load_dataset(p)

    _freedecoder = BeamDecoder(
        _model,
        maxtime=50,
        beamsize=beamsize,
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    # testing
    tt.tick("testing reloaded")
    _testresults = q.test_epoch(model=_freedecoder,
                                dataloader=_ds.dataloader("test", batsize),
                                losses=vlosses,
                                device=device)
    print(_testresults)
    # NOTE(review): the reloaded decoder is configured differently from
    # `freedecoder` (maxtime, metrics, orderless set) -- confirm that equal
    # results are actually expected here.
    assert testresults == _testresults
    tt.tock("tested")
Exemplo n.º 4
0
def run(traindomains="ALL",
        domain="recipes",
        mincoverage=2,
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        ftlr=0.0001,
        cosinelr=False,
        warmup=0.,
        batsize=30,
        pretrainbatsize=100,
        epochs=100,
        resetmode="none",
        pretrainepochs=100,
        minpretrainepochs=10,
        dropout=0.1,
        decoderdropout=0.5,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="bert-base-uncased",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=30,
        localtest=False,
        printtest=False,
        fullsimplify=True,
        nopretrain=False,
        onlyabstract=False,
        pretrainsetting="all",  # "all", "all+lex", "lex"
        finetunesetting="min",      # "lex", "all", "min"
        ):
    """Pretrain on the training domains, finetune on ``domain`` and evaluate.

    Phase 1 (skipped when ``nopretrain``) trains on ``tds`` with early
    stopping on the second validation metric ("tree_acc") and optional
    re-initialisation controlled by ``resetmode``; the best remembered model
    is then restored. Phase 2 finetunes on ``ftds`` for ``epochs`` epochs.
    Finally the finetune-validation and test loaders are evaluated.

    Args:
        traindomains: "ALL" (every known domain except ``domain``) or an
            explicit collection passed on to ``load_ds``.
        domain: held-out domain used for finetuning and testing.
        mincoverage, fullsimplify, onlyabstract, pretrainsetting,
            finetunesetting: forwarded to ``load_ds``.
        lr, ftlr: Adam learning rates for pretraining / finetuning.
        enclrmul: multiplier applied to the learning rate of parameters whose
            name starts with "model.model.encoder".
        numbeam, dropout, decoderdropout, smoothing, maxlen, numlayers, hdim,
            numheads, encoder: forwarded to ``create_model`` (``encoder``
            also to ``load_ds`` as ``nl_mode``).
        cosinelr: use linear warmup followed by cosine decay instead of
            linear warmup followed by a constant factor.
        warmup: number of warmup steps of the schedule.
        batsize, pretrainbatsize: batch sizes for finetune / pretrain loaders.
        epochs: finetuning epochs; also the horizon of both schedules.
        resetmode: "none", "once", "more" or "forever"; selects how often the
            ``Reinitializer`` resets.
        pretrainepochs: maximum number of pretraining epochs.
        minpretrainepochs: ``min_epochs`` of the early stopper; overridden
            when ``resetmode`` is not "none".
        wreg: Adam ``weight_decay``.
        gradnorm: max norm for gradient clipping before each optimizer step.
        patience: early-stopping patience.
        gpu: device index; negative selects the CPU.
        seed: seed for ``random``, torch and numpy.
        localtest: run one pretraining batch through both models and print.
        printtest: print per-example predictions on the test loader.
        nopretrain: skip the pretraining phase.

    Returns:
        dict: the call arguments, extended with ``"{split}_{metric.name}"``
        entries holding the final finetune train/valid/test metrics.
    """
    # Must stay the first statement: it snapshots only the call arguments.
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    # Unknown reset modes behave like "none" (all zeros).
    numresets = {"once": 1, "more": 3, "forever": 1000}.get(resetmode, 0)
    resetafter, resetevery = (15, 5) if numresets > 0 else (0, 0)

    print(f'Resetting: "{resetmode}": {numresets} times, first after {resetafter} epochs, then every {resetevery} epochs')

    if traindomains == "ALL":
        alldomains = {"recipes", "restaurants", "blocks", "calendar", "housing", "publications"}
        traindomains = alldomains - {domain, }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc, generaltokenmask = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, onlyabstract=onlyabstract,
                pretrainsetting=pretrainsetting, finetunesetting=finetunesetting)
    n_pre = len(tds) + len(vds)
    n_ft = len(ftds) + len(fvds) + len(xds)
    tt.msg(f"{len(tds)/n_pre:.2f}/{len(vds)/n_pre:.2f} ({len(tds)}/{len(vds)}) train/valid")
    tt.msg(f"{len(ftds)/n_ft:.2f}/{len(fvds)/n_ft:.2f}/{len(xds)/n_ft:.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test")
    collate = partial(autocollate, pad_value=0)
    tdl = DataLoader(tds, batch_size=pretrainbatsize, shuffle=True, collate_fn=collate)
    ftdl = DataLoader(ftds, batch_size=batsize, shuffle=True, collate_fn=collate)
    vdl = DataLoader(vds, batch_size=pretrainbatsize, shuffle=False, collate_fn=collate)
    fvdl = DataLoader(fvds, batch_size=batsize, shuffle=False, collate_fn=collate)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=collate)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 decoderdropout=decoderdropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 numbeam=numbeam,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab),
                                 generaltokenmask=generaltokenmask,
                                 resetmode=resetmode
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    def make_lr_schedule(optimizer):
        # Both phases use `epochs` as the schedule horizon.
        # NOTE(review): pretraining runs for up to `pretrainepochs` epochs but
        # its cosine horizon is `epochs` -- confirm this is intended.
        t_max = epochs
        print(f"Total number of updates: {t_max} .")
        if cosinelr:
            shape = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max - warmup) >> 0.
        else:
            shape = q.sched.Linear(steps=warmup) >> 1.
        return q.sched.LRSchedule(optimizer, shape)

    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    tt.msg("different param groups")
    optim = torch.optim.Adam(_encoder_param_groups(trainm, lr, enclrmul),
                             lr=lr, weight_decay=wreg)

    if resetmode != "none":
        # Early stopping must not trigger before the last scheduled reset.
        minpretrainepochs = resetafter + (numresets - 1) * resetevery
    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=minpretrainepochs,
                         more_is_better=True, remember_f=lambda: deepcopy(trainm.model))

    reinit = Reinitializer(trainm.model, resetafter=resetafter, resetevery=resetevery, numresets=numresets, resetothers=[eyt])

    pretrain_schedule = make_lr_schedule(optim)

    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: pretrain_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device,
                         on_end=[lambda: eyt.on_epoch_end(), lambda: reinit()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=pretrainepochs,
                       check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    # Parameters are collected again because trainm.model may have been
    # replaced by the remembered copy above.
    tt.msg("different param groups")
    ftoptim = torch.optim.Adam(_encoder_param_groups(trainm, ftlr, enclrmul),
                               lr=ftlr, weight_decay=wreg)

    ft_schedule = make_lr_schedule(ftoptim)

    trainepoch = partial(q.train_epoch, model=trainm, dataloader=ftdl, optim=ftoptim, losses=ftmetrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: ft_schedule.step()])
    validepoch = partial(q.test_epoch, model=testm, dataloader=fvdl, losses=ftvmetrics, device=device,
                         on_end=[])

    tt.tick("finetuning")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done finetuning")

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm, dataloader=fvdl, losses=ftvmetrics, device=device)
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=ftxmetrics, device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        cpu = torch.device("cpu")
        c, t = 0, 0
        for testbatch in xdl:
            input_ids = testbatch[0].to(device)
            output_ids = testbatch[1]
            ret = predm.generate(input_ids, attention_mask=input_ids != predm.config.pad_token_id,
                                 max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(cpu)) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(cpu)) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        # Guard against an empty test loader (would divide by zero).
        if t > 0:
            print(f"seq acc: {c/t}")
        else:
            print("seq acc: n/a (no test examples)")

    print("done")

    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    return settings


def _encoder_param_groups(module, base_lr, enclrmul):
    """Split ``module``'s parameters into encoder and non-encoder Adam groups.

    Parameters whose name starts with "model.model.encoder" get the learning
    rate ``base_lr * enclrmul``; all others inherit the optimizer default.

    Raises:
        Exception: if no encoder parameters are found.
    """
    prefix = "model.model.encoder"
    named = list(module.named_parameters())
    encparams = [v for k, v in named if k.startswith(prefix)]
    otherparams = [v for k, v in named if not k.startswith(prefix)]
    if not encparams:
        raise Exception("No encoder parameters found!")
    return [{"params": encparams, "lr": base_lr * enclrmul},
            {"params": otherparams}]
Exemplo n.º 5
0
def run(lr=0.001,
        batsize=20,
        epochs=60,
        embdim=128,
        encdim=256,
        numlayers=1,
        beamsize=1,
        dropout=.25,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.,
        cosine_restarts=1.,
        seed=456789,
        p_step=.2,
        p_min=.3,
        ):
    """Train and evaluate a ``BasicGenModel`` on ``GeoDataset``.

    Training uses a teacher-forced ``SeqDecoder`` (``tf_ratio=1.``);
    evaluation uses a free-running ``SeqDecoder`` when ``beamsize == 1`` and
    a ``BeamDecoder`` otherwise. After training, metrics are reported on the
    "test" split and the user is asked whether to save the run; a saved run
    is reloaded and re-evaluated.

    Args:
        lr: Adam learning rate.
        batsize: batch size passed to every ``ds.dataloader`` call.
        epochs: number of training epochs; also the horizon of the schedule.
        embdim: ``embdim`` of the model.
        encdim: passed to the model as ``hdim``.
        numlayers: ``numlayers`` of the model.
        beamsize: 1 selects the free-running ``SeqDecoder``, anything else a
            ``BeamDecoder`` with that beam size.
        dropout: dropout passed to the model.
        wreg: Adam ``weight_decay``.
        cuda: if true, run on ``torch.device("cuda", gpu)``, else on CPU.
        gpu: CUDA device index used when ``cuda`` is true.
        minfreq: ``min_freq`` of the dataset.
        gradnorm: max norm for gradient clipping before each optimizer step.
        smoothing: smoothing argument of the ``CELoss``.
        cosine_restarts: ``cycles`` argument of the warm-restart cosine
            schedule; a negative value disables the schedule.
        seed: seed for ``torch.manual_seed`` and ``np.random.seed``.
        p_step, p_min: forwarded to ``BasicGenModel``.
    """
    # Must stay the first statement: it snapshots only the call arguments.
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")
    tt.tick("loading data")
    ds = GeoDataset(sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    model = BasicGenModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder, feedatt=True,
                          p_step=p_step, p_min=p_min)

    def make_tree_accuracy():
        # Fresh metric object per decoder.
        return TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                            orderless={"and", "or"})

    criteria = [CELoss(ignore_index=0, mode="logprobs", smoothing=smoothing)]

    tfdecoder = SeqDecoder(model, tf_ratio=1.,
                           eval=criteria + [SeqAccuracies(), make_tree_accuracy()])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    if beamsize == 1:
        freedecoder = SeqDecoder(model, maxtime=100, tf_ratio=0.,
                                 eval=[SeqAccuracies(), make_tree_accuracy()])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc")
    else:
        freedecoder = BeamDecoder(model, maxtime=100, beamsize=beamsize,
                                  eval=[SeqAccuracies()],
                                  eval_beam=[make_tree_accuracy()])
        vlosses = make_array_of_metrics("seq_acc", "tree_acc", "tree_acc_at_last")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule (stepped at the end of every training epoch)
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(), gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=tfdecoder, dataloader=ds.dataloader("train", batsize), optim=optim, losses=losses,
                         _train_batch=trainbatch, device=device, on_end=reduce_lr)

    # 7. define validation function (using partial)
    # NOTE(review): per-epoch validation runs on the "test" split.
    validepoch = partial(q.test_epoch, model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)

    # 8. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done training")

    # testing
    # NOTE(review): both reports below evaluate the "test" split, although
    # the first is labelled "validation" -- confirm whether a "valid" split
    # was intended.
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input("Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>")
    answer = tosave.strip().lower()
    # fullmatch (not match): an answer such as "12abc" must count as "No"
    # instead of passing the check and then crashing inside int().
    overwrite_match = re.fullmatch(r"\d+", answer)
    if answer not in ("y", "yes") and overwrite_match is None:
        return

    overwrite = int(answer) if overwrite_match else None
    p = q.save_run(model, localargs, filepath=__file__, overwrite=overwrite)
    q.save_dataset(ds, p)
    _model, _localargs = q.load_run(p)
    _ds = q.load_dataset(p)

    _freedecoder = BeamDecoder(_model, maxtime=50, beamsize=beamsize,
                               eval_beam=[TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                       orderless={"op:and", "SW:concat"})])

    # testing
    tt.tick("testing reloaded")
    _testresults = q.test_epoch(model=_freedecoder, dataloader=_ds.dataloader("test", batsize),
                                losses=vlosses, device=device)
    print(_testresults)
    # NOTE(review): the reloaded decoder is configured differently from
    # `freedecoder` (maxtime, metrics, orderless set) -- confirm that equal
    # results are actually expected here.
    assert testresults == _testresults
    tt.tock("tested")
Exemplo n.º 6
0
def run(
    lr=2.5e-4,
    edropout=0.1,
    wdropout=0.1,
    rdropout=0.1,
    adropout=0.1,
    dropout=-1.,
    numlayers=2,
    numheads=8,
    abspos=False,
    tie_wordvecs=False,
    gradnorm=0.5,
    epochs=200,
    dim=256,
    seqlen=50,
    batsize=32,
    eval_batsize=64,
    cuda=False,
    gpu=0,
    test=True,
    subsampleeval=10,
    wreg=1e-6,
    lrcycle=5,
    lrwarmup=3,
):
    """Train a ``TransformerLM``, validate after every epoch, then test it.

    Args:
        lr: Learning rate handed to Adam.
        edropout, wdropout, rdropout, adropout: Embedding, word, residual and
            attention dropout values forwarded to ``TransformerLM``.
        dropout: When non-negative, replaces all four dropout values above.
        numlayers, numheads, dim, tie_wordvecs: Forwarded to ``TransformerLM``.
        abspos: The model is built with ``relpos=not abspos``.
        gradnorm: Max norm for gradient clipping before each optimizer step.
        epochs: Maximum number of epochs given to ``q.run_training``.
        seqlen: Passed to ``load_data``; the model gets ``maxlen=2 * seqlen``.
        batsize, eval_batsize: Batch sizes passed to ``load_data``.
        cuda, gpu: Use ``cuda:<gpu>`` when ``cuda`` is truthy, else the CPU.
        test: When truthy, run a few validation batches through the model
            before training and print some sizes (smoke test).
        subsampleeval: Passed to ``load_data`` as ``subsample_eval``.
        wreg: Weight decay handed to Adam.
        lrcycle, lrwarmup: Each is multiplied by the number of training
            batches and passed to ``q.CosineLRwithWarmup``.
    """
    timer = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    timer.tick("loading data")
    train_batches, valid_batches, test_batches, D = load_data(
        batsize=batsize,
        eval_batsize=eval_batsize,
        seqlen=seqlen,
        subsample_eval=subsampleeval)
    timer.tock("data loaded")
    print(f"{len(train_batches)} batches in train")

    # A non-negative global dropout wins over the individual settings.
    if dropout >= 0.:
        edropout = adropout = rdropout = wdropout = dropout

    timer.tick("creating model")
    m = TransformerLM(dim=dim,
                      worddic=D,
                      numlayers=numlayers,
                      numheads=numheads,
                      activation=q.GeLU,
                      embedding_dropout=edropout,
                      attention_dropout=adropout,
                      word_dropout=wdropout,
                      residual_dropout=rdropout,
                      relpos=not abspos,
                      tie_wordvecs=tie_wordvecs,
                      maxlen=2 * seqlen)
    m = m.to(device)
    valid_m = TransformerLMCell(m)

    if test:
        # Smoke test: forward the first seven validation batches ...
        for step, vbatch in enumerate(valid_batches):
            on_device = [part.to(device) for part in vbatch]
            y = valid_m(on_device[0])
            if step >= 6:
                break
        # ... then walk the whole loader so `step` ends at the last index.
        for step, _ in enumerate(valid_batches):
            pass
        print(step, batsize, seqlen, valid_batches.data.size(0))
        print(y.size())

    # Separate loss wrappers per phase, each paired with a perplexity metric
    # for validation and testing.
    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]
    for wrapped in (loss, *validlosses, *testlosses):
        wrapped.loss.to(device)  # put losses on right device

    numbats = len(train_batches)
    print(f"{numbats} batches in training")
    optim = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wreg)
    # The schedule is expressed in batches, not epochs.
    sched = q.CosineLRwithWarmup(optim,
                                 lrcycle * numbats,
                                 warmup=lrwarmup * numbats)

    def clip_gradients():
        return torch.nn.utils.clip_grad_norm_(m.parameters(), gradnorm)

    def advance_schedule():
        return sched.step()

    train_batch_f = partial(
        q.train_batch,
        on_before_optim_step=[clip_gradients, advance_schedule])
    train_epoch_f = partial(q.train_epoch,
                            model=m,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f)
    valid_epoch_f = partial(q.test_epoch,
                            model=valid_m,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device)
    timer.tock("created model")

    timer.tick("training")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    timer.tock("trained")

    timer.tick("testing")
    testresults = q.test_epoch(model=valid_m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    timer.tock("tested")
Exemplo n.º 7
0
def run(
    lr=0.001,
    batsize=20,
    epochs=60,
    embdim=128,
    encdim=256,
    numlayers=1,
    beamsize=5,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
    numcvfolds=6,
    testfold=-1,  # if non-default, must be within number of splits, the chosen value is used for validation
    reorder_random=False,
):
    """Train a ``BasicGenModel`` on ``GeoDataset`` and optionally save it.

    The model is trained with teacher forcing (``tf_ratio=1.``), validated
    with free-running decoding (``tf_ratio=0.``) and finally tested with a
    beam decoder. After testing the user is asked on stdin whether to save
    the run; if so the run is reloaded, re-tested and its train/test
    predictions are written as JSON next to the saved run.

    Args:
        lr: Learning rate handed to Adam.
        batsize: Batch size for every data loader.
        epochs: Maximum number of epochs; also the LR schedule length.
        embdim, encdim, numlayers, dropout: Forwarded to ``BasicGenModel``
            (``encdim`` as ``hdim``).
        beamsize: Beam size of the beam decoders.
        wreg: Weight decay handed to Adam.
        cuda, gpu: Use ``cuda:<gpu>`` when ``cuda`` is truthy, else the CPU.
        minfreq: Passed to ``GeoDataset`` as ``min_freq``.
        gradnorm: Max norm for gradient clipping before each optimizer step.
        smoothing: Label smoothing passed to the training ``CELoss``.
        cosine_restarts: When ``>= 0`` a
            ``q.WarmupCosineWithHardRestartsSchedule`` with that many cycles
            is stepped at the end of every training epoch.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        numcvfolds: Number of cross-validation folds; only used when
            ``testfold`` is not ``-1``.
        testfold: ``-1`` disables cross-validation. Any other value is passed
            to ``GeoDataset`` and switches validation to the "valid" split.
        reorder_random: Forwarded to ``GeoDataset``.

    Returns:
        ``vlosses[1].get_epoch_error()`` when ``testfold`` is given (the
        second validation metric, requested as "tree_acc"); otherwise
        ``None``.
    """
    localargs = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    # testfold == -1 means "no cross-validation": both settings become None.
    cvfolds = None if testfold == -1 else numcvfolds
    testfold = None if testfold == -1 else testfold
    ds = GeoDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq,
        cvfolds=cvfolds,
        testfold=testfold,
        reorder_random=reorder_random)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    # Teacher-forced decoder used for training.
    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless={"and"})
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    # Free-running decoder used for per-epoch validation.
    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[
                                 SeqAccuracies(),
                                 TreeAccuracy(tensor2tree=partial(
                                     tensor2tree, D=ds.query_encoder.vocab),
                                              orderless={"and"})
                             ])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # Beam decoder used for the final test.
    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[
                                  TreeAccuracy(tensor2tree=partial(
                                      tensor2tree, D=ds.query_encoder.vocab),
                                               orderless={"and"})
                              ])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # 4. define optim
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule (stepped once per epoch, hence t_max == epochs)
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 5. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    train_on = "train"
    # Without cross-validation, validation runs on the "test" split.
    valid_on = "test" if testfold is None else "valid"
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 6. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # In cross-validation mode only the validation score is reported.
    if testfold is not None:
        return vlosses[1].get_epoch_error()

    # testing
    # NOTE: both evaluations below read the "test" split; the first one is
    # merely labelled "validation" in its output.
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    # fullmatch (not match) so that answers such as "12abc" are treated as
    # "No" instead of crashing in int() after training has finished.
    overwrite_match = re.fullmatch(r"\d+", tosave)
    if tosave.lower() in ("y", "yes") or overwrite_match:
        overwrite = int(tosave) if overwrite_match else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        # Reload what was just written to check the saved run is usable.
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[
                                       TreeAccuracy(tensor2tree=partial(
                                           tensor2tree,
                                           D=ds.query_encoder.vocab),
                                                    orderless={"and"})
                                   ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
Exemplo n.º 8
0
def run(
    domain="restaurants",
    lr=0.001,
    enclrmul=0.1,
    cosinelr=False,
    warmup=0.,
    batsize=20,
    epochs=100,
    dropout=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="bart-large",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=50,
    localtest=False,
    printtest=False,
    trainonvalid=False,
):
    """Train and evaluate a pretrained-encoder parser on one domain.

    Args:
        domain: Dataset domain passed to ``load_ds``.
        lr: Base learning rate for Adam.
        enclrmul: Multiplier applied to ``lr`` for parameters whose name
            starts with ``"model.model.encoder"``.
        cosinelr: Use linear warmup followed by cosine decay to 0 when truthy,
            otherwise linear warmup followed by a constant factor of 1.
        warmup: Number of warmup steps of the schedule (stepped per epoch).
        batsize: Batch size of all three data loaders.
        epochs: Maximum number of epochs; also the schedule length.
        dropout, smoothing, maxlen: Forwarded to ``create_model``.
        wreg: Weight decay for Adam.
        gradnorm: Max norm for gradient clipping.
        patience: Accepted and recorded in the settings, but not otherwise
            used in this function.
        gpu: Negative selects the CPU, otherwise ``torch.device(gpu)``.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        encoder: Passed to ``load_ds`` (``nl_mode``) and ``create_model``
            (``encoder_name``).
        numlayers, hdim, numheads: Decoder layers / dim / heads for
            ``create_model``.
        localtest: Run one training batch through both models and print it.
        printtest: After testing, generate and print predictions for every
            test example together with a sequence accuracy.
        trainonvalid: Forwarded to ``load_ds``.

    Returns:
        dict: The call arguments plus one ``"<split>_<metric name>"`` entry
        per train/valid/test metric.
    """
    settings = locals().copy()
    print(locals())
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    if gpu < 0:
        device = torch.device("cpu")
    else:
        device = torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_ds(domain=domain,
                                          nl_mode=encoder,
                                          trainonvalid=trainonvalid)
    # All three loaders pad with the same value.
    collate = partial(autocollate, pad_value=1)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True,
                     collate_fn=collate)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False,
                     collate_fn=collate)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False,
                     collate_fn=collate)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(
        encoder_name=encoder,
        dec_vocabsize=flenc.vocab.number_of_ids(),
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        smoothing=smoothing,
        maxlen=maxlen,
        tensor2tree=partial(_tensor2tree, D=flenc.vocab))
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        for candidate in (trainm, testm):
            out = candidate(*batch)
            print(out)

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()  # names listed here would be left out of training
    if exclude_params:
        trainable_params = [
            (pname, param) for pname, param in trainable_params
            if pname not in exclude_params
        ]

    tt.msg("different param groups")
    # Encoder parameters get their own (scaled) learning rate.
    enc_prefix = "model.model.encoder"
    encparams, otherparams = [], []
    for pname, param in trainable_params:
        if pname.startswith(enc_prefix):
            encparams.append(param)
        else:
            otherparams.append(param)
    if not encparams:
        raise Exception("No encoder parameters found!")
    paramgroups = [
        {"params": encparams, "lr": lr * enclrmul},
        {"params": otherparams},
    ]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    warmup_phase = q.sched.Linear(steps=warmup)
    if cosinelr:
        schedule_spec = warmup_phase >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        schedule_spec = warmup_phase >> 1.
    lr_schedule = q.sched.LRSchedule(optim, schedule_spec)

    def step_schedule():
        return lr_schedule.step()

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[step_schedule])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device)

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        cpu = torch.device("cpu")
        correct = 0
        total = 0
        for testbatch in xdl:
            output_ids = testbatch[1]
            input_ids = testbatch[0].to(device)
            # Mask out padding positions when generating.
            pad_mask = input_ids != predm.config.pad_token_id
            ret = predm.generate(input_ids,
                                 attention_mask=pad_mask,
                                 max_length=maxlen)
            inp_strs = [
                nltok.decode(row,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for row in input_ids
            ]
            out_strs = [flenc.vocab.tostr(row.to(cpu)) for row in ret]
            gold_strs = [flenc.vocab.tostr(row.to(cpu)) for row in output_ids]

            for src, pred, gold in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{src}'\n--> {pred}\n <=> {gold}")
                if pred == gold:
                    correct += 1
                else:
                    print("NOT SAME")
                total += 1
        print(f"seq acc: {correct/total}")

    print("done")

    # Record the final value of every metric next to the run arguments.
    for datasplit, metricarray in (("train", metrics),
                                   ("valid", vmetrics),
                                   ("test", xmetrics)):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    return settings
Exemplo n.º 9
0
def run_rerank(
    lr=0.001,
    batsize=20,
    epochs=1,
    embdim=301,  # input dimension of the two LSTM score encoders
    encdim=200,
    numlayers=1,
    beamsize=5,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    cosine_restarts=1.,
    domain="restaurants",
    gensavedp="overnight_basic/run{}",
    genrunid=1,
):
    """Build a beam reranker on top of a previously saved generator run.

    The dataset and generator model are loaded from
    ``gensavedp.format(genrunid)``. A ``SimpleScoreModel`` with two
    bidirectional LSTM encoders is combined with the generator in a
    ``BeamReranker``, which is then run on one small training batch.

    NOTE: the function currently calls ``sys.exit()`` right after that first
    batch, so the training, testing and saving code below it never runs.

    Args:
        lr: Learning rate for Adam.
        batsize: Batch size for the data loaders.
        epochs: Maximum number of epochs; also the LR schedule length.
        embdim: First positional argument of both ``q.LSTMEncoder`` calls.
        encdim: Each encoder layer gets ``encdim // 2`` units.
        numlayers: Number of layers of each ``q.LSTMEncoder``.
        beamsize: Beam size of the reranker and the beam decoders.
        dropout: Passed to the encoders as ``dropout_in``.
        wreg: Weight decay for Adam.
        cuda, gpu: Use ``cuda:<gpu>`` when ``cuda`` is truthy, else the CPU.
        minfreq, domain: Only recorded in ``localargs``; not otherwise used.
        gradnorm: Max norm for gradient clipping.
        cosine_restarts: When ``>= 0`` a
            ``q.WarmupCosineWithHardRestartsSchedule`` with that many cycles
            is stepped at the end of every training epoch.
        gensavedp: Path template of the saved generator run.
        genrunid: Value formatted into ``gensavedp``.
    """
    localargs = locals().copy()
    print(locals())
    gensavedrunp = gensavedp.format(genrunid)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = q.load_dataset(gensavedrunp)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)

    genmodel, genargs = q.load_run(gensavedrunp)

    # Two separate encoders: one for the input side, one for the output side.
    inpenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    outenc = q.LSTMEncoder(embdim,
                           *([encdim // 2] * numlayers),
                           bidir=True,
                           dropout_in=dropout)
    # The score model reuses the generator's input and output embeddings.
    scoremodel = SimpleScoreModel(genmodel.inp_emb, genmodel.out_emb,
                                  LSTMEncoderWrapper(inpenc),
                                  LSTMEncoderWrapper(outenc), DotSimilarity())

    model = BeamReranker(genmodel, scoremodel, beamsize=beamsize, maxtime=50)

    # todo: run over whole dataset to populate beam cache
    testbatch = next(iter(ds.dataloader("train", batsize=2)))
    model(testbatch)

    # NOTE: debugging stop -- nothing below this line is currently executed.
    sys.exit()

    tfdecoder = SeqDecoder(TFTransition(model), [
        CELoss(ignore_index=0, mode="logprobs"),
        SeqAccuracies(),
        TreeAccuracy(tensor2tree=partial(tensor2tree,
                                         D=ds.query_encoder.vocab),
                     orderless={"op:and", "SW:concat"})
    ])
    freedecoder = BeamDecoder(
        model,
        maxtime=50,
        beamsize=beamsize,
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    trainable_params = tfdecoder.named_parameters()
    exclude_params = {"model.model.inp_emb.emb.weight"
                      }  # don't train input embeddings if doing glove
    trainable_params = [
        v for k, v in trainable_params if k not in exclude_params
    ]

    # 4. define optim
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)

    # lr schedule (stepped once per epoch, hence t_max == epochs)
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 5. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 6. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    # fullmatch (not match) so that answers such as "12abc" are treated as
    # "No" instead of crashing in int(); the raw string avoids the invalid
    # "\d" escape of a normal string literal.
    overwrite_match = re.fullmatch(r"\d+", tosave)
    if tosave.lower() in ("y", "yes") or overwrite_match:
        overwrite = int(tosave) if overwrite_match else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        # Reload what was just written to check the saved run is usable.
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        # The reloaded model must reproduce the in-memory results exactly.
        assert (testresults == _testresults)
        tt.tock("tested")
Exemplo n.º 10
0
def run(
        sourcelang="en",
        supportlang="en",
        testlang="en",
        lr=0.001,
        enclrmul=0.1,
        numbeam=1,
        cosinelr=False,
        warmup=0.,
        batsize=20,
        epochs=100,
        dropout=0.1,
        dropoutdec=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        encoder="xlm-roberta-base",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        trainonvalid=False,
        statesimweight=0.,
        probsimweight=0.,
        projmode="simple",  # "simple" or "twolayer"
):
    """Train and evaluate a multilingual geoquery parser with early stopping.

    Args:
        sourcelang, supportlang, testlang: Language arguments passed
            positionally to ``load_multilingual_geoquery``.
        lr: Base learning rate for Adam.
        enclrmul: Multiplier applied to ``lr`` for parameters whose name
            starts with ``"model.model.encoder.model"``.
        numbeam, dropout, dropoutdec, smoothing, maxlen, statesimweight,
            probsimweight, projmode: Forwarded to ``create_model``.
        cosinelr: Use linear warmup followed by cosine decay to 0 when truthy,
            otherwise linear warmup followed by a constant factor of 1.
        warmup: Number of warmup steps of the schedule (stepped per epoch).
        batsize: Batch size of all three data loaders.
        epochs: Maximum number of epochs; also the schedule length.
        wreg: Weight decay for Adam.
        gradnorm: Max norm for gradient clipping.
        patience: Patience of the ``q.EarlyStopper`` watching the last
            validation metric.
        gpu: Negative selects the CPU, otherwise ``torch.device(gpu)``.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        encoder: Tokenizer name for data loading and ``encoder_name`` for
            ``create_model``.
        numlayers, hdim, numheads: Decoder layers / dim / heads for
            ``create_model``.
        localtest: Run one training batch through both models and print it.
        printtest: After testing, generate and print predictions for every
            test example together with a sequence accuracy.
        trainonvalid: Forwarded to ``load_multilingual_geoquery``.

    Returns:
        dict: The call arguments plus one ``"<split>_<metric name>"`` entry
        per train/valid/test metric.
    """
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    if gpu < 0:
        device = torch.device("cpu")
    else:
        device = torch.device(gpu)

    tt.tick("loading data")

    tds, vds, xds, nltok, flenc = load_multilingual_geoquery(
        sourcelang,
        supportlang,
        testlang,
        nltok_name=encoder,
        trainonvalid=trainonvalid)
    # Report the relative and absolute size of each split.
    n_train, n_valid, n_test = len(tds), len(vds), len(xds)
    n_total = n_train + n_valid + n_test
    tt.msg(
        f"{n_train/n_total:.2f}/{n_valid/n_total:.2f}/{n_test/n_total:.2f} ({n_train}/{n_valid}/{n_test}) train/valid/test"
    )
    # All three loaders pad the natural-language side with the tokenizer's
    # pad id.
    collate = partial(collate_fn, pad_value_nl=nltok.pad_token_id)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True,
                     collate_fn=collate)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False,
                     collate_fn=collate)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False,
                     collate_fn=collate)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(
        encoder_name=encoder,
        dec_vocabsize=flenc.vocab.number_of_ids(),
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        dropoutdec=dropoutdec,
        smoothing=smoothing,
        maxlen=maxlen,
        numbeam=numbeam,
        tensor2tree=partial(_tensor2tree, D=flenc.vocab),
        statesimweight=statesimweight,
        probsimweight=probsimweight,
        projmode=projmode,
    )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        for candidate in (trainm, testm):
            out = candidate(*batch)
            print(out)

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()  # names listed here would be left out of training
    if exclude_params:
        trainable_params = [
            (pname, param) for pname, param in trainable_params
            if pname not in exclude_params
        ]

    tt.msg("different param groups")
    # Encoder parameters get their own (scaled) learning rate.
    enc_prefix = "model.model.encoder.model"
    encparams, otherparams = [], []
    for pname, param in trainable_params:
        if pname.startswith(enc_prefix):
            encparams.append(param)
        else:
            otherparams.append(param)
    if not encparams:
        raise Exception("No encoder parameters found!")
    paramgroups = [
        {"params": encparams, "lr": lr * enclrmul},
        {"params": otherparams},
    ]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    def snapshot_models():
        # Copies of both sub-models, kept by the early stopper.
        return (deepcopy(trainm.model), deepcopy(trainm.model2))

    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=snapshot_models)

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    warmup_phase = q.sched.Linear(steps=warmup)
    if cosinelr:
        schedule_spec = warmup_phase >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        schedule_spec = warmup_phase >> 1.
    lr_schedule = q.sched.LRSchedule(optim, schedule_spec)

    def step_schedule():
        return lr_schedule.step()

    def notify_early_stopper():
        return eyt.on_epoch_end()

    def should_stop():
        return eyt.check_stop()

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[step_schedule])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[notify_early_stopper])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[should_stop])
    tt.tock("done training")

    # Put the remembered sub-models back into both wrappers, if any exist.
    if eyt.remembered is not None:
        for wrapper in (trainm, testm):
            wrapper.model = eyt.remembered[0]
            wrapper.model2 = eyt.remembered[1]
    tt.msg("reloaded best")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model2
        predm.to(device)
        cpu = torch.device("cpu")
        correct = 0
        total = 0
        for testbatch in xdl:
            output_ids = testbatch[1]
            input_ids = testbatch[0].to(device)
            # Mask out padding positions when generating.
            pad_mask = input_ids != predm.config.pad_token_id
            ret = predm.generate(input_ids,
                                 attention_mask=pad_mask,
                                 max_length=maxlen)
            inp_strs = [
                nltok.decode(row,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for row in input_ids
            ]
            out_strs = [flenc.vocab.tostr(row.to(cpu)) for row in ret]
            gold_strs = [flenc.vocab.tostr(row.to(cpu)) for row in output_ids]

            for src, pred, gold in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{src}'\n--> {pred}\n <=> {gold}")
                if pred == gold:
                    correct += 1
                else:
                    print("NOT SAME")
                total += 1
        print(f"seq acc: {correct/total}")

    print("done")

    # Record the final value of every metric next to the run arguments.
    for datasplit, metricarray in (("train", metrics),
                                   ("valid", vmetrics),
                                   ("test", xmetrics)):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    return settings
Exemplo n.º 11
0
def run(
        lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        distill="glove",  # "rnnlm", "glove"
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False,
        repretrain=False,  # retrain base model instead of loading it
        savepath="rnnlm.base.pt",  # where to save after training
        glovepath="../../../data/glove/glove.300d"):
    """Train (or reload) a base RNN language model, then train a student by distillation.

    Phase 1 (base model): if ``savepath`` exists and ``repretrain`` is False,
    the pickled model is loaded from ``savepath``. Otherwise a fresh
    ``RNNLayer_LM`` is trained with SGD, gradient-norm clipping and a
    ``ReduceLROnPlateau`` scheduler stepped on the validation CE loss, and is
    then saved to ``savepath``. In both cases the base model is evaluated on
    the test batches and the results are printed.

    Phase 2 (student): a second ``RNNLayer_LM`` with the same dimensions is
    trained with ``q.DistillLoss(temperature=2.)`` through
    ``train_epoch_distill``. The teacher signal is selected by ``distill``:
    ``"rnnlm"`` passes the base model as ``mbase``; ``"glove"`` passes a
    ``GloveGoldGetter`` built from ``glovepath`` as ``goldgetter``.

    Args:
        lr: initial SGD learning rate (used for both base and student).
        dropout: dropout passed to ``RNNLayer_LM``.
        dropconnect: accepted but not referenced in this function body.
        gradnorm: max norm used for gradient clipping before each optim step.
        epochs: ``max_epochs`` for both training runs.
        embdim, encdim, numlayers: model dims are ``[embdim] + [encdim] * numlayers``.
        tieweights: forwarded to ``RNNLayer_LM``.
        distill: ``"rnnlm"`` or ``"glove"`` (see above).
        seqlen: ``mu`` of the ``VariableSeqlen`` handed to ``load_data``.
        batsize, eval_batsize: forwarded to ``load_data``.
        cuda, gpu: when ``cuda`` is true, run on ``torch.device("cuda", gpu)``.
        test: when training the base model from scratch, first push a few
            training batches through it and print the output size.
        repretrain: retrain the base model even if ``savepath`` exists.
        savepath: file the base model is loaded from / saved to.
        glovepath: path handed to ``GloveGoldGetter``.

    Raises:
        q.SumTingWongException: if ``distill`` is neither ``"rnnlm"`` nor ``"glove"``.
    """
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    # region base training
    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    if os.path.exists(savepath) and repretrain is False:
        tt.tick("reloading base model")
        # NOTE(security): torch.load unpickles a full model object here; only
        # point `savepath` at files produced by this script / trusted sources.
        with open(savepath, "rb") as f:
            m = torch.load(f)
            m.to(device)
        tt.tock("reloaded base model")
    else:
        tt.tick("preparing training base")
        dims = [embdim] + ([encdim] * numlayers)

        m = RNNLayer_LM(*dims,
                        worddic=D,
                        dropout=dropout,
                        tieweights=tieweights).to(device)

        if test:
            # Smoke test: run a handful of batches through the untrained model.
            # The model already lives on `device`, so the input has to be moved
            # there as well (otherwise this fails whenever cuda=True).
            y = None
            for i, batch in enumerate(train_batches):
                y = m(batch[0].to(device))
                if i > 5:
                    break
            if y is not None:  # loader may be empty
                print(y.size())

        optim = torch.optim.SGD(m.parameters(), lr=lr)

        train_batch_f = partial(q.train_batch,
                                on_before_optim_step=[
                                    lambda: torch.nn.utils.clip_grad_norm_(
                                        m.parameters(), gradnorm)
                                ])
        # patience=0: shrink the lr by 4x as soon as validation loss stops improving
        lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                         mode="min",
                                                         factor=1 / 4,
                                                         patience=0,
                                                         verbose=True)
        lrp_f = lambda: lrp.step(validloss.get_epoch_error())

        train_epoch_f = partial(q.train_epoch,
                                model=m,
                                dataloader=train_batches,
                                optim=optim,
                                losses=[loss],
                                device=device,
                                _train_batch=train_batch_f)
        valid_epoch_f = partial(q.test_epoch,
                                model=m,
                                dataloader=valid_batches,
                                losses=validlosses,
                                device=device,
                                on_end=[lrp_f])

        tt.tock("prepared training base")
        tt.tick("training base model")
        q.run_training(train_epoch_f,
                       valid_epoch_f,
                       max_epochs=epochs,
                       validinter=1)
        tt.tock("trained base model")

        with open(savepath, "wb") as f:
            torch.save(m, f)

    tt.tick("testing base model")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested base model")
    # endregion

    # region distillation
    tt.tick("preparing training student")
    dims = [embdim] + ([encdim] * numlayers)
    ms = RNNLayer_LM(*dims, worddic=D, dropout=dropout,
                     tieweights=tieweights).to(device)

    # Fresh loss wrappers for the student: training uses the distillation
    # loss, validation/test still report plain CE and perplexity.
    loss = q.LossWrapper(q.DistillLoss(temperature=2.))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for l in [loss] + validlosses + testlosses:  # put losses on right device
        l.loss.to(device)

    optim = torch.optim.SGD(ms.parameters(), lr=lr)

    train_batch_f = partial(
        train_batch_distill,
        on_before_optim_step=[
            lambda: torch.nn.utils.clip_grad_norm_(ms.parameters(), gradnorm)
        ])
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=1 / 4,
                                                     patience=0,
                                                     verbose=True)
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    # Select the teacher signal handed to train_epoch_distill.
    if distill == "rnnlm":
        mbase = m
        goldgetter = None
    elif distill == "glove":
        mbase = None
        tt.tick("creating gold getter based on glove")
        goldgetter = GloveGoldGetter(glovepath, worddic=D)
        goldgetter.to(device)
        tt.tock("created gold getter")
    else:
        raise q.SumTingWongException("unknown distill mode {}".format(distill))

    train_epoch_f = partial(train_epoch_distill,
                            model=ms,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f,
                            mbase=mbase,
                            goldgetter=goldgetter)
    valid_epoch_f = partial(q.test_epoch,
                            model=ms,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("prepared training student")
    tt.tick("training student model")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained student model")

    tt.tick("testing student model")
    testresults = q.test_epoch(model=ms,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested student model")
Exemplo n.º 12
0
def example_usage_full_with_penalty_and_hyperparam():
    """End-to-end usage example: training with an extra penalty term whose
    weight is a hyperparameter that is decayed at the start of every epoch.

    Builds a tiny linear+softmax classifier on random data, trains it with two
    cross-entropy losses plus a model-provided penalty, validates with the two
    cross-entropy losses only, and finally prints the test results.
    """

    # 1. define model
    class PenalizedClassifier(torch.nn.Module):
        """Linear layer + softmax that also records a penalty on every forward."""

        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(5, 5)
            self.sm = torch.nn.Softmax(-1)
            self._pen = 0

        def batch_reset(self):
            # called before every batch: clears the stored penalty
            self._pen = 0

        def get_penalty(self):
            # looked up by name by the PenaltyGetter below
            return self._pen

        def forward(self, _x):
            scores = self.lin(_x)
            self._pen = torch.sum(scores, dim=1)
            return self.sm(scores)

    model = PenalizedClassifier()

    # 2. define data
    inputs = torch.rand(100, 5)
    targets = torch.randint(0, 5, (100, ))
    full_ds = torch.utils.data.TensorDataset(inputs, targets)
    train_ds, valid_ds, test_ds = torch.utils.data.random_split(
        full_ds, [70, 10, 20])

    def make_loader(ds, shuffle):
        return torch.utils.data.DataLoader(ds, batch_size=2, shuffle=shuffle)

    trainloader = make_loader(train_ds, True)
    validloader = make_loader(valid_ds, False)
    testloader = make_loader(test_ds, False)

    # 3. define losses and penalties and wrap them
    penweight = q.hyperparam(1.)
    ce_mean = q.LossWrapper(torch.nn.CrossEntropyLoss(reduction="mean"))
    ce_sum = q.LossWrapper(torch.nn.CrossEntropyLoss(reduction="sum"))
    penalty = q.LossWrapper(
        q.PenaltyGetter(model, "get_penalty", factor=penweight))

    # 4. define optim
    optim = torch.optim.SGD(model.parameters(), lr=1.)

    # 5. other options (device, ...)
    device = torch.device("cpu")

    def on_start_train_epoch():
        # shrink the penalty weight a little before each training epoch
        penweight.v /= 1.2
        print(q.v(penweight))

    # 6. define training function (using partial)
    train_kwargs = {
        "model": model,
        "dataloader": trainloader,
        "optim": optim,
        "losses": [ce_mean, ce_sum, penalty],
        "device": device,
        "on_start": [on_start_train_epoch],
    }
    trainepoch = partial(q.train_epoch, **train_kwargs)

    # 7. define validation function (using partial)
    valid_kwargs = {
        "model": model,
        "dataloader": validloader,
        "losses": [ce_mean, ce_sum],
        "device": device,
    }
    validepoch = partial(q.test_epoch, **valid_kwargs)

    # 8. run training
    run_training(run_train_epoch=trainepoch,
                 run_valid_epoch=validepoch,
                 max_epochs=50)

    # 9. run test function
    testresults = q.test_epoch(model=model,
                               dataloader=testloader,
                               losses=[ce_mean, ce_sum],
                               device=device)
    print(testresults)
Exemplo n.º 13
0
def run(
    traindomains="ALL",
    domain="recipes",
    mincoverage=2,
    lr=0.001,
    advlr=-1,
    enclrmul=0.1,
    numbeam=1,
    ftlr=0.0001,
    cosinelr=False,
    warmup=0.,
    batsize=30,
    epochs=100,
    pretrainepochs=100,
    dropout=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="bert-base-uncased",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=30,
    localtest=False,
    printtest=False,
    fullsimplify=True,
    domainstart=False,
    useall=False,
    nopretrain=False,
    entropycontrib=1.,
    advsteps=5,
):
    """Adversarially pretrain on the training domains, finetune on ``domain``, then evaluate.

    Steps:
      1. Load the datasets via ``load_ds`` and wrap them in ``DataLoader``s.
      2. Build ``trainm``, ``advtrainm`` and ``testm`` via ``create_model``.
      3. Unless ``nopretrain`` is set, pretrain with ``adv_train_epoch`` for up
         to ``pretrainepochs`` epochs, early-stopping on the second validation
         metric (``tree_acc``), and restore the remembered model if there is one.
      4. Finetune with ``q.train_epoch`` on the finetuning split for up to
         ``epochs`` epochs with a fresh optimizer (``ftlr``) and early stopper.
      5. Evaluate on the finetune-validation and test loaders, optionally
         printing every test prediction (``printtest``).

    Args:
        traindomains: ``"ALL"`` means every known domain except ``domain``;
            otherwise passed through to ``load_ds`` unchanged.
        domain: held-out target domain (``testdomain`` for ``load_ds``).
        lr, ftlr: base learning rates for pretraining / finetuning.
        advlr: adversary learning rate; negative means "use ``lr``".
        enclrmul: multiplier applied to the learning rate of encoder params
            (those whose name starts with ``"model.model.encoder"``).
        cosinelr, warmup: learning-rate schedule configuration.
        gpu: negative means CPU, otherwise passed to ``torch.device``.
        localtest: run one training batch through ``trainm`` and ``testm`` first.
        printtest: print inputs, predictions and gold for every test example.
        Remaining arguments are forwarded to ``load_ds`` / ``create_model`` /
        the optimizers / the training loop as named.

    Returns:
        dict: a copy of the call arguments, extended with
        ``"{train,valid,test}_<metric name>"`` entries taken from the
        finetuning metrics.

    Raises:
        Exception: if no parameter name starts with ``"model.model.encoder"``.
    """
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    if advlr < 0:
        advlr = lr
    if traindomains == "ALL":
        alldomains = {
            "recipes", "restaurants", "blocks", "calendar", "housing",
            "publications"
        }
        traindomains = alldomains - {
            domain,
        }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    def _make_loader(ds, shuffle):
        # all loaders share batch size and padding; only shuffling differs
        return DataLoader(ds,
                          batch_size=batsize,
                          shuffle=shuffle,
                          collate_fn=partial(autocollate, pad_value=0))

    def _make_paramgroups(base_lr):
        # Split trainm's parameters into encoder / non-encoder groups so the
        # encoder can be trained with a scaled-down learning rate.
        trainable_params = list(trainm.named_parameters())
        exclude_params = set()
        # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
        if len(exclude_params) > 0:
            trainable_params = [(k, v) for k, v in trainable_params
                                if k not in exclude_params]

        tt.msg("different param groups")
        encparams = [
            v for k, v in trainable_params
            if k.startswith("model.model.encoder")
        ]
        otherparams = [
            v for k, v in trainable_params
            if not k.startswith("model.model.encoder")
        ]
        if len(encparams) == 0:
            raise Exception("No encoder parameters found!")
        return [{
            "params": encparams,
            "lr": base_lr * enclrmul
        }, {
            "params": otherparams
        }]

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc, absflenc = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, add_domain_start=domainstart, useall=useall)
    advds = Dataset(tds.examples)
    tt.msg(
        f"{len(tds)/(len(tds) + len(vds)):.2f}/{len(vds)/(len(tds) + len(vds)):.2f} ({len(tds)}/{len(vds)}) train/valid"
    )
    tt.msg(
        f"{len(ftds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(fvds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(xds)/(len(ftds) + len(fvds) + len(xds)):.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test"
    )
    tdl = _make_loader(tds, True)
    advdl = _make_loader(advds, True)
    ftdl = _make_loader(ftds, True)
    vdl = _make_loader(vds, False)
    fvdl = _make_loader(fvds, False)
    xdl = _make_loader(xds, False)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, advtrainm, testm = create_model(
        encoder_name=encoder,
        fl_vocab=flenc.vocab,
        abs_fl_vocab=absflenc.vocab,
        dec_layers=numlayers,
        dec_dim=hdim,
        dec_heads=numheads,
        dropout=dropout,
        smoothing=smoothing,
        maxlen=maxlen,
        numbeam=numbeam,
        abs_id=absflenc.vocab["@ABS@"],
        entropycontrib=entropycontrib,
    )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "ce", "elem_acc", "tree_acc")
    advmetrics = make_array_of_metrics("adv_loss", "adv_elem_acc",
                                       "adv_tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    paramgroups = _make_paramgroups(lr)

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    advoptim = torch.optim.Adam(advtrainm.parameters(),
                                lr=advlr,
                                weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)
    advclipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        advtrainm.parameters(), gradnorm)

    # early stopping on validation tree accuracy (vmetrics[1])
    eyt = q.EarlyStopper(vmetrics[1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    # The schedules are stepped once per pretraining epoch, so their length
    # must follow `pretrainepochs` (not the finetuning `epochs`).
    t_max = pretrainepochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
        advlr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
        advlr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)
    advlr_schedule = q.sched.LRSchedule(advoptim, advlr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    advtrainbatch = partial(q.train_batch,
                            on_before_optim_step=[advclipgradnorm])
    trainepoch = partial(
        adv_train_epoch,
        model=trainm,
        dataloader=tdl,
        optim=optim,
        losses=metrics,
        advmodel=advtrainm,
        advdataloader=advdl,
        advoptim=advoptim,
        advlosses=advmetrics,
        _train_batch=trainbatch,
        _adv_train_batch=advtrainbatch,
        device=device,
        on_end=[lambda: lr_schedule.step(), lambda: advlr_schedule.step()],
        advsteps=advsteps)
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch,
                       run_valid_epoch=validepoch,
                       max_epochs=pretrainepochs,
                       check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "ce", "elem_acc", "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    # rebuilt here because trainm.model may have been replaced above
    paramgroups = _make_paramgroups(ftlr)

    ftoptim = torch.optim.Adam(paramgroups, lr=ftlr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(ftvmetrics[1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ftoptim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=ftdl,
                         optim=ftoptim,
                         losses=ftmetrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=fvdl,
                         losses=ftvmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=fvdl,
                                losses=ftvmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=ftxmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0  # exact string matches / total examples
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        if t > 0:
            print(f"seq acc: {c/t}")
        else:
            # empty test loader: avoid dividing by zero
            print("seq acc: n/a (no test examples)")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    # record the final finetuning metrics next to the hyperparameters
    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Exemplo n.º 14
0
def run(domain="restaurants",
        lr=0.001,
        ptlr=0.0001,
        enclrmul=0.1,
        cosinelr=False,
        ptcosinelr=False,
        warmup=0.,
        ptwarmup=0.,
        batsize=20,
        ptbatsize=50,
        epochs=100,
        ptepochs=100,
        dropout=0.1,
        wreg=1e-9,
        gradnorm=3,
        smoothing=0.,
        patience=5,
        gpu=-1,
        seed=123456789,
        dataseed=12345678,
        datatemp=0.33,
        ptN=3000,
        tokenmaskp=0.,
        spanmaskp=0.,
        spanmasklamda=2.2,
        treemaskp=0.,
        encoder="bart-large",
        numlayers=6,
        hdim=600,
        numheads=8,
        maxlen=50,
        localtest=False,
        printtest=False,
        ):
    """Pretrain on perturbed grammar-generated examples, then finetune and evaluate on ``domain``.

    Steps:
      1. Load train/valid/test data for ``domain`` via ``load_ds``.
      2. Build a grammar from the train and valid sets (``build_grammar``) and
         a ``PCFGDataset`` of ``ptN`` generated examples.
      3. Pretrain ``pretrainm`` for ``ptepochs`` epochs on (perturbed, original)
         pairs; perturbation is subtree masking, then token masking, then span
         masking, each enabled only when its probability is > 0. The generated
         dataset's seed is advanced after every pretraining epoch.
      4. Finetune ``trainm`` for up to ``epochs`` epochs with early stopping on
         the second validation metric (``tree_acc``), then restore the
         remembered model if there is one.
      5. Evaluate on the test loader, optionally printing every prediction.

    Args:
        lr / ptlr: finetuning / pretraining learning rates.
        enclrmul: multiplier for the learning rate of encoder params (names
            starting with ``"model.model.encoder"``) during finetuning.
        cosinelr, warmup / ptcosinelr, ptwarmup: schedule configuration for
            finetuning / pretraining.
        gpu: negative means CPU, otherwise passed to ``torch.device``.
        localtest: print dataset-uniqueness diagnostics and run one training
            batch through ``trainm`` and ``testm`` before training.
        printtest: print inputs, predictions and gold for every test example.
        Remaining arguments are forwarded as named to the dataset, masker,
        model, optimizer and training-loop constructors.

    Returns:
        dict: a copy of the call arguments, extended with
        ``"{train,valid,test}_<metric name>"`` entries.

    Raises:
        Exception: if no parameter name starts with ``"model.model.encoder"``.
    """
    settings = locals().copy()
    # print the captured arguments once (locals() would now also contain `settings`)
    print(settings)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_ds(domain=domain, nl_mode=encoder)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=partial(autocollate, pad_value=1))
    tt.tock("data loaded")

    tt.tick("creating grammar dataset generator")
    pcfg = build_grammar(tds, vds)
    ptds = PCFGDataset(pcfg, N=ptN, seed=seed, temperature=datatemp, maxlen=100)
    tt.tock("created dataset generator")

    tt.tick("creating model")
    trainm, testm, pretrainm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 tensor2tree=partial(_tensor2tree, D=flenc.vocab)
                                 )
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        print("generated dataset")
        print(ptds[0])
        print(ptds[0])
        allexamples = []
        for i in tqdm(range(len(ptds))):
            allexamples.append(ptds[i])
        uniqueexamples = {str(x) for x in allexamples}
        print(f"{100*len(uniqueexamples)/len(allexamples)}% unique examples ({len(uniqueexamples)}/{len(allexamples)})")
        ptds.advance_seed()
        print(ptds[0])
        allexamples = list(ptds.examples)
        uniqueexamples2 = {str(x) for x in allexamples}
        print(f"{100*len(uniqueexamples2)/len(allexamples)}% unique examples ({len(uniqueexamples2)}/{len(allexamples)})")
        print(f"{len(uniqueexamples & uniqueexamples2)}/{len(uniqueexamples | uniqueexamples2)} overlap")
        print("---")
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    # region pretraining
    # setup data perturbation (identity function when the probability is 0)
    tokenmasker = TokenMasker(p=tokenmaskp, seed=dataseed) if tokenmaskp > 0 else lambda x: x
    spanmasker = SpanMasker(p=spanmaskp, lamda=spanmasklamda, seed=dataseed) if spanmaskp > 0 else lambda x: x
    treemasker = SubtreeMasker(p=treemaskp, seed=dataseed) if treemaskp > 0 else lambda x: x

    # each example becomes a (perturbed, original) pair: mask subtrees,
    # convert both sides to tokens, then mask tokens and spans on the input side
    perturbed_ptds = ptds\
        .map(lambda x: (treemasker(x), x))\
        .map(lambda x: (flenc.convert(x[0], "tokens"),
                        flenc.convert(x[1], "tokens")))\
        .map(lambda x: (spanmasker(tokenmasker(x[0])), x[1]))
    perturbed_ptds_tokens = perturbed_ptds
    perturbed_ptds = perturbed_ptds\
        .map(lambda x: (flenc.convert(x[0], "tensor"),
                        flenc.convert(x[1], "tensor")))

    if localtest:
        allex = []
        allperturbedex = []
        _nepo = 50
        print(f"checking {_nepo}, each {ptN} generated examples")
        for _e in tqdm(range(_nepo)):
            for i in range(len(perturbed_ptds_tokens)):
                ex = str(ptds[i])
                perturbed_ex = perturbed_ptds_tokens[i]
                perturbed_ex = f"{' '.join(perturbed_ex[0])}->{' '.join(perturbed_ex[1])}"
                allex.append(ex)
                allperturbedex.append(perturbed_ex)
            ptds.advance_seed()
        uniqueex = set(allex)
        uniqueperturbedex = set(allperturbedex)
        print(f"{len(uniqueex)}/{len(allex)} unique examples")
        print(f"{len(uniqueperturbedex)}/{len(allperturbedex)} unique perturbed examples")

    ptdl = DataLoader(perturbed_ptds, batch_size=ptbatsize, shuffle=True, collate_fn=partial(autocollate, pad_value=1))
    ptmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    ptparams = pretrainm.parameters()
    ptoptim = torch.optim.Adam(ptparams, lr=ptlr, weight_decay=wreg)
    # clip the gradients of the parameters that ptoptim actually steps
    ptclipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(pretrainm.parameters(), gradnorm)
    t_max = ptepochs
    print(f"Total number of pretraining updates: {t_max} .")
    if ptcosinelr:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> q.sched.Cosine(steps=t_max-ptwarmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=ptwarmup) >> 1.
    lr_schedule = q.sched.LRSchedule(ptoptim, lr_schedule)

    pttrainbatch = partial(q.train_batch, on_before_optim_step=[ptclipgradnorm])
    # after every epoch: step the lr schedule and advance the generated dataset's seed
    pttrainepoch = partial(q.train_epoch, model=pretrainm, dataloader=ptdl, optim=ptoptim, losses=ptmetrics,
                         _train_batch=pttrainbatch, device=device, on_end=[lambda: lr_schedule.step(),
                                                                           lambda: ptds.advance_seed()])

    tt.tick("pretraining")
    q.run_training(run_train_epoch=pttrainepoch, max_epochs=ptepochs)
    tt.tock("done pretraining")

    # endregion

    # region finetuning
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    trainable_params = list(trainm.named_parameters())
    exclude_params = set()
    # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
    if len(exclude_params) > 0:
        trainable_params = [(k, v) for k, v in trainable_params if k not in exclude_params]

    tt.msg("different param groups")
    encparams = [v for k, v in trainable_params if k.startswith("model.model.encoder")]
    otherparams = [v for k, v in trainable_params if not k.startswith("model.model.encoder")]
    if len(encparams) == 0:
        raise Exception("No encoder parameters found!")
    paramgroups = [{"params": encparams, "lr": lr * enclrmul},
                   {"params": otherparams}]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    # early stopping on validation tree accuracy (vmetrics[1])
    eyt = q.EarlyStopper(vmetrics[1], patience=patience, min_epochs=10, more_is_better=True, remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=trainm, dataloader=tdl, optim=optim, losses=metrics,
                         _train_batch=trainbatch, device=device, on_end=[lambda: lr_schedule.step()])
    # The early stopper reads a validation metric, so it must be updated after
    # the validation epoch has refreshed that metric, not after the train epoch.
    validepoch = partial(q.test_epoch, model=testm, dataloader=vdl, losses=vmetrics, device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs, check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    tt.tick("testing")
    testresults = q.test_epoch(model=testm, dataloader=xdl, losses=xmetrics, device=device)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0  # exact string matches / total examples
        for testbatch in iter(xdl):
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(input_ids, attention_mask=input_ids != predm.config.pad_token_id,
                                      max_length=maxlen)
            inp_strs = [nltok.decode(input_idse, skip_special_tokens=True, clean_up_tokenization_spaces=False) for input_idse in input_ids]
            out_strs = [flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret]
            gold_strs = [flenc.vocab.tostr(output_idse.to(torch.device("cpu"))) for output_idse in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        if t > 0:
            print(f"seq acc: {c/t}")
        else:
            # empty test loader: avoid dividing by zero
            print("seq acc: n/a (no test examples)")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    # record the final metrics next to the hyperparameters
    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics], ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
Exemplo n.º 15
0
def run(lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False):
    """Train an ``RNNLayer_LM`` language model with SGD, then evaluate it.

    The function loads batches through ``load_data``, builds the model,
    trains it with ``q.run_training`` (validating every epoch), and finally
    runs one ``q.test_epoch`` over the test batches and prints its result.

    Args:
        lr: Initial learning rate given to ``torch.optim.SGD``.
        dropout: Forwarded to ``RNNLayer_LM`` as ``dropout``.
        dropconnect: Accepted for interface compatibility; not used anywhere
            in this function.
        gradnorm: Max norm passed to ``torch.nn.utils.clip_grad_norm_`` in
            the ``on_before_optim_step`` hook of the training batch function.
        epochs: Passed to ``q.run_training`` as ``max_epochs``.
        embdim: First entry of the dimension list given to ``RNNLayer_LM``.
        encdim: Dimension repeated ``numlayers`` times after ``embdim``.
        numlayers: Number of times ``encdim`` is repeated.
        tieweights: Forwarded to ``RNNLayer_LM`` as ``tieweights``.
        seqlen: Used as ``mu`` of the ``VariableSeqlen`` given to
            ``load_data``.
        batsize: Forwarded to ``load_data`` as ``batsize``.
        eval_batsize: Forwarded to ``load_data`` as ``eval_batsize``.
        cuda: When true, run on ``torch.device("cuda", gpu)`` instead of CPU.
        gpu: CUDA device index, only read when ``cuda`` is true.
        test: When true, run a few forward passes over the training batches
            before training and print the size of the last output.

    Returns:
        None. Results are printed.
    """
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize,
                  seqlen=VariableSeqlen(minimum=5, maximum_offset=10, mu=seqlen, sigma=0))
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    tt.tick("creating model")
    dims = [embdim] + ([encdim] * numlayers)

    m = RNNLayer_LM(*dims, worddic=D, dropout=dropout,
                    tieweights=tieweights).to(device)

    if test:
        # Smoke test: push a handful of training batches through the model.
        y = None
        for i, batch in enumerate(train_batches):
            # The model already lives on `device`, so the input must be moved
            # there as well; `.to` is a no-op when it is already on `device`.
            y = m(batch[0].to(device))
            if i > 5:
                break
        # `y` stays None when there were no training batches at all.
        if y is not None:
            print(y.size())

    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    for lossobj in [loss] + validlosses + testlosses:  # put losses on right device
        lossobj.loss.to(device)

    optim = torch.optim.SGD(m.parameters(), lr=lr)

    # Clip the gradient norm right before every optimizer step.
    train_batch_f = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: torch.nn.utils.clip_grad_norm_(m.parameters(), gradnorm)
        ])
    # Shrink the learning rate to a quarter as soon as the monitored value
    # stops decreasing (patience=0).
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
    # confirm against the PyTorch version this script is pinned to.
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=1 / 4,
                                                     patience=0,
                                                     verbose=True)
    # The scheduler is driven by the validation loss of the finished epoch.
    lrp_f = lambda: lrp.step(validloss.get_epoch_error())

    train_epoch_f = partial(q.train_epoch,
                            model=m,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f)
    valid_epoch_f = partial(q.test_epoch,
                            model=m,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("created model")
    tt.tick("training")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained")

    tt.tick("testing")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")