Example #1
	def create_summary(self, args):
		self.summaryHelper = SummaryHelper("%s/%s_%s" % \
				(args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

		self.trainSummary = self.summaryHelper.addGroup(scalar=["loss",
																"perplexity",
																"kl_loss",
																"kld",
																"kl_weight"],
														prefix="train")

		scalarlist = ["loss", "perplexity", "kl_loss", "kld", "kl_weight"]
		tensorlist = []
		textlist = []
		for i in args.show_sample:
			textlist.append("show_str%d" % i)
		self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist,
													   prefix="dev")
		self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist,
													   prefix="test")
Example #2
	def create_summary(self, args):
		self.summaryHelper = SummaryHelper("%s/%s_%s" % \
				(args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

		self.trainSummary = self.summaryHelper.addGroup(scalar=["neg_elbo",
																"recontruction_loss",
																"KL_weight",
																"KL_divergence",
																"bow_loss",
																"perplexity"], prefix="train")

		scalarlist = ["neg_elbo", "reconstruction_loss", "KL_weight", "KL_divergence", "bow_loss",
					  "perplexity"]
		tensorlist = []
		textlist = []
		emblist = []
		for i in args.show_sample:
			textlist.append("show_str%d" % i)
		self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist,
				embedding=emblist, prefix="dev")
		self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist,
				embedding=emblist, prefix="test")
Example #3
def create_summary(args):
    summaryHelper = SummaryHelper("%s/%s_%s" % \
            (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

    gen_trainSummary = summaryHelper.addGroup(scalar=["loss", "rewards"],
                                                    prefix="gen_train")
    dis_trainSummary = summaryHelper.addGroup(scalar=["loss", "accuracy"],
                                                    prefix="dis_train")
    gen_scalarlist = ["loss", "rewards"]
    dis_scalarlist = ["loss", "accuracy"]
    tensorlist = []
    textlist = []
    for i in args.show_sample:
        textlist.append("show_str%d" % i)
    gen_devSummary = summaryHelper.addGroup(scalar=gen_scalarlist, tensor=tensorlist, text=textlist,
                                                    prefix="gen_dev")
    gen_testSummary = summaryHelper.addGroup(scalar=gen_scalarlist, tensor=tensorlist, text=textlist,
                                                    prefix="gen_test")
    dis_devSummary = summaryHelper.addGroup(scalar=dis_scalarlist, tensor=tensorlist, text=textlist,
                                                    prefix="dis_dev")
    dis_testSummary = summaryHelper.addGroup(scalar=dis_scalarlist, tensor=tensorlist, text=textlist,
                                                    prefix="dis_test")

    return gen_trainSummary, gen_devSummary, gen_testSummary, dis_trainSummary, dis_devSummary, dis_testSummary
Example #4
class VAEModel(object):
    def __init__(self, data, args, embed):

        with tf.variable_scope("input"):
            with tf.variable_scope("embedding"):
                # build the embedding table and embedding input
                if embed is None:
                    # initialize the embedding randomly
                    self.embed = tf.get_variable(
                        'embed', [data.vocab_size, args.embedding_size],
                        tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    self.embed = tf.get_variable('embed',
                                                 dtype=tf.float32,
                                                 initializer=embed)

            self.sentence = tf.placeholder(tf.int32, (None, None),
                                           'sen_inps')  # batch*len
            self.sentence_length = tf.placeholder(tf.int32, (None, ),
                                                  'sen_lens')  # batch
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

            batch_size, batch_len = tf.shape(self.sentence)[0], tf.shape(
                self.sentence)[1]
            self.decoder_max_len = batch_len - 1

            self.encoder_input = tf.nn.embedding_lookup(
                self.embed, self.sentence)  # batch*len*unit
            self.encoder_len = self.sentence_length

            decoder_input = tf.split(self.sentence, [self.decoder_max_len, 1],
                                     1)[0]  # no eos_id
            self.decoder_input = tf.nn.embedding_lookup(
                self.embed, decoder_input)  # batch*(len-1)*unit
            self.decoder_target = tf.split(self.sentence,
                                           [1, self.decoder_max_len],
                                           1)[1]  # no go_id, batch*(len-1)
            self.decoder_len = self.sentence_length - 1
            self.decoder_mask = tf.sequence_mask(
                self.decoder_len, self.decoder_max_len,
                dtype=tf.float32)  # batch*(len-1)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.encoder_len,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('recognition_net'):
            recog_input = encoder_state
            self.recog_mu = tf.layers.dense(inputs=recog_input,
                                            units=args.z_dim,
                                            activation=None,
                                            name='recog_mu')
            self.recog_logvar = tf.layers.dense(inputs=recog_input,
                                                units=args.z_dim,
                                                activation=None,
                                                name='recog_logvar')

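            # reparameterization trick: z = mu + exp(logvar / 2) * eps, with eps ~ N(0, I)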
            epsilon = tf.random_normal(tf.shape(self.recog_logvar),
                                       name="epsilon")
            std = tf.exp(0.5 * self.recog_logvar)
            self.recog_z = tf.add(self.recog_mu,
                                  tf.multiply(std, epsilon),
                                  name='recog_z')

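            # closed-form KL( N(mu, exp(logvar)) || N(0, I) ), averaged over the batch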
            self.kld = tf.reduce_mean(0.5 * tf.reduce_sum(
                tf.exp(self.recog_logvar) + self.recog_mu * self.recog_mu -
                self.recog_logvar - 1,
                axis=-1))
            self.prior_z = tf.random_normal(tf.shape(self.recog_logvar),
                                            name="prior_z")
            latent_sample = tf.cond(self.use_prior,
                                    lambda: self.prior_z,
                                    lambda: self.recog_z,
                                    name='latent_sample')
            dec_init_state = tf.layers.dense(inputs=latent_sample,
                                             units=args.dh_size,
                                             activation=None)

        with tf.variable_scope("output_layer",
                               initializer=tf.orthogonal_initializer()):
            self.output_layer = Dense(
                data.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                use_bias=True)

        with tf.variable_scope("decode",
                               initializer=tf.orthogonal_initializer()):
            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=self.decoder_input, sequence_length=self.decoder_len)
            train_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell_dec,
                helper=train_helper,
                initial_state=dec_init_state,
                output_layer=self.output_layer)
            train_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=train_decoder,
                maximum_iterations=self.decoder_max_len,
                impute_finished=True)
            logits = train_output.rnn_output

            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.decoder_target, logits=logits)
            crossent = tf.reduce_sum(crossent * self.decoder_mask)
            self.sen_loss = crossent / tf.to_float(batch_size)
            self.ppl_loss = crossent / tf.reduce_sum(self.decoder_mask)

            self.decoder_distribution_teacher = tf.nn.log_softmax(logits)

        with tf.variable_scope("decode", reuse=True):
            infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
            infer_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell_dec,
                helper=infer_helper,
                initial_state=dec_init_state,
                output_layer=self.output_layer)
            infer_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=infer_decoder,
                maximum_iterations=self.decoder_max_len,
                impute_finished=True)
            self.decoder_distribution = infer_output.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

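        # KL annealing: kl_weights rises linearly from 0 to 1 over full_kl_step steps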
        self.kl_weights = tf.minimum(
            tf.to_float(self.global_step) / args.full_kl_step, 1.0)
        self.kl_loss = self.kl_weights * tf.maximum(self.kld, args.min_kl)
        self.loss = self.sen_loss + self.kl_loss

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.MomentumOptimizer(learning_rate=self.learning_rate,
                                         momentum=args.momentum)
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity", "kl_loss", "kld", "kl_weight"],
            prefix="train")

        scalarlist = ["loss", "perplexity", "kl_loss", "kld", "kl_weight"]
        tensorlist = []
        textlist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: False
        }
        if forward_only:
            output_feed = [
                self.loss, self.decoder_distribution_teacher, self.ppl_loss,
                self.kl_loss, self.kld, self.kl_weights
            ]
        else:
            output_feed = [
                self.loss, self.gradient_norm, self.update, self.ppl_loss,
                self.kl_loss, self.kld, self.kl_weights
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: True
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss_step = np.zeros((1, ))
        ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data != None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss_step += outputs[0]
            ppl_loss_step += outputs[-4]
            kl_loss_step += outputs[-3]
            kld_step += outputs[-2]
            kl_weight_step = outputs[-1]
            times += 1
            batched_data = data.get_next_batch(key_name)

        loss_step /= times
        ppl_loss_step /= times
        kl_loss_step /= times
        kld_step /= times

        print('    loss: %.2f' % loss_step)
        print('    kl_loss: %.2f' % kl_loss_step)
        print('    perplexity: %.2f' % np.exp(ppl_loss_step))
        print('    kld: %.2f' % kld_step)
        return loss_step, ppl_loss_step, kl_loss_step, kld_step, kl_weight_step

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data != None:
                if self.global_step.eval() % args.checkpoint_steps == 0 \
                        and self.global_step.eval() != 0:
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step))
                    print('    loss: %.2f' % loss_step)
                    print('    kl_loss: %.2f' % kl_loss_step)
                    print('    perplexity: %.2f' % np.exp(ppl_loss_step))
                    print('    kld: %.2f' % kld_step)
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(ppl_loss_step),
                            'kl_loss': kl_loss_step,
                            'kld': kld_step,
                            'kl_weight': kl_weight_step
                        })
                    #self.saver.save(sess, '%s/checkpoint_latest' % args.model_dir, global_step=self.global_step)\
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    devout = self.evaluate(sess, data, args.batch_size, "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': devout[0],
                            'perplexity': np.exp(devout[1]),
                            'kl_loss': devout[2],
                            'kld': devout[3],
                            'kl_weight': devout[4]
                        })

                    testout = self.evaluate(sess, data, args.batch_size,
                                            "test")
                    self.testSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': testout[0],
                            'perplexity': np.exp(testout[1]),
                            'kl_loss': testout[2],
                            'kld': testout[3],
                            'kl_weight': testout[4]
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if devout[0] < best_valid:
                        best_valid = devout[0]
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0
                    ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0

                start_time = time.time()
                outputs = self.step_decoder(sess, batched_data)
                loss_step += outputs[0] / args.checkpoint_steps
                ppl_loss_step += outputs[-4] / args.checkpoint_steps
                kl_loss_step += outputs[-3] / args.checkpoint_steps
                kld_step += outputs[-2] / args.checkpoint_steps
                kl_weight_step = outputs[-1]

                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        results = []
        while batched_data != None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            gen_log_prob = self.step_decoder(sess,
                                             batched_data,
                                             forward_only=True)[1]
            metric1_data = {
                'sent_allvocabs': np.array(batched_data['sent_allvocabs']),
                'sent_length': np.array(batched_data['sent_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                result_token = []
                response_id_list = response_id.tolist()
                response_token = data.index_to_sen(response_id_list)
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                for token in response_token:
                    if token != data.ext_vocab[data.eos_id]:
                        result_token.append(token)
                    else:
                        break
                results.append(result_token)
                batch_results.append(result_id)

            metric2_data = {'gen': np.array(batch_results)}
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i in range(len(res['gen'])):
                f.write("%s\n" % " ".join(res['gen'][i]))

        print("result output to %s." % test_file)
Example #5
class HredModel(object):
    def __init__(self, data, args, embed):
        self.init_states = tf.placeholder(tf.float32, (None, args.ch_size),
                                          'ctx_inps')  # batch*ch_size
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prev_posts = tf.placeholder(tf.int32, (None, None),
                                         'enc_prev_inps')
        self.prev_posts_length = tf.placeholder(tf.int32, (None, ),
                                                'enc_prev_lens')

        self.kgs = tf.placeholder(tf.int32, (None, None, None),
                                  'kg_inps')  # batch*len
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None),
                                           'kg_h_lens')  # batch
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')  # batch
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')  # batch
        self.kgs_index = tf.placeholder(tf.float32, (None, None),
                                        'kg_indices')  # batch

        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens')
        self.is_train = tf.placeholder(tf.bool)

        num_past_turns = tf.shape(self.posts)[0] // tf.shape(
            self.origin_responses)[0]

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
        kg_len = tf.shape(self.kgs)[2]
        kg_h_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_h_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
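        # key mask covers the head+relation tokens; value mask covers the tail tokens (hrt minus hr)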
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)
        #self.knowledge_max = tf.reduce_max(tf.where(tf.cast(tf.tile(knowledge_mask, [1, 1, args.embedding_size]), tf.bool), self.knowledge_input, -mask_value), axis=1)
        #self.knowledge_min = tf.reduce_max(tf.where(tf.cast(tf.tile(knowledge_mask, [1, 1, args.embedding_size]), tf.bool), self.knowledge_input, mask_value), axis=1)
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))

        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input))  # batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            prev_output, _ = dynamic_rnn(cell_enc,
                                         tf.nn.embedding_lookup(
                                             self.embed, self.prev_posts),
                                         self.prev_posts_length,
                                         dtype=tf.float32,
                                         scope="encoder_rnn")

        with tf.variable_scope('context'):
            encoder_state_reshape = tf.reshape(
                encoder_state, [-1, num_past_turns, args.eh_size])
            _, self.context_state = dynamic_rnn(cell_ctx,
                                                encoder_state_reshape,
                                                self.context_length,
                                                dtype=tf.float32,
                                                scope='context_rnn')

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct attention
        '''
		encoder_len = tf.shape(encoder_output)[1]
		attention_memory = tf.reshape(encoder_output, [batch_size, -1, args.eh_size])
		attention_mask = tf.reshape(tf.sequence_mask(self.posts_length, encoder_len), [batch_size, -1])
		attention_mask = tf.concat([tf.ones([batch_size, 1], tf.bool), attention_mask[:,1:]], axis=1)
		attn_mechanism = MyAttention(args.dh_size, attention_memory, attention_mask)
		'''
        attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            args.dh_size,
            prev_output,
            memory_sequence_length=tf.maximum(self.prev_posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # calculate kg embedding
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(self.context_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        kg_alignment = tf.nn.softmax(kg_score)
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)

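        # kg_num_mask (a mask over padded knowledge entries) is not defined in this
        # snippet; it presumably comes from elsewhere in the original model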
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg *
            tf.cast(kg_num_mask, tf.float32),
            axis=1)
        #self.knowledge_embed = tf.Print(self.knowledge_embed, ['acc', self.kg_acc, 'loss', self.kg_loss])
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1), [1, decoder_len, 1])
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)
        # construct helper
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)
        #infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sent_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        self.loss = self.decoder_loss + self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        input_feed = {
            #self.init_states: data['init_states'],
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.prev_posts: data['prev_posts'],
            self.prev_posts_length: data['prev_posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
            self.context_length: data['context_length'],
            self.kgs: data['kg'],
            self.kgs_h_length: data['kg_h_length'],
            self.kgs_hr_length: data['kg_hr_length'],
            self.kgs_hrt_length: data['kg_hrt_length'],
            self.kgs_index: data['kg_index'],
        }

        if inference:
            input_feed.update({self.is_train: False})
            output_feed = [
                self.generation_index, self.decoder_distribution_teacher,
                self.decoder_all_loss, self.kg_loss, self.kg_acc
            ]
        else:
            input_feed.update({self.is_train: True})
            if forward_only:
                output_feed = [
                    self.decoder_loss, self.decoder_distribution_teacher,
                    self.kg_loss, self.kg_acc
                ]
            else:
                output_feed = [
                    self.decoder_loss, self.gradient_norm, self.update,
                    self.kg_loss, self.kg_acc
                ]

        return sess.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss = np.zeros((3, ))
        total_length = np.zeros((3, ))
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data != None:
            decoder_loss, _, kg_loss, kg_acc = self.step_decoder(
                sess, batched_data, forward_only=True)
            length = np.sum(
                np.maximum(np.array(batched_data['responses_length']) - 1, 0))
            kg_length = np.sum(np.max(batched_data['kg_index'], axis=-1))
            total_length += [length, kg_length, kg_length]
            loss += [
                decoder_loss * length, kg_loss * kg_length, kg_acc * kg_length
            ]
            batched_data = data.get_next_batch(key_name)
        loss /= total_length
        print(
            '	perplexity on %s set: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f'
            % (key_name, np.exp(loss[0]), np.exp(loss[1]), loss[1], loss[2]))
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((3, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        for epoch_step in range(args.epochs):
            while batched_data != None:
                if self.global_step.eval() % args.checkpoint_steps == 0 \
                        and self.global_step.eval() != 0:
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           np.exp(loss_step[0]), np.exp(
                               loss_step[1]), loss_step[1], loss_step[2]))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step[0],
                            'perplexity': np.exp(loss_step[0])
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss[0],
                            'perplexity': np.exp(dev_loss[0])
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss[0] < best_valid:
                        best_valid = dev_loss[0]
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [
                        np.sum(loss_step[0])
                    ]
                    loss_step, time_step = np.zeros((3, )), .0

                start_time = time.time()
                step_out = self.step_decoder(sess, batched_data)
                loss_step += np.array([step_out[0], step_out[3], step_out[4]
                                       ]) / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):

        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding='utf8') as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data != None:

            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()

            batched_data['responses_length'] = [
                len(batched_data['responses'][0])
            ]
            for each_resp in test_distractors[cnt]:
                batched_data['responses'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['responses_length'].append(
                    len(batched_data['responses'][-1]))
            max_length = max(batched_data['responses_length'])
            resp = np.zeros((len(batched_data['responses']), max_length),
                            dtype=int)
            for i, each_resp in enumerate(batched_data['responses']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['responses'] = resp

            posts = []
            posts_length = []
            prev_posts = []
            prev_posts_length = []
            context_length = []

            kg = []
            kg_h_length = []
            kg_hr_length = []
            kg_hrt_length = []
            kg_index = []

            for _ in range(len(resp)):
                posts += batched_data['posts']
                posts_length += batched_data['posts_length']
                prev_posts += batched_data['prev_posts']
                prev_posts_length += batched_data['prev_posts_length']
                context_length += batched_data['context_length']

                kg += batched_data['kg']
                kg_h_length += batched_data['kg_h_length']
                kg_hr_length += batched_data['kg_hr_length']
                kg_hrt_length += batched_data['kg_hrt_length']
                kg_index += batched_data['kg_index']

            batched_data['posts'] = posts
            batched_data['posts_length'] = posts_length
            batched_data['prev_posts'] = prev_posts
            batched_data['prev_posts_length'] = prev_posts_length
            batched_data['context_length'] = context_length

            batched_data['kg'] = kg
            batched_data['kg_h_length'] = kg_h_length
            batched_data['kg_hr_length'] = kg_hr_length
            batched_data['kg_hrt_length'] = kg_hrt_length
            batched_data['kg_index'] = kg_index

            _, _, loss, _, _ = self.step_decoder(sess,
                                                 batched_data,
                                                 inference=True)
            loss_record.append(loss)
            cnt += 1

            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

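        # candidate 0 in each row is the ground-truth response; hits@k checks whether
        # it is among the k candidates with the lowest loss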
        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3}

    def test_process(self, sess, data, args):

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")

        while batched_data != None:
            batched_responses_id, gen_log_prob, _, _, _ = self.step_decoder(
                sess, batched_data, False, True)
            metric1_data = {
                'resp_allvocabs':
                np.array(batched_data['responses_allvocabs']),
                'resp_length': np.array(batched_data['responses_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs': np.array(batched_data['responses_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['resp'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        print("result output to %s" % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
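
A self-contained toy illustration (values made up) of the ranking logic in test_process_hits above: candidate 0 in each row is the ground-truth response, and hits@k measures how often it falls among the k lowest-loss candidates.

import numpy as np

# rows: test contexts; columns: per-candidate loss, candidate 0 = ground truth
loss = np.array([[0.2, 0.9, 1.1, 0.8],   # ground truth has the lowest loss -> hits@1 and hits@3
                 [0.7, 0.3, 0.9, 1.2],   # ground truth ranks 2nd           -> hits@3 only
                 [1.5, 0.2, 0.4, 0.6]])  # ground truth ranks last          -> neither
loss_rank = np.argsort(loss, axis=1)
hits1 = float(np.mean(loss_rank[:, 0] == 0))                   # 1/3
hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))  # 2/3
print(hits1, hits3)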
Example #6
class CVAEModel(object):
    def __init__(self, data, args, embed):
        with tf.name_scope("placeholders"):
            self.contexts = tf.placeholder(tf.int32, (None, None, None),
                                           'cxt_inps')  # [batch, utt_len, len]
            self.contexts_length = tf.placeholder(tf.int32, (None, ),
                                                  'cxt_lens')  # [batch]
            self.posts_length = tf.placeholder(tf.int32, (None, None),
                                               'enc_lens')  # [batch, utt_len]
            self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                                   'dec_inps')  # [batch, len]
            self.origin_responses_length = tf.placeholder(
                tf.int32, (None, ), 'dec_lens')  # [batch]

        # deal with original data to adapt encoder and decoder
        max_sen_length = tf.shape(self.contexts)[2]
        max_cxt_size = tf.shape(self.contexts)[1]
        self.posts_input = tf.reshape(
            self.contexts, [-1, max_sen_length])  # [batch * cxt_len, utt_len]
        self.flat_posts_length = tf.reshape(self.posts_length,
                                            [-1])  # [batch * cxt_len]

        decoder_len = tf.shape(self.origin_responses)[1]
        self.responses_target = tf.split(self.origin_responses,
                                         [1, decoder_len - 1],
                                         1)[1]  # no go_id
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1
        self.decoder_mask = tf.sequence_mask(self.responses_length,
                                             decoder_len,
                                             dtype=tf.float32)
        loss_mask = tf.cast(tf.greater(self.responses_length, 0),
                            dtype=tf.float32)
        batch_size = tf.reduce_sum(loss_mask)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        with tf.name_scope("embedding"):
            # build the embedding table and embedding input
            if embed is None:
                # initialize the embedding randomly
                self.word_embed = tf.get_variable(
                    'word_embed', [data.vocab_size, args.word_embedding_size],
                    tf.float32)
            else:
                # initialize the embedding by pre-trained word vectors
                self.word_embed = tf.get_variable('word_embed',
                                                  dtype=tf.float32,
                                                  initializer=embed)
            posts_enc_input = tf.nn.embedding_lookup(self.word_embed,
                                                     self.posts_input)
            responses_enc_input = tf.nn.embedding_lookup(
                self.word_embed, self.origin_responses)
            responses_dec_input = tf.nn.embedding_lookup(
                self.word_embed, self.responses_input)

        with tf.name_scope("cell"):
            # build rnn_cell
            cell_enc_fw = tf.nn.rnn_cell.GRUCell(args.eh_size)
            cell_enc_bw = tf.nn.rnn_cell.GRUCell(args.eh_size)
            cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
            cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            posts_enc_output, posts_enc_state = tf.nn.bidirectional_dynamic_rnn(
                cell_enc_fw,
                cell_enc_bw,
                posts_enc_input,
                self.flat_posts_length,
                dtype=tf.float32,
                scope="encoder_bi_rnn")
            posts_enc_state = tf.reshape(tf.concat(posts_enc_state, 1),
                                         [-1, max_cxt_size, 2 * args.eh_size])

        with tf.variable_scope('context'):
            _, context_state = tf.nn.dynamic_rnn(cell_ctx,
                                                 posts_enc_state,
                                                 self.contexts_length,
                                                 dtype=tf.float32,
                                                 scope='context_rnn')
            cond_info = context_state

        with tf.variable_scope("recognition_network"):
            _, responses_enc_state = tf.nn.bidirectional_dynamic_rnn(
                cell_enc_fw,
                cell_enc_bw,
                responses_enc_input,
                self.responses_length,
                dtype=tf.float32,
                scope='encoder_bid_rnn')
            responses_enc_state = tf.concat(responses_enc_state, 1)
            recog_input = tf.concat((cond_info, responses_enc_state), 1)
            recog_output = tf.layers.dense(recog_input, 2 * args.latent_size)
            recog_mu, recog_logvar = tf.split(recog_output, 2, 1)
            recog_z = self.sample_gaussian(
                (tf.size(self.contexts_length), args.latent_size), recog_mu,
                recog_logvar)

        with tf.variable_scope("prior_network"):
            prior_input = cond_info
            prior_fc_1 = tf.layers.dense(prior_input,
                                         2 * args.latent_size,
                                         activation=tf.tanh)
            prior_output = tf.layers.dense(prior_fc_1, 2 * args.latent_size)
            prior_mu, prior_logvar = tf.split(prior_output, 2, 1)
            prior_z = self.sample_gaussian(
                (tf.size(self.contexts_length), args.latent_size), prior_mu,
                prior_logvar)

        with tf.name_scope("decode"):
            # get output projection function
            dec_init_fn = MyDense(args.dh_size, use_bias=True)
            output_fn = MyDense(data.vocab_size, use_bias=True)

            with tf.name_scope("training"):
                decoder_input = responses_dec_input
                gen_input = tf.concat((cond_info, recog_z), 1)
                dec_init_fn_input = gen_input
                train_helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_input, tf.maximum(self.responses_length, 0))
                dec_init_state = dec_init_fn(dec_init_fn_input)
                decoder_train = tf.contrib.seq2seq.BasicDecoder(
                    cell_dec, train_helper, dec_init_state, output_fn)
                train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_train, impute_finished=True, scope="decoder_rnn")
                responses_dec_output = train_outputs.rnn_output
                self.decoder_distribution_teacher = tf.nn.log_softmax(
                    responses_dec_output, 2)

                with tf.name_scope("losses"):
                    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=responses_dec_output,
                        labels=self.responses_target)
                    self.reconstruct_loss = tf.reduce_sum(
                        crossent * self.decoder_mask) / batch_size

                    self.KL_loss = tf.reduce_sum(\
                     loss_mask * self.KL_divergence(prior_mu, prior_logvar, recog_mu, recog_logvar)) / batch_size
                    self.KL_weight = tf.minimum(
                        1.0,
                        tf.to_float(self.global_step) / args.full_kl_step)
                    self.anneal_KL_loss = self.KL_weight * self.KL_loss

                    bow_logits = tf.layers.dense(tf.layers.dense(dec_init_fn_input, 400, activation=tf.tanh),\
                            data.vocab_size)
                    tile_bow_logits = tf.tile(tf.expand_dims(bow_logits, 1),
                                              [1, decoder_len, 1])
                    bow_loss = self.decoder_mask * tf.nn.sparse_softmax_cross_entropy_with_logits(\
                     logits=tile_bow_logits,\
                     labels=self.responses_target)
                    self.bow_loss = tf.reduce_sum(bow_loss) / batch_size

                    self.neg_elbo = self.reconstruct_loss + self.KL_loss
                    self.train_loss = self.reconstruct_loss + self.anneal_KL_loss + self.bow_loss
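                    # neg_elbo (reconstruction + full KL) is the value reported in the
                    # summaries, while train_loss is what is actually optimized: the KL term
                    # is annealed linearly from 0 to 1 over args.full_kl_step steps, and the
                    # bag-of-words loss is an auxiliary term that pushes the latent code
                    # (together with the context) to predict the words of the response.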

            with tf.name_scope("inference"):
                gen_input = tf.concat((cond_info, prior_z), 1)
                dec_init_fn_input = gen_input
                dec_init_state = dec_init_fn(dec_init_fn_input)

                infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    self.word_embed,
                    tf.fill([tf.size(self.contexts_length)], data.go_id),
                    data.eos_id)
                decoder_infer = MyBasicDecoder(cell_dec,
                                               infer_helper,
                                               dec_init_state,
                                               output_layer=output_fn,
                                               _aug_context_vector=None)
                infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_infer,
                    impute_finished=True,
                    maximum_iterations=args.max_sen_length,
                    scope="decoder_rnn")
                self.decoder_distribution = infer_outputs.rnn_output
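                # take the argmax over the vocabulary with the first two ids excluded
                # (presumably the pad and unk tokens in this vocabulary layout), then add 2
                # to map back to real vocabulary ids, so the decoder never emits unk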
                self.generation_index = tf.argmax(
                    tf.split(self.decoder_distribution,
                             [2, data.vocab_size - 2], 2)[1],
                    2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.train_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def sample_gaussian(self, shape, mu, logvar):
        normal = tf.random_normal(shape=shape, dtype=tf.float32)
        z = tf.exp(logvar / 2) * normal + mu
        return z

    def KL_divergence(self, prior_mu, prior_logvar, recog_mu, recog_logvar):
        KL_divergence = 0.5 * (
            tf.exp(recog_logvar - prior_logvar) +
            tf.pow(recog_mu - prior_mu, 2) / tf.exp(prior_logvar) - 1 -
            (recog_logvar - prior_logvar))
        return tf.reduce_sum(KL_divergence, axis=1)
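    # sample_gaussian is the reparameterization trick, z = mu + exp(logvar / 2) * eps with
    # eps ~ N(0, I); KL_divergence is the closed-form KL between two diagonal Gaussians,
    #   KL(q || p) = 0.5 * sum_i( exp(lv_q - lv_p) + (mu_q - mu_p)^2 / exp(lv_p) - 1 - (lv_q - lv_p) ),
    # summed over the latent dimensions, which is exactly what the expression above computes.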

    def store_checkpoint(self, sess, path, key):
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(scalar=[
            "neg_elbo", "recontruction_loss", "KL_weight", "KL_divergence",
            "bow_loss", "perplexity"
        ],
                                                        prefix="train")

        scalarlist = [
            "neg_elbo", "reconstruction_loss", "KL_weight", "KL_divergence",
            "bow_loss", "perplexity"
        ]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def _pad_batch(self, raw_batch):
        '''Pad ``posts_length`` and trim the session; invoked by ``SwitchboardCorpus.split_session``
        and ``SwitchboardCorpus.multi_reference_batches``.
        '''
        batch = {'posts_length': [], \
           'contexts_length': [], \
           'responses_length': raw_batch['responses_length']}
        max_cxt_size = np.shape(raw_batch['contexts'])[1]
        max_post_len = 0

        for i, speaker in enumerate(raw_batch['posts_length']):
            batch['contexts_length'].append(len(raw_batch['posts_length'][i]))

            if raw_batch['posts_length'][i]:
                max_post_len = max(max_post_len,
                                   max(raw_batch['posts_length'][i]))
            batch['posts_length'].append(raw_batch['posts_length'][i] + \
                    [0] * (max_cxt_size - len(raw_batch['posts_length'][i])))

        batch['contexts'] = raw_batch['contexts'][:, :, :max_post_len]
        batch['responses'] = raw_batch[
            'responses'][:, :np.max(raw_batch['responses_length'])]
        return batch

    def _cut_batch_data(self, batch_data, start, end):
        '''Use ``session[start: end - 1]`` as the context and ``session[end - 1]`` as the response;
        invoked by ``SwitchboardCorpus.split_session``.
        '''
        raw_batch = {'posts_length': [], 'responses_length': []}
        for i in range(len(batch_data['turn_length'])):
            raw_batch['posts_length'].append( \
             batch_data['sent_length'][i][start: end - 1])
            turn_len = len(batch_data['sent_length'][i])
            if end - 1 < turn_len:
                raw_batch['responses_length'].append( \
                 batch_data['sent_length'][i][end - 1])
            else:
                raw_batch['responses_length'].append(1)

        raw_batch['contexts'] = batch_data['sent'][:, start:end - 1]
        raw_batch['responses'] = batch_data['sent'][:, end - 1]
        return self._pad_batch(raw_batch)

    def split_session(self, batch_data, session_window, inference=False):
        '''Split a session so that different utterances take turns serving as the response.

        Arguments:
            batch_data (dict): must be in the same format as the return value of self.get_batch
            session_window (int): maximum number of context turns kept before the response
                (earlier turns are dropped)
            inference (bool): True: utterances serve as responses in their original order (no shuffling)
                              False: the order in which utterances serve as responses is shuffled

        Returns:
            (generator): yields dicts, each of which contains at least:

                * contexts (:class:`numpy.array`): A 3-d PADDED array containing ids of words in the contexts.
                    Only valid words are provided; ``unk_id`` is used for words that are not valid.
                    Size: ``[batch_size, max_turn_length, max_utterance_length]``
                * contexts_length (list): A 1-d list, the number of turns in each context.
                    Size: ``[batch_size]``
                * posts_length (list): A 2-d PADDED list, the length of each utterance.
                    Size: ``[batch_size, max_turn_length]``
                * responses (:class:`numpy.array`): A 2-d PADDED array containing ids of words in the responses.
                    Size: ``[batch_size, max_response_length]``
                * responses_length (list): A 1-d list, the length of each response.
                    Size: ``[batch_size]``
        '''
        max_turn = np.max(batch_data['turn_length'])
        ends = list(range(2, max_turn + 1))
        if not inference:
            np.random.shuffle(ends)
        for end in ends:
            start = max(end - session_window, 0)
            turn_data = self._cut_batch_data(batch_data, start, end)
            yield turn_data
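    # For a batch whose longest session has max_turn utterances, split_session yields one
    # sub-batch per response position end in {2, ..., max_turn}: utterances
    # [max(end - session_window, 0), end - 1) form the context and utterance end - 1 is the
    # response; the response positions are visited in shuffled order unless inference=True.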

    def multi_reference_batches(self, data, batch_size):
        '''Get batches with multiple response candidates.

        Arguments:
            * data: the dataloader providing the ``multi_ref`` set
            * batch_size (int): batch size

        Returns:
            (generator): yields dicts containing the fields described in self.split_session plus:

                * candidate (list): A 3-d list of response candidates.
                    Size: ``[batch_size, _num_candidates, _num_words]``
        '''
        data.restart('multi_ref', batch_size, shuffle=False)
        batch_data = data.get_next_batch('multi_ref')
        while batch_data is not None:
            batch = self._cut_batch_data(batch_data,\
                0, np.max(batch_data['turn_length']))
            batch['candidate'] = batch_data['candidate']
            yield batch
            batch_data = data.get_next_batch('multi_ref')

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        input_feed = {
            self.contexts: data['contexts'],
            self.contexts_length: data['contexts_length'],
            self.posts_length: data['posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
        }
        if inference:
            output_feed = [self.generation_index]
        else:
            if forward_only:
                output_feed = [
                    self.neg_elbo, self.reconstruct_loss, self.KL_weight,
                    self.KL_loss, self.bow_loss,
                    self.decoder_distribution_teacher
                ]
            else:
                output_feed = [
                    self.neg_elbo, self.reconstruct_loss, self.KL_weight,
                    self.KL_loss, self.bow_loss, self.update
                ]

        return sess.run(output_feed, input_feed)

    def train_step(self, sess, data, args):
        output = self.step_decoder(sess, data)
        res = output[:5]
        log_ppl = output[1] * np.sum(np.array(data['responses_length'], dtype=np.int32) > 1) / np.sum(\
         np.maximum(np.array(data['responses_length'], dtype=np.int32) - 1, 0))
        res += [log_ppl]
        return res
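    # log_ppl rescales the batch-averaged reconstruction loss into an approximate per-token
    # negative log-likelihood: it multiplies by the number of non-empty responses
    # (length > 1) and divides by the total number of target tokens; train_process later
    # exponentiates the running average of this value to report perplexity.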

    def evaluate(self, sess, data, key_name, args):
        loss_list = np.array([.0] * 6, dtype=np.float32)
        total_length = 0
        total_inst = 0
        data.restart(key_name, batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            for cut_batch_data in self.split_session(batched_data,
                                                     args.session_window):
                output = self.step_decoder(sess,
                                           cut_batch_data,
                                           forward_only=True)
                batch_size = np.sum(
                    np.array(cut_batch_data['responses_length'],
                             dtype=np.int32) > 1)
                loss_list[:-1] += np.array(output[:5],
                                           dtype=np.float32) * batch_size
                loss_list[-1] += output[1] * batch_size
                total_length += np.sum(
                    np.maximum(
                        np.array(cut_batch_data['responses_length'],
                                 dtype=np.int32) - 1, 0))
                total_inst += batch_size
            batched_data = data.get_next_batch(key_name)
        loss_list[:-1] /= total_inst
        loss_list[-1] = np.exp(loss_list[-1] / total_length)

        print('	perplexity on %s set: %.2f' % (key_name, loss_list[-1]))
        return loss_list

    def train_process(self, sess, data, args):
        time_step, epoch_step = .0, 0
        loss_names = [
            'neg_elbo', 'reconstruction_loss', 'KL_weight', 'KL_divergence',
            'bow_loss', 'perplexity'
        ]
        loss_list = np.array([.0] * len(loss_names), dtype=np.float32)
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                for cut_batch_data in self.split_session(
                        batched_data, args.session_window):
                    start_time = time.time()
                    output = self.train_step(sess, cut_batch_data, args)
                    loss_list += np.array(output[:len(loss_list)],
                                          dtype=np.float32)

                    time_step += time.time() - start_time

                    if (self.global_step.eval() +
                            1) % args.checkpoint_steps == 0:
                        loss_list /= args.checkpoint_steps
                        loss_list[-1] = np.exp(loss_list[-1])
                        time_step /= args.checkpoint_steps

                        print(
                            "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                            % (epoch_step, self.global_step.eval(),
                               self.learning_rate.eval(), time_step,
                               loss_list[-1]))
                        self.trainSummary(self.global_step.eval() // args.checkpoint_steps, \
                              dict(zip(loss_names, loss_list)))
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_latest/checkpoint' % args.model_dir,
                            "latest")

                        dev_loss = self.evaluate(sess, data, "dev", args)
                        self.devSummary(self.global_step.eval() // args.checkpoint_steps,\
                            dict(zip(loss_names, dev_loss)))

                        test_loss = self.evaluate(sess, data, "test", args)
                        self.testSummary(self.global_step.eval() // args.checkpoint_steps,\
                             dict(zip(loss_names, test_loss)))

                        if loss_list[0] > max(previous_losses):
                            sess.run(self.learning_rate_decay_op)
                        if dev_loss[0] < best_valid:
                            best_valid = dev_loss[0]
                            self.store_checkpoint(
                                sess, '%s/checkpoint_best/checkpoint' %
                                args.model_dir, "best")

                        previous_losses = previous_losses[1:] + [loss_list[0]]
                        loss_list *= .0
                        time_step = .0

                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        def get_batch_results(batched_responses_id, data):
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)
            return batch_results

        def padding(matrix, pad_go_id=False):
            l = max([len(d) for d in matrix])
            if not pad_go_id:
                res = [[d + [data.pad_id] * (l - len(d)) for d in matrix]]
            else:
                res = [[[data.go_id] + d + [data.pad_id] * (l - len(d))
                        for d in matrix]]
            return res

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        cnt = 0
        start_time = time.time()
        while batched_data is not None:
            conv_data = [{
                'contexts': [],
                'responses': [],
                'generations': []
            } for _ in range(len(batched_data['turn_length']))]
            for cut_batch_data in self.split_session(batched_data,
                                                     args.session_window,
                                                     inference=True):
                eval_out = self.step_decoder(sess,
                                             cut_batch_data,
                                             forward_only=True)
                decoder_loss, gen_prob = eval_out[:6], eval_out[-1]
                batched_responses_id = self.step_decoder(sess,
                                                         cut_batch_data,
                                                         inference=True)[0]
                batch_results = get_batch_results(batched_responses_id, data)

                cut_batch_data['gen_prob'] = gen_prob
                cut_batch_data['generations'] = batch_results
                responses_length = []
                for length in cut_batch_data['responses_length']:
                    if length == 1:
                        length += 1
                    responses_length.append(length)
                metric1_data = {
                    'sent': np.expand_dims(cut_batch_data['responses'], 1),
                    'sent_length': np.expand_dims(responses_length, 1),
                    'gen_prob': np.expand_dims(cut_batch_data['gen_prob'], 1)
                }
                metric1.forward(metric1_data)
                valid_index = [
                    idx for idx, length in enumerate(
                        cut_batch_data['responses_length']) if length > 1
                ]
                for key in ['contexts', 'responses', 'generations']:
                    for idx, d in enumerate(cut_batch_data[key]):
                        if idx in valid_index:
                            if key == 'contexts':
                                d = d[cut_batch_data['contexts_length'][idx] -
                                      1]
                            conv_data[idx][key].append(list(d))

                if (cnt + 1) % 10 == 0:
                    print('processed %d batches, time cost %.2f s/batch' %
                          (cnt, (time.time() - start_time) / 10))
                    start_time = time.time()
                cnt += 1

            for conv in conv_data:
                metric2_data = {
                    'context':
                    np.array(padding(conv['contexts'])),
                    'turn_length':
                    np.array([len(conv['contexts'])], dtype=np.int32),
                    'reference':
                    np.array(padding(conv['responses'])),
                    'gen':
                    np.array(padding(conv['generations'], pad_go_id=True))
                }
                metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'a') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i, context in enumerate(res['context']):
                f.write("session: \t%d\n" % i)
                for j in range(len(context)):
                    f.write("\tpost:\t%s\n" % " ".join(context[j]))
                    f.write("\tresp:\t%s\n" % " ".join(res['reference'][i][j]))
                    f.write("\tgen:\t%s\n\n" % " ".join(res['gen'][i][j]))
                f.write("\n")

        print("result output to %s" % test_file)

    def test_multi_ref(self, sess, data, embed, args):
        def process_cands(candidates):
            res = []
            for cands in candidates:
                tmp = []
                for sent in cands:
                    tmp.append([
                        wid if wid < data.vocab_size else data.unk_id
                        for wid in sent
                    ])
                res.append(tmp)
            return res

        prec_rec_metrics = data.get_precision_recall_metric(embed)
        for batch_data in self.multi_reference_batches(data, args.batch_size):
            responses = []
            for _ in range(args.repeat_N):
                batched_responses_id = self.step_decoder(sess,
                                                         batch_data,
                                                         inference=True)[0]
                for rid, resp in enumerate(batched_responses_id):
                    resp = list(resp)
                    if rid == len(responses):
                        responses.append([])
                    if data.eos_id in resp:
                        resp = resp[:resp.index(data.eos_id)]
                    if len(resp) > 0:
                        responses[rid].append(resp)
            metric_data = {
                'resp': process_cands(batch_data['candidate']),
                'gen': responses
            }
            prec_rec_metrics.forward(metric_data)

        res = prec_rec_metrics.close()

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Multi Reference Result:")
            f.write("Test Multi Reference Result:\n")
            for key, val in res.items():
                print("\t{}\t{}".format(key, val))
                f.write("\t{}\t{}".format(key, val) + "\n")
            f.write("\n")
Example #7
class LM(object):
    def __init__(self, data, args, embed):

        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]
        self.responses_target = tf.split(self.origin_responses,
                                         [1, decoder_len - 1], 1)[1]
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1

        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
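        # The mask above is built by placing a one-hot 1 at position (responses_length - 1),
        # i.e. the last valid target token, and taking a reversed cumulative sum, which gives
        # a 0/1 mask that is 1 for every target position up to that token and 0 for padding.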

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell = tf.nn.rnn_cell.GRUCell(args.eh_size)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # build encoder
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            _, self.encoder_state = dynamic_rnn(cell,
                                                self.encoder_input,
                                                self.posts_length,
                                                dtype=tf.float32,
                                                scope="decoder_rnn")

        # construct helper and attention
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.eos_id), data.eos_id)

        dec_start = tf.cond(
            self.is_train,
            lambda: tf.zeros([batch_size, args.dh_size], dtype=tf.float32),
            lambda: self.encoder_state)

        # build decoder (train)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            self.decoder_output, _ = dynamic_rnn(
                cell,
                self.decoder_input,
                self.responses_length,
                dtype=tf.float32,
                initial_state=self.encoder_state,
                scope='decoder_rnn')
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sent_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: True
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher,
                self.decoder_output
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: False
        }
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        print('    perplexity on %s set: %.2f' % (key_name, np.exp(loss)))
        print(loss)
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        for i in range(2):
            print(
                data.convert_ids_to_tokens(
                    batched_data['post_allvocabs'][i].tolist(), trim=False))
            print(
                data.convert_ids_to_tokens(
                    batched_data['resp_allvocabs'][i].tolist(), trim=False))

        for epoch_step in range(args.epochs):
            while batched_data is not None:
                if (self.global_step.eval() % args.checkpoint_steps == 0
                        and self.global_step.eval() != 0):
                    show = lambda a: '[%s]' % (' '.join(
                        ['%.2f' % x for x in a]))
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           show(np.exp(loss_step))))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(loss_step)
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss,
                            'perplexity': np.exp(dev_loss)
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):

        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r') as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:
            batched_data['resp_length'] = [len(batched_data['resp'][0])]
            batched_data['resp'] = batched_data['resp'].tolist()
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.eos_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['resp_length'].append(
                    len(batched_data['resp'][-1]))
            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

            post = []
            post_length = []
            for _ in range(len(resp)):
                post = post + batched_data['post'].tolist()
                post_length = post_length + batched_data['post_length'].tolist()
            batched_data['post'] = post
            batched_data['post_length'] = post_length

            _, _, loss = self.inference(sess, batched_data)
            loss_record.append(loss)
            cnt += 1

            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3}

    def test_process(self, sess, data, args):
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size * 4, shuffle=False)
        batched_data = data.get_next_batch("test")

        for i in range(3):
            print(
                '  post %d ' % i,
                data.convert_ids_to_tokens(
                    batched_data['post_allvocabs'][i].tolist(), trim=False))
            print(
                '  resp %d ' % i,
                data.convert_ids_to_tokens(
                    batched_data['resp_allvocabs'][i].tolist(), trim=False))

        results = []
        while batched_data is not None:
            batched_responses_id, gen_log_prob, _ = self.inference(
                sess, batched_data)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                result_token = []
                response_id_list = response_id.tolist()
                response_token = data.convert_ids_to_tokens(response_id_list)
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                for token in response_token:
                    if token != data.ext_vocab[data.eos_id]:
                        result_token.append(token)
                    else:
                        break
                results.append(result_token)
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'post_allvocabs': np.array(batched_data['post_allvocabs']),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['post'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        print("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
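
The hits@k ranking logic used in test_process_hits above can be checked in isolation. The snippet below is only an illustration on a made-up loss matrix (shapes and numbers are invented); column 0 holds the gold response and the remaining columns hold distractors, exactly as the batches are laid out in that method.

import numpy as np

loss = np.array([[0.7, 1.2, 0.9, 1.5],   # gold response has the lowest loss -> counts for hits@1
                 [1.1, 0.8, 1.0, 1.3]])  # gold response is only the 3rd best -> counts for hits@3 only
loss_rank = np.argsort(loss, axis=1)     # candidate indices ordered from lowest to highest loss
hits1 = float(np.mean(loss_rank[:, 0] == 0))                   # gold is the single best candidate
hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))  # gold is among the three best
print(hits1, hits3)  # 0.5 1.0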
Example #8
class Classification(BaseModel):
    def __init__(self, param):
        args = param.args
        net = Network(param)
        self.optimizer = optim.Adam(net.get_parameters_by_name(), lr=args.lr)
        optimizerList = {"optimizer": self.optimizer}
        checkpoint_manager = CheckpointManager(args.name, args.model_dir, \
            args.checkpoint_steps, args.checkpoint_max_to_keep, "min")
        super().__init__(param, net, optimizerList, checkpoint_manager)

        self.create_summary()

    def create_summary(self):
        args = self.param.args
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), \
          args)
        self.trainSummary = self.summaryHelper.addGroup(\
         scalar=["loss", "accuracy_on_batch"],\
         prefix="train")

        scalarlist = ["loss", "accuracy"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in self.args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="test")

    def _preprocess_batch(self, data):
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.sent.shape[0]
        data.sent = cuda(torch.LongTensor(data.sent.transpose(
            1, 0)))  # length * batch_size
        data.label = cuda(torch.LongTensor(data.label))
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        data = dm.get_next_batch(key)
        if data is None:
            if restart:
                dm.restart(key)
                return self.get_next_batch(dm, key, False)
            else:
                return None
        return self._preprocess_batch(data)

    def get_batches(self, dm, key):
        batches = list(
            dm.get_batches(key, batch_size=self.args.batch_size,
                           shuffle=False))
        return len(batches), (self._preprocess_batch(data) for data in batches)

    def get_select_batch(self, dm, key, i):
        data = dm.get_batch(key, i)
        if data is None:
            return None
        return self._preprocess_batch(data)

    def train(self, batch_num):
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            accuracy = np.mean(
                (incoming.result.label == incoming.result.prediction
                 ).float().detach().cpu().numpy())
            detail_arr = storage_to_list(incoming.result)
            detail_arr.update({'accuracy_on_batch': accuracy})
            self.trainSummary(self.now_batch, detail_arr)
            logging.info("batch %d : classification loss=%f, batch accuracy=%f", \
             self.now_batch, loss.detach().cpu().numpy(), accuracy)

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
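            # select the i-th full batch of examples (indices i * batch_size .. (i + 1) * batch_size - 1)
            # so that a fixed set of samples is rendered as show_str in the summaries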
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update({'loss':get_mean(result_arr, 'loss'), \
         'accuracy':get_accuracy(result_arr, label_key='label', prediction_key='prediction')})
        return detail_arr

    def train_process(self):
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            devloss_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, devloss_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            testloss_detail = self.evaluate("test")
            self.testSummary(self.now_batch, testloss_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            self.save_checkpoint(value=devloss_detail.loss.tolist())

    def test(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        metric = dm.get_accuracy_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval accuracy")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            data = incoming.data
            data.prediction = incoming.result.prediction
            data.label = incoming.data.label
            metric.forward(data)
        res = metric.close()

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            for key, value in res.items():
                if isinstance(value, float) or isinstance(value, bytes):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            f.flush()
        logging.info("result output to %s.", filename)

    def test_process(self):
        logging.info("Test Start.")
        self.net.eval()
        self.test("dev")
        self.test("test")
        logging.info("Test Finish.")
Example #9
class Seq2seq(BaseModel):
    def __init__(self, param):
        args = param.args
        net = Network(param)
        self.optimizer = optim.Adam(net.get_parameters_by_name(), lr=args.lr)
        optimizerList = {"optimizer": self.optimizer}
        checkpoint_manager = CheckpointManager(args.name, args.model_dir, \
            args.checkpoint_steps, args.checkpoint_max_to_keep, "min")
        super().__init__(param, net, optimizerList, checkpoint_manager)

        self.create_summary()

    def create_summary(self):
        args = self.param.args
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), \
          args)
        self.trainSummary = self.summaryHelper.addGroup(\
         scalar=["loss", "word_loss", "perplexity"],\
         prefix="train")

        scalarlist = ["word_loss", "perplexity_avg_on_batch"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in self.args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(\
         scalar=scalarlist,\
         tensor=tensorlist,\
         text=textlist,\
         embedding=emblist,\
         prefix="test")

    def _preprocess_batch(self, data):
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.post.shape[0]
        data.post = cuda(torch.LongTensor(data.post.transpose(
            1, 0)))  # length * batch_size
        data.resp = cuda(torch.LongTensor(data.resp.transpose(
            1, 0)))  # length * batch_size
        data.post_bert = cuda(torch.LongTensor(data.post_bert.transpose(
            1, 0)))  # length * batch_size
        data.resp_bert = cuda(torch.LongTensor(data.resp_bert.transpose(
            1, 0)))  # length * batch_size
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        data = dm.get_next_batch(key)
        if data is None:
            if restart:
                dm.restart(key)
                return self.get_next_batch(dm, key, False)
            else:
                return None
        return self._preprocess_batch(data)

    def get_batches(self, dm, key):
        batches = list(
            dm.get_batches(key, batch_size=self.args.batch_size,
                           shuffle=False))
        return len(batches), (self._preprocess_batch(data) for data in batches)

    def get_select_batch(self, dm, key, i):
        data = dm.get_batch(key, i)
        if data is None:
            return None
        return self._preprocess_batch(data)

    def train(self, batch_num):
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update(
            {key: get_mean(result_arr, key)
             for key in result_arr[0]})
        detail_arr.perplexity_avg_on_batch = np.exp(detail_arr.word_loss)
        return detail_arr

    def train_process(self):
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            devloss_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, devloss_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            testloss_detail = self.evaluate("test")
            self.testSummary(self.now_batch, testloss_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            self.save_checkpoint(value=devloss_detail.loss.tolist())

    def test(self, key):
        args = self.param.args
        dm = self.param.volatile.dm

        metric1 = dm.get_teacher_forcing_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                gen_log_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            data = incoming.data
            data.resp = incoming.data.resp_allvocabs
            data.resp_length = incoming.data.resp_length
            data.gen_log_prob = gen_log_prob.transpose(
                1, 0).detach().cpu().numpy()
            metric1.forward(data)
        res = metric1.close()

        metric2 = dm.get_inference_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            data = incoming.data
            data.resp = incoming.data.resp_allvocabs
            data.post = incoming.data.post_allvocabs
            data.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            metric2.forward(data)
        res.update(metric2.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            for key, value in res.items():
                if isinstance(value, float) or isinstance(value, bytes):
                    logging.info("\t{}:\t{}".format(key, value))
                    f.write("{}:\t{}\n".format(key, value))
            for i in range(len(res['post'])):
                f.write("post:\t%s\n" % " ".join(res['post'][i]))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))
            f.flush()
        logging.info("result output to %s.", filename)

    def test_process(self):
        logging.info("Test Start.")
        self.net.eval()
        self.test("dev")
        self.test("test")
        logging.info("Test Finish.")
Example #10
class HredModel(object):
    def __init__(self, data, args, embed):
        self.init_states = tf.placeholder(tf.float32, (None, args.ch_size),
                                          'ctx_inps')  # batch*ch_size
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('context'):
            _, self.context_state = cell_ctx(encoder_state, self.init_states)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        attn_mechanism = tf.contrib.seq2seq.LuongAttention(
            args.dh_size,
            encoder_output,
            memory_sequence_length=tf.maximum(self.posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            self.decoder_distribution_teacher, self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sen_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        input_feed = {
            self.init_states: data['init_states'],
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
        }

        if inference:
            output_feed = [self.generation_index, self.context_state]
        else:
            if forward_only:
                output_feed = [
                    self.decoder_loss, self.decoder_distribution_teacher,
                    self.context_state
                ]
            else:
                output_feed = [
                    self.decoder_loss, self.gradient_norm, self.update,
                    self.context_state
                ]

        return sess.run(output_feed, input_feed)

    def get_step_data(self, step_data, batched_data, turn):
        current_batch_size = batched_data['sent'].shape[0]
        max_turn_length = batched_data['sent'].shape[1]
        max_sent_length = batched_data['sent'].shape[2]
        if turn == -1:
            step_data['posts'] = np.zeros((current_batch_size, 1), dtype=int)
        else:
            step_data['posts'] = batched_data['sent'][:, turn, :]
        step_data['responses'] = batched_data['sent'][:, turn + 1, :]
        step_data['posts_length'] = np.zeros((current_batch_size, ), dtype=int)
        step_data['responses_length'] = np.zeros((current_batch_size, ),
                                                 dtype=int)
        for i in range(current_batch_size):
            if turn < len(batched_data['sent_length'][i]):
                if turn == -1:
                    step_data['posts_length'][i] = 1
                else:
                    step_data['posts_length'][i] = batched_data['sent_length'][
                        i][turn]
            if turn + 1 < len(batched_data['sent_length'][i]):
                step_data['responses_length'][i] = batched_data['sent_length'][
                    i][turn + 1]
        max_posts_length = np.max(step_data['posts_length'])
        max_responses_length = np.max(step_data['responses_length'])
        step_data['posts'] = step_data['posts'][:, 0:max_posts_length]
        step_data['responses'] = step_data['responses'][:,
                                                        0:max_responses_length]

    def train_step(self, sess, data, args):
        current_batch_size = data['sent'].shape[0]
        max_turn_length = data['sent'].shape[1]
        max_sent_length = data['sent'].shape[2]
        loss = np.zeros((1, ))
        total_length = np.zeros((1, ))
        step_data = {}
        context_states = np.zeros((current_batch_size, args.ch_size))

        for turn in range(max_turn_length - 1):
            self.get_step_data(step_data, data, turn)
            step_data['init_states'] = context_states
            decoder_loss, _, _, context_states = self.step_decoder(
                sess, step_data)
            length = np.sum(np.maximum(step_data['responses_length'] - 1, 0))
            total_length += length
            loss += decoder_loss * length
        return loss / total_length

    def evaluate(self, sess, data, key_name, args):
        loss = np.zeros((1, ))
        total_length = np.zeros((1, ))
        data.restart(key_name, batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            current_batch_size = batched_data['sent'].shape[0]
            max_turn_length = batched_data['sent'].shape[1]
            max_sent_length = batched_data['sent'].shape[2]
            step_data = {}
            context_states = np.zeros((current_batch_size, args.ch_size))
            for turn in range(max_turn_length - 1):
                self.get_step_data(step_data, batched_data, turn)
                step_data['init_states'] = context_states
                decoder_loss, _, context_states = self.step_decoder(
                    sess, step_data, forward_only=True)
                length = np.sum(
                    np.maximum(step_data['responses_length'] - 1, 0))
                total_length += length
                loss += decoder_loss * length
            batched_data = data.get_next_batch(key_name)
        loss /= total_length

        print('\tperplexity on %s set: %.2f' % (key_name, np.exp(loss)))
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                for step in range(args.checkpoint_steps):
                    if batched_data is None:
                        break
                    start_time = time.time()
                    loss_step += self.train_step(sess, batched_data, args)
                    time_step += time.time() - start_time
                    batched_data = data.get_next_batch("train")

                loss_step /= args.checkpoint_steps
                time_step /= args.checkpoint_steps
                show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a]))
                print(
                    "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                    % (epoch_step, self.global_step.eval(),
                       self.learning_rate.eval(), time_step,
                       show(np.exp(loss_step))))
                self.trainSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': loss_step,
                        'perplexity': np.exp(loss_step)
                    })
                #self.saver.save(sess, '%s/checkpoint_latest' % args.model_dir, global_step=self.global_step)\
                self.store_checkpoint(
                    sess, '%s/checkpoint_latest/checkpoint' % args.model_dir,
                    "latest")

                dev_loss = self.evaluate(sess, data, "dev", args)
                self.devSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': dev_loss,
                        'perplexity': np.exp(dev_loss)
                    })

                test_loss = self.evaluate(sess, data, "test", args)
                self.testSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': test_loss,
                        'perplexity': np.exp(test_loss)
                    })

                if np.sum(loss_step) > max(previous_losses):
                    sess.run(self.learning_rate_decay_op)
                if dev_loss < best_valid:
                    best_valid = dev_loss
                    self.store_checkpoint(
                        sess, '%s/checkpoint_best/checkpoint' % args.model_dir,
                        "best")

                previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                loss_step, time_step = np.zeros((1, )), .0

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        def get_batch_results(batched_responses_id, data):
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                response_token = data.index_to_sen(response_id_list)
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)
            return batch_results

        def padding(matrix):
            l = max([len(d) for d in matrix])
            res = [d + [data.pad_id] * (l - len(d)) for d in matrix]
            return res

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        cnt = 0
        start_time = time.time()
        while batched_data is not None:
            current_batch_size = batched_data['sent'].shape[0]
            max_turn_length = batched_data['sent'].shape[1]
            max_sent_length = batched_data['sent'].shape[2]
            if cnt > 0 and cnt % 10 == 0:
                print('processing %d batch data, time cost %.2f s/batch' %
                      (cnt, (time.time() - start_time) / 10))
                start_time = time.time()
            cnt += 1
            step_data = {}
            context_states = np.zeros((current_batch_size, args.ch_size))
            batched_gen_prob = []
            batched_gen = []
            for turn in range(max_turn_length):
                self.get_step_data(step_data, batched_data, turn - 1)
                step_data['init_states'] = context_states
                decoder_loss, gen_prob, context_states = self.step_decoder(
                    sess, step_data, forward_only=True)
                batched_responses_id, context_states = self.step_decoder(
                    sess, step_data, inference=True)
                batch_results = get_batch_results(batched_responses_id, data)
                step_data['gen_prob'] = gen_prob
                batched_gen_prob.append(step_data['gen_prob'])
                step_data['generations'] = batch_results
                batched_gen.append(step_data['generations'])

            def transpose(batched_gen_prob):
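                # reorder in place from [turn][batch] to [batch][turn], then
                # append a copy of each sequence's final step (one extra entry
                # per sequence) before handing the data to the metrics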
                batched_gen_prob_temp = [[0 for i in range(max_turn_length)]
                                         for j in range(current_batch_size)]
                for i in range(max_turn_length):
                    for j in range(current_batch_size):
                        batched_gen_prob_temp[j][i] = batched_gen_prob[i][j]
                batched_gen_prob[:] = batched_gen_prob_temp[:]
                for i in range(current_batch_size):
                    for j in range(max_turn_length):
                        batched_gen_prob[i][j] = np.concatenate(
                            (batched_gen_prob[i][j],
                             [batched_gen_prob[i][j][-1]]),
                            axis=0)

            transpose(batched_gen_prob)
            transpose(batched_gen)

            sent_length = []
            for i in range(current_batch_size):
                sent_length.append(
                    np.array(batched_data['sent_length'][i]) + 1)
            batched_sent = np.zeros(
                (current_batch_size, max_turn_length, max_sent_length + 2),
                dtype=int)
            empty_sent = np.zeros((current_batch_size, 1, max_sent_length + 2),
                                  dtype=int)
            for i in range(current_batch_size):
                for j, _ in enumerate(sent_length[i]):
                    batched_sent[i][j][0] = data.go_id
                    batched_sent[i][j][1:sent_length[i][j]] = batched_data[
                        'sent'][i][j][0:sent_length[i][j] - 1]
                empty_sent[i][0][0] = data.go_id
                empty_sent[i][0][1] = data.eos_id

            metric1_data = {
                'sent_allvocabs': batched_data['sent_allvocabs'],
                'sent_length': batched_data['sent_length'],
                'gen_log_prob': batched_gen_prob,
            }
            metric1.forward(metric1_data)
            metric2_data = {
                'context_allvocabs': [],
                'reference_allvocabs': batched_data['sent_allvocabs'],
                'turn_length': batched_data['turn_length'],
                'gen': batched_gen,
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i in range(len(res['context'])):
                f.write("batch number:\t%d\n" % i)
                for j in range(
                        min(len(res['context'][i]), len(res['reference'][i]))):
                    if j > 0 and " ".join(res['context'][i][j]) != " ".join(
                            res['reference'][i][j - 1]):
                        f.write("\n")
                    f.write("post:\t%s\n" % " ".join(res['context'][i][j]))
                    f.write("resp:\t%s\n" % " ".join(res['reference'][i][j]))
                    if j < len(res['gen'][i]):
                        f.write("gen:\t%s\n" % " ".join(res['gen'][i][j]))
                    else:
                        f.write("gen:\n")

        print("result output to %s" % test_file)
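A minimal, hypothetical driver sketch for the HredModel example above, assuming the data, args and embed objects expose exactly the attributes and methods referenced in the class, and using plain TensorFlow 1.x session APIs.

import tensorflow as tf

def run_hred(data, args, embed=None):
    # build the graph under args.name so the "args.name in k.name" filter
    # inside HredModel.__init__ collects the model's trainable variables
    with tf.variable_scope(args.name):
        model = HredModel(data, args, embed)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.print_parameters()
        model.train_process(sess, data, args)  # train, validate, checkpoint
        model.test_process(sess, data, args)   # write generations and metrics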
Example #11
0
class Seq2SeqModel(object):
    def __init__(self, data, args, embed):

        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  #batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        attn_mechanism = tf.contrib.seq2seq.LuongAttention(
            args.dh_size,
            encoder_output,
            memory_sequence_length=self.posts_length)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            self.decoder_distribution_teacher, self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sen_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length']
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length']
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        print('    perplexity on %s set: %.2f' % (key_name, np.exp(loss)))
        print(loss)
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                if self.global_step.eval(
                ) % args.checkpoint_steps == 0 and self.global_step.eval(
                ) != 0:
                    show = lambda a: '[%s]' % (' '.join(
                        ['%.2f' % x for x in a]))
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           show(np.exp(loss_step))))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(loss_step)
                        })
                    #self.saver.save(sess, '%s/checkpoint_latest' % args.model_dir, global_step=self.global_step)\
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss,
                            'perplexity': np.exp(dev_loss)
                        })

                    test_loss = self.evaluate(sess, data, args.batch_size,
                                              "test")
                    self.testSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': test_loss,
                            'perplexity': np.exp(test_loss)
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        results = []
        while batched_data is not None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            _, gen_prob = self.step_decoder(sess,
                                            batched_data,
                                            forward_only=True)
            metric1_data = {
                'resp': np.array(batched_data['resp']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_prob': np.array(gen_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                result_token = []
                response_id_list = response_id.tolist()
                response_token = data.index_to_sen(response_id_list)
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                for token in response_token:
                    if token != data.ext_vocab[data.eos_id]:
                        result_token.append(token)
                    else:
                        break
                results.append(result_token)
                batch_results.append(result_id)

            metric2_data = {
                'post': np.array(batched_data['post']),
                'resp': np.array(batched_data['resp']),
                'gen': np.array(batch_results)
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i in range(len(res['post'])):
                f.write("post:\t%s\n" % " ".join(res['post'][i]))
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n" % " ".join(res['gen'][i]))

        print("result output to %s." % test_file)
Example #12
0
class Seq2SeqModel(object):
    def __init__(self, data, args, embed):
        # posts: the encoder input, i.e. the dialogue history, [batch, encoder_len]
        # posts_length: actual length of each input sequence, [batch]
        # prevs_length: length of the history excluding the last turn
        #               (including <go> and <eos>), [batch]
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens_prev')  # batch

        # origin_responses: the response tokens, [batch, resp_len]
        # origin_responses_length: actual length of each response, [batch]
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        # split the responses here, dropping the leading go_id
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        # build the decoder input and target: the input drops the trailing eos_id
        # and the target drops the leading go_id, so the two stay aligned
        # [batch, decoder_len] (decoder_len here equals resp_len - 1)
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # encoder input, [batch, encoder_len]
        self.posts_input = self.posts  # batch*len
        # compute the decoder mask matrix
        # shape [batch, decoder_len]
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
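        # e.g. with responses_length = [3, 1] and decoder_len = 4:
        #   one_hot(length - 1)       -> [[0, 0, 1, 0], [1, 0, 0, 0]]
        #   cumsum(..., reverse=True) -> [[1, 1, 1, 0], [1, 0, 0, 0]]
        # i.e. ones over the valid target positions, zeros over the padding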

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        # look up word embeddings for the encoder and decoder inputs
        # encoder_input: [batch, encoder_len, embed_size]
        # decoder_input: [batch, decoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input)) #batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            # encoder_output: [batch, encoder_len, eh_size]
            # encoder_state: [batch, eh_size]
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        encoder_len = tf.shape(encoder_output)[1]
        # compute the mask matrices for posts and prevs
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        # XOR is 1 where the two masks differ and 0 where they agree,
        # so attention only covers the last turn, [batch, encoder_len]
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])
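        # e.g. posts_length = 6, prevs_length = 4, encoder_len = 7:
        #   posts_mask -> [1, 1, 1, 1, 1, 1, 0]
        #   prevs_mask -> [1, 1, 1, 1, 0, 0, 0]
        #   xor        -> [0, 0, 0, 0, 1, 1, 0]  (tokens of the last turn only)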

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        # at inference time every sequence starts from go_id,
        # as defined when the input data was packed
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        # the encoder encodes the full multi-turn input,
        # but the decoder attends only to the last turn;
        # define the encoder-decoder attention here
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # project the encoder's final hidden state to the decoder hidden size
        # [batch, dh_size]
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # compute the loss and the output distribution
            # decoder_distribution_teacher: [batch, decoder_len, vocab_size] (log probabilities)
            # decoder_loss: scalar loss averaged over all tokens in the batch
            # decoder_all_loss: per-sentence loss, [batch]
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
                sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            # output_fn reuses the weights and biases defined above
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch, max_decoder_len, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # take the argmax while excluding the first two tokens (<pad> and <unk>)
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK
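            # e.g. with vocab_size = 5 and logits [0.1, 0.9, 0.2, 0.7, 0.3]:
            #   split keeps [0.2, 0.7, 0.3]; argmax = 1; +2 -> index 3,
            #   so <pad> (0) and <unk> (1) can never be generated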

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)  # define the optimizer
        gradients = tf.gradients(self.decoder_loss, self.params)  # gradients of the parameters
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)  # gradient clipping
        self.update = opt.apply_gradients(
            zip(clipped_gradients,
                self.params), global_step=self.global_step)  # apply the update

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")

    def print_parameters(self):
        for item in self.params:
            logger.info('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: True
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: False
        }
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        logger.info(
            f'Evaluate loss: {float(loss):.2f} | perplexity on {key_name} set: {float(np.exp(loss)): .2f}'
        )
        # print(loss)
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        for i in range(2):
            logger.info(
                f"post@{i}: {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
            )
            logger.info(
                f"resp@{i}: {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
            )

        for epoch_step in range(args.epochs):
            while batched_data is not None:
                if self.global_step.eval(
                ) % args.checkpoint_steps == 0 and self.global_step.eval(
                ) != 0:
                    show = lambda a: '[%s]' % (' '.join(
                        ['%.2f' % x for x in a]))
                    logger.info(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           show(np.exp(loss_step))))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(loss_step)
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss,
                            'perplexity': np.exp(dev_loss)
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):

        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding="utf-8") as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:

            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()

            batched_data['resp_length'] = [len(batched_data['resp'][0])]
            batched_data['resp'] = batched_data['resp']
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['resp_length'].append(
                    len(batched_data['resp'][-1]))
            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

            post = []
            post_length = []
            prev_length = []

            for _ in range(len(resp)):
                post += batched_data['post']
                post_length += batched_data['post_length']
                prev_length += batched_data['prev_length']

            batched_data['post'] = post
            batched_data['post_length'] = post_length
            batched_data['prev_length'] = prev_length

            _, _, loss = self.inference(sess, batched_data)
            loss_record.append(loss)
            cnt += 1
            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
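        # each row of loss holds the loss of the ground-truth response (index 0)
        # followed by the distractors; argsort ranks the candidates by loss and
        # hits@k checks whether index 0 is among the k lowest-loss candidates,
        # e.g. a row [1.2, 0.8, 1.5] -> argsort [1, 0, 2]: the ground truth is
        # ranked second, so it counts for hits@3 but not for hits@1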
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5}

    def test_process(self, sess, data, args):
        metric1 = data.get_teacher_forcing_metric()  # mainly computes perplexity
        metric2 = data.get_inference_metric()  # mainly computes BLEU and distinct
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")

        for i in range(3):
            logger.info(
                f"post@{i}: {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
            )
            logger.info(
                f"resp@{i}: {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
            )

        while batched_data is not None:
            batched_responses_id, gen_log_prob, _ = self.inference(
                sess, batched_data)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['resp'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        logger.info("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if isinstance(val, (bytes, int, float))
        }
Example #13
0
class HredModel(object):
    def __init__(self, data, args, embed):
        #self.init_states = tf.placeholder(tf.float32, (None, args.ch_size), 'ctx_inps')  # batch*ch_size
        # posts: [batch*(num_turns-1), max_post_length]
        # posts_length: [batch*(num_turns-1),]
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch * num_turns-1 * len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch * num_turns-1

        # prev_posts: [batch, max_prev_length], i.e. the last turn of the posts above
        # prev_posts_length: [batch], i.e. the actual length of each utterance in that last turn
        self.prev_posts = tf.placeholder(tf.int32, (None, None),
                                         'enc_prev_inps')
        self.prev_posts_length = tf.placeholder(tf.int32, (None, ),
                                                'enc_prev_lens')

        # origin_responses: [batch, max_response_length]
        # origin_responses_length: [batch]
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        # context_length: [batch], the actual number of turns of each posts sequence
        self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens')
        self.is_train = tf.placeholder(tf.bool)

        # corresponds to num_turns-1 (it may also be smaller), i.e. the actual
        # maximum number of turns in the current batch
        num_past_turns = tf.shape(self.posts)[0] // tf.shape(
            self.origin_responses)[0]

        # deal with original data to adapt encoder and decoder
        # get the decoder inputs and targets:
        # the input drops the trailing <eos>, the target drops the leading <go>
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        # [batch, decoder_length]
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        # self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input))  # batch*len*unit
        # self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            # encoder_output: [batch*(num_turns-1), max_post_length, eh_size]
            # encoder_state: [batch*(num_turns-1), eh_size]
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            # prev_output: [batch, max_prev_length, eh_size]
            prev_output, _ = tf.nn.dynamic_rnn(cell_enc,
                                               tf.nn.embedding_lookup(
                                                   self.embed,
                                                   self.prev_posts),
                                               self.prev_posts_length,
                                               dtype=tf.float32,
                                               scope="encoder_rnn")

        # encoder_hidden_size = tf.shape(encoder_state)[-1]

        with tf.variable_scope('context'):
            # encoder_state_reshape: [batch, num_turns-1, eh_size]
            # context_output: [batch, num_turns-1, ch_size]
            # context_state: [batch, ch_size]
            encoder_state_reshape = tf.reshape(
                encoder_state, [-1, num_past_turns, args.eh_size])
            context_output, self.context_state = tf.nn.dynamic_rnn(
                cell_ctx,
                encoder_state_reshape,
                self.context_length,
                dtype=tf.float32,
                scope='context_rnn')

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)

        #encoder_len = tf.shape(encoder_output)[1]
        #attention_memory = tf.reshape(encoder_output, [batch_size, -1, args.eh_size])
        #attention_mask = tf.reshape(tf.sequence_mask(self.posts_length, encoder_len), [batch_size, -1])
        '''
        attention_memory = context_output
        attention_mask = tf.reshape(tf.sequence_mask(self.context_length, self.num_turns - 1), [batch_size, -1])
        '''
        #attention_mask = tf.concat([tf.ones([batch_size, 1], tf.bool), attention_mask[:, 1:]], axis=1)
        #attn_mechanism = MyAttention(args.dh_size, attention_memory, attention_mask)
        # note that the attention memory here is the encoding of the last utterance,
        # i.e. [batch_size, prev_post_length, eh_size]; if the query dimension does not
        # match the memory, the attention first projects the query with a linear layer
        attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            args.dh_size,
            prev_output,
            memory_sequence_length=tf.maximum(self.prev_posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # project the context state (the encoded posts) to the decoder dimension [batch, dh_size]
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            # decoder_output: [batch, decoder_length, dh_size]
            self.decoder_output = train_outputs.rnn_output
            # self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # decoder_distribution_teacher: [batch, decoder_length, vocab_size]
            # decoder_loss: scalar
            # decoder_all_loss: [batch, ], the log-loss of each response
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch, max_decoder_length, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # word indices at each decoding step: [batch, max_decoder_length]
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            logger.info('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
            self.context_length: data['context_length'],
            self.prev_posts: data['prev_posts'],
            self.prev_posts_length: data['prev_posts_length']
        }

        if inference:
            input_feed.update({self.is_train: False})
            output_feed = [
                self.generation_index, self.decoder_distribution_teacher,
                self.decoder_all_loss
            ]
        else:
            input_feed.update({self.is_train: True})
            if forward_only:
                output_feed = [
                    self.decoder_loss, self.decoder_distribution_teacher
                ]
            else:
                output_feed = [
                    self.decoder_loss, self.gradient_norm, self.update
                ]

        return sess.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data != None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        logger.info('    perplexity on %s set: %.2f' %
                    (key_name, float(np.exp(loss))))
        logger.info(loss)
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data != None:
                if self.global_step.eval(
                ) % args.checkpoint_steps == 0 and self.global_step.eval(
                ) != 0:
                    show = lambda a: '[%s]' % (' '.join(
                        ['%.2f' % x for x in a]))
                    logger.info(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           show(np.exp(loss_step))))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(loss_step)
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss,
                            'perplexity': np.exp(dev_loss)
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    # if the dev loss is lower than the best so far, save the current model as the best checkpoint
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):

        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding="utf-8") as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data != None:

            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()

            batched_data['responses_length'] = [
                len(batched_data['responses'][0])
            ]
            batched_data['responses'] = batched_data['responses']
            for each_resp in test_distractors[cnt]:
                batched_data['responses'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['responses_length'].append(
                    len(batched_data['responses'][-1]))
            max_length = max(batched_data['responses_length'])
            resp = np.zeros((len(batched_data['responses']), max_length),
                            dtype=int)
            for i, each_resp in enumerate(batched_data['responses']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['responses'] = resp

            posts = []
            posts_length = []
            prev_posts = []
            prev_posts_length = []
            context_length = []
            for _ in range(len(resp)):
                posts += batched_data['posts']
                posts_length += batched_data['posts_length']
                prev_posts += batched_data['prev_posts']
                prev_posts_length += batched_data['prev_posts_length']
                context_length += batched_data['context_length']
            batched_data['posts'] = posts
            batched_data['posts_length'] = posts_length
            batched_data['prev_posts'] = prev_posts
            batched_data['prev_posts_length'] = prev_posts_length
            batched_data['context_length'] = context_length

            _, _, loss = self.step_decoder(sess, batched_data, inference=True)
            loss_record.append(loss)
            cnt += 1
            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5}

    def test_process(self, sess, data, args):

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")

        while batched_data != None:
            batched_responses_id, gen_log_prob, _ = self.step_decoder(
                sess, batched_data, inference=True)
            metric1_data = {
                'resp_allvocabs':
                np.array(batched_data['responses_allvocabs']),
                'resp_length': np.array(batched_data['responses_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs': np.array(batched_data['responses_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.output_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['resp'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        logger.info("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
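# A minimal numpy sketch (not from the example above) of the decoder_mask trick used
# in HredModel: a reversed cumulative sum over a one-hot vector placed at position
# length-1 produces 1s for the first `length` steps and 0s for the padding, i.e. an
# ordinary sequence mask. Assumes all lengths are >= 1, as in the code above where
# responses_length = origin_responses_length - 1 for non-empty responses.
import numpy as np

def cumsum_mask(lengths, max_len):
    one_hot = np.eye(max_len)[np.asarray(lengths) - 1]    # [batch, max_len]
    return np.cumsum(one_hot[:, ::-1], axis=1)[:, ::-1]   # reversed cumulative sum

lengths, max_len = [2, 4, 1], 4
mask = cumsum_mask(lengths, max_len)
print(mask)
# same result as a plain sequence mask
expected = (np.arange(max_len)[None, :] < np.asarray(lengths)[:, None]).astype(float)
assert np.array_equal(mask, expected)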
Пример #14
0
class LMModel(object):
    def __init__(self, data, args, embed):

        with tf.variable_scope("input"):
            with tf.variable_scope("embedding"):
                # build the embedding table and embedding input
                if embed is None:
                    # initialize the embedding randomly
                    self.embed = tf.get_variable(
                        'embed',
                        [data.frequent_vocab_size, args.embedding_size],
                        tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    self.embed = tf.get_variable('embed',
                                                 dtype=tf.float32,
                                                 initializer=embed)

            # input
            self.sentence = tf.placeholder(tf.int32, (None, None),
                                           'sen_inps')  # batch*len
            self.sentence_length = tf.placeholder(tf.int32, (None, ),
                                                  'sen_lens')  # batch
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

            batch_size, batch_len = tf.shape(self.sentence)[0], tf.shape(
                self.sentence)[1]
            self.scentence_max_len = batch_len - 1

            # data processing
            LM_input = tf.split(self.sentence, [self.scentence_max_len, 1],
                                1)[0]  # no eos_id
            self.LM_input = tf.nn.embedding_lookup(
                self.embed, LM_input)  # batch*(len-1)*unit
            self.LM_target = tf.split(self.sentence,
                                      [1, self.scentence_max_len],
                                      1)[1]  # no go_id, batch*(len-1)
            self.input_len = self.sentence_length - 1
            self.input_mask = tf.sequence_mask(
                self.input_len, self.scentence_max_len,
                dtype=tf.float32)  # 0 for <pad>, batch*(len-1)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build LSTM NN
        basic_cell = tf.nn.rnn_cell.LSTMCell(args.dh_size)
        with tf.variable_scope('rnnlm'):
            LM_output, _ = dynamic_rnn(basic_cell,
                                       self.LM_input,
                                       self.input_len,
                                       dtype=tf.float32,
                                       scope="rnnlm")
        # fully connected layer
        LM_output = tf.layers.dense(
            inputs=LM_output, units=data.frequent_vocab_size
        )  # shape of LM_output: (batch_size, batch_len-1, vocab_size)

        # loss
        with tf.variable_scope("loss",
                               initializer=tf.orthogonal_initializer()):
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=LM_output, labels=self.LM_target)
            crossent = tf.reduce_sum(crossent *
                                     self.input_mask)  # to ignore <pad>s

            self.sen_loss = crossent / tf.to_float(batch_size)
            self.ppl_loss = crossent / tf.reduce_sum(
                self.input_mask)  # crossent per word.
            # self.ppl_loss = tf.Print(self.ppl_loss, [self.ppl_loss] )

            self.decoder_distribution_teacher = tf.nn.log_softmax(LM_output)
        with tf.variable_scope("decode", reuse=True):
            self.decoder_distribution = LM_output  # (batch_size, batch_len-1, vocab_size)
            # for inference
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.frequent_vocab_size - 2], 2)[1],
                2) + 2  # for removing UNK. 0 for <pad> and 1 for <unk>

        self.loss = self.sen_loss

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        opt = tf.train.MomentumOptimizer(learning_rate=self.learning_rate,
                                         momentum=args.momentum)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(scalar=[
            "loss",
            "perplexity",
        ],
                                                        prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_LM(self, session, data, forward_only=False):
        '''
        run the LM for one step (batch)
        '''
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: False
        }
        if forward_only:
            # test mode
            output_feed = [
                self.loss, self.decoder_distribution_teacher, self.ppl_loss
            ]
        else:
            # train mode
            output_feed = [
                self.loss, self.gradient_norm, self.update, self.ppl_loss
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length']
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        '''
        to get the loss and ppl_loss per step on dev and test
        '''
        loss_step = np.zeros((1, ))
        ppl_loss_step = 0
        times = 0

        data.restart(key_name, batch_size=batch_size,
                     shuffle=False)  # initialize mini-batches
        batched_data = data.get_next_batch(key_name)
        while batched_data != None:
            outputs = self.step_LM(sess, batched_data, forward_only=True)
            loss_step += outputs[0]
            ppl_loss_step += outputs[-1]
            times += 1
            batched_data = data.get_next_batch(key_name)

        loss_step /= times
        ppl_loss_step /= times

        print('    loss: %.2f' % loss_step)
        return loss_step, ppl_loss_step

    def train_process(self, sess, data, args):
        # 'X_step' <=> X per step
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        ppl_loss_step = 0
        previous_losses = [1e18] * 5  # previous 5 losses
        best_valid = 1e18

        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data != None:
                if self.global_step.eval(
                ) % args.checkpoint_steps == 0 and self.global_step.eval(
                ) != 0:
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step))
                    print('    loss: %.2f' % loss_step)
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step,
                            'perplexity': np.exp(ppl_loss_step),
                        })
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    devout = self.evaluate(sess, data, args.batch_size, "dev")
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': devout[0],
                            'perplexity': np.exp(devout[1]),
                        })

                    testout = self.evaluate(sess, data, args.batch_size,
                                            "test")
                    self.testSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': testout[0],
                            'perplexity': np.exp(testout[1]),
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if devout[0] < best_valid:
                        best_valid = devout[0]
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0
                    ppl_loss_step = 0

                start_time = time.time()
                outputs = self.step_LM(sess, batched_data)

                # outputs: loss, gradient_norm, update, ppl_loss (train mode)
                loss_step += outputs[0] / args.checkpoint_steps
                ppl_loss_step += outputs[-1] / args.checkpoint_steps

                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        results = []
        while batched_data != None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            gen_prob = self.step_LM(sess, batched_data, forward_only=True)[1]
            metric1_data = {
                'sent_allvocabs': np.array(batched_data['sent_allvocabs']),
                'sent_length': np.array(batched_data['sent_length']),
                'gen_log_prob': np.array(gen_prob)
            }
            metric1.forward(metric1_data)
            batch_results = []
            for response_id in batched_responses_id:
                result_token = []
                response_id_list = response_id.tolist()
                response_token = data.convert_ids_to_tokens(response_id_list)
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                for token in response_token:
                    ext_vocab = data.get_special_tokens_mapping()
                    if token != ext_vocab['eos']:
                        result_token.append(token)
                    else:
                        break
                results.append(result_token)
                batch_results.append(result_id)

            metric2_data = {'gen': np.array(batch_results)}
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i in range(len(res['gen'])):
                f.write("%s\n" % " ".join(res['gen'][i]))

        print("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
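# A minimal numpy sketch (not from the example above) of how LMModel turns masked
# token-level cross-entropy into perplexity: sum the negative log-likelihoods over
# the non-<pad> positions, divide by the number of real tokens (cf. ppl_loss), and
# exponentiate (cf. np.exp(ppl_loss) in train_process). The toy inputs are made up.
import numpy as np

def masked_perplexity(log_probs, targets, mask):
    """log_probs: [batch, len, vocab] log-softmax output; targets: [batch, len]; mask: [batch, len] in {0, 1}."""
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=2).squeeze(-1)  # [batch, len]
    per_token_loss = np.sum(nll * mask) / np.sum(mask)  # cross-entropy per real token
    return float(np.exp(per_token_loss))

# toy example: vocab size 3, one padded position at the end
log_probs = np.log(np.array([[[0.7, 0.2, 0.1],
                              [0.1, 0.8, 0.1],
                              [1 / 3, 1 / 3, 1 / 3]]]))
targets = np.array([[0, 1, 2]])
mask = np.array([[1.0, 1.0, 0.0]])
print(masked_perplexity(log_probs, targets, mask))  # exp(-(ln 0.7 + ln 0.8) / 2) ≈ 1.34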
Пример #15
0
class Seq2SeqModel(object):
    def __init__(self, data, args, embed):
        # the inputs here are the same as in the earlier seq2seq model
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens_prev')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # kgs: all knowledge triples of the dialogue this example belongs to: [batch, max_kg_nums, max_kg_length]
        # kgs_h_length: length of the head entity of each triple: [batch, max_kg_nums]
        # kgs_hr_length: length of the head entity plus relation of each triple: [batch, max_kg_nums]
        # kgs_hrt_length: length of h, r and t of each triple: [batch, max_kg_nums]
        # kgs_index: indicator matrix of the triples actually used by the current response:
        #            [batch, max_kg_nums] (1 if the triple is used, 0 otherwise)
        self.kgs = tf.placeholder(tf.int32, (None, None, None), 'kg_inps')
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None), 'kg_h_lens')
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')
        self.kgs_index = tf.placeholder(tf.float32, (None, None), 'kg_indices')

        # hyperparameter balancing the decoding loss and the kg loss
        self.lamb = tf.placeholder(tf.float32, name='lamb')
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        # get the decoder inputs and targets
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # get the encoder input
        self.posts_input = self.posts  # batch*len
        # decoder mask over the <pad> positions
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
        kg_len = tf.shape(self.kgs)[2]
        #kg_len = tf.Print(kg_len, [batch_size, kg_len, decoder_len, self.kgs_length])
        # kg_h_mask = tf.reshape(tf.cumsum(tf.one_hot(self.kgs_h_length-1,
        # 	kg_len), reverse=True, axis=2), [batch_size, -1, kg_len, 1])
        # build the mask for the keys (i.e. h,r): [batch_size, max_kg_nums, max_kg_length, 1]
        # and the mask for the values (i.e. t): [batch_size, max_kg_nums, max_kg_length, 1]
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)
        # encoder_input: [batch, encoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        # decoder_input: [batch, decoder_len, embed_size]
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        # kg_input: [batch, max_kg_nums, max_kg_length, embed_size]
        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)
        #self.encoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.posts_input)) #batch*len*unit
        #self.decoder_input = tf.cond(self.is_train,
        #							 lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #							 lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")
        # key: mean word embedding of h and r of each triple [batch, max_kg_nums, embed_size]
        # value: mean word embedding of t of each triple [batch, max_kg_nums, embed_size]
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        # project the encoder state to the embedding dimension
        # query: [batch, 1, embed_size]
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(encoder_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])
        # [batch, max_kg_nums]
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        # where the hrt length is > 0 (a triple exists at that slot), keep kg_score; otherwise set the score to -inf
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        # attention weight for each triple [batch, max_kg_nums]
        kg_alignment = tf.nn.softmax(kg_score)

        # use the argmax of the kg attention to compute the kg selection accuracy and the kg loss
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)
        # knowledge embedding after attention: [batch, embed_size]
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg, axis=1)
        # tile to [batch, decoder_len, embed_size]
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1), [1, decoder_len, 1])
        # concatenate the knowledge with the original decoder input as the new decoder input [batch, decoder_len, 2*embed_size]
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        encoder_len = tf.shape(encoder_output)[1]
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        #infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        # so that at inference time each step's input is the previous output concatenated with the knowledge
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)
        #attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(args.dh_size, encoder_output,
        #  memory_sequence_length=self.posts_length)
        # MyAttention works around BahdanauAttention only accepting a sequence length, allowing an arbitrary attention mask
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            # output distribution and decoding loss
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
             sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # decoding output distribution
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        # combine the decoding loss and the kg loss
        self.loss = self.decoder_loss + self.lamb * self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)
            #self.best_global_step = self.global_step

    def create_summary(self, args):
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
          (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        for item in self.params:
            logger.info('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, lamb=1.0, forward_only=False):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.kgs: data['kg'],
            self.kgs_h_length: data['kg_h_length'],
            self.kgs_hr_length: data['kg_hr_length'],
            self.kgs_hrt_length: data['kg_hrt_length'],
            self.kgs_index: data['kg_index'],
            self.lamb: lamb,
            self.is_train: True
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher,
                self.kg_loss, self.kg_acc
            ]
        else:
            output_feed = [
                self.decoder_loss, self.gradient_norm, self.update,
                self.kg_loss, self.kg_acc
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data, lamb=1.0):
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.kgs: data['kg'],
            self.kgs_h_length: data['kg_h_length'],
            self.kgs_hr_length: data['kg_hr_length'],
            self.kgs_hrt_length: data['kg_hrt_length'],
            self.kgs_index: data['kg_index'],
            self.lamb: lamb,
            self.is_train: False
        }
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss, self.kg_loss, self.kg_acc
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name, lamb=1.0):
        loss = np.zeros((3, ))
        total_length = np.zeros((3, ))
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data != None:
            decoder_loss, _, kg_loss, kg_acc = self.step_decoder(
                sess, batched_data, lamb=lamb, forward_only=True)
            # total number of target tokens in this batch's responses (excluding <go>)
            length = np.sum(
                np.maximum(np.array(batched_data['resp_length']) - 1, 0))
            # number of examples in this batch that use at least one knowledge triple
            kg_length = np.sum(np.max(batched_data['kg_index'], axis=-1))
            total_length += [length, kg_length, kg_length]
            loss += [
                decoder_loss * length, kg_loss * kg_length, kg_acc * kg_length
            ]
            batched_data = data.get_next_batch(key_name)

        loss /= total_length
        logger.info(
            '	perplexity on %s set: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f'
            % (key_name, np.exp(loss[0]), np.exp(loss[1]), loss[1], loss[2]))
        return loss

    def train_process(self, sess, data, args):
        loss_step, time_step, epoch_step = np.zeros((3, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        for i in range(2):
            logger.info(
                f"post@ {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
            )
            logger.info(
                f"length@ {batched_data['prev_length'][i], batched_data['post_length'][i]}"
            )
            logger.info(
                f"last@ {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist()[batched_data['prev_length'][i]: batched_data['post_length'][i]], trim=False)}"
            )
            logger.info(
                f"resp@ {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
            )

        for epoch_step in range(args.epochs):
            while batched_data != None:
                if self.global_step.eval(
                ) % args.checkpoint_steps == 0 and self.global_step.eval(
                ) != 0:
                    logger.info(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f"
                        % (epoch_step, self.global_step.eval(),
                           self.learning_rate.eval(), time_step,
                           np.exp(loss_step[0]), np.exp(
                               loss_step[1]), loss_step[1], loss_step[2]))
                    self.trainSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': loss_step[0],
                            'perplexity': np.exp(loss_step[0])
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess,
                                             data,
                                             args.batch_size,
                                             "dev",
                                             lamb=args.lamb)
                    self.devSummary(
                        self.global_step.eval() // args.checkpoint_steps, {
                            'loss': dev_loss[0],
                            'perplexity': np.exp(dev_loss[0])
                        })

                    # decay the learning rate when the decoder loss stops improving
                    # (previous_losses only tracks the decoder loss, loss_step[0])
                    if loss_step[0] > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss[0] < best_valid:
                        best_valid = dev_loss[0]
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [loss_step[0]]
                    loss_step, time_step = np.zeros((3, )), .0

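                # Run one training step and accumulate loss and time, averaged over
                # the checkpoint interval.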
                start_time = time.time()
                step_out = self.step_decoder(sess,
                                             batched_data,
                                             lamb=args.lamb)
                loss_step += np.array([step_out[0], step_out[3], step_out[4]
                                       ]) / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):
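        # Score the gold response against its distractors by per-sentence loss;
        # hits@k is the fraction of test samples whose gold response (candidate 0)
        # ranks among the k lowest-loss candidates.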

        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding='utf8') as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:

            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()

            batched_data['resp_length'] = [len(batched_data['resp'][0])]
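            # Append each distractor (tokenised with jieba) as an extra candidate
            # response, then pad all candidates to the same length.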
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['resp_length'].append(
                    len(batched_data['resp'][-1]))
            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

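            # Tile the post and knowledge fields so every candidate shares the same context.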
            post = []
            post_length = []
            prev_length = []

            kg = []
            kg_h_length = []
            kg_hr_length = []
            kg_hrt_length = []
            kg_index = []

            for _ in range(len(resp)):
                post += batched_data['post']
                post_length += batched_data['post_length']
                prev_length += batched_data['prev_length']

                kg += batched_data['kg']
                kg_h_length += batched_data['kg_h_length']
                kg_hr_length += batched_data['kg_hr_length']
                kg_hrt_length += batched_data['kg_hrt_length']
                kg_index += batched_data['kg_index']

            batched_data['post'] = post
            batched_data['post_length'] = post_length
            batched_data['prev_length'] = prev_length

            batched_data['kg'] = kg
            batched_data['kg_h_length'] = kg_h_length
            batched_data['kg_hr_length'] = kg_hr_length
            batched_data['kg_hrt_length'] = kg_hrt_length
            batched_data['kg_index'] = kg_index

            _, _, loss, _, _ = self.inference(sess,
                                              batched_data,
                                              lamb=args.lamb)
            loss_record.append(loss)
            cnt += 1
            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5}

    def test_process(self, sess, data, args):
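        # metric1 evaluates teacher-forcing predictions on the gold responses, metric2
        # evaluates free-running generations; hits@k results are merged in afterwards.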

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")

        for i in range(2):
            logger.info(
                f"post@{i}: {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
            )
            logger.info(
                f"resp@{i}: {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
            )

        while batched_data is not None:
            batched_responses_id, gen_log_prob, _, _, _ = self.inference(
                sess, batched_data, lamb=args.lamb)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)
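            # Truncate each generated sequence at the first eos token (inclusive).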
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

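        # Write the scalar metrics and the reference/generated responses to a text file.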
        test_file = args.output_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['resp'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        logger.info("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }