# Imports assumed by the snippets below (they appear at module level in the
# original source files, alongside globals such as ncomp2do, delta_vals and
# debug that several of these functions read without being passed in).
import os
from copy import deepcopy

import numpy as np
import scipy.optimize
from sklearn import decomposition

import classifiers        # project module providing the Fisher-information helpers
import load_activations   # project module providing get_info()


def get_fisher_info(activ_path, save_path, model_name, dataset_name):

    #%% get information for this model and dataset
    info = load_activations.get_info(model_name, dataset_name)

    # extract some things from info
    layers2load = info['layer_labels_full']
    nLayers = info['nLayers']

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    orilist = info['orilist']
    sflist = info['sflist']
    if info['nPhase'] == 1:
        nOri = 180
        orilist_adj = orilist
    else:
        nOri = 360
        orilist_adj = deepcopy(orilist)
        phaselist = info['phaselist']
        orilist_adj[phaselist == 1] = orilist_adj[phaselist == 1] + 180
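        # with two phases, orientation labels span 0-359 deg here; the Fisher
        # information is folded back onto 0-179 deg further below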

    nSF = np.size(np.unique(sflist))

    fisher_info = np.zeros(
        [nLayers, nSF, 180,
         len(ncomp2do), np.size(delta_vals)])
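    # dims: [layer x spatial freq x orientation (0-179 deg) x num components x delta]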

    #%% loop over layers
    for ll in range(nLayers):

        #%% load all activations

        fn = os.path.join(activ_path,
                          'allStimsReducedWts_%s.npy' % layers2load[ll])
        print('loading reduced activations from %s\n' % fn)
        allw = np.load(fn)
        # allw is nIms x nFeatures

        #%% make sure no bad units that will mess up calculations
        constant_inds = np.all(np.equal(
            allw,
            np.tile(np.expand_dims(allw[0, :], axis=0),
                    [np.shape(allw)[0], 1])),
                               axis=0)
        assert (np.sum(constant_inds) == 0)
        is_inf = np.any(allw == np.inf, axis=0)
        assert (np.sum(is_inf) == 0)

        #%% get fisher information for this layer, each spatial frequency
        print(
            'calculating fisher information across all remaining units [%d x %d]...\n'
            % (np.shape(allw)[0], np.shape(allw)[1]))
        for sf in range(nSF):

            inds = np.where(sflist == sf)[0]

            for dd in range(np.size(delta_vals)):

                for nn in range(len(ncomp2do)):

                    ori_axis, fi = classifiers.get_fisher_info_cov(
                        allw[inds, 0:ncomp2do[nn]],
                        orilist_adj[inds],
                        delta=delta_vals[dd])
                    if nOri == 360:
                        fi = np.reshape(fi, [180, 2], order='F')
                        fi = np.mean(fi, axis=1)

                    if np.any(np.isnan(fi)):
                        print(
                            'warning: there are some nan elements in fisher information matrix for %s'
                            % layers2load[ll])

                    fisher_info[ll, sf, :, nn, dd] = np.squeeze(fi)

    #%% save everything into one big file
    save_name = os.path.join(save_path, 'Fisher_info_cov_vary_ncomps.npy')
    print('saving to %s\n' % save_name)
    np.save(save_name, fisher_info)
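

# A simplified reference sketch of the covariance-based Fisher information that
# classifiers.get_fisher_info_cov presumably estimates in the loop above (an
# illustrative assumption, not the repository's implementation): for each
# orientation theta, FI(theta) ~ f'(theta)^T C^-1 f'(theta), where f'(theta) is a
# finite-difference derivative of the mean population response and C is the
# response covariance pooled over the two orientations being compared.
def _fisher_info_cov_sketch(activ, orilist, delta=1, n_ori=180):
    # activ: [nIms x nFeatures]; orilist: orientation label (deg) for each image
    ori_axis = np.arange(n_ori)
    fi = np.zeros(n_ori)
    for oo in ori_axis:
        resp1 = activ[orilist == oo, :]
        resp2 = activ[orilist == np.mod(oo + delta, n_ori), :]
        # derivative of the mean response with respect to orientation, per unit
        deriv = (np.mean(resp2, axis=0) - np.mean(resp1, axis=0)) / delta
        # covariance pooled over the two orientations (pinv guards against singularity)
        cov = (np.cov(resp1, rowvar=False) + np.cov(resp2, rowvar=False)) / 2
        fi[oo] = deriv @ np.linalg.pinv(cov) @ deriv
    return ori_axis, fi
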

# Example 2

def get_fisher_info(activ_path, save_path, model_name, dataset_name,
                    num_batches):

    #%% get information for this model and dataset
    info = load_activations.get_info(model_name, dataset_name)

    # extract some things from info
    layers2load = info['layer_labels_full']
    nLayers = info['nLayers']

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    orilist = info['orilist']
    sflist = info['sflist']
    if info['nPhase'] == 2:
        nOri = 360
        orilist_adj = deepcopy(orilist)
        phaselist = info['phaselist']
        orilist_adj[phaselist == 1] = orilist_adj[phaselist == 1] + 180
    else:
        nOri = 180
        orilist_adj = orilist

    nSF = np.size(np.unique(sflist))

    print(
        'dataset %s has %d spatial freqs, orients from 0-%d deg, %d unique phases\n'
        % (dataset_name, nSF, nOri, info['nPhase']))

    fisher_info = np.zeros([nLayers, nSF, 180, np.size(delta_vals)])
    deriv2 = np.zeros([nLayers, nSF, 180, np.size(delta_vals)])
    varpooled = np.zeros([nLayers, nSF, 180, np.size(delta_vals)])
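    # dims for each: [layer x spatial freq x orientation (0-179 deg) x delta]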

    #%% loop over layers
    for ll in range(nLayers):

        #%% load all activations
        allw = None

        for bb in np.arange(0, num_batches, 1):

            #          file = os.path.join(activ_path, 'batch' + str(0) +'_' + layers2load[ll] +'.npy')
            file = os.path.join(
                activ_path, 'batch' + str(bb) + '_' + layers2load[ll] + '.npy')
            print('loading from %s\n' % file)
            w = np.squeeze(np.load(file))
            # w will be nIms x nFeatures
            w = np.reshape(w, [np.shape(w)[0], np.prod(np.shape(w)[1:])])

            if bb == 0:
                allw = w
#              allw = w[:,0:3]
            else:
                allw = np.concatenate((allw, w), axis=0)
#              allw = np.concatenate((allw, w[:,0:3]), axis=0)

        #%% remove any bad units that will mess up Fisher information calculation
        # first take out all the constant units, leaving only units with variance over images
        constant_inds = np.all(np.equal(
            allw,
            np.tile(np.expand_dims(allw[0, :], axis=0),
                    [np.shape(allw)[0], 1])),
                               axis=0)
        print('removing units with no variance (%d out of %d)...\n' %
              (np.sum(constant_inds), np.size(constant_inds)))
        allw = allw[:, ~constant_inds]
        # also remove any units with infinite responses (usually none)
        is_inf = np.any(allw == np.inf, axis=0)
        print('removing units with inf response (%d out of %d)...\n' %
              (np.sum(is_inf), np.size(is_inf)))
        allw = allw[:, ~is_inf]
        if np.shape(allw)[1] > 0:

            #%% get fisher information for this layer, each spatial frequency
            print(
                'calculating fisher information across all remaining units [%d x %d]...\n'
                % (np.shape(allw)[0], np.shape(allw)[1]))
            for sf in range(nSF):

                inds = np.where(sflist == sf)[0]

                for dd in range(np.size(delta_vals)):

                    ori_axis, fi, d, v = classifiers.get_fisher_info(
                        allw[inds, :], orilist_adj[inds], delta=delta_vals[dd])
                    if nOri == 360:
                        fi = np.reshape(fi, [2, 180])
                        fi = np.mean(fi, axis=0)
                        d = np.reshape(d, [2, 180])
                        d = np.mean(d, axis=0)
                        v = np.reshape(v, [2, 180])
                        v = np.mean(v, axis=0)

                    if np.any(np.isnan(fi)):
                        print(
                            'warning: there are some nan elements in fisher information matrix for %s'
                            % layers2load[ll])
    #          assert(not np.any(np.isnan(fi)))
    #          assert(not np.any(np.isnan(d)))
    #          assert(not np.any(np.isnan(v)))
                    fisher_info[ll, sf, :, dd] = np.squeeze(fi)
                    deriv2[ll, sf, :, dd] = np.squeeze(d)
                    varpooled[ll, sf, :, dd] = np.squeeze(v)
        else:
            print('skipping %s, no units left\n' % layers2load[ll])
    #%% save everything into one big file
    save_name = os.path.join(save_path, 'Fisher_info_all_units.npy')
    print('saving to %s\n' % save_name)
    np.save(save_name, fisher_info)

    # also saving these intermediate values (numerator and denominator of FI expressions)
    save_name = os.path.join(save_path, 'Deriv_sq_all_units.npy')
    print('saving to %s\n' % save_name)
    np.save(save_name, deriv2)

    save_name = os.path.join(save_path, 'Pooled_var_all_units.npy')
    print('saving to %s\n' % save_name)
    np.save(save_name, varpooled)
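

# A small usage sketch for the files saved above: reload them and collapse the
# orientation axis to one value per layer and spatial frequency for the first
# delta value. The path is a hypothetical placeholder; each array is
# [nLayers x nSF x 180 x nDeltas], and deriv2/varpooled hold the numerator and
# denominator terms of the Fisher-information estimate.
def _example_load_fisher_outputs(save_path='/path/to/fisher_output'):
    fi = np.load(os.path.join(save_path, 'Fisher_info_all_units.npy'))
    d2 = np.load(os.path.join(save_path, 'Deriv_sq_all_units.npy'))
    vp = np.load(os.path.join(save_path, 'Pooled_var_all_units.npy'))
    print('shapes:', fi.shape, d2.shape, vp.shape)
    mean_fi_by_layer = np.mean(fi[:, :, :, 0], axis=2)  # [nLayers x nSF]
    return mean_fi_by_layer
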

# Example 3

def reduce_activ(raw_path,
                 reduced_path,
                 model_name,
                 dataset_name,
                 num_batches,
                 min_var_expl,
                 max_comp_keep=1000,
                 min_comp_keep=10):

    #%% get information for this model and dataset
    info = load_activations.get_info(model_name, dataset_name)

    # extract some things from info
    layers2load = info['layer_labels_full']
    nLayers = info['nLayers']

    if not os.path.exists(reduced_path):
        os.makedirs(reduced_path)

    #%% loop over layers
    for ll in range(nLayers):

        #%% first loading all activations, raw (large) format
        allw = None

        for bb in np.arange(0, num_batches, 1):

            file = os.path.join(
                raw_path, 'batch' + str(bb) + '_' + layers2load[ll] + '.npy')
            print('loading from %s\n' % file)
            w = np.squeeze(np.load(file))
            # full W here is NHWC format: Number of images x Height (top to bottom) x Width (left to right) x Channels.
            # new w will be nIms x nFeatures
            w = np.reshape(w, [np.shape(w)[0], np.prod(np.shape(w)[1:])])

            if bb == 0:
                allw = w
            else:
                allw = np.concatenate((allw, w), axis=0)

        #%% now use PCA to reduce dimensionality

        pca = decomposition.PCA()
        print('size of allw before reducing is %d by %d' %
              (np.shape(allw)[0], np.shape(allw)[1]))
        print('\n STARTING PCA\n')
        weights_reduced = pca.fit_transform(allw)

        # decide how many components are needed
        var_expl = pca.explained_variance_ratio_
        # index of the first component at which the cumulative variance exceeds
        # the criterion; that index + 1 is the number of components needed
        over_thresh = np.where(np.cumsum(var_expl) > min_var_expl / 100)[0]

        if np.size(over_thresh) == 0:
            ncomp2keep = max_comp_keep
            print(
                'need all the data to capture %d percent of variance, but max components is set to %d so keeping that many only!\n'
                % (min_var_expl, max_comp_keep))
        elif over_thresh[0] + 1 > max_comp_keep:
            ncomp2keep = max_comp_keep
            print(
                'need >%d components to capture %d percent of variance, but max components is set to %d so keeping that many only!'
                % (max_comp_keep, min_var_expl, max_comp_keep))
        else:
            ncomp2keep = over_thresh[0] + 1
            print('need %d components to capture %d percent of variance' %
                  (ncomp2keep, min_var_expl))

        if ncomp2keep < min_comp_keep:
            print(
                'need only %d components to capture %d percent of variance, but keeping first %d!'
                % (ncomp2keep, min_var_expl, min_comp_keep))
            ncomp2keep = min_comp_keep

        weights_reduced = weights_reduced[:, 0:ncomp2keep]

        print('saving %d components\n' % np.shape(weights_reduced)[1])
        #%% Save the result as a single file

        fn2save = os.path.join(
            reduced_path, 'allStimsReducedWts_' + layers2load[ll] + '.npy')
        print('saving to %s\n' % (fn2save))
        np.save(fn2save, weights_reduced)

        fn2save = os.path.join(reduced_path,
                               'allStimsVarExpl_' + layers2load[ll] + '.npy')
        print('saving to %s\n' % (fn2save))
        np.save(fn2save, var_expl)
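

# A minimal usage sketch for reduce_activ; all paths, names and batch counts here
# are illustrative assumptions. min_var_expl is given in percent, so 95 keeps
# enough components to explain 95% of the variance, clipped to the range
# [min_comp_keep, max_comp_keep].
def _example_run_reduce_activ():
    reduce_activ(raw_path='/path/to/activations_raw',           # hypothetical
                 reduced_path='/path/to/activations_reduced',   # hypothetical
                 model_name='vgg16',                            # hypothetical
                 dataset_name='CosGratings',                    # hypothetical
                 num_batches=96,                                # hypothetical
                 min_var_expl=95,
                 max_comp_keep=1000,
                 min_comp_keep=10)
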

# Example 4

def analyze_orient_tuning(root, model, training_str, dataset_all, nSamples,
                          param_str, ckpt_str):

    if ckpt_str == '0':
        ckpt_str_print = '00000'
    else:
        ckpt_str_print = '%s' % (np.round(int(ckpt_str), -4))

    save_path = os.path.join(root, 'code', 'unit_tuning', model, training_str,
                             param_str, dataset_all)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # get info about the model/dataset we're using here
    info = load_activations.get_info(model, dataset_all)
    nLayers = info['nLayers']
    layers2load = info['layer_labels_full']
    layer_labels = info['layer_labels']

    if debug == 1:
        nLayers = 2

    #%% loop over layers and do processing for each
    for ll in range(nLayers):

        # check if this has been done yet, skip if so
        files_done = os.listdir(save_path)
        fit_files_this_layer = [
            ff for ff in files_done
            if layer_labels[ll] in ff and 'fastpars_avgspace' in ff
        ]
        n_done = np.size(fit_files_this_layer)
        if n_done >= 1:
            continue

        #%% loop over different versions of the evaluation image set (samples)
        # load the full activation patterns
        for kk in range(nSamples):

            if kk == 0 and 'FiltIms' not in dataset_all:
                dataset = dataset_all
            elif 'FiltIms' in dataset_all:
                dataset = '%s_rand%d' % (dataset_all, kk + 1)
            else:
                if kk == 0:
                    dataset = dataset_all
                else:
                    dataset = '%s%d' % (dataset_all, kk)

            file_path = os.path.join(
                root, 'activations', model, training_str, param_str, dataset,
                'eval_at_ckpt-%s_orient_tuning' % (ckpt_str))

            file_name = os.path.join(
                file_path,
                'AllUnitsOrientTuningAvgSpace_%s.npy' % (layers2load[ll]))
            print('loading from %s\n' % file_name)

            # [nUnits x nSf x nOri]
            t = np.load(file_name)

            if kk == 0:
                nUnits = np.shape(t)[0]
                nSF = np.shape(t)[1]
                nOri = np.shape(t)[2]
                #        ori_axis = np.arange(0, nOri,1)
                all_units = np.zeros([nSamples, nUnits, nSF, nOri])

            print(
                'analyzing tuning curves for %d units, %d spatial frequencies, orientations 0-%d deg\n'
                % (nUnits, nSF, nOri))

            # [nSamples x nUnits x nSF x nOri]
            all_units[kk, :, :, :] = t

        #%% in this version, all units are kept (no screening of non-responsive units)
        nUnits = np.shape(all_units)[1]

        resp_units = all_units

        mean_TF = np.mean(resp_units, 1)

        # some parameters that can be estimated quickly - no fitting needed
        maxori = np.argmax(resp_units, axis=3)
        minori = np.argmin(resp_units, axis=3)
        # squared slope of each unit's orientation tuning function; element i of
        # np.diff refers to the step between orientations i and i+1
        sqslopevals = np.diff(resp_units, axis=3)**2
        maxsqslopeori = np.argmax(sqslopevals, axis=3)
        maxsqslopeval = np.max(sqslopevals, axis=3)
        meanresp = np.mean(resp_units, axis=3)
        maxresp = np.max(resp_units, axis=3)
        minresp = np.min(resp_units, axis=3)

        fastpars = dict()
        fastpars['maxori'] = maxori
        fastpars['minori'] = minori
        fastpars['maxsqslopeori'] = maxsqslopeori
        fastpars['maxsqslopeval'] = maxsqslopeval
        fastpars['meanresp'] = meanresp
        fastpars['maxresp'] = maxresp
        fastpars['minresp'] = minresp

        #%% Save the responsive units
        save_name = os.path.join(
            save_path, '%s_mean_resp_avgspace_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, mean_TF)

        save_name = os.path.join(
            save_path, '%s_fastpars_avgspace_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, fastpars)

        resp_units = None
        all_units = None
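

# A toy illustration (not part of the pipeline) of the quick, fit-free tuning
# metrics computed above, applied to a single synthetic orientation tuning curve.
def _demo_fast_tuning_metrics():
    ori = np.arange(180)
    # circular, von-Mises-like curve peaking near 45 deg
    curve = np.exp(4 * (np.cos(2 * np.pi * (ori - 45) / 180) - 1))
    print('preferred orientation (maxori):', np.argmax(curve))
    # the squared slope peaks on the flank of the tuning curve, not at its peak
    print('steepest-slope orientation (maxsqslopeori):',
          np.argmax(np.diff(curve)**2))
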
def analyze_orient_tuning(root, model, training_str, dataset_all, nSamples, param_str, ckpt_str, rand_seed):
  
  save_path = os.path.join(root,'code','unit_tuning',model,training_str,param_str,dataset_all)
  if not os.path.exists(save_path):
    os.makedirs(save_path)
    
  # get info about the model/dataset we're using here
  info = load_activations.get_info(model,dataset_all)
  nLayers = info['nLayers']
  layers2load = info['layer_labels_full']     
  layer_labels = info['layer_labels']
  
  if debug==1:
    nLayers=2
    
  # going to count units with zero response
  # going to get rid of the totally non responsive ones to reduce the size of this big matrix.  
  # also going to get rid of any units with no response variance (exactly same response for all stims)   
  nTotalUnits = np.zeros([nLayers,1])
  propZeroUnits = np.zeros([nLayers,1]) # how many units had zero resp for all stims?
  propConstUnits = np.zeros([nLayers,1])  # how many units had some constant, nonzero resp for all stims?
  
  #%% loop over layers and do processing for each
  for ll in range(nLayers):
    
    # check if this has been done yet, skip if so
    files_done = os.listdir(save_path)
    fit_files_this_layer = [ff for ff in files_done if layer_labels[ll] in ff and 'jitter_%d'%rand_seed in ff]
    n_done = np.size(fit_files_this_layer)
    if n_done>=3:
      continue
    
    #%% loop over different versions of the evaluation image set (samples)
    # load the full activation patterns
    for kk in range(nSamples):
    
      if kk==0 and 'FiltIms' not in dataset_all:
        dataset = dataset_all
      elif 'FiltIms' in dataset_all:
        dataset = '%s_rand%d'%(dataset_all,kk+1)
      else:
        if kk==0:
          dataset=dataset_all
        else:
          dataset = '%s%d'%(dataset_all,kk)
          
      file_path = os.path.join(root,'activations',model,training_str,param_str,dataset,
                               'eval_at_ckpt-%s_orient_tuning'%(ckpt_str))
      
      file_name = os.path.join(file_path,'AllUnitsOrientTuning_%s.npy'%(layers2load[ll]))
      print('loading from %s\n'%file_name)
    
      # [nUnits x nSf x nOri]    
      t = np.load(file_name)
            
      if kk==0:
        nUnits = np.shape(t)[0]
        nSF = np.shape(t)[1]
        nOri = np.shape(t)[2]
        ori_axis = np.arange(0, nOri,1) 
        all_units = np.zeros([nSamples, nUnits, nSF, nOri])
      
      print('analyzing tuning curves for %d spatial frequencies, orientations 0-%d deg\n'%(nSF,nOri))
      
      # [nSamples x nUnits x nSF x nOri]
      all_units[kk,:,:,:] = t
    
      #%% For each unit - label it according to its spatial position and channel.
      # full W matrix is in NHWC format: Number of images x Height (top to bottom) x Width (left to right) x Channels.
      # First dim of t matrix is currently [H x W x C]  (reshaped from W)
      H = info['activ_dims'][ll]
      W = info['activ_dims'][ll]
      C = int(np.shape(t)[0]/H/W)        
      clabs = np.tile(np.expand_dims(np.arange(0,C),axis=1),[H*W,1])
      wlabs = np.expand_dims(np.repeat(np.tile(np.expand_dims(np.arange(0,W),axis=1),[H,1]), C),axis=1)
      hlabs = np.expand_dims(np.repeat(np.arange(0,H),W*C),axis=1)
      # the coords matrix goes [nUnits x 3] where columns are [H,W,C]
      coords = np.concatenate((hlabs,wlabs,clabs),axis=1)      
      assert np.array_equal(coords, np.unique(coords, axis=0))
  
    #%% now identify the non-responsive units in this layer.
    nUnits = np.shape(all_units)[1]
    is_zero = np.zeros([nUnits,nSamples,nSF])
    is_constant_nonzero = np.zeros([nUnits, nSamples, nSF])
    
    for kk in range(nSamples):
      print('identifying nonresponsive units in %s, sample %d'%(layer_labels[ll],kk))
      for sf in range(nSF):
        # take out data, [nUnits x nOri]        
        vals = all_units[kk,:,sf,:]
        
        # find units where signal is zero for all ori
        # add these to a running list of which units were zero for any sample and spatial frequency.
        is_zero[:,kk,sf] = np.all(vals==0,axis=1)
        
        # find units where signal is constant for all images (no variance)
        constval = all_units[0,:,0,0]
        const = np.all(np.equal(vals, np.tile(np.expand_dims(constval, axis=1), [1,nOri])),axis=1)
        # don't count the zero units here so we can see how many of each...
        is_constant_nonzero[:,kk,sf] = np.logical_and(const, ~np.all(vals==0,axis=1))
        
    is_zero_any = np.any(np.any(is_zero,axis=2),axis=1)  
    propZeroUnits[ll] = np.sum(is_zero_any==True)/nUnits
    
    is_constant_any = np.any(np.any(is_constant_nonzero,axis=2),axis=1)
    propConstUnits[ll] = np.sum(is_constant_any==True)/nUnits

    nTotalUnits[ll] = nUnits
    
    # make a new matrix with only the good units in it.
    units2use = np.logical_and(~is_zero_any, ~is_constant_any)
    resp_units = all_units[:,units2use,:,:]
    
    # record the spatial position and channel corresponding to each unit.
    coords_good = coords[units2use,:]
    
    #%% Save the responsive units     
    save_name =os.path.join(save_path,'%s_all_responsive_units_eval_at_ckpt_%d.npy'%(layer_labels[ll],np.round(int(ckpt_str),-4)))
    print('saving to %s\n'%save_name)
    np.save(save_name,resp_units)
  
    save_name =os.path.join(save_path,'%s_coordsHWC_all_responsive_units_eval_at_ckpt_%d.npy'%(layer_labels[ll],np.round(int(ckpt_str),-4)))
    print('saving to %s\n'%save_name)
    np.save(save_name,coords_good)
    
    #%% Fit von mises functions to each unit, save parameters
    
    # random seed for generating random x-axis jitter
    np.random.seed(rand_seed+ll)  
    jitter_by = np.random.randint(0,nOri,size=[nSF,nUnits])
    
    nPars = 4  # center, k, amplitude, baseline
    nUnits = np.shape(resp_units)[1]  # note that nUnits is smaller now than it was above, counting only responsive ones.
    # record r2 and parameters of the final fit
    r2 = np.zeros((nUnits,nSF))
    fit_pars = np.zeros((nUnits,nSF,nPars+1)) # last element here is FWHM
    # as a check for tuning consistency - calculate r2 for the fit to each individual sample.
    r2_each_sample = np.zeros((nUnits, nSF, nSamples))
    
    for sf in range(nSF):
      
      # get the data for this spatial frequency and this unit
      
      # dat is nSamples x nUnits x nOri
      dat = resp_units[:,:,sf,:]
      # meandat is nUnits x nOri
      meandat = np.mean(dat,axis=0)
          
      if debug==0:
        units2do = np.arange(0,nUnits)
      else:
        units2do = np.arange(0,100)
        
      
      # loop over units and fit each one
      for uu in units2do:

        print('fitting %s, sf %d, unit %d / %d, jitter by %d deg'%(layer_labels[ll],sf,uu, nUnits,jitter_by[sf,uu]))
    
        real_y = meandat[uu,:]
        # circularly shift the function over before fitting. Make sure that we don't end up with a bunch of units tuned to edges of orient space by accident.
        real_y = np.roll(real_y, jitter_by[sf,uu])
        # estimate amplitude and baseline from the max and min 
        init_b = np.min(real_y)
        init_a = np.max(real_y)-init_b
        # initialize the mu parameter with the max of the curve
        init_mu = ori_axis[np.argmax(real_y)]
        init_k = 1
        params_start = (init_mu,init_k,init_a,init_b)
       
        # do the fitting with scipy curve fitting toolbox
        try:
          # constrain the center to 0-179, constrain k to positive values that aren't too small, let a and b vary freely.
          params, params_covar = scipy.optimize.curve_fit(von_mises_deg, ori_axis, real_y, p0=params_start, bounds=((-0.0001,10**(-15),-np.inf,-np.inf),(nOri+0.0001,np.inf,np.inf,np.inf)))
        except (RuntimeError, ValueError):
          print('fitting failed for unit %d'%uu)
          r2[uu,sf] = np.nan
          fit_pars[uu,sf,:] = np.nan
          continue
        
       
        pred_y = von_mises_deg(ori_axis,params[0],params[1],params[2],params[3])
              
        # get r2 for this fit
        r2[uu,sf] = get_r2(real_y, pred_y)
        
        # get r2 for how well this fit captures individual samples
        for ss in range(nSamples):
          real_y_this_sample = dat[ss,uu,:]
          # circularly shift the function over to match the mean tuning function
          real_y_this_sample = np.roll(real_y_this_sample, jitter_by[sf,uu])
          r2_each_sample[uu,sf,ss] = get_r2(real_y_this_sample, pred_y)
          
        # record the final fit parameters.
        # finally, re-adjust the center to where it would have been before shifting over.
        params[0] = np.mod(params[0]-jitter_by[sf,uu],nOri)
        fit_pars[uu,sf,0:nPars] = params
        
        # also calculate fwhm for the von mises function - more interpretable than k.        
        if params[2]<0:
          # if the von mises had a negative amplitude, we actually want the width of the negative peak.
          fwhm = get_fwhm((-1)*pred_y,ori_axis)
        else:
          fwhm = get_fwhm(pred_y,ori_axis)
        fit_pars[uu,sf,nPars] = fwhm

    #%% save the results of the fitting, all units this layer
    save_name =os.path.join(save_path,'%s_fit_jitter_%d_r2_eval_at_ckpt_%d.npy'%(layer_labels[ll], rand_seed, np.round(int(ckpt_str),-4)))
    print('saving to %s\n'%save_name)
    np.save(save_name,r2)
    save_name =os.path.join(save_path,'%s_fit_jitter_%d_r2_each_sample_eval_at_ckpt_%d.npy'%(layer_labels[ll], rand_seed, np.round(int(ckpt_str),-4)))
    print('saving to %s\n'%save_name)
    np.save(save_name,r2_each_sample)
    save_name =os.path.join(save_path,'%s_fit_jitter_%d_pars_eval_at_ckpt_%d.npy'%(layer_labels[ll], rand_seed, np.round(int(ckpt_str),-4)))
    print('saving to %s\n'%save_name)
    np.save(save_name,fit_pars)


  #%% save the proportion of non-responsive units in each layer
 
  save_name =os.path.join(save_path,'PropZeroUnits_eval_at_ckpt_%d.npy'%(np.round(int(ckpt_str),-4)))
  print('saving to %s\n'%save_name)
  np.save(save_name,propZeroUnits)
  save_name =os.path.join(save_path,'PropConstUnits_eval_at_ckpt_%d.npy'%(np.round(int(ckpt_str),-4)))
  print('saving to %s\n'%save_name)
  np.save(save_name,propConstUnits)
  save_name =os.path.join(save_path,'TotalUnits_eval_at_ckpt_%d.npy'%(np.round(int(ckpt_str),-4)))
  print('saving to %s\n'%save_name)
  np.save(save_name,nTotalUnits)
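

# The fitting code above relies on three helpers defined elsewhere in its source
# file: von_mises_deg, get_r2 and get_fwhm. The versions below are plausible
# minimal sketches of what they do (assumptions, not the repository's actual
# definitions).
def von_mises_deg(xx, mu, k, a, b):
    # von Mises shaped tuning curve over an orientation axis in degrees, with one
    # full period per axis length, rescaled to run from baseline b to peak a + b
    axis_len = np.float64(np.size(xx))
    yy = np.exp(k * (np.cos((xx - mu) * 2 * np.pi / axis_len) - 1))
    rng = np.max(yy) - np.min(yy)
    if rng > 0:
        yy = (yy - np.min(yy)) / rng
    return a * yy + b


def get_r2(actual, predicted):
    # coefficient of determination between a measured and a fitted tuning curve
    ss_res = np.sum((actual - predicted)**2)
    ss_tot = np.sum((actual - np.mean(actual))**2)
    return 1 - ss_res / ss_tot


def get_fwhm(curve, axis_deg):
    # full width at half max (in degrees) of a single-peaked tuning curve,
    # ignoring wrap-around at the edges of the orientation axis
    half_max = np.min(curve) + (np.max(curve) - np.min(curve)) / 2
    above = np.where(curve >= half_max)[0]
    return axis_deg[above[-1]] - axis_deg[above[0]]
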
def analyze_orient_tuning(root, model, training_str, dataset_all, nSamples,
                          param_str, ckpt_str):

    if ckpt_str == '0':
        ckpt_str_print = '00000'
    else:
        ckpt_str_print = '%s' % (np.round(int(ckpt_str), -4))

    save_path = os.path.join(root, 'code', 'unit_tuning', model, training_str,
                             param_str, dataset_all)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # get info about the model/dataset we're using here
    info = load_activations.get_info(model, dataset_all)
    nLayers = info['nLayers']
    layers2load = info['layer_labels_full']
    layer_labels = info['layer_labels']

    if debug == 1:
        nLayers = 2

    # going to count units with zero response
    # going to get rid of the totally non-responsive ones to reduce the size of this big matrix.
    # also going to get rid of any units with no response variance (exactly same response for all stims)
    nTotalUnits = np.zeros([nLayers, 1])
    # how many units had zero resp for all stims?
    propZeroUnits = np.zeros([nLayers, 1])
    # how many units had some constant, nonzero resp for all stims?
    propConstUnits = np.zeros([nLayers, 1])

    #%% loop over layers and do processing for each
    for ll in range(nLayers):

        # check if this has been done yet, skip if so
        files_done = os.listdir(save_path)
        fit_files_this_layer = [
            ff for ff in files_done
            if layer_labels[ll] in ff and 'fastpars_eval' in ff
        ]
        n_done = np.size(fit_files_this_layer)
        if n_done >= 1:
            continue

        #%% loop over different versions of the evaluation image set (samples)
        # load the full activation patterns
        for kk in range(nSamples):

            if kk == 0 and 'FiltIms' not in dataset_all:
                dataset = dataset_all
            elif 'FiltIms' in dataset_all:
                dataset = '%s_rand%d' % (dataset_all, kk + 1)
            else:
                if kk == 0:
                    dataset = dataset_all
                else:
                    dataset = '%s%d' % (dataset_all, kk)

            file_path = os.path.join(
                root, 'activations', model, training_str, param_str, dataset,
                'eval_at_ckpt-%s_orient_tuning' % (ckpt_str))

            file_name = os.path.join(
                file_path, 'AllUnitsOrientTuning_%s.npy' % (layers2load[ll]))
            print('loading from %s\n' % file_name)

            # [nUnits x nSf x nOri]
            t = np.load(file_name)

            if kk == 0:
                nUnits = np.shape(t)[0]
                nSF = np.shape(t)[1]
                nOri = np.shape(t)[2]
                #        ori_axis = np.arange(0, nOri,1)
                all_units = np.zeros([nSamples, nUnits, nSF, nOri])

            print(
                'analyzing tuning curves for %d spatial frequencies, orientations 0-%d deg\n'
                % (nSF, nOri))

            # [nSamples x nUnits x nSF x nOri]
            all_units[kk, :, :, :] = t

            #%% For each unit - label it according to its spatial position and channel.
            # full W matrix is in NHWC format: Number of images x Height (top to bottom) x Width (left to right) x Channels.
            # First dim of t matrix is currently [H x W x C]  (reshaped from W)
            H = info['activ_dims'][ll]
            W = info['activ_dims'][ll]
            C = int(np.shape(t)[0] / H / W)
            clabs = np.tile(np.expand_dims(np.arange(0, C), axis=1),
                            [H * W, 1])
            wlabs = np.expand_dims(np.repeat(
                np.tile(np.expand_dims(np.arange(0, W), axis=1), [H, 1]), C),
                                   axis=1)
            hlabs = np.expand_dims(np.repeat(np.arange(0, H), W * C), axis=1)
            # the coords matrix goes [nUnits x 3] where columns are [H,W,C]
            coords = np.concatenate((hlabs, wlabs, clabs), axis=1)
            assert np.array_equal(coords, np.unique(coords, axis=0))
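            # the assert confirms that every (H, W, C) triple is distinct, i.e. the
            # labels enumerate each unit of a C-order-flattened [H x W x C] array
            # exactly once (flat index = h*W*C + w*C + c)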

        #%% now identify the non-responsive units in this layer.
        nUnits = np.shape(all_units)[1]
        is_zero = np.zeros([nUnits, nSamples, nSF])
        is_constant_nonzero = np.zeros([nUnits, nSamples, nSF])

        for kk in range(nSamples):
            print('identifying nonresponsive units in %s, sample %d' %
                  (layer_labels[ll], kk))
            for sf in range(nSF):
                # take out data, [nUnits x nOri]
                vals = all_units[kk, :, sf, :]

                # find units where signal is zero for all ori
                # add these to a running list of which units were zero for any sample and spatial frequency.
                is_zero[:, kk, sf] = np.all(vals == 0, axis=1)

                # find units where signal is constant for all images (no variance)
                constval = all_units[0, :, 0, 0]
                const = np.all(np.equal(
                    vals, np.tile(np.expand_dims(constval, axis=1),
                                  [1, nOri])),
                               axis=1)
                # don't count the zero units here so we can see how many of each...
                is_constant_nonzero[:, kk, sf] = np.logical_and(
                    const, ~np.all(vals == 0, axis=1))

        is_zero_any = np.any(np.any(is_zero, axis=2), axis=1)
        propZeroUnits[ll] = np.sum(is_zero_any == True) / nUnits

        is_constant_any = np.any(np.any(is_constant_nonzero, axis=2), axis=1)
        propConstUnits[ll] = np.sum(is_constant_any == True) / nUnits

        nTotalUnits[ll] = nUnits

        # make a new matrix with only the good units in it.
        units2use = np.logical_and(~is_zero_any, ~is_constant_any)
        resp_units = all_units[:, units2use, :, :]

        # record the spatial position and channel corresponding to each unit.
        coords_good = coords[units2use, :]

        mean_TF = np.mean(resp_units, 1)

        # some parameters that can be estimated quickly - no fitting needed
        maxori = np.argmax(resp_units, axis=3)
        minori = np.argmin(resp_units, axis=3)
        # squared slope of each unit's orientation tuning function; element i of
        # np.diff refers to the step between orientations i and i+1
        sqslopevals = np.diff(resp_units, axis=3)**2
        maxsqslopeori = np.argmax(sqslopevals, axis=3)
        maxsqslopeval = np.max(sqslopevals, axis=3)
        meanresp = np.mean(resp_units, axis=3)
        maxresp = np.max(resp_units, axis=3)
        minresp = np.min(resp_units, axis=3)

        fastpars = dict()
        fastpars['maxori'] = maxori
        fastpars['minori'] = minori
        fastpars['maxsqslopeori'] = maxsqslopeori
        fastpars['maxsqslopeval'] = maxsqslopeval
        fastpars['meanresp'] = meanresp
        fastpars['maxresp'] = maxresp
        fastpars['minresp'] = minresp

        #%% Save the responsive units
        save_name = os.path.join(
            save_path, '%s_all_responsive_units_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, resp_units)

        save_name = os.path.join(
            save_path, '%s_mean_resp_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, mean_TF)

        save_name = os.path.join(
            save_path,
            '%s_coordsHWC_all_responsive_units_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, coords_good)

        save_name = os.path.join(
            save_path, '%s_fastpars_eval_at_ckpt_%s.npy' %
            (layer_labels[ll], ckpt_str_print))
        print('saving to %s\n' % save_name)
        np.save(save_name, fastpars)

        resp_units = None
        coords_good = None
        is_zero = None
        is_constant_nonzero = None
        all_units = None

    #%% save the proportion of non-responsive units in each layer

    save_name = os.path.join(
        save_path, 'PropZeroUnits_eval_at_ckpt_%s.npy' % (ckpt_str_print))
    print('saving to %s\n' % save_name)
    np.save(save_name, propZeroUnits)
    save_name = os.path.join(
        save_path, 'PropConstUnits_eval_at_ckpt_%s.npy' % (ckpt_str_print))
    print('saving to %s\n' % save_name)
    np.save(save_name, propConstUnits)
    save_name = os.path.join(
        save_path, 'TotalUnits_eval_at_ckpt_%s.npy' % (ckpt_str_print))
    print('saving to %s\n' % save_name)
    np.save(save_name, nTotalUnits)
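

# A hedged end-to-end sketch of how these routines might be chained. In this
# combined listing, duplicated function names resolve to the most recently
# defined version, and every path, name and parameter value below is an
# illustrative assumption, not taken from the original scripts.
def _example_pipeline():
    global debug, delta_vals
    debug = 0
    delta_vals = [1]                             # orientation step(s) in deg

    root = '/path/to/project'                    # hypothetical
    model = 'vgg16'                              # hypothetical
    training_str = 'scratch_imagenet_rot_0'      # hypothetical
    param_str = 'params1'                        # hypothetical
    dataset = 'CosGratings'                      # hypothetical

    # PCA-reduce raw activations, then estimate Fisher information from raw batches
    reduce_activ('/path/to/activations_raw', '/path/to/activations_reduced',
                 model, dataset, num_batches=96, min_var_expl=95)
    get_fisher_info('/path/to/activations_raw', '/path/to/fisher_output',
                    model, dataset, num_batches=96)

    # single-unit orientation tuning analysis at one checkpoint
    analyze_orient_tuning(root, model, training_str, dataset,
                          nSamples=4, param_str=param_str, ckpt_str='400000')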