Example #1
0
def get_gabor_filters(params):
    """ Return a Gabor filterbank (generate it on the first call)

    The filterbank is cached in the module-level global ``filt_l``. Once it
    is set (not None), later calls return the cached list and ignore
    ``params`` entirely.

    NOTE(review): assumes the module defines ``filt_l = None`` at top level;
    otherwise the first call raises NameError -- confirm.

    Inputs:
    params -- filters parameters (dict) with keys 'kshape' (fh, fw),
              'orients', 'freqs' and 'phases'. The last three are iterated
              repeatedly, so they must be re-iterable (e.g. lists).

    Outputs:
    filt_l -- filterbank (list), one gabor2d filter per (freq, orient, phase)
              combination, with freq varying slowest and phase fastest

    """

    global filt_l

    if filt_l is not None:
        return filt_l

    # -- get parameters
    fh, fw = params['kshape']
    orients = params['orients']
    freqs = params['freqs']
    phases = params['phases']
    # `//` is the integer division Python 2's `/` performed on these ints;
    # it keeps the filter centre integral under Python 3 as well.
    xc = fw // 2
    yc = fh // 2

    # -- build the filterbank into a local list and publish it to the cache
    # only once it is complete, so an exception part-way through cannot leave
    # a partial filterbank cached for every later call.
    filters = []
    for freq in freqs:
        for orient in orients:
            for phase in phases:
                # create 2d gabor
                filters.append(gabor2d(xc, yc, xc, yc,
                                       freq, orient, phase,
                                       (fw, fh)))

    filt_l = filters
    return filt_l


