def test_decompose_3d():
    """Tests computing a rank-1 approximation to a 3D filter.
    Note that this tests both ``filtertools.decompose()`` and
    ``filtertools.lowranksta()``.

    A separable filter of shape (time, nx, ny) is built from a Gaussian
    temporal kernel and a Gaussian spatial profile, corrupted with a small
    amount of noise, and decomposed. The recovered vectors are only defined
    up to scale, so they and the true profiles are min-max rescaled to
    [0, 1] before being compared.
    """
    np.random.seed(0)
    filter_length = 50
    nx, ny = 10, 10

    def gaussian(x, mu, sigma):
        # Gaussian-shaped bump; its normalizing constant is irrelevant here
        # because every profile is min-max rescaled below.
        return np.exp(-((x - mu) / sigma)**2) / np.sqrt(sigma * 2 * np.pi)

    temporal = gaussian(np.linspace(-3, 3, filter_length), -1, 1.5)
    spatial = gaussian(np.linspace(-3, 3, nx * ny), 0, 1.0).reshape(nx, ny)

    # Reshape into a genuine 3D (time, nx, ny) filter. np.outer alone yields
    # a 2D (time, nx * ny) array, which would never exercise the 3D code path
    # this test is named after.
    true_filter = np.outer(temporal, spatial.ravel()).reshape(
        filter_length, nx, ny)
    noise_std = 0.01 * (temporal.max() - temporal.min())
    true_filter += np.random.randn(*true_filter.shape) * noise_std

    s, t = flt.decompose(true_filter)

    # s/t are unit vectors, so rescale them and the inputs to [0, 1]
    s -= s.min()
    s /= s.max()
    t -= t.min()
    t /= t.max()
    temporal -= temporal.min()
    temporal /= temporal.max()
    spatial -= spatial.min()
    spatial /= spatial.max()

    tol = 0.1
    assert np.allclose(temporal, t, atol=tol)
    assert np.allclose(spatial.ravel(), s.ravel(), atol=tol)
Example #2
0
def test_decompose_3d():
    """Tests computing a rank-1 approximation to a 3D filter.
    Note that this tests both ``filtertools.decompose()`` and
    ``filtertools.lowranksta()``.

    A separable filter of shape (time, nx, ny) is built from a Gaussian
    temporal kernel and a Gaussian spatial profile, corrupted with a small
    amount of noise, and decomposed. The recovered vectors are only defined
    up to scale, so they and the true profiles are min-max rescaled to
    [0, 1] before being compared.
    """
    np.random.seed(0)
    filter_length = 50
    nx, ny = 10, 10

    def gaussian(x, mu, sigma):
        # Gaussian-shaped bump; its normalizing constant is irrelevant here
        # because every profile is min-max rescaled below.
        return np.exp(-((x - mu) / sigma)**2) / np.sqrt(sigma * 2 * np.pi)

    temporal = gaussian(np.linspace(-3, 3, filter_length), -1, 1.5)
    spatial = gaussian(np.linspace(-3, 3, nx * ny), 0, 1.0).reshape(nx, ny)

    # Reshape into a genuine 3D (time, nx, ny) filter. np.outer alone yields
    # a 2D (time, nx * ny) array, which would never exercise the 3D code path
    # this test is named after.
    true_filter = np.outer(temporal, spatial.ravel()).reshape(
        filter_length, nx, ny)
    noise_std = 0.01 * (temporal.max() - temporal.min())
    true_filter += np.random.randn(*true_filter.shape) * noise_std

    s, t = flt.decompose(true_filter)

    # s/t are unit vectors, so rescale them and the inputs to [0, 1]
    s -= s.min()
    s /= s.max()
    t -= t.min()
    t /= t.max()
    temporal -= temporal.min()
    temporal /= temporal.max()
    spatial -= spatial.min()
    spatial /= spatial.max()

    tol = 0.1
    assert np.allclose(temporal, t, atol=tol)
    assert np.allclose(spatial.ravel(), s.ravel(), atol=tol)
Example #3
0
def test_decompose():
    """Check that a noisy separable filter is recovered by a rank-1 fit.

    Exercises ``filtertools.decompose()`` and, through it,
    ``filtertools.lowranksta()``: the recovered spatial and temporal
    components must match the generating ones to within an absolute
    tolerance of 0.1.
    """
    np.random.seed(0)
    nx = ny = 10
    n_time = 50
    temporal, spatial, true_filter = utils.create_spatiotemporal_filter(
        nx, ny, n_time)

    # Corrupt the filter with small additive Gaussian noise (std 0.01).
    true_filter += 0.01 * np.random.randn(*true_filter.shape)

    spatial_hat, temporal_hat = flt.decompose(true_filter)

    atol = 0.1
    assert np.allclose(temporal, temporal_hat, atol=atol)
    assert np.allclose(spatial, spatial_hat, atol=atol)
