def run(lr=0.001,
        batsize=10,
        test=False,
        dopreproc=False,
        ):
    if dopreproc:
        from qelos_core.scripts.convai.preproc import run as run_preproc
        run_preproc()
    else:
        tt = q.ticktock("script")
        tt.tick("loading data")
        train_dataset, valid_dataset = load_datasets()
        tt.tock("loaded data")
        print("{} unique words, {} training examples, {} valid examples".format(
            len(train_dataset.D), len(train_dataset), len(valid_dataset)))
        trainloader = q.dataload(train_dataset, shuffle=True, batch_size=batsize)
        validloader = q.dataload(valid_dataset, shuffle=True, batch_size=batsize)
        # test
        if test:
            testexample = train_dataset[10]
            trainloader_iter = iter(trainloader)
            tt.tick("getting 1000 batches")
            for i in range(1000):
                # advance the single iterator created above
                # (calling next(iter(trainloader)) each step would refetch the first batch forever)
                batch = next(trainloader_iter)
            tt.tock("got 1000 batches")
        print("done")
def run(lr=0.001):
    x = np.random.random((1000, 5)).astype("float32")
    y = np.random.randint(0, 5, (1000,)).astype("int64")
    trainloader = q.dataload(x[:800], y[:800], batch_size=100)
    validloader = q.dataload(x[800:], y[800:], batch_size=100)
    m = torch.nn.Sequential(torch.nn.Linear(5, 100),
                            torch.nn.Linear(100, 5))
    m[1].weight.requires_grad = False
    losses = q.lossarray(torch.nn.CrossEntropyLoss())
    params = m.parameters()
    for param in params:
        print(param.requires_grad)
    # copy the frozen weight; .detach().numpy() alone shares memory with the tensor,
    # which would make the norm below zero even if the weight did change
    init_val = m[1].weight.detach().numpy().copy()
    optim = torch.optim.Adam(q.params_of(m), lr=lr)
    trainer = q.trainer(m).on(trainloader).loss(losses).optimizer(optim).epochs(100)
    # for b, (i, e) in trainer.inf_batches():
    #     print(i, e)
    validator = q.tester(m).on(validloader).loss(losses)
    q.train(trainer, validator).run()
    new_val = m[1].weight.detach().numpy()
    print(np.linalg.norm(new_val - init_val))
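# Minimal entry-point sketch for running the script above from the command line.
# This assumes qelos exposes q.argprun (used in other qelos scripts to map CLI flags
# onto keyword arguments); if it does not, calling run() directly works as well.
if __name__ == "__main__":
    q.argprun(run)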
def run_toy(lr=0.001,
            seqlen=8,
            batsize=10,
            epochs=1000,
            embdim=32,
            innerdim=64,
            z_dim=32,
            noaccumulate=False,
            usebase=False,
            ):
    # generate some toy data
    N = 1000
    data, vocab = gen_toy_data(N, seqlen=seqlen, mode="copymiddlefixed")
    datasm = q.StringMatrix()
    datasm.set_dictionary(vocab)
    datasm.tokenize = lambda x: list(x)
    for data_e in data:
        datasm.add(data_e)
    datasm.finalize()

    real_data = q.dataset(datasm.matrix)
    gen_data_d = q.gan.gauss_dataset(z_dim, len(real_data))
    disc_data = q.datacat([real_data, gen_data_d], 1)
    gen_data = q.gan.gauss_dataset(z_dim)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)

    discriminator = Discriminator(datasm.D, embdim, innerdim)
    generator = Decoder(datasm.D, embdim, z_dim, "<START>", innerdim, maxtime=seqlen)

    SeqGAN = SeqGAN_Base if usebase else SeqGAN_DCL

    disc_model = SeqGAN(discriminator, generator, gan_mode=q.gan.GAN.DISC_TRAIN, accumulate=not noaccumulate)
    gen_model = SeqGAN(discriminator, generator, gan_mode=q.gan.GAN.GEN_TRAIN, accumulate=not noaccumulate)

    disc_optim = torch.optim.Adam(q.params_of(discriminator), lr=lr)
    gen_optim = torch.optim.Adam(q.params_of(generator), lr=lr)

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(q.no_losses(2))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(q.no_losses(2))

    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer)
    gan_trainer.run(epochs, disciters=5, geniters=1, burnin=500)

    # print some predictions:
    with torch.no_grad():
        rvocab = {v: k for k, v in vocab.items()}
        q.batch_reset(generator)
        eval_z = torch.randn(50, z_dim)
        eval_y, _ = generator(eval_z)
        for i in range(len(eval_y)):
            prow = "".join([rvocab[mij] for mij in eval_y[i].numpy()])
            print(prow)

    print("done")
def run_cond_toy(lr=0.001, seqlen=8, batsize=10, epochs=1000, embdim=5, innerdim=32, z_dim=5,
                 usebase=False, nrexamples=1000):
    data, vocab = gen_toy_data(nrexamples, seqlen=seqlen, mode="twointerleaveboth")
    datasm = q.StringMatrix()
    datasm.set_dictionary(vocab)
    datasm.tokenize = lambda x: list(x)
    for data_e in data:
        datasm.add(data_e)
    datasm.finalize()

    real_data = q.dataset(datasm.matrix)
    shuffled_datasm_matrix = datasm.matrix + 0
    np.random.shuffle(shuffled_datasm_matrix)
    fake_data = q.dataset(shuffled_datasm_matrix)
    disc_data = q.datacat([real_data, fake_data], 1)
    gen_data = q.dataset(datasm.matrix)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)

    discr = Discriminator(datasm.D, embdim, innerdim)
    decoder = Decoder_Cond(datasm.D, embdim, z_dim, "<START>", innerdim)

    disc_model = SeqGAN_Cond(discr, decoder, gan_mode=q.gan.GAN.DISC_TRAIN)
    gen_model = SeqGAN_Cond(discr, decoder, gan_mode=q.gan.GAN.GEN_TRAIN)

    disc_optim = torch.optim.Adam(q.params_of(discr), lr=lr)
    gen_optim = torch.optim.Adam(q.params_of(decoder), lr=lr)

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(q.no_losses(2))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(q.no_losses(2))

    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer)
    gan_trainer.run(epochs, disciters=5, geniters=1, burnin=500)

    with torch.no_grad():
        rvocab = {v: k for k, v in vocab.items()}
        q.batch_reset(decoder)
        eval_z = torch.tensor(datasm.matrix[:50])
        eval_y, _, _, _ = decoder(eval_z)
        for i in range(len(eval_y)):
            prow = "".join([rvocab[mij] for mij in eval_y[i].numpy()])
            print(prow)

    print("done")
def __call__(self, iter=None):
    iter = self._iter if iter is None else iter
    self.generator.eval()
    with torch.no_grad():
        # collect generated images
        generated = []
        self.tt.tick("running generator")
        for i, batch in enumerate(self.gendata):
            batch = (batch,) if not q.issequence(batch) else batch
            batch = [torch.tensor(batch_e).to(self.device) for batch_e in batch]
            _gen = self.generator(*batch)
            # unwrap sequence outputs before detaching (.detach()/.cpu() are tensor methods)
            _gen = _gen[0] if q.issequence(_gen) else _gen
            _gen = _gen.detach().cpu()
            generated.append(_gen)
            self.tt.live("{}/{}".format(i, len(self.gendata)))
        batsize = max(map(len, generated))
        generated = torch.cat(generated, 0)
        self.tt.tock("generated data")
        gen_loaded = q.dataload(generated, batch_size=batsize, shuffle=False)
        rets = [iter]
        for scorer in self.scorers:
            ret = scorer(gen_loaded)
            if ret is not None:
                rets.append(ret)
    if self.logger is not None:
        self.logger.liner_write("validator-{}.txt".format(self.name),
                                " ".join(map(str, rets)))
    self._iter += 1
    return " ".join(map(str, rets[1:]))
def tst_inception_cifar10(cuda=False, gpu=1, batsize=32):
    # assumes torchvision is imported at module level as:
    #   import torchvision.datasets as dset
    #   import torchvision.transforms as transforms
    class IgnoreLabelDataset(torch.utils.data.Dataset):
        def __init__(self, orig):
            self.orig = orig

        def __getitem__(self, index):
            return self.orig[index][0]

        def __len__(self):
            return len(self.orig)

    cifar = dset.CIFAR10(root='../datasets/cifar/', download=True, train=True,
                         transform=transforms.Compose([
                             transforms.Resize(32),    # transforms.Scale is deprecated in favour of Resize
                             transforms.ToTensor(),
                             transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                  std=[0.5, 0.5, 0.5])
                         ]))

    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)
    print(device, cuda)

    cifar = IgnoreLabelDataset(cifar)
    cifar_loader = q.dataload(cifar, batch_size=batsize)

    scorer = q.gan.FIDandIS(device=device)
    print(scorer.inception.training)
    scorer.set_real_stats_with(cifar_loader)

    print("Calculating FID and IS ... ")
    scores = scorer.get_scores(cifar_loader)
    print(scores)
def run(lr=0.001):
    # data
    x = torch.randn(1000, 5, 5)
    real_data = q.dataset(x)
    gen_data_d = q.gan.gauss_dataset(10, len(real_data))
    disc_data = q.datacat([real_data, gen_data_d], 1)
    gen_data = q.gan.gauss_dataset(10)

    disc_data = q.dataload(disc_data, batch_size=20, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=20, shuffle=True)

    next(iter(disc_data))

    # models
    class Generator(torch.nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.lin1 = torch.nn.Linear(10, 20)
            self.lin2 = torch.nn.Linear(20, 25)

        def forward(self, z):
            ret = self.lin1(z)
            ret = torch.sigmoid(ret)    # torch.nn.functional.sigmoid is deprecated
            ret = self.lin2(ret)
            ret = torch.sigmoid(ret)
            ret = ret.view(z.size(0), 5, 5)
            return ret

    class Discriminator(torch.nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.lin1 = torch.nn.Linear(25, 20)
            self.lin2 = torch.nn.Linear(20, 10)
            self.lin3 = torch.nn.Linear(10, 1)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            ret = self.lin1(x)
            ret = torch.sigmoid(ret)
            ret = self.lin2(ret)
            ret = torch.sigmoid(ret)
            ret = self.lin3(ret)
            ret = torch.sigmoid(ret)
            ret = ret.squeeze(1)
            return ret

    discriminator = Discriminator()
    generator = Generator()

    disc_model = q.gan.GAN(discriminator, generator, gan_mode=q.gan.GAN.DISC_TRAIN)
    gen_model = q.gan.GAN(discriminator, generator, gan_mode=q.gan.GAN.GEN_TRAIN)

    disc_optim = torch.optim.Adam(q.params_of(discriminator), lr=lr)
    gen_optim = torch.optim.Adam(q.params_of(generator), lr=lr)

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(q.no_losses(1))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(q.no_losses(1))

    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer)
    gan_trainer.run(50, disciters=10, geniters=3)
def run_classify(lr=0.001, seqlen=6, numex=500, epochs=25, batsize=10, test=True, cuda=False, gpu=0):
    device = torch.device("cpu")
    if cuda:
        device = torch.device("cuda", gpu)

    # region construct data
    colors = "red blue green magenta cyan orange yellow grey salmon pink purple teal".split()
    D = dict(zip(colors, range(len(colors))))
    inpseqs = []
    targets = []
    for i in range(numex):
        inpseq = list(np.random.choice(colors, seqlen, replace=False))
        target = np.random.choice(range(len(inpseq)), 1)[0]
        target_class = D[inpseq[target]]
        inpseq[target] = "${}$".format(inpseq[target])
        inpseqs.append("".join(inpseq))
        targets.append(target_class)

    sm = q.StringMatrix()
    sm.tokenize = lambda x: list(x)
    for inpseq in inpseqs:
        sm.add(inpseq)
    sm.finalize()
    print(sm[0])
    print(sm.D)
    targets = np.asarray(targets)

    data = q.dataload(sm.matrix[:-100], targets[:-100], batch_size=batsize)
    valid_data = q.dataload(sm.matrix[-100:], targets[-100:], batch_size=batsize)
    # endregion

    # region model
    embdim = 20
    enc2inpdim = 45
    encdim = 20
    outdim = 20
    emb = q.WordEmb(embdim, worddic=sm.D)   # sm dictionary (characters)
    out = q.WordLinout(outdim, worddic=D)   # target dictionary
    # encoders:
    enc1 = q.RNNEncoder(embdim, encdim, bidir=True)
    enc2 = q.RNNCellEncoder(enc2inpdim, outdim // 2, bidir=True)

    # model
    class Model(torch.nn.Module):
        def __init__(self, dim, _emb, _out, _enc1, _enc2, **kw):
            super(Model, self).__init__(**kw)
            self.dim, self.emb, self.out, self.enc1, self.enc2 = dim, _emb, _out, _enc1, _enc2
            self.score = torch.nn.Sequential(
                torch.nn.Linear(dim, 1, bias=False),
                torch.nn.Sigmoid())
            self.emb_expander = ExpandVecs(embdim, enc2inpdim, 2)
            self.enc_expander = ExpandVecs(encdim * 2, enc2inpdim, 2)

        def forward(self, x, with_att=False):
            # embed and encode
            xemb, xmask = self.emb(x)
            xenc = self.enc1(xemb, mask=xmask)
            # compute attention
            xatt = self.score(xenc).squeeze(2) * xmask.float()[:, :xenc.size(1)]
            # encode again
            _xemb = self.emb_expander(xemb[:, :xenc.size(1)])
            _xenc = self.enc_expander(xenc)
            _, xenc2 = self.enc2(_xemb, gate=xatt, mask=xmask[:, :xenc.size(1)], ret_states=True)
            scores = self.out(xenc2.view(xenc.size(0), -1))
            if with_att:
                return scores, xatt
            else:
                return scores

    model = Model(40, emb, out, enc1, enc2)
    # endregion

    # region test
    if test:
        inps = torch.tensor(sm.matrix[0:2])
        outs = model(inps)
    # endregion

    # region train
    optimizer = torch.optim.Adam(q.params_of(model), lr=lr)
    trainer = q.trainer(model).on(data).loss(torch.nn.CrossEntropyLoss(), q.Accuracy())\
        .optimizer(optimizer).hook(q.ClipGradNorm(5.)).device(device)
    validator = q.tester(model).on(valid_data).loss(q.Accuracy()).device(device)
    q.train(trainer, validator).run(epochs=epochs)
    # endregion

    # region check attention    # TODO
    # feed a batch
    inpd = torch.tensor(sm.matrix[400:410])
    outd, att = model(inpd, with_att=True)
    outd = torch.max(outd, 1)[1].cpu().detach().numpy()
    inpd = inpd.cpu().detach().numpy()
    att = att.cpu().detach().numpy()
    rD = {v: k for k, v in sm.D.items()}
    roD = {v: k for k, v in D.items()}
    for i in range(len(att)):
        inpdi = " ".join([rD[x] for x in inpd[i]])
        outdi = roD[outd[i]]
        print("input: {}\nattention: {}\nprediction: {}".format(
            inpdi, " ".join(["{:.1f}".format(x) for x in att[i]]), outdi))
def run_words(lr=0.001,
              seqlen=8,
              batsize=50,
              epochs=1000,
              embdim=64,
              innerdim=128,
              z_dim=64,
              usebase=True,
              noaccumulate=False,
              ):
    # get some words
    N = 1000
    glove = q.PretrainedWordEmb(50, vocabsize=N + 2)
    words = list(glove.D.keys())[2:]
    datasm = q.StringMatrix()
    datasm.tokenize = lambda x: list(x)
    for word in words:
        datasm.add(word)
    datasm.finalize()
    datamat = datasm.matrix[:, :seqlen]
    # replace <mask> with <end>
    datamat = datamat + (datamat == datasm.D["<MASK>"]) * (datasm.D["<END>"] - datasm.D["<MASK>"])

    real_data = q.dataset(datamat)
    gen_data_d = q.gan.gauss_dataset(z_dim, len(real_data))
    disc_data = q.datacat([real_data, gen_data_d], 1)
    gen_data = q.gan.gauss_dataset(z_dim)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)

    discriminator = Discriminator(datasm.D, embdim, innerdim)
    generator = Decoder(datasm.D, embdim, z_dim, "<START>", innerdim, maxtime=seqlen)

    SeqGAN = SeqGAN_Base if usebase else SeqGAN_DCL

    disc_model = SeqGAN(discriminator, generator, gan_mode=q.gan.GAN.DISC_TRAIN, accumulate=not noaccumulate)
    gen_model = SeqGAN(discriminator, generator, gan_mode=q.gan.GAN.GEN_TRAIN, accumulate=not noaccumulate)

    disc_optim = torch.optim.Adam(q.params_of(discriminator), lr=lr)
    gen_optim = torch.optim.Adam(q.params_of(generator), lr=lr)

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(q.no_losses(2))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(q.no_losses(2))

    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer)
    gan_trainer.run(epochs, disciters=5, geniters=1, burnin=500)

    # print some predictions:
    with torch.no_grad():
        rvocab = {v: k for k, v in datasm.D.items()}
        q.batch_reset(generator)
        eval_z = torch.randn(50, z_dim)
        eval_y, _ = generator(eval_z)
        for i in range(len(eval_y)):
            prow = "".join([rvocab[mij] for mij in eval_y[i].numpy()])
            print(prow)

    print("done")
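# Quick standalone check of the "<MASK> -> <END>" id-replacement trick used in run_words
# above: adding (end_id - mask_id) only where the matrix equals mask_id rewrites those
# cells to end_id and leaves all other ids untouched. The ids below are made-up example
# values, not the actual dictionary entries of datasm.D.
def _demo_mask_to_end():
    import numpy as np
    mask_id, end_id = 0, 3
    m = np.array([[5, 7, 0, 0],
                  [2, 0, 0, 0]])
    replaced = m + (m == mask_id) * (end_id - mask_id)
    # every former 0 (mask) is now 3 (end); other ids are unchanged
    assert (replaced == np.array([[5, 7, 3, 3],
                                  [2, 3, 3, 3]])).all()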
def run(lr=0.001,
        dropout=0.2,
        batsize=50,
        embdim=50,
        encdim=50,
        decdim=50,
        numlayers=1,
        bidir=False,
        which="geo",    # "geo", "atis", "jobs"
        test=True,
        ):
    settings = locals().copy()
    logger = q.log.Logger(prefix="seq2seq_base")
    logger.save_settings(**settings)

    # region data
    nlsm, qlsm, splitidxs = load_data(which=which)
    print(nlsm[0], qlsm[0])
    print(nlsm._rarewords)

    trainloader = q.dataload(nlsm.matrix[:splitidxs[0]], qlsm.matrix[:splitidxs[0]],
                             batch_size=batsize, shuffle=True)
    devloader = q.dataload(nlsm.matrix[splitidxs[0]:splitidxs[1]], qlsm.matrix[splitidxs[0]:splitidxs[1]],
                           batch_size=batsize, shuffle=False)
    testloader = q.dataload(nlsm.matrix[splitidxs[1]:], qlsm.matrix[splitidxs[1]:],
                            batch_size=batsize, shuffle=False)
    # endregion

    # region model
    encdims = [encdim] * numlayers
    outdim = (encdim if not bidir else encdim * 2) + decdim
    nlemb = q.WordEmb(embdim, worddic=nlsm.D)
    qlemb = q.WordEmb(embdim, worddic=qlsm.D)
    nlenc = q.LSTMEncoder(embdim, *encdims, bidir=bidir, dropout_in=dropout)
    att = q.att.DotAtt()
    if numlayers > 1:
        qldec_core = torch.nn.Sequential(
            *[q.LSTMCell(_indim, _outdim, dropout_in=dropout)
              for _indim, _outdim in [(embdim, decdim)] + [(decdim, decdim)] * (numlayers - 1)]
        )
    else:
        qldec_core = q.LSTMCell(embdim, decdim, dropout_in=dropout)
    qlout = q.WordLinout(outdim, worddic=qlsm.D)
    qldec = q.LuongCell(emb=qlemb, core=qldec_core, att=att, out=qlout)

    class Model(torch.nn.Module):
        def __init__(self, _nlemb, _nlenc, _qldec, train=True, **kw):
            super(Model, self).__init__(**kw)
            self.nlemb, self.nlenc, self._q_train = _nlemb, _nlenc, train
            if train:
                self.qldec = q.TFDecoder(_qldec)
            else:
                self.qldec = q.FreeDecoder(_qldec, maxtime=100)

        def forward(self, x, y):    # (batsize, seqlen) int ids
            xemb, xmask = self.nlemb(x)
            xenc = self.nlenc(xemb, mask=xmask)
            if self._q_train is False:
                assert y.dim() == 2
            dec = self.qldec(y, ctx=xenc, ctxmask=xmask[:, :xenc.size(1)])
            return dec

    m_train = Model(nlemb, nlenc, qldec, train=True)
    m_test = Model(nlemb, nlenc, qldec, train=False)

    if test:
        test_out = m_train(torch.tensor(nlsm.matrix[:5]), torch.tensor(qlsm.matrix[:5]))
        print("test_out.size() = {}".format(test_out.size()))
def run(lr=0.0001,
        batsize=64,
        epochs=100000,
        lamda=10,
        disciters=5,
        burnin=-1,
        validinter=1000,
        devinter=100,
        cuda=False,
        gpu=0,
        z_dim=128,
        test=False,
        dim_d=128,
        dim_g=128,
        ):
    settings = locals().copy()
    logger = q.log.Logger(prefix="wgan_resnet_cifar")
    logger.save_settings(**settings)
    print("started")

    burnin = disciters if burnin == -1 else burnin

    if test:
        validinter = 10
        burnin = 1
        batsize = 2
        devinter = 1

    tt = q.ticktock("script")

    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("creating networks")
    gen = OldGenerator(z_dim, dim_g).to(device)
    crit = OldDiscriminator(dim_d).to(device)
    tt.tock("created networks")

    # test
    # z = torch.randn(3, z_dim).to(device)
    # x = gen(z)
    # s = crit(x)

    # data
    # load cifar
    tt.tick("loading data")
    traincifar, testcifar = load_cifar_dataset(train=True), load_cifar_dataset(train=False)
    print(len(traincifar))

    gen_data_d = q.gan.gauss_dataset(z_dim, len(traincifar))
    disc_data = q.datacat([traincifar, gen_data_d], 1)

    gen_data = q.gan.gauss_dataset(z_dim)
    gen_data_valid = q.gan.gauss_dataset(z_dim, 50000)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)
    gen_data_valid = q.dataload(gen_data_valid, batch_size=batsize, shuffle=False)
    validcifar_loader = q.dataload(testcifar, batch_size=batsize, shuffle=False)

    dev_data_gauss = q.gan.gauss_dataset(z_dim, len(testcifar))
    dev_disc_data = q.datacat([testcifar, dev_data_gauss], 1)
    dev_disc_data = q.dataload(dev_disc_data, batch_size=batsize, shuffle=False)
    # q.embed()
    tt.tock("loaded data")

    disc_model = q.gan.WGAN(crit, gen, lamda=lamda).disc_train()
    gen_model = q.gan.WGAN(crit, gen, lamda=lamda).gen_train()

    disc_optim = torch.optim.Adam(q.params_of(crit), lr=lr, betas=(0.5, 0.9))
    gen_optim = torch.optim.Adam(q.params_of(gen), lr=lr, betas=(0.5, 0.9))

    disc_bt = UnquantizeTransform()

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(1).device(device)

    fidandis = q.gan.FIDandIS(device=device)
    if not test:
        fidandis.set_real_stats_with(validcifar_loader)
    saver = q.gan.GenDataSaver(logger, "saved.npz")
    generator_validator = q.gan.GeneratorValidator(gen, [fidandis, saver], gen_data_valid, device=device,
                                                   logger=logger, validinter=validinter)

    train_validator = q.tester(disc_model).on(dev_disc_data).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    train_validator.validinter = devinter

    tt.tick("training")
    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer,
                                   validators=(generator_validator, train_validator),
                                   lr_decay=True)
    gan_trainer.run(epochs, disciters=disciters, geniters=1, burnin=burnin)
    tt.tock("trained")
def run(lr=OPT_LR, batsize=100, epochs=1000, validinter=20,
        wreg=0.00000000001, dropout=0.1,
        embdim=50, encdim=50, numlayers=1,
        cuda=False, gpu=0,
        mode="flat", test=False, gendata=False):
    if gendata:
        loadret = load_jsons()
        # open in binary mode for pickling
        pickle.dump(loadret, open("loadcache.flat.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL)
    else:
        settings = locals().copy()
        logger = q.Logger(prefix="rank_lstm")
        logger.save_settings(**settings)

        device = torch.device("cpu")
        if cuda:
            device = torch.device("cuda", gpu)

        tt = q.ticktock("script")

        # region DATA
        tt.tick("loading data")
        qsm, csm, goldchainids, badchainids = pickle.load(open("loadcache.{}.pkl".format(mode), "rb"))
        eids = np.arange(0, len(goldchainids))

        data = [qsm.matrix, eids]
        traindata, validdata = q.datasplit(data, splits=(7, 3), random=False)
        validdata, testdata = q.datasplit(validdata, splits=(1, 2), random=False)

        trainloader = q.dataload(*traindata, batch_size=batsize, shuffle=True)

        input_feeder = FlatInpFeeder(csm.matrix, goldchainids, badchainids)

        def inp_bt(_qsm_batch, _eids_batch):
            golds_batch, bads_batch = input_feeder(_eids_batch)
            return _qsm_batch, golds_batch, bads_batch

        if test:
            # test input feeder
            eids = q.var(torch.arange(0, 10).long()).v
            _test_golds_batch, _test_bads_batch = input_feeder(eids)
        tt.tock("data loaded")
        # endregion

        # region MODEL
        dims = [encdim // 2] * numlayers
        question_encoder = FlatEncoder(embdim, dims, qsm.D, bidir=True)
        query_encoder = FlatEncoder(embdim, dims, csm.D, bidir=True)
        similarity = DotDistance()

        rankmodel = RankModel(question_encoder, query_encoder, similarity)
        scoremodel = ScoreModel(question_encoder, query_encoder, similarity)
        # endregion

        # region VALIDATION
        rankcomp = RankingComputer(scoremodel, validdata[1], validdata[0],
                                   csm.matrix, goldchainids, badchainids)
        # endregion

        # region TRAINING
        optim = torch.optim.Adam(q.params_of(rankmodel), lr=lr, weight_decay=wreg)
        trainer = q.trainer(rankmodel).on(trainloader).loss(1)\
            .set_batch_transformer(inp_bt).optimizer(optim).device(device)

        def validation_function():
            rankmetrics = rankcomp.compute(RecallAt(1, totaltrue=1),
                                           RecallAt(5, totaltrue=1),
                                           MRR())
            ret = []
            for rankmetric in rankmetrics:
                rankmetric = np.asarray(rankmetric)
                ret_i = rankmetric.mean()
                ret.append(ret_i)
            return "valid: " + " - ".join(["{:.4f}".format(x) for x in ret])

        q.train(trainer, validation_function).run(epochs, validinter=validinter)
def run(lr=0.0001,
        batsize=64,
        epochs=100000,
        lamda=10,
        disciters=5,
        burnin=-1,
        validinter=1000,
        devinter=100,
        cuda=False,
        gpu=0,
        z_dim=128,
        test=False,
        dim_d=128,
        dim_g=128,
        vggversion=13,
        vgglayer=9,
        vggvanilla=False,       # if True, makes trainable feature transform
        extralayers=False,      # adds a couple extra res blocks to generator to match added VGG
        pixelpenalty=False,     # if True, uses penalty based on pixel-wise interpolate
        inceptionpath="/data/lukovnik/",
        normalwgan=False,
        ):
    # vggvanilla=True and pixelpenalty=True makes a normal WGAN
    settings = locals().copy()
    logger = q.log.Logger(prefix="wgan_resnet_cifar_feat")
    logger.save_settings(**settings)

    burnin = disciters if burnin == -1 else burnin

    if test:
        validinter = 10
        burnin = 1
        batsize = 2
        devinter = 1

    tt = q.ticktock("script")

    device = torch.device("cpu") if not cuda else torch.device("cuda", gpu)

    tt.tick("creating networks")
    if not normalwgan:
        print("doing wgan-feat")
        gen = OldGenerator(z_dim, dim_g, extra_layers=extralayers).to(device)
        inpd = get_vgg_outdim(vggversion, vgglayer)
        crit = ReducedDiscriminator(inpd, dim_d).to(device)
        subvgg = SubVGG(vggversion, vgglayer, pretrained=not vggvanilla)
    else:
        print("doing normal wgan")
        gen = OldGenerator(z_dim, dim_g, extra_layers=False).to(device)
        crit = OldDiscriminator(dim_d).to(device)
        subvgg = None
    tt.tock("created networks")

    # test
    # z = torch.randn(3, z_dim).to(device)
    # x = gen(z)
    # s = crit(x)

    # data
    # load cifar
    tt.tick("loading data")
    traincifar, testcifar = load_cifar_dataset(train=True), load_cifar_dataset(train=False)
    print(len(traincifar), len(testcifar))

    gen_data_d = q.gan.gauss_dataset(z_dim, len(traincifar))
    disc_data = q.datacat([traincifar, gen_data_d], 1)

    gen_data = q.gan.gauss_dataset(z_dim)
    gen_data_valid = q.gan.gauss_dataset(z_dim, 50000)

    swd_gen_data = q.gan.gauss_dataset(z_dim, 10000)
    swd_real_data = []
    swd_shape = traincifar[0].size()
    for i in range(10000):
        swd_real_data.append(testcifar[i])
    swd_real_data = torch.stack(swd_real_data, 0)

    disc_data = q.dataload(disc_data, batch_size=batsize, shuffle=True)
    gen_data = q.dataload(gen_data, batch_size=batsize, shuffle=True)
    gen_data_valid = q.dataload(gen_data_valid, batch_size=batsize, shuffle=False)
    validcifar_loader = q.dataload(testcifar, batch_size=batsize, shuffle=False)

    swd_batsize = 64
    swd_gen_data = q.dataload(swd_gen_data, batch_size=swd_batsize, shuffle=False)
    swd_real_data = q.dataload(swd_real_data, batch_size=swd_batsize, shuffle=False)

    dev_data_gauss = q.gan.gauss_dataset(z_dim, len(testcifar))
    dev_disc_data = q.datacat([testcifar, dev_data_gauss], 1)
    dev_disc_data = q.dataload(dev_disc_data, batch_size=batsize, shuffle=False)
    # q.embed()
    tt.tock("loaded data")

    if not normalwgan:
        disc_model = q.gan.WGAN_F(crit, gen, subvgg, lamda=lamda, pixel_penalty=pixelpenalty).disc_train()
        gen_model = q.gan.WGAN_F(crit, gen, subvgg, lamda=lamda, pixel_penalty=pixelpenalty).gen_train()
    else:
        disc_model = q.gan.WGAN(crit, gen, lamda=lamda).disc_train()
        gen_model = q.gan.WGAN(crit, gen, lamda=lamda).gen_train()

    disc_params = q.params_of(crit)
    if vggvanilla and not normalwgan:
        disc_params += q.params_of(subvgg)
    disc_optim = torch.optim.Adam(disc_params, lr=lr, betas=(0.5, 0.9))
    gen_optim = torch.optim.Adam(q.params_of(gen), lr=lr, betas=(0.5, 0.9))

    disc_bt = UnquantizeTransform()

    disc_trainer = q.trainer(disc_model).on(disc_data).optimizer(disc_optim).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    gen_trainer = q.trainer(gen_model).on(gen_data).optimizer(gen_optim).loss(1).device(device)

    # fidandis = q.gan.FIDandIS(device=device)
    tfis = q.gan.tfIS(inception_path=inceptionpath, gpu=gpu)
    # if not test:
    #     fidandis.set_real_stats_with(validcifar_loader)
    saver = q.gan.GenDataSaver(logger, "saved.npz")
    generator_validator = q.gan.GeneratorValidator(gen, [tfis, saver], gen_data_valid, device=device,
                                                   logger=logger, validinter=validinter)

    train_validator = q.tester(disc_model).on(dev_disc_data).loss(3).device(device)\
        .set_batch_transformer(lambda a, b: (disc_bt(a), b))
    train_validator.validinter = devinter

    tt.tick("initializing SWD")
    swd = q.gan.SlicedWassersteinDistance(swd_shape)
    swd.prepare_reals(swd_real_data)
    tt.tock("SWD initialized")

    swd_validator = q.gan.GeneratorValidator(gen, [swd], swd_gen_data, device=device,
                                             logger=logger, validinter=validinter, name="swd")

    tt.tick("training")
    gan_trainer = q.gan.GANTrainer(disc_trainer, gen_trainer,
                                   validators=(generator_validator, train_validator, swd_validator),
                                   lr_decay=True)
    gan_trainer.run(epochs, disciters=disciters, geniters=1, burnin=burnin)
    tt.tock("trained")
def run_normal_seqvae_toy(lr=0.001,
                          embdim=64,
                          encdim=100,
                          zdim=64,
                          batsize=50,
                          epochs=100,
                          ):
    # test
    vocsize = 100
    seqlen = 12

    wD = dict((chr(xi), xi) for xi in range(vocsize))

    # region encoder
    encoder_emb = q.WordEmb(embdim, worddic=wD)
    encoder_lstm = q.FastestLSTMEncoder(embdim, encdim)

    class EncoderNet(torch.nn.Module):
        def __init__(self, emb, core):
            super(EncoderNet, self).__init__()
            self.emb, self.core = emb, core

        def forward(self, x):
            embs, mask = self.emb(x)
            out, states = self.core(embs, mask, ret_states=True)
            top_state = states[-1][0][:, 0]
            # top_state = top_state.unsqueeze(1).repeat(1, out.size(1), 1)
            return top_state    # (batsize, encdim)

    encoder_net = EncoderNet(encoder_emb, encoder_lstm)
    encoder = Posterior(encoder_net, encdim, zdim)
    # endregion

    # region decoder
    decoder_emb = q.WordEmb(embdim, worddic=wD)
    decoder_lstm = q.LSTMCell(embdim + zdim, encdim)
    decoder_outlin = q.WordLinout(encdim, worddic=wD)

    class DecoderCell(torch.nn.Module):
        def __init__(self, emb, core, out, **kw):
            super(DecoderCell, self).__init__()
            self.emb, self.core, self.out = emb, core, out

        def forward(self, xs, z=None):
            embs, mask = self.emb(xs)
            core_inp = torch.cat([embs, z], 1)
            core_out = self.core(core_inp)
            out = self.out(core_out)
            return out

    decoder_cell = DecoderCell(decoder_emb, decoder_lstm, decoder_outlin)
    decoder = q.TFDecoder(decoder_cell)
    # endregion

    likelihood = Likelihood()

    vae = SeqVAE(encoder, decoder, likelihood)

    x = torch.randint(0, vocsize, (batsize, seqlen), dtype=torch.int64)
    ys = vae(x)

    optim = torch.optim.Adam(q.params_of(vae), lr=lr)

    x = torch.randint(0, vocsize, (batsize * 100, seqlen), dtype=torch.int64)
    dataloader = q.dataload(x, batch_size=batsize, shuffle=True)

    trainer = q.trainer(vae).on(dataloader).optimizer(optim).loss(4).epochs(epochs)
    trainer.run()

    print("done \n\n")