def test_continued_training_encdec(self): m = SimpleSeqEncDecAtt(inpvocsize=20, outvocsize=20, inpembdim=5, outembdim=5, encdim=10, decdim=10) data = np.random.randint(0, 20, (50, 7)) r = m.train([data, data[:, :-1]], data[:, 1:]).cross_entropy().adadelta().train( 5, 10, returnerrors=True) a = r[1] print "\n".join(map(str, a)) for i in range(0, len(a) - 1): self.assertTrue(a[i + 1] < a[i]) m.get_params() m.save("/tmp/testmodelsave") m = m.load("/tmp/testmodelsave") r = m.train([data, data[:, :-1]], data[:, 1:]).cross_entropy().adadelta().train( 5, 10, returnerrors=True) b = r[1] print "\n".join(map(str, b)) for i in range(0, len(b) - 1): self.assertTrue(b[i + 1] < b[i]) self.assertTrue(b[0] < a[-1])
def run(numbats=50,
        epochs=10,
        lr=0.5,
        embdim=50,
        encdim=400,
        dropout=0.2,
        layers=1,
        inconcat=True,
        outconcat=True,
        posemb=False,
        customemb=False,
        preproc="none",     # "none" or "generate" or "abstract" or "gensample"
        bidir=False,
        corruptnoise=0.0,
        inspectdata=False,
        relinearize="none",
        pretrain=False,
        pretrainepochs=-1,
        pretrainnumbats=-1,
        pretrainlr=-0.1,
        loadpretrained="none",
        pretrainfixdecoder=False,
        wreg=0.0,
        testmode=False):
    """Experiment driver: train an attention encoder-decoder on the geo dataset.

    Loads (optionally pretraining) data, builds a SimpleSeqEncDecAtt model,
    optionally pretrains on auto-generated data (or loads a pretrained model
    from disk), then trains on the real train split while validating on the
    test split, and finally drops into an IPython shell with a `play()`
    helper for inspecting predictions.

    Notable flags (semantics as visible in this function):
    - preproc: data preprocessing mode; "generate" augments (alters dataset
      size), "gensample" builds a type dictionary used by GenSample during
      training and validates on the un-preprocessed matrices.
    - relinearize: "greedy" or "deep" re-linearization of logical forms via
      LambdaParser before loading.
    - pretrain / loadpretrained: pretrain on loadgeoauto data, or load a
      previously saved pretrained model ("<name>.pre.sp.model").
    - pretrainfixdecoder: after pretraining, rebuild the encoder for the real
      vocabulary and freeze the whole decoder (lr=0); otherwise only the
      decoder embedder is frozen and the encoder embedder swapped.
    - posemb / customemb: wrap embeddings with positional indexes / custom
      embedding initialization.
    - inspectdata: drop into IPython `embed()` at several checkpoints.
    - testmode: shrink pretraining data to 100 examples and report parameter
      movement (L1 norms) after pretraining.
    """
    # TODO: bi-encoder and other beasts
    # TODO: make sure gensample results NOT IN test data
    if pretrain == True:
        # pretraining only supported for these preprocessing modes
        assert (preproc == "none" or preproc == "gensample")
    # -1 means "default to the main number of epochs"
    pretrainepochs = epochs if pretrainepochs == -1 else pretrainepochs

    ######### DATA LOADING AND TRANSFORMATIONS ###########
    # optional source-side re-linearization of logical forms
    srctransformer = None
    if relinearize != "none":
        lambdaparser = LambdaParser()
        if relinearize == "greedy":
            def srctransformer(x): return lambdaparser.parse(x).greedy_linearize(deeppref=True)
        elif relinearize == "deep":
            def srctransformer(x): return lambdaparser.parse(x).deep_linearize()
        else:
            raise Exception("unknown linearization")

    adic = {}
    if pretrain or loadpretrained != "none":
        ### PRETRAIN DATA LOAD ###
        # NOTE: adic from the auto data is passed into loadgeo below so the
        # output dictionaries are shared between pretraining and real data
        qmat_auto, amat_auto, qdic_auto, adic, qwc_auto, awc_auto = \
            loadgeoauto(reverse=True, transformer=srctransformer)

        def pp(i):
            # debug helper: print the i-th pretraining example as words
            print wordids2string(qmat_auto[i], {v: k for k, v in qdic_auto.items()}, 0)
            print wordids2string(amat_auto[i], {v: k for k, v in adic.items()}, 0)
        if inspectdata:
            print "pretrain inspect"
            embed()

    qmat, amat, qdic, adic, qwc, awc = loadgeo(customemb=customemb, reverse=True,
                                               transformer=srctransformer, adic=adic)

    maskid = 0
    typdic = None
    # keep unmodified copies; used by "generate" augmentation and as the
    # validation matrices in "gensample" mode
    oqmat = qmat.copy()
    oamat = amat.copy()
    print "{} is preproc".format(preproc)
    if preproc != "none":
        qmat, amat, qdic, adic, qwc, awc = preprocess(qmat, amat, qdic, adic, qwc, awc,
                                                      maskid, qreversed=True,
                                                      dorare=preproc != "generate")
        if preproc == "generate":   # alters size
            print "generating"
            qmat, amat = generate(qmat, amat, qdic, adic, oqmat, oamat, reversed=True)
            #embed()
        elif preproc == "gensample":
            typdic = gentypdic(qdic, adic)

    ######### train/test split from here #########
    qmat_t, qmat_x = split_train_test(qmat)
    amat_t, amat_x = split_train_test(amat)
    oqmat_t, oqmat_x = split_train_test(oqmat)
    oamat_t, oamat_x = split_train_test(oamat)

    # report how much the test split overlaps the train split
    qoverlap, aoverlap, overlap = compute_overlap(qmat_t, amat_t, qmat_x, amat_x)
    print "overlaps: {}, {}: {}".format(len(qoverlap), len(aoverlap), len(overlap))
    if inspectdata:
        embed()

    np.random.seed(12345)
    # per-layer encoder/decoder dims; bidirectional encoder halves each
    # direction so the concatenated size stays encdim
    encdimi = [encdim / 2 if bidir else encdim] * layers
    decdimi = [encdim] * layers

    # *i matrices are the decoder-input versions (may get positional indexes
    # added below while the gold versions stay untouched)
    amati_t, amati_x = amat_t, amat_x
    oamati_t, oamati_x = oamat_t, oamat_x
    if pretrain:
        amati_auto = amat_auto

    if posemb:      # add positional indexes to datamatrices
        qmat_t, oqmat_t, amat_t, oamat_t = add_pos_indexes(qmat_t, oqmat_t, amat_t, oamat_t)
        qmat_x, oqmat_x, amat_x, oamat_x = add_pos_indexes(qmat_x, oqmat_x, amat_x, oamat_x)

    if preproc == "gensample":
        # validate on the original (un-preprocessed) test matrices
        qmat_x, amat_x, amati_x = oqmat_x, oamat_x, oamati_x

    # reverse dictionaries: id -> word
    rqdic = {v: k for k, v in qdic.items()}
    radic = {v: k for k, v in adic.items()}

    def tpp(i):
        # debug helper: print the i-th TRAIN example as words
        print wordids2string(qmat_t[i], rqdic, 0)
        print wordids2string(amat_t[i], radic, 0)

    def xpp(i):
        # debug helper: print the i-th TEST example as words
        print wordids2string(qmat_x[i], rqdic, 0)
        print wordids2string(amat_x[i], radic, 0)

    if inspectdata:
        embed()

    print "{} training examples".format(qmat_t.shape[0])

    ################## MODEL DEFINITION ##################
    # encdec prerequisites
    inpemb = WordEmb(worddic=qdic, maskid=maskid, dim=embdim)
    outemb = WordEmb(worddic=adic, maskid=maskid, dim=embdim)
    if pretrain == True:
        # separate input embedder for the pretraining vocabulary
        inpemb_auto = WordEmb(worddic=qdic_auto, maskid=maskid, dim=embdim)
        #outemb = WordEmb(worddic=adic, maskid=maskid, dim=embdim)
    if customemb:
        inpemb, outemb = do_custom_emb(inpemb, outemb, awc, embdim)
        if pretrain:
            inpemb_auto, outemb = do_custom_emb(inpemb_auto, outemb, awc_auto, embdim)
    if posemb:      # use custom emb layers, with positional embeddings
        posembdim = 50
        inpemb = VectorPosEmb(inpemb, qmat_t.shape[1], posembdim)
        outemb = VectorPosEmb(outemb, amat_t.shape[1], posembdim)
        if pretrain:
            inpemb_auto = VectorPosEmb(inpemb_auto, qmat_auto.shape[1], posembdim)
            # output posemb must cover the longest sequence of both datasets
            outemb = VectorPosEmb(outemb, max(amat_auto.shape[1], amat_t.shape[1]), posembdim)

    smodim = embdim
    # output softmax; indim is encdim+encdim, presumably because the decoder
    # state is concatenated with the attention summary (outconcat) — confirm
    smo = SoftMaxOut(indim=encdim + encdim, innerdim=smodim,
                     outvocsize=len(adic) + 1, dropout=dropout)
    if customemb:
        # tie the output layer to the (transposed) output embedding matrix
        smo.setlin2(outemb.baseemb.W.T)

    # encdec model
    encdec = SimpleSeqEncDecAtt(inpvocsize=max(qdic.values()) + 1,
                                inpembdim=embdim,
                                inpemb=inpemb,
                                outvocsize=max(adic.values()) + 1,
                                outembdim=embdim,
                                outemb=outemb,
                                encdim=encdimi,
                                decdim=decdimi,
                                maskid=maskid,
                                statetrans=True,
                                dropout=dropout,
                                inconcat=inconcat,
                                outconcat=outconcat,
                                rnu=GRU,
                                vecout=smo,
                                bidir=bidir,
                                )

    ################## TRAINING ##################
    if pretrain == True or loadpretrained != "none":
        if pretrain == True and loadpretrained == "none":
            # switch the encoder to the pretraining vocabulary
            if pretrainfixdecoder:
                encdec.remake_encoder(inpvocsize=max(qdic_auto.values()) + 1,
                                      inpembdim=embdim, inpemb=inpemb_auto,
                                      maskid=maskid, bidir=bidir,
                                      dropout_h=dropout, dropout_in=dropout)
            else:
                encdec.enc.embedder = inpemb_auto
        if loadpretrained != "none":
            encdec = encdec.load(loadpretrained + ".pre.sp.model")
            print "MODEL LOADED: {}".format(loadpretrained)
        if pretrain == True:
            if pretrainnumbats < 0:
                # derive pretraining batch count so the batch SIZE matches
                # what numbats would give on the real training data
                import math
                batsize = int(math.ceil(qmat_t.shape[0] * 1.0 / numbats))
                pretrainnumbats = int(math.ceil(qmat_auto.shape[0] * 1.0 / batsize))
                print "{} batches".format(pretrainnumbats)
            if pretrainlr < 0:
                pretrainlr = lr
            if testmode:
                # snapshot params and shrink data for a quick smoke run
                oldparamvals = {p: p.v for p in encdec.get_params()}
                qmat_auto = qmat_auto[:100]
                amat_auto = amat_auto[:100]
                amati_auto = amati_auto[:100]
                pretrainnumbats = 10
            #embed()
            encdec.train([qmat_auto, amat_auto[:, :-1]], amati_auto[:, 1:])\
                .cross_entropy().adadelta(lr=pretrainlr).grad_total_norm(5.) \
                .l2(wreg).exp_mov_avg(0.95) \
                .split_validate(splits=10, random=True).cross_entropy().seq_accuracy() \
                .train(pretrainnumbats, pretrainepochs)
            if testmode:
                # report how far each parameter moved during pretraining
                for p in encdec.get_params():
                    print np.linalg.norm(p.v - oldparamvals[p], ord=1)

            savepath = "{}.pre.sp.model".format(random.randint(1000, 9999))
            print "PRETRAIN SAVEPATH: {}".format(savepath)
            encdec.save(savepath)
        # NaN somewhere at 75% in training, in one of RNU's? --> with rmsprop
        if pretrainfixdecoder:
            # rebuild encoder for the real vocabulary and freeze the decoder
            encdec.remake_encoder(inpvocsize=max(qdic.values()) + 1,
                                  inpembdim=embdim, inpemb=inpemb, bidir=bidir,
                                  maskid=maskid, dropout_h=dropout, dropout_in=dropout)
            encdec.dec.set_lr(0.0)
        else:
            # freeze only the decoder embeddings, swap in the real-vocab
            # encoder embedder
            encdec.dec.embedder.set_lr(0.0)
            encdec.enc.embedder = inpemb

    # main training run; GenSample is a no-op unless typdic was built
    # ("gensample" preproc), RandomCorrupt injects noise with p=corruptnoise
    encdec.train([qmat_t, amat_t[:, :-1]], amati_t[:, 1:])\
        .sampletransform(GenSample(typdic),
                         RandomCorrupt(corruptdecoder=(2, max(adic.values()) + 1),
                                       corruptencoder=(2, max(qdic.values()) + 1),
                                       maskid=maskid, p=corruptnoise))\
        .cross_entropy().adadelta(lr=lr).grad_total_norm(5.) \
        .l2(wreg).exp_mov_avg(0.8) \
        .validate_on([qmat_x, amati_x[:, :-1]], amat_x[:, 1:]) \
        .cross_entropy().seq_accuracy()\
        .train(numbats, epochs)
        #.split_validate(splits=10, random=True)\

    qrwd = {v: k for k, v in qdic.items()}
    arwd = {v: k for k, v in adic.items()}

    def play(*x, **kw):
        """Interactive inspection helper (used from the embed() shell below).

        play(i)  -- print test question i, its gold answer and the prediction;
                    returns False when hidecorrect=True and prediction matches.
        play()   -- step through all test examples, pausing on each shown one.
        """
        hidecorrect = False
        if "hidecorrect" in kw:
            hidecorrect = kw["hidecorrect"]
        if len(x) == 1:
            x = x[0]
            q = wordids2string(qmat_x[x], rwd=qrwd, maskid=maskid, reverse=True)
            ga = wordids2string(amat_x[x, 1:], rwd=arwd, maskid=maskid)
            pred = encdec.predict(qmat_x[x:x + 1], amati_x[x:x + 1, :-1])
            pa = wordids2string(np.argmax(pred[0], axis=1), rwd=arwd, maskid=maskid)
            if hidecorrect and ga == pa[:len(ga)]:  # correct
                return False
            else:
                print "{}: {}".format(x, q)
                print ga
                print pa
                return True
        elif len(x) == 0:
            for i in range(0, qmat_x.shape[0]):
                r = play(i)
                if r:
                    raw_input()
        else:
            raise Exception("invalid argument to play")

    # drop into IPython with the trained model and helpers in scope
    embed()