def get_filterbank(config):
    """Build a filterbank of shape (num_filters, fh, fw) from a model config.

    Inputs:
    config -- model configuration (dict). Its 'filter' entry (dict) selects
              the generator via 'model_name' and holds that generator's
              parameters. The kernel shape (fh, fw) is read from 'kshape',
              or from 'ker_shape' when 'kshape' is absent.

    Supported 'model_name' values:
    'really_random'   -- get_random_filterbank with 'num_filters' filters
                         ('normalize' defaults to True)
    'random_gabor'    -- 'num_filters' Gabor filters with a random orientation
                         and phase in [0, 2*pi) unless 'orient' / 'phase' are
                         given. The frequency is 1/'divfreq' when it is set
                         and non-zero. Otherwise it is 1/w, with w drawn by
                         np.random.randint('min_wavelength', 'max_wavelength'),
                         where the upper bound is exclusive.
    'gridded_gabor'   -- Gabor filters on the grid freqs x orients x phases,
                         with 'norients' orientations o*pi/norients and
                         freqs 1/d for d in 'divfreqs'
    'pixels'          -- a single all-ones filter
    'specific_gabor'  -- one Gabor filter per aligned entry of 'divfreqs',
                         'orients' and 'phases' (zip stops at the shortest)
    'cairo_generated' -- filters centre-cropped from cairo-rendered images
    'center_surround' -- delegated to center_surround_orth ('orth' truthy,
                         the default) or center_surround, called with the
                         full model config

    Outputs:
    filterbank -- float32 array of shape (num_filters, fh, fw). The exception
                  is 'center_surround', which returns whatever the delegate
                  returns.

    Raises:
    KeyError   -- if neither 'kshape' nor 'ker_shape' is set
    ValueError -- if 'model_name' is not one of the names above
    """
    model_config = config
    filter_config = config['filter']
    model_name = filter_config['model_name']
    kshape = filter_config.get('kshape', filter_config.get('ker_shape'))
    if kshape is None:
        raise KeyError("filter config needs a 'kshape' (or 'ker_shape') entry")
    fh, fw = kshape
    # `//` is the integer division Python 2's `/` performed on these ints;
    # it keeps Gabor centres and crop offsets integral under Python 3 too.
    xc = fw // 2
    yc = fh // 2

    if model_name == 'really_random':
        num_filters = filter_config['num_filters']
        filterbank = get_random_filterbank(
            (num_filters, fh, fw),
            normalization=filter_config.get('normalize', True))

    elif model_name == 'random_gabor':
        num_filters = filter_config['num_filters']
        df = filter_config.get('divfreq')
        filterbank = np.empty((num_filters, fh, fw))
        for i in range(num_filters):
            # The .get() defaults are evaluated eagerly, so a random draw is
            # consumed even when 'orient' / 'phase' is fixed in the config.
            # This is kept so that seeded runs reproduce the same filters.
            orient = filter_config.get('orient', 2 * np.pi * np.random.random())
            if not df:
                freq = 1. / np.random.randint(filter_config['min_wavelength'],
                                              high=filter_config['max_wavelength'])
            else:
                freq = 1. / df
            phase = filter_config.get('phase', 2 * np.pi * np.random.random())
            filterbank[i, :, :] = v1m.gabor2d(xc, yc, xc, yc,
                                              freq, orient, phase,
                                              (fw, fh))

    elif model_name == 'gridded_gabor':
        norients = filter_config['norients']
        orients = [o * np.pi / norients for o in range(norients)]
        divfreqs = filter_config['divfreqs']
        freqs = [1. / d for d in divfreqs]
        phases = filter_config['phases']
        filterbank = _gabor_filterbank(
            itertools.product(freqs, orients, phases), fh, fw)

    elif model_name == 'pixels':
        filterbank = np.ones((1, fh, fw))

    elif model_name == 'specific_gabor':
        orients = filter_config['orients']
        divfreqs = filter_config['divfreqs']
        phases = filter_config['phases']
        freqs = [1. / d for d in divfreqs]
        # zip pairs entries positionally and silently drops any extras.
        filterbank = _gabor_filterbank(zip(freqs, orients, phases), fh, fw)

    elif model_name == 'cairo_generated':
        specs = filter_config.get('specs')
        if not specs:
            specs = [spec['image'] for spec in
                     rendering.cairo_config_gen(filter_config['spec_gen'])]
        filterbank = np.empty((len(specs), fh, fw))
        for i, spec in enumerate(specs):
            im_fh = rendering.cairo_render(spec, returnfh=True)
            arr = processing.image2array({'color_space': 'rgb'}, im_fh).astype(np.int32)
            # Channel 0 minus channel 1 (presumably red minus green for 'rgb'),
            # in int32 so the subtraction can go negative.
            arr = arr[:, :, 0] - arr[:, :, 1]
            # Centre-crop an (fh, fw) window around the image midpoint.
            r0 = arr.shape[0] // 2
            c0 = arr.shape[1] // 2
            filterbank[i, :, :] = normalize(
                arr[r0 - yc:r0 + (fh - yc), c0 - xc:c0 + (fw - xc)])

    elif model_name == 'center_surround':
        if filter_config.get('orth', True):
            return center_surround_orth(model_config)
        return center_surround(model_config)

    else:
        raise ValueError('unknown filter model_name: %r' % (model_name,))

    # Same conversion np.cast[np.float32] performed (asarray, then an astype
    # copy); np.cast itself was removed in NumPy 2.0.
    return np.asarray(filterbank).astype(np.float32)


def _gabor_filterbank(params, fh, fw):
    """Stack one v1m.gabor2d filter per (freq, orient, phase) triple.

    Each filter is centred at (fw // 2, fh // 2) and requested with size
    (fw, fh). The result is a float64 array of shape (len(params), fh, fw),
    ordered as the triples are given.
    """
    params = list(params)
    xc = fw // 2
    yc = fh // 2
    filterbank = np.empty((len(params), fh, fw))
    for i, (freq, orient, phase) in enumerate(params):
        filterbank[i, :, :] = v1m.gabor2d(xc, yc, xc, yc,
                                          freq, orient, phase,
                                          (fw, fh))
    return filterbank