Exemple #1
0
    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        """Run one recurrent step: two LSTM layers, two VQ layers, then a mix.

        The "rnn1" LSTM cell receives the current input together with
        ``h1_q_tm1``; the "rnn1_q" LSTM cell receives the new "rnn1" hidden
        state. The second cell's state entry ``s[0]`` and its output are each
        passed through their own ``VqEmbedding`` ("h1_vq_emb" and
        "out_vq_emb"), cross-combined with two ``Bilinear`` layers and merged
        by a final ``Linear`` layer.

        Args:
            x_t: input for the current step.
            h1_tm1: previous-step state passed to "rnn1" (first state slot).
            c1_tm1: previous-step state passed to "rnn1" (second state slot).
            h1_q_tm1: previous-step state passed to "rnn1_q" (first state
                slot); also used as the second input of "rnn1".
            c1_q_tm1: previous-step state passed to "rnn1_q" (second state
                slot).

        Returns:
            A list with the following layout:
              r[0]    f_output
              r[1:3]  h1_t, c1_t
              r[3:9]  h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb
              r[9:]   output_q_t, output_nst_q_t, output, output_i_t,
                      output_emb
        """
        # NOTE(review): s[0] / s[1] are treated as the new hidden and cell
        # state respectively -- confirm against LSTMCell's return contract.
        output, s = LSTMCell([x_t, h1_q_tm1], [n_hid, n_hid],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        # `output` is rebound here; only the second cell's output is used.
        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        output_q_t, output_i_t, output_nst_q_t, output_emb = VqEmbedding(
            output, n_hid, n_emb, random_state=random_state, name="out_vq_emb")

        # not great: both index tensors are cast to float32 before being
        # returned (presumably so every returned tensor shares one dtype --
        # TODO confirm against the caller).
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        # Cast the *output* indices; the previous code cast h1_i_t here by
        # mistake, which discarded the "out_vq_emb" indices entirely.
        output_i_t = tf.cast(output_i_t, tf.float32)

        # Cross-mix: the quantized hidden state is paired with the output
        # embedding, and the quantized output with the hidden-state embedding.
        lf_output = Bilinear(h1_q_t,
                             n_hid,
                             output_emb,
                             n_hid,
                             random_state=random_state,
                             name="out_mix",
                             init=forward_init)
        rf_output = Bilinear(output_q_t,
                             n_hid,
                             h1_emb,
                             n_hid,
                             random_state=random_state,
                             name="h_mix",
                             init=forward_init)
        f_output = Linear([lf_output, rf_output], [n_emb, n_emb],
                          n_hid,
                          random_state=random_state,
                          name="out_f",
                          init=forward_init)

        # r[0]
        rets = [f_output]
        # r[1:3]
        rets += [h1_t, c1_t]
        # r[3:9]
        rets += [h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb]
        # r[9:]
        rets += [output_q_t, output_nst_q_t, output, output_i_t, output_emb]
        return rets
Exemple #2
0
    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        """Run one recurrent step: two LSTM layers followed by one VQ layer.

        The "rnn1" LSTM cell consumes ``x_t``; the "rnn1_q" LSTM cell consumes
        the new "rnn1" state entry ``s[0]``. The second cell's ``s[0]`` is then
        passed through ``VqEmbedding`` ("h1_vq_emb").

        Args:
            x_t: input for the current step.
            h1_tm1, c1_tm1: previous-step state passed to "rnn1".
            h1_q_tm1, c1_q_tm1: previous-step state passed to "rnn1_q".

        Returns:
            Tuple ``(output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t,
            h1_i_t)`` where ``output`` is the second cell's output and
            ``h1_i_t`` has been cast to float32.
        """
        # Keyword arguments shared by both LSTM cells.
        lstm_kwargs = dict(random_state=random_state,
                           cell_dropout=cell_dropout,
                           init=rnn_init)

        # The first cell's output is not used; only its state is.
        _, first_state = LSTMCell([x_t], [in_emb], h1_tm1, c1_tm1, n_hid,
                                  name="rnn1", **lstm_kwargs)
        h1_t, c1_t = first_state[0], first_state[1]

        output, second_state = LSTMCell([h1_t], [n_hid], h1_q_tm1, c1_q_tm1,
                                        n_hid, name="rnn1_q", **lstm_kwargs)
        h1_cq_t, c1_q_t = second_state[0], second_state[1]

        h1_q_t, raw_idx, h1_nst_q_t, _ = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        # not great: the index tensor is returned as float32
        h1_i_t = tf.cast(raw_idx, tf.float32)

        return (output, h1_t, c1_t, h1_q_t, c1_q_t,
                h1_nst_q_t, h1_cq_t, h1_i_t)
def create_vqvae(inp, bn):
    """Wire up encoder -> VqEmbedding -> decoder and return every stage.

    Args:
        inp: input passed to ``create_encoder``.
        bn: forwarded unchanged to both ``create_encoder`` and
            ``create_decoder``.

    Returns:
        Tuple ``(x_tilde, z_e_x, z_q_x, z_i_x, z_nst_q_x, emb)``: the decoder
        output, the encoder output, and the four values returned by
        ``VqEmbedding`` (the first of which is what the decoder consumes).
    """
    encoded = create_encoder(inp, bn)

    # The input width handed to VqEmbedding is read from the last l_dims entry.
    vq_in_dim = l_dims[-1][0]
    quantized, indices, nst_quantized, embedding = VqEmbedding(
        encoded,
        vq_in_dim,
        embedding_dim,
        random_state=random_state,
        name="embed")

    decoded = create_decoder(quantized, bn)
    return decoded, encoded, quantized, indices, nst_quantized, embedding
    def step(x_t, h1_tm1):
        """Run one GRU step and quantize its state in ``n_split`` slices.

        The "rnn1" GRU cell consumes ``x_t``. Its state entry ``s[0]`` is cut
        along the second axis into ``n_split`` consecutive slices of width
        ``n_hid // n_split``; each slice goes through ``VqEmbedding`` and the
        per-slice results are concatenated along the last axis.

        Args:
            x_t: input for the current step.
            h1_tm1: previous-step state passed to the GRU cell.

        Returns:
            Tuple ``(output, h1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t)``: the GRU
            output, the concatenated first and third VqEmbedding results, the
            unquantized GRU state, and the concatenated per-slice indices
            (one column per slice) cast to float32.
        """
        output, s = GRUCell([x_t], [1],
                            h1_tm1,
                            n_hid,
                            random_state=random_state,
                            name="rnn1",
                            init=rnn_init)
        h1_cq_t = s[0]

        # Slice width is loop-invariant. If n_hid is not a multiple of
        # n_split, the trailing remainder columns are never quantized.
        e_div = n_hid // n_split

        qhs = []
        ihs = []
        nst_qhs = []
        for i in range(n_split):
            h1_q_t, h1_i_t, h1_nst_q_t, _ = VqEmbedding(
                h1_cq_t[:, i * e_div:(i + 1) * e_div],
                e_div,
                n_emb,
                random_state=random_state,
                # shared space? Every slice uses the same name; a per-slice
                # alternative would be "h1_{}_vq_emb".format(i).
                name="h1_vq_emb")
            qhs.append(h1_q_t)
            # Add a trailing axis so the per-slice indices can be
            # concatenated into one column per slice.
            ihs.append(h1_i_t[:, None])
            nst_qhs.append(h1_nst_q_t)
        h1_q_t = tf.concat(qhs, axis=-1)
        h1_nst_q_t = tf.concat(nst_qhs, axis=-1)
        h1_i_t = tf.concat(ihs, axis=-1)

        # not great: the index tensor is returned as float32
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        return output, h1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t