def z1_pre_encoder(x, z2, rhus=[256, 256]):
    """
    Pre-stochastic layer encoder for z1 (latent segment variable)
    Args:
        x(tf.Tensor): tensor of shape (bs, T, F)
        z2(tf.Tensor): tensor of shape (bs, D1)
        rhus(list): list of numbers of LSTM layer hidden units
    Return:
        out(tf.Tensor): concatenation of hidden states of all LSTM layers
    """
    bs, T = tf.shape(x)[0], tf.shape(x)[1]
    z2 = tf.tile(tf.expand_dims(z2, 1), (1, T, 1))
    x_z2 = tf.concat([x, z2], axis=-1)

    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    init_state = cell.zero_state(bs, x.dtype)
    name = "z1_enc_lstm_%s" % ("_".join(map(str, rhus)),)
    _, final_state = dynamic_rnn(cell, x_z2, dtype=x.dtype,
                                 initial_state=init_state,
                                 time_major=False, scope=name)

    out = [l_final_state.h for l_final_state in final_state]
    out = tf.concat(out, axis=-1)
    return out
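# Illustration only (not part of the repo): a minimal usage sketch for
# z1_pre_encoder, assuming TF 1.x and that the module's RNN imports
# (MultiRNNCell, BasicLSTMCell, dynamic_rnn) are in scope. The shapes are
# hypothetical: 20-frame segments of 80-dim features and a 32-dim z2.
def _example_z1_pre_encoder():
    x = tf.placeholder(tf.float32, shape=(None, 20, 80), name="x")
    z2 = tf.placeholder(tf.float32, shape=(None, 32), name="z2")
    # Concatenated final hidden states of the two LSTM layers: (bs, 256 + 256).
    return z1_pre_encoder(x, z2, rhus=[256, 256])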
def decoder(z1, z2, x, rhus=[256, 256], x_mu_nl=None, x_logvar_nl=None):
    """
    decoder
    Args:
        z1(tf.Tensor)
        z2(tf.Tensor)
        x(tf.Tensor): tensor of shape (bs, T, F). only shape is used
        rhus(list): list of numbers of LSTM layer hidden units
    Return:
        out(tf.Tensor): decoder LSTM outputs stacked over time
        px_z(list): [x_mu, x_logvar] of p(x|z1, z2), each of shape (bs, T, F)
        x_sample(tf.Tensor): samples from p(x|z1, z2), shape (bs, T, F)
    """
    bs = tf.shape(x)[0]
    z1_z2 = tf.concat([z1, z2], axis=-1)

    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    state_t = cell.zero_state(bs, x.dtype)
    name = "dec_lstm_%s_step" % ("_".join(map(str, rhus)),)

    def cell_step(inp, prev_state):
        return cell(inp, prev_state, scope=name)

    gdim = x.get_shape().as_list()[2]
    gname = "dec_gauss_step"

    def glayer_step(inp):
        return gauss_layer(inp, gdim, x_mu_nl, x_logvar_nl, gname)

    out, x_mu, x_logvar, x_sample = [], [], [], []
    for t in xrange(x.get_shape().as_list()[1]):
        if t > 0:
            tf.get_variable_scope().reuse_variables()
        out_t, state_t, x_mu_t, x_logvar_t, x_sample_t = decoder_step(
            z1_z2, state_t, cell_step, glayer_step)
        out.append(out_t)
        x_mu.append(x_mu_t)
        x_logvar.append(x_logvar_t)
        x_sample.append(x_sample_t)

    out = tf.stack(out, axis=1, name="dec_pre_out")
    x_mu = tf.stack(x_mu, axis=1, name="dec_x_mu")
    x_logvar = tf.stack(x_logvar, axis=1, name="dec_x_logvar")
    x_sample = tf.stack(x_sample, axis=1, name="dec_x_sample")
    px_z = [x_mu, x_logvar]
    return out, px_z, x_sample
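# Illustration only (not part of the repo): a usage sketch for decoder, assuming
# TF 1.x / Python 2 and that the repo's gauss_layer and decoder_step helpers are
# in scope. x only supplies the static output shape (bs, T, F); the latent and
# feature dimensions below are hypothetical.
def _example_decoder():
    z1 = tf.placeholder(tf.float32, shape=(None, 32), name="z1")
    z2 = tf.placeholder(tf.float32, shape=(None, 32), name="z2")
    x = tf.placeholder(tf.float32, shape=(None, 20, 80), name="x")
    out, (x_mu, x_logvar), x_sample = decoder(z1, z2, x, rhus=[256, 256])
    return x_mu, x_logvar, x_sample  # each of shape (bs, 20, 80)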
def z2_pre_encoder(x, rhus=[256, 256]):
    """
    Pre-stochastic layer encoder for z2 (latent sequence variable)
    Args:
        x(tf.Tensor): tensor of shape (bs, T, F)
        rhus(list): list of numbers of LSTM layer hidden units
    Return:
        out(tf.Tensor): concatenation of hidden states of all LSTM layers
    """
    bs = tf.shape(x)[0]
    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    init_state = cell.zero_state(bs, x.dtype)
    name = "z2_enc_lstm_%s" % ("_".join(map(str, rhus)),)
    _, final_state = dynamic_rnn(cell, x, dtype=x.dtype,
                                 initial_state=init_state,
                                 time_major=False, scope=name)

    out = [l_final_state.h for l_final_state in final_state]
    out = tf.concat(out, axis=-1)
    return out
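# Illustration only (not part of the repo): z2_pre_encoder summarizes the whole
# (bs, T, F) sequence; its output would normally feed the stochastic z2 layer.
# gauss_layer is the repo helper wrapped by the decoder above, assumed here to
# return (mu, logvar, sample); the shapes and the 32-dim z2 are hypothetical.
def _example_z2_pre_encoder():
    x = tf.placeholder(tf.float32, shape=(None, 20, 80), name="x")
    z2_pre_out = z2_pre_encoder(x, rhus=[256, 256])  # (bs, 256 + 256)
    z2_mu, z2_logvar, z2 = gauss_layer(z2_pre_out, 32, None, None, "z2_enc_gauss")
    return z2_mu, z2_logvar, z2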
def _build_rnn_decoder_and_recon_x(self, inputs, targets, training, reuse=False):
    with tf.variable_scope("dec_rec_and_recon_x", reuse=reuse):
        C, T, F = self._model_conf["target_shape"]
        Cell = _cell_dict[self._model_conf["rec_cell_type"]]
        cell = MultiRNNCell([Cell(hu)
                             for hu in self._model_conf["rec_dec"]])
        if self._model_conf["rec_learn_init"]:
            raise NotImplementedError
        else:
            input_shape = tuple(array_ops.shape(input_)
                                for input_ in nest.flatten(inputs))
            batch_size = input_shape[0][0]
            init_state = cell.zero_state(batch_size,
                                         self._model_conf["input_dtype"])

        rec_dec_inp = self._model_conf["rec_dec_inp_test"]
        if training:
            rec_dec_inp = self._model_conf["rec_dec_inp_train"]

        if rec_dec_inp is not None:
            n_concur = self._model_conf["rec_dec_concur"]
            if T % n_concur != 0:
                raise ValueError("total time steps must be " +
                                 "multiples of rec_dec_concur")
            n_frame = T // n_concur
        else:
            # no reconstruction history is fed back; emit one frame per step
            n_concur = 1
            n_frame = T
        n_hist = self._model_conf["rec_dec_inp_hist"]
        info("decoder: n_frame=%s, n_concur=%s, n_hist=%s" % (
            n_frame, n_concur, n_hist))

        def make_hist(hist, new_hist):
            with tf.name_scope("make_hist"):
                if not self._model_conf["x_conti"]:
                    # TODO add target embedding?
                    new_hist = tf.cast(new_hist, tf.float32)
                if n_hist > n_concur:
                    diff = n_hist - n_concur
                    return tf.concat([hist[:, :, -diff:, :], new_hist],
                                     axis=-2)
                else:
                    return new_hist[:, :, -n_hist:, :]

        outputs = []
        if self._model_conf["x_conti"]:
            x_mu, x_logvar, x = [], [], []
        else:
            x_logits, x = [], []

        state_f = init_state
        hist = tf.zeros((array_ops.shape(inputs)[0], C, n_hist, F),
                        dtype=self._model_conf["input_dtype"],
                        name="init_hist")
        for f in xrange(n_frame):
            input_f = inputs
            if rec_dec_inp:
                input_f = tf.concat(
                    [inputs, tf.reshape(hist, (-1, C * n_hist * F))],
                    axis=-1, name="input_f_%s" % f)
            if f > 0:
                tf.get_variable_scope().reuse_variables()
            output_f, state_f = cell(input_f, state_f)
            outputs.append(output_f)

            # TODO: input hist as well (like sampleRNN)?
if self._model_conf["x_conti"]: x_mu_f, x_logvar_f, x_f = dense_latent( inputs=output_f, num_outputs=C * n_concur * F, mu_nl=self._model_conf["x_mu_nl"], logvar_nl=self._model_conf["x_logvar_nl"], scope="recon_x_f") x_mu.append( tf.reshape(x_mu_f, (-1, C, n_concur, F), name="recon_x_mu_f_4d")) x_logvar.append( tf.reshape(x_logvar_f, (-1, C, n_concur, F), name="recon_x_logvar_f_4d")) x.append( tf.reshape(x_f, (-1, C, n_concur, F), name="recon_x_f_4d")) if rec_dec_inp == "targets": t_slice = slice(f * n_concur, (f + 1) * n_concur) hist = make_hist(hist, targets[:, :, t_slice, :]) elif rec_dec_inp == "x_mu": hist = make_hist(hist, x_mu[-1]) elif rec_dec_inp == "x": hist = make_hist(hist, x[-1]) elif rec_dec_inp: raise ValueError("unsupported rec_dec_inp (%s)" % (rec_dec_inp)) else: raise ValueError # n_bins = self._model_conf["n_bins"] # x_logits_f, x_f = cat_dense_latent( # inputs=output_f, # num_outputs=C * n_concur * F, # n_bins=n_bins, # scope="recon_x_f") # x_logits.append(tf.reshape( # x_logits_f, # (-1, C, n_concur, F, n_bins), # name="recon_x_logits_f_5d")) # x.append(tf.reshape( # x_f, # (-1, C, n_concur, F), # name="recon_x_f_4d")) # if rec_dec_inp == "targets": # t_slice = slice(f * n_concur, (f + 1) * n_concur) # hist = make_hist(hist, targets[:, :, t_slice, :]) # elif rec_dec_inp == "x_max": # hist = make_hist(hist, tf.argmax(x_logits[-1], -1)) # elif rec_dec_inp == "x": # hist = make_hist(hist, x[-1]) # elif rec_dec_inp: # raise ValueError("unsupported rec_dec_inp (%s)" % ( # rec_dec_inp)) # (bs, n_frame, top_rnn_hu) outputs = tf.stack(outputs, axis=1, name="rec_outputs") x = tf.concat(x, axis=2, name="recon_x_t_4d") if self._model_conf["x_conti"]: x_mu = tf.concat(x_mu, axis=2, name="recon_x_mu_t_4d") x_logvar = tf.concat(x_logvar, axis=2, name="recon_x_logvar_t_4d") px = [x_mu, x_logvar] else: x_logits = tf.concat(x_logits, axis=2, name="recon_x_logits_t_5d") px = x_logits return outputs, px, x
def _build_z2_encoder(self, inputs, z1, reuse=False):
    weights_regularizer = l2_regularizer(self._train_conf["l2_weight"])
    normalizer_fn = batch_norm if self._model_conf["if_bn"] else None
    normalizer_params = None
    if self._model_conf["if_bn"]:
        normalizer_params = {
            "scope": "BatchNorm",
            "is_training": self._feed_dict["is_train"],
            "reuse": reuse
        }
        # TODO: need to upgrade to a later commit that supports
        # the param_regularizers args

    if not hasattr(self, "_debug_outputs"):
        self._debug_outputs = {}

    C, T, F = self._model_conf["target_shape"]
    n_concur = self._model_conf["rec_z2_enc_concur"]
    if T % n_concur != 0:
        raise ValueError("total time steps must be multiples of %s" % (
            n_concur))
    n_frame = T // n_concur
    info("z2_encoder: n_frame=%s, n_concur=%s" % (n_frame, n_concur))

    # input_dim = np.prod(inputs.get_shape().as_list()[1:])
    # outputs = tf.concat([tf.reshape(inputs, [-1, input_dim]), z1], axis=1)
    with tf.variable_scope("z2_enc", reuse=reuse):
        # recurrent layers
        if self._model_conf["rec_z2_enc"]:
            # reshape to (N, n_frame, n_concur*C*F)
            inputs = array_ops.transpose(inputs, (0, 2, 1, 3))
            inputs_shape = inputs.get_shape().as_list()
            inputs_depth = np.prod(inputs_shape[2:])
            new_shape = (-1, n_frame, n_concur * inputs_depth)
            inputs = tf.reshape(inputs, new_shape)

            # append z1 to each frame
            tiled_z1 = tf.tile(tf.expand_dims(z1, 1), (1, n_frame, 1))
            inputs = tf.concat([inputs, tiled_z1], axis=-1)
            self._debug_outputs["inp_reshape"] = inputs

            if self._model_conf["rec_z2_enc_bi"]:
                raise NotImplementedError
            else:
                Cell = _cell_dict[self._model_conf["rec_cell_type"]]
                cell = MultiRNNCell([Cell(hu)
                                     for hu in self._model_conf["rec_z2_enc"]])
                if self._model_conf["rec_learn_init"]:
                    raise NotImplementedError
                else:
                    input_shape = tuple(array_ops.shape(input_)
                                        for input_ in nest.flatten(inputs))
                    batch_size = input_shape[0][0]
                    init_state = cell.zero_state(
                        batch_size, self._model_conf["input_dtype"])
                _, final_states = dynamic_rnn(
                    cell, inputs,
                    dtype=self._model_conf["input_dtype"],
                    initial_state=init_state,
                    time_major=False,
                    scope="z2_enc_%sL_rec" % len(self._model_conf["rec_z2_enc"]))
                self._debug_outputs["raw_rnn_out"] = _
                self._debug_outputs["raw_rnn_final"] = final_states

                if self._model_conf["rec_z2_enc_out"].startswith("last"):
                    final_states = final_states[-1:]
                if self._model_conf["rec_cell_type"] == "lstm":
                    outputs = []
                    for state in final_states:
                        if "h" in self._model_conf["rec_z2_enc_out"].split("_")[1]:
                            outputs.append(state.h)
                        if "c" in self._model_conf["rec_z2_enc_out"].split("_")[1]:
                            outputs.append(state.c)
                else:
                    outputs = final_states
                outputs = tf.concat(outputs, axis=-1)
                self._debug_outputs["concat_rnn_out"] = outputs
        else:
            input_dim = np.prod(inputs.get_shape().as_list()[1:])
            outputs = tf.concat([tf.reshape(inputs, [-1, input_dim]), z1],
                                axis=1)

        # fully connected layers
        output_dim = np.prod(outputs.get_shape().as_list()[1:])
        outputs = tf.reshape(outputs, [-1, output_dim])
        for i, hu in enumerate(self._model_conf["hu_z2_enc"]):
            outputs = fully_connected(
                inputs=outputs,
                num_outputs=hu,
                activation_fn=nn.relu,
                normalizer_fn=normalizer_fn,
                normalizer_params=normalizer_params,
                weights_regularizer=weights_regularizer,
                reuse=reuse,
                scope="z2_enc_fc%s" % (i + 1))

        z2_mu, z2_logvar, z2 = dense_latent(
            outputs,
            self._model_conf["n_latent2"],
            logvar_nl=self._model_conf["z2_logvar_nl"],
            reuse=reuse,
            scope="z2_enc_lat")
    return [z2_mu, z2_logvar], z2
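# Illustration only (not part of the repo): additional hypothetical config keys
# read by _build_z2_encoder; shared keys ("rec_cell_type", "rec_learn_init",
# "input_dtype", "target_shape") appear in the decoder sketch above. All values
# are assumptions for the sketch.
def _example_z2_enc_conf():
    model_conf = {
        "if_bn": False,              # skip batch norm in this sketch
        "rec_z2_enc": [256, 256],    # RNN hidden units; [] selects the MLP-only path
        "rec_z2_enc_concur": 4,      # frames folded into each RNN step; T % 4 == 0
        "rec_z2_enc_bi": False,      # bidirectional variant raises NotImplementedError
        "rec_z2_enc_out": "last_h",  # take h (and/or c) of the last layer's final state
        "hu_z2_enc": [512],          # fully connected layers on top of the RNN summary
        "n_latent2": 32,             # z2 dimensionality
        "z2_logvar_nl": None,
    }
    train_conf = {"l2_weight": 1e-4}  # weight for the fully connected l2 regularizer
    return model_conf, train_conf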