def train(self, epochs, start=0):  # training entry point
     """Run the training loop from step `start` up to `epochs`.

     Restores the last checkpoint when one exists and `self.restore_model`
     is set, otherwise initializes all variables. Every `save_step` steps
     the model is checkpointed and a random sample is printed; every
     `eval_step` steps BLEU is computed and written to the summary log.
     """
     if not self.init_train:  # the training graph must be built first
         raise Exception('Train graph is not inited!')
     with self.train_graph.as_default():
         if path.isfile(self.model_file +
                        '.meta') and self.restore_model:  # resume from the previous checkpoint?
             print("Reloading model file before training.")
             self.train_saver.restore(self.train_session, self.model_file)
         else:
             self.train_session.run(self.train_init)
         total_loss = 0  # running loss since the last reset
         for step in range(start, epochs):  # iterate `epochs` times
             # fetch one batch of training data
             data = next(self.train_data)
             in_seq = data['in_seq']
             in_seq_len = data['in_seq_len']
             target_seq = data['target_seq']
             target_seq_len = data['target_seq_len']
             # one optimizer step: fetch output, loss, train op and summary
             output, loss, train, summary = self.train_session.run(
                 [
                     self.train_output, self.loss, self.train_op,
                     self.train_summary
                 ],
                 feed_dict={
                     self.train_in_seq: in_seq,
                     self.train_in_seq_len: in_seq_len,
                     self.train_target_seq: target_seq,
                     self.train_target_seq_len: target_seq_len
                 })
             total_loss += loss  # accumulate the batch loss
             self.log_writter.add_summary(summary, step)
             if step % self.save_step == 0:  # checkpoint every save_step steps
                 self.train_saver.save(self.train_session, self.model_file)
                 print("Saving model. Step: %d, loss: %f" %
                       (step, total_loss / self.save_step))
                 # print sample output
                 # NOTE(review): the input is decoded with eval_reader.vocabs
                 # while output/target use train_reader.vocabs — confirm the
                 # two vocabularies are identical.
                 sid = random.randint(0, self.batch_size - 1)
                 input_text = reader.decode_text(in_seq[sid],
                                                 self.eval_reader.vocabs)
                 output_text = reader.decode_text(output[sid],
                                                  self.train_reader.vocabs)
                 target_text = reader.decode_text(
                     target_seq[sid],
                     self.train_reader.vocabs).split(' ')[1:]
                 target_text = ' '.join(target_text)
                 print('******************************')
                 print('src: ' + input_text)
                 print('output: ' + output_text)
                 print('target: ' + target_text)
             if step % self.eval_step == 0:  # evaluate BLEU every eval_step steps
                 bleu_score = self.eval(step)
                 print("Evaluate model. Step: %d, score: %f, loss: %f" %
                       (step, bleu_score, total_loss / self.save_step))
                 eval_summary = tf.Summary(value=[
                     tf.Summary.Value(tag='bleu', simple_value=bleu_score)
                 ])
                 self.log_writter.add_summary(eval_summary, step)
             if step % self.save_step == 0:
                 # reset the running loss only after both report branches ran
                 total_loss = 0
Example #2
0
 def eval(self, train_step):
     """Evaluate the current checkpoint over the whole eval set.

     Restores the latest saved model into the eval session, decodes every
     eval batch, and returns the corpus BLEU score scaled to 0-100.

     Args:
         train_step: current training step (unused here; kept for callers).

     Returns:
         float: BLEU * 100 over the entire eval corpus.
     """
     with self.eval_graph.as_default():
         self.eval_saver.restore(self.eval_session, self.model_file)
         target_results = []
         output_results = []
         # Sampling denominator for occasionally printing a decoded example;
         # loop-invariant, so compute it once instead of per sequence.
         prob = int(self.eval_reader.data_size * self.batch_size / 10)
         for step in range(0, self.eval_reader.data_size):
             data = next(self.eval_data)
             in_seq = data['in_seq']
             in_seq_len = data['in_seq_len']
             target_seq = data['target_seq']
             outputs = self.eval_session.run(
                 self.eval_output,
                 feed_dict={
                     self.eval_in_seq: in_seq,
                     self.eval_in_seq_len: in_seq_len})
             for i, output in enumerate(outputs):
                 target = target_seq[i]
                 output_text = reader.decode_text(
                     output, self.eval_reader.vocabs).split(' ')
                 # target[1:] skips the leading start-of-sequence token
                 target_text = reader.decode_text(
                     target[1:], self.eval_reader.vocabs).split(' ')
                 # compute_bleu expects a list of reference lists per sample
                 target_results.append([target_text])
                 output_results.append(output_text)
                 if random.randint(1, prob) == 1:
                     # occasionally print one decoded sample for inspection
                     print('====================')
                     input_text = reader.decode_text(
                         in_seq[i], self.eval_reader.vocabs)
                     print('src:' + input_text)
                     print('output: ' + ' '.join(output_text))
                     print('target: ' + ' '.join(target_text))
         return bleu.compute_bleu(target_results, output_results)[0] * 100
