Example #1
0
    def on_monitor(self, model, dataset, algorithm):
        """
        Save a grid of dataset examples to an image file.

        Draws ``self.rows * self.cols`` topological examples from
        ``dataset``, optionally rescales them according to
        ``self.rescale`` and writes them to ``self.save_path`` formatted
        with the number of values recorded so far by the first channel of
        the model's monitor.

        Parameters
        ----------
        model : object
            Model being trained; ``model.monitor.channels`` is read to
            number the output file.
        dataset : object
            Must provide ``get_batch_topo`` and ``adjust_for_viewer``.
        algorithm : object
            Unused; part of the callback signature.

        Raises
        ------
        ValueError
            If ``self.rescale`` is not 'none', 'global' or 'individual', or
            if the examples have a channel count other than 1 or 3.
        """
        if self.rescale == 'none':
            global_rescale = False
            patch_rescale = False
        elif self.rescale == 'global':
            global_rescale = True
            patch_rescale = False
        elif self.rescale == 'individual':
            global_rescale = False
            patch_rescale = True
        else:
            # An explicit error survives `python -O`, unlike `assert False`.
            raise ValueError("rescale=%r, must be 'none', 'global', or "
                             "'individual'" % (self.rescale,))

        # implement saving from show_examples
        examples = dataset.get_batch_topo(self.rows * self.cols)
        examples = dataset.adjust_for_viewer(examples)

        if global_rescale:
            max_abs = np.abs(examples).max()
            # Divide out of place: the batch may be a view of the dataset's
            # own storage, which an in-place `/=` would overwrite. Skip an
            # all-zero batch to avoid producing NaNs.
            if max_abs > 0:
                examples = examples / max_abs

        num_channels = examples.shape[3]
        if num_channels == 1:
            is_color = False
        elif num_channels == 3:
            is_color = True
        else:
            raise ValueError('got unknown image format with %d channels; '
                             'supported formats are 1 channel greyscale or '
                             'three channel RGB' % num_channels)

        pv = patch_viewer.PatchViewer((self.rows, self.cols),
                                      examples.shape[1:3],
                                      is_color=is_color)
        for i in range(self.rows * self.cols):
            pv.add_patch(examples[i, :, :, :], activation=0.0,
                         rescale=patch_rescale)

        # Number the file by how many values the first monitor channel has
        # recorded. Iterating values() works on Python 2 and 3 alike,
        # unlike indexing keys()[0].
        first_channel = next(iter(model.monitor.channels.values()))
        pv.save(self.save_path % len(first_channel.val_record))
        print("dataset image saved")
Example #2
0
def show_facetube(x):
    """
    Display the frames of a face tube in a PatchViewer window.

    The frames are globally rescaled so that the largest absolute value
    maps to 1. At most 36 frames are shown on a 6x6 grid; longer inputs
    are subsampled at evenly spaced frame indices.

    Parameters
    ----------
    x : numpy.ndarray
        4-D batch indexed as ``x[frame, row, col, channel]`` with 1
        (greyscale) or 3 (RGB) channels. It is not modified.

    Raises
    ------
    ValueError
        If ``x`` holds no frames or has a channel count other than 1 or 3.
    """
    examples = np.asarray(x)
    num_frames = examples.shape[0]
    if num_frames == 0:
        raise ValueError('show_facetube needs at least one frame')

    num_channels = examples.shape[3]
    if num_channels == 1:
        is_color = False
    elif num_channels == 3:
        is_color = True
    else:
        raise ValueError('unsupported number of channels: %d' % num_channels)

    # Rescale into a new array so the caller's data is left untouched, and
    # skip an all-zero input to avoid producing NaNs.
    max_abs = np.abs(examples).max()
    if max_abs > 0:
        examples = examples / float(max_abs)

    max_side = 6
    rows = int(np.sqrt(num_frames))
    if rows >= max_side:
        # Take max_side ** 2 evenly spaced frames. Integer arithmetic yields
        # exactly that many indices (floor(k * num_frames / n_shown)),
        # whereas a float arange could produce one too many for the grid.
        rows = cols = max_side
        n_shown = max_side * max_side
        idx = (np.arange(n_shown) * num_frames) // n_shown
    else:
        cols = num_frames // rows + 1
        idx = np.arange(num_frames)

    pv = patch_viewer.PatchViewer((rows, cols),
                                  examples.shape[1:3],
                                  is_color=is_color)
    for i in idx:
        pv.add_patch(examples[i, :, :, :], activation=0.0, rescale=False)

    pv.show()
Example #3
0
    def on_monitor(self, model, dataset, algorithm):
        """
        Every ``self.freq``-th call, save a grid of dataset examples.

        A batch of ``self.rows * self.cols`` examples is drawn with
        ``dataset.get_batch_design``, brought into ('b', 0, 1, 'c') layout
        using ``dataset.transformer.input_space``, optionally rescaled
        according to ``self.rescale``, and saved to
        ``self.save_path % self.count``.

        Parameters
        ----------
        model : object
            Unused; part of the callback signature.
        dataset : object
            Must provide ``get_batch_design`` and
            ``transformer.input_space``.
        algorithm : object
            Unused; part of the callback signature.

        Raises
        ------
        ValueError
            If ``self.rescale`` is not 'none', 'global' or 'individual', or
            if the examples have a channel count other than 1 or 3.
        """
        self.count += 1
        if self.count % self.freq != 0:
            return None

        if self.rescale == 'none':
            global_rescale = False
            patch_rescale = False
        elif self.rescale == 'global':
            global_rescale = True
            patch_rescale = False
        elif self.rescale == 'individual':
            global_rescale = False
            patch_rescale = True
        else:
            raise ValueError("rescale=%r, must be 'none', 'global', or "
                             "'individual'" % (self.rescale,))

        # implement saving from show_examples
        examples = dataset.get_batch_design(self.rows * self.cols)
        input_space = dataset.transformer.input_space
        if input_space:
            if len(input_space.axes) != len(examples.shape):
                # Design matrix: reshape every row to the space's image
                # shape followed by its channel count.
                examples = examples.reshape(
                    [examples.shape[0]] + list(input_space.shape) +
                    [input_space.num_channels])
            else:
                # Already topological: reorder the axes into the
                # ('b', 0, 1, 'c') layout the indexing below relies on (the
                # destination was previously an undefined name `default`).
                # NOTE(review): assumes the batch is laid out in
                # input_space.axes order -- confirm against the dataset.
                examples = input_space.convert_numpy(
                    examples, input_space.axes, ('b', 0, 1, 'c'))

        if global_rescale:
            max_abs = np.abs(examples).max()
            # Divide out of place: the batch may be a view of the dataset's
            # own storage. Skip an all-zero batch to avoid NaNs.
            if max_abs > 0:
                examples = examples / max_abs

        num_channels = examples.shape[3]
        if num_channels == 1:
            is_color = False
        elif num_channels == 3:
            is_color = True
        else:
            raise ValueError('got unknown image format with %d channels; '
                             'supported formats are 1 channel greyscale or '
                             'three channel RGB' % num_channels)

        pv = patch_viewer.PatchViewer((self.rows, self.cols),
                                      examples.shape[1:3],
                                      is_color=is_color)
        for i in range(self.rows * self.cols):
            pv.add_patch(examples[i, :, :, :], activation=0.0,
                         rescale=patch_rescale)

        pv.save(self.save_path % self.count)

        print("dataset image saved")
Example #4
0
def show_random_examples():
    example_cnt = test.X.shape[0]
    pv = patch_viewer.PatchViewer(
        grid_shape=(1, test.clip_shape_ds[0]),
        patch_shape=[test.clip_shape_ds[1], test.clip_shape_ds[2]])
    while True:
        i = np.random.randint(example_cnt)
        print 'i =', i
        print 'rain_bits =', test.y[i]
        print 'pred_bits =', test.y_pred[i]
        track_frames = test.X[i].reshape((3, 12, 12))
        print 'track_frames[-1].sum() =', track_frames[-1].sum()
        for j in range(track_frames.shape[0]):
            pv.add_patch(track_frames[j].astype('float32'), activation=0)
        pv.show()
