def create_pixel_cnn(inp, lbl, cond):
    """Build a conditional gated PixelCNN stack and return its output map.

    ``inp`` and ``cond`` are each embedded with a 256-entry table into
    ``n_channels`` features. ``n_layers`` GatedMaskedConv2d layers are then
    stacked:

    - the first uses ``mask_type="img_A"``, ``residual=False`` and
      ``kernel_size0``;
    - the remaining ``n_layers - 1`` use ``mask_type="img_B"`` and
      ``kernel_size1``.

    Every gated layer is conditioned on the class input ``lbl``
    (``n_labels`` classes) and on the embedded ``cond`` as a spatial map.
    The second output of the last gated layer is passed through a 1x1
    Conv2d, a ReLU and a second 1x1 Conv2d with 256 output channels.

    Reads module-level names defined outside this function: ``n_channels``,
    ``n_labels``, ``n_layers``, ``kernel_size0``, ``kernel_size1`` and
    ``random_state``.

    Parameters
    ----------
    inp : tensor
        Input to model. Presumably integer pixel values in [0, 255], given
        the 256-entry embedding -- confirm with the caller.
    lbl : tensor
        Class conditioning, passed as ``conditioning_class_input``.
    cond : tensor
        Spatial conditioning. It is embedded like ``inp`` (256-entry table)
        and passed as ``conditioning_spatial_map``.

    Returns
    -------
    tensor
        Output of the final 1x1 convolution (256 channels). No softmax is
        applied here; presumably the caller treats these as logits.
    """
    # Embedding returns a pair; only the embedded tensor is used here.
    e_inp, _ = Embedding(inp, 256, n_channels, random_state=random_state,
                         name="inp_emb")
    c_inp, _ = Embedding(cond, 256, n_channels, random_state=random_state,
                         name="cond_emb")

    # First gated layer: "img_A" mask and no residual connection. Both the
    # first and second input stacks start from the same embedding.
    o_v, o_h = GatedMaskedConv2d([e_inp], [n_channels], [e_inp], [n_channels],
                                 n_channels,
                                 residual=False,
                                 conditioning_class_input=lbl,
                                 conditioning_num_classes=n_labels,
                                 conditioning_spatial_map=c_inp,
                                 kernel_size=kernel_size0, name="pcnn0",
                                 mask_type="img_A",
                                 random_state=random_state)
    # Remaining layers: "img_B" mask. Each consumes the previous layer's two
    # outputs (presumably the vertical / horizontal stacks -- confirm).
    for i in range(n_layers - 1):
        o_v, o_h = GatedMaskedConv2d([o_v], [n_channels], [o_h], [n_channels],
                                     n_channels,
                                     conditioning_class_input=lbl,
                                     conditioning_num_classes=n_labels,
                                     conditioning_spatial_map=c_inp,
                                     kernel_size=kernel_size1,
                                     name="pcnn{}".format(i + 1),
                                     mask_type="img_B",
                                     random_state=random_state)

    # Output head (1x1 conv -> ReLU -> 1x1 conv), applied to o_h only.
    cleanup = Conv2d([o_h], [n_channels], n_channels, kernel_size=(1, 1),
                     name="conv_c",
                     random_state=random_state)
    r_p = ReLU(cleanup)
    out = Conv2d([r_p], [n_channels], 256, kernel_size=(1, 1),
                 name="conv_o",
                 random_state=random_state)
    return out
Example #2
0
def create_encoder(inp, bn_flag):
    """Encode a single-channel input with three Conv2d + BatchNorm2d layers.

    Stages "enc1" and "enc2" use ``l_dims[0]`` and ``l_dims[1]``. Each is
    strided, padded with ``bpad`` and followed by BatchNorm2d and ReLU.
    Stage "enc3" uses ``l_dims[2]``; it passes no ``strides`` or
    ``border_mode`` and ends with BatchNorm2d only (no ReLU).

    For every ``l_dims`` entry:

    - element 0 is the output channel count;
    - elements 1:3 are ``kernel_size``;
    - the last element is ``strides`` (first two stages only).

    Parameters
    ----------
    inp : tensor
        Input declared to Conv2d as having 1 channel.
    bn_flag :
        Forwarded to every BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        Batch-normalised output of "enc3" (``l_dims[2][0]`` channels).
    """
    hidden = inp
    in_chans = 1
    for stage, conv_name, bn_name in ((0, "enc1", "bn_enc1"),
                                      (1, "enc2", "bn_enc2")):
        spec = l_dims[stage]
        conv = Conv2d([hidden], [in_chans],
                      spec[0],
                      kernel_size=spec[1:3],
                      name=conv_name,
                      strides=spec[-1],
                      border_mode=bpad,
                      random_state=random_state)
        hidden = ReLU(BatchNorm2d(conv, bn_flag, name=bn_name))
        in_chans = spec[0]

    final_spec = l_dims[2]
    conv = Conv2d([hidden], [in_chans],
                  final_spec[0],
                  kernel_size=final_spec[1:3],
                  name="enc3",
                  random_state=random_state)
    return BatchNorm2d(conv, bn_flag, name="bn_enc3")
Example #3
0
def create_decoder(latent, bn_flag):
    """Decode ``latent`` back to a single-channel map in [0, 1] via Sigmoid.

    The decoder walks ``l_dims`` in reverse:

    - a Conv2d ("dec1") with ``l_dims[2]``'s kernel;
    - a strided ConvTranspose2d ("dec2") with ``l_dims[1]``'s kernel and
      stride;
    - a strided ConvTranspose2d ("dec3") with ``l_dims[0]``'s kernel and
      stride, producing 1 channel, followed by Sigmoid.

    The first two layers are each followed by BatchNorm2d and ReLU. The
    transposed convs use ``bpad`` padding.

    Parameters
    ----------
    latent : tensor
        Input declared as having ``l_dims[2][0]`` channels.
    bn_flag :
        Forwarded to each BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        Sigmoid of the final 1-channel transposed convolution.
    """
    deep, mid, shallow = l_dims[2], l_dims[1], l_dims[0]

    h = Conv2d([latent], [deep[0]],
               mid[0],
               kernel_size=deep[1:3],
               name="dec1",
               random_state=random_state)
    # NOTE(review): this first batch norm is named "bn_dec3", while the
    # other decoders name their first one "bn_dec1". It is kept unchanged
    # because renaming alters the layer's variable name (and would
    # presumably break restoring existing checkpoints); confirm intent.
    h = ReLU(BatchNorm2d(h, bn_flag, name="bn_dec3"))

    h = ConvTranspose2d([h], [mid[0]],
                        shallow[0],
                        kernel_size=mid[1:3],
                        name="dec2",
                        strides=mid[-1],
                        border_mode=bpad,
                        random_state=random_state)
    h = ReLU(BatchNorm2d(h, bn_flag, name="bn_dec2"))

    h = ConvTranspose2d([h], [shallow[0]],
                        1,
                        kernel_size=shallow[1:3],
                        name="dec3",
                        strides=shallow[-1],
                        border_mode=bpad,
                        random_state=random_state)
    return Sigmoid(h)
Example #4
0
def create_decoder(latent, bn_flag):
    """Decode ``latent`` through four layers and return a transposed Sigmoid map.

    The layers are:

    - Conv2d "dec1" (``l_dims[3]`` kernel) -> BatchNorm2d -> ReLU;
    - ConvTranspose2d "dec2" and "dec3" (``l_dims[2]`` and ``l_dims[1]``
      kernels and strides, ``dbpad`` padding), each -> BatchNorm2d -> ReLU;
    - ConvTranspose2d "dec4" with 257 output channels and kernel
      ``(l_dims[0][1], 1)``, then Sigmoid.

    The last two axes of the Sigmoid output are swapped before returning.

    Parameters
    ----------
    latent : tensor
        Input declared as having ``l_dims[3][0]`` channels.
    bn_flag :
        Forwarded to every BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        ``tf.transpose(Sigmoid(dec4), (0, 1, 3, 2))``.
    """
    h = Conv2d([latent], [l_dims[3][0]], l_dims[2][0],
               kernel_size=l_dims[3][1:3], name="dec1",
               random_state=random_state)
    h = ReLU(BatchNorm2d(h, bn_flag, name="bn_dec1"))

    # Each stage reads l_dims[src] for input channels, kernel and stride,
    # and emits l_dims[dst][0] channels.
    for src, dst, conv_name, bn_name in ((2, 1, "dec2", "bn_dec2"),
                                         (1, 0, "dec3", "bn_dec3")):
        h = ConvTranspose2d([h], [l_dims[src][0]], l_dims[dst][0],
                            kernel_size=l_dims[src][1:3], name=conv_name,
                            strides=l_dims[src][-1],
                            border_mode=dbpad,
                            random_state=random_state)
        h = ReLU(BatchNorm2d(h, bn_flag, name=bn_name))

    # Original note: "hack it and do depth to space". The kernel width is
    # forced to 1. NOTE(review): the meaning of the 257 output channels is
    # not visible here -- confirm.
    kernel_hw = list(l_dims[0][1:3])
    kernel_hw[1] = 1
    h = ConvTranspose2d([h], [l_dims[0][0]], 257, kernel_size=kernel_hw,
                        name="dec4",
                        strides=l_dims[0][-1],
                        border_mode=dbpad,
                        random_state=random_state)
    # Swap axes 2 and 3 of the activated output.
    return tf.transpose(Sigmoid(h), (0, 1, 3, 2))
Example #5
0
def create_decoder(latent, bn_flag):
    """Decode ``latent`` into a 5-D tensor of shape (-1, 1, 48, 4, n_out).

    The layers are:

    - Conv2d "dec1" (``l_dims[-1]`` kernel) -> BatchNorm2d -> ReLU;
    - ConvTranspose2d "dec2" (``l_dims[-2]`` kernel and stride) ->
      BatchNorm2d -> ReLU;
    - ConvTranspose2d "dec3" (``l_dims[-3]`` kernel and stride) producing
      ``4 * n_out`` channels, with no normalisation or activation.

    The transposed convs use ``dbpad`` padding. The final output is
    reshaped so its channels split into (4, n_out).

    Parameters
    ----------
    latent : tensor
        Input declared as having ``l_dims[-1][0]`` channels.
    bn_flag :
        Forwarded to each BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        "dec3" output reshaped to ``(-1, 1, 48, 4, n_out)``.
        NOTE(review): 48 is hard-coded and must match the spatial size
        "dec3" actually produces -- confirm.
    """
    first, second, third = l_dims[-1], l_dims[-2], l_dims[-3]

    h = Conv2d([latent], [first[0]],
               second[0],
               kernel_size=first[1:3],
               name="dec1",
               random_state=random_state)
    h = ReLU(BatchNorm2d(h, bn_flag, name="bn_dec1"))

    h = ConvTranspose2d([h], [second[0]],
                        third[0],
                        kernel_size=second[1:3],
                        name="dec2",
                        strides=second[-1],
                        border_mode=dbpad,
                        random_state=random_state)
    h = ReLU(BatchNorm2d(h, bn_flag, name="bn_dec2"))

    out = ConvTranspose2d([h], [third[0]],
                          4 * n_out,
                          kernel_size=third[1:3],
                          name="dec3",
                          strides=third[-1],
                          border_mode=dbpad,
                          random_state=random_state)
    return tf.reshape(out, (-1, 1, 48, 4, n_out))
Example #6
0
def create_decoder(latent, bn_flag):
    """Decode ``latent`` and feed the result to DiscreteMixtureOfLogistics.

    The decoder walks ``l_dims`` from the last entry back to the
    fifth-from-last:

    - a Conv2d ("dec1") -> BatchNorm2d -> ReLU;
    - three ConvTranspose2d layers ("dec2" to "dec4"), each ->
      BatchNorm2d -> ReLU;
    - a final ConvTranspose2d ("dec5") projecting to ``dmol_proj``
      channels.

    The transposed convs use ``dbpad`` padding. For each ``l_dims`` entry,
    element 0 is a channel count, elements 1:3 are ``kernel_size`` and the
    last element is ``strides``.

    Parameters
    ----------
    latent : tensor
        Input declared as having ``l_dims[-1][0]`` channels.
    bn_flag :
        Forwarded to every BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tuple
        The three values unpacked from DiscreteMixtureOfLogistics, in its
        order (named mix, means, lin_scales here).
    """
    hidden = Conv2d([latent], [l_dims[-1][0]],
                    l_dims[-2][0],
                    kernel_size=l_dims[-1][1:3],
                    name="dec1",
                    random_state=random_state)
    hidden = ReLU(BatchNorm2d(hidden, bn_flag, name="bn_dec1"))

    # Stage k reads l_dims[-k] (input channels, kernel, stride) and emits
    # l_dims[-k - 1][0] channels.
    for k, conv_name, bn_name in ((2, "dec2", "bn_dec2"),
                                  (3, "dec3", "bn_dec3"),
                                  (4, "dec4", "bn_dec4")):
        spec = l_dims[-k]
        hidden = ConvTranspose2d([hidden], [spec[0]],
                                 l_dims[-k - 1][0],
                                 kernel_size=spec[1:3],
                                 name=conv_name,
                                 strides=spec[-1],
                                 border_mode=dbpad,
                                 random_state=random_state)
        hidden = ReLU(BatchNorm2d(hidden, bn_flag, name=bn_name))

    last = l_dims[-5]
    proj = ConvTranspose2d([hidden], [last[0]],
                           dmol_proj,
                           kernel_size=last[1:3],
                           name="dec5",
                           strides=last[-1],
                           border_mode=dbpad,
                           random_state=random_state)
    mix, means, lin_scales = DiscreteMixtureOfLogistics(
        [proj], [dmol_proj],
        n_components=n_components,
        name="d_out",
        random_state=random_state)
    return mix, means, lin_scales
Example #7
0
def create_encoder(inp, bn_flag):
    """Embed each input channel separately, then encode with three Conv2d layers.

    Each of the first 4 trailing-axis channels of ``inp`` is embedded
    independently (``n_out``-entry table, ``inp_emb_dim`` features). The
    embeddings are concatenated on the last axis and encoded by:

    - "enc1" and "enc2": strided Conv2d (``ebpad`` padding) -> BatchNorm2d
      -> ReLU;
    - "enc3": Conv2d (no strides/border_mode passed) -> BatchNorm2d.

    Reads module-level ``n_out``, ``inp_emb_dim``, ``l_dims``, ``ebpad``,
    ``random_state`` and ``tf``.

    Parameters
    ----------
    inp : tensor
        Input whose last axis holds the channels to embed (only indices
        0..3 are read). Presumably integer codes below ``n_out`` -- confirm
        with the caller.
    bn_flag :
        Forwarded to every BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        Batch-normalised output of "enc3" (``l_dims[2][0]`` channels).
    """
    # Number of input channels embedded separately. It also sizes enc1's
    # declared input width, so the two cannot drift apart.
    n_in_chans = 4

    e_inps = []
    for ci in range(n_in_chans):
        # Embedding returns a pair; only the embedded tensor is used.
        e_ci, _ = Embedding(inp[..., ci][..., None],
                            n_out,
                            inp_emb_dim,
                            random_state=random_state,
                            name="inp_emb_{}".format(ci))
        e_inps.append(e_ci)
    e_inp = tf.concat(e_inps, axis=-1)

    l1 = Conv2d([e_inp], [n_in_chans * inp_emb_dim],
                l_dims[0][0],
                kernel_size=l_dims[0][1:3],
                name="enc1",
                strides=l_dims[0][-1],
                border_mode=ebpad,
                random_state=random_state)
    bn_l1 = BatchNorm2d(l1, bn_flag, name="bn_enc1")
    r_l1 = ReLU(bn_l1)

    l2 = Conv2d([r_l1], [l_dims[0][0]],
                l_dims[1][0],
                kernel_size=l_dims[1][1:3],
                name="enc2",
                strides=l_dims[1][-1],
                border_mode=ebpad,
                random_state=random_state)
    bn_l2 = BatchNorm2d(l2, bn_flag, name="bn_enc2")
    r_l2 = ReLU(bn_l2)

    # Final stage: no strides/border_mode, no activation after batch norm.
    l3 = Conv2d([r_l2], [l_dims[1][0]],
                l_dims[2][0],
                kernel_size=l_dims[2][1:3],
                name="enc3",
                random_state=random_state)
    bn_l3 = BatchNorm2d(l3, bn_flag, name="bn_enc3")
    return bn_l3
Example #8
0
def create_decoder(latent, bn_flag):
    """Decode ``latent`` to a single-channel map squashed by Tanh.

    The decoder walks ``l_dims`` from the last entry back to the
    fifth-from-last:

    - a Conv2d ("dec1") -> BatchNorm2d -> ReLU;
    - three ConvTranspose2d layers ("dec2" to "dec4"), each ->
      BatchNorm2d -> ReLU;
    - a final 1-channel ConvTranspose2d ("dec5") followed by Tanh.

    The transposed convs use ``dbpad`` padding. For each ``l_dims`` entry,
    element 0 is a channel count, elements 1:3 are ``kernel_size`` and the
    last element is ``strides``.

    Parameters
    ----------
    latent : tensor
        Input declared as having ``l_dims[-1][0]`` channels.
    bn_flag :
        Forwarded to every BatchNorm2d call (presumably a training-mode
        switch -- confirm).

    Returns
    -------
    tensor
        Tanh of the final 1-channel transposed convolution.
    """
    hidden = Conv2d([latent], [l_dims[-1][0]],
                    l_dims[-2][0],
                    kernel_size=l_dims[-1][1:3],
                    name="dec1",
                    random_state=random_state)
    hidden = ReLU(BatchNorm2d(hidden, bn_flag, name="bn_dec1"))

    # Stage k reads l_dims[-k] (input channels, kernel, stride) and emits
    # l_dims[-k - 1][0] channels.
    for k, conv_name, bn_name in ((2, "dec2", "bn_dec2"),
                                  (3, "dec3", "bn_dec3"),
                                  (4, "dec4", "bn_dec4")):
        spec = l_dims[-k]
        hidden = ConvTranspose2d([hidden], [spec[0]],
                                 l_dims[-k - 1][0],
                                 kernel_size=spec[1:3],
                                 name=conv_name,
                                 strides=spec[-1],
                                 border_mode=dbpad,
                                 random_state=random_state)
        hidden = ReLU(BatchNorm2d(hidden, bn_flag, name=bn_name))

    last = l_dims[-5]
    hidden = ConvTranspose2d([hidden], [last[0]],
                             1,
                             kernel_size=last[1:3],
                             name="dec5",
                             strides=last[-1],
                             border_mode=dbpad,
                             random_state=random_state)
    return Tanh(hidden)