Example #3
0
    def infer(self, texts):
        """Run inference for one text or a list of texts.

        Each input is split into characters, terminated with '</s>',
        encoded via the inference vocabulary, and fed through the graph
        as one batch.

        Args:
            texts: a string or a list of strings.

        Returns:
            list: one entry per input; with beam decoding each entry is a
            list of candidate strings, otherwise a single-element list.
        """
        if not self.init_infer:
            raise Exception('Infer graph is not inited!')
        with self.infer_graph.as_default():
            if not isinstance(texts, list):
                texts = [texts]
            in_seq_list = []
            in_seq_len_list = []
            for text in texts:
                # character-level encoding with an explicit end-of-sequence
                in_seq = reader.encode_text(
                    list(text) + ['</s>'], self.infer_vocab_indices)
                in_seq_list.append(in_seq)
                in_seq_len_list.append(len(in_seq))
            # (removed dead timing code: `import time` and two `pre_time`
            # assignments that were never read)
            outputs = self.infer_session.run(self.infer_output,
                                             feed_dict={
                                                 self.infer_in_seq:
                                                 in_seq_list,
                                                 self.infer_in_seq_len:
                                                 in_seq_len_list
                                             })
            outputs_final = []
            for output in outputs:
                output_text = []
                if self.decode_method == 'beam':
                    # beam output: column i holds the i-th candidate sequence
                    for i in range(len(output[0])):
                        output_x = [x[i] for x in output]
                        output_text.append(
                            reader.decode_text(output_x, self.infer_vocabs))
                else:
                    output_text.append(
                        reader.decode_text(output, self.infer_vocabs))
                outputs_final.append(output_text)
            return outputs_final
Example #4
0
 def head_infer(self, text):
     """Two-pass "head" inference seeded with random word ids.

     NOTE(review): instead of encoding `text` directly (see the
     commented-out line below), this seeds the first pass with five
     random vocabulary ids, appends token id 3 (presumably a separator)
     plus the first-pass output, runs a second pass, and finally lets
     `reader.headrandom` merge both sequences with the encoded head
     words — confirm the intended protocol against `reader`.
     """
     # input_words = reader.encode_text(text.split(' '), self.infer_vocab_indices)
     input_words = []
     if not self.init_infer:  # only warns; does not abort like infer()
         print("=====infer 289===== 模型未初始化")
     with self.infer_graph.as_default():
         print("===text==", text.split(' ') + ['</s>'])
         # seed with five random word ids in [4, batch_size * 10]
         for i in range(5):
             input_words.append(random.randint(4, self.batch_size * 10))
         random.shuffle(input_words)
         print("===input_words297==", input_words)
         words_in_seq = input_words
         words_in_seq_len = len(input_words)
         print("===infer_in_seq_len 300====", self.infer_in_seq_len)
         # first pass: decode from the random seed sequence
         outputs = self.infer_session.run(self.infer_output,
                                          feed_dict={
                                              self.infer_in_seq:
                                              [words_in_seq],
                                              self.infer_in_seq_len:
                                              [words_in_seq_len]
                                          })
         print("=====outputs 306====", outputs)
         # second pass: seed + separator (3) + first-pass output fed back in
         input_words.append(3)
         input_words.extend(outputs[0])
         in_seq_len = len(input_words)
         output = self.infer_session.run(self.infer_output,
                                         feed_dict={
                                             self.infer_in_seq:
                                             [input_words],
                                             self.infer_in_seq_len:
                                             [in_seq_len]
                                         })
         head = reader.encode_text(text.split(' '),
                                   self.infer_vocab_indices)
         # merge the two decoded sequences with the requested head words
         output_one, output_two = reader.headrandom(input_words, output[0],
                                                    head)
         output_top = reader.decode_text(output_one, self.infer_vocabs)
         output_text = reader.decode_text(output_two, self.infer_vocabs)
         return output_top, output_text