Example #4
0
# Optimize a stimulus with COBYLA under `constraint`, evaluate it under both
# models, save the results to HDF5 and save an image of its spatial profile.
# NOTE(review): `model_separation`, `constraint`, `unit_norm_constraint`,
# `initial_guess` (for this first run), `ln_model`, `naturalscenes_model`,
# `mksavedir` and `join` are defined outside this view.
res_constrained = minimize(model_separation,
                           x0=initial_guess,
                           constraints=constraint,
                           method='COBYLA')

optimal_stimulus = res_constrained.x
constraint_violation = unit_norm_constraint(optimal_stimulus)

# The parameter vector is a 40-sample temporal kernel followed by a flattened
# 50 x 50 spatial profile; their outer product is a rank-1 stimulus.
temporal_kernel = optimal_stimulus[:40]
spatial_profile = optimal_stimulus[40:]
low_rank_optimal_stimulus = np.outer(temporal_kernel, spatial_profile)
optimal_stimulus = low_rank_optimal_stimulus.reshape((1, 40, 50, 50))

ln_response = ln_model.predict(optimal_stimulus)[0][0]
convnet_response = naturalscenes_model.predict(optimal_stimulus)[0][0]
responses = np.array([ln_response, convnet_response])

## SAVE RESULT ##
save_dir = mksavedir(prefix='Maximal Differentiated Stimuli')
# The context manager closes the HDF5 file even if one of the writes fails.
with h5py.File(join(save_dir, 'differentiated_stimuli.h5'), 'w') as f:
    f.create_dataset('stimulus', data=optimal_stimulus)
    f.create_dataset('responses', data=responses)
    f.create_dataset('constraint', data=constraint_violation)

spatial_profile, time = ft.decompose(optimal_stimulus[0])
fig_filename = 'differentiated_stimuli.png'
plt.imshow(spatial_profile, interpolation='nearest')
# Pass a real boolean: newer matplotlib no longer maps the string 'off' to False.
plt.grid(False)
plt.savefig(join(save_dir, fig_filename))

# Second run from a fresh random binary starting point of length
# 40 + 50*50, z-scored.
initial_guess = zscore(np.random.randint(0, 2, 40 + 50*50).astype('float32'))
res_constrained = minimize(model_separation, x0=initial_guess, constraints=constraint, method='COBYLA')

optimal_stimulus = res_constrained.x
constraint_violation = unit_norm_constraint(optimal_stimulus)

temporal_kernel = optimal_stimulus[:40]
spatial_profile = optimal_stimulus[40:]
low_rank_optimal_stimulus = np.outer(temporal_kernel, spatial_profile)
optimal_stimulus = low_rank_optimal_stimulus.reshape((1, 40, 50, 50))

ln_response = ln_model.predict(optimal_stimulus)[0][0]
convnet_response = naturalscenes_model.predict(optimal_stimulus)[0][0]
responses = np.array([ln_response, convnet_response])

## SAVE RESULT ##
save_dir = mksavedir(prefix='Maximal Differentiated Stimuli')
with h5py.File(join(save_dir, 'differentiated_stimuli.h5'), 'w') as f:
    f.create_dataset('stimulus', data=optimal_stimulus)
    f.create_dataset('responses', data=responses)
    f.create_dataset('constraint', data=constraint_violation)

spatial_profile, time = ft.decompose(optimal_stimulus[0])
fig_filename = 'differentiated_stimuli.png'
plt.imshow(spatial_profile, interpolation='nearest')
plt.grid(False)
plt.savefig(join(save_dir, fig_filename))


