Example #1
    def get_distance_from_activations(self,
                                      gen_acts,
                                      real_acts=None,
                                      eps=1e-6):
        mu1, sigma1 = self.get_activation_stats(gen_acts)
        mu2, sigma2 = self.get_activation_stats(
            real_acts) if real_acts is not None else self.real_stats

        tt = q.ticktock("scorer")
        tt.tick("computing fid")
        # from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py
        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        tt.tock("fid computed")
        return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
                2 * tr_covmean)
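A note on the return value: it is the Fréchet distance between two Gaussians fitted to the generated and real activations, d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrtm(sigma1.dot(sigma2))). A minimal self-contained sketch of the same formula on toy 8-dimensional activations (the data here is made up for illustration):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    # matrix square root of the covariance product; keep only the real part
    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

gen_acts, real_acts = np.random.randn(100, 8), np.random.randn(100, 8)
mu_g, sig_g = gen_acts.mean(0), np.cov(gen_acts, rowvar=False)
mu_r, sig_r = real_acts.mean(0), np.cov(real_acts, rowvar=False)
print(frechet_distance(mu_g, sig_g, mu_r, sig_r))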
Example #2
 def __init__(self,
              generator,
              scorers,
              gendata,
              device=torch.device("cpu"),
              logger=None,
              validinter=1,
              name="main"):
     """
     :param generator:   the generator
     :param scorers:     scorers (FID, IS, Imagesaver)
     :param gendata:     dataloader of data to feed to generator to generate images
     :param device:      device used only for batches (generator/scorers are not set to this device)
     """
     super(GeneratorValidator, self).__init__()
     self.name = name
     self.history = {}
     self.generator = generator
     self.scorers = scorers
     self.gendata = gendata
     self.device = device
     self.tt = q.ticktock("validator-{}".format(name))
     self._iter = 0
     self.logger = logger
     self.validinter = validinter
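Every example in this collection measures time with q.ticktock. Its implementation is not shown here, so the following stand-in is reconstructed purely from the calls these snippets make (tick, tock, msg, live, stoplive); treat the exact qelos behavior as an assumption:

import sys, time

class TickTock(object):
    """Hypothetical stand-in for q.ticktock, inferred from usage only."""

    def __init__(self, prefix="-"):
        self.prefix = prefix
        self._start = None

    def tick(self, msg=None):
        if msg is not None:
            self.msg(msg)
        self._start = time.time()

    def tock(self, msg=""):
        self.msg("{} ({:.3f}s)".format(msg, time.time() - self._start))

    def msg(self, msg):
        print("{}: {}".format(self.prefix, msg))

    def live(self, msg):
        # rewrite the current console line for rolling progress updates
        sys.stdout.write("\r{}: {}".format(self.prefix, msg))
        sys.stdout.flush()

    def stoplive(self):
        sys.stdout.write("\n")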
Example #3
def run(
    lr=0.001,
    batsize=10,
    test=False,
    dopreproc=False,
):
    if dopreproc:
        from qelos_core.scripts.convai.preproc import run as run_preproc
        run_preproc()
    else:
        tt = q.ticktock("script")
        tt.tick("loading data")
        train_dataset, valid_dataset = load_datasets()
        tt.tock("loaded data")
        print(
            "{} unique words, {} training examples, {} valid examples".format(
                len(train_dataset.D), len(train_dataset), len(valid_dataset)))
        trainloader = q.dataload(train_dataset,
                                 shuffle=True,
                                 batch_size=batsize)
        validloader = q.dataload(valid_dataset,
                                 shuffle=True,
                                 batch_size=batsize)
        # test
        if test:
            testexample = train_dataset[10]
            trainloader_iter = iter(trainloader)
            tt.tick("getting 1000 batches")
            for i in range(1000):
                batch = next(trainloader_iter)  # reuse the iterator; iter(trainloader) would restart at batch 0
            tt.tock("got 1000 batches")

        print("done")
def load_data(p="../../../datasets/semparse/", which=None, devfrac=0.1, devfracrandom=False):
    tt = q.ticktock("dataloader")
    tt.tick("loading data")
    assert(which is not None)
    which = {"geo": "geoquery", "atis": "atis", "jobs": "jobs"}[which]
    trainp = os.path.join(p, which, "train.txt")
    testp = os.path.join(p, which, "test.txt")
    devp = os.path.join(p, which, "dev.txt")

    trainlines = open(trainp).readlines()
    testlines = open(testp).readlines()

    if not os.path.exists(devp):
        tt.msg("no dev file, taking {} from training data".format(devfrac))
        splitidx = round(len(trainlines) * devfrac)
        devlines = trainlines[-splitidx:]      # take the dev slice first,
        trainlines = trainlines[:-splitidx]    # then drop it from the training data
    else:
        devlines = open(devp).readlines()

    tt.msg("{} examples in training set".format(len(trainlines)))
    tt.msg("{} examples in dev set".format(len(devlines)))
    tt.msg("{} examples in test set".format(len(testlines)))

    nlsm = q.StringMatrix(freqcutoff=1)
    nlsm.tokenize = lambda x: x.strip().split()
    qlsm = q.StringMatrix(indicate_start_end=True, freqcutoff=1)
    qlsm.tokenize = lambda x: x.strip().split()

    i = 0
    for line in trainlines:
        nl, ql = line.split("\t")
        nlsm.add(nl)
        qlsm.add(ql)
        i += 1

    nlsm.unseen_mode = True
    qlsm.unseen_mode = True

    devstart = i

    for line in devlines:
        nl, ql = line.split("\t")
        nlsm.add(nl)
        qlsm.add(ql)
        i += 1

    teststart = i

    for line in testlines:
        nl, ql = line.split("\t")
        nlsm.add(nl)
        qlsm.add(ql)

    nlsm.finalize()
    qlsm.finalize()
    tt.tock("data loaded")

    return nlsm, qlsm, (devstart, teststart)
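q.StringMatrix acts here as an incremental token-index matrix builder: add() ingests lines through the assigned tokenize function, switching unseen_mode on freezes the vocabulary before dev/test lines are added, and finalize() fixes the index matrix (.matrix) and the word dictionary (.D) used elsewhere in these examples. A hedged usage sketch based only on the calls visible above (the import alias is an assumption):

import qelos_core as q   # assumed import behind the q alias in these scripts

nlsm = q.StringMatrix(freqcutoff=1)            # drop very rare tokens from the vocabulary
nlsm.tokenize = lambda x: x.strip().split()    # plain whitespace tokenization, as above
nlsm.add("what states border texas")           # training rows may grow the vocabulary
nlsm.unseen_mode = True                        # from here on, unknown words no longer extend it
nlsm.add("what rivers cross ohio")
nlsm.finalize()                                # builds nlsm.matrix and the dictionary nlsm.D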
Example #5
def run(lr=20.,
        dropout=0.2,
        dropconnect=0.2,
        gradnorm=0.25,
        epochs=25,
        embdim=200,
        encdim=200,
        numlayers=2,
        seqlen=35,
        batsize=20,
        eval_batsize=10,
        cuda=False,
        gpu=0,
        test=False):
    tt = q.ticktock("script")
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)
    tt.tick("loading data")
    train_batches, valid_batches, test_batches, D = \
        load_data(batsize=batsize, eval_batsize=eval_batsize, seqlen=seqlen)
    tt.tock("data loaded")
    print("{} batches in train".format(len(train_batches)))

    tt.tick("creating model")
    dims = [embdim] + ([encdim] * numlayers)
    m = RNNLayer_LM(*dims, worddic=D, dropout=dropout)

    if test:
        for i, batch in enumerate(train_batches):
            y = m(batch[0])
            if i > 5:
                break
        print(y.size())

    loss = q.SeqKLLoss(time_average=True, size_average=True, mode="logits")
    ppl_loss = q.SeqPPL_Loss(time_average=True,
                             size_average=True,
                             mode="logits")

    optim = torch.optim.SGD(q.params_of(m), lr=lr)
    gradclip = q.ClipGradNorm(gradnorm)

    trainer = q.trainer(m).on(train_batches).loss(loss).optimizer(
        optim).device(device).hook(m).hook(gradclip)
    tester = q.tester(m).on(valid_batches).loss(
        loss, ppl_loss).device(device).hook(m)

    tt.tock("created model")
    tt.tick("training")
    q.train(trainer, tester).run(epochs=epochs)
    tt.tock("trained")
Example #6
 def __init__(self,
              img_shape,
              nhood_size=7,
              nhoods_per_image=128,
              dir_repeats=4,
              dirs_per_repeat=128):
     super(SlicedWassersteinDistance, self).__init__()
     self.impl = SWD_np(img_shape,
                        nhood_size=nhood_size,
                        nhoods_per_image=nhoods_per_image,
                        dir_repeats=dir_repeats,
                        dirs_per_repeat=dirs_per_repeat)
     self.tt = q.ticktock("SWD")
Example #7
 def __init__(self,
              image_shape,
              nhood_size=7,
              nhoods_per_image=128,
              dir_repeats=4,
              dirs_per_repeat=128):
     self.nhood_size = nhood_size
     self.nhoods_per_image = nhoods_per_image
     self.dir_repeats = dir_repeats
     self.dirs_per_repeat = dirs_per_repeat
     self.resolutions = []
     res = image_shape[1]
     while res >= 16:
         self.resolutions.append(res)
         res //= 2
     self.tt = q.ticktock()
Example #8
    def get_scores_from_probs(self, allprobs):
        tt = q.ticktock("scorer")

        tt.tick("calculating scores")

        scores = []
        splits = self.splits
        for i in range(splits):
            lo = i * allprobs.shape[0] // splits
            hi = (i + 1) * allprobs.shape[0] // splits
            part = allprobs[lo:hi, :]
            part_means = np.expand_dims(np.mean(part, 0), 0)
            kl = part * (np.log(part) - np.log(part_means))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))

        tt.tock("calculated scores")
        return np.mean(scores), np.std(scores)
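This loop is the standard Inception Score: per split, IS = exp(E_x[KL(p(y|x) || p(y))]), with p(y) approximated by the split mean. A compact sketch of one split on made-up class probabilities:

import numpy as np

part = np.random.dirichlet(np.ones(10), size=500)    # 500 rows of p(y|x) over 10 classes
part_means = np.expand_dims(np.mean(part, 0), 0)     # marginal p(y)
kl = part * (np.log(part) - np.log(part_means))
print(np.exp(np.mean(np.sum(kl, 1))))                # score for this split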
Example #9
def load_data(p="../../datasets/simplequestions/"):
    tt = q.ticktock("dataloader")
    tt.tick("loading")
    questions, subjects, subject_names, relations, spans, (start_valid, start_test) \
        = load_questions(p)
    generate_candidates(p)
    tt.tock("{} questions loaded".format(len(questions)))

    tt.tick("generating matrices")
    qsm = q.StringMatrix(freqcutoff=2)
    qsm.tokenize = lambda x: x.split()
    for question in tqdm.tqdm(questions[:start_valid]):
        qsm.add(question)
    qsm.unseen_mode = True
    for question in tqdm.tqdm(questions[start_valid:]):
        qsm.add(question)
    tt.msg("finalizing")
    qsm.finalize()
    print(qsm[0])
    q.embed()
    tt.tock("matrices generated")
Example #10
 def get_inception_outs(self, data):  # dataloader
     tt = q.ticktock("inception")
     tt.tick("running data through network")
     probses = []
     activationses = []
     for i, batch in enumerate(data):
         batch = (batch, ) if not q.issequence(batch) else batch
         batch = [
             torch.as_tensor(batch_e).to(self.device)  # as_tensor avoids copying tensors that already exist
             for batch_e in batch
         ]
         probs, activations = self.inception(*batch)
         probs = torch.nn.functional.softmax(probs, dim=1)  # explicit class dim (implicit dim is deprecated)
         probses.append(probs.detach())
         activationses.append(activations.detach())
         tt.live("{}/{}".format(i, len(data)))
     tt.stoplive()
     tt.tock("done")
     probses = torch.cat(probses, 0)
     activationses = torch.cat(activationses, 0)
     return (probses.cpu().detach().numpy(),
             activationses.cpu().detach().numpy())
Example #11
    def do_epoch(self, tt=q.ticktock("-")):
        self.stop_training = self.current_epoch + 1 == self.max_epochs
        self.losses.push_and_reset(epoch=self.current_epoch - 1)
        # tt.tick()
        self.do_callbacks(self.START_EPOCH)
        self.do_callbacks(self.START_TRAIN)

        for i, _batch in enumerate(self.dataloader):
            ttmsg = self.do_batch(_batch, i=i)
            tt.live(ttmsg)

        tt.stoplive()
        self.do_callbacks(self.END_TRAIN)
        ttmsg = "Epoch {}/{} -- train: {}"\
            .format(
                self.current_epoch+1,
                self.max_epochs,
                self.losses.pp()
            )
        # tt.tock(ttmsg)
        self.do_callbacks(self.END_EPOCH)
        self.current_epoch += 1
        return ttmsg
Example #12
 def runloop(self, validinter=1, print_on_valid_only=False):
     tt = q.ticktock("runner")
     self.do_callbacks(self.START)
     validinter_count = 0
     while self.trainer.stop_training is not True:
         tt.tick()
         self.do_callbacks(self.START_EPOCH)
         self.do_callbacks(self.START_TRAIN)
         self.trainer.do_epoch()
         ttmsg = "Epoch {}/{} -- train: {}" \
             .format(
             self.trainer.current_epoch,
             self.trainer.max_epochs,
             self.trainer.losses.pp()
         )
         self.do_callbacks(self.END_TRAIN)
         validepoch = False
         if self.validator is not None and validinter_count % validinter == 0:
             self.do_callbacks(self.START_VALID)
             if isinstance(self.validator, tester):
                 self.validator.do_epoch(self.trainer.current_epoch,
                                         self.trainer.max_epochs)
                 ttmsg += " -- {}" \
                     .format(self.validator.losses.pp())
             else:
                 toprint = self.validator()
                 ttmsg += " -- {}".format(toprint)
             self.do_callbacks(self.END_VALID)
             validepoch = True
         self.do_callbacks(self.END_EPOCH)
         validinter_count += 1
         if not print_on_valid_only or validepoch:
             tt.tock(ttmsg)
             if self._logger is not None:
                 self._logger.liner_write("losses.txt", ttmsg)
     self.do_callbacks(self.END)
Example #13
def run(lr=0.0001,
        batsize=64,
        epochs=100000,
        lamda=10,
        disciters=5,
        burnin=-1,
        validinter=1000,
        devinter=100,
        cuda=False,
        gpu=0,
        z_dim=128,
        test=False,
        dim_d=128,
        dim_g=128,
        vggversion=13,
        vgglayer=9,
        vggvanilla=False,           # if True, makes trainable feature transform
        extralayers=False,          # adds a couple extra res blocks to generator to match added VGG
        pixelpenalty=False,         # if True, uses penalty based on pixel-wise interpolate
        inceptionpath="/data/lukovnik/",
        normalwgan=False,
        ):
    # vggvanilla=True and pixelpenalty=True makes a normal WGAN

    settings = locals().copy()
    logger = q.log.Logger(prefix="wgan_resnet_cifar_feat")
    logger.save_settings(**settings)

    burnin = disciters if burnin == -1 else burnin

    if test:
        validinter=10
        burnin=1
        batsize=2
        devinter = 1

    tt = q.ticktock("script")

    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("creating networks")
    if not normalwgan:
        print("doing wgan-feat")
        gen = OldGenerator(z_dim, dim_g, extra_layers=extralayers).to(device)
        inpd = get_vgg_outdim(vggversion, vgglayer)
        crit = ReducedDiscriminator(inpd, dim_d).to(device)
        subvgg = SubVGG(vggversion, vgglayer, pretrained=not vggvanilla)
    else:
        print("doing normal wgan")
        gen = OldGenerator(z_dim, dim_g, extra_layers=False).to(device)
        crit = OldDiscriminator(dim_d).to(device)
        subvgg = None
    tt.tock("created networks")

    # test
    # z = torch.randn(3, z_dim).to(device)
    # x = gen(z)
    # s = crit(x)

    # data
    # load cifar
    tt.tick("loading data")
    traincifar, testcifar = load_cifar_dataset(train=True), load_cifar_dataset(train=False)
    print(len(traincifar), len(testcifar))

    gen_data_d = q.gan.gauss_dataset(z_dim, len(traincifar))
    disc_data = q.datacat([traincifar, gen_data_d], 1)

    gen_data = q.gan.gauss_dataset(z_dim)
    gen_data_valid = q.gan.gauss_dataset(z_dim, 50000)

    swd_gen_data = q.gan.gauss_dataset(z_dim, 10000)
    swd_real_data = []
    swd_shape = traincifar[0].size()
    for i in range(10000):
        swd_real_data.append(testcifar[i])
    swd_real_data = torch.stack(swd_real_data, 0)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)
    gen_data_valid = q.dataload(gen_data_valid, batch_size=batsize, shuffle=False)
    validcifar_loader = q.dataload(testcifar, batch_size=batsize, shuffle=False)

    swd_batsize = 64
    swd_gen_data = q.dataload(swd_gen_data, batch_size=swd_batsize, shuffle=False)
    swd_real_data = q.dataload(swd_real_data, batch_size=swd_batsize, shuffle=False)

    dev_data_gauss = q.gan.gauss_dataset(z_dim, len(testcifar))
    dev_disc_data = q.datacat([testcifar, dev_data_gauss], 1)
    dev_disc_data = q.dataload(dev_disc_data, batch_size=batsize, shuffle=False)
    # q.embed()
    tt.tock("loaded data")

    if not normalwgan:
        disc_model = q.gan.WGAN_F(crit, gen, subvgg, lamda=lamda, pixel_penalty=pixelpenalty).disc_train()
        gen_model = q.gan.WGAN_F(crit, gen, subvgg, lamda=lamda, pixel_penalty=pixelpenalty).gen_train()
    else:
        disc_model = q.gan.WGAN(crit, gen, lamda=lamda).disc_train()
        gen_model = q.gan.WGAN(crit, gen, lamda=lamda).gen_train()

    disc_params = q.params_of(crit)
    if vggvanilla and not normalwgan:
        disc_params += q.params_of(subvgg)
    disc_optim = torch.optim.Adam(disc_params, lr=lr, betas=(0.5, 0.9))
    gen_optim = torch.optim.Adam(q.params_of(gen), lr=lr, betas=(0.5, 0.9))

    disc_bt = UnquantizeTransform()

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(1).device(device)

    # fidandis = q.gan.FIDandIS(device=device)
    tfis = q.gan.tfIS(inception_path=inceptionpath, gpu=gpu)
    # if not test:
    #     fidandis.set_real_stats_with(validcifar_loader)
    saver = q.gan.GenDataSaver(logger, "saved.npz")
    generator_validator = q.gan.GeneratorValidator(gen, [tfis, saver], gen_data_valid, device=device,
                                         logger=logger, validinter=validinter)

    train_validator = q.tester(disc_model).on(dev_disc_data).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))

    train_validator.validinter = devinter

    tt.tick("initializing SWD")
    swd = q.gan.SlicedWassersteinDistance(swd_shape)
    swd.prepare_reals(swd_real_data)
    tt.tock("SWD initialized")

    swd_validator = q.gan.GeneratorValidator(gen, [swd], swd_gen_data, device=device,
                                             logger=logger, validinter=validinter, name="swd")

    tt.tick("training")
    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer,
                                   validators=(generator_validator, train_validator, swd_validator),
                                   lr_decay=True)

    gan_trainer.run(epochs, disciters=disciters, geniters=1, burnin=burnin)
    tt.tock("trained")
Example #14
def load_jsons(datap="../../../datasets/lcquad/newdata.json",
               relp="../../../datasets/lcquad/nrelations.json",
               mode="flat"):
    tt = q.ticktock("data loader")
    tt.tick("loading jsons")

    data = json.load(open(datap))
    rels = json.load(open(relp))

    tt.tock("jsons loaded")

    tt.tick("extracting data")
    questions = []
    goldchains = []
    badchains = []
    for dataitem in data:
        questions.append(dataitem["parsed-data"]["corrected_question"])
        goldchain = []
        for x in dataitem["parsed-data"]["path_id"]:
            goldchain += [x[0], int(x[1:])]
        goldchains.append(goldchain)
        badchainses = []
        goldfound = False
        for badchain in (dataitem["uri"]["hop-1-properties"]
                         + dataitem["uri"]["hop-2-properties"]):
            if goldchain == badchain:
                goldfound = True
            else:
                if len(badchain) == 2:
                    badchain += [-1, -1]
                badchainses.append(badchain)
        badchains.append(badchainses)

    tt.tock("extracted data")

    tt.msg("mode: {}".format(mode))

    if mode == "flat":
        tt.tick("flattening")

        def flatten_chain(chainspec):
            flatchainspec = []
            for x in chainspec:
                if x in ("+", "-"):
                    flatchainspec.append(x)
                elif x > -1:
                    relwords = rels[str(x)]
                    flatchainspec += relwords
                elif x == -1:
                    pass
                else:
                    raise q.SumTingWongException("unexpected symbol in chain")
            return " ".join(flatchainspec)

        goldchainids = []
        badchainsids = []

        uniquechainids = {}

        qsm = q.StringMatrix()
        csm = q.StringMatrix()
        csm.tokenize = lambda x: x.lower().strip().split()

        def get_ensure_chainid(flatchain):
            if flatchain not in uniquechainids:
                uniquechainids[flatchain] = len(uniquechainids)
                csm.add(flatchain)
                assert (len(csm) == len(uniquechainids))
            return uniquechainids[flatchain]

        eid = 0
        numchains = 0
        for question, goldchain, badchainses in zip(questions, goldchains,
                                                    badchains):
            qsm.add(question)
            # flatten gold chain
            flatgoldchain = flatten_chain(goldchain)
            chainid = get_ensure_chainid(flatgoldchain)
            goldchainids.append(chainid)
            badchainsids.append([])
            numchains += 1
            for badchain in badchainses:
                flatbadchain = flatten_chain(badchain)
                chainid = get_ensure_chainid(flatbadchain)
                badchainsids[eid].append(chainid)
                numchains += 1
            eid += 1
            tt.live("{}".format(eid))

        assert (len(badchainsids) == len(questions))
        tt.stoplive()
        tt.msg("{} unique chains from {} total".format(len(csm), numchains))
        qsm.finalize()
        csm.finalize()
        tt.tock("flattened")
        csm.tokenize = None
        return qsm, csm, goldchainids, badchainsids
    else:
        raise q.SumTingWongException("unsupported mode: {}".format(mode))
Example #15
def run(
    lr=0.0001,
    batsize=64,
    epochs=100000,
    lamda=10,
    disciters=5,
    burnin=-1,
    validinter=1000,
    devinter=100,
    cuda=False,
    gpu=0,
    z_dim=128,
    test=False,
    dim_d=128,
    dim_g=128,
):

    settings = locals().copy()
    logger = q.log.Logger(prefix="wgan_resnet_cifar")
    logger.save_settings(**settings)
    print("started")

    burnin = disciters if burnin == -1 else burnin

    if test:
        validinter = 10
        burnin = 1
        batsize = 2
        devinter = 1

    tt = q.ticktock("script")

    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("creating networks")
    gen = OldGenerator(z_dim, dim_g).to(device)
    crit = OldDiscriminator(dim_d).to(device)
    tt.tock("created networks")

    # test
    # z = torch.randn(3, z_dim).to(device)
    # x = gen(z)
    # s = crit(x)

    # data
    # load cifar
    tt.tick("loading data")
    traincifar, testcifar = load_cifar_dataset(train=True), load_cifar_dataset(
        train=False)
    print(len(traincifar))

    gen_data_d = q.gan.gauss_dataset(z_dim, len(traincifar))
    disc_data = q.datacat([traincifar, gen_data_d], 1)

    gen_data = q.gan.gauss_dataset(z_dim)
    gen_data_valid = q.gan.gauss_dataset(z_dim, 50000)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)
    gen_data_valid = q.dataload(gen_data_valid,
                                batch_size=batsize,
                                shuffle=False)
    validcifar_loader = q.dataload(testcifar,
                                   batch_size=batsize,
                                   shuffle=False)

    dev_data_gauss = q.gan.gauss_dataset(z_dim, len(testcifar))
    dev_disc_data = q.datacat([testcifar, dev_data_gauss], 1)
    dev_disc_data = q.dataload(dev_disc_data,
                               batch_size=batsize,
                               shuffle=False)
    # q.embed()
    tt.tock("loaded data")

    disc_model = q.gan.WGAN(crit, gen, lamda=lamda).disc_train()
    gen_model = q.gan.WGAN(crit, gen, lamda=lamda).gen_train()

    disc_optim = torch.optim.Adam(q.params_of(crit), lr=lr, betas=(0.5, 0.9))
    gen_optim = torch.optim.Adam(q.params_of(gen), lr=lr, betas=(0.5, 0.9))

    disc_bt = UnquantizeTransform()

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(
        1).device(device)

    fidandis = q.gan.FIDandIS(device=device)
    if not test:
        fidandis.set_real_stats_with(validcifar_loader)
    saver = q.gan.GenDataSaver(logger, "saved.npz")
    generator_validator = q.gan.GeneratorValidator(gen, [fidandis, saver],
                                                   gen_data_valid,
                                                   device=device,
                                                   logger=logger,
                                                   validinter=validinter)

    train_validator = q.tester(disc_model).on(dev_disc_data).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))

    train_validator.validinter = devinter

    tt.tick("training")
    gan_trainer = q.gan.GANTrainer(disc_trainer,
                                   gen_trainer,
                                   validators=(generator_validator,
                                               train_validator),
                                   lr_decay=True)

    gan_trainer.run(epochs, disciters=disciters, geniters=1, burnin=burnin)
    tt.tock("trained")
Example #16
def run(lr=OPT_LR,
        batsize=100,
        epochs=1000,
        validinter=20,
        wreg=0.00000000001,
        dropout=0.1,
        embdim=50,
        encdim=50,
        numlayers=1,
        cuda=False,
        gpu=0,
        mode="flat",
        test=False,
        gendata=False):
    if gendata:
        loadret = load_jsons()
        pickle.dump(loadret,
                    open("loadcache.flat.pkl", "wb"),  # pickle requires a binary-mode file
                    protocol=pickle.HIGHEST_PROTOCOL)
    else:
        settings = locals().copy()
        logger = q.Logger(prefix="rank_lstm")
        logger.save_settings(**settings)

        device = torch.device("cpu")
        if cuda:
            device = torch.device("cuda", gpu)

        tt = q.ticktock("script")

        # region DATA
        tt.tick("loading data")
        qsm, csm, goldchainids, badchainids = pickle.load(
            open("loadcache.{}.pkl".format(mode), "rb"))  # binary mode, matching the dump above
        eids = np.arange(0, len(goldchainids))

        data = [qsm.matrix, eids]
        traindata, validdata = q.datasplit(data, splits=(7, 3), random=False)
        validdata, testdata = q.datasplit(validdata,
                                          splits=(1, 2),
                                          random=False)

        trainloader = q.dataload(*traindata, batch_size=batsize, shuffle=True)

        input_feeder = FlatInpFeeder(csm.matrix, goldchainids, badchainids)

        def inp_bt(_qsm_batch, _eids_batch):
            golds_batch, bads_batch = input_feeder(_eids_batch)
            return _qsm_batch, golds_batch, bads_batch

        if test:
            # test input feeder
            eids = q.var(torch.arange(0, 10).long()).v
            _test_golds_batch, _test_bads_batch = input_feeder(eids)
        tt.tock("data loaded")
        # endregion

        # region MODEL
        dims = [encdim // 2] * numlayers

        question_encoder = FlatEncoder(embdim, dims, qsm.D, bidir=True)
        query_encoder = FlatEncoder(embdim, dims, csm.D, bidir=True)
        similarity = DotDistance()

        rankmodel = RankModel(question_encoder, query_encoder, similarity)
        scoremodel = ScoreModel(question_encoder, query_encoder, similarity)
        # endregion

        # region VALIDATION
        rankcomp = RankingComputer(scoremodel, validdata[1], validdata[0],
                                   csm.matrix, goldchainids, badchainids)
        # endregion

        # region TRAINING
        optim = torch.optim.Adam(q.params_of(rankmodel),
                                 lr=lr,
                                 weight_decay=wreg)
        trainer = q.trainer(rankmodel).on(trainloader).loss(1)\
                   .set_batch_transformer(inp_bt).optimizer(optim).device(device)

        def validation_function():
            rankmetrics = rankcomp.compute(RecallAt(1, totaltrue=1),
                                           RecallAt(5, totaltrue=1), MRR())
            ret = []
            for rankmetric in rankmetrics:
                rankmetric = np.asarray(rankmetric)
                ret_i = rankmetric.mean()
                ret.append(ret_i)
            return "valid: " + " - ".join(["{:.4f}".format(x) for x in ret])

        q.train(trainer, validation_function).run(epochs,
                                                  validinter=validinter)
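The validation function aggregates standard ranking metrics: with one gold chain per question (totaltrue=1), RecallAt(k) is 1 when the gold chain ranks in the top k, and MRR is the reciprocal of its rank. A small numpy sketch on toy scores (not the qelos implementations):

import numpy as np

scores = np.array([0.1, 0.9, 0.3, 0.7])   # model scores for four candidate chains
gold = 2                                  # index of the gold chain
order = np.argsort(-scores)               # candidates sorted best-first
rank = int(np.where(order == gold)[0][0]) + 1
print(float(rank <= 1), float(rank <= 5), 1.0 / rank)   # recall@1, recall@5, MRR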
Example #17
    def loadvalue(self,
                  path,
                  dim,
                  indim=None,
                  worddic=None,
                  maskid=True,
                  rareid=True):
        # TODO: nonstandard mask and rareid?
        tt = q.ticktock(self.__class__.__name__)
        tt.tick()
        # load weights
        if path not in self.loadcache:
            W = np.load(path + ".npy")
        else:
            W = self.loadcache[path][0]
        tt.tock("vectors loaded")

        # load words
        tt.tick()
        if path not in self.loadcache:
            words = json.load(open(path + ".words"))
        else:
            words = self.loadcache[path][1]
        tt.tock("words loaded")

        # cache
        if self.useloadcache:
            self.loadcache[path] = (W, words)

        # select
        if indim is not None:
            W = W[:indim, :]

        if rareid:
            W = np.concatenate([np.zeros_like(W[0, :])[np.newaxis, :], W],
                               axis=0)
        if maskid:
            W = np.concatenate([np.zeros_like(W[0, :])[np.newaxis, :], W],
                               axis=0)

        tt.tick()

        # dictionary
        D = OrderedDict()
        i = 0
        if maskid:    # mirror the truthiness checks used when prepending the zero rows above
            D[self.masktoken] = i
            i += 1
        if rareid:
            D[self.raretoken] = i
            i += 1
        wordset = set(words)
        for j, word in enumerate(words):
            if indim is not None and j >= indim:
                break
            if word.lower() not in wordset and self.trylowercase:
                word = word.lower()
            D[word] = i
            i += 1
        tt.tock("dictionary created")

        if worddic is not None:
            vocsize = max(worddic.values()) + 1
            new_weight = np.zeros((vocsize, W.shape[1]), dtype=W.dtype)
            new_dic = {}
            for k, v in worddic.items():
                if k in D:
                    new_weight[v, :] = W[D[k], :]
                    new_dic[k] = v

            W = new_weight
            D = new_dic

        return W, D
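The final block remaps the loaded vectors onto a caller-supplied vocabulary: each word of worddic that exists in the loaded dictionary D copies its row into the slot worddic assigns, and words without a pretrained vector stay zero. The same remapping in isolation, on toy vocabularies:

import numpy as np

W = np.arange(12, dtype=np.float32).reshape(4, 3)     # four loaded vectors of dim 3
D = {"<MASK>": 0, "<RARE>": 1, "cat": 2, "dog": 3}    # loaded dictionary
worddic = {"dog": 0, "cat": 1, "fish": 2}             # target vocabulary

new_weight = np.zeros((max(worddic.values()) + 1, W.shape[1]), dtype=W.dtype)
for k, v in worddic.items():
    if k in D:                                        # "fish" keeps an all-zero row
        new_weight[v, :] = W[D[k], :]
print(new_weight)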
Example #18
    def runloop(self, iters, disciters=1, geniters=1, burnin=10):
        tt = q.ticktock("gan runner")
        self.do_callbacks(self.START)
        current_iter = 0
        disc_batch_iter = self.disc_trainer.inf_batches(with_info=False)
        gen_batch_iter = self.gen_trainer.inf_batches(with_info=False)

        lr_decay_disc, lr_decay_gen = None, None
        if self.lr_decay:
            lr_decay_disc = torch.optim.lr_scheduler.LambdaLR(
                self.disc_trainer.optim,
                lr_lambda=lambda ep: max(0, 1. - ep * 1. / iters))
            lr_decay_gen = torch.optim.lr_scheduler.LambdaLR(
                self.gen_trainer.optim,
                lr_lambda=lambda ep: max(0, 1. - ep * 1. / iters))

        while self.stop_training is not True:
            tt.tick()
            self.do_callbacks(self.START_EPOCH)
            self.do_callbacks(self.START_TRAIN)
            self.do_callbacks(self.START_DISC)

            if lr_decay_disc is not None:
                lr_decay_disc.step()
            if lr_decay_gen is not None:
                lr_decay_gen.step()

            _disciters = burnin if current_iter == 0 else disciters

            for disc_iter in range(_disciters):  # discriminator iterations
                batch = next(disc_batch_iter)
                self.disc_trainer.do_batch(batch)
                ttmsg = "iter {}/{} - disc: {}/{} :: {}".format(
                    current_iter, iters, disc_iter + 1, _disciters,
                    self.disc_trainer.losses.pp())
                tt.live(ttmsg)
            tt.stoplive()
            self.do_callbacks(self.END_DISC)
            self.do_callbacks(self.START_GEN)
            for gen_iter in range(geniters):  # generator iterations
                batch = next(gen_batch_iter)
                self.gen_trainer.do_batch(batch)
                ttmsg = "iter {}/{} - gen: {}/{} :: {}".format(
                    current_iter, iters, gen_iter + 1, geniters,
                    self.gen_trainer.losses.pp())
                tt.live(ttmsg)
            tt.stoplive()
            self.do_callbacks(self.END_GEN)
            ttmsg = "iter {}/{} - disc: {} - gen: {}".format(
                current_iter, iters, self.disc_trainer.losses.pp(),
                self.gen_trainer.losses.pp())
            self.disc_trainer.losses.push_and_reset()
            self.gen_trainer.losses.push_and_reset()
            self.do_callbacks(self.END_TRAIN)

            if self.validators is not None:
                for validator in self.validators:
                    if current_iter % validator.validinter == 0:
                        self.do_callbacks(self.START_VALID)
                        if isinstance(validator, q.tester):
                            validator.do_epoch()
                            ttmsg += " -- {}".format(validator.losses.pp())
                        else:
                            toprint = validator(iter=current_iter)
                            ttmsg += " -- {}".format(toprint)
                        self.do_callbacks(self.END_VALID)
            self.do_callbacks(self.END_EPOCH)
            tt.tock(ttmsg)
            current_iter += 1
            self.stop_training = current_iter >= iters
        self.do_callbacks(self.END)
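With lr_decay enabled, both optimizers are wrapped in LambdaLR schedulers that scale the base learning rate by max(0, 1 - ep/iters), a linear anneal to zero over the run. A standalone sketch of the same schedule:

import torch

iters = 10
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.SGD([param], lr=1.0)
sched = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda ep: max(0, 1. - ep * 1. / iters))
for it in range(iters):
    print(it, optim.param_groups[0]["lr"])   # 1.0, 0.9, ... shrinking linearly toward 0
    optim.step()
    sched.step()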