def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
    """One recurrence step of a two-LSTM model with vector-quantized states.

    "rnn1" reads the current input together with the previous "rnn1_q"
    hidden state; "rnn1_q" then reads the new "rnn1" hidden state. The
    "rnn1_q" hidden state and the "rnn1_q" output are each passed through
    their own VqEmbedding layer. Each quantized stream is mixed with the
    other stream's embedding via a Bilinear layer. The two mixes are then
    projected back to n_hid units by a Linear layer.

    Parameters
    ----------
    x_t : input for the current timestep (fed to "rnn1" with declared size n_hid).
    h1_tm1, c1_tm1 : previous hidden / cell state of "rnn1".
    h1_q_tm1, c1_q_tm1 : previous hidden / cell state of "rnn1_q".

    Returns
    -------
    list
        r[0]   : f_output
        r[1:3] : h1_t, c1_t
        r[3:9] : h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb
        r[9:]  : output_q_t, output_nst_q_t, output, output_i_t, output_emb
    """
    # The output of "rnn1" itself is not used; only its state is kept.
    _, s = LSTMCell([x_t, h1_q_tm1],
                    [n_hid, n_hid],
                    h1_tm1, c1_tm1, n_hid,
                    random_state=random_state,
                    name="rnn1", init=rnn_init)
    h1_t = s[0]
    c1_t = s[1]

    # "output" below is the output of "rnn1_q".
    output, s = LSTMCell([h1_t],
                         [n_hid],
                         h1_q_tm1, c1_q_tm1, n_hid,
                         random_state=random_state,
                         name="rnn1_q", init=rnn_init)
    h1_cq_t = s[0]
    c1_q_t = s[1]

    h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
        h1_cq_t, n_hid, n_emb,
        random_state=random_state,
        name="h1_vq_emb")
    output_q_t, output_i_t, output_nst_q_t, output_emb = VqEmbedding(
        output, n_hid, n_emb,
        random_state=random_state,
        name="out_vq_emb")

    # Code indices are cast to float32 so they can be returned alongside the
    # float tensors (presumably required by the scan/loop driving this step).
    h1_i_t = tf.cast(h1_i_t, tf.float32)
    # Bug fix: previously cast h1_i_t here, discarding the out-VQ indices.
    output_i_t = tf.cast(output_i_t, tf.float32)

    lf_output = Bilinear(h1_q_t, n_hid, output_emb, n_hid,
                         random_state=random_state,
                         name="out_mix", init=forward_init)
    rf_output = Bilinear(output_q_t, n_hid, h1_emb, n_hid,
                         random_state=random_state,
                         name="h_mix", init=forward_init)
    f_output = Linear([lf_output, rf_output],
                      [n_emb, n_emb],
                      n_hid,
                      random_state=random_state,
                      name="out_f", init=forward_init)

    # r[0]
    rets = [f_output]
    # r[1:3]
    rets += [h1_t, c1_t]
    # r[3:9]
    rets += [h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb]
    # r[9:]
    rets += [output_q_t, output_nst_q_t, output, output_i_t, output_emb]
    return rets
def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
    """Advance the stacked "rnn1" / "rnn1_q" LSTMs by one step and quantize.

    "rnn1" consumes x_t (declared size in_emb). "rnn1_q" consumes the new
    "rnn1" hidden state. The "rnn1_q" hidden state is then passed through a
    VqEmbedding layer named "h1_vq_emb".

    Returns
    -------
    tuple
        (output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t)
        where output is the "rnn1_q" output and h1_i_t holds the code
        indices cast to float32.
    """
    # Lower cell: only its state is kept.
    _, lower_state = LSTMCell([x_t], [in_emb], h1_tm1, c1_tm1, n_hid,
                              random_state=random_state,
                              cell_dropout=cell_dropout,
                              name="rnn1", init=rnn_init)
    h1_t, c1_t = lower_state[0], lower_state[1]

    # Upper cell: reads the fresh lower hidden state.
    output, upper_state = LSTMCell([h1_t], [n_hid], h1_q_tm1, c1_q_tm1, n_hid,
                                   random_state=random_state,
                                   cell_dropout=cell_dropout,
                                   name="rnn1_q", init=rnn_init)
    h1_cq_t, c1_q_t = upper_state[0], upper_state[1]

    # The codebook (4th return) is not needed by callers of this step.
    h1_q_t, h1_i_t, h1_nst_q_t, _ = VqEmbedding(
        h1_cq_t, n_hid, n_emb,
        random_state=random_state,
        name="h1_vq_emb")

    # Indices cast to float32 to sit alongside the float outputs.
    return (output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t,
            tf.cast(h1_i_t, tf.float32))
def create_vqvae(inp, bn):
    """Wire encoder -> VqEmbedding ("embed") -> decoder.

    Parameters
    ----------
    inp : input passed to create_encoder.
    bn : forwarded unchanged to both create_encoder and create_decoder.

    Returns
    -------
    tuple
        (x_tilde, z_e_x, z_q_x, z_i_x, z_nst_q_x, emb): the decoder output,
        the encoder output, and the four values returned by VqEmbedding.
    """
    encoded = create_encoder(inp, bn)
    quantized, code_indices, nst_quantized, codebook = VqEmbedding(
        encoded,
        l_dims[-1][0],
        embedding_dim,
        random_state=random_state,
        name="embed")
    reconstruction = create_decoder(quantized, bn)
    return (reconstruction, encoded, quantized, code_indices,
            nst_quantized, codebook)
def step(x_t, h1_tm1):
    """One GRU step whose hidden state is vector-quantized in equal chunks.

    The GRU ("rnn1") hidden state h1_cq_t is split along the last axis into
    n_split slices of width n_hid // n_split. Each slice is quantized by a
    VqEmbedding layer named "h1_vq_emb". Every slice uses the same name,
    presumably so they share one codebook (TODO confirm). The per-slice
    results are concatenated back together.

    Parameters
    ----------
    x_t : input for the current timestep (declared size 1 to GRUCell).
    h1_tm1 : previous GRU hidden state.

    Returns
    -------
    tuple
        (output, h1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t)
        h1_i_t stacks one code index per slice along the last axis and is
        cast to float32.

    Raises
    ------
    ValueError
        If n_split is not positive or does not divide n_hid. Previously the
        trailing n_hid % n_split units were silently left unquantized.
    """
    if n_split <= 0 or n_hid % n_split != 0:
        raise ValueError(
            "n_hid ({}) must be divisible by a positive n_split ({})".format(
                n_hid, n_split))
    e_div = n_hid // n_split

    output, s = GRUCell([x_t], [1], h1_tm1, n_hid,
                        random_state=random_state,
                        name="rnn1", init=rnn_init)
    h1_cq_t = s[0]

    qhs = []
    ihs = []
    nst_qhs = []
    for i in range(n_split):
        # Same layer name for every slice (shared space?). A per-slice name,
        # e.g. "h1_{}_vq_emb".format(i), was an alternative.
        h1_q_t, h1_i_t, h1_nst_q_t, _ = VqEmbedding(
            h1_cq_t[:, i * e_div:(i + 1) * e_div],
            e_div,
            n_emb,
            random_state=random_state,
            name="h1_vq_emb")
        qhs.append(h1_q_t)
        # Add a trailing axis so the per-slice indices concatenate side by side.
        ihs.append(h1_i_t[:, None])
        nst_qhs.append(h1_nst_q_t)

    h1_q_t = tf.concat(qhs, axis=-1)
    h1_nst_q_t = tf.concat(nst_qhs, axis=-1)
    # Indices cast to float32 to sit alongside the float outputs.
    h1_i_t = tf.cast(tf.concat(ihs, axis=-1), tf.float32)
    return output, h1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t