Example #6
0
def visualize_sta(sta,
                  fig_size=(8, 10),
                  display=True,
                  save=False,
                  normalize=True,
                  title='STA',
                  fig_dir=None):
    '''
    Visualize one or many STAs of deep-retina interunits.

    Computes the spatial and temporal profiles by SVD (via ``ft.decompose``)
    and draws, for each STA, the spatial profile as an image with the
    temporal profile as a line beneath it.

    INPUTS:
    sta             weight array of shape (time, space, space)
                        or (num_units, time, space, space)
    fig_size        figure size in inches
    display         bool; display figure?
    save            bool; save figure?
    normalize       bool; if True, every spatial image shares the symmetric
                    color limit [-max|sta|, +max|sta|]
    title           figure title; also the saved file base name
                    (``<title>.png``)
    fig_dir         directory to save the figure in (default: current
                    working directory)
    '''
    import os

    if fig_dir is None:
        fig_dir = os.getcwd()

    # A single STA is 3D (time, space, space); a stack of STAs has a leading
    # unit axis. Test the rank: len() would return the size of the first axis.
    single = sta.ndim == 3
    num_units = 1 if single else sta.shape[0]

    if normalize:
        max_val = np.max(np.abs(sta[:]))
        colorlimit = [-max_val, max_val]

    # plot space and time profiles together
    fig = plt.gcf()
    fig.set_size_inches(fig_size)
    plt.title(title, fontsize=20)
    num_cols = int(np.sqrt(num_units))
    # float() keeps true division under Python 2 so no unit is dropped
    num_rows = int(np.ceil(num_units / float(num_cols)))
    for x in range(num_cols):
        for y in range(num_rows):
            plt_idx = y * num_cols + x + 1
            # the grid can have more cells than units; leave extras empty
            if plt_idx > num_units:
                continue
            unit_sta = sta if single else sta[plt_idx - 1]
            spatial, temporal = ft.decompose(unit_sta)

            # each grid cell spans 4 subplot2grid rows: 3 for the spatial
            # image and 1 for the temporal trace
            ax = plt.subplot2grid((num_rows * 4, num_cols), (4 * y, x),
                                  rowspan=3)
            if normalize:
                ax.imshow(spatial,
                          interpolation='nearest',
                          cmap='seismic',
                          clim=colorlimit)
            else:
                ax.imshow(spatial, interpolation='nearest', cmap='seismic')
            # pass a real boolean: newer matplotlib no longer maps 'off'
            # to False
            ax.grid(False)
            ax.axis('off')

            ax = plt.subplot2grid((num_rows * 4, num_cols), (4 * y + 3, x),
                                  rowspan=1)
            # samples are plotted 10 units apart (presumably 10 ms bins --
            # confirm)
            ax.plot(np.linspace(0, len(temporal) * 10, len(temporal)),
                    temporal,
                    'k',
                    linewidth=2)
            ax.grid(False)
            ax.axis('off')

    # Save before showing, and close only afterwards, so that save=True does
    # not discard the figure before display=True can show it.
    if save:
        plt.savefig(os.path.join(fig_dir, title + '.png'), dpi=300)
    if display:
        plt.show()
    if save:
        plt.close()
