Example 1
# Assumed TensorFlow 1.x / Python 2 imports for these example snippets:
import tensorflow as tf
from tensorflow.contrib.rnn import MultiRNNCell, BasicLSTMCell
from tensorflow.python.ops.rnn import dynamic_rnn


def z1_pre_encoder(x, z2, rhus=[256, 256]):
    """
    Pre-stochastic layer encoder for z1 (latent segment variable)
    Args:
        x(tf.Tensor): tensor of shape (bs, T, F)
        z2(tf.Tensor): tensor of shape (bs, D1)
        rhus(list): list of numbers of LSTM layer hidden units
    Return:
        out(tf.Tensor): concatenation of hidden states of all LSTM layers
    """
    bs, T = tf.shape(x)[0], tf.shape(x)[1]
    z2 = tf.tile(tf.expand_dims(z2, 1), (1, T, 1))
    x_z2 = tf.concat([x, z2], axis=-1)

    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    init_state = cell.zero_state(bs, x.dtype)
    name = "z1_enc_lstm_%s" % ("_".join(map(str, rhus)), )
    _, final_state = dynamic_rnn(cell,
                                 x_z2,
                                 dtype=x.dtype,
                                 initial_state=init_state,
                                 time_major=False,
                                 scope=name)

    out = [l_final_state.h for l_final_state in final_state]
    out = tf.concat(out, axis=-1)
    return out
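A minimal usage sketch for the encoder above, assuming the TF 1.x imports listed with Example 1 and made-up placeholder shapes (batch size left open, T=20 frames, F=80 features, a 32-dimensional z2):

# Hypothetical usage of z1_pre_encoder (shapes are illustrative only)
x_ph = tf.placeholder(tf.float32, shape=(None, 20, 80), name="x")   # (bs, T, F)
z2_ph = tf.placeholder(tf.float32, shape=(None, 32), name="z2")     # (bs, D1)
z1_pre_out = z1_pre_encoder(x_ph, z2_ph, rhus=[256, 256])           # (bs, 256 + 256)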
Example 2
def decoder(z1, z2, x, rhus=[256, 256], x_mu_nl=None, x_logvar_nl=None):
    """
    decoder
    Args:
        z1(tf.Tensor)
        z2(tf.Tensor)
        x(tf.Tensor): tensor of shape (bs, T, F). only shape is used
        rhus(list)
    """
    bs = tf.shape(x)[0]
    z1_z2 = tf.concat([z1, z2], axis=-1)

    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    state_t = cell.zero_state(bs, x.dtype)
    name = "dec_lstm_%s_step" % ("_".join(map(str, rhus)), )

    def cell_step(inp, prev_state):
        return cell(inp, prev_state, scope=name)

    gdim = x.get_shape().as_list()[2]
    gname = "dec_gauss_step"

    def glayer_step(inp):
        return gauss_layer(inp, gdim, x_mu_nl, x_logvar_nl, gname)

    out, x_mu, x_logvar, x_sample = [], [], [], []
    for t in xrange(x.get_shape().as_list()[1]):
        if t > 0:
            tf.get_variable_scope().reuse_variables()

        out_t, state_t, x_mu_t, x_logvar_t, x_sample_t = decoder_step(
            z1_z2, state_t, cell_step, glayer_step)
        out.append(out_t)
        x_mu.append(x_mu_t)
        x_logvar.append(x_logvar_t)
        x_sample.append(x_sample_t)

    out = tf.stack(out, axis=1, name="dec_pre_out")
    x_mu = tf.stack(x_mu, axis=1, name="dec_x_mu")
    x_logvar = tf.stack(x_logvar, axis=1, name="dec_x_logvar")
    x_sample = tf.stack(x_sample, axis=1, name="dec_x_sample")
    px_z = [x_mu, x_logvar]
    return out, px_z, x_sample
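The loop above depends on two project helpers that are not shown in this snippet, decoder_step and gauss_layer. Below is a hedged sketch of what they are assumed to do, inferred only from how they are called here (one RNN step followed by a dense diagonal-Gaussian output layer with a reparameterized sample); the project's actual implementations may differ.

# Hypothetical sketches of the helpers used above (not the project's code)
def gauss_layer(inp, dim, mu_nl=None, logvar_nl=None, scope=None):
    with tf.variable_scope(scope, "gauss_layer"):
        mu = tf.layers.dense(inp, dim, activation=mu_nl, name="mu")
        logvar = tf.layers.dense(inp, dim, activation=logvar_nl, name="logvar")
        eps = tf.random_normal(tf.shape(mu))
        sample = mu + tf.exp(0.5 * logvar) * eps   # reparameterization trick
    return mu, logvar, sample

def decoder_step(inp, prev_state, cell_step, glayer_step):
    out, state = cell_step(inp, prev_state)       # one step of the stacked LSTM
    mu, logvar, sample = glayer_step(out)         # Gaussian output for this frame
    return out, state, mu, logvar, sample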
Example 3
def z2_pre_encoder(x, rhus=[256, 256]):
    """
    Pre-stochastic layer encoder for z2 (latent sequence variable)
    Args:
        x(tf.Tensor): tensor of shape (bs, T, F)
        rhus(list): list of numbers of LSTM layer hidden units
    Return:
        out(tf.Tensor): concatenation of hidden states of all LSTM layers
    """
    bs = tf.shape(x)[0]
    
    cell = MultiRNNCell([BasicLSTMCell(rhu) for rhu in rhus])
    init_state = cell.zero_state(bs, x.dtype)
    name = "z2_enc_lstm_%s" % ("_".join(map(str, rhus)),)
    _, final_state = dynamic_rnn(cell, x, dtype=x.dtype,
            initial_state=init_state, time_major=False, scope=name)
    
    out = [l_final_state.h for l_final_state in final_state]
    out = tf.concat(out, axis=-1)
    return out
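Taken together, the two pre-encoders reflect the hierarchical inference order implied by their docstrings: z2 (the latent sequence variable) is inferred from x alone, and z1 (the latent segment variable) is then inferred from x conditioned on z2. A hedged sketch of the chaining, reusing the x_ph placeholder from the Example 1 sketch and the hypothetical gauss_layer from Example 2 (latent dimensions are illustrative):

# Hypothetical inference chaining (latent dimensions are made up)
z2_pre = z2_pre_encoder(x_ph, rhus=[256, 256])
z2_mu, z2_logvar, z2_sample = gauss_layer(z2_pre, 32, scope="z2_enc_gauss")
z1_pre = z1_pre_encoder(x_ph, z2_sample, rhus=[256, 256])
z1_mu, z1_logvar, z1_sample = gauss_layer(z1_pre, 32, scope="z1_enc_gauss")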
Example 4
    def _build_rnn_decoder_and_recon_x(self,
                                       inputs,
                                       targets,
                                       training,
                                       reuse=False):
        with tf.variable_scope("dec_rec_and_recon_x", reuse=reuse):
            C, T, F = self._model_conf["target_shape"]

            Cell = _cell_dict[self._model_conf["rec_cell_type"]]
            cell = MultiRNNCell([Cell(hu) \
                    for hu in self._model_conf["rec_dec"]])

            if self._model_conf["rec_learn_init"]:
                raise NotImplementedError
            else:
                input_shape = tuple(array_ops.shape(input_) \
                        for input_ in nest.flatten(inputs))
                batch_size = input_shape[0][0]
                init_state = cell.zero_state(batch_size,
                                             self._model_conf["input_dtype"])

            rec_dec_inp = self._model_conf["rec_dec_inp_test"]
            if training:
                rec_dec_inp = self._model_conf["rec_dec_inp_train"]

            if rec_dec_inp is not None:
                n_concur = self._model_conf["rec_dec_concur"]
                if T % n_concur != 0:
                    raise ValueError("total time steps must be " + \
                            "multiples of rec_dec_concur")
                n_frame = T // n_concur
            else:
                n_frame = T
                n_concur = 1  # no feedback input: the decoder emits one frame per step
            n_hist = self._model_conf["rec_dec_inp_hist"]
            info("decoder: n_frame=%s, n_concur=%s, n_hist=%s" %
                 (n_frame, n_concur, n_hist))

            def make_hist(hist, new_hist):
                with tf.name_scope("make_hist"):
                    if not self._model_conf["x_conti"]:
                        # TODO add target embedding?
                        new_hist = tf.cast(new_hist, tf.float32)

                    if n_hist > n_concur:
                        diff = n_hist - n_concur
                        return tf.concat([hist[:, :, -diff:, :], new_hist],
                                         axis=-2)
                    else:
                        return new_hist[:, :, -n_hist:, :]

            outputs = []
            if self._model_conf["x_conti"]:
                x_mu, x_logvar, x = [], [], []
            else:
                x_logits, x = [], []
            state_f = init_state
            hist = tf.zeros((array_ops.shape(inputs)[0], C, n_hist, F),
                            dtype=self._model_conf["input_dtype"],
                            name="init_hist")

            for f in xrange(n_frame):
                input_f = inputs
                if rec_dec_inp:
                    input_f = tf.concat(
                        [inputs,
                         tf.reshape(hist, (-1, C * n_hist * F))],
                        axis=-1,
                        name="input_f_%s" % f)
                if f > 0:
                    tf.get_variable_scope().reuse_variables()

                output_f, state_f = cell(input_f, state_f)
                outputs.append(output_f)

                # TODO: input hist as well (like sampleRNN)?
                if self._model_conf["x_conti"]:
                    x_mu_f, x_logvar_f, x_f = dense_latent(
                        inputs=output_f,
                        num_outputs=C * n_concur * F,
                        mu_nl=self._model_conf["x_mu_nl"],
                        logvar_nl=self._model_conf["x_logvar_nl"],
                        scope="recon_x_f")
                    x_mu.append(
                        tf.reshape(x_mu_f, (-1, C, n_concur, F),
                                   name="recon_x_mu_f_4d"))
                    x_logvar.append(
                        tf.reshape(x_logvar_f, (-1, C, n_concur, F),
                                   name="recon_x_logvar_f_4d"))
                    x.append(
                        tf.reshape(x_f, (-1, C, n_concur, F),
                                   name="recon_x_f_4d"))

                    if rec_dec_inp == "targets":
                        t_slice = slice(f * n_concur, (f + 1) * n_concur)
                        hist = make_hist(hist, targets[:, :, t_slice, :])
                    elif rec_dec_inp == "x_mu":
                        hist = make_hist(hist, x_mu[-1])
                    elif rec_dec_inp == "x":
                        hist = make_hist(hist, x[-1])
                    elif rec_dec_inp:
                        raise ValueError("unsupported rec_dec_inp (%s)" %
                                         (rec_dec_inp))
                else:
                    raise ValueError
                    # n_bins = self._model_conf["n_bins"]
                    # x_logits_f, x_f = cat_dense_latent(
                    #         inputs=output_f,
                    #         num_outputs=C * n_concur * F,
                    #         n_bins=n_bins,
                    #         scope="recon_x_f")
                    # x_logits.append(tf.reshape(
                    #         x_logits_f,
                    #         (-1, C, n_concur, F, n_bins),
                    #         name="recon_x_logits_f_5d"))
                    # x.append(tf.reshape(
                    #         x_f,
                    #         (-1, C, n_concur, F),
                    #         name="recon_x_f_4d"))

                    # if rec_dec_inp == "targets":
                    #     t_slice = slice(f * n_concur, (f + 1) * n_concur)
                    #     hist = make_hist(hist, targets[:, :, t_slice, :])
                    # elif rec_dec_inp == "x_max":
                    #     hist = make_hist(hist, tf.argmax(x_logits[-1], -1))
                    # elif rec_dec_inp == "x":
                    #     hist = make_hist(hist, x[-1])
                    # elif rec_dec_inp:
                    #     raise ValueError("unsupported rec_dec_inp (%s)" % (
                    #             rec_dec_inp))

            # (bs, n_frame, top_rnn_hu)
            outputs = tf.stack(outputs, axis=1, name="rec_outputs")
            x = tf.concat(x, axis=2, name="recon_x_t_4d")

            if self._model_conf["x_conti"]:
                x_mu = tf.concat(x_mu, axis=2, name="recon_x_mu_t_4d")
                x_logvar = tf.concat(x_logvar,
                                     axis=2,
                                     name="recon_x_logvar_t_4d")
                px = [x_mu, x_logvar]
            else:
                x_logits = tf.concat(x_logits,
                                     axis=2,
                                     name="recon_x_logits_t_5d")
                px = x_logits

        return outputs, px, x
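The method above is driven entirely by entries in self._model_conf. Below is a hypothetical configuration fragment covering the keys it reads; the key names are taken from the code, but the values are illustrative rather than the project's defaults.

# Illustrative _model_conf fragment for the decoder (values are made up)
model_conf = {
    "target_shape": (1, 20, 80),      # (C, T, F)
    "rec_cell_type": "lstm",          # key into _cell_dict
    "rec_dec": [256, 256],            # hidden units per decoder RNN layer
    "rec_learn_init": False,          # learned initial state not implemented
    "rec_dec_inp_train": "targets",   # feed ground-truth history while training
    "rec_dec_inp_test": "x_mu",       # feed predicted means at test time
    "rec_dec_concur": 1,              # frames emitted per RNN step; must divide T
    "rec_dec_inp_hist": 1,            # number of history frames fed back
    "x_conti": True,                  # continuous (Gaussian) outputs
    "x_mu_nl": None,
    "x_logvar_nl": None,
    "input_dtype": tf.float32,
}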
Example 5
    def _build_z2_encoder(self, inputs, z1, reuse=False):
        weights_regularizer = l2_regularizer(self._train_conf["l2_weight"])
        normalizer_fn = batch_norm if self._model_conf["if_bn"] else None
        normalizer_params = None
        if self._model_conf["if_bn"]:
            normalizer_params = {
                "scope": "BatchNorm",
                "is_training": self._feed_dict["is_train"],
                "reuse": reuse
            }
            # TODO: need to upgrade to latest,
            #       which commit support param_regularizers args

        if not hasattr(self, "_debug_outputs"):
            self._debug_outputs = {}

        C, T, F = self._model_conf["target_shape"]
        n_concur = self._model_conf["rec_z2_enc_concur"]
        if T % n_concur != 0:
            raise ValueError("total time steps must be multiples of %s" %
                             (n_concur))
        n_frame = T // n_concur
        info("z2_encoder: n_frame=%s, n_concur=%s" % (n_frame, n_concur))

        # input_dim = np.prod(inputs.get_shape().as_list()[1:])
        # outputs = tf.concat([tf.reshape(inputs, [-1, input_dim]), z1], axis=1)

        with tf.variable_scope("z2_enc", reuse=reuse):
            # recurrent layers
            if self._model_conf["rec_z2_enc"]:
                # reshape to (N, n_frame, n_concur*C*F)
                inputs = array_ops.transpose(inputs, (0, 2, 1, 3))
                inputs_shape = inputs.get_shape().as_list()
                inputs_depth = np.prod(inputs_shape[2:])
                new_shape = (-1, n_frame, n_concur * inputs_depth)
                inputs = tf.reshape(inputs, new_shape)

                # append z1 to each frame
                tiled_z1 = tf.tile(tf.expand_dims(z1, 1), (1, n_frame, 1))
                inputs = tf.concat([inputs, tiled_z1], axis=-1)

                self._debug_outputs["inp_reshape"] = inputs
                if self._model_conf["rec_z2_enc_bi"]:
                    raise NotImplementedError
                else:
                    Cell = _cell_dict[self._model_conf["rec_cell_type"]]
                    cell = MultiRNNCell([Cell(hu) \
                            for hu in self._model_conf["rec_z2_enc"]])

                    if self._model_conf["rec_learn_init"]:
                        raise NotImplementedError
                    else:
                        input_shape = tuple(array_ops.shape(input_) \
                                for input_ in nest.flatten(inputs))
                        batch_size = input_shape[0][0]
                        init_state = cell.zero_state(
                            batch_size, self._model_conf["input_dtype"])

                    _, final_states = dynamic_rnn(
                        cell,
                        inputs,
                        dtype=self._model_conf["input_dtype"],
                        initial_state=init_state,
                        time_major=False,
                        scope="z2_enc_%sL_rec" %
                        len(self._model_conf["rec_z2_enc"]))
                    self._debug_outputs["raw_rnn_out"] = _
                    self._debug_outputs["raw_rnn_final"] = final_states

                    if self._model_conf["rec_z2_enc_out"].startswith("last"):
                        final_states = final_states[-1:]

                    if self._model_conf["rec_cell_type"] == "lstm":
                        outputs = []
                        for state in final_states:
                            if "h" in self._model_conf["rec_z2_enc_out"].split(
                                    "_")[1]:
                                outputs.append(state.h)
                            if "c" in self._model_conf["rec_z2_enc_out"].split(
                                    "_")[1]:
                                outputs.append(state.c)
                    else:
                        outputs = final_states

                    outputs = tf.concat(outputs, axis=-1)
                    self._debug_outputs["concat_rnn_out"] = outputs
            else:
                input_dim = np.prod(inputs.get_shape().as_list()[1:])
                outputs = tf.concat([tf.reshape(inputs, [-1, input_dim]), z1],
                                    axis=1)

            # fully connected layers
            output_dim = np.prod(outputs.get_shape().as_list()[1:])
            outputs = tf.reshape(outputs, [-1, output_dim])

            for i, hu in enumerate(self._model_conf["hu_z2_enc"]):
                outputs = fully_connected(
                    inputs=outputs,
                    num_outputs=hu,
                    activation_fn=nn.relu,
                    normalizer_fn=normalizer_fn,
                    normalizer_params=normalizer_params,
                    weights_regularizer=weights_regularizer,
                    reuse=reuse,
                    scope="z2_enc_fc%s" % (i + 1))

            z2_mu, z2_logvar, z2 = dense_latent(
                outputs,
                self._model_conf["n_latent2"],
                logvar_nl=self._model_conf["z2_logvar_nl"],
                reuse=reuse,
                scope="z2_enc_lat")

        return [z2_mu, z2_logvar], z2
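One detail worth spelling out: rec_z2_enc_out is parsed as two underscore-separated fields. The first selects which layers' final states to keep ("last..." keeps only the top layer; anything else keeps all layers), and for LSTM cells the second selects which state components to concatenate ("h", "c", or both). Two illustrative values consistent with that parsing (assumed, not taken from the project's configs):

# Illustrative rec_z2_enc_out settings (interpretation follows the parsing above)
model_conf["rec_z2_enc_out"] = "last_h"   # top layer only, hidden state h
model_conf["rec_z2_enc_out"] = "all_hc"   # all layers, both h and c concatenated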