示例#1
0
def run(
    sourcelang="en",
    testlang="en",
    lr=0.001,
    enclrmul=0.1,
    numbeam=1,
    cosinelr=False,
    warmup=0.,
    batsize=20,
    epochs=100,
    dropout=0.1,
    dropoutdec=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="xlm-roberta-base",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=50,
    localtest=False,
    printtest=False,
    trainonvalid=False,
):
    """Train a sequence decoder on multilingual GeoQuery and evaluate it.

    Loads the data with ``load_multilingual_geoquery``, builds a train/test
    model pair with ``create_model``, trains with Adam (the encoder parameters
    get their own learning rate), early-stops on the last validation metric,
    restores the remembered best model, and finally evaluates on the
    validation and test loaders.

    Args:
        sourcelang: Passed as the first two positional arguments of
            ``load_multilingual_geoquery``.
        testlang: Passed as the third positional argument of
            ``load_multilingual_geoquery``.
        lr: Base learning rate for Adam.
        enclrmul: Multiplier applied to ``lr`` for parameters whose name
            starts with ``"model.model.encoder.model"``.
        numbeam: Forwarded to ``create_model``.
        cosinelr: If true, use a linear phase followed by a cosine phase;
            otherwise a linear phase followed by a constant factor of 1.
        warmup: Number of steps of the linear schedule phase.
        batsize: Batch size of all three data loaders.
        epochs: Maximum number of training epochs (also used as the total
            number of schedule steps).
        dropout: Forwarded to ``create_model``.
        dropoutdec: Forwarded to ``create_model``.
        wreg: ``weight_decay`` passed to Adam.
        gradnorm: Maximum norm used for gradient clipping before each
            optimizer step.
        smoothing: Forwarded to ``create_model``.
        patience: Patience of the early stopper.
        gpu: Negative selects the CPU; otherwise passed to ``torch.device``.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        encoder: Name used both as tokenizer name and as encoder name.
        numlayers: Number of decoder layers (``dec_layers``).
        hdim: Decoder dimension (``dec_dim``).
        numheads: Number of decoder heads (``dec_heads``).
        maxlen: Forwarded to ``create_model`` and used as ``max_length`` when
            generating predictions for ``printtest``.
        localtest: If true, push one training batch through both models and
            print the outputs before training.
        printtest: If true, print input/prediction/gold triples for the test
            loader together with the resulting sequence accuracy.
        trainonvalid: Forwarded to ``load_multilingual_geoquery``.

    Returns:
        dict: The call arguments, extended with one
        ``"<split>_<metric name>"`` entry per train/valid/test metric.

    Raises:
        Exception: If no trainable parameter belongs to the encoder.
    """
    # Must stay the first statement so that only the call arguments are captured.
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, vds, xds, nltok, flenc = load_multilingual_geoquery(
        sourcelang,
        sourcelang,
        testlang,
        nltok_name=encoder,
        trainonvalid=trainonvalid)
    ntrain, nvalid, ntest = len(tds), len(vds), len(xds)
    # Guard against empty splits: the ratios are only informational and
    # must not abort the run with a ZeroDivisionError.
    total = max(ntrain + nvalid + ntest, 1)
    tt.msg(
        f"{ntrain / total:.2f}/{nvalid / total:.2f}/{ntest / total:.2f} ({ntrain}/{nvalid}/{ntest}) train/valid/test"
    )
    collate = partial(collate_fn, pad_value_nl=nltok.pad_token_id)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=collate)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=collate)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=collate)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 dropoutdec=dropoutdec,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 numbeam=numbeam,
                                 tensor2tree=partial(_tensor2tree,
                                                     D=flenc.vocab))
    tt.tock("model created")

    # Smoke test: run one batch of data through both models.
    if localtest:
        batch = next(iter(tdl))
        print(trainm(*batch))
        print(testm(*batch))

    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    # Names listed here are left out of the optimizer (currently none).
    exclude_params = set()

    tt.msg("different param groups")
    encoder_prefix = "model.model.encoder.model"
    encparams, otherparams = [], []
    for paramname, param in trainm.named_parameters():
        if paramname in exclude_params:
            continue
        if paramname.startswith(encoder_prefix):
            encparams.append(param)
        else:
            otherparams.append(param)
    if not encparams:
        raise Exception("No encoder parameters found!")
    paramgroups = [
        {"params": encparams, "lr": lr * enclrmul},
        {"params": otherparams},
    ]

    optim = torch.optim.Adam(paramgroups, lr=lr, weight_decay=wreg)

    def clipgradnorm():
        # Clips over all parameters of the training model, not just the groups.
        return torch.nn.utils.clip_grad_norm_(trainm.parameters(), gradnorm)

    eyt = q.EarlyStopper(vmetrics[-1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    t_max = epochs
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        schedule_spec = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            steps=t_max - warmup) >> 0.
    else:
        schedule_spec = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, schedule_spec)

    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    # Swap in the model copy remembered by the early stopper, if any.
    if eyt.remembered is not None:
        trainm.model = eyt.remembered
        testm.model = eyt.remembered
    tt.msg("reloaded best")

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=vdl,
                                losses=vmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=xmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        cpu = torch.device("cpu")
        numcorrect, numtotal = 0, 0
        for testbatch in xdl:
            input_ids = testbatch[0].to(device)
            output_ids = testbatch[1]
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(ids,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for ids in input_ids
            ]
            out_strs = [flenc.vocab.tostr(pred.to(cpu)) for pred in ret]
            gold_strs = [flenc.vocab.tostr(gold.to(cpu)) for gold in output_ids]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    numcorrect += 1
                else:
                    print("NOT SAME")
                numtotal += 1
        # An empty test loader must not crash the run after training finished.
        if numtotal > 0:
            print(f"seq acc: {numcorrect/numtotal}")
        else:
            print("seq acc: n/a (no test examples)")

    print("done")

    # Record the final value of every metric under "<split>_<metric name>".
    for metricarray, datasplit in zip([metrics, vmetrics, xmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    return settings
示例#2
0
def run_span_io(
        lr=DEFAULT_LR,
        dropout=.5,
        wreg=DEFAULT_WREG,
        batsize=DEFAULT_BATSIZE,
        epochs=DEFAULT_EPOCHS,
        cuda=False,
        gpu=0,
        balanced=False,
        warmup=-1.,
        sched="ang",  # "lin", "cos"
):
    """Train and test a BERT-based IO span detector.

    Loads the "span/io" data, wraps ``bert-base-uncased`` in an
    ``IOSpanDetector``, trains it with ``BertAdam`` for ``epochs`` epochs
    and prints the results of a final pass over the test set.

    Args:
        lr: Learning rate for ``BertAdam``.
        dropout: Dropout passed to ``IOSpanDetector``.
        wreg: ``weight_decay`` for ``BertAdam``.
        batsize: Batch size of all data loaders.
        epochs: Number of training epochs.
        cuda: If true, run on ``torch.device("cuda", gpu)``; otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        balanced: If true, weight positives in the BCE loss by the ratio of
            label-1 to label-2 counts in the training targets.
        warmup: ``warmup`` argument of ``BertAdam``.
        sched: Key into ``schedmap`` selecting the ``BertAdam`` schedule.
    """
    # The next two statements must stay first: print(locals()) shows exactly
    # the call arguments plus ``settings``.
    settings = locals().copy()
    print(locals())
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    # region data
    tt = q.ticktock("script")
    tt.msg("running span io with BERT")
    tt.tick("loading data")
    trainds, devds, testds = load_data(which="span/io")
    tt.tock("data loaded")
    tt.msg(f"Train/Dev/Test sizes: {len(trainds)} {len(devds)} {len(testds)}")
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)

    # Class counts of the training targets, used to balance the BCE loss.
    train_targets = trainds.tensors[1]
    num_pos = (train_targets == 2).float().sum()
    num_neg = (train_targets == 1).float().sum()
    pos_weight = (num_neg / num_pos) if balanced else None
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    spandet = IOSpanDetector(bert, dropout=dropout)
    spandet.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    optim = BertAdam(spandet.parameters(),
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=len(trainloader) * epochs,
                     schedule=schedmap[sched])
    losses = [
        AutomaskedBCELoss(pos_weight=pos_weight),
        AutomaskedBinarySeqAccuracy()
    ]
    # One set of wrappers per split, all around the same loss objects.
    trainlosses, devlosses, testlosses = (
        [q.LossWrapper(lossfn) for lossfn in losses] for _ in range(3))

    def make_eval_loop(loader, wrapped):
        return partial(q.test_epoch,
                       model=spandet,
                       dataloader=loader,
                       losses=wrapped,
                       device=device)

    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = make_eval_loop(devloader, devlosses)
    testloop = make_eval_loop(testloader, testlosses)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    print(testloop())
    tt.tock("tested")
    # endregion
示例#3
0
def run_relations(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskentity=False,
    warmup=-1.,
    sched="ang",
    savep="exp_bert_rels_",
    test=False,
    freezeemb=False,
):
    """Train a BERT-based relation classifier and save its predictions.

    Loads the "rel+borders" data, trains a ``RelationClassifier`` on top of
    ``bert-base-uncased`` with ``BertAdam``, prints the test results and, if
    ``savep`` is non-empty, writes the settings plus test/dev predictions
    into a freshly created directory ``savep + <index>``.

    Args:
        lr: Learning rate for ``BertAdam``.
        dropout: Dropout passed to ``RelationClassifier``.
        wreg: ``weight_decay`` for ``BertAdam``.
        initwreg: ``init_weight_decay`` for ``BertAdam``.
        batsize: Batch size of all data loaders.
        epochs: Number of training epochs (forced to 0 when ``test`` is true).
        smoothing: Label smoothing for ``q.SmoothedCELoss``.
        cuda: If true, run on ``torch.device("cuda", gpu)``; otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        balanced: Not used by this function.
        maskentity: If true, transform the datasets with
            ``replace_entity_span``; otherwise keep only tensors 0 and 2.
        warmup: ``warmup`` argument of ``BertAdam``.
        sched: Key into ``schedmap`` selecting the ``BertAdam`` schedule.
        savep: Path prefix of the output directory; empty disables saving.
        test: If true, skip training and evaluate on the first 10 test
            examples only.
        freezeemb: If true, leave the BERT word embeddings out of the
            optimizer.
    """
    # These three statements must stay first and in this order: the saved
    # settings hold the original arguments, the printout the effective ones.
    settings = locals().copy()
    if test:
        epochs = 0
    print(locals())
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    # region data
    tt = q.ticktock("script")
    tt.msg("running relation classifier with BERT")
    tt.tick("loading data")
    trainds, devds, testds, relD = load_data(which="rel+borders", retrelD=True)
    if maskentity:
        trainds, devds, testds = replace_entity_span(trainds, devds, testds)
    else:
        trainds, devds, testds = [
            TensorDataset(ds.tensors[0], ds.tensors[2])
            for ds in [trainds, devds, testds]
        ]
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)
    # Input-only views (first tensor) used to produce raw predictions.
    evalds = TensorDataset(*testds.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=batsize, shuffle=False)
    evalds_dev = TensorDataset(*devds.tensors[:1])
    evalloader_dev = DataLoader(evalds_dev, batch_size=batsize, shuffle=False)
    if test:
        # Quick mode: restrict test-side loaders to the first 10 examples.
        evalloader = DataLoader(TensorDataset(*evalds[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testds[:10]),
                                batch_size=batsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    m = RelationClassifier(bert, relD, dropout=dropout)
    m.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs

    # Optionally keep the word embeddings out of the optimizer.
    params = [
        param for paramname, param in m.named_parameters()
        if not (freezeemb
                and paramname.startswith("bert.embeddings.word_embeddings"))
    ]
    optim = BertAdam(params,
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=schedmap[sched],
                     init_weight_decay=initwreg)
    losses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    xlosses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=m,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=m,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=m,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")
    # endregion

    if savep:
        tt.tick("making predictions and saving")
        # Claim the first free "<savep><i>" directory. Creating it directly
        # (instead of exists() followed by mkdir()) avoids a race between
        # concurrently running experiments.
        i = 0
        while True:
            savedir = savep + str(i)
            try:
                os.mkdir(savedir)
            except FileExistsError:
                i += 1
            else:
                break
        # Save settings; the context manager guarantees the file is closed.
        with open(os.path.join(savedir, "settings.json"), "w") as f:
            json.dump(settings, f)
        # Save raw predictions (first model output) for test and dev inputs.
        for split, loader in (("test", evalloader), ("dev", evalloader_dev)):
            preds = q.eval_loop(m, loader, device=device)
            preds = preds[0].cpu().detach().numpy()
            np.save(os.path.join(savedir, f"relpreds.{split}.npy"), preds)

        tt.tock("done")
示例#4
0
def run_both(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    evalbatsize=-1,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskmention=False,
    warmup=-1.,
    sched="ang",
    cycles=-1.,
    savep="exp_bert_both_",
    test=False,
    freezeemb=False,
    large=False,
    datafrac=1.,
    savemodel=False,
):
    """Jointly train a BERT border detector and relation classifier.

    Loads the "forboth" data, trains a ``BordersAndRelationClassifier``
    through ``BordersAndRelationLosses`` with ``BertAdam``, prints the test
    results and, if ``savep`` is non-empty, writes settings, optionally the
    model, and border/relation predictions for test and dev into a freshly
    created directory ``savep + <index>``.

    Args:
        lr: Learning rate for ``BertAdam``.
        dropout: Dropout passed to the classifier.
        wreg: ``weight_decay`` for ``BertAdam``.
        initwreg: Factor for ``InitL2Penalty``; must be 0.
        batsize: Training batch size.
        evalbatsize: Batch size of dev/test loaders; negative means
            ``batsize``.
        epochs: Number of training epochs (forced to 0 when ``test`` is true).
        smoothing: ``cesmoothing`` of ``BordersAndRelationLosses``.
        cuda: If true, run on ``torch.device("cuda", gpu)``; otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        balanced: Not used by this function.
        maskmention: ``mask_entity_mention`` of the classifier.
        warmup: ``warmup`` argument of ``get_schedule``.
        sched: Schedule name passed to ``get_schedule``.
        cycles: ``cycles`` argument of ``get_schedule``; -1 selects 0.5 for
            ``"cos"`` and 1.0 for ``"cosrestart"``/``"coshardrestart"`` and is
            otherwise forwarded unchanged.
        savep: Path prefix of the output directory; empty disables saving.
        test: If true, skip training and evaluate on the first 10 test
            examples only.
        freezeemb: If true, leave the BERT word embeddings out of the
            optimizer.
        large: If true, use ``bert-large-uncased`` instead of
            ``bert-base-uncased``.
        datafrac: Forwarded to ``load_data``.
        savemodel: If true, also save the whole model as ``model.pt``.
    """
    # Keep these two statements first: print(locals()) shows exactly the call
    # arguments plus ``settings``.
    settings = locals().copy()
    print(locals())
    tt = q.ticktock("script")
    if evalbatsize < 0:
        evalbatsize = batsize
    tt.msg("running borders and rel classifier with BERT")
    if test:
        epochs = 0
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")
    if cycles == -1:
        # Pick a default number of cycles for the cosine schedules.
        if sched == "cos":
            cycles = 0.5
        elif sched in ["cosrestart", "coshardrestart"]:
            cycles = 1.0

    # region data
    tt.tick("loading data")
    trainds, devds, testds, relD = load_data(which="forboth",
                                             retrelD=True,
                                             datafrac=datafrac)
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    # Input-only views (first tensor) used to produce raw predictions.
    evalds = TensorDataset(*testds.tensors[:1])
    evalds_dev = TensorDataset(*devds.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)
    if test:
        # Quick mode: restrict test-side loaders to the first 10 examples.
        evalloader = DataLoader(TensorDataset(*evalds[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testds[:10]),
                                batch_size=batsize,
                                shuffle=False)
    print("number of relations: {}".format(len(relD)))
    # endregion

    # region model
    tt.tick("loading BERT")
    whichbert = "bert-large-uncased" if large else "bert-base-uncased"
    bert = BertModel.from_pretrained(whichbert)
    m = BordersAndRelationClassifier(bert,
                                     relD,
                                     dropout=dropout,
                                     mask_entity_mention=maskmention)
    m.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    assert (initwreg == 0.)
    # NOTE(review): the penalty object is never used afterwards; it is still
    # constructed in case the constructor has side effects - confirm and drop.
    initl2penalty = InitL2Penalty(bert, factor=q.hyperparam(initwreg))

    # Optionally keep the word embeddings out of the optimizer.
    params = [
        param for paramname, param in m.named_parameters()
        if not (freezeemb
                and paramname.startswith("bert.embeddings.word_embeddings"))
    ]
    schedule = get_schedule(sched,
                            warmup=warmup,
                            t_total=totalsteps,
                            cycles=cycles)
    optim = BertAdam(params, lr=lr, weight_decay=wreg, schedule=schedule)
    tmodel = BordersAndRelationLosses(m, cesmoothing=smoothing)
    # The loss model's outputs are picked apart by index (7 of them).
    tlosses = [q.SelectedLinearLoss(i) for i in range(7)]
    xlosses = [q.SelectedLinearLoss(i) for i in range(7)]
    trainlosses = [q.LossWrapper(l) for l in tlosses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=tmodel,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=tmodel,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=tmodel,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    m.clip_len = True
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    settings["testres"] = testres
    tt.tock("tested")
    # endregion

    if savep:
        tt.tick("making predictions and saving")
        # Claim the first free "<savep><i>" directory. Creating it directly
        # (instead of exists() followed by mkdir()) avoids a race between
        # concurrently running experiments.
        i = 0
        while True:
            savedir = savep + str(i)
            try:
                os.mkdir(savedir)
            except FileExistsError:
                i += 1
            else:
                break
        print(savedir)
        # Save model; the context manager guarantees the file is closed.
        if savemodel:
            with open(os.path.join(savedir, "model.pt"), "wb") as f:
                torch.save(m, f)
        # Save settings (including the test results recorded above).
        with open(os.path.join(savedir, "settings.json"), "w") as f:
            json.dump(settings, f)
        # Predictions are made with clip_len switched off (it is on during
        # training and when the model is saved above).
        m.clip_len = False
        for split, loader in (("test", evalloader), ("dev", evalloader_dev)):
            preds = q.eval_loop(m, loader, device=device)
            borderpreds = preds[0].cpu().detach().numpy()
            relpreds = preds[1].cpu().detach().numpy()
            np.save(os.path.join(savedir, f"borderpreds.{split}.npy"),
                    borderpreds)
            np.save(os.path.join(savedir, f"relpreds.{split}.npy"), relpreds)

        tt.tock("done")
示例#5
0
def run(lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        tieweights=False,
        seqlen=35,
        batsize=20,
        eval_batsize=80,
        cuda=False,
        gpu=0,
        test=False):
    """Train an RNN language model with SGD and report test loss/perplexity.

    Args:
        lr: Initial learning rate for SGD.
        dropout: Dropout passed to ``RNNLayer_LM``.
        dropconnect: Not used by this function.
        gradnorm: Maximum norm for gradient clipping before each step.
        epochs: Maximum number of training epochs.
        embdim: First entry of the model's dimension list.
        encdim: Dimension repeated ``numlayers`` times after ``embdim``.
        numlayers: Number of ``encdim`` entries in the dimension list.
        tieweights: Passed to ``RNNLayer_LM``.
        seqlen: ``mu`` of the ``VariableSeqlen`` handed to ``load_data``.
        batsize: Training batch size passed to ``load_data``.
        eval_batsize: Evaluation batch size passed to ``load_data``.
        cuda: If true, run on ``torch.device("cuda", gpu)``; otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        test: If true, run the first few training batches through the model
            before training and print the size of the last output.
    """
    tt = q.ticktock("script")
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    tt.tick("loading data")
    seqlen_sampler = VariableSeqlen(minimum=5,
                                    maximum_offset=10,
                                    mu=seqlen,
                                    sigma=0)
    train_batches, valid_batches, test_batches, D = load_data(
        batsize=batsize, eval_batsize=eval_batsize, seqlen=seqlen_sampler)
    tt.tock("data loaded")
    print(f"{len(train_batches)} batches in train")

    tt.tick("creating model")
    dims = [embdim] + [encdim] * numlayers
    m = RNNLayer_LM(*dims, worddic=D, dropout=dropout, tieweights=tieweights)
    m = m.to(device)

    if test:
        # Smoke test: forward batches 0..6, then show the last output's size.
        for batch_idx, batch in enumerate(train_batches):
            y = m(batch[0])
            if batch_idx >= 6:
                break
        print(y.size())

    loss = q.LossWrapper(q.CELoss(mode="logits"))
    validloss = q.LossWrapper(q.CELoss(mode="logits"))
    validlosses = [validloss, PPLfromCE(validloss)]
    testloss = q.LossWrapper(q.CELoss(mode="logits"))
    testlosses = [testloss, PPLfromCE(testloss)]

    # Put every loss on the right device.
    for wrapped in (loss, *validlosses, *testlosses):
        wrapped.loss.to(device)

    optim = torch.optim.SGD(m.parameters(), lr=lr)

    def clip_gradients():
        return torch.nn.utils.clip_grad_norm_(m.parameters(), gradnorm)

    train_batch_f = partial(q.train_batch,
                            on_before_optim_step=[clip_gradients])

    # Quarter the learning rate as soon as the validation loss stops improving.
    lrp = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                     mode="min",
                                                     factor=0.25,
                                                     patience=0,
                                                     verbose=True)

    def lrp_f():
        return lrp.step(validloss.get_epoch_error())

    train_epoch_f = partial(q.train_epoch,
                            model=m,
                            dataloader=train_batches,
                            optim=optim,
                            losses=[loss],
                            device=device,
                            _train_batch=train_batch_f)
    valid_epoch_f = partial(q.test_epoch,
                            model=m,
                            dataloader=valid_batches,
                            losses=validlosses,
                            device=device,
                            on_end=[lrp_f])

    tt.tock("created model")
    tt.tick("training")
    q.run_training(train_epoch_f,
                   valid_epoch_f,
                   max_epochs=epochs,
                   validinter=1)
    tt.tock("trained")

    tt.tick("testing")
    testresults = q.test_epoch(model=m,
                               dataloader=test_batches,
                               losses=testlosses,
                               device=device)
    print(testresults)
    tt.tock("tested")
示例#6
0
def run_span_borders(
    lr=DEFAULT_LR,
    dropout=.5,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=DEFAULT_EPOCHS,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    warmup=-1.,
    sched="ang",
    savep="exp_bert_span_borders_",
    freezeemb=False,
):
    """Train a BERT-based span border detector and save its predictions.

    Loads the "span/borders" data, trains a ``BorderSpanDetector`` on top of
    ``bert-base-uncased`` with ``BertAdam``, prints the test results and, if
    ``savep`` is non-empty, writes the settings plus test/dev border
    predictions into a freshly created directory ``savep + <index>``.

    Args:
        lr: Learning rate for ``BertAdam``.
        dropout: Dropout passed to ``BorderSpanDetector``.
        wreg: ``weight_decay`` for ``BertAdam``.
        initwreg: Not used by this function.
        batsize: Batch size of all data loaders.
        epochs: Number of training epochs.
        smoothing: Label smoothing for ``q.SmoothedCELoss``.
        cuda: If true, run on ``torch.device("cuda", gpu)``; otherwise on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        balanced: Not used by this function.
        warmup: ``warmup`` argument of ``BertAdam``.
        sched: Key into ``schedmap`` selecting the ``BertAdam`` schedule.
        savep: Path prefix of the output directory; empty disables saving.
        freezeemb: If true, leave the BERT word embeddings out of the
            optimizer.
    """
    # Keep these two statements first: print(locals()) shows exactly the call
    # arguments plus ``settings``.
    settings = locals().copy()
    print(locals())
    device = torch.device("cuda", gpu) if cuda else torch.device("cpu")

    # region data
    tt = q.ticktock("script")
    tt.msg("running span border with BERT")
    tt.tick("loading data")
    trainds, devds, testds = load_data(which="span/borders")
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=batsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=batsize, shuffle=False)
    # Views without the last tensor, used to produce raw predictions.
    evalds = TensorDataset(*testds.tensors[:-1])
    evalloader = DataLoader(evalds, batch_size=batsize, shuffle=False)
    evalds_dev = TensorDataset(*devds.tensors[:-1])
    evalloader_dev = DataLoader(evalds_dev, batch_size=batsize, shuffle=False)
    # endregion

    # region model
    tt.tick("loading BERT")
    bert = BertModel.from_pretrained("bert-base-uncased")
    spandet = BorderSpanDetector(bert, dropout=dropout)
    spandet.to(device)
    tt.tock("loaded BERT")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    # Optionally keep the word embeddings out of the optimizer.
    params = [
        param for paramname, param in spandet.named_parameters()
        if not (freezeemb
                and paramname.startswith("bert.embeddings.word_embeddings"))
    ]
    optim = BertAdam(params,
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=schedmap[sched])
    losses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(reduction="none"),
        q.SeqAccuracy()
    ]
    xlosses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(reduction="none"),
        q.SeqAccuracy()
    ]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=spandet,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=spandet,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")
    # endregion

    if savep:
        tt.tick("making predictions and saving")
        # Claim the first free "<savep><i>" directory. Creating it directly
        # (instead of exists() followed by mkdir()) avoids a race between
        # concurrently running experiments.
        i = 0
        while True:
            savedir = savep + str(i)
            try:
                os.mkdir(savedir)
            except FileExistsError:
                i += 1
            else:
                break
        # Save settings; the context manager guarantees the file is closed.
        with open(os.path.join(savedir, "settings.json"), "w") as f:
            json.dump(settings, f)
        # Save raw predictions (first model output) for test and dev inputs.
        for split, loader in (("test", evalloader), ("dev", evalloader_dev)):
            preds = q.eval_loop(spandet, loader, device=device)
            preds = preds[0].cpu().detach().numpy()
            np.save(os.path.join(savedir, f"borderpreds.{split}.npy"), preds)
        tt.tock("done")
def run(lr=0.001,
        enclrmul=0.1,
        hdim=768,
        numlayers=8,
        numheads=12,
        dropout=0.1,
        wreg=0.,
        batsize=10,
        epochs=100,
        warmup=0,
        sustain=0,
        cosinelr=False,
        gradacc=1,
        gradnorm=100,
        patience=5,
        validinter=3,
        seed=87646464,
        gpu=-1,
        datamode="full",    # "full", "ltr" (left to right)
        ):
    """Train a tree-insertion tagger on the "restaurants" data and validate it by decoding.

    One ``TransformerTagger`` is shared by two wrappers: a
    ``TreeInsertionTaggerModel`` that is trained on ``tds`` and a
    ``TreeInsertionDecoder`` (wrapped in ``TreeInsertionDecoderTrainModel``)
    that is only evaluated, on ``tds_seq`` and ``vds_seq``. Training runs
    through ``q.run_training`` with early stopping on the validation
    "treeacc" metric.

    Args:
        lr: Base learning rate passed to Adam.
        enclrmul: Factor applied to ``lr`` for parameters whose name contains
            "bert_model.".
        hdim: Forwarded to ``TransformerTagger``.
        numlayers: Forwarded to ``TransformerTagger``.
        numheads: Forwarded to ``TransformerTagger``.
        dropout: Forwarded to ``TransformerTagger``.
        wreg: Adam ``weight_decay``.
        batsize: Batch size of every DataLoader.
        epochs: Maximum number of training epochs; also the horizon of the
            learning-rate schedule.
        warmup: Number of steps of the linear warm-up part of the schedule.
        sustain: Accepted and printed with the settings, but not used.
        cosinelr: If true, warm-up is followed by cosine decay to 0;
            otherwise the schedule stays at 1 after warm-up.
        gradacc: Gradient accumulation steps passed to ``q.train_batch``.
        gradnorm: Max norm used to clip the tagger's gradients.
        patience: Patience of the early stopper.
        validinter: Forwarded to ``q.run_training``.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        gpu: Negative selects the CPU, otherwise ``torch.device(gpu)``.
        datamode: Forwarded to ``load_ds`` and ``TreeInsertionDecoder``.

    Returns:
        None. The copy kept by the early stopper is not reloaded here.

    Raises:
        Exception: If the tagger has no parameter with "bert_model." in its name.
    """
    settings = locals().copy()
    print(json.dumps(settings, indent=4))

    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt = q.ticktock("script")
    tt.tick("loading")
    tds, vds, xds, tds_seq, vds_seq, xds_seq, nltok, flenc, orderless = load_ds("restaurants", mode=datamode)
    tt.tock("loaded")

    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=collate_fn)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=collate_fn)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=collate_fn)

    tdl_seq = DataLoader(tds_seq, batch_size=batsize, shuffle=True, collate_fn=autocollate)
    vdl_seq = DataLoader(vds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)
    xdl_seq = DataLoader(xds_seq, batch_size=batsize, shuffle=False, collate_fn=autocollate)

    # model: the same tagger instance backs both the training wrapper and the decoder
    tagger = TransformerTagger(hdim, flenc.vocab, numlayers, numheads, dropout)
    tagmodel = TreeInsertionTaggerModel(tagger)
    decodermodel = TreeInsertionDecoder(tagger, seqenc=flenc, maxsteps=50, max_tree_size=30,
                                        mode=datamode)
    decodermodel = TreeInsertionDecoderTrainModel(decodermodel)

    # batch = next(iter(tdl))
    # out = tagmodel(*batch)

    tmetrics = make_array_of_metrics("loss", "elemrecall", "seqrecall", reduction="mean")
    tseqmetrics = make_array_of_metrics("treeacc", reduction="mean")
    vmetrics = make_array_of_metrics("treeacc", reduction="mean")
    xmetrics = make_array_of_metrics("treeacc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        """Split m's parameters into a "bert_model." group (scaled lr) and the rest."""
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "bert_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{"params": bertparams, "lr": _lr * _enclrmul},
                       {"params": otherparams}]
        return paramgroups
    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        """Build Adam over the two parameter groups of _m."""
        # Use the _lr argument, not the enclosing `lr`, so the helper honours
        # the rate it is actually called with.
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    # vmetrics[0] is the only validation metric ("treeacc"), hence more_is_better.
    eyt = q.EarlyStopper(vmetrics[0], patience=patience, min_epochs=30, more_is_better=True, remember_f=lambda: deepcopy(tagger))
    # def wandb_logger():
    #     d = {}
    #     for name, loss in zip(["loss", "elem_acc", "seq_acc", "tree_acc"], metrics):
    #         d["train_"+name] = loss.get_epoch_error()
    #     for name, loss in zip(["seq_acc", "tree_acc"], vmetrics):
    #         d["valid_"+name] = loss.get_epoch_error()
    #     wandb.log(d)
    t_max = epochs
    optim = get_optim(tagger, lr, enclrmul, wreg)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(steps=t_max-warmup) >> 0.
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch, gradient_accumulation_steps=gradacc,
                                        on_before_optim_step=[lambda : clipgradnorm(_m=tagger, _norm=gradnorm)])

    # The schedule is advanced from the train epoch's on_end hook.
    trainepoch = partial(q.train_epoch, model=tagmodel,
                                        dataloader=tdl,
                                        optim=optim,
                                        losses=tmetrics,
                                        device=device,
                                        _train_batch=trainbatch,
                                        on_end=[lambda: lr_schedule.step()])

    # Decoding-based evaluation on the training sequences (no optimisation).
    trainseqepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=tseqmetrics,
                         dataloader=tdl_seq,
                         device=device)

    validepoch = partial(q.test_epoch,
                         model=decodermodel,
                         losses=vmetrics,
                         dataloader=vdl_seq,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    # NOTE(review): this extra validation pass also fires eyt.on_epoch_end()
    # once before any training has happened.
    validepoch()        # TODO: remove this after debugging

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=[trainseqepoch, validepoch],
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")
示例#8
0
def run(
    traindomains="ALL",
    domain="recipes",
    mincoverage=2,
    lr=0.001,
    enclrmul=0.1,
    numbeam=1,
    ftlr=0.0001,
    cosinelr=False,
    warmup=0.,
    batsize=30,
    epochs=100,
    pretrainepochs=100,
    dropout=0.1,
    wreg=1e-9,
    gradnorm=3,
    smoothing=0.,
    patience=5,
    gpu=-1,
    seed=123456789,
    encoder="bert-base-uncased",
    numlayers=6,
    hdim=600,
    numheads=8,
    maxlen=30,
    localtest=False,
    printtest=False,
    fullsimplify=True,
    domainstart=False,
    useall=False,
    nopretrain=False,
    onlyabstract=False,
):
    """Pretrain a parser on several domains, then fine-tune and test it on ``domain``.

    Phase 1 (skipped when ``nopretrain`` is set) trains on ``tds`` and
    validates on ``vds``. Phase 2 fine-tunes on ``ftds`` and validates on
    ``fvds``. After each phase the model kept by the early stopper, if any,
    is put back into both the train and the test wrapper. Finally the model
    is evaluated on ``fvds`` and ``xds``.

    Args:
        traindomains: Pretraining domains. The string "ALL" means all six
            known domains except ``domain``.
        domain: Target domain, passed to ``load_ds`` as ``testdomain``.
        mincoverage: Forwarded to ``load_ds``.
        lr: Base learning rate for pretraining.
        enclrmul: Factor applied to the base rate for parameters whose name
            starts with "model.model.encoder".
        numbeam: Forwarded to ``create_model``.
        ftlr: Base learning rate for fine-tuning.
        cosinelr: If true, warm-up is followed by cosine decay to 0;
            otherwise the schedule stays at 1 after warm-up.
        warmup: Number of steps of the linear warm-up part of the schedule.
        batsize: Batch size of every DataLoader.
        epochs: Maximum number of fine-tuning epochs.
        pretrainepochs: Maximum number of pretraining epochs.
        dropout: Forwarded to ``create_model``.
        wreg: Adam ``weight_decay``.
        gradnorm: Max norm used to clip gradients.
        smoothing: Forwarded to ``create_model``.
        patience: Early-stopping patience during pretraining (fine-tuning
            uses a fixed patience of 1000).
        gpu: Negative selects the CPU, otherwise ``torch.device(gpu)``.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        encoder: Encoder name, passed to ``load_ds`` and ``create_model``.
        numlayers: Decoder layers, forwarded to ``create_model``.
        hdim: Decoder dimension, forwarded to ``create_model``.
        numheads: Decoder heads, forwarded to ``create_model``.
        maxlen: Forwarded to ``create_model`` and used as ``max_length``
            when generating for ``printtest``.
        localtest: If true, push one training batch through both wrappers
            and print the outputs.
        printtest: If true, print every test example with its prediction.
        fullsimplify: Forwarded to ``load_ds``.
        domainstart: Forwarded to ``load_ds`` as ``add_domain_start``.
        useall: Forwarded to ``load_ds``.
        nopretrain: If true, skip the pretraining loop.
        onlyabstract: Forwarded to ``load_ds``.

    Returns:
        dict: The call settings, extended with ``"{split}_{metric.name}"``
        entries for the fine-tuning train/valid/test metrics.

    Raises:
        Exception: If no parameter name starts with "model.model.encoder".
    """
    settings = locals().copy()
    print(json.dumps(settings, indent=4))
    if traindomains == "ALL":
        alldomains = {
            "recipes", "restaurants", "blocks", "calendar", "housing",
            "publications"
        }
        traindomains = alldomains - {
            domain,
        }
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if gpu < 0 else torch.device(gpu)

    tt.tick("loading data")
    tds, ftds, vds, fvds, xds, nltok, flenc = \
        load_ds(traindomains=traindomains, testdomain=domain, nl_mode=encoder, mincoverage=mincoverage,
                fullsimplify=fullsimplify, add_domain_start=domainstart, useall=useall, onlyabstract=onlyabstract)
    tt.msg(
        f"{len(tds)/(len(tds) + len(vds)):.2f}/{len(vds)/(len(tds) + len(vds)):.2f} ({len(tds)}/{len(vds)}) train/valid"
    )
    tt.msg(
        f"{len(ftds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(fvds)/(len(ftds) + len(fvds) + len(xds)):.2f}/{len(xds)/(len(ftds) + len(fvds) + len(xds)):.2f} ({len(ftds)}/{len(fvds)}/{len(xds)}) fttrain/ftvalid/test"
    )
    collate = partial(autocollate, pad_value=0)
    tdl = DataLoader(tds, batch_size=batsize, shuffle=True, collate_fn=collate)
    ftdl = DataLoader(ftds, batch_size=batsize, shuffle=True, collate_fn=collate)
    vdl = DataLoader(vds, batch_size=batsize, shuffle=False, collate_fn=collate)
    fvdl = DataLoader(fvds, batch_size=batsize, shuffle=False, collate_fn=collate)
    xdl = DataLoader(xds, batch_size=batsize, shuffle=False, collate_fn=collate)
    tt.tock("data loaded")

    tt.tick("creating model")
    trainm, testm = create_model(encoder_name=encoder,
                                 dec_vocabsize=flenc.vocab.number_of_ids(),
                                 dec_layers=numlayers,
                                 dec_dim=hdim,
                                 dec_heads=numheads,
                                 dropout=dropout,
                                 smoothing=smoothing,
                                 maxlen=maxlen,
                                 numbeam=numbeam,
                                 tensor2tree=partial(_tensor2tree,
                                                     D=flenc.vocab))
    tt.tock("model created")

    # run a batch of data through the model
    if localtest:
        batch = next(iter(tdl))
        out = trainm(*batch)
        print(out)
        out = testm(*batch)
        print(out)

    def _make_paramgroups(_lr):
        """Return Adam parameter groups for trainm: encoder at _lr * enclrmul, rest at default."""
        trainable_params = list(trainm.named_parameters())
        exclude_params = set()
        # exclude_params.add("model.model.inp_emb.emb.weight")  # don't train input embeddings if doing glove
        if len(exclude_params) > 0:
            trainable_params = [(k, v) for k, v in trainable_params
                                if k not in exclude_params]

        tt.msg("different param groups")
        encparams = [
            v for k, v in trainable_params
            if k.startswith("model.model.encoder")
        ]
        otherparams = [
            v for k, v in trainable_params
            if not k.startswith("model.model.encoder")
        ]
        if len(encparams) == 0:
            raise Exception("No encoder parameters found!")
        return [{
            "params": encparams,
            "lr": _lr * enclrmul
        }, {
            "params": otherparams
        }]

    def _make_schedule(_optim, _t_max):
        """Wrap _optim in a warm-up (+ optional cosine) schedule spanning _t_max steps."""
        print(f"Total number of updates: {_t_max} .")
        if cosinelr:
            sched = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
                steps=_t_max - warmup) >> 0.
        else:
            sched = q.sched.Linear(steps=warmup) >> 1.
        return q.sched.LRSchedule(_optim, sched)

    # trainm.parameters() is evaluated at call time, so this one closure also
    # covers the parameters of a model swapped in after pretraining.
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(
        trainm.parameters(), gradnorm)
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])

    # region pretrain on all domains
    metrics = make_array_of_metrics("loss", "elem_acc", "seq_acc", "tree_acc")
    vmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    xmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    optim = torch.optim.Adam(_make_paramgroups(lr), lr=lr, weight_decay=wreg)

    # vmetrics[1] is "tree_acc"
    eyt = q.EarlyStopper(vmetrics[1],
                         patience=patience,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    # The pretraining loop runs for `pretrainepochs`, so the schedule must
    # span that many steps (it was previously sized with `epochs`).
    lr_schedule = _make_schedule(optim, pretrainepochs)

    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=tdl,
                         optim=optim,
                         losses=metrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=vdl,
                         losses=vmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    if not nopretrain:
        tt.tick("pretraining")
        q.run_training(run_train_epoch=trainepoch,
                       run_valid_epoch=validepoch,
                       max_epochs=pretrainepochs,
                       check_stop=[lambda: eyt.check_stop()])
        tt.tock("done pretraining")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    # region finetune
    ftmetrics = make_array_of_metrics("loss", "elem_acc", "seq_acc",
                                      "tree_acc")
    ftvmetrics = make_array_of_metrics("seq_acc", "tree_acc")
    ftxmetrics = make_array_of_metrics("seq_acc", "tree_acc")

    # Parameter groups are rebuilt because trainm.model may have been replaced above.
    ftoptim = torch.optim.Adam(_make_paramgroups(ftlr),
                               lr=ftlr,
                               weight_decay=wreg)

    # `eyt` and `lr_schedule` are rebound on purpose: the lambdas below (and
    # the check_stop lambda) look the names up when called, and the
    # pretraining closures are no longer used from here on.
    eyt = q.EarlyStopper(ftvmetrics[1],
                         patience=1000,
                         min_epochs=10,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(trainm.model))

    lr_schedule = _make_schedule(ftoptim, epochs)

    trainepoch = partial(q.train_epoch,
                         model=trainm,
                         dataloader=ftdl,
                         optim=ftoptim,
                         losses=ftmetrics,
                         _train_batch=trainbatch,
                         device=device,
                         on_end=[lambda: lr_schedule.step()])
    validepoch = partial(q.test_epoch,
                         model=testm,
                         dataloader=fvdl,
                         losses=ftvmetrics,
                         device=device,
                         on_end=[lambda: eyt.on_epoch_end()])

    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validepoch,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()])
    tt.tock("done training")

    if eyt.get_remembered() is not None:
        tt.msg("reloaded")
        trainm.model = eyt.get_remembered()
        testm.model = eyt.get_remembered()

    # endregion

    tt.tick("testing")
    validresults = q.test_epoch(model=testm,
                                dataloader=fvdl,
                                losses=ftvmetrics,
                                device=device)
    testresults = q.test_epoch(model=testm,
                               dataloader=xdl,
                               losses=ftxmetrics,
                               device=device)
    print(validresults)
    print(testresults)
    tt.tock("tested")

    if printtest:
        predm = testm.model
        predm.to(device)
        c, t = 0, 0  # exact string matches / examples seen
        for testbatch in xdl:
            input_ids = testbatch[0]
            output_ids = testbatch[1]
            input_ids = input_ids.to(device)
            ret = predm.generate(
                input_ids,
                attention_mask=input_ids != predm.config.pad_token_id,
                max_length=maxlen)
            inp_strs = [
                nltok.decode(input_idse,
                             skip_special_tokens=True,
                             clean_up_tokenization_spaces=False)
                for input_idse in input_ids
            ]
            out_strs = [
                flenc.vocab.tostr(rete.to(torch.device("cpu"))) for rete in ret
            ]
            gold_strs = [
                flenc.vocab.tostr(output_idse.to(torch.device("cpu")))
                for output_idse in output_ids
            ]

            for x, y, g in zip(inp_strs, out_strs, gold_strs):
                print(" ")
                print(f"'{x}'\n--> {y}\n <=> {g}")
                if y == g:
                    c += 1
                else:
                    print("NOT SAME")
                t += 1
        # Guard against an empty test loader instead of dividing by zero.
        if t > 0:
            print(f"seq acc: {c/t}")
        else:
            print("seq acc: n/a (no test examples)")
        # testout = q.eval_loop(model=testm, dataloader=xdl, device=device)
        # print(testout)

    print("done")
    # settings.update({"train_seqacc": losses[]})

    for metricarray, datasplit in zip([ftmetrics, ftvmetrics, ftxmetrics],
                                      ["train", "valid", "test"]):
        for metric in metricarray:
            settings[f"{datasplit}_{metric.name}"] = metric.get_epoch_error()

    # print(settings)
    return settings