Example #7
0
def visualize_convnet_weights(weights,
                              title='convnet',
                              layer_name='layer_0',
                              fig_dir=None,
                              fig_size=(8, 10),
                              dpi=300,
                              space=True,
                              time=True,
                              display=True,
                              save=False,
                              cmap='seismic',
                              normalize=True):
    '''
    Visualize convolutional spatiotemporal filters in a convolutional neural
    network.

    Computes the spatial and temporal profiles by SVD.

    INPUTS:
    weights         weight array of shape (num_filters, history, space, space)
                        or full path to an HDF5 weight file, in which case the
                        array is read from ``[layer_name]['param_0']``
    title           title of plots; also the saved plot file base name
    layer_name      group of the HDF5 weight file to read (only used when
                    ``weights`` is a path)
    fig_dir         where to save figures (default: current working directory)
    fig_size        figure size in inches
    dpi             resolution in dots per inch
    space           bool; if True plots the spatial profiles of weights
    time            bool; if True plots the temporal profiles of weights
                    NOTE: if space and time are both False, function returns
                    spatial and temporal profiles instead of plotting
    display         bool; display figure?
    save            bool; save figure?
    cmap            matplotlib colormap for the spatial profiles
    normalize       bool; if True, every spatial image shares the symmetric
                    color limit [-max|weights|, +max|weights|]

    OUTPUT:
    When space or time are true, outputs are plots saved to fig_dir.
    When neither space nor time are true, output is:
        spatial_profiles        list of spatial profiles of filters
        temporal_profiles       list of temporal profiles of filters
    '''

    if fig_dir is None:
        fig_dir = os.getcwd()

    # if user supplied path instead of array of weights, read the array into
    # memory so the HDF5 file can be closed immediately
    if isinstance(weights, str):
        with h5py.File(weights, 'r') as weight_file:
            weights = weight_file[layer_name]['param_0'][:]

    num_filters = weights.shape[0]

    # keyword arguments shared by every spatial imshow call; the color limit
    # is only applied when normalizing
    imshow_kwargs = {'interpolation': 'nearest', 'cmap': cmap}
    if normalize:
        max_val = np.max(np.abs(weights[:]))
        imshow_kwargs['clim'] = [-max_val, max_val]

    # don't plot anything, just return spatial and temporal profiles
    if not (space or time):
        spatial_profiles = []
        temporal_profiles = []
        for f in weights:
            spatial, temporal = ft.decompose(f)
            spatial_profiles.append(spatial)
            temporal_profiles.append(temporal)
        return spatial_profiles, temporal_profiles

    def _time_axis(temporal):
        # samples are plotted 10 units apart (presumably 10 ms bins -- confirm)
        return np.linspace(0, len(temporal) * 10, len(temporal))

    fig = plt.gcf()
    fig.set_size_inches(fig_size)
    plt.title(title, fontsize=20)

    # float() keeps true division under Python 2 so no filter is dropped
    if space and time:
        num_rows = int(np.sqrt(num_filters))
        num_cols = int(np.ceil(num_filters / float(num_rows)))
    else:
        num_cols = int(np.sqrt(num_filters))
        num_rows = int(np.ceil(num_filters / float(num_cols)))

    for x in range(num_cols):
        for y in range(num_rows):
            plt_idx = y * num_cols + x + 1
            # in case fewer weights than fit neatly in rows and cols
            if plt_idx > num_filters:
                continue
            spatial, temporal = ft.decompose(weights[plt_idx - 1])

            if space and time:
                # each grid cell spans 4 subplot2grid rows: 3 for the
                # spatial image and 1 for the temporal trace
                ax = plt.subplot2grid((num_rows * 4, num_cols), (4 * y, x),
                                      rowspan=3)
                ax.imshow(spatial, **imshow_kwargs)
                ax.set_title('Subunit %i' % plt_idx)
                ax.grid(False)
                ax.axis('off')

                ax = plt.subplot2grid((num_rows * 4, num_cols),
                                      (4 * y + 3, x),
                                      rowspan=1)
                ax.plot(_time_axis(temporal), temporal, 'k', linewidth=2)
            elif space:
                # plot just spatial profile
                ax = plt.subplot(num_rows, num_cols, plt_idx)
                ax.imshow(spatial, **imshow_kwargs)
            else:
                # plot just temporal profile
                ax = plt.subplot(num_rows, num_cols, plt_idx)
                ax.plot(_time_axis(temporal), temporal, 'k', linewidth=2)
            # pass a real boolean: newer matplotlib no longer maps 'off'
            # to False
            ax.grid(False)
            ax.axis('off')

    # Save before showing, and close only afterwards, so that save=True does
    # not discard the figure before display=True can show it.
    if save:
        plt.savefig(os.path.join(fig_dir,
                                 title + '_spatiotemporal_profiles.png'),
                    dpi=dpi)
    if display:
        plt.show()
    if save:
        plt.close()
Example #8
0
def plot_filters(weights, cmap='seismic', normalize=True):
    """Plots an array of spatiotemporal filters

    Each filter is decomposed with ``ft.decompose`` into a spatial frame,
    drawn as an image, and a temporal profile, drawn as a line beneath it.

    Parameters
    ----------
    weights : array_like
        Must have shape (num_conv_filters, num_temporal, num_spatial, num_spatial)

    cmap : str, optional
        A matplotlib colormap (Default: 'seismic')

    normalize : boolean, optional
        If True (default), use a color limit symmetric about zero set by the
        largest absolute value in the weights array. If False, use the
        minimum and maximum values of the weights array.

    Returns
    -------
    fig : a matplotlib figure handle
    """

    # create the figure
    fig = plt.figure(figsize=(12, 8))

    # number of convolutional filters
    num_filters = weights.shape[0]

    # get the number of rows and columns in the grid
    nrows, ncols = gridshape(num_filters, tol=2.0)

    # build the grid for all of the filters
    outer_grid = gridspec.GridSpec(nrows, ncols)

    # color limits shared by every spatial frame
    if normalize:
        max_val = np.max(np.abs(weights.ravel()))
        vmin, vmax = -max_val, max_val
    else:
        vmin = np.min(weights.ravel())
        vmax = np.max(weights.ravel())

    # loop over each convolutional filter
    for w, og in zip(weights, outer_grid):

        # get the spatial and temporal frame
        spatial, temporal = ft.decompose(w)

        # build the gridspec (spatial and temporal subplots) for this filter
        inner_grid = gridspec.GridSpecFromSubplotSpec(2,
                                                      1,
                                                      subplot_spec=og,
                                                      height_ratios=(4, 1),
                                                      hspace=0.0)

        # plot the spatial frame
        ax = fig.add_subplot(inner_grid[0])
        ax.imshow(spatial,
                  interpolation='nearest',
                  cmap=cmap,
                  vmin=vmin,
                  vmax=vmax)
        # pass a real boolean: newer matplotlib no longer maps 'off' to False
        ax.grid(False)
        ax.axis('off')

        # plot the temporal profile
        ax = fig.add_subplot(inner_grid[1])
        ax.plot(temporal, 'k', lw=2)
        ax.grid(False)
        ax.axis('off')

    # Draw before showing: with a blocking show() the figure may already be
    # closed when it returns.
    plt.draw()
    plt.show()
    return fig
def visualize_convnet_weights(weights, title='convnet', fig_dir=pwd,
        fig_size=(8,10), dpi=500, space=True, time=True, display=False,
        save=True):
    '''
    Visualize convolutional spatiotemporal filters in a convolutional neural
    network.

    Computes the spatial and temporal profiles by SVD.

    INPUTS:
    weights         weight array of shape (num_filters, history, space, space)
    title           title of plots; also the saved plot file base name
    fig_dir         where to save figures
    fig_size        figure size in inches
    dpi             resolution in dots per inch
    space           bool; if True plots the spatial profiles of weights
    time            bool; if True plots the temporal profiles of weights
                    NOTE: if space and time are both False, function returns
                    spatial and temporal profiles instead of plotting
    display         bool; display figure?
    save            bool; save figure?

    OUTPUT:
    When space or time are true, outputs are plots saved to fig_dir.
    When neither space nor time are true, output is:
        spatial_profiles        list of spatial profiles of filters
        temporal_profiles       list of temporal profiles of filters
    '''
    import os

    num_filters = weights.shape[0]

    # don't plot anything, just return spatial and temporal profiles
    if not (space or time):
        spatial_profiles = []
        temporal_profiles = []
        for f in weights:
            spatial, temporal = ft.decompose(f)
            spatial_profiles.append(spatial)
            temporal_profiles.append(temporal)
        return spatial_profiles, temporal_profiles

    fig = plt.gcf()
    fig.set_size_inches(fig_size)
    plt.title(title, fontsize=20)
    num_cols = int(np.sqrt(num_filters))
    # float() keeps true division under Python 2 so no filter is dropped
    num_rows = int(np.ceil(num_filters / float(num_cols)))

    for x in range(num_cols):
        for y in range(num_rows):
            plt_idx = y * num_cols + x + 1
            # the grid can have more cells than filters; leave extras empty
            if plt_idx > num_filters:
                continue
            spatial, temporal = ft.decompose(weights[plt_idx - 1])
            # samples are plotted 10 units apart; equals the previously
            # hard-coded np.linspace(0, 400, 40) when there are 40 samples
            t_axis = np.linspace(0, 10 * len(temporal), len(temporal))

            if space and time:
                # plot space and time profiles together: each grid cell
                # spans 4 subplot2grid rows, 3 for the image, 1 for the trace
                ax = plt.subplot2grid((num_rows*4, num_cols), (4*y, x), rowspan=3)
                ax.imshow(spatial, interpolation='nearest', cmap='gray')
                ax.grid(False)
                ax.axis('off')

                ax = plt.subplot2grid((num_rows*4, num_cols), (4*y+3, x), rowspan=1)
                ax.plot(t_axis, temporal, 'k', linewidth=2)
            elif space:
                # plot just spatial profile (pyplot calls so that colorbar()
                # finds the image just drawn)
                plt.subplot(num_rows, num_cols, plt_idx)
                plt.imshow(spatial, interpolation='nearest', cmap='gray')
                plt.colorbar()
            else:
                # plot just temporal profile
                plt.subplot(num_rows, num_cols, plt_idx)
                plt.plot(t_axis, temporal, 'k', linewidth=2)
            # pass a real boolean: newer matplotlib no longer maps 'off'
            # to False
            plt.grid(False)
            plt.axis('off')

    # Save before showing: a blocking show() may close the figure, which
    # would leave savefig writing an empty canvas.
    if save:
        plt.savefig(os.path.join(fig_dir, title + '_spatiotemporal_profiles.png'), dpi=dpi)
    if display:
        plt.show()