def encoder(self, batch, sequence_lengths):
  """Define the bi-directional encoder module of sketch-rnn."""
  unused_outputs, last_states = tf.nn.bidirectional_dynamic_rnn(
      self.enc_cell_fw,
      self.enc_cell_bw,
      batch,
      sequence_length=sequence_lengths,
      time_major=False,
      swap_memory=True,
      dtype=tf.float32,
      scope='ENC_RNN')

  last_state_fw, last_state_bw = last_states
  last_h_fw = self.enc_cell_fw.get_output(last_state_fw)
  last_h_bw = self.enc_cell_bw.get_output(last_state_bw)
  last_h = tf.concat([last_h_fw, last_h_bw], 1)
  mu = rnn.super_linear(
      last_h,
      self.hps.z_size,
      input_size=self.hps.enc_rnn_size * 2,  # bi-directional, so x2
      scope='ENC_RNN_mu',
      init_w='gaussian',
      weight_start=0.001)
  presig = rnn.super_linear(
      last_h,
      self.hps.z_size,
      input_size=self.hps.enc_rnn_size * 2,  # bi-directional, so x2
      scope='ENC_RNN_sigma',
      init_w='gaussian',
      weight_start=0.001)
  return mu, presig
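# For intuition only: rnn.super_linear acts as a fully connected layer, so the
# two projections above amount to (names illustrative, not Magenta's API):
#
#   last_h : [batch_size, 2 * enc_rnn_size]  (fw/bw final states concatenated)
#   mu     = last_h @ W_mu + b_mu        -> [batch_size, z_size]
#   presig = last_h @ W_sigma + b_sigma  -> [batch_size, z_size]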
def get_mu_sig(self, image_embedding):
  """Project an image embedding to the latent Gaussian params (mu, presig)."""
  enc_size = int(image_embedding.shape[-1])
  mu = rnn.super_linear(
      image_embedding,
      self.hps.z_size,
      input_size=enc_size,
      scope='ENC_RNN_mu',
      init_w='gaussian',
      weight_start=0.001)
  presig = rnn.super_linear(
      image_embedding,
      self.hps.z_size,
      input_size=enc_size,
      scope='ENC_RNN_sigma',
      init_w='gaussian',
      weight_start=0.001)
  return mu, presig
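# get_mu_sig uses the same variable-scope names ('ENC_RNN_mu'/'ENC_RNN_sigma')
# as encoder() above, so an image embedding (e.g. from a CNN) is projected into
# the same latent space as the stroke encoder. Hypothetical usage, assuming a
# [batch_size, enc_size] embedding tensor:
#
#   mu, presig = model.get_mu_sig(image_embedding)  # each [batch_size, z_size]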
def get_init_state(self, image_embedding):
  """Sample z from an image embedding and build the decoder's initial input."""
  self.mean, self.presig = self.get_mu_sig(image_embedding)
  self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
  eps = tf.random_normal(
      (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
  if self.hps.is_train:
    batch_z = self.mean + tf.multiply(self.sigma, eps)
  else:
    batch_z = self.mean
  if self.hps.inter_z:
    # `sample_gussian` (sic) appears to be a tensor defined elsewhere (e.g. a
    # placeholder used for latent-space interpolation); the misspelled name is
    # kept as-is to match its definition.
    batch_z = self.mean + tf.multiply(self.sigma, self.sample_gussian)

  # KL cost
  kl_cost = -0.5 * tf.reduce_mean(
      (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
  kl_cost = tf.maximum(kl_cost, self.hps.kl_tolerance)

  # get initial decoder state based on batch_z
  initial_state = tf.nn.tanh(
      rnn.super_linear(
          batch_z,
          self.cell.state_size,
          init_w='gaussian',
          weight_start=0.001,
          input_size=self.hps.z_size))

  # concatenate z onto every timestep of the decoder input
  pre_tile_y = tf.reshape(batch_z, [self.hps.batch_size, 1, self.hps.z_size])
  overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
  actual_input_x = tf.concat([self.input_x, overlay_x], 2)
  return initial_state, actual_input_x, batch_z, kl_cost
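# Standalone NumPy sketch (illustration only, not part of the model) of the
# reparameterization and KL computations above, assuming presig = log(sigma^2):
import numpy as np

def sample_z_and_kl(mean, presig, kl_tolerance=0.2):
  sigma = np.exp(presig / 2.0)  # exp(log(sigma^2) / 2) = sigma
  eps = np.random.normal(size=mean.shape)
  z = mean + sigma * eps  # sampling stays differentiable w.r.t. mean, presig
  kl = -0.5 * np.mean(1 + presig - mean**2 - np.exp(presig))
  return z, max(kl, kl_tolerance)  # KL term is floored at kl_tolerance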
def build_model(self, hps):
  """Define model architecture."""
  if hps.is_training:
    self.global_step = tf.Variable(0, name='global_step', trainable=False)

  if hps.dec_model == 'lstm':
    cell_fn = rnn.LSTMCell
  elif hps.dec_model == 'layer_norm':
    cell_fn = rnn.LayerNormLSTMCell
  elif hps.dec_model == 'hyper':
    cell_fn = rnn.HyperLSTMCell
  else:
    assert False, 'please choose a respectable cell'

  if hps.enc_model == 'lstm':
    enc_cell_fn = rnn.LSTMCell
  elif hps.enc_model == 'layer_norm':
    enc_cell_fn = rnn.LayerNormLSTMCell
  elif hps.enc_model == 'hyper':
    enc_cell_fn = rnn.HyperLSTMCell
  else:
    assert False, 'please choose a respectable cell'

  use_recurrent_dropout = self.hps.use_recurrent_dropout
  use_input_dropout = self.hps.use_input_dropout
  use_output_dropout = self.hps.use_output_dropout

  cell = cell_fn(
      hps.dec_rnn_size,
      use_recurrent_dropout=use_recurrent_dropout,
      dropout_keep_prob=self.hps.recurrent_dropout_prob)

  if hps.conditional:  # vae mode:
    # Both branches of the original `if hps.enc_model == 'hyper'` check were
    # identical, so the encoder cells are simply built unconditionally.
    self.enc_cell_fw = enc_cell_fn(
        hps.enc_rnn_size,
        use_recurrent_dropout=use_recurrent_dropout,
        dropout_keep_prob=self.hps.recurrent_dropout_prob)
    self.enc_cell_bw = enc_cell_fn(
        hps.enc_rnn_size,
        use_recurrent_dropout=use_recurrent_dropout,
        dropout_keep_prob=self.hps.recurrent_dropout_prob)

  # dropout:
  tf.logging.info('Input dropout mode = %s.', use_input_dropout)
  tf.logging.info('Output dropout mode = %s.', use_output_dropout)
  tf.logging.info('Recurrent dropout mode = %s.', use_recurrent_dropout)
  if use_input_dropout:
    tf.logging.info('Dropout to input w/ keep_prob = %4.4f.',
                    self.hps.input_dropout_prob)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell, input_keep_prob=self.hps.input_dropout_prob)
  if use_output_dropout:
    tf.logging.info('Dropout to output w/ keep_prob = %4.4f.',
                    self.hps.output_dropout_prob)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell, output_keep_prob=self.hps.output_dropout_prob)
  self.cell = cell

  self.sequence_lengths = tf.placeholder(
      dtype=tf.int32, shape=[self.hps.batch_size])
  self.input_data = tf.placeholder(
      dtype=tf.float32,
      shape=[self.hps.batch_size, self.hps.max_seq_len + 1, 5])

  # The target/expected vectors of strokes
  self.output_x = self.input_data[:, 1:self.hps.max_seq_len + 1, :]
  # vectors of strokes to be fed to decoder (same as above, but lagged behind
  # one step to include the initial dummy value of (0, 0, 1, 0, 0))
  self.input_x = self.input_data[:, :self.hps.max_seq_len, :]

  # either do the vae bit and get z, or do unconditional, decoder-only
  if hps.conditional:  # vae mode:
    self.mean, self.presig = self.encoder(self.output_x,
                                          self.sequence_lengths)
    self.sigma = tf.exp(self.presig / 2.0)  # sigma > 0. div 2.0 -> sqrt.
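    # Reparameterization trick (cf. the NumPy sketch after get_init_state):
    # z = mean + sigma * eps with eps ~ N(0, I), keeping the sample
    # differentiable w.r.t. mean and presig.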
    eps = tf.random_normal(
        (self.hps.batch_size, self.hps.z_size), 0.0, 1.0, dtype=tf.float32)
    self.batch_z = self.mean + tf.multiply(self.sigma, eps)

    # KL cost
    self.kl_cost = -0.5 * tf.reduce_mean(
        (1 + self.presig - tf.square(self.mean) - tf.exp(self.presig)))
    self.kl_cost = tf.maximum(self.kl_cost, self.hps.kl_tolerance)

    pre_tile_y = tf.reshape(self.batch_z,
                            [self.hps.batch_size, 1, self.hps.z_size])
    overlay_x = tf.tile(pre_tile_y, [1, self.hps.max_seq_len, 1])
    actual_input_x = tf.concat([self.input_x, overlay_x], 2)
    self.initial_state = tf.nn.tanh(
        rnn.super_linear(
            self.batch_z,
            cell.state_size,
            init_w='gaussian',
            weight_start=0.001,
            input_size=self.hps.z_size))
  else:  # unconditional, decoder-only generation
    self.batch_z = tf.zeros(
        (self.hps.batch_size, self.hps.z_size), dtype=tf.float32)
    self.kl_cost = tf.zeros([], dtype=tf.float32)
    actual_input_x = self.input_x
    self.initial_state = cell.zero_state(
        batch_size=hps.batch_size, dtype=tf.float32)

  self.num_mixture = hps.num_mixture

  # Number of outputs is 3 (one logit per pen state) plus 6 per mixture
  # component: mean_x, stdev_x, mean_y, stdev_y, correlation_xy, and the
  # mixture weight/probability (Pi_k)
  n_out = (3 + self.num_mixture * 6)

  with tf.variable_scope('RNN'):
    output_w = tf.get_variable('output_w', [self.hps.dec_rnn_size, n_out])
    output_b = tf.get_variable('output_b', [n_out])

  # decoder module of sketch-rnn is below
  output, last_state = tf.nn.dynamic_rnn(
      cell,
      actual_input_x,
      initial_state=self.initial_state,
      time_major=False,
      swap_memory=True,
      dtype=tf.float32,
      scope='RNN')

  output = tf.reshape(output, [-1, hps.dec_rnn_size])
  output = tf.nn.xw_plus_b(output, output_w, output_b)
  self.final_state = last_state

  # NB: the below are inner functions, not methods of Model
  def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
    """Returns result of eq # 24 of http://arxiv.org/abs/1308.0850."""
    norm1 = tf.subtract(x1, mu1)
    norm2 = tf.subtract(x2, mu2)
    s1s2 = tf.multiply(s1, s2)
    # eq 25
    z = (tf.square(tf.div(norm1, s1)) + tf.square(tf.div(norm2, s2)) -
         2 * tf.div(tf.multiply(rho, tf.multiply(norm1, norm2)), s1s2))
    neg_rho = 1 - tf.square(rho)
    result = tf.exp(tf.div(-z, 2 * neg_rho))
    denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(neg_rho))
    result = tf.div(result, denom)
    return result

  def get_lossfunc(z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr,
                   z_pen_logits, x1_data, x2_data, pen_data):
    """Returns a loss fn based on eq #26 of http://arxiv.org/abs/1308.0850."""
    # This represents the L_R only (i.e. does not include the KL loss term).
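    # L_R = L_s + L_p (eq 9 of https://arxiv.org/abs/1704.03477): L_s is the
    # negative log-likelihood of the (dx, dy) offsets under the GMM, and L_p
    # is the cross-entropy over the three pen states.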
    result0 = tf_2d_normal(x1_data, x2_data, z_mu1, z_mu2, z_sigma1, z_sigma2,
                           z_corr)
    epsilon = 1e-6
    # result1 is the loss wrt pen offset (L_s in equation 9 of
    # https://arxiv.org/pdf/1704.03477.pdf)
    result1 = tf.multiply(result0, z_pi)
    result1 = tf.reduce_sum(result1, 1, keep_dims=True)
    result1 = -tf.log(result1 + epsilon)  # avoid log(0)

    fs = 1.0 - pen_data[:, 2]  # use training data for this
    fs = tf.reshape(fs, [-1, 1])
    # Zero out loss terms beyond N_s, the last actual stroke
    result1 = tf.multiply(result1, fs)

    # result2: loss wrt pen state, (L_p in equation 9)
    result2 = tf.nn.softmax_cross_entropy_with_logits(
        labels=pen_data, logits=z_pen_logits)
    result2 = tf.reshape(result2, [-1, 1])
    if not self.hps.is_training:  # eval mode, mask eos columns
      result2 = tf.multiply(result2, fs)

    result = result1 + result2
    return result

  # below is where we need to do MDN (Mixture Density Network) splitting of
  # distribution params
  def get_mixture_coef(output):
    """Returns the tf slices containing mdn dist params."""
    # This uses eqns 18 -> 23 of http://arxiv.org/abs/1308.0850.
    z = output
    z_pen_logits = z[:, 0:3]  # pen states
    z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr = tf.split(z[:, 3:], 6, 1)

    # process output z's into MDN parameters

    # softmax all the pi's and pen states:
    z_pi = tf.nn.softmax(z_pi)
    z_pen = tf.nn.softmax(z_pen_logits)

    # exponentiate the sigmas and also make corr between -1 and 1.
    z_sigma1 = tf.exp(z_sigma1)
    z_sigma2 = tf.exp(z_sigma2)
    z_corr = tf.tanh(z_corr)

    r = [z_pi, z_mu1, z_mu2, z_sigma1, z_sigma2, z_corr, z_pen, z_pen_logits]
    return r

  out = get_mixture_coef(output)
  [o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr, o_pen, o_pen_logits] = out

  self.pi = o_pi
  self.mu1 = o_mu1
  self.mu2 = o_mu2
  self.sigma1 = o_sigma1
  self.sigma2 = o_sigma2
  self.corr = o_corr
  self.pen_logits = o_pen_logits
  # pen state probabilities (result of applying softmax to self.pen_logits)
  self.pen = o_pen

  # reshape target data so that it is compatible with prediction shape
  target = tf.reshape(self.output_x, [-1, 5])
  [x1_data, x2_data, eos_data, eoc_data, cont_data] = tf.split(target, 5, 1)
  pen_data = tf.concat([eos_data, eoc_data, cont_data], 1)

  lossfunc = get_lossfunc(o_pi, o_mu1, o_mu2, o_sigma1, o_sigma2, o_corr,
                          o_pen_logits, x1_data, x2_data, pen_data)

  self.r_cost = tf.reduce_mean(lossfunc)

  if self.hps.is_training:
    self.lr = tf.Variable(self.hps.learning_rate, trainable=False)
    optimizer = tf.train.AdamOptimizer(self.lr)

    self.kl_weight = tf.Variable(self.hps.kl_weight_start, trainable=False)
    self.cost = self.r_cost + self.kl_cost * self.kl_weight

    gvs = optimizer.compute_gradients(self.cost)
    g = self.hps.grad_clip
    capped_gvs = [(tf.clip_by_value(grad, -g, g), var) for grad, var in gvs]
    self.train_op = optimizer.apply_gradients(
        capped_gvs, global_step=self.global_step, name='train_step')
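# Illustration only: a NumPy translation of the inner tf_2d_normal above,
# evaluating the bivariate normal density of eqs 24-25 from
# http://arxiv.org/abs/1308.0850 (not part of the model):
import numpy as np

def np_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
  """Bivariate normal density at (x1, x2); args are floats or NumPy arrays."""
  norm1, norm2 = x1 - mu1, x2 - mu2
  z = ((norm1 / s1) ** 2 + (norm2 / s2) ** 2
       - 2.0 * rho * norm1 * norm2 / (s1 * s2))
  neg_rho = 1.0 - rho ** 2
  return np.exp(-z / (2.0 * neg_rho)) / (2.0 * np.pi * s1 * s2 * np.sqrt(neg_rho))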