Example #5
0
 def infer(self, text):
     """Encode a space-separated input, run one decode pass, and return
     the decoded output string."""
     if not self.init_infer:
         raise Exception('Infer graph is not inited!')
     with self.infer_graph.as_default():
         tokens = text.split(' ') + ['</s>']
         encoded = reader.encode_text(tokens, self.infer_vocab_indices)
         feed = {
             self.infer_in_seq: [encoded],
             self.infer_in_seq_len: [len(encoded)],
         }
         results = self.infer_session.run(self.infer_output, feed_dict=feed)
         return reader.decode_text(results[0], self.infer_vocabs)
Example #6
0
    def infer(self, text):
        """Two-pass inference: decode once from the padded input, then feed
        the combined (input + separator + first output) sequence through a
        second decode pass.

        Returns a (decoded_combined_input, decoded_second_pass) tuple.
        """
        words = text.split(' ')
        input_words = reader.encode_text(words, self.infer_vocab_indices)
        if not self.init_infer:
            print("模型未初始化")
        with self.infer_graph.as_default():
            # normalize/pad the encoded tokens to the batch layout
            input_words = reader.handleData(input_words, len(words),
                                            self.batch_size)
            print("===infer_in_seq_len====", self.infer_in_seq_len)

            first_pass = self.infer_session.run(
                self.infer_output,
                feed_dict={
                    self.infer_in_seq: [input_words],
                    self.infer_in_seq_len: [len(input_words)]
                })
            # append separator token (id 3) and the first-pass output, then
            # decode the combined sequence a second time
            input_words.append(3)
            input_words.extend(first_pass[0])

            second_pass = self.infer_session.run(
                self.infer_output,
                feed_dict={
                    self.infer_in_seq: [input_words],
                    self.infer_in_seq_len: [len(input_words)]
                })
            decoded_input = reader.decode_text(input_words, self.infer_vocabs)
            decoded_output = reader.decode_text(second_pass[0],
                                                self.infer_vocabs)
            return decoded_input, decoded_output