示例#9
0
def run(
    lr=0.0001,
    enclrmul=0.1,
    smoothing=0.1,
    gradnorm=3,
    batsize=60,
    epochs=16,
    patience=-1,
    validinter=1,
    validfrac=0.1,
    warmup=3,
    cosinelr=False,
    dataset="scan/length",
    maxsize=50,
    seed=42,
    hdim=768,
    numlayers=6,
    numheads=12,
    dropout=0.1,
    bertname="bert-base-uncased",
    testcode=False,
    userelpos=False,
    gpu=-1,
    evaltrain=False,
    trainonvalid=False,
    trainonvalidonly=False,
    recomputedata=False,
    adviters=3,  # adversary updates per main update
    advreset=10,  # reset adversary after this number of epochs
    advcontrib=1.,
    advmaskfrac=0.2,
    advnumsamples=3,
):
    """Train a set decoder together with an adversarial tagger and report accuracies.

    Builds a ``SetDecoder`` on a ``TransformerEncoder`` and an ``AdvTagger``
    on a second encoder, combines them in an ``AdvModel`` and trains both
    with ``adv_train_epoch``, each with its own Adam optimiser and LR
    schedule. Metrics are logged to wandb after every validation pass, and
    final train/valid/test numbers are written into the wandb config.

    Args:
        lr: Base learning rate of both optimisers.
        enclrmul: Factor applied to ``lr`` for parameters whose name contains
            "encoder_model.".
        smoothing: Accepted and logged, but not used in this function.
        gradnorm: Max norm used to clip gradients of ``model.core`` and
            ``model.adv``.
        batsize: Batch size of every DataLoader.
        epochs: Maximum number of epochs; also the schedule horizon.
        patience: Early-stopping patience; a negative value means ``epochs``.
        validinter: Forwarded to ``q.run_training``.
        validfrac: Forwarded to ``load_ds``.
        warmup: Number of steps of the linear warm-up part of the schedule.
        cosinelr: If true, warm-up is followed by cosine decay; requires
            ``epochs > warmup + 10``.
        dataset: Forwarded to ``load_ds``.
        maxsize: Accepted and logged, but not used in this function.
        seed: Seed for ``random``, ``torch`` and ``numpy``.
        hdim: Forwarded to both encoders and to ``SetDecoder``.
        numlayers: Forwarded to both encoders.
        numheads: Forwarded to both encoders.
        dropout: Forwarded to both encoders.
        bertname: Passed to ``load_ds`` and, as ``weightmode``, to the main
            encoder (the adversary's encoder uses "vanilla").
        testcode: If true, run one batch through the model in train and in
            eval mode before training.
        userelpos: Use relative positions; absolute positions are used
            otherwise.
        gpu: Negative selects the CPU, otherwise ``torch.device("cuda", gpu)``.
        evaltrain: If true, also evaluate on the training data at each
            validation step.
        trainonvalid: Train on train+valid and validate on test.
        trainonvalidonly: Train on the validation set only.
        recomputedata: Forwarded to ``load_ds`` as ``recompute``.
        adviters: Adversary updates per main update.
        advreset: Accepted and logged, but not used in this function.
        advcontrib: Forwarded to ``AdvModel``.
        advmaskfrac: Forwarded to ``AdvTagger`` as ``maskfrac``.
        advnumsamples: Forwarded to ``AdvModel`` as ``numsamples``.

    Returns:
        None.

    Raises:
        Exception: If a sub-model has no parameter with "encoder_model." in
            its name.
    """

    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    wandb.init(project="compgen_set_aib", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)
    if trainonvalid:
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    # Main model and adversary draw from independently shuffled loaders.
    traindl_main = DataLoader(trainds,
                              batch_size=batsize,
                              shuffle=True,
                              collate_fn=autocollate)
    traindl_adv = DataLoader(trainds,
                             batch_size=batsize,
                             shuffle=True,
                             collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    if trainonvalidonly:
        traindl_main = DataLoader(validds,
                                  batch_size=batsize,
                                  shuffle=True,
                                  collate_fn=autocollate)
        traindl_adv = DataLoader(validds,
                                 batch_size=batsize,
                                 shuffle=True,
                                 collate_fn=autocollate)
        validdl = testdl
    # print(json.dumps(next(iter(trainds)), indent=3))
    # print(next(iter(traindl)))
    # print(next(iter(validdl)))
    tt.tock()
    tt.tock()

    tt.tick("model")
    encoder = TransformerEncoder(hdim,
                                 vocab=inpdic,
                                 numlayers=numlayers,
                                 numheads=numheads,
                                 dropout=dropout,
                                 weightmode=bertname,
                                 userelpos=userelpos,
                                 useabspos=not userelpos)
    advencoder = TransformerEncoder(hdim,
                                    vocab=inpdic,
                                    numlayers=numlayers,
                                    numheads=numheads,
                                    dropout=dropout,
                                    weightmode="vanilla",
                                    userelpos=userelpos,
                                    useabspos=not userelpos)
    setdecoder = SetDecoder(hdim, vocab=fldic, encoder=encoder)
    adv = AdvTagger(advencoder, maskfrac=advmaskfrac, vocab=inpdic)
    model = AdvModel(setdecoder,
                     adv,
                     numsamples=advnumsamples,
                     advcontrib=advcontrib)
    tt.tock()

    if testcode:
        tt.tick("testcode")
        batch = next(iter(traindl_main))
        # out = tagger(batch[1])
        tt.tick("train")
        out = model(*batch)
        tt.tock()
        model.train(False)
        tt.tick("test")
        out = model(*batch)
        tt.tock()
        tt.tock("testcode")

    tloss_main = make_array_of_metrics("loss",
                                       "mainloss",
                                       "advloss",
                                       "acc",
                                       reduction="mean")
    tloss_adv = make_array_of_metrics("loss", reduction="mean")
    # For the three evaluation arrays, index 0 is "loss" and index 1 is "acc".
    tmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    vmetrics = make_array_of_metrics("loss", "acc", reduction="mean")
    xmetrics = make_array_of_metrics("loss", "acc", reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        """Split m's parameters into an "encoder_model." group (scaled lr) and the rest."""
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        """Build Adam over the two parameter groups of _m."""
        # Use the _lr argument, not the enclosing `lr`.
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    if patience < 0:
        patience = epochs
    # Track validation accuracy (index 1). The stopper used to watch index 0,
    # the loss, while declaring more_is_better=True.
    eyt = q.EarlyStopper(vmetrics[1],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(model))

    def wandb_logger():
        """Log the current epoch's training and validation numbers to wandb."""
        d = {}
        for name, loss in zip(["loss", "mainloss", "advloss", "acc"],
                              tloss_main):
            d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(["advloss"], tloss_adv):
            d["train_adv_" + name] = loss.get_epoch_error()
        # tmetrics is only filled when the train-eval epoch runs; otherwise
        # keep the accuracy measured during training.
        if evaltrain:
            d["train_acc"] = tmetrics[1].get_epoch_error()
        # Index 1 is the accuracy; pairing "acc" with index 0 logged the loss.
        d["valid_acc"] = vmetrics[1].get_epoch_error()
        wandb.log(d)

    t_max = epochs
    optim_main = get_optim(model.core, lr, enclrmul)
    optim_adv = get_optim(model.adv, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule_main = q.sched.LRSchedule(optim_main, lr_schedule)
    lr_schedule_adv = q.sched.LRSchedule(optim_adv, lr_schedule)

    trainbatch_main = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.core, _norm=gradnorm)
        ])
    trainbatch_adv = partial(
        q.train_batch,
        on_before_optim_step=[
            lambda: clipgradnorm(_m=model.adv, _norm=gradnorm)
        ])

    # NOTE(review): this unconditionally replaces the validation loader,
    # overriding the trainonvalid / trainonvalidonly choices above.
    print("using test data for validation")
    validdl = testdl

    trainepoch = partial(adv_train_epoch,
                         main_model=model.main_trainmodel,
                         adv_model=model.adv_trainmodel,
                         main_dataloader=traindl_main,
                         adv_dataloader=traindl_adv,
                         main_optim=optim_main,
                         adv_optim=optim_adv,
                         main_losses=tloss_main,
                         adv_losses=tloss_adv,
                         adviters=adviters,
                         device=device,
                         print_every_batch=True,
                         _main_train_batch=trainbatch_main,
                         _adv_train_batch=trainbatch_adv,
                         on_end=[
                             lambda: lr_schedule_main.step(),
                             lambda: lr_schedule_adv.step()
                         ])

    # eval epochs
    trainevalepoch = partial(q.test_epoch,
                             model=model,
                             losses=tmetrics,
                             dataloader=traindl_main,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=model,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testepoch = partial(q.test_epoch,
                        model=model,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    if eyt.remembered is not None:
        # NOTE(review): deliberate tripwire left by the author; reloading is
        # effectively disabled. Also note that rebinding `model` below would
        # not affect the partials above, which already captured the old object.
        assert False
        tt.msg("reloading best")
        model = eyt.remembered

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock()

    # `tloss` was never defined; the training loss lives in tloss_main[0].
    settings.update({"final_train_loss": tloss_main[0].get_epoch_error()})
    settings.update({"final_train_acc": tmetrics[1].get_epoch_error()})
    settings.update({"final_valid_acc": vmetrics[1].get_epoch_error()})
    settings.update({"final_test_acc": xmetrics[1].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)
示例#10
0
def run_relations(
    lr=DEFAULT_LR,
    dropout=.3,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    epochs=10,
    smoothing=DEFAULT_SMOOTHING,
    cuda=False,
    gpu=0,
    balanced=False,
    maskentity=False,
    savep="exp_bilstm_rels_",
    test=False,
    datafrac=1.,
    glove=False,
    embdim=50,
    dim=300,
    numlayers=2,
    warmup=0.0,
    cycles=0.5,
    sched="cos",
    evalbatsize=-1,
    classweighted=False,
    fixembed=False,
):
    """Train a BiLSTM relation classifier, test it and optionally save predictions.

    Loads word matrices, span borders and relation labels through
    ``load_data``, trains a ``RelationClassifier`` with ``BertAdam``, prints
    the test results and, when ``savep`` is non-empty, writes the settings
    and the dev/test predictions into a fresh directory ``savep + <index>``.

    Args:
        lr: Learning rate for ``BertAdam``.
        dropout: Dropout for the LSTM encoder input and the classifier.
        wreg: Weight decay.
        initwreg: Accepted and saved with the settings, but not used here.
        batsize: Training batch size.
        epochs: Number of training epochs (forced to 0 when ``test`` is set).
        smoothing: Label smoothing for ``q.SmoothedCELoss``.
        cuda: Run on ``torch.device("cuda", gpu)`` instead of the CPU.
        gpu: CUDA device index, used only when ``cuda`` is true.
        balanced: Accepted and saved with the settings, but not used here.
        maskentity: If true, datasets go through ``replace_entity_span``;
            otherwise tensors 0 and 2 of each split are used.
        savep: Prefix of the output directory; empty disables saving.
        test: Smoke-test mode: no training, and test/eval loaders are cut
            to the first 10 examples.
        datafrac: Forwarded to ``load_data``.
        glove: If true, load pretrained vectors and let them override the
            base embedding.
        embdim: Embedding dimension; also the LSTM input size.
        dim: LSTM hidden size per layer and classifier dimension.
        numlayers: Number of LSTM layers.
        warmup: Forwarded to ``get_schedule`` and ``BertAdam``.
        cycles: Forwarded to ``get_schedule``.
        sched: Schedule name for ``get_schedule``.
        evalbatsize: Evaluation batch size; negative means ``batsize``.
        classweighted: Weight the training loss by inverse training-set
            relation counts.
        fixembed: With ``glove``, freeze both embedding tables.

    Returns:
        None.
    """
    print(locals())
    settings = locals().copy()
    if evalbatsize < 0:
        evalbatsize = batsize
    if test:
        epochs = 0
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running relation classifier with BiLSTM")
    tt.tick("loading data")
    data = load_data(which="wordmat,wordborders,rels",
                     datafrac=datafrac,
                     retrelD=True)
    trainds, devds, testds, wD, relD = data
    rev_wD = {v: k for k, v in wD.items()}

    def pp(ids):
        """Render a row of word ids as text, skipping id 0."""
        ret = " ".join(
            [rev_wD[idse.item()] for idse in ids if idse.item() != 0])
        return ret

    print(pp(trainds.tensors[0][0]))
    print(trainds.tensors[1][0])
    if maskentity:
        trainds, devds, testds = replace_entity_span(trainds,
                                                     devds,
                                                     testds,
                                                     D=wD)
    else:
        # keep only tensor 0 (the questions printed below) and tensor 2
        # (used as the relation id for bincount further down)
        trainds, devds, testds = [
            TensorDataset(ds.tensors[0], ds.tensors[2])
            for ds in [trainds, devds, testds]
        ]

    # Preview up to 10 questions per split. Slicing, unlike indexing 0..9,
    # does not fail on splits with fewer than 10 rows (e.g. a small datafrac).
    for question in trainds.tensors[0][:10]:
        print(pp(question))
    print()
    for question in devds.tensors[0][:10]:
        print(pp(question))
    print()
    for question in testds.tensors[0][:10]:
        print(pp(question))

    # Per-relation frequency in the training set, padded to the full id range.
    relcounts = torch.zeros(max(relD.values()) + 1)
    trainrelcounts = torch.tensor(
        np.bincount(trainds.tensors[1].detach().cpu().numpy()))
    relcounts[:len(trainrelcounts)] += trainrelcounts.float()
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    # Input-only views of test/dev, used to produce the saved predictions.
    evalds = TensorDataset(*testloader.dataset.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:1])
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)

    if test:
        evalloader = DataLoader(TensorDataset(*evalloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
        testloader = DataLoader(TensorDataset(*testloader.dataset[:10]),
                                batch_size=batsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("making model")
    emb = q.WordEmb(embdim, worddic=wD)
    if glove:
        print("using glove")
        # The third item of the file used to be unpacked into `dim`, silently
        # overwriting the LSTM hidden-size argument; bind it to a throwaway name.
        stoi_, vectors_, _glovedim = torch.load(
            "../../data/buboqa/data/sq_glove300d.pt")
        # map vectors from custom glove ids to wD ids
        vectors = torch.zeros(max(wD.values()) + 1,
                              embdim,
                              device=vectors_.device,
                              dtype=vectors_.dtype)
        stoi = {}
        for k, v in stoi_.items():
            if k in wD:
                vectors[wD[k]] = vectors_[v]
                stoi[k] = wD[k]
        print("{} words in stoi that are in wD".format(len(stoi)))
        gloveemb = q.WordEmb(embdim, worddic=stoi, _weight=vectors)
        # gloveemb = q.WordEmb.load_glove("glove.{}d".format(embdim), selectD=wD)
        if fixembed:
            gloveemb.freeze()
            emb.freeze()
        emb = q.SwitchedWordEmb(emb).override(gloveemb)

    bilstm = q.rnn.LSTMEncoder(embdim,
                               *([dim] * numlayers),
                               bidir=True,
                               dropout_in=dropout)
    # bilstm = torch.nn.LSTM(embdim, dim, batch_first=True, num_layers=numlayers, bidirectional=True, dropout=dropout)
    m = RelationClassifier(emb=emb,
                           bilstm=bilstm,
                           dim=dim,
                           relD=relD,
                           dropout=dropout)
    m.to(device)

    # model = RelationPrediction(config)
    tt.tock("made model")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    # only optimise parameters that were not frozen above
    params = [param for param in m.parameters() if param.requires_grad]
    sched = get_schedule(sched,
                         warmup=warmup,
                         t_total=totalsteps,
                         cycles=cycles)
    optim = BertAdam(params,
                     lr=lr,
                     weight_decay=wreg,
                     warmup=warmup,
                     t_total=totalsteps,
                     schedule=sched)
    # optim = torch.optim.Adam(params, lr=lr, weight_decay=wreg)
    # losses = [
    #     torch.nn.CrossEntropyLoss(size_average=True),
    #     q.Accuracy()
    # ]
    losses = [
        q.SmoothedCELoss(smoothing=smoothing,
                         weight=1 /
                         relcounts.clamp_min(1e-6) if classweighted else None),
        q.Accuracy()
    ]
    # xlosses = [
    #     torch.nn.CrossEntropyLoss(size_average=True),
    #     q.Accuracy()
    # ]
    xlosses = [q.SmoothedCELoss(smoothing=smoothing), q.Accuracy()]
    trainlosses = [q.LossWrapper(l) for l in losses]
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=m,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=m,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=m,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        # pick the first unused numeric suffix for the output directory
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        savedir = savep + str(i)
        os.mkdir(savedir)
        # save model
        # torch.save(m, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings (context manager so the file is flushed and closed)
        with open(os.path.join(savedir, "settings.json"), "w") as f:
            json.dump(settings, f)
        # save relation dictionary
        # json.dump(relD, open(os.path.join(savedir, "relD.json"), "w"))
        # save test predictions
        testpreds = q.eval_loop(m, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.test.npy"), testpreds)
        # save dev predictions
        testpreds = q.eval_loop(m, evalloader_dev, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "relpreds.dev.npy"), testpreds)
        tt.msg("saved in {}".format(savedir))
        # save bert-tokenized questions
        # tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        # with open(os.path.join(savedir, "testquestions.txt"), "w") as f:
        #     for batch in evalloader:
        #         ques, io = batch
        #         ques = ques.numpy()
        #         for question in ques:
        #             qstr = " ".join([x for x in tokenizer.convert_ids_to_tokens(question) if x != "[PAD]"])
        #             f.write(qstr + "\n")

        tt.tock("done")
示例#11
0
def run_span_borders(
    lr=DEFAULT_LR,
    dropout=.3,
    wreg=DEFAULT_WREG,
    initwreg=DEFAULT_INITWREG,
    batsize=DEFAULT_BATSIZE,
    evalbatsize=-1,
    epochs=DEFAULT_EPOCHS,
    smoothing=DEFAULT_SMOOTHING,
    dim=200,
    numlayers=1,
    cuda=False,
    gpu=0,
    savep="exp_bilstm_span_borders_",
    datafrac=1.,
    glove=False,
    fixembed=False,
    embdim=50,
    sched="cos",
    warmup=0.1,
    cycles=0.5,
):
    """Train a BiLSTM span-border detector, test it and optionally save predictions.

    The function loads the "wordmat,wordborders" data, builds a
    ``BorderSpanDetector`` on top of a bidirectional ``q.rnn.LSTMEncoder``,
    trains it with ``BertAdam`` for ``epochs`` epochs, prints the test
    results and, when ``savep`` is non-empty, writes the settings and the
    dev/test predictions into a freshly created directory.

    Args:
        lr: learning rate given to ``BertAdam``.
        dropout: dropout used by both the LSTM encoder and the span detector.
        wreg: weight decay given to ``BertAdam``.
        initwreg: recorded in the saved settings but not otherwise used here.
        batsize: training batch size.
        evalbatsize: batch size for dev/test/prediction loaders; a negative
            value means "use ``batsize``".
        epochs: number of training epochs (also used to compute the total
            number of steps for the schedule).
        smoothing: label smoothing given to ``q.SmoothedCELoss``.
        dim: size of each LSTM layer (the detector receives ``dim * 2``).
        numlayers: number of LSTM layers.
        cuda: run on ``torch.device("cuda", gpu)`` when true, else on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        savep: prefix of the output directory; the first non-existing
            ``savep + str(i)`` is created. An empty string disables saving.
        datafrac: forwarded to ``load_data``.
        glove: when true, override the word embeddings with vectors loaded
            from ``../../data/buboqa/data/sq_glove300d.pt``.
        fixembed: freeze the glove embeddings (only relevant with ``glove``).
        embdim: word embedding dimensionality.
        sched: schedule name forwarded to ``get_schedule``.
        warmup: warmup setting forwarded to ``get_schedule``.
        cycles: cycles setting forwarded to ``get_schedule``.

    Returns:
        None. Results are printed and, optionally, written to disk.
    """
    # Must stay the first statement: snapshots exactly the call arguments.
    settings = locals().copy()
    print(locals())
    if evalbatsize < 0:
        evalbatsize = batsize
    if cuda:
        device = torch.device("cuda", gpu)
    else:
        device = torch.device("cpu")
    # region data
    tt = q.ticktock("script")
    tt.msg("running span border with BiLSTM")
    tt.tick("loading data")
    data = load_data(which="wordmat,wordborders", datafrac=datafrac)
    trainds, devds, testds, wD = data
    tt.tock("data loaded")
    tt.msg("Train/Dev/Test sizes: {} {} {}".format(len(trainds), len(devds),
                                                   len(testds)))
    trainloader = DataLoader(trainds, batch_size=batsize, shuffle=True)
    devloader = DataLoader(devds, batch_size=evalbatsize, shuffle=False)
    testloader = DataLoader(testds, batch_size=evalbatsize, shuffle=False)
    # Prediction loaders only carry the first tensor of each dataset
    # (presumably the input word ids; the gold borders are dropped).
    evalds = TensorDataset(*testloader.dataset.tensors[:1])
    evalloader = DataLoader(evalds, batch_size=evalbatsize, shuffle=False)
    evalds_dev = TensorDataset(*devloader.dataset.tensors[:1])
    evalloader_dev = DataLoader(evalds_dev,
                                batch_size=evalbatsize,
                                shuffle=False)
    # endregion

    # region model
    tt.tick("creating model")
    emb = q.WordEmb(embdim, worddic=wD)
    if glove:
        print("using glove")
        # NOTE(review): this rebinds the `dim` argument to the third value
        # stored in the glove file, so the LSTM size below changes whenever
        # glove=True. Kept as-is to preserve behaviour -- confirm intent.
        stoi_, vectors_, dim = torch.load(
            "../../data/buboqa/data/sq_glove300d.pt")
        # map vectors from custom glove ids to wD ids
        vectors = torch.zeros(max(wD.values()) + 1,
                              embdim,
                              device=vectors_.device,
                              dtype=vectors_.dtype)
        stoi = {}
        for k, v in stoi_.items():
            if k in wD:
                vectors[wD[k]] = vectors_[v]
                stoi[k] = wD[k]
        print("{} words in stoi that are in wD".format(len(stoi)))
        gloveemb = q.WordEmb(embdim, worddic=stoi, _weight=vectors)
        # gloveemb = q.WordEmb.load_glove("glove.{}d".format(embdim), selectD=wD)
        if fixembed:
            gloveemb.freeze()
        emb = q.SwitchedWordEmb(emb).override(gloveemb)
    # inpD = tokenizer.vocab
    # q.WordEmb.masktoken = "[PAD]"
    # emb = q.WordEmb(embdim, worddic=inpD)
    bilstm = q.rnn.LSTMEncoder(embdim,
                               *([dim] * numlayers),
                               bidir=True,
                               dropout_in_shared=dropout)
    spandet = BorderSpanDetector(emb, bilstm, dim * 2, dropout=dropout)
    spandet.to(device)
    tt.tock("model created")
    # endregion

    # region training
    totalsteps = len(trainloader) * epochs
    params = spandet.parameters()
    sched = get_schedule(sched,
                         warmup=warmup,
                         t_total=totalsteps,
                         cycles=cycles)
    optim = BertAdam(params, lr=lr, weight_decay=wreg, schedule=sched)
    # optim = torch.optim.Adam(spandet.parameters(), lr=lr, weight_decay=wreg)
    losses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(),
        q.SeqAccuracy()
    ]
    xlosses = [
        q.SmoothedCELoss(smoothing=smoothing),
        SpanF1Borders(),
        q.SeqAccuracy()
    ]
    trainlosses = [q.LossWrapper(l) for l in losses]
    # Dev and test get separate wrappers around the same `xlosses` objects.
    devlosses = [q.LossWrapper(l) for l in xlosses]
    testlosses = [q.LossWrapper(l) for l in xlosses]
    trainloop = partial(q.train_epoch,
                        model=spandet,
                        dataloader=trainloader,
                        optim=optim,
                        losses=trainlosses,
                        device=device)
    devloop = partial(q.test_epoch,
                      model=spandet,
                      dataloader=devloader,
                      losses=devlosses,
                      device=device)
    testloop = partial(q.test_epoch,
                       model=spandet,
                       dataloader=testloader,
                       losses=testlosses,
                       device=device)

    tt.tick("training")
    q.run_training(trainloop, devloop, max_epochs=epochs)
    tt.tock("done training")

    tt.tick("testing")
    testres = testloop()
    print(testres)
    tt.tock("tested")
    # endregion

    if len(savep) > 0:
        tt.tick("making predictions and saving")
        # Pick the first numbered directory that does not exist yet.
        i = 0
        while os.path.exists(savep + str(i)):
            i += 1
        savedir = savep + str(i)
        os.mkdir(savedir)
        # save model
        # torch.save(spandet, open(os.path.join(savedir, "model.pt"), "wb"))
        # save settings (context manager so the handle is flushed and closed)
        with open(os.path.join(savedir, "settings.json"), "w") as f:
            json.dump(settings, f)

        # Expose the training sequence length on the model before predicting.
        outlen = trainloader.dataset.tensors[0].size(1)
        spandet.outlen = outlen

        # save test predictions
        testpreds = q.eval_loop(spandet, evalloader, device=device)
        testpreds = testpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.test.npy"), testpreds)
        # save dev predictions
        devpreds = q.eval_loop(spandet, evalloader_dev, device=device)
        devpreds = devpreds[0].cpu().detach().numpy()
        np.save(os.path.join(savedir, "borderpreds.dev.npy"), devpreds)
        tt.msg("saved in {}".format(savedir))
        tt.tock("done")
示例#12
0
def run(lr=0.001,
        batsize=20,
        epochs=70,
        embdim=128,
        encdim=256,
        numlayers=1,
        beamsize=5,
        dropout=.25,
        wreg=1e-10,
        cuda=False,
        gpu=0,
        minfreq=2,
        gradnorm=3.,
        smoothing=0.1,
        cosine_restarts=1.,
        seed=123456,
        beta_spec="none",
        minkl=0.05,
        ):
    """Train a ``BasicGenModel`` on ``GeoDataset``, evaluate it and optionally save it.

    Training uses a teacher-forced ``SeqDecoder``; validation uses a free-running
    ``SeqDecoder`` and the final evaluation uses a ``BeamDecoder``. After
    evaluation the user is asked interactively whether the run should be saved;
    when it is, the run is reloaded, re-tested and train/test predictions are
    written as JSON next to it.

    Args:
        lr: learning rate for ``torch.optim.Adam``.
        batsize: batch size for every dataloader.
        epochs: number of training epochs (also the length of both schedules).
        embdim: ``embdim`` of ``BasicGenModel``.
        encdim: ``hdim`` of ``BasicGenModel``.
        numlayers: ``numlayers`` of ``BasicGenModel``.
        beamsize: beam size of the ``BeamDecoder``.
        dropout: ``dropout`` of ``BasicGenModel``.
        wreg: weight decay for Adam.
        cuda: run on ``torch.device("cuda", gpu)`` when true, else on CPU.
        gpu: CUDA device index, only used when ``cuda`` is true.
        minfreq: ``min_freq`` of ``GeoDataset``.
        gradnorm: max norm for gradient clipping before each optimizer step.
        smoothing: label smoothing of the training ``CELoss``.
        cosine_restarts: when ``>= 0``, a
            ``q.WarmupCosineWithHardRestartsSchedule`` with this many cycles
            is registered as an epoch-end callback; otherwise no LR schedule.
        seed: seed for ``torch`` and ``numpy``.
        beta_spec: spec for the ``q.EnvelopeSchedule`` driving the penalty
            weight; ``"none"`` is translated to ``"0:0"``.
        minkl: ``minkl`` of ``BasicGenModel``.

    Returns:
        None. Results are printed and, optionally, written to disk.
    """
    # Must stay the first statement: snapshots exactly the call arguments.
    localargs = locals().copy()
    print(locals())
    torch.manual_seed(seed)
    np.random.seed(seed)
    tt = q.ticktock("script")
    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    tt.tick("loading data")
    ds = GeoDataset(sentence_encoder=SequenceEncoder(tokenizer=split_tokenizer), min_freq=minfreq)
    print(f"max lens: {ds.maxlen_input} (input) and {ds.maxlen_output} (output)")
    tt.tock("data loaded")

    # Weight of the state penalty; beta_sched is stepped as an epoch-end callback.
    beta_ = q.hyperparam(0)
    if beta_spec == "none":
        beta_spec = "0:0"
    beta_sched = q.EnvelopeSchedule(beta_, beta_spec, numsteps=epochs)

    do_rare_stats(ds)
    # batch = next(iter(train_dl))
    # print(batch)
    # print("input graph")
    # print(batch.batched_states)

    model = BasicGenModel(embdim=embdim, hdim=encdim, dropout=dropout, numlayers=numlayers,
                          sentence_encoder=ds.sentence_encoder, query_encoder=ds.query_encoder,
                          feedatt=True, minkl=minkl)

    # sentence_rare_tokens = set([ds.sentence_encoder.vocab(i) for i in model.inp_emb.rare_token_ids])
    # do_rare_stats(ds, sentence_rare_tokens=sentence_rare_tokens)

    # Teacher-forced decoder used for training.
    tfdecoder = SeqDecoder(model, tf_ratio=1.,
                           eval=[CELoss(ignore_index=0, mode="logprobs", smoothing=smoothing),
                                 StatePenalty(lambda x: x.mstate.kld, weight=beta_),
                                 SeqAccuracies(),
                                 TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                              orderless={"and", "or"})])
    losses = make_array_of_metrics("loss", "penalty", "elem_acc", "seq_acc", "tree_acc")

    # Free-running decoder used for validation.
    freedecoder = SeqDecoder(model, maxtime=100, tf_ratio=0.,
                             eval=[SeqAccuracies(),
                                   TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                orderless={"and", "or"})])
    vlosses = make_array_of_metrics("seq_acc", "tree_acc")

    # Beam-search decoder used for the final evaluation.
    beamdecoder = BeamDecoder(model, maxtime=100, beamsize=beamsize, copy_deep=True,
                              eval=[SeqAccuracies()],
                              eval_beam=[TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                      orderless={"and", "or"})])
    beamlosses = make_array_of_metrics("seq_acc", "tree_acc", "tree_acc_at_last")

    # 4. define optim
    # optim = torch.optim.Adam(trainable_params, lr=lr, weight_decay=wreg)
    optim = torch.optim.Adam(tfdecoder.parameters(), lr=lr, weight_decay=wreg)

    # lr schedule
    if cosine_restarts >= 0:
        # t_max = epochs * len(train_dl)
        t_max = epochs
        print(f"Total number of updates: {t_max}")
        lr_schedule = q.WarmupCosineWithHardRestartsSchedule(optim, 0, t_max, cycles=cosine_restarts)
        on_epoch_end = [lambda: lr_schedule.step()]
    else:
        on_epoch_end = []

    on_epoch_end.append(lambda: beta_sched.step())

    # 6. define training function
    clipgradnorm = lambda: torch.nn.utils.clip_grad_norm_(tfdecoder.parameters(), gradnorm)
    # clipgradnorm = lambda: None
    trainbatch = partial(q.train_batch, on_before_optim_step=[clipgradnorm])
    trainepoch = partial(q.train_epoch, model=tfdecoder, dataloader=ds.dataloader("train", batsize), optim=optim, losses=losses,
                         _train_batch=trainbatch, device=device, on_end=on_epoch_end)

    # 7. define validation function (using partial)
    # NOTE(review): validation runs on the "test" split -- confirm this is intended.
    validepoch = partial(q.test_epoch, model=freedecoder, dataloader=ds.dataloader("test", batsize), losses=vlosses, device=device)
    # validepoch = partial(q.test_epoch, model=freedecoder, dataloader=valid_dl, losses=vlosses, device=device)

    # p = q.save_run(freedecoder, localargs, filepath=__file__)
    # q.save_dataset(ds, p)
    # _freedecoder, _localargs = q.load_run(p)
    # _ds = q.load_dataset(p)
    # sys.exit()

    # 7. run training
    tt.tick("training")
    q.run_training(run_train_epoch=trainepoch, run_valid_epoch=validepoch, max_epochs=epochs)
    tt.tock("done training")

    # testing
    # NOTE(review): both evaluations below use the "test" split, so the
    # "validation test results" are computed on the same data as the test results.
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder, dataloader=ds.dataloader("test", batsize), losses=beamlosses, device=device)
    print("validation test results: ", testresults)
    tt.tock("tested")
    tt.tick("testing")
    testresults = q.test_epoch(model=beamdecoder, dataloader=ds.dataloader("test", batsize), losses=beamlosses, device=device)
    print("test results: ", testresults)
    tt.tock("tested")

    # save model?
    tosave = input("Save this model? 'y(es)'=Yes, <int>=overwrite previous, otherwise=No) \n>").strip()
    # if True:
    #     overwrite = None
    # The whole answer must be digits to count as an overwrite index: a prefix
    # match (e.g. "12abc") would pass the check and then crash in int() after
    # the entire training run has completed.
    wants_overwrite = re.fullmatch(r"\d+", tosave) is not None
    if tosave.lower() in ("y", "yes") or wants_overwrite:
        overwrite = int(tosave) if wants_overwrite else None
        p = q.save_run(model, localargs, filepath=__file__, overwrite=overwrite)
        q.save_dataset(ds, p)
        # Reload what was just written to check that the saved run is usable.
        _model, _localargs = q.load_run(p)
        _ds = q.load_dataset(p)

        _freedecoder = BeamDecoder(_model, maxtime=100, beamsize=beamsize, copy_deep=True,
                                   eval=[SeqAccuracies()],
                                   eval_beam=[TreeAccuracy(tensor2tree=partial(tensor2tree, D=ds.query_encoder.vocab),
                                                           orderless={"and", "or"})])

        # testing
        tt.tick("testing reloaded")
        _testresults = q.test_epoch(model=_freedecoder, dataloader=_ds.dataloader("test", batsize),
                                    losses=beamlosses, device=device)
        print(_testresults)
        tt.tock("tested")

        # save predictions
        _, testpreds = q.eval_loop(_freedecoder, ds.dataloader("test", batsize=batsize, shuffle=False), device=device)
        testout = get_outputs_for_save(testpreds)
        _, trainpreds = q.eval_loop(_freedecoder, ds.dataloader("train", batsize=batsize, shuffle=False), device=device)
        trainout = get_outputs_for_save(trainpreds)

        with open(os.path.join(p, "trainpreds.json"), "w") as f:
            ujson.dump(trainout, f)

        with open(os.path.join(p, "testpreds.json"), "w") as f:
            ujson.dump(testout, f)
示例#13
0
def run(
        lr=0.0001,
        enclrmul=0.01,
        smoothing=0.,
        gradnorm=3,
        batsize=60,
        epochs=16,
        patience=10,
        validinter=3,
        validfrac=0.1,
        warmup=3,
        cosinelr=False,
        dataset="scan/length",
        mode="normal",  # "normal", "noinp"
        maxsize=50,
        seed=42,
        hdim=768,
        numlayers=2,
        numtmlayers=6,
        numheads=6,
        dropout=0.1,
        worddropout=0.,
        bertname="bert-base-uncased",
        testcode=False,
        userelpos=False,
        gpu=-1,
        evaltrain=False,
        trainonvalid=False,
        trainonvalidonly=False,
        recomputedata=False,
        version="v1"):
    """Train and evaluate an ``S2Z2S`` decoder built from two ``DecoderCell``s.

    Loads the dataset with ``load_ds``, trains with Adam (separate learning
    rate for parameters whose name contains ``"encoder_model."``), tracks the
    first validation metric with ``q.EarlyStopper``, logs to wandb and finally
    reports train/valid/test tree accuracy.

    Args:
        lr: base learning rate.
        enclrmul: multiplier applied to ``lr`` for the encoder parameters.
        smoothing: ``smoothing`` of ``S2Z2S``.
        gradnorm: max norm for gradient clipping before each optimizer step.
        batsize: batch size of all dataloaders.
        epochs: maximum number of training epochs.
        patience: patience of the early stopper; forced to ``-1`` when the
            second ``/``-separated part of ``dataset`` contains ``"mcd"``.
            The remembered best state is only reloaded when it is ``>= 0``.
        validinter: ``validinter`` forwarded to ``q.run_training``.
        validfrac: ``validfrac`` forwarded to ``load_ds``.
        warmup: number of steps of the linear warmup schedule.
        cosinelr: follow the warmup with a cosine schedule instead of a
            constant one (requires ``epochs > warmup + 10``).
        dataset: dataset identifier forwarded to ``load_ds``.
        mode: recorded in the settings but not otherwise used here.
        maxsize: ``max_size`` of ``S2Z2S``; a negative value selects 155 for
            ``cfq*`` datasets and 55 for ``scan*`` datasets.
        seed: seed for ``random``, ``torch`` and ``numpy``.
        hdim, numlayers, numtmlayers, numheads, dropout, worddropout:
            forwarded to both ``DecoderCell`` instances.
        bertname: ``bertname`` forwarded to ``load_ds``.
        testcode: run one batch through the decoder in train and eval mode
            before training starts.
        userelpos: recorded in the settings but not otherwise used here.
        gpu: CUDA device index; a negative value selects the CPU.
        evaltrain: additionally evaluate on the training loader at every
            validation and log those metrics.
        trainonvalid: currently disabled (hits ``assert False``).
        trainonvalidonly: train on the validation loader and validate on the
            test loader.
        recomputedata: ``recompute`` forwarded to ``load_ds``.
        version: recorded in the settings but not otherwise used here.

    Returns:
        Tuple ``(decoder, testds)``: the trained model and the test dataset.
    """
    # Must stay the first statement: snapshots exactly the call arguments.
    settings = locals().copy()
    q.pp_dict(settings, indent=3)
    # wandb.init()

    # torch.backends.cudnn.enabled = False

    wandb.init(project="compgen_butterfly", config=settings, reinit=True)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device("cpu") if gpu < 0 else torch.device("cuda", gpu)

    if maxsize < 0:
        if dataset.startswith("cfq"):
            maxsize = 155
        elif dataset.startswith("scan"):
            maxsize = 55
        print(f"maxsize: {maxsize}")

    tt = q.ticktock("script")
    tt.tick("data")
    trainds, validds, testds, fldic, inpdic = load_ds(dataset=dataset,
                                                      validfrac=validfrac,
                                                      bertname=bertname,
                                                      recompute=recomputedata)

    # Guard the index so a dataset name without "/" does not raise IndexError.
    dataset_parts = dataset.split("/")
    if len(dataset_parts) > 1 and "mcd" in dataset_parts[1]:
        print(f"Setting patience to -1 because MCD (was {patience})")
        patience = -1

    tt.msg(f"TRAIN DATA: {len(trainds)}")
    tt.msg(f"DEV DATA: {len(validds)}")
    tt.msg(f"TEST DATA: {len(testds)}")
    if trainonvalid:
        # Deliberately disabled: the assert aborts before the merge below.
        assert False
        trainds = trainds + validds
        validds = testds

    tt.tick("dataloaders")
    traindl = DataLoader(trainds,
                         batch_size=batsize,
                         shuffle=True,
                         collate_fn=autocollate)
    validdl = DataLoader(validds,
                         batch_size=batsize,
                         shuffle=False,
                         collate_fn=autocollate)
    testdl = DataLoader(testds,
                        batch_size=batsize,
                        shuffle=False,
                        collate_fn=autocollate)
    tt.tock()  # closes "dataloaders"
    tt.tock()  # closes "data"

    tt.tick("model")
    # cell = GRUDecoderCell(hdim, vocab=fldic, inpvocab=inpdic, numlayers=numlayers, dropout=dropout, worddropout=worddropout)
    # decoder = S2S(cell, vocab=fldic, max_size=maxsize, smoothing=smoothing)
    cell1 = DecoderCell(hdim,
                        vocab=fldic,
                        inpvocab=inpdic,
                        numlayers=numlayers,
                        numtmlayers=numtmlayers,
                        dropout=dropout,
                        worddropout=worddropout,
                        mode="cont",
                        numheads=numheads)
    cell2 = DecoderCell(hdim,
                        vocab=fldic,
                        inpvocab=inpdic,
                        numlayers=numlayers,
                        numtmlayers=numtmlayers,
                        dropout=dropout,
                        worddropout=worddropout,
                        mode="normal",
                        noencoder=True,
                        numheads=numheads)
    decoder = S2Z2S(cell1,
                    cell2,
                    vocab=fldic,
                    max_size=maxsize,
                    smoothing=smoothing)
    # print(f"one layer of decoder: \n {cell.decoder.block[0]}")
    print(decoder)
    tt.tock()

    if testcode:
        # Smoke test: one batch through the decoder in train and in eval mode.
        tt.tick("testcode")
        batch = next(iter(traindl))
        # out = tagger(batch[1])
        tt.tick("train")
        out = decoder(*batch)
        tt.tock()
        decoder.train(False)
        tt.tick("test")
        out = decoder(*batch)
        tt.tock()
        tt.tock("testcode")

    # One shared name list keeps the wandb keys aligned with the metric order
    # (assumes make_array_of_metrics returns metrics in argument order).
    tlossnames = ["loss", "elemacc", "acc"]
    tloss = make_array_of_metrics(*tlossnames, reduction="mean")
    metricnames = ["treeacc"]
    tmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    vmetrics = make_array_of_metrics(*metricnames, reduction="mean")
    xmetrics = make_array_of_metrics(*metricnames, reduction="mean")

    # region parameters
    def get_parameters(m, _lr, _enclrmul):
        """Split parameters of `m` into an encoder group (scaled LR) and the rest."""
        bertparams = []
        otherparams = []
        for k, v in m.named_parameters():
            if "encoder_model." in k:
                bertparams.append(v)
            else:
                otherparams.append(v)
        if len(bertparams) == 0:
            raise Exception("No encoder parameters found!")
        paramgroups = [{
            "params": bertparams,
            "lr": _lr * _enclrmul
        }, {
            "params": otherparams
        }]
        return paramgroups

    # endregion

    def get_optim(_m, _lr, _enclrmul, _wreg=0):
        """Build the Adam optimizer from the parameter groups of `_m`."""
        # Use the function's own `_lr` rather than the enclosing `lr`.
        paramgroups = get_parameters(_m, _lr=_lr, _enclrmul=_enclrmul)
        optim = torch.optim.Adam(paramgroups, lr=_lr, weight_decay=_wreg)
        return optim

    def clipgradnorm(_m=None, _norm=None):
        """Clip the gradient norm of all parameters of `_m` to `_norm`."""
        torch.nn.utils.clip_grad_norm_(_m.parameters(), _norm)

    eyt = q.EarlyStopper(vmetrics[0],
                         patience=patience,
                         min_epochs=30,
                         more_is_better=True,
                         remember_f=lambda: deepcopy(decoder.state_dict()))

    def wandb_logger():
        """Log the current epoch errors of all tracked metrics to wandb."""
        d = {}
        # Previously ["loss", "acc"] was zipped against three metrics, which
        # logged the element accuracy under "train_acc".
        for name, loss in zip(tlossnames, tloss):
            d["train_" + name] = loss.get_epoch_error()
        if evaltrain:
            for name, loss in zip(metricnames, tmetrics):
                d["train_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, vmetrics):
            d["valid_" + name] = loss.get_epoch_error()
        for name, loss in zip(metricnames, xmetrics):
            d["test_" + name] = loss.get_epoch_error()
        wandb.log(d)

    # The LR schedule is stepped via the train-epoch on_end callback, so its
    # length is expressed in epochs.
    t_max = epochs
    optim = get_optim(decoder, lr, enclrmul)
    print(f"Total number of updates: {t_max} .")
    if cosinelr:
        assert t_max > (warmup + 10)
        lr_schedule = q.sched.Linear(steps=warmup) >> q.sched.Cosine(
            low=0., high=1.0, steps=t_max - warmup) >> (0. * lr)
    else:
        lr_schedule = q.sched.Linear(steps=warmup) >> 1.
    lr_schedule = q.sched.LRSchedule(optim, lr_schedule)

    trainbatch = partial(q.train_batch,
                         on_before_optim_step=[
                             lambda: clipgradnorm(_m=decoder, _norm=gradnorm)
                         ])

    if trainonvalidonly:
        traindl = validdl
        validdl = testdl

    trainepoch = partial(q.train_epoch,
                         model=decoder,
                         dataloader=traindl,
                         optim=optim,
                         losses=tloss,
                         device=device,
                         _train_batch=trainbatch,
                         on_end=[lambda: lr_schedule.step()])

    trainevalepoch = partial(q.test_epoch,
                             model=decoder,
                             losses=tmetrics,
                             dataloader=traindl,
                             device=device)

    on_end_v = [lambda: eyt.on_epoch_end(), lambda: wandb_logger()]
    validepoch = partial(q.test_epoch,
                         model=decoder,
                         losses=vmetrics,
                         dataloader=validdl,
                         device=device,
                         on_end=on_end_v)
    testepoch = partial(q.test_epoch,
                        model=decoder,
                        losses=xmetrics,
                        dataloader=testdl,
                        device=device)

    tt.tick("training")
    if evaltrain:
        validfs = [trainevalepoch, validepoch]
    else:
        validfs = [validepoch]
    validfs = validfs + [testepoch]

    q.run_training(run_train_epoch=trainepoch,
                   run_valid_epoch=validfs,
                   max_epochs=epochs,
                   check_stop=[lambda: eyt.check_stop()],
                   validinter=validinter)
    tt.tock("done training")

    tt.tick("running test before reloading")
    testres = testepoch()
    print(f"Test tree acc: {testres}")
    tt.tock("ran test")

    # Restore the best remembered weights unless early stopping was disabled.
    if eyt.remembered is not None and patience >= 0:
        tt.msg("reloading best")
        decoder.load_state_dict(eyt.remembered)

        tt.tick("rerunning validation")
        validres = validepoch()
        tt.tock(f"Validation results: {validres}")

    tt.tick("running train")
    trainres = trainevalepoch()
    print(f"Train tree acc: {trainres}")
    tt.tock()

    tt.tick("running test")
    testres = testepoch()
    print(f"test tree acc: {testres}")
    tt.tock()

    settings.update({"final_train_loss": tloss[0].get_epoch_error()})
    settings.update({"final_train_tree_acc": tmetrics[0].get_epoch_error()})
    settings.update({"final_valid_tree_acc": vmetrics[0].get_epoch_error()})
    settings.update({"final_test_tree_acc": xmetrics[0].get_epoch_error()})

    wandb.config.update(settings)
    q.pp_dict(settings)

    return decoder, testds