def test_bernoulli_csl(data_path='../../data', one_chain=False):
    """Score binarized MNIST test digits with CSL under Bernoulli means.

    The grey-level MNIST test images, clipped into the open interval (0, 1),
    are used as the means of independent Bernoulli variables, and the
    binarized MNIST test images are scored against them with ``CSL``.

    NOTE(review): this file defines ``test_bernoulli_csl`` twice with the
    same logic; the later definition shadows this one, so this body is never
    executed.  Consider deleting one of the two definitions.

    Args:
        data_path: directory handed to ``data_tools.load_mnist`` and
            ``data_tools.load_mnist_binary``.  Defaults to the previously
            hard-coded ``'../../data'``.
        one_chain: when True, treat ``means`` as a single (N, D) chain and use
            ``get_CSL_fn_independent_Bernoulli_v2`` together with
            ``compute_CSL_with_minibatches_one_chain``.  Defaults to False,
            which is the path the original code always took.

    Returns:
        None.  The result of ``compute_CSL_with_minibatches*`` is discarded,
        as before (a test function returning a value triggers warnings in
        recent pytest versions).
    """
    # print(...) with a single argument behaves identically on Python 2 and 3.
    print('loading MNIST test set')

    # Only the test splits are used; train/valid splits are discarded.
    (_, _), (_, _), (test_x, _) = data_tools.load_mnist(data_path)
    (_, _), (_, _), (test_x_b, _) = data_tools.load_mnist_binary(data_path)

    # Keep every mean strictly inside (0, 1), presumably so that log(p) and
    # log(1 - p) in the Bernoulli likelihood stay finite.  1 - 1e-10 would
    # round to 1.0 in float32, which is presumably why the upper margin is
    # much larger than the lower one.
    means = numpy.clip(test_x.astype('float32'), 1e-10, 1 - 1e-5)

    csl = CSL()
    # Requires exactly 1000 * 10 * 784 values, i.e. 10000 flattened
    # 784-pixel binary images, split into 1000 minibatches of 10.
    minibatches = test_x_b.reshape((1000, 10, 784)).astype('float32')

    if one_chain:
        # means is an (N, D) matrix representing a single chain.
        csl_fn = csl.get_CSL_fn_independent_Bernoulli_v2(means)
        csl.compute_CSL_with_minibatches_one_chain(csl_fn, minibatches)
    else:
        # Multi-chain path.  NOTE(review): this path was documented as taking
        # a 3-D (N, K, D) tensor (N chains, each with K samples of dimension
        # D), yet the reshape below builds a 4-D (10, 100, 10, 784) array.
        # Confirm which layout compute_CSL_with_minibatches expects.
        chains = means.reshape(10, 100, 10, 784)
        csl_fn = csl.get_CSL_fn_independent_Bernoulli()
        csl.compute_CSL_with_minibatches(csl_fn, minibatches, chains)
def test_bernoulli_csl(data_path='../../data', one_chain=False):
    """Score binarized MNIST test digits with CSL under Bernoulli means.

    The grey-level MNIST test images, clipped into the open interval (0, 1),
    are used as the means of independent Bernoulli variables, and the
    binarized MNIST test images are scored against them with ``CSL``.

    NOTE(review): an earlier definition of ``test_bernoulli_csl`` in this
    file is shadowed by this one; consider deleting the duplicate.

    Args:
        data_path: directory handed to ``data_tools.load_mnist`` and
            ``data_tools.load_mnist_binary``.  Defaults to the previously
            hard-coded ``'../../data'``.
        one_chain: when True, treat ``means`` as a single (N, D) chain and use
            ``get_CSL_fn_independent_Bernoulli_v2`` together with
            ``compute_CSL_with_minibatches_one_chain``.  Defaults to False,
            which is the path the original code always took.

    Returns:
        None.  The result of ``compute_CSL_with_minibatches*`` is discarded,
        as before (a test function returning a value triggers warnings in
        recent pytest versions).
    """
    # print(...) with a single argument behaves identically on Python 2 and 3.
    print('loading MNIST test set')

    # Only the test splits are used; train/valid splits are discarded.
    (_, _), (_, _), (test_x, _) = data_tools.load_mnist(data_path)
    (_, _), (_, _), (test_x_b, _) = data_tools.load_mnist_binary(data_path)

    # Keep every mean strictly inside (0, 1), presumably so that log(p) and
    # log(1 - p) in the Bernoulli likelihood stay finite.  1 - 1e-10 would
    # round to 1.0 in float32, which is presumably why the upper margin is
    # much larger than the lower one.
    means = numpy.clip(test_x.astype('float32'), 1e-10, 1 - 1e-5)

    csl = CSL()
    # Requires exactly 1000 * 10 * 784 values, i.e. 10000 flattened
    # 784-pixel binary images, split into 1000 minibatches of 10.
    minibatches = test_x_b.reshape((1000, 10, 784)).astype('float32')

    if one_chain:
        # means is an (N, D) matrix representing a single chain.
        csl_fn = csl.get_CSL_fn_independent_Bernoulli_v2(means)
        csl.compute_CSL_with_minibatches_one_chain(csl_fn, minibatches)
    else:
        # Multi-chain path.  NOTE(review): this path was documented as taking
        # a 3-D (N, K, D) tensor (N chains, each with K samples of dimension
        # D), yet the reshape below builds a 4-D (10, 100, 10, 784) array.
        # Confirm which layout compute_CSL_with_minibatches expects.
        chains = means.reshape(10, 100, 10, 784)
        csl_fn = csl.get_CSL_fn_independent_Bernoulli()
        csl.compute_CSL_with_minibatches(csl_fn, minibatches, chains)
def visualize_mnist(data_path='../data'):
    """Tile the first 2500 MNIST training images into a PNG and display it.

    The images are laid out as a 50 x 50 grid of 28 x 28 tiles with 2-pixel
    spacing, written to ``samples_mnist.png`` in the current working
    directory, and then shown by running ``eog samples_mnist.png`` through
    the shell.

    Args:
        data_path: directory handed to ``data.load_mnist``.  Defaults to the
            previously hard-coded ``'../data'``.

    Returns:
        None.
    """
    # Only the training inputs are needed; labels and other splits are dropped.
    (train_X, _), (_, _), (_, _) = data.load_mnist(data_path)

    # 2500 images exactly fill the 50 x 50 tile grid below.
    images = train_X[0:2500, :]
    image_data = tile_raster_images(images,
                                    img_shape=[28, 28],
                                    tile_shape=[50, 50],
                                    tile_spacing=(2, 2))

    im_new = Image.fromarray(numpy.uint8(image_data))
    im_new.save('samples_mnist.png')

    # os.system waits for the command to finish, so this call blocks until the
    # viewer exits; its exit status is ignored, so a missing `eog` only prints
    # a shell error rather than raising.
    os.system('eog samples_mnist.png')