def run(
    lr=0.001,
    batsize=20,
    epochs=60,
    embdim=128,
    encdim=256,
    numlayers=1,
    beamsize=5,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
    numcvfolds=6,
    testfold=-1,  # -1 disables cross-validation; otherwise must lie in [0, numcvfolds) and selects the fold used for validation
    bertversion="bert-base-multilingual-uncased",
):
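    """Train a BERT-encoder seq2seq parser on the Geo dataset.

    With the default testfold=-1 the model trains on "train" and validates on
    "test"; otherwise fold `testfold` of a `numcvfolds`-fold cross-validation
    is held out for validation and the epoch value of the tree-accuracy
    metric on that fold is returned.
    """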
    localargs = locals().copy()
    print(ujson.dumps(localargs, indent=4))
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    cvfolds = None if testfold == -1 else numcvfolds
    testfold = None if testfold == -1 else testfold

    # bertversion = "bert-base-multilingual-uncased"
    # bertversion = "bert-base-uncased"
    berttokenizer = BertTokenizer.from_pretrained(bertversion)
    ds = GeoDataset(bert_tokenizer=berttokenizer,
                    min_freq=minfreq,
                    cvfolds=cvfolds,
                    testfold=testfold)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    # bert = None
    bert = BertModel.from_pretrained(bertversion)
    model = BasicGenModel(bert=bert,
                          embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    orderless = {"_intersection"}
    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless=orderless)
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[
                                 SeqAccuracies(),
                                 TreeAccuracy(tensor2tree=partial(
                                     tensor2tree, D=ds.query_encoder.vocab),
                                              orderless=orderless)
                             ])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[
                                  TreeAccuracy(tensor2tree=partial(
                                      tensor2tree, D=ds.query_encoder.vocab),
                                               orderless=orderless)
                              ])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
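    # two parameter groups: the pretrained BERT encoder is fine-tuned at
    # 0.1 * lr, everything else at the full learning rate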
    bertparams = [
        p for (n, p) in tfdecoder.named_parameters()
        if n.startswith("model.bert.")
    ]
    otherparams = [
        p for (n, p) in tfdecoder.named_parameters()
        if not n.startswith("model.bert.")
    ]
    paramgroups = [{
        "params": bertparams,
        "lr": lr * 0.1
    }, {
        "params": otherparams
    }]
    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    train_on = "train"
    valid_on = "test" if testfold is None else "valid"
    # lr schedule
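    # warm up (10 steps here), then cosine-anneal with `cosine_restarts` hard
    # restarts; reduce_lr steps the schedule once per batch (see trainbatch)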
    if cosine_restarts >= 0:
        t_max = epochs * len(ds.dataloader(train_on, batsize))  # updates, using the actual training batch size
        # t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 10, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch,
                         on_before_optim_step=[clipgradnorm],
                         on_end=reduce_lr)

    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    if testfold is not None:
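        # cross-validation mode: return this fold's tree accuracy and skip
        # final testing and the interactive save prompt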
        return vlosses[1].get_epoch_error()

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> "
    )
    # if True:
    #     overwrite = None
    if tosave.lower() == "y" or tosave.lower() == "yes" or re.match(
            "\d+", tosave.lower()):
        overwrite = int(tosave) if re.match("\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[
                                       TreeAccuracy(tensor2tree=partial(
                                           tensor2tree,
                                           D=_ds.query_encoder.vocab),
                                                    orderless=orderless)
                                   ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
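
# Hypothetical usage sketch:
#   run(cuda=True, gpu=0)            # plain train/test run
#   run(testfold=0, numcvfolds=6)    # validate on fold 0 of a 6-fold CV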

# Example 2

def run(
    lr=0.001,
    batsize=20,
    epochs=1,
    embdim=301,
    encdim=200,
    numlayers=1,
    beamsize=5,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    cosine_restarts=1.,
    domain="restaurants",
):
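    """Train an attention-based seq2seq parser (no BERT) on one domain of the
    Overnight dataset; validation and testing decode with beam search.
    """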
    localargs = locals().copy()
    print(localargs)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = OvernightDataset(
        domain=domain,
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    sentence_rare_tokens = {
        ds.sentence_encoder.vocab(i)
        for i in model.inp_emb.rare_token_ids
    }
    do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0, mode="logprobs"),
                               SeqAccuracies(),
                               TreeAccuracy(tensor2tree=partial(
                                   tensor2tree, D=ds.query_encoder.vocab),
                                            orderless={"op:and", "SW:concat"})
                           ])
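    # "op:and" and "SW:concat" nodes have order-insensitive children when
    # TreeAccuracy compares predicted and gold trees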
    # beamdecoder = BeamActionSeqDecoder(tfdecoder.model, beamsize=beamsize, maxsteps=50)
    freedecoder = BeamDecoder(
        model,
        maxtime=50,
        beamsize=beamsize,
        tf_ratio=0.,
        eval_beam=[
            TreeAccuracy(tensor2tree=partial(tensor2tree,
                                             D=ds.query_encoder.vocab),
                         orderless={"op:and", "SW:concat"})
        ])

    # # test
    # tt.tick("doing one epoch")
    # for batch in iter(train_dl):
    #     batch = batch.to(device)
    #     ttt.tick("start batch")
    #     # with torch.no_grad():
    #     out = tfdecoder(batch)
    #     ttt.tock("end batch")
    # tt.tock("done one epoch")
    # print(out)
    # sys.exit()

    # beamdecoder(next(iter(train_dl)))

    # print(dict(tfdecoder.named_parameters()).keys())

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    trainable_params = tfdecoder.named_parameters()
    exclude_params = set()
    # don't train input embeddings when using pretrained (GloVe) vectors
    exclude_params.add("model.model.inp_emb.emb.weight")
    trainable_params = [
        v for k, v in trainable_params if k not in exclude_params
    ]

    # 4. define optim
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    # optim = torch.optim.SGD(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
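        # reduce_lr is attached as on_end of trainepoch below, so the schedule
        # steps once per epoch and t_max counts epochs rather than batches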
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        tfdecoder.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # 7. define validation function (using partial)
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? ('y(es)'=Yes, <int>=overwrite previous, otherwise=No)\n> "
    )
    if tosave.lower() in ("y", "yes") or re.fullmatch(r"\d+", tosave):
        overwrite = int(tosave) if re.fullmatch(r"\d+", tosave) else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(
            _model,
            maxtime=50,
            beamsize=beamsize,
            eval_beam=[
                TreeAccuracy(tensor2tree=partial(tensor2tree,
                                                 D=_ds.query_encoder.vocab),
                             orderless={"op:and", "SW:concat"})
            ])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        assert (testresults == _testresults)
        tt.tock("tested")