def create_summary(self, args):
    """Create the summary helper and the train/dev/test summary groups.

    The helper is created under ``<args.log_dir>/<args.name>_<HHMMSS>``
    (local time at call time).  Three groups are stored on ``self``:
    ``trainSummary`` (scalars only) and ``devSummary`` / ``testSummary``
    (scalars plus one text entry per element of ``args.show_sample``).

    Args:
        args: run configuration; reads ``log_dir``, ``name`` and
            ``show_sample`` and is forwarded unchanged to ``SummaryHelper``.
    """
    timestamp = time.strftime("%H%M%S", time.localtime())
    log_path = "%s/%s_%s" % (args.log_dir, args.name, timestamp)
    self.summaryHelper = SummaryHelper(log_path, args)

    self.trainSummary = self.summaryHelper.addGroup(
        scalar=["loss", "perplexity", "kl_loss", "kld", "kl_weight"],
        prefix="train")

    # dev and test deliberately share the very same list objects.
    eval_scalars = ["loss", "perplexity", "kl_loss", "kld", "kl_weight"]
    eval_tensors = []
    eval_texts = ["show_str%d" % sample_id for sample_id in args.show_sample]

    eval_kwargs = dict(scalar=eval_scalars, tensor=eval_tensors,
                       text=eval_texts)
    self.devSummary = self.summaryHelper.addGroup(prefix="dev", **eval_kwargs)
    self.testSummary = self.summaryHelper.addGroup(prefix="test",
                                                   **eval_kwargs)
def create_summary(self, args):
    """Create the summary helper and the train/dev/test summary groups.

    The helper is created under ``<args.log_dir>/<args.name>_<HHMMSS>``
    (local time at call time).  ``trainSummary`` tracks scalars only;
    ``devSummary`` / ``testSummary`` additionally register one text entry
    per element of ``args.show_sample`` and an (empty) embedding list.

    Args:
        args: run configuration; reads ``log_dir``, ``name`` and
            ``show_sample`` and is forwarded unchanged to ``SummaryHelper``.
    """
    timestamp = time.strftime("%H%M%S", time.localtime())
    log_path = "%s/%s_%s" % (args.log_dir, args.name, timestamp)
    self.summaryHelper = SummaryHelper(log_path, args)

    # NOTE(review): the train group spells "recontruction_loss" while the
    # dev/test groups use "reconstruction_loss".  Kept as-is because callers
    # of trainSummary may pass the misspelled key -- confirm before renaming.
    self.trainSummary = self.summaryHelper.addGroup(
        scalar=["neg_elbo", "recontruction_loss", "KL_weight",
                "KL_divergence", "bow_loss", "perplexity"],
        prefix="train")

    # dev and test deliberately share the very same list objects.
    eval_scalars = ["neg_elbo", "reconstruction_loss", "KL_weight",
                    "KL_divergence", "bow_loss", "perplexity"]
    eval_tensors = []
    eval_texts = ["show_str%d" % sample_id for sample_id in args.show_sample]
    eval_embeddings = []

    eval_kwargs = dict(scalar=eval_scalars, tensor=eval_tensors,
                       text=eval_texts, embedding=eval_embeddings)
    self.devSummary = self.summaryHelper.addGroup(prefix="dev", **eval_kwargs)
    self.testSummary = self.summaryHelper.addGroup(prefix="test",
                                                   **eval_kwargs)
def create_summary(args):
    """Build the generator/discriminator summary groups for one run.

    A single ``SummaryHelper`` is created under
    ``<args.log_dir>/<args.name>_<HHMMSS>`` (local time at call time) and six
    groups are registered on it: train/dev/test for the generator (scalars
    ``loss`` and ``rewards``) and for the discriminator (scalars ``loss`` and
    ``accuracy``).  The dev/test groups also register one text entry per
    element of ``args.show_sample``.

    Args:
        args: run configuration; reads ``log_dir``, ``name`` and
            ``show_sample`` and is forwarded unchanged to ``SummaryHelper``.

    Returns:
        Tuple ``(gen_train, gen_dev, gen_test, dis_train, dis_dev,
        dis_test)`` of the groups returned by ``addGroup``.
    """
    timestamp = time.strftime("%H%M%S", time.localtime())
    helper = SummaryHelper(
        "%s/%s_%s" % (args.log_dir, args.name, timestamp), args)

    gen_train = helper.addGroup(scalar=["loss", "rewards"],
                                prefix="gen_train")
    dis_train = helper.addGroup(scalar=["loss", "accuracy"],
                                prefix="dis_train")

    gen_scalars = ["loss", "rewards"]
    dis_scalars = ["loss", "accuracy"]
    # The tensor/text lists are shared by all four evaluation groups.
    shared_tensors = []
    shared_texts = ["show_str%d" % sample_id
                    for sample_id in args.show_sample]

    gen_dev = helper.addGroup(scalar=gen_scalars, tensor=shared_tensors,
                              text=shared_texts, prefix="gen_dev")
    gen_test = helper.addGroup(scalar=gen_scalars, tensor=shared_tensors,
                               text=shared_texts, prefix="gen_test")
    dis_dev = helper.addGroup(scalar=dis_scalars, tensor=shared_tensors,
                              text=shared_texts, prefix="dis_dev")
    dis_test = helper.addGroup(scalar=dis_scalars, tensor=shared_tensors,
                               text=shared_texts, prefix="dis_test")

    return gen_train, gen_dev, gen_test, dis_train, dis_dev, dis_test
