Пример #1
0
def svhn_encoder(x, numHidden, labels, num_labels, mb_size, image_width):
    """Build a label-conditioned convolutional encoder and return its output.

    The network is nine ConvPoolLayer stages (kernel_len=5, no batch norm),
    arranged as three groups of three in which the last stage of each group
    uses stride=2. These are followed by two HiddenLayers (the first flattens
    its input). The resulting features are concatenated with `labels` along
    axis 1 and fed through two more batch-normalized HiddenLayers (h1, h2).

    Args:
        x: input image batch. It is transposed with (0, 3, 1, 2) and passed
            through `normalize` before the first convolution, so it presumably
            arrives channels-last (batch, height, width, channels) -- TODO
            confirm against the caller.
        numHidden: output width of every HiddenLayer built here.
        labels: label tensor concatenated onto the encoder features (axis 1).
        num_labels: size of `labels` along axis 1; added to h1's input width.
        mb_size: not used by this function (kept for interface compatibility).
        image_width: not used by this function; the flattened conv-feature
            size is hard-coded (see the NOTE on the first HiddenLayer).

    Returns:
        dict with keys:
            'layers': all layers built here, in order (conv stack, the two
                hidden layers, then h1 and h2).
            'extra_params': always an empty list.
            'output': symbolic output of h2.
    """
    c = [3, 64, 128, 256, 256]

    # (in_channels, out_channels, extra kwargs) for each conv stage.
    # `stride` is passed only where the architecture sets it, so the other
    # stages keep ConvPoolLayer's own default.
    conv_specs = [
        (c[0], c[1], {}),
        (c[1], c[1], {}),
        (c[1], c[1], {'stride': 2}),

        (c[1], c[2], {}),
        (c[2], c[2], {}),
        (c[2], c[2], {'stride': 2}),

        (c[2], c[3], {}),
        (c[3], c[3], {}),
        (c[3], c[4], {'stride': 2}),
    ]

    layerLst = []
    for c_in, c_out, extra in conv_specs:
        layerLst.append(ConvPoolLayer(in_channels=c_in, out_channels=c_out,
                                      kernel_len=5, batch_norm=False, **extra))

    # NOTE(review): 4 * 4 assumes the conv stack leaves a 4x4 spatial map
    # (e.g. a 32x32 input if each stride=2 stage halves it and the others
    # preserve size). That depends on ConvPoolLayer's padding; confirm.
    # image_width is not consulted here.
    layerLst.append(HiddenLayer(num_in=4 * 4 * c[4], num_out=numHidden,
                                flatten_input=True, batch_norm=False))
    layerLst.append(HiddenLayer(num_in=numHidden, num_out=numHidden,
                                batch_norm=True))

    # Chain the layers: each layer consumes the previous layer's output.
    h = normalize(x.transpose(0, 3, 1, 2))
    for layer in layerLst:
        h = layer.output(h)

    # Label conditioning: the encoder features are joined with the labels
    # before the final two hidden layers.
    h1 = HiddenLayer(num_in=numHidden + num_labels, num_out=numHidden,
                     batch_norm=True)
    h2 = HiddenLayer(num_in=numHidden, num_out=numHidden, batch_norm=True)

    h1_out = h1.output(T.concatenate([h, labels], axis=1))
    h2_out = h2.output(h1_out)

    return {'layers': layerLst + [h1, h2], 'extra_params': [], 'output': h2_out}
Пример #2
0
            params += [layerParams[paramKey]]
            l2_loss += T.mean(T.sqr(params[-1]))

    l2_loss = 0.0001 * T.sqrt(l2_loss)

    params += encoder_extra_params + decoder_extra_params

    for param in params:
        print param.get_value().shape

    variational_loss = config['vae_weight'] * 0.5 * T.sum(z_mean**2 + z_var - T.log(z_var) - 1.0)


    #smoothness_penalty = 0.001 * (total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,0:1,:,:]) + total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,1:2,:,:]) + total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,2:3,:,:]))

    square_loss = config['square_loss_weight'] * 1.0 * T.sum(T.sqr(normalize(x) - normalize(x_reconstructed)))

    loss = 0.0

    loss += l2_loss

    loss += square_loss

    netDist = NetDist(x, x_reconstructed, config)

    if config['style_weight'] > 0.0:
        style_loss, style_out_1, style_out_2 = netDist.get_dist_style()
        style_loss *= config['style_weight']
    else:
        style_loss = style_out_1 = style_out_2 = theano.shared(np.asarray(0.0).astype('float32'))
Пример #3
0
            params += [layerParams[paramKey]]
            l2_loss += T.mean(T.sqr(params[-1]))

    l2_loss = 0.0001 * T.sqrt(l2_loss)

    params += encoder_extra_params + decoder_extra_params

    for param in params:
        print param.get_value().shape

    variational_loss = config['vae_weight'] * 0.5 * T.sum(z_mean**2 + z_var - T.log(z_var) - 1.0)

    smoothness_penalty = 0.0 * (total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,0:1,:,:]) + total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,1:2,:,:]) + total_denoising_variation_penalty(x_reconstructed.transpose(0,3,1,2)[:,2:3,:,:]))


    raw_square_loss = T.sum(T.sqr(normalize(x) - normalize(x_reconstructed)))

    square_loss = raw_square_loss * config['square_loss_weight']

    loss = 0.0

    loss += l2_loss

    loss += square_loss

    netDist = NetDist(x, x_reconstructed, config)

    if config['style_weight'] > 0.0:
        style_loss, style_out_1, style_out_2 = netDist.get_dist_style()
        style_loss *= config['style_weight']
    else: