Example #1
    def __init__(self):
        self.__use_pretrain = True
        # Hyper-parameter settings
        self.hps = hps

        self.enc_len = hps.bucket[0]
        self.dec_len = hps.bucket[1]
        self.sens_num = hps.sens_num
        self.key_slots = hps.key_slots

        self.tool = PoetryTool(sens_num=hps.sens_num,
            key_slots=hps.key_slots, enc_len=hps.bucket[0],
            dec_len=hps.bucket[1])
        # If there is no pre-trained word embedding, just
        # set it to None; the word embedding will then be
        # initialized from a normal distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.PAD_ID = self.tool.get_PAD_ID()
        assert self.PAD_ID == 0

        self.hps = self.hps._replace(vocab_size=vocab_size)

        print("Params  sets: ")
        print (self.hps)
        print("___________________")
Example #2
    def __init__(self):
        # Hyper-parameter settings
        self.hps = hps
        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        # If there is no pre-trained word embedding, just
        # set it to None; the word embedding will then be
        # initialized from a normal distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        print(PAD_ID)
        assert PAD_ID > 0

        self.hps = self.hps._replace(vocab_size=vocab_size, PAD_ID=PAD_ID)

        print("Params  sets: ")
        print(self.hps)
        print("___________________")
        raw_input("Please check the parameters and press enter to continue>")
Example #3
    def __init__(self, beam_size, model_file=None):
        # Hyper-parameter settings
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size
        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        assert PAD_ID > 0

        self.hps = self.hps._replace(vocab_size=vocab_size,
                                     PAD_ID=PAD_ID,
                                     batch_size=beam_size,
                                     mode='decode')
        self.model = PoemModel(self.hps)

        self.EOS_ID, self.PAD_ID, self.GO_ID, self.UNK_ID \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.his_mem_size = self.hps.his_mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()

        if model_file is None:
            self.load_model(self.sess, self.model)
        else:
            self.load_model_by_path(self.sess, self.model, model_file)

        self.__buildPH()
Example #4
    def __init__(self, beam_size, model_path=None):
        # Hyper-parameter settings
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.hps = self.hps._replace(vocab_size=vocab_size,
                                     batch_size=beam_size)

        f_idxes = self.tool.build_fvec()
        self.model = graphs.WorkingMemoryModel(self.hps, f_idxes)
        self.model.build_eval_graph()

        self.PAD_ID, self.UNK_ID, self.B_ID, self.E_ID, _ \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.mem_size = self.hps.mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()

        self.load_model(model_path)
        self.__buildPH()
Example #5
class PoemTrainer(object):

    def __init__(self):
        self.__use_pretrain = True
        # Hyper-parameter settings
        self.hps = hps

        self.enc_len = hps.bucket[0]
        self.dec_len = hps.bucket[1]
        self.sens_num = hps.sens_num
        self.key_slots = hps.key_slots

        self.tool = PoetryTool(sens_num=hps.sens_num,
            key_slots=hps.key_slots, enc_len=hps.bucket[0],
            dec_len=hps.bucket[1])
        # If there is no pre-trained word embedding, just
        # set it to None; the word embedding will then be
        # initialized from a normal distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.PAD_ID = self.tool.get_PAD_ID()
        assert self.PAD_ID == 0

        self.hps = self.hps._replace(vocab_size=vocab_size)

        print("Params  sets: ")
        print (self.hps)
        print("___________________")


    def create_model(self, sess, path):
        ckpt = tf.train.get_checkpoint_state(path)
        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
        #print (ckpt.model_checkpoint_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" %
                  ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            sess.run(tf.global_variables_initializer())

    def restore_pretrained_model(self, sess, saver_for_restore, pre_model_dir, model_dir):
        """Restores pretrained model if there is no ckpt model."""
        ckpt = tf.train.get_checkpoint_state(model_dir)
        checkpoint_exists = ckpt and ckpt.model_checkpoint_path
        if checkpoint_exists:
            print('Checkpoint exists in FLAGS.train_dir; skipping '
                'pretraining restore')
            return

        pretrain_ckpt = tf.train.get_checkpoint_state(pre_model_dir)
        if not (pretrain_ckpt and pretrain_ckpt.model_checkpoint_path):
            raise ValueError('Asked to restore model from %s but no checkpoint found.' % pre_model_dir)
        print("restore from %s" % (pretrain_ckpt.model_checkpoint_path))
        saver_for_restore.restore(sess, pretrain_ckpt.model_checkpoint_path)
        print("restore OK!")

    def step(self, sess, data_dic, model, valid):
        '''For training one batch'''
        keep_prob = 1.0 if valid else self.hps.keep_prob

        input_feed = {}
        input_feed[model.keep_prob] = keep_prob
        input_feed[model.gama] = [self.gama]
        
        for step in range(self.key_slots):
            # NOTE: Each topic word must consist of no more than 2 characters
            for j in range(2):
                input_feed[model.key_inps[step][j].name] = data_dic['key_inps'][step][j]
        input_feed[model.key_mask.name] = data_dic['key_mask']

        for step in range(0, self.sens_num):
            if len(data_dic['enc_inps'][step]) != self.enc_len:
                raise ValueError("Encoder length mismatch: %d != %d." %
                    (len(data_dic['enc_inps'][step]), self.enc_len))

            if len(data_dic['dec_inps'][step]) != self.dec_len:
                raise ValueError("Decoder length mismatch: %d != %d." %
                    (len(data_dic['dec_inps'][step]), self.dec_len))

            if len(data_dic['trg_weights'][step]) != self.dec_len:
                raise ValueError("Weights length mismatch: %d != %d." %
                    (len(data_dic['trg_weights'][step]), self.dec_len))
        
            for l in range(self.enc_len):
                input_feed[model.enc_inps[step][l].name] = data_dic['enc_inps'][step][l]
                input_feed[model.write_masks[step][l].name] = data_dic['write_masks'][step][l]
            for l in range(self.dec_len):
                input_feed[model.dec_inps[step][l].name] = data_dic['dec_inps'][step][l]
                input_feed[model.trg_weights[step][l].name] = data_dic['trg_weights'][step][l]
                input_feed[model.len_inps[step][l].name] = data_dic['len_inps'][step][l]
                input_feed[model.ph_inps[step][l].name] = data_dic['ph_inps'][step][l]

            last_target = model.dec_inps[step][self.dec_len].name
            input_feed[last_target] = np.ones([self.hps.batch_size], dtype=np.int32) * self.PAD_ID
            input_feed[model.enc_mask[step].name] = data_dic['enc_mask'][step]

        output_feed = []       
        for step in range(0, self.sens_num):
            for l in range(self.dec_len):  # Output logits.
                output_feed.append(self.outs_op[step][l])

        output_feed += [self.gen_loss_op, self.l2_loss_op, self.global_step_op]

        if not valid:
            output_feed += [self.train_op] 

        outputs = sess.run(output_feed, input_feed)

        logits = []
        for step in range(0, self.sens_num):
            logits.append(outputs[step*self.dec_len:(step+1)*self.dec_len])

        n = self.dec_len * self.sens_num
        return logits, outputs[n], outputs[n+1], outputs[n+2]

    def sample(self, enc_inps, dec_inps, key_inps, outputs):

        sample_num = min(self.hps.sample_num, self.hps.batch_size)

        # Random select some examples
        idxes = random.sample(list(range(0, self.hps.batch_size)), 
            sample_num)

        #
        for idx in idxes:
            keys = []
            # NOTE: Each keyword must consist of no more than 2 characters
            for i in range(0, self.key_slots):
                key_idx = [key_inps[i][0][idx], key_inps[i][1][idx]]
                keys.append("".join(self.tool.idxes2chars(key_idx)))
            key_str = " ".join(keys)

            # Build lines
            print ("%s" % (key_str))
            for step in range(0, self.sens_num):
                inputs = [c[idx] for c in enc_inps[step]]
                sline = "".join(self.tool.idxes2chars(inputs))

                target = [c[idx] for c in dec_inps[step]]
                tline = "".join(self.tool.idxes2chars(target))

                outline = [c[idx] for c in outputs[step]]
                outline = self.tool.greedy_search(outline)

                if step == 0:
                    print(sline.ljust(35) + " # " + tline.ljust(30) + " # " + outline.ljust(30) + " # ")
                else:
                    print(sline.ljust(30) + " # " + tline.ljust(30) + " # " + outline.ljust(30) + " # ")


    def run_validation(self, sess, model, valid_batches, valid_batch_num, epoch):
        print("run validation...")
        total_gen_loss = 0.0
        total_l2_loss = 0.0
        for step in range(0, valid_batch_num):
            batch = valid_batches[step]
            _, gen_loss, l2_loss, _ = self.step(sess, batch, model, True)
            total_gen_loss += gen_loss
            total_l2_loss += l2_loss
        total_gen_loss /= valid_batch_num
        total_l2_loss /= valid_batch_num
        info = "validation epoch: %d  loss: %.3f  ppl: %.2f, l2 loss: %.3f" % \
            (epoch, total_gen_loss, np.exp(total_gen_loss), total_l2_loss)
        print(info)
        fout = open("validlog.txt", 'a')
        fout.write(info + "\n")
        fout.close()

    def train(self):
        print ("building data...")
        train_batches, train_batch_num = self.tool.build_data(self.hps.train_data,
            self.hps.batch_size, 'train')
        valid_batches, valid_batch_num = self.tool.build_data(self.hps.valid_data,
            self.hps.batch_size, 'train')

        print ("train batch num: %d" % (train_batch_num))
        print ("valid batch num: %d" % (valid_batch_num))

        f_idxes = self.tool.build_fvec()

        input("Please check the parameters and press enter to continue>")

        model = graphs.WorkingMemoryModel(self.hps, f_idxes=f_idxes)
        self.train_op, self.outs_op, self.gen_loss_op, self.l2_loss_op, \
            self.global_step_op = model.training()

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.98)
        gpu_options.allow_growth = True

        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess:

            self.create_model(sess, self.hps.model_path)

            # check the initialization
            print ("uninitialized_variables")
            uni = sess.run(tf.report_uninitialized_variables(tf.global_variables(),
                name='report_uninitialized_variables'))
            print (len(uni))

            if self.__use_pretrain:
                variables_to_restore = model.pretrained_variables
                print ("parameters num of restored pre-trained model: %d" % (len(variables_to_restore)))
                saver_for_restore = tf.train.Saver(variables_to_restore)
                self.restore_pretrained_model(sess,  saver_for_restore, self.hps.pre_model_path, self.hps.model_path)


            burn_down = min(self.hps.burn_down, self.hps.max_epoch)
            total_annealing_steps = self.hps.annealing_epoches * train_batch_num
            print ("total annealing steps: %d" % (total_annealing_steps))

            total_gen_loss = 0.0
            total_l2_loss = 0.0
            total_steps = 0
            time1 = time.time()

            for epoch in range(1, self.hps.max_epoch+1):

                for step in range(0, train_batch_num):
                    batch = train_batches[step]

                    total_steps += 1
                    self.gama = 1.0 - min(total_steps / total_annealing_steps, 1.0) * 0.9

                    logits, gen_loss, l2_loss, global_step = self.step(sess, batch, model, False)

                    total_gen_loss += gen_loss
                    total_l2_loss += l2_loss

                    if step % self.hps.steps_per_train_log == 0:
                        time2 = time.time()
                        time_cost = float(time2-time1) / self.hps.steps_per_train_log

                        cur_gen_loss = total_gen_loss / (total_steps+1)
                        cur_ppl = math.exp(cur_gen_loss)
                        cur_l2_loss = total_l2_loss / (total_steps+1)

                        self.sample(batch['enc_inps'], batch['dec_inps'], batch['key_inps'], logits)

                        process_info = "epoch: %d, %d/%d %.3f%%, %.3f s per iter" % (epoch, step, train_batch_num,
                            float(step+1) / train_batch_num * 100, time_cost)
                        
                        train_info = "train loss: %.3f  ppl:%.2f, l2 loss: %.3f, lr:%.4f. gama:%.4f" \
                            % (cur_gen_loss, cur_ppl, cur_l2_loss, model.learning_rate.eval(), self.gama)
                        print(process_info)
                        print(train_info)
                        print("______________________")
                        
                        info = process_info + " " + train_info
                        fout = open("trainlog.txt", 'a')
                        fout.write(info + "\n")
                        fout.close()

                        time1 = time.time()


                current_epoch = int(global_step // train_batch_num)
                
                lr0 = model.learning_rate.eval()
                if epoch > burn_down:
                    print("lr decay...")
                    sess.run(model.learning_rate_decay_op)
                lr1 = model.learning_rate.eval()
                print("%.4f to %.4f" % (lr0, lr1))

                if epoch % self.hps.epoches_per_validate == 0:
                    self.run_validation(sess, model, valid_batches, valid_batch_num, current_epoch)

                if epoch % self.hps.epoches_per_checkpoint == 0:
                    # Save checkpoint
                    print("saving model...")
                    checkpoint_path = os.path.join(self.hps.model_path, "poem.ckpt" + "_" + str(current_epoch))
                    saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
                    saver.save(sess, checkpoint_path, global_step=global_step)
                
                print("shuffle data...")
                random.shuffle(train_batches)
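
In train() above, gama (fed into the model each step) is annealed linearly from 1.0 down to 0.1 over annealing_epoches * train_batch_num steps and then held constant. A standalone sketch of that schedule, with a hypothetical step budget:

# Minimal sketch of the gama annealing schedule used in train() above.
total_annealing_steps = 1000.0               # hypothetical budget
for total_steps in (0, 250, 500, 1000, 2000):
    gama = 1.0 - min(total_steps / total_annealing_steps, 1.0) * 0.9
    print(total_steps, round(gama, 3))
# 0 -> 1.0, 250 -> 0.775, 500 -> 0.55, 1000 -> 0.1, 2000 -> 0.1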
Example #6
class PoemTrainer(object):
    def __init__(self):
        # Hyper-parameter settings
        self.hps = hps
        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        # If there is no pre-trained word embedding, just
        # set it to None; the word embedding will then be
        # initialized from a normal distribution.
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        print(PAD_ID)
        assert PAD_ID > 0

        self.hps = self.hps._replace(vocab_size=vocab_size, PAD_ID=PAD_ID)

        print("Params  sets: ")
        print(self.hps)
        print("___________________")
        raw_input("Please check the parameters and press enter to continue>")

    def create_model(self, session, model):
        """Create the model and initialize or load parameters in session."""
        ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" %
                  ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            print("Created model with fresh parameters.")
            session.run(tf.global_variables_initializer())

        return model

    def sample(self, enc_inps, dec_inps, key_inps, outputs):

        sample_num = self.hps.sample_num
        if sample_num > self.hps.batch_size:
            sample_num = self.hps.batch_size

        # Random select some examples
        idxes = random.sample(range(0, self.hps.batch_size), sample_num)

        #
        for idx in idxes:
            keys = []
            for i in xrange(0, self.hps.key_slots):
                key_idx = [key_inps[i][0][idx], key_inps[i][1][idx]]
                keys.append("".join(self.tool.idxes2chars(key_idx)))
            key_str = " ".join(keys)

            # Build lines
            print("%s" % (key_str))
            for step in xrange(0, self.hps.sens_num):
                inputs = [c[idx] for c in enc_inps[step]]
                sline = "".join(self.tool.idxes2chars(inputs))

                target = [c[idx] for c in dec_inps[step]]
                tline = "".join(self.tool.idxes2chars(target))

                outline = [c[idx] for c in outputs[step]]
                outline = self.tool.greedy_search(outline)

                if step == 0:
                    print(
                        sline.ljust(25) + " # " + tline.ljust(30) + " # " +
                        outline.ljust(30) + " # ")
                else:
                    print(
                        sline.ljust(30) + " # " + tline.ljust(30) + " # " +
                        outline.ljust(30) + " # ")

    def run_validation(self, sess, model, valid_batches, valid_batch_num,
                       epoch):
        print("run validation...")
        total_gen_loss = 0.0
        for step in xrange(0, valid_batch_num):
            batch = valid_batches[step]
            outputs, gen_loss = model.step(sess, batch, True)
            total_gen_loss += gen_loss
        total_gen_loss /= valid_batch_num
        info = "validation epoch: %d  loss: %.3f  ppl: %.2f" % \
            (epoch, total_gen_loss, np.exp(total_gen_loss))
        print(info)
        fout = open("validlog.txt", 'a')
        fout.write(info + "\n")
        fout.close()

    def train(self):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.98)
        gpu_options.allow_growth = True

        with tf.Session(config=tf.ConfigProto(
                gpu_options=gpu_options)) as sess:

            # Create model.
            model = PoemModel(self.hps, self.init_emb)
            self.create_model(sess, model)

            # Build batched data
            train_batch_num, valid_batch_num, \
            train_batches, valid_batches = self.tool.build_data(
                self.hps.batch_size, self.hps.train_data, self.hps.valid_data)

            print("train_batch_num: %d" % (train_batch_num))
            print("valid_batch_num: %d" % (valid_batch_num))

            for epoch in xrange(1, self.hps.max_epoch + 1):
                total_gen_loss = 0.0
                time1 = time.time()

                for step in xrange(0, train_batch_num):
                    batch = train_batches[step]
                    outputs, gen_loss = model.step(sess, batch, False)
                    total_gen_loss += gen_loss

                    if step % self.hps.steps_per_train_log == 0:
                        time2 = time.time()
                        time_cost = float(time2 -
                                          time1) / self.hps.steps_per_train_log
                        time1 = time2
                        process_info = "epoch: %d, %d/%d %.3f%%, %.3f s per iter" % (
                            epoch, step, train_batch_num,
                            float(step + 1) / train_batch_num * 100, time_cost)

                        self.sample(batch['enc_inps'], batch['dec_inps'],
                                    batch['key_inps'], outputs)
                        current_gen_loss = total_gen_loss / (step + 1)
                        ppl = math.exp(current_gen_loss
                                       ) if current_gen_loss < 300 else float(
                                           'inf')
                        train_info = "train loss: %.3f  ppl:%.2f" % (
                            current_gen_loss, ppl)
                        print(process_info)
                        print(train_info)
                        print("______________________")

                        info = process_info + " " + train_info
                        fout = open("trainlog.txt", 'a')
                        fout.write(info + "\n")
                        fout.close()

                current_epoch = int(model.global_step.eval() //
                                    train_batch_num)

                if epoch > self.hps.burn_down:
                    lr0 = model.learning_rate.eval()
                    print("lr decay...")
                    sess.run(model.learning_rate_decay_op)
                    lr1 = model.learning_rate.eval()
                    print("%.4f to %.4f" % (lr0, lr1))

                if epoch % self.hps.epoches_per_validate == 0:
                    self.run_validation(sess, model, valid_batches,
                                        valid_batch_num, epoch)

                if epoch % self.hps.epoches_per_checkpoint == 0:
                    # Save checkpoint and zero timer and loss.
                    print("saving model...")
                    checkpoint_path = os.path.join(
                        self.hps.model_path,
                        "poem.ckpt" + "_" + str(current_epoch))
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=model.global_step)

                print("shuffle data...")
                random.shuffle(train_batches)
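
Both trainers report perplexity as the exponential of the mean per-token cross-entropy (see run_validation above; the train loop additionally guards against overflow for losses over 300). A minimal standalone sketch of that relation, with hypothetical numbers:

import math

# Hypothetical per-batch mean cross-entropy values (in nats).
batch_losses = [5.31, 5.02, 4.87]
mean_loss = sum(batch_losses) / len(batch_losses)
# perplexity = exp(average negative log-likelihood), with overflow guard
ppl = math.exp(mean_loss) if mean_loss < 300 else float('inf')
print("loss: %.3f  ppl: %.2f" % (mean_loss, ppl))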
Example #7
class Generator(object):
    def __init__(self, beam_size, model_file=None):
        # Hyper-parameter settings
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size
        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        PAD_ID = self.tool.get_PAD_ID()
        assert PAD_ID > 0

        self.hps = self.hps._replace(vocab_size=vocab_size,
                                     PAD_ID=PAD_ID,
                                     batch_size=beam_size,
                                     mode='decode')
        self.model = PoemModel(self.hps)

        self.EOS_ID, self.PAD_ID, self.GO_ID, self.UNK_ID \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.his_mem_size = self.hps.his_mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()

        if model_file is None:
            self.load_model(self.sess, self.model)
        else:
            self.load_model_by_path(self.sess, self.model, model_file)

        self.__buildPH()

    def load_model(self, session, model):
        """load parameters in session."""
        ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
            print("Reading model parameters from %s" %
                  ckpt.model_checkpoint_path)
            model.saver.restore(session, ckpt.model_checkpoint_path)
        else:
            raise ValueError("%s not found! " % ckpt.model_checkpoint_path)

    def load_model_by_path(self, session, model, modefile):
        """load parameters in session."""
        if tf.gfile.Exists(modefile):
            print("Reading model parameters from %s" % modefile)
            model.saver.restore(session, modefile)
        else:
            raise ValueError("%s not found! " % modefile)

    def __buildPH(self):
        # Convert the rhyme dictionary from rhyme id -> character list
        # to rhyme id -> character-id list.
        self.__PHDic = self.dtool.buildPHDicForIdx(
            copy.deepcopy(self.tool.get_vocab()))

    def addtionFilter(self, trans, pos):
        # Collect, per beam, the characters already generated at earlier
        # positions of the current line so they can be forbidden.
        pos -= 1
        preidx = range(0, pos)
        batch_size = len(trans)
        forbidden_list = [[] for _ in xrange(0, batch_size)]

        for i in range(0, batch_size):
            prechar = [trans[i][c] for c in preidx]
            forbidden_list[i] = prechar

        return forbidden_list

    def beam_select(self, probs, trans, k, trg_len, beam_size, repeatidxvec,
                    ph):
        # trans holds the existing candidates: [candidate_num, current_len];
        # probs holds this step's costs: [candidate_num, vocab_size].
        V = np.shape(probs)[1]  # vocabulary size
        n_samples = np.shape(probs)[0]
        if k == 1:
            n_samples = beam_size

        # trans_indices, word_indices, costs
        hypothesis = []  # (char_idx, which beam, prob)
        cost_eps = float(1e5)

        # Control inner repeat
        forbidden_list = self.addtionFilter(trans, k)
        for i in range(0, np.shape(probs)[0]):
            probs[i, forbidden_list[i]] = cost_eps

        # Control global repeat
        probs[:, repeatidxvec] = cost_eps

        # hard control for genre
        if ph != 0:
            #print (k, gl)
            probs *= cost_eps
            probs[:, self.__PHDic[ph]] /= float(cost_eps)

        flat_next_costs = probs.flatten()  # flatten to a 1-D cost array
        # np.argpartition places the n_samples-th smallest cost at its
        # sorted position, with all smaller costs (unordered) before it.
        best_costs_indices = np.argpartition(
            flat_next_costs, n_samples)[:n_samples]

        trans_indices = [int(idx)
                         for idx in best_costs_indices / V]  # which beam line
        word_indices = best_costs_indices % V
        costs = flat_next_costs[best_costs_indices]

        for i in range(0, n_samples):
            hypothesis.append((word_indices[i], trans_indices[i], costs[i]))

        return hypothesis
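
beam_select picks the n_samples globally cheapest (beam, word) pairs by flattening the [candidates, vocab] cost matrix: np.argpartition finds the k smallest costs without a full sort, then integer division and modulo recover the beam and word indices. A small worked example with hypothetical costs:

import numpy as np

# Hypothetical 2-beam, 4-word cost matrix (lower is better).
costs = np.array([[2.0, 0.5, 3.0, 1.0],
                  [0.8, 4.0, 0.2, 2.5]])   # [n_beams, vocab_size]
V = costs.shape[1]
n_samples = 3

flat = costs.flatten()
best = np.argpartition(flat, n_samples)[:n_samples]  # unordered cheapest k
beam_indices = best // V      # which beam each winner came from
word_indices = best % V       # which word was chosen
print(list(zip(beam_indices, word_indices, flat[best])))
# -> (1, 2, 0.2), (0, 1, 0.5), (1, 0, 0.8) in some order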

    def beam_search(self, sess, sen, len_inps, ori_key_states,
                    key_initial_state, ori_topic_trace, ori_his_mem,
                    ori_his_mem_mask, ori_global_trace, enc_mask, ori_key_mask,
                    repeatidxvec, phs):
        trg_length = len(phs)
        beam_size = self.beam_size
        n_samples = beam_size

        enc_state, attn_states = self.model.encoder_computer(
            sess, sen, enc_mask)  # dev note: why is input_feed 0?
        enc_states = copy.deepcopy(attn_states)
        enc_mask = np.array(enc_mask)

        key_states = copy.deepcopy(ori_key_states)
        topic_trace = copy.deepcopy(ori_topic_trace)
        global_trace = copy.deepcopy(ori_global_trace)
        his_mem = copy.deepcopy(ori_his_mem)
        his_mem_mask = copy.deepcopy(ori_his_mem_mask)
        key_mask = copy.deepcopy(ori_key_mask)

        fin_trans, fin_costs, fin_align = [], [], []

        trans = [[] for i in xrange(0, beam_size)]
        costs = [0.0]

        key_align = []
        for i in range(beam_size):
            key_align.append(np.zeros([1, self.key_slots], dtype=np.float32))

        state = enc_state
        if not (key_initial_state is None):
            state = key_initial_state
        inp = np.array([self.GO_ID] * beam_size)

        ph_inp = [phs[0]] * n_samples

        output, state, alignments = self.model.decoder_state_computer(
            sess, inp, len_inps[0], ph_inp, state, attn_states, key_states,
            his_mem, global_trace, enc_mask, key_mask, his_mem_mask,
            topic_trace)  #output [batch_size,vocab_size]

        for k in range(1, 2 * trg_length):
            if n_samples == 0:
                break

            if k == 1:
                output = output[0, :]  #[vocab_size]

            log_probs = np.log(output)
            # costs is a 1-D array; costs[:, None] adds an axis so the
            # subtraction broadcasts each prefix cost over the vocabulary.
            next_costs = np.array(costs)[:, None] - log_probs

            # Form a beam for the next iteration
            new_trans = [[] for i in xrange(0, n_samples)]
            new_costs = np.zeros(n_samples, dtype="float32")
            new_states = np.zeros((n_samples, self.hidden_size),
                                  dtype="float32")
            new_align = [[] for i in xrange(0, n_samples)]

            inputs = np.zeros(n_samples, dtype="int64")

            # Note: steps beyond len(phs) get no hard constraint on rhyme
            ph_require = phs[k - 1] if k <= len(phs) else 0

            #print (gl_require)
            hypothesis = self.beam_select(next_costs, trans, k, trg_length,
                                          n_samples, repeatidxvec, ph_require)

            for i, (next_word, orig_idx, next_cost) in enumerate(hypothesis):
                #print("%d %d %d %f %s" % (i, next_word, orig_idx, next_cost))
                new_trans[i] = trans[orig_idx] + [next_word]
                new_costs[i] = next_cost
                align_start = self.his_mem_slots
                align_end = self.his_mem_slots + self.key_slots
                current_align = alignments[orig_idx, align_start:align_end]
                new_align[i] = np.concatenate(
                    (key_align[orig_idx], [current_align]), axis=0)
                new_states[i] = state[orig_idx, :]
                inputs[i] = next_word

            # Filter the sequences that end with end-of-sequence character
            trans, costs, indices, key_align = [], [], [], []

            for i in range(n_samples):
                if new_trans[i][-1] != self.EOS_ID:
                    trans.append(new_trans[i])
                    costs.append(new_costs[i])
                    indices.append(i)
                    key_align.append(new_align[i])
                else:
                    n_samples -= 1
                    fin_trans.append(new_trans[i])
                    fin_costs.append(new_costs[i])
                    fin_align.append(new_align[i])

            if n_samples == 0:
                break

            inputs = inputs[indices]
            new_states = new_states[indices]
            attn_states = attn_states[indices, :, :]

            global_trace = global_trace[indices, :]
            enc_mask = enc_mask[indices, :, :]
            key_states = key_states[indices, :, :]
            his_mem = his_mem[indices, :, :]
            key_mask = key_mask[indices, :, :]
            his_mem_mask = his_mem_mask[indices, :]
            topic_trace = topic_trace[indices, :]

            if k >= np.shape(len_inps)[0]:
                specify_len = len_inps[np.shape(len_inps)[0] - 1, indices]
            else:
                specify_len = len_inps[k, indices]

            if k >= len(phs):
                ph_inp = [0] * n_samples
            else:
                ph_inp = [phs[k]] * n_samples

            output, state, alignments = self.model.decoder_state_computer(
                sess, inputs, specify_len, ph_inp, new_states, attn_states,
                key_states, his_mem, global_trace, enc_mask, key_mask,
                his_mem_mask, topic_trace)

        #print (np.shape(fin_align))
        for i in range(len(fin_align)):
            fin_align[i] = fin_align[i][1:, :]

        index = np.argsort(fin_costs)  # indices that sort costs ascending
        fin_align = np.array(fin_align)[index]
        fin_trans = np.array(fin_trans)[index]
        fin_costs = np.array(sorted(fin_costs))

        if len(fin_trans) == 0:
            index = np.argsort(costs)
            fin_align = np.array(key_align)[index]
            fin_trans = np.array(trans)[index]
            fin_costs = np.array(sorted(costs))

        return fin_trans, fin_costs, fin_align, enc_states
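
The cost update above relies on NumPy broadcasting: costs is a vector of per-beam prefix costs (accumulated negative log-probabilities), and costs[:, None] adds an axis so subtracting log_probs spreads each prefix cost over the whole vocabulary. A minimal sketch with hypothetical numbers:

import numpy as np

costs = np.array([0.7, 1.2])                 # -log p of two prefixes
probs = np.array([[0.5, 0.3, 0.2],
                  [0.1, 0.6, 0.3]])          # next-step distributions per beam
next_costs = costs[:, None] - np.log(probs)  # [2, 3]: prefix + step cost
print(next_costs.shape)                      # (2, 3)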

    def get_new_global_trace(self, sess, history, ori_enc_states, beam_size):

        enc_states = np.expand_dims(ori_enc_states,
                                    axis=0)  #[1,enc_len,2*hidden_size]
        prev_history = np.expand_dims(history[0, :], 0)  #[1,global_trace_size]
        #print (np.shape(prev_encoder_state))
        #tt = input(">")
        new_history = self.model.global_trace_computer(sess, prev_history,
                                                       enc_states)
        new_history = np.tile(new_history,
                              [beam_size, 1])  #[beam_size,global_trace_size]
        return new_history

    def get_new_his_mem(self, sess, ori_his_mem, enc_states, ori_global_trace,
                        beam_size, src_len):
        his_mem = np.expand_dims(ori_his_mem,
                                 axis=0)  #[1,his_mem_slots,his_mem_size]
        fin_states = []
        for i in xrange(0, np.shape(enc_states)[0]):
            fin_states.append(np.expand_dims(enc_states[i],
                                             0))  #enc_len [1,2*hidden_size]

        mask = [np.ones((1, 1))] * src_len + [np.zeros(
            (1, 1))] * (np.shape(enc_states)[0] - src_len)
        global_trace = np.expand_dims(ori_global_trace[0, :],
                                      0)  #[1,global_trace_size]
        new_his_mem = self.model.his_mem_computer(sess, his_mem, fin_states,
                                                  mask, global_trace)

        new_his_mem = np.tile(
            new_his_mem,
            [beam_size, 1, 1])  #[beam_size,his_mem_slots,his_mem_size]
        return new_his_mem

    def get_new_topic_trace(self, sess, ori_topic_trace, key_align,
                            ori_key_states, beam_size):
        key_states = np.expand_dims(ori_key_states[0, :, :],
                                    0)  #[1,key_slots,2*hidden_size]
        topic_trace = np.expand_dims(ori_topic_trace,
                                     axis=0)  #[1,topic_trace_size+key_slots]
        key_align = np.mean(key_align,
                            axis=0)  # [trg_len,key_slots] -> [key_slots]
        key_align = np.expand_dims(key_align, axis=0)  #[1,key_slots]
        new_topic_trace = self.model.topic_trace_computer(
            sess, key_states, topic_trace, key_align)
        new_topic_trace = np.tile(
            new_topic_trace,
            [beam_size, 1])  #[beam_size,topic_trace_size+key_slots]
        return new_topic_trace

    def generate_one(self, keystr, pattern):  # pattern: 4 lines of 5 or 7 characters
        beam_size = self.beam_size
        sens_num = len(pattern)
        keys = keystr.strip()
        ans, repeatidxes = [], []
        print("using keywords: %s" % (keystr))
        keys = keystr.split(" ")
        keys_idxes = [
            self.tool.chars2idxes(self.tool.line2chars(key)) for key in keys
        ]
        #print (keys_idxes)
        key_inps, key_mask = self.tool.gen_batch_key_beam(
            keys_idxes, beam_size
        )  #key_inps:key_slots,2 [batch_size] key_mask:batch_size [key_slots,1]

        # Calculate initial_key state and key_states
        key_initial_state, key_states = self.model.key_memory_computer(
            self.sess, key_inps, key_mask)

        his_mem_mask = np.zeros([beam_size, self.his_mem_slots],
                                dtype=np.float32)
        global_trace = np.zeros([beam_size, self.global_trace_size],
                                dtype='float32')
        his_mem = np.zeros([beam_size, self.his_mem_slots, self.his_mem_size],
                           dtype='float32')
        topic_trace = np.zeros(
            [beam_size, self.topic_trace_size + self.key_slots],
            dtype='float32')

        # Generate the first line, line0 is an empty list
        sen = []
        for step in xrange(0, sens_num):
            print("generating %d line..." % (step + 1))
            phs = pattern[step]
            trg_len = len(phs)
            if step > 0:
                key_initial_state = None
            src_len = len(sen)
            batch_sen, enc_mask, len_inps = self.tool.gen_batch_beam(
                sen, trg_len, beam_size)  #len_inps [dec_len,batch_size]
            trans, costs, align, enc_states = self.beam_search(
                self.sess, batch_sen, len_inps, key_states, key_initial_state,
                topic_trace, his_mem, his_mem_mask, global_trace, enc_mask,
                key_mask, repeatidxes, phs)

            trans, costs, align, enc_states = self.pFilter(
                trans, costs, align, enc_states, trg_len)

            if len(trans) == 0:
                return [], ("line %d generation failed!" % (step + 1))

            which = 0

            his_mem = self.get_new_his_mem(self.sess, his_mem[which, :, :],
                                           enc_states[which], global_trace,
                                           beam_size, src_len)

            if step >= 1:  # update his_mem_mask
                one_his_mem = his_mem[
                    which, :, :]  #[his_mem_slots,his_mem_size]
                his_mem_mask = np.sum(np.abs(one_his_mem),
                                      axis=1)  #[his_mem_slots]
                his_mem_mask = his_mem_mask != 0
                his_mem_mask = np.tile(his_mem_mask.astype(
                    np.float32), [beam_size, 1])  #[beam_size,his_mem_slots]

            sentence = self.tool.beam_get_sentence(trans[which])
            sentence = sentence.strip()
            ans.append(sentence)
            attn_aligns = align[which][0:trg_len, :]  #trg_len,key_slots
            topic_trace = self.get_new_topic_trace(self.sess,
                                                   topic_trace[which, :],
                                                   attn_aligns, key_states,
                                                   beam_size)
            global_trace = self.get_new_global_trace(self.sess, global_trace,
                                                     enc_states[which],
                                                     beam_size)

            sentence = self.tool.line2chars(sentence)
            sen = self.tool.chars2idxes(sentence)
            repeatidxes = list(set(repeatidxes).union(set(sen)))

        return ans, "ok"

    def pFilter(self, trans, costs, align, states, trg_len):
        new_trans, new_costs, new_align, new_states = [], [], [], []

        for i in range(len(trans)):
            if len(trans[i]) < trg_len:
                continue
            tran = trans[i][0:trg_len]
            sen = self.tool.idxes2chars(tran)
            sen = "".join(sen)
            if trg_len > 4 and self.dtool.checkIfInLib(sen):
                continue
            new_trans.append(tran)
            new_align.append(align[i])
            new_states.append(states[i])
            new_costs.append(costs[i])

        return new_trans, new_costs, new_align, new_states
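
A hedged usage sketch for the Generator above: keystr is a space-separated keyword string and pattern holds one ph code per character for each line (a 0 code applies no hard phonological constraint, per beam_select). The keywords, ph codes, and beam size here are hypothetical; the real codes come from data_tool's rhyme dictionary:

generator = Generator(beam_size=20)
pattern = [[0, 0, 0, 0, 1],    # 4 lines x 5 characters; non-zero codes
           [0, 0, 0, 0, 2],    # force characters from the matching
           [0, 0, 0, 0, 0],    # rhyme class, 0 leaves the step free
           [0, 0, 0, 0, 2]]
lines, status = generator.generate_one(u"春 江 月 夜", pattern)
if status == "ok":
    print("\n".join(lines))
else:
    print(status)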
Example #8
class Generator(object):
    def __init__(self, beam_size, model_path=None):
        # Hyper-parameter settings
        self.hps = hps
        self.dtool = data_tool
        self.beam_size = beam_size
        if hps.init_emb == '':
            self.init_emb = None
        else:
            self.init_emb = np.load(self.hps.init_emb)
            print("init_emb_size: %s" % str(np.shape(self.init_emb)))

        self.tool = PoetryTool(sens_num=hps.sens_num,
                               key_slots=hps.key_slots,
                               enc_len=hps.bucket[0],
                               dec_len=hps.bucket[1])
        self.tool.load_dic(hps.vocab_path, hps.ivocab_path)

        vocab_size = self.tool.get_vocab_size()
        assert vocab_size > 0
        self.hps = self.hps._replace(vocab_size=vocab_size,
                                     batch_size=beam_size)

        f_idxes = self.tool.build_fvec()
        self.model = graphs.WorkingMemoryModel(self.hps, f_idxes)
        self.model.build_eval_graph()

        self.PAD_ID, self.UNK_ID, self.B_ID, self.E_ID, _ \
            = self.tool.get_special_IDs()

        self.enc_len = self.hps.bucket[0]
        self.dec_len = self.hps.bucket[1]
        self.topic_trace_size = self.hps.topic_trace_size
        self.key_slots = self.hps.key_slots
        self.his_mem_slots = self.hps.his_mem_slots
        self.mem_size = self.hps.mem_size
        self.global_trace_size = self.hps.global_trace_size
        self.hidden_size = self.hps.hidden_size

        self.sess = tf.InteractiveSession()

        self.load_model(model_path)
        self.__buildPH()

    def load_model(self, model_path):
        """load parameters in session."""
        saver = tf.train.Saver(tf.global_variables(),
                               write_version=tf.train.SaverDef.V1)

        if model_path is None:
            ckpt = tf.train.get_checkpoint_state(self.hps.model_path)
            if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
                print("Reading model parameters from %s" %
                      ckpt.model_checkpoint_path)
                saver.restore(self.sess, ckpt.model_checkpoint_path)
            else:
                raise ValueError("%s not found! " % ckpt.model_checkpoint_path)
        else:
            print("Reading model parameters from %s" % model_path)
            saver.restore(self.sess, model_path)

    def __buildPH(self):
        self.__PHDic = self.dtool.buildPHDicForIdx(
            copy.deepcopy(self.tool.get_vocab()))

    def addtionFilter(self, trans, pos):
        pos -= 1
        preidx = range(0, pos)
        batch_size = len(trans)
        forbidden_list = [[] for _ in range(0, batch_size)]

        for i in range(0, batch_size):
            prechar = [trans[i][c] for c in preidx]
            forbidden_list[i] = prechar

        return forbidden_list

    def beam_select(self, probs, trans, k, trg_len, beam_size, repeatidxvec,
                    ph):
        V = np.shape(probs)[1]  # vocabulary size
        n_samples = np.shape(probs)[0]
        if k == 1:
            n_samples = beam_size

        # trans_indices, word_indices, costs
        hypothesis = []  # (char_idx, which beam, prob)
        cost_eps = float(1e5)

        # Control inner repeat
        forbidden_list = self.addtionFilter(trans, k)
        for i in range(0, np.shape(probs)[0]):
            probs[i, forbidden_list[i]] = cost_eps

        # Control global repeat
        probs[:, repeatidxvec] = cost_eps

        # hard control for genre
        if ph != 0:
            #print (k, gl)
            probs *= cost_eps
            probs[:, self.__PHDic[ph]] /= float(cost_eps)

        flat_next_costs = probs.flatten()
        best_costs_indices = np.argpartition(flat_next_costs.flatten(),
                                             n_samples)[:n_samples]

        trans_indices = [int(idx)
                         for idx in best_costs_indices / V]  # which beam line
        word_indices = best_costs_indices % V
        costs = flat_next_costs[best_costs_indices]

        for i in range(0, n_samples):
            hypothesis.append((word_indices[i], trans_indices[i], costs[i]))

        return hypothesis

    def beam_search(self, sen, len_inps, ori_key_states, key_initial_state,
                    ori_topic_trace, ori_his_mem, ori_his_mem_mask,
                    ori_global_trace, enc_mask, ori_key_mask, repeatidxvec,
                    phs):
        trg_length = len(phs)
        beam_size = self.beam_size

        enc_state, ori_attn_states = self.encoder_computer(sen, enc_mask)
        enc_mask = np.tile(enc_mask, [beam_size, 1, 1])
        key_mask = np.tile(ori_key_mask, [beam_size, 1, 1])

        attn_states = copy.deepcopy(ori_attn_states)
        key_states = copy.deepcopy(ori_key_states)
        topic_trace = copy.deepcopy(ori_topic_trace)
        global_trace = copy.deepcopy(ori_global_trace)
        his_mem = copy.deepcopy(ori_his_mem)
        his_mem_mask = copy.deepcopy(ori_his_mem_mask)

        fin_trans, fin_costs, fin_align = [], [], []

        trans = [[] for i in range(0, beam_size)]
        costs = [0.0]

        key_align = []
        for i in range(beam_size):
            key_align.append(np.zeros([1, self.key_slots], dtype=np.float32))

        state = enc_state
        if not (key_initial_state is None):
            state = key_initial_state
        inp = np.array([self.B_ID] * beam_size)

        ph_inp = [phs[0]] * beam_size
        specify_len = len_inps[0]

        output, state, alignments = self.decoder_computer(
            inp, specify_len, ph_inp, state, attn_states, key_states, his_mem,
            global_trace, enc_mask, key_mask, his_mem_mask, topic_trace)

        n_samples = beam_size

        for k in range(1, trg_length + 4):
            if n_samples == 0:
                break

            if k == 1:
                output = output[0, :]

            log_probs = np.log(output)
            next_costs = np.array(costs)[:, None] - log_probs

            # Form a beam for the next iteration
            new_trans = [[] for i in range(0, n_samples)]
            new_costs = np.zeros(n_samples, dtype="float32")
            new_states = np.zeros((n_samples, self.hidden_size),
                                  dtype="float32")
            new_align = [[] for i in range(0, n_samples)]

            inputs = np.zeros(n_samples, dtype=np.int32)

            ph_require = phs[k - 1] if k <= len(phs) else 0

            #print (gl_require)
            hypothesis = self.beam_select(next_costs, trans, k, trg_length,
                                          n_samples, repeatidxvec, ph_require)

            for i, (next_word, orig_idx, next_cost) in enumerate(hypothesis):
                #print("%d %d %d %f %s" % (i, next_word, orig_idx, next_cost))
                new_trans[i] = trans[orig_idx] + [next_word]
                new_costs[i] = next_cost
                align_start = self.his_mem_slots
                align_end = self.his_mem_slots + self.key_slots
                current_align = alignments[orig_idx, :, align_start:align_end]
                new_align[i] = np.concatenate(
                    (key_align[orig_idx], current_align), axis=0)
                new_states[i] = state[orig_idx, :]
                inputs[i] = next_word

            # Filter the sequences that end with end-of-sequence character
            trans, costs, indices, key_align = [], [], [], []

            for i in range(n_samples):
                if new_trans[i][-1] != self.E_ID:
                    trans.append(new_trans[i])
                    costs.append(new_costs[i])
                    indices.append(i)
                    key_align.append(new_align[i])
                else:
                    n_samples -= 1
                    fin_trans.append(new_trans[i])
                    fin_costs.append(new_costs[i])
                    fin_align.append(new_align[i])

            if n_samples == 0:
                break

            inputs = inputs[indices]
            new_states = new_states[indices]
            attn_states = attn_states[indices, :, :]

            global_trace = global_trace[indices, :]
            enc_mask = enc_mask[indices, :, :]
            key_states = key_states[indices, :, :]
            his_mem = his_mem[indices, :, :]
            key_mask = key_mask[indices, :, :]
            his_mem_mask = his_mem_mask[indices, :]
            topic_trace = topic_trace[indices, :]

            if k >= np.shape(len_inps)[0]:
                specify_len = len_inps[np.shape(len_inps)[0] - 1, indices]
            else:
                specify_len = len_inps[k, indices]

            if k >= len(phs):
                ph_inp = [0] * n_samples
            else:
                ph_inp = [phs[k]] * n_samples

            output, state, alignments = self.decoder_computer(
                inputs, specify_len, ph_inp, new_states, attn_states,
                key_states, his_mem, global_trace, enc_mask, key_mask,
                his_mem_mask, topic_trace)

        for i in range(len(fin_align)):
            fin_align[i] = fin_align[i][1:, :]

        index = np.argsort(fin_costs)
        fin_align = np.array(fin_align)[index]
        fin_trans = np.array(fin_trans)[index]
        fin_costs = np.array(sorted(fin_costs))

        if len(fin_trans) == 0:
            index = np.argsort(costs)
            fin_align = np.array(key_align)[index]
            fin_trans = np.array(trans)[index]
            fin_costs = np.array(sorted(costs))

        return fin_trans, fin_costs, fin_align, ori_attn_states

    def get_new_global_trace(self, global_trace, ori_enc_states, beam_size):
        enc_states = np.expand_dims(ori_enc_states, axis=0)
        prev_global_trace = np.expand_dims(global_trace[0, :], 0)
        new_global_trace = self.global_trace_computer(prev_global_trace,
                                                      enc_states)
        new_global_trace = np.tile(new_global_trace, [beam_size, 1])
        return new_global_trace

    def get_new_his_mem(self, ori_his_mem, attn_states, ori_global_trace,
                        beam_size, src_len):
        # ori_his_mem: [his_mem_slots, mem_size]
        his_mem = np.expand_dims(ori_his_mem, axis=0)
        enc_outs = []
        assert np.shape(attn_states)[0] == self.enc_len
        for i in range(0, self.enc_len):
            enc_outs.append(np.expand_dims(attn_states[i, :], 0))

        mask = [np.ones((1, 1))] * src_len + [np.zeros(
            (1, 1))] * (self.enc_len - src_len)
        global_trace = np.expand_dims(ori_global_trace[0, :], 0)
        new_his_mem = self.his_mem_computer(his_mem, enc_outs, mask,
                                            global_trace)

        new_his_mem = np.tile(new_his_mem, [beam_size, 1, 1])
        return new_his_mem

    def get_new_topic_trace(self, ori_topic_trace, key_align, ori_key_states,
                            beam_size):
        key_states = np.expand_dims(ori_key_states[0, :, :], 0)
        topic_trace = np.expand_dims(ori_topic_trace, axis=0)
        key_align = np.mean(key_align, axis=0)
        key_align = np.expand_dims(key_align, axis=0)
        new_topic_trace = self.topic_trace_computer(key_states, topic_trace,
                                                    key_align)
        new_topic_trace = np.tile(new_topic_trace, [beam_size, 1])
        return new_topic_trace

    def generate_one(self, keystr, pattern):
        beam_size = self.beam_size
        ans, repeatidxes = [], []
        sens_num = len(pattern)

        keys = keystr.strip()
        print("using keywords: %s" % (keystr))
        keys = keystr.split(" ")
        keys_idxes = [
            self.tool.chars2idxes(self.tool.line2chars(key)) for key in keys
        ]

        # Here we set batch size=1 and then tile the results
        key_inps, key_mask = self.tool.gen_batch_key_beam(keys_idxes,
                                                          batch_size=1)

        # Calculate initial_key state and key_states
        key_initial_state, key_states = self.key_memory_computer(
            key_inps, key_mask, beam_size)

        his_mem_mask = np.zeros([beam_size, self.his_mem_slots],
                                dtype=np.float32)
        global_trace = np.zeros([beam_size, self.global_trace_size],
                                dtype=np.float32)
        his_mem = np.zeros([beam_size, self.his_mem_slots, self.mem_size],
                           dtype=np.float32)
        topic_trace = np.zeros(
            [beam_size, self.topic_trace_size + self.key_slots],
            dtype=np.float32)

        # Generate the lines, line0 is an empty list
        sen = []
        for step in range(0, sens_num):
            print("generating %d line..." % (step + 1))
            if step > 0:
                key_initial_state = None

            phs = pattern[step]
            trg_len = len(phs)
            src_len = len(sen)

            batch_sen, enc_mask, len_inps = self.tool.gen_enc_beam(
                sen, trg_len, batch_size=1)
            len_inps = np.tile(len_inps, [1, beam_size])

            trans, costs, align, attn_states = self.beam_search(
                batch_sen, len_inps, key_states, key_initial_state,
                topic_trace, his_mem, his_mem_mask, global_trace, enc_mask,
                key_mask, repeatidxes, phs)

            trans, costs, align, attn_states = self.pFilter(
                trans, costs, align, attn_states, trg_len)

            if len(trans) == 0:
                return [], ("line %d generation failed!" % (step + 1))

            which = 0

            one_his_mem = his_mem[which, :, :]
            his_mem = self.get_new_his_mem(one_his_mem, attn_states[which],
                                           global_trace, beam_size, src_len)

            if step >= 1:
                his_mem_mask = np.sum(np.abs(one_his_mem), axis=1)
                his_mem_mask = his_mem_mask != 0
                his_mem_mask = np.tile(his_mem_mask.astype(np.float32),
                                       [beam_size, 1])

            sentence = self.tool.beam_get_sentence(trans[which])
            sentence = sentence.strip()
            ans.append(sentence)
            attn_aligns = align[which][0:trg_len, :]
            topic_trace = self.get_new_topic_trace(topic_trace[which, :],
                                                   attn_aligns, key_states,
                                                   beam_size)
            global_trace = self.get_new_global_trace(global_trace,
                                                     attn_states[which],
                                                     beam_size)

            sentence = self.tool.line2chars(sentence)
            sen = self.tool.chars2idxes(sentence)
            repeatidxes = list(set(repeatidxes).union(set(sen)))

        return ans, "ok"

    def pFilter(self, trans, costs, align, states, trg_len):
        new_trans, new_costs, new_align, new_states = [], [], [], []

        for i in range(len(trans)):
            if len(trans[i]) < trg_len:
                continue

            tran = trans[i][0:trg_len]
            sen = self.tool.idxes2chars(tran)
            sen = "".join(sen)
            if trg_len > 4 and self.dtool.checkIfInLib(sen):
                continue
            new_trans.append(tran)
            new_align.append(align[i])
            new_states.append(states[i])
            new_costs.append(costs[i])

        return new_trans, new_costs, new_align, new_states

    # -----------------------------------------------------------------
    # Some APIs for beam search
    def key_memory_computer(self, key_inps, key_mask, beam_size):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        for step in range(self.key_slots):
            for j in range(2):
                input_feed[self.model.key_inps[step]
                           [j].name] = key_inps[step][j]
        input_feed[self.model.key_mask.name] = key_mask

        output_feed = [self.model.key_initial_state, self.model.key_states]
        [key_initial_state,
         key_states] = self.sess.run(output_feed, input_feed)
        key_initial_state = np.tile(key_initial_state, [beam_size, 1])
        key_states = np.tile(key_states, [beam_size, 1, 1])
        return key_initial_state, key_states

    def encoder_computer(self, enc_inps, enc_mask):
        assert self.enc_len == len(enc_inps)
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        for l in range(self.enc_len):
            input_feed[self.model.enc_inps[0][l].name] = enc_inps[l]

        input_feed[self.model.enc_mask[0].name] = enc_mask
        output_feed = [self.model.enc_state, self.model.attn_states]
        [enc_state, attn_states] = self.sess.run(output_feed, input_feed)

        enc_state = np.tile(enc_state, [self.beam_size, 1])
        attn_states = np.tile(attn_states, [self.beam_size, 1, 1])

        return enc_state, attn_states
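
encoder_computer and key_memory_computer run the graph once and replicate the batch-1 result across the beam with np.tile. A minimal sketch of that expansion, with hypothetical sizes for enc_len and 2*hidden_size:

import numpy as np

beam_size = 3
attn_states = np.zeros([1, 25, 512])   # hypothetical [1, enc_len, 2*hidden_size]
tiled = np.tile(attn_states, [beam_size, 1, 1])  # replicate along the batch axis
print(tiled.shape)                     # (3, 25, 512)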

    def decoder_computer(self, dec_inp, len_inp, ph_inp, prev_state,
                         attn_states, key_states, his_mem, global_trace,
                         enc_mask, key_mask, his_mem_mask, topic_trace):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0

        input_feed[self.model.dec_inps[0][0].name] = dec_inp
        input_feed[self.model.len_inps[0][0].name] = len_inp
        input_feed[self.model.ph_inps[0][0].name] = ph_inp

        input_feed[self.model.beam_attn_states.name] = attn_states
        input_feed[self.model.enc_mask[0].name] = enc_mask
        input_feed[self.model.beam_initial_state.name] = prev_state
        input_feed[self.model.beam_key_states.name] = key_states
        input_feed[self.model.key_mask.name] = key_mask
        input_feed[self.model.beam_global_trace.name] = global_trace
        input_feed[self.model.beam_topic_trace.name] = topic_trace
        input_feed[self.model.beam_his_mem.name] = his_mem
        input_feed[self.model.beam_his_mem_mask.name] = his_mem_mask

        output_feed = [
            self.model.next_out, self.model.next_state, self.model.next_align
        ]

        [next_output, next_state,
         next_align] = self.sess.run(output_feed, input_feed)

        return next_output, next_state, next_align

    def his_mem_computer(self, his_mem, enc_outs, write_mask, global_trace):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.gama] = [0.05]
        input_feed[self.model.beam_his_mem.name] = his_mem

        for l in range(self.enc_len):
            input_feed[self.model.beam_enc_outs[l]] = enc_outs[l]
            input_feed[self.model.write_masks[0][l]] = write_mask[l]
        input_feed[self.model.beam_global_trace.name] = global_trace

        output_feed = [self.model.new_his_mem]
        [new_his_mem] = self.sess.run(output_feed, input_feed)

        return new_his_mem

    def topic_trace_computer(self, key_states, prev_topic_trace, key_align):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.beam_key_states.name] = key_states
        input_feed[self.model.beam_topic_trace.name] = prev_topic_trace
        input_feed[self.model.beam_key_align.name] = key_align
        output_feed = [self.model.new_topic_trace]

        [new_topic_trace] = self.sess.run(output_feed, input_feed)
        return new_topic_trace

    def global_trace_computer(self, prev_global_trace, prev_attn_states):
        input_feed = {}
        input_feed[self.model.keep_prob] = 1.0
        input_feed[self.model.beam_global_trace] = prev_global_trace
        input_feed[self.model.beam_attn_states] = prev_attn_states
        output_feed = [self.model.new_global_trace]

        [new_global_trace] = self.sess.run(output_feed, input_feed)
        return new_global_trace
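
All of the *_computer helpers above share one pattern: build an input_feed dict, list the tensors to fetch, and call sess.run. Note that TensorFlow 1.x accepts feed_dict keys either as placeholder tensors or as their string names, which is why the helpers mix model.gama with model.beam_his_mem.name. A minimal sketch of both forms, assuming TensorFlow 1.x and a hypothetical placeholder:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None], name="x")
y = x * 2.0
with tf.Session() as sess:
    print(sess.run(y, {x: [1.0, 2.0]}))   # feed by placeholder tensor
    print(sess.run(y, {"x:0": [3.0]}))    # feed by name (note the ":0" suffix)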