Example #1
def stimulate(sr_graph, sess, file_name, dimx, dimy):

    piece = '2015-11-09-3.mat'

    # saving filename.
    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          file_name + '_prosthesis.pkl')
    save_dict = {}

    # load dictionary.
    dict_dir = '/home/bhaishahster/Downloads/dictionary'
    dict_src = os.path.join(dict_dir, piece)
    # _, dictionary, cellID_list, EA, elec_loc = load_single_elec_stim_data(gfile.Open(dict_src, 'r'))
    _, dictionary, cellID_list, EA, elec_loc = load_single_elec_stim_data(
        dict_src)
    dictionary = dictionary.T

    # Load cell properties
    cell_data_dir = '/home/bhaishahster/Downloads/rgb-8-1-0.48-11111'
    cell_file = os.path.join(cell_data_dir, piece)
    data_cell = sio.loadmat(gfile.Open(cell_file, 'r'))
    data_util.process_dataset(data_cell, dimx=80, dimy=40, num_cell_types=2)

    # Load stimulus
    data = h5py.File(os.path.join(cell_data_dir, 'stimulus.mat'))
    stimulus = np.array(data.get('stimulus')) - 0.5

    # Generate targets

    # random 100 samples
    t_len = 100
    stim_history = 30
    stim_batch = np.zeros(
        (t_len, stimulus.shape[1], stimulus.shape[2], stim_history))
    for isample, itime in enumerate(
            np.random.randint(0, stimulus.shape[0], t_len)):
        stim_batch[isample, :, :, :] = np.transpose(
            stimulus[itime:itime - stim_history:-1, :, :], [1, 2, 0])

    # Debugging hook: drops into an interactive IPython shell.
    from IPython import embed
    embed()

    # Use regression to decide dictionary elements
    regress_dictionary(sr_graph, stim_batch, dictionary, 10, dimx, dimy,
                       data_cell)

    # Select stimulation pattern
    dict_sel_np_logit, r_s, dictionary, d_log = get_optimal_stimulation(
        stim_batch, sr_graph, dictionary, data_cell, sess)

    save_dict.update({
        'dict_sel': dict_sel_np_logit,
        'resp_sample': r_s,
        'dictionary': dictionary,
        'd_log': d_log
    })

    pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
def compute_sta(cell_id, response, cid_idx, delay=4):

    # compute STA for one cell - load a chunk, compute STA, repeat
    sta_frame = np.zeros((640, 320, 3))
    for ichunk in np.arange(1000) + 1:
        print('Loading chunk: %d' % ichunk)
        # get a stimulus chunk
        stim_path = FLAGS.data_location + 'Stimulus/'
        stim_file = sio.loadmat(
            gfile.Open(stim_path + 'stim_chunk_' + str(ichunk) + '.mat'))
        chunk_start = np.squeeze(stim_file['chunk_start'])
        chunk_end = np.squeeze(stim_file['chunk_end'])
        jump = stim_file['jump']
        stim_chunk = stim_file['stimulus_chunk']
        stim_chunk = np.transpose(stim_chunk, [3, 0, 1, 2])
        print(np.shape(stim_chunk))

        print(chunk_start, chunk_end)
        # find time bins with spikes and accumulate the stimulus frame at the given delay
        for itime in np.arange(29, chunk_end - chunk_start + 1):
            if response[chunk_start + itime - 1, cid_idx] > 0:
                sta_frame = sta_frame + np.squeeze(
                    stim_chunk[itime - delay, :, :, :])

    # compute STA
    sta_frame = sta_frame / np.sum(response[:, cid_idx])
    # Debugging breakpoint: drops into the pdb interactive debugger.
    import pdb
    pdb.set_trace()

    # plot STA
    plt.imshow(sta_frame[:, :, 2])
    plt.show()
    plt.draw()
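
The loop above adds the stimulus frame shown 'delay' bins before every time bin that contains a spike and then divides by the total spike count. As a reference, here is a minimal, self-contained numpy sketch of spike-triggered averaging on synthetic data; the shapes, the spike rate and the helper name sta_sketch are assumptions for illustration, and the sketch weights each frame by the bin's spike count (the standard STA), whereas the loop above adds each spiking frame once.

import numpy as np

def sta_sketch(stimulus, spikes, delay=4):
    """stimulus: (T, X, Y) frames; spikes: (T,) counts; returns the (X, Y) STA."""
    sta = np.zeros(stimulus.shape[1:])
    for t in range(delay, stimulus.shape[0]):
        if spikes[t] > 0:
            # Accumulate the frame shown 'delay' bins before the spikes.
            sta += spikes[t] * stimulus[t - delay]
    return sta / spikes.sum()

rng = np.random.RandomState(0)
stim = rng.randn(1000, 8, 4)        # white-noise stimulus frames
spk = rng.poisson(0.1, size=1000)   # synthetic spike counts
print(sta_sketch(stim, spk).shape)  # (8, 4)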
Example #3
def get_partitions(partition_file, taskid):
  """Reads partition file and return training/testing paritions for the taskid.

  The partition file has each task in differnet rows, with each row format -
  taskid: training dataset list: testing dataset list
  Args:
    partition_file : File containing all the partitions.
    taskid : Index of partition to load.

  Returns:
    training_datasets : dataset ids to train on.
    testing_datasets: dataset ids to test on.
  """
  # Load partitions
  with gfile.Open(partition_file, 'r') as f:
    content = f.readlines()

  for iline in content:
    tokens = iline.split(':')
    tokens = [i.replace(' ', '') for i in tokens]
    if taskid == int(tokens[0]):
      training_datasets = tokens[1].split(',')
      testing_datasets = tokens[2].split(',')

      training_datasets = [int(i) for i in training_datasets]
      testing_datasets = [int(i) for i in testing_datasets]
      break

  return training_datasets, testing_datasets
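
For reference, a self-contained sketch of the partition-file format that get_partitions expects, parsed with the same split logic; it uses plain open() and a temporary file instead of gfile, and the task ids and dataset ids are made up for illustration.

import tempfile

contents = '0: 1,2,3: 4\n1: 2,3,4: 1\n'  # taskid: train ids: test ids
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
  f.write(contents)
  path = f.name

with open(path, 'r') as f:
  for line in f.readlines():
    tokens = [tok.replace(' ', '') for tok in line.split(':')]
    if int(tokens[0]) == 1:
      print([int(i) for i in tokens[1].split(',')])  # [2, 3, 4]
      print([int(i) for i in tokens[2].split(',')])  # [1]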
Example #4
def setup_dataset():
    # initialize paths, get dataset properties, etc
    path = FLAGS.data_location

    # load cell response
    response_path = path + 'response.mat'
    response_file = sio.loadmat(gfile.Open(response_path))
    resp_mat = response_file['binned_spikes']
    resp_file_cids = np.squeeze(response_file['cell_ids'])

    # load off parasol cell IDs
    cids_path = path + 'cell_ids/cell_ids_OFF parasol.mat'
    cids_file = sio.loadmat(gfile.Open(cids_path))
    cids_select = np.squeeze(cids_file['cids'])

    # find index of cells to choose from resp_mat
    resp_file_choose_idx = np.array([])
    for icell in np.array(cids_select):
        idx = np.where(resp_file_cids == icell)
        resp_file_choose_idx = np.append(resp_file_choose_idx, idx[0])

    # finally, get selected cells from resp_mat
    global response
    response = resp_mat[resp_file_choose_idx.astype('int'), :].T
    print(cids_select)
    print(resp_file_choose_idx.astype('int'))

    # load population time courses
    time_c_file_path = path + 'cell_ids/time_courses.mat'
    time_c_file = sio.loadmat(gfile.Open(time_c_file_path))
    tc_mat = time_c_file['time_courses']
    tm_cids = np.squeeze(time_c_file['cids'])

    # find average time course of cells of interest
    tc_file_choose_idx = np.array([])
    for icell in np.array(cids_select):
        idx = np.where(tm_cids == icell)
        tc_file_choose_idx = np.append(tc_file_choose_idx, idx[0])
    tc_select = tc_mat[tc_file_choose_idx.astype('int'), :, :]
    tc_mean = np.squeeze(np.mean(tc_select, axis=0))
    n_cells = cids_select.shape[0]
    FLAGS.n_cells = n_cells

    # Return spike counts 'response' for cells 'cids_select' ('n_cells' cells),
    # per-cell time courses 'tc_select' and the mean time course 'tc_mean'.
    return response, cids_select, n_cells, tc_select, tc_mean
Example #5
def get_stimulus_batch(ichunk):
    stim_path = FLAGS.data_location + 'Stimulus/'
    stim_file = sio.loadmat(
        gfile.Open(stim_path + 'stim_chunk_' + str(ichunk) + '.mat'))
    chunk_start = np.squeeze(stim_file['chunk_start'])
    chunk_end = np.squeeze(stim_file['chunk_end'])
    jump = stim_file['jump']
    stim_chunk = stim_file['stimulus_chunk']
    stim_chunk = np.transpose(stim_chunk, [3, 0, 1, 2])
    return stim_chunk, chunk_start, chunk_end
def get_test_data():
    # the last chunk of data is test data
    test_data_chunks = [FLAGS.n_chunks]
    for ichunk in test_data_chunks:
        filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_part = data['maskedMovdd_part'].T
        resp_part = data['Y_part'].T
        test_len = stim_part.shape[0]
    stim_part = stim_part[:, chosen_mask]
    resp_part = resp_part[:, cells_choose]
    return stim_part, resp_part, test_len
def get_su_spks(subunit_fit_path, stimulus_test, resp_ret):
    """Predict spikes for each cell and each number of subunits."""

    # Predict responses to 'stimulus_test'
    n_valid_cells = resp_ret['valid_cells'].sum()
    resp_su = np.zeros((10, stimulus_test.shape[0], n_valid_cells))
    cell_ids = resp_ret['cellID_list'].squeeze()

    # resp_su: subunits x time x cells
    for Nsub in range(1, 11):
        for icell in range(n_valid_cells):

            # Get subunits.
            fit_file = os.path.join(
                subunit_fit_path,
                'Cell_%d_Nsub_%d.pkl' % (cell_ids[icell], Nsub))
            try:
                su_fit = pickle.load(gfile.Open(fit_file, 'r'))
            except:
                print('Cell %d not loaded ' % cell_ids[icell])
                continue

            print(Nsub, '%d' % cell_ids[icell])

            # Get window to extract stimulus around RF.
            windx = su_fit['windx']
            windy = su_fit['windy']
            stim_cell = np.reshape(
                stimulus_test[:, windx[0]:windx[1], windy[0]:windy[1]],
                [stimulus_test.shape[0], -1])

            # Filter in time.
            ttf = su_fit['ttf']
            stim_filter = np.zeros_like(stim_cell)
            for idelay in range(30):
                length = stim_filter[idelay:, :].shape[0]
                stim_filter[idelay:, :] += stim_cell[:length, :] * ttf[idelay]

            stim_filter -= np.mean(stim_filter)
            stim_filter /= np.sqrt(np.var(stim_filter))

            # Compute firing rate.
            K = su_fit['K']
            b = su_fit['b']
            firing_rate = np.exp(stim_filter.dot(K) + b[:, 0]).sum(-1)

            # Sample spikes.
            resp_su[Nsub - 1, :, icell] = np.random.poisson(firing_rate)

    return resp_su
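
The firing-rate line above implements a subunit LN model: the time-filtered stimulus is passed through Nsub spatial filters (the columns of K), each filter output is exponentiated, and the subunit outputs are summed before Poisson sampling. A minimal numpy sketch of that computation; all shapes and values are made up for illustration.

import numpy as np

rng = np.random.RandomState(0)
T, n_pix, n_sub = 100, 50, 3
stim_filtered = rng.randn(T, n_pix)  # time-filtered, normalized stimulus
K = 0.1 * rng.randn(n_pix, n_sub)    # one spatial filter per subunit
b = -1.0 * np.ones(n_sub)            # per-subunit bias

# rate(t) = sum_j exp(stim_filtered(t) . K[:, j] + b[j])
firing_rate = np.exp(stim_filtered.dot(K) + b).sum(-1)  # shape (T,)
spikes = rng.poisson(firing_rate)                       # sampled spike counts
print(firing_rate.shape, spikes.shape)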
def get_test_data():
    # stimulus.astype('float32')[216000-1000: 216000-1, :]
    # response.astype('float32')[216000-1000: 216000-1, :]
    # length
    test_data_chunks = [FLAGS.n_chunks]
    for ichunk in test_data_chunks:
        filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_part = data['maskedMovdd_part'].T
        resp_part = data['Y_part'].T
        test_len = stim_part.shape[0]
    #logfile.write('\nReturning test data')
    stim_part = stim_part[:, chosen_mask]
    resp_part = resp_part[:, cells_choose]
    return stim_part, resp_part, test_len
Example #9
def get_data_retina(piece='2005-04-06-4'):
  """Load data for 1 piece."""

  # Load data
  file_d = h5py.File('/home/bhaishahster/'
                   'Downloads/%s.mat' % piece, 'r')

  stimulus = np.array(file_d.get('stimulus'))
  responses = np.array(file_d.get('response'))
  ei = np.array(file_d.get('ei'))
  cell_type = np.array(file_d.get('cell_type'))

  elec_loc_file = gfile.Open('/home/bhaishahster/'
                             'Downloads/Elec_loc512.mat', 'r')
  data = sio.loadmat(elec_loc_file)
  elec_loc = np.squeeze(np.array([data['elec_locx'], data['elec_locy']])).T

  #########################################################################
  # Process
  stimulus = np.mean(stimulus, 1) - 0.5
  stimulus = stimulus[:-1, :, :]
  stimx = stimulus.shape[1]
  stimy = stimulus.shape[2]
  ei_magnitude = np.sqrt(np.sum(ei**2, 0))
  rfs = np.reshape(stimulus[:-4, :],
                   [-1, stimulus.shape[1]*stimulus.shape[2]]).T.dot(responses[4:, :])
  rfs = np.reshape(rfs, [stimulus.shape[1], stimulus.shape[2], -1])
  rfs = clean_rfs(rfs, nbd=1)

  ei_embedding_matrix = embed_ei_grid(elec_loc, smoothness_sigma=15)
  eix, eiy, _ = ei_embedding_matrix.shape
  n_elec = 512

  # compile data
  data = {'stimulus': stimulus,
          'responses': responses,
          'ei_magnitude': ei_magnitude,
          'rfs': rfs,
          'elec_loc': elec_loc,
          'stimx': stimx,
          'stimy': stimy,
          'eix': eix,
          'eiy': eiy,
          'ei_embedding_matrix': ei_embedding_matrix, 'n_elec': n_elec,
          'cell_type': cell_type}

  return data
def get_null_stimulus(resp_ret, subunit_fit_path, stimulus_test):
    """Project the sitmulus into null space."""

    A = []

    # Collect RF for all cells
    n_valid_cells = resp_ret['valid_cells'].sum()
    cell_ids = resp_ret['cellID_list'].squeeze()

    # Use the single-subunit (Nsub = 1) fit to get each cell's spatial filter.
    Nsub = 1
    for icell in range(n_valid_cells):

        # Get subunits.
        fit_file = os.path.join(
            subunit_fit_path, 'Cell_%d_Nsub_%d.pkl' % (cell_ids[icell], Nsub))
        try:
            su_fit = pickle.load(gfile.Open(fit_file, 'r'))
        except:
            print('Cell %d not loaded ' % cell_ids[icell])
            continue

        print(Nsub, '%d' % cell_ids[icell])

        # Get window to extract stimulus around RF.
        windx = su_fit['windx']
        windy = su_fit['windy']

        sta = np.zeros((stimulus_test.shape[1], stimulus_test.shape[2]))
        sta[windx[0]:windx[1], windy[0]:windy[1]] = np.reshape(
            su_fit['K'].squeeze(), (windx[1] - windx[0], windy[1] - windy[0]))
        A += [sta]

    A = np.array(A)
    A_2d = np.reshape(A, (A.shape[0], -1))
    stim_test_2d = np.reshape(stimulus_test, (stimulus_test.shape[0], -1))

    stimulus_test_null = stim_test_2d.T - A_2d.T.dot(
        np.linalg.solve(A_2d.dot(A_2d.T), A_2d.dot(stim_test_2d.T)))
    stimulus_test_null = stimulus_test_null.T

    stimulus_test_null = np.reshape(
        stimulus_test_null,
        [-1, stimulus_test.shape[1], stimulus_test.shape[2]])
    return stimulus_test_null
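
The algebra above removes, from every stimulus frame, the component lying in the span of the stacked receptive fields A: s_null = s - A'(AA')^{-1} A s, so each cell's RF sees (numerically) nothing of the projected stimulus. A small self-contained numpy check of that identity on random data; the shapes are made up for illustration.

import numpy as np

rng = np.random.RandomState(0)
n_cells, n_pix, T = 5, 40, 20
A = rng.randn(n_cells, n_pix)  # stacked, flattened RFs (cells x pixels)
S = rng.randn(T, n_pix)        # flattened stimulus frames (time x pixels)

# Subtract the projection of each frame onto the row space of A.
S_null = (S.T - A.T.dot(np.linalg.solve(A.dot(A.T), A.dot(S.T)))).T
print(np.abs(A.dot(S_null.T)).max())  # ~1e-14: stimuli orthogonal to every RF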
Example #11
def main(unused_argv=()):

  # Load stimulus-response data
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  responses = []
  print(datasets)
  for icnt, idataset in enumerate([datasets]): #TODO(bhaishahster): HACK.

    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op

      responses += resp

    for idataset in range(len(responses)):
      k, b, ttf = fit_ln_population(responses[idataset]['responses'], stimulus)  # Use FLAGS.taskID
      save_dict = {'k': k, 'b': b, 'ttf': ttf}

      save_analysis_filename = os.path.join(FLAGS.save_folder,
                                            responses[idataset]['piece']
                                            + '_ln_model.pkl')
      pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
Example #12
def get_test_data():
    # stimulus.astype('float32')[216000-1000: 216000-1, :]
    # response.astype('float32')[216000-1000: 216000-1, :]
    # length
    global chosen_mask
    global cells_choose
    stim_part = np.array([]).reshape(0, np.sum(chosen_mask))
    resp_part = np.array([]).reshape(0, np.sum(cells_choose))

    test_data_chunks = np.arange(FLAGS.n_chunks - 20, FLAGS.n_chunks + 1)
    for ichunk in test_data_chunks:
        filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)

        s = data['maskedMovdd_part'].T
        r = data['Y_part'].T
        print(np.shape(s))
        print(np.shape(stim_part))
        stim_part = np.append(stim_part, s[:, chosen_mask], axis=0)
        resp_part = np.append(resp_part, r[:, cells_choose], axis=0)

        test_len = stim_part.shape[0]
    return stim_part, resp_part, test_len
Example #13
def main(argv):
    print('\nCode started')

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic':
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = tot_spks[cells_choose]
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    #short_filename = 'data_model=ASM_pop_bg'
    #  short_filename = ('data_model=ASM_pop_batch_sz='+ str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) +
    #    '_step_sz'+ str(FLAGS.step_sz)+'_bg')

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    if FLAGS.model_id == 'logistic':
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    if FLAGS.model_id == 'poisson_full':
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Learn population model!
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        # variables
        if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.1 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))
        if FLAGS.model_id == 'logistic':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.random.randn(
                n_cells
            )  #np.log((np.sum(response,0))/(response.shape[0]-np.sum(response,0)))
            b = tf.Variable(b_init, dtype='float32')

        # get relevant files
        file_list = gfile.ListDirectory(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename
        print('\nLoading: ', save_filename)
        bin_files = []
        meta_files = []
        for file_n in file_list:
            if re.search(short_filename + '.', file_n):
                if re.search('.meta', file_n):
                    meta_files += [file_n]
                else:
                    bin_files += [file_n]
        #print(bin_files)
        print(len(meta_files), len(bin_files), len(file_list))

        # get iteration numbers
        iterations = np.array([])
        for file_name in bin_files:
            try:
                iterations = np.append(
                    iterations, int(file_name.split('/')[-1].split('-')[-1]))
            except:
                print('Could not load filename: ' + file_name)
        iterations.sort()
        print(iterations)

        iter_plot = iterations[-1]
        print(int(iter_plot))

        # load tensorflow variables
        saver_var = tf.train.Saver(tf.all_variables())

        restore_file = save_filename + '-' + str(int(iter_plot))
        saver_var.restore(sess, restore_file)

        a_eval = a.eval()
        print(np.exp(np.squeeze(a_eval)))
        #print(np.shape(a_eval))

        # get 2D region to plot
        mask2D = np.reshape(chosen_mask, [40, 80])
        nz_idx = np.nonzero(mask2D)
        np.shape(nz_idx)
        print(nz_idx)
        ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
        xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])
        w_eval = w.eval()

        plt.figure()
        n_su = w_eval.shape[1]
        for isu in np.arange(n_su):
            xx = np.zeros((3200))
            xx[chosen_mask] = w_eval[:, isu]
            fig = plt.subplot(int(np.ceil(np.sqrt(n_su))),
                              int(np.ceil(np.sqrt(n_su))), isu + 1)
            plt.imshow(np.reshape(xx, [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            #if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge':
            #  plt.title(str(a_eval[isu, :]))
            #else:
            #  plt.title(str(np.squeeze(np.exp(a_eval[:, 0, isu]))), fontsize=12)

        plt.suptitle('Iteration:' + str(int(iter_plot)) + ' batchSz:' +
                     str(FLAGS.batchsz) + ' step size:' + str(FLAGS.step_sz),
                     fontsize=18)
        plt.show()
        plt.draw()
def subunit_discriminability(dataset_dict,
                             stimuli,
                             responses,
                             sr_graph,
                             num_examples=1000):
    ## compute distances between s-r pairs - pos and neg.
    ## negative_stim - if the negative is a stimulus or a response
    if num_examples % 100 != 0:
        raise ValueError('Only supports examples which are multiples of 100.')

    subunit_fit_loc = '/home/bhaishahster/stim-resp_collection_big_wn_retina_subunit_properties_train'
    subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

    save_dict = {}

    datasets_log = {}
    for dat_key, datasets in dataset_dict.items():

        distances_log = {}
        distances_retina_sr_log = []
        distances_retina_rr_log = []
        for iretina in range(len(datasets)):
            # Find the relevant subunit fit
            piece = responses[iretina]['piece']
            matched_dataset = [
                ifit for ifit in subunits_datasets if piece[:12] == ifit[:12]
            ]
            if matched_dataset == []:
                raise ValueError('Could not find subunit fit')

            subunit_fit_path = os.path.join(subunit_fit_loc,
                                            matched_dataset[0])

            # Get predicted spikes.
            dat_resp_su = pickle.load(
                gfile.Open(
                    os.path.join(subunit_fit_path, 'response_prediction.pkl'),
                    'r'))
            resp_su = dat_resp_su[
                'resp_su']  # it has non-rejected cells as well.

            # Remove some cells.
            select_cells = [
                icell for icell in range(resp_su.shape[2])
                if dat_resp_su['cell_ids'][icell] in responses[iretina]
                ['cellID_list'].squeeze()
            ]

            select_cells = np.array(select_cells)
            resp_su = resp_su[:, :, select_cells].astype(np.float32)

            # Get stimulus
            stimulus = stimuli[responses[iretina]['stimulus_key']]
            stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]
            responses_recorded_test = responses[iretina]['responses'][
                FLAGS.test_min:FLAGS.test_max, :]

            # Sample stimuli and responses.
            random_times = np.random.randint(40, stimulus_test.shape[0],
                                             num_examples)
            batch_size = 100

            # Recorded response - predicted response distances.
            distances_retina = np.zeros((num_examples, 10)) + np.nan
            for Nsub in range(1, 11):
                for ibatch in range(
                        np.floor(num_examples / batch_size).astype(np.int)):

                    # construct stimulus tensor.
                    stim_history = 30
                    resp_pred_batch = np.zeros((batch_size, resp_su.shape[2]))
                    resp_rec_batch = np.zeros((batch_size, resp_su.shape[2]))

                    for isample in range(batch_size):
                        itime = random_times[batch_size * ibatch + isample]
                        resp_pred_batch[isample, :] = resp_su[Nsub - 1,
                                                              itime, :]
                        resp_rec_batch[isample, :] = responses_recorded_test[
                            itime, :]

                    # Embed predicted responses
                    feed_dict = make_feed_dict(sr_graph,
                                               responses[iretina],
                                               responses=resp_pred_batch)
                    embed_predicted = sr_graph.sess.run(
                        sr_graph.anchor_model.responses_embed,
                        feed_dict=feed_dict)

                    # Embed recorded responses
                    feed_dict = make_feed_dict(sr_graph,
                                               responses[iretina],
                                               responses=resp_rec_batch)
                    embed_recorded = sr_graph.sess.run(
                        sr_graph.anchor_model.responses_embed,
                        feed_dict=feed_dict)

                    dd = sr_graph.sess.run(sr_graph.distances_arbitrary,
                                           feed_dict={
                                               sr_graph.arbitrary_embedding_1:
                                               embed_predicted,
                                               sr_graph.arbitrary_embedding_2:
                                               embed_recorded
                                           })

                    distances_retina[batch_size * ibatch:batch_size *
                                     (ibatch + 1), Nsub - 1] = dd
                    print(iretina, Nsub, ibatch)

            distances_retina_rr_log += [distances_retina]

            # Stimulus - predicted response distances.
            distances_retina = np.zeros((num_examples, 10)) + np.nan
            for Nsub in range(1, 11):
                for ibatch in range(
                        np.floor(num_examples / batch_size).astype(np.int)):

                    # construct stimulus tensor.
                    stim_history = 30
                    stim_batch = np.zeros(
                        (batch_size, stimulus_test.shape[1],
                         stimulus_test.shape[2], stim_history))
                    resp_batch = np.zeros((batch_size, resp_su.shape[2]))

                    for isample in range(batch_size):
                        itime = random_times[batch_size * ibatch + isample]
                        stim_batch[isample, :, :, :] = np.transpose(
                            stimulus_test[itime:itime - stim_history:-1, :, :],
                            [1, 2, 0])
                        resp_batch[isample, :] = resp_su[Nsub - 1, itime, :]

                    feed_dict = make_feed_dict(sr_graph, responses[iretina],
                                               resp_batch, stim_batch)

                    # Get distances
                    d_pos = sr_graph.sess.run(sr_graph.d_s_r_pos,
                                              feed_dict=feed_dict)

                    distances_retina[batch_size * ibatch:batch_size *
                                     (ibatch + 1), Nsub - 1] = d_pos
                    print(iretina, Nsub, ibatch)

            distances_retina_sr_log += [distances_retina]

        distances_log.update({'rr': distances_retina_rr_log})
        distances_log.update({'sr': distances_retina_sr_log})
        datasets_log.update({dat_key: distances_log})
    save_dict.update({
        'datasets_log': datasets_log,
        'dataset_dict': dataset_dict
    })

    return save_dict
def main(argv):

    # global variables will be used for getting training data
    global cells_choose
    global chosen_mask
    global chunk_order

    # set random seeds: when same algorithm run with different FLAGS,
    # the sequence of random data is same.
    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    # initial chunk order (will be re-shuffled everytime we go over a chunk)
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    # Load data summary
    data_filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(data_filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])

    # which cells to train subunits for
    if FLAGS.all_cells == 'True':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    else:
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    n_cells = np.sum(cells_choose)  # number of cells

    # load spikes and relevant stimulus pixels for chosen cells
    tot_spks = np.squeeze(data_summary['tot_spks'])
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    # chosen_mask = which pixels to learn subunits over
    if FLAGS.masked_stimulus == 'True':
        chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                               dtype='bool')
    else:
        chosen_mask = np.array(np.ones(3200).astype('bool'))
    stim_dim = np.sum(chosen_mask)  # stimulus dimensions
    print('\ndataset summary loaded')

    # print parameters
    print('Save folder name: ' + str(FLAGS.folder_name) + '\nmodel:' +
          str(FLAGS.model_id) + '\nLoss:' + str(FLAGS.loss) +
          '\nmasked stimulus:' + str(FLAGS.masked_stimulus) + '\nall_cells?' +
          str(FLAGS.all_cells) + '\nbatch size' + str(FLAGS.batchsz) +
          '\nstep size' + str(FLAGS.step_sz) + '\ntraining length: ' +
          str(FLAGS.train_len) + '\nn_cells: ' + str(n_cells))

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # filename for saving file
    short_filename = ('_masked_stim=' + str(FLAGS.masked_stimulus) +
                      '_all_cells=' + str(FLAGS.all_cells) + '_loss=' +
                      str(FLAGS.loss) + '_batch_sz=' + str(FLAGS.batchsz) +
                      '_step_sz' + str(FLAGS.step_sz) + '_tlen=' +
                      str(FLAGS.train_len) + '_bg')

    with tf.Session() as sess:
        # set up stimulus and response placeholders
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.loss == 'poisson':
            b_init = np.array(
                0.000001 * np.ones(n_cells)
            )  # a very small positive bias needed to avoid log(0) in poisson loss
        else:
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells)
            )  # log-odds, a good initialization for some losses (like logistic)

        # different firing rate models
        if FLAGS.model_id == 'exp_additive':
            # This model was implemented for earlier work.
            # firing rate for cell c: lam_c = sum_s exp(w_s.x + a_sc)

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                                     dtype='float32'),
                            name='a')
            # firing rate model
            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
            regularization = 0
            vars_fit = [w, a]

            def proj():
                # Called after every training step to project onto parameter constraints.
                pass

        if FLAGS.model_id == 'relu':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b
            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                              '_nsu=' + str(n_su) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_su, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            # firing rate model
            lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # regularization of parameters
            regularization = (FLAGS.lam_w * tf.reduce_sum(tf.abs(w)) +
                              FLAGS.lam_a * tf.reduce_sum(tf.abs(a)))
            # projection to satisfy constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i are over small windows that are convolutionally related to each other.
            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows(
            )  # get convolutional windows

            # variables
            w = tf.Variable(np.array(0.1 +
                                     0.05 * np.random.rand(dimx, dimy, n_pix),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_wts = tf.nn.relu(tf.reduce_sum(tf.mul(stim_masked, w), 3))
            # get firing rate
            lam = tf.matmul(tf.reshape(stim_wts, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i are over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del differs between windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to learn
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            #  stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm':
            # firing rate for cell c: lam_c = a_sfm_c'.relu(w.x) + b,
            # a_sfm_c = softmax(a) : so a cell cannot be connected to all subunits equally well.

            # where the w_i are over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del differs between windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to fit
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm_exp':
            # firing rate for cell c: lam_c = exp(a_sfm_c'.relu(w.x)) + b,
            # a_sfm_c = softmax(a) : so a cell cannot be connected to all subunits equally well.
            # exponential output NL would cancel the log() in poisson and might get better estimation properties.

            # where the w_i are over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del differs between windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            # get windows
            mask_tf, dimx, dimy, n_pix = get_windows()

            # declare variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # filter stimulus
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # get subunit activation
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get cell firing rates
            lam = tf.exp(
                tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm)) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        # different loss functions
        if FLAGS.loss == 'poisson':
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len

        if FLAGS.loss == 'logistic':
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * lam)) / data_len

        if FLAGS.loss == 'hinge':
            loss_inter = tf.reduce_sum(
                tf.nn.relu(1 - 2 * (resp - 0.5) * lam)) / data_len

        loss = loss_inter + regularization  # add regularization to get final loss function

        # training consists of calling training()
        # which performs a train step and
        # project parameters to model specific constraints using proj()
        train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
            loss, var_list=vars_fit)

        def training(inp_dict):
            sess.run(train_step,
                     feed_dict=inp_dict)  # one step of gradient descent
            proj()  # model specific projection operations

        # evaluate loss on given data.
        def get_loss(inp_dict):
            ls = sess.run(loss, feed_dict=inp_dict)
            return ls

        # saving details
        # make a folder with name derived from parameters of the algorithm
        # - it saves checkpoint files and summaries used in tensorboard
        parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
        # make folder if it does not exist
        if not gfile.IsDirectory(parent_folder):
            gfile.MkDir(parent_folder)
        FLAGS.save_location = parent_folder + short_filename + '/'
        if not gfile.IsDirectory(FLAGS.save_location):
            gfile.MkDir(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename

        # create summary writers
        # create histogram summary for all parameters which are learnt
        for ivar in vars_fit:
            tf.histogram_summary(ivar.name, ivar)
        # loss summary
        l_summary = tf.scalar_summary('loss', loss)
        # loss without regularization summary
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)
        # Merge all the summary writer ops into one op (this way,
        # calling one op stores all summaries)
        merged = tf.merge_all_summaries()
        # training and testing has separate summary writers
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        ## Fitting procedure
        print('Start fitting')
        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            # restore previous fits if they are available
            # - useful when programs are preempted frequently.
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            # restore previous iteration count and start from there.
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)  # restore variables
            load_prev = True
        except:
            print('No previous dataset')

        if load_prev:
            print('Previous results loaded')
        else:
            print('Variables initialized')

        # Finally, do fitting
        icnt = 0
        # get test data and make test dictionary
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        for istep in np.arange(start_iter, 400000):
            print(istep)
            # get training data and make test dictionary
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }

            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)

                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
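
For reference, the three per-batch losses defined above (Poisson, logistic and hinge) written as plain numpy expressions; lam plays the role of the predicted firing rate or score, resp the recorded response, and all shapes and values are made up for illustration (softplus(x) = log(1 + exp(x)) is written with np.logaddexp).

import numpy as np

rng = np.random.RandomState(0)
lam = np.abs(rng.randn(100, 4)) + 1e-6  # predicted rates (time x cells)
resp = rng.poisson(0.1, size=(100, 4))  # recorded spike counts
data_len = 100.

poisson_loss = (lam.sum() / 120. - (resp * np.log(lam)).sum()) / data_len
logistic_loss = np.logaddexp(0., -2 * (resp - 0.5) * lam).sum() / data_len
hinge_loss = np.maximum(0., 1 - 2 * (resp - 0.5) * lam).sum() / data_len
print(poisson_loss, logistic_loss, hinge_loss)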
def get_next_training_batch(iteration):
    # Returns a new batch of training data : stimulus and response arrays
    # we will use global stimulus and response variables to permute training data
    # chunks and store where we are in list of training data

    # each chunk might have multiple training batches.
    # So go through all batches in a 'chunk' before moving on to the next chunk
    global stim_train_part
    global resp_train_part
    global chunk_order

    togo = True
    while togo:
        if (iteration % FLAGS.n_b_in_c == 0):
            # iteration being a multiple of the number of batches in a chunk
            # means we finished a chunk; load a new chunk of data.
            ichunk = (iteration // FLAGS.n_b_in_c) % (
                FLAGS.train_len - 1)  # -1 as the last chunk is used for testing
            if (ichunk == 0):
                # if starting over the chunks again, shuffle the chunks
                chunk_order = np.random.permutation(np.arange(
                    FLAGS.train_len))  # remove first chunk - weird?
            if chunk_order[
                    ichunk] + 1 != 1:  # 1st chunk was weird for the dataset used
                filename = FLAGS.data_location + 'Off_par_data_' + str(
                    chunk_order[ichunk] + 1) + '.mat'
                file_r = gfile.Open(filename, 'r')
                data = sio.loadmat(file_r)
                stim_train_part = data['maskedMovdd_part']  # stimulus
                resp_train_part = data['Y_part']  # response

                ichunk = chunk_order[ichunk] + 1
                while stim_train_part.shape[1] < FLAGS.batchsz:
                    # if the current loaded data is smaller than batch size, load more chunks
                    if (ichunk > FLAGS.n_chunks):
                        ichunk = 2
                    filename = FLAGS.data_location + 'Off_par_data_' + str(
                        ichunk) + '.mat'
                    file_r = gfile.Open(filename, 'r')
                    data = sio.loadmat(file_r)
                    stim_train_part = np.append(stim_train_part,
                                                data['maskedMovdd_part'],
                                                axis=1)
                    resp_train_part = np.append(resp_train_part,
                                                data['Y_part'],
                                                axis=1)
                    ichunk = ichunk + 1

        ibatch = iteration % FLAGS.n_b_in_c  # which section of current chunk to use
        try:
            stim_train = np.array(stim_train_part[:, ibatch:ibatch +
                                                  FLAGS.batchsz],
                                  dtype='float32').T
            resp_train = np.array(resp_train_part[:, ibatch:ibatch +
                                                  FLAGS.batchsz],
                                  dtype='float32').T
            togo = False
        except:
            iteration = np.random.randint(1, 100000)
            print('Load exception iteration: ' + str(iteration) + ' chunk: ' +
                  str(chunk_order[ichunk]) + ' batch: ' + str(ibatch))
            togo = True

    stim_train = stim_train[:, chosen_mask]
    resp_train = resp_train[:, cells_choose]
    return stim_train, resp_train, FLAGS.batchsz
def linear_rank1_models(sr_model,
                        sess,
                        dimx,
                        dimy,
                        n_cells,
                        center_locations,
                        cell_masks,
                        firing_rates,
                        cell_type,
                        time_window=30):
    """Learn a linear rank 1 (space time separable) stimulus response metric.

  The metric is d(s, r) = -r' F s g, with r = response, s = stimulus
    (space x time), g the time filter and F the (# cells x space) spatial filter.

  g is initialized with a precomputed average time filter
      across multiple retinas.
  F is initialized with signed gaussians at the center locations
      such that it forms a mosaic.

  g, F are learned by minimizing the distance between positive examples and
    all the negatives in a batch.

  Since we think F must be spatially localized, we project to
    minimize locally normalized L1 regularization at each step.


  Args:
    sr_model : The variant of linear model ('lin_rank1' or 'lin_rank1_blind')
    sess : Tensorflow session.
    dimx: X dimension of the stimulus.
    dimy: Y dimension of the stimulus.
    n_cells : Number of cells
    center_locations : Location of centers for cell in the
                        2D grid (Dimx X Dimy X n_cells).
    cell_masks : Masks for spatial filter for each cell (Dimx x Dimy x n_cells).
    firing_rates : Mean firing rate of different cells ( # cells)
    cell_type : Cell type of each cell. (# cells)
    time_window : Length of stimulus (in time).

  Returns:
    sr_graph : Container of the embedding parameters and losses.

  Raises:
    ValueError : If the model type is not supported.
  """

    # Embed stimulus.
    stim_tf = tf.placeholder(tf.float32,
                             shape=[None, dimx, dimy, time_window
                                    ])  # batch x X x Y x time_window

    pos_resp = tf.placeholder(dtype=tf.float32,
                              shape=[None, n_cells, 1],
                              name='pos')

    neg_resp = tf.placeholder(dtype=tf.float32,
                              shape=[None, n_cells, 1],
                              name='neg')

    # n_cells x number of selected cells
    select_cells_mat = tf.placeholder(dtype=tf.float32,
                                      shape=[n_cells, None],
                                      name='selected_cells')

    # Declare spatial and temporal filter variables.
    k_init = approximate_rfs_from_centers(center_locations, cell_type,
                                          firing_rates, dimx, dimy, n_cells)
    ttf_data = pickle.load(gfile.Open(FLAGS.ttf_file, 'r'))
    ttf_init = ttf_data['ttf'].astype(np.float32)

    k_tf = tf.Variable(k_init.astype(np.float32))
    ttf_tf = tf.Variable(ttf_init.astype(np.float32))

    # Do filtering in space and time
    tfd = tf.expand_dims
    ttf_4d = tfd(tfd(tfd(ttf_tf, 0), 0), 3)
    stim_time_filtered = tf.nn.conv2d(stim_tf,
                                      ttf_4d,
                                      strides=[1, 1, 1, 1],
                                      padding='VALID')  # Batch x X x Y x 1
    stim_time_filtered_3d = stim_time_filtered[:, :, :, 0]
    stim_time_filtered_2d = tf.reshape(stim_time_filtered_3d,
                                       [-1, dimx * dimy])
    k_tf_flat = tf.reshape(k_tf, [dimx * dimy, n_cells])
    lam_raw = tf.matmul(stim_time_filtered_2d, k_tf_flat)  # Batch x n_cells

    # select_cells - all outputs of size Batch x n_selected_cells
    pos_resp_sel = tf.matmul(pos_resp[:, :, 0], select_cells_mat)
    neg_resp_sel = tf.matmul(neg_resp[:, :, 0], select_cells_mat)
    lam_raw_sel = tf.matmul(lam_raw, select_cells_mat)

    # Loss
    beta = FLAGS.beta
    d_pos = -tf.reduce_mean(pos_resp_sel * lam_raw_sel, 1)  # Batch
    d_pairwise_neg = -tf.reduce_mean(
        tfd(neg_resp_sel, 0) * tfd(lam_raw_sel, 1), 2)  # Batch_pos x Batch_neg

    difference = (tf.expand_dims(d_pos / beta, 1) - d_pairwise_neg / beta
                  )  # positives x negatives

    # Option 2
    # log(1 + \sum_j(exp(d+ - dj-)))
    difference_padded = tf.pad(difference, [[0, 0], [0, 1]])
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(difference_padded, 1), 0)

    accuracy_tf = tf.reduce_mean(
        tf.sign(-tf.expand_dims(d_pos, 1) + d_pairwise_neg))
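    # Note: this is the mean of sign(d_neg - d_pos) over positive/negative pairs, so it
    # lies in [-1, 1] (fraction correctly ordered minus fraction incorrectly ordered)
    # rather than being a 0-1 accuracy.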

    ## Train model
    if sr_model == 'lin_rank1':
        train_op_part = tf.train.AdagradOptimizer(
            FLAGS.learning_rate).minimize(loss)
        # Project K
        with tf.control_dependencies([train_op_part]):
            # Locally reweighted L1 for spatial locality.
            # proj_k = project_spatially_locaized_l1(k_tf, dimx, dimy, n_cells,
            #                                       lnl1_reg=0.00001, eps_neigh=0.01)

            # Project K to the mask
            cell_masks_tf = tf.constant(cell_masks.astype(np.float32))
            proj_k = tf.assign(k_tf, k_tf * cell_masks_tf)
        train_op = tf.group(train_op_part, proj_k)

    elif sr_model == 'lin_rank1_blind':
        # Only time filter is trained in blind model.
        # train_op = tf.train.AdagradOptimizer(FLAGS.learning_rate
        #                                    ).minimize(loss, var_list=[ttf_tf])
        train_op = []  # Blind model not trained.

    else:
        raise ValueError('Only lin_rank1 and lin_rank1_blind supported')

    # Store everything in a graph.
    sr_graph = collections.namedtuple(
        'SR_Graph', 'sess train_op '
        'select_cells_mat '
        'd_s_r_pos d_pairwise_s_rneg '
        'loss accuracy_tf stim_tf '
        'pos_resp neg_resp ttf_tf k_tf ')

    sr_graph = sr_graph(sess=sess,
                        train_op=train_op,
                        select_cells_mat=select_cells_mat,
                        d_s_r_pos=d_pos,
                        d_pairwise_s_rneg=d_pairwise_neg,
                        loss=loss,
                        accuracy_tf=accuracy_tf,
                        stim_tf=stim_tf,
                        pos_resp=pos_resp,
                        neg_resp=neg_resp,
                        ttf_tf=ttf_tf,
                        k_tf=k_tf)

    return sr_graph
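

# Illustrative sketch (not part of the original file): the softmax-over-negatives loss
# used above, written in plain NumPy so the logsumexp form is easy to check. The names
# `d_pos` and `d_neg` are hypothetical precomputed distances; `beta` mirrors FLAGS.beta.
import numpy as np


def softmax_over_negatives_loss(d_pos, d_neg, beta=10.0):
    """loss = sum_i beta * log(1 + sum_j exp((d_pos[i] - d_neg[i, j]) / beta))."""
    diff = (d_pos[:, None] - d_neg) / beta  # positives x negatives
    # Padding with a zero column reproduces the '1 +' inside the log, since exp(0) == 1.
    diff_padded = np.concatenate([diff, np.zeros((diff.shape[0], 1))], axis=1)
    return np.sum(beta * np.log(np.sum(np.exp(diff_padded), axis=1)))


# Example: 3 positive pairs with 4 negatives each.
# print(softmax_over_negatives_loss(np.array([-2.0, -1.5, -3.0]), -np.random.rand(3, 4)))
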
Beispiel #18
0
def test_encoding(training_datasets, testing_datasets, responses, stimuli,
                  sr_graph, sess, file_name, sample_fcn):

    print('Testing for encoding model')
    tf.logging.info('Testing for encoding model')
    #from IPython import embed; embed()

    # saving filename.
    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          file_name + '_analysis_sample_resps')

    retina_ids = [resp['piece'] for resp in responses]

    # 4) Find `retina_embed` for a test retina
    # Find mean of embedding of training retinas
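    # Procedure: embed a test batch from every training retina, average those embeddings
    # to get an initial retina_params, then refine it for each held-out retina with
    # ~100 gradient steps on loss_arbit_ret_params (lr = 0.001).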

    latent_dimensionality = int(FLAGS.resp_layers.split(',')[-2])
    batch_sz = 500
    ret_params_list = []
    for idataset in training_datasets:
        if idataset >= len(responses):
            continue
        print(idataset)
        rng = numpy.random.RandomState(23)
        feed_dict = sample_fcn.batch(stimuli,
                                     responses,
                                     idataset,
                                     sr_graph,
                                     batch_pos_sz=batch_sz,
                                     batch_neg_sz=batch_sz,
                                     batch_type='test',
                                     if_continuous=False,
                                     rng=rng)
        op = sess.run(sr_graph.retina_params, feed_dict=feed_dict)
        ret_params_list += [op]
    ret_params_list = np.array(ret_params_list)
    ret_params_init = np.mean(ret_params_list, 0)  # Use for initialization

    # Now optimize ret_params for each retina.
    loss_arbit_ret_params = sr_graph.loss_arbit_ret_params
    ret_params_grad = tf.gradients(loss_arbit_ret_params,
                                   sr_graph.retina_params_arbitrary)
    ret_params_dict = {}
    for idataset in testing_datasets:
        rng = numpy.random.RandomState(23)
        ret_params_new = np.copy(ret_params_init)
        lr = 0.001
        ret_log = []
        for iiter in range(100):
            feed_dict = sample_fcn.batch(stimuli,
                                         responses,
                                         idataset,
                                         sr_graph,
                                         batch_pos_sz=500,
                                         batch_neg_sz=0,
                                         batch_type='train',
                                         if_continuous=True,
                                         rng=rng)
            feed_dict.update(
                {sr_graph.retina_params_arbitrary: ret_params_new})
            delta_ret_param, l_np = sess.run(
                [ret_params_grad, loss_arbit_ret_params], feed_dict=feed_dict)
            print('Retina: %d, step: %d, loss: %.3f, Ret_params: %s' %
                  (idataset, iiter, l_np, ret_params_new))
            ret_log += [np.copy(ret_params_new)]
            ret_params_new -= lr * delta_ret_param[0]

        dataset_log = {
            'final_embedding': np.copy(ret_params_new),
            'path': np.copy(ret_log)
        }
        ret_params_dict.update({idataset: dataset_log})

    pickle.dump(ret_params_dict,
                gfile.Open((save_analysis_filename + '_test6.pkl'), 'w'))

    # Latent embedding of different retinas.
    # 3) Interpolate between latent representation and see how responses change.
    interpolation_retinas_log = [[63, 5], [73, 65], [41, 83], [38, 57],
                                 [72, 76], [48, 58], [2, 5]]  # [[72, 76], [48, 58], [2, 5]]
    batch_sz = 500
    save_dict_log = []
    for interpolation_retinas in interpolation_retinas_log:
        retina_params_end_pts = []
        cell_info = []
        valid_cell_log = []

        # 3a) Find latent representation of each retina.
        for idataset in interpolation_retinas:
            rng = numpy.random.RandomState(23)
            feed_dict = sample_fcn.batch(stimuli,
                                         responses,
                                         idataset,
                                         sr_graph,
                                         batch_pos_sz=batch_sz,
                                         batch_neg_sz=batch_sz,
                                         batch_type='test',
                                         if_continuous=False,
                                         rng=rng)
            op = sess.run([
                sr_graph.retina_params, sr_graph.stim_tf,
                sr_graph.anchor_model.embed_locations_original,
                sr_graph.anchor_model.map_cell_grid_tf
            ],
                          feed_dict=feed_dict)
            ret_params_np, stim_np, cell_locs, map_cell_grid = op
            retina_params_end_pts += [ret_params_np]
            rcct_log = []
            for icell in range(map_cell_grid.shape[2]):
                r, c = np.where(map_cell_grid[:, :, icell] > 0)
                ct = np.where(np.squeeze(cell_locs[r, c, :]) > 0)[0]
                rcct_log += [[r[0], c[0], ct[0]]]
            cell_info += [np.array(rcct_log)]
            valid_cell_log += [responses[idataset]['valid_cells']]

        # 3b) Now, interpolate and check the responses.
        fr_interpolate_log = []
        alpha_log = np.arange(0, 1.1, 0.1)
        for alpha in alpha_log:
            retina_params_interpolate = (
                alpha * retina_params_end_pts[0] +
                (1 - alpha) * retina_params_end_pts[1])
            feed_dict = {
                sr_graph.stim_tf: stim_np,
                sr_graph.retina_params_arbitrary: retina_params_interpolate
            }
            fr_interpolate = sess.run(
                sr_graph.response_pred_from_arbit_ret_params,
                feed_dict=feed_dict)
            fr_interpolate_log += [fr_interpolate]

        fr_interpolate_log = np.array(fr_interpolate_log)
        save_dict = {
            'interpolation_retinas': interpolation_retinas,
            'alpha_log': alpha_log,
            'fr_interpolate_log': fr_interpolate_log,
            'cell_info': cell_info,
            'valid_cell_log': valid_cell_log,
            'retina_params_end_pts': retina_params_end_pts
        }
        save_dict_log += [save_dict]

    pickle.dump(save_dict_log,
                gfile.Open((save_analysis_filename + '_test5.pkl'), 'w'))

    # 1) Predict responses of different retinas - training AND testing retinas.
    #from IPython import embed; embed()

    tag_list = ['training_datasets', 'testing_datasets']
    results_tr_tst = {}
    for itr_tst, tr_test_datasets in enumerate(
        [training_datasets, testing_datasets]):
        results_log = {}
        for idataset in tr_test_datasets:

            if idataset >= len(responses):
                continue
            print(idataset)
            rng = numpy.random.RandomState(23)
            feed_dict = sample_fcn.batch(stimuli,
                                         responses,
                                         idataset,
                                         sr_graph,
                                         batch_pos_sz=200,
                                         batch_neg_sz=200,
                                         batch_type='test',
                                         if_continuous=True,
                                         rng=rng)

            op = sess.run([
                sr_graph.fr_predicted,
                sr_graph.anchor_model.embed_locations_original,
                sr_graph.anchor_model.map_cell_grid_tf,
                sr_graph.anchor_model.responses_tf, sr_graph.retina_params
            ],
                          feed_dict=feed_dict)
            fr_pred_np, cell_locs, map_cell_grid, responses_np, ret_params_np = op

            # r, c, z = np.where(cell_locs > 0)
            fr_pred_cell = np.squeeze(np.zeros_like(responses_np))
            for icell in range(responses_np.shape[1]):
                r, c = np.where(map_cell_grid[:, :, icell] > 0)
                ct = np.where(np.squeeze(cell_locs[r, c, :]) > 0)[0]
                r = r[0]
                c = c[0]
                ct = ct[0]
                fr_pred_cell[:, icell] = fr_pred_np[:, r, c, ct]
                # tms = np.arange(responses_np.shape[0])
                # plt.stem(tms, responses_np[:, icell, 0])
                # plt.plot(3 * fr_pred_np[:, r, c, 0])
                # plt.plot(3 * fr_pred_np[:, r, c, 1])

            # Find the stimulus
            stim_np = feed_dict[sr_graph.stim_tf]
            t_len, dx, dy, t_depth = stim_np.shape

            stim_np_compressed = np.zeros(
                (t_len + t_depth - 1, dx, dy))  # (t_len + t_depth - 1) x dx x dy
            stim_np_compressed[:t_depth] = np.transpose(
                stim_np[0, :, :, :], [2, 0, 1])
            for itm in np.arange(1, t_len):
                stim_np_compressed[itm + t_depth -
                                   1, :, :] = stim_np[itm, :, :, 0]

            save_dict = {
                'fr_pred_cell': fr_pred_cell,
                'responses_recorded': np.squeeze(responses_np),
                'valid_cells': responses[idataset]['valid_cells'],
                'ctype_1hot': responses[idataset]['ctype_1hot'],
                'cell_locs': cell_locs,
                'map_cell_grid': map_cell_grid,
                'ret_params_np': ret_params_np,
                'fr_pred_np': fr_pred_np,
                'stimulus_key': responses[idataset]['stimulus_key']
            }
            results_tr_tst.update(
                {responses[idataset]['stimulus_key']: stim_np_compressed})

            results_log.update({idataset: save_dict})

        results_tr_tst.update({tag_list[itr_tst]: results_log})
        results_tr_tst.update({'retina_ids': retina_ids})
    pickle.dump(results_tr_tst,
                gfile.Open((save_analysis_filename + '_test3.pkl'), 'w'))

    # 2) Is the latent representation consistent for each retina, across responses?
    latent_dimensionality = int(FLAGS.resp_layers.split(',')[-2])
    batch_sz_list = [500]
    n_repeats = 1
    tag_list = ['training_datasets', 'testing_datasets']
    results_tr_tst = {}
    for itr_tst, tr_test_datasets in enumerate(
        [training_datasets, testing_datasets]):
        results_log = {}
        for idataset in tr_test_datasets:
            if idataset >= len(responses):
                continue
            print(idataset)
            rng = numpy.random.RandomState(23)
            ret_params_np = np.zeros(
                (len(batch_sz_list), n_repeats, latent_dimensionality))
            for ibatch_sz, batch_sz in enumerate(batch_sz_list):
                for iresp in range(n_repeats):
                    print(idataset, ibatch_sz, iresp)
                    feed_dict = sample_fcn.batch(stimuli,
                                                 responses,
                                                 idataset,
                                                 sr_graph,
                                                 batch_pos_sz=batch_sz,
                                                 batch_neg_sz=batch_sz,
                                                 batch_type='test',
                                                 if_continuous=False,
                                                 rng=rng)
                    op = sess.run(sr_graph.retina_params, feed_dict=feed_dict)
                    print(op)
                    ret_params_np[ibatch_sz, iresp, :] = op

            save_dict = {
                'ret_params_np': ret_params_np,
                'batch_sz_list': batch_sz_list,
                'valid_cells': responses[idataset]['valid_cells']
            }
            results_log.update({idataset: save_dict})

        results_tr_tst.update({tag_list[itr_tst]: results_log})
        results_tr_tst.update({'retina_ids': retina_ids})
    pickle.dump(results_tr_tst,
                gfile.Open((save_analysis_filename + '_test4.pkl'), 'w'))

    # 3) Embedding of EIs
    if hasattr(sr_graph, 'retina_params_from_ei'):
        latent_dimensionality = int(FLAGS.resp_layers.split(',')[-2])
        batch_sz = 500
        tag_list = ['training_datasets', 'testing_datasets']
        results_tr_tst = {}
        for itr_tst, tr_test_datasets in enumerate(
            [training_datasets, testing_datasets]):
            results_log = {}
            for idataset in tr_test_datasets:
                if idataset >= len(responses):
                    continue
                print(idataset)
                rng = numpy.random.RandomState(23)
                feed_dict = sample_fcn.batch(stimuli,
                                             responses,
                                             idataset,
                                             sr_graph,
                                             batch_pos_sz=batch_sz,
                                             batch_neg_sz=batch_sz,
                                             batch_type='test',
                                             if_continuous=False,
                                             rng=rng)
                op = sess.run(
                    [sr_graph.retina_params_from_ei, sr_graph.retina_params],
                    feed_dict=feed_dict)
                ret_params_from_ei_np, ret_params_np = op
                print(op)

                save_dict = {
                    'ret_params_np': ret_params_np,
                    'ret_params_from_ei_np': ret_params_from_ei_np,
                    'valid_cells': responses[idataset]['valid_cells']
                }
                results_log.update({idataset: save_dict})

            results_tr_tst.update({tag_list[itr_tst]: results_log})
            results_tr_tst.update({'retina_ids': retina_ids})
        pickle.dump(
            results_tr_tst,
            gfile.Open((save_analysis_filename + '_test4_ei.pkl'), 'w'))
def get_next_training_batch(iteration):
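    """Return the next (stimulus, response) training batch for the given iteration.

    Every FLAGS.n_b_in_c iterations a new chunk ('Off_par_data_<n>.mat') is loaded,
    appending further chunks until at least FLAGS.batchsz samples are available; a
    failed load retries with a re-randomized iteration index. The returned stimulus
    and response are restricted to chosen_mask pixels and cells_choose cells.
    """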
    # stimulus.astype('float32')[tms[icnt: icnt+FLAGS.batchsz], :],
    # response.astype('float32')[tms[icnt: icnt+FLAGS.batchsz], :]
    # FLAGS.batchsz

    # we will use global stimulus and response variables
    global stim_train_part
    global resp_train_part
    global chunk_order

    togo = True
    while togo:
        if (iteration % FLAGS.n_b_in_c == 0):
            # load new chunk of data
            ichunk = (iteration // FLAGS.n_b_in_c) % (
                FLAGS.train_len - 1)  # last chunk is used for testing
            if (ichunk == 0
                ):  # shuffle training chunks at start of training data
                chunk_order = np.random.permutation(np.arange(
                    FLAGS.train_len))  # remove first chunk - weird?
            #  if logfile != None :
            #    logfile.write('\nTraining chunks shuffled')

            if chunk_order[ichunk] + 1 != 1:
                filename = FLAGS.data_location + 'Off_par_data_' + str(
                    chunk_order[ichunk] + 1) + '.mat'
                file_r = gfile.Open(filename, 'r')
                data = sio.loadmat(file_r)
                stim_train_part = data['maskedMovdd_part']
                resp_train_part = data['Y_part']

                ichunk = chunk_order[ichunk] + 1
                while stim_train_part.shape[1] < FLAGS.batchsz:
                    #print('Need to add extra chunk')
                    if (ichunk > FLAGS.n_chunks):
                        ichunk = 2
                    filename = FLAGS.data_location + 'Off_par_data_' + str(
                        ichunk) + '.mat'
                    file_r = gfile.Open(filename, 'r')
                    data = sio.loadmat(file_r)
                    stim_train_part = np.append(stim_train_part,
                                                data['maskedMovdd_part'],
                                                axis=1)
                    resp_train_part = np.append(resp_train_part,
                                                data['Y_part'],
                                                axis=1)
                    #print(np.shape(stim_train_part), np.shape(resp_train_part))
                    ichunk = ichunk + 1

            #  if logfile != None:
            #    logfile.write('\nNew training data chunk loaded at: '+ str(iteration) + ' chunk #: ' + str(chunk_order[ichunk]))

        ibatch = iteration % FLAGS.n_b_in_c
        try:
            stim_train = np.array(stim_train_part[:, ibatch:ibatch +
                                                  FLAGS.batchsz],
                                  dtype='float32').T
            resp_train = np.array(resp_train_part[:, ibatch:ibatch +
                                                  FLAGS.batchsz],
                                  dtype='float32').T
            togo = False
        except Exception:
            iteration = np.random.randint(1, 100000)
            print('Load exception iteration: ' + str(iteration) + ' chunk: ' +
                  str(chunk_order[ichunk]) + ' batch: ' + str(ibatch))
            togo = True

    stim_train = stim_train[:, chosen_mask]
    resp_train = resp_train[:, cells_choose]
    return stim_train, resp_train, FLAGS.batchsz
def main(argv):
    print('\nCode started')

    global cells_choose
    global chosen_mask

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    global chunk_order
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id in ('poisson', 'logistic', 'hinge', 'poisson_relu'):
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)

    print(FLAGS.model_id)

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')
    else:
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Learn population model!
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))

            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
            #loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'poisson_relu':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            # Exponentiated rectified subunit drive, pooled by a (cf. the other model branches).
            f = tf.matmul(tf.exp(tf.nn.relu(tf.matmul(stim, w))), a) + b
            # loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. -
            #               tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            loss_inter = (tf.reduce_sum(f) / 120. -
                          tf.reduce_sum(resp * tf.log(f))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'logistic':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'hinge':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.relu(1 - 2 *
                                                  (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        # summaries
        l_summary = tf.scalar_summary('loss', loss)
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)

        # Merge all the summaries and write them out to /tmp/mnist_logs (by default)
        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        print('\nStarting new code')

        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)
            load_prev = True
        except Exception:
            print('No previous checkpoint found')

        if load_prev:
            print('\nPrevious results loaded')
        else:
            print('\nVariables initialized')
        # Do the fitting

        icnt = 0
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        #logfile.close()

        for istep in np.arange(start_iter, 400000):

            print(istep)
            # get training data
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }
            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)
                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
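

# Illustrative sketch (not part of the original file): the Poisson objective used in the
# 'poisson' / 'poisson_full' branches above, as a standalone NumPy function. The names and
# the default frame_rate=120. mirror the hard-coded values in the graph.
import numpy as np


def poisson_nll(lam, resp, data_len, frame_rate=120.):
    """Poisson negative log-likelihood (up to constants): sum(lam)/rate - sum(resp*log(lam))."""
    return (np.sum(lam) / frame_rate - np.sum(resp * np.log(lam))) / data_len
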
Beispiel #21
0
def main(argv):
    logfile = gfile.Open(
        FLAGS.save_location + 'log_bias=' + str(FLAGS.bias_ratio) + '_lam_W=' +
        str(FLAGS.lam_W) + '_lam_a=' + str(FLAGS.lam_a) + '.txt', "w")
    logfile.write('Starting new thread\n')
    logfile.flush()
    print('\nlog file written once')

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)

    #plt.ion()
    ## Load data
    file = h5py.File(FLAGS.data_location + 'Off_parasol.mat', 'r')
    logfile.write('\ndataset loaded')
    # Load Masked movie
    data = file.get('maskedMovdd')
    maskedMov = np.array(data)
    cells = file.get('cells')
    nCells = cells.shape[0]
    ttf_log = file.get('ttf_log')
    ttf_avg = file.get('ttf_avg')
    stimulus = maskedMov
    total_mask_log = file.get('totalMaskAccept_log')

    # Load spike Response of cells
    data = file.get('Y')
    biSpkResp_coll = np.array(data, dtype='float32')
    mask = np.array(np.ones(3200), dtype=bool)

    ##
    Nsub = FLAGS.ratio_SU * nCells
    StimDim = maskedMov.shape[1]

    # initialize subunits
    W_init = initialize_SU(nSU=Nsub)
    a_init = np.random.rand(Nsub, nCells)

    su_act = stimulus[10000:12000, :].dot(W_init)
    su_std = np.sqrt(np.diag(su_act.T.dot(su_act)) / stimulus.shape[0])

    bias_init = FLAGS.bias_ratio * su_std
    print(bias_init)
    logfile.write('bias = ' + str(bias_init))
    logfile.write('\nSU initialized')

    with tf.Session() as sess:
        stim = tf.placeholder(tf.float32, shape=[None, StimDim], name="stim")
        resp = tf.placeholder(tf.float32, name="resp")
        data_len = tf.placeholder(tf.float32, name="data_len")

        #W = tf.Variable(tf.random_uniform([StimDim,Nsub]))
        W = tf.Variable(np.array(W_init, dtype='float32'))
        a = tf.Variable(np.array(a_init, dtype='float32'))
        bias = tf.Variable(np.array(bias_init, dtype='float32'))
        #a = tf.Variable(np.identity(Nsub,dtype='float32')*0.01)
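        # Subunit LN model: firing rate = relu(stim . W + bias) pooled through
        # non-negative weights relu(a); the small constant keeps tf.log(lam) finite.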
        lam = tf.matmul(tf.nn.relu(tf.matmul(stim, W) + bias),
                        tf.nn.relu(a)) + 0.0001  # collapse a dimension
        loss = (tf.reduce_sum(lam) / 120. - tf.reduce_sum(
            resp * tf.log(lam))) / data_len + FLAGS.lam_W * tf.reduce_sum(
                tf.abs(W)) + FLAGS.lam_a * tf.reduce_sum(tf.abs(a))
        train_step = tf.train.AdamOptimizer(1e-4).minimize(
            loss, var_list=[W, a, bias])

        sess.run(tf.initialize_all_variables())

        # Do the fitting
        batchSz = 100
        icnt = 0

        fd_test = {
            stim: stimulus.astype('float32')[216000 - 10000:216000 - 1, :],
            resp:
            biSpkResp_coll.astype('float32')[216000 - 10000:216000 - 1, :],
            data_len: 10000
        }
        ls_train_log = np.array([])
        ls_test_log = np.array([])
        tms = np.random.permutation(np.arange(216000 - 10000))  # keep the last 10000 samples for testing
        for istep in range(100000):
            time_start = timeit.timeit()
            fd_train = {
                stim: stimulus.astype('float32')[tms[icnt:icnt + batchSz], :],
                resp:
                biSpkResp_coll.astype('float32')[tms[icnt:icnt + batchSz], :],
                data_len: batchSz
            }
            sess.run(train_step, feed_dict=fd_train)
            if istep % 10 == 0:
                ls_train = sess.run(loss, feed_dict=fd_train)
                ls_test = sess.run(loss, feed_dict=fd_test)
                ls_train_log = np.append(ls_train_log, ls_train)
                ls_test_log = np.append(ls_test_log, ls_test)
                logfile.write('\nIterations: ' + str(istep) +
                              ' Training error: ' + str(ls_train) +
                              ' Testing error: ' + str(ls_test))
                logfile.flush()
                sio.savemat(
                    FLAGS.save_location + 'data_bias=' +
                    str(FLAGS.bias_ratio) + '_lam_W=' + str(FLAGS.lam_W) +
                    '_lam_a' + str(FLAGS.lam_a) + '_ratioSU' +
                    str(FLAGS.ratio_SU) + '_grid_spacing_' +
                    str(FLAGS.su_grid_spacing) + '.mat', {
                        'bias_ratio': FLAGS.bias_ratio,
                        'bias_init': bias_init,
                        'bias': bias.eval(),
                        'W': W.eval(),
                        'a': a.eval(),
                        'W_init': W_init,
                        'a_init': a_init,
                        'ls_train_log': ls_train_log,
                        'ls_test_log': ls_test_log
                    })

            icnt = icnt + batchSz
            if icnt > 216000 - 10000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 10000))

    logfile.close()
def learn_lin_embedding(stimulus, responses, filename,
                        lam_l1=0.01, beta=10, time_window=30, lr=0.01):
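  """Learn a linear stimulus-response embedding (per-cell spatial filter A, shared ttf).

  A is initialized from a delayed stimulus-response cross-correlation; A and ttf are then
  trained with the same softmax-over-negatives loss as above, with an L1 proximal
  (soft-thresholding) step of weight lam_l1 keeping A sparse.
  """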

  num_cell_types = 2
  dimx = 80
  dimy = 40

  leng = np.minimum(stimulus.shape[0], responses.shape[0])
  resp_short = responses[:leng, :]
  stim_short = np.reshape(stimulus[:leng, :, :], [leng, -1])
  init_A = stim_short[:-4, :].T.dot(resp_short[4:, :])
  init_A_3d = np.reshape(init_A.T, [-1, stimulus.shape[1], stimulus.shape[2]])

  n_cells = responses.shape[1]
  with tf.Session() as sess:

    stim_tf = tf.placeholder(tf.float32,
                             shape=[None, dimx, dimy, time_window])  # batch x X x Y x time_window

    # Linear filter in time
    ttf = tf.Variable(np.ones((time_window, 1)).astype(np.float32))
    stim_tf_2d = tf.reshape(stim_tf, [-1, time_window])
    stim_filtered_2d = tf.matmul(stim_tf_2d, ttf)
    stim_filtered = tf.reshape(stim_filtered_2d, [-1, dimx, dimy, 1])
    stim_filtered = tf.gather(tf.transpose(stim_filtered, [3, 0, 1, 2]), 0)

    # filter in space
    # A = tf.Variable(np.ones((n_cells, dimx, dimy)).astype(np.float32))
    A = tf.Variable(init_A_3d.astype(np.float32))
    A_2d = tf.reshape(A, [n_cells, dimx*dimy])
    
    stim_filtered_perm_2d = tf.reshape(stim_filtered, [-1, dimx*dimy])
    stim_space_filtered_2d = tf.matmul(stim_filtered_perm_2d, tf.transpose(A_2d))
    stim_out = tf.expand_dims(stim_space_filtered_2d, 2)


    responses_anchor_tf = tf.placeholder(dtype=tf.float32,
                                         shape=[None, n_cells, 1],
                                         name='anchor')  # batch x n_cells x 1

    responses_neg_tf = tf.placeholder(dtype=tf.float32,
                                      shape=[None, n_cells, 1],
                                      name='neg')  # batch x n_cells x 1

    from IPython import embed; embed()
    d_s_r_pos = - tf.reduce_sum((stim_out*responses_anchor_tf)**2, [1, 2]) # batch
    d_pairwise_s_rneg = - tf.reduce_sum((tf.expand_dims(stim_out, 1) *
                               tf.expand_dims(responses_neg_tf, 0))**2, [2, 3]) # batch x batch_neg


    difference = (tf.expand_dims(d_s_r_pos/beta, 1) - d_pairwise_s_rneg/beta)  # positives x negatives

    # # log(1 + \sum_j(exp(d+ - dj-)))
    difference_padded = tf.pad(difference, [[0, 0], [0, 1]])
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(difference_padded, 1), 0)

    accuracy_tf =  tf.reduce_mean(tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg))

    lr = 0.01
    train_op = tf.train.AdagradOptimizer(lr).minimize(loss)

    with tf.control_dependencies([train_op]):
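      # Proximal step for the L1 penalty: relu(A - lam_l1) - relu(-A - lam_l1)
      # equals sign(A) * max(|A| - lam_l1, 0), i.e. soft-thresholding.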
      prox_op = tf.assign(A, tf.nn.relu(A - lam_l1) - tf.nn.relu(- A - lam_l1))

    update = tf.group(train_op, prox_op)


    # Now train
    sess.run(tf.global_variables_initializer())
    a_log = []
    for iiter in range(200000):
      stim_batch, resp_batch, resp_batch_neg = sample_datasets.get_batch(stimulus, responses,
                                                         batch_size=100, batch_neg_resp=100,
                                                         stim_history=time_window,
                                                         min_window=10,
                                                         batch_type='train')


      feed_dict = {stim_tf: stim_batch.astype(np.float32),
                   responses_anchor_tf: np.expand_dims(resp_batch, 2).astype(np.float32),
                   responses_neg_tf: np.expand_dims(resp_batch_neg, 2).astype(np.float32)}


      _, l, a = sess.run([update, loss, accuracy_tf], feed_dict = feed_dict)
      a_log += [a]

      if iiter % 10 == 0:
        print(a)
      if iiter % 1000 == 0:
        save_dict = {'A': sess.run(A), 'ttf': sess.run(ttf)}
        pickle.dump(save_dict, gfile.Open(os.path.join(FLAGS.save_folder, filename), 'w'))

    return [sess.run(A), sess.run(ttf)]
Beispiel #23
0
def main(unused_argv=()):

  # set random seed
  np.random.seed(121)
  print('random seed reset')

  # Get details of stored model.
  model_savepath, model_filename = config.get_filepaths()

  # Load responses to two trials of long white noise.
  data_wn = du.DataUtilsMetric(os.path.join(FLAGS.data_path, FLAGS.data_test))

  # Quadratic score function.
  with tf.Session() as sess:

    # Define and restore/initialize the model.
    tf.logging.info('Model : %s ' % FLAGS.model)
    met = config.get_model(sess, model_savepath, model_filename,
                           data_wn, True)
    print('IS_TRAINING = TRUE!!! ')
    tf.logging.info('IS_TRAINING = TRUE!!! ')

    PrintModelAnalysis(tf.get_default_graph())

    # get triplets
    # triplet A
    outputs = data_wn.get_triplets(batch_size=FLAGS.batch_size_test,
                                   time_window=FLAGS.time_window)
    anchor_test, pos_test, neg_test, _, _, _ = outputs
    triplet_a = (anchor_test, pos_test, neg_test)

    # triplet B
    outputs = data_wn.get_tripletsB(batch_size=FLAGS.batch_size_test,
                                    time_window=FLAGS.time_window)
    anchor_test, pos_test, neg_test, _, _, _ = outputs
    triplet_b = (anchor_test, pos_test, neg_test)

    triplets = [triplet_a, triplet_b]
    triplet_labels = ['triplet A', 'triplet B']

    analysis_results = {}  # collect analysis results in a dictionary

    # 1. Plot distances between positive and negative pairs.
    # analyse.plot_pos_neg_distances(met, anchor_test, pos_test, neg_test)
    # tf.logging.info('Distances plotted')

    # 2. Accuracy of triplet orderings - fraction of triplets where
    # distance with positive is smaller than distance with negative.
    triplet_dict = {}
    for iitriplet, itriplet in enumerate(triplets):
      dist_pos, dist_neg, accuracy = analyse.compute_distances(met, *itriplet)
      dist_analysis = {'pos': dist_pos,
                       'neg': dist_neg,
                       'accuracy': accuracy}
      triplet_dict.update({triplet_labels[iitriplet]: dist_analysis})
    analysis_results.update({'distances': triplet_dict})
    tf.logging.info('Accuracy computed')

    # 3. Precision-Recall analysis : declare positive if s(x,y)<t and
    # negative otherwise. Vary threshold t, and plot precision-recall and
    # ROC curves.
    triplet_dict = {}
    for iitriplet, itriplet in enumerate(triplets):
      output = analyse.precision_recall(met, *itriplet, toplot=False)
      precision_log, recall_log, f1_log, fpr_log, tpr_log, pr_data = output
      pr = {'precision': precision_log, 'recall': recall_log,
            'pr_data': pr_data}
      roc = {'TPR': tpr_log, 'FPR': fpr_log}
      pr_results = {'PR': pr, 'F1': f1_log, 'ROC': roc}
      triplet_dict.update({triplet_labels[iitriplet]: pr_results})
    analysis_results.update({'PR_analysis': triplet_dict})
    tf.logging.info('Precision Recall, F1 score and ROC curves computed')

    # 4. Clustering analysis: How well clustered are responses for a stimulus?
    # Get all trials for a few (1000) stimuli and compute
    # distances between all pairs of points.
    # See how many of responses generated by same stimulus are actually
    # near to each other.
    n_tests = 10
    p_log = []
    r_log = []
    s_log = []
    resp_log = []
    dist_log = []
    embedding_log = []
    for itest in range(n_tests):

      n_stims = 10  # previously 100
      tf.logging.info('Number of random samples is : %d' % n_stims)
      resp_fcn = data_wn.get_response_all_trials
      resp_all_trials, stim_id = resp_fcn(n_stims, FLAGS.time_window,
                                          random_seed=itest)

      #  TODO(bhaishahster) : Remove duplicates from resp_all_trials
      distance_pairs = analyse.get_pairwise_distances(met, resp_all_trials)
      k_log = [1, 2, 3, 4, 5, 10, 15, 20, 50, 75, 100, 200, 300, 400, 500]
      precision_log = []
      recall_log = []
      for k in k_log:
        precision, recall = analyse.topK_retrieval(distance_pairs, k, stim_id)
        precision_log += [precision]
        recall_log += [recall]

      p_log += [precision_log]
      r_log += [recall_log]
      s_log += [stim_id]
      resp_log += [resp_all_trials]
      dist_log += [distance_pairs]


      #tf.logging.info('Getting 2D t-SNE embedding')
      #model = manifold.TSNE(n_components=2)
      #tSNE_embedding = model.fit_transform(distance_pairs)
      #embedding_log += [tSNE_embedding]

    all_trials = {'distances': dist_log, 'K': k_log,
                  'precision': p_log,
                  'recall': r_log,
                  'probe_stim_idx': s_log, 'probes': resp_log,
                  'embedding': embedding_log}
    analysis_results_clustering = {'all_trials': all_trials}
    pickle_file_clustering = (os.path.join(model_savepath, model_filename)
                              + '_' + FLAGS.data_test +
                              '_analysis_clustering.pkl')
    pickle.dump(analysis_results_clustering, gfile.Open(pickle_file_clustering, 'w'))

    tf.logging.info('Clustering analysis done.')

    '''
    # sample few/all repeats of stimuli which are continous.
    repeats = data_wn.get_repeats()
    n_samples_max = 10
    samples = np.random.randint(0, repeats.shape[0],
                                np.minimum(n_samples_max, repeats.shape[0]))

    n_start_times = 5
    time_window = 15
    resps_cont = np.zeros((n_start_times, n_samples_max,
                           time_window, repeats.shape[-1]))
    from IPython import embed; embed()

    for istart in range(n_start_times):
      start_tm = np.random.randint(repeats.shape[1] - time_window)
      resps_cont[istart, :, :, :] = repeats[samples, start_tm:
                                            start_tm+time_window, :]
    resps_cont_2d = np.reshape(resps_cont, [-1, resps_cont.shape[-1]])
    resps_cont_3d = np.expand_dims(resps_cont_2d, 2)
    distances_cont_resp = analyse.get_pairwise_distances(met, resps_cont_3d)

    n_components = 2
    model = manifold.TSNE(n_components=n_components)
    ts = model.fit_transform(distances_cont_resp)
    tts = np.reshape(ts, [n_start_times, n_samples_max,
                          time_window, n_components])

    from IPython import embed; embed()

    plt.figure()
    for istart in [1]:  # range(n_start_times):
      for isample in range(n_samples_max):
        pts = tts[istart, isample, :, :]
        plt.plot(pts[:, 0], pts[:, 1])

    plt.show()
    '''

    # 5. Store the parameters of the score function.
    score_params = met.get_parameters()
    analysis_results.update({'score_params': score_params})
    tf.logging.info('Got interesting parameters of score')

    # 6. Retrieval analysis on training data.
    # Retrieve the nearest responses in training data for a probe test response.

    # Load training data.
#     data_wn_train = du.DataUtilsMetric(os.path.join(FLAGS.data_path,
#                                                     FLAGS.data_train))
#
#     out_data = data_wn_train.get_all_responses(FLAGS.time_window)
#     train_all_resp, train_stim_time = out_data
#
#     # Get a few test stimuli. Here we use all repeats of a few stimuli.
#     n_stims = 100
#     resp_all_trials, stim_id = data_wn.get_response_all_trials(n_stims,
#                                                                FLAGS.time_window)
#     k = 1000
#     retrieved, retrieved_stim = analyse.topK_retrieval_probes(train_all_resp,
#                                                               train_stim_time,
#                                                               resp_all_trials,
#                                                               k, met)
#     retrieval_dict = {'probe': resp_all_trials, 'probe_stim_idx': stim_id,
#                       'retrieved': retrieved,
#                       'retrieved_stim_idx': retrieved_stim}
#     analysis_results.update({'retrieval': retrieval_dict})
#     tf.logging.info('Retrieved nearest points in training data'
#                     ' for some probes in test data')

    # TODO(bhaishahster) : Decode stimulus using retrieved responses.


    # 7. Learn encoding model.
    # Learn mapping from stimulus to response.

    # from IPython import embed; embed()
    '''
    data_wn_train = du.DataUtilsMetric(os.path.join(FLAGS.data_path,
                                                    'example_long_wn_2rep_'
                                                    'ON_OFF_with_stim.mat'))

    data_wn_test = du.DataUtilsMetric(os.path.join(FLAGS.data_path,
                                                    'example_wn_30reps_ON_'
                                                    'OFF_with_stimulus.mat'))
    stimulus_test = data_wn_test.get_stimulus()
    response_test = data_wn_test.get_repeats()

    stimulus = data_wn_train.get_stimulus()
    response = data_wn_train.get_repeats()
    ttf = data_wn_train.ttf[::-1]
    encoding_fcn = encoding_model.learn_encoding_model_ln

     # Initialize ttf, RF using ttf and scale ttf to match firing rate
    RF_np, ttf_np, model = encoding_fcn(sess, met, stimulus, response, ttf_in=ttf,
                                 lr=0.1)
    firing_rate_pred = sess.run(model.firing_rate,
                                feed_dict={model.stimulus: stimulus_test})

    initialize_all = {'RF': RF_np, 'ttf': ttf,
                      'firing_rate_test': firing_rate_pred}

    # Initialize ttf and do no other initializations
    RF_np_noinit, ttf_np_noinit, model = encoding_fcn(sess,met, stimulus, response,
                                               ttf_in=ttf,
                                               initialize_RF_using_ttf=False,
                                               scale_ttf=False, lr=0.1)
    firing_rate_pred = sess.run(model.firing_rate,
                                feed_dict={model.stimulus: stimulus_test})
    initialize_only_ttf = {'RF': RF_np_noinit, 'ttf': ttf_np_noinit,
                           'firing_rate_test': firing_rate_pred}

    # No initialization (neither ttf nor RF)
    RF_np_noinit2, ttf_np_noinit2, model = encoding_fcn(sess, met, stimulus, response,
                                                 ttf_in=None,
                                                 initialize_RF_using_ttf=False,
                                                 scale_ttf=False, lr=0.1)
    firing_rate_pred = sess.run(model.firing_rate,
                                feed_dict={model.stimulus: stimulus_test})
    initialize_none = {'RF': RF_np_noinit2, 'ttf': ttf_np_noinit2,
                       'firing_rate_test': firing_rate_pred}

    encoding_models = {'Init_all': initialize_all,
                       'Init_ttf': initialize_only_ttf,
                       'Init_none': initialize_none,
                       'responses_test': response_test}

    analysis_results.update({'Encoding_models': encoding_models})
    '''


    # 8. Is similarity in images implicitly learnt in the metric?
    # Reconstruction done in colab notebook

    '''
    class StimulusMetric(object):
      """Compute MSE between stimuli."""

      def get_distance(self, in1, in2):
        return np.sqrt(np.sum(np.sum((in1 - in2)**2, 2), 1))

    # TODO(bhaishahster) : Filtering in time still needs to be done.
    stimuli_met = StimulusMetric()

    stim_distance, resp_distance, times, responses = analyse.compare_stimulus_score_similarity(data_wn, stimuli_met,
                                                                            met)
    compare_stim_mse_resp_met = {'stimulus_mse': stim_distance,
                                 'response_metric': resp_distance,
                                 'times': times,
                                 'response_pairs': responses}
    analysis_results.update({'perception': compare_stim_mse_resp_met})
    '''

    # 9. Retrieve nearest responses from ALL possible response patterns
    # Retrieve the nearest responses in training data for a probe test response.

    '''
    import itertools
    lst = list(map(list, itertools.product([0, 1], repeat=data_wn.n_cells)))
    all_resp = np.array(lst)
    all_resp = np.expand_dims(all_resp, 2)

    # Get a few test stimuli. Here we use all repeats of a few stimuli.
    n_stims = 100
    probe_responses, stim_id = data_wn.get_response_all_trials(n_stims,
                                                               FLAGS.time_window)
    distances_corpus = analyse.compute_all_distances(all_resp, probe_responses,
                                                              met)
    retrieval_dict = {'probe': probe_responses, 'probe_stim_idx': stim_id,
                      'corpus': all_resp,
                      'distance_corpus': distances_corpus}

    analysis_results.update({'retrieval_ALL_responses': retrieval_dict})
    tf.logging.info('Distance of probe to ALL possible response patterns')
    '''
    # 10. Get embedding for all possible responses,
    #       only if there are less than 15 cells
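    # (There are 2**n_cells binary response patterns, so enumerating all of them is
    #  only feasible for a small number of cells.)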

    if data_wn.n_cells < 15:
      import itertools
      lst = list(map(list, itertools.product([0, 1], repeat=data_wn.n_cells)))
      all_resp = np.expand_dims(np.array(lst), 2)  # use time_window of 1.
      all_resp_embedding = met.get_embedding(all_resp)
      analysis_results.update({'all_resp_embedding': all_resp_embedding})

    
    # save analysis in a pickle file
    # from IPython import embed; embed()
    pickle_file = (os.path.join(model_savepath, model_filename) + '_' +
                   FLAGS.data_test +
                   '_analysis.pkl')
    pickle.dump(analysis_results, gfile.Open(pickle_file, 'w'))
    # pickle.dump(analysis_results, file_io.FileIO(pickle_file, 'w'))
    tf.logging.info('File: ' + pickle_file)
    tf.logging.info('Analysis results saved')
    print('File: ' + pickle_file)
def main(argv):
    #plt.ion() # interactive plotting
    window = FLAGS.window
    n_pix = (2 * window + 1)**2
    dimx = np.floor(1 + ((40 - (2 * window + 1)) / FLAGS.stride)).astype('int')
    dimy = np.floor(1 + ((80 - (2 * window + 1)) / FLAGS.stride)).astype('int')
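    # n_pix = pixels in each (2*window+1) x (2*window+1) subunit patch; dimx x dimy is
    # the number of patch positions tiling the 40 x 80 stimulus at the given stride.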
    nCells = 107
    # load model
    # load filename
    print(FLAGS.model_id)
    with tf.Session() as sess:
        if FLAGS.model_id == 'relu':
            # lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0
            short_filename = ('data_model=' + str(FLAGS.model_id) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                              '_ratioSU=' + str(FLAGS.ratio_SU) +
                              '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                              '_normalized_bg')
            w = tf.Variable(
                np.array(np.random.randn(3200, 749), dtype='float32'))
            a = tf.Variable(
                np.array(np.random.randn(749, 107), dtype='float32'))

        if FLAGS.model_id == 'relu_window':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))  # exp 5
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')

            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_sfm':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_sfm_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.01 + 0.005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.02 + np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_a_support':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.002 * np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        if FLAGS.model_id == 'exp_window_a_support':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.002 * np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
        FLAGS.save_location = parent_folder + short_filename + '/'

        # get relevant files
        file_list = gfile.ListDirectory(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename
        print('\nLoading: ', save_filename)
        bin_files = []
        meta_files = []
        for file_n in file_list:
            if re.search(short_filename + '.', file_n):
                if re.search('.meta', file_n):
                    meta_files += [file_n]
                else:
                    bin_files += [file_n]
        #print(bin_files)
        print(len(meta_files), len(bin_files), len(file_list))

        # get iteration numbers
        iterations = np.array([])
        for file_name in bin_files:
            try:
                iterations = np.append(
                    iterations, int(file_name.split('/')[-1].split('-')[-1]))
            except ValueError:
                print('Could not parse iteration number from: ' + file_name)
        iterations.sort()
        print(iterations)

        iter_plot = iterations[-1]
        print(int(iter_plot))

        # load tensorflow variables
        saver_var = tf.train.Saver(tf.all_variables())

        restore_file = save_filename + '-' + str(int(iter_plot))
        saver_var.restore(sess, restore_file)

        # plot subunit - cell connections
        plt.figure()
        plt.cla()
        plt.imshow(a.eval(), cmap='gray', interpolation='nearest')
        print(np.shape(a.eval()))
        plt.title('Iteration: ' + str(int(iter_plot)))
        plt.show()
        plt.draw()

        # plot all subunits on 40x80 grid
        try:
            wts = w.eval()
            for isu in range(100):
                fig = plt.subplot(10, 10, isu + 1)
                plt.imshow(np.reshape(wts[:, isu], [40, 80]),
                           interpolation='nearest',
                           cmap='gray')
            plt.title('Iteration: ' + str(int(iter_plot)))
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
        except:
            print('w full does not exist? ')

        # plot a few subunits - wmother + wdel
        try:
            wts = w.eval()
            print('wts shape:', np.shape(wts))
            icnt = 1
            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    fig = plt.subplot(dimx, dimy, icnt)
                    plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]),
                                          (2 * window + 1, 2 * window + 1)),
                               interpolation='nearest',
                               cmap='gray')
                    icnt = icnt + 1
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)
            plt.show()
            plt.draw()
        except:
            print('w does not exist?')

        # plot wmother
        try:
            w_mot = np.squeeze(w_mother.eval())
            print(w_mot)
            plt.imshow(w_mot, interpolation='nearest', cmap='gray')
            plt.title('Mother subunit')
            plt.show()
            plt.draw()
        except:
            print('w mother does not exist')

        # plot wmother + wdel
        try:
            w_mot = np.squeeze(w_mother.eval())
            w_del = np.squeeze(w_del.eval())
            wts = np.array(np.random.randn(dimx, dimy, (2 * window + 1)**2))
            for idimx in np.arange(dimx):
                print(idimx)
                for idimy in np.arange(dimy):
                    wts[idimx,
                        idimy, :] = np.ndarray.flatten(w_mot) + w_del[idimx,
                                                                      idimy, :]
        except:
            print('w mother + w delta do not exist? ')
        '''
        try:
            icnt = 1
            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    fig = plt.subplot(dimx, dimy, icnt)
                    plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]),
                                          (2 * window + 1, 2 * window + 1)),
                               interpolation='nearest',
                               cmap='gray')
                    icnt = icnt + 1
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)
        except:
            print('w mother + w delta plotting error?')

        # plot wdel
        try:
            w_del = np.squeeze(w_del.eval())
            icnt = 1
            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    fig = plt.subplot(dimx, dimy, icnt)
                    plt.imshow(np.reshape(w_del[idimx, idimy, :],
                                          (2 * window + 1, 2 * window + 1)),
                               interpolation='nearest',
                               cmap='gray')
                    icnt = icnt + 1
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)
        except:
            print('w delta does not exist?')
        plt.suptitle('Iteration: ' + str(int(iter_plot)))
        plt.show()
        plt.draw()
        '''
        # select a cell, and show its subunits.
        #try:

        ## Load data summary, get mask
        filename = FLAGS.data_location + 'data_details.mat'
        summary_file = gfile.Open(filename, 'r')
        data_summary = sio.loadmat(summary_file)
        total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
        stas = data_summary['stas']
        print(np.shape(total_mask))

        # a is 2D

        a_eval = a.eval()
        print(np.shape(a_eval))
        # get softmax numpy
        if FLAGS.model_id == 'relu_window_mother_sfm' or FLAGS.model_id == 'relu_window_mother_sfm_exp':
            b = np.exp(a_eval) / np.sum(np.exp(a_eval), 0)
        else:
            b = a_eval
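        # Note: the softmax above is the plain exp/sum form; a numerically
        # more stable equivalent subtracts the per-column max first, e.g.
        #   a_shift = a_eval - a_eval.max(0, keepdims=True)
        #   b = np.exp(a_shift) / np.sum(np.exp(a_shift), 0)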

        plt.figure()
        plt.imshow(b, interpolation='nearest', cmap='gray')
        plt.show()
        plt.draw()

        # plot subunits for multiple cells.
        n_cells = 10
        n_plots_max = 20
        plt.figure()
        for icell_cnt, icell in enumerate(np.arange(n_cells)):
            mask2D = np.reshape(total_mask[icell, :], [40, 80])
            nz_idx = np.nonzero(mask2D)
            np.shape(nz_idx)
            print(nz_idx)
            ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
            xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])

            icnt = -1
            a_thr = np.percentile(np.abs(b[:, icell]), 99.5)
            n_plots = np.sum(np.abs(b[:, icell]) > a_thr)
            nx = np.ceil(np.sqrt(n_plots)).astype('int')
            ny = np.ceil(np.sqrt(n_plots)).astype('int')
            ifig = 0
            ww_sum = np.zeros((40, 80))

            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    icnt = icnt + 1
                    if (np.abs(b[icnt, icell]) > a_thr):
                        ifig = ifig + 1
                        fig = plt.subplot(n_cells, n_plots_max,
                                          icell_cnt * n_plots_max + ifig + 2)
                        ww = np.zeros((40, 80))
                        ww[idimx * FLAGS.stride:idimx * FLAGS.stride +
                           (2 * window + 1),
                           idimy * FLAGS.stride:idimy * FLAGS.stride +
                           (2 * window + 1)] = b[icnt, icell] * (np.reshape(
                               wts[idimx, idimy, :],
                               (2 * window + 1, 2 * window + 1)))
                        plt.imshow(ww, interpolation='nearest', cmap='gray')
                        plt.ylim(ylim)
                        plt.xlim(xlim)
                        plt.title(b[icnt, icell])
                        fig.axes.get_xaxis().set_visible(False)
                        fig.axes.get_yaxis().set_visible(False)

                        ww_sum = ww_sum + ww

            fig = plt.subplot(n_cells, n_plots_max,
                              icell_cnt * n_plots_max + 2)
            plt.imshow(ww_sum, interpolation='nearest', cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('STA from model')

            fig = plt.subplot(n_cells, n_plots_max,
                              icell_cnt * n_plots_max + 1)
            plt.imshow(np.reshape(stas[:, icell], [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('True STA')

        plt.show()
        plt.draw()

        #except:
        #  print('a not 2D?')

        # Using xlim and ylim, plot the relevant 'windows' together with their weights.
        sq_flat = np.zeros((dimx, dimy), dtype='int')
        icnt = 0
        for idimx in np.arange(dimx):
            for idimy in np.arange(dimy):
                sq_flat[idimx, idimy] = icnt
                icnt = icnt + 1

        n_cells = 1
        n_plots_max = 10
        plt.figure()
        for icell_cnt, icell in enumerate(np.array(
            [1, 2, 3, 4, 5])):  #enumerate(np.arange(n_cells)):
            a_thr = np.percentile(np.abs(b[:, icell]), 99.5)
            mask2D = np.reshape(total_mask[icell, :], [40, 80])
            nz_idx = np.nonzero(mask2D)
            np.shape(nz_idx)
            print(nz_idx)
            ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
            xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])
            print(xlim, ylim)

            win_startx = np.ceil(
                (xlim[0] - (2 * window + 1)) / FLAGS.stride).astype('int')
            win_endx = np.floor((xlim[1] - 1) / FLAGS.stride).astype('int')
            win_starty = np.ceil(
                (ylim[0] - (2 * window + 1)) / FLAGS.stride).astype('int')
            win_endy = np.floor((ylim[1] - 1) / FLAGS.stride).astype('int')
            dimx_plot = win_endx - win_startx + 1
            dimy_plot = win_endy - win_starty + 1
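            # A subunit window at grid index i covers pixels
            # [i * stride, i * stride + 2 * window]; the ceil/floor above pick
            # out the range of grid indices whose windows overlap the cell's
            # RF bounding box (xlim, ylim).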
            ww_sum = np.zeros((40, 80))
            for irow, idimy in enumerate(np.arange(win_startx, win_endx + 1)):
                for icol, idimx in enumerate(
                        np.arange(win_starty, win_endy + 1)):
                    fig = plt.subplot(dimx_plot + 1, dimy_plot,
                                      (irow + 1) * dimy_plot + icol + 1)
                    ww = np.zeros((40, 80))
                    ww[idimx * FLAGS.stride:idimx * FLAGS.stride +
                       (2 * window + 1),
                       idimy * FLAGS.stride:idimy * FLAGS.stride +
                       (2 * window + 1)] = (np.reshape(
                           wts[idimx,
                               idimy, :], (2 * window + 1, 2 * window + 1)))
                    plt.imshow(ww, interpolation='nearest', cmap='gray')
                    plt.ylim(ylim)
                    plt.xlim(xlim)
                    if b[sq_flat[idimx, idimy], icell] > a_thr:
                        plt.title(b[sq_flat[idimx, idimy], icell],
                                  fontsize=10,
                                  color='g')
                    else:
                        plt.title(b[sq_flat[idimx, idimy], icell],
                                  fontsize=10,
                                  color='r')
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)

                    ww_sum = ww_sum + ww * b[sq_flat[idimx, idimy], icell]

            fig = plt.subplot(dimx_plot + 1, dimy_plot, 2)
            plt.imshow(ww_sum, interpolation='nearest', cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('STA from model')

            fig = plt.subplot(dimx_plot + 1, dimy_plot, 1)
            plt.imshow(np.reshape(stas[:, icell], [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('True STA')

            plt.show()
            plt.draw()
Beispiel #25
0
    def __init__(self,
                 data_location,
                 batch_sz_in=1000,
                 masked_stimulus=False,
                 chosen_cells=None,
                 total_samples=216000,
                 storage_chunk_sz=1000,
                 data_chunk_prefix='Off_par_data_',
                 stimulus_dimension=3200,
                 test_length=20000,
                 sr_key={
                     'stimulus': 'maskedMovdd_part',
                     'response': 'Y_part'
                 }):

        # Loads the dataset and sets up bookkeeping variables.
        # data_location: where the data is stored after splitting into chunks;
        #   each chunk has prefix 'data_chunk_prefix' and holds
        #   'storage_chunk_sz' samples out of 'total_samples'.
        # masked_stimulus: if True, clip the stimulus to the pixels relevant
        #   for the selected cells.
        # chosen_cells: which cells to include in 'response'; if None, include
        #   all cells.
        # sr_key: which variables in the stored data hold the stimulus and
        #   response matrices; the stored matrices should have shapes
        #   stimulus (stimulus_dimension x time) and response (n_cells x time).
        # batch_sz_in: number of examples in each batch of training data.
        # test_length: number of examples in the test data.
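        # Hypothetical usage sketch (the class name is not shown in this
        # snippet, so 'ResponseDataset' below is only a placeholder for it):
        #   data = ResponseDataset('/path/to/data/', batch_sz_in=1000,
        #                          masked_stimulus=False, chosen_cells=None)
        #   # data.stimulus: (total_samples, stim_dim)
        #   # data.response: (total_samples, n_cells)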

        if chosen_cells is None:
            all_cells = True
        else:
            all_cells = False

        self.batch_sz = batch_sz_in

        print('Initializing datasets')
        # Load summary of datasets
        data_filename = data_location + 'data_details.mat'
        tf.logging.info('Loading summary file: ' + data_filename)
        summary_file = gfile.Open(data_filename, 'r')

        data_summary = sio.loadmat(summary_file)
        #data_summary = sio.loadmat(data_filename)

        cells = np.squeeze(data_summary['cells'])
        print('Dataset details loaded')

        # choose cells to build model for
        if all_cells:
            self.cells_choose = np.array(np.ones(np.shape(cells)),
                                         dtype='bool')
        else:
            cells_choose = np.zeros(len(cells))
            for icell in chosen_cells:
                cells_choose += np.array(cells == icell).astype(np.int)
            self.cells_choose = cells_choose > 0
        n_cells = np.sum(self.cells_choose)  # number of cells
        self.stas = np.array(data_summary['stas'])
        self.stas = self.stas[:, self.cells_choose]
        print('Cells selected: %d' % n_cells)

        # count total number of spikes for chosen cells
        tot_spks = np.squeeze(data_summary['tot_spks'])
        tot_spks_chosen_cells = np.array(tot_spks[self.cells_choose],
                                         dtype='float32')
        self.total_spikes_chosen_cells = tot_spks_chosen_cells
        print('Total number of spikes loaded')

        # choose what stimulus mask to use
        # self.chosen_mask = which pixels to learn subunits over
        if masked_stimulus:
            total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
            self.cell_mask = total_mask
            self.chosen_mask = np.array(
                np.sum(total_mask[self.cells_choose, :], 0) > 0, dtype='bool')
        else:
            total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
            self.cell_mask = total_mask
            self.chosen_mask = np.array(
                np.ones(stimulus_dimension).astype('bool'))
        stim_dim = np.sum(self.chosen_mask)  # stimulus dimensions
        print('Stimulus mask size: %d' % np.sum(stim_dim))
        print('Stimulus mask loaded')

        # load stimulus and response from .mat files
        # Python can't load very large .mat files in one go, so the data was
        # split into smaller chunks that are stitched back together here
        self.stimulus = np.zeros((total_samples, np.sum(self.chosen_mask)))
        self.response = np.zeros((total_samples, n_cells))
        n_chunks_load = np.int(total_samples / storage_chunk_sz)
        for ichunk in range(n_chunks_load):
            print('Loading %d' % ichunk)
            filename = data_location + data_chunk_prefix + str(ichunk +
                                                               1) + '.mat'
            tf.logging.info('Trying to load: ' + filename)
            file_r = gfile.Open(filename, 'r')

            data = sio.loadmat(file_r)
            #data = sio.loadmat(filename)

            X = data[sr_key['stimulus']].T
            Y = data[sr_key['response']].T
            self.stimulus[ichunk * storage_chunk_sz:(ichunk + 1) *
                          storage_chunk_sz, :] = X[:, self.chosen_mask]
            self.response[ichunk * storage_chunk_sz:(ichunk + 1) *
                          storage_chunk_sz, :] = Y[:, self.cells_choose]

        # set up training and testing chunk IDs
        n_chunks = np.int(total_samples / self.batch_sz)
        test_num_chunks = np.int(test_length / self.batch_sz)
        self.test_chunks = np.arange(test_num_chunks)
        self.train_chunks = np.random.permutation(
            np.arange(test_num_chunks + 1, n_chunks))
        self.ichunk_train = 0
def main(argv):

    print('\nCode started')

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)

    ## Load data
    file = h5py.File(FLAGS.data_location + 'Off_parasol.mat', 'r')
    print('\ndataset loaded')

    # Load Masked movie
    data = file.get('maskedMovdd')
    stimulus = np.array(data)
    cells = file.get('cells')
    nCells = cells.shape[0]
    total_mask_log = file.get('totalMaskAccept_log')
    Nsub = FLAGS.ratio_SU * nCells
    stim_dim = stimulus.shape[1]

    # Load spike Response of cells
    data = file.get('Y')
    response = np.array(data, dtype='float32')
    tot_spks = np.squeeze(np.sum(response, axis=0))

    print(sys.getsizeof(file))
    print(sys.getsizeof(stimulus))
    print(sys.getsizeof(response))

    with tf.Session() as sess:
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.model_id == 0:
            # MODEL: lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0
            w_init = initialize_su(n_su=Nsub)
            a_init = np.random.rand(Nsub, nCells)
            w = tf.Variable(np.array(w_init, dtype='float32'))
            a = tf.Variable(np.array(a_init, dtype='float32'))
            lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + 0.0001
            loss = (tf.reduce_sum(lam / tot_spks) / 120. -
                    tf.reduce_sum(resp * tf.log(lam) / tot_spks)) / data_len
            loss_with_reg = loss + FLAGS.lam_w * tf.reduce_sum(
                tf.abs(w)) + FLAGS.lam_a * tf.reduce_sum(tf.abs(a))
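            # 'loss' is the negative Poisson log-likelihood up to constants,
            # sum(lam) - sum(resp * log(lam)), weighted by 1/tot_spks per cell
            # and averaged over data_len samples; the first term is also
            # divided by 120, presumably the sampling rate in Hz, though that
            # is not stated here. 'loss_with_reg' adds L1 penalties on w and a.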
            # training steps for a.
            train_step_a = tf.train.AdagradOptimizer(FLAGS.eta_a).minimize(
                loss, var_list=[a])
            # as 'a' is positive, this is op soft-thresholding for L1 and projecting to feasible set
            soft_th_a = tf.assign(a, tf.nn.relu(a - FLAGS.eta_a * FLAGS.lam_a))

            # training steps for w
            train_step_w = tf.train.AdagradOptimizer(FLAGS.eta_w).minimize(
                loss, var_list=[w])
            # do soft thresholding for 'w'
            soft_th_w = tf.assign(
                w,
                tf.nn.relu(w - FLAGS.eta_w * FLAGS.lam_w) -
                tf.nn.relu(-w - FLAGS.eta_w * FLAGS.lam_w))
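            # Each gradient step followed by its assign op is one
            # proximal-gradient update: for step size eta and L1 weight lam,
            # prox(v) = sign(v) * max(|v| - eta * lam, 0), which is exactly
            # the relu(v - eta*lam) - relu(-v - eta*lam) form used for w.
            # For 'a' only the positive part is kept, which soft-thresholds
            # and projects onto the constraint a >= 0 at the same time.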

            save_filename = (FLAGS.save_location + 'data_model=' +
                             str(FLAGS.model_id) + '_lam_w=' +
                             str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                             '_eta_w=' + str(FLAGS.eta_w) + '_eta_a=' +
                             str(FLAGS.eta_a) + '_ratioSU=' +
                             str(FLAGS.ratio_SU) + '_grid_spacing=' +
                             str(FLAGS.su_grid_spacing) + '_proximal')

        # initialize the model
        logfile = gfile.Open(save_filename + '.txt', "w")
        logfile.write('Starting new code\n')
        logfile.flush()
        sess.run(tf.initialize_all_variables())

        # Do the fitting
        batchsz = 100
        icnt = 0
        fd_test = {
            stim: stimulus.astype('float32')[216000 - 1000:216000, :],
            resp: response.astype('float32')[216000 - 1000:216000, :],
            data_len: 1000
        }

        ls_train_log = np.array([])
        ls_train_reg_log = np.array([])
        ls_test_log = np.array([])
        ls_test_reg_log = np.array([])

        tms = np.random.permutation(np.arange(216000 - 1000))
        for istep in range(100000):
            fd_train = {
                stim: stimulus.astype('float32')[tms[icnt:icnt + batchsz], :],
                resp: response.astype('float32')[tms[icnt:icnt + batchsz], :],
                data_len: batchsz
            }

            # gradient step for 'w'
            sess.run(train_step_w, feed_dict=fd_train)
            # soft thresholding for w
            sess.run(soft_th_w, feed_dict=fd_train)
            # gradient step for 'a'
            sess.run(train_step_a, feed_dict=fd_train)
            # soft thresholding for a, and project in constraint set
            sess.run(soft_th_a, feed_dict=fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = sess.run(loss, feed_dict=fd_train)
                ls_train_log = np.append(ls_train_log, ls_train)
                ls_train_reg = sess.run(loss_with_reg, feed_dict=fd_train)
                ls_train_reg_log = np.append(ls_train_reg_log, ls_train_reg)

                ls_test = sess.run(loss, feed_dict=fd_test)
                ls_test_log = np.append(ls_test_log, ls_test)
                ls_test_reg = sess.run(loss_with_reg, feed_dict=fd_test)
                ls_test_reg_log = np.append(ls_test_reg_log, ls_test_reg)

                # log results
                logfile.write('\nIterations: ' + str(istep) +
                              ' Training loss: ' + str(ls_train) +
                              ' with reg: ' + str(ls_train_reg) +
                              ' Testing loss: ' + str(ls_test) +
                              ' with reg: ' + str(ls_test_reg) +
                              '  w_l1_norm: ' + str(np.sum(np.abs(w.eval()))) +
                              ' a_l1_norm: ' + str(np.sum(np.abs(a.eval()))))
                logfile.flush()

                sio.savemat(
                    save_filename + '.mat', {
                        'w': w.eval(),
                        'a': a.eval(),
                        'w_init': w_init,
                        'a_init': a_init,
                        'ls_train_log': ls_train_log,
                        'ls_train_reg_log': ls_train_reg_log,
                        'ls_test_log': ls_test_log,
                        'ls_test_reg_log': ls_test_reg_log
                    })

            icnt += batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))

    logfile.close()
Beispiel #27
0
def main(unused_argv=()):

  np.random.seed(23)
  tf.set_random_seed(1234)
  random.seed(50)

  # Load stimulus-response data.
  # Collect population response across retinas in the list 'responses'.
  # Stimulus for each retina is indicated by 'stim_id',
  # which is found in 'stimuli' dictionary.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  for icnt, idataset in enumerate(datasets):

    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key,
                                           boundary=FLAGS.valid_cells_boundary)
      stimulus, resp, dimx, dimy, _ = op

      stimuli.update({key: stimulus})
      responses += resp

  taskid = FLAGS.taskid
  dat = responses[taskid]
  stimulus = stimuli[dat['stimulus_key']]

  # parameters
  window = 5

  # Compute the time course and non-linearity as two parameters that might be
  # explored in the embedded space.
  n_cells = dat['responses'].shape[1]
  T = np.minimum(stimulus.shape[0], dat['responses'].shape[0])

  stim_short = stimulus[:T, :, :]
  resp_short = dat['responses'][:T, :].astype(np.float32)

  save_dict = {}

  # Find time course, non-linearity and RF parameters

  ########################################################################
  # Separation between cell types
  ########################################################################
  save_dict.update({'cell_type': dat['cell_type']})
  save_dict.update({'dist_nn_cell_type': dat['dist_nn_cell_type']})

  ########################################################################
  # Find mean firing rate
  ########################################################################
  mean_fr = dat['responses'].mean(0)
  mean_fr_1 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==1])
  mean_fr_2 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==2])

  mean_fr_dict = {'mean_fr': mean_fr,
                  'mean_fr_1': mean_fr_1, 'mean_fr_2': mean_fr_2}
  save_dict.update({'mean_fr_dict': mean_fr_dict})

  ########################################################################
  # compute STAs
  ########################################################################
  stas = np.zeros((n_cells, 80, 40, 30))
  for icell in range(n_cells):
    print(icell)
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]
    stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1])
    for idelay in range(30):
      stas[icell, windx[0]: windx[1], windy[0]: windy[1], idelay] = np.reshape(resp_short[idelay:, icell].dot(stim_cell[:T-idelay, :]),
                                                                           [windx[1] - windx[0], windy[1] - windy[0]]) / np.sum(resp_short[idelay:, icell])
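  # Each slice stas[icell, :, :, idelay] is a spike-triggered average at lag
  # 'idelay': the stimulus patch around the cell's center, weighted by the
  # cell's response 'idelay' frames later and normalized by the summed
  # response.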

  stas_dict = {'stas': stas}
  # save_dict.update({'stas_dict': stas_dict})


  ########################################################################
  # Find time courses for each cell
  ########################################################################
  ttf_log = []
  for icell in range(n_cells):
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]
    ll = stas[icell, windx[0]: windx[1], windy[0]: windy[1], :]

    ll_2d = np.reshape(ll, [-1, ll.shape[-1]])

    u, s, v = np.linalg.svd(ll_2d)

    ttf_log += [v[0, :]]

  ttf_log = np.array(ttf_log)
  signs = [np.sign(ttf_log[icell, np.argmax(np.abs(ttf_log[icell, :]))]) for icell in range(ttf_log.shape[0])]
  ttf_corrected = np.expand_dims(np.array(signs), 1) * ttf_log
  ttf_corrected[np.squeeze(dat['cell_type'])==1, :] = ttf_corrected[np.squeeze(dat['cell_type'])==1, :] * -1
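  # The leading right singular vector of each (space x time) STA matrix gives
  # the cell's dominant temporal course. The sign flips above first make the
  # largest-magnitude lobe positive for every cell and then negate cell type
  # 1, presumably so the two cell types end up with opposite polarity
  # conventions (e.g. ON vs OFF), though that is not stated here.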

  ttf_mean_1 = ttf_corrected[np.squeeze(dat['cell_type'])==1, :].mean(0)
  ttf_mean_2 = ttf_corrected[np.squeeze(dat['cell_type'])==2, :].mean(0)

  ttf_params_1 = get_times(ttf_mean_1)
  ttf_params_2 = get_times(ttf_mean_2)

  ttf_dict = {'ttf_log': ttf_log,
              'ttf_mean_1': ttf_mean_1, 'ttf_mean_2': ttf_mean_2,
              'ttf_params_1': ttf_params_1, 'ttf_params_2': ttf_params_2}

  save_dict.update({'ttf_dict': ttf_dict})
  '''
  plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==1, :].T, 'r', alpha=0.3)
  plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==2, :].T, 'k', alpha=0.3)

  plt.plot(ttf_mean_1, 'r--')
  plt.plot(ttf_mean_2, 'k--')
  '''

  ########################################################################
  ## Find non-linearity
  ########################################################################
  f_nl = lambda x, p0, p1, p2, p3: p0 + p1*x + p2* np.power(x, 2) + p3* np.power(x, 3)

  nl_params_log = []
  stim_resp_log = []
  for icell in range(n_cells):
    print(icell)
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]

    stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1])
    sta_cell = np.reshape(stas[icell, windx[0]: windx[1], windy[0]: windy[1], :], [-1, stas.shape[-1]])

    stim_filter = np.zeros(stim_short.shape[0])
    for idelay in range(30):
      stim_filter[idelay: ] += stim_cell[:T-idelay, :].dot(sta_cell[:, idelay])

    # Normalize stim_filter
    stim_filter -= np.mean(stim_filter)
    stim_filter /= np.sqrt(np.var(stim_filter))

    resp_cell = resp_short[:, icell]

    stim_nl = []
    resp_nl = []
    for ipercentile in range(3, 97, 1):
      lb = np.percentile(stim_filter, ipercentile-3)
      ub = np.percentile(stim_filter, ipercentile+3)
      tms = np.logical_and(stim_filter >= lb, stim_filter < ub)
      stim_nl += [np.mean(stim_filter[tms])]
      resp_nl += [np.mean(resp_cell[tms])]

    stim_nl = np.array(stim_nl)
    resp_nl = np.array(resp_nl)
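    # Nonparametric estimate of the nonlinearity: the filtered stimulus is
    # binned into overlapping percentile windows (+/- 3 percentiles around
    # each point), the mean response in each bin is recorded, and the cubic
    # f_nl is fit to the resulting (stim_nl, resp_nl) pairs below.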

    popt, pcov = scipy.optimize.curve_fit(f_nl, stim_nl, resp_nl, p0=[1, 0, 0, 0])
    nl_params_log += [popt]
    stim_resp_log += [[stim_nl, resp_nl]]

  nl_params_log = np.array(nl_params_log)

  np_params_mean_1 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==1, :], 0)
  np_params_mean_2 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==2, :], 0)

  nl_params_dict = {'nl_params_log': nl_params_log,
                    'np_params_mean_1': np_params_mean_1,
                    'np_params_mean_2': np_params_mean_2,
                    'stim_resp_log': stim_resp_log}

  save_dict.update({'nl_params_dict': nl_params_dict})

  '''
  # Visualize Non-linearities
  for icell in range(n_cells):

    stim_in = np.arange(-3, 3, 0.1)
    fr = f_nl(stim_in, *nl_params_log[icell, :])
    if np.squeeze(dat['cell_type'])[icell] == 1:
      c = 'r'
    else:
      c = 'k'
    plt.plot(stim_in, fr, c, alpha=0.2)

  fr = f_nl(stim_in, *np_params_mean_1)
  plt.plot(stim_in, fr, 'r--')

  fr = f_nl(stim_in, *np_params_mean_2)
  plt.plot(stim_in, fr, 'k--')
  '''

  pickle.dump(save_dict, gfile.Open(os.path.join(FLAGS.save_folder , dat['piece']), 'w'))
  pickle.dump(stas_dict, gfile.Open(os.path.join(FLAGS.save_folder , 'stas' + dat['piece']), 'w'))
Beispiel #28
0
def main(argv):
    print('\nCode started')

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    global chosen_mask
    global cells_choose

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
        cells == 3066)
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = tot_spks[cells_choose]
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)
    # get test data
    stim_test, resp_test, test_length = get_test_data()

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    #short_filename = 'data_model=ASM_pop_bg'
    #  short_filename = ('data_model=ASM_pop_batch_sz='+ str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) +
    #    '_step_sz'+ str(FLAGS.step_sz)+'_bg')

    # saving details
    batchsz = np.array([1000, 1000, 1000], dtype='int')
    n_b_in_c = np.array([1, 1, 1], dtype='int')
    step_sz = np.array([1, 1, 1], dtype='float32')
    folder_names = ['experiment25', 'experiment22', 'experiment23']
    roc_data = [[] for _ in range(n_cells)]  # one list of ROC curves per cell
    for icnt, FLAGS.model_id in enumerate(['hinge', 'poisson', 'logistic']):
        # restore file
        FLAGS.batchsz = batchsz[icnt]
        FLAGS.n_b_in_c = n_b_in_c[icnt]
        FLAGS.step_sz = step_sz[icnt]
        folder_name = folder_names[icnt]
        if FLAGS.model_id == 'poisson':
            short_filename = ('data_model=ASM_pop_batch_sz=' +
                              str(FLAGS.batchsz) + '_n_b_in_c' +
                              str(FLAGS.n_b_in_c) + '_step_sz' +
                              str(FLAGS.step_sz) + '_bg')

        if FLAGS.model_id == 'poisson_full':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' +
                              str(FLAGS.n_b_in_c) + '_step_sz' +
                              str(FLAGS.step_sz) + '_bg')

        if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' +
                              str(FLAGS.n_b_in_c) + '_step_sz' +
                              str(FLAGS.step_sz) + '_bg')

        print(FLAGS.model_id)
        parent_folder = FLAGS.save_location + folder_name + '/'
        save_location = parent_folder + short_filename + '/'
        restore_file = get_latest_file(save_location, short_filename)

        tf.reset_default_graph()
        with tf.Session() as sess:
            # Learn population model!
            stim = tf.placeholder(tf.float32,
                                  shape=[None, stim_dim],
                                  name='stim')
            resp = tf.placeholder(tf.float32, name='resp')
            data_len = tf.placeholder(tf.float32, name='data_len')

            # define models
            if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
                w = tf.Variable(
                    np.array(0.01 * np.random.randn(stim_dim, n_su),
                             dtype='float32'))
                a = tf.Variable(
                    np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                             dtype='float32'))
                z = tf.transpose(
                    tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a), 2))

            if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge':
                w = tf.Variable(
                    np.array(0.01 * np.random.randn(stim_dim, n_su),
                             dtype='float32'))
                a = tf.Variable(
                    np.array(0.01 * np.random.rand(n_su, n_cells),
                             dtype='float32'))
                # Alternative: initialize b from the empirical spike log-odds,
                # np.log(np.sum(response, 0) / (response.shape[0] - np.sum(response, 0)))
                b_init = np.random.randn(n_cells)
                b = tf.Variable(b_init, dtype='float32')
                z = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b

            # restore variables
            # load tensorflow variables
            print(tf.all_variables())
            saver_var = tf.train.Saver(tf.all_variables())
            saver_var.restore(sess, restore_file)
            fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

            z_eval = sess.run(z, feed_dict=fd_test)
            print(z_eval[0:20, :])
            print(resp_test[0:20, :])

            for roc_cell in np.arange(n_cells):
                roc = get_roc(z_eval[:, roc_cell], resp_test[:, roc_cell])
                roc_data[roc_cell] = roc_data[roc_cell] + [roc]
                print(roc_cell)

    plt.figure()
    for icell in range(n_cells):
        plt.subplot(1, n_cells, icell + 1)
        for icnt in np.arange(3):
            plt.plot(roc_data[icell][icnt][0], roc_data[icell][icnt][1])
            plt.hold(True)
            plt.xlabel('recall')
            plt.ylabel('precision')
        plt.legend(['hinge', 'poisson', 'logistic'])
        cells_ch = cells[cells_choose]
        plt.title(cells_ch[icell])
    plt.show()
    plt.draw()
Beispiel #29
0
def test_metric(training_datasets, testing_datasets, responses, stimuli,
                sr_graph, sess, file_name):
    print('Testing for metric learning')
    tf.logging.info('Testing for metric learning')

    # saving filename.
    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          file_name + '_analysis_sample_resps')

    save_dict = {}
    #if FLAGS.taskid in [2, 8, 14, 20]:
    #  return

    retina_ids = [resp['piece'] for resp in responses]
    '''
  # NULL + SUBUNITS: Test subunits on null stimulus
  # Responses with different number of subunits - which is closer to stimulus?
  # Take embedding of population responses to same stimulus but
  # different number of subunits and see which is closer to the stimulus.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  su_analysis = experiments.subunit_discriminability_null(dataset_dict,
                                                          stimuli, responses,
                                                          sr_graph,
                                                          num_examples=10000)
  su_analysis.update({'rand_key': rand_number})
  su_analysis.update({'retina_ids': retina_ids})
  pickle.dump(su_analysis, gfile.Open(save_analysis_filename + '_test14_su_analysis_null.pkl', 'w'))

  

  # SUBUNITS: Responses with different number of subunits - which is closer to stimulus?
  # Take embedding of population responses to same stimulus but
  # different number of subunits and see which is closer to the stimulus.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  su_analysis = experiments.subunit_discriminability(dataset_dict,
                                          stimuli, responses, sr_graph,
                                          num_examples=10000)
  su_analysis.update({'rand_key': rand_number})
  su_analysis.update({'retina_ids': retina_ids})
  pickle.dump(su_analysis, gfile.Open(save_analysis_filename + '_test12_su_analysis.pkl', 'w'))

  '''
    # Embed stimuli-response + decode
    # WN
    for stim_key in stimuli.keys():
        expt = experiments.stimulus_response_embedding_expt
        stim_resp_embedding = expt(stimuli,
                                   responses,
                                   sr_graph,
                                   stim_key=stim_key,
                                   n_samples=1000)
        save_dict.update({'stim_resp_embedding_wn': stim_resp_embedding})
        save_dict.update({'retina_ids': retina_ids})
        pickle.dump(
            stim_resp_embedding,
            gfile.Open((save_analysis_filename + '_test3_%s.pkl') % stim_key,
                       'w'))
    '''
  # stimulus + response embed
  for stim_key in stimuli.keys():
    expt = experiments.stimulus_response_embedding_expt
    stim_resp_embedding = expt(stimuli, responses, sr_graph,
                               stim_key=stim_key, n_samples=1000, if_continuous=True)
    save_dict.update({'stim_resp_embedding_wn': stim_resp_embedding})
    save_dict.update({'retina_ids': retina_ids})
    pickle.dump(stim_resp_embedding, gfile.Open((save_analysis_filename + '_test3_continuous_%s.pkl') %  stim_key, 'w'))
  '''

    # ROC analysis on training and testing datasets
    rand_number = np.random.rand(2)
    dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
    roc_analysis = experiments.roc_analysis(dataset_dict,
                                            stimuli,
                                            responses,
                                            sr_graph,
                                            num_examples=10000)
    roc_analysis.update({'rand_key': rand_number})
    roc_analysis.update({'retina_ids': retina_ids})
    pickle.dump(roc_analysis,
                gfile.Open(save_analysis_filename + '_test4.pkl', 'w'))

    # ROC analysis with subsampling of cells
    # dataset_dict = {'train_sr': training_datasets, 'test_sr': testing_datasets}
    dataset_dict = {'test_sr': testing_datasets}
    if len(training_datasets) == 1:
        dataset_dict.update({'train_sr': training_datasets})

    roc_frac_cells_dict = {}
    for frac_cells in [0.05, 0.1, 0.15, 0.2, 0.5, 0.8, 1.0]:
        print('frac_cells : %.3f' % frac_cells)
        roc_analysis = experiments.roc_analysis(dataset_dict,
                                                stimuli,
                                                responses,
                                                sr_graph,
                                                num_examples=10000,
                                                frac_cells=frac_cells)
        roc_frac_cells_dict.update({frac_cells: roc_analysis})
    roc_frac_cells_dict.update({'retina_ids': retina_ids})
    pickle.dump(
        roc_frac_cells_dict,
        gfile.Open(save_analysis_filename + '_test4_frac_cells_new.pkl', 'w'))
    print(save_analysis_filename + '_test4_frac_cells_new.pkl')

    # ROC analysis - RSS triplets - on training and testing datasets
    rand_number = np.random.rand(2)
    dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
    roc_analysis = experiments.roc_analysis(dataset_dict,
                                            stimuli,
                                            responses,
                                            sr_graph,
                                            num_examples=10000,
                                            negative_stim=True)
    roc_analysis.update({'rand_key': rand_number})
    roc_analysis.update({'retina_ids': retina_ids})
    pickle.dump(roc_analysis,
                gfile.Open(save_analysis_filename + '_test4_rss.pkl', 'w'))

    ## Response transformations
    expt = experiments.response_transformation_increase_nl
    nl_expt = expt(stimuli,
                   responses,
                   sr_graph,
                   time_start_list=[100, 4000, 10000],
                   time_len=100,
                   alpha_list=[1.5, 1.25, 0.8, 0.6])
    nl_expt.update({'retina_ids': retina_ids})
    pickle.dump(
        nl_expt,
        gfile.Open(save_analysis_filename + '_test7_resp_transform_wn.pkl',
                   'w'))

    ## Embed all stimuli.
    '''
  stim_embedding_dict = experiments.stimulus_embedding_expt(stimuli, sr_graph,
                                                n_stims_per_type=1000)

  save_dict.update({'stimulus_embedding': stim_embedding_dict})

  pickle.dump(stim_embedding_dict, gfile.Open(save_analysis_filename + '_test1.pkl', 'w'))
  print('Saved after test 1')
  '''

    ## Explore invariances of the stimulus embedding.
    # Change luminance, contrast, geometrical changes (translate, rotate) and see
    # how they are reflected in embedded space.
    #
    # Get stimuli which are transformed
    stim_transformations_dict = experiments.stimulus_transformations_expt(
        stimuli,
        sr_graph,
        n_stims_per_type_transform=50,
        n_stims_per_type_bkg=500)  # 2000
    save_dict.update({'stimulus_transformations': stim_transformations_dict})

    pickle.dump(stim_transformations_dict,
                gfile.Open(save_analysis_filename + '_test2.pkl', 'w'))
    print('Saved after test 2')

    # NSEM
    '''
  stim_resp_embedding = expt(stimuli, responses, sr_graph,
                             stim_key='stim_2', n_samples=100)
  save_dict.update({'stim_resp_embedding_nsem': stim_resp_embedding})
  save_dict.update({'retina_ids': retina_ids})
  pickle.dump(stim_resp_embedding, gfile.Open(save_analysis_filename + '_test3_nsem.pkl', 'w'))
  print('Saved after test 3')
  '''

    # Drop different cell types
    expt = experiments.resp_drop_cells_expt
    for stim_key in stimuli.keys():
        resp_drop_cells = expt(stimuli,
                               responses,
                               sr_graph,
                               stim_key=stim_key,
                               n_samples=100)
        save_dict.update({'resp_drop_cells': resp_drop_cells})
        save_dict.update({'retina_ids': retina_ids})
        pickle.dump(
            resp_drop_cells,
            gfile.Open(save_analysis_filename + '_test8_%s.pkl' % stim_key,
                       'w'))

    ## TODO(bhaishahster): joint-auto-embed model response prediction for some stimuli
    # Rasters across different retinas embedded in same space.
    expt = experiments.resp_wn_repeats_multiple_retina
    resp_repeats = expt(sr_graph, n_samples=100)
    pickle.dump(resp_repeats,
                gfile.Open(save_analysis_filename + '_test9_repeats.pkl', 'w'))

    # Predict responses by embedding from rasters of different retina.
    expt = experiments.resp_wn_repeats_multiple_retina_encoding
    resp_repeats, response_embeddings, cell_geometry_log = expt(
        sr_graph, n_repeats=10)  # n_repeats=None for using all repeats
    pickle.dump(
        resp_repeats,
        gfile.Open(save_analysis_filename + '_test10_repeats_prediction.pkl',
                   'w'))

    # Interpolate between retinas
    expt = experiments.resp_wn_repeats_interpolate_embeddings
    interpolate_dict = expt(
        sr_graph, [ri[:7, 600:, :, :, :] for ri in response_embeddings])
    pickle.dump([interpolate_dict, cell_geometry_log],
                gfile.Open(
                    save_analysis_filename +
                    '_test11_repeats_pred_interpolation.pkl', 'w'))

    # Interpolate between retinas - mean of embedding across repeats
    expt = experiments.resp_wn_repeats_interpolate_embeddings
    interpolate_dict = expt(sr_graph, [
        np.expand_dims(ri[:, 600:, :, :, :].mean(0), 0)
        for ri in response_embeddings
    ])
    pickle.dump([interpolate_dict, cell_geometry_log],
                gfile.Open(
                    save_analysis_filename +
                    '_test11_repeats_pred_interpolation_mean.pkl', 'w'))

    return
    '''
  # Load multiple check points and analyse accuracy
  # /retina/response_model/python/metric_learning/end_to_end/stimulus_response_embedding --logtostderr --mode=1 --taskid=2 --save_suffix='_stim-resp_wn_nsem' --stim_layers='1, 5, 1, 3, 64, 1, 3, 64, 1, 3, 64, 1, 3, 64, 2, 3, 64, 2, 3, 1, 1' --resp_layers='3, 64, 1, 3, 64, 1, 3, 64, 1, 3, 64, 2, 3, 64, 2, 3, 1, 1' --batch_norm=True --save_folder='//home/bhaishahster/end_to_end_feb_5_6PM' --learning_rate=0.001 --batch_train_sz=100 --batch_neg_train_sz=100 --sr_model='convolutional_embedding'
  for frac_cells in [0.2, 1.0, 0.5]:
    dataset_dict = {'train_sr': training_datasets}
    saver = tf.train.Saver()

    filename = bookkeeping.get_filename(training_datasets, testing_datasets,
                                        FLAGS.beta, FLAGS.sr_model)
    long_filename = os.path.join(FLAGS.save_folder, filename)
    checkpoints = gfile.Glob(long_filename + '*.meta')
    checkpoints = [cpts[:-5] for cpts in checkpoints]

    roc_dict = {}
    for cpts in checkpoints:
      try:
        print(cpts)
        saver.restore(sess, cpts)
        iteration = int(cpts.split('/')[-1].split('-')[-1])

        roc_analysis = experiments.roc_analysis(dataset_dict,
                                                stimuli, responses, sr_graph,
                                                num_examples=10000,
                                                frac_cells=frac_cells)
        roc_dict.update({iteration: roc_analysis})
      except:
        pass
    pickle.dump(roc_dict, gfile.Open((save_analysis_filename +
                                      '_test6_frac_cells_%.2f.pkl') % frac_cells, 'w'))




  # ROC analysis with negatives generated at small difference to positives.
  dataset_dict = {'train_sr': training_datasets}
  roc_delta_t_dict = {}
  for delta_t in [1, 2, 3, 4, 5]:
    print('Delta t : %d' % delta_t)
    roc_analysis = experiments.roc_analysis(dataset_dict,
                                            stimuli, responses, sr_graph,
                                            num_examples=10000,
                                            delta_t=delta_t)
    roc_delta_t_dict.update({delta_t: roc_analysis})
  pickle.dump(roc_delta_t_dict, gfile.Open(save_analysis_filename + '_test4_delta_t.pkl', 'w'))

  '''

    ## ROC curves of responses from repeats - dataset 1
    '''
  repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
  repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'));
  repeats_data['cell_type'] = repeats_data['cell_type'].T
  # process repeats data
  num_cell_types = 2
  dimx = 80
  dimy = 40
  data_util.process_dataset(repeats_data, dimx, dimy, num_cell_types)
  # analyse and store the result
  test_reps = analyse_response_repeats(repeats_data,
                                       sr_graph.anchor_model,
                                       sr_graph.neg_model, sr_graph.sess)
  save_dict.update({'test_reps_2015-09-23-7': test_reps})
  pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
  '''

    ## ROC curves of responses from repeats - dataset 2
    '''
  repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
  repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'));
  data_util.process_dataset(repeats_data, dimx, dimy, num_cell_types)

  # analyse and store the result
  test_clustering = analyse_response_repeats_all_trials(repeats_data,
                                                        sr_graph.anchor_model,
                                                        sr_graph.neg_model,
                                                        sr_graph.sess)
  save_dict.update({'test_reps_2005_08_03_0': test_clustering})
  pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
  '''
    #
    # get model params

    model_pars_dict = {
        'model_pars': sr_graph.sess.run(tf.trainable_variables())
    }
    pickle.dump(model_pars_dict,
                gfile.Open(save_analysis_filename + '_test5.pkl', 'w'))

    print(save_analysis_filename)
def main(argv):
    np.random.seed(23)

    # Figure out dictionary path.
    dict_list = gfile.ListDirectory(FLAGS.src_dict)
    dict_path = os.path.join(FLAGS.src_dict, dict_list[FLAGS.taskid])

    # Load the dictionary
    if dict_path[-3:] == 'pkl':
        data = pickle.load(gfile.Open(dict_path, 'r'))
    if dict_path[-3:] == 'mat':
        data = sio.loadmat(gfile.Open(dict_path, 'r'))

    #FLAGS.save_dir = '/home/bhaishahster/stimulation_algos/dictionaries/' + dict_list[FLAGS.taskid][:-4]
    FLAGS.save_dir = FLAGS.save_dir + dict_list[FLAGS.taskid][:-4]
    if not gfile.Exists(FLAGS.save_dir):
        gfile.MkDir(FLAGS.save_dir)

    # S_collection = data['S']  # Target
    A = data['A']  # Decoder
    D = data['D'].T  # Dictionary

    # clean dictionary
    thr_list = np.arange(0, 1, 0.01)
    dict_val = []
    for thr in thr_list:
        dict_val += [np.sum(np.sum(D.T > thr, 1) != 0)]
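    # dict_val[i] counts how many dictionary elements activate at least one
    # cell above threshold thr_list[i]; the curve below is used to pick a
    # cutoff interactively, and elements that never exceed it are dropped.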
    plt.ion()
    plt.figure()
    plt.plot(thr_list, dict_val)
    plt.xlabel('Threshold')
    plt.ylabel(
        'Number of dictionary elements with \n at least one entry above threshold'
    )
    plt.title('Please choose threshold')
    thr_use = float(input('What threshold to use?'))
    plt.axvline(thr_use)
    plt.title('Using threshold: %.5f' % thr_use)

    dict_valid = np.sum(D.T > thr_use, 1) > 0
    D = D[:, dict_valid]
    D = np.append(D, np.zeros((D.shape[0], 1)), 1)
    print(
        'Appending a "dummy" dictionary element that does not activate any cell'
    )

    # Vary stimulus resolution
    for itarget in range(20):
        n_targets = 1
        for stix_resolution in [32, 64, 16, 8]:

            # Get the target
            x_dim = int(640 / stix_resolution)
            y_dim = int(320 / stix_resolution)
            targets = (np.random.rand(y_dim, x_dim, n_targets) < 0.5) - 0.5
            upscale = stix_resolution // 8
            targets = np.repeat(np.repeat(targets, upscale, axis=0),
                                upscale,
                                axis=1)
            targets = np.reshape(targets, [-1, targets.shape[-1]])
            S_actual = targets[:, 0]

            # Remove null component of A from S
            S = A.dot(np.linalg.pinv(A).dot(S_actual))
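            # A.dot(pinv(A)) is the orthogonal projector onto the column space
            # of the decoder A, so S keeps only the part of the target that
            # some response pattern could actually reconstruct.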

            # Run Greedy first to initialize
            x_greedy = greedy_stimulation(S,
                                          A,
                                          D,
                                          max_stims=FLAGS.t_max * FLAGS.delta,
                                          file_suffix='%d_%d' %
                                          (stix_resolution, itarget),
                                          save=True,
                                          save_dir=FLAGS.save_dir)

            # Load greedy output from previous run
            #data_greedy = pickle.load(gfile.Open('/home/bhaishahster/greedy_2000_32_0.pkl', 'r'))
            #x_greedy = data_greedy['x_chosen']

            # Plan for multiple time points
            x_init = np.zeros((x_greedy.shape[0], FLAGS.t_max))
            for it in range(FLAGS.t_max):
                print((it + 1) * FLAGS.delta - 1)
                x_init[:, it] = x_greedy[:, (it + 1) * FLAGS.delta - 1]

            #simultaneous_planning(S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
            #                     delta=FLAGS.delta, normalization=FLAGS.normalization,
            #                    file_suffix='%d_%d_normal' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir)

            from IPython import embed
            embed()
            simultaneous_planning_interleaved_discretization(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_pgd' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=np.inf,
                steps_max=500 * 20 - 1)

            # Interleaved discretization.
            simultaneous_planning_interleaved_discretization(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_pgd_od' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=500,
                steps_max=500 * 20 - 1)

            # Exponential weighing.
            simultaneous_planning_interleaved_discretization_exp_gradient(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_ew' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=np.inf,
                steps_max=500 * 20 - 1)

            # Exponential weighing with interleaved discretization.
            simultaneous_planning_interleaved_discretization_exp_gradient(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_ew_od' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=500,
                steps_max=500 * 20 - 1)
            '''
      # Plot results
      data_fractional = pickle.load(gfile.Open('/home/bhaishahster/2012-09-24-0_SAD_fr1/pgd_20_5.000000_C_100_32_0_normal.pkl', 'r'))
      plt.ion()
      start = 0
      end = 20

      error_target_frac = np.linalg.norm(S - data_fractional['S'])
      #error_target_greedy = np.linalg.norm(S - data_greedy['S'])
      print('Did target change? \n Err from fractional: %.3f ' % error_target_frac)
      #'\n Err from greedy %.3f' % (error_target_frac,
      #                             error_target_greedy))
      # from IPython import embed; embed()
      normalize = np.sum(data_fractional['S'] ** 2)
      plt.ion()
      plt.plot(np.arange(120, len(data_fractional['f_log'])), data_fractional['f_log'][120:] / normalize, 'b')
      plt.axhline(data_fractional['errors'][5:].mean() / normalize, color='m')
      plt.axhline(data_fractional['errors_ht_discrete'][5:].mean() / normalize, color='y')
      plt.axhline(data_fractional['errors_rr_discrete'][5:].mean() / normalize, color='r')
      plt.axhline(data_greedy['error_curve'][120:].mean() / normalize, color='k')
      plt.legend(['fraction_curve', 'fractional error', 'HT', 'RR', 'Greedy'])
      plt.pause(1.0)
      '''

    from IPython import embed
    embed()