Exemplo n.º 1
0
    def test_get_params(self):
        """Check that ``get_params()`` returns exactly the hand-collected set.

        The expected set is assembled from the recurrent layers of the encoder
        and the decoder, the decoder's output linear layer, and the parameters
        of the attention distance.
        """
        gru_param_names = "w wm whf u um uhf b bm bhf".split()  # GRU params

        attdist = LinearDistance(110, 110, 100)
        encdec = SimpleSeqEncDecAtt(
            inpvocsize=19, outvocsize=17, outconcat=False,
            encdim=(110, 100), decdim=100, attdist=attdist)

        expected = set()
        # Encoder layers first, then decoder layers.
        for block in (encdec.enc.block, encdec.dec.block):
            for rnn_layer in block.layers:
                expected.update(
                    getattr(rnn_layer, name) for name in gru_param_names)

        # Linear output layer of the decoder.
        out_lin = encdec.dec.lin
        expected.add(out_lin.W)
        expected.add(out_lin.b)

        # Parameters owned by the attention distance.
        dist = encdec.dec.attention.attentiongenerator.dist
        expected |= {
            dist.lin.W, dist.lin.b, dist.lin2.W, dist.lin2.b, dist.agg
        }

        self.assertEqual(expected, encdec.get_params())
Exemplo n.º 2
0
    def do_test_shapes(self, bidir=False):
        """Build a small model and verify encoder and prediction output shapes.

        :param bidir: forwarded to the model; when true, the expected width of
            the encoder output is ``encdim * 2``.
        """
        # Vocabulary sizes, batch size and sequence lengths used throughout.
        inpvocsize, outvocsize = 100, 13
        batsize, inpseqlen, outseqlen = 11, 7, 5
        encdim = 9

        model = SimpleSeqEncDecAtt(
            inpvocsize=inpvocsize, inpembdim=10,
            outvocsize=outvocsize, outembdim=5,
            encdim=encdim, decdim=7, attdim=8,
            bidir=bidir)

        src_shape = (batsize, inpseqlen)
        tgt_shape = (batsize, outseqlen)
        inpseq = np.random.randint(0, inpvocsize, src_shape).astype("int32")
        outseq = np.random.randint(0, outvocsize, tgt_shape).astype("int32")

        # The encoder's predict returns three values; only the first is checked.
        final_enc, _, _ = model.enc.predict(inpseq)
        enc_width = encdim * 2 if bidir else encdim
        self.assertEqual(final_enc.shape, (batsize, enc_width))

        probs = model.predict(inpseq, outseq)
        self.assertEqual(probs.shape, (batsize, outseqlen, outvocsize))
Exemplo n.º 3
0
    def do_test_shapes(self, bidir=False):
        """Shape check for the encoder output and the full model prediction.

        :param bidir: passed through to the model; the expected encoder output
            width doubles when it is true.
        """
        # Model hyperparameters, gathered so they can be splatted in one go.
        dims = dict(inpvocsize=100, inpembdim=10,
                    outvocsize=13, outembdim=5,
                    encdim=9, decdim=7, attdim=8)
        batsize, inpseqlen, outseqlen = 11, 7, 5

        m = SimpleSeqEncDecAtt(bidir=bidir, **dims)

        inpseq = np.random.randint(
            0, dims["inpvocsize"], (batsize, inpseqlen)).astype("int32")
        outseq = np.random.randint(
            0, dims["outvocsize"], (batsize, outseqlen)).astype("int32")

        # First of the three values returned by the encoder's predict.
        encoded, _, _ = m.enc.predict(inpseq)
        if bidir:
            expected_enc_shape = (batsize, dims["encdim"] * 2)
        else:
            expected_enc_shape = (batsize, dims["encdim"])
        self.assertEqual(encoded.shape, expected_enc_shape)

        prediction = m.predict(inpseq, outseq)
        self.assertEqual(prediction.shape,
                         (batsize, outseqlen, dims["outvocsize"]))
Exemplo n.º 4
0
    def do_test_shapes(self, bidir=False, sepatt=False, rnu=GRU):
        """Verify encoder output, encoder state and prediction shapes.

        :param bidir: forwarded to the model. When true, the per-layer encoder
            sizes are halved before building the model and every expected
            width is doubled again in the assertions.
        :param sepatt: forwarded to the model. When true, the second value
            returned by the encoder's predict is expected to carry an extra
            axis of size 2.
        :param rnu: recurrent unit class forwarded to the model; state checks
            are only performed for ``GRU`` (2 states) and ``LSTM`` (4 states).
        """
        inpvocsize = 100
        outvocsize = 13
        inpembdim = 10
        outembdim = 7
        encdim = [26, 14]
        decdim = [21, 15]
        batsize = 11
        inpseqlen = 6
        outseqlen = 5

        if bidir:
            # Explicit floor division: the layer sizes must stay integers even
            # when "/" means true division (Python 3, or
            # ``from __future__ import division``).
            encdim = [e // 2 for e in encdim]

        m = SimpleSeqEncDecAtt(inpvocsize=inpvocsize,
                               inpembdim=inpembdim,
                               outvocsize=outvocsize,
                               outembdim=outembdim,
                               encdim=encdim,
                               decdim=decdim,
                               bidir=bidir,
                               statetrans=True,
                               attdist=LinearDistance(15, 14, 17),
                               sepatt=sepatt,
                               rnu=rnu)

        inpseq = np.random.randint(0, inpvocsize,
                                   (batsize, inpseqlen)).astype("int32")
        outseq = np.random.randint(0, outvocsize,
                                   (batsize, outseqlen)).astype("int32")

        def width(dim):
            # Expected feature width for a layer of size ``dim``.
            return dim * 2 if bidir else dim

        predenco, enco, states = m.enc.predict(inpseq)
        self.assertEqual(predenco.shape, (batsize, width(encdim[-1])))

        # Layer size expected for each returned state, depending on the unit.
        if rnu == GRU:
            statedims = encdim
        elif rnu == LSTM:
            statedims = [encdim[0], encdim[0], encdim[1], encdim[1]]
        else:
            statedims = None  # unknown unit: no state checks
        if statedims is not None:
            self.assertEqual(len(states), len(statedims))
            for state, encdime in zip(states, statedims):
                self.assertEqual(state.shape,
                                 (batsize, inpseqlen, width(encdime)))

        if sepatt:
            self.assertEqual(enco.shape,
                             (batsize, inpseqlen, 2, width(encdim[-1])))

        pred = m.predict(inpseq, outseq)
        self.assertEqual(pred.shape, (batsize, outseqlen, outvocsize))

        # List all parameters reachable from the first output, sorted by name.
        _, outvar = m.autobuild(inpseq, outseq)
        for p in sorted(outvar[0].allparams, key=lambda x: str(x)):
            print(p)
Exemplo n.º 5
0
def run(p="../../../data/atis/atis.pkl", wordembdim=70, lablembdim=70, innerdim=300, lr=0.01, numbats=100, epochs=20, validinter=1, wreg=0.0001, depth=1, attdim=300):
    """Train a ``SimpleSeqEncDecAtt`` on the pickled data at ``p`` and evaluate it.

    :param p: path of a pickle holding ``(train, test, dics)``.
    :param wordembdim: passed to the model as ``inpembdim``.
    :param lablembdim: passed to the model as ``outembdim``.
    :param innerdim: size of each encoder/decoder layer (repeated ``depth`` times).
    :param lr: learning rate given to ``adagrad``.
    :param numbats: first argument of the final ``.train(...)`` call.
    :param epochs: second argument of the final ``.train(...)`` call.
    :param validinter: passed to ``.validinter(...)``.
    :param wreg: passed to ``.l2(...)``.
    :param depth: number of encoder/decoder layers.
    :param attdim: passed to the model as ``attdim``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle, which a bare ``open(p)`` passed
    # straight to ``pickle.load`` never did.
    with open(p) as datafile:
        train, test, dics = pickle.load(datafile)
    word2idx = dics["words2idx"]
    table2idx = dics["tables2idx"]
    label2idx = dics["labels2idx"]
    label2idxrev = {v: k for k, v in label2idx.items()}
    train = zip(*train)
    test = zip(*test)
    print("%d training examples, %d test examples" % (len(train), len(test)))
    #tup2text(train[0], word2idx, table2idx, label2idx)

    # Longest sequence (first element of each tuple) over train and test.
    maxlen = 0
    for tup in train + test:
        maxlen = max(len(tup[0]), maxlen)

    numwords = max(word2idx.values()) + 2
    numlabels = max(label2idx.values()) + 2

    # get training data
    traindata = getdatamatrix(train, maxlen, 0).astype("int32")
    traingold = getdatamatrix(train, maxlen, 2).astype("int32")
    trainmask = (traindata > 0).astype("float32")

    # test data
    testdata = getdatamatrix(test, maxlen, 0).astype("int32")
    testgold = getdatamatrix(test, maxlen, 2).astype("int32")
    testmask = (testdata > 0).astype("float32")

    # Evaluate the gold labels against themselves and print the result.
    res = atiseval(testgold-1, testgold-1, label2idxrev)
    print(res)

    # define model
    innerdim = [innerdim] * depth
    m = SimpleSeqEncDecAtt(
        inpvocsize=numwords,
        inpembdim=wordembdim,
        outvocsize=numlabels,
        outembdim=lablembdim,
        encdim=innerdim,
        decdim=innerdim,
        attdim=attdim,
        inconcat=False
    )

    # training
    m.train([traindata, shiftdata(traingold), trainmask], traingold).adagrad(lr=lr).grad_total_norm(1.).seq_cross_entropy().l2(wreg)\
        .validate_on([testdata, shiftdata(testgold), testmask], testgold).seq_cross_entropy().seq_accuracy().validinter(validinter)\
        .train(numbats, epochs)

    # predict after training
    s = SeqEncDecAttSearch(m)
    testpred = s.decode(testdata)
    testpred = testpred * testmask
    #testpredprobs = m.predict(testdata, shiftdata(testgold), testmask)
    #testpred = np.argmax(testpredprobs, axis=2)-1
    #testpred = testpred * testmask
    #print np.vectorize(lambda x: label2idxrev[x] if x > -1 else " ")(testpred)

    evalres = atiseval(testpred-1, testgold-1, label2idxrev)
    print(evalres)
Exemplo n.º 6
0
 def test_continued_training_encdec(self):
     """Train, save, reload and train again; errors must keep decreasing.

     Both training runs must produce strictly decreasing error sequences,
     and the first error after reloading must be below the last error of
     the first run.
     """
     m = SimpleSeqEncDecAtt(inpvocsize=20, outvocsize=20,
                            inpembdim=5, outembdim=5,
                            encdim=10, decdim=10)
     data = np.random.randint(0, 20, (50, 7))

     def fit(model):
         # Run one training session and return element 1 of its result.
         result = model.train(
             [data, data[:, :-1]],
             data[:, 1:]).cross_entropy().adadelta().train(
                 5, 10, returnerrors=True)
         return result[1]

     def check_strictly_decreasing(errs):
         for i in range(1, len(errs)):
             self.assertTrue(errs[i] < errs[i - 1])

     a = fit(m)
     print("\n".join(map(str, a)))
     check_strictly_decreasing(a)

     # Round-trip the model through disk before continuing training.
     m.get_params()
     m.save("/tmp/testmodelsave")
     m = m.load("/tmp/testmodelsave")

     b = fit(m)
     print("\n".join(map(str, b)))
     check_strictly_decreasing(b)
     self.assertTrue(b[0] < a[-1])
    def test_vector_out(self):
        """Compare prediction shapes without and with ``vecout=True``.

        Without ``vecout`` the last axis has size ``outvocsize`` (17); with
        ``vecout=True`` it has size ``decdim`` (110).
        """
        shared = dict(inpvocsize=19, outvocsize=17, outconcat=False,
                      decdim=110)

        encdec = SimpleSeqEncDecAtt(**shared)
        encdata = np.random.randint(0, 19, (1, 5))
        decdata = np.random.randint(0, 17, (1, 5))
        symbol_pred = encdec.predict(encdata, decdata)
        self.assertEqual(symbol_pred.shape, (1, 5, 17))

        encdec = SimpleSeqEncDecAtt(vecout=True, **shared)
        vector_pred = encdec.predict(encdata, decdata)
        print(vector_pred.shape)
        self.assertEqual(vector_pred.shape, (1, 5, 110))
Exemplo n.º 8
0
def run_seqdecatt(  # seems to work
    wreg=0.00001,
    epochs=50,
    numbats=50,
    lr=0.1,
    statedim=50,
    encdim=50,
    attdim=50,
    numwords=5000,
):
    """Train an encoder-decoder to reproduce lowercase words and decode samples.

    Words are taken from the keys of a ``Glove`` vocabulary, restricted to
    those matching ``^[a-z]+$``; the model is trained with the word matrix as
    both input and gold output, then a handful of test words are decoded.
    """
    # get words
    vocsize = 28
    lm = Glove(50, numwords)
    allwords = [w for w in lm.D.keys() if re.match("^[a-z]+$", w)]
    #embed()
    invwords = [word[::-1] for word in allwords]
    data = words2ints(allwords)
    idata = words2ints(invwords)
    startsym = 0

    golddata = data
    #golddata = idata

    print(data[:10])
    print(shiftdata(data, startsym)[:10])

    testwords = ["the", "alias", "mock", "test", "stalin", "allahuakbar",
                 "python", "pythonista"]
    testpred = words2ints(testwords)

    block = SimpleSeqEncDecAtt(
        inpvocsize=vocsize, outvocsize=vocsize,
        encdim=encdim, decdim=statedim, attdim=attdim,
        inconcat=False, bidir=False, statetrans=None)

    # Same fluent chain as before, split into named stages.
    trainer = block.train([data, shiftdata(golddata, startsym)], golddata)
    trainer = trainer.seq_cross_entropy().grad_total_norm(1.0)
    trainer = trainer.adagrad(lr=lr).l2(wreg)
    validated = trainer.split_validate(splits=5, random=True)
    validated = validated.seq_cross_entropy().seq_accuracy().validinter(2)
    validated.train(numbats=numbats, epochs=epochs)

    s = SeqEncDecSearch(block)
    pred, probs = s.decode(testpred, startsym, testpred.shape[1])
    print("%s %s" % (ints2words(pred), probs))
Exemplo n.º 9
0
    def test_seqdecatt(self, statedim=50, encdim=50, attdim=50, startsym=0):
        """Greedy-search an untrained model and check the context mask.

        The mask exposed by the search wrapper must equal ``testpred > 0``
        for the encoded test words.
        """
        # get words
        vocsize = 27

        words = ["the", "alias", "mock", "test", "stalin", "allahuakbar",
                 "python", "pythonista", " "]
        testpred = words2ints(words)
        print(testpred)

        block = SimpleSeqEncDecAtt(
            inpvocsize=vocsize, outvocsize=vocsize,
            encdim=encdim, decdim=statedim, attdim=attdim,
            inconcat=False, maskid=0)

        searcher = GreedySearch(block, startsymbol=startsym,
                                maxlen=testpred.shape[1])
        searcher.init(testpred, testpred.shape[0])

        ctxmask, ctx = searcher.wrapped.recpred.nonseqvals
        print(ctxmask)
        expected_mask = testpred > 0
        self.assertTrue(np.all(ctxmask == expected_mask))

        pred, probs = searcher.search(testpred.shape[0])
        print("%s %s" % (ints2words(pred), probs))
Exemplo n.º 10
0
 def test_set_lr(self):
     """Check ``lrmul`` on encoder, decoder and attention parameters.

     After ``dec.set_lr(0.1)`` and ``dec.attention.set_lr(0.5)``, encoder
     parameters must keep ``lrmul == 1.0``, attention parameters must have
     0.5 and all other decoder parameters 0.1.
     """
     attdist = LinearDistance(110, 110, 100)
     encdec = SimpleSeqEncDecAtt(inpvocsize=19, outvocsize=17,
                                 outconcat=False,
                                 encdim=110, decdim=110,
                                 attdist=attdist)
     decoder = encdec.dec
     decoder.set_lr(0.1)
     decoder.attention.set_lr(0.5)  # TODO

     encdata = np.random.randint(0, 19, (2, 5))
     decdata = np.random.randint(0, 17, (2, 5))
     o = encdec(Val(encdata), Val(decdata))
     #print "\n".join(["{}: {}".format(x, x.lrmul) for x in o.allparams])

     encparams = encdec.enc.get_params()
     decparams = decoder.get_params()
     attparams = decoder.attention.get_params()

     def describe(params):
         # One "param: lrmul" line per parameter, plus a trailing newline.
         lines = ["{}: {}".format(x, x.lrmul) for x in params]
         return "\n".join(lines) + "\n"

     print(describe(encparams))
     print(describe(decparams))

     for x in encparams:
         self.assertEqual(x.lrmul, 1.0)
     for x in decparams:
         expected_lrmul = 0.5 if x in attparams else 0.1
         self.assertEqual(x.lrmul, expected_lrmul)
Exemplo n.º 11
0
def run_seqdecatt(  # seems to work
        wreg=0.00001,
        epochs=50,
        numbats=50,
        lr=0.1,
        statedim=50,
        encdim=50,
        attdim=50,
        numwords=5000,
    ):
    """Fit an encoder-decoder on lowercase vocabulary words, then decode a few.

    The training input and the gold output are the same word matrix; the
    decoder input is that matrix passed through ``shiftdata``.
    """
    word_pattern = "^[a-z]+$"

    # get words
    vocsize = 28
    lm = Glove(50, numwords)
    allwords = [key for key in lm.D.keys() if re.match(word_pattern, key)]
    #embed()
    invwords = [word[::-1] for word in allwords]
    data = words2ints(allwords)
    idata = words2ints(invwords)
    startsym = 0

    golddata = data
    #golddata = idata

    print(data[:10])
    print(shiftdata(data, startsym)[:10])

    testpred = words2ints([
        "the", "alias", "mock", "test", "stalin", "allahuakbar", "python",
        "pythonista",
    ])

    model_kwargs = dict(inpvocsize=vocsize, outvocsize=vocsize,
                        encdim=encdim, decdim=statedim, attdim=attdim,
                        inconcat=False, bidir=False, statetrans=None)
    block = SimpleSeqEncDecAtt(**model_kwargs)

    # Training configuration, then validation configuration, then the run.
    training = block.train([data, shiftdata(golddata, startsym)], golddata)
    training = (training.seq_cross_entropy()
                        .grad_total_norm(1.0)
                        .adagrad(lr=lr)
                        .l2(wreg))
    validation = (training.split_validate(splits=5, random=True)
                          .seq_cross_entropy()
                          .seq_accuracy()
                          .validinter(2))
    validation.train(numbats=numbats, epochs=epochs)

    s = SeqEncDecSearch(block)
    pred, probs = s.decode(testpred, startsym, testpred.shape[1])
    print("%s %s" % (ints2words(pred), probs))
Exemplo n.º 12
0
    def test_encdec_attention_output_extra(self):
        """Request extra outputs from predict and sanity-check the attention.

        Checks the prediction and attention-weight shapes, and that the
        attention weights of the first example sum to one along axis 0 of
        the transposed slice.
        """
        m = SimpleSeqEncDecAtt()
        xdata = np.random.randint(0, 400, (50, 13))
        ydata = np.random.randint(0, 100, (50, 7))

        predict_with_extras = m.predict.return_extra_outs(
            ["attention_weights", "i_t"])
        pred, extra = predict_with_extras(xdata, ydata)

        self.assertEqual(pred.shape, (50, 7, 100))
        weights = extra["attention_weights"]
        self.assertEqual(weights.shape, (7, 50, 13))

        #AttentionPlotter.plot(extra["attention_weights"][:, 0, :].T)
        attw = weights[:, 0, :].T
        colsums = np.sum(attw, axis=0)
        print(colsums)
        print(np.sum(attw, axis=1))
        self.assertTrue(np.allclose(np.ones_like(colsums), colsums))
Exemplo n.º 13
0
    def test_seqdecatt(  # seems to work
        wreg=0.00001,  # TODO: regularization other than 0.0001 first stagnates, then goes down
        epochs=50,
        numbats=20,
        lr=0.1,
        statedim=50,
        encdim=50,
        attdim=50,
        startsym=0,
    ):
        """Decode a few test words with an untrained encoder-decoder.

        NOTE(review): this is indented like a method but takes no ``self``;
        if it is bound to a class, the instance would arrive as ``wreg``.
        Confirm against the enclosing scope.
        """
        # get words
        vocsize = 27
        embdim = 50
        lm = Glove(embdim, 2000)
        allwords = [w for w in lm.D.keys() if re.match("^[a-z]+$", w)]

        # The first 1000 words are kept apart from the remaining ones.
        vwords, words = allwords[:1000], allwords[1000:]
        data = words2ints(words)
        sdata = shiftdata(data)
        vdata = words2ints(vwords)
        svdata = shiftdata(vdata)
        testneglogprob = 17

        # The last entry is a string of spaces as wide as the data matrix.
        blank = " " * (data.shape[1])
        testpred = words2ints([
            "the", "alias", "mock", "test", "stalin", "allahuakbar",
            "python", "pythonista", blank,
        ])
        print(testpred)

        block = SimpleSeqEncDecAtt(
            inpvocsize=vocsize, outvocsize=vocsize,
            encdim=encdim, decdim=statedim, attdim=attdim,
            inconcat=False)

        s = SeqEncDecSearch(block)
        pred, probs = s.decode(testpred, startsym, testpred.shape[1])
        print("%s %s" % (ints2words(pred), probs))
    def test_vector_out(self):
        """Prediction's last axis is 17 by default and 110 with ``vecout=True``."""
        def build(**extra):
            # Model with the settings shared by both halves of the test.
            return SimpleSeqEncDecAtt(inpvocsize=19, outvocsize=17,
                                      outconcat=False, decdim=110, **extra)

        encdec = build()
        encdata = np.random.randint(0, 19, (1, 5))
        decdata = np.random.randint(0, 17, (1, 5))
        pred = encdec.predict(encdata, decdata)
        self.assertEqual(pred.shape, (1, 5, 17))

        encdec = build(vecout=True)
        pred = encdec.predict(encdata, decdata)
        print(pred.shape)
        self.assertEqual(pred.shape, (1, 5, 110))
Exemplo n.º 15
0
def run(
        numbats=50,
        epochs=10,
        lr=0.5,
        embdim=50,
        encdim=400,
        dropout=0.2,
        layers=1,
        inconcat=True,
        outconcat=True,
        posemb=False,
        customemb=False,
        preproc="none",  # "none" or "generate" or "abstract" or "gensample"
        bidir=False,
        corruptnoise=0.0,
        inspectdata=False,
        relinearize="none",
        pretrain=False,
        pretrainepochs=-1,
        pretrainnumbats=-1,
        pretrainlr=-0.1,
        loadpretrained="none",
        pretrainfixdecoder=False,
        wreg=0.0,
        testmode=False):
    """Load the data, build a ``SimpleSeqEncDecAtt``, optionally pretrain, then train.

    The function ends by dropping into ``embed()`` with the local ``play``
    helper available for inspecting predictions on the test split.

    :param numbats: first argument of the main ``.train(...)`` call; also used
        to derive the pretraining batch count when ``pretrainnumbats < 0``.
    :param epochs: second argument of the main ``.train(...)`` call; default
        for ``pretrainepochs``.
    :param lr: learning rate given to ``adadelta``; default for ``pretrainlr``.
    :param embdim: ``dim`` of the word embeddings and ``innerdim`` of the
        ``SoftMaxOut`` layer.
    :param encdim: size of each decoder layer; encoder layers use
        ``encdim / 2`` when ``bidir`` is set.
    :param dropout: forwarded to ``SoftMaxOut``, the model and ``remake_encoder``.
    :param layers: number of encoder and decoder layers.
    :param inconcat: forwarded to the model.
    :param outconcat: forwarded to the model.
    :param posemb: when true, positional indexes are added to the data
        matrices and embeddings are wrapped in ``VectorPosEmb``.
    :param customemb: forwarded to ``loadgeo``; also enables ``do_custom_emb``
        and ``smo.setlin2``.
    :param preproc: ``"none"`` skips ``preprocess``; ``"generate"`` also runs
        ``generate``; ``"gensample"`` builds ``typdic`` via ``gentypdic``.
    :param bidir: forwarded to the model and ``remake_encoder``.
    :param corruptnoise: ``p`` of ``RandomCorrupt`` in the main training.
    :param inspectdata: when true, calls ``embed()`` at several checkpoints.
    :param relinearize: ``"none"``, ``"greedy"`` or ``"deep"``; any other value
        raises ``Exception("unknown linearization")``.
    :param pretrain: when true, runs a pretraining phase on the data returned
        by ``loadgeoauto`` and saves the model.
    :param pretrainepochs: pretraining epochs; ``-1`` means use ``epochs``.
    :param pretrainnumbats: pretraining batch count; negative means compute it
        from the main batch size.
    :param pretrainlr: pretraining learning rate; negative means use ``lr``.
    :param loadpretrained: when not ``"none"``, prefix of a
        ``<prefix>.pre.sp.model`` file to load.
    :param pretrainfixdecoder: when true, the encoder is rebuilt with
        ``remake_encoder`` and the decoder learning rate is set to 0.0 after
        pretraining.
    :param wreg: passed to ``.l2(...)``.
    :param testmode: when true, pretraining uses the first 100 rows and
        prints the L1 norm of each parameter change.
    """

    #TODO: bi-encoder and other beasts
    #TODO: make sure gensample results NOT IN test data

    if pretrain == True:
        assert (preproc == "none" or preproc == "gensample")
        pretrainepochs = epochs if pretrainepochs == -1 else pretrainepochs

    ######### DATA LOADING AND TRANSFORMATIONS ###########
    # Optional transformer handed to the data loaders; stays None for "none".
    srctransformer = None
    if relinearize != "none":
        lambdaparser = LambdaParser()
        if relinearize == "greedy":

            def srctransformer(x):
                return lambdaparser.parse(x).greedy_linearize(deeppref=True)
        elif relinearize == "deep":

            def srctransformer(x):
                return lambdaparser.parse(x).deep_linearize()
        else:
            raise Exception("unknown linearization")

    adic = {}
    if pretrain or loadpretrained != "none":  ### PRETRAIN DATA LOAD ###
        qmat_auto, amat_auto, qdic_auto, adic, qwc_auto, awc_auto = \
            loadgeoauto(reverse=True, transformer=srctransformer)

        # Interactive helper: print pretraining example i as strings.
        def pp(i):
            print wordids2string(qmat_auto[i],
                                 {v: k
                                  for k, v in qdic_auto.items()}, 0)
            print wordids2string(amat_auto[i], {v: k
                                                for k, v in adic.items()}, 0)

        if inspectdata:
            print "pretrain inspect"
            embed()
    # adic (possibly filled by loadgeoauto above) is passed in and returned.
    qmat, amat, qdic, adic, qwc, awc = loadgeo(customemb=customemb,
                                               reverse=True,
                                               transformer=srctransformer,
                                               adic=adic)

    maskid = 0
    typdic = None
    # Copies of the matrices taken before any preprocessing.
    oqmat = qmat.copy()
    oamat = amat.copy()
    print "{} is preproc".format(preproc)
    if preproc != "none":
        qmat, amat, qdic, adic, qwc, awc = preprocess(
            qmat,
            amat,
            qdic,
            adic,
            qwc,
            awc,
            maskid,
            qreversed=True,
            dorare=preproc != "generate")
        if preproc == "generate":  # alters size
            print "generating"
            qmat, amat = generate(qmat,
                                  amat,
                                  qdic,
                                  adic,
                                  oqmat,
                                  oamat,
                                  reversed=True)
            #embed()
        elif preproc == "gensample":
            typdic = gentypdic(qdic, adic)

    ######### train/test split from here #########
    qmat_t, qmat_x = split_train_test(qmat)
    amat_t, amat_x = split_train_test(amat)
    oqmat_t, oqmat_x = split_train_test(oqmat)
    oamat_t, oamat_x = split_train_test(oamat)

    qoverlap, aoverlap, overlap = compute_overlap(qmat_t, amat_t, qmat_x,
                                                  amat_x)
    print "overlaps: {}, {}: {}".format(len(qoverlap), len(aoverlap),
                                        len(overlap))

    if inspectdata:
        embed()

    np.random.seed(12345)

    # NOTE(review): encdim / 2 relies on integer division to keep the layer
    # size an int; under true division this would be a float -- confirm.
    encdimi = [encdim / 2 if bidir else encdim] * layers
    decdimi = [encdim] * layers

    # The "i" variants keep references to the answer matrices as they are
    # before add_pos_indexes below rebinds amat_t / amat_x.
    amati_t, amati_x = amat_t, amat_x
    oamati_t, oamati_x = oamat_t, oamat_x
    if pretrain:
        amati_auto = amat_auto

    if posemb:  # add positional indexes to datamatrices
        qmat_t, oqmat_t, amat_t, oamat_t = add_pos_indexes(
            qmat_t, oqmat_t, amat_t, oamat_t)
        qmat_x, oqmat_x, amat_x, oamat_x = add_pos_indexes(
            qmat_x, oqmat_x, amat_x, oamat_x)

    # With "gensample", the test split uses the unpreprocessed matrices.
    if preproc == "gensample":
        qmat_x, amat_x, amati_x = oqmat_x, oamat_x, oamati_x

    rqdic = {v: k for k, v in qdic.items()}
    radic = {v: k for k, v in adic.items()}

    # Interactive helpers: print training / test example i as strings.
    def tpp(i):
        print wordids2string(qmat_t[i], rqdic, 0)
        print wordids2string(amat_t[i], radic, 0)

    def xpp(i):
        print wordids2string(qmat_x[i], rqdic, 0)
        print wordids2string(amat_x[i], radic, 0)

    if inspectdata:
        embed()
    print "{} training examples".format(qmat_t.shape[0])

    ################## MODEL DEFINITION ##################
    # encdec prerequisites
    inpemb = WordEmb(worddic=qdic, maskid=maskid, dim=embdim)
    outemb = WordEmb(worddic=adic, maskid=maskid, dim=embdim)

    if pretrain == True:
        inpemb_auto = WordEmb(worddic=qdic_auto, maskid=maskid, dim=embdim)
        #outemb = WordEmb(worddic=adic, maskid=maskid, dim=embdim)

    if customemb:
        inpemb, outemb = do_custom_emb(inpemb, outemb, awc, embdim)
        if pretrain:
            inpemb_auto, outemb = do_custom_emb(inpemb_auto, outemb, awc_auto,
                                                embdim)

    if posemb:  # use custom emb layers, with positional embeddings
        posembdim = 50
        inpemb = VectorPosEmb(inpemb, qmat_t.shape[1], posembdim)
        outemb = VectorPosEmb(outemb, amat_t.shape[1], posembdim)
        if pretrain:
            inpemb_auto = VectorPosEmb(inpemb_auto, qmat_auto.shape[1],
                                       posembdim)
            # Sized by the wider of the pretraining and training answer matrices.
            outemb = VectorPosEmb(outemb,
                                  max(amat_auto.shape[1], amat_t.shape[1]),
                                  posembdim)

    smodim = embdim
    smo = SoftMaxOut(indim=encdim + encdim,
                     innerdim=smodim,
                     outvocsize=len(adic) + 1,
                     dropout=dropout)

    if customemb:
        smo.setlin2(outemb.baseemb.W.T)

    # encdec model
    encdec = SimpleSeqEncDecAtt(
        inpvocsize=max(qdic.values()) + 1,
        inpembdim=embdim,
        inpemb=inpemb,
        outvocsize=max(adic.values()) + 1,
        outembdim=embdim,
        outemb=outemb,
        encdim=encdimi,
        decdim=decdimi,
        maskid=maskid,
        statetrans=True,
        dropout=dropout,
        inconcat=inconcat,
        outconcat=outconcat,
        rnu=GRU,
        vecout=smo,
        bidir=bidir,
    )

    ################## TRAINING ##################
    if pretrain == True or loadpretrained != "none":
        # Swap in an encoder (or just its embedder) built on the pretraining
        # vocabulary before pretraining from scratch.
        if pretrain == True and loadpretrained == "none":
            if pretrainfixdecoder:
                encdec.remake_encoder(inpvocsize=max(qdic_auto.values()) + 1,
                                      inpembdim=embdim,
                                      inpemb=inpemb_auto,
                                      maskid=maskid,
                                      bidir=bidir,
                                      dropout_h=dropout,
                                      dropout_in=dropout)
            else:
                encdec.enc.embedder = inpemb_auto
        if loadpretrained != "none":
            encdec = encdec.load(loadpretrained + ".pre.sp.model")
            print "MODEL LOADED: {}".format(loadpretrained)
        if pretrain == True:
            if pretrainnumbats < 0:
                # Derive the batch count so that pretraining batches have the
                # same size as the main training batches.
                import math
                batsize = int(math.ceil(qmat_t.shape[0] * 1.0 / numbats))
                pretrainnumbats = int(
                    math.ceil(qmat_auto.shape[0] * 1.0 / batsize))
                print "{} batches".format(pretrainnumbats)
            if pretrainlr < 0:
                pretrainlr = lr
            if testmode:
                # Remember parameter values and shrink the pretraining data.
                oldparamvals = {p: p.v for p in encdec.get_params()}
                qmat_auto = qmat_auto[:100]
                amat_auto = amat_auto[:100]
                amati_auto = amati_auto[:100]
                pretrainnumbats = 10
            #embed()
            encdec.train([qmat_auto, amat_auto[:, :-1]], amati_auto[:, 1:])\
                .cross_entropy().adadelta(lr=pretrainlr).grad_total_norm(5.) \
                .l2(wreg).exp_mov_avg(0.95) \
                .split_validate(splits=10, random=True).cross_entropy().seq_accuracy() \
                .train(pretrainnumbats, pretrainepochs)

            if testmode:
                # Print how far each parameter moved during pretraining.
                for p in encdec.get_params():
                    print np.linalg.norm(p.v - oldparamvals[p], ord=1)
            # The save path uses a random four-digit number as its prefix.
            savepath = "{}.pre.sp.model".format(random.randint(1000, 9999))
            print "PRETRAIN SAVEPATH: {}".format(savepath)
            encdec.save(savepath)

        # NaN somewhere at 75% in training, in one of RNU's? --> with rmsprop
        # Switch back to the main-task input vocabulary for the main training.
        if pretrainfixdecoder:
            encdec.remake_encoder(inpvocsize=max(qdic.values()) + 1,
                                  inpembdim=embdim,
                                  inpemb=inpemb,
                                  bidir=bidir,
                                  maskid=maskid,
                                  dropout_h=dropout,
                                  dropout_in=dropout)
            encdec.dec.set_lr(0.0)
        else:
            encdec.dec.embedder.set_lr(0.0)
            encdec.enc.embedder = inpemb

    # NOTE(review): training feeds amat_t as decoder input and amati_t as gold,
    # while validate_on (and play() below) feed amati_x as input and amat_x as
    # gold. The two only differ when posemb is set or preproc == "gensample";
    # confirm the swap is intentional.
    encdec.train([qmat_t, amat_t[:, :-1]], amati_t[:, 1:])\
        .sampletransform(GenSample(typdic),
                         RandomCorrupt(corruptdecoder=(2, max(adic.values()) + 1),
                                       corruptencoder=(2, max(qdic.values()) + 1),
                                       maskid=maskid, p=corruptnoise))\
        .cross_entropy().adadelta(lr=lr).grad_total_norm(5.) \
        .l2(wreg).exp_mov_avg(0.8) \
        .validate_on([qmat_x, amati_x[:, :-1]], amat_x[:, 1:]) \
        .cross_entropy().seq_accuracy()\
        .train(numbats, epochs)
    #.split_validate(splits=10, random=True)\

    qrwd = {v: k for k, v in qdic.items()}
    arwd = {v: k for k, v in adic.items()}

    # Interactive helper. play(i) prints question, gold answer and predicted
    # answer for test example i and returns True (or returns False without
    # printing when hidecorrect is set and the prediction starts with the gold
    # string). play() walks over all test examples, pausing after each one
    # that was printed.
    def play(*x, **kw):
        hidecorrect = False
        if "hidecorrect" in kw:
            hidecorrect = kw["hidecorrect"]
        if len(x) == 1:
            x = x[0]
            q = wordids2string(qmat_x[x],
                               rwd=qrwd,
                               maskid=maskid,
                               reverse=True)
            ga = wordids2string(amat_x[x, 1:], rwd=arwd, maskid=maskid)
            pred = encdec.predict(qmat_x[x:x + 1], amati_x[x:x + 1, :-1])
            pa = wordids2string(np.argmax(pred[0], axis=1),
                                rwd=arwd,
                                maskid=maskid)
            if hidecorrect and ga == pa[:len(ga)]:  # correct
                return False
            else:
                print "{}: {}".format(x, q)
                print ga
                print pa
                return True
        elif len(x) == 0:
            for i in range(0, qmat_x.shape[0]):
                r = play(i)
                if r:
                    raw_input()
        else:
            raise Exception("invalid argument to play")

    embed()
Exemplo n.º 16
0
    def test_two_phase_training(self):
        """Train, freeze the decoder, swap the encoder, then train again.

        Phase one must change every decoder parameter. After
        ``dec.set_lr(0.0)`` and ``remake_encoder``, phase two must leave the
        decoder parameters unchanged while every new encoder parameter moves.
        """
        def snapshot(params):
            # Map each parameter to its current value.
            return dict(zip(params, [x.v for x in params]))

        def fit(inp, out):
            # One short training run.
            encdec.train([inp, out[:, :-1]],
                         out[:, 1:]).cross_entropy().rmsprop(
                             lr=0.001).train(1, 5)

        encdec = SimpleSeqEncDecAtt(inpvocsize=19, inpembdim=50,
                                    outvocsize=17, outembdim=40,
                                    outconcat=False,
                                    encdim=110, decdim=110,
                                    statetrans=True)
        originaldecparams = encdec.dec.get_params()
        originalencparams = encdec.enc.get_params()
        originaldecparamvals = snapshot(originaldecparams)
        for param in originaldecparams:
            self.assertEqual(param.lrmul, 1.0)
        for param in originalencparams:
            self.assertEqual(param.lrmul, 1.0)

        # ---- phase one: everything trainable ----
        inpseq = np.random.randint(0, 19, (10, 20))
        outseq = np.random.randint(0, 17, (10, 15))
        fit(inpseq, outseq)

        traineddecparamvals = snapshot(originaldecparams)
        for k in originaldecparamvals:
            before, after = originaldecparamvals[k], traineddecparamvals[k]
            self.assertTrue(not np.allclose(before, after))
            print("{} {}".format(k, np.linalg.norm(before - after)))

        # ---- phase two: frozen decoder, fresh encoder ----
        encdec.dec.set_lr(0.0)
        encdec.remake_encoder(inpvocsize=21, inpembdim=60, innerdim=110)
        for param in originaldecparams:
            self.assertEqual(param.lrmul, 0.0)
        newencparams = encdec.enc.get_params()
        # None of the new encoder parameters may be one of the old ones.
        self.assertEqual(newencparams.difference(originalencparams),
                         newencparams)
        originalnewencparamvals = snapshot(newencparams)

        inpseq = np.random.randint(0, 21, (10, 16))
        outseq = np.random.randint(0, 17, (10, 14))
        fit(inpseq, outseq)

        trainednewencparamvals = snapshot(newencparams)
        newdecparamvals = snapshot(originaldecparams)
        print("\n")
        for k in originaldecparams:
            self.assertTrue(
                np.allclose(traineddecparamvals[k], newdecparamvals[k]))
            delta = newdecparamvals[k] - traineddecparamvals[k]
            print("{} {}".format(k, np.linalg.norm(delta)))
        print("\n")
        for k in newencparams:
            after, before = (trainednewencparamvals[k],
                             originalnewencparamvals[k])
            self.assertTrue(not np.allclose(after, before))
            print("{} {}".format(k, np.linalg.norm(after - before)))
Exemplo n.º 17
0
def run(p="../../../data/atis/atis.pkl",
        wordembdim=70,
        lablembdim=70,
        innerdim=300,
        lr=0.01,
        numbats=100,
        epochs=20,
        validinter=1,
        wreg=0.0001,
        depth=1,
        attdim=300):
    """Train an attention encoder-decoder on the ATIS pickle and print test scores.

    The pickle at ``p`` must unpack to ``(train, test, dics)``, with ``dics``
    providing the ``"words2idx"`` and ``"labels2idx"`` vocabularies.  Nothing
    is returned; results are printed.

    :param p: path of the pickled dataset (only load files you trust).
    :param wordembdim: passed to the model as ``inpembdim``.
    :param lablembdim: passed to the model as ``outembdim``.
    :param innerdim: width of every encoder and decoder layer.
    :param lr: learning rate handed to ``adagrad``.
    :param numbats: first argument of the final ``train`` call.
    :param epochs: second argument of the final ``train`` call.
    :param validinter: value handed to ``validinter`` on the trainer.
    :param wreg: value handed to ``l2`` on the trainer.
    :param depth: number of layers; ``innerdim`` is repeated this many times.
    :param attdim: passed to the model as ``attdim``.
    """
    # SECURITY: unpickling executes arbitrary code -- only load trusted files.
    # Binary mode is what pickle expects, and the with-block guarantees the
    # handle is closed (the previous bare open() leaked it).
    with open(p, "rb") as datafile:
        train, test, dics = pickle.load(datafile)
    word2idx = dics["words2idx"]
    label2idx = dics["labels2idx"]
    label2idxrev = {v: k for k, v in label2idx.items()}
    # Transpose to one tuple per example; list() keeps len() and "+" below
    # working even where zip() is lazy.
    train = list(zip(*train))
    test = list(zip(*test))
    print("%d training examples, %d test examples" % (len(train), len(test)))

    # Longest first field over both splits (0 when there are no examples).
    maxlen = max([len(tup[0]) for tup in train + test] or [0])

    # NOTE(review): the +2 presumably leaves room for ids shifted by one so
    # that 0 can act as padding (masks are `data > 0`, evaluation subtracts
    # 1) -- confirm against getdatamatrix.
    numwords = max(word2idx.values()) + 2
    numlabels = max(label2idx.values()) + 2

    # get training data
    traindata = getdatamatrix(train, maxlen, 0).astype("int32")
    traingold = getdatamatrix(train, maxlen, 2).astype("int32")
    trainmask = (traindata > 0).astype("float32")

    # test data
    testdata = getdatamatrix(test, maxlen, 0).astype("int32")
    testgold = getdatamatrix(test, maxlen, 2).astype("int32")
    testmask = (testdata > 0).astype("float32")

    # Evaluate gold against itself first and print the result.
    res = atiseval(testgold - 1, testgold - 1, label2idxrev)
    print(res)

    # define model: the same layer widths are used for encoder and decoder
    layerdims = [innerdim] * depth
    m = SimpleSeqEncDecAtt(inpvocsize=numwords,
                           inpembdim=wordembdim,
                           outvocsize=numlabels,
                           outembdim=lablembdim,
                           encdim=layerdims,
                           decdim=layerdims,
                           attdim=attdim,
                           inconcat=False)

    # training (the order of the chained calls is kept exactly as before)
    (m.train([traindata, shiftdata(traingold), trainmask], traingold)
        .adagrad(lr=lr).grad_total_norm(1.).seq_cross_entropy().l2(wreg)
        .validate_on([testdata, shiftdata(testgold), testmask], testgold)
        .seq_cross_entropy().seq_accuracy().validinter(validinter)
        .train(numbats, epochs))

    # predict after training
    # An earlier, disabled variant called
    # m.predict(testdata, shiftdata(testgold), testmask) and took
    # np.argmax(..., axis=2) - 1 instead of running the search below.
    s = SeqEncDecAttSearch(m)
    testpred = s.decode(testdata)
    testpred = testpred * testmask  # zero out predictions where testdata is 0

    evalres = atiseval(testpred - 1, testgold - 1, label2idxrev)
    print(evalres)
Exemplo n.º 18
0
def run(
    epochs=50,
    mode="char",  # "char" or "word" or "charword"
    numbats=1000,
    lr=0.1,
    wreg=0.000001,
    bidir=False,
    layers=1,
    encdim=200,
    decdim=200,
    embdim=100,
    negrate=1,
    margin=1.,
    hingeloss=False,
    debug=False,
    preeval=False,
    sumhingeloss=False,
    checkdata=False,  # starts interactive shell for data inspection
    printpreds=False,
    subjpred=False,
    predpred=False,
    specemb=-1,
    balancednegidx=False,
    usetypes=False,
    evalsplits=50,
    relembrep=False,
):
    """Train a sequence scorer with negative sampling, rank-evaluate it, save scores.

    Steps, in order: load data through ``readdata``; build an entity encoder
    (``SimpleSeq2Vec`` wrapped in ``EntEnc`` and optionally ``EntEmbEnc`` /
    ``EntEncRep``); build a ``SimpleSeqEncDecAtt`` and a ``SeqMatchScore``
    around both; optionally evaluate before training; train; evaluate with
    ``SeqEncDecRankSearch`` + ``FullRankEval``; write the evaluation results
    to ``<script name>.results/<i>.res`` using the first unused ``i`` below
    1000.

    :param epochs: ``epochs`` argument of the training call.
    :param mode: forwarded to ``readdata``; also picks the join character of
        the ``p`` helper that is defined when ``checkdata`` is set.
    :param numbats: ``numbats`` argument of the training call.
    :param lr: learning rate handed to ``adagrad``.
    :param wreg: value handed to ``l2``.
    :param bidir: forwarded to both encoders; when set, ``encdim`` and
        ``decdim`` are divided by 2 to obtain the per-layer sizes.
    :param layers: number of encoder layers and of entity-encoder layers.
    :param encdim: encoder layer size (before the ``bidir`` halving).
    :param decdim: size used for the entity encoder's ``innerdim`` and for
        the model's ``decdim`` (before the ``bidir`` halving).
    :param embdim: embedding size for the model input and the entity encoder.
    :param negrate: value handed to ``negrate`` on the trainer.
    :param margin: margin added inside the hinge objectives.
    :param hingeloss: use ``(n - p + margin)`` clipped at 0 as objective.
    :param debug: overrides several settings (see the first block of the
        body) and is forwarded to ``readdata`` and to the pre-evaluation.
    :param preeval: also run the evaluation once before training.
    :param sumhingeloss: use the hinge objective summed over axis 1 and
        disable score aggregation in the scorer; takes precedence over
        ``hingeloss``.
    :param checkdata: call ``embed()`` after loading, with lookup helpers
        in scope.
    :param printpreds: forwarded as ``debug`` to ``decode``; also prints the
        final predictions.
    :param subjpred: keep only column 0 of the gold matrices (unless
        ``predpred`` is also set).
    :param predpred: keep only column 1 of the gold matrices (unless
        ``subjpred`` is also set).
    :param specemb: when > 0, wrap the entity encoder in ``EntEmbEnc`` with
        this value as last argument.
    :param balancednegidx: when set (and a split point is available), draw
        negatives from below/above the split with equal probability.
    :param usetypes: forwarded to ``readdata``.
    :param evalsplits: forwarded as ``split`` to ``decode``.
    :param relembrep: wrap the entity encoder in ``EntEncRep``.
    :raises Exception: when result files 0..999 all exist already.
    """
    if debug:  # debug settings
        sumhingeloss = True
        numbats = 10
        lr = 0.02
        epochs = 10
        printpreds = True
        # NOTE(review): with whatpred fixed to "all" neither branch below fires.
        whatpred = "all"
        if whatpred == "pred":
            predpred = True
        elif whatpred == "subj":
            subjpred = True
        #preeval = True
        specemb = 100
        margin = 1.
        balancednegidx = True
        evalsplits = 1
        relembrep = True
        #usetypes=True
        #mode = "charword"
        #checkdata = True
    # load the right file
    maskid = -1
    tt = ticktock("script")
    specids = specemb > 0
    tt.tick()
    # NOTE(review): readdata is called with a literal specids=True, not with the
    # `specids` flag computed above -- confirm this is intended.
    (traindata, traingold), (validdata, validgold), (testdata, testgold), \
    worddic, entdic, entmat, relstarts, canids, wordmat, chardic\
        = readdata(mode, testcans="testcans.pkl", debug=debug, specids=True,
                   usetypes=usetypes, maskid=maskid)
    entmat = entmat.astype("int32")

    #embed()

    # Optionally restrict the gold matrices to a single column (kept 2-D).
    if subjpred is True and predpred is False:
        traingold = traingold[:, [0]]
        validgold = validgold[:, [0]]
        testgold = testgold[:, [0]]
    if predpred is True and subjpred is False:
        traingold = traingold[:, [1]]
        validgold = validgold[:, [1]]
        testgold = testgold[:, [1]]

    if checkdata:
        # Reverse dictionaries and a pretty-printer for use inside the shell.
        rwd = {v: k for k, v in worddic.items()}
        red = {v: k for k, v in entdic.items()}

        def p(xids):
            # Join the words of the ids; ids of -1 or below become "".
            return (" " if mode == "word" else "").join(
                [rwd[xid] if xid > -1 else "" for xid in xids])

        embed()

    print traindata.shape, traingold.shape, testdata.shape, testgold.shape

    tt.tock("data loaded")

    # *data: matrix of word ids (-1 filler), example per row
    # *gold: vector of true entity ids
    # entmat: matrix of word ids (-1 filler), entity label per row, indexes according to *gold
    # *dic: from word/ent-fbid to integer id, as used in data

    numwords = max(worddic.values()) + 1
    numents = max(entdic.values()) + 1
    print "%d words, %d entities" % (numwords, numents)

    # Per-layer encoder sizes; halved when bidirectional.
    if bidir:
        encinnerdim = [encdim / 2] * layers
    else:
        encinnerdim = [encdim] * layers

    # The entity ("memory") encoder reuses the main embedding/layer settings.
    memembdim = embdim
    memlayers = layers
    membidir = bidir
    if membidir:
        decinnerdim = [decdim / 2] * memlayers
    else:
        decinnerdim = [decdim] * memlayers

    entenc = EntEnc(
        SimpleSeq2Vec(indim=numwords,
                      inpembdim=memembdim,
                      innerdim=decinnerdim,
                      maskid=maskid,
                      bidir=membidir))

    # Number of distinct values in the first column of entmat.
    numentembs = len(np.unique(entmat[:, 0]))
    if specids:  # include vectorembedder
        entenc = EntEmbEnc(entenc, numentembs, specemb)
    if relembrep:
        repsplit = entmat[relstarts, 0]
        entenc = EntEncRep(entenc, numentembs, repsplit)

        # adjust params for enc/dec construction
        #encinnerdim[-1] += specemb
        #innerdim[-1] += specemb

    # The entity encoder doubles as the decoder's output embedder.
    encdec = SimpleSeqEncDecAtt(inpvocsize=numwords,
                                inpembdim=embdim,
                                encdim=encinnerdim,
                                bidir=bidir,
                                outembdim=entenc,
                                decdim=decinnerdim,
                                vecout=True,
                                statetrans="matdot")

    # (positional args, keyword args) for SeqMatchScore, assembled up front so
    # the aggregator can be overridden conditionally.
    scorerargs = ([encdec, SeqUnroll(entenc)], {
        "argproc": lambda x, y, z: ((x, y), (z, )),
        "scorer": GenDotDistance(decinnerdim[-1], entenc.outdim)
    })
    if sumhingeloss:
        scorerargs[1]["aggregator"] = lambda x: x  # no aggregation of scores
    scorer = SeqMatchScore(*scorerargs[0], **scorerargs[1])

    #scorer.save("scorer.test.save")

    # TODO: below this line, check and test
    class PreProc(object):
        # Replaces the two decoder-side id arguments by rows of entmat and
        # passes the encoder data through unchanged.
        def __init__(self, entmat):
            self.f = PreProcE(entmat)

        def __call__(self, encdata, decsg,
                     decgold):  # gold: idx^(batsize, seqlen)
            return (encdata, self.f(decsg), self.f(decgold)), {}

    class PreProcE(object):
        # Indexes the wrapped entity matrix with the given ids.
        def __init__(self, entmat):
            self.em = Val(entmat)

        def __call__(self, x):
            return self.em[x]

    transf = PreProc(entmat)

    class NegIdxGen(object):
        # Callable producing random replacement ids with the shape of `gold`.
        def __init__(self, rng, midsplit=None):
            self.min = 0
            self.max = rng  # exclusive upper bound of the sampled ids
            self.midsplit = midsplit

        def __call__(
            self, datas, sgold, gold
        ):  # the whole target sequence is corrupted, corruption targets the whole set of entities and relations together
            if self.midsplit is None or not balancednegidx:
                # Uniform over [min, max).
                return datas, sgold, np.random.randint(
                    self.min, self.max, gold.shape).astype("int32")
            else:
                # Draw one id below and one at/above midsplit per position and
                # pick between them with a random 0/1 mask.
                entrand = np.random.randint(self.min, self.midsplit,
                                            gold.shape)
                relrand = np.random.randint(self.midsplit, self.max,
                                            gold.shape)
                mask = np.random.randint(0, 2, gold.shape)
                ret = entrand * mask + relrand * (1 - mask)
                return datas, sgold, ret.astype("int32")

    # Objective on (positive score p, negative score n); later assignments win.
    obj = lambda p, n: n - p
    if hingeloss:
        obj = lambda p, n: (n - p + margin).clip(0, np.infty)
    if sumhingeloss:  #
        obj = lambda p, n: T.sum((n - p + margin).clip(0, np.infty), axis=1)

    traingoldshifted = shiftdata(traingold)
    validgoldshifted = shiftdata(validgold)

    #embed()
    # eval
    if preeval:
        tt.tick("pre-evaluating")
        s = SeqEncDecRankSearch(encdec, entenc, scorer.s, scorer.agg)
        eval = FullRankEval()  # NOTE: shadows the builtin eval inside run()
        pred, scores = s.decode(testdata,
                                testgold.shape[1],
                                candata=entmat,
                                canids=canids,
                                split=evalsplits,
                                transform=transf.f,
                                debug=printpreds)
        evalres = eval.eval(pred, testgold, debug=debug)
        for k, evalre in evalres.items():
            print("{}:\t{}".format(k, evalre))
        tt.tock("pre-evaluated")

    # (positional args, keyword args) for NegIdxGen.
    negidxgenargs = ([numents], {"midsplit": relstarts})
    if debug:
        pass
        #negidxgenargs = ([numents], {})

    tt.tick("training")
    nscorer = scorer.nstrain([traindata, traingoldshifted, traingold]).transform(transf) \
        .negsamplegen(NegIdxGen(*negidxgenargs[0], **negidxgenargs[1])).negrate(negrate).objective(obj) \
        .adagrad(lr=lr).l2(wreg).grad_total_norm(1.0) \
        .validate_on([validdata, validgoldshifted, validgold]) \
        .train(numbats=numbats, epochs=epochs)
    tt.tock("trained")

    #scorer.save("scorer.test.save")

    # eval
    tt.tick("evaluating")
    s = SeqEncDecRankSearch(encdec, entenc, scorer.s, scorer.agg)
    eval = FullRankEval()  # NOTE: shadows the builtin eval inside run()
    pred, scores = s.decode(testdata,
                            testgold.shape[1],
                            candata=entmat,
                            canids=canids,
                            split=evalsplits,
                            transform=transf.f,
                            debug=printpreds)
    if printpreds:
        print pred
    # "subj" wins over "pred"; False when neither flag is set.
    debugarg = "subj" if subjpred else "pred" if predpred else False
    evalres = eval.eval(pred, testgold, debug=debugarg)
    for k, evalre in evalres.items():
        print("{}:\t{}".format(k, evalre))
    tt.tock("evaluated")

    # save
    # Results go to "<script name>.results/<i>.res" (relative to the current
    # working directory), using the first i in 0..999 that is not taken yet.
    basename = os.path.splitext(os.path.basename(__file__))[0]
    dirname = basename + ".results"
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    savenamegen = lambda i: "{}/{}.res".format(dirname, i)
    savename = None
    for i in xrange(1000):
        savename = savenamegen(i)
        if not os.path.exists(savename):
            break
        savename = None  # slot taken; stays None if every slot is taken
    if savename is None:
        raise Exception("exceeded number of saved results")
    with open(savename, "w") as f:
        # First line records the command line, then one "key:\tvalue" per metric.
        f.write("{}\n".format(" ".join(sys.argv)))
        for k, evalre in evalres.items():
            f.write("{}:\t{}\n".format(k, evalre))
Exemplo n.º 19
0
def run(
        numbats=50,
        epochs=10,
        lr=0.5,
        embdim=50,
        encdim=400,
        dropout=0.2,
        layers=1,
        inconcat=True,
        outconcat=True,
        posemb=False,
        customemb=False,
        preproc="none",  # "none" or "generate" or "abstract" or "gensample"
        bidir=False,
        corruptnoise=0.0,
        inspectdata=False,
        relinearize="none",
        wreg=0.0,
        testmode=False,
        autolr=0.5,
        autonumbats=500,
        **kw):
    """Interleave training of two encoder-decoders that share one decoder.

    A main model is trained on the data from ``loadgeo`` and a second model,
    built with ``decoder=encdec.dec``, on the data from ``loadgeoauto``.  The
    two trainers are interleaved for ``epochs`` epochs, after which ``embed()``
    is called with the ``play`` helper in scope.

    :param numbats: first argument of ``train_lambda`` for the main trainer.
    :param epochs: ``epochs`` argument of the interleaved ``train`` call.
    :param lr: ``adadelta`` learning rate of the main trainer.
    :param embdim: word-embedding size; also the ``innerdim`` of the output
        layer.
    :param encdim: decoder layer size; encoder layer size too, halved when
        ``bidir`` is set.
    :param dropout: forwarded to both models and to the output layer.
    :param layers: number of encoder and decoder layers.
    :param inconcat: forwarded to both models.
    :param outconcat: forwarded to both models.
    :param posemb: add positional indexes to the data matrices and wrap the
        embedders in ``VectorPosEmb``.
    :param customemb: forwarded to ``loadgeo``; also enables
        ``do_custom_emb`` and ``smo.setlin2``.
    :param preproc: ``"none"`` skips ``preprocess``; ``"generate"`` also calls
        ``generate``; ``"gensample"`` builds ``typdic`` and validates on the
        unprocessed test matrices.
    :param bidir: forwarded to both models.
    :param corruptnoise: ``p`` argument of ``RandomCorrupt``.
    :param inspectdata: call ``embed()`` at two checkpoints before training.
    :param relinearize: ``"none"``, ``"greedy"`` or ``"deep"``; selects the
        transformer applied by the data loaders.
    :param wreg: value handed to ``l2`` on both trainers.
    :param testmode: not used in this function body.
    :param autolr: ``adadelta`` learning rate of the auxiliary trainer.
    :param autonumbats: first argument of ``train_lambda`` for the auxiliary
        trainer.
    :param kw: accepted but not used in this function body.
    :raises Exception: if ``relinearize`` is not one of the known values.
    """

    ######### DATA LOADING AND TRANSFORMATIONS ###########
    srctransformer = None
    if relinearize != "none":
        lambdaparser = LambdaParser()
        if relinearize == "greedy":

            def srctransformer(x):
                return lambdaparser.parse(x).greedy_linearize(deeppref=True)
        elif relinearize == "deep":

            def srctransformer(x):
                return lambdaparser.parse(x).deep_linearize()
        else:
            raise Exception("unknown linearization")
    adic = {}  # immediately replaced by the dictionary returned below
    ### AUTO DATA LOAD ###
    qmat_auto, amat_auto, qdic_auto, adic, qwc_auto, awc_auto = \
        loadgeoauto(reverse=True, transformer=srctransformer)

    def pp(i):
        # Shell helper: print row i of the auto data as words.
        print wordids2string(qmat_auto[i],
                             {v: k
                              for k, v in qdic_auto.items()}, 0)
        print wordids2string(amat_auto[i], {v: k for k, v in adic.items()}, 0)

    if inspectdata:
        print "auto data inspect"
        #embed()

    ### TRAIN DATA LOAD ###
    # The answer dictionary from the auto data is passed on to loadgeo.
    qmat, amat, qdic, adic, qwc, awc = loadgeo(customemb=customemb,
                                               reverse=True,
                                               transformer=srctransformer,
                                               adic=adic)

    maskid = 0
    typdic = None
    # Keep unprocessed copies; they are split alongside the processed matrices.
    oqmat = qmat.copy()
    oamat = amat.copy()
    print "{} is preproc".format(preproc)
    if preproc != "none":
        qmat, amat, qdic, adic, qwc, awc = preprocess(
            qmat,
            amat,
            qdic,
            adic,
            qwc,
            awc,
            maskid,
            qreversed=True,
            dorare=preproc != "generate")
        if preproc == "generate":  # alters size
            print "generating"
            qmat, amat = generate(qmat,
                                  amat,
                                  qdic,
                                  adic,
                                  oqmat,
                                  oamat,
                                  reversed=True)
            #embed()
        elif preproc == "gensample":
            typdic = gentypdic(qdic, adic)

    ######### train/test split from here #########
    qmat_t, qmat_x = split_train_test(qmat)
    amat_t, amat_x = split_train_test(amat)
    oqmat_t, oqmat_x = split_train_test(oqmat)
    oamat_t, oamat_x = split_train_test(oamat)

    qoverlap, aoverlap, overlap = compute_overlap(qmat_t, amat_t, qmat_x,
                                                  amat_x)
    print "overlaps: {}, {}: {}".format(len(qoverlap), len(aoverlap),
                                        len(overlap))

    if inspectdata:
        embed()

    np.random.seed(12345)

    # Per-layer sizes: encoder is halved when bidirectional, decoder is not.
    encdimi = [encdim / 2 if bidir else encdim] * layers
    decdimi = [encdim] * layers

    # The *i names keep referring to the answer matrices as they are here,
    # before the posemb branch below rebinds amat_* / oamat_*.
    amati_t, amati_x = amat_t, amat_x
    oamati_t, oamati_x = oamat_t, oamat_x
    amati_auto = amat_auto

    if posemb:  # add positional indexes to datamatrices
        qmat_t, oqmat_t, amat_t, oamat_t = add_pos_indexes(
            qmat_t, oqmat_t, amat_t, oamat_t)
        qmat_x, oqmat_x, amat_x, oamat_x = add_pos_indexes(
            qmat_x, oqmat_x, amat_x, oamat_x)

    if preproc == "gensample":
        # Validate on the unprocessed test matrices.
        qmat_x, amat_x, amati_x = oqmat_x, oamat_x, oamati_x

    rqdic = {v: k for k, v in qdic.items()}
    radic = {v: k for k, v in adic.items()}

    def tpp(i):
        # Shell helper: print training example i as words.
        print wordids2string(qmat_t[i], rqdic, 0)
        print wordids2string(amat_t[i], radic, 0)

    def xpp(i):
        # Shell helper: print test example i as words.
        print wordids2string(qmat_x[i], rqdic, 0)
        print wordids2string(amat_x[i], radic, 0)

    if inspectdata:
        embed()
    print "{} training examples".format(qmat_t.shape[0])

    ################## MODEL DEFINITION ##################
    inpemb = WordEmb(worddic=qdic, maskid=maskid, dim=embdim)
    outemb = WordEmb(worddic=adic, maskid=maskid, dim=embdim)

    inpemb_auto = WordEmb(worddic=qdic_auto, maskid=maskid, dim=embdim)

    if customemb:
        inpemb, outemb = do_custom_emb(inpemb, outemb, awc, embdim)
        inpemb_auto, outemb = do_custom_emb(inpemb_auto, outemb, awc_auto,
                                            embdim)

    if posemb:  # use custom emb layers, with positional embeddings
        posembdim = 50
        inpemb = VectorPosEmb(inpemb, qmat_t.shape[1], posembdim)
        outemb = VectorPosEmb(outemb, amat_t.shape[1], posembdim)

        inpemb_auto = VectorPosEmb(inpemb_auto, qmat_auto.shape[1], posembdim)
        # NOTE(review): outemb was already wrapped in VectorPosEmb three lines
        # up, so this wraps it a second time -- confirm that is intended.
        outemb = VectorPosEmb(outemb, max(amat_auto.shape[1], amat_t.shape[1]),
                              posembdim)

    smodim = embdim
    # NOTE(review): indim is encdim + encdim; presumably the decoder state
    # concatenated with a context vector of the same size -- confirm.
    smo = SoftMaxOut(indim=encdim + encdim,
                     innerdim=smodim,
                     outvocsize=len(adic) + 1,
                     dropout=dropout)

    if customemb:
        smo.setlin2(outemb.baseemb.W.T)

    # main encdec model
    encdec = SimpleSeqEncDecAtt(
        inpvocsize=max(qdic.values()) + 1,
        inpembdim=embdim,
        inpemb=inpemb,
        outvocsize=max(adic.values()) + 1,
        outembdim=embdim,
        outemb=outemb,
        encdim=encdimi,
        decdim=decdimi,
        maskid=maskid,
        statetrans=True,
        dropout=dropout,
        inconcat=inconcat,
        outconcat=outconcat,
        rnu=GRU,
        vecout=smo,
        bidir=bidir,
    )

    # Auxiliary model: own input embedder, but the main model's decoder object.
    encdec_auto = SimpleSeqEncDecAtt(inpvocsize=max(qdic_auto.values()) + 1,
                                     inpembdim=embdim,
                                     inpemb=inpemb_auto,
                                     encdim=encdimi,
                                     decdim=decdimi,
                                     maskid=maskid,
                                     statetrans=True,
                                     dropout=dropout,
                                     inconcat=inconcat,
                                     outconcat=outconcat,
                                     rnu=GRU,
                                     bidir=bidir,
                                     decoder=encdec.dec)

    # Sanity check: every parameter the two models share is a decoder parameter.
    encdec_params = encdec.get_params()
    encdec_auto_params = encdec_auto.get_params()
    dec_params = encdec.dec.get_params()
    overlapping_params = encdec_params.intersection(encdec_auto_params)
    print "\n".join(map(str, overlapping_params))
    assert (len(overlapping_params.difference(dec_params)) == 0)

    ################## INTERLEAVED TRAINING ##################

    # NOTE(review): training feeds amat_t as decoder input and amati_t as gold,
    # while validation feeds amati_x as input and amat_x as gold.  The two
    # names refer to different arrays only after the posemb branch rebinds
    # amat_* -- confirm which pairing is intended when posemb is set.
    main_trainer = encdec.train([qmat_t, amat_t[:, :-1]], amati_t[:, 1:])\
        .sampletransform(GenSample(typdic),
                         RandomCorrupt(corruptdecoder=(2, max(adic.values()) + 1),
                                       corruptencoder=(2, max(qdic.values()) + 1),
                                       maskid=maskid, p=corruptnoise))\
        .cross_entropy().adadelta(lr=lr).grad_total_norm(5.) \
        .l2(wreg).exp_mov_avg(0.8) \
        .validate_on([qmat_x, amati_x[:, :-1]], amat_x[:, 1:]) \
        .cross_entropy().seq_accuracy()\
        .train_lambda(numbats, 1)

    auto_trainer = encdec_auto.train([qmat_auto, amat_auto[:, :-1]], amati_auto[:, 1:]) \
        .cross_entropy().adadelta(lr=autolr).grad_total_norm(5.) \
        .l2(wreg).exp_mov_avg(0.95) \
        .split_validate(splits=50, random=True).cross_entropy().seq_accuracy()\
        .train_lambda(autonumbats, 1)

    #embed()

    main_trainer.interleave(auto_trainer).train(epochs=epochs)

    qrwd = {v: k for k, v in qdic.items()}
    arwd = {v: k for k, v in adic.items()}

    def play(*x, **kw):
        # Shell helper.  play(i) shows question, gold answer and prediction for
        # test example i and returns True, or returns False without printing
        # when hidecorrect=True and the prediction starts with the gold string.
        # play() walks all test examples, pausing for input after each one
        # that was shown.  More than one positional argument is an error.
        # Note that this `kw` shadows run()'s own **kw.
        hidecorrect = False
        if "hidecorrect" in kw:
            hidecorrect = kw["hidecorrect"]
        if len(x) == 1:
            x = x[0]
            q = wordids2string(qmat_x[x],
                               rwd=qrwd,
                               maskid=maskid,
                               reverse=True)
            ga = wordids2string(amat_x[x, 1:], rwd=arwd, maskid=maskid)
            pred = encdec.predict(qmat_x[x:x + 1], amati_x[x:x + 1, :-1])
            pa = wordids2string(np.argmax(pred[0], axis=1),
                                rwd=arwd,
                                maskid=maskid)
            if hidecorrect and ga == pa[:len(ga)]:  # correct
                return False
            else:
                print "{}: {}".format(x, q)
                print ga
                print pa
                return True
        elif len(x) == 0:
            # The recursive call does not forward hidecorrect, so every example
            # is shown.
            for i in range(0, qmat_x.shape[0]):
                r = play(i)
                if r:
                    raw_input()
        else:
            raise Exception("invalid argument to play")

    embed()