예제 #1
0
            recon_samples_list = self.vhv(vis_samples)
            vis_p = self._one_rbm_compute_down(self.rbm_list[0], recon_samples_list[1])
            vis_samples = recon_samples_list[0]
            return x+1, vis_p, vis_samples

        _, prob_imgs, sampled_imgs = tf.while_loop(cond, body, [0, vis, vis], back_prop=False)
        return prob_imgs, sampled_imgs

if __name__ == '__main__':
    # Command line: python drbn.py pcd/cd cd-k [output_dir]
    #   argv[1]  'pcd' sets use_pcd to True; any other value sets it to False
    #   argv[2]  integer forwarded to train_rbm.train as cd_k
    #   argv[3]  optional output directory (None when omitted)
    if len(sys.argv) < 3 or len(sys.argv) > 4:
        # Parenthesised single-argument form prints the same text on
        # Python 2 and is also valid Python 3 syntax.
        print('usage: python drbn.py pcd/cd cd-k [output_dir]')
        # Exit non-zero so shells and wrapper scripts can detect the
        # invalid invocation.
        sys.exit(1)

    use_pcd = (sys.argv[1] == 'pcd')
    cd_k = int(sys.argv[2])
    output_dir = sys.argv[3] if len(sys.argv) == 4 else None

    # NOTE(review): unpickling can execute arbitrary code; only load a
    # trusted mnist.pkl.
    # The context manager closes the file handle once loading finishes.
    with open('mnist.pkl', 'rb') as mnist_file:
        # Only the training inputs are kept; the other elements of the
        # pickled tuple are discarded.
        (train_xs, _), _, _ = cPickle.load(mnist_file)
    train_xs = train_xs.reshape((-1, 28, 28, 1))
    batch_size = 20
    lr = 0.001 if use_pcd else 0.1

    # drbn = DRBN([784])
    drbn = DRBN([28, 28, 1])
    drbn.add_conv_layer((5, 5, 1, 64), (2, 2), 'SAME', 'conv1')
    drbn.add_conv_layer((5, 5, 64, 64), (2, 2), 'SAME', 'conv2')
    drbn.add_fc_layer(500, 'fc1')
    drbn.print_network()

    # NOTE(review): the literal 40 is presumably the number of training
    # epochs -- confirm against train_rbm.train's signature.
    train_rbm.train(drbn, train_xs, lr, 40, batch_size, use_pcd, cd_k, output_dir)
예제 #2
0
    # Script body: trains a GaussianCRBM on CIFAR-10 training images.
    # use_pcd, cd_k and output_dir are defined outside the lines visible here.

    # Switch Keras to the 'tf' image dimension ordering if it is configured
    # otherwise. The message says "temporarily", but nothing in the visible
    # code restores the previous setting.
    if keras.backend.image_dim_ordering() != 'tf':
        keras.backend.set_image_dim_ordering('tf')
        print "INFO: temporarily set 'image_dim_ordering' to 'tf'"

    # Keep only the training images; the labels and the test split are
    # discarded.
    (train_xs, _), (_, _) = cifar10.load_data()
    # mean and std are returned here but not used by any active code below
    # (they appear only in the trailing comment on the train call).
    train_xs, mean, std = utils.preprocess_cifar10(train_xs)
    batch_size = 20
    pcd_chain_size = 100
    lr = 0.0001 if use_pcd else 1e-5

    # Keep only index 0 of the last axis while preserving that axis --
    # presumably the first colour channel under 'tf' ordering; TODO confirm.
    train_xs = train_xs[:, :, :, :1]  #.reshape(-1, 32*32)
    # NOTE(review): the positional arguments look like input shape, filter
    # shape, strides and padding mode -- confirm against GaussianCRBM's
    # constructor.
    rbm = GaussianCRBM((32, 32, 1), (12, 12, 1, 256), (2, 2), 'VALID',
                       output_dir, {})
    # rbm = GaussianRBM(32*32, 1000, output_dir)
    # drbn = DRBN([32*32], output_dir)
    # drbn.add_fc_layer(500, 'fc1', use_gaussian=True)
    # drbn.add_fc_layer(500, 'fc2')
    # drbn.add_fc_layer(1000, 'fc3')
    # drbn.print_network()

    # NOTE(review): the literal 40 is presumably the number of training
    # epochs -- confirm against train_rbm.train's signature.
    train_rbm.train(rbm,
                    train_xs,
                    lr,
                    40,
                    batch_size,
                    use_pcd,
                    cd_k,
                    output_dir,
                    pcd_chain_size=pcd_chain_size)  # , mean, std)