recon_samples_list = self.vhv(vis_samples) vis_p = self._one_rbm_compute_down(self.rbm_list[0], recon_samples_list[1]) vis_samples = recon_samples_list[0] return x+1, vis_p, vis_samples _, prob_imgs, sampled_imgs = tf.while_loop(cond, body, [0, vis, vis], back_prop=False) return prob_imgs, sampled_imgs if __name__ == '__main__': if len(sys.argv) < 3 or len(sys.argv) > 4: print 'usage: python drbn.py pcd/cd cd-k [output_dir]' sys.exit() else: use_pcd = (sys.argv[1] == 'pcd') cd_k = int(sys.argv[2]) output_dir = None if len(sys.argv) == 3 else sys.argv[3] (train_xs, _), _, _ = cPickle.load(file('mnist.pkl', 'rb')) train_xs = train_xs.reshape((-1, 28, 28, 1)) batch_size = 20 lr = 0.001 if use_pcd else 0.1 # drbn = DRBN([784]) drbn = DRBN([28, 28, 1]) drbn.add_conv_layer((5, 5, 1, 64), (2, 2), 'SAME', 'conv1') drbn.add_conv_layer((5, 5, 64, 64), (2, 2), 'SAME', 'conv2') drbn.add_fc_layer(500, 'fc1') drbn.print_network() train_rbm.train(drbn, train_xs, lr, 40, batch_size, use_pcd, cd_k, output_dir)
# Train a single GaussianCRBM on one channel of the CIFAR-10 training images.
# use_pcd, cd_k and output_dir are expected to be set earlier in the script.

# Force TensorFlow ('tf') image dimension ordering in Keras, announcing the
# change when one is made.
if keras.backend.image_dim_ordering() != 'tf':
    keras.backend.set_image_dim_ordering('tf')
    print("INFO: temporarily set 'image_dim_ordering' to 'tf'")

# Hyper-parameters. Here PCD uses the larger learning rate of the two.
batch_size = 20
pcd_chain_size = 100
lr = 1e-4 if use_pcd else 1e-5

# Load CIFAR-10, keeping only the training images (labels and the test split
# are discarded), then preprocess them.
(train_xs, _), (_, _) = cifar10.load_data()
train_xs, mean, std = utils.preprocess_cifar10(train_xs)
# NOTE(review): mean and std are not passed to train_rbm.train below --
# confirm whether they are needed.

# Keep only the first slice along the last axis so the data matches the
# single-channel (32, 32, 1) CRBM input below (assumes a channels-last
# layout, consistent with the 'tf' ordering set above -- TODO confirm).
train_xs = train_xs[:, :, :, :1]

rbm = GaussianCRBM((32, 32, 1), (12, 12, 1, 256), (2, 2), 'VALID',
                   output_dir, {})
train_rbm.train(rbm, train_xs, lr, 40, batch_size, use_pcd, cd_k,
                output_dir, pcd_chain_size=pcd_chain_size)