import tensorflow as tf
import tensorflow.contrib.slim as slim

from tensorflow.contrib.rnn import LSTMStateTuple

tfd = tf.contrib.distributions

# NOTE: create_cell, decoding_infer, and decoding_train are helper functions
# defined elsewhere in this repository; the imports above are the assumed
# TF 1.x dependencies of this file.


def bow_predict_seq2seq(enc_seq2seq_inputs, enc_seq2seq_targets,
    enc_seq2seq_lens, embedding_matrix, enc_outputs, enc_state, enc_layers,
    num_paraphrase, max_len, enc_lens, batch_size, vocab_size, state_size,
    drop_out, dec_start_id):
  """Bag-of-words prediction, modeled as sequence to sequence."""
  enc_seq2seq_inputs = tf.nn.embedding_lookup(
    embedding_matrix, enc_seq2seq_inputs)
  # [B, P, T, S] -> [P, B, T, S]
  enc_seq2seq_inputs = tf.transpose(enc_seq2seq_inputs, [1, 0, 2, 3])
  # [B, P, T] -> [P, B, T]
  enc_seq2seq_targets = tf.transpose(enc_seq2seq_targets, [1, 0, 2])
  # [B, P] -> [P, B]
  enc_seq2seq_lens = tf.transpose(enc_seq2seq_lens, [1, 0])

  init_state = enc_state
  enc_pred_loss = 0.0
  bow_topk_prob = tf.zeros([batch_size, vocab_size])
  enc_infer_pred = []
  for i in range(num_paraphrase):
    # encoder prediction cell
    enc_pred_cell = [
      create_cell("enc_pred_p_%d_l_%d" % (i, j), state_size, drop_out)
      for j in range(enc_layers)]
    enc_pred_cell = tf.nn.rnn_cell.MultiRNNCell(enc_pred_cell)

    # projection to the vocabulary
    enc_pred_proj = tf.layers.Dense(vocab_size, name="enc_pred_proj",
      kernel_initializer=tf.random_normal_initializer(stddev=0.05),
      bias_initializer=tf.constant_initializer(0.))

    # greedy decoding for inference
    _, enc_seq_predict = decoding_infer(dec_start_id, enc_pred_cell,
      enc_pred_proj, embedding_matrix, init_state, enc_outputs, batch_size,
      max_len, enc_lens, max_len, is_attn=True)
    enc_infer_pred.append(enc_seq_predict)

    # teacher-forced decoding for training
    enc_pred_inputs = enc_seq2seq_inputs[i]
    enc_seq_train = decoding_train(enc_pred_inputs, enc_pred_cell,
      init_state, enc_outputs, max_len, enc_lens, max_len, is_attn=True)
    enc_seq_train_logits = enc_pred_proj(enc_seq_train)

    # sequence to sequence loss
    enc_seq_mask = tf.sequence_mask(
      enc_seq2seq_lens[i], max_len, dtype=tf.float32)
    enc_seq_loss = tf.contrib.seq2seq.sequence_loss(
      enc_seq_train_logits, enc_seq2seq_targets[i], enc_seq_mask)
    enc_pred_loss += enc_seq_loss

    # prediction probability, masked beyond the true sequence length
    enc_pred_prob = tf.nn.softmax(enc_seq_train_logits)     # [B, T, V]
    enc_pred_prob *= tf.expand_dims(enc_seq_mask, 2)        # mask: [B, T, 1]
    enc_pred_prob = tf.reduce_sum(enc_pred_prob, axis=1)    # [B, V]
    # NOTE: the probability of a word predicted at several time steps is
    # accumulated repeatedly
    bow_topk_prob += enc_pred_prob

  enc_pred_loss /= num_paraphrase
  enc_infer_pred = tf.stack(enc_infer_pred)                 # [P, B, T]
  enc_infer_pred = tf.transpose(enc_infer_pred, [1, 0, 2])  # [B, P, T]
  return bow_topk_prob, enc_pred_loss, enc_infer_pred
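
# A small usage sketch (not part of the original source): once
# bow_predict_seq2seq has accumulated `bow_topk_prob`, the k most probable
# bag-of-words candidates can be read off with tf.nn.top_k. The function
# name and the parameter k are hypothetical, for illustration only.
def bow_topk_sketch(bow_topk_prob, k):
  """Top-k word ids and (unnormalized) scores from the accumulated BOW."""
  topk_scores, topk_ids = tf.nn.top_k(bow_topk_prob, k=k)  # both [B, k]
  return topk_ids, topk_scores
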
def build(self):
  """Build the model."""
  print("Building the sequence to sequence model ... ")

  vocab_size = self.vocab_size
  state_size = self.state_size
  enc_layers = self.enc_layers

  # placeholders
  with tf.name_scope("placeholders"):
    enc_inputs = tf.placeholder(tf.int32, [None, None], "enc_inputs")
    inp_lens = tf.placeholder(tf.int32, [None], "inp_lens")
    self.drop_out = tf.placeholder(tf.float32, (), "drop_out")

    self.enc_inputs = enc_inputs
    self.inp_lens = inp_lens

    if (self.mode == "train"):
      dec_inputs = tf.placeholder(tf.int32, [None, None], "dec_inputs")
      targets = tf.placeholder(tf.int32, [None, None], "targets")
      out_lens = tf.placeholder(tf.int32, [None], "out_lens")
      self.learning_rate = tf.placeholder(tf.float32, (), "learning_rate")
      self.lambda_kl = tf.placeholder(tf.float32, (), "lambda_kl")

      self.dec_inputs = dec_inputs
      self.targets = targets
      self.out_lens = out_lens

  batch_size = tf.shape(enc_inputs)[0]
  max_len = tf.shape(enc_inputs)[1]

  # embedding
  with tf.variable_scope("embeddings"):
    embedding_matrix = tf.get_variable(
      name="embedding_matrix",
      shape=[vocab_size, state_size],
      dtype=tf.float32,
      initializer=tf.random_normal_initializer(stddev=0.05))

    enc_inputs = tf.nn.embedding_lookup(embedding_matrix, enc_inputs)
    if (self.mode == "train"):
      dec_inputs = tf.nn.embedding_lookup(embedding_matrix, dec_inputs)

  # encoder
  with tf.variable_scope("encoder"):
    # TODO: residual LSTM, layer normalization
    # if(self.bidirectional):
    #   enc_cell_fw = [create_cell(
    #     "enc-fw-%d" % i, state_size, self.drop_out, self.no_residual)
    #     for i in range(enc_layers)]
    #   enc_cell_bw = [create_cell(
    #     "enc-bw-%d" % i, state_size, self.drop_out, self.no_residual)
    #     for i in range(enc_layers)]
    # else:
    enc_cell = [
      create_cell("enc-%d" % i, state_size, self.drop_out, self.no_residual)
      for i in range(enc_layers)]
    enc_cell = tf.nn.rnn_cell.MultiRNNCell(enc_cell)
    enc_outputs, enc_state = tf.nn.dynamic_rnn(
      enc_cell, enc_inputs, sequence_length=inp_lens, dtype=tf.float32)

  # decoder
  with tf.variable_scope("decoder"):
    dec_cell = [
      create_cell("dec-%d" % i, state_size, self.drop_out, self.no_residual)
      for i in range(enc_layers)]
    dec_cell = tf.nn.rnn_cell.MultiRNNCell(dec_cell)
    dec_proj = tf.layers.Dense(vocab_size, name="dec_proj",
      kernel_initializer=tf.random_normal_initializer(stddev=0.05),
      bias_initializer=tf.constant_initializer(0.))

  # latent code
  if (self.vae):
    print("Using vae model")
    with tf.variable_scope("latent_code"):
      # mean-pool the encoder outputs over the valid time steps
      enc_mean = tf.reduce_sum(enc_outputs, 1)
      enc_mean /= tf.expand_dims(tf.cast(inp_lens, tf.float32), 1)
      z_code = enc_mean

      if (self.prior == "gaussian"):
        print("Gaussian prior")
        latent_proj = tf.layers.Dense(2 * state_size, name="latent_proj",
          kernel_initializer=tf.random_normal_initializer(stddev=0.05),
          bias_initializer=tf.constant_initializer(0.))
        # NOTE: z_scale is unconstrained here; a softplus is commonly
        # applied to keep the scale positive
        z_loc, z_scale = tf.split(
          latent_proj(z_code), [state_size, state_size], 1)
        z_mvn = tfd.MultivariateNormalDiag(z_loc, z_scale)
        z_sample = z_mvn.sample()
      elif (self.prior == "vmf"):
        # von Mises-Fisher prior, currently disabled:
        # print("vmf prior")
        # latent_proj = tf.layers.Dense(state_size + 1, name="latent_proj",
        #   kernel_initializer=tf.random_normal_initializer(stddev=0.05),
        #   bias_initializer=tf.constant_initializer(0.))
        # z_mu, z_conc = tf.split(latent_proj(z_code), [state_size, 1], 1)
        # z_mu /= tf.expand_dims(tf.norm(z_mu, axis=1), axis=1)
        # z_conc = tf.reshape(z_conc, [batch_size])
        # z_vmf = tfd.VonMisesFisher(z_mu, z_conc)
        # z_sample = z_vmf.sample()
        pass

      # NOTE: assumes a two-layer LSTM decoder
      dec_init_state = (LSTMStateTuple(c=z_sample, h=z_sample),
                        LSTMStateTuple(c=z_sample, h=z_sample))
  else:
    print("Using normal seq2seq, no latent variable")
latent variable") dec_init_state = enc_state with tf.variable_scope("decoding"): # greedy decoding _, dec_outputs_predict = decoding_infer(self.dec_start_id, dec_cell, dec_proj, embedding_matrix, dec_init_state, enc_outputs, batch_size, max_len, inp_lens, max_len, self.is_attn, self.sampling_method, self.topk_sampling_size, state_size=self.state_size) # decoding with forward sampling # dec_outputs_sampling = decodeing_infer() # TBC if (self.mode == "train"): # training decoding dec_logits_train, _, _, _, _ = decoding_train( dec_inputs, dec_cell, dec_proj, dec_init_state, enc_outputs, max_len, inp_lens, max_len, self.is_attn, self.state_size) all_variables = slim.get_variables_to_restore() model_variables = [ var for var in all_variables if var.name.split("/")[0] == self.model_name ] print("%s model, variable list:" % self.model_name) for v in model_variables: print(" %s" % v.name) self.model_saver = tf.train.Saver(all_variables, max_to_keep=3) # loss and optimizer dec_mask = tf.sequence_mask(out_lens, max_len, dtype=tf.float32) dec_loss = tf.contrib.seq2seq.sequence_loss( dec_logits_train, targets, dec_mask) if (self.vae): if (self.prior == "gaussian"): standard_normal = tfd.MultivariateNormalDiag( tf.zeros(state_size), tf.ones(state_size)) prior_prob = standard_normal.log_prob(z_sample) # [B] posterior_prob = z_mvn.log_prob(z_sample) # [B] kl_loss = tf.reduce_mean(posterior_prob - prior_prob) loss = dec_loss + self.lambda_kl * kl_loss elif (self.prior == "vmf"): # vmf_mu_0 = tf.ones(state_size) / tf.cast(state_size, tf.float32) # standard_vmf = tfd.VonMisesFisher(vmf_mu_0, 0) # prior_prob = standard_vmf.log_prob(z_sample) # [B] # posterior_prob = z_vmf.log_prob(z_sample) # [B] # kl_loss = tf.reduce_mean(posterior_prob - prior_prob) # loss = dec_loss + self.lambda_kl * kl_loss pass else: loss = dec_loss optimizer = tf.train.AdamOptimizer(self.learning_rate) train_op = optimizer.minimize(loss) self.train_output = {"train_op": train_op, "loss": loss} self.train_output.update(self.inspect) if (self.vae): self.train_output["dec_loss"] = dec_loss self.train_output["kl_loss"] = kl_loss self.valid_output = {"nll": tf.exp(loss)} self.infer_output = {"dec_predict": dec_outputs_predict} else: self.infer_output = {"dec_predict": dec_outputs_predict} return