Esempio n. 1
0
def test_lgnonoffcell():
    """Evaluate a single LGNOnOffCell on a full-field flash movie.

    The ON and OFF subfields share one temporal filter and differ only in
    the width of their Gaussian spatial filter and the sign of the amplitude.
    The downsampled time stamps and firing rates are compared against
    reference values.
    """
    flash_movie = movie.FullFieldFlashMovie(range(120), range(240), 0.3,
                                            0.7).full(t_max=2.0)

    shared_temporal = TemporalFilterCosineBump(weights=[3.441, -2.115],
                                               kpeaks=[8.269, 19.991],
                                               delays=[0.0, 0.0])

    # ON subfield: narrow Gaussian, positive amplitude.
    on_filter = SpatioTemporalFilter(
        GaussianSpatialFilter(sigma=(1.85, 1.85),
                              origin=(0.0, 0.0),
                              translate=(120.0, 60.0)),
        shared_temporal,
        amplitude=20)

    # OFF subfield: wider Gaussian at the same location, negative amplitude.
    off_filter = SpatioTemporalFilter(
        GaussianSpatialFilter(sigma=(3.85, 3.85),
                              origin=(0.0, 0.0),
                              translate=(120.0, 60.0)),
        shared_temporal,
        amplitude=-20)

    model = LGNModel([LGNOnOffCell(on_filter, off_filter)])
    results = model.evaluate(flash_movie, downsample=10)
    assert len(results) == 1

    # The single result holds the time stamps at [0] and the rates at [1].
    first_cell = results[0]
    times = np.array(first_cell[0], dtype=np.float64)
    rates = np.array(first_cell[1], dtype=np.float64)

    expected_times = [0.0, 0.41666, 0.83333, 1.250, 1.6666]
    expected_rates = [0.0, 3.7286, 0.0, 0.0, 0.0]
    assert np.allclose(times, expected_times, atol=1.0e-4)
    assert np.allclose(rates, expected_rates, atol=1.0e-3)
Esempio n. 2
0
def plot_sfilter_params():
    """Show Gaussian spatial-filter kernels for a grid of parameters.

    Draws a 2x2 figure: each row uses one ``sigma`` pair and each column one
    ``rotation`` value. The top row carries the rotation as its title and
    the left column carries the sigma as its y-label. The figure is displayed
    with ``plt.show()``.
    """
    grating = movie.GratingMovie(200, 200).create_movie(t_max=2.0)

    sigma_choices = [(30.0, 20.0), (20.0, 30.0)]
    rotation_choices = [0.0, 45.0]
    fig, axes = plt.subplots(2, 2, figsize=(7, 7))

    for row_idx, sigma in enumerate(sigma_choices):
        for col_idx, rot in enumerate(rotation_choices):
            ax = axes[row_idx, col_idx]

            spatial_filter = GaussianSpatialFilter(translate=(0, 0),
                                                   sigma=sigma,
                                                   rotation=rot)
            kernel = spatial_filter.get_kernel(grating.row_range,
                                               grating.col_range)
            ax.imshow(kernel.full(),
                      extent=(0, 200, 0, 200),
                      origin='lower')

            # Label only the outer edge of the grid.
            if row_idx == 0:
                ax.title.set_text('spatial_rotation={}'.format(rot))
            if col_idx == 0:
                ax.set_ylabel('spatial_size={}'.format(sigma))

    plt.show()
Esempio n. 3
0
def test_spatialfilter_kernel():
    """Check type, dense shape and sum of a GaussianSpatialFilter kernel."""
    n_frames, n_rows, n_cols = 1001, 120, 240
    blank_movie = movie.Movie(
        np.zeros((n_frames, n_rows, n_cols)),
        t_range=np.linspace(0.0, 1.0, n_frames, endpoint=True))

    spatial_filter = GaussianSpatialFilter(translate=(-80, -20),
                                           sigma=(30, 10),
                                           rotation=15.0)
    kernel = spatial_filter.get_kernel(row_range=blank_movie.row_range,
                                       col_range=blank_movie.col_range)

    assert isinstance(kernel, Kernel2D)
    # The dense kernel spans the movie's spatial extent and sums to one.
    assert kernel.full().shape == (n_rows, n_cols)
    assert np.isclose(np.sum(kernel.full()), 1.0)
Esempio n. 4
0
def test_offunit():
    """Evaluate a single OffUnit on a full-field flash movie.

    The unit combines a Gaussian spatial filter, a cosine-bump temporal
    filter (amplitude -1.0) and a rectifying scalar transfer function; the
    downsampled time stamps and rates are compared against reference values.
    """
    flash_movie = movie.FullFieldFlashMovie(range(120), range(240), 0.3,
                                            0.7).full(t_max=2.0)

    gaussian = GaussianSpatialFilter(translate=(120.0, 60.0),
                                     sigma=(0.615, 0.615),
                                     origin=(0.0, 0.0))
    cosine_bump = TemporalFilterCosineBump(weights=[3.441, -2.115],
                                           kpeaks=[8.269, 19.991],
                                           delays=[0.0, 0.0])
    off_filter = SpatioTemporalFilter(gaussian, cosine_bump, amplitude=-1.0)
    rectifier = ScalarTransferFunction('Heaviside(s+1.05)*(s+1.05)')

    model = LGNModel([OffUnit(off_filter, rectifier)])
    results = model.evaluate(flash_movie, downsample=10)
    assert len(results) == 1

    # The single result holds the time stamps at [0] and the rates at [1].
    first_cell = results[0]
    times = np.array(first_cell[0], dtype=np.float64)
    rates = np.array(first_cell[1], dtype=np.float64)

    expected_times = [0.0, 0.41666, 0.83333, 1.250, 1.6666]
    expected_rates = [1.05, 1.2364, 1.05, 1.05, 1.05]
    assert np.allclose(times, expected_times, atol=1.0e-4)
    assert np.allclose(rates, expected_rates, atol=1.0e-3)
Esempio n. 5
0
def test_spatiotemporalfilter_kernel():
    """Check type, dense shape and normalisation of a SpatioTemporalFilter kernel."""
    n_frames, n_rows, n_cols = 1001, 120, 240
    blank_movie = movie.Movie(
        np.zeros((n_frames, n_rows, n_cols)),
        t_range=np.linspace(0.0, 1.0, n_frames, endpoint=True))

    temporal = TemporalFilterCosineBump(weights=[33.328, -2.10059],
                                        kpeaks=[59.0, 120.0],
                                        delays=[0.0, 0.0])
    spatial = GaussianSpatialFilter(translate=(-80, -20),
                                    sigma=(30, 10),
                                    rotation=15.0)
    combined = SpatioTemporalFilter(spatial, temporal)

    kernel = combined.get_spatiotemporal_kernel(
        row_range=blank_movie.row_range,
        col_range=blank_movie.col_range,
        t_range=blank_movie.t_range)

    assert isinstance(kernel, Kernel3D)
    # The temporal extent (987) is shorter than the movie's 1001 frames.
    assert kernel.full().shape == (987, n_rows, n_cols)

    # After normalisation the dense kernel must sum to one.
    kernel.normalize()
    assert np.isclose(np.sum(kernel.full()), 1.0)
Esempio n. 6
0
def test_twosubfieldlinearcell():
    """Evaluate a single TwoSubfieldLinearCell on a full-field flash movie.

    Two spatio-temporal filters (sustained ON and sustained OFF) share one
    Gaussian spatial filter; the OFF filter is passed first to the cell
    constructor. The downsampled time stamps and rates are compared against
    reference values.
    """
    flash_movie = movie.FullFieldFlashMovie(range(120), range(240), 0.3,
                                            0.7).full(t_max=2.0)

    shared_spatial = GaussianSpatialFilter(translate=(120.0, 60.0),
                                           sigma=(0.615, 0.615),
                                           origin=(0.0, 0.0))

    # Temporal filters take (weights, kpeaks, delays) positionally.
    son_temporal = TemporalFilterCosineBump(
        [2.696143077048376, -1.8923936798453962],
        [37.993506826528716, 71.40822128514205],
        [42.0, 71.90456690180808])
    soff_temporal = TemporalFilterCosineBump(
        [3.7309552296292257, -1.4209858354384888],
        [21.556972532016253, 51.56392683711558],
        [61.0, 74.85742945288372])

    son_filter = SpatioTemporalFilter(shared_spatial,
                                      son_temporal,
                                      amplitude=1.0)
    soff_filter = SpatioTemporalFilter(shared_spatial,
                                       soff_temporal,
                                       amplitude=-1.51426850536)

    combine_subfields = MultiTransferFunction(
        (symbolic_x, symbolic_y),
        'Heaviside(x+2.0)*(x+2.0)+Heaviside(y+2.0)*(y+2.0)')

    cell = TwoSubfieldLinearCell(
        soff_filter,
        son_filter,
        subfield_separation=6.64946870229,
        onoff_axis_angle=249.09534316916634,
        dominant_subfield_location=(23.194207541958235, 49.44758663758982),
        transfer_function=combine_subfields)

    results = LGNModel([cell]).evaluate(flash_movie, downsample=10)
    assert len(results) == 1

    # The single result holds the time stamps at [0] and the rates at [1].
    first_cell = results[0]
    times = np.array(first_cell[0], dtype=np.float64)
    rates = np.array(first_cell[1], dtype=np.float64)

    expected_times = [0.0, 0.41666, 0.83333, 1.250, 1.6666]
    expected_rates = [4.0, 3.26931, 3.885, 4.0, 4.0]
    assert np.allclose(times, expected_times, atol=1.0e-4)
    assert np.allclose(rates, expected_rates, atol=1.0e-3)
Esempio n. 7
0
    def _make_kernels(self, cell_type, num_cells):
        """Build spatiotemporal kernels for randomly sampled cells of one type.

        For each of ``num_cells`` cells, one parameter row whose ``model_id``
        equals ``cell_type`` is drawn uniformly at random (with replacement)
        and converted into a kernel. Each kernel is strided along its first
        axis and replicated three times along a new leading axis before being
        stored (presumably one copy per input channel -- TODO confirm).

        Args:
            cell_type: Value matched against the ``model_id`` column of the
                parameter table. ``'sONsOFF_001'`` and ``'sONtOFF_001'`` are
                built as two-subfield (dominant + non-dominant) cells.
            num_cells: Number of kernels to generate.

        Returns:
            dict: Empty when ``num_cells`` is 0. Otherwise one entry, keyed
            ``'dom_nondom'`` for the two-subfield cell types and ``'dom'``
            for every other type, holding a tensor allocated with shape
            ``(num_cells, 3, *self.kernel_size)``.

        Raises:
            ValueError: If kernels are requested but no parameter row matches
                ``cell_type``.
        """
        # Only the first table returned by the loader is used here.
        param_table, _ = self._load_param_values()

        # Evaluate the row selector once instead of once per column.
        type_mask = param_table['model_id'] == cell_type

        def column(name):
            """Return the entries of column `name` belonging to `cell_type`."""
            return param_table[name][type_mask]

        spatial_sizes = column('spatial_size')
        kpeaks_dom_0 = column('kpeaks_dom_0')
        kpeaks_dom_1 = column('kpeaks_dom_1')
        weight_dom_0 = column('weight_dom_0')
        weight_dom_1 = column('weight_dom_1')
        delay_dom_0 = column('delay_dom_0')
        delay_dom_1 = column('delay_dom_1')
        kpeaks_non_dom_0 = column('kpeaks_non_dom_0')
        kpeaks_non_dom_1 = column('kpeaks_non_dom_1')
        weight_non_dom_0 = column('weight_non_dom_0')
        weight_non_dom_1 = column('weight_non_dom_1')
        delay_non_dom_0 = column('delay_non_dom_0')
        delay_non_dom_1 = column('delay_non_dom_1')
        sf_seps = column('sf_sep')
        angles = column('tuning_angle')

        is_on_off = cell_type in ('sONsOFF_001', 'sONtOFF_001')

        # NOTE (from the original author): this needs to be corrected for
        # sONsOFF/sONtOFF cells.
        if is_on_off:
            amplitude = 1.0
            nondom_amplitude = -1.0
        elif 'sOFF' in cell_type or 'tOFF' in cell_type:
            amplitude = -1.0
        else:
            amplitude = 1.0

        # Sample among the labels that actually exist for this cell type.
        # Drawing an integer in [min(label), max(label)) -- as was done
        # before -- never selects the last row (torch.randint's upper bound
        # is exclusive), fails when only one row matches, and can yield a
        # label that is absent when the labels are not contiguous.
        row_labels = list(kpeaks_dom_0.keys())
        if num_cells > 0 and not row_labels:
            raise ValueError(
                'no parameter rows found for cell type {!r}'.format(cell_type))

        # Only one of the two outputs is ever filled, so allocate just that.
        output_key = 'dom_nondom' if is_on_off else 'dom'
        kernel_data = torch.empty((num_cells, 3, *self.kernel_size))

        row_range = range(0, self.kernel_size[1])
        col_range = range(0, self.kernel_size[2])

        for cell_idx in range(num_cells):
            pick = int(torch.randint(low=0, high=len(row_labels), size=(1,)))
            row = row_labels[pick]

            dom_temporal = TemporalFilterCosineBump(
                weights=(weight_dom_0[row], weight_dom_1[row]),
                kpeaks=(kpeaks_dom_0[row], kpeaks_dom_1[row]),
                delays=(delay_dom_0[row], delay_dom_1[row]))

            sigma = spatial_sizes[row]
            dom_spatial = GaussianSpatialFilter(translate=(0.0, 0.0),
                                                sigma=(sigma, sigma),
                                                rotation=0,
                                                origin='center')
            dom_filter = SpatioTemporalFilter(spatial_filter=dom_spatial,
                                              temporal_filter=dom_temporal,
                                              amplitude=amplitude)

            dom_kernel = dom_filter.get_spatiotemporal_kernel(
                row_range=row_range, col_range=col_range)
            # Materialise the dense kernel once; it is needed both for the
            # stride computation and (for single-subfield cells) the output.
            dom_full = dom_kernel.full()

            # Stride along the first axis, derived from the dominant kernel's
            # length and the requested self.kernel_size[0]. It is reused for
            # the two-subfield kernel as well.
            temporal_ds_rate = ((dom_full.shape[0] - 2)
                                // (self.kernel_size[0] - 1))

            if not is_on_off:
                kernel_data[cell_idx, :, :, :, :] = torch.Tensor(
                    dom_full)[::temporal_ds_rate, :, :].repeat([3, 1, 1, 1])
                continue

            nondom_temporal = TemporalFilterCosineBump(
                weights=(weight_non_dom_0[row], weight_non_dom_1[row]),
                kpeaks=(kpeaks_non_dom_0[row], kpeaks_non_dom_1[row]),
                delays=(delay_non_dom_0[row], delay_non_dom_1[row]))
            nondom_spatial = GaussianSpatialFilter(translate=(0.0, 0.0),
                                                   sigma=(sigma, sigma),
                                                   rotation=0,
                                                   origin='center')
            nondom_filter = SpatioTemporalFilter(
                spatial_filter=nondom_spatial,
                temporal_filter=nondom_temporal,
                amplitude=nondom_amplitude)

            on_off_cell = TwoSubfieldLinearCell(
                dominant_filter=dom_filter,
                nondominant_filter=nondom_filter,
                subfield_separation=sf_seps[row],
                onoff_axis_angle=angles[row],
                dominant_subfield_location=(0, 0),
                transfer_function=MultiTransferFunction(
                    (symbolic_x, symbolic_y),
                    'Heaviside(x)*(x)+Heaviside(y)*(y)'))
            on_off_kernel = on_off_cell.get_spatiotemporal_kernel(
                row_range=row_range, col_range=col_range)
            kernel_data[cell_idx, :, :, :, :] = torch.Tensor(
                on_off_kernel.full())[::temporal_ds_rate, :, :].repeat(
                    [3, 1, 1, 1])

        # Keep the historical contract: nothing is reported for zero cells.
        if num_cells > 0:
            return {output_key: kernel_data}
        return {}