def __init__(self):
     """Build the model, vocabularies, corpora, placeholders and optimizer.

     Loads the source/target vocabularies from the paths given in
     ``Config``, wraps the training and validation files in ``BiCorpus``
     objects, fetches the training buckets, sets the initial validation
     tracking values, creates the four input placeholders, and finally
     calls ``createBucketNetworks``.
     """
     def _batch_placeholder(dtype, name):
         # All four placeholders share the same
         # [Config.MaxLength, Config.BatchSize] shape.
         return tf.placeholder(
             dtype,
             shape=[Config.MaxLength, Config.BatchSize],
             name=name)

     self.model = NMT_Model.NMT_Model()

     # Create both vocabularies first, then load their dictionaries.
     src_vocab = Corpus.Vocabulary()
     trg_vocab = Corpus.Vocabulary()
     src_vocab.loadDict(Config.srcVocabF)
     trg_vocab.loadDict(Config.trgVocabF)
     self.srcVocab, self.trgVocab = src_vocab, trg_vocab

     # Parallel corpora for training and validation share the vocabularies.
     self.trainData = Corpus.BiCorpus(
         src_vocab, trg_vocab, Config.trainSrcF, Config.trainTrgF)
     self.valData = Corpus.BiCorpus(
         src_vocab, trg_vocab, Config.valSrcF, Config.valTrgF)
     self.buckets = self.trainData.getBuckets()
     self.networkBucket = {}

     # Initial validation tracking values.
     # NOTE(review): presumably early-stopping state (best cross-entropy /
     # BLEU and a bad-validation counter with its limit) -- confirm against
     # the training loop.
     self.bestValCE, self.bestBleu = 999999, 0
     self.badValCount, self.maxBadVal = 0, 5
     self.learningRate = Config.LearningRate

     # Graph inputs: int32 token inputs and float32 masks for each side.
     self.inputSrc = _batch_placeholder(tf.int32, 'srcInput')
     self.maskSrc = _batch_placeholder(tf.float32, 'srcMask')
     self.inputTrg = _batch_placeholder(tf.int32, 'trgInput')
     self.maskTrg = _batch_placeholder(tf.float32, 'trgMask')

     # NOTE(review): the optimizer is constructed with its default
     # arguments; self.learningRate is not passed in here -- confirm this
     # is intended.
     self.optimizer = tf.train.AdamOptimizer()
     self.createBucketNetworks()
示例#2
0
 def __init__(self):
     """Build the model, vocabularies, corpora, decoder and example network.

     Loads the source/target vocabularies from the paths given in
     ``Config``, creates the training/validation corpora and the BLEU
     validation corpus, constructs the decoder, builds one network via
     ``getNetwork(1, 1)``, sets the initial validation tracking values,
     and loads ``Config.initModelF`` into the model when that file exists.
     """
     self.model = NMT_Model.NMT_Model()

     # Create both vocabularies first, then load their dictionaries.
     src_vocab = Corpus.Vocabulary()
     trg_vocab = Corpus.Vocabulary()
     src_vocab.loadDict(Config.srcVocabF)
     trg_vocab.loadDict(Config.trgVocabF)
     self.srcVocab, self.trgVocab = src_vocab, trg_vocab

     # Corpora: parallel train/validation data plus the BLEU validation set.
     self.trainData = Corpus.BiCorpus(
         src_vocab, trg_vocab, Config.trainSrcF, Config.trainTrgF)
     self.valData = Corpus.BiCorpus(
         src_vocab, trg_vocab, Config.valSrcF, Config.valTrgF)
     self.valBleuData = Corpus.ValCorpus(
         src_vocab, trg_vocab, Config.valFile, Config.refCount)

     self.decoder = NMT_Decoder.NMT_Decoder(
         self.model, src_vocab, trg_vocab)

     # networkBucket is set before getNetwork runs, as in the original order.
     self.networkBucket = {}
     self.exampleNetwork = self.getNetwork(1, 1)

     # Initial validation tracking values.
     # NOTE(review): presumably early-stopping state (best cross-entropy /
     # BLEU and a bad-validation counter with its limit) -- confirm against
     # the training loop.
     self.bestValCE, self.bestBleu = 999999, 0
     self.badValCount, self.maxBadVal = 0, 5
     self.learningRate = Config.LearningRate

     # Load an initial model only when the configured file is present.
     init_model_path = Config.initModelF
     if os.path.isfile(init_model_path):
         self.model.loadModel(init_model_path)