Example #7
0
    def train(self, epochs, start=0):
        """Run the training loop from `start` to `epochs`.

        Restores an existing checkpoint when possible, otherwise
        initializes variables; every `save_step` steps the model is saved
        and a random sample printed, every `eval_step` steps BLEU is
        evaluated and logged.
        """
        if not self.init_train:
            raise Exception('Train graph is not inited!')
        with self.train_graph.as_default():
            if path.isfile(self.model_file + '.meta') and self.restore_model:
                print("训练之前重新加载模型文件。")
                # Reload model parameters to continue training or for testing.
                self.train_saver.restore(self.train_session, self.model_file)
            else:
                self.train_session.run(self.train_init)
            total_loss = 0
            for step in range(start, epochs):
                data = next(self.train_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                # One optimizer step: fetch output, loss, train op, summary.
                output, loss, train, summary = self.train_session.run(
                    [
                        self.train_output, self.loss, self.train_op,
                        self.train_summary
                    ],
                    feed_dict={
                        self.train_in_seq: in_seq,
                        self.train_in_seq_len: in_seq_len,
                        self.train_target_seq: target_seq,
                        self.train_target_seq_len: target_seq_len
                    })
                total_loss += loss
                # Write the per-step summary for TensorBoard.
                self.log_writter.add_summary(summary, step)
                if step % 10 == 0:
                    print("=======step===", step)

                if step % self.save_step == 0:
                    self.train_saver.save(self.train_session, self.model_file)
                    print("保存模型  步骤: %d, loss: %f" %
                          (step, total_loss / self.save_step))
                    # print sample output
                    sid = random.randint(0, self.batch_size - 1)
                    input_text = reader.decode_text(in_seq[sid],
                                                    self.eval_reader.vocabs)
                    output_text = reader.decode_text(output[sid],
                                                     self.train_reader.vocabs)
                    # Bug fix: the original passed the %-format string and its
                    # arguments straight to print(), so "%d" printed literally;
                    # apply the % operator to substitute sid.
                    print("=====train  166 output[%d]========" % sid,
                          output[sid])
                    # Drop the leading start-of-sequence token from the target.
                    target_text = reader.decode_text(
                        target_seq[sid],
                        self.train_reader.vocabs).split(' ')[1:]
                    target_text = ' '.join(target_text)
                    print('******************************')
                    print('src: ' + input_text)
                    print('output: ' + output_text)
                    print('target: ' + target_text)
                if step % self.eval_step == 0:
                    bleu_score = self.eval(step)
                    print("评估模型  步骤: %d, score: %f, loss: %f" %
                          (step, bleu_score, total_loss / self.save_step))
                    eval_summary = tf.Summary(value=[
                        tf.Summary.Value(tag='bleu', simple_value=bleu_score)
                    ])
                    self.log_writter.add_summary(eval_summary, step)
                if step % self.save_step == 0:
                    # Reset the running loss only after both report branches.
                    total_loss = 0
Example #8
0
    def train(self, epochs, start=0):
        """Training loop that additionally feeds a per-token reconstruction
        weight tensor.

        For every target token id `c`, the first 178 entries of
        `help.edge[c]` become that token's weight vector, fed via
        `self.train_target_seq_new_weight` alongside the usual inputs.

        NOTE(review): 178 looks like the width of the edge table /
        vocabulary — confirm against `help.edge`'s shape.
        """
        if not self.init_train:
            raise Exception('Train graph is not inited!')
        with self.train_graph.as_default():
            # Judge whether additional training is needed: resume from an
            # existing checkpoint when allowed, otherwise initialize.
            if path.isfile(self.model_file + '.meta') and self.restore_model:
                print("Reloading model file before training.")
                self.train_saver.restore(self.train_session, self.model_file)
            else:
                self.train_session.run(self.train_init)
            total_loss = 0
            for step in range(start, epochs):
                data = next(self.train_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                # Build the reconstruction weights directly with a nested
                # comprehension. The original deep-copied target_seq and then
                # overwrote every element — a throwaway full copy — and its
                # inner comprehension shadowed the outer loop index `i`.
                target_seq_new_weight = [
                    [[help.edge[c][k] for k in range(178)] for c in row]
                    for row in target_seq
                ]
                output, loss, train, summary = self.train_session.run(
                    [
                        self.train_output, self.loss, self.train_op,
                        self.train_summary
                    ],
                    feed_dict={
                        self.train_in_seq: in_seq,
                        self.train_in_seq_len: in_seq_len,
                        self.train_target_seq: target_seq,
                        self.train_target_seq_len: target_seq_len,
                        self.train_target_seq_new_weight: target_seq_new_weight
                    })
                total_loss += loss
                self.log_writter.add_summary(summary, step)
                if step % self.save_step == 0:
                    self.train_saver.save(self.train_session, self.model_file)
                    print("Saving model. Step: %d, loss: %f" %
                          (step, total_loss / self.save_step))
                    # print sample output
                    sid = random.randint(0, self.batch_size - 1)
                    input_text = reader.decode_text(in_seq[sid],
                                                    self.eval_reader.vocabs)
                    output_text = reader.decode_text(output[sid],
                                                     self.train_reader.vocabs)
                    # Drop the leading start-of-sequence token from the target.
                    target_text = reader.decode_text(
                        target_seq[sid],
                        self.train_reader.vocabs).split(' ')[1:]
                    target_text = ' '.join(target_text)
                    print('******************************')
                    print('src: ' + input_text)
                    print('output: ' + output_text)
                    print(output[sid])
                    print('target: ' + target_text)
                if step % self.eval_step == 0:
                    bleu_score = self.eval(step)
                    print("Evaluate model. Step: %d, score: %f, loss: %f" %
                          (step, bleu_score, total_loss / self.save_step))
                    eval_summary = tf.Summary(value=[
                        tf.Summary.Value(tag='bleu', simple_value=bleu_score)
                    ])
                    self.log_writter.add_summary(eval_summary, step)
                if step % self.save_step == 0:
                    # Reset the running loss after both report branches ran.
                    total_loss = 0
Example #9
0
File: model.py  Project: yf1291/nlp3
                    output_results.append(output_text)
                    if random.randint(1, prob) == 1:

                        input_text = reader.decode_text(in_seq[i],
                                self.eval_reader.vocabs)



            return bleu.compute_bleu(target_results, output_results)[0] * 100


    def reload_infer_model(self):
        """Restore the latest checkpoint into the inference session."""
        with self.infer_graph.as_default():
            self.infer_saver.restore(self.infer_session, self.model_file)


    def infer(self, text):
        """Run a single decode pass over a space-separated input string and
        return the decoded output text."""
        if not self.init_infer:
            raise Exception('Infer graph is not inited!')
        with self.infer_graph.as_default():
            token_ids = reader.encode_text(text.split(' ') + ['</s>'],
                                           self.infer_vocab_indices)
            batch_outputs = self.infer_session.run(
                self.infer_output,
                feed_dict={
                    self.infer_in_seq: [token_ids],
                    self.infer_in_seq_len: [len(token_ids)],
                })
            return reader.decode_text(batch_outputs[0], self.infer_vocabs)
Example #10
0
    def train(self, epochs, start=0):
        """Run the training loop, tracking and checkpointing the best BLEU.

        Besides the regular checkpoint at `model_file`, the model with the
        highest BLEU seen so far is saved to a fixed `max_output` path.

        NOTE(review): when a step is a multiple of both save_step and
        eval_step, self.eval() runs twice in that iteration — confirm this
        is intended.
        """
        print('model file', self.model_file)
        if not self.init_train:
            raise Exception('Train graph is not inited!')
        with self.train_graph.as_default():
            if path.isfile(self.model_file + '.meta') and self.restore_model:
                print("Reloading model file before training.")
                self.train_saver.restore(self.train_session, self.model_file)
            else:
                self.train_session.run(self.train_init)
            total_loss = 0
            max_bleu = 0  # best BLEU score seen so far
            # fixed path for the best-scoring checkpoint
            t = path.join('data/dl-data/models/tf-lib/max_output',
                          'model.ckpl')
            for step in range(start, epochs + 1):
                data = next(self.train_data)
                in_seq = data['in_seq']
                in_seq_len = data['in_seq_len']
                target_seq = data['target_seq']
                target_seq_len = data['target_seq_len']
                output, loss, train, summary = self.train_session.run(
                    [
                        self.train_output, self.loss, self.train_op,
                        self.train_summary
                    ],
                    feed_dict={
                        self.train_in_seq: in_seq,
                        self.train_in_seq_len: in_seq_len,
                        self.train_target_seq: target_seq,
                        self.train_target_seq_len: target_seq_len
                    })
                total_loss += loss
                self.log_writter.add_summary(summary, step)
                if step % self.save_step == 0:  # save a checkpoint every save_step iterations
                    self.train_saver.save(self.train_session, self.model_file)
                    print("Saving model. Step: %d, loss: %f" %
                          (step, total_loss / self.save_step))
                    print(self.model_file)
                    current_bleu = self.eval(step)
                    print('当前分数为', current_bleu)
                    if current_bleu > max_bleu:
                        max_bleu = current_bleu
                        self.train_saver.save(self.train_session, t)
                        print('已保存最大分数模型')
                    # print sample output
                    sid = random.randint(0, self.batch_size - 1)
                    input_text = reader.decode_text(in_seq[sid],
                                                    self.eval_reader.vocabs)
                    output_text = reader.decode_text(output[sid],
                                                     self.train_reader.vocabs)
                    target_text = reader.decode_text(
                        target_seq[sid],
                        self.train_reader.vocabs).split(' ')[1:]
                    target_text = ' '.join(target_text)
                    print('******************************')
                    print('src: ' + input_text)
                    print('output: ' + output_text)
                    print('target: ' + target_text)
                if step % self.eval_step == 0:  # evaluate the model
                    bleu_score = self.eval(step)
                    print("Evaluate model. Step: %d, score: %f, loss: %f" %
                          (step, bleu_score, total_loss / self.save_step))
                    # keep the checkpoint with the best BLEU so far
                    if bleu_score > max_bleu:
                        print('分数为%f,大于当前最大%f,保存模型' % (bleu_score, max_bleu))
                        max_bleu = bleu_score
                        self.train_saver.save(self.train_session, t)
                        print("save model. Step: %d, score: %f, loss: %f" %
                              (step, bleu_score, total_loss / self.save_step))

                    eval_summary = tf.Summary(value=[
                        tf.Summary.Value(tag='bleu', simple_value=bleu_score)
                    ])
                    self.log_writter.add_summary(eval_summary, step)
                if step % self.save_step == 0:
                    total_loss = 0