Example #5
0
def get_weights_report(model_path=None,
                       model=None,
                       rescale='individual',
                       border=False,
                       norm_sort=False,
                       dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str, optional
        Filepath of the model to make the report on. Exactly one of
        `model_path` and `model` must be given.
    model : object, optional
        Already-loaded model to make the report on. A dict is treated as a
        saved matlab dictionary whose 2-D array holds the weights.
    rescale : str
        A string specifying how to rescale the filter images:
            - 'individual' (default) : scale each filter so that it
                  uses as much as possible of the dynamic range
                  of the display under the constraint that 0
                  is gray and no value gets clipped
            - 'global' : scale the whole ensemble of weights
            - 'none' :   don't rescale
    border : bool, optional
        If True, patches are added with ``activation=0``, otherwise with
        ``activation=None``.
    norm_sort : bool, optional
        If True, filters are displayed in order of decreasing norm.
    dataset : pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    pv : PatchViewer
        Viewer with one patch per hidden unit (for a dict input, the viewer
        built by ``patch_viewer.make_viewer``).

    Raises
    ------
    ValueError
        If neither or both of `model_path` and `model` are given, if
        `rescale` is invalid, if a dict input holds no 2-D array, or if the
        model supports no weights interface.
    """

    if model is None:
        if model_path is None:
            raise ValueError('either model_path or model must be given')
        logger.info('making weights report')
        logger.info('loading model')
        model = serial.load(model_path)
        logger.info('loading done')
    elif model_path is not None:
        raise ValueError('model_path and model cannot both be given')
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")

    if isinstance(model, dict):
        # assume this was a saved matlab dictionary; drop its metadata
        # entries, tolerating their absence.
        for meta_key in ('__version__', '__header__', '__globals__'):
            model.pop(meta_key, None)
        keys = [key for key in model
                if hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if not keys:
            raise ValueError('no 2-D array found in the model dictionary')
        if len(keys) > 1:
            # Several candidate matrices: ask which one holds the weights.
            # (The old `> 2` test let exactly two candidates reach the
            # single-element unpacking below and crash.)
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for candidate in keys:
                    logger.info('\t{0}'.format(candidate))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            logger.info('loading dataset...')
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            logger.info('...done')

        try:
            W = model.get_weights()
        except AttributeError as e:
            reraise_as(AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
but look at the original exception to be sure.
If this is an older model class, it may have weights stored as weightsShared,
etc.
Original exception: """+str(e)))

    if W is None and weights_view is None:
        raise ValueError("model doesn't support any weights interfaces")

    norm_prop = None
    if weights_view is None:
        weights_format = model.get_weights_format()
        assert hasattr(weights_format, '__iter__')
        assert len(weights_format) == 2
        assert weights_format[0] in ['v', 'h']
        assert weights_format[1] in ['v', 'h']
        assert weights_format[0] != weights_format[1]

        # Put one hidden unit per row.
        if weights_format[0] == 'v':
            W = W.T
        h = W.shape[0]

        if norm_sort:
            norms = np.sqrt(1e-8 + np.square(W).sum(axis=1))
            norm_prop = norms / norms.max()

        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h
    elif norm_sort:
        # Topological weights: each filter's norm is taken over all of its
        # elements (norm_prop was previously undefined on this path).
        flat = weights_view.reshape(h, -1)
        norms = np.sqrt(1e-8 + np.square(flat).sum(axis=1))
        norm_prop = norms / norms.max()

    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=weights_view.shape[-1] == 3)

    if global_rescale:
        max_abs = np.abs(weights_view).max()
        # Out of place so arrays owned elsewhere are not modified; skip
        # all-zero weights to avoid NaNs.
        if max_abs > 0:
            weights_view = weights_view / max_abs

    if norm_sort:
        logger.info('sorting weights by decreasing norm')
        idx = sorted(range(h), key=lambda k: -norm_prop[k])
    else:
        idx = range(h)

    act = 0 if border else None

    for i in idx:
        pv.add_patch(weights_view[i, ...], rescale=patch_rescale,
                     activation=act)

    abs_weights = np.abs(weights_view)
    logger.info('smallest enc weight magnitude: {0}'.format(abs_weights.min()))
    logger.info('mean enc weight magnitude: {0}'.format(abs_weights.mean()))
    logger.info('max enc weight magnitude: {0}'.format(abs_weights.max()))

    if W is not None:
        norms = np.sqrt(np.square(W).sum(axis=1))
        assert norms.shape == (h,)
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

    return pv
Example #6
0
    print 'this dataset has ' + str(len(examples.shape) -
                                    2) + ' topological dimensions'
    quit(-1)
#

# Choose the PatchViewer colour mode from the size of the last axis of
# `examples` (defined earlier in the script, outside this view).
if examples.shape[3] == 1:
    is_color = False
elif examples.shape[3] == 3:
    is_color = True
else:
    # Unsupported channel count: report it and exit with status -1.
    print 'got unknown image format with ' + str(
        examples.shape[3]) + ' channels'
    print 'supported formats are 1 channel greyscale or three channel RGB'
    quit(-1)
#

# Print the per-example image shape (axes 1 and 2), which is also used as
# the patch shape below.
print examples.shape[1:3]

# One grid cell per example; `rows`, `cols`, `patch_rescale` and `out` are
# assumed to be set earlier in the script (not visible here).
pv = patch_viewer.PatchViewer((rows, cols),
                              examples.shape[1:3],
                              is_color=is_color)

for i in xrange(rows * cols):
    pv.add_patch(examples[i, :, :, :], activation=0.0, rescale=patch_rescale)
#

# Display interactively when no output path was given; otherwise save.
if out is None:
    pv.show()
else:
    pv.save(out)
Example #7
0
def show_examples(path, rows, cols, rescale='global', out=None):
    """
    .. todo::

        WRITEME

    Parameters
    ----------
    path : string
        The pickle or YAML file to show examples of
    rows : int
        WRITEME
    cols : int
        WRITEME
    rescale : {'rescale', 'global', 'individual'}
        Default is 'rescale', WRITEME
    out : string, optional
        WRITEME
    """

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True

    if path.endswith('.pkl'):
        from pylearn2.utils import serial
        obj = serial.load(path)
    elif path.endswith('.yaml'):
        print 'Building dataset from yaml...'
        obj = yaml_parse.load_path(path)
        print '...done'
    else:
        obj = yaml_parse.load(path)

    if hasattr(obj, 'get_batch_topo'):
        # obj is a Dataset
        dataset = obj

        examples = dataset.get_batch_topo(rows * cols)
    else:
        # obj is a Model
        model = obj
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        theano_rng = RandomStreams(42)
        design_examples_var = model.random_design_matrix(batch_size=rows *
                                                         cols,
                                                         theano_rng=theano_rng)
        from theano import function
        print 'compiling sampling function'
        f = function([], design_examples_var)
        print 'sampling'
        design_examples = f()
        print 'loading dataset'
        dataset = yaml_parse.load(model.dataset_yaml_src)
        examples = dataset.get_topological_view(design_examples)

    norms = np.asarray([
        np.sqrt(np.sum(np.square(examples[i, :])))
        for i in xrange(examples.shape[0])
    ])
    print 'norms of examples: '
    print '\tmin: ', norms.min()
    print '\tmean: ', norms.mean()
    print '\tmax: ', norms.max()

    print 'range of elements of examples', (examples.min(), examples.max())
    print 'dtype: ', examples.dtype

    examples = dataset.adjust_for_viewer(examples)

    if global_rescale:
        examples /= np.abs(examples).max()

    if len(examples.shape) != 4:
        print 'sorry, view_examples.py only supports image examples for now.'
        print 'this dataset has ' + str(len(examples.shape) - 2),
        print 'topological dimensions'
        quit(-1)

    if examples.shape[3] == 1:
        is_color = False
    elif examples.shape[3] == 3:
        is_color = True
    else:
        print 'got unknown image format with ' + str(examples.shape[3]),
        print 'channels'
        print 'supported formats are 1 channel greyscale or three channel RGB'
        quit(-1)

    print examples.shape[1:3]

    pv = patch_viewer.PatchViewer((rows, cols),
                                  examples.shape[1:3],
                                  is_color=is_color)

    for i in xrange(rows * cols):
        pv.add_patch(examples[i, :, :, :],
                     activation=0.0,
                     rescale=patch_rescale)

    if out is None:
        pv.show()
    else:
        pv.save(out)
def main(options, positional_args):
    """
    .. todo::

        WRITEME
    """
    assert len(positional_args) == 1

    path, = positional_args

    out = options.out
    rescale = options.rescale

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        assert False

    if path.endswith('.pkl'):
        from pylearn2.utils import serial
        obj = serial.load(path)
    elif path.endswith('.yaml'):
        print 'Building dataset from yaml...'
        obj = yaml_parse.load_path(path)
        print '...done'
    else:
        obj = yaml_parse.load(path)

    rows = options.rows
    cols = options.cols

    if hasattr(obj, 'get_batch_topo'):
        #obj is a Dataset
        dataset = obj

        examples = dataset.get_batch_topo(rows * cols)
    else:
        #obj is a Model
        model = obj
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        theano_rng = RandomStreams(42)
        design_examples_var = model.random_design_matrix(batch_size=rows *
                                                         cols,
                                                         theano_rng=theano_rng)
        from theano import function
        print 'compiling sampling function'
        f = function([], design_examples_var)
        print 'sampling'
        design_examples = f()
        print 'loading dataset'
        dataset = yaml_parse.load(model.dataset_yaml_src)
        examples = dataset.get_topological_view(design_examples)

    norms = N.asarray([
        N.sqrt(N.sum(N.square(examples[i, :])))
        for i in xrange(examples.shape[0])
    ])
    print 'norms of examples: '
    print '\tmin: ', norms.min()
    print '\tmean: ', norms.mean()
    print '\tmax: ', norms.max()

    print 'range of elements of examples', (examples.min(), examples.max())
    print 'dtype: ', examples.dtype

    examples = dataset.adjust_for_viewer(examples)

    if global_rescale:
        examples /= N.abs(examples).max()

    if len(examples.shape) != 4:
        print 'sorry, view_examples.py only supports image examples for now.'
        print 'this dataset has ' + str(len(examples.shape) -
                                        2) + ' topological dimensions'
        quit(-1)

    is_color = False
    assert examples.shape[3] == 2

    print examples.shape[1:3]

    pv = patch_viewer.PatchViewer((rows, cols * 2),
                                  examples.shape[1:3],
                                  is_color=is_color)

    for i in xrange(rows * cols):
        # Load patches in backwards order for easier cross-eyed viewing
        # (Ian can't do the magic eye thing where you focus your eyes past the screen, must
        # focus eyes in front of screen)
        pv.add_patch(examples[i, :, :, 1],
                     activation=0.0,
                     rescale=patch_rescale)
        pv.add_patch(examples[i, :, :, 0],
                     activation=0.0,
                     rescale=patch_rescale)

    if out is None:
        pv.show()
    else:
        pv.save(out)
Example #9
0
        if norm_sort:
            norms = np.sqrt(1e-8 + np.square(W).sum(axis=1))
            norm_prop = norms / norms.max()

        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h
    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=weights_view.shape[-1] == 3)

    if global_rescale:
        weights_view /= np.abs(weights_view).max()

    if norm_sort:
        print 'sorting weights by decreasing norm'
        idx = sorted(range(h), key=lambda l: -norm_prop[l])
    else:
        idx = range(h)

    if border:
        act = 0
    else:
        act = None
Example #10
0
def main(options, positional_args):
    assert len(positional_args) == 1

    path ,= positional_args

    out = options.out
    rescale = options.rescale

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        assert False

    if path.endswith('.pkl'):
        from pylearn2.utils import serial
        obj = serial.load(path)
    elif path.endswith('.yaml'):
        print 'Building dataset from yaml...'
        obj =yaml_parse.load_path(path)
        print '...done'
    else:
        obj = yaml_parse.load(path)

    rows = options.rows
    cols = options.cols

    if hasattr(obj,'get_batch_topo'):
        #obj is a Dataset
        dataset = obj

        examples = dataset.get_batch_topo(rows*cols)
    else:
        #obj is a Model
        model = obj
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        theano_rng = RandomStreams(42)
        design_examples_var = model.random_design_matrix(batch_size = rows * cols, theano_rng = theano_rng)
        from theano import function
        print 'compiling sampling function'
        f = function([],design_examples_var)
        print 'sampling'
        design_examples = f()
        print 'loading dataset'
        dataset = yaml_parse.load(model.dataset_yaml_src)
        examples = dataset.get_topological_view(design_examples)

    norms = N.asarray( [
            N.sqrt(N.sum(N.square(examples[i,:])))
                        for i in xrange(examples.shape[0])
                        ])
    print 'norms of examples: '
    print '\tmin: ',norms.min()
    print '\tmean: ',norms.mean()
    print '\tmax: ',norms.max()

    print 'range of elements of examples',(examples.min(),examples.max())
    print 'dtype: ', examples.dtype

    examples = dataset.adjust_for_viewer(examples)

    if global_rescale:
        examples /= N.abs(examples).max()

    if len(examples.shape) != 4:
        print 'sorry, view_examples.py only supports image examples for now.'
        print 'this dataset has '+str(len(examples.shape)-2)+' topological dimensions'
        quit(-1)
    #

    if examples.shape[3] == 1:
        is_color = False
    elif examples.shape[3] == 3:
        is_color = True
    else:
        print 'got unknown image format with '+str(examples.shape[3])+' channels'
        print 'supported formats are 1 channel greyscale or three channel RGB'
        quit(-1)
    #

    print examples.shape[1:3]

    pv = patch_viewer.PatchViewer( (rows, cols), examples.shape[1:3], is_color = is_color)

    for i in xrange(rows*cols):
        pv.add_patch(examples[i,:,:,:], activation = 0.0, rescale = patch_rescale)
    #

    if out is None:
        pv.show()
    else:
        pv.save(out)
