def plotImageData(data, imgShape, saveDir = None, prefix = 'imgdata', tileShape = (20,30), show = False, onlyRescaled = False):
    '''Tile the columns of ``data`` as images, then save and/or show them.

    Up to three renderings are produced, in this order:
      * '<prefix>.png'                - raw values (skipped if onlyRescaled)
      * '<prefix>_rescale.png'        - each tile rescaled, colours together
      * '<prefix>_rescale_indiv.png'  - each tile rescaled, colours separately
                                        (only when len(imgShape) > 2)
    Each rendering is saved under saveDir when saveDir is truthy and
    displayed when show is True.
    '''
    colorImage = len(imgShape) > 2

    # (file name template, scaling keyword arguments for tile_raster_images)
    renderings = []
    if not onlyRescaled:
        renderings.append(('%s.png',
                           dict(scale_rows_to_unit_interval = False)))
    renderings.append(('%s_rescale.png',
                       dict(scale_rows_to_unit_interval = True,
                            scale_colors_together = True)))
    if colorImage:
        renderings.append(('%s_rescale_indiv.png',
                           dict(scale_rows_to_unit_interval = True,
                                scale_colors_together = False)))

    for nameTemplate, scaling in renderings:
        tiled = tile_raster_images(X = data.T, img_shape = imgShape,
                                   tile_shape = tileShape, tile_spacing = (1,1),
                                   **scaling)
        image = Image.fromarray(tiled)
        if saveDir:
            image.save(os.path.join(saveDir, nameTemplate % prefix))
        if show:
            image.show()
def demo(saveImage=False):
    '''Show (or save) sample images from randomCircleSampleMatrix.

    Draws 100 samples of 15 x 15 images with the RNG seeded to 0.

    :param saveImage: if True, save a 10 x 10 tiling of the samples to
        'demoCircleSamples.png' in the current directory; otherwise show
        the first 25 samples in a 5 x 5 pyplot grid, each titled with the
        values of its label row.
    '''
    random.seed(0)
    Nw = 15
    xx, yy = randomCircleSampleMatrix(Nw=Nw, Nsamples=100)

    if saveImage:
        image_data = tile_raster_images(X=xx,
                                        img_shape=(Nw, Nw),
                                        tile_shape=(10, 10),
                                        tile_spacing=(1, 1))
        Image.fromarray(image_data).save('demoCircleSamples.png')
    else:
        pyplot.figure()
        for ii in range(25):
            # subplot positions are 1-based; index 0 is invalid.
            ax = pyplot.subplot(5, 5, ii + 1)
            imagesc(xx[ii, :].reshape((Nw, Nw)), ax=ax)
            pyplot.title('%d, %d, %.2f' % tuple(yy[ii, :]))
        pyplot.show()
def plotTopActivations(activations,
                       data,
                       imgShape,
                       saveDir=None,
                       nActivations=50,
                       nSamples=20,
                       prefix='topact',
                       show=False):
    '''Tile, for each of the first units, the examples that activate it most.

    Tile row ``ii`` shows the ``nSamples`` columns of ``data`` with the
    largest values in ``activations[ii, :]``, strongest first. Only the top
    activations are plotted.

    :param activations: 2-D array indexed as [unit, example].
    :param data: 2-D array with one example per column; it must have
        prod(imgShape) rows.
    :param imgShape: shape of a single image tile.
    :param saveDir: if truthy, save the tiling as '<prefix>.png' there.
    :param nActivations: number of leading units to plot (clamped to the
        number of rows of ``activations``).
    :param nSamples: examples shown per unit (clamped to the number of
        columns of ``activations``).
    :param show: if True, display the image.
    '''
    nActivations = min(nActivations, activations.shape[0])
    # A unit cannot have more top examples than there are examples; without
    # this clamp the index slice below is too short to fill its tile row.
    nSamples = min(nSamples, activations.shape[1])

    # Only the first nActivations rows are used, so sort just those.
    sortIdx = argsort(activations[:nActivations], 1)

    plotData = zeros((prod(imgShape), nActivations * nSamples))
    for ii in range(nActivations):
        # Example indices of the nSamples largest activations, descending.
        idx = sortIdx[ii, ::-1][:nSamples]
        plotData[:, (ii * nSamples):((ii + 1) * nSamples)] = data[:, idx]

    image = Image.fromarray(
        tile_raster_images(X=plotData.T,
                           img_shape=imgShape,
                           tile_shape=(nActivations, nSamples),
                           tile_spacing=(1, 1),
                           scale_rows_to_unit_interval=True))

    if saveDir:
        image.save(os.path.join(saveDir, '%s.png' % prefix))
    if show:
        image.show()
# Example #4
# 0
def makeData():
    '''Generate, pickle and preview the simpleShapes datasets.

    For every image width Nw and sample count Nsamples the RNG is reseeded
    with 0. Then eight datasets named '<split>_<shape>_<prob>' are written
    to '../data/simpleShapes/<name>_<Nw:02d>_<Nsamples>.pkl.gz':
      * shape 'sc' uses randomSquareCircle, '4sc' uses random4SquaresCircles
      * prob 'p5' passes prob=.5, 'p1' passes prob=.1
    When Nsamples == 5000, a 10 x 15 tiling of the samples is also saved as
    '../data/simpleShapes/<name>_<Nw:02d>.png'.
    '''
    # Produces train_sc_p5, test_sc_p5, train_sc_p1, test_sc_p1,
    # train_4sc_p5, ... test_4sc_p1 in that order. The order matters
    # because all eight datasets draw from the same seeded RNG stream.
    variants = [(split, shape, probTag)
                for shape in ('sc', '4sc')
                for probTag in ('p5', 'p1')
                for split in ('train', 'test')]
    generatorFor = {'sc': randomSquareCircle, '4sc': random4SquaresCircles}
    probabilityFor = {'p5': .5, 'p1': .1}

    for Nw in (2, 4, 10, 16, 28, 50):
        for Nsamples in (50, 500, 5000, 50000):
            random.seed(0)
            for split, shape, probTag in variants:
                name = '%s_%s_%s' % (split, shape, probTag)
                generator = generatorFor[shape]
                sampleXAndY = generator(Nw, Nsamples, prob=probabilityFor[probTag])
                saveToFile(
                    '../data/simpleShapes/%s_%02d_%d.pkl.gz' %
                    (name, Nw, Nsamples), sampleXAndY)
                xx, _ = sampleXAndY
                if Nsamples != 5000:
                    continue
                preview = tile_raster_images(X=xx,
                                             img_shape=(Nw, Nw),
                                             tile_shape=(10, 15),
                                             tile_spacing=(1, 1),
                                             scale_rows_to_unit_interval=False)
                Image.fromarray(preview).save(
                    '../data/simpleShapes/%s_%02d.png' % (name, Nw))
def plotImageRicaWW(WW, imgShape, saveDir, tileShape = None, prefix = 'WW'):
    '''Save the rows of WW as tiled images under saveDir.

    Nothing is written when saveDir is falsy. Otherwise '<prefix>.png' is
    written with the colour channels scaled together. For colour images
    (len(imgShape) > 2), '<prefix>_rescale_indiv.png' is also written with
    the channels scaled separately. tileShape defaults to
    getTileShape(number of rows of WW).
    '''
    colorImage = len(imgShape) > 2
    nRows, _ = WW.shape    # unpacking also insists that WW is 2-D
    if tileShape is None:
        tileShape = getTileShape(nRows)
    if not saveDir:
        return

    outputs = [('%s.png', True)]
    if colorImage:
        outputs.append(('%s_rescale_indiv.png', False))
    for nameTemplate, together in outputs:
        tiled = tile_raster_images(X = WW,
                                   img_shape = imgShape, tile_shape = tileShape,
                                   tile_spacing = (1,1),
                                   scale_colors_together = together)
        Image.fromarray(tiled).save(os.path.join(saveDir, nameTemplate % prefix))
def plotImageData(data,
                  imgShape,
                  saveDir=None,
                  prefix='imgdata',
                  tileShape=(20, 30),
                  show=False,
                  onlyRescaled=False):
    '''Render the columns of ``data`` as tiled images (saved and/or shown).

    Renderings, in order:
      1. '<prefix>.png' with raw values, unless onlyRescaled is set;
      2. '<prefix>_rescale.png' with each tile rescaled, colours together;
      3. '<prefix>_rescale_indiv.png' with each tile rescaled and colours
         separately, only for colour images (len(imgShape) > 2).
    '''
    isColor = (len(imgShape) > 2)

    def render(fileTemplate, **scaling):
        # Tile the rows of data.T, then save to saveDir and/or display.
        image = Image.fromarray(
            tile_raster_images(X=data.T,
                               img_shape=imgShape,
                               tile_shape=tileShape,
                               tile_spacing=(1, 1),
                               **scaling))
        if saveDir:
            image.save(os.path.join(saveDir, fileTemplate % prefix))
        if show:
            image.show()

    if not onlyRescaled:
        render('%s.png', scale_rows_to_unit_interval=False)
    render('%s_rescale.png',
           scale_rows_to_unit_interval=True,
           scale_colors_together=True)
    if isColor:
        render('%s_rescale_indiv.png',
               scale_rows_to_unit_interval=True,
               scale_colors_together=False)
def plotImageRicaWW(WW, imgShape, saveDir, tileShape=None, prefix='WW'):
    '''Write tiled images of the rows of WW to saveDir (if saveDir is truthy).

    Always writes '<prefix>.png' (colour channels scaled together). For
    colour images it also writes '<prefix>_rescale_indiv.png' (channels
    scaled independently). When tileShape is None it is taken from
    getTileShape(WW.shape[0]).
    '''
    isColor = len(imgShape) > 2
    nOutputs, _nInputs = WW.shape
    tileShape = getTileShape(nOutputs) if tileShape is None else tileShape

    def saveTiles(fileName, scaleColorsTogether):
        # Tile the rows of WW and write the result under saveDir.
        tiles = tile_raster_images(X=WW,
                                   img_shape=imgShape,
                                   tile_shape=tileShape,
                                   tile_spacing=(1, 1),
                                   scale_colors_together=scaleColorsTogether)
        Image.fromarray(tiles).save(os.path.join(saveDir, fileName))

    if saveDir:
        saveTiles('%s.png' % prefix, True)
        if isColor:
            saveTiles('%s_rescale_indiv.png' % prefix, False)
def demo():
    '''Save and display a 10 x 15 tiling of randomSquareCircle samples.

    Draws 150 samples of 2 x 2 images with the RNG seeded to 0. The tiles
    are written unscaled to 'demo.png' in the current directory.
    '''
    random.seed(0)
    width = 2
    samples, _labels = randomSquareCircle(Nw = width, Nsamples = 150)

    tiled = tile_raster_images(X = samples, img_shape = (width, width),
                               tile_shape = (10, 15), tile_spacing = (1,1),
                               scale_rows_to_unit_interval = False)
    picture = Image.fromarray(tiled)
    picture.save('demo.png')
    print('saved as demo.png')
    picture.show()
# Example #9
# 0
def demo():
    '''Tile 150 seeded randomSquareCircle samples (2 x 2 px) into demo.png and show it.'''
    random.seed(0)
    Nw = 2
    imgSize = (Nw, Nw)
    sampleRows = randomSquareCircle(Nw=Nw, Nsamples=150)[0]

    image = Image.fromarray(tile_raster_images(X=sampleRows,
                                               img_shape=imgSize,
                                               tile_shape=(10, 15),
                                               tile_spacing=(1, 1),
                                               scale_rows_to_unit_interval=False))
    image.save('demo.png')
    print('saved as demo.png')
    image.show()
def makeData():
    '''Write the simpleShapes datasets (plus preview PNGs) to ../data/simpleShapes.

    The RNG is reseeded with 0 for each (Nw, Nsamples) pair before the
    eight datasets for that pair are generated. Dataset names follow
    '<train|test>_<sc|4sc>_<p5|p1>'. 'sc' and '4sc' select
    randomSquareCircle and random4SquaresCircles; 'p5' and 'p1' select
    prob=.5 and prob=.1. Only the Nsamples == 5000 runs get a PNG preview.
    '''
    # (name, generator, probability) in the fixed generation order; every
    # dataset draws from the same seeded RNG stream, so order matters.
    specs = []
    for shapeTag, generator in (('sc', randomSquareCircle),
                                ('4sc', random4SquaresCircles)):
        for probTag, probability in (('p5', .5), ('p1', .1)):
            for split in ('train', 'test'):
                specs.append(('%s_%s_%s' % (split, shapeTag, probTag),
                              generator, probability))

    for Nw in (2, 4, 10, 16, 28, 50):
        for Nsamples in (50, 500, 5000, 50000):
            random.seed(0)
            for name, generator, probability in specs:
                sampleXAndY = generator(Nw, Nsamples, prob = probability)
                saveToFile('../data/simpleShapes/%s_%02d_%d.pkl.gz' % (name, Nw, Nsamples), sampleXAndY)
                xx, yy = sampleXAndY
                if Nsamples != 5000:
                    continue
                preview = tile_raster_images(X = xx, img_shape = (Nw,Nw),
                                             tile_shape = (10, 15), tile_spacing = (1,1),
                                             scale_rows_to_unit_interval = False)
                Image.fromarray(preview).save('../data/simpleShapes/%s_%02d.png' % (name, Nw))
def plotTopActivations(activations, data, imgShape, saveDir = None, nActivations = 50, nSamples = 20, prefix = 'topact', show = False):
    '''Plot, for each of the first units, the examples that activate it most.

    Tile row ii holds the nSamples columns of `data` with the largest values
    in activations[ii,:], strongest first (top activations only).

    activations: 2-D array indexed [unit, example].
    data:        2-D array with one example per column and prod(imgShape) rows.
    nActivations is clamped to activations.shape[0] and nSamples to
    activations.shape[1]. The result is saved to '<saveDir>/<prefix>.png'
    if saveDir is truthy and displayed if show is True.
    '''
    nActivations = min(nActivations, activations.shape[0])
    # Can't show more top examples than exist; otherwise the slice below is
    # shorter than its destination and the assignment fails.
    nSamples = min(nSamples, activations.shape[1])

    # Only the first nActivations rows are needed, so only those are sorted.
    sortIdx = argsort(activations[:nActivations], 1)

    plotData = zeros((prod(imgShape), nActivations*nSamples))

    for ii in range(nActivations):
        idx = sortIdx[ii,::-1][:nSamples]      # top nSamples, descending
        plotData[:,(ii*nSamples):((ii+1)*nSamples)] = data[:,idx]

    image = Image.fromarray(tile_raster_images(
        X = plotData.T, img_shape = imgShape,
        tile_shape = (nActivations, nSamples), tile_spacing=(1,1),
        scale_rows_to_unit_interval = True))

    if saveDir:
        image.save(os.path.join(saveDir, '%s.png' % prefix))
    if show:
        image.show()
def demo(saveImage = False):
    '''Demo of randomCircleSampleMatrix: 100 seeded samples of 15 x 15 images.

    saveImage: if True, write a 10 x 10 tiling to 'demoCircleSamples.png'
    in the current directory. Otherwise (default) plot the first 25 samples
    in a 5 x 5 pyplot grid titled with their label rows.
    '''
    random.seed(0)
    Nw = 15
    xx, yy = randomCircleSampleMatrix(Nw = Nw, Nsamples = 100)

    if saveImage:
        image_data = tile_raster_images(X = xx,
                                        img_shape = (Nw,Nw),
                                        tile_shape = (10,10),
                                        tile_spacing = (1,1))
        Image.fromarray(image_data).save('demoCircleSamples.png')
    else:
        pyplot.figure()
        for ii in range(25):
            # subplot numbering starts at 1, not 0
            ax = pyplot.subplot(5,5,ii+1)
            imagesc(xx[ii,:].reshape((Nw,Nw)), ax=ax)
            pyplot.title('%d, %d, %.2f' % tuple(yy[ii,:]))
        pyplot.show()
def test_rbm(learning_rate=0.1, training_epochs = 15,
             datasets = None, batch_size = 20,
             n_chains = 20, n_samples = 14, output_dir = 'rbm_plots',
             img_dim = 28, n_input = None, n_hidden = 500, quickHack = False,
             visibleModel = 'binary', initWfactor = 1.0,
             imgPlotFunction = None):
    '''
    Demonstrate how to train an RBM.

    This is demonstrated on MNIST.

    :param learning_rate: learning rate used for training the RBM

    :param training_epochs: number of epochs used for training

    :param dataset: path the the pickled dataset

    :param batch_size: size of a batch used to train the RBM

    :param n_chains: number of parallel Gibbs chains to be used for sampling

    :param n_samples: number of samples to plot for each chain


    :param visibleModel: 'real' or 'binary'

    :param initWfactor: Typicaly 1 for binary or .01 for real

    XXX:param pcaDims: None to skip PCA or >0 to use PCA to reduce dimensionality of data first.

    '''

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x,  test_set_y  = datasets[2]

    if quickHack:
        train_set_x = train_set_x[:2500,:]
        if train_set_y is not None:
            train_set_y = train_set_y[:2500]

    print ('(%d, %d, %d) %d dimensional examples in (train, valid, test)' % 
           (train_set_x.shape[0], valid_set_x.shape[0], test_set_x.shape[0], train_set_x.shape[1]))

    
    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.shape[0] / batch_size
    print 'n_train_batches is', n_train_batches

    rng        = numpy.random.RandomState(1)

    if n_input is None:
        n_input = train_set_x.shape[1]

    # construct the RBM class
    rbm = RBM(nVisible=n_input, nHidden = n_hidden, numpyRng = rng,
              visibleModel = visibleModel, initWfactor = initWfactor)


    #################################
    #     Training the RBM          #
    #################################

    print 'starting training.'

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    plotting_time = 0.
    start_time = time.clock()

    # go through training epochs
    meanCosts = []
    ii = -1
    metrics = array([])
    plotEvery = 100
    for epoch in xrange(training_epochs):
        # go through the training set
        for batch_index in xrange(n_train_batches):
            #print 'about to train using exemplars %d to %d.' % (batch_index*batch_size, (batch_index+1)*batch_size)

            ii += 1
            if ii % plotEvery == 0:
                plotWeights = '%03i_%05i' % (epoch, batch_index)
                calcMetrics = True
            else:
                plotWeights = False
                calcMetrics = False

            # metric is xEntropyCost, reconError
            metric = rbm.train(train_set_x[batch_index*batch_size:(batch_index+1)*batch_size],
                               lr = learning_rate, metrics = calcMetrics, plotWeights = plotWeights,
                               output_dir = output_dir)

            if calcMetrics:
                if len(metrics) == 0:
                    metrics = array([metric])
                else:
                    metrics = vstack((metrics, metric))

            if ii % plotEvery == 0:
                # Plot filters after each single step
                plotting_start = time.clock()
                # Construct image from the weight matrix
                image = Image.fromarray(tile_raster_images(
                         X = imgPlotFunction(rbm.W.T) if imgPlotFunction else rbm.W.T,
                         img_shape = (img_dim,img_dim),tile_shape = (10,10),
                         tile_spacing=(1,1)))
                image.save(os.path.join(output_dir, 'filters_at_epoch_batch_%03i_%05i.png' % (epoch, batch_index)))
                plotting_stop = time.clock()
                plotting_time += (plotting_stop - plotting_start)

                #print '  Training epoch %d batch %d, xEntropyCost is ' % (epoch, batch_index), numpy.mean(mean_cost),
                print '  Training epoch %d batch %d, xEntropyCost is ' % (epoch, batch_index), metrics[-1,0],
                print '\trecon error ', metrics[-1,1]

        thisEpochStart =  epoch   *n_train_batches/plotEvery
        thisEpochEnd   = (epoch+1)*n_train_batches/plotEvery
        epochMeanXEnt  = mean(metrics[thisEpochStart:thisEpochEnd,0])
        epochMeanRecon = mean(metrics[thisEpochStart:thisEpochEnd,1])
        print 'Training epoch %d mean xEntropyCost is ' % (epoch), epochMeanXEnt, '\trecon error ', epochMeanRecon

        meanCosts.append(epochMeanXEnt)

        # Plot filters after each training epoch
        plotting_start = time.clock()
        # Construct image from the weight matrix
        image = Image.fromarray(tile_raster_images(
                 X = imgPlotFunction(rbm.W.T) if imgPlotFunction else rbm.W.T,
                 img_shape = (img_dim,img_dim),tile_shape = (10,10),
                 tile_spacing=(1,1)))
        image.save(os.path.join(output_dir, 'filters_at_epoch_%03i.png' % epoch))
        plotting_stop = time.clock()
        plotting_time += (plotting_stop - plotting_start)

    plotting_start = time.clock()
    pyplot.plot(metrics)
    pyplot.savefig(os.path.join(output_dir, 'reconErr.png'))
    plotting_time += (time.clock() - plotting_start)
    
    end_time = time.clock()

    pretraining_time = (end_time - start_time) - plotting_time

    print ('Training took %f minutes' %(pretraining_time/60.))
    print ('Plotting took %f minutes' %(plotting_time/60.))


    #################################
    #   Plot some samples from RBM  #
    #################################


    # find out the number of test samples
    number_of_test_samples = test_set_x.shape[0]

    plot_every = 1

    # if imgPlotFunction is defined, then also plot before function if
    # the data is of the same dimension (e.g. for ZCA, but not for
    # PCA).
    plotRawAlso = (imgPlotFunction and train_set_x.shape[0] == img_dim * img_dim)
        
    # create a space to store the image for plotting ( we need to leave
    # room for the tile_spacing as well)
    image_data = numpy.ones(((img_dim+1)*n_samples-1,(img_dim+1)*n_chains-1), dtype='uint8') * 51  # dark gray
    if plotRawAlso:
        image_data_raw = numpy.ones(((img_dim+1)*n_samples-1,(img_dim+1)*n_chains-1), dtype='uint8') * 51  # dark gray
    
    for ii in xrange(n_chains):
        # generate `plot_every` intermediate samples that we discard, because successive samples in the chain are too correlated
        test_idx = rng.randint(number_of_test_samples)
        
        samples = numpy.zeros((n_chains, n_input))

        visMean = test_set_x[test_idx,:]
        visSample = visMean
        for jj in xrange(n_samples):
            samples[jj,:] = visMean # show the mean, but use the sample for gibbs steps
            if jj == n_samples-1: break  # skip the last for speed
            plot_every = 2**jj  # exponentially increasing number of gibbs samples. max for n_samples=14 is 2^12
            for ss in xrange(plot_every):
                visMean, visSample = rbm.gibbs_vhv(visSample)[4:6]   # 4 for mean, 5 for sample

        print ' ... plotting sample ', ii
        image_data[:,(img_dim+1)*ii:(img_dim+1)*ii+img_dim] = tile_raster_images(
                X = imgPlotFunction(samples) if imgPlotFunction else samples,
                img_shape = (img_dim,img_dim),
                tile_shape = (n_samples, 1),
                tile_spacing = (1,1))
        if plotRawAlso:
            image_data_raw[:,(img_dim+1)*ii:(img_dim+1)*ii+img_dim] = tile_raster_images(
                    X = samples,
                    img_shape = (img_dim,img_dim),
                    tile_shape = (n_samples, 1),
                    tile_spacing = (1,1))

    image = Image.fromarray(image_data)
    image.save(os.path.join(output_dir, 'samples.png'))
    if plotRawAlso:
        image = Image.fromarray(image_data)
        image.save(os.path.join(output_dir, 'samplesRaw.png'))
    
    saveToFile(os.path.join(output_dir, 'rbm.pkl.gz'), rbm)
    
    return rbm, meanCosts
def plotRicaReconstructions(rica,
                            data,
                            imgShape,
                            saveDir=None,
                            unwhitener=None,
                            tileShape=None,
                            number=50,
                            prefix='recon',
                            onlyHilights=False,
                            hilightCmap=None):
    '''Plots reconstructions for some randomly chosen data points.'''

    if saveDir:
        print 'Plotting %d recon plots...' % number,
        sys.stdout.flush()
        imgIsColor = len(imgShape) > 2
        nOutputs, nInputs = rica.WW.shape
        if tileShape is None: tileShape = getTileShape(nOutputs)
        tileRescaleFactor = 2
        reconRescaleFactor = 3

        font = ImageFont.load_default()

        hidden = dot(rica.WW, data[:, :number])
        reconstruction = dot(rica.WW.T, hidden)

        if unwhitener:
            #pdb.set_trace() DEBUG?
            dataOrig = unwhitener(data[:, :number])
            reconstructionOrig = unwhitener(reconstruction[:, :number])
        for ii in xrange(number):
            # Hilighted tiled image
            hilightAmount = abs(hidden[:, ii])
            maxHilight = hilightAmount.max()
            #hilightAmount -= hilightAmount.min()   # Don't push to 0
            hilightAmount /= maxHilight + 1e-6

            if hilightCmap:
                cmap = cm.get_cmap(hilightCmap)
                hilights = cmap(hilightAmount)[:, :3]  # chop off alpha channel
            else:
                # default black -> red colormap
                hilights = outer(hilightAmount, array([1, 0, 0]))

            tileImg = Image.fromarray(
                tile_raster_images(X=rica.WW,
                                   img_shape=imgShape,
                                   tile_shape=tileShape,
                                   tile_spacing=(2, 2),
                                   scale_colors_together=True,
                                   hilights=hilights,
                                   onlyHilights=onlyHilights))
            tileImg = tileImg.resize(
                [x * tileRescaleFactor for x in tileImg.size])

            # Input / Reconstruction image
            if unwhitener:
                rawReconErr = array([
                    dataOrig[:, ii], data[:, ii], reconstruction[:, ii],
                    reconstructionOrig[:, ii],
                    reconstruction[:, ii] - data[:, ii],
                    reconstructionOrig[:, ii] - dataOrig[:, ii]
                ])
                # Scale data-raw and recon-raw together between 0 and 1
                rawReconErr = scale_rows_together_to_unit_interval(
                    rawReconErr, [0, 3], anchor0=False)
                # Scale data-white and recon-white together, map 0 -> 50% gray
                rawReconErr = scale_rows_together_to_unit_interval(
                    rawReconErr, [1, 2], anchor0=True)
                # Scale diffs independently to [0,1]
                rawReconErr = scale_some_rows_to_unit_interval(
                    rawReconErr, [4, 5])
            else:
                rawReconErr = array([
                    data[:, ii], reconstruction[:, ii],
                    reconstruction[:, ii] - data[:, ii]
                ])
                # Scale data-raw and recon-raw together between 0 and 1
                rawReconErr = scale_rows_together_to_unit_interval(
                    rawReconErr, [0, 1], anchor0=False)
                # Scale diffs independently to [0,1]
                rawReconErr = scale_some_rows_to_unit_interval(
                    rawReconErr, [2])
            rawReconErrImg = Image.fromarray(
                tile_raster_images(X=rawReconErr,
                                   img_shape=imgShape,
                                   tile_shape=(rawReconErr.shape[0], 1),
                                   tile_spacing=(1, 1),
                                   scale_rows_to_unit_interval=False))
            rawReconErrImg = rawReconErrImg.resize(
                [x * reconRescaleFactor for x in rawReconErrImg.size])

            # Add Red activation limit
            redString = '%g' % maxHilight
            fontSize = font.font.getsize(redString)
            size = (max(tileImg.size[0],
                        fontSize[0]), tileImg.size[1] + fontSize[1])
            tempImage = Image.new('RGBA', size, (51, 51, 51))
            tempImage.paste(tileImg, (0, 0))
            draw = ImageDraw.Draw(tempImage)
            draw.text(((size[0] - fontSize[0]) / 2, size[1] - fontSize[1]),
                      redString,
                      font=font)
            tileImg = tempImage

            # Combined
            costEtc = rica.cost(rica.WW, data[:, ii:ii + 1])
            costString = rica.getReconPlotString(costEtc)
            fontSize = font.font.getsize(costString)
            size = (max(
                tileImg.size[0] + rawReconErrImg.size[0] + reconRescaleFactor,
                fontSize[0]),
                    max(tileImg.size[1], rawReconErrImg.size[1]) + fontSize[1])
            wholeImage = Image.new('RGBA', size, (51, 51, 51))
            wholeImage.paste(tileImg, (0, 0))
            wholeImage.paste(rawReconErrImg,
                             (tileImg.size[0] + reconRescaleFactor, 0))
            draw = ImageDraw.Draw(wholeImage)
            draw.text(((size[0] - fontSize[0]) / 2, size[1] - fontSize[1]),
                      costString,
                      font=font)
            wholeImage.save(os.path.join(saveDir,
                                         '%s_%04d.png' % (prefix, ii)))

        print 'done.'
