def encoder(self, batch, sequence_lengths):
    tf.logging.info('model - encoder section')
    unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        self.enc_cell_fw,
        self.enc_cell_bw,
        batch,
        sequence_length=sequence_lengths,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='ENC_RNN')

    last_state_fw, last_state_bw = last_states
    last_h_fw = self.enc_cell_fw.get_output(last_state_fw)  # final hidden output of the forward RNN
    last_h_bw = self.enc_cell_bw.get_output(last_state_bw)  # final hidden output of the backward RNN
    last_h = tf.concat([last_h_fw, last_h_bw], 1)  # final hidden state h of the bi-directional RNN
    # ===============================================================================
    mu = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
        scope='ENC_RNN_mu',
        init_w='gaussian',
        weight_start=0.001)
    presig = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2,  # bi-dir, so x2
        scope='ENC_RNN_sigma',
        init_w='gaussian',
        weight_start=0.001)
    return mu, presig  # the encoder's final z_size-dimensional (128) outputs
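# `rnn.super_linear` is called throughout these snippets but its definition is not
# included here. The following is a minimal, hypothetical stand-in (the name
# `super_linear_sketch` and the initializer details are assumptions, not the
# original rnn.py implementation); it only mirrors the keyword arguments used above.
import tensorflow as tf

def super_linear_sketch(x, output_size, input_size=None, scope='linear',
                        init_w='gaussian', weight_start=0.001, use_bias=True):
    """Plain fully-connected layer: y = x W (+ b)."""
    if input_size is None:
        input_size = int(x.get_shape()[-1])
    with tf.variable_scope(scope):
        if init_w == 'gaussian':
            initializer = tf.random_normal_initializer(stddev=weight_start)
        else:
            initializer = tf.constant_initializer(weight_start)
        w = tf.get_variable('super_linear_w', [input_size, output_size],
                            tf.float32, initializer=initializer)
        out = tf.matmul(x, w)
        if use_bias:
            b = tf.get_variable('super_linear_b', [output_size], tf.float32,
                                initializer=tf.constant_initializer(0.0))
            out = out + b
        return out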
def cnn_encoder(self, img_x):
    with tf.variable_scope('ENC_CNN'):
        # high-pass filter
        x = self.high_pass_filtering(img_x)  # [N, 48, 48, 1]
        # 6 conv layers
        x = self.conv_2d('conv1', x, filter_size=2, out_filters=4,
                         strides=self.stride_arr(2))  # [N, 24, 24, 4]
        x = tf.nn.relu(x)
        x = self.conv_2d('conv2', x, filter_size=2, out_filters=4,
                         strides=self.stride_arr(1))  # [N, 24, 24, 4]
        x = tf.nn.relu(x)
        x = self.conv_2d('conv3', x, filter_size=2, out_filters=8,
                         strides=self.stride_arr(2))  # [N, 12, 12, 8]
        x = tf.nn.relu(x)
        x = self.conv_2d('conv4', x, filter_size=2, out_filters=8,
                         strides=self.stride_arr(1))  # [N, 12, 12, 8]
        x = tf.nn.relu(x)
        x = self.conv_2d('conv5', x, filter_size=2, out_filters=8,
                         strides=self.stride_arr(2))  # [N, 6, 6, 8]
        x = tf.nn.relu(x)
        x = self.conv_2d('conv6', x, filter_size=2, out_filters=8,
                         strides=self.stride_arr(1))  # [N, 6, 6, 8]
        x = tf.tanh(x)

        x = tf.reshape(x, shape=[x.shape[0], -1])  # [N, 6 * 6 * 8]

        mu = rnn.super_linear(
            x,
            self.hps.z_size,
            scope='ENC_CNN_mu',
            init_w='gaussian',
            weight_start=0.001)
        presig = rnn.super_linear(
            x,
            self.hps.z_size,
            scope='ENC_CNN_sigma',
            init_w='gaussian',
            weight_start=0.001)
        return mu, presig
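# The helpers `stride_arr`, `conv_2d` and `high_pass_filtering` used above are not
# part of this snippet. A plausible sketch is given below (assumes `tf` is imported
# as in the surrounding code); the kernel initializer and the Laplacian high-pass
# kernel are assumptions, not the original implementation.
def stride_arr(self, stride):
    # NHWC stride vector for tf.nn.conv2d.
    return [1, stride, stride, 1]

def conv_2d(self, name, x, filter_size, out_filters, strides):
    in_filters = int(x.get_shape()[-1])
    with tf.variable_scope(name):
        kernel = tf.get_variable(
            'DW', [filter_size, filter_size, in_filters, out_filters],
            tf.float32, initializer=tf.random_normal_initializer(stddev=0.01))
        return tf.nn.conv2d(x, kernel, strides, padding='SAME')

def high_pass_filtering(self, x):
    # Fixed 3x3 Laplacian kernel as a simple high-pass filter (assumed choice).
    kernel = tf.reshape(
        tf.constant([[0., -1., 0.], [-1., 4., -1.], [0., -1., 0.]]), [3, 3, 1, 1])
    return tf.nn.conv2d(x, kernel, [1, 1, 1, 1], padding='SAME')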
def encoder(self, batch, sequence_lengths): """Define the bi-directional encoder module of sketch-rnn.""" unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn( self.enc_cell_fw, self.enc_cell_bw, batch, sequence_length=sequence_lengths, time_major=False, swap_memory=True, dtype=tf.float32, scope='ENC_RNN') last_state_fw, last_state_bw = last_states last_h_fw = self.enc_cell_fw.get_output(last_state_fw) last_h_bw = self.enc_cell_bw.get_output(last_state_bw) last_h = tf.concat([last_h_fw, last_h_bw], 1) mu = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_mu', init_w='gaussian', weight_start=0.001) presig = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_sigma', init_w='gaussian', weight_start=0.001) return mu, presig
def latent_space(self, reuse):
    with tf.variable_scope("latent_space", reuse=reuse):
        mu = rnn.super_linear(
            self.last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2 + self.expr_num,  # bi-dir, so x2
            scope='ENC_RNN_mu',
            init_w='gaussian',
            weight_start=0.001)
        presig = rnn.super_linear(
            self.last_h,
            self.hps.z_size,
            input_size=self.hps.enc_rnn_size * 2 + self.expr_num,  # bi-dir, so x2
            scope='ENC_RNN_sigma',
            init_w='gaussian',
            weight_start=0.001)
        self.mean, self.presig = mu, presig
        self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
        eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                               0.0, 1.0, dtype=tf.float32)
        self.batch_z = self.mean + tf.multiply(self.sigma, eps)
def encoder(self, batch, sequence_lengths): """Define the encoder module of sketch-rnn.""" mu = 0 presig = 0 # adding CNN option if self.hps.enc_CNN: last_h = cnn.cnn_model(batch) mu = rnn.super_linear(last_h, self.hps.z_size, scope='ENC_CNN_mu', init_w='gaussian', weight_start=0.001) presig = rnn.super_linear(last_h, self.hps.z_size, scope='ENC_CNN_sigma', init_w='gaussian', weight_start=0.001) # bi-directional RNN else: unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn( #returns forward and backward tensors self.enc_cell_fw, self.enc_cell_bw, batch, sequence_length=sequence_lengths, time_major=False, swap_memory=True, dtype=tf.float32, scope='ENC_RNN') last_state_fw, last_state_bw = last_states #split the tuple #gets outputs of the final states last_h_fw = self.enc_cell_fw.get_output(last_state_fw) last_h_bw = self.enc_cell_bw.get_output(last_state_bw) last_h = tf.concat([last_h_fw, last_h_bw], 1) mu = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_mu', init_w='gaussian', weight_start=0.001) presig = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_sigma', init_w='gaussian', weight_start=0.001) return mu, presig
def get_mu_sig(self, encoded_h):
    input_size = int(encoded_h.shape[-1])
    mu = rnn.super_linear(encoded_h,
                          self.hps.z_size,
                          input_size=input_size,
                          scope='ENC_mu',
                          init_w='gaussian',
                          weight_start=0.001)
    presig = rnn.super_linear(encoded_h,
                              self.hps.z_size,
                              input_size=input_size,
                              scope='ENC_sigma',
                              init_w='gaussian',
                              weight_start=0.001)
    return mu, presig
def cnn_encoder(self, goal_batch):
    image_input = goal_batch  # tf.stack([current_batch, goal_batch], axis=3)
    flattened_conv = tf.layers.Flatten()(conv_model.make_model(image_input))
    mu = rnn.super_linear(flattened_conv,
                          self.hps.z_size,
                          scope='ENC_RNN_mu',
                          init_w='gaussian',
                          weight_start=0.001)
    presig = rnn.super_linear(flattened_conv,
                              self.hps.z_size,
                              scope='ENC_RNN_sigma',
                              init_w='gaussian',
                              weight_start=0.001)
    return mu, presig
def encoder(self, batch):
    z = rnn.super_linear(batch,
                         self.hps.z_size,
                         input_size=4096,
                         scope='latent_z',
                         init_w='gaussian',
                         weight_start=0.001)
    return z
def encoder(self, batch, label, sequence_lengths, conditioned=True):
    """Define the bi-directional encoder module of sketch-rnn."""
    unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
        self.enc_cell_fw,
        self.enc_cell_bw,
        batch,
        sequence_length=sequence_lengths,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='ENC_RNN')

    last_state_fw, last_state_bw = last_states
    last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
    last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
    if not conditioned:
        last_h_fw = tf.zeros_like(last_h_fw)
        last_h_bw = tf.zeros_like(last_h_bw)
    label_embedding = tf.nn.embedding_lookup(
        self.input_label_embedding_matrix, label)
    last_h = tf.concat([last_h_fw, last_h_bw, label_embedding], 1)
    mu = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2 + self.hps.label_embedding_dim,  # bi-dir, so x2
        scope='ENC_RNN_mu',
        init_w='gaussian',
        weight_start=0.001)
    presig = rnn.super_linear(
        last_h,
        self.hps.z_size,
        input_size=self.hps.enc_rnn_size * 2 + self.hps.label_embedding_dim,  # bi-dir, so x2
        scope='ENC_RNN_sigma',
        init_w='gaussian',
        weight_start=0.001)
    return mu, presig
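# `self.input_label_embedding_matrix` is created elsewhere. A plausible definition is
# a trainable lookup table with one row per class label; the `num_classes`
# hyperparameter name below is an assumption, not taken from the original code.
def build_label_embedding_matrix(self):
    self.input_label_embedding_matrix = tf.get_variable(
        'input_label_embedding_matrix',
        [self.hps.num_classes, self.hps.label_embedding_dim],
        dtype=tf.float32,
        initializer=tf.random_uniform_initializer(-1.0, 1.0))
    # tf.nn.embedding_lookup(matrix, label) then maps integer labels of shape
    # [batch_size] to embeddings of shape [batch_size, label_embedding_dim].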
def get_decoder_inputs(self, encoded_h, is_seq=True, name_scope=None):
    mean, presig = self.get_mu_sig(encoded_h)
    sigma = tf.exp(presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
    eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                           0.0, 1.0, dtype=tf.float32)
    batch_z = mean + tf.multiply(sigma, eps)  # [N, z_size]
    if not is_seq:
        return batch_z, mean, presig

    pre_tile_y = tf.reshape(batch_z,
                            [self.hps.batch_size, 1, self.hps.z_size])
    overlay_x = tf.tile(pre_tile_y,
                        [1, self.hps.max_seq_len, 1])  # [N, max_seq_len, z_size]
    actual_input_x = tf.concat([self.input_x, overlay_x], 2)

    initial_state = tf.nn.tanh(
        rnn.super_linear(batch_z,
                         self.dec_cell.state_size,
                         init_w='gaussian',
                         weight_start=0.001,
                         input_size=self.hps.z_size))

    if name_scope == 'p2s':
        # print('p2s seq decoder')
        self.initial_state_p2s = initial_state
        return batch_z, self.initial_state_p2s, actual_input_x, mean, presig
    elif name_scope == 's2s':
        # print('s2s seq decoder')
        self.initial_state_s2s = initial_state
        return batch_z, self.initial_state_s2s, actual_input_x, mean, presig
    else:
        raise Exception('Unknown name_scope', name_scope)
def build_model(self, hps):
    if hps.is_training:
        self.global_step = tf.Variable(0, name='global_step', trainable=False)

    if hps.dec_model == 'lstm':
        cell_fn = rnn.LSTMCell
    elif hps.dec_model == 'layer_norm':
        cell_fn = rnn.LayerNormLSTMCell
    elif hps.dec_model == 'hyper':
        cell_fn = rnn.HyperLSTMCell
    else:
        assert False, 'please choose a respectable cell'

    if hps.enc_model == 'lstm':
        enc_cell_fn = rnn.LSTMCell  # use an LSTM as the encoder cell
    elif hps.enc_model == 'layer_norm':
        enc_cell_fn = rnn.LayerNormLSTMCell
    elif hps.enc_model == 'hyper':
        enc_cell_fn = rnn.HyperLSTMCell
    else:
        assert False, 'please choose a respectable cell'

    # dropout settings
    use_recurrent_dropout = self.hps.use_recurrent_dropout
    use_input_dropout = self.hps.use_input_dropout
    use_output_dropout = self.hps.use_output_dropout

    cell = cell_fn(
        hps.dec_rnn_size,  # 512
        use_recurrent_dropout=use_recurrent_dropout,
        dropout_keep_prob=self.hps.recurrent_dropout_prob)

    if hps.conditional:
        if hps.enc_model == 'hyper':
            self.enc_cell_fw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
            self.enc_cell_bw = enc_cell_fn(
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
        else:  # enc_model = lstm
            self.enc_cell_fw = enc_cell_fn(  # forward cell of the bi-directional RNN
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)
            self.enc_cell_bw = enc_cell_fn(  # backward cell of the bi-directional RNN
                hps.enc_rnn_size,
                use_recurrent_dropout=use_recurrent_dropout,
                dropout_keep_prob=self.hps.recurrent_dropout_prob)

    # dropout:
    tf.logging.info('Input dropout mode = %s.', use_input_dropout)  # False
    tf.logging.info('Output dropout mode = %s.', use_output_dropout)  # False
    tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)  # True
    if use_input_dropout:
        tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                        self.hps.input_dropout_prob)
        cell = tf.contrib.rnn.DropoutWrapper(
            cell, input_keep_prob=self.hps.input_dropout_prob)
    if use_output_dropout:
        tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                        self.hps.output_dropout_prob)
        cell = tf.contrib.rnn.DropoutWrapper(
            cell, output_keep_prob=self.hps.output_dropout_prob)

    # ============================== decode !!! ============================
    tf.logging.info('model - decoder section')
    self.cell = cell

    self.sequence_lengths = tf.placeholder(
        dtype=tf.int32, shape=[self.hps.batch_size])  # batch size
    # decoder inputs, including the start-of-sequence token
    self.source_input = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])
    self.target_input = tf.placeholder(
        dtype=tf.float32,
        shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

    tf.logging.info('model - encoder inputs')
    self.encoder_input_x = self.source_input[:, 1:self.hps.max_seq_len + 1, :]
    tf.logging.info('model - decoder inputs and targets')
    self.output_x = self.target_input[:, 1:self.hps.max_seq_len + 1, :]
    self.decoder_input_source_x = self.source_input[:, :self.hps.max_seq_len, :]
    self.decoder_input_target_x = self.target_input[:, :self.hps.max_seq_len, :]

    # if conditional is true, feed the latent variable z into the decoder input
    if hps.conditional:
        self.mean, self.presig = self.encoder(self.encoder_input_x,
                                              self.sequence_lengths)
        self.sigma = tf.exp(self.presig / 2.0)
        eps = tf.random_normal((self.hps.batch_size, self.hps.z_size),
                               0.0, 1.0, dtype=tf.float32)
        self.batch_z = self.mean + tf.multiply(self.sigma, eps)
        # self.kl_cost = -0.5 * tf.reduce_mean(
        #     (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
        # self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)

        # feed in the data to obtain the latent variable h
        pre_tile_y = tf.reshape(self.batch_z,
                                [self.hps.batch_size, 1, self.hps.z_size])
        overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
        part_input_x = tf.concat([self.decoder_input_target_x, overlay_x], 2)
        # at each time step the decoder input is the source and target inputs
        # concatenated along the feature (column) axis
        actual_input_x = tf.concat(
            [self.decoder_input_source_x, part_input_x], 2)
        self.initial_state = tf.nn.tanh(
            rnn.super_linear(self.batch_z,
                             cell.state_size,
                             init_w='gaussian',
                             weight_start=0.001,
                             input_size=self.hps.z_size))
    # unconditional, decoder-only generation
    else:
        self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                dtype=tf.float32)
        actual_input_x = self.input_x
        self.initial_state = cell.zero_state(batch_size=hps.batch_size,
                                             dtype=tf.float32)

    tf.logging.info(
        'model - start Gaussian mixture model sampling ======================================')
    self.num_mixture = hps.num_mixture  # number of Gaussians in the mixture model (20)
    # the decoder output y has dimension 5M + M + 3 = 6M + 3: the Gaussian
    # parameters, the mixture weights, and (p1, p2, p3)
    n_out = (3 + self.num_mixture * 6)

    with tf.variable_scope('RNN'):
        output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
        output_b = tf.get_variable('output_b', [n_out])

    # decoder module of sketch-rnn is below
    output, last_state = tf.nn.dynamic_rnn(
        cell,
        actual_input_x,
        initial_state=self.initial_state,
        time_major=False,
        swap_memory=True,
        dtype=tf.float32,
        scope='RNN')

    output = tf.reshape(output, [-1, hps.dec_rnn_size])
    output = tf.nn.xw_plus_b(output, output_w, output_b)  # fully-connected layer with n_out units
    self.final_state = last_state

    # x1 and x2 are the offsets along the x and y axes; result is the
    # probability density of a bivariate normal distribution
    def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
        tf.logging.info(
            'model - compute the bivariate normal density using the formula from the paper')
        norm1 = tf.subtract(x1, mu1)
        norm2 = tf.subtract(x2, mu2)
        s1s2 = tf.multiply(s1, s2)
        z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
             2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
        neg_rho = 1 - tf.square(rho)
        result = tf.exp(tf.div(-z, 2 * neg_rho))
        denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
        result = tf.div(result, denom)
        return result

    # reconstruction loss: the negative log of the density above
    def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                     z_pen_logits, x1_data, x2_data, pen_data):
        tf.logging.info('model - computing the reconstruction loss')
        # loss for the sampled (x, y) offsets
        result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1,
                               z_sigma2, z_corr)
        epsilon = 1e-6
        result1 = tf.multiply(result0, z_pi)  # weight each Gaussian by its mixture weight
        result1 = tf.reduce_sum(result1, 1, keep_dims=True)  # sum over all components
        result1 = -tf.log(result1 + epsilon)  # avoid log(0)
        # if the last pen bit is 1 the sketch has ended, so the mask is 0; otherwise 1
        fs = 1.0 - pen_data[:, 2]
        fs = tf.reshape(fs, [-1, 1])  # reshape the [batch] vector to shape [batch, 1]
        result1 = tf.multiply(result1, fs)

        # result2: loss wrt pen state, (L_p in equation 9)
        result2 = tf.nn.softmax_cross_entropy_with_logits(  # L_p is the cross-entropy
            labels=pen_data, logits=z_pen_logits)
        result2 = tf.reshape(result2, [-1, 1])
        if not self.hps.is_training:  # eval mode, mask eos columns
            result2 = tf.multiply(result2, fs)

        result = result1 + result2
        return result

    # below is where we need to do MDN (Mixture Density Network) splitting of
    # distribution params
    def get_mixture_coef(output):
        """Returns the tf slices containing mdn dist params."""
        tf.logging.info(
            'model - split the decoder output into the Gaussian mixture parameters')
        # build the mixture model from the 6M + 3 values output by the decoder
        z = output
        z_pen_logits = z[:, 0:3]  # pen states: the first three values of z are (p1, p2, p3)
        # the remaining six groups are the weights pi, mu(x), mu(y), sigma1, sigma2, corr
        z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)

        # process output z's into MDN parameters

        # softmax all the pi's and pen states:
        z_pi = tf.nn.softmax(z_pi)
        # softmax (p1, p2, p3) and pi so that all values are positive and sum to one
        z_pen = tf.nn.softmax(z_pen_logits)

        # exponentiate the sigmas and also make corr between -1 and 1.
        z_sigma1 = tf.exp(z_sigma1)
        z_sigma2 = tf.exp(z_sigma2)
        z_corr = tf.tanh(z_corr)

        r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
        return r

    tf.logging.info('model - call the output-splitting function')
    out = get_mixture_coef(output)
    [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

    self.pi = o_pi
    self.mu1 = o_mu1
    self.mu2 = o_mu2
    self.sigma1 = o_sigma1
    self.sigma2 = o_sigma2
    self.corr = o_corr
    self.pen_logits = o_pen_logits
    self.pen = o_pen

    # reconstruction loss
    target = tf.reshape(self.output_x, [-1, 5])  # target output (x, y, p1, p2, p3)
    [x1_data, x2_data, p1, p2, p3] = tf.split(target, 5, 1)
    pen_data = tf.concat([p1, p2, p3], 1)
    lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                            o_pen_logits, x1_data, x2_data, pen_data)
    self.r_cost = tf.reduce_mean(lossfunc)

    # identity loss
    sourcess = tf.reshape(self.encoder_input_x, [-1, 5])  # network input (x, y, p1, p2, p3)
    [s1_data, s2_data, sp1, sp2, sp3] = tf.split(sourcess, 5, 1)
    spen_data = tf.concat([sp1, sp2, sp3], 1)
    identyfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                              o_pen_logits, s1_data, s2_data, spen_data)
    self.il_cost = tf.reduce_mean(identyfunc)

    if self.hps.is_training:
        tf.logging.info('model - set the learning rate and loss function')
        self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
        optimizer = tf.train.AdamOptimizer(self.lr)  # use the Adam optimizer
        # self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False)
        self.il_weight = tf.Variable(self.hps.il_weight_start, trainable=False)
        self.cost = self.r_cost + self.il_cost * self.il_weight

        gvs = optimizer.compute_gradients(self.cost)
        g = self.hps.grad_clip  # 1.0
        capped_gvs = [(tf.clip_by_value(grad, -g, g), var)
                      for grad, var in gvs]
        self.train_op = optimizer.apply_gradients(
            capped_gvs, global_step=self.global_step, name='train_step')
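# In this variant the KL term is commented out, while other build_model variants
# below use the analytic Gaussian KL. The following NumPy snippet is an added
# sanity check of that closed form (not part of the original code); it assumes
# presig = log(sigma^2), matching sigma = exp(presig / 2) in the encoders above.
import numpy as np

def gaussian_kl(mean, presig):
    """KL(N(mean, exp(presig)) || N(0, 1)), averaged as in self.kl_cost."""
    return -0.5 * np.mean(1 + presig - np.square(mean) - np.exp(presig))

# A unit-Gaussian posterior should give (near) zero KL:
print(gaussian_kl(np.zeros((4, 128)), np.zeros((4, 128))))  # ~0.0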
def encoder(self, batch, sequence_lengths): """Define the bi-directional encoder module of sketch-rnn.""" unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn( self.enc_cell_fw, self.enc_cell_bw, batch, sequence_length=sequence_lengths, time_major=False, swap_memory=True, dtype=tf.float32, scope='ENC_RNN') last_state_fw, last_state_bw = last_states last_h_fw = self.enc_cell_fw.get_output(last_state_fw) last_h_bw = self.enc_cell_bw.get_output(last_state_bw) last_h = tf.concat([last_h_fw, last_h_bw], 1) mu = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_mu', init_w='gaussian', weight_start=0.001) presig = rnn.super_linear( last_h, self.hps.z_size, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_sigma', init_w='gaussian', weight_start=0.001) # Will most likely have to change the size. Maybe create an array of the # super linear terms? But either way will have to separate into u, b, w. u_array = [] w_array = [] b_array = [] for i in range(self.hps.num_flows): u = rnn.super_linear( last_h, # self.hps.num_flows * (self.hps.z_size * 2 + 1), (self.hps.z_size), input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_u_' + str(i), init_w='gaussian', weight_start=0.001) u_array.append(u) w = rnn.super_linear( last_h, # self.hps.num_flows * (self.hps.z_size * 2 + 1), (self.hps.z_size), input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_w_' + str(i), init_w='gaussian', weight_start=0.001) w_array.append(w) b = rnn.super_linear( last_h, # self.hps.num_flows * (self.hps.z_size * 2 + 1), 1, input_size=self.hps.enc_rnn_size * 2, # bi-dir, so x2 scope='ENC_RNN_b_' + str(i), init_w='gaussian', weight_start=0.001) b_array.append(b) return mu, presig, u_array, w_array, b_array
def build_model(self, hps): """Define model architecture.""" if hps.is_training: self.global_step = tf.Variable(0, name='global_step', trainable=False) if hps.dec_model == 'lstm': cell_fn = rnn.LSTMCell elif hps.dec_model == 'layer_norm': cell_fn = rnn.LayerNormLSTMCell elif hps.dec_model == 'hyper': cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' if hps.enc_model == 'lstm': enc_cell_fn = rnn.LSTMCell elif hps.enc_model == 'layer_norm': enc_cell_fn = rnn.LayerNormLSTMCell elif hps.enc_model == 'hyper': enc_cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' use_recurrent_dropout = self.hps.use_recurrent_dropout use_input_dropout = self.hps.use_input_dropout use_output_dropout = self.hps.use_output_dropout cell = cell_fn( hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) if hps.conditional: # vae mode: if hps.enc_model == 'hyper': self.enc_cell_fw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) self.enc_cell_bw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) else: self.enc_cell_fw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) self.enc_cell_bw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) # dropout: tf.logging.info('Input dropout mode = %s.', use_input_dropout) tf.logging.info('Output dropout mode = %s.', use_output_dropout) tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout) if use_input_dropout: tf.logging.info('Dropout to input w/ keep_prob = %4.4f.', self.hps.input_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, input_keep_prob=self.hps.input_dropout_prob) if use_output_dropout: tf.logging.info('Dropout to output w/ keep_prob = %4.4f.', self.hps.output_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, output_keep_prob=self.hps.output_dropout_prob) self.cell = cell self.sequence_lengths = tf.placeholder( dtype=tf.int32, shape=[self.hps.batch_size]) self.input_data = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5]) # The target/expected vectors of strokes self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :] # vectors of strokes to be fed to decoder (same as above, but lagged behind # one step to include initial dummy value of (0, 0, 1, 0, 0)) self.input_x = self.input_data[:, :self.hps.max_seq_len, :] # either do vae-bit and get z, or do unconditional, decoder-only if hps.conditional: # vae mode: self.mean, self.presig, self.us, self.ws, self.bs = self.encoder(self.output_x, self.sequence_lengths) self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt. eps = tf.random_normal( (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32) self.batch_z = self.mean + tf.multiply(self.sigma, eps) # Apply normalizing flow self.sum_log_det_jacobian = 0. 
h_prime = lambda x: 1 - tf.tanh(x) ** 2 # Derivative of nonlinearity h above for i in range(self.hps.num_flows): psi = h_prime(tf.expand_dims(tf.reduce_sum( self.batch_z * self.ws[i], -1), -1) + self.bs[i]) * self.ws[i] self.sum_log_det_jacobian += tf.log(tf.abs(1 + tf.reduce_sum(psi * self.us[i], -1))) self.batch_z = self.planar_flow(self.batch_z, self.ws[i], self.us[i], self.bs[i]) # KL cost self.kl_cost = -0.5 * tf.reduce_mean( (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig))) self.kl_cost -= tf.reduce_mean(self.sum_log_det_jacobian) self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance) pre_tile_y = tf.reshape(self.batch_z, [self.hps.batch_size, 1, self.hps.z_size]) overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1]) actual_input_x = tf.concat([self.input_x, overlay_x], 2) self.initial_state = tf.nn.tanh( rnn.super_linear( self.batch_z, cell.state_size, init_w='gaussian', weight_start=0.001, input_size=self.hps.z_size)) else: # unconditional, decoder-only generation self.batch_z = tf.zeros( (self.hps.batch_size, self.hps.z_size), dtype=tf.float32) self.kl_cost = tf.zeros([], dtype=tf.float32) actual_input_x = self.input_x self.initial_state = cell.zero_state( batch_size=hps.batch_size, dtype=tf.float32) self.num_mixture = hps.num_mixture # TODO(deck): Better understand this comment. # Number of outputs is 3 (one logit per pen state) plus 6 per mixture # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the # mixture weight/probability (Pi_k) n_out = (3 + self.num_mixture * 6) with tf.variable_scope('RNN'): output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out]) output_b = tf.get_variable('output_b', [n_out]) # decoder module of sketch-rnn is below output, last_state = tf.nn.dynamic_rnn( cell, actual_input_x, initial_state=self.initial_state, time_major=False, swap_memory=True, dtype=tf.float32, scope='RNN') output = tf.reshape(output, [-1, hps.dec_rnn_size]) output = tf.nn.xw_plus_b(output, output_w, output_b) self.final_state = last_state # NB: the below are inner functions, not methods of Model def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850.""" norm1 = tf.subtract(x1, mu1) norm2 = tf.subtract(x2, mu2) s1s2 = tf.multiply(s1, s2) # eq 25 z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) - 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2)) neg_rho = 1 - tf.square(rho) result = tf.exp(tf.div(-z, 2 * neg_rho)) denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho)) result = tf.div(result, denom) return result def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits, x1_data, x2_data, pen_data): """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850.""" # This represents the L_R only (i.e. does not include the KL loss term). 
result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr) epsilon = 1e-6 # result1 is the loss wrt pen offset (L_s in equation 9 of # https://arxiv.org/pdf/1704.03477.pdf) result1 = tf.multiply(result0, z_pi) result1 = tf.reduce_sum(result1, 1, keep_dims=True) result1 = -tf.log(result1 + epsilon) # avoid log(0) fs = 1.0 - pen_data[:, 2] # use training data for this fs = tf.reshape(fs, [-1, 1]) # Zero out loss terms beyond N_s, the last actual stroke result1 = tf.multiply(result1, fs) # result2: loss wrt pen state, (L_p in equation 9) result2 = tf.nn.softmax_cross_entropy_with_logits( labels=pen_data, logits=z_pen_logits) result2 = tf.reshape(result2, [-1, 1]) if not self.hps.is_training: # eval mode, mask eos columns result2 = tf.multiply(result2, fs) result = result1 + result2 return result # below is where we need to do MDN (Mixture Density Network) splitting of # distribution params def get_mixture_coef(output): """Returns the tf slices containing mdn dist params.""" # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850. z = output z_pen_logits = z[:, 0:3] # pen states z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1) # process output z's into MDN parameters # softmax all the pi's and pen states: z_pi = tf.nn.softmax(z_pi) z_pen = tf.nn.softmax(z_pen_logits) # exponentiate the sigmas and also make corr between -1 and 1. z_sigma1 = tf.exp(z_sigma1) z_sigma2 = tf.exp(z_sigma2) z_corr = tf.tanh(z_corr) r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits] return r out = get_mixture_coef(output) [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out self.pi = o_pi self.mu1 = o_mu1 self.mu2 = o_mu2 self.sigma1 = o_sigma1 self.sigma2 = o_sigma2 self.corr = o_corr self.pen_logits = o_pen_logits # pen state probabilities (result of applying softmax to self.pen_logits) self.pen = o_pen # reshape target data so that it is compatible with prediction shape target = tf.reshape(self.output_x, [-1, 5]) [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1) pen_data = tf.concat([eos_data, eoc_data, cont_data], 1) lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits, x1_data, x2_data, pen_data) self.r_cost = tf.reduce_mean(lossfunc) if self.hps.is_training: self.lr = tf.Variable(self.hps.learning_rate, trainable=False) optimizer = tf.train.AdamOptimizer(self.lr) self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False) self.cost = self.r_cost + self.kl_cost * self.kl_weight gvs = optimizer.compute_gradients(self.cost) g = self.hps.grad_clip capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs] self.train_op = optimizer.apply_gradients( capped_gvs, global_step=self.global_step, name='train_step')
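# A NumPy twin of tf_2d_normal above (eqs. 24-25 of arXiv:1308.0850), useful for
# spot-checking the graph output; this is an added cross-check, not original code.
import numpy as np

def np_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
    norm1, norm2 = x1 - mu1, x2 - mu2
    s1s2 = s1 * s2
    z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 - 2.0 * rho * norm1 * norm2 / s1s2
    neg_rho = 1.0 - rho ** 2
    return np.exp(-z / (2.0 * neg_rho)) / (2.0 * np.pi * s1s2 * np.sqrt(neg_rho))

# With zero correlation, unit sigmas, and x at the mean, the density is 1/(2*pi):
print(np_2d_normal(0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0))  # ~0.1592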
def build_model(self, hps): """Define model architecture.""" if hps.is_training: self.global_step = tf.Variable(0, name='global_step', trainable=False) if hps.dec_model == 'lstm': cell_fn = rnn.LSTMCell elif hps.dec_model == 'layer_norm': cell_fn = rnn.LayerNormLSTMCell elif hps.dec_model == 'hyper': cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' if hps.enc_model == 'lstm': enc_cell_fn = rnn.LSTMCell elif hps.enc_model == 'layer_norm': enc_cell_fn = rnn.LayerNormLSTMCell elif hps.enc_model == 'hyper': enc_cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' use_recurrent_dropout = self.hps.use_recurrent_dropout use_input_dropout = self.hps.use_input_dropout use_output_dropout = self.hps.use_output_dropout cell = cell_fn(hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) if hps.conditional: # vae mode: self.enc_cell_fw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) self.enc_cell_bw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) # dropout: tf.logging.info('Input dropout mode = %s.', use_input_dropout) tf.logging.info('Output dropout mode = %s.', use_output_dropout) tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout) if use_input_dropout: tf.logging.info('Dropout to input w/ keep_prob = %4.4f.', self.hps.input_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, input_keep_prob=self.hps.input_dropout_prob) if use_output_dropout: tf.logging.info('Dropout to output w/ keep_prob = %4.4f.', self.hps.output_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, output_keep_prob=self.hps.output_dropout_prob) self.cell = cell self.input_data = tf.placeholder(dtype=tf.float32, shape=[ self.hps.batch_size, self.hps.max_seq_len, self.hps.input_dimension_get ]) self.input_handle = self.input_data[:, :, :self.hps. input_dimension_real] # The target/expected vectors of strokes self.output_x = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size, self.hps.max_seq_len, 1]) # always 0 or 1 indicates one's action # vectors of strokes to be fed to decoder (same as above, but lagged behind # one step to include initial dummy value of (0, 0, 1, 0, 0)) self.input_x = self.input_handle # either do vae-bit and get z, or do unconditional, decoder-only if hps.conditional: # vae mode: _ = tf.concat([ self.input_handle[:, :self.hps.input_seq_len, :], self.output_x[:, :self.hps.input_seq_len, :] ], axis=2) self.mean = self.encoder(_) # self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt. 
# eps = tf.random_normal( # (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32) self.batch_z = self.mean # KL cost self.kl_cost = -0.5 * tf.reduce_mean((1 + -tf.square(self.mean))) self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance) pre_tile_y = tf.reshape(self.batch_z, [self.hps.batch_size, 1, self.hps.z_size]) overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1]) actual_input_x = tf.concat([self.input_x, overlay_x], 2) self.initial_state = tf.nn.tanh( rnn.super_linear(self.batch_z, cell.state_size, init_w='gaussian', weight_start=0.001, input_size=self.hps.z_size)) else: # unconditional, decoder-only generation self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size), dtype=tf.float32) self.kl_cost = tf.zeros([], dtype=tf.float32) actual_input_x = self.input_x self.initial_state = cell.zero_state(batch_size=hps.batch_size, dtype=tf.float32) # TODO(deck): Better understand this comment. # Number of outputs is 3 (one logit per pen state) plus 6 per mixture # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the # mixture weight/probability (Pi_k) # n_out = (3 + self.num_mixture * 6) n_out = 1 with tf.variable_scope('RNN'): output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out]) output_b = tf.get_variable('output_b', [n_out]) # decoder module of sketch-rnn is below output, last_state = tf.nn.dynamic_rnn( cell, actual_input_x, initial_state=self.initial_state, time_major=False, swap_memory=True, dtype=tf.float32, scope='RNN') output = tf.reshape(output, [-1, hps.dec_rnn_size]) output = tf.nn.xw_plus_b(output, output_w, output_b) self.final_state = last_state label = tf.reshape(self.output_x, [self.hps.batch_size, self.hps.max_seq_len]) out = tf.reshape(output, [self.hps.batch_size, self.hps.max_seq_len]) self.sigmoid_out = tf.sigmoid(out) # self.r_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=out[:,self.hps.input_seq_len:], labels=label[:,self.hps.input_seq_len:])) self.r_cost = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=out, labels=label)) if self.hps.is_training: self.lr = tf.Variable(self.hps.learning_rate, trainable=False) optimizer = tf.train.AdamOptimizer(self.lr) self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False) self.cost = self.r_cost + self.kl_cost * self.kl_weight gvs = optimizer.compute_gradients(self.cost) g = self.hps.grad_clip capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs] self.train_op = optimizer.apply_gradients( capped_gvs, global_step=self.global_step, name='train_step')
def build_model(self, hps): """Define model architecture.""" if hps.is_training: self.global_step = tf.Variable(0, name='global_step', trainable=False) if hps.dec_model == 'lstm': cell_fn = rnn.LSTMCell elif hps.dec_model == 'layer_norm': cell_fn = rnn.LayerNormLSTMCell elif hps.dec_model == 'hyper': cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' if hps.enc_model == 'lstm': enc_cell_fn = rnn.LSTMCell elif hps.enc_model == 'layer_norm': enc_cell_fn = rnn.LayerNormLSTMCell elif hps.enc_model == 'hyper': enc_cell_fn = rnn.HyperLSTMCell else: assert False, 'please choose a respectable cell' use_recurrent_dropout = False if self.hps.use_recurrent_dropout == 1: use_recurrent_dropout = True use_input_dropout = False if self.hps.use_input_dropout == 0 else True use_output_dropout = False if self.hps.use_output_dropout == 0 else True if hps.dec_model == 'hyper': cell = cell_fn(hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) else: cell = cell_fn(hps.dec_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) if hps.conditional: # vae mode: if hps.enc_model == 'hyper': self.enc_cell_fw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) self.enc_cell_bw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) else: self.enc_cell_fw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) self.enc_cell_bw = enc_cell_fn( hps.enc_rnn_size, use_recurrent_dropout=use_recurrent_dropout, dropout_keep_prob=self.hps.recurrent_dropout_prob) # dropout: tf.logging.info('Input dropout mode = %s.', use_input_dropout) tf.logging.info('Output dropout mode = %s.', use_output_dropout) tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout) if use_input_dropout: tf.logging.info('Dropout to input w/ keep_prob = %4.4f.', self.hps.input_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, input_keep_prob=self.hps.input_dropout_prob) if use_output_dropout: tf.logging.info('Dropout to output w/ keep_prob = %4.4f.', self.hps.output_dropout_prob) cell = tf.contrib.rnn.DropoutWrapper( cell, output_keep_prob=self.hps.output_dropout_prob) self.cell = cell self.sequence_lengths = tf.placeholder(dtype=tf.int32, shape=[self.hps.batch_size]) self.input_data = tf.placeholder( dtype=tf.float32, shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5]) self.input_x = self.input_data[:, :self.hps.max_seq_len, :] self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :] self.img_data = tf.placeholder( tf.float32, [self.hps.batch_size, self.hps.img_size, self.hps.img_size, 1]) # either do vae-bit and get z, or do unconditional, decoder-only if hps.conditional: # vae mode: self.mean, self.presig = self.cnn_encoder(self.img_data) self.sigma = tf.exp(self.presig / 2.0) # sigma > 0. div 2.0 -> sqrt. 
eps = tf.random_normal((self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32) self.batch_z = self.mean + tf.multiply(self.sigma, eps) # KL cost self.kl1 = -0.5 * tf.reduce_mean(self.presig) self.kl2 = -0.5 * tf.reduce_mean(-tf.square(self.mean)) self.kl3 = -0.5 * tf.reduce_mean(-tf.exp(self.presig)) self.kl_cost = -0.5 * tf.reduce_mean( (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig))) self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance) pre_tile_y = tf.reshape(self.batch_z, [self.hps.batch_size, 1, self.hps.z_size]) overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1]) actual_input_x = tf.concat([self.input_x, overlay_x], 2) self.initial_state = tf.nn.tanh( rnn.super_linear(self.batch_z, cell.state_size, init_w='gaussian', weight_start=0.001, input_size=self.hps.z_size)) else: # unconditional, decoder-only generation self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size), dtype=tf.float32) self.kl_cost = tf.zeros([], dtype=tf.float32) actual_input_x = self.input_x self.initial_state = cell.zero_state(batch_size=hps.batch_size, dtype=tf.float32) self.num_mixture = hps.num_mixture # TODO(deck): Better understand this comment. # Number of outputs is end_of_stroke + prob + 2*(mu + sig) + corr n_out = (3 + self.num_mixture * 6) with tf.variable_scope('RNN'): output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out]) output_b = tf.get_variable('output_b', [n_out]) # decoder module of sketch-rnn is below output, last_state = tf.nn.dynamic_rnn( cell, actual_input_x, initial_state=self.initial_state, time_major=False, swap_memory=True, dtype=tf.float32, scope='RNN') output = tf.reshape(output, [-1, hps.dec_rnn_size]) output = tf.nn.xw_plus_b(output, output_w, output_b) self.final_state = last_state def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho): """Returns result of eq # 24 and 25 of http://arxiv.org/abs/1308.0850.""" norm1 = tf.subtract(x1, mu1) norm2 = tf.subtract(x2, mu2) s1s2 = tf.multiply(s1, s2) z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) - 2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2)) neg_rho = 1 - tf.square(rho) result = tf.exp(tf.div(-z, 2 * neg_rho)) denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho)) result = tf.div(result, denom) return result def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen_logits, x1_data, x2_data, pen_data): """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850.""" result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr) epsilon = 1e-6 result1 = tf.multiply(result0, z_pi) result1 = tf.reduce_sum(result1, 1, keep_dims=True) result1 = -tf.log(result1 + epsilon) # avoid log(0) fs = 1.0 - pen_data[:, 2] # use training data for this fs = tf.reshape(fs, [-1, 1]) result1 = tf.multiply(result1, fs) result2 = tf.nn.softmax_cross_entropy_with_logits( labels=pen_data, logits=z_pen_logits) result2 = tf.reshape(result2, [-1, 1]) if not self.hps.is_training: # eval mode, mask eos columns result2 = tf.multiply(result2, fs) result = result1 + result2 return result # below is where we need to do MDN splitting of distribution params def get_mixture_coef(output): """Returns the tf slices containing mdn dist params.""" # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850. 
z = output z_pen_logits = z[:, 0:3] # pen states z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split( z[:, 3:], 6, 1) # process output z's into MDN paramters # softmax all the pi's and pen states: z_pi = tf.nn.softmax(z_pi) z_pen = tf.nn.softmax(z_pen_logits) # exponentiate the sigmas and also make corr between -1 and 1. z_sigma1 = tf.exp(z_sigma1) z_sigma2 = tf.exp(z_sigma2) z_corr = tf.tanh(z_corr) r = [ z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits ] return r out = get_mixture_coef(output) [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out self.pi = o_pi self.mu1 = o_mu1 self.mu2 = o_mu2 self.sigma1 = o_sigma1 self.sigma2 = o_sigma2 self.corr = o_corr self.pen = o_pen # state of the pen self.pen_logits = o_pen_logits # state of the pen # reshape target data so that it is compatible with prediction shape target = tf.reshape(self.output_x, [-1, 5]) [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1) pen_data = tf.concat([eos_data, eoc_data, cont_data], 1) lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen_logits, x1_data, x2_data, pen_data) self.r_cost = tf.reduce_mean(lossfunc) if self.hps.is_training: self.cost = self.r_cost + self.kl_cost * self.hps.kl_weight self.lr = tf.Variable(self.hps.learning_rate, trainable=False) optimizer = tf.train.AdamOptimizer(self.lr) if self.hps.kl_weight != 0: self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False) self.cost = self.r_cost + self.kl_cost * self.kl_weight else: self.kl_weight = tf.Variable(self.hps.kl_weight, trainable=False) self.cost = self.r_cost gvs = optimizer.compute_gradients(self.cost) g = self.hps.grad_clip capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs] self.train_op = optimizer.apply_gradients( capped_gvs, global_step=self.global_step, name='train_step')
def decoder(self, reuse):
    with tf.variable_scope("decode", reuse=reuse):
        if self.hps.conditional:
            self.batch_z = tf.concat([self.batch_z, self.c_expr], 1)
            pre_tile_y = tf.reshape(
                self.batch_z,
                [self.hps.batch_size, 1, self.hps.z_size + self.expr_num])
            overlay_x = tf.tile(
                pre_tile_y,
                [1, self.hps.max_seq_len, 1])  # replicating the latent input multiple times
            actual_input_x = tf.concat([self.input_x, overlay_x], 2)
        else:
            self.batch_z = tf.zeros((self.hps.batch_size, self.hps.z_size),
                                    dtype=tf.float32)
            self.batch_z = tf.concat([self.batch_z, self.c_expr], 1)
            actual_input_x = self.input_x

        self.num_mixture = self.hps.num_mixture  # Number of mixtures in Gaussian mixture model.
        n_out = (3 + self.num_mixture * 6)

        output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
        output_b = tf.get_variable('output_b', [n_out])

        if self.hps.conditional:
            self.initial_state = tf.nn.tanh(
                rnn.super_linear(self.batch_z,
                                 self.cell.state_size,
                                 init_w='gaussian',
                                 weight_start=0.001,
                                 input_size=self.hps.z_size + self.expr_num))
        else:
            self.initial_state = self.cell.zero_state(
                batch_size=self.hps.batch_size, dtype=tf.float32)

        output, last_state = tf.nn.dynamic_rnn(
            self.cell,
            actual_input_x,
            initial_state=self.initial_state,
            time_major=False,
            swap_memory=True,
            dtype=tf.float32,
            scope='RNN')

        self.c_kl_batch_train = tf.zeros([], dtype=tf.float32)
        self.training_logits = output
        output = tf.reshape(self.training_logits, [-1, self.hps.dec_rnn_size])
        output = tf.nn.xw_plus_b(output, output_w, output_b)
        self.final_state = last_state

        out = self.get_mixture_coef(output)
        [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen,
         o_pen_logits] = out

        self.pi = o_pi
        self.mu1 = o_mu1
        self.mu2 = o_mu2
        self.sigma1 = o_sigma1
        self.sigma2 = o_sigma2
        self.corr = o_corr
        self.pen_logits = o_pen_logits
        self.pen = o_pen
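# At inference time the MDN parameters produced above can be sampled to generate
# stroke offsets. The NumPy sketch below handles a single timestep; it is a hedged
# illustration, not the repository's sampling code (temperature handling omitted).
import numpy as np

def sample_offset(pi, mu1, mu2, sigma1, sigma2, corr, rng=np.random):
    """pi/mu/sigma/corr: 1-D arrays of length num_mixture for one timestep."""
    k = rng.choice(len(pi), p=pi / pi.sum())  # pick a mixture component
    cov = [[sigma1[k] ** 2, corr[k] * sigma1[k] * sigma2[k]],
           [corr[k] * sigma1[k] * sigma2[k], sigma2[k] ** 2]]
    return rng.multivariate_normal([mu1[k], mu2[k]], cov)  # (dx, dy)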