Example #1
import tensorflow as tf

# OneHot, Linear, LSTMCell, VqEmbedding, Softmax, and scan are helpers
# from the author's utility library (not shown); later examples assume
# the same imports and helpers.
def create_vqrnn(inp_tm1, inp_t, h1_init, c1_init, h1_q_init, c1_q_init):
    oh_tm1 = OneHot(inp_tm1, n_inputs)
    p_tm1 = Linear([oh_tm1], [n_inputs],
                   n_hid,
                   random_state=random_state,
                   name="proj",
                   init=forward_init)

    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        output, s = LSTMCell([x_t], [n_hid],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t,
            n_hid,
            n_clusters,
            random_state=random_state,
            name="h1_vq_emb")

        # cast the int32 code indices to float32 so scan can stack them
        # with the float outputs; not ideal
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        return output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t

    r = scan(step, [p_tm1],
             [None, h1_init, c1_init, h1_q_init, c1_q_init, None, None, None])
    out = r[0]
    hiddens = r[1]
    cells = r[2]
    q_hiddens = r[3]
    q_cells = r[4]
    q_nst_hiddens = r[5]
    q_nvq_hiddens = r[6]
    i_hiddens = r[7]

    pred = Linear([out], [n_hid],
                  n_inputs,
                  random_state=random_state,
                  name="out",
                  init=forward_init)
    pred_sm = Softmax(pred)
    return pred_sm, pred, hiddens, cells, q_hiddens, q_cells, q_nst_hiddens, q_nvq_hiddens, i_hiddens, oh_tm1
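
A note on VqEmbedding: it comes from the author's utility library, but its four return values (straight-through quantized output, code indices, non-straight-through quantized output, codebook) follow the usual VQ-VAE pattern. A minimal sketch of such a layer in TF 1.x, with all names here being assumptions rather than the library's actual implementation:

import tensorflow as tf

def vq_embedding_sketch(z_e, dim, n_codes, name="vq"):
    # codebook of n_codes code vectors, each of size dim; AUTO_REUSE so
    # repeated calls (e.g. inside scan) share the same codebook
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        emb = tf.get_variable(
            "emb", shape=[n_codes, dim],
            initializer=tf.random_uniform_initializer(-1., 1.))
    # squared distance from each input row to each code
    d = (tf.reduce_sum(z_e ** 2, axis=-1, keepdims=True)
         - 2. * tf.matmul(z_e, emb, transpose_b=True)
         + tf.reduce_sum(emb ** 2, axis=-1)[None])
    idx = tf.argmin(d, axis=-1)             # nearest-code index per row
    z_q = tf.nn.embedding_lookup(emb, idx)  # quantized vectors
    # straight-through estimator: forward pass uses z_q, gradients
    # flow back to z_e unchanged
    z_q_st = z_e + tf.stop_gradient(z_q - z_e)
    return z_q_st, idx, z_q, emb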
Example #2
    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        output, s = LSTMCell([x_t, h1_q_tm1], [n_hid, n_hid],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        output_q_t, output_i_t, output_nst_q_t, output_emb = VqEmbedding(
            output, n_hid, n_emb, random_state=random_state, name="out_vq_emb")

        # cast the int32 code indices to float32 so scan can stack them
        # with the float outputs; not ideal
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        output_i_t = tf.cast(output_i_t, tf.float32)

        lf_output = Bilinear(h1_q_t,
                             n_hid,
                             output_emb,
                             n_hid,
                             random_state=random_state,
                             name="out_mix",
                             init=forward_init)
        rf_output = Bilinear(output_q_t,
                             n_hid,
                             h1_emb,
                             n_hid,
                             random_state=random_state,
                             name="h_mix",
                             init=forward_init)
        f_output = Linear([lf_output, rf_output], [n_emb, n_emb],
                          n_hid,
                          random_state=random_state,
                          name="out_f",
                          init=forward_init)

        # r[0]
        rets = [f_output]
        # r[1:3]
        rets += [h1_t, c1_t]
        # r[3:9]
        rets += [h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb]
        # r[9:]
        rets += [output_q_t, output_nst_q_t, output, output_i_t, output_emb]
        return rets
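
Bilinear here appears to score a state vector against a codebook: given x of size n_hid and a codebook emb of shape [n_emb, n_hid], x^T W emb^T yields one score per code, which matches the [n_emb, n_emb] input sizes declared for the out_f Linear above. A hedged sketch of such a layer (names are assumptions):

import tensorflow as tf

def bilinear_sketch(x, x_dim, emb, emb_dim, name="bilinear"):
    # x: [batch, x_dim]; emb: [n_codes, emb_dim]
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        W = tf.get_variable("W", shape=[x_dim, emb_dim])
    # [batch, x_dim] @ [x_dim, emb_dim] @ [emb_dim, n_codes] -> [batch, n_codes]
    return tf.matmul(tf.matmul(x, W), emb, transpose_b=True)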
Example #3
def create_model(inp_tm1, inp_t, h1_init, c1_init):
    def step(x_t, h1_tm1, c1_tm1):
        output, s = LSTMCell([x_t], [1],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]
        return output, h1_t, c1_t

    r = scan(step, [inp_tm1], [None, h1_init, c1_init])
    out = r[0]
    hiddens = r[1]
    cells = r[2]
    pred = Linear([out], [n_hid],
                  1,
                  random_state=random_state,
                  name="out",
                  init=forward_init)
    """
    z_e_x = create_encoder(inp, bn)
    z_q_x, z_i_x, emb = VqEmbedding(z_e_x, l_dims[-1][0], embedding_dim, random_state=random_state, name="embed")
    x_tilde = create_decoder(z_q_x, bn)
    """
    return pred, hiddens, cells
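
The scan helper appears to follow Theano's conventions: the first list holds sequences iterated over axis 0, and the second gives initial values for recurrent outputs, with None marking outputs that are not fed back. A minimal sketch of the same recurrence on plain tf.scan, assuming inp_tm1 is [T, batch, 1], states are [batch, n_hid], and the LSTM output has the same size as the hidden state (as it does here):

import tensorflow as tf

def run_rnn_sketch(step, inp_tm1, h1_init, c1_init):
    # step(x_t, h1_tm1, c1_tm1) -> (output, h1_t, c1_t)
    def body(prev, x_t):
        _, h1_tm1, c1_tm1 = prev
        return step(x_t, h1_tm1, c1_tm1)
    # the first output is not fed back (None in outputs_info), so it
    # only needs a dummy seed with the right shape and dtype
    out0 = tf.zeros_like(h1_init)
    return tf.scan(body, inp_tm1, initializer=(out0, h1_init, c1_init))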
Example #4
def create_model(inp_tm1, inp_t, cell_dropout, h1_init, c1_init, h1_q_init,
                 c1_q_init):
    e_tm1, emb_r = Embedding(inp_tm1,
                             n_inputs,
                             in_emb,
                             random_state=random_state,
                             name="in_emb")

    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        output, s = LSTMCell([x_t], [in_emb],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             cell_dropout=cell_dropout,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             cell_dropout=cell_dropout,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        # cast the int32 code indices to float32 so scan can stack them
        # with the float outputs; not ideal
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        return output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t

    r = scan(step, [e_tm1],
             [None, h1_init, c1_init, h1_q_init, c1_q_init, None, None, None])
    out = r[0]
    hiddens = r[1]
    cells = r[2]
    q_hiddens = r[3]
    q_cells = r[4]
    q_nst_hiddens = r[5]
    q_nvq_hiddens = r[6]
    i_hiddens = r[7]

    # tied weights? (see the sketch after this example)
    pred = Linear([out], [n_hid],
                  n_inputs,
                  random_state=random_state,
                  name="out",
                  init=forward_init)
    pred_sm = Softmax(pred)
    return pred_sm, pred, hiddens, cells, q_hiddens, q_cells, q_nst_hiddens, q_nvq_hiddens, i_hiddens
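
The "tied weights?" comment presumably refers to reusing the input embedding matrix emb_r as the output projection instead of a separate Linear. That would require the LSTM output size to equal in_emb; a hedged sketch of the idea (names and shapes are assumptions):

import tensorflow as tf

def tied_logits_sketch(out, emb_r):
    # out: [T, batch, in_emb]; emb_r: [n_inputs, in_emb]
    shp = tf.shape(out)
    flat = tf.reshape(out, [-1, tf.shape(emb_r)[1]])
    # project onto the transposed input embedding -> one logit per symbol
    logits = tf.matmul(flat, emb_r, transpose_b=True)
    return tf.reshape(logits, [shp[0], shp[1], tf.shape(emb_r)[0]])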
Example #5
def create_model(inp, out):
    p_inp, emb = PositionalEncoding(inp,
                                    n_syms,
                                    dim,
                                    random_state=random_state,
                                    name="inp_pos_emb")
    prev = p_inp
    enc_atts = []
    for i in range(n_layers):
        li, atti = TransformerBlock(prev,
                                    prev,
                                    prev,
                                    dim,
                                    mask=False,
                                    random_state=random_state,
                                    name="l{}".format(i))
        prev = li
        enc_atts.append(atti)

    p_out, out_emb = PositionalEncoding(out,
                                        n_syms,
                                        dim,
                                        random_state=random_state,
                                        name="out_pos_emb")
    dec_atts = []
    prev_in = prev
    prev = p_out
    for i in range(n_layers):
        li, atti = TransformerBlock(prev_in,
                                    prev_in,
                                    prev,
                                    dim,
                                    mask=True,
                                    random_state=random_state,
                                    name="dl{}".format(i))
        prev = li
        dec_atts.append(atti)
    prob_logits = Linear([prev], [dim],
                         n_syms,
                         random_state=random_state,
                         name="prob_logits")
    return prob_logits, enc_atts, dec_atts
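
A possible driver for this encoder-decoder builder, assuming sequence-major integer symbol ids of shape [T, batch] and globals n_syms, dim, n_layers, batch_size defined elsewhere (the placeholder shapes and the loss wiring are assumptions, with target shifting omitted for brevity):

import tensorflow as tf

inp = tf.placeholder(tf.int32, shape=[None, batch_size], name="inp")
out = tf.placeholder(tf.int32, shape=[None, batch_size], name="out")
prob_logits, enc_atts, dec_atts = create_model(inp, out)
# teacher-forced cross-entropy against the target symbols
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=out,
                                                   logits=prob_logits))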
Example #6
def create_model(inp_tm1, inp_t, h1_init, c1_init, h1_q_init, c1_q_init):
    p_tm1 = Linear([inp_tm1], [1],
                   n_hid,
                   random_state=random_state,
                   name="proj_in",
                   init=forward_init)

    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        output, s = LSTMCell([x_t, h1_q_tm1], [n_hid, n_hid],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        output_q_t, output_i_t, output_nst_q_t, output_emb = VqEmbedding(
            output, n_hid, n_emb, random_state=random_state, name="out_vq_emb")

        # cast the int32 code indices to float32 so scan can stack them
        # with the float outputs; not ideal
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        output_i_t = tf.cast(output_i_t, tf.float32)

        lf_output = Bilinear(h1_q_t,
                             n_hid,
                             output_emb,
                             n_hid,
                             random_state=random_state,
                             name="out_mix",
                             init=forward_init)
        rf_output = Bilinear(output_q_t,
                             n_hid,
                             h1_emb,
                             n_hid,
                             random_state=random_state,
                             name="h_mix",
                             init=forward_init)
        f_output = Linear([lf_output, rf_output], [n_emb, n_emb],
                          n_hid,
                          random_state=random_state,
                          name="out_f",
                          init=forward_init)

        # r[0]
        rets = [f_output]
        # r[1:3]
        rets += [h1_t, c1_t]
        # r[3:9]
        rets += [h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t, h1_emb]
        # r[9:]
        rets += [output_q_t, output_nst_q_t, output, output_i_t, output_emb]
        return rets

    outputs_info = [
        None, h1_init, c1_init, h1_q_init, c1_q_init, None, None, None, None,
        None, None, None, None, None
    ]
    r = scan(step, [p_tm1], outputs_info)
    out = r[0]
    hiddens = r[1]
    cells = r[2]
    q_hiddens = r[3]
    q_cells = r[4]
    q_nst_hiddens = r[5]
    q_nvq_hiddens = r[6]
    i_hiddens = r[7]
    emb_hiddens = r[8]
    q_out = r[9]
    q_nst_out = r[10]
    q_nvq_out = r[11]
    i_out = r[12]
    emb_out = r[13]

    l1 = Linear([out, q_hiddens], [n_hid, n_hid],
                n_hid,
                random_state=random_state,
                name="l1",
                init=forward_init)
    r_l1 = ReLU(l1)
    pred = Linear([r_l1], [n_hid],
                  1,
                  random_state=random_state,
                  name="out",
                  init=forward_init)
    outs_names = [
        "pred", "hiddens", "cells", "q_hiddens", "q_cells", "q_nst_hiddens",
        "q_nvq_hiddens", "i_hiddens", "emb_hiddens", "q_out", "q_nst_out",
        "q_nvq_out", "i_out", "emb_out"
    ]
    outs_tf = [eval(name) for name in outs_names]
    c = namedtuple("Core", outs_names)(*outs_tf)
    return c
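
The eval-over-names trick above works but defeats linters; an explicit construction is equivalent (a sketch with the same behavior):

from collections import namedtuple

Core = namedtuple("Core", outs_names)
c = Core(pred, hiddens, cells, q_hiddens, q_cells, q_nst_hiddens,
         q_nvq_hiddens, i_hiddens, emb_hiddens, q_out, q_nst_out,
         q_nvq_out, i_out, emb_out)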
Example #7
def create_model(inp_tm1, h1_q_init):
    def step(x_t, h1_tm1):
        output, s = GRUCell([x_t], [1],
                            h1_tm1,
                            n_hid,
                            random_state=random_state,
                            name="rnn1",
                            init=rnn_init)
        h1_cq_t = s[0]
        """
        output, s = LSTMCell([h1_t], [n_hid], h1_q_tm1, c1_q_tm1, n_hid,
                             random_state=random_state,
                             name="rnn1_q", init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]
        """
        qhs = []
        ihs = []
        nst_qhs = []
        embs = []
        e_div = n_hid // n_split
        for i in range(n_split):
            h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
                h1_cq_t[:, i * e_div:(i + 1) * e_div],
                e_div,
                n_emb,
                random_state=random_state,
                # shared codebook across all splits; the commented-out
                # name below would give each split its own codebook
                name="h1_vq_emb")
            # name="h1_{}_vq_emb".format(i))
            qhs.append(h1_q_t)
            ihs.append(h1_i_t[:, None])
            nst_qhs.append(h1_nst_q_t)
            embs.append(h1_emb)
        h1_q_t = tf.concat(qhs, axis=-1)
        h1_nst_q_t = tf.concat(nst_qhs, axis=-1)
        h1_i_t = tf.concat(ihs, axis=-1)

        # cast the int32 code indices to float32 so scan can stack them
        # with the float outputs; not ideal
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        return output, h1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t

    r = scan(step, [inp_tm1], [None, h1_q_init, None, None, None])
    out = r[0]
    q_hiddens = r[1]
    q_nst_hiddens = r[2]
    q_nvq_hiddens = r[3]
    i_hiddens = r[4]

    l1 = Linear([out], [n_hid],
                n_hid,
                random_state=random_state,
                name="l1",
                init=forward_init)
    r_l1 = ReLU(l1)
    pred = Linear([r_l1], [n_hid],
                  1,
                  random_state=random_state,
                  name="out",
                  init=forward_init)

    return pred, q_hiddens, q_nst_hiddens, q_nvq_hiddens, i_hiddens
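
Since every split uses the same codebook name (and thus, if the library reuses variables, the same codebook), the per-split loop can also be written as one reshape. A sketch reusing the hypothetical vq_embedding_sketch from the Example #1 note, assuming n_hid % n_split == 0:

import tensorflow as tf

# inside step(), in place of the split loop:
e_div = n_hid // n_split
z = tf.reshape(h1_cq_t, [-1, e_div])        # [batch * n_split, e_div]
h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = vq_embedding_sketch(
    z, e_div, n_emb, name="h1_vq_emb")
h1_q_t = tf.reshape(h1_q_t, [-1, n_hid])    # back to [batch, n_hid]
h1_nst_q_t = tf.reshape(h1_nst_q_t, [-1, n_hid])
h1_i_t = tf.reshape(h1_i_t, [-1, n_split])  # one code index per split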
Example #8
from collections import namedtuple

import tensorflow as tf

def create_graph():
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(2899)

        text = tf.placeholder(tf.float32, shape=[None, batch_size, 1])
        text_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        mask = tf.placeholder(tf.float32, shape=[None, batch_size, 1])
        mask_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        mels = tf.placeholder(tf.float32,
                              shape=[None, batch_size, output_size])
        mel_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        bias = tf.placeholder_with_default(tf.zeros(shape=[]), shape=[])
        cell_dropout = tf.placeholder_with_default(cell_dropout_scale *
                                                   tf.ones(shape=[]),
                                                   shape=[])
        prenet_dropout = tf.placeholder_with_default(0.5 * tf.ones(shape=[]),
                                                     shape=[])
        bn_flag = tf.placeholder_with_default(tf.zeros(shape=[]), shape=[])

        att_w_init = tf.placeholder(tf.float32,
                                    shape=[batch_size, 2 * enc_units])
        att_k_init = tf.placeholder(tf.float32,
                                    shape=[batch_size, window_mixtures])
        att_h_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        att_c_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        h1_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        c1_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        h2_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        c2_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])

        in_mels = mels[:-1, :, :]
        in_mel_mask = mel_mask[:-1]
        out_mels = mels[1:, :, :]
        out_mel_mask = mel_mask[1:]

        projmel1 = Linear([in_mels], [output_size],
                          prenet_units,
                          dropout_flag_prob_keep=prenet_dropout,
                          name="prenet1",
                          random_state=random_state)
        projmel2 = Linear([projmel1], [prenet_units],
                          prenet_units,
                          dropout_flag_prob_keep=prenet_dropout,
                          name="prenet2",
                          random_state=random_state)

        text_char_e, t_c_emb = Embedding(text,
                                         vocabulary_size,
                                         emb_dim,
                                         random_state=random_state,
                                         name="text_char_emb")
        text_phone_e, t_p_emb = Embedding(text,
                                          vocabulary_size,
                                          emb_dim,
                                          random_state=random_state,
                                          name="text_phone_emb")

        text_e = (1. - mask) * text_char_e + mask * text_phone_e

        # masks are 0/1; embed them with a vocabulary size of two so the
        # mask embedding has the same size as the text embedding and a
        # comparable influence on the representation
        mask_e, m_emb = Embedding(mask,
                                  2,
                                  emb_dim,
                                  random_state=random_state,
                                  name="mask_emb")
        conv_text = SequenceConv1dStack([text_e + mask_e], [emb_dim],
                                        n_filts,
                                        bn_flag,
                                        n_stacks=n_stacks,
                                        kernel_sizes=[(1, 1), (3, 3), (5, 5)],
                                        name="enc_conv1",
                                        random_state=random_state)

        # text_mask and mask_mask should be identical, so either works here
        bitext = BiLSTMLayer([conv_text], [n_filts],
                             enc_units,
                             input_mask=text_mask,
                             name="encode_bidir",
                             init=rnn_init,
                             random_state=random_state)

        def step(inp_t, inp_mask_t, corr_inp_t, att_w_tm1, att_k_tm1,
                 att_h_tm1, att_c_tm1, h1_tm1, c1_tm1, h2_tm1, c2_tm1):

            o = GaussianAttentionCell(
                [corr_inp_t],
                [prenet_units],
                (att_h_tm1, att_c_tm1),
                att_k_tm1,
                bitext,
                2 * enc_units,
                dec_units,
                att_w_tm1,
                input_mask=inp_mask_t,
                conditioning_mask=text_mask,
                #attention_scale=1. / 10.,
                attention_scale=1.,
                step_op="softplus",
                name="att",
                random_state=random_state,
                cell_dropout=1.,  #cell_dropout,
                init=rnn_init)
            att_w_t, att_k_t, att_phi_t, s = o
            att_h_t = s[0]
            att_c_t = s[1]

            output, s = LSTMCell([corr_inp_t, att_w_t, att_h_t],
                                 [prenet_units, 2 * enc_units, dec_units],
                                 h1_tm1,
                                 c1_tm1,
                                 dec_units,
                                 input_mask=inp_mask_t,
                                 random_state=random_state,
                                 cell_dropout=cell_dropout,
                                 name="rnn1",
                                 init=rnn_init)
            h1_t = s[0]
            c1_t = s[1]

            output, s = LSTMCell([corr_inp_t, att_w_t, h1_t],
                                 [prenet_units, 2 * enc_units, dec_units],
                                 h2_tm1,
                                 c2_tm1,
                                 dec_units,
                                 input_mask=inp_mask_t,
                                 random_state=random_state,
                                 cell_dropout=cell_dropout,
                                 name="rnn2",
                                 init=rnn_init)
            h2_t = s[0]
            c2_t = s[1]
            return output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t, h1_t, c1_t, h2_t, c2_t

        r = scan(step, [in_mels, in_mel_mask, projmel2], [
            None, att_w_init, att_k_init, None, att_h_init, att_c_init,
            h1_init, c1_init, h2_init, c2_init
        ])
        output = r[0]
        att_w = r[1]
        att_k = r[2]
        att_phi = r[3]
        att_h = r[4]
        att_c = r[5]
        h1 = r[6]
        c1 = r[7]
        h2 = r[8]
        c2 = r[9]

        pred = Linear([output], [dec_units],
                      output_size,
                      name="out_proj",
                      random_state=random_state)
        """
        mix, means, lins = DiscreteMixtureOfLogistics([proj], [output_size], n_output_channels=1,
                                                      name="dml", random_state=random_state)
        cc = DiscreteMixtureOfLogisticsCost(mix, means, lins, out_mels, 256)
        """

        # TODO: correct masking; the masked variant below is commented
        # out, so the current loss is an unmasked mean
        cc = (pred - out_mels)**2
        #cc = out_mel_mask[..., None] * cc
        #loss = tf.reduce_sum(tf.reduce_sum(cc, axis=-1)) / tf.reduce_sum(out_mel_mask)
        loss = tf.reduce_mean(tf.reduce_sum(cc, axis=-1))

        learning_rate = 0.0001
        #steps = tf.Variable(0.)
        #learning_rate = tf.train.exponential_decay(0.001, steps, staircase=True,
        #                                           decay_steps=50000, decay_rate=0.5)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           use_locking=True)
        grad, var = zip(*optimizer.compute_gradients(loss))
        grad, _ = tf.clip_by_global_norm(grad, 10.)
        #train_step = optimizer.apply_gradients(zip(grad, var), global_step=steps)
        train_step = optimizer.apply_gradients(zip(grad, var))

    things_names = [
        "mels",
        "mel_mask",
        "in_mels",
        "in_mel_mask",
        "out_mels",
        "out_mel_mask",
        "text",
        "text_mask",
        "mask",
        "mask_mask",
        "bias",
        "cell_dropout",
        "prenet_dropout",
        "bn_flag",
        "pred",
        #"mix", "means", "lins",
        "att_w_init",
        "att_k_init",
        "att_h_init",
        "att_c_init",
        "h1_init",
        "c1_init",
        "h2_init",
        "c2_init",
        "att_w",
        "att_k",
        "att_phi",
        "att_h",
        "att_c",
        "h1",
        "c1",
        "h2",
        "c2",
        "loss",
        "train_step",
        "learning_rate"
    ]
    things_tf = [eval(name) for name in things_names]
    for tn, tt in zip(things_names, things_tf):
        graph.add_to_collection(tn, tt)
    train_model = namedtuple('Model', things_names)(*things_tf)
    return graph, train_model
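
A hedged training-step sketch for this graph. The feed arrays (mel_batch, text_batch, ...) are hypothetical and the zero-state shapes are read off the placeholders defined in create_graph:

import numpy as np
import tensorflow as tf

graph, model = create_graph()
with graph.as_default():
    init_op = tf.global_variables_initializer()

def zeros(*shape):
    return np.zeros(shape, dtype=np.float32)

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    feed = {
        model.mels: mel_batch,           # [T, batch, output_size]
        model.mel_mask: mel_mask_batch,  # [T, batch]
        model.text: text_batch,
        model.text_mask: text_mask_batch,
        model.mask: mask_batch,
        model.mask_mask: mask_mask_batch,
        model.att_w_init: zeros(batch_size, 2 * enc_units),
        model.att_k_init: zeros(batch_size, window_mixtures),
        model.att_h_init: zeros(batch_size, dec_units),
        model.att_c_init: zeros(batch_size, dec_units),
        model.h1_init: zeros(batch_size, dec_units),
        model.c1_init: zeros(batch_size, dec_units),
        model.h2_init: zeros(batch_size, dec_units),
        model.c2_init: zeros(batch_size, dec_units),
    }
    loss_val, _ = sess.run([model.loss, model.train_step], feed_dict=feed)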
Example #9
        def create_model():
            in_speech = speech[:-1, :, :]
            in_speech_mask = speech_mask[:-1]
            out_speech = speech[1:, :, :]
            out_speech_mask = speech_mask[1:]

            def step(inp_t, inp_mask_t, att_w_tm1, att_k_tm1, att_h_tm1,
                     att_c_tm1, h1_tm1, c1_tm1, h2_tm1, c2_tm1):

                o = GaussianAttentionCell([inp_t], [speech_size],
                                          (att_h_tm1, att_c_tm1),
                                          att_k_tm1,
                                          sequence,
                                          num_letters,
                                          num_units,
                                          att_w_tm1,
                                          input_mask=inp_mask_t,
                                          conditioning_mask=sequence_mask,
                                          attention_scale=1. / 10.,
                                          name="att",
                                          random_state=random_state,
                                          cell_dropout=cell_dropout,
                                          init=rnn_init)
                att_w_t, att_k_t, att_phi_t, s = o
                att_h_t = s[0]
                att_c_t = s[1]

                output, s = LSTMCell([inp_t, att_w_t, att_h_t],
                                     [speech_size, num_letters, num_units],
                                     h1_tm1,
                                     c1_tm1,
                                     num_units,
                                     input_mask=inp_mask_t,
                                     random_state=random_state,
                                     cell_dropout=cell_dropout,
                                     name="rnn1",
                                     init=rnn_init)
                h1_t = s[0]
                c1_t = s[1]

                output, s = LSTMCell([inp_t, att_w_t, h1_t],
                                     [speech_size, num_letters, num_units],
                                     h2_tm1,
                                     c2_tm1,
                                     num_units,
                                     input_mask=inp_mask_t,
                                     random_state=random_state,
                                     cell_dropout=cell_dropout,
                                     name="rnn2",
                                     init=rnn_init)
                h2_t = s[0]
                c2_t = s[1]
                return output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t, h1_t, c1_t, h2_t, c2_t

            r = scan(step, [in_speech, in_speech_mask], [
                None, att_w_init, att_k_init, None, att_h_init, att_c_init,
                h1_init, c1_init, h2_init, c2_init
            ])
            output = r[0]
            att_w = r[1]
            att_k = r[2]
            att_phi = r[3]
            att_h = r[4]
            att_c = r[5]
            h1 = r[6]
            c1 = r[7]
            h2 = r[8]
            c2 = r[9]

            mean_pred = Linear([output], [num_units],
                               speech_size,
                               random_state=random_state,
                               init=forward_init,
                               name="mean_proj")
            loss = tf.reduce_mean(tf.square(mean_pred - out_speech))

            # save params for easier model loading and prediction
            for param in [('speech', speech), ('in_speech', in_speech),
                          ('out_speech', out_speech),
                          ('speech_mask', speech_mask),
                          ('in_speech_mask', in_speech_mask),
                          ('out_speech_mask', out_speech_mask),
                          ('sequence', sequence),
                          ('sequence_mask', sequence_mask), ('bias', bias),
                          ('cell_dropout', cell_dropout),
                          ('att_w_init', att_w_init), ('att_k_init',
                                                       att_k_init),
                          ('att_h_init', att_h_init),
                          ('att_c_init', att_c_init), ('h1_init', h1_init),
                          ('c1_init', c1_init), ('h2_init', h2_init),
                          ('c2_init', c2_init), ('att_w', att_w),
                          ('att_k', att_k), ('att_phi', att_phi),
                          ('att_h', att_h), ('att_c', att_c), ('h1', h1),
                          ('c1', c1), ('h2', h2), ('c2', c2),
                          ('mean_pred', mean_pred)]:
                tf.add_to_collection(*param)

            with tf.name_scope('training'):
                learning_rate = 0.0001
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                   use_locking=True)
                grad, var = zip(*optimizer.compute_gradients(loss))
                grad, _ = tf.clip_by_global_norm(grad, 3.)
                train_step = optimizer.apply_gradients(zip(grad, var))

            with tf.name_scope('summary'):
                # TODO: add more summaries
                summary = tf.summary.merge([tf.summary.scalar('loss', loss)])

            things_names = [
                "speech", "speech_mask", "in_speech", "in_speech_mask",
                "out_speech", "out_speech_mask", "sequence", "sequence_mask",
                "att_w_init", "att_k_init", "att_h_init", "att_c_init",
                "h1_init", "c1_init", "h2_init", "c2_init", "att_w", "att_k",
                "att_phi", "att_h", "att_c", "h1", "c1", "h2", "c2",
                "mean_pred", "loss", "train_step", "learning_rate", "summary"
            ]
            things_tf = [
                speech, speech_mask, in_speech, in_speech_mask, out_speech,
                out_speech_mask, sequence, sequence_mask, att_w_init,
                att_k_init, att_h_init, att_c_init, h1_init, c1_init, h2_init,
                c2_init, att_w, att_k, att_phi, att_h, att_c, h1, c1, h2, c2,
                mean_pred, loss, train_step, learning_rate, summary
            ]
            return namedtuple('Model', things_names)(*things_tf)
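
Because the tensors are stored in collections above, a later prediction script can recover them from a checkpoint's metagraph without rebuilding the model. A hedged sketch (the checkpoint path is hypothetical):

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("model.ckpt.meta")
    saver.restore(sess, "model.ckpt")
    # look up saved tensors by the collection names used above
    mean_pred = tf.get_collection("mean_pred")[0]
    speech = tf.get_collection("speech")[0]
    speech_mask = tf.get_collection("speech_mask")[0]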