class GenerateAndSave(TrainExtension):
    """
    Keeps track of what the generator in a (vanilla) GAN returns for a
    particular set of noise values.

    A batch of noise is drawn once, at construction time. On every
    monitoring step the generator is run on that same batch, and the
    samples are tiled into one image that is saved to disk.

    Parameters
    ----------
    generator : Generator
        The GAN generator whose output is visualised.
    save_prefix : str
        Path prefix for the saved images. Each file is named
        ``<save_prefix>.<epochs seen, zero-padded to 5 digits>.png``.
    batch_size : int, optional
        Number of noise vectors drawn once and reused at every
        monitoring step.
    grid_shape : tuple of int, optional
        (rows, cols) of the patch grid.
        NOTE(review): presumably rows * cols must be >= batch_size so that
        every sample fits. Confirm against PatchViewer.
    patch_shape : tuple of int, optional
        (height, width) of each sample patch. Defaults to (32, 32).
    is_color : bool, optional
        Whether samples are rendered as colour patches. Defaults to True.
    """

    def __init__(self, generator, save_prefix, batch_size=20, grid_shape=(5, 4),
                 patch_shape=(32, 32), is_color=True):
        assert isinstance(generator, Generator)

        # Compiled forward pass of the generator: noise matrix -> samples.
        self.batch_sym = T.matrix('generate_batch')
        self.generate_f = theano.function([self.batch_sym],
                                          generator.dropout_fprop(self.batch_sym)[0])

        # The noise is evaluated once and stored as a concrete value, so every
        # monitoring step renders the same noise inputs.
        self.batch = generator.get_noise(batch_size).eval()
        self.save_prefix = save_prefix
        self.patch_viewer = PatchViewer(grid_shape=grid_shape,
                                        patch_shape=patch_shape,
                                        is_color=is_color)

    def on_monitor(self, model, dataset, algorithm):
        """
        Render the generator's output for the fixed noise batch and save it.

        Parameters
        ----------
        model : Model
            Model being trained; only ``model.monitor.get_epochs_seen()`` is
            read, to number the output file.
        dataset, algorithm
            Unused; part of the TrainExtension callback signature.
        """
        # Swap axes 0 and 3 so that iterating yields one sample per step.
        # NOTE(review): assumes the generator emits (channels, rows, cols,
        # batch). Confirm against the generator's output space.
        samples = self.generate_f(self.batch).swapaxes(0, 3)

        self.patch_viewer.clear()
        for sample in samples:
            self.patch_viewer.add_patch(sample, rescale=True)

        epochs_seen = model.monitor.get_epochs_seen()
        fname = self.save_prefix + '.%05i.png' % epochs_seen
        self.patch_viewer.save(fname)
def plot(w):
    """
    Tile the filters in ``w`` into one nested mosaic image and return it.

    The layout, from outermost to innermost, is:
      * ``main_viewer``: one pane per plotted channel (``nplots`` panes);
      * ``chan_viewer``: one pane per filter block (``nblocks`` panes);
      * ``block_viewer``: a ``bw_g x bw_h`` grid of individual filters.

    The function reads the module-level globals ``model``, ``opts``,
    ``nplots`` and ``get_dims``.

    When ``opts.splitblocks`` is set, each block is also saved as
    ``filters/filters_chan<chan>_block<block>.png``. The ``filters/``
    directory is created if it is missing.

    Parameters
    ----------
    w : 2-D array
        One filter per row. Each row is converted to a topological view of
        shape ``[opts.height, opts.width, opts.chans]`` via
        ``DefaultViewConverter``.

    Returns
    -------
    A copy of the final ``main_viewer.image``.
    """
    nblocks = int(model.n_g / model.sparse_gmask.bw_g)
    filters_per_block = model.sparse_gmask.bw_g * model.sparse_hmask.bw_h

    # Innermost grid: a single block of bw_g x bw_h filters.
    block_viewer = PatchViewer((model.sparse_gmask.bw_g, model.sparse_hmask.bw_h),
                               (opts.height, opts.width),
                               is_color=opts.color,
                               pad=(2, 2))

    # Middle grid: every block for one channel.
    chan_viewer = PatchViewer(get_dims(nblocks),
                              (block_viewer.image.shape[0],
                               block_viewer.image.shape[1]),
                              is_color=opts.color,
                              pad=(5, 5))

    # Outer grid: one pane per plotted channel.
    main_viewer = PatchViewer(get_dims(nplots),
                              (chan_viewer.image.shape[0],
                               chan_viewer.image.shape[1]),
                              is_color=opts.color,
                              pad=(10, 10))

    topo_shape = [opts.height, opts.width, opts.chans]
    view_converter = DefaultViewConverter(topo_shape)

    # os.makedirs raises OSError when the directory already exists, so a
    # second call or a re-run would otherwise crash here. (The exist_ok
    # argument is Python 3-only, and this file targets Python 2.)
    if opts.splitblocks and not os.path.isdir('filters/'):
        os.makedirs('filters/')

    for chan_i in xrange(nplots):
        # Colour images keep all channels; otherwise select channel chan_i.
        viewer_dims = slice(0, None) if opts.color else chan_i

        for bidx in xrange(nblocks):
            for fidx in xrange(filters_per_block):
                fi = bidx * filters_per_block + fidx
                topo_view = view_converter.design_mat_to_topo_view(w[fi:fi+1, :])
                # Errors propagate to the caller rather than opening an
                # interactive debugger, which would hang unattended runs.
                block_viewer.add_patch(topo_view[0, :, :, viewer_dims])

            if opts.splitblocks:
                pl.imshow(block_viewer.image, interpolation='nearest')
                pl.axis('off')
                pl.title('Wv - block %i, chan %i' % (bidx, chan_i))
                # Arguments follow the placeholder order: channel, then block.
                pl.savefig('filters/filters_chan%i_block%i.png' % (chan_i, bidx))

            # NOTE(review): the -0.5 presumably undoes PatchViewer's display
            # offset before the image is nested into the next grid. Confirm.
            chan_viewer.add_patch(block_viewer.image[:, :, viewer_dims] - 0.5)
            block_viewer.clear()

        main_viewer.add_patch(chan_viewer.image[:, :, viewer_dims] - 0.5)
        chan_viewer.clear()

    return copy.copy(main_viewer.image)
Exemple #3
0
        else: 
            new_w = w_di
    else:
        new_w = numpy.zeros((len(w_di), opts.height * opts.width)) if di else w_di

    for fi in xrange(len(w_di)):

        if opts.k != -1:
            # build "new_w" as a linear combination of "strongest" filters in layer below
            if di > 0:
                temp.fill(0.)
                idx = numpy.argsort(w_di[fi])[-opts.k:]
                for fi_m1 in idx:
                    new_w[fi:fi+1] += w_di[fi, fi_m1] * prev_w[fi_m1:fi_m1+1,:]
                #for fi_m1 in xrange(len(w_di[fi])):
            else:
                temp = w_di[fi:fi+1,:]

        topo_view = view_converter.design_mat_to_topo_view(new_w[fi:fi+1])
        block_viewer.add_patch(topo_view[0])

    main_viewer.add_patch(block_viewer.image[:,:,0] - 0.5)
    block_viewer.clear()
    
    prev_w = new_w

# Display the final weight mosaic: hide the axes, write the figure to
# 'weights.png', then open the plotting window.
pl.imshow(main_viewer.image,
          interpolation=None)
pl.axis('off')
pl.savefig('weights.png')
pl.show()
Exemple #4
0
    # Visualise second-layer weights W[1] projected through W[0]. Each weight
    # is kept at random, with probability equal to its magnitude divided by
    # the largest same-signed magnitude in its row.
    # NOTE(review): W, opts, max_filters, make_viewer, get_dims, main_viewer,
    # k and pl come from the enclosing scope, which is outside this view.

    # positive weights
    # Copy W[1] so that zeroing the negative entries does not mutate W.
    posw = copy.copy(W[1])
    posw[posw < 0] = 0.
    # Normalise each row by its maximum, so entries lie in [0, 1]. A row with
    # no positive entries gives 0/0 = NaN (numpy may warn). Because NaN > r is
    # False, such rows contribute nothing through the mask below.
    probs = posw / numpy.max(posw, axis=1)[:, None]
    # r is uniform in [0, 1), so (probs > r) keeps each entry with
    # probability probs.
    r = numpy.random.random(probs.shape)
    # assumes W[1] columns line up with W[0] rows — TODO confirm
    pos_w1 = numpy.dot(posw * (probs > r), W[0])

    # negative weights
    # Same procedure on the magnitudes of the negative entries of W[1]:
    # negw holds -w where w < 0, and 0 elsewhere.
    negw = copy.copy(-W[1])
    negw[negw < 0] = 0.
    probs = negw / numpy.max(negw, axis=1)[:, None]
    r = numpy.random.random(probs.shape)
    # Negate again so the sampled negative weights keep their sign.
    neg_w1 = numpy.dot(-negw * (probs > r), W[0])

    # NOTE(review): make_viewer and get_dims are defined elsewhere;
    # presumably they build a patch grid sized for max_filters filters of
    # shape (opts.height, opts.width). Confirm.
    block_viewer = make_viewer(pos_w1 + neg_w1, get_dims(max_filters),
                               (opts.height, opts.width))

    # Only channel 0 of the block image is added to the main mosaic.
    main_viewer.add_patch(block_viewer.image[:, :, 0] - 0.5)

    pl.imshow(main_viewer.image, interpolation=None)
    pl.axis('off')
    pl.savefig('weights.png')
    # pl.show() when k == 0, pl.draw() on later passes.
    # presumably k is the enclosing loop counter — verify.
    if k == 0:
        pl.show()
    else:
        pl.draw()

    # Reset both viewers so the next pass starts from empty grids.
    block_viewer.clear()
    main_viewer.clear()
Exemple #5
0
    # NOTE(review): posw is created before this view; presumably it is a copy
    # of W[1]. Confirm. Zero its negative entries, keeping the positive weights.
    posw[posw < 0] = 0.
    # Normalise each row by its maximum, so entries lie in [0, 1]. A row with
    # no positive entries gives 0/0 = NaN (numpy may warn). Because NaN > r is
    # False, such rows contribute nothing through the mask below.
    probs = posw / numpy.max(posw, axis=1)[:, None]
    # r is uniform in [0, 1), so (probs > r) keeps each entry with
    # probability probs.
    r = numpy.random.random(probs.shape)
    # assumes W[1] columns line up with W[0] rows — TODO confirm
    pos_w1 = numpy.dot(posw * (probs > r), W[0])

    # negative weights
    # Same procedure on the magnitudes of the negative entries of W[1]:
    # negw holds -w where w < 0, and 0 elsewhere.
    negw = copy.copy(-W[1])
    negw[negw < 0] = 0.
    probs = negw / numpy.max(negw, axis=1)[:, None]
    r = numpy.random.random(probs.shape)
    # Negate again so the sampled negative weights keep their sign.
    neg_w1 = numpy.dot(-negw * (probs > r), W[0])

    # NOTE(review): make_viewer, get_dims and max_filters come from outside
    # this view; presumably this builds a patch grid of
    # (opts.height, opts.width) filters. Confirm.
    block_viewer = make_viewer(
            pos_w1 + neg_w1,
            get_dims(max_filters),
            (opts.height, opts.width))

    # Only channel 0 of the block image is added to the main mosaic.
    main_viewer.add_patch(block_viewer.image[:,:,0] - 0.5)


    pl.imshow(main_viewer.image, interpolation=None)
    pl.axis('off');
    pl.savefig('weights.png')
    # pl.show() when k == 0, pl.draw() on later passes.
    # presumably k is the enclosing loop counter — verify.
    if k == 0:
        pl.show()
    else:
        pl.draw()

    # Reset both viewers so the next pass starts from empty grids.
    block_viewer.clear()
    main_viewer.clear()