def plotRicaReconstructions(rica, data, imgShape, saveDir = None, unwhitener = None, tileShape = None, number = 50, prefix = 'recon', onlyHilights = False, hilightCmap = None):
    '''Plots reconstructions for some randomly chosen data points.'''

    if saveDir:
        print 'Plotting %d recon plots...' % number,
        sys.stdout.flush()
        imgIsColor = len(imgShape) > 2
        nOutputs, nInputs = rica.WW.shape
        if tileShape is None: tileShape = getTileShape(nOutputs)
        tileRescaleFactor  = 2
        reconRescaleFactor = 3
        
        font = ImageFont.load_default()

        hidden = dot(rica.WW, data[:,:number])
        reconstruction = dot(rica.WW.T, hidden)

        if unwhitener:
            #pdb.set_trace() DEBUG?
            dataOrig = unwhitener(data[:,:number])
            reconstructionOrig = unwhitener(reconstruction[:,:number])
        for ii in xrange(number):
            # Hilighted tiled image
            hilightAmount = abs(hidden[:,ii])
            maxHilight = hilightAmount.max()
            #hilightAmount -= hilightAmount.min()   # Don't push to 0
            hilightAmount /= maxHilight + 1e-6

            if hilightCmap:
                cmap = cm.get_cmap(hilightCmap)
                hilights = cmap(hilightAmount)[:,:3]  # chop off alpha channel
            else:
                # default black -> red colormap
                hilights = outer(hilightAmount, array([1,0,0]))
            
            tileImg = Image.fromarray(tile_raster_images(
                X = rica.WW,
                img_shape = imgShape, tile_shape = tileShape,
                tile_spacing=(2,2),
                scale_colors_together = True,
                hilights = hilights,
                onlyHilights = onlyHilights))
            tileImg = tileImg.resize([x*tileRescaleFactor for x in tileImg.size])

            # Input / Reconstruction image
            if unwhitener:
                rawReconErr = array([dataOrig[:,ii], data[:,ii], reconstruction[:,ii], reconstructionOrig[:,ii],
                                     reconstruction[:,ii]-data[:,ii], reconstructionOrig[:,ii]-dataOrig[:,ii]])
                # Scale data-raw and recon-raw together between 0 and 1
                rawReconErr = scale_rows_together_to_unit_interval(rawReconErr, [0, 3], anchor0 = False)
                # Scale data-white and recon-white together, map 0 -> 50% gray
                rawReconErr = scale_rows_together_to_unit_interval(rawReconErr, [1, 2], anchor0 = True)
                # Scale diffs independently to [0,1]
                rawReconErr = scale_some_rows_to_unit_interval(rawReconErr, [4, 5])
            else:
                rawReconErr = array([data[:,ii], reconstruction[:,ii],
                                     reconstruction[:,ii]-data[:,ii]])
                # Scale data-raw and recon-raw together between 0 and 1
                rawReconErr = scale_rows_together_to_unit_interval(rawReconErr, [0, 1], anchor0 = False)
                # Scale diffs independently to [0,1]
                rawReconErr = scale_some_rows_to_unit_interval(rawReconErr, [2])
            rawReconErrImg = Image.fromarray(tile_raster_images(
                X = rawReconErr,
                img_shape = imgShape, tile_shape = (rawReconErr.shape[0], 1),
                tile_spacing=(1,1),
                scale_rows_to_unit_interval = False))
            rawReconErrImg = rawReconErrImg.resize([x*reconRescaleFactor for x in rawReconErrImg.size])

            # Add Red activation limit
            redString = '%g' % maxHilight
            fontSize = font.font.getsize(redString)
            size = (max(tileImg.size[0], fontSize[0]), tileImg.size[1] + fontSize[1])
            tempImage = Image.new('RGBA', size, (51, 51, 51))
            tempImage.paste(tileImg, (0, 0))
            draw = ImageDraw.Draw(tempImage)
            draw.text(((size[0]-fontSize[0])/2, size[1]-fontSize[1]), redString, font=font)
            tileImg = tempImage

            # Combined
            costEtc = rica.cost(rica.WW, data[:,ii:ii+1])
            costString = rica.getReconPlotString(costEtc)
            fontSize = font.font.getsize(costString)
            size = (max(tileImg.size[0] + rawReconErrImg.size[0] + reconRescaleFactor, fontSize[0]),
                    max(tileImg.size[1], rawReconErrImg.size[1]) + fontSize[1])
            wholeImage = Image.new('RGBA', size, (51, 51, 51))
            wholeImage.paste(tileImg, (0, 0))
            wholeImage.paste(rawReconErrImg, (tileImg.size[0] + reconRescaleFactor, 0))
            draw = ImageDraw.Draw(wholeImage)
            draw.text(((size[0]-fontSize[0])/2, size[1]-fontSize[1]), costString, font=font)
            wholeImage.save(os.path.join(saveDir, '%s_%04d.png' % (prefix, ii)))

        print 'done.'
def testIca(datasets, savedir = None, smallImgHack = False, quickHack = False):
    '''Test ICA on a given dataset.'''

    random.seed(1)

    # 0. Get data
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x,  test_set_y  = datasets[2]

    if quickHack:
        print '!!! Using quickHack !!!'
        train_set_x = train_set_x[:2500,:]
        if train_set_y is not None:
            train_set_y = train_set_y[:2500]
    if smallImgHack:
        print '!!! Using smallImgHack !!! (images will be misaligned)'
        train_set_x = train_set_x[:,:100]

    print ('(%d, %d, %d) %d dimensional examples in (train, valid, test)' % 
           (train_set_x.shape[0], valid_set_x.shape[0], test_set_x.shape[0], train_set_x.shape[1]))

    nDim = train_set_x.shape[1]
    imgDim = int(round(sqrt(nDim)))    # Might not always be true...

    randIdxRaw    = random.randint(0, nDim, 100)
    randIdxWhite  = random.randint(0, nDim, 100)
    randIdxSource = random.randint(0, nDim, 100)

    image = Image.fromarray(tile_raster_images(
             X = train_set_x,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'data_raw.png'))
    image.show()

    pyplot.figure()
    for ii in range(20):
        idx = randIdxRaw[ii]
        pyplot.subplot(4,5,ii+1)
        pyplot.title('raw dim %d' % idx)
        pyplot.hist(train_set_x[:,idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_raw_hist.png'))

    # 1. Whiten data
    print 'Whitening data with pca...'
    pca = PCA(train_set_x)
    xWhite = pca.toZca(train_set_x)
    print '  done.'

    pyplot.figure()
    for ii in range(20):
        idx = randIdxWhite[ii]
        pyplot.subplot(4,5,ii+1)
        pyplot.title('data white dim %d' % idx)
        pyplot.hist(xWhite[:,idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_white_hist.png'))

    image = Image.fromarray(tile_raster_images(
             X = xWhite,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'data_white.png'))
    image.show()

    # 1.1 plot hist
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('data white 20 random dims')
    histMax = 0
    histMin = 1e10
    for ii in range(20):
        idx = randIdxWhite[ii]
        hist, binEdges = histogram(xWhite[:,idx], bins = 20, density = True)
        histMax = max(histMax, max(hist))
        histMin = min(histMin, min(hist[hist != 0]))   # min non-zero entry
        binMiddles = binEdges[:-1] + (binEdges[1] - binEdges[0])/2
        #print ' %d from %f to %f' % (ii, min(binMiddles), max(binMiddles))
        pyplot.semilogy(binMiddles, hist, '.-')
    pyplot.axis('tight')
    ax = looser(pyplot.axis(), semilogy = True)
    xAbsMax = max(fabs(ax[0:2]))
    xx = linspace(-xAbsMax, xAbsMax, 100)
    pyplot.semilogy(xx, mlab.normpdf(xx, 0, 1), 'k', linewidth = 3)
    pyplot.axis((-xAbsMax, xAbsMax, ax[2], ax[3]))
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_white_log_hist.png'))

    # 1.2 plot points
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('data white 20 random dims')
    nSamples = min(xWhite.shape[0], 1000)
    print 'data_white_log_points plotted with', nSamples, 'samples.'
    for ii in range(10):
        idx = randIdxWhite[ii]
        pyplot.plot(xWhite[:nSamples,idx],
                    ii + random.uniform(-.25, .25, nSamples), 'o')
    pyplot.axis('tight')
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_white_log_points.png'))

    # 2. Fit ICA
    rng = random.RandomState(1)
    ica = FastICA(random_state = rng, whiten = False)
    print 'Fitting ICA...'
    ica.fit(xWhite)
    print '  done.'
    if savedir:  saveToFile(os.path.join(savedir, 'ica.pkl.gz'), ica)

    print 'Geting sources and mixing matrix...'
    sourcesWhite = ica.transform(xWhite)  # Estimate the sources
    #S_fica /= S_fica.std(axis=0)   # (should already be done)
    mixingMatrix = ica.get_mixing_matrix()
    print '  done.'

    sources = pca.fromZca(sourcesWhite)
    

    # 3. Show independent components and inferred sources
    image = Image.fromarray(tile_raster_images(
             X = mixingMatrix,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'ic_white.png'))
    image.show()
    image = Image.fromarray(tile_raster_images(
             X = mixingMatrix.T,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'ic_white.T.png'))
    image.show()
    image = Image.fromarray(tile_raster_images(
             X = pca.fromZca(mixingMatrix),
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'ic_raw.png'))
    image.show()
    image = Image.fromarray(tile_raster_images(
             X = pca.fromZca(mixingMatrix.T),
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'ic_raw.T.png'))
    image.show()

    pyplot.figure()
    for ii in range(20):
        idx = randIdxSource[ii]
        pyplot.subplot(4,5,ii+1)
        pyplot.title('sourceWhite %d' % idx)
        pyplot.hist(sourcesWhite[:,idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'sources_white_hist.png'))

    image = Image.fromarray(tile_raster_images(
             X = sourcesWhite,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'sources_white.png'))
    image.show()

    # 3.1 plot hist
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('sources white 20 random dims')
    histMax = 0
    histMin = 1e10
    for ii in range(20):
        idx = randIdxSource[ii]
        hist, binEdges = histogram(sourcesWhite[:,idx], bins = 20, density = True)
        histMax = max(histMax, max(hist))
        histMin = min(histMin, min(hist[hist != 0]))   # min non-zero entry
        binMiddles = binEdges[:-1] + (binEdges[1] - binEdges[0])/2
        #print ' %d from %f to %f' % (ii, min(binMiddles), max(binMiddles))
        pyplot.semilogy(binMiddles, hist, '.-')
    pyplot.axis('tight')
    ax = looser(pyplot.axis(), semilogy = True)
    xAbsMax = max(fabs(ax[0:2]))
    xx = linspace(-xAbsMax, xAbsMax, 100)
    pyplot.semilogy(xx, mlab.normpdf(xx, 0, 1), 'k', linewidth = 3)
    pyplot.axis((-xAbsMax, xAbsMax, ax[2], ax[3]))
    if savedir: pyplot.savefig(os.path.join(savedir, 'sources_white_log_hist.png'))

    # 3.2 plot points
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('sources white 20 random dims')
    nSamples = min(sourcesWhite.shape[0], 1000)
    print 'sources_white_log_points plotted with', nSamples, 'samples.'
    for ii in range(10):
        idx = randIdxWhite[ii]
        pyplot.plot(sourcesWhite[:nSamples,idx],
                    ii + random.uniform(-.25, .25, nSamples), 'o')
    pyplot.axis('tight')
    if savedir: pyplot.savefig(os.path.join(savedir, 'sources_white_log_points.png'))


    image = Image.fromarray(tile_raster_images(
             X = sources,
             img_shape = (imgDim,imgDim), tile_shape = (10,15),
             tile_spacing=(1,1)))
    if savedir:  image.save(os.path.join(savedir, 'sources_raw.png'))
    image.show()


    
    if savedir:
        print 'plots saved in', savedir
    else:
        import ipdb; ipdb.set_trace()
def testIca(datasets, savedir=None, smallImgHack=False, quickHack=False):
    '''Test ICA on a given dataset.'''

    random.seed(1)

    # 0. Get data
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    if quickHack:
        print '!!! Using quickHack !!!'
        train_set_x = train_set_x[:2500, :]
        if train_set_y is not None:
            train_set_y = train_set_y[:2500]
    if smallImgHack:
        print '!!! Using smallImgHack !!! (images will be misaligned)'
        train_set_x = train_set_x[:, :100]

    print('(%d, %d, %d) %d dimensional examples in (train, valid, test)' %
          (train_set_x.shape[0], valid_set_x.shape[0], test_set_x.shape[0],
           train_set_x.shape[1]))

    nDim = train_set_x.shape[1]
    imgDim = int(round(sqrt(nDim)))  # Might not always be true...

    randIdxRaw = random.randint(0, nDim, 100)
    randIdxWhite = random.randint(0, nDim, 100)
    randIdxSource = random.randint(0, nDim, 100)

    image = Image.fromarray(
        tile_raster_images(X=train_set_x,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'data_raw.png'))
    image.show()

    pyplot.figure()
    for ii in range(20):
        idx = randIdxRaw[ii]
        pyplot.subplot(4, 5, ii + 1)
        pyplot.title('raw dim %d' % idx)
        pyplot.hist(train_set_x[:, idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_raw_hist.png'))

    # 1. Whiten data
    print 'Whitening data with pca...'
    pca = PCA(train_set_x)
    xWhite = pca.toZca(train_set_x)
    print '  done.'

    pyplot.figure()
    for ii in range(20):
        idx = randIdxWhite[ii]
        pyplot.subplot(4, 5, ii + 1)
        pyplot.title('data white dim %d' % idx)
        pyplot.hist(xWhite[:, idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'data_white_hist.png'))

    image = Image.fromarray(
        tile_raster_images(X=xWhite,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'data_white.png'))
    image.show()

    # 1.1 plot hist
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('data white 20 random dims')
    histMax = 0
    histMin = 1e10
    for ii in range(20):
        idx = randIdxWhite[ii]
        hist, binEdges = histogram(xWhite[:, idx], bins=20, density=True)
        histMax = max(histMax, max(hist))
        histMin = min(histMin, min(hist[hist != 0]))  # min non-zero entry
        binMiddles = binEdges[:-1] + (binEdges[1] - binEdges[0]) / 2
        #print ' %d from %f to %f' % (ii, min(binMiddles), max(binMiddles))
        pyplot.semilogy(binMiddles, hist, '.-')
    pyplot.axis('tight')
    ax = looser(pyplot.axis(), semilogy=True)
    xAbsMax = max(fabs(ax[0:2]))
    xx = linspace(-xAbsMax, xAbsMax, 100)
    pyplot.semilogy(xx, mlab.normpdf(xx, 0, 1), 'k', linewidth=3)
    pyplot.axis((-xAbsMax, xAbsMax, ax[2], ax[3]))
    if savedir:
        pyplot.savefig(os.path.join(savedir, 'data_white_log_hist.png'))

    # 1.2 plot points
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('data white 20 random dims')
    nSamples = min(xWhite.shape[0], 1000)
    print 'data_white_log_points plotted with', nSamples, 'samples.'
    for ii in range(10):
        idx = randIdxWhite[ii]
        pyplot.plot(xWhite[:nSamples, idx],
                    ii + random.uniform(-.25, .25, nSamples), 'o')
    pyplot.axis('tight')
    if savedir:
        pyplot.savefig(os.path.join(savedir, 'data_white_log_points.png'))

    # 2. Fit ICA
    rng = random.RandomState(1)
    ica = FastICA(random_state=rng, whiten=False)
    print 'Fitting ICA...'
    ica.fit(xWhite)
    print '  done.'
    if savedir: saveToFile(os.path.join(savedir, 'ica.pkl.gz'), ica)

    print 'Geting sources and mixing matrix...'
    sourcesWhite = ica.transform(xWhite)  # Estimate the sources
    #S_fica /= S_fica.std(axis=0)   # (should already be done)
    mixingMatrix = ica.get_mixing_matrix()
    print '  done.'

    sources = pca.fromZca(sourcesWhite)

    # 3. Show independent components and inferred sources
    image = Image.fromarray(
        tile_raster_images(X=mixingMatrix,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'ic_white.png'))
    image.show()
    image = Image.fromarray(
        tile_raster_images(X=mixingMatrix.T,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'ic_white.T.png'))
    image.show()
    image = Image.fromarray(
        tile_raster_images(X=pca.fromZca(mixingMatrix),
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'ic_raw.png'))
    image.show()
    image = Image.fromarray(
        tile_raster_images(X=pca.fromZca(mixingMatrix.T),
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'ic_raw.T.png'))
    image.show()

    pyplot.figure()
    for ii in range(20):
        idx = randIdxSource[ii]
        pyplot.subplot(4, 5, ii + 1)
        pyplot.title('sourceWhite %d' % idx)
        pyplot.hist(sourcesWhite[:, idx])
    if savedir: pyplot.savefig(os.path.join(savedir, 'sources_white_hist.png'))

    image = Image.fromarray(
        tile_raster_images(X=sourcesWhite,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'sources_white.png'))
    image.show()

    # 3.1 plot hist
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('sources white 20 random dims')
    histMax = 0
    histMin = 1e10
    for ii in range(20):
        idx = randIdxSource[ii]
        hist, binEdges = histogram(sourcesWhite[:, idx], bins=20, density=True)
        histMax = max(histMax, max(hist))
        histMin = min(histMin, min(hist[hist != 0]))  # min non-zero entry
        binMiddles = binEdges[:-1] + (binEdges[1] - binEdges[0]) / 2
        #print ' %d from %f to %f' % (ii, min(binMiddles), max(binMiddles))
        pyplot.semilogy(binMiddles, hist, '.-')
    pyplot.axis('tight')
    ax = looser(pyplot.axis(), semilogy=True)
    xAbsMax = max(fabs(ax[0:2]))
    xx = linspace(-xAbsMax, xAbsMax, 100)
    pyplot.semilogy(xx, mlab.normpdf(xx, 0, 1), 'k', linewidth=3)
    pyplot.axis((-xAbsMax, xAbsMax, ax[2], ax[3]))
    if savedir:
        pyplot.savefig(os.path.join(savedir, 'sources_white_log_hist.png'))

    # 3.2 plot points
    pyplot.figure()
    pyplot.hold(True)
    pyplot.title('sources white 20 random dims')
    nSamples = min(sourcesWhite.shape[0], 1000)
    print 'sources_white_log_points plotted with', nSamples, 'samples.'
    for ii in range(10):
        idx = randIdxWhite[ii]
        pyplot.plot(sourcesWhite[:nSamples, idx],
                    ii + random.uniform(-.25, .25, nSamples), 'o')
    pyplot.axis('tight')
    if savedir:
        pyplot.savefig(os.path.join(savedir, 'sources_white_log_points.png'))

    image = Image.fromarray(
        tile_raster_images(X=sources,
                           img_shape=(imgDim, imgDim),
                           tile_shape=(10, 15),
                           tile_spacing=(1, 1)))
    if savedir: image.save(os.path.join(savedir, 'sources_raw.png'))
    image.show()

    if savedir:
        print 'plots saved in', savedir
    else:
        import ipdb
        ipdb.set_trace()
# Example #18
# 0
def test_rbm(learning_rate=0.1,
             training_epochs=15,
             datasets=None,
             batch_size=20,
             n_chains=20,
             n_samples=14,
             output_dir='rbm_plots',
             img_dim=28,
             n_input=None,
             n_hidden=500,
             quickHack=False,
             visibleModel='binary',
             initWfactor=1.0,
             imgPlotFunction=None):
    '''
    Demonstrate how to train an RBM.

    This is demonstrated on MNIST.

    :param learning_rate: learning rate used for training the RBM

    :param training_epochs: number of epochs used for training

    :param dataset: path the the pickled dataset

    :param batch_size: size of a batch used to train the RBM

    :param n_chains: number of parallel Gibbs chains to be used for sampling

    :param n_samples: number of samples to plot for each chain


    :param visibleModel: 'real' or 'binary'

    :param initWfactor: Typicaly 1 for binary or .01 for real

    XXX:param pcaDims: None to skip PCA or >0 to use PCA to reduce dimensionality of data first.

    '''

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    if quickHack:
        train_set_x = train_set_x[:2500, :]
        if train_set_y is not None:
            train_set_y = train_set_y[:2500]

    print('(%d, %d, %d) %d dimensional examples in (train, valid, test)' %
          (train_set_x.shape[0], valid_set_x.shape[0], test_set_x.shape[0],
           train_set_x.shape[1]))

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.shape[0] / batch_size
    print 'n_train_batches is', n_train_batches

    rng = numpy.random.RandomState(1)

    if n_input is None:
        n_input = train_set_x.shape[1]

    # construct the RBM class
    rbm = RBM(nVisible=n_input,
              nHidden=n_hidden,
              numpyRng=rng,
              visibleModel=visibleModel,
              initWfactor=initWfactor)

    #################################
    #     Training the RBM          #
    #################################

    print 'starting training.'

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    plotting_time = 0.
    start_time = time.clock()

    # go through training epochs
    meanCosts = []
    ii = -1
    metrics = array([])
    plotEvery = 100
    for epoch in xrange(training_epochs):
        # go through the training set
        for batch_index in xrange(n_train_batches):
            #print 'about to train using exemplars %d to %d.' % (batch_index*batch_size, (batch_index+1)*batch_size)

            ii += 1
            if ii % plotEvery == 0:
                plotWeights = '%03i_%05i' % (epoch, batch_index)
                calcMetrics = True
            else:
                plotWeights = False
                calcMetrics = False

            # metric is xEntropyCost, reconError
            metric = rbm.train(
                train_set_x[batch_index * batch_size:(batch_index + 1) *
                            batch_size],
                lr=learning_rate,
                metrics=calcMetrics,
                plotWeights=plotWeights,
                output_dir=output_dir)

            if calcMetrics:
                if len(metrics) == 0:
                    metrics = array([metric])
                else:
                    metrics = vstack((metrics, metric))

            if ii % plotEvery == 0:
                # Plot filters after each single step
                plotting_start = time.clock()
                # Construct image from the weight matrix
                image = Image.fromarray(
                    tile_raster_images(X=imgPlotFunction(rbm.W.T)
                                       if imgPlotFunction else rbm.W.T,
                                       img_shape=(img_dim, img_dim),
                                       tile_shape=(10, 10),
                                       tile_spacing=(1, 1)))
                image.save(
                    os.path.join(
                        output_dir, 'filters_at_epoch_batch_%03i_%05i.png' %
                        (epoch, batch_index)))
                plotting_stop = time.clock()
                plotting_time += (plotting_stop - plotting_start)

                #print '  Training epoch %d batch %d, xEntropyCost is ' % (epoch, batch_index), numpy.mean(mean_cost),
                print '  Training epoch %d batch %d, xEntropyCost is ' % (
                    epoch, batch_index), metrics[-1, 0],
                print '\trecon error ', metrics[-1, 1]

        thisEpochStart = epoch * n_train_batches / plotEvery
        thisEpochEnd = (epoch + 1) * n_train_batches / plotEvery
        epochMeanXEnt = mean(metrics[thisEpochStart:thisEpochEnd, 0])
        epochMeanRecon = mean(metrics[thisEpochStart:thisEpochEnd, 1])
        print 'Training epoch %d mean xEntropyCost is ' % (
            epoch), epochMeanXEnt, '\trecon error ', epochMeanRecon

        meanCosts.append(epochMeanXEnt)

        # Plot filters after each training epoch
        plotting_start = time.clock()
        # Construct image from the weight matrix
        image = Image.fromarray(
            tile_raster_images(
                X=imgPlotFunction(rbm.W.T) if imgPlotFunction else rbm.W.T,
                img_shape=(img_dim, img_dim),
                tile_shape=(10, 10),
                tile_spacing=(1, 1)))
        image.save(
            os.path.join(output_dir, 'filters_at_epoch_%03i.png' % epoch))
        plotting_stop = time.clock()
        plotting_time += (plotting_stop - plotting_start)

    plotting_start = time.clock()
    pyplot.plot(metrics)
    pyplot.savefig(os.path.join(output_dir, 'reconErr.png'))
    plotting_time += (time.clock() - plotting_start)

    end_time = time.clock()

    pretraining_time = (end_time - start_time) - plotting_time

    print('Training took %f minutes' % (pretraining_time / 60.))
    print('Plotting took %f minutes' % (plotting_time / 60.))

    #################################
    #   Plot some samples from RBM  #
    #################################

    # find out the number of test samples
    number_of_test_samples = test_set_x.shape[0]

    plot_every = 1

    # if imgPlotFunction is defined, then also plot before function if
    # the data is of the same dimension (e.g. for ZCA, but not for
    # PCA).
    plotRawAlso = (imgPlotFunction
                   and train_set_x.shape[0] == img_dim * img_dim)

    # create a space to store the image for plotting ( we need to leave
    # room for the tile_spacing as well)
    image_data = numpy.ones(
        ((img_dim + 1) * n_samples - 1, (img_dim + 1) * n_chains - 1),
        dtype='uint8') * 51  # dark gray
    if plotRawAlso:
        image_data_raw = numpy.ones(
            ((img_dim + 1) * n_samples - 1, (img_dim + 1) * n_chains - 1),
            dtype='uint8') * 51  # dark gray

    for ii in xrange(n_chains):
        # generate `plot_every` intermediate samples that we discard, because successive samples in the chain are too correlated
        test_idx = rng.randint(number_of_test_samples)

        samples = numpy.zeros((n_chains, n_input))

        visMean = test_set_x[test_idx, :]
        visSample = visMean
        for jj in xrange(n_samples):
            samples[
                jj, :] = visMean  # show the mean, but use the sample for gibbs steps
            if jj == n_samples - 1: break  # skip the last for speed
            plot_every = 2**jj  # exponentially increasing number of gibbs samples. max for n_samples=14 is 2^12
            for ss in xrange(plot_every):
                visMean, visSample = rbm.gibbs_vhv(visSample)[
                    4:6]  # 4 for mean, 5 for sample

        print ' ... plotting sample ', ii
        image_data[:, (img_dim + 1) * ii:(img_dim + 1) * ii +
                   img_dim] = tile_raster_images(X=imgPlotFunction(samples) if
                                                 imgPlotFunction else samples,
                                                 img_shape=(img_dim, img_dim),
                                                 tile_shape=(n_samples, 1),
                                                 tile_spacing=(1, 1))
        if plotRawAlso:
            image_data_raw[:, (img_dim + 1) * ii:(img_dim + 1) * ii +
                           img_dim] = tile_raster_images(X=samples,
                                                         img_shape=(img_dim,
                                                                    img_dim),
                                                         tile_shape=(n_samples,
                                                                     1),
                                                         tile_spacing=(1, 1))

    image = Image.fromarray(image_data)
    image.save(os.path.join(output_dir, 'samples.png'))
    if plotRawAlso:
        image = Image.fromarray(image_data)
        image.save(os.path.join(output_dir, 'samplesRaw.png'))

    saveToFile(os.path.join(output_dir, 'rbm.pkl.gz'), rbm)

    return rbm, meanCosts