def create_pixel_cnn(inp, lbl, cond):
    e_inp, inp_emb = Embedding(inp, 256, n_channels, random_state=random_state, name="inp_emb")
    c_inp, cond_emb = Embedding(cond, 256, n_channels, random_state=random_state, name="cond_emb")
    l1_v, l1_h = GatedMaskedConv2d([e_inp], [n_channels], [e_inp], [n_channels],
                                   n_channels,
                                   residual=False,
                                   conditioning_class_input=lbl,
                                   conditioning_num_classes=n_labels,
                                   conditioning_spatial_map=c_inp,
                                   kernel_size=kernel_size0, name="pcnn0",
                                   mask_type="img_A",
                                   random_state=random_state)
    o_v = l1_v
    o_h = l1_h
    for i in range(n_layers - 1):
        t_v, t_h = GatedMaskedConv2d([o_v], [n_channels], [o_h], [n_channels],
                                     n_channels,
                                     conditioning_class_input=lbl,
                                     conditioning_num_classes=n_labels,
                                     conditioning_spatial_map=c_inp,
                                     kernel_size=kernel_size1, name="pcnn{}".format(i + 1),
                                     mask_type="img_B",
                                     random_state=random_state)
        o_v = t_v
        o_h = t_h

    cleanup = Conv2d([o_h], [n_channels], n_channels, kernel_size=(1, 1),
                     name="conv_c",
                     random_state=random_state)
    r_p = ReLU(cleanup)
    out = Conv2d([r_p], [n_channels], 256, kernel_size=(1, 1),
                 name="conv_o",
                 random_state=random_state)
    # softmax is applied in the loss (see the sketch below), so return raw 256-way logits
    return out
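
# A minimal sketch of how the 256-way logits returned by create_pixel_cnn might
# feed a per-pixel categorical loss; `targets` (int pixel intensities in
# [0, 255], shaped like the logits minus the channel axis) is an assumed input,
# not part of the original example.
def example_pixel_cnn_loss(inp, lbl, cond, targets):
    logits = create_pixel_cnn(inp, lbl, cond)
    # sparse cross-entropy over the 256 intensity classes, averaged over pixels
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                       logits=logits))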
def create_model(inp_tm1, inp_t, cell_dropout, h1_init, c1_init, h1_q_init,
                 c1_q_init):
    e_tm1, emb_r = Embedding(inp_tm1,
                             n_inputs,
                             in_emb,
                             random_state=random_state,
                             name="in_emb")

    def step(x_t, h1_tm1, c1_tm1, h1_q_tm1, c1_q_tm1):
        output, s = LSTMCell([x_t], [in_emb],
                             h1_tm1,
                             c1_tm1,
                             n_hid,
                             random_state=random_state,
                             cell_dropout=cell_dropout,
                             name="rnn1",
                             init=rnn_init)
        h1_t = s[0]
        c1_t = s[1]

        output, s = LSTMCell([h1_t], [n_hid],
                             h1_q_tm1,
                             c1_q_tm1,
                             n_hid,
                             random_state=random_state,
                             cell_dropout=cell_dropout,
                             name="rnn1_q",
                             init=rnn_init)
        h1_cq_t = s[0]
        c1_q_t = s[1]

        h1_q_t, h1_i_t, h1_nst_q_t, h1_emb = VqEmbedding(
            h1_cq_t, n_hid, n_emb, random_state=random_state, name="h1_vq_emb")

        # cast the integer code indices to float so scan can stack them with
        # the float-valued states; not great, but keeps the scan outputs uniform
        h1_i_t = tf.cast(h1_i_t, tf.float32)
        return output, h1_t, c1_t, h1_q_t, c1_q_t, h1_nst_q_t, h1_cq_t, h1_i_t

    r = scan(step, [e_tm1],
             [None, h1_init, c1_init, h1_q_init, c1_q_init, None, None, None])
    out = r[0]
    hiddens = r[1]
    cells = r[2]
    q_hiddens = r[3]
    q_cells = r[4]
    q_nst_hiddens = r[5]
    q_nvq_hiddens = r[6]
    i_hiddens = r[7]

    # the output projection could potentially tie weights with the input embedding
    pred = Linear([out], [n_hid],
                  n_inputs,
                  random_state=random_state,
                  name="out",
                  init=forward_init)
    pred_sm = Softmax(pred)
    return pred_sm, pred, hiddens, cells, q_hiddens, q_cells, q_nst_hiddens, q_nvq_hiddens, i_hiddens
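
# A sketch of the standard VQ-VAE auxiliary losses (codebook + commitment) that
# typically accompany a vector-quantized embedding; whether VqEmbedding applies
# these internally is not shown here, so treat this as an assumption. It uses
# the non-straight-through quantized states (q_nst_hiddens) and the continuous
# pre-quantization states (q_nvq_hiddens) returned above.
def example_vq_aux_loss(q_nst_hiddens, q_nvq_hiddens, beta=0.25):
    # codebook loss: move codebook entries toward the (frozen) encoder outputs
    codebook_loss = tf.reduce_mean(
        tf.square(q_nst_hiddens - tf.stop_gradient(q_nvq_hiddens)))
    # commitment loss: keep encoder outputs committed to their (frozen) codes
    commit_loss = tf.reduce_mean(
        tf.square(tf.stop_gradient(q_nst_hiddens) - q_nvq_hiddens))
    return codebook_loss + beta * commit_loss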
def create_encoder(inp, bn_flag):
    e_inps = []
    for ci in range(4):
        e_inp, emb = Embedding(inp[..., ci][..., None],
                               n_out,
                               inp_emb_dim,
                               random_state=random_state,
                               name="inp_emb_{}".format(ci))
        e_inps.append(e_inp)
    e_inp = tf.concat(e_inps, axis=-1)
    l1 = Conv2d([e_inp], [4 * inp_emb_dim],
                l_dims[0][0],
                kernel_size=l_dims[0][1:3],
                name="enc1",
                strides=l_dims[0][-1],
                border_mode=ebpad,
                random_state=random_state)
    bn_l1 = BatchNorm2d(l1, bn_flag, name="bn_enc1")
    r_l1 = ReLU(bn_l1)

    l2 = Conv2d([r_l1], [l_dims[0][0]],
                l_dims[1][0],
                kernel_size=l_dims[1][1:3],
                name="enc2",
                strides=l_dims[1][-1],
                border_mode=ebpad,
                random_state=random_state)
    bn_l2 = BatchNorm2d(l2, bn_flag, name="bn_enc2")
    r_l2 = ReLU(bn_l2)

    l3 = Conv2d([r_l2], [l_dims[1][0]],
                l_dims[2][0],
                kernel_size=l_dims[2][1:3],
                name="enc3",
                random_state=random_state)
    bn_l3 = BatchNorm2d(l3, bn_flag, name="bn_enc3")
    return bn_l3
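
# The indexing above implies each l_dims entry has the layout
# (out_channels, kernel_h, kernel_w, stride). A hypothetical configuration
# consistent with that layout (the real values live elsewhere in the script):
example_l_dims = [(64, 4, 4, 2),    # enc1: 64 filters, 4x4 kernel, stride 2
                  (128, 4, 4, 2),   # enc2: 128 filters, 4x4 kernel, stride 2
                  (128, 3, 3, 1)]   # enc3: stride entry unused (no strides kwarg passed)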
def create_graph():
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(2899)

        text = tf.placeholder(tf.float32, shape=[None, batch_size, 1])
        text_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        mask = tf.placeholder(tf.float32, shape=[None, batch_size, 1])
        mask_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        mels = tf.placeholder(tf.float32,
                              shape=[None, batch_size, output_size])
        mel_mask = tf.placeholder(tf.float32, shape=[None, batch_size])

        bias = tf.placeholder_with_default(tf.zeros(shape=[]), shape=[])
        cell_dropout = tf.placeholder_with_default(cell_dropout_scale *
                                                   tf.ones(shape=[]),
                                                   shape=[])
        prenet_dropout = tf.placeholder_with_default(0.5 * tf.ones(shape=[]),
                                                     shape=[])
        bn_flag = tf.placeholder_with_default(tf.zeros(shape=[]), shape=[])

        att_w_init = tf.placeholder(tf.float32,
                                    shape=[batch_size, 2 * enc_units])
        att_k_init = tf.placeholder(tf.float32,
                                    shape=[batch_size, window_mixtures])
        att_h_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        att_c_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        h1_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        c1_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        h2_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])
        c2_init = tf.placeholder(tf.float32, shape=[batch_size, dec_units])

        in_mels = mels[:-1, :, :]
        in_mel_mask = mel_mask[:-1]
        out_mels = mels[1:, :, :]
        out_mel_mask = mel_mask[1:]

        projmel1 = Linear([in_mels], [output_size],
                          prenet_units,
                          dropout_flag_prob_keep=prenet_dropout,
                          name="prenet1",
                          random_state=random_state)
        projmel2 = Linear([projmel1], [prenet_units],
                          prenet_units,
                          dropout_flag_prob_keep=prenet_dropout,
                          name="prenet2",
                          random_state=random_state)

        text_char_e, t_c_emb = Embedding(text,
                                         vocabulary_size,
                                         emb_dim,
                                         random_state=random_state,
                                         name="text_char_emb")
        text_phone_e, t_p_emb = Embedding(text,
                                          vocabulary_size,
                                          emb_dim,
                                          random_state=random_state,
                                          name="text_phone_emb")

        text_e = (1. - mask) * text_char_e + mask * text_phone_e

        # masks are either 0 or 1; embed them with a vocabulary size of 2 so the
        # text and mask embeddings have the same size and comparable impact on
        # the representation
        mask_e, m_emb = Embedding(mask,
                                  2,
                                  emb_dim,
                                  random_state=random_state,
                                  name="mask_emb")
        conv_text = SequenceConv1dStack([text_e + mask_e], [emb_dim],
                                        n_filts,
                                        bn_flag,
                                        n_stacks=n_stacks,
                                        kernel_sizes=[(1, 1), (3, 3), (5, 5)],
                                        name="enc_conv1",
                                        random_state=random_state)

        # text_mask and mask_mask should be identical, so it doesn't matter which one we use
        bitext = BiLSTMLayer([conv_text], [n_filts],
                             enc_units,
                             input_mask=text_mask,
                             name="encode_bidir",
                             init=rnn_init,
                             random_state=random_state)

        def step(inp_t, inp_mask_t, corr_inp_t, att_w_tm1, att_k_tm1,
                 att_h_tm1, att_c_tm1, h1_tm1, c1_tm1, h2_tm1, c2_tm1):

            o = GaussianAttentionCell(
                [corr_inp_t],
                [prenet_units],
                (att_h_tm1, att_c_tm1),
                att_k_tm1,
                bitext,
                2 * enc_units,
                dec_units,
                att_w_tm1,
                input_mask=inp_mask_t,
                conditioning_mask=text_mask,
                #attention_scale=1. / 10.,
                attention_scale=1.,
                step_op="softplus",
                name="att",
                random_state=random_state,
                cell_dropout=1.,  #cell_dropout,
                init=rnn_init)
            att_w_t, att_k_t, att_phi_t, s = o
            att_h_t = s[0]
            att_c_t = s[1]

            output, s = LSTMCell([corr_inp_t, att_w_t, att_h_t],
                                 [prenet_units, 2 * enc_units, dec_units],
                                 h1_tm1,
                                 c1_tm1,
                                 dec_units,
                                 input_mask=inp_mask_t,
                                 random_state=random_state,
                                 cell_dropout=cell_dropout,
                                 name="rnn1",
                                 init=rnn_init)
            h1_t = s[0]
            c1_t = s[1]

            output, s = LSTMCell([corr_inp_t, att_w_t, h1_t],
                                 [prenet_units, 2 * enc_units, dec_units],
                                 h2_tm1,
                                 c2_tm1,
                                 dec_units,
                                 input_mask=inp_mask_t,
                                 random_state=random_state,
                                 cell_dropout=cell_dropout,
                                 name="rnn2",
                                 init=rnn_init)
            h2_t = s[0]
            c2_t = s[1]
            return output, att_w_t, att_k_t, att_phi_t, att_h_t, att_c_t, h1_t, c1_t, h2_t, c2_t

        r = scan(step, [in_mels, in_mel_mask, projmel2], [
            None, att_w_init, att_k_init, None, att_h_init, att_c_init,
            h1_init, c1_init, h2_init, c2_init
        ])
        output = r[0]
        att_w = r[1]
        att_k = r[2]
        att_phi = r[3]
        att_h = r[4]
        att_c = r[5]
        h1 = r[6]
        c1 = r[7]
        h2 = r[8]
        c2 = r[9]

        pred = Linear([output], [dec_units],
                      output_size,
                      name="out_proj",
                      random_state=random_state)
        """
        mix, means, lins = DiscreteMixtureOfLogistics([proj], [output_size], n_output_channels=1,
                                                      name="dml", random_state=random_state)
        cc = DiscreteMixtureOfLogisticsCost(mix, means, lins, out_mels, 256)
        """

        # the commented lines below apply the correct sequence masking; the
        # active loss averages over all timesteps, padding included
        cc = (pred - out_mels) ** 2
        #cc = out_mel_mask[..., None] * cc
        #loss = tf.reduce_sum(tf.reduce_sum(cc, axis=-1)) / tf.reduce_sum(out_mel_mask)
        loss = tf.reduce_mean(tf.reduce_sum(cc, axis=-1))

        learning_rate = 0.0001
        #steps = tf.Variable(0.)
        #learning_rate = tf.train.exponential_decay(0.001, steps, staircase=True,
        #                                           decay_steps=50000, decay_rate=0.5)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           use_locking=True)
        grad, var = zip(*optimizer.compute_gradients(loss))
        grad, _ = tf.clip_by_global_norm(grad, 10.)
        #train_step = optimizer.apply_gradients(zip(grad, var), global_step=steps)
        train_step = optimizer.apply_gradients(zip(grad, var))

    things_names = [
        "mels",
        "mel_mask",
        "in_mels",
        "in_mel_mask",
        "out_mels",
        "out_mel_mask",
        "text",
        "text_mask",
        "mask",
        "mask_mask",
        "bias",
        "cell_dropout",
        "prenet_dropout",
        "bn_flag",
        "pred",
        #"mix", "means", "lins",
        "att_w_init",
        "att_k_init",
        "att_h_init",
        "att_c_init",
        "h1_init",
        "c1_init",
        "h2_init",
        "c2_init",
        "att_w",
        "att_k",
        "att_phi",
        "att_h",
        "att_c",
        "h1",
        "c1",
        "h2",
        "c2",
        "loss",
        "train_step",
        "learning_rate"
    ]
    # look each created tensor/op up by name so it can be added to a graph collection
    things_tf = [eval(name) for name in things_names]
    for tn, tt in zip(things_names, things_tf):
        graph.add_to_collection(tn, tt)
    train_model = namedtuple('Model', things_names)(*things_tf)
    return graph, train_model
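
# A minimal sketch of one training step against the graph built above; the
# batch arrays are assumed inputs, and all recurrent/attention states start at
# zero. Illustrative only, not part of the original script.
import numpy as np

def example_train_step(sess, model, text_batch, text_mask_batch, mask_batch,
                       mask_mask_batch, mels_batch, mel_mask_batch):
    zeros = lambda dim: np.zeros((batch_size, dim), dtype=np.float32)
    feed = {model.text: text_batch,
            model.text_mask: text_mask_batch,
            model.mask: mask_batch,
            model.mask_mask: mask_mask_batch,
            model.mels: mels_batch,
            model.mel_mask: mel_mask_batch,
            model.att_w_init: zeros(2 * enc_units),
            model.att_k_init: zeros(window_mixtures),
            model.att_h_init: zeros(dec_units),
            model.att_c_init: zeros(dec_units),
            model.h1_init: zeros(dec_units),
            model.c1_init: zeros(dec_units),
            model.h2_init: zeros(dec_units),
            model.c2_init: zeros(dec_units)}
    loss, _ = sess.run([model.loss, model.train_step], feed_dict=feed)
    return loss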