class PoemTrainer(object):
    def __init__(self):
        self.__use_pretrain = True
        # Construct hyper-parameters
        self.hps = hps
        self.enc_len = hps.bucket[0]
        self.dec_len = hps.bucket[1]
        self.sens_num = hps.sens_num
        self.key_slots = hps.key_slots
        self.tool = PoetryTool(sens_num=hps.sens_num, key_slots=hps.key_slots,
                               enc_len=hps.bucket[0], dec_len=hps.bucket[1])

        # If there is no pre-trained word embedding, set it to None;
        # the word embedding will then be initialized from a normal distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.PAD_ID = self.tool.get_PAD_ID()
        assert self.PAD_ID == 0
        self.hps = self.hps._replace(vocab_size=vocab_size)

        print("Params sets: ")
        print(self.hps)
        print("___________________")

    def create_model(self, sess, path):
        ckpt = tf.train.get_checkpoint_state(path)
        saver = tf.train.Saver(tf.global_variables(),
                               write_version=tf.train.SaverDef.V1)
        #print (ckpt.model_checkpoint_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            sess.run(tf.global_variables_initializer())

    def restore_pretrained_model(self, sess, saver_for_restore,
                                 pre_model_dir, model_dir):
        """Restore the pre-trained model if there is no checkpoint yet."""
        ckpt = tf.train.get_checkpoint_state(model_dir)
        checkpoint_exists = ckpt and ckpt.model_checkpoint_path
        if checkpoint_exists:
            print('Checkpoint exists in FLAGS.train_dir; skipping '
                  'pretraining restore')
            return
        pretrain_ckpt = tf.train.get_checkpoint_state(pre_model_dir)
        if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path):
            raise ValueError('Asked to restore model from %s but no checkpoint found.'
                             % pre_model_dir)
        print("restore from %s" % (pretrain_ckpt.model_checkpoint_path))
        saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path)
        print("restore OK!")

    def step(self, sess, data_dic, model, valid):
        '''Run one training or validation batch.'''
        keep_prob = 1.0 if valid else self.hps.keep_prob
        input_feed = {}
        input_feed[model.keep_prob] = keep_prob
        input_feed[model.gama] = [self.gama]

        for step in range(self.key_slots):
            # NOTE: Each topic word must consist of no more than 2 characters
            for j in range(2):
                input_feed[model.key_inps[step][j].name] = data_dic['key_inps'][step][j]
        input_feed[model.key_mask.name] = data_dic['key_mask']

        for step in range(0, self.sens_num):
            if len(data_dic['enc_inps'][step]) != self.enc_len:
                raise ValueError("Encoder length must be equal %d != %d."
                                 % (len(data_dic['enc_inps'][step]), self.enc_len))
            if len(data_dic['dec_inps'][step]) != self.dec_len:
                raise ValueError("Decoder length must be equal %d != %d."
                                 % (len(data_dic['dec_inps'][step]), self.dec_len))
            if len(data_dic['trg_weights'][step]) != self.dec_len:
                raise ValueError("Weights length must be equal %d != %d."
                                 % (len(data_dic['trg_weights'][step]), self.dec_len))

            for l in range(self.enc_len):
                input_feed[model.enc_inps[step][l].name] = data_dic['enc_inps'][step][l]
                input_feed[model.write_masks[step][l].name] = data_dic['write_masks'][step][l]
            for l in range(self.dec_len):
                input_feed[model.dec_inps[step][l].name] = data_dic['dec_inps'][step][l]
                input_feed[model.trg_weights[step][l].name] = data_dic['trg_weights'][step][l]
                input_feed[model.len_inps[step][l].name] = data_dic['len_inps'][step][l]
                input_feed[model.ph_inps[step][l].name] = data_dic['ph_inps'][step][l]

            last_target = model.dec_inps[step][self.dec_len].name
            input_feed[last_target] = np.ones([self.hps.batch_size],
                                              dtype=np.int32) * self.PAD_ID
            input_feed[model.enc_mask[step].name] = data_dic['enc_mask'][step]

        output_feed = []
        for step in range(0, self.sens_num):
            for l in range(self.dec_len):
                # Output logits.
                output_feed.append(self.outs_op[step][l])
        output_feed += [self.gen_loss_op, self.l2_loss_op, self.global_step_op]
        if not valid:
            output_feed += [self.train_op]

        outputs = sess.run(output_feed, input_feed)
        logits = []
        for step in range(0, self.sens_num):
            logits.append(outputs[step*self.dec_len:(step+1)*self.dec_len])
        n = self.dec_len * self.sens_num
        return logits, outputs[n], outputs[n+1], outputs[n+2]

    def sample(self, enc_inps, dec_inps, key_inps, outputs):
        sample_num = min(self.hps.sample_num, self.hps.batch_size)
        # Randomly select some examples
        idxes = random.sample(list(range(0, self.hps.batch_size)), sample_num)

        for idx in idxes:
            keys = []
            # NOTE: Each keyword must consist of no more than 2 characters
            for i in range(0, self.key_slots):
                key_idx = [key_inps[i][0][idx], key_inps[i][1][idx]]
                keys.append("".join(self.tool.idxes2chars(key_idx)))
            key_str = " ".join(keys)

            # Build lines
            print("%s" % (key_str))
            for step in range(0, self.sens_num):
                inputs = [c[idx] for c in enc_inps[step]]
                sline = "".join(self.tool.idxes2chars(inputs))
                target = [c[idx] for c in dec_inps[step]]
                tline = "".join(self.tool.idxes2chars(target))
                outline = [c[idx] for c in outputs[step]]
                outline = self.tool.greedy_search(outline)
                if step == 0:
                    print(sline.ljust(35) + " # " + tline.ljust(30)
                          + " # " + outline.ljust(30) + " # ")
                else:
                    print(sline.ljust(30) + " # " + tline.ljust(30)
                          + " # " + outline.ljust(30) + " # ")

    def run_validation(self, sess, model, valid_batches, valid_batch_num, epoch):
        print("run validation...")
        total_gen_loss = 0.0
        total_l2_loss = 0.0
        for step in range(0, valid_batch_num):
            batch = valid_batches[step]
            _, gen_loss, l2_loss, _ = self.step(sess, batch, model, True)
            total_gen_loss += gen_loss
            total_l2_loss += l2_loss

        total_gen_loss /= valid_batch_num
        total_l2_loss /= valid_batch_num
        info = "validation epoch: %d loss: %.3f ppl: %.2f, l2 loss: %.3f" % \
            (epoch, total_gen_loss, np.exp(total_gen_loss), total_l2_loss)
        print(info)
        fout = open("validlog.txt", 'a')
        fout.write(info + "\n")
        fout.close()

    def train(self):
        print("building data...")
        train_batches, train_batch_num = self.tool.build_data(
            self.hps.train_data, self.hps.batch_size, 'train')
        valid_batches, valid_batch_num = self.tool.build_data(
            self.hps.valid_data, self.hps.batch_size, 'train')
        print("train batch num: %d" % (train_batch_num))
        print("valid batch num: %d" % (valid_batch_num))
        f_idxes = self.tool.build_fvec()

        input("Please check the parameters and press enter to continue>")

        model = graphs.WorkingMemoryModel(self.hps, f_idxes=f_idxes)
        self.train_op, self.outs_op, self.gen_loss_op, self.l2_loss_op, \
            self.global_step_op = model.training()

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.98)
        gpu_options.allow_growth = True
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                              allow_soft_placement=True)) as sess:
            self.create_model(sess, self.hps.model_path)

            # Check the initialization
            print("uninitialized_variables")
            uni = sess.run(tf.report_uninitialized_variables(
                tf.global_variables(), name='report_uninitialized_variables'))
            print(len(uni))

            if self.__use_pretrain:
                variables_to_restore = model.pretrained_variables
                print("parameters num of restored pre-trained model: %d"
                      % (len(variables_to_restore)))
                saver_for_restore = tf.train.Saver(variables_to_restore)
                self.restore_pretrained_model(sess, saver_for_restore,
                                              self.hps.pre_model_path,
                                              self.hps.model_path)

            burn_down = min(self.hps.burn_down, self.hps.max_epoch)
            total_annealing_steps = self.hps.annealing_epoches * train_batch_num
            print("total annealing steps: %d" % (total_annealing_steps))

            total_gen_loss = 0.0
            total_l2_loss = 0.0
            total_steps = 0
            time1 = time.time()
            for epoch in range(1, self.hps.max_epoch+1):
                for step in range(0, train_batch_num):
                    batch = train_batches[step]
                    total_steps += 1
                    self.gama = 1.0 - min(total_steps / total_annealing_steps, 1.0) * 0.9
                    logits, gen_loss, l2_loss, global_step = self.step(
                        sess, batch, model, False)
                    total_gen_loss += gen_loss
                    total_l2_loss += l2_loss

                    if step % self.hps.steps_per_train_log == 0:
                        time2 = time.time()
                        time_cost = float(time2-time1) / self.hps.steps_per_train_log
                        cur_gen_loss = total_gen_loss / (total_steps+1)
                        cur_ppl = math.exp(cur_gen_loss)
                        cur_l2_loss = total_l2_loss / (total_steps+1)
                        self.sample(batch['enc_inps'], batch['dec_inps'],
                                    batch['key_inps'], logits)
                        process_info = "epoch: %d, %d/%d %.3f%%, %.3f s per iter" % (
                            epoch, step, train_batch_num,
                            float(step+1) / train_batch_num * 100, time_cost)
                        train_info = "train loss: %.3f ppl:%.2f, l2 loss: %.3f, lr:%.4f. gama:%.4f" \
                            % (cur_gen_loss, cur_ppl, cur_l2_loss,
                               model.learning_rate.eval(), self.gama)
                        print(process_info)
                        print(train_info)
                        print("______________________")

                        info = process_info + " " + train_info
                        fout = open("trainlog.txt", 'a')
                        fout.write(info + "\n")
                        fout.close()
                        time1 = time.time()

                current_epoch = int(global_step // train_batch_num)

                lr0 = model.learning_rate.eval()
                if epoch > burn_down:
                    print("lr decay...")
                    sess.run(model.learning_rate_decay_op)
                lr1 = model.learning_rate.eval()
                print("%.4f to %.4f" % (lr0, lr1))

                if epoch % self.hps.epoches_per_validate == 0:
                    self.run_validation(sess, model, valid_batches,
                                        valid_batch_num, current_epoch)

                if epoch % self.hps.epoches_per_checkpoint == 0:
                    # Save checkpoint
                    print("saving model...")
                    checkpoint_path = os.path.join(
                        self.hps.model_path, "poem.ckpt" + "_" + str(current_epoch))
                    saver = tf.train.Saver(tf.global_variables(),
                                           write_version=tf.train.SaverDef.V1)
                    saver.save(sess, checkpoint_path, global_step=global_step)

                print("shuffle data...")
                random.shuffle(train_batches)
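
# Hedged usage sketch (assumption, not part of the original file): the project
# presumably drives this trainer from a small entry point roughly like the one
# below; the real script name and any command-line handling are not shown here.
def _train_entry_point():
    # Hyper-parameters come from the module-level `hps`, so construction takes
    # no arguments; train() builds the data, the graph, and runs the epochs.
    trainer = PoemTrainer()
    trainer.train()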
class PoemTrainer(object):
    def __init__(self):
        # Construct hyper-parameter
        self.hps = hps
        self.tool = PoetryTool(sens_num=hps.sens_num, key_slots=hps.key_slots,
                               enc_len=hps.bucket[0], dec_len=hps.bucket[1])

        # If there isn't a pre-trained word embedding, just
        # set it to None, then the word embedding
        # will be initialized with a norm distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        print(PAD_ID)
        assert PAD_ID > 0
        self.hps = self.hps._replace(vocab_size=vocab_size, PAD_ID=PAD_ID)

        print("Params sets: ")
        print(self.hps)
        print("___________________")
        raw_input("Please check the parameters and press enter to continue>")

    def create_model(self, session, model):
        """Create the model and initialize or load parameters in session."""
        ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            session.run(tf.global_variables_initializer())
        return model

    def sample(self, enc_inps, dec_inps, key_inps, outputs):
        sample_num = self.hps.sample_num
        if sample_num > self.hps.batch_size:
            sample_num = self.hps.batch_size
        # Random select some examples
        idxes = random.sample(range(0, self.hps.batch_size), sample_num)

        for idx in idxes:
            keys = []
            for i in xrange(0, self.hps.key_slots):
                key_idx = [key_inps[i][0][idx], key_inps[i][1][idx]]
                keys.append("".join(self.tool.idxes2chars(key_idx)))
            key_str = " ".join(keys)

            # Build lines
            print("%s" % (key_str))
            for step in xrange(0, self.hps.sens_num):
                inputs = [c[idx] for c in enc_inps[step]]
                sline = "".join(self.tool.idxes2chars(inputs))
                target = [c[idx] for c in dec_inps[step]]
                tline = "".join(self.tool.idxes2chars(target))
                outline = [c[idx] for c in outputs[step]]
                outline = self.tool.greedy_search(outline)
                if step == 0:
                    print(sline.ljust(25) + " # " + tline.ljust(30)
                          + " # " + outline.ljust(30) + " # ")
                else:
                    print(sline.ljust(30) + " # " + tline.ljust(30)
                          + " # " + outline.ljust(30) + " # ")

    def run_validation(self, sess, model, valid_batches, valid_batch_num, epoch):
        print("run validation...")
        total_gen_loss = 0.0
        for step in xrange(0, valid_batch_num):
            batch = valid_batches[step]
            outputs, gen_loss = model.step(sess, batch, True)
            total_gen_loss += gen_loss

        total_gen_loss /= valid_batch_num
        info = "validation epoch: %d loss: %.3f ppl: %.2f" % \
            (epoch, total_gen_loss, np.exp(total_gen_loss))
        print(info)
        fout = open("validlog.txt", 'a')
        fout.write(info + "\n")
        fout.close()

    def train(self):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.98)
        gpu_options.allow_growth = True
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            # Create model.
            model = PoemModel(self.hps, self.init_emb)
            self.create_model(sess, model)

            # Build batched data
            train_batch_num, valid_batch_num, \
                train_batches, valid_batches = self.tool.build_data(
                    self.hps.batch_size, self.hps.train_data, self.hps.valid_data)
            print("train_batch_num: %d" % (train_batch_num))
            print("valid_batch_num: %d" % (valid_batch_num))

            for epoch in xrange(1, self.hps.max_epoch + 1):
                total_gen_loss = 0.0
                time1 = time.time()
                for step in xrange(0, train_batch_num):
                    batch = train_batches[step]
                    outputs, gen_loss = model.step(sess, batch, False)
                    total_gen_loss += gen_loss

                    if step % self.hps.steps_per_train_log == 0:
                        time2 = time.time()
                        time_cost = float(time2 - time1) / self.hps.steps_per_train_log
                        time1 = time2
                        process_info = "epoch: %d, %d/%d %.3f%%, %.3f s per iter" % (
                            epoch, step, train_batch_num,
                            float(step + 1) / train_batch_num * 100, time_cost)
                        self.sample(batch['enc_inps'], batch['dec_inps'],
                                    batch['key_inps'], outputs)
                        current_gen_loss = total_gen_loss / (step + 1)
                        ppl = math.exp(current_gen_loss) \
                            if current_gen_loss < 300 else float('inf')
                        train_info = "train loss: %.3f ppl:%.2f" % (current_gen_loss, ppl)
                        print(process_info)
                        print(train_info)
                        print("______________________")

                        info = process_info + " " + train_info
                        fout = open("trainlog.txt", 'a')
                        fout.write(info + "\n")
                        fout.close()

                current_epoch = int(model.global_step.eval() // train_batch_num)
                if epoch > self.hps.burn_down:
                    lr0 = model.learning_rate.eval()
                    print("lr decay...")
                    sess.run(model.learning_rate_decay_op)
                    lr1 = model.learning_rate.eval()
                    print("%.4f to %.4f" % (lr0, lr1))

                if epoch % self.hps.epoches_per_validate == 0:
                    self.run_validation(sess, model, valid_batches,
                                        valid_batch_num, epoch)

                if epoch % self.hps.epoches_per_checkpoint == 0:
                    # Save checkpoint and zero timer and loss.
                    print("saving model...")
                    checkpoint_path = os.path.join(
                        self.hps.model_path, "poem.ckpt" + "_" + str(current_epoch))
                    model.saver.save(sess, checkpoint_path,
                                     global_step=model.global_step)

                print("shuffle data...")
                random.shuffle(train_batches)
class Generator(object):
    def __init__(self, beam_size, model_file=None):
        # Construct hyper-parameter
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size
        self.tool = PoetryTool(sens_num=hps.sens_num, key_slots=hps.key_slots,
                               enc_len=hps.bucket[0], dec_len=hps.bucket[1])

        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        assert PAD_ID > 0
        self.hps = self.hps._replace(vocab_size=vocab_size, PAD_ID=PAD_ID,
                                     batch_size=beam_size, mode='decode')

        self.model = PoemModel(self.hps)
        self.EOS_ID, self.PAD_ID, self.GO_ID, self.UNK_ID \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.his_mem_size = self.hps.his_mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()
        if model_file is None:
            self.load_model(self.sess, self.model)
        else:
            self.load_model_by_path(self.sess, self.model, model_file)
        self.__buildPH()

    def load_model(self, session, model):
        """Load parameters in session."""
        ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            raise ValueError("%s not found!" % ckpt.model_checkpoint_path)

    def load_model_by_path(self, session, model, modefile):
        """Load parameters in session."""
        if tf.gfile.Exists(modefile):
            print("Reading model parameters from %s" % modefile)
            model.saver.restore(session, modefile)
        else:
            raise ValueError("%s not found!" % modefile)

    def __buildPH(self):
        # Convert __GLDic from rhyme id -> character list into rhyme id -> character-id list
        self.__PHDic = self.dtool.buildPHDicForIdx(
            copy.deepcopy(self.tool.get_vocab()))

    def addtionFilter(self, trans, pos):
        # trans: existing candidates, shape [batch_size, pos-1]; forbid characters
        # already used earlier in the current line
        pos -= 1
        preidx = range(0, pos)  # e.g. 0 1 2
        batch_size = len(trans)
        forbidden_list = [[] for _ in xrange(0, batch_size)]
        for i in range(0, batch_size):
            prechar = [trans[i][c] for c in preidx]
            forbidden_list[i] = prechar
        return forbidden_list

    def beam_select(self, probs, trans, k, trg_len, beam_size, repeatidxvec, ph):
        # trans holds the existing candidates, probs are this step's probabilities.
        # trans: [candidate_num, already_len], probs: [candidate_num, vocab_size]
        V = np.shape(probs)[1]  # vocabulary size
        n_samples = np.shape(probs)[0]
        if k == 1:
            n_samples = beam_size

        # trans_indices, word_indices, costs
        hypothesis = []  # (char_idx, which beam, prob)
        cost_eps = float(1e5)

        # Control inner repeat
        forbidden_list = self.addtionFilter(trans, k)
        for i in range(0, np.shape(probs)[0]):
            probs[i, forbidden_list[i]] = cost_eps

        # Control global repeat
        probs[:, repeatidxvec] = cost_eps

        # Hard control for genre
        if ph != 0:
            #print (k, gl)
            probs *= cost_eps
            probs[:, self.__PHDic[ph]] /= float(cost_eps)

        flat_next_costs = probs.flatten()  # flatten into a 1-D array
        # np.argpartition(a, k) returns indices such that the element at position k
        # is in its sorted place, everything left of it is smaller and everything
        # right of it is larger; so the first n_samples entries index the n_samples
        # smallest costs (in arbitrary order).
        best_costs_indices = np.argpartition(
            flat_next_costs.flatten(), n_samples)[:n_samples]

        trans_indices = [int(idx) for idx in best_costs_indices / V]  # which beam line
        word_indices = best_costs_indices % V
        costs = flat_next_costs[best_costs_indices]

        for i in range(0, n_samples):
            hypothesis.append((word_indices[i], trans_indices[i], costs[i]))
        return hypothesis

    def beam_search(self, sess, sen, len_inps, ori_key_states, key_initial_state,
                    ori_topic_trace, ori_his_mem, ori_his_mem_mask,
                    ori_global_trace, enc_mask, ori_key_mask, repeatidxvec, phs):
        trg_length = len(phs)
        beam_size = self.beam_size
        n_samples = beam_size

        enc_state, attn_states = self.model.encoder_computer(sess, sen, enc_mask)
        # (dev note) why is input_feed 0?
        enc_states = copy.deepcopy(attn_states)
        enc_mask = np.array(enc_mask)
        key_states = copy.deepcopy(ori_key_states)
        topic_trace = copy.deepcopy(ori_topic_trace)
        global_trace = copy.deepcopy(ori_global_trace)
        his_mem = copy.deepcopy(ori_his_mem)
        his_mem_mask = copy.deepcopy(ori_his_mem_mask)
        key_mask = copy.deepcopy(ori_key_mask)

        fin_trans, fin_costs, fin_align = [], [], []
        trans = [[] for i in xrange(0, beam_size)]
        costs = [0.0]
        key_align = []
        for i in range(beam_size):
            key_align.append(np.zeros([1, self.key_slots], dtype=np.float32))

        state = enc_state
        if not (key_initial_state is None):
            state = key_initial_state

        inp = np.array([self.GO_ID] * beam_size)
        ph_inp = [phs[0]] * n_samples
        output, state, alignments = self.model.decoder_state_computer(
            sess, inp, len_inps[0], ph_inp, state, attn_states, key_states,
            his_mem, global_trace, enc_mask, key_mask, his_mem_mask, topic_trace)
        # output: [batch_size, vocab_size]

        for k in range(1, 2 * trg_length):
            if n_samples == 0:
                break
            if k == 1:
                output = output[0, :]  # [vocab_size]
            log_probs = np.log(output)
            # np.array(costs) is array([0.0]); np.array(costs)[:, None] adds a
            # dimension, giving array([[0.0]]) with shape (1, 1)
            next_costs = np.array(costs)[:, None] - log_probs

            # Form a beam for the next iteration
            new_trans = [[] for i in xrange(0, n_samples)]
            new_costs = np.zeros(n_samples, dtype="float32")
            new_states = np.zeros((n_samples, self.hidden_size), dtype="float32")
            new_align = [[] for i in xrange(0, n_samples)]
            inputs = np.zeros(n_samples, dtype="int64")

            # Note that here k < len(gls) means that we don't put hard constraint on yun
            ph_require = phs[k - 1] if k <= len(phs) else 0
            #print (gl_require)
            hypothesis = self.beam_select(next_costs, trans, k, trg_length,
                                          n_samples, repeatidxvec, ph_require)

            for i, (next_word, orig_idx, next_cost) in enumerate(hypothesis):
                #print("%d %d %d %f %s" % (i, next_word, orig_idx, next_cost))
                new_trans[i] = trans[orig_idx] + [next_word]
                new_costs[i] = next_cost
                align_start = self.his_mem_slots
                align_end = self.his_mem_slots + self.key_slots
                current_align = alignments[orig_idx, align_start:align_end]
                new_align[i] = np.concatenate(
                    (key_align[orig_idx], [current_align]), axis=0)
                new_states[i] = state[orig_idx, :]
                inputs[i] = next_word

            # Filter the sequences that end with the end-of-sequence character
            trans, costs, indices, key_align = [], [], [], []
            for i in range(n_samples):
                if new_trans[i][-1] != self.EOS_ID:
                    trans.append(new_trans[i])
                    costs.append(new_costs[i])
                    indices.append(i)
                    key_align.append(new_align[i])
                else:
                    n_samples -= 1
                    fin_trans.append(new_trans[i])
                    fin_costs.append(new_costs[i])
                    fin_align.append(new_align[i])
            if n_samples == 0:
                break

            inputs = inputs[indices]
            new_states = new_states[indices]
            attn_states = attn_states[indices, :, :]
            global_trace = global_trace[indices, :]
            enc_mask = enc_mask[indices, :, :]
            key_states = key_states[indices, :, :]
            his_mem = his_mem[indices, :, :]
            key_mask = key_mask[indices, :, :]
            his_mem_mask = his_mem_mask[indices, :]
            topic_trace = topic_trace[indices, :]

            if k >= np.shape(len_inps)[0]:
                specify_len = len_inps[np.shape(len_inps)[0] - 1, indices]
            else:
                specify_len = len_inps[k, indices]

            if k >= len(phs):
                ph_inp = [0] * n_samples
            else:
                ph_inp = [phs[k]] * n_samples

            output, state, alignments = self.model.decoder_state_computer(
                sess, inputs, specify_len, ph_inp, new_states, attn_states,
                key_states, his_mem, global_trace, enc_mask, key_mask,
                his_mem_mask, topic_trace)

        #print (np.shape(fin_align))
        for i in range(len(fin_align)):
            fin_align[i] = fin_align[i][1:, :]

        index = np.argsort(fin_costs)  # sort ascending and return the index list
        fin_align = np.array(fin_align)[index]
        fin_trans = np.array(fin_trans)[index]
        fin_costs = np.array(sorted(fin_costs))

        if len(fin_trans) == 0:
            index = np.argsort(costs)
            fin_align = np.array(key_align)[index]
            fin_trans = np.array(trans)[index]
            fin_costs = np.array(sorted(costs))

        return fin_trans, fin_costs, fin_align, enc_states

    def get_new_global_trace(self, sess, history, ori_enc_states, beam_size):
        enc_states = np.expand_dims(ori_enc_states, axis=0)  # [1, enc_len, 2*hidden_size]
        prev_history = np.expand_dims(history[0, :], 0)  # [1, global_trace_size]
        #print (np.shape(prev_encoder_state))
        #tt = input(">")
        new_history = self.model.global_trace_computer(sess, prev_history, enc_states)
        new_history = np.tile(new_history, [beam_size, 1])  # [beam_size, global_trace_size]
        return new_history

    def get_new_his_mem(self, sess, ori_his_mem, enc_states, ori_global_trace,
                        beam_size, src_len):
        his_mem = np.expand_dims(ori_his_mem, axis=0)  # [1, his_mem_slots, his_mem_size]
        fin_states = []
        for i in xrange(0, np.shape(enc_states)[0]):
            fin_states.append(np.expand_dims(enc_states[i], 0))  # enc_len x [1, 2*hidden_size]
        mask = [np.ones((1, 1))] * src_len + \
            [np.zeros((1, 1))] * (np.shape(enc_states)[0] - src_len)
        global_trace = np.expand_dims(ori_global_trace[0, :], 0)  # [1, global_trace_size]
        new_his_mem = self.model.his_mem_computer(sess, his_mem, fin_states,
                                                  mask, global_trace)
        new_his_mem = np.tile(new_his_mem,
                              [beam_size, 1, 1])  # [beam_size, his_mem_slots, his_mem_size]
        return new_his_mem

    def get_new_topic_trace(self, sess, ori_topic_trace, key_align,
                            ori_key_states, beam_size):
        key_states = np.expand_dims(ori_key_states[0, :, :], 0)  # [1, key_slots, 2*hidden_size]
        topic_trace = np.expand_dims(ori_topic_trace, axis=0)  # [1, topic_trace_size+key_slots]
        key_align = np.mean(key_align, axis=0)  # [trg_len, key_slots] -> [key_slots]
        key_align = np.expand_dims(key_align, axis=0)  # [1, key_slots]
        new_topic_trace = self.model.topic_trace_computer(
            sess, key_states, topic_trace, key_align)
        new_topic_trace = np.tile(new_topic_trace,
                                  [beam_size, 1])  # [beam_size, topic_trace_size+key_slots]
        return new_topic_trace

    def generate_one(self, keystr, pattern):
        # pattern: e.g. 4 lines of 5 characters, or 4 lines of 7 characters
        beam_size = self.beam_size
        sens_num = len(pattern)
        keys = keystr.strip()
        ans, repeatidxes = [], []
        print("using keywords: %s" % (keystr))
        keys = keystr.split(" ")
        keys_idxes = [
            self.tool.chars2idxes(self.tool.line2chars(key)) for key in keys
        ]
        #print (keys_idxes)
        key_inps, key_mask = self.tool.gen_batch_key_beam(keys_idxes, beam_size)
        # key_inps: key_slots x 2 x [batch_size]; key_mask: batch_size x [key_slots, 1]

        # Calculate initial_key state and key_states
        key_initial_state, key_states = self.model.key_memory_computer(
            self.sess, key_inps, key_mask)

        his_mem_mask = np.zeros([beam_size, self.his_mem_slots], dtype=np.float32)
        global_trace = np.zeros([beam_size, self.global_trace_size], dtype='float32')
        his_mem = np.zeros([beam_size, self.his_mem_slots, self.his_mem_size],
                           dtype='float32')
        topic_trace = np.zeros([beam_size, self.topic_trace_size + self.key_slots],
                               dtype='float32')

        # Generate the first line, line0 is an empty list
        sen = []
        for step in xrange(0, sens_num):
            print("generating %d line..." % (step + 1))
            phs = pattern[step]
            trg_len = len(phs)
            if step > 0:
                key_initial_state = None
            src_len = len(sen)
            batch_sen, enc_mask, len_inps = self.tool.gen_batch_beam(
                sen, trg_len, beam_size)  # len_inps: [dec_len, batch_size]

            trans, costs, align, enc_states = self.beam_search(
                self.sess, batch_sen, len_inps, key_states, key_initial_state,
                topic_trace, his_mem, his_mem_mask, global_trace, enc_mask,
                key_mask, repeatidxes, phs)
            trans, costs, align, enc_states = self.pFilter(
                trans, costs, align, enc_states, trg_len)
            if len(trans) == 0:
                return [], ("line %d generation failed!" % (step + 1))

            which = 0
            his_mem = self.get_new_his_mem(self.sess, his_mem[which, :, :],
                                           enc_states[which], global_trace,
                                           beam_size, src_len)

            if step >= 1:
                # Update his_mem_mask
                one_his_mem = his_mem[which, :, :]  # [his_mem_slots, his_mem_size]
                his_mem_mask = np.sum(np.abs(one_his_mem), axis=1)  # [his_mem_slots]
                his_mem_mask = his_mem_mask != 0
                his_mem_mask = np.tile(his_mem_mask.astype(np.float32),
                                       [beam_size, 1])  # [beam_size, his_mem_slots]

            sentence = self.tool.beam_get_sentence(trans[which])
            sentence = sentence.strip()
            ans.append(sentence)

            attn_aligns = align[which][0:trg_len, :]  # [trg_len, key_slots]
            topic_trace = self.get_new_topic_trace(self.sess, topic_trace[which, :],
                                                   attn_aligns, key_states, beam_size)
            global_trace = self.get_new_global_trace(self.sess, global_trace,
                                                     enc_states[which], beam_size)

            sentence = self.tool.line2chars(sentence)
            sen = self.tool.chars2idxes(sentence)
            repeatidxes = list(set(repeatidxes).union(set(sen)))

        return ans, "ok"

    def pFilter(self, trans, costs, align, states, trg_len):
        new_trans, new_costs, new_align, new_states = [], [], [], []
        for i in range(len(trans)):
            if len(trans[i]) < trg_len:
                continue
            tran = trans[i][0:trg_len]
            sen = self.tool.idxes2chars(tran)
            sen = "".join(sen)
            if trg_len > 4 and self.dtool.checkIfInLib(sen):
                continue
            new_trans.append(tran)
            new_align.append(align[i])
            new_states.append(states[i])
            new_costs.append(costs[i])
        return new_trans, new_costs, new_align, new_states
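
# Hedged usage sketch (assumption, not part of the original file): the beam size,
# keywords, and pattern below are illustrative only. Each element of `pattern` is
# one line's list of phonological codes (0 = no hard constraint), so its length
# sets the target line length.
def _generate_entry_point():
    generator = Generator(beam_size=20)
    pattern = [[0, 0, 0, 0, 0, 0, 0]] * 4   # four unconstrained 7-character lines
    lines, state = generator.generate_one("秋风 孤雁", pattern)
    if state == "ok":
        print("\n".join(lines))
    else:
        print(state)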
class Generator(object):
    def __init__(self, beam_size, model_path=None):
        # Construct hyper-parameter
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size

        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool = PoetryTool(sens_num=hps.sens_num, key_slots=hps.key_slots,
                               enc_len=hps.bucket[0], dec_len=hps.bucket[1])
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.hps = self.hps._replace(vocab_size=vocab_size, batch_size=beam_size)

        f_idxes = self.tool.build_fvec()
        self.model = graphs.WorkingMemoryModel(self.hps, f_idxes)
        self.model.build_eval_graph()

        self.PAD_ID, self.UNK_ID, self.B_ID, self.E_ID, _ \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.mem_size = self.hps.mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()
        self.load_model(model_path)
        self.__buildPH()

    def load_model(self, model_path):
        """load parameters in session."""
        saver = tf.train.Saver(tf.global_variables(),
                               write_version=tf.train.SaverDef.V1)
        if model_path is None:
            ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
            if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                raise ValueError("%s not found!" % ckpt.model_checkpoint_path)
        else:
            print("Reading model parameters from %s" % model_path)
            saver.restore(self.sess, model_path)

    def __buildPH(self):
        self.__PHDic = self.dtool.buildPHDicForIdx(
            copy.deepcopy(self.tool.get_vocab()))

    def addtionFilter(self, trans, pos):
        pos -= 1
        preidx = range(0, pos)
        batch_size = len(trans)
        forbidden_list = [[] for _ in range(0, batch_size)]
        for i in range(0, batch_size):
            prechar = [trans[i][c] for c in preidx]
            forbidden_list[i] = prechar
        return forbidden_list

    def beam_select(self, probs, trans, k, trg_len, beam_size, repeatidxvec, ph):
        V = np.shape(probs)[1]  # vocabulary size
        n_samples = np.shape(probs)[0]
        if k == 1:
            n_samples = beam_size

        # trans_indices, word_indices, costs
        hypothesis = []  # (char_idx, which beam, prob)
        cost_eps = float(1e5)

        # Control inner repeat
        forbidden_list = self.addtionFilter(trans, k)
        for i in range(0, np.shape(probs)[0]):
            probs[i, forbidden_list[i]] = cost_eps

        # Control global repeat
        probs[:, repeatidxvec] = cost_eps

        # hard control for genre
        if ph != 0:
            #print (k, gl)
            probs *= cost_eps
            probs[:, self.__PHDic[ph]] /= float(cost_eps)

        flat_next_costs = probs.flatten()
        best_costs_indices = np.argpartition(flat_next_costs.flatten(),
                                             n_samples)[:n_samples]

        trans_indices = [int(idx) for idx in best_costs_indices / V]  # which beam line
        word_indices = best_costs_indices % V
        costs = flat_next_costs[best_costs_indices]

        for i in range(0, n_samples):
            hypothesis.append((word_indices[i], trans_indices[i], costs[i]))
        return hypothesis

    def beam_search(self, sen, len_inps, ori_key_states, key_initial_state,
                    ori_topic_trace, ori_his_mem, ori_his_mem_mask,
                    ori_global_trace, enc_mask, ori_key_mask, repeatidxvec, phs):
        trg_length = len(phs)
        beam_size = self.beam_size

        enc_state, ori_attn_states = self.encoder_computer(sen, enc_mask)
        enc_mask = np.tile(enc_mask, [beam_size, 1, 1])
        key_mask = np.tile(ori_key_mask, [beam_size, 1, 1])

        attn_states = copy.deepcopy(ori_attn_states)
        key_states = copy.deepcopy(ori_key_states)
        topic_trace = copy.deepcopy(ori_topic_trace)
        global_trace = copy.deepcopy(ori_global_trace)
        his_mem = copy.deepcopy(ori_his_mem)
        his_mem_mask = copy.deepcopy(ori_his_mem_mask)

        fin_trans, fin_costs, fin_align = [], [], []
        trans = [[] for i in range(0, beam_size)]
        costs = [0.0]
        key_align = []
        for i in range(beam_size):
            key_align.append(np.zeros([1, self.key_slots], dtype=np.float32))

        state = enc_state
        if not (key_initial_state is None):
            state = key_initial_state

        inp = np.array([self.B_ID] * beam_size)
        ph_inp = [phs[0]] * beam_size
        specify_len = len_inps[0]
        output, state, alignments = self.decoder_computer(
            inp, specify_len, ph_inp, state, attn_states, key_states, his_mem,
            global_trace, enc_mask, key_mask, his_mem_mask, topic_trace)

        n_samples = beam_size
        for k in range(1, trg_length + 4):
            if n_samples == 0:
                break
            if k == 1:
                output = output[0, :]
            log_probs = np.log(output)
            next_costs = np.array(costs)[:, None] - log_probs

            # Form a beam for the next iteration
            new_trans = [[] for i in range(0, n_samples)]
            new_costs = np.zeros(n_samples, dtype="float32")
            new_states = np.zeros((n_samples, self.hidden_size), dtype="float32")
            new_align = [[] for i in range(0, n_samples)]
            inputs = np.zeros(n_samples, dtype=np.int32)

            ph_require = phs[k - 1] if k <= len(phs) else 0
            #print (gl_require)
            hypothesis = self.beam_select(next_costs, trans, k, trg_length,
                                          n_samples, repeatidxvec, ph_require)

            for i, (next_word, orig_idx, next_cost) in enumerate(hypothesis):
                #print("%d %d %d %f %s" % (i, next_word, orig_idx, next_cost))
                new_trans[i] = trans[orig_idx] + [next_word]
                new_costs[i] = next_cost
                align_start = self.his_mem_slots
                align_end = self.his_mem_slots + self.key_slots
                current_align = alignments[orig_idx, :, align_start:align_end]
                new_align[i] = np.concatenate(
                    (key_align[orig_idx], current_align), axis=0)
                new_states[i] = state[orig_idx, :]
                inputs[i] = next_word

            # Filter the sequences that end with end-of-sequence character
            trans, costs, indices, key_align = [], [], [], []
            for i in range(n_samples):
                if new_trans[i][-1] != self.E_ID:
                    trans.append(new_trans[i])
                    costs.append(new_costs[i])
                    indices.append(i)
                    key_align.append(new_align[i])
                else:
                    n_samples -= 1
                    fin_trans.append(new_trans[i])
                    fin_costs.append(new_costs[i])
                    fin_align.append(new_align[i])
            if n_samples == 0:
                break

            inputs = inputs[indices]
            new_states = new_states[indices]
            attn_states = attn_states[indices, :, :]
            global_trace = global_trace[indices, :]
            enc_mask = enc_mask[indices, :, :]
            key_states = key_states[indices, :, :]
            his_mem = his_mem[indices, :, :]
            key_mask = key_mask[indices, :, :]
            his_mem_mask = his_mem_mask[indices, :]
            topic_trace = topic_trace[indices, :]

            if k >= np.shape(len_inps)[0]:
                specify_len = len_inps[np.shape(len_inps)[0] - 1, indices]
            else:
                specify_len = len_inps[k, indices]

            if k >= len(phs):
                ph_inp = [0] * n_samples
            else:
                ph_inp = [phs[k]] * n_samples

            output, state, alignments = self.decoder_computer(
                inputs, specify_len, ph_inp, new_states, attn_states, key_states,
                his_mem, global_trace, enc_mask, key_mask, his_mem_mask, topic_trace)

        for i in range(len(fin_align)):
            fin_align[i] = fin_align[i][1:, :]

        index = np.argsort(fin_costs)
        fin_align = np.array(fin_align)[index]
        fin_trans = np.array(fin_trans)[index]
        fin_costs = np.array(sorted(fin_costs))

        if len(fin_trans) == 0:
            index = np.argsort(costs)
            fin_align = np.array(key_align)[index]
            fin_trans = np.array(trans)[index]
            fin_costs = np.array(sorted(costs))

        return fin_trans, fin_costs, fin_align, ori_attn_states

    def get_new_global_trace(self, global_trace, ori_enc_states, beam_size):
        enc_states = np.expand_dims(ori_enc_states, axis=0)
        prev_global_trace = np.expand_dims(global_trace[0, :], 0)
        new_global_trace = self.global_trace_computer(prev_global_trace, enc_states)
        new_global_trace = np.tile(new_global_trace, [beam_size, 1])
        return new_global_trace

    def get_new_his_mem(self, ori_his_mem, attn_states, ori_global_trace,
                        beam_size, src_len):
        # ori_his_mem: [4, mem_size]
        his_mem = np.expand_dims(ori_his_mem, axis=0)
        enc_outs = []
        assert np.shape(attn_states)[0] == self.enc_len
        for i in range(0, self.enc_len):
            enc_outs.append(np.expand_dims(attn_states[i, :], 0))
        mask = [np.ones((1, 1))] * src_len + \
            [np.zeros((1, 1))] * (self.enc_len - src_len)
        global_trace = np.expand_dims(ori_global_trace[0, :], 0)
        new_his_mem = self.his_mem_computer(his_mem, enc_outs, mask, global_trace)
        new_his_mem = np.tile(new_his_mem, [beam_size, 1, 1])
        return new_his_mem

    def get_new_topic_trace(self, ori_topic_trace, key_align, ori_key_states,
                            beam_size):
        key_states = np.expand_dims(ori_key_states[0, :, :], 0)
        topic_trace = np.expand_dims(ori_topic_trace, axis=0)
        key_align = np.mean(key_align, axis=0)
        key_align = np.expand_dims(key_align, axis=0)
        new_topic_trace = self.topic_trace_computer(key_states, topic_trace, key_align)
        new_topic_trace = np.tile(new_topic_trace, [beam_size, 1])
        return new_topic_trace

    def generate_one(self, keystr, pattern):
        beam_size = self.beam_size
        ans, repeatidxes = [], []
        sens_num = len(pattern)
        keys = keystr.strip()
        print("using keywords: %s" % (keystr))
        keys = keystr.split(" ")
        keys_idxes = [
            self.tool.chars2idxes(self.tool.line2chars(key)) for key in keys
        ]

        # Here we set batch size=1 and then tile the results
        key_inps, key_mask = self.tool.gen_batch_key_beam(keys_idxes, batch_size=1)

        # Calculate initial_key state and key_states
        key_initial_state, key_states = self.key_memory_computer(
            key_inps, key_mask, beam_size)

        his_mem_mask = np.zeros([beam_size, self.his_mem_slots], dtype=np.float32)
        global_trace = np.zeros([beam_size, self.global_trace_size], dtype=np.float32)
        his_mem = np.zeros([beam_size, self.his_mem_slots, self.mem_size],
                           dtype=np.float32)
        topic_trace = np.zeros([beam_size, self.topic_trace_size + self.key_slots],
                               dtype=np.float32)

        # Generate the lines, line0 is an empty list
        sen = []
        for step in range(0, sens_num):
            print("generating %d line..." % (step + 1))
            if step > 0:
                key_initial_state = None
            phs = pattern[step]
            trg_len = len(phs)
            src_len = len(sen)
            batch_sen, enc_mask, len_inps = self.tool.gen_enc_beam(
                sen, trg_len, batch_size=1)
            len_inps = np.tile(len_inps, [1, beam_size])

            trans, costs, align, attn_states = self.beam_search(
                batch_sen, len_inps, key_states, key_initial_state, topic_trace,
                his_mem, his_mem_mask, global_trace, enc_mask, key_mask,
                repeatidxes, phs)
            trans, costs, align, attn_states = self.pFilter(
                trans, costs, align, attn_states, trg_len)
            if len(trans) == 0:
                return [], ("line %d generation failed!" % (step + 1))

            which = 0
            one_his_mem = his_mem[which, :, :]
            his_mem = self.get_new_his_mem(one_his_mem, attn_states[which],
                                           global_trace, beam_size, src_len)

            if step >= 1:
                his_mem_mask = np.sum(np.abs(one_his_mem), axis=1)
                his_mem_mask = his_mem_mask != 0
                his_mem_mask = np.tile(his_mem_mask.astype(np.float32),
                                       [beam_size, 1])

            sentence = self.tool.beam_get_sentence(trans[which])
            sentence = sentence.strip()
            ans.append(sentence)

            attn_aligns = align[which][0:trg_len, :]
            topic_trace = self.get_new_topic_trace(topic_trace[which, :],
                                                   attn_aligns, key_states, beam_size)
            global_trace = self.get_new_global_trace(global_trace,
                                                     attn_states[which], beam_size)

            sentence = self.tool.line2chars(sentence)
            sen = self.tool.chars2idxes(sentence)
            repeatidxes = list(set(repeatidxes).union(set(sen)))

        return ans, "ok"

    def pFilter(self, trans, costs, align, states, trg_len):
        new_trans, new_costs, new_align, new_states = [], [], [], []
        for i in range(len(trans)):
            if len(trans[i]) < trg_len:
                continue
            tran = trans[i][0:trg_len]
            sen = self.tool.idxes2chars(tran)
            sen = "".join(sen)
            if trg_len > 4 and self.dtool.checkIfInLib(sen):
                continue
            new_trans.append(tran)
            new_align.append(align[i])
            new_states.append(states[i])
            new_costs.append(costs[i])
        return new_trans, new_costs, new_align, new_states

    # -----------------------------------------------------------------
    # Some apis for beam search
    def key_memory_computer(self, key_inps, key_mask, beam_size):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        for step in range(self.key_slots):
            for j in range(2):
                input_feed[self.model.key_inps[step][j].name] = key_inps[step][j]
        input_feed[self.model.key_mask.name] = key_mask

        output_feed = [self.model.key_initial_state, self.model.key_states]
        [key_initial_state, key_states] = self.sess.run(output_feed, input_feed)

        key_initial_state = np.tile(key_initial_state, [beam_size, 1])
        key_states = np.tile(key_states, [beam_size, 1, 1])
        return key_initial_state, key_states

    def encoder_computer(self, enc_inps, enc_mask):
        assert self.enc_len == len(enc_inps)
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        for l in range(self.enc_len):
            input_feed[self.model.enc_inps[0][l].name] = enc_inps[l]
        input_feed[self.model.enc_mask[0].name] = enc_mask

        output_feed = [self.model.enc_state, self.model.attn_states]
        [enc_state, attn_states] = self.sess.run(output_feed, input_feed)

        enc_state = np.tile(enc_state, [self.beam_size, 1])
        attn_states = np.tile(attn_states, [self.beam_size, 1, 1])
        return enc_state, attn_states

    def decoder_computer(self, dec_inp, len_inp, ph_inp, prev_state, attn_states,
                         key_states, his_mem, global_trace, enc_mask, key_mask,
                         his_mem_mask, topic_trace):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.dec_inps[0][0].name] = dec_inp
        input_feed[self.model.len_inps[0][0].name] = len_inp
        input_feed[self.model.ph_inps[0][0].name] = ph_inp
        input_feed[self.model.beam_attn_states.name] = attn_states
        input_feed[self.model.enc_mask[0].name] = enc_mask
        input_feed[self.model.beam_initial_state.name] = prev_state
        input_feed[self.model.beam_key_states.name] = key_states
        input_feed[self.model.key_mask.name] = key_mask
        input_feed[self.model.beam_global_trace.name] = global_trace
        input_feed[self.model.beam_topic_trace.name] = topic_trace
        input_feed[self.model.beam_his_mem.name] = his_mem
        input_feed[self.model.beam_his_mem_mask.name] = his_mem_mask

        output_feed = [
            self.model.next_out, self.model.next_state, self.model.next_align
        ]
        [next_output, next_state, next_align] = self.sess.run(output_feed, input_feed)
        return next_output, next_state, next_align

    def his_mem_computer(self, his_mem, enc_outs, write_mask, global_trace):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.gama] = [0.05]
        input_feed[self.model.beam_his_mem.name] = his_mem
        for l in range(self.enc_len):
            input_feed[self.model.beam_enc_outs[l]] = enc_outs[l]
            input_feed[self.model.write_masks[0][l]] = write_mask[l]
        input_feed[self.model.beam_global_trace.name] = global_trace

        output_feed = [self.model.new_his_mem]
        [new_his_mem] = self.sess.run(output_feed, input_feed)
        return new_his_mem

    def topic_trace_computer(self, key_states, prev_topic_trace, key_align):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.beam_key_states.name] = key_states
        input_feed[self.model.beam_topic_trace.name] = prev_topic_trace
        input_feed[self.model.beam_key_align.name] = key_align

        output_feed = [self.model.new_topic_trace]
        [new_topic_trace] = self.sess.run(output_feed, input_feed)
        return new_topic_trace

    def global_trace_computer(self, prev_global_trace, prev_attn_states):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.beam_global_trace] = prev_global_trace
        input_feed[self.model.beam_attn_states] = prev_attn_states

        output_feed = [self.model.new_global_trace]
        [new_global_trace] = self.sess.run(output_feed, input_feed)
        return new_global_trace
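
# Minimal sketch (assumption: toy sizes and values) of the beam cost bookkeeping
# that beam_search/beam_select perform above: accumulated costs are expanded by
# broadcasting against the negative log-probabilities of the current step, and
# np.argpartition picks out the n_samples cheapest (hypothesis, token) pairs.
def _beam_cost_sketch():
    import numpy as np
    # 2 live hypotheses, vocabulary of 5 tokens
    costs = np.array([0.7, 1.2])                    # accumulated negative log-likelihoods
    probs = np.array([[0.1, 0.4, 0.2, 0.2, 0.1],
                      [0.3, 0.1, 0.3, 0.2, 0.1]])   # decoder distribution per hypothesis
    next_costs = costs[:, None] - np.log(probs)     # [n_hyp, vocab]; lower is better
    flat = next_costs.flatten()
    n_samples = 2
    best = np.argpartition(flat, n_samples)[:n_samples]
    trans_indices = best // np.shape(probs)[1]      # which hypothesis each candidate extends
    word_indices = best % np.shape(probs)[1]        # which token it appends
    # Same (word, beam, cost) triples that beam_select returns as `hypothesis`
    return list(zip(word_indices, trans_indices, flat[best]))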