class VAEModel(object):
    """Sentence VAE built with TensorFlow 1.x graph-mode APIs.

    A GRU encoder reads the sentence; its final state feeds two dense layers
    producing ``recog_mu`` / ``recog_logvar``.  A latent vector (sampled from
    the recognition distribution, or from a standard normal when
    ``use_prior`` is fed True) is projected to the decoder's initial state.
    The training loss is ``sen_loss + kl_weights * max(kld, args.min_kl)``.

    NOTE(review): the variable-scope nesting in ``__init__`` was
    reconstructed from a flattened source; confirm it against existing
    checkpoints before relying on variable names.
    """

    def __init__(self, data, args, embed):
        """Build the whole graph, the optimizer, the savers and summaries.

        Args:
            data: dataloader; reads ``vocab_size``, ``go_id`` and ``eos_id``.
            args: hyper-parameters (``embedding_size``, ``eh_size``,
                ``dh_size``, ``z_dim``, ``lr``, ``lr_decay``, ``momentum``,
                ``grad_clip``, ``full_kl_step``, ``min_kl``, ``name``,
                ``checkpoint_max_to_keep``, ``log_dir``, ``show_sample``).
            embed: initial embedding value passed as ``initializer``, or
                None to let TensorFlow initialise the table.
        """
        with tf.variable_scope("input"):
            with tf.variable_scope("embedding"):
                # build the embedding table and embedding input
                if embed is None:
                    # initialize the embedding randomly
                    self.embed = tf.get_variable(
                        'embed', [data.vocab_size, args.embedding_size],
                        tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    self.embed = tf.get_variable('embed',
                                                 dtype=tf.float32,
                                                 initializer=embed)

            self.sentence = tf.placeholder(tf.int32, (None, None),
                                           'sen_inps')  # batch*len
            self.sentence_length = tf.placeholder(tf.int32, (None, ),
                                                  'sen_lens')  # batch
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

            batch_size, batch_len = tf.shape(self.sentence)[0], tf.shape(
                self.sentence)[1]
            self.decoder_max_len = batch_len - 1

            self.encoder_input = tf.nn.embedding_lookup(
                self.embed, self.sentence)  # batch*len*unit
            self.encoder_len = self.sentence_length

            # Decoder input drops the last column, the target drops the first.
            decoder_input = tf.split(self.sentence,
                                     [self.decoder_max_len, 1],
                                     1)[0]  # no eos_id
            self.decoder_input = tf.nn.embedding_lookup(
                self.embed, decoder_input)  # batch*(len-1)*unit
            self.decoder_target = tf.split(
                self.sentence, [1, self.decoder_max_len],
                1)[1]  # no go_id, batch*(len-1)
            self.decoder_len = self.sentence_length - 1
            self.decoder_mask = tf.sequence_mask(
                self.decoder_len, self.decoder_max_len,
                dtype=tf.float32)  # batch*(len-1)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.encoder_len,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('recognition_net'):
            recog_input = encoder_state
            self.recog_mu = tf.layers.dense(inputs=recog_input,
                                            units=args.z_dim,
                                            activation=None,
                                            name='recog_mu')
            self.recog_logvar = tf.layers.dense(inputs=recog_input,
                                                units=args.z_dim,
                                                activation=None,
                                                name='recog_logvar')

            # z = mu + std * epsilon, with epsilon ~ N(0, I)
            epsilon = tf.random_normal(tf.shape(self.recog_logvar),
                                       name="epsilon")
            std = tf.exp(0.5 * self.recog_logvar)
            self.recog_z = tf.add(self.recog_mu,
                                  tf.multiply(std, epsilon),
                                  name='recog_z')

            # Batch mean of 0.5 * sum(exp(logvar) + mu^2 - logvar - 1).
            self.kld = tf.reduce_mean(0.5 * tf.reduce_sum(
                tf.exp(self.recog_logvar) + self.recog_mu * self.recog_mu -
                self.recog_logvar - 1,
                axis=-1))
            self.prior_z = tf.random_normal(tf.shape(self.recog_logvar),
                                            name="prior_z")
            latent_sample = tf.cond(self.use_prior,
                                    lambda: self.prior_z,
                                    lambda: self.recog_z,
                                    name='latent_sample')
            dec_init_state = tf.layers.dense(inputs=latent_sample,
                                             units=args.dh_size,
                                             activation=None)

        with tf.variable_scope("output_layer",
                               initializer=tf.orthogonal_initializer()):
            self.output_layer = Dense(
                data.vocab_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.1),
                use_bias=True)

        with tf.variable_scope("decode",
                               initializer=tf.orthogonal_initializer()):
            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=self.decoder_input, sequence_length=self.decoder_len)
            train_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell_dec,
                helper=train_helper,
                initial_state=dec_init_state,
                output_layer=self.output_layer)
            train_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=train_decoder,
                maximum_iterations=self.decoder_max_len,
                impute_finished=True)
            logits = train_output.rnn_output

            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.decoder_target, logits=logits)
            crossent = tf.reduce_sum(crossent * self.decoder_mask)
            # sen_loss is per sentence, ppl_loss is per (unmasked) token.
            self.sen_loss = crossent / tf.to_float(batch_size)
            self.ppl_loss = crossent / tf.reduce_sum(self.decoder_mask)

            self.decoder_distribution_teacher = tf.nn.log_softmax(logits)

        with tf.variable_scope("decode", reuse=True):
            infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
            infer_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell=cell_dec,
                helper=infer_helper,
                initial_state=dec_init_state,
                output_layer=self.output_layer)
            infer_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder=infer_decoder,
                maximum_iterations=self.decoder_max_len,
                impute_finished=True)
            self.decoder_distribution = infer_output.rnn_output
            # argmax over vocabulary ids >= 2 only (first two ids skipped)
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.vocab_size - 2], 2)[1],
                2) + 2  # for removing UNK

        # KL weight grows linearly with global_step up to 1.0.
        self.kl_weights = tf.minimum(
            tf.to_float(self.global_step) / args.full_kl_step, 1.0)
        self.kl_loss = self.kl_weights * tf.maximum(self.kld, args.min_kl)
        self.loss = self.sen_loss + self.kl_loss

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.MomentumOptimizer(learning_rate=self.learning_rate,
                                         momentum=args.momentum)
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        """Save a checkpoint with the "latest" saver or, otherwise, the best one."""
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        """Create the summary helper and the train/dev/test summary groups."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity", "kl_loss", "kld", "kl_weight"],
            prefix="train")

        scalarlist = ["loss", "perplexity", "kl_loss", "kld", "kl_weight"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        """Print name and static shape of every variable in ``self.params``."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one teacher-forced step using the recognition-network sample.

        Returns:
            forward_only=True:  [loss, log-softmax distribution, ppl_loss,
            kl_loss, kld, kl_weights].
            forward_only=False: [loss, gradient_norm, update, ppl_loss,
            kl_loss, kld, kl_weights] (this also applies the update).
            In both layouts the last four entries are the same, which is why
            callers index them from the end.
        """
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: False
        }
        if forward_only:
            output_feed = [
                self.loss, self.decoder_distribution_teacher, self.ppl_loss,
                self.kl_loss, self.kld, self.kl_weights
            ]
        else:
            output_feed = [
                self.loss, self.gradient_norm, self.update, self.ppl_loss,
                self.kl_loss, self.kld, self.kl_weights
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Greedy-decode with z drawn from the prior; returns [generation_index]."""
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: True
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Average the forward-only losses over every batch of ``key_name``.

        Returns:
            ``(loss, ppl_loss, kl_loss, kld, kl_weight)``; ``loss`` is a
            shape-(1,) array, ``kl_weight`` is the value from the last batch.
            If the split yields no batch, all averages are zero instead of
            raising ZeroDivisionError.
        """
        loss_step = np.zeros((1, ))
        ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss_step += outputs[0]
            ppl_loss_step += outputs[-4]
            kl_loss_step += outputs[-3]
            kld_step += outputs[-2]
            kl_weight_step = outputs[-1]
            times += 1
            batched_data = data.get_next_batch(key_name)

        # Guard against an empty split (times == 0).
        denom = max(times, 1)
        loss_step /= denom
        ppl_loss_step /= denom
        kl_loss_step /= denom
        kld_step /= denom

        # np.sum turns the shape-(1,) accumulator into a scalar for "%f".
        print('    loss: %.2f' % np.sum(loss_step))
        print('    kl_loss: %.2f' % kl_loss_step)
        print('    perplexity: %.2f' % np.exp(ppl_loss_step))
        print('    kld: %.2f' % kld_step)
        return loss_step, ppl_loss_step, kl_loss_step, kld_step, kl_weight_step

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs with periodic evaluation.

        Every ``args.checkpoint_steps`` global steps: log the running
        averages, save the "latest" checkpoint, evaluate on dev and test,
        decay the learning rate when the training loss exceeds the maximum
        of the last five recorded losses, and save a "best" checkpoint when
        the dev loss improves.
        """
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # Evaluate the step counter once; nothing in the checkpoint
                # branch below advances it.
                step = self.global_step.eval()
                if step % args.checkpoint_steps == 0 and step != 0:
                    summary_step = step // args.checkpoint_steps
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f"
                        % (epoch_step, step, self.learning_rate.eval(),
                           time_step))
                    print('    loss: %.2f' % np.sum(loss_step))
                    print('    kl_loss: %.2f' % kl_loss_step)
                    print('    perplexity: %.2f' % np.exp(ppl_loss_step))
                    print('    kld: %.2f' % kld_step)
                    self.trainSummary(
                        summary_step, {
                            'loss': loss_step,
                            'perplexity': np.exp(ppl_loss_step),
                            'kl_loss': kl_loss_step,
                            'kld': kld_step,
                            'kl_weight': kl_weight_step
                        })
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    devout = self.evaluate(sess, data, args.batch_size, "dev")
                    self.devSummary(
                        summary_step, {
                            'loss': devout[0],
                            'perplexity': np.exp(devout[1]),
                            'kl_loss': devout[2],
                            'kld': devout[3],
                            'kl_weight': devout[4]
                        })

                    testout = self.evaluate(sess, data, args.batch_size,
                                            "test")
                    self.testSummary(
                        summary_step, {
                            'loss': testout[0],
                            'perplexity': np.exp(testout[1]),
                            'kl_loss': testout[2],
                            'kld': testout[3],
                            'kl_weight': testout[4]
                        })

                    train_loss = np.sum(loss_step)
                    if train_loss > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if devout[0] < best_valid:
                        best_valid = devout[0]
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [train_loss]
                    loss_step, time_step = np.zeros((1, )), .0
                    ppl_loss_step, kl_loss_step, kld_step, kl_weight_step = 0, 0, 0, 0

                start_time = time.time()
                outputs = self.step_decoder(sess, batched_data)
                loss_step += outputs[0] / args.checkpoint_steps
                ppl_loss_step += outputs[-4] / args.checkpoint_steps
                kl_loss_step += outputs[-3] / args.checkpoint_steps
                kld_step += outputs[-2] / args.checkpoint_steps
                kl_weight_step = outputs[-1]

                time_step += (time.time() -
                              start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        """Score the test split and write the report file.

        Feeds the teacher-forcing metric with log-probabilities and the
        inference metric with generated ids (cut after the first EOS), then
        writes float metrics and ``res['gen']`` to
        ``<args.out_dir>/<args.name>_test.txt``.
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        while batched_data is not None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            gen_log_prob = self.step_decoder(sess,
                                             batched_data,
                                             forward_only=True)[1]
            metric1_data = {
                'sent_allvocabs': np.array(batched_data['sent_allvocabs']),
                'sent_length': np.array(batched_data['sent_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep everything up to and including the first EOS
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {'gen': np.array(batch_results)}
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for sentence in res['gen']:
                f.write("%s\n" % " ".join(sentence))

        print("result output to %s." % test_file)
class HredModel(object):
    """Hierarchical encoder-decoder with knowledge-triple attention (TF 1.x).

    Each post is encoded by a GRU, the per-turn states are summarised by a
    context GRU, and a GRU decoder with Bahdanau attention over the previous
    post generates the response.  A context query scores the knowledge
    entries; the weighted value embedding is appended to every decoder input.
    The training loss is ``decoder_loss + kg_loss``.

    NOTE(review): the variable-scope nesting in ``__init__`` was
    reconstructed from a flattened source; confirm it against existing
    checkpoints before relying on variable names.
    """

    def __init__(self, data, args, embed):
        """Build the whole graph, the optimizer, the savers and summaries.

        Args:
            data: dataloader; reads ``vocab_size``, ``go_id`` and ``eos_id``.
            args: hyper-parameters (``embedding_size``, ``eh_size``,
                ``ch_size``, ``dh_size``, ``softmax_samples``,
                ``max_sent_length``, ``lr``, ``lr_decay``, ``grad_clip``,
                ``name``, ``checkpoint_max_to_keep``, ``log_dir``,
                ``show_sample``).
            embed: initial embedding value passed as ``initializer``, or
                None to let TensorFlow initialise the table.
        """
        # Created but never fed by step_decoder (the feed entry is disabled).
        self.init_states = tf.placeholder(tf.float32, (None, args.ch_size),
                                          'ctx_inps')  # batch*ch_size
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prev_posts = tf.placeholder(tf.int32, (None, None),
                                         'enc_prev_inps')
        self.prev_posts_length = tf.placeholder(tf.int32, (None, ),
                                                'enc_prev_lens')

        # kgs is batch*kg_num*kg_len (see the mask reshapes below); the
        # three length tensors and kgs_index are batch*kg_num.
        self.kgs = tf.placeholder(tf.int32, (None, None, None), 'kg_inps')
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None),
                                           'kg_h_lens')
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')
        self.kgs_index = tf.placeholder(tf.float32, (None, None),
                                        'kg_indices')

        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch
        self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens')
        # Fed by step_decoder but not consumed by any op built here.
        self.is_train = tf.placeholder(tf.bool)

        num_past_turns = tf.shape(self.posts)[0] // tf.shape(
            self.origin_responses)[0]

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(
            self.origin_responses)[0], tf.shape(self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses,
                                  [1, decoder_len - 1], 1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        # one_hot(length - 1) followed by a reverse cumsum yields ones on
        # positions < length and zeros elsewhere.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])
        kg_len = tf.shape(self.kgs)[2]
        kg_h_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_h_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        # Key = first kgs_hr_length tokens, value = the tokens after them.
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)
        # Masked mean over tokens; the denominator is clamped to >= 1.
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        # Same scope with AUTO_REUSE: the previous post shares the encoder.
        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            prev_output, _ = dynamic_rnn(cell_enc,
                                         tf.nn.embedding_lookup(
                                             self.embed, self.prev_posts),
                                         self.prev_posts_length,
                                         dtype=tf.float32,
                                         scope="encoder_rnn")

        with tf.variable_scope('context'):
            encoder_state_reshape = tf.reshape(
                encoder_state, [-1, num_past_turns, args.eh_size])
            _, self.context_state = dynamic_rnn(cell_ctx,
                                                encoder_state_reshape,
                                                self.context_length,
                                                dtype=tf.float32,
                                                scope='context_rnn')

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct attention over the previous post's encoder outputs
        attn_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            args.dh_size,
            prev_output,
            memory_sequence_length=tf.maximum(self.prev_posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        ctx_state_shaping = tf.layers.dense(self.context_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # calculate kg embedding
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(self.context_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        # Entries with zero length get a score of -inf.
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        kg_alignment = tf.nn.softmax(kg_score)
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)

        # The original referenced `kg_num_mask` without ever defining it
        # (NameError while building the graph).  It is defined here with the
        # same validity test used for kg_score above, shaped
        # batch*kg_num*1 so it broadcasts over the embedding axis.
        kg_num_mask = tf.expand_dims(tf.greater(self.kgs_hrt_length, 0),
                                     axis=-1)
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg *
            tf.cast(kg_num_mask, tf.float32),
            axis=1)
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1),
            [1, decoder_len, 1])
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)

        # construct helper
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
                sampled_sequence_loss(self.decoder_output,
                                      self.responses_target,
                                      self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sent_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            # argmax over vocabulary ids >= 2 only (first two ids skipped)
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.vocab_size - 2], 2)[1],
                2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        self.loss = self.decoder_loss + self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        """Save with the "latest" saver or, otherwise, the best saver.

        ``name`` is forwarded as ``latest_filename`` to ``Saver.save``.
        """
        if key == "latest":
            self.latest_saver.save(sess,
                                   path,
                                   global_step=self.global_step,
                                   latest_filename=name)
        else:
            self.best_saver.save(sess,
                                 path,
                                 global_step=self.global_step,
                                 latest_filename=name)

    def create_summary(self, args):
        """Create the summary helper and the train/dev/test summary groups."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        emblist = []
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        """Print name and static shape of every variable in ``self.params``."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        """Run one step of the graph on a batch.

        Returns:
            inference=True: [generation_index, teacher distribution,
            decoder_all_loss, kg_loss, kg_acc].
            forward_only=True: [decoder_loss, teacher distribution, kg_loss,
            kg_acc].
            otherwise: [decoder_loss, gradient_norm, update, kg_loss, kg_acc]
            (this also applies the parameter update).
        """
        input_feed = {
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.prev_posts: data['prev_posts'],
            self.prev_posts_length: data['prev_posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
            self.context_length: data['context_length'],
            self.kgs: data['kg'],
            self.kgs_h_length: data['kg_h_length'],
            self.kgs_hr_length: data['kg_hr_length'],
            self.kgs_hrt_length: data['kg_hrt_length'],
            self.kgs_index: data['kg_index'],
        }

        if inference:
            input_feed.update({self.is_train: False})
            output_feed = [
                self.generation_index, self.decoder_distribution_teacher,
                self.decoder_all_loss, self.kg_loss, self.kg_acc
            ]
        else:
            input_feed.update({self.is_train: True})
            if forward_only:
                output_feed = [
                    self.decoder_loss, self.decoder_distribution_teacher,
                    self.kg_loss, self.kg_acc
                ]
            else:
                output_feed = [
                    self.decoder_loss, self.gradient_norm, self.update,
                    self.kg_loss, self.kg_acc
                ]

        return sess.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Return length-weighted [decoder_loss, kg_loss, kg_acc] on a split.

        The decoder loss is weighted by the number of target tokens, the
        knowledge terms by the number of samples whose ``kg_index`` row has
        a non-zero maximum.  A zero denominator (empty split, or no sample
        with knowledge labels) yields 0 for that entry instead of NaN.
        """
        loss = np.zeros((3, ))
        total_length = np.zeros((3, ))
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            decoder_loss, _, kg_loss, kg_acc = self.step_decoder(
                sess, batched_data, forward_only=True)
            length = np.sum(
                np.maximum(
                    np.array(batched_data['responses_length']) - 1, 0))
            kg_length = np.sum(np.max(batched_data['kg_index'], axis=-1))
            total_length += [length, kg_length, kg_length]
            loss += [
                decoder_loss * length, kg_loss * kg_length,
                kg_acc * kg_length
            ]
            batched_data = data.get_next_batch(key_name)

        # Only exact-zero denominators are replaced, so non-empty results
        # are unchanged.
        loss /= np.where(total_length > 0, total_length, 1.0)
        print(
            '    perplexity on %s set: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f'
            % (key_name, np.exp(loss[0]), np.exp(loss[1]), loss[1], loss[2]))
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs with periodic dev evaluation.

        Every ``args.checkpoint_steps`` global steps: log the running
        averages, save the "latest" checkpoint, evaluate on dev, decay the
        learning rate when the averaged decoder loss exceeds the maximum of
        the last three recorded ones, and save a "best" checkpoint when the
        dev decoder loss improves.
        """
        loss_step, time_step, epoch_step = np.zeros((3, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # Evaluate the step counter once; nothing in the checkpoint
                # branch below advances it.
                step = self.global_step.eval()
                if step % args.checkpoint_steps == 0 and step != 0:
                    summary_step = step // args.checkpoint_steps
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f"
                        % (epoch_step, step, self.learning_rate.eval(),
                           time_step, np.exp(loss_step[0]),
                           np.exp(loss_step[1]), loss_step[1], loss_step[2]))
                    self.trainSummary(
                        summary_step, {
                            'loss': loss_step[0],
                            'perplexity': np.exp(loss_step[0])
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(
                        summary_step, {
                            'loss': dev_loss[0],
                            'perplexity': np.exp(dev_loss[0])
                        })

                    # The history stores the decoder loss only, so compare
                    # against the decoder loss as well (the original summed
                    # in kg_loss and kg_acc on this side of the comparison).
                    train_loss = loss_step[0]
                    if train_loss > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss[0] < best_valid:
                        best_valid = dev_loss[0]
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [train_loss]
                    loss_step, time_step = np.zeros((3, )), .0

                start_time = time.time()
                step_out = self.step_decoder(sess, batched_data)
                # indices 0/3/4 = decoder_loss, kg_loss, kg_acc
                loss_step += np.array([step_out[0], step_out[3], step_out[4]
                                       ]) / args.checkpoint_steps

                time_step += (time.time() -
                              start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):
        """Compute hits@1 / hits@3 of the true response among distractors.

        For each test sample (batch size 1) the true response is placed at
        row 0, followed by the distractors read from
        ``<args.datapath>/test_distractors.json``; every other input is
        replicated once per candidate and candidates are ranked by
        ``decoder_all_loss``.
        """
        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding='utf8') as f:
            test_distractors = json.load(f)

        # Inputs that are simply repeated for every candidate response.
        replicated_keys = ('posts', 'posts_length', 'prev_posts',
                           'prev_posts_length', 'context_length', 'kg',
                           'kg_h_length', 'kg_hr_length', 'kg_hrt_length',
                           'kg_index')

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:
            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()
            batched_data['responses_length'] = [
                len(batched_data['responses'][0])
            ]
            for each_resp in test_distractors[cnt]:
                batched_data['responses'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['responses_length'].append(
                    len(batched_data['responses'][-1]))

            # Right-pad all candidates with zeros to a rectangular array.
            max_length = max(batched_data['responses_length'])
            resp = np.zeros((len(batched_data['responses']), max_length),
                            dtype=int)
            for i, each_resp in enumerate(batched_data['responses']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['responses'] = resp

            num_candidates = len(resp)
            for key in replicated_keys:
                batched_data[key] = list(batched_data[key]) * num_candidates

            _, _, loss, _, _ = self.step_decoder(sess,
                                                 batched_data,
                                                 inference=True)
            loss_record.append(loss)
            cnt += 1

            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        # Index 0 is the true response: first place, or within the top three.
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3}

    def test_process(self, sess, data, args):
        """Score the test split, write the report file, return plain metrics.

        Returns:
            The entries of the merged metric dict whose value type is exactly
            ``bytes``, ``int`` or ``float`` (subclasses such as ``bool`` and
            numpy scalars are excluded by the exact type check).
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        while batched_data is not None:
            batched_responses_id, gen_log_prob, _, _, _ = self.step_decoder(
                sess, batched_data, False, True)
            metric1_data = {
                'resp_allvocabs':
                np.array(batched_data['responses_allvocabs']),
                'resp_length': np.array(batched_data['responses_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep everything up to and including the first EOS
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs':
                np.array(batched_data['responses_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            res_print = sorted(res.items(), key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['resp'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        print("result output to %s" % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
class CVAEModel(object):
    """Hierarchical conditional VAE dialogue model built on the TF 1.x graph API.

    The graph encodes every context utterance with a bidirectional GRU, folds
    the utterance states with a context GRU, and conditions a GRU decoder on
    the context state concatenated with a latent sample (drawn from the
    recognition network for training and from the prior network for
    inference).

    The remaining methods slice multi-turn sessions into (context, response)
    batches and drive training, evaluation and testing.
    """

    # Single source of truth for the scalar names written to the summaries.
    # The train group, the dev/test groups and ``train_process`` used to spell
    # the reconstruction loss three different ways, so the dict handed to the
    # summary writers never matched the registered scalar names.
    _LOSS_NAMES = (
        "neg_elbo",
        "reconstruction_loss",
        "KL_weight",
        "KL_divergence",
        "bow_loss",
        "perplexity",
    )

    def __init__(self, data, args, embed):
        """Build the whole graph, the optimizer, the savers and the summaries.

        Arguments:
            data: data manager; ``vocab_size``, ``go_id`` and ``eos_id`` are read.
            args: hyper-parameter namespace (``lr``, ``lr_decay``, ``eh_size``,
                ``ch_size``, ``dh_size``, ``latent_size``, ``full_kl_step``,
                ``max_sen_length``, ``grad_clip``, ``name``, ...).
            embed: initial value for the embedding table, or ``None`` to let
                TensorFlow initialize a ``[vocab_size, word_embedding_size]``
                table.
        """
        with tf.name_scope("placeholders"):
            self.contexts = tf.placeholder(tf.int32, (None, None, None), 'cxt_inps')  # [batch, utt_len, len]
            self.contexts_length = tf.placeholder(tf.int32, (None, ), 'cxt_lens')  # [batch]
            self.posts_length = tf.placeholder(tf.int32, (None, None), 'enc_lens')  # [batch, utt_len]
            self.origin_responses = tf.placeholder(tf.int32, (None, None), 'dec_inps')  # [batch, len]
            self.origin_responses_length = tf.placeholder(tf.int32, (None, ), 'dec_lens')  # [batch]

        # deal with original data to adapt encoder and decoder
        max_sen_length = tf.shape(self.contexts)[2]
        max_cxt_size = tf.shape(self.contexts)[1]
        self.posts_input = tf.reshape(self.contexts, [-1, max_sen_length])  # [batch * cxt_len, utt_len]
        self.flat_posts_length = tf.reshape(self.posts_length, [-1])  # [batch * cxt_len]

        decoder_len = tf.shape(self.origin_responses)[1]
        self.responses_target = tf.split(self.origin_responses, [1, decoder_len - 1], 1)[1]  # no go_id
        self.responses_input = tf.split(self.origin_responses, [decoder_len - 1, 1], 1)[0]  # no eos_id
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1

        self.decoder_mask = tf.sequence_mask(self.responses_length, decoder_len, dtype=tf.float32)
        # Rows whose response length (after dropping one token) is 0 do not
        # count towards the loss; ``batch_size`` is the number of rows that do.
        loss_mask = tf.cast(tf.greater(self.responses_length, 0), dtype=tf.float32)
        batch_size = tf.reduce_sum(loss_mask)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr), trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        with tf.name_scope("embedding"):
            # build the embedding table and embedding input
            if embed is None:
                # initialize the embedding randomly
                self.word_embed = tf.get_variable(
                    'word_embed', [data.vocab_size, args.word_embedding_size], tf.float32)
            else:
                # initialize the embedding by pre-trained word vectors
                self.word_embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed)

            posts_enc_input = tf.nn.embedding_lookup(self.word_embed, self.posts_input)
            responses_enc_input = tf.nn.embedding_lookup(self.word_embed, self.origin_responses)
            responses_dec_input = tf.nn.embedding_lookup(self.word_embed, self.responses_input)

        with tf.name_scope("cell"):
            # build rnn_cell
            cell_enc_fw = tf.nn.rnn_cell.GRUCell(args.eh_size)
            cell_enc_bw = tf.nn.rnn_cell.GRUCell(args.eh_size)
            cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
            cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            _, posts_enc_state = tf.nn.bidirectional_dynamic_rnn(
                cell_enc_fw, cell_enc_bw, posts_enc_input, self.flat_posts_length,
                dtype=tf.float32, scope="encoder_bi_rnn")
            # Concatenate forward/backward final states and restore the
            # [batch, cxt_len, 2 * eh_size] layout for the context RNN.
            posts_enc_state = tf.reshape(tf.concat(posts_enc_state, 1),
                                         [-1, max_cxt_size, 2 * args.eh_size])

        with tf.variable_scope('context'):
            _, context_state = tf.nn.dynamic_rnn(
                cell_ctx, posts_enc_state, self.contexts_length,
                dtype=tf.float32, scope='context_rnn')
            cond_info = context_state

        with tf.variable_scope("recognition_network"):
            _, responses_enc_state = tf.nn.bidirectional_dynamic_rnn(
                cell_enc_fw, cell_enc_bw, responses_enc_input, self.responses_length,
                dtype=tf.float32, scope='encoder_bid_rnn')
            responses_enc_state = tf.concat(responses_enc_state, 1)
            recog_input = tf.concat((cond_info, responses_enc_state), 1)
            recog_output = tf.layers.dense(recog_input, 2 * args.latent_size)
            recog_mu, recog_logvar = tf.split(recog_output, 2, 1)
            recog_z = self.sample_gaussian(
                (tf.size(self.contexts_length), args.latent_size), recog_mu, recog_logvar)

        with tf.variable_scope("prior_network"):
            prior_input = cond_info
            prior_fc_1 = tf.layers.dense(prior_input, 2 * args.latent_size, activation=tf.tanh)
            prior_output = tf.layers.dense(prior_fc_1, 2 * args.latent_size)
            prior_mu, prior_logvar = tf.split(prior_output, 2, 1)
            prior_z = self.sample_gaussian(
                (tf.size(self.contexts_length), args.latent_size), prior_mu, prior_logvar)

        with tf.name_scope("decode"):
            # get output projection function
            dec_init_fn = MyDense(args.dh_size, use_bias=True)
            output_fn = MyDense(data.vocab_size, use_bias=True)

            with tf.name_scope("training"):
                decoder_input = responses_dec_input
                # Training conditions the decoder on the recognition sample.
                gen_input = tf.concat((cond_info, recog_z), 1)
                dec_init_fn_input = gen_input

                train_helper = tf.contrib.seq2seq.TrainingHelper(
                    decoder_input, tf.maximum(self.responses_length, 0))
                dec_init_state = dec_init_fn(dec_init_fn_input)
                decoder_train = tf.contrib.seq2seq.BasicDecoder(
                    cell_dec, train_helper, dec_init_state, output_fn)
                train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_train, impute_finished=True, scope="decoder_rnn")

                responses_dec_output = train_outputs.rnn_output
                self.decoder_distribution_teacher = tf.nn.log_softmax(responses_dec_output, 2)

                with tf.name_scope("losses"):
                    crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=responses_dec_output, labels=self.responses_target)
                    self.reconstruct_loss = tf.reduce_sum(crossent * self.decoder_mask) / batch_size

                    self.KL_loss = tf.reduce_sum(
                        loss_mask * self.KL_divergence(prior_mu, prior_logvar, recog_mu, recog_logvar)
                    ) / batch_size
                    # Linear KL annealing: the weight reaches 1.0 at full_kl_step.
                    self.KL_weight = tf.minimum(
                        1.0, tf.to_float(self.global_step) / args.full_kl_step)
                    self.anneal_KL_loss = self.KL_weight * self.KL_loss

                    # Bag-of-words loss: one prediction from the decoder's
                    # initial input, tiled over every target position.
                    bow_logits = tf.layers.dense(
                        tf.layers.dense(dec_init_fn_input, 400, activation=tf.tanh),
                        data.vocab_size)
                    tile_bow_logits = tf.tile(tf.expand_dims(bow_logits, 1), [1, decoder_len, 1])
                    bow_loss = self.decoder_mask * tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=tile_bow_logits, labels=self.responses_target)
                    self.bow_loss = tf.reduce_sum(bow_loss) / batch_size

                    self.neg_elbo = self.reconstruct_loss + self.KL_loss
                    self.train_loss = self.reconstruct_loss + self.anneal_KL_loss + self.bow_loss

            with tf.name_scope("inference"):
                # Inference conditions the decoder on the prior sample.
                gen_input = tf.concat((cond_info, prior_z), 1)
                dec_init_fn_input = gen_input
                dec_init_state = dec_init_fn(dec_init_fn_input)

                infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    self.word_embed,
                    tf.fill([tf.size(self.contexts_length)], data.go_id),
                    data.eos_id)
                decoder_infer = MyBasicDecoder(cell_dec, infer_helper, dec_init_state,
                                               output_layer=output_fn,
                                               _aug_context_vector=None)
                infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder_infer, impute_finished=True,
                    maximum_iterations=args.max_sen_length, scope="decoder_rnn")
                self.decoder_distribution = infer_outputs.rnn_output
                # for removing UNK: the argmax skips the first two vocabulary
                # entries, then shifts the index back by 2.
                self.generation_index = tf.argmax(
                    tf.split(self.decoder_distribution, [2, data.vocab_size - 2], 2)[1], 2) + 2

        # calculate the gradient of parameters and update
        self.params = [k for k in tf.trainable_variables() if args.name in k.name]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.train_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=1,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def sample_gaussian(self, shape, mu, logvar):
        """Draw ``mu + exp(logvar / 2) * eps`` with ``eps`` standard normal of ``shape``."""
        normal = tf.random_normal(shape=shape, dtype=tf.float32)
        z = tf.exp(logvar / 2) * normal + mu
        return z

    def KL_divergence(self, prior_mu, prior_logvar, recog_mu, recog_logvar):
        """Return the per-row KL(recognition || prior) of two diagonal Gaussians.

        The element-wise closed form is summed over axis 1.
        """
        KL_divergence = 0.5 * (
            tf.exp(recog_logvar - prior_logvar)
            + tf.pow(recog_mu - prior_mu, 2) / tf.exp(prior_logvar)
            - 1
            - (recog_logvar - prior_logvar))
        return tf.reduce_sum(KL_divergence, axis=1)

    def store_checkpoint(self, sess, path, key):
        """Save a checkpoint with the "latest" saver when ``key == "latest"``, else with the "best" saver."""
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        """Create the summary helper and the train/dev/test summary writers.

        All three groups register the scalar names from ``_LOSS_NAMES``; the
        dev and test groups additionally register one ``show_str<i>`` text
        entry for every ``i`` in ``args.show_sample``.
        """
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())),
            args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=list(self._LOSS_NAMES), prefix="train")

        scalarlist = list(self._LOSS_NAMES)
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        emblist = []
        self.devSummary = self.summaryHelper.addGroup(
            scalar=scalarlist, tensor=tensorlist, text=textlist,
            embedding=emblist, prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(
            scalar=scalarlist, tensor=tensorlist, text=textlist,
            embedding=emblist, prefix="test")

    def print_parameters(self):
        """Print the name and shape of every variable in ``self.params``."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def _pad_batch(self, raw_batch):
        """Pad ``posts_length`` and trim the context/response arrays.

        Arguments:
            raw_batch (dict): ``posts_length`` (list of lists),
                ``responses_length`` (list), ``contexts`` (3-d array) and
                ``responses`` (2-d array), as built by ``_cut_batch_data``.

        Returns:
            (dict): ``posts_length`` right-padded with zeros to the context
            width, ``contexts_length`` (number of utterances per row),
            ``contexts`` trimmed to the longest utterance, ``responses``
            trimmed to the longest response, and ``responses_length``
            (the same object as in ``raw_batch``).
        """
        batch = {
            'posts_length': [],
            'contexts_length': [],
            'responses_length': raw_batch['responses_length'],
        }
        max_cxt_size = np.shape(raw_batch['contexts'])[1]
        max_post_len = 0
        for lengths in raw_batch['posts_length']:
            batch['contexts_length'].append(len(lengths))
            if lengths:
                max_post_len = max(max_post_len, max(lengths))
            batch['posts_length'].append(lengths + [0] * (max_cxt_size - len(lengths)))
        batch['contexts'] = raw_batch['contexts'][:, :, :max_post_len]
        batch['responses'] = raw_batch['responses'][:, :np.max(raw_batch['responses_length'])]
        return batch

    def _cut_batch_data(self, batch_data, start, end):
        """Use ``session[start:end - 1]`` as context and ``session[end - 1]`` as response.

        Sessions with fewer than ``end`` utterances get a response length of
        1, which the loss and metric code treats as "no response".
        """
        raw_batch = {'posts_length': [], 'responses_length': []}
        for i in range(len(batch_data['turn_length'])):
            sent_length = batch_data['sent_length'][i]
            raw_batch['posts_length'].append(sent_length[start:end - 1])
            if end - 1 < len(sent_length):
                raw_batch['responses_length'].append(sent_length[end - 1])
            else:
                raw_batch['responses_length'].append(1)
        raw_batch['contexts'] = batch_data['sent'][:, start:end - 1]
        raw_batch['responses'] = batch_data['sent'][:, end - 1]
        return self._pad_batch(raw_batch)

    def split_session(self, batch_data, session_window, inference=False):
        """Yield one batch per response position of a multi-turn batch.

        Arguments:
            batch_data (dict): must provide ``turn_length``, ``sent_length``
                and ``sent``.
            session_window (int): at most ``session_window - 1`` utterances
                before the response are kept as context.
            inference (bool): True: utterances take turns serving as the
                response, in order. False: the order of response positions
                is shuffled.

        Yields:
            (dict): each element contains at least:

            * contexts (:class:`numpy.array`): A 3-d PADDED array of word ids.
              Size: `[batch_size, max_turn_length, max_utterance_length]`
            * contexts_length (list): A 1-d list, number of turns.
              Size: `[batch_size]`
            * posts_length (list): A 2-d PADDED list, the length of utterances.
              Size: `[batch_size, max_turn_length]`
            * responses (:class:`numpy.array`): A 2-d PADDED array of word ids.
              Size: `[batch_size, max_response_length]`
            * responses_length (list): A 1-d list, the length of responses.
              Size: `[batch_size]`
        """
        max_turn = np.max(batch_data['turn_length'])
        ends = list(range(2, max_turn + 1))
        if not inference:
            np.random.shuffle(ends)
        for end in ends:
            start = max(end - session_window, 0)
            yield self._cut_batch_data(batch_data, start, end)

    def multi_reference_batches(self, data, batch_size):
        """Yield batches that carry multiple response candidates.

        Arguments:
            data: data manager providing the ``multi_ref`` split.
            batch_size (int): batch size.

        Yields:
            (dict): the keys produced by ``_cut_batch_data`` (the last turn
            of the batch is the response) plus ``candidate``, copied from
            the raw batch.
        """
        data.restart('multi_ref', batch_size, shuffle=False)
        batch_data = data.get_next_batch('multi_ref')
        while batch_data is not None:
            batch = self._cut_batch_data(batch_data, 0, np.max(batch_data['turn_length']))
            batch['candidate'] = batch_data['candidate']
            yield batch
            batch_data = data.get_next_batch('multi_ref')

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        """Run one session step on ``data``.

        Returns:
            (list): ``[generation_index]`` when ``inference`` is set;
            otherwise the five losses/weights followed by the
            teacher-forcing log-probabilities (``forward_only``) or the
            result of the update op (training).
        """
        input_feed = {
            self.contexts: data['contexts'],
            self.contexts_length: data['contexts_length'],
            self.posts_length: data['posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
        }
        if inference:
            output_feed = [self.generation_index]
        else:
            output_feed = [
                self.neg_elbo, self.reconstruct_loss, self.KL_weight,
                self.KL_loss, self.bow_loss,
            ]
            if forward_only:
                output_feed.append(self.decoder_distribution_teacher)
            else:
                output_feed.append(self.update)
        return sess.run(output_feed, input_feed)

    def train_step(self, sess, data, args):
        """Run one training step and return the five losses plus the per-token log-perplexity."""
        output = self.step_decoder(sess, data)
        res = output[:5]
        responses_length = np.array(data['responses_length'], dtype=np.int32)
        # The reconstruction loss is averaged per instance; rescale it to a
        # per-token average over the decoded tokens.
        log_ppl = output[1] * np.sum(responses_length > 1) / np.sum(
            np.maximum(responses_length - 1, 0))
        res += [log_ppl]
        return res

    def evaluate(self, sess, data, key_name, args):
        """Evaluate on split ``key_name``.

        Returns:
            (:class:`numpy.array`): instance-weighted means of the five
            losses/weights, followed by the perplexity.
        """
        loss_list = np.array([.0] * len(self._LOSS_NAMES), dtype=np.float32)
        total_length = 0
        total_inst = 0
        data.restart(key_name, batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            for cut_batch_data in self.split_session(batched_data, args.session_window):
                output = self.step_decoder(sess, cut_batch_data, forward_only=True)
                responses_length = np.array(cut_batch_data['responses_length'], dtype=np.int32)
                batch_size = np.sum(responses_length > 1)
                loss_list[:-1] += np.array(output[:5], dtype=np.float32) * batch_size
                loss_list[-1] += output[1] * batch_size
                total_length += np.sum(np.maximum(responses_length - 1, 0))
                total_inst += batch_size
            batched_data = data.get_next_batch(key_name)
        loss_list[:-1] /= total_inst
        loss_list[-1] = np.exp(loss_list[-1] / total_length)
        print(' perplexity on %s set: %.2f' % (key_name, loss_list[-1]))
        return loss_list

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` steps the running losses are logged,
        the "latest" checkpoint is stored, dev and test are evaluated, the
        learning rate is decayed if the training loss exceeds the last five
        recorded ones, and the "best" checkpoint is stored when the dev
        loss improves.
        """
        time_step, epoch_step = .0, 0
        loss_names = list(self._LOSS_NAMES)
        loss_list = np.array([.0] * len(loss_names), dtype=np.float32)
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                for cut_batch_data in self.split_session(batched_data, args.session_window):
                    start_time = time.time()
                    output = self.train_step(sess, cut_batch_data, args)
                    loss_list += np.array(output[:len(loss_list)], dtype=np.float32)
                    time_step += time.time() - start_time

                    # The step counter does not change until the next
                    # train_step, so read it once per iteration.
                    global_step = self.global_step.eval()
                    if (global_step + 1) % args.checkpoint_steps == 0:
                        summary_index = global_step // args.checkpoint_steps
                        loss_list /= args.checkpoint_steps
                        loss_list[-1] = np.exp(loss_list[-1])
                        time_step /= args.checkpoint_steps
                        print(
                            "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                            % (epoch_step, global_step, self.learning_rate.eval(),
                               time_step, loss_list[-1]))
                        self.trainSummary(summary_index, dict(zip(loss_names, loss_list)))
                        self.store_checkpoint(
                            sess, '%s/checkpoint_latest/checkpoint' % args.model_dir, "latest")

                        dev_loss = self.evaluate(sess, data, "dev", args)
                        self.devSummary(summary_index, dict(zip(loss_names, dev_loss)))
                        test_loss = self.evaluate(sess, data, "test", args)
                        self.testSummary(summary_index, dict(zip(loss_names, test_loss)))

                        if loss_list[0] > max(previous_losses):
                            sess.run(self.learning_rate_decay_op)
                        if dev_loss[0] < best_valid:
                            best_valid = dev_loss[0]
                            self.store_checkpoint(
                                sess, '%s/checkpoint_best/checkpoint' % args.model_dir, "best")

                        previous_losses = previous_losses[1:] + [loss_list[0]]
                        loss_list *= .0
                        time_step = .0

                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        """Run teacher-forcing and inference metrics on the test split.

        Float results are printed and, together with every session's posts,
        references and generations, appended to
        ``<out_dir>/<name>_test.txt``.
        """

        def get_batch_results(batched_responses_id, data):
            """Cut every generated id sequence right after its first eos."""
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)
            return batch_results

        def padding(matrix, pad_go_id=False):
            """Right-pad the rows with pad_id and wrap them in a batch of one session."""
            width = max([len(d) for d in matrix])
            if not pad_go_id:
                res = [[d + [data.pad_id] * (width - len(d)) for d in matrix]]
            else:
                res = [[[data.go_id] + d + [data.pad_id] * (width - len(d)) for d in matrix]]
            return res

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        cnt = 0
        start_time = time.time()
        while batched_data is not None:
            conv_data = [{'contexts': [], 'responses': [], 'generations': []}
                         for _ in range(len(batched_data['turn_length']))]
            for cut_batch_data in self.split_session(batched_data, args.session_window,
                                                     inference=True):
                eval_out = self.step_decoder(sess, cut_batch_data, forward_only=True)
                gen_prob = eval_out[-1]
                batched_responses_id = self.step_decoder(sess, cut_batch_data,
                                                         inference=True)[0]
                batch_results = get_batch_results(batched_responses_id, data)

                cut_batch_data['gen_prob'] = gen_prob
                cut_batch_data['generations'] = batch_results

                # Placeholder responses (length 1) are reported as length 2.
                # NOTE(review): presumably the metric rejects shorter
                # sentences; confirm against the metric implementation.
                responses_length = []
                for length in cut_batch_data['responses_length']:
                    if length == 1:
                        length += 1
                    responses_length.append(length)
                metric1_data = {
                    'sent': np.expand_dims(cut_batch_data['responses'], 1),
                    'sent_length': np.expand_dims(responses_length, 1),
                    'gen_prob': np.expand_dims(cut_batch_data['gen_prob'], 1)
                }
                metric1.forward(metric1_data)

                # Only sessions that really have a response at this turn
                # contribute to the inference metric.
                valid_index = {
                    idx for idx, length in enumerate(cut_batch_data['responses_length'])
                    if length > 1
                }
                for key in ['contexts', 'responses', 'generations']:
                    for idx, d in enumerate(cut_batch_data[key]):
                        if idx in valid_index:
                            if key == 'contexts':
                                # keep only the last context utterance
                                d = d[cut_batch_data['contexts_length'][idx] - 1]
                            conv_data[idx][key].append(list(d))

            if (cnt + 1) % 10 == 0:
                print('processing %d batch data, time cost %.2f s/batch' %
                      (cnt, (time.time() - start_time) / 10))
                start_time = time.time()
            cnt += 1

            for conv in conv_data:
                metric2_data = {
                    'context': np.array(padding(conv['contexts'])),
                    'turn_length': np.array([len(conv['contexts'])], dtype=np.int32),
                    'reference': np.array(padding(conv['responses'])),
                    'gen': np.array(padding(conv['generations'], pad_go_id=True))
                }
                metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'a') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i, context in enumerate(res['context']):
                f.write("session: \t%d\n" % i)
                for j in range(len(context)):
                    f.write("\tpost:\t%s\n" % " ".join(context[j]))
                    f.write("\tresp:\t%s\n" % " ".join(res['reference'][i][j]))
                    f.write("\tgen:\t%s\n\n" % " ".join(res['gen'][i][j]))
                f.write("\n")

        print("result output to %s" % test_file)

    def test_multi_ref(self, sess, data, embed, args):
        """Score ``args.repeat_N`` sampled generations per instance against the reference candidates.

        Results are printed and written to ``<out_dir>/<name>_test.txt``
        (the file is overwritten).
        """

        def process_cands(candidates):
            """Map every word id outside the vocabulary to unk_id."""
            res = []
            for cands in candidates:
                tmp = []
                for sent in cands:
                    tmp.append([wid if wid < data.vocab_size else data.unk_id
                                for wid in sent])
                res.append(tmp)
            return res

        prec_rec_metrics = data.get_precision_recall_metric(embed)
        for batch_data in self.multi_reference_batches(data, args.batch_size):
            responses = []
            for _ in range(args.repeat_N):
                batched_responses_id = self.step_decoder(sess, batch_data,
                                                         inference=True)[0]
                for rid, resp in enumerate(batched_responses_id):
                    resp = list(resp)
                    if rid == len(responses):
                        responses.append([])
                    if data.eos_id in resp:
                        resp = resp[:resp.index(data.eos_id)]
                    # empty generations are not handed to the metric
                    if len(resp) > 0:
                        responses[rid].append(resp)
            metric_data = {
                'resp': process_cands(batch_data['candidate']),
                'gen': responses
            }
            prec_rec_metrics.forward(metric_data)

        res = prec_rec_metrics.close()

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Multi Reference Result:")
            f.write("Test Multi Reference Result:\n")
            for key, val in res.items():
                print("\t{}\t{}".format(key, val))
                f.write("\t{}\t{}".format(key, val) + "\n")
            f.write("\n")
class LM(object):
    """GRU language model (TF 1.x graph API) with loss-based hits@k evaluation.

    A single GRU cell, shared under the ``decoder`` variable scope, first
    reads the post, then scores the response with teacher forcing, and is
    also unrolled greedily for generation.
    """

    def __init__(self, data, args, embed):
        """Build the graph, the optimizer, the savers and the summaries.

        Arguments:
            data: data manager; ``vocab_size`` and ``eos_id`` are read.
            args: hyper-parameter namespace (``lr``, ``lr_decay``,
                ``eh_size``, ``dh_size``, ``softmax_samples``,
                ``max_sent_length``, ``grad_clip``, ``name``, ...).
            embed: initial value for the embedding table, or ``None`` to let
                TensorFlow initialize a ``[vocab_size, embedding_size]`` table.
        """
        self.posts = tf.placeholder(tf.int32, (None, None), 'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ), 'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None), 'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ), 'dec_lens')  # batch
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        # input drops the last token, target drops the first one
        self.responses_input = tf.split(self.origin_responses, [decoder_len - 1, 1], 1)[0]
        self.responses_target = tf.split(self.origin_responses, [1, decoder_len - 1], 1)[1]
        self.responses_length = self.origin_responses_length - 1
        decoder_len = decoder_len - 1

        # 1.0 for every position up to and including responses_length - 1
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True, axis=1),
            [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr), trainable=False, dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input)
        #self.decoder_input = tf.cond(self.is_train,
        #    lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8),
        #    lambda: tf.nn.embedding_lookup(self.embed, self.responses_input))

        # build rnn_cell
        cell = tf.nn.rnn_cell.GRUCell(args.eh_size)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # build encoder (shares the 'decoder/decoder_rnn' variables)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            _, self.encoder_state = dynamic_rnn(
                cell, self.encoder_input, self.posts_length,
                dtype=tf.float32, scope="decoder_rnn")

        # construct helper and attention
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.eos_id), data.eos_id)

        # Greedy decoding starts from a zero state when is_train is fed True
        # and from the post's final state otherwise.
        dec_start = tf.cond(
            self.is_train,
            lambda: tf.zeros([batch_size, args.dh_size], dtype=tf.float32),
            lambda: self.encoder_state)

        # build decoder (train)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            self.decoder_output, _ = dynamic_rnn(
                cell, self.decoder_input, self.responses_length,
                dtype=tf.float32, initial_state=self.encoder_state,
                scope='decoder_rnn')
            #self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8)
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
                sampled_sequence_loss(self.decoder_output, self.responses_target,
                                      self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer, impute_finished=True,
                maximum_iterations=args.max_sent_length, scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            # for removing UNK: the argmax skips the first two vocabulary
            # entries, then shifts the index back by 2.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2], 2)[1], 2) + 2

        # calculate the gradient of parameters and update
        self.params = [k for k in tf.trainable_variables() if args.name in k.name]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=1,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        """Save a checkpoint with the "latest" saver when ``key == "latest"``, else with the "best" saver.

        ``name`` is passed through as the saver's ``latest_filename``.
        """
        saver = self.latest_saver if key == "latest" else self.best_saver
        saver.save(sess, path, global_step=self.global_step, latest_filename=name)

    def create_summary(self, args):
        """Create the summary helper and the train/dev summary writers."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())),
            args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        self.devSummary = self.summaryHelper.addGroup(
            scalar=scalarlist, tensor=tensorlist, text=textlist, prefix="dev")

    def print_parameters(self):
        """Print the name and shape of every variable in ``self.params``."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one teacher-forcing step, with (default) or without the parameter update.

        Returns:
            (list): ``[loss, teacher-forcing distribution, decoder output]``
            when ``forward_only`` is set, else
            ``[loss, gradient norm, update result]``.
        """
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: True
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher,
                self.decoder_output
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Return ``[generated ids, teacher-forcing distribution, per-instance loss]`` with ``is_train`` fed False."""
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: False
        }
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Return the mean batch loss on split ``key_name`` as a shape-(1,) array."""
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        # ``loss`` has shape (1,); take the scalar explicitly because
        # %-formatting an array relies on deprecated NumPy conversion.
        print(' perplexity on %s set: %.2f' % (key_name, np.exp(loss).item()))
        print(loss)
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` steps the running loss is logged,
        the "latest" checkpoint is stored, dev is evaluated, the learning
        rate is decayed if the training loss exceeds the last three recorded
        ones, and the "best" checkpoint is stored when the dev loss improves.
        """
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        # Preview up to two training samples; a missing or short first batch
        # must not crash the run before training starts.
        if batched_data is not None:
            for i in range(min(2, len(batched_data['post_allvocabs']))):
                print(data.convert_ids_to_tokens(
                    batched_data['post_allvocabs'][i].tolist(), trim=False))
                print(data.convert_ids_to_tokens(
                    batched_data['resp_allvocabs'][i].tolist(), trim=False))

        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # The step counter only changes in step_decoder below, so a
                # single read serves the whole checkpoint branch.
                global_step = self.global_step.eval()
                if global_step % args.checkpoint_steps == 0 and global_step != 0:
                    summary_index = global_step // args.checkpoint_steps
                    show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a]))
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, global_step, self.learning_rate.eval(),
                           time_step, show(np.exp(loss_step))))
                    self.trainSummary(summary_index, {
                        'loss': loss_step,
                        'perplexity': np.exp(loss_step)
                    })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' % (args.model_dir, args.name),
                        "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size, "dev")
                    self.devSummary(summary_index, {
                        'loss': dev_loss,
                        'perplexity': np.exp(dev_loss)
                    })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' % (args.model_dir, args.name),
                            "best", args.name)

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):
        """Compute hits@1 / hits@3 by ranking the true response among distractors.

        For every test instance the true response (row 0) and the
        distractors read from ``<datapath>/test_distractors.json`` are
        scored against the same post; a hit means row 0 is among the k
        lowest-loss rows.

        Returns:
            (dict): ``{'hits@1': float, 'hits@3': float}``.
        """
        distractor_path = os.path.join(args.datapath, 'test_distractors.json')
        with open(distractor_path, 'r', encoding='utf-8') as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:
            # Row 0 is the gold response; the distractors follow it.
            batched_data['resp_length'] = [len(batched_data['resp'][0])]
            batched_data['resp'] = batched_data['resp'].tolist()
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.eos_id]
                    + data.convert_tokens_to_ids(jieba.lcut(each_resp))
                    + [data.eos_id])
                batched_data['resp_length'].append(len(batched_data['resp'][-1]))

            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

            # Repeat the post once per candidate (list repetition instead of
            # the quadratic repeated concatenation).
            batched_data['post'] = batched_data['post'].tolist() * len(resp)
            batched_data['post_length'] = batched_data['post_length'].tolist() * len(resp)

            _, _, loss = self.inference(sess, batched_data)
            loss_record.append(loss)
            cnt += 1

            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3}

    def test_process(self, sess, data, args):
        """Run teacher-forcing, inference and hits metrics on the test split.

        Float results are printed and, together with the reference and
        generated responses, written to ``<out_dir>/<name>_test.txt``.

        Returns:
            (dict): the metric entries whose value type is exactly
            ``bytes``, ``int`` or ``float``.
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size * 4, shuffle=False)
        batched_data = data.get_next_batch("test")

        # Preview up to three test samples; a missing or short first batch
        # must not crash the evaluation.
        if batched_data is not None:
            for i in range(min(3, len(batched_data['post_allvocabs']))):
                print(' post %d ' % i,
                      data.convert_ids_to_tokens(
                          batched_data['post_allvocabs'][i].tolist(), trim=False))
                print(' resp %d ' % i,
                      data.convert_ids_to_tokens(
                          batched_data['resp_allvocabs'][i].tolist(), trim=False))

        while batched_data is not None:
            batched_responses_id, gen_log_prob, _ = self.inference(sess, batched_data)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                # keep everything up to and including the first eos
                if data.eos_id in response_id_list:
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'post_allvocabs': np.array(batched_data['post_allvocabs']),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            res_print = list(res.items())
            res_print.sort(key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i in range(len(res['post'])):
                f.write("resp:\t%s\n" % " ".join(res['resp'][i]))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        print("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
class Classification(BaseModel):
    """Training / evaluation driver for the classification ``Network``.

    Builds the network, an Adam optimizer and a ``CheckpointManager``, hands
    them to ``BaseModel`` and registers train/dev/test summary groups.
    """

    def __init__(self, param):
        """Create network, optimizer, checkpoint manager and summaries.

        Args:
            param: container exposing ``args`` (hyper-parameters); the
                other methods additionally read ``param.volatile.dm``.
        """
        args = param.args
        net = Network(param)
        self.optimizer = optim.Adam(net.get_parameters_by_name(), lr=args.lr)
        optimizerList = {"optimizer": self.optimizer}
        checkpoint_manager = CheckpointManager(
            args.name, args.model_dir, args.checkpoint_steps,
            args.checkpoint_max_to_keep, "min")
        super().__init__(param, net, optimizerList, checkpoint_manager)
        self.create_summary()

    def create_summary(self):
        """Register the summary groups written during train / dev / test."""
        args = self.param.args
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())),
            args)
        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "accuracy_on_batch"],
            prefix="train")

        scalarlist = ["loss", "accuracy"]
        tensorlist = []
        textlist = []
        emblist = []
        # One text slot per sample batch that evaluate() will render.
        for i in self.args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(
            scalar=scalarlist,
            tensor=tensorlist,
            text=textlist,
            embedding=emblist,
            prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(
            scalar=scalarlist,
            tensor=tensorlist,
            text=textlist,
            embedding=emblist,
            prefix="test")

    def _preprocess_batch(self, data):
        """Wrap a raw batch in ``Storage`` and move its arrays to tensors.

        ``data.sent`` is transposed before conversion, so the network
        receives it as length * batch_size.
        """
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.sent.shape[0]
        data.sent = cuda(torch.LongTensor(data.sent.transpose(1, 0)))  # length * batch_size
        data.label = cuda(torch.LongTensor(data.label))
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        """Return the next preprocessed batch of ``key``.

        When the data manager is exhausted it is restarted once (if
        ``restart`` is true) and the read is retried; otherwise ``None``
        is returned.
        """
        data = dm.get_next_batch(key)
        if data is None:
            if restart:
                dm.restart(key)
                return self.get_next_batch(dm, key, False)
            return None
        return self._preprocess_batch(data)

    def get_batches(self, dm, key):
        """Return ``(batch_count, generator of preprocessed batches)``."""
        batches = list(
            dm.get_batches(key, batch_size=self.args.batch_size, shuffle=False))
        return len(batches), (self._preprocess_batch(data) for data in batches)

    def get_select_batch(self, dm, key, i):
        """Return the preprocessed batch selected by ``i``, or ``None``."""
        data = dm.get_batch(key, i)
        if data is None:
            return None
        return self._preprocess_batch(data)

    def train(self, batch_num):
        """Run ``batch_num`` training batches, logging loss and accuracy."""
        args = self.param.args
        dm = self.param.volatile.dm
        datakey = 'train'

        for i in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, datakey)
            incoming.args = Storage()

            # NOTE(review): assuming zero_grad() clears accumulated gradients,
            # they are cleared on the very iteration that later calls
            # optimizer.step(), so backward() results of the earlier
            # iterations in the window are dropped before the step.
            # Confirm this is intended when batch_num_per_gradient > 1.
            if (i + 1) % args.batch_num_per_gradient == 0:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            accuracy = np.mean(
                (incoming.result.label == incoming.result.prediction
                 ).float().detach().cpu().numpy())
            detail_arr = storage_to_list(incoming.result)
            detail_arr.update({'accuracy_on_batch': accuracy})
            self.trainSummary(self.now_batch, detail_arr)
            logging.info("batch %d : classification loss=%f, batch accuracy=%f",
                         self.now_batch, loss.detach().cpu().numpy(), accuracy)

            loss.backward()

            if (i + 1) % args.batch_num_per_gradient == 0:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        """Evaluate on split ``key``.

        Returns:
            A ``Storage`` holding the mean ``loss``, the ``accuracy`` and
            one ``show_str<i>`` entry per index in ``args.show_sample``.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        dm.restart(key, args.batch_size, shuffle=False)

        result_arr = []
        while True:
            incoming = self.get_next_batch(dm, key, restart=False)
            if incoming is None:
                break
            incoming.args = Storage()

            with torch.no_grad():
                self.net.forward(incoming)
            result_arr.append(incoming.result)

        detail_arr = Storage()
        for i in args.show_sample:
            # Indices of the i-th batch-sized slice of the split.
            index = [i * args.batch_size + j for j in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, index)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            detail_arr["show_str%d" % i] = incoming.result.show_str

        detail_arr.update({
            'loss': get_mean(result_arr, 'loss'),
            'accuracy': get_accuracy(result_arr, label_key='label',
                                     prediction_key='prediction'),
        })
        return detail_arr

    def train_process(self):
        """Train until ``args.epochs``; evaluate and checkpoint each epoch."""
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            devloss_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, devloss_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            testloss_detail = self.evaluate("test")
            self.testSummary(self.now_batch, testloss_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            # The dev loss is the value handed to the checkpoint manager.
            self.save_checkpoint(value=devloss_detail.loss.tolist())

    def test(self, key):
        """Compute the accuracy metric on split ``key`` and write it out.

        Results go to ``<args.out_dir>/<args.name>_<key>.txt`` and to the
        log; only ``float`` and ``bytes`` metric values are reported.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        metric = dm.get_accuracy_metric()
        batch_num, batches = self.get_batches(dm, key)
        logging.info("eval accuracy")
        for incoming in tqdm.tqdm(batches, total=batch_num):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)

                data = incoming.data
                # Fixed: these two lines referenced the undefined name
                # ``imcoming`` and raised NameError on the first batch.
                data.prediction = incoming.result.prediction
                data.label = incoming.data.label
                metric.forward(data)
        res = metric.close()

        os.makedirs(args.out_dir, exist_ok=True)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as f:
            logging.info("%s Test Result:", key)
            # A distinct loop name keeps the ``key`` parameter intact.
            for name, value in res.items():
                if isinstance(value, (float, bytes)):
                    logging.info("\t{}:\t{}".format(name, value))
                    f.write("{}:\t{}\n".format(name, value))
            f.flush()
        logging.info("result output to %s.", filename)

    def test_process(self):
        """Run :meth:`test` on the dev and test splits in eval mode."""
        logging.info("Test Start.")
        self.net.eval()
        self.test("dev")
        self.test("test")
        logging.info("Test Finish.")
class Seq2seq(BaseModel):
    """Training / evaluation driver for the sequence-to-sequence ``Network``."""

    def __init__(self, param):
        """Assemble network, Adam optimizer, checkpoint manager, summaries."""
        hparams = param.args
        network = Network(param)
        self.optimizer = optim.Adam(network.get_parameters_by_name(),
                                    lr=hparams.lr)
        manager = CheckpointManager(hparams.name, hparams.model_dir,
                                    hparams.checkpoint_steps,
                                    hparams.checkpoint_max_to_keep, "min")
        super().__init__(param, network, {"optimizer": self.optimizer}, manager)
        self.create_summary()

    def create_summary(self):
        """Create the train, dev and test summary groups."""
        args = self.param.args
        run_dir = "%s/%s_%s" % (args.log_dir, args.name,
                                time.strftime("%H%M%S", time.localtime()))
        self.summaryHelper = SummaryHelper(run_dir, args)
        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "word_loss", "perplexity"], prefix="train")

        # dev and test receive the very same list objects, as before.
        eval_spec = {
            "scalar": ["word_loss", "perplexity_avg_on_batch"],
            "tensor": [],
            "text": ["show_str%d" % i for i in self.args.show_sample],
            "embedding": [],
        }
        self.devSummary = self.summaryHelper.addGroup(prefix="dev", **eval_spec)
        self.testSummary = self.summaryHelper.addGroup(prefix="test", **eval_spec)

    def _preprocess_batch(self, data):
        """Wrap a raw batch in ``Storage`` and tensorise its id matrices.

        Each of ``post``, ``resp``, ``post_bert`` and ``resp_bert`` is
        transposed (to length * batch_size) and moved through ``cuda``.
        """
        incoming = Storage()
        incoming.data = data = Storage(data)
        data.batch_size = data.post.shape[0]
        for field in ("post", "resp", "post_bert", "resp_bert"):
            matrix = getattr(data, field)
            # length * batch_size
            setattr(data, field, cuda(torch.LongTensor(matrix.transpose(1, 0))))
        return incoming

    def get_next_batch(self, dm, key, restart=True):
        """Fetch the next batch; optionally restart once when exhausted."""
        raw = dm.get_next_batch(key)
        if raw is not None:
            return self._preprocess_batch(raw)
        if not restart:
            return None
        dm.restart(key)
        return self.get_next_batch(dm, key, False)

    def get_batches(self, dm, key):
        """Return the number of batches and a lazy stream of them."""
        raw_batches = list(
            dm.get_batches(key, batch_size=self.args.batch_size, shuffle=False))
        stream = (self._preprocess_batch(raw) for raw in raw_batches)
        return len(raw_batches), stream

    def get_select_batch(self, dm, key, i):
        """Return the batch picked by ``i``, preprocessed, or ``None``."""
        raw = dm.get_batch(key, i)
        return None if raw is None else self._preprocess_batch(raw)

    def train(self, batch_num):
        """Train on ``batch_num`` batches of the 'train' split."""
        args = self.param.args
        dm = self.param.volatile.dm

        for step in range(batch_num):
            self.now_batch += 1
            incoming = self.get_next_batch(dm, 'train')
            incoming.args = Storage()

            # True on the iterations that end an accumulation window.
            boundary = (step + 1) % args.batch_num_per_gradient == 0
            if boundary:
                self.zero_grad()
            self.net.forward(incoming)

            loss = incoming.result.loss
            self.trainSummary(self.now_batch, storage_to_list(incoming.result))
            logging.info("batch %d : gen loss=%f", self.now_batch,
                         loss.detach().cpu().numpy())

            loss.backward()

            if boundary:
                nn.utils.clip_grad_norm_(self.net.parameters(), args.grad_clip)
                self.optimizer.step()

    def evaluate(self, key):
        """Evaluate on split ``key`` and collect averaged results.

        Returns:
            ``Storage`` with the mean of every result field, the rendered
            ``show_str<i>`` samples and ``perplexity_avg_on_batch``
            (``exp`` of the mean ``word_loss``).
        """
        args = self.param.args
        dm = self.param.volatile.dm
        dm.restart(key, args.batch_size, shuffle=False)

        collected = []
        incoming = self.get_next_batch(dm, key, restart=False)
        while incoming is not None:
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
            collected.append(incoming.result)
            incoming = self.get_next_batch(dm, key, restart=False)

        summary = Storage()
        for sample_id in args.show_sample:
            first = sample_id * args.batch_size
            picks = [first + offset for offset in range(args.batch_size)]
            incoming = self.get_select_batch(dm, key, picks)
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            summary["show_str%d" % sample_id] = incoming.result.show_str

        averages = {}
        for field in collected[0]:
            averages[field] = get_mean(collected, field)
        summary.update(averages)
        summary.perplexity_avg_on_batch = np.exp(summary.word_loss)
        return summary

    def train_process(self):
        """Epoch loop: train, evaluate dev and test, then checkpoint."""
        args = self.param.args
        dm = self.param.volatile.dm

        while self.now_epoch < args.epochs:
            self.now_epoch += 1
            self.updateOtherWeights()

            dm.restart('train', args.batch_size)
            self.net.train()
            self.train(args.batch_per_epoch)

            self.net.eval()
            dev_detail = self.evaluate("dev")
            self.devSummary(self.now_batch, dev_detail)
            logging.info("epoch %d, evaluate dev", self.now_epoch)

            test_detail = self.evaluate("test")
            self.testSummary(self.now_batch, test_detail)
            logging.info("epoch %d, evaluate test", self.now_epoch)

            self.save_checkpoint(value=dev_detail.loss.tolist())

    def test(self, key):
        """Run teacher-forcing and free-run metrics on ``key``; save results.

        Output is written to ``<args.out_dir>/<args.name>_<key>.txt``:
        first every ``float`` / ``bytes`` metric, then the post, reference
        and generated sentences.
        """
        args = self.param.args
        dm = self.param.volatile.dm

        # Pass 1: teacher forcing.
        tf_metric = dm.get_teacher_forcing_metric()
        total, stream = self.get_batches(dm, key)
        logging.info("eval teacher-forcing")
        for incoming in tqdm.tqdm(stream, total=total):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.forward(incoming)
                log_prob = nn.functional.log_softmax(incoming.gen.w, -1)
            batch = incoming.data
            batch.resp = incoming.data.resp_allvocabs
            batch.resp_length = incoming.data.resp_length
            batch.gen_log_prob = log_prob.transpose(1, 0).detach().cpu().numpy()
            tf_metric.forward(batch)
        res = tf_metric.close()

        # Pass 2: free-run generation.
        gen_metric = dm.get_inference_metric()
        total, stream = self.get_batches(dm, key)
        logging.info("eval free-run")
        for incoming in tqdm.tqdm(stream, total=total):
            incoming.args = Storage()
            with torch.no_grad():
                self.net.detail_forward(incoming)
            batch = incoming.data
            batch.resp = incoming.data.resp_allvocabs
            batch.post = incoming.data.post_allvocabs
            batch.gen = incoming.gen.w_o.detach().cpu().numpy().transpose(1, 0)
            gen_metric.forward(batch)
        res.update(gen_metric.close())

        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
        filename = args.out_dir + "/%s_%s.txt" % (args.name, key)

        with open(filename, 'w') as out:
            logging.info("%s Test Result:", key)
            for metric_name, value in res.items():
                if isinstance(value, (float, bytes)):
                    logging.info("\t{}:\t{}".format(metric_name, value))
                    out.write("{}:\t{}\n".format(metric_name, value))
            for idx, post in enumerate(res['post']):
                out.write("post:\t%s\n" % " ".join(post))
                out.write("resp:\t%s\n" % " ".join(res['resp'][idx]))
                out.write("gen:\t%s\n" % " ".join(res['gen'][idx]))
            out.flush()
        logging.info("result output to %s.", filename)

    def test_process(self):
        """Switch to eval mode and test on the dev and test splits."""
        logging.info("Test Start.")
        self.net.eval()
        for split in ("dev", "test"):
            self.test(split)
        logging.info("Test Finish.")
class HredModel(object):
    """TensorFlow 1.x graph and training loops for a multi-turn dialogue model.

    The graph has a GRU utterance encoder, a GRU context cell stepped once
    per turn (its state is fed in and read back through ``init_states`` /
    ``context_state``) and a Luong-attention GRU decoder.
    """

    def __init__(self, data, args, embed):
        """Build placeholders, the graph, the optimizer and the savers.

        Args:
            data: data loader; ``vocab_size``, ``go_id`` and ``eos_id``
                are read here.
            args: hyper-parameters (sizes, ``lr``, ``lr_decay``,
                ``grad_clip``, checkpoint and logging settings).
            embed: initial embedding matrix, or ``None`` to create a
                ``vocab_size * embedding_size`` variable instead.
        """
        self.init_states = tf.placeholder(tf.float32, (None, args.ch_size),
                                          'ctx_inps')  # batch*ch_size
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        # 1 for every position up to and including the last target token.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True, axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr), trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed', dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        with tf.variable_scope('context'):
            _, self.context_state = cell_ctx(encoder_state, self.init_states)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        # Lengths are clamped to >= 1 so empty turns still decode one step.
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, tf.maximum(self.responses_length, 1))
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        attn_mechanism = tf.contrib.seq2seq.LuongAttention(
            args.dh_size, encoder_output,
            memory_sequence_length=tf.maximum(self.posts_length, 1))
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        ctx_state_shaping = tf.layers.dense(self.context_state, args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping)

        # build decoder (train)
        # NOTE(review): the source had lost its indentation; every statement
        # up to the next "build decoder" comment is assumed to sit inside
        # this variable scope - confirm against a known-good checkpoint.
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            self.decoder_distribution_teacher, self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer, impute_finished=True,
                maximum_iterations=args.max_sen_length, scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            # argmax over ids >= 2 only, then shift back.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        """Save with the 'latest' saver if ``key == "latest"``, else 'best'."""
        if key == "latest":
            self.latest_saver.save(sess, path, global_step=self.global_step)
        else:
            self.best_saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        """Register train / dev / test summary groups (loss, perplexity)."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())),
            args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = []
        emblist = []
        for i in args.show_sample:
            textlist.append("show_str%d" % i)
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        """Print the name and shape of every trained variable."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, sess, data, forward_only=False, inference=False):
        """Run one turn through the graph.

        Returns (in this order):
            * ``inference``: generated ids, new context state;
            * ``forward_only``: loss, teacher-forced distribution, new
              context state;
            * otherwise: loss, gradient norm, update op result, new
              context state.
        """
        input_feed = {
            self.init_states: data['init_states'],
            self.posts: data['posts'],
            self.posts_length: data['posts_length'],
            self.origin_responses: data['responses'],
            self.origin_responses_length: data['responses_length'],
        }
        if inference:
            output_feed = [self.generation_index, self.context_state]
        elif forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher,
                self.context_state
            ]
        else:
            output_feed = [
                self.decoder_loss, self.gradient_norm, self.update,
                self.context_state
            ]
        return sess.run(output_feed, input_feed)

    def get_step_data(self, step_data, batched_data, turn):
        """Fill ``step_data`` with the (post, response) pair of one turn.

        Turn ``turn`` is the post and turn ``turn + 1`` the response.
        ``turn == -1`` yields a single all-zero post token of length 1.
        Sessions with fewer turns get length 0. Both matrices are trimmed
        to the longest length present in the batch.

        Args:
            step_data: dict updated in place (``posts``, ``responses``,
                ``posts_length``, ``responses_length``).
            batched_data: dict with ``sent`` (indexed batch, turn, token)
                and ``sent_length`` (per-session list of turn lengths).
            turn: index of the post turn, or -1.
        """
        current_batch_size = batched_data['sent'].shape[0]
        if turn == -1:
            step_data['posts'] = np.zeros((current_batch_size, 1), dtype=int)
        else:
            step_data['posts'] = batched_data['sent'][:, turn, :]
        step_data['responses'] = batched_data['sent'][:, turn + 1, :]
        step_data['posts_length'] = np.zeros((current_batch_size, ), dtype=int)
        step_data['responses_length'] = np.zeros((current_batch_size, ),
                                                 dtype=int)
        for i in range(current_batch_size):
            turn_lengths = batched_data['sent_length'][i]
            if turn < len(turn_lengths):
                if turn == -1:
                    step_data['posts_length'][i] = 1
                else:
                    step_data['posts_length'][i] = turn_lengths[turn]
            if turn + 1 < len(turn_lengths):
                step_data['responses_length'][i] = turn_lengths[turn + 1]
        max_posts_length = np.max(step_data['posts_length'])
        max_responses_length = np.max(step_data['responses_length'])
        step_data['posts'] = step_data['posts'][:, 0:max_posts_length]
        step_data['responses'] = step_data['responses'][:, 0:max_responses_length]

    def train_step(self, sess, data, args):
        """Train on every turn of one batch; return the per-token loss."""
        current_batch_size = data['sent'].shape[0]
        max_turn_length = data['sent'].shape[1]
        loss = np.zeros((1, ))
        total_length = np.zeros((1, ))
        step_data = {}
        context_states = np.zeros((current_batch_size, args.ch_size))
        for turn in range(max_turn_length - 1):
            self.get_step_data(step_data, data, turn)
            step_data['init_states'] = context_states
            decoder_loss, _, _, context_states = self.step_decoder(
                sess, step_data)
            # Weight each turn's loss by its number of target tokens.
            length = np.sum(np.maximum(step_data['responses_length'] - 1, 0))
            total_length += length
            loss += decoder_loss * length
        return loss / total_length

    def evaluate(self, sess, data, key_name, args):
        """Return the token-weighted mean loss on split ``key_name``."""
        loss = np.zeros((1, ))
        total_length = np.zeros((1, ))
        data.restart(key_name, batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            current_batch_size = batched_data['sent'].shape[0]
            max_turn_length = batched_data['sent'].shape[1]
            step_data = {}
            context_states = np.zeros((current_batch_size, args.ch_size))
            for turn in range(max_turn_length - 1):
                self.get_step_data(step_data, batched_data, turn)
                step_data['init_states'] = context_states
                decoder_loss, _, context_states = self.step_decoder(
                    sess, step_data, forward_only=True)
                length = np.sum(
                    np.maximum(step_data['responses_length'] - 1, 0))
                total_length += length
                loss += decoder_loss * length
            batched_data = data.get_next_batch(key_name)

        loss /= total_length
        print('    perplexity on %s set: %.2f' % (key_name, np.exp(loss)))
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        After every window of at most ``args.checkpoint_steps`` batches the
        mean loss is logged, the 'latest' checkpoint is stored, dev and
        test are evaluated, the learning rate is decayed when the window
        loss exceeds all of the last five, and the 'best' checkpoint is
        stored on a new lowest dev loss.
        """
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                steps_run = 0
                for _ in range(args.checkpoint_steps):
                    if batched_data is None:
                        break
                    start_time = time.time()
                    loss_step += self.train_step(sess, batched_data, args)
                    time_step += time.time() - start_time
                    batched_data = data.get_next_batch("train")
                    steps_run += 1

                # Average over the steps that really ran: the last window
                # of an epoch can be shorter than checkpoint_steps, and
                # dividing by the nominal size under-reported its loss.
                loss_step /= steps_run
                time_step /= steps_run
                perplexity_str = '[%s]' % (' '.join(
                    ['%.2f' % x for x in np.exp(loss_step)]))
                print(
                    "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                    % (epoch_step, self.global_step.eval(),
                       self.learning_rate.eval(), time_step, perplexity_str))
                self.trainSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': loss_step,
                        'perplexity': np.exp(loss_step)
                    })
                self.store_checkpoint(
                    sess, '%s/checkpoint_latest/checkpoint' % args.model_dir,
                    "latest")

                dev_loss = self.evaluate(sess, data, "dev", args)
                self.devSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': dev_loss,
                        'perplexity': np.exp(dev_loss)
                    })

                test_loss = self.evaluate(sess, data, "test", args)
                self.testSummary(
                    self.global_step.eval() // args.checkpoint_steps, {
                        'loss': test_loss,
                        'perplexity': np.exp(test_loss)
                    })

                if np.sum(loss_step) > max(previous_losses):
                    sess.run(self.learning_rate_decay_op)
                if dev_loss < best_valid:
                    best_valid = dev_loss
                    self.store_checkpoint(
                        sess, '%s/checkpoint_best/checkpoint' % args.model_dir,
                        "best")

                previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                loss_step, time_step = np.zeros((1, )), .0

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        """Score the test split and write results to a text file.

        Every turn is decoded twice from the same context state: once
        teacher-forced (for the probability metric) and once in inference
        mode (for the generation metric). Output goes to
        ``<args.out_dir>/<args.name>_test.txt``.
        """

        def get_batch_results(batched_responses_id, data):
            """Cut each generated id sequence after its first ``eos_id``."""
            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    result_id = response_id_list[:response_id_list.
                                                 index(data.eos_id) + 1]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)
            return batch_results

        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        cnt = 0
        start_time = time.time()
        while batched_data is not None:
            current_batch_size = batched_data['sent'].shape[0]
            max_turn_length = batched_data['sent'].shape[1]
            if cnt > 0 and cnt % 10 == 0:
                print('processing %d batch data, time cost %.2f s/batch' %
                      (cnt, (time.time() - start_time) / 10))
                start_time = time.time()
            cnt += 1
            step_data = {}
            context_states = np.zeros((current_batch_size, args.ch_size))
            batched_gen_prob = []
            batched_gen = []
            for turn in range(max_turn_length):
                # turn - 1 == -1 makes the first real turn the response.
                self.get_step_data(step_data, batched_data, turn - 1)
                step_data['init_states'] = context_states
                _, gen_prob, context_states = self.step_decoder(
                    sess, step_data, forward_only=True)
                batched_responses_id, context_states = self.step_decoder(
                    sess, step_data, inference=True)
                batch_results = get_batch_results(batched_responses_id, data)
                step_data['gen_prob'] = gen_prob
                batched_gen_prob.append(step_data['gen_prob'])
                step_data['generations'] = batch_results
                batched_gen.append(step_data['generations'])

            def transpose(batched_gen_prob):
                """Reorder [turn][batch] to [batch][turn] in place, then
                extend every entry by a copy of its last element."""
                batched_gen_prob_temp = [[0 for i in range(max_turn_length)]
                                         for j in range(current_batch_size)]
                for i in range(max_turn_length):
                    for j in range(current_batch_size):
                        batched_gen_prob_temp[j][i] = batched_gen_prob[i][j]
                batched_gen_prob[:] = batched_gen_prob_temp[:]
                # NOTE(review): the source had lost its indentation; this
                # loop is placed inside the helper because its
                # [batch][turn] indexing only holds after the reorder -
                # confirm.
                for i in range(current_batch_size):
                    for j in range(max_turn_length):
                        batched_gen_prob[i][j] = np.concatenate(
                            (batched_gen_prob[i][j],
                             [batched_gen_prob[i][j][-1]]),
                            axis=0)

            transpose(batched_gen_prob)
            transpose(batched_gen)

            metric1_data = {
                'sent_allvocabs': batched_data['sent_allvocabs'],
                'sent_length': batched_data['sent_length'],
                'gen_log_prob': batched_gen_prob,
            }
            metric1.forward(metric1_data)
            metric2_data = {
                'context_allvocabs': [],
                'reference_allvocabs': batched_data['sent_allvocabs'],
                'turn_length': batched_data['turn_length'],
                'gen': batched_gen,
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w') as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for i in range(len(res['context'])):
                f.write("batch number:\t%d\n" % i)
                for j in range(
                        min(len(res['context'][i]), len(res['reference'][i]))):
                    # Blank line whenever the post is not the previous
                    # reference, i.e. the dialogue is not contiguous.
                    if j > 0 and " ".join(res['context'][i][j]) != " ".join(
                            res['reference'][i][j - 1]):
                        f.write("\n")
                    f.write("post:\t%s\n" % " ".join(res['context'][i][j]))
                    f.write("resp:\t%s\n" % " ".join(res['reference'][i][j]))
                    if j < len(res['gen'][i]):
                        f.write("gen:\t%s\n" % " ".join(res['gen'][i][j]))
                    else:
                        f.write("gen:\n")

        print("result output to %s" % test_file)
class Seq2SeqModel(object):
    """GRU encoder-decoder with Luong attention, built as a TensorFlow 1.x graph.

    The constructor builds the whole graph (placeholders, encoder, teacher-forced
    decoder, greedy inference decoder, optimizer, savers and summaries).  The
    remaining methods run that graph for training, evaluation and testing.
    """

    def __init__(self, data, args, embed):
        """Build the computation graph.

        Args:
            data: dataloader; ``vocab_size``, ``go_id`` and ``eos_id`` are read here.
            args: hyper-parameter namespace (``lr``, ``lr_decay``, ``eh_size``,
                ``dh_size``, ``embedding_size``, ``softmax_samples``,
                ``max_sen_length``, ``grad_clip``, ``name``,
                ``checkpoint_max_to_keep``, plus what ``create_summary`` reads).
            embed: initial embedding matrix, or ``None`` to let TensorFlow
                initialise a ``[vocab_size, embedding_size]`` table.
        """
        self.posts = tf.placeholder(tf.int32, (None, None), 'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ), 'enc_lens')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None), 'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ), 'dec_lens')  # batch

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses,
                                  [1, decoder_len - 1], 1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        self.posts_input = self.posts  # batch*len
        # Reverse cumulative sum of a one-hot at the last valid position gives
        # 1 for every position up to and including it, 0 afterwards.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        self.encoder_input = tf.nn.embedding_lookup(
            self.embed, self.posts_input)  # batch*len*unit
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = dynamic_rnn(cell_enc,
                                                        self.encoder_input,
                                                        self.posts_length,
                                                        dtype=tf.float32,
                                                        scope="encoder_rnn")

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        attn_mechanism = tf.contrib.seq2seq.LuongAttention(
            args.dh_size,
            encoder_output,
            memory_sequence_length=self.posts_length)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # Project the final encoder state to the decoder's hidden size so it
        # can seed the decoder cell state.
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            self.decoder_distribution_teacher, self.decoder_loss = sampled_sequence_loss(
                self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_sen_length,
                scope="decoder_rnn")
            self.decoder_distribution = infer_outputs.rnn_output
            # Drop the first two vocabulary columns before the argmax and shift
            # the index back by 2, so ids 0 and 1 are never generated.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.vocab_size - 2], 2)[1],
                2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        gradients = tf.gradients(self.decoder_loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key):
        """Save a checkpoint with the "latest" saver, or the "best" saver otherwise."""
        saver = self.latest_saver if key == "latest" else self.best_saver
        saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        """Create the train/dev/test summary groups used for TensorBoard logging."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        emblist = []
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      embedding=emblist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       embedding=emblist,
                                                       prefix="test")

    def print_parameters(self):
        """Print the name and static shape of every trained variable."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_decoder(self, session, data, forward_only=False):
        """Run one batch through the teacher-forced decoder.

        Args:
            session: object whose ``run(fetches, feed_dict)`` executes the graph.
            data: batch dict with ``post``, ``post_length``, ``resp`` and
                ``resp_length``.
            forward_only: when true, fetch only the loss and the teacher-forced
                distribution; otherwise also apply one optimizer update.

        Returns:
            ``[loss, distribution]`` if ``forward_only`` else
            ``[loss, gradient_norm, update_result]``.
        """
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length']
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Greedy-decode one batch; returns a one-element list ``[generation_index]``."""
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length']
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Return the mean per-batch loss over the ``key_name`` split.

        The result is a NumPy array of shape ``(1,)``; its perplexity is printed.
        """
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        print(' perplexity on %s set: %.2f' % (key_name, np.exp(loss)))
        print(loss)
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` global steps this logs the running
        training loss, saves the "latest" checkpoint, evaluates on dev and test,
        decays the learning rate when the training loss exceeds all of the last
        five recorded losses, and saves the "best" checkpoint when the dev loss
        improves.
        """
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 5
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # Read the step counter once; nothing below changes it until
                # the next training step.
                global_step = self.global_step.eval()
                if global_step % args.checkpoint_steps == 0 and global_step != 0:
                    perplexity_str = '[%s]' % (' '.join(
                        ['%.2f' % x for x in np.exp(loss_step)]))
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, global_step,
                           self.learning_rate.eval(), time_step,
                           perplexity_str))
                    summary_step = global_step // args.checkpoint_steps
                    self.trainSummary(summary_step, {
                        'loss': loss_step,
                        'perplexity': np.exp(loss_step)
                    })
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(summary_step, {
                        'loss': dev_loss,
                        'perplexity': np.exp(dev_loss)
                    })

                    test_loss = self.evaluate(sess, data, args.batch_size,
                                              "test")
                    self.testSummary(summary_step, {
                        'loss': test_loss,
                        'perplexity': np.exp(test_loss)
                    })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [
                        np.sum(loss_step)
                    ]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        """Run the test split, feed both metrics and write a result file.

        Generated id sequences are truncated just after the first ``eos_id``
        before being passed to the inference metric.  Float-valued metrics are
        printed and written to ``<args.out_dir>/<args.name>_test.txt``, followed
        by each post / reference / generation triple.
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        while batched_data is not None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            _, gen_prob = self.step_decoder(sess,
                                            batched_data,
                                            forward_only=True)
            metric1_data = {
                'resp': np.array(batched_data['resp']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_prob': np.array(gen_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep the eos token itself, drop everything after it
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'post': np.array(batched_data['post']),
                'resp': np.array(batched_data['resp']),
                'gen': np.array(batch_results)
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        # Tokens may be non-ASCII, so do not rely on the platform's default
        # encoding.
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for post, resp, gen in zip(res['post'], res['resp'], res['gen']):
                f.write("post:\t%s\n" % " ".join(post))
                f.write("resp:\t%s\n" % " ".join(resp))
                f.write("gen:\t%s\n" % " ".join(gen))

        print("result output to %s." % test_file)
class Seq2SeqModel(object):
    """Multi-turn GRU encoder-decoder whose attention is masked to the last turn.

    The encoder reads the whole dialogue history; the decoder's attention only
    sees encoder positions that lie between ``prevs_length`` and
    ``posts_length``.  Built as a TensorFlow 1.x graph in the constructor; the
    other methods run that graph for training, evaluation, testing and hits@k.
    """

    def __init__(self, data, args, embed):
        """Build the computation graph.

        Args:
            data: dataloader; ``vocab_size``, ``go_id`` and ``eos_id`` are read here.
            args: hyper-parameter namespace (``lr``, ``lr_decay``, ``eh_size``,
                ``dh_size``, ``embedding_size``, ``softmax_samples``,
                ``max_decoder_length``, ``grad_clip``, ``name``,
                ``checkpoint_max_to_keep``, plus what ``create_summary`` reads).
            embed: initial embedding matrix, or ``None`` to let TensorFlow
                initialise a ``[vocab_size, embedding_size]`` table.
        """
        # posts: encoder input, i.e. the dialogue history, [batch, encoder_len]
        # posts_length: actual length of each encoder input, [batch]
        # prevs_length: length of the turns before the last one
        #               (including <go> and <eos>), [batch]
        self.posts = tf.placeholder(tf.int32, (None, None), 'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ), 'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ), 'enc_lens_prev')  # batch
        # origin_responses: the reply, [batch, resp_len]
        # origin_responses_length: actual length of each reply, [batch]
        self.origin_responses = tf.placeholder(tf.int32, (None, None), 'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ), 'dec_lens')  # batch
        # NOTE: is_train is fed by step_decoder/inference but no op consumes
        # it; the embedding/output dropout it was meant to gate is disabled.
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        # Split the reply so that the leading go_id is removed.
        self.responses = tf.split(self.origin_responses,
                                  [1, decoder_len - 1], 1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        # Decoder input drops the trailing eos_id and decoder target drops the
        # leading go_id, which keeps the two aligned.
        # [batch, decoder_len] (decoder_len here equals resp_len - 1)
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1], 1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # encoder input, [batch, encoder_len]
        self.posts_input = self.posts  # batch*len
        # Decoder mask, [batch, decoder_len]: the reverse cumulative sum of a
        # one-hot at the last valid position is 1 up to it and 0 afterwards.
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        # Turn encoder and decoder inputs into word vectors.
        # encoder_input: [batch, encoder_len, embed_size]
        # decoder_input: [batch, decoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            # encoder_output: [batch, encoder_len, eh_size]
            # encoder_state: [batch, eh_size]
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        encoder_len = tf.shape(encoder_output)[1]
        # Masks over encoder positions for the full history and for the turns
        # before the last one.
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        # XOR is True where the two masks differ and False where they agree,
        # so only the positions of the last turn are kept. [batch, encoder_len]
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        # At inference time every sequence starts from go_id.
        infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            self.embed, tf.fill([batch_size], data.go_id), data.eos_id)
        # The encoder encodes all turns, but the decoder only attends to the
        # last turn through attention_mask.
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        # Project the final encoder state to the decoder hidden size.
        # [batch, dh_size]
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            # Loss and distribution:
            # decoder_distribution_teacher: [batch, decoder_length, vocab_size]
            #     (log probabilities)
            # decoder_loss: loss over all words of the batch, scalar
            # decoder_all_loss: loss of each sentence, [batch]
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
                sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            # output_fn reuses the projection weights and bias created above.
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # [batch, max_decoder_len, vocab_size]
            self.decoder_distribution = infer_outputs.rnn_output
            # Ignore the first two vocabulary entries (<pad> and <unk>) when
            # taking the argmax, then shift the index back by 2.
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.vocab_size - 2], 2)[1],
                2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)  # optimizer
        gradients = tf.gradients(self.decoder_loss, self.params)  # gradients
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)  # gradient clipping
        self.update = opt.apply_gradients(
            zip(clipped_gradients, self.params),
            global_step=self.global_step)  # parameter update

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    def store_checkpoint(self, sess, path, key, name):
        """Save with the "latest" saver, or the "best" saver otherwise.

        ``name`` is passed through as the saver's ``latest_filename``.
        """
        saver = self.latest_saver if key == "latest" else self.best_saver
        saver.save(sess,
                   path,
                   global_step=self.global_step,
                   latest_filename=name)

    def create_summary(self, args):
        """Create the train and dev summary groups used for TensorBoard logging."""
        self.summaryHelper = SummaryHelper(
            "%s/%s_%s" % (args.log_dir, args.name,
                          time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")

    def print_parameters(self):
        """Log the name and static shape of every trained variable."""
        for item in self.params:
            logger.info('%s: %s' % (item.name, item.get_shape()))

    def _log_samples(self, data, batched_data, count):
        """Log up to ``count`` post/response pairs of a batch; no-op for ``None``."""
        if batched_data is None:
            return
        for i in range(min(count, len(batched_data['post_allvocabs']))):
            logger.info(
                f"post@{i}: {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
            )
            logger.info(
                f"resp@{i}: {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
            )

    def step_decoder(self, session, data, forward_only=False):
        """Run one batch through the teacher-forced decoder.

        Args:
            session: object whose ``run(fetches, feed_dict)`` executes the graph.
            data: batch dict with ``post``, ``post_length``, ``prev_length``,
                ``resp`` and ``resp_length``.
            forward_only: when true, fetch only the loss and the teacher-forced
                distribution; otherwise also apply one optimizer update.

        Returns:
            ``[loss, distribution]`` if ``forward_only`` else
            ``[loss, gradient_norm, update_result]``.
        """
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            # A forward-only pass is evaluation, not training.
            self.is_train: not forward_only
        }
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher
            ]
        else:
            output_feed = [self.decoder_loss, self.gradient_norm, self.update]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Greedy-decode one batch.

        Returns:
            ``[generation_index, teacher_forced_distribution, per_sentence_loss]``.
        """
        input_feed = {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.is_train: False
        }
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Return the mean per-batch loss over the ``key_name`` split.

        The result is a NumPy array of shape ``(1,)``; loss and perplexity are
        logged.
        """
        loss = np.zeros((1, ))
        times = 0
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_decoder(sess, batched_data, forward_only=True)
            loss += outputs[0]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss /= times

        logger.info(
            f'Evaluate loss: {float(loss):.2f} | perplexity on {key_name} set: {float(np.exp(loss)): .2f}'
        )
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` global steps this logs the running
        training loss, saves the "latest" checkpoint, evaluates on dev, decays
        the learning rate when the training loss exceeds all of the last three
        recorded losses, and saves the "best" checkpoint when the dev loss
        improves.
        """
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        self._log_samples(data, batched_data, 2)

        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # Read the step counter once; nothing below changes it until
                # the next training step.
                global_step = self.global_step.eval()
                if global_step % args.checkpoint_steps == 0 and global_step != 0:
                    perplexity_str = '[%s]' % (' '.join(
                        ['%.2f' % x for x in np.exp(loss_step)]))
                    logger.info(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s"
                        % (epoch_step, global_step,
                           self.learning_rate.eval(), time_step,
                           perplexity_str))
                    summary_step = global_step // args.checkpoint_steps
                    self.trainSummary(summary_step, {
                        'loss': loss_step,
                        'perplexity': np.exp(loss_step)
                    })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess, data, args.batch_size,
                                             "dev")
                    self.devSummary(summary_step, {
                        'loss': dev_loss,
                        'perplexity': np.exp(dev_loss)
                    })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss < best_valid:
                        best_valid = dev_loss
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [
                        np.sum(loss_step)
                    ]
                    loss_step, time_step = np.zeros((1, )), .0

                start_time = time.time()
                loss_step += self.step_decoder(
                    sess, batched_data)[0] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):
        """Compute hits@1/3/5 of the true reply against distractor replies.

        For each test sample (batch size 1) the true reply is scored together
        with the distractors listed for it in
        ``<args.datapath>/test_distractors.json``; replies are ranked by their
        per-sentence loss and the true reply (row 0) is a hit when it is among
        the k lowest.

        Returns:
            dict with float values under ``'hits@1'``, ``'hits@3'``, ``'hits@5'``.
        """
        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding="utf-8") as f:
            test_distractors = json.load(f)

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")

        loss_record = []
        cnt = 0
        while batched_data is not None:
            # Work on plain lists so that candidates can be appended.
            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()
            batched_data['resp_length'] = [len(batched_data['resp'][0])]
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['resp_length'].append(
                    len(batched_data['resp'][-1]))

            # Zero-pad every candidate to the longest one.
            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

            # Repeat the single context once per candidate reply.
            num_candidates = len(resp)
            batched_data['post'] = list(batched_data['post']) * num_candidates
            batched_data['post_length'] = list(
                batched_data['post_length']) * num_candidates
            batched_data['prev_length'] = list(
                batched_data['prev_length']) * num_candidates

            _, _, loss = self.inference(sess, batched_data)
            loss_record.append(loss)
            cnt += 1

            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5}

    def test_process(self, sess, data, args):
        """Run the test split, feed both metrics, add hits@k and write a result file.

        Generated id sequences are truncated just after the first ``eos_id``
        before being passed to the inference metric.  Float-valued metrics are
        printed and written (sorted by name) to
        ``<args.out_dir>/<args.name>_test.txt``, followed by each reference /
        generation pair.

        Returns:
            dict of the results whose value type is exactly ``bytes``, ``int``
            or ``float``.
        """
        metric1 = data.get_teacher_forcing_metric()  # mainly perplexity
        metric2 = data.get_inference_metric()  # mainly BLEU and distinct
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        self._log_samples(data, batched_data, 3)

        while batched_data is not None:
            batched_responses_id, gen_log_prob, _ = self.inference(
                sess, batched_data)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep the eos token itself, drop everything after it
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                'gen': np.array(batch_results),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            for key, value in sorted(res.items(), key=lambda kv: kv[0]):
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for resp, gen in zip(res['resp'], res['gen']):
                f.write("resp:\t%s\n" % " ".join(resp))
                f.write("gen:\t%s\n\n" % " ".join(gen))

        logger.info("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
class HredModel(object): def __init__(self, data, args, embed): #self.init_states = tf.placeholder(tf.float32, (None, args.ch_size), 'ctx_inps') # batch*ch_size # posts: [batch*(num_turns-1), max_post_length] # posts_length: [batch*(num_turns-1),] self.posts = tf.placeholder(tf.int32, (None, None), 'enc_inps') # batch * num_turns-1 * len self.posts_length = tf.placeholder(tf.int32, (None, ), 'enc_lens') # batch * num_turns-1 # prev_posts: [batch, max_prev_length],即对应上面post的最后一轮 # prev_posts_length: [batch],即最后一轮每句话的实际长度 self.prev_posts = tf.placeholder(tf.int32, (None, None), 'enc_prev_inps') self.prev_posts_length = tf.placeholder(tf.int32, (None, ), 'enc_prev_lens') # origin_responses: [batch, max_response_length] # origin_responses_length: [batch] self.origin_responses = tf.placeholder(tf.int32, (None, None), 'dec_inps') # batch*len self.origin_responses_length = tf.placeholder(tf.int32, (None, ), 'dec_lens') # batch # context_length: [batch],表示每一个posts的实际轮次 self.context_length = tf.placeholder(tf.int32, (None, ), 'ctx_lens') self.is_train = tf.placeholder(tf.bool) # 即对应num_turns-1(也有可能比这个小) # 表示当前batch的实际最大轮次 num_past_turns = tf.shape(self.posts)[0] // tf.shape( self.origin_responses)[0] # deal with original data to adapt encoder and decoder # 获取解码器的输入和输出 # 其中输入没有最后的<eos>,输出没有最开始的<go> batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape( self.origin_responses)[1] self.responses = tf.split(self.origin_responses, [1, decoder_len - 1], 1)[1] # no go_id self.responses_length = self.origin_responses_length - 1 self.responses_input = tf.split(self.origin_responses, [decoder_len - 1, 1], 1)[0] # no eos_id self.responses_target = self.responses decoder_len = decoder_len - 1 self.posts_input = self.posts # batch*len # [batch, decoder_length] self.decoder_mask = tf.reshape( tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len), reverse=True, axis=1), [-1, decoder_len]) # initialize the training process self.learning_rate = 
tf.Variable(float(args.lr), trainable=False, dtype=tf.float32) self.learning_rate_decay_op = self.learning_rate.assign( self.learning_rate * args.lr_decay) self.global_step = tf.Variable(0, trainable=False) # build the embedding table and embedding input if embed is None: # initialize the embedding randomly self.embed = tf.get_variable( 'embed', [data.vocab_size, args.embedding_size], tf.float32) else: # initialize the embedding by pre-trained word vectors self.embed = tf.get_variable('embed', dtype=tf.float32, initializer=embed) self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts) self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_input) # self.encoder_input = tf.cond(self.is_train, # lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.posts_input), 0.8), # lambda: tf.nn.embedding_lookup(self.embed, self.posts_input)) # batch*len*unit # self.decoder_input = tf.cond(self.is_train, # lambda: tf.nn.dropout(tf.nn.embedding_lookup(self.embed, self.responses_input), 0.8), # lambda: tf.nn.embedding_lookup(self.embed, self.responses_input)) # build rnn_cell cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size) cell_ctx = tf.nn.rnn_cell.GRUCell(args.ch_size) cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size) # build encoder with tf.variable_scope('encoder'): # encoder_output: [batch*(num_turns-1), max_post_length, eh_size] # encoder_state: [batch*(num_turns-1), eh_size] encoder_output, encoder_state = tf.nn.dynamic_rnn( cell_enc, self.encoder_input, self.posts_length, dtype=tf.float32, scope="encoder_rnn") with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE): # prev_output: [batch, max_prev_length, eh_size] prev_output, _ = tf.nn.dynamic_rnn(cell_enc, tf.nn.embedding_lookup( self.embed, self.prev_posts), self.prev_posts_length, dtype=tf.float32, scope="encoder_rnn") # encoder_hidden_size = tf.shape(encoder_state)[-1] with tf.variable_scope('context'): # encoder_state_reshape: [batch, num_turns-1, eh_size] # context_output: [batch, 
num_turns-1, ch_size] # context_state: [batch, ch_size] encoder_state_reshape = tf.reshape( encoder_state, [-1, num_past_turns, args.eh_size]) context_output, self.context_state = tf.nn.dynamic_rnn( cell_ctx, encoder_state_reshape, self.context_length, dtype=tf.float32, scope='context_rnn') # get output projection function output_fn = MyDense(data.vocab_size, use_bias=True) sampled_sequence_loss = output_projection_layer( args.dh_size, data.vocab_size, args.softmax_samples) # construct helper and attention train_helper = tf.contrib.seq2seq.TrainingHelper( self.decoder_input, tf.maximum(self.responses_length, 1)) infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper( self.embed, tf.fill([batch_size], data.go_id), data.eos_id) #encoder_len = tf.shape(encoder_output)[1] #attention_memory = tf.reshape(encoder_output, [batch_size, -1, args.eh_size]) #attention_mask = tf.reshape(tf.sequence_mask(self.posts_length, encoder_len), [batch_size, -1]) ''' attention_memory = context_output attention_mask = tf.reshape(tf.sequence_mask(self.context_length, self.num_turns - 1), [batch_size, -1]) ''' #attention_mask = tf.concat([tf.ones([batch_size, 1], tf.bool), attention_mask[:, 1:]], axis=1) #attn_mechanism = MyAttention(args.dh_size, attention_memory, attention_mask) # 注意这里的inputs,是最后一句话的编码,即[batch_size, prev_post_length, eh_size] # 在attention中,如果query的维度和inputs不一致,需要先经过线性层将query转化为 attn_mechanism = tf.contrib.seq2seq.BahdanauAttention( args.dh_size, prev_output, memory_sequence_length=tf.maximum(self.prev_posts_length, 1)) cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper( cell_dec, attn_mechanism, attention_layer_size=args.dh_size) # 将posts的编码输出转化为解码器的维度 [batch, dh_size] ctx_state_shaping = tf.layers.dense(self.context_state, args.dh_size, activation=None) dec_start = cell_dec_attn.zero_state( batch_size, dtype=tf.float32).clone(cell_state=ctx_state_shaping) # build decoder (train) with tf.variable_scope('decoder'): decoder_train = tf.contrib.seq2seq.BasicDecoder( 
cell_dec_attn, train_helper, dec_start) train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode( decoder_train, impute_finished=True, scope="decoder_rnn") # 这里的decoder_output: [batch, decoder_length, dh_size] self.decoder_output = train_outputs.rnn_output # self.decoder_output = tf.nn.dropout(self.decoder_output, 0.8) # decoder_distribution_teacher: [batch, decoder_length, vocab_size] # decoder_loss: 标量 # decoder_all_loss: [batch, ],表示每一句话的对数损失 self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \ sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask) # build decoder (test) with tf.variable_scope('decoder', reuse=True): decoder_infer = tf.contrib.seq2seq.BasicDecoder( cell_dec_attn, infer_helper, dec_start, output_layer=output_fn) infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode( decoder_infer, impute_finished=True, maximum_iterations=args.max_decoder_length, scope="decoder_rnn") # [batch, max_decoder_length, vocab_size] self.decoder_distribution = infer_outputs.rnn_output # 得到每一步解码的单词索引[batch, max_decoder_length] self.generation_index = tf.argmax( tf.split(self.decoder_distribution, [2, data.vocab_size - 2], 2)[1], 2) + 2 # for removing UNK # calculate the gradient of parameters and update self.params = [ k for k in tf.trainable_variables() if args.name in k.name ] opt = tf.train.AdamOptimizer(self.learning_rate) gradients = tf.gradients(self.decoder_loss, self.params) clipped_gradients, self.gradient_norm = tf.clip_by_global_norm( gradients, args.grad_clip) self.update = opt.apply_gradients(zip(clipped_gradients, self.params), global_step=self.global_step) # save checkpoint self.latest_saver = tf.train.Saver( write_version=tf.train.SaverDef.V2, max_to_keep=args.checkpoint_max_to_keep, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) # create summary for 
tensorboard self.create_summary(args) def store_checkpoint(self, sess, path, key, name): if key == "latest": self.latest_saver.save(sess, path, global_step=self.global_step, latest_filename=name) else: self.best_saver.save(sess, path, global_step=self.global_step, latest_filename=name) def create_summary(self, args): self.summaryHelper = SummaryHelper("%s/%s_%s" % \ (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args) self.trainSummary = self.summaryHelper.addGroup( scalar=["loss", "perplexity"], prefix="train") scalarlist = ["loss", "perplexity"] tensorlist = [] textlist = [] emblist = [] for i in args.show_sample: textlist.append("show_str%d" % i) self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist, embedding=emblist, prefix="dev") self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist, tensor=tensorlist, text=textlist, embedding=emblist, prefix="test") def print_parameters(self): for item in self.params: logger.info('%s: %s' % (item.name, item.get_shape())) def step_decoder(self, sess, data, forward_only=False, inference=False): input_feed = { self.posts: data['posts'], self.posts_length: data['posts_length'], self.origin_responses: data['responses'], self.origin_responses_length: data['responses_length'], self.context_length: data['context_length'], self.prev_posts: data['prev_posts'], self.prev_posts_length: data['prev_posts_length'] } if inference: input_feed.update({self.is_train: False}) output_feed = [ self.generation_index, self.decoder_distribution_teacher, self.decoder_all_loss ] else: input_feed.update({self.is_train: True}) if forward_only: output_feed = [ self.decoder_loss, self.decoder_distribution_teacher ] else: output_feed = [ self.decoder_loss, self.gradient_norm, self.update ] return sess.run(output_feed, input_feed) def evaluate(self, sess, data, batch_size, key_name): loss = np.zeros((1, )) times = 0 data.restart(key_name, batch_size=batch_size, shuffle=False) 
batched_data = data.get_next_batch(key_name) while batched_data != None: outputs = self.step_decoder(sess, batched_data, forward_only=True) loss += outputs[0] times += 1 batched_data = data.get_next_batch(key_name) loss /= times logger.info(' perplexity on %s set: %.2f' % (key_name, float(np.exp(loss)))) logger.info(loss) return loss def train_process(self, sess, data, args): loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0 previous_losses = [1e18] * 3 best_valid = 1e18 data.restart("train", batch_size=args.batch_size, shuffle=True) batched_data = data.get_next_batch("train") for epoch_step in range(args.epochs): while batched_data != None: if self.global_step.eval( ) % args.checkpoint_steps == 0 and self.global_step.eval( ) != 0: show = lambda a: '[%s]' % (' '.join( ['%.2f' % x for x in a])) logger.info( "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity %s" % (epoch_step, self.global_step.eval(), self.learning_rate.eval(), time_step, show(np.exp(loss_step)))) self.trainSummary( self.global_step.eval() // args.checkpoint_steps, { 'loss': loss_step, 'perplexity': np.exp(loss_step) }) self.store_checkpoint( sess, '%s/checkpoint_latest/%s' % (args.model_dir, args.name), "latest", args.name) dev_loss = self.evaluate(sess, data, args.batch_size, "dev") self.devSummary( self.global_step.eval() // args.checkpoint_steps, { 'loss': dev_loss, 'perplexity': np.exp(dev_loss) }) if np.sum(loss_step) > max(previous_losses): sess.run(self.learning_rate_decay_op) # 如果验证集的损失小于之前最小的损失,则将当前模型保存到最好的模型中 if dev_loss < best_valid: best_valid = dev_loss self.store_checkpoint( sess, '%s/checkpoint_best/%s' % (args.model_dir, args.name), "best", args.name) previous_losses = previous_losses[1:] + [np.sum(loss_step)] loss_step, time_step = np.zeros((1, )), .0 start_time = time.time() loss_step += self.step_decoder( sess, batched_data)[0] / args.checkpoint_steps time_step += (time.time() - start_time) / args.checkpoint_steps batched_data = 
data.get_next_batch("train") data.restart("train", batch_size=args.batch_size, shuffle=True) batched_data = data.get_next_batch("train") def test_process_hits(self, sess, data, args): with open(os.path.join(args.datapath, 'test_distractors.json'), 'r', encoding="utf-8") as f: test_distractors = json.load(f) data.restart("test", batch_size=1, shuffle=False) batched_data = data.get_next_batch("test") loss_record = [] cnt = 0 while batched_data != None: for key in batched_data: if isinstance(batched_data[key], np.ndarray): batched_data[key] = batched_data[key].tolist() batched_data['responses_length'] = [ len(batched_data['responses'][0]) ] batched_data['responses'] = batched_data['responses'] for each_resp in test_distractors[cnt]: batched_data['responses'].append( [data.go_id] + data.convert_tokens_to_ids(jieba.lcut(each_resp)) + [data.eos_id]) batched_data['responses_length'].append( len(batched_data['responses'][-1])) max_length = max(batched_data['responses_length']) resp = np.zeros((len(batched_data['responses']), max_length), dtype=int) for i, each_resp in enumerate(batched_data['responses']): resp[i, :len(each_resp)] = each_resp batched_data['responses'] = resp posts = [] posts_length = [] prev_posts = [] prev_posts_length = [] context_length = [] for _ in range(len(resp)): posts += batched_data['posts'] posts_length += batched_data['posts_length'] prev_posts += batched_data['prev_posts'] prev_posts_length += batched_data['prev_posts_length'] context_length += batched_data['context_length'] batched_data['posts'] = posts batched_data['posts_length'] = posts_length batched_data['prev_posts'] = prev_posts batched_data['prev_posts_length'] = prev_posts_length batched_data['context_length'] = context_length _, _, loss = self.step_decoder(sess, batched_data, inference=True) loss_record.append(loss) cnt += 1 batched_data = data.get_next_batch("test") assert cnt == len(test_distractors) loss = np.array(loss_record) loss_rank = np.argsort(loss, axis=1) hits1 = 
float(np.mean(loss_rank[:, 0] == 0)) hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0)) hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0)) return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5} def test_process(self, sess, data, args): metric1 = data.get_teacher_forcing_metric() metric2 = data.get_inference_metric() data.restart("test", batch_size=args.batch_size, shuffle=False) batched_data = data.get_next_batch("test") while batched_data != None: batched_responses_id, gen_log_prob, _ = self.step_decoder( sess, batched_data, inference=True) metric1_data = { 'resp_allvocabs': np.array(batched_data['responses_allvocabs']), 'resp_length': np.array(batched_data['responses_length']), 'gen_log_prob': np.array(gen_log_prob) } metric1.forward(metric1_data) batch_results = [] for response_id in batched_responses_id: response_id_list = response_id.tolist() if data.eos_id in response_id_list: result_id = response_id_list[:response_id_list. index(data.eos_id) + 1] else: result_id = response_id_list batch_results.append(result_id) metric2_data = { 'gen': np.array(batch_results), 'resp_allvocabs': np.array(batched_data['responses_allvocabs']) } metric2.forward(metric2_data) batched_data = data.get_next_batch("test") res = metric1.close() res.update(metric2.close()) res.update(self.test_process_hits(sess, data, args)) test_file = args.output_dir + "/%s_%s.txt" % (args.name, "test") with open(test_file, 'w', encoding="utf-8") as f: print("Test Result:") res_print = list(res.items()) res_print.sort(key=lambda x: x[0]) for key, value in res_print: if isinstance(value, float): print("\t%s:\t%f" % (key, value)) f.write("%s:\t%f\n" % (key, value)) f.write('\n') for i in range(len(res['resp'])): f.write("resp:\t%s\n" % " ".join(res['resp'][i])) f.write("gen:\t%s\n\n" % " ".join(res['gen'][i])) logger.info("result output to %s." % test_file) return { key: val for key, val in res.items() if type(val) in [bytes, int, float] }
class LMModel(object):
    """LSTM language model with its training, evaluation and test loops.

    The constructor builds the whole TensorFlow graph (placeholders,
    embedding, LSTM, projection, losses, optimizer, savers and summaries);
    the remaining methods drive that graph through a session.
    """

    def __init__(self, data, args, embed):
        """Build the graph.

        Args:
            data: dataloader; only ``frequent_vocab_size`` is read here.
            args: hyper-parameters; reads ``embedding_size``, ``lr``,
                ``lr_decay``, ``dh_size``, ``name``, ``grad_clip``,
                ``momentum``, ``checkpoint_max_to_keep`` plus whatever
                ``create_summary`` needs.
            embed: initial embedding matrix, or ``None`` to let
                ``tf.get_variable`` initialize the table.
        """
        # NOTE(review): the original indentation was lost; the extent of each
        # `with` block below was reconstructed from the statement order.
        # Confirm against a checkpoint that variable names are unchanged.
        with tf.variable_scope("input"):
            with tf.variable_scope("embedding"):
                # build the embedding table and embedding input
                if embed is None:
                    # initialize the embedding randomly
                    self.embed = tf.get_variable(
                        'embed',
                        [data.frequent_vocab_size, args.embedding_size],
                        tf.float32)
                else:
                    # initialize the embedding by pre-trained word vectors
                    self.embed = tf.get_variable('embed',
                                                 dtype=tf.float32,
                                                 initializer=embed)

            # input
            self.sentence = tf.placeholder(tf.int32, (None, None),
                                           'sen_inps')  # batch*len
            self.sentence_length = tf.placeholder(tf.int32, (None, ),
                                                  'sen_lens')  # batch
            self.use_prior = tf.placeholder(dtype=tf.bool, name="use_prior")

            batch_size, batch_len = tf.shape(self.sentence)[0], tf.shape(
                self.sentence)[1]
            self.scentence_max_len = batch_len - 1

            # data processing: inputs drop the last column, targets drop the
            # first one, so target[t] is the token following input[t].
            LM_input = tf.split(self.sentence, [self.scentence_max_len, 1],
                                1)[0]  # no eos_id
            self.LM_input = tf.nn.embedding_lookup(
                self.embed, LM_input)  # batch*(len-1)*unit
            self.LM_target = tf.split(self.sentence,
                                      [1, self.scentence_max_len],
                                      1)[1]  # no go_id, batch*(len-1)
            self.input_len = self.sentence_length - 1
            self.input_mask = tf.sequence_mask(
                self.input_len, self.scentence_max_len,
                dtype=tf.float32)  # 0 for <pad>, batch*(len-1)

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build LSTM NN
        basic_cell = tf.nn.rnn_cell.LSTMCell(args.dh_size)
        with tf.variable_scope('rnnlm'):
            LM_output, _ = dynamic_rnn(basic_cell,
                                       self.LM_input,
                                       self.input_len,
                                       dtype=tf.float32,
                                       scope="rnnlm")
        # fully connected layer
        LM_output = tf.layers.dense(
            inputs=LM_output, units=data.frequent_vocab_size
        )  # shape of LM_output: (batch_size, batch_len-1, vocab_size)

        # loss
        with tf.variable_scope("loss",
                               initializer=tf.orthogonal_initializer()):
            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=LM_output, labels=self.LM_target)
            crossent = tf.reduce_sum(crossent *
                                     self.input_mask)  # to ignore <pad>s

            self.sen_loss = crossent / tf.to_float(batch_size)
            self.ppl_loss = crossent / tf.reduce_sum(
                self.input_mask)  # crossent per word.

            self.decoder_distribution_teacher = tf.nn.log_softmax(LM_output)

        with tf.variable_scope("decode", reuse=True):
            self.decoder_distribution = LM_output  # (batch_size, batch_len-1, vocab_size)
            # for inference: arg-max over ids >= 2 only
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution,
                         [2, data.frequent_vocab_size - 2], 2)[1], 2
            ) + 2  # for removing UNK. 0 for <pad> and 1 for <unk>

        self.loss = self.sen_loss

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        opt = tf.train.MomentumOptimizer(learning_rate=self.learning_rate,
                                         momentum=args.momentum)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    @staticmethod
    def _rows_to_array(rows):
        """Convert a list of id lists to an ndarray, tolerating ragged rows.

        Equal-length rows give the same 2-D array ``np.array(rows)`` always
        gave. Rows of different lengths give a 1-D object array of lists,
        which is what older NumPy produced implicitly and what NumPy >= 1.24
        refuses to build without an explicit object dtype.
        """
        if len({len(row) for row in rows}) <= 1:
            return np.array(rows)
        ragged = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            ragged[i] = row
        return ragged

    def store_checkpoint(self, sess, path, key):
        """Save the session with the "latest" saver if ``key == "latest"``,
        otherwise with the "best" saver."""
        saver = self.latest_saver if key == "latest" else self.best_saver
        saver.save(sess, path, global_step=self.global_step)

    def create_summary(self, args):
        """Create the train/dev/test summary groups under ``args.log_dir``."""
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
                (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(scalar=[
            "loss",
            "perplexity",
        ],
                                                        prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        """Print the name and static shape of every trained variable."""
        for item in self.params:
            print('%s: %s' % (item.name, item.get_shape()))

    def step_LM(self, session, data, forward_only=False):
        """Run the LM for one step (batch).

        Returns ``[loss, log-distribution, ppl_loss]`` when ``forward_only``
        is true, else ``[loss, gradient_norm, update, ppl_loss]``; the
        per-word loss is always the last element.
        """
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length'],
            self.use_prior: False
        }
        if forward_only:
            # test mode
            output_feed = [
                self.loss, self.decoder_distribution_teacher, self.ppl_loss
            ]
        else:
            # train mode
            output_feed = [
                self.loss, self.gradient_norm, self.update, self.ppl_loss
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data):
        """Return ``[generation_index]`` for one batch."""
        input_feed = {
            self.sentence: data['sent'],
            self.sentence_length: data['sent_length']
        }
        output_feed = [self.generation_index]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name):
        """Average the sentence loss and per-word loss over one data split.

        Returns:
            ``(loss_step, ppl_loss_step)``: a shape-(1,) array with the mean
            batch loss and the mean per-word loss. A split that yields no
            batch still fails on the division by the zero batch count, as
            before.
        """
        loss_step = np.zeros((1, ))
        ppl_loss_step = 0
        times = 0
        data.restart(key_name, batch_size=batch_size,
                     shuffle=False)  # initialize mini-batches
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            outputs = self.step_LM(sess, batched_data, forward_only=True)
            loss_step += outputs[0]
            ppl_loss_step += outputs[-1]
            times += 1
            batched_data = data.get_next_batch(key_name)
        loss_step /= times
        ppl_loss_step /= times

        # Reduce to a scalar before formatting: '%f' % <1-element array>
        # relies on implicit array-to-scalar conversion.
        print(' loss: %.2f' % np.sum(loss_step))
        return loss_step, ppl_loss_step

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` global steps this reports the running
        averages, saves the "latest" checkpoint, evaluates on dev and test,
        decays the learning rate when the training loss exceeds each of the
        last five recorded ones, and saves the "best" checkpoint when the dev
        loss improves.
        """
        # 'X_step' <=> X per step
        loss_step, time_step, epoch_step = np.zeros((1, )), .0, 0
        ppl_loss_step = 0
        previous_losses = [1e18] * 5  # previous 5 losses
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")
        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # The step counter only moves in step_LM below, so one read
                # per iteration is enough.
                global_step = self.global_step.eval()
                if global_step % args.checkpoint_steps == 0 and global_step != 0:
                    summary_step = global_step // args.checkpoint_steps
                    print(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f"
                        % (epoch_step, global_step,
                           self.learning_rate.eval(), time_step))
                    print(' loss: %.2f' % np.sum(loss_step))
                    self.trainSummary(
                        summary_step, {
                            'loss': loss_step,
                            'perplexity': np.exp(ppl_loss_step),
                        })
                    self.store_checkpoint(
                        sess,
                        '%s/checkpoint_latest/checkpoint' % args.model_dir,
                        "latest")

                    devout = self.evaluate(sess, data, args.batch_size, "dev")
                    self.devSummary(
                        summary_step, {
                            'loss': devout[0],
                            'perplexity': np.exp(devout[1]),
                        })

                    testout = self.evaluate(sess, data, args.batch_size,
                                            "test")
                    self.testSummary(
                        summary_step, {
                            'loss': testout[0],
                            'perplexity': np.exp(testout[1]),
                        })

                    if np.sum(loss_step) > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if devout[0] < best_valid:
                        best_valid = devout[0]
                        self.store_checkpoint(
                            sess,
                            '%s/checkpoint_best/checkpoint' % args.model_dir,
                            "best")

                    previous_losses = previous_losses[1:] + [np.sum(loss_step)]
                    loss_step, time_step = np.zeros((1, )), .0
                    ppl_loss_step = 0

                start_time = time.time()
                # outputs: loss, gradient_norm, update, ppl_loss
                outputs = self.step_LM(sess, batched_data)
                loss_step += outputs[0] / args.checkpoint_steps
                ppl_loss_step += outputs[-1] / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process(self, sess, data, args):
        """Run the test split through both metrics and write a report.

        Float metrics are printed and written to
        ``<args.out_dir>/<args.name>_test.txt``, followed by one generated
        sentence per line. Returns the metric entries whose value type is
        exactly ``bytes``, ``int`` or ``float``.
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")
        while batched_data is not None:
            batched_responses_id = self.inference(sess, batched_data)[0]
            gen_prob = self.step_LM(sess, batched_data, forward_only=True)[1]
            metric1_data = {
                'sent_allvocabs': np.array(batched_data['sent_allvocabs']),
                'sent_length': np.array(batched_data['sent_length']),
                'gen_log_prob': np.array(gen_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep everything up to and including the first <eos>
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            # Truncation makes rows ragged; build the array explicitly.
            metric2_data = {'gen': self._rows_to_array(batch_results)}
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())

        # NOTE(review): the sibling models read args.output_dir; confirm that
        # out_dir is the intended attribute for this model's arguments.
        test_file = args.out_dir + "/%s_%s.txt" % (args.name, "test")
        # Explicit encoding: generated tokens need not be ASCII.
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            for key, value in res.items():
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            for generated in res['gen']:
                f.write("%s\n" % " ".join(generated))

        print("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }
class Seq2SeqModel(object):
    """Attention seq2seq dialogue model with knowledge-triple attention.

    The constructor builds the TensorFlow graph: a GRU encoder, an attention
    over knowledge triples whose attended value embedding is concatenated to
    every decoder input, and an attention GRU decoder built once for teacher
    forcing and once for greedy inference. The remaining methods drive the
    graph through a session.
    """

    def __init__(self, data, args, embed):
        """Build the graph.

        Args:
            data: dataloader; reads ``vocab_size``, ``go_id`` and ``eos_id``.
            args: hyper-parameters; reads ``lr``, ``lr_decay``,
                ``embedding_size``, ``eh_size``, ``dh_size``,
                ``softmax_samples``, ``max_decoder_length``, ``name``,
                ``grad_clip``, ``checkpoint_max_to_keep`` plus whatever
                ``create_summary`` needs.
            embed: initial embedding matrix, or ``None`` to let
                ``tf.get_variable`` initialize the table.
        """
        # NOTE(review): the original indentation was lost; the extent of each
        # `with` block below was reconstructed from the statement order.
        # Confirm against a checkpoint that variable names are unchanged.

        # The inputs are the same as for the plain seq2seq model.
        self.posts = tf.placeholder(tf.int32, (None, None),
                                    'enc_inps')  # batch*len
        self.posts_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens')  # batch
        self.prevs_length = tf.placeholder(tf.int32, (None, ),
                                           'enc_lens_prev')  # batch
        self.origin_responses = tf.placeholder(tf.int32, (None, None),
                                               'dec_inps')  # batch*len
        self.origin_responses_length = tf.placeholder(tf.int32, (None, ),
                                                      'dec_lens')  # batch

        # kgs: all knowledge triples of the dialogue this sample belongs to:
        #   [batch, max_kg_nums, max_kg_length]
        # kgs_h_length: length of the head entity of each triple: [batch, max_kg_nums]
        # kgs_hr_length: length of head entity + relation: [batch, max_kg_nums]
        # kgs_hrt_length: length of head + relation + tail: [batch, max_kg_nums]
        # kgs_index: indicator of the triples actually used by the current
        #   utterance (1 = used, 0 = not used): [batch, max_kg_nums]
        self.kgs = tf.placeholder(tf.int32, (None, None, None), 'kg_inps')
        self.kgs_h_length = tf.placeholder(tf.int32, (None, None), 'kg_h_lens')
        self.kgs_hr_length = tf.placeholder(tf.int32, (None, None),
                                            'kg_hr_lens')
        self.kgs_hrt_length = tf.placeholder(tf.int32, (None, None),
                                             'kg_hrt_lens')
        self.kgs_index = tf.placeholder(tf.float32, (None, None), 'kg_indices')

        # hyper-parameter balancing the decoding loss against the kg loss
        self.lamb = tf.placeholder(tf.float32, name='lamb')
        # Only the (disabled) input dropout consumed this flag; no op built
        # in this constructor reads it.
        self.is_train = tf.placeholder(tf.bool)

        # deal with original data to adapt encoder and decoder
        # decoder inputs and targets
        batch_size, decoder_len = tf.shape(self.origin_responses)[0], tf.shape(
            self.origin_responses)[1]
        self.responses = tf.split(self.origin_responses, [1, decoder_len - 1],
                                  1)[1]  # no go_id
        self.responses_length = self.origin_responses_length - 1
        self.responses_input = tf.split(self.origin_responses,
                                        [decoder_len - 1, 1],
                                        1)[0]  # no eos_id
        self.responses_target = self.responses
        decoder_len = decoder_len - 1
        # encoder input
        self.posts_input = self.posts  # batch*len
        # decoder mask: 1 for real target positions, 0 for padding
        self.decoder_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.responses_length - 1, decoder_len),
                      reverse=True,
                      axis=1), [-1, decoder_len])

        kg_len = tf.shape(self.kgs)[2]
        # Masks over the tokens of each triple, both
        # [batch_size, max_kg_nums, max_kg_length, 1]: the key covers the
        # head+relation tokens, the value covers the tail tokens.
        kg_hr_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hr_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_hrt_mask = tf.reshape(
            tf.cumsum(tf.one_hot(self.kgs_hrt_length - 1, kg_len),
                      reverse=True,
                      axis=2), [batch_size, -1, kg_len, 1])
        kg_key_mask = kg_hr_mask
        kg_value_mask = kg_hrt_mask - kg_hr_mask

        # initialize the training process
        self.learning_rate = tf.Variable(float(args.lr),
                                         trainable=False,
                                         dtype=tf.float32)
        self.learning_rate_decay_op = self.learning_rate.assign(
            self.learning_rate * args.lr_decay)
        self.global_step = tf.Variable(0, trainable=False)

        # build the embedding table and embedding input
        if embed is None:
            # initialize the embedding randomly
            self.embed = tf.get_variable(
                'embed', [data.vocab_size, args.embedding_size], tf.float32)
        else:
            # initialize the embedding by pre-trained word vectors
            self.embed = tf.get_variable('embed',
                                         dtype=tf.float32,
                                         initializer=embed)

        # encoder_input: [batch, encoder_len, embed_size]
        self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts)
        # decoder_input: [batch, decoder_len, embed_size]
        self.decoder_input = tf.nn.embedding_lookup(self.embed,
                                                    self.responses_input)
        # kg_input: [batch, max_kg_nums, max_kg_length, embed_size]
        self.kg_input = tf.nn.embedding_lookup(self.embed, self.kgs)

        # build rnn_cell
        cell_enc = tf.nn.rnn_cell.GRUCell(args.eh_size)
        cell_dec = tf.nn.rnn_cell.GRUCell(args.dh_size)

        # build encoder
        with tf.variable_scope('encoder'):
            encoder_output, encoder_state = tf.nn.dynamic_rnn(
                cell_enc,
                self.encoder_input,
                self.posts_length,
                dtype=tf.float32,
                scope="encoder_rnn")

        # key: mean word embedding of a triple's head+relation tokens
        #   [batch, max_kg_nums, embed_size]
        # value: mean word embedding of a triple's tail tokens
        #   [batch, max_kg_nums, embed_size]
        # The divisor is clamped to 1 so empty triples do not divide by zero.
        self.kg_key_avg = tf.reduce_sum(
            self.kg_input * kg_key_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_key_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))
        self.kg_value_avg = tf.reduce_sum(
            self.kg_input * kg_value_mask, axis=2) / tf.maximum(
                tf.reduce_sum(kg_value_mask, axis=2),
                tf.ones_like(tf.expand_dims(self.kgs_hrt_length, -1),
                             dtype=tf.float32))

        # project the encoder final state to embed_size
        # query: [batch, 1, embed_size]
        with tf.variable_scope('knowledge'):
            query = tf.reshape(
                tf.layers.dense(tf.concat(encoder_state, axis=-1),
                                args.embedding_size,
                                use_bias=False),
                [batch_size, 1, args.embedding_size])

        # [batch, max_kg_nums]
        kg_score = tf.reduce_sum(query * self.kg_key_avg, axis=2)
        # keep the score where a triple exists (hrt length > 0), -inf elsewhere
        # NOTE(review): a row with no triple at all is all -inf before the
        # softmax; confirm the data always supplies at least one triple.
        kg_score = tf.where(tf.greater(self.kgs_hrt_length, 0), kg_score,
                            -tf.ones_like(kg_score) * np.inf)
        # attention weight of every triple: [batch, max_kg_nums]
        kg_alignment = tf.nn.softmax(kg_score)

        # accuracy and loss of the knowledge attention, from the position of
        # its highest weight
        kg_max = tf.argmax(kg_alignment, axis=-1)
        kg_max_onehot = tf.one_hot(kg_max,
                                   tf.shape(kg_alignment)[1],
                                   dtype=tf.float32)
        self.kg_acc = tf.reduce_sum(
            kg_max_onehot * self.kgs_index) / tf.maximum(
                tf.reduce_sum(tf.reduce_max(self.kgs_index, axis=-1)),
                tf.constant(1.0))
        self.kg_loss = tf.reduce_sum(
            -tf.log(tf.clip_by_value(kg_alignment, 1e-12, 1.0)) *
            self.kgs_index,
            axis=1) / tf.maximum(tf.reduce_sum(self.kgs_index, axis=1),
                                 tf.ones([batch_size], dtype=tf.float32))
        self.kg_loss = tf.reduce_mean(self.kg_loss)

        # attended knowledge embedding: [batch, embed_size]
        self.knowledge_embed = tf.reduce_sum(
            tf.expand_dims(kg_alignment, axis=-1) * self.kg_value_avg, axis=1)
        # repeated along time: [batch, decoder_len, embed_size]
        knowledge_embed_extend = tf.tile(
            tf.expand_dims(self.knowledge_embed, axis=1), [1, decoder_len, 1])
        # knowledge concatenated to the original decoder input gives the new
        # decoder input: [batch, decoder_len, 2*embed_size]
        self.decoder_input = tf.concat(
            [self.decoder_input, knowledge_embed_extend], axis=2)

        # get output projection function
        output_fn = MyDense(data.vocab_size, use_bias=True)
        sampled_sequence_loss = output_projection_layer(
            args.dh_size, data.vocab_size, args.softmax_samples)

        # attention mask: encoder positions lying between prevs_length and
        # posts_length (xor of the two prefix masks)
        encoder_len = tf.shape(encoder_output)[1]
        posts_mask = tf.sequence_mask(self.posts_length, encoder_len)
        prevs_mask = tf.sequence_mask(self.prevs_length, encoder_len)
        attention_mask = tf.reshape(tf.logical_xor(posts_mask, prevs_mask),
                                    [batch_size, encoder_len])

        # construct helper and attention
        train_helper = tf.contrib.seq2seq.TrainingHelper(
            self.decoder_input, self.responses_length)
        # custom helper so that, at inference, every step input is the
        # previous output concatenated with the knowledge embedding
        infer_helper = MyInferenceHelper(self.embed,
                                         tf.fill([batch_size], data.go_id),
                                         data.eos_id, self.knowledge_embed)
        # MyAttention takes an explicit mask, whereas BahdanauAttention only
        # accepts a memory sequence length
        attn_mechanism = MyAttention(args.dh_size, encoder_output,
                                     attention_mask)
        cell_dec_attn = tf.contrib.seq2seq.AttentionWrapper(
            cell_dec, attn_mechanism, attention_layer_size=args.dh_size)
        enc_state_shaping = tf.layers.dense(encoder_state,
                                            args.dh_size,
                                            activation=None)
        dec_start = cell_dec_attn.zero_state(
            batch_size, dtype=tf.float32).clone(cell_state=enc_state_shaping)

        # build decoder (train)
        with tf.variable_scope('decoder'):
            decoder_train = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, train_helper, dec_start)
            train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_train, impute_finished=True, scope="decoder_rnn")
            self.decoder_output = train_outputs.rnn_output
            # output distribution and decoding loss
            # NOTE(review): kept inside the 'decoder' scope on the assumption
            # that the projection variables are shared with `output_fn` in the
            # reuse=True pass below; confirm against output_projection_layer.
            self.decoder_distribution_teacher, self.decoder_loss, self.decoder_all_loss = \
                sampled_sequence_loss(self.decoder_output, self.responses_target, self.decoder_mask)

        # build decoder (test)
        with tf.variable_scope('decoder', reuse=True):
            decoder_infer = tf.contrib.seq2seq.BasicDecoder(
                cell_dec_attn, infer_helper, dec_start, output_layer=output_fn)
            infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder_infer,
                impute_finished=True,
                maximum_iterations=args.max_decoder_length,
                scope="decoder_rnn")
            # decoding distribution
            self.decoder_distribution = infer_outputs.rnn_output
            self.generation_index = tf.argmax(
                tf.split(self.decoder_distribution, [2, data.vocab_size - 2],
                         2)[1], 2) + 2  # for removing UNK

        # calculate the gradient of parameters and update
        self.params = [
            k for k in tf.trainable_variables() if args.name in k.name
        ]
        opt = tf.train.AdamOptimizer(self.learning_rate)
        # total loss = decoding loss + lamb * kg loss
        self.loss = self.decoder_loss + self.lamb * self.kg_loss
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(
            gradients, args.grad_clip)
        self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
                                          global_step=self.global_step)

        # save checkpoint
        self.latest_saver = tf.train.Saver(
            write_version=tf.train.SaverDef.V2,
            max_to_keep=args.checkpoint_max_to_keep,
            pad_step_number=True,
            keep_checkpoint_every_n_hours=1.0)
        self.best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                         max_to_keep=1,
                                         pad_step_number=True,
                                         keep_checkpoint_every_n_hours=1.0)

        # create summary for tensorboard
        self.create_summary(args)

    @staticmethod
    def _rows_to_array(rows):
        """Convert a list of id lists to an ndarray, tolerating ragged rows.

        Equal-length rows give the same 2-D array ``np.array(rows)`` always
        gave. Rows of different lengths give a 1-D object array of lists,
        which is what older NumPy produced implicitly and what NumPy >= 1.24
        refuses to build without an explicit object dtype.
        """
        if len({len(row) for row in rows}) <= 1:
            return np.array(rows)
        ragged = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            ragged[i] = row
        return ragged

    def store_checkpoint(self, sess, path, key, name):
        """Save the session with the "latest" saver if ``key == "latest"``,
        otherwise with the "best" saver; ``name`` is the state-file name."""
        saver = self.latest_saver if key == "latest" else self.best_saver
        saver.save(sess,
                   path,
                   global_step=self.global_step,
                   latest_filename=name)

    def create_summary(self, args):
        """Create the train/dev/test summary groups under ``args.log_dir``."""
        self.summaryHelper = SummaryHelper("%s/%s_%s" % \
                (args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())), args)

        self.trainSummary = self.summaryHelper.addGroup(
            scalar=["loss", "perplexity"], prefix="train")

        scalarlist = ["loss", "perplexity"]
        tensorlist = []
        textlist = ["show_str%d" % i for i in args.show_sample]
        self.devSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                      tensor=tensorlist,
                                                      text=textlist,
                                                      prefix="dev")
        self.testSummary = self.summaryHelper.addGroup(scalar=scalarlist,
                                                       tensor=tensorlist,
                                                       text=textlist,
                                                       prefix="test")

    def print_parameters(self):
        """Log the name and static shape of every trained variable."""
        for item in self.params:
            logger.info('%s: %s' % (item.name, item.get_shape()))

    def _build_feed(self, data, lamb, is_train):
        """Map one batch dict onto the graph placeholders."""
        return {
            self.posts: data['post'],
            self.posts_length: data['post_length'],
            self.prevs_length: data['prev_length'],
            self.origin_responses: data['resp'],
            self.origin_responses_length: data['resp_length'],
            self.kgs: data['kg'],
            self.kgs_h_length: data['kg_h_length'],
            self.kgs_hr_length: data['kg_hr_length'],
            self.kgs_hrt_length: data['kg_hrt_length'],
            self.kgs_index: data['kg_index'],
            self.lamb: lamb,
            self.is_train: is_train
        }

    def step_decoder(self, session, data, lamb=1.0, forward_only=False):
        """Run one teacher-forced step.

        Returns ``[decoder_loss, distribution, kg_loss, kg_acc]`` when
        ``forward_only`` is true, else
        ``[decoder_loss, gradient_norm, update, kg_loss, kg_acc]``.
        """
        # is_train is fed True in both modes, exactly as before.
        input_feed = self._build_feed(data, lamb, True)
        if forward_only:
            output_feed = [
                self.decoder_loss, self.decoder_distribution_teacher,
                self.kg_loss, self.kg_acc
            ]
        else:
            output_feed = [
                self.decoder_loss, self.gradient_norm, self.update,
                self.kg_loss, self.kg_acc
            ]
        return session.run(output_feed, input_feed)

    def inference(self, session, data, lamb=1.0):
        """Run greedy decoding plus the teacher-forced outputs for one batch.

        Returns ``[generation_index, teacher distribution, per-sample
        decoder loss, kg_loss, kg_acc]``.
        """
        input_feed = self._build_feed(data, lamb, False)
        output_feed = [
            self.generation_index, self.decoder_distribution_teacher,
            self.decoder_all_loss, self.kg_loss, self.kg_acc
        ]
        return session.run(output_feed, input_feed)

    def evaluate(self, sess, data, batch_size, key_name, lamb=1.0):
        """Compute weighted averages over one data split.

        Returns a length-3 array ``[decoder_loss, kg_loss, kg_acc]``. The
        decoder loss is weighted by the number of target tokens per batch and
        the kg entries by the number of samples that use knowledge. An entry
        whose total weight is zero comes out as NaN, as before.
        """
        loss = np.zeros((3, ))
        total_length = np.zeros((3, ))
        data.restart(key_name, batch_size=batch_size, shuffle=False)
        batched_data = data.get_next_batch(key_name)
        while batched_data is not None:
            decoder_loss, _, kg_loss, kg_acc = self.step_decoder(
                sess, batched_data, lamb=lamb, forward_only=True)
            # number of target tokens in the batch (response length minus one)
            length = np.sum(
                np.maximum(np.array(batched_data['resp_length']) - 1, 0))
            # number of samples in the batch that use at least one triple
            kg_length = np.sum(np.max(batched_data['kg_index'], axis=-1))
            total_length += [length, kg_length, kg_length]
            loss += [
                decoder_loss * length, kg_loss * kg_length, kg_acc * kg_length
            ]
            batched_data = data.get_next_batch(key_name)

        loss /= total_length
        logger.info(
            ' perplexity on %s set: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f'
            % (key_name, np.exp(loss[0]), np.exp(loss[1]), loss[1], loss[2]))
        return loss

    def train_process(self, sess, data, args):
        """Train for ``args.epochs`` epochs.

        Every ``args.checkpoint_steps`` global steps this logs the running
        averages, saves the "latest" checkpoint, evaluates on dev, decays the
        learning rate when the training decoder loss exceeds each of the last
        three recorded ones, and saves the "best" checkpoint when the dev
        decoder loss improves.
        """
        loss_step, time_step, epoch_step = np.zeros((3, )), .0, 0
        previous_losses = [1e18] * 3
        best_valid = 1e18
        data.restart("train", batch_size=args.batch_size, shuffle=True)
        batched_data = data.get_next_batch("train")

        # Log up to two samples of the first batch; the bound follows the
        # batch so a batch of one sample (or an empty split) does not crash.
        if batched_data is not None:
            for i in range(min(2, len(batched_data['post_length']))):
                post_ids = batched_data['post_allvocabs'][i].tolist()
                resp_ids = batched_data['resp_allvocabs'][i].tolist()
                prev_len = batched_data['prev_length'][i]
                post_len = batched_data['post_length'][i]
                logger.info(
                    f"post@ {data.convert_ids_to_tokens(post_ids, trim=False)}"
                )
                logger.info(f"length@ {prev_len, post_len}")
                logger.info(
                    f"last@ {data.convert_ids_to_tokens(post_ids[prev_len: post_len], trim=False)}"
                )
                logger.info(
                    f"resp@ {data.convert_ids_to_tokens(resp_ids, trim=False)}"
                )

        for epoch_step in range(args.epochs):
            while batched_data is not None:
                # The step counter only moves in step_decoder below, so one
                # read per iteration is enough.
                global_step = self.global_step.eval()
                if global_step % args.checkpoint_steps == 0 and global_step != 0:
                    summary_step = global_step // args.checkpoint_steps
                    logger.info(
                        "Epoch %d global step %d learning rate %.4f step-time %.2f perplexity: %.2f, kg_ppx: %.2f, kg_loss: %.4f, kg_acc: %.4f"
                        % (epoch_step, global_step,
                           self.learning_rate.eval(), time_step,
                           np.exp(loss_step[0]), np.exp(loss_step[1]),
                           loss_step[1], loss_step[2]))
                    self.trainSummary(
                        summary_step, {
                            'loss': loss_step[0],
                            'perplexity': np.exp(loss_step[0])
                        })
                    self.store_checkpoint(
                        sess, '%s/checkpoint_latest/%s' %
                        (args.model_dir, args.name), "latest", args.name)

                    dev_loss = self.evaluate(sess,
                                             data,
                                             args.batch_size,
                                             "dev",
                                             lamb=args.lamb)
                    self.devSummary(
                        summary_step, {
                            'loss': dev_loss[0],
                            'perplexity': np.exp(dev_loss[0])
                        })

                    # previous_losses stores decoder losses only, so compare
                    # the decoder loss; summing all of loss_step would add
                    # the kg loss and the kg accuracy to one side only.
                    train_loss = float(loss_step[0])
                    if train_loss > max(previous_losses):
                        sess.run(self.learning_rate_decay_op)
                    if dev_loss[0] < best_valid:
                        best_valid = dev_loss[0]
                        self.store_checkpoint(
                            sess, '%s/checkpoint_best/%s' %
                            (args.model_dir, args.name), "best", args.name)

                    previous_losses = previous_losses[1:] + [train_loss]
                    loss_step, time_step = np.zeros((3, )), .0

                start_time = time.time()
                step_out = self.step_decoder(sess,
                                             batched_data,
                                             lamb=args.lamb)
                # step_out: decoder_loss, gradient_norm, update, kg_loss, kg_acc
                loss_step += np.array([step_out[0], step_out[3], step_out[4]
                                       ]) / args.checkpoint_steps
                time_step += (time.time() - start_time) / args.checkpoint_steps
                batched_data = data.get_next_batch("train")

            data.restart("train", batch_size=args.batch_size, shuffle=True)
            batched_data = data.get_next_batch("train")

    def test_process_hits(self, sess, data, args):
        """Rank each true response against its distractors by decoder loss.

        For every test sample (batch size 1) the distractor responses from
        ``<args.datapath>/test_distractors.json`` are appended after the true
        response, the other inputs are repeated once per candidate, and the
        per-candidate decoder losses are sorted. Returns hits@1/3/5: the
        fraction of samples whose true response (candidate 0) is among the
        1/3/5 lowest losses.
        """
        with open(os.path.join(args.datapath, 'test_distractors.json'),
                  'r',
                  encoding='utf8') as f:
            test_distractors = json.load(f)

        repeated_keys = ('post', 'post_length', 'prev_length', 'kg',
                         'kg_h_length', 'kg_hr_length', 'kg_hrt_length',
                         'kg_index')

        data.restart("test", batch_size=1, shuffle=False)
        batched_data = data.get_next_batch("test")
        loss_record = []
        cnt = 0
        while batched_data is not None:
            # work on plain lists so candidates can be appended
            for key in batched_data:
                if isinstance(batched_data[key], np.ndarray):
                    batched_data[key] = batched_data[key].tolist()

            batched_data['resp_length'] = [len(batched_data['resp'][0])]
            for each_resp in test_distractors[cnt]:
                batched_data['resp'].append(
                    [data.go_id] +
                    data.convert_tokens_to_ids(jieba.lcut(each_resp)) +
                    [data.eos_id])
                batched_data['resp_length'].append(
                    len(batched_data['resp'][-1]))

            # zero-pad all candidates to the longest one
            max_length = max(batched_data['resp_length'])
            resp = np.zeros((len(batched_data['resp']), max_length), dtype=int)
            for i, each_resp in enumerate(batched_data['resp']):
                resp[i, :len(each_resp)] = each_resp
            batched_data['resp'] = resp

            # repeat the single sample's inputs once per candidate
            for key in repeated_keys:
                batched_data[key] = list(batched_data[key]) * len(resp)

            _, _, loss, _, _ = self.inference(sess,
                                              batched_data,
                                              lamb=args.lamb)
            loss_record.append(loss)
            cnt += 1
            batched_data = data.get_next_batch("test")

        assert cnt == len(test_distractors)

        loss = np.array(loss_record)
        # loss_rank[:, :k] holds the candidate indices with the k lowest
        # losses; the true response is candidate 0.
        loss_rank = np.argsort(loss, axis=1)
        hits1 = float(np.mean(loss_rank[:, 0] == 0))
        hits3 = float(np.mean(np.min(loss_rank[:, :3], axis=1) == 0))
        hits5 = float(np.mean(np.min(loss_rank[:, :5], axis=1) == 0))
        return {'hits@1': hits1, 'hits@3': hits3, 'hits@5': hits5}

    def test_process(self, sess, data, args):
        """Run the test split through both metrics and write a report.

        Float metrics (including hits@k) are printed and written, sorted by
        name, to ``<args.output_dir>/<args.name>_test.txt``, followed by every
        reference/generated pair. Returns the metric entries whose value type
        is exactly ``bytes``, ``int`` or ``float``.
        """
        metric1 = data.get_teacher_forcing_metric()
        metric2 = data.get_inference_metric()
        data.restart("test", batch_size=args.batch_size, shuffle=False)
        batched_data = data.get_next_batch("test")

        # Log up to two samples of the first batch (bounded by its size).
        if batched_data is not None:
            for i in range(min(2, len(batched_data['post_allvocabs']))):
                logger.info(
                    f"post@{i}: {data.convert_ids_to_tokens(batched_data['post_allvocabs'][i].tolist(), trim=False)}"
                )
                logger.info(
                    f"resp@{i}: {data.convert_ids_to_tokens(batched_data['resp_allvocabs'][i].tolist(), trim=False)}"
                )

        while batched_data is not None:
            batched_responses_id, gen_log_prob, _, _, _ = self.inference(
                sess, batched_data, lamb=args.lamb)
            metric1_data = {
                'resp_allvocabs': np.array(batched_data['resp_allvocabs']),
                'resp_length': np.array(batched_data['resp_length']),
                'gen_log_prob': np.array(gen_log_prob)
            }
            metric1.forward(metric1_data)

            batch_results = []
            for response_id in batched_responses_id:
                response_id_list = response_id.tolist()
                if data.eos_id in response_id_list:
                    # keep everything up to and including the first <eos>
                    end = response_id_list.index(data.eos_id) + 1
                    result_id = response_id_list[:end]
                else:
                    result_id = response_id_list
                batch_results.append(result_id)

            metric2_data = {
                # Truncation makes rows ragged; build the array explicitly.
                'gen': self._rows_to_array(batch_results),
                'resp_allvocabs': np.array(batched_data['resp_allvocabs'])
            }
            metric2.forward(metric2_data)
            batched_data = data.get_next_batch("test")

        res = metric1.close()
        res.update(metric2.close())
        res.update(self.test_process_hits(sess, data, args))

        test_file = args.output_dir + "/%s_%s.txt" % (args.name, "test")
        with open(test_file, 'w', encoding="utf-8") as f:
            print("Test Result:")
            res_print = sorted(res.items(), key=lambda x: x[0])
            for key, value in res_print:
                if isinstance(value, float):
                    print("\t%s:\t%f" % (key, value))
                    f.write("%s:\t%f\n" % (key, value))
            f.write('\n')
            for i, reference in enumerate(res['resp']):
                f.write("resp:\t%s\n" % " ".join(reference))
                f.write("gen:\t%s\n\n" % " ".join(res['gen'][i]))

        logger.info("result output to %s." % test_file)
        return {
            key: val
            for key, val in res.items() if type(val) in [bytes, int, float]
        }