def next_inputs_fn(time, outputs, state, sample_ids):
    """Build the decoder's next input from its current output.

    The next input is the mixture-weighted mean of the predicted component
    means (no sampling), concatenated with a greedy (argmax) one-hot pen
    state.

    NOTE(review): the signature and the ``(finished, next_inputs,
    next_state)`` return triple look like a ``CustomHelper``-style
    ``next_inputs_fn``. ``gmm`` and ``max_seq_len`` are resolved from the
    enclosing scope — confirm against the caller.

    Args:
        time: current decoding step; compared against ``max_seq_len``.
        outputs: raw decoder outputs, decoded by ``gmm.get_mixture_coef``.
        state: decoder state; returned unchanged.
        sample_ids: unused.

    Returns:
        A tuple ``(elements_finished, next_inputs, next_state)``:
        - ``elements_finished``: boolean tensor, True where the argmax pen
          class is the last of the 3 classes, or everywhere once
          ``time >= max_seq_len``.
        - ``next_inputs``: ``[x1, x2, one_hot(argmax(pen))]``, 5 columns.
        - ``next_state``: ``state``, unchanged.
    """
    # Only the mixture weights, the means and the pen probabilities are needed.
    pi, mu1, mu2, _, _, _, pen, _ = gmm.get_mixture_coef(outputs)

    # Greedy pen state: one-hot of the most likely of the 3 pen classes.
    idx_eos = tf.argmax(pen, axis=1)
    eos = tf.one_hot(idx_eos, depth=3)

    # Deterministic next point: mixture-weighted mean of the component means.
    next_x1 = tf.reduce_sum(tf.multiply(mu1, pi), axis=1, keepdims=True)
    next_x2 = tf.reduce_sum(tf.multiply(mu2, pi), axis=1, keepdims=True)
    next_x = tf.concat([next_x1, next_x2], axis=1)
    next_inputs = tf.concat([next_x, eos], axis=1)  # shape: (batch_size, 5)

    # A sequence is finished when the greedy pen class is the last one
    # (eos[:, -1] == 1). Comparing against a scalar avoids building
    # tf.ones([next_x.shape[0]]), which needs a statically known batch size
    # and fails when that dimension is None.
    elements_finished_1 = tf.equal(eos[:, -1], 1.0)  # boolean, [batch_size]
    elements_finished_2 = time >= max_seq_len  # scalar; broadcast below
    elements_finished = tf.logical_or(elements_finished_1, elements_finished_2)

    next_state = state
    return elements_finished, next_inputs, next_state
def build_model(self, reuse=tf.AUTO_REUSE):
    """Define model architecture.

    Builds the TF1 graph for the model:
    - input placeholders;
    - the encoder;
    - the decoder, which also returns an attention history;
    - Gaussian-mixture output coefficients from ``gmm.get_mixture_coef``;
    - the losses from ``gmm.get_loss``.

    When ``self.hps.is_training`` is true it also builds an Adam optimizer
    with element-wise gradient clipping and scalar summaries.

    NOTE: graph ops are created in statement order and the placeholders are
    unnamed, so their auto-generated op names depend on this ordering.

    Args:
        reuse: variable-scope reuse flag, forwarded to ``self.encoder`` and
            to the ``"optimizer"`` variable scope (not to ``self.decoder``).

    Attributes set on ``self``:
        enc_seq_lens, dec_seq_lens: int32 placeholders of shape [batch_size].
        enc_input_data, dec_input_data: float32 placeholders of shape
            [batch_size, max_seq_len + 1, 5].
        pi, mu1, mu2, sigma1, sigma2, corr, pen_logits, pen: mixture
            coefficients computed on the flattened decoder output.
        Ld, Lc, Loss: loss terms, with ``Loss = Ld + lc_weight * Lc``.
        lr, train_op, summ: only when ``self.hps.is_training`` is true.
    """
    self.enc_seq_lens = tf.placeholder(
        dtype=tf.int32, shape=[self.hps.batch_size])  # encoder actual input data length
    self.dec_seq_lens = tf.placeholder(
        dtype=tf.int32, shape=[self.hps.batch_size])  # decoder actual input data length

    # input of encoder, reference data. We insert (0, 0, 1, 0, 0) at timestep_0, so "max_seq_len + 1"
    self.enc_input_data = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])
    # input of decoder, target data
    self.dec_input_data = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

    # encoding (timestep_0 is skipped for the encoder)
    enc_input_x = self.enc_input_data[:, 1:self.hps.max_seq_len + 1, :]  # R_1 ~ R_{max_seq_len}
    enc_all_h, enc_last_h = self.encoder(enc_input_x, self.enc_seq_lens, reuse=reuse)

    # decoding: inputs are steps 0..max_seq_len-1; targets (below) are the
    # same data shifted by one step
    dec_input = self.dec_input_data[:, :self.hps.max_seq_len, :]  # T0 ~ T_{max_seq_len-1}
    dec_out, timemajor_attn_hist = self.decoder(enc_last_h, enc_all_h,
                                                self.enc_seq_lens, dec_input)
    # NOTE(review): batch_major_attn_hist is neither used nor stored on self
    # in the visible code — confirm whether it is meant to be kept.
    batch_major_attn_hist = tf.transpose(timemajor_attn_hist, perm=[1, 0, 2])

    n_out = (3 + self.hps.num_mixture * 6)  # decoder output dimension
    dec_out = tf.reshape(
        dec_out, [-1, n_out])  # shape = (batch_size * max_seq_len, n_out)

    # shape of first 6 tensors: (batch_size * max_seq_len, num_mixture),
    # shape of last 2 tensors: (batch_size * max_seq_len, 3)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
     o_pen_logits] = gmm.get_mixture_coef(dec_out)

    self.pi = o_pi
    self.mu1 = o_mu1
    self.mu2 = o_mu2
    self.sigma1 = o_sigma1
    self.sigma2 = o_sigma2
    self.corr = o_corr
    self.pen_logits = o_pen_logits
    self.pen = o_pen

    # reshape target data so that it is compatible with prediction shape
    target = tf.reshape(self.dec_input_data[:, 1:self.hps.max_seq_len + 1, :],
                        [-1, 5])  # (batch_size * max_seq_len, 5)
    # columns 0-1 are the coordinates, columns 2-4 the 3-way pen state
    [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
    pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

    # Seq(16) and Seq(17) in paper
    Ld, Lc = gmm.get_loss(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                          o_pen, o_pen_logits, x1_data, x2_data, pen_data)
    # Ld: summed, then divided by the total of the decoder sequence lengths;
    # Lc: plain mean over all elements.
    self.Ld = tf.reduce_sum(Ld) / tf.to_float(
        tf.reduce_sum(self.dec_seq_lens))
    self.Lc = tf.reduce_mean(Lc)
    self.Loss = self.Ld + self.hps.lc_weight * self.Lc  # Seq(19) in paper

    if self.hps.is_training:
        with tf.variable_scope("optimizer", reuse=reuse):
            # non-trainable so the learning rate can be changed externally
            self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
            optimizer = tf.train.AdamOptimizer(self.lr)
            gvs = optimizer.compute_gradients(self.Loss)
            g = self.hps.grad_clip
            # element-wise clip to [-g, g]; variables without a gradient are dropped
            capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                          for grad, var in gvs if grad is not None]
            self.train_op = optimizer.apply_gradients(capped_gvs)

        with tf.name_scope("summary"):
            Loss_summ = tf.summary.scalar("Loss", self.Loss)
            Ld_summ = tf.summary.scalar("Ld", self.Ld)
            Lc_summ = tf.summary.scalar("Lc", self.Lc)
            lr_summ = tf.summary.scalar("lr", self.lr)
            self.summ = tf.summary.merge(
                [Loss_summ, Ld_summ, Lc_summ, lr_summ])
    else:
        # inference graphs must be built without RNN dropout
        assert self.hps.rnn_dropout_keep_prob == 1.0