import numpy as np

# The encoder/decoder blocks (SimpleSeq2Vec, Seq2Vec, Seq2Idx, SimpleVec2Idx,
# MemVec2Idx, DotMemAddr, LinearGateMemAddr, WordEncoderPlusGlove) and the data
# helpers (readdata, getcharmemdata, getmemdata, getdic2glove, evaluate) are
# assumed to be imported from the surrounding project modules.


def run(epochs=10,
        numbats=100,
        numsam=10000,
        lr=0.1,
        datap="../../../data/simplequestions/datamat.char.pkl",
        innerdim=200,
        wreg=0.00005,
        bidir=False,
        keepmincount=5,
        mem=False,
        sameenc=False,
        memaddr="dot",
        memattdim=100,
        membidir=False,
        memlayers=1,
        memmaxwords=5,
        memmaxchars=20,
        layers=1):
    """ Character-level relation classification on SimpleQuestions """
    (traindata, traingold), (validdata, validgold), (testdata, testgold), chardic, entdic \
        = readdata(datap)

    if mem:
        # build a character-level memory matrix from the relation names
        memdata = getcharmemdata(entdic, chardic, maxwords=memmaxwords, maxchar=memmaxchars)

    print traindata.shape, testdata.shape

    numchars = max(chardic.values()) + 1
    numrels = max(entdic.values()) + 1
    print numchars, numrels

    if bidir:
        encinnerdim = [innerdim / 2] * layers
    else:
        encinnerdim = [innerdim] * layers

    # character-level question encoder
    enc = SimpleSeq2Vec(indim=numchars,
                        inpembdim=None,
                        innerdim=encinnerdim,
                        maskid=-1,
                        bidir=bidir)

    if mem:
        if membidir:
            innerdim = [innerdim / 2] * memlayers
        else:
            innerdim = [innerdim] * memlayers
        memindim = numchars
        # separate encoder for the relation-name memory
        memenc = SimpleSeq2Vec(indim=memindim,
                               inpembdim=None,
                               innerdim=innerdim,
                               maskid=-1,
                               bidir=membidir)
        if memaddr is None or memaddr == "dot":
            memaddr = DotMemAddr
        elif memaddr == "lin":
            memaddr = LinearGateMemAddr
        dec = MemVec2Idx(memenc, memdata, memdim=innerdim,
                         memaddr=memaddr, memattdim=memattdim)
    else:
        dec = SimpleVec2Idx(indim=innerdim, outdim=numrels)

    m = Seq2Idx(enc, dec)

    m = m.train([traindata], traingold).adagrad(lr=lr).l2(wreg).grad_total_norm(1.0).cross_entropy()\
         .validate_on([validdata], validgold).accuracy().cross_entropy().takebest()\
         .train(numbats=numbats, epochs=epochs)

    pred = m.predict(testdata)
    print pred.shape
    evalres = evaluate(np.argmax(pred, axis=1), testgold)
    print str(evalres) + "%"
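# Usage sketch for the character-level variant above (a hypothetical call, not
# part of the original script): enable the relation-name memory decoder with
# linear-gate addressing and a bidirectional memory encoder, e.g.
#
#   run(epochs=20, mem=True, memaddr="lin", membidir=True, memattdim=100)
#
# All keyword arguments mirror the defaults in the signature above.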
def run(epochs=10,
        numbats=100,
        numsam=10000,
        lr=0.1,
        datap="../../../data/simplequestions/datamat.wordchar.pkl",
        embdim=50,
        encdim=50,
        innerdim=200,
        wreg=0.00005,
        bidir=False,
        keepmincount=5,
        sameenc=False,
        memaddr="dot",
        memattdim=100,
        layers=1,
        embtrainfrac=0.0,
        mem=False,
        membidir=False,
        memlayers=1,
        sharedwordenc=False):
    """ Memory-match-based, GloVe-based word-level relation classification """
    (traindata, traingold), (validdata, validgold), (testdata, testgold), worddic, chardic, entdic \
        = readdata(datap)

    # get words from relation names, update word dic
    memdata = getmemdata(entdic, worddic, chardic)

    # get GloVe and transform word mats to GloVe index space
    d2g, newdic, glove = getdic2glove(worddic, dim=embdim, trainfrac=embtrainfrac)
    traindata, validdata, testdata, memdata = \
        [np.concatenate([np.vectorize(d2g)(x[..., 0]).reshape(x.shape[:2] + (1,)),
                         x[..., 1:]],
                        axis=2)
         for x in [traindata, validdata, testdata, memdata]]

    print traindata.shape, testdata.shape
    #embed()

    numwords = max(worddic.values()) + 1    # don't use this, use GloVe
    numchars = max(chardic.values()) + 1
    numrels = max(entdic.values()) + 1

    if bidir:
        encinnerdim = [innerdim / 2] * layers
    else:
        encinnerdim = [innerdim] * layers

    # word representation = GloVe word embedding + character-level word encoder
    wordemb = WordEncoderPlusGlove(numchars=numchars, encdim=encdim, embdim=embdim,
                                   maskid=-1, embtrainfrac=embtrainfrac)
    rnn, lastdim = SimpleSeq2Vec.makernu(embdim + encdim, encinnerdim, bidir=bidir)
    enc = Seq2Vec(wordemb, rnn, maskid=-1)

    if mem:
        memembdim = embdim
        memencdim = encdim
        if membidir:
            innerdim = [innerdim / 2] * memlayers
        else:
            innerdim = [innerdim] * memlayers
        if not sharedwordenc:
            memwordemb = WordEncoderPlusGlove(numchars=numchars, encdim=encdim, embdim=embdim,
                                              maskid=-1, embtrainfrac=embtrainfrac)
        else:
            memwordemb = wordemb    # share the question's word encoder for the memory
        memrnn, memlastdim = SimpleSeq2Vec.makernu(memembdim + memencdim, innerdim, bidir=membidir)
        memenc = Seq2Vec(memwordemb, memrnn, maskid=-1)
        if memaddr is None or memaddr == "dot":
            memaddr = DotMemAddr
        elif memaddr == "lin":
            memaddr = LinearGateMemAddr
        dec = MemVec2Idx(memenc, memdata, memdim=innerdim,
                         memaddr=memaddr, memattdim=memattdim)
    else:
        dec = SimpleVec2Idx(indim=innerdim, outdim=numrels)

    m = Seq2Idx(enc, dec)

    m = m.train([traindata], traingold).adagrad(lr=lr).l2(wreg).grad_total_norm(1.0).cross_entropy()\
         .validate_on([validdata], validgold).accuracy().cross_entropy().takebest()\
         .train(numbats=numbats, epochs=epochs)

    pred = m.predict(testdata)
    print pred.shape
    evalres = evaluate(np.argmax(pred, axis=1), testgold)
    print str(evalres) + "%"
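# Minimal entry-point sketch (an assumption, not part of the original script):
# launch the word-level GloVe variant defined above with its defaults; override
# hyperparameters via keyword arguments, e.g. run(epochs=20, mem=True, sharedwordenc=True).
if __name__ == "__main__":
    run()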