Example #11
0
                        num_pieces = 1,
                        kernel_shape = (4, 4),
                        pool_shape = (1, 1),
                        pool_stride=(1, 1),
                        irange = 0.05)
deconv = Deconv(layer_name='deconv',
                num_channels=1,
                kernel_shape=(4, 4),
                irange=0.05)

mlp = MLP(input_space=input_space,
          layers=[conv, deconv])

# Copy the convolution layer's filter values into the deconvolution layer.
mlp.layers[1].transformer._filters.set_value(
    mlp.layers[0].transformer._filters.get_value())

x = input_space.get_theano_batch()
out = mlp.fprop(x)
f = theano.function([x], out)

data = MNIST('test')
data_specs = (input_space, 'features')
# Named batch_iter rather than `iter` so the builtin is not shadowed at
# module level.
batch_iter = data.iterator(mode='sequential', batch_size=2,
                           data_specs=data_specs)
pv = patch_viewer.PatchViewer((10, 10), (28, 28))
# Show channel 0 of the first example of the first batch next to the
# network's output for it, then stop.
for item in batch_iter:
    res = f(item)
    pv.add_patch(item[0, :, :, 0])
    pv.add_patch(res[0, :, :, 0])
    pv.show()
    break

Example #12
0
def get_weights_report(model_path, rescale = True):
    print 'making weights report'
    print 'loading model'
    p = serial.load(model_path)
    print 'loading done'

    dataset = yaml_parse.load(p.dataset_yaml_src)

    if hasattr(p,'get_weights'):
        p.weights = p.get_weights()

    if 'weightsShared' in dir(p):
        p.weights = p.weightsShared.get_value()

    if 'W' in dir(p):
        if hasattr(p.W,'__array__'):
            warnings.warn('model.W is an ndarray; I can figure out how to display this but that seems like a sign of a bad bug')
            p.weights = p.W
            from theano import shared
            p.W = shared(p.W)
        else:
            p.weights = p.W.get_value()

    if 'D' in dir(p):
        p.decWeightsShared = p.D

    if 'enc_weights_shared' in dir(p):
        p.weights = p.enc_weights_shared.get_value()


    if 'W' in dir(p) and len(p.W.get_value().shape) == 3:
        W = p.W.get_value()
        nh , nv, ns = W.shape
        pv = patch_viewer.PatchViewer(grid_shape=(nh,ns), patch_shape= dataset.view_shape()[0:2])

        for i in range(0,nh):
            for k in range(0,ns):
                patch = W[i,:,k]
                patch = dataset.vec_to_view(patch, weights = True)
                pv.add_patch( patch, rescale = rescale)
            #
        #
    elif len(p.weights.shape) == 2:
        if hasattr(p,'get_weights_format'):
            p.weights_format = p.get_weights_format

        assert type(p.weights_format()) == type([])
        assert len(p.weights_format()) == 2
        assert p.weights_format()[0] in ['v','h']
        assert p.weights_format()[1] in ['v','h']
        assert p.weights_format()[0] != p.weights_format()[1]

        if p.weights_format()[0] == 'v':
            p.weights = p.weights.transpose()
        h = p.weights.shape[0]


        hr = int(N.ceil(N.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(p):
            hr, hc = p.hidShape

        pv = patch_viewer.PatchViewer(grid_shape=(hr,hc), patch_shape=dataset.view_shape()[0:2],
                is_color = dataset.view_shape()[2] == 3)
        weights_mat = p.weights

        assert weights_mat.shape[0] == h
        weights_view = dataset.get_weights_view(weights_mat)
        assert weights_view.shape[0] == h
        #print 'weights_view shape '+str(weights_view.shape)
        for i in range(0,h):
            patch = weights_view[i,...]
            pv.add_patch( patch, rescale   = rescale)
    else:
        e = p.weights
        d = p.dec_weights_shared.value

        h = e.shape[0]

        if len(e.shape) == 8:
            raise Exception("get_weights_report doesn't support tiled convolution yet, use the show_weights8 app")

        if e.shape[4] != 1:
            raise Exception('weights shape: '+str(e.shape))
        shape = e.shape[1:3]
        dur = e.shape[3]

        show_dec = id(e) != id(d)

        pv = PatchViewer.PatchViewer( grid_shape = ((1+show_dec)*h,dur), patch_shape=shape)
        for i in range(0,h):
            pv.addVid( e[i,:,:,:,0], rescale = rescale)
            if show_dec:
                pv.addVid( d[i,:,:,:,0], rescale = rescale)

    print 'smallest enc weight magnitude: '+str(N.abs(p.weights).min())
    print 'mean enc weight magnitude: '+str(N.abs(p.weights).mean())
    print 'max enc weight magnitude: '+str(N.abs(p.weights).max())


    return pv
        weights_view = dataset.get_weights_view(W)
        assert weights_view.shape[0] == h
    try:
        hr, hc = model.get_weights_view_shape()
    except NotImplementedError:
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

    if (weights_view.shape[-1] == 1 and len(channels) != 1):
        channels = [0]

    pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                  patch_shape=weights_view.shape[1:3],
                                  is_color=len(channels) == 3)

    if global_rescale:
        weights_view /= np.abs(weights_view).max()

    if norm_sort:
        logger.info('sorting weights by decreasing norm')
        idx = sorted(range(h), key=lambda l: -norm_prop[l])
    else:
        idx = range(h)

    if border:
        act = 0
    else:
        act = None
Example #14
0
import sys
from pylearn2.utils import serial
from pylearn2.gui import patch_viewer

# Usage: <script> <model_path> <height> <width>
# Shows every 2-D slice W[j, :, :, i] of the first layer's filter tensor on
# a height x width grid.
if len(sys.argv) < 4:
    # sys.exit with a string prints it to stderr and exits with status 1.
    sys.exit('usage: %s <model_path> <height> <width>' % sys.argv[0])

model_path = sys.argv[1]
height = int(sys.argv[2])
width = int(sys.argv[3])

model = serial.load(model_path)
# The first layer's transformer must have exactly one parameter.
W, = model.layers[0].transformer.get_params()
W = W.get_value()

# One grid cell per (axis-3 index, axis-0 index) pair. An explicit check
# replaces the bare assert, which `python -O` would strip.
num_patches = W.shape[0] * W.shape[3]
if height * width != num_patches:
    sys.exit('grid %dx%d has %d cells but there are %d patches to show'
             % (height, width, height * width, num_patches))

pv = patch_viewer.PatchViewer(grid_shape=(height, width),
                              patch_shape=[W.shape[1], W.shape[2]])
for i in range(W.shape[3]):
    for j in range(W.shape[0]):
        pv.add_patch(W[j, :, :, i])
pv.show()
Example #15
0
            h = W.shape[0]

            if norm_sort:
                norms = np.sqrt(1e-8 + np.square(W).sum(axis=1))
                norm_prop = norms / norms.max()

            weights_view = dataset.get_weights_view(W)
            assert weights_view.shape[0] == h
        #print 'weights_view shape '+str(weights_view.shape)
        hr = int(np.ceil(np.sqrt(h)))
        hc = hr
        if 'hidShape' in dir(model):
            hr, hc = model.hidShape

        pv = patch_viewer.PatchViewer(grid_shape=(hr, hc),
                                      patch_shape=weights_view.shape[1:3],
                                      is_color=weights_view.shape[-1] == 3)

        if global_rescale:
            weights_view /= np.abs(weights_view).max()

        if norm_sort:
            print 'sorting weights by decreasing norm'
            idx = sorted(range(h), key=lambda l: -norm_prop[l])
        else:
            idx = range(h)

        if border:
            act = 0
        else:
            act = None