Beispiel #1
0
def run(
    lr=0.001,
    batsize=20,
    epochs=1,
    embdim=301,
    encdim=200,
    numlayers=1,
    beamsize=5,
    dropout=.2,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    cosine_restarts=1.,
    domain="restaurants",
):
    """Train, evaluate and optionally save a ``BasicGenModel`` on one Overnight domain.

    The script loads ``OvernightDataset`` for ``domain``, trains the model
    through a ``SeqDecoder`` (``tf_ratio=1.``), validates and tests it through
    a ``BeamDecoder``, and finally asks on stdin whether the run should be
    saved. When saving, the run is immediately reloaded and re-tested, and the
    reloaded results are asserted to equal the original test results.

    Args:
        lr: Learning rate given to Adam.
        batsize: Batch size given to every ``ds.dataloader`` call.
        epochs: Maximum number of training epochs; also used as ``t_max`` of
            the LR schedule.
        embdim: ``embdim`` of the model.
        encdim: ``hdim`` of the model.
        numlayers: ``numlayers`` of the model.
        beamsize: Beam size of the ``BeamDecoder``.
        dropout: ``dropout`` of the model.
        wreg: ``weight_decay`` given to Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` given to the dataset.
        gradnorm: Maximum norm used for gradient clipping before each
            optimizer step.
        cosine_restarts: ``cycles`` of the cosine-with-hard-restarts schedule;
            a negative value disables LR scheduling.
        domain: Overnight domain to load.

    Raises:
        AssertionError: If the saved-and-reloaded model does not reproduce the
            test results of the in-memory model.
    """
    # Snapshot the arguments before any other local name is bound.
    localargs = locals().copy()
    print(localargs)
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    tt.tick("loading data")
    ds = OvernightDataset(
        domain=domain,
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    sentence_rare_tokens = {
        ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids
    }
    do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    def make_tree_accuracy():
        # Every decoder gets its own metric instance, as in the original code.
        return TreeAccuracy(
            tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
            orderless={"op:and", "SW:concat"})

    # Decoder used for training (tf_ratio=1.).
    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0, mode="logprobs"),
                               SeqAccuracies(),
                               make_tree_accuracy(),
                           ])
    # Beam decoder used for validation and testing (tf_ratio=0.).
    freedecoder = BeamDecoder(model,
                              maxtime=50,
                              beamsize=beamsize,
                              tf_ratio=0.,
                              eval_beam=[make_tree_accuracy()])

    losses = make_array_of_metrics("loss", "seq_acc", "tree_acc")
    vlosses = make_array_of_metrics("tree_acc", "tree_acc_at3",
                                    "tree_acc_at_last")

    # don't train input embeddings if doing glove
    exclude_params = {"model.model.inp_emb.emb.weight"}
    trainable_params = [
        param for name, param in tfdecoder.named_parameters()
        if name not in exclude_params
    ]

    # optimizer
    optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)

    # lr schedule: stepped through the `on_end` hook of the training epoch
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(),
                                              gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader("train", batsize),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # validation function
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader("valid", batsize),
                         losses=vlosses,
                         device=device)

    # run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # testing
    tt.tick("testing")
    testresults = q.test_epoch(model=freedecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=vlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip().lower()
    # fullmatch: the whole answer must be digits, so input such as "12abc" is
    # treated as "No" instead of crashing in int() after training finished.
    wants_overwrite = re.fullmatch(r"\d+", answer) is not None
    if answer in ("y", "yes") or wants_overwrite:
        overwrite = int(answer) if wants_overwrite else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        # NOTE(review): the metric is built from the in-memory `ds` vocabulary,
        # not from the reloaded `_ds` -- kept as in the original.
        _freedecoder = BeamDecoder(_model,
                                   maxtime=50,
                                   beamsize=beamsize,
                                   eval_beam=[make_tree_accuracy()])

        # Re-test the reloaded model to check that save/load round-trips.
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=vlosses,
                                    device=device)
        print(_testresults)
        assert (testresults == _testresults)
        tt.tock("tested")
Beispiel #2
0
def run(
    lr=0.001,
    batsize=20,
    epochs=60,
    embdim=128,
    encdim=256,
    numlayers=1,
    beamsize=5,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    smoothing=0.1,
    cosine_restarts=1.,
    seed=123456,
    numcvfolds=6,
    testfold=-1,  # if non-default, must be within number of splits, the chosen value is used for validation
    reorder_random=False,
):
    """Train a ``BasicGenModel`` on ``GeoDataset``, optionally as one CV fold.

    With the default ``testfold=-1`` the model is trained on "train",
    validated on "test" after every epoch, tested with a ``BeamDecoder``, and
    the user is then asked on stdin whether to save the run. A saved run is
    reloaded, re-tested, and its predictions on the test and train splits are
    written to ``testpreds.json`` / ``trainpreds.json`` in the run directory.

    With any other ``testfold`` the dataset is built with ``numcvfolds`` folds,
    validation uses the "valid" split, and the function returns right after
    training.

    Args:
        lr: Learning rate given to Adam.
        batsize: Batch size given to every ``ds.dataloader`` call.
        epochs: Maximum number of training epochs; also used as ``t_max`` of
            the LR schedule.
        embdim: ``embdim`` of the model.
        encdim: ``hdim`` of the model.
        numlayers: ``numlayers`` of the model.
        beamsize: Beam size of the ``BeamDecoder`` instances.
        dropout: ``dropout`` of the model.
        wreg: ``weight_decay`` given to Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` given to the dataset.
        gradnorm: Maximum norm used for gradient clipping before each
            optimizer step.
        smoothing: ``smoothing`` given to the training ``CELoss``.
        cosine_restarts: ``cycles`` of the cosine-with-hard-restarts schedule;
            a negative value disables LR scheduling.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        numcvfolds: Number of folds given to the dataset when ``testfold`` is
            not ``-1``.
        testfold: ``-1`` disables cross-validation; otherwise the fold given
            to the dataset as ``testfold``.
        reorder_random: Forwarded to ``GeoDataset``.

    Returns:
        ``vlosses[1].get_epoch_error()`` (the second validation metric,
        registered as "tree_acc") when ``testfold`` is not ``-1``; otherwise
        ``None``.
    """
    # Snapshot the arguments before any other local name is bound.
    localargs = locals().copy()
    print(localargs)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    tt.tick("loading data")
    # testfold == -1 is the "no cross-validation" sentinel.
    use_cv = testfold != -1
    cvfolds = numcvfolds if use_cv else None
    testfold = testfold if use_cv else None
    ds = GeoDataset(
        sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer),
        min_freq=minfreq,
        cvfolds=cvfolds,
        testfold=testfold,
        reorder_random=reorder_random)
    print(
        f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    do_rare_stats(ds)

    model = BasicGenModel(embdim=embdim,
                          hdim=encdim,
                          dropout=dropout,
                          numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder,
                          query_encoder=ds.query_encoder,
                          feedatt=True)

    def make_tree_accuracy():
        # Every decoder gets its own metric instance, as in the original code.
        return TreeAccuracy(
            tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
            orderless={"and"})

    # Decoder used for training (tf_ratio=1.).
    tfdecoder = SeqDecoder(model,
                           tf_ratio=1.,
                           eval=[
                               CELoss(ignore_index=0,
                                      mode="logprobs",
                                      smoothing=smoothing),
                               SeqAccuracies(),
                               make_tree_accuracy(),
                           ])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    # Decoder used for per-epoch validation (tf_ratio=0.).
    freedecoder = SeqDecoder(model,
                             maxtime=100,
                             tf_ratio=0.,
                             eval=[SeqAccuracies(),
                                   make_tree_accuracy()])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # Beam decoder used for the final test passes.
    beamdecoder = BeamDecoder(model,
                              maxtime=100,
                              beamsize=beamsize,
                              copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[make_tree_accuracy()])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc",
                                       "tree_acc_at_last")

    # optimizer
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule: stepped through the `on_end` hook of the training epoch
    if cosine_restarts >= 0:
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(),
                                              gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    train_on = "train"
    # Without cross-validation, validation runs on the "test" split.
    valid_on = "test" if testfold is None else "valid"
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=ds.dataloader(train_on,
                                                  batsize,
                                                  shuffle=True),
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # validation function
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=ds.dataloader(valid_on,
                                                  batsize,
                                                  shuffle=False),
                         losses=vlosses,
                         device=device)

    # run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")

    # In cross-validation mode only the fold's validation metric is reported.
    if testfold is not None:
        return vlosses[1].get_epoch_error()

    # testing
    # NOTE(review): both passes below evaluate the same "test" split with the
    # same decoder; the first is only labelled "validation". Kept as in the
    # original -- confirm whether a separate validation split was intended.
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder,
                               dataloader=ds.dataloader("test", batsize),
                               losses=beamlosses,
                               device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input(
        "Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>"
    )
    answer = tosave.strip().lower()
    # fullmatch: the whole answer must be digits, so input such as "12abc" is
    # treated as "No" instead of crashing in int() after training finished.
    wants_overwrite = re.fullmatch(r"\d+", answer) is not None
    if answer in ("y", "yes") or wants_overwrite:
        overwrite = int(answer) if wants_overwrite else None
        p = q.save_run(model,
                       localargs,
                       filepath=__file__,
                       overwrite=overwrite)
        q.save_dataset(ds, p)
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        # NOTE(review): the metric is built from the in-memory `ds` vocabulary,
        # not from the reloaded `_ds` -- kept as in the original.
        _freedecoder = BeamDecoder(_model,
                                   maxtime=100,
                                   beamsize=beamsize,
                                   copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[make_tree_accuracy()])

        # Re-test the reloaded model on the reloaded dataset.
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder,
                                    dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses,
                                    device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions of the reloaded model on the test and train splits
        _, testpreds = q.eval_loop(_freedecoder,
                                   ds.dataloader("test",
                                                 batsize=batsize,
                                                 shuffle=False),
                                   device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder,
                                    ds.dataloader("train",
                                                  batsize=batsize,
                                                  shuffle=False),
                                    device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
Beispiel #3
0
def run(
    lr=0.001,
    batsize=20,
    epochs=100,
    embdim=64,
    encdim=128,
    numlayers=1,
    dropout=.25,
    wreg=1e-10,
    cuda=False,
    gpu=0,
    minfreq=2,
    gradnorm=3.,
    beamsize=1,
    cosine_restarts=1.,
    seed=456789,
):
    """Train a seq2seq model on ``GeoQueryDatasetFunQL`` and validate on "test".

    Input sentences are whitespace-split and Porter-stemmed. Training runs
    through a ``SeqDecoder`` with ``tf_ratio=1.``; validation after each epoch
    uses a ``SeqDecoder`` with ``tf_ratio=0.`` when ``beamsize == 1`` and a
    ``BeamDecoder`` otherwise.

    Args:
        lr: Learning rate given to Adam.
        batsize: Batch size of the train and test dataloaders.
        epochs: Maximum number of training epochs; also used as ``t_max`` of
            the LR schedule.
        embdim: ``embdim`` given to ``create_model``.
        encdim: ``hdim`` given to ``create_model``.
        numlayers: ``numlayers`` given to ``create_model``.
        dropout: ``dropout`` given to ``create_model``.
        wreg: ``weight_decay`` given to Adam.
        cuda: If true, run on ``torch.device("cuda", gpu)``, otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` given to the dataset.
        gradnorm: Maximum norm used for gradient clipping before each
            optimizer step.
        beamsize: ``1`` selects the plain ``SeqDecoder`` for validation; any
            other value selects a ``BeamDecoder`` with that beam size.
        cosine_restarts: ``cycles`` of the cosine-with-hard-restarts schedule;
            a negative value disables LR scheduling.
        seed: Seed for ``torch`` and ``numpy``.
    """
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    tt.tick("loading data")
    stemmer = PorterStemmer()

    def tokenizer(x):
        # Whitespace tokenization followed by Porter stemming of each token.
        return [stemmer.stem(xe) for xe in x.split()]

    ds = GeoQueryDatasetFunQL(
        sentence_encoder=SequenceEncoder(tokenizer=tokenizer),
        min_freq=minfreq)

    train_dl = ds.dataloader("train", batsize=batsize)
    test_dl = ds.dataloader("test", batsize=batsize)
    tt.tock("data loaded")

    do_rare_stats(ds)

    model = create_model(embdim=embdim,
                         hdim=encdim,
                         dropout=dropout,
                         numlayers=numlayers,
                         sentence_encoder=ds.sentence_encoder,
                         query_encoder=ds.query_encoder,
                         feedatt=True)

    def make_accuracy_metrics():
        # Every decoder gets its own metric instances, as in the original code.
        return [
            SeqAccuracies(),
            TreeAccuracy(
                tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab)),
        ]

    # Decoder used for training (tf_ratio=1.).
    tfdecoder = SeqDecoder(
        model,
        tf_ratio=1.,
        eval=[CELoss(ignore_index=0, mode="logprobs"),
              *make_accuracy_metrics()])
    losses = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")

    # Decoder used for validation: plain decoding or beam search.
    if beamsize == 1:
        freedecoder = SeqDecoder(model,
                                 maxtime=100,
                                 tf_ratio=0.,
                                 eval=make_accuracy_metrics())
    else:
        print("Doing beam search!")
        freedecoder = BeamDecoder(model,
                                  beamsize=beamsize,
                                  maxtime=60,
                                  eval=make_accuracy_metrics())
    # Both decoders report the same validation metrics.
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # optimizer
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule: stepped through the `on_end` hook of the training epoch
    if cosine_restarts >= 0:
        t_max = epochs
        # t_max equals the number of epochs; the previous message also printed
        # "epochs * len(train_dl)", a product that is not what t_max holds.
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(
            optim, 0, t_max, cycles=cosine_restarts)
        reduce_lr = [lambda: lr_schedule.step()]
    else:
        reduce_lr = []

    # training function
    def clipgradnorm():
        return torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(),
                                              gradnorm)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=tfdecoder,
                         dataloader=train_dl,
                         optim=optim,
                         losses=losses,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=reduce_lr)

    # validation function
    validepoch = partial(q.test_epoch,
                         model=freedecoder,
                         dataloader=test_dl,
                         losses=vlosses,
                         device=device)

    # run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs)
    tt.tock("done training")