def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
    """Run one recurrent step of two stacked LSTMs with vector-quantized codes.

    ``rnn1`` consumes ``x_t`` together with ``h1_q_tm1``. ``rnn1_q`` runs on
    ``rnn1``'s new hidden state, with ``h1_q_tm1`` / ``c1_q_tm1`` as its
    previous state. The hidden state and the output of ``rnn1_q`` are each
    passed through their own ``VqEmbedding``. The two quantized streams are
    cross-mixed with ``Bilinear`` layers against the other stream's
    embedding, then projected back to ``n_hid`` with ``Linear``.

    Relies on enclosing-scope names: ``LSTMCell``, ``VqEmbedding``,
    ``Bilinear``, ``Linear``, ``tf``, ``n_hid``, ``n_emb``, ``random_state``,
    ``rnn_init`` and ``forward_init``.

    Args:
        x_t: input at the current step.
        h1_tm1, c1_tm1: previous hidden / cell state passed to ``rnn1``.
        h1_q_tm1: passed as an extra input to ``rnn1`` and as the previous
            hidden state of ``rnn1_q``.
        c1_q_tm1: previous cell state passed to ``rnn1_q``.

    Returns:
        A list of 14 entries:
          r[0]     f_output
          r[1:3]   h1_t, c1_t
          r[3:9]   h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb
          r[9:14]  output_q_t, output_nst_q_t, output, output_i_t, output_emb
    """
    output, s = LSTMCell([x_t, h1_q_tm1], [n_hid, n_hid],
                         h1_tm1, c1_tm1, n_hid,
                         random_state=random_state,
                         name="rnn1", init=rnn_init)
    h1_t = s[0]
    c1_t = s[1]

    # `output` is deliberately rebound to rnn1_q's output; that output is
    # quantized below by "out_vq_emb".
    output, s = LSTMCell([h1_t], [n_hid],
                         h1_q_tm1, c1_q_tm1, n_hid,
                         random_state=random_state,
                         name="rnn1_q", init=rnn_init)
    h1_cq_t = s[0]
    c1_q_t = s[1]

    h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
        h1_cq_t, n_hid, n_emb,
        random_state=random_state,
        name="h1_vq_emb")
    output_q_t, output_i_t, output_nst_q_t, output_emb = VqEmbedding(
        output, n_hid, n_emb,
        random_state=random_state,
        name="out_vq_emb")

    # Code indices are cast to float32. Presumably this lets every step
    # output share one dtype (e.g. inside a scan); TODO confirm.
    # Each index tensor is cast from its OWN VqEmbedding result. Previously
    # output_i_t was built from h1_i_t, which silently dropped the
    # output-stream indices.
    h1_i_t = tf.cast(h1_i_t, tf.float32)
    output_i_t = tf.cast(output_i_t, tf.float32)

    # Cross-mix each quantized vector with the other stream's embedding.
    lf_output = Bilinear(h1_q_t, n_hid, output_emb, n_hid,
                         random_state=random_state,
                         name="out_mix", init=forward_init)
    rf_output = Bilinear(output_q_t, n_hid, h1_emb, n_hid,
                         random_state=random_state,
                         name="h_mix", init=forward_init)
    f_output = Linear([lf_output, rf_output], [n_emb, n_emb], n_hid,
                      random_state=random_state,
                      name="out_f", init=forward_init)

    return [
        # r[0]
        f_output,
        # r[1:3]
        h1_t, c1_t,
        # r[3:9]
        h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb,
        # r[9:14]
        output_q_t, output_nst_q_t, output, output_i_t, output_emb,
    ]
def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
    """Advance two stacked LSTMs by one step and vector-quantize the top state.

    ``rnn1`` reads ``x_t`` (declared width ``in_emb``). ``rnn1_q`` reads
    ``rnn1``'s new hidden state. The hidden state of ``rnn1_q`` is then
    quantized by ``VqEmbedding`` ("h1_vq_emb"). The embedding table that
    ``VqEmbedding`` returns is not used here.

    Relies on enclosing-scope names: ``LSTMCell``, ``VqEmbedding``, ``tf``,
    ``in_emb``, ``n_hid``, ``n_emb``, ``random_state``, ``cell_dropout``
    and ``rnn_init``.

    Returns:
        (output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t),
        where ``output`` comes from ``rnn1_q``.
    """
    # Keyword arguments shared by both LSTM cells.
    lstm_kwargs = {
        "random_state": random_state,
        "cell_dropout": cell_dropout,
        "init": rnn_init,
    }

    _, lower_state = LSTMCell([x_t], [in_emb], h1_tm1, c1_tm1, n_hid,
                              name="rnn1", **lstm_kwargs)
    h1_t, c1_t = lower_state[0], lower_state[1]

    output, upper_state = LSTMCell([h1_t], [n_hid], h1_q_tm1, c1_q_tm1, n_hid,
                                   name="rnn1_q", **lstm_kwargs)
    h1_cq_t, c1_q_t = upper_state[0], upper_state[1]

    h1_q_t, h1_i_t, h1_nst_q_t, _ = VqEmbedding(h1_cq_t, n_hid, n_emb,
                                                random_state=random_state,
                                                name="h1_vq_emb")

    # Code indices are returned as float32 alongside the other step outputs.
    h1_i_t = tf.cast(h1_i_t, tf.float32)

    return output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t
def step(inp_t, inp_mask_t, corr_inp_t,
         att_w_tm1, att_k_tm1, att_h_tm1, att_c_tm1,
         h1_tm1, c1_tm1, h2_tm1, c2_tm1):
    """Run one decoder step: Gaussian attention over ``bitext``, then two LSTMs.

    Only ``corr_inp_t`` feeds the cells. ``inp_t`` is accepted but unused
    here, presumably to keep the step signature fixed for the caller;
    TODO confirm.

    Relies on enclosing-scope names: ``GaussianAttentionCell``,
    ``LSTMCell``, ``bitext``, ``text_mask``, ``prenet_units``,
    ``enc_units``, ``dec_units``, ``random_state``, ``cell_dropout`` and
    ``rnn_init``.

    Returns:
        (output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t,
         h1_t, c1_t, h2_t, c2_t), where ``output`` comes from ``rnn2``.
    """
    att_w_t, att_k_t, att_phi_t, att_state = GaussianAttentionCell(
        [corr_inp_t], [prenet_units],
        (att_h_tm1, att_c_tm1),
        att_k_tm1,
        bitext,
        2 * enc_units,
        dec_units,
        att_w_tm1,
        input_mask=inp_mask_t,
        conditioning_mask=text_mask,
        attention_scale=1.,
        step_op="softplus",
        name="att",
        random_state=random_state,
        # NOTE(review): the attention cell gets a hard-coded 1. here, while
        # rnn1/rnn2 use the shared `cell_dropout`. Confirm this is intended.
        cell_dropout=1.,
        init=rnn_init)
    att_h_t, att_c_t = att_state[0], att_state[1]

    # Keyword arguments shared by both decoder LSTMs.
    lstm_kwargs = {
        "input_mask": inp_mask_t,
        "random_state": random_state,
        "cell_dropout": cell_dropout,
        "init": rnn_init,
    }

    _, rnn1_state = LSTMCell(
        [corr_inp_t, att_w_t, att_h_t],
        [prenet_units, 2 * enc_units, dec_units],
        h1_tm1, c1_tm1, dec_units,
        name="rnn1", **lstm_kwargs)
    h1_t, c1_t = rnn1_state[0], rnn1_state[1]

    output, rnn2_state = LSTMCell(
        [corr_inp_t, att_w_t, h1_t],
        [prenet_units, 2 * enc_units, dec_units],
        h2_tm1, c2_tm1, dec_units,
        name="rnn2", **lstm_kwargs)
    h2_t, c2_t = rnn2_state[0], rnn2_state[1]

    return (output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t,
            h1_t, c1_t, h2_t, c2_t)
def step(inp_t, inp_mask_t,
         att_w_tm1, att_k_tm1, att_h_tm1, att_c_tm1,
         h1_tm1, c1_tm1, h2_tm1, c2_tm1):
    """Run one step: Gaussian attention over ``sequence``, then two LSTMs.

    The attention cell reads ``inp_t`` (width ``speech_size``) and attends
    over ``sequence`` (width ``num_letters``). Both LSTMs see ``inp_t``, the
    attention read ``att_w_t`` and the state of the layer below.

    Relies on enclosing-scope names: ``GaussianAttentionCell``,
    ``LSTMCell``, ``sequence``, ``sequence_mask``, ``speech_size``,
    ``num_letters``, ``num_units``, ``random_state``, ``cell_dropout`` and
    ``rnn_init``.

    Returns:
        (output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t,
         h1_t, c1_t, h2_t, c2_t), where ``output`` comes from ``rnn2``.
    """
    # Keyword arguments shared by the attention cell and both LSTMs.
    shared_kwargs = {
        "input_mask": inp_mask_t,
        "random_state": random_state,
        "cell_dropout": cell_dropout,
        "init": rnn_init,
    }

    att_w_t, att_k_t, att_phi_t, att_state = GaussianAttentionCell(
        [inp_t], [speech_size],
        (att_h_tm1, att_c_tm1),
        att_k_tm1,
        sequence,
        num_letters,
        num_units,
        att_w_tm1,
        conditioning_mask=sequence_mask,
        attention_scale=1. / 10.,
        name="att",
        **shared_kwargs)
    att_h_t, att_c_t = att_state[0], att_state[1]

    _, rnn1_state = LSTMCell(
        [inp_t, att_w_t, att_h_t],
        [speech_size, num_letters, num_units],
        h1_tm1, c1_tm1, num_units,
        name="rnn1", **shared_kwargs)
    h1_t, c1_t = rnn1_state[0], rnn1_state[1]

    output, rnn2_state = LSTMCell(
        [inp_t, att_w_t, h1_t],
        [speech_size, num_letters, num_units],
        h2_tm1, c2_tm1, num_units,
        name="rnn2", **shared_kwargs)
    h2_t, c2_t = rnn2_state[0], rnn2_state[1]

    return (output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t,
            h1_t, c1_t, h2_t, c2_t)
def step(x_t, h1_tm1, c1_tm1):
    """Advance the single ``rnn1`` LSTM by one step.

    ``x_t`` is declared to ``LSTMCell`` as an input of width 1.
    Relies on enclosing-scope names: ``LSTMCell``, ``n_hid``,
    ``random_state`` and ``rnn_init``.

    Returns:
        (output, h1_t, c1_t): the cell output plus the new hidden and cell
        state.
    """
    cell_out, new_state = LSTMCell(
        [x_t], [1],
        h1_tm1, c1_tm1,
        n_hid,
        random_state=random_state,
        name="rnn1",
        init=rnn_init)
    return cell_out, new_state[0], new_state[1]