def main(unused_argv=()):
  """Fit a linear embedding to the response record selected by FLAGS.taskid.

  Every sub-directory of FLAGS.src_dir is loaded with
  data_util.get_stimulus_response. Each stimulus is stored under the key
  'stim_<listing index>', and all response records are concatenated into one
  list. The record at index FLAGS.taskid is passed, together with the
  stimulus named by its 'stimulus_key', to learn_lin_embedding. The output
  filename is 'linear_taskid_<taskid>_piece_<piece>.pkl'.
  """
  dataset_names = gfile.ListDirectory(FLAGS.src_dir)
  print(dataset_names)

  stimulus_by_key = {}
  all_responses = []
  for index, name in enumerate(dataset_names):
    # Skip plain files; only directories are loaded.
    if not gfile.IsDirectory(os.path.join(FLAGS.src_dir, name)):
      continue
    stim_key = 'stim_%d' % index
    stim, resp_list, _, _, _ = data_util.get_stimulus_response(
        FLAGS.src_dir, name, stim_key)
    stimulus_by_key[stim_key] = stim
    all_responses.extend(resp_list)

  print('# Responses %d' % len(all_responses))
  task = all_responses[FLAGS.taskid]
  stimulus = stimulus_by_key[task['stimulus_key']]
  save_filename = 'linear_taskid_%d_piece_%s.pkl' % (FLAGS.taskid,
                                                     task['piece'])
  print(save_filename)
  learn_lin_embedding(stimulus, np.double(task['responses']),
                      filename=save_filename,
                      lam_l1=0.00001, beta=10, time_window=30,
                      lr=0.01)

  print('DONE!')
Example #2
0
def get_latest_file(save_location, short_filename):
    """Return the checkpoint path with the highest iteration number.

    Lists ``save_location`` and keeps the entries that contain
    ``short_filename`` followed by ``-``. That is the same separator used to
    build the returned path, so names that merely share a prefix
    (e.g. ``model_v2-100`` for ``model``) are ignored. Entries ending in
    ``.meta`` are counted separately. For every other entry, the iteration is
    the run of digits right after ``short_filename-`` when that run is
    followed by ``.`` or the end of the name. This covers both ``name-1000``
    and ``name-1000.index`` / ``name-1000.data-00000-of-00001``.

    Args:
      save_location: Directory prefix. It is concatenated directly with
        ``short_filename``, so it should end with a path separator.
      short_filename: Checkpoint base name; treated literally, not as a regex.

    Returns:
      ``save_location + short_filename + '-' + str(<largest iteration>)``.

    Raises:
      FileNotFoundError: if no matching entry carries an iteration number.
    """
    file_list = gfile.ListDirectory(save_location)
    print(save_location, short_filename)
    save_filename = save_location + short_filename
    print('\nLoading: ', save_filename)

    # re.escape: short_filename often contains '.' (e.g. from float flags),
    # which must match literally.
    prefix = re.escape(short_filename) + '-'
    step_re = re.compile(prefix + r'(\d+)(?:\.|$)')

    bin_files = []
    meta_files = []
    for file_n in file_list:
        if re.search(prefix, file_n):
            if file_n.endswith('.meta'):
                meta_files.append(file_n)
            else:
                bin_files.append(file_n)
    print(len(meta_files), len(bin_files), len(file_list))

    # A V2 checkpoint contributes several files for one step; dedupe them.
    iterations = set()
    for file_name in bin_files:
        match = step_re.search(file_name)
        if match is None:
            print('Could not load filename: ' + file_name)
            continue
        iterations.add(int(match.group(1)))
    iterations = sorted(iterations)
    print(iterations)

    if not iterations:
        raise FileNotFoundError('No checkpoint for %r found in %r'
                                % (short_filename, save_location))
    iter_plot = iterations[-1]
    print(iter_plot)
    return save_filename + '-' + str(iter_plot)
Example #3
0
def inputs(name, data_location, batch_size, num_epochs, stim_dim, resp_dim):
    """Return shuffled (stimulus, response) batches read from .tfrecords data.

    Works for .tfrecords files made using
    CoarseDataUtils.convert_to_TFRecords. ``name`` is resolved, in this
    order, as:

    1. the file ``data_location/name``;
    2. the file ``data_location/name.tfrecords``;
    3. the directory ``data_location/name``, every entry of which is used as
       an input file.

    Args:
      name: Base name of the data file or directory.
      data_location: Directory that contains the data.
      batch_size: Number of examples per returned batch.
      num_epochs: Forwarded to tf.train.string_input_producer.
      stim_dim: Forwarded to read_and_decode.
      resp_dim: Forwarded to read_and_decode.

    Returns:
      ``(stimulus_batch, response_batch)`` from tf.train.shuffle_batch.

    Raises:
      FileNotFoundError: if none of the three candidate locations exists.
    """
    with tf.name_scope('input'):
        filename = os.path.join(data_location, name)
        filename_extension = os.path.join(data_location, name + '.tfrecords')
        if gfile.Exists(filename) and not gfile.IsDirectory(filename):
            tf.logging.info('%s Exists' % filename)
            filenames = [filename]
        elif (gfile.Exists(filename_extension) and
              not gfile.IsDirectory(filename_extension)):
            tf.logging.info('%s Exists' % filename_extension)
            filenames = [filename_extension]
        elif gfile.IsDirectory(filename):
            tf.logging.info('%s Exists and is a directory' % filename)
            filenames = [
                os.path.join(filename, short_name)
                for short_name in gfile.ListDirectory(filename)
            ]
        else:
            # None of the candidate locations exists; fail with a clear message.
            raise FileNotFoundError(
                'No input data found: neither file %s, file %s nor directory '
                '%s exists' % (filename, filename_extension, filename))
        tf.logging.info(filenames)
        filename_queue = tf.train.string_input_producer(filenames,
                                                        num_epochs=num_epochs,
                                                        capacity=10000)

    # Even when reading in multiple threads, share the filename queue.
    stimulus, response = read_and_decode(filename_queue, stim_dim, resp_dim)

    # Shuffle the examples and collect them into batch_size batches
    # (internally uses a RandomShuffleQueue) with 30 enqueuing threads.
    # min_after_dequeue ensures a minimum amount of shuffling of examples.
    stimulus_batch, response_batch = tf.train.shuffle_batch(
        [stimulus, response],
        batch_size=batch_size,
        num_threads=30,
        capacity=5000 + 3 * batch_size,
        min_after_dequeue=2000)
    return stimulus_batch, response_batch
Example #4
0
def main(unused_argv=()):
  """Fit and save an LN population model for every loaded response record.

  Every sub-directory of FLAGS.src_dir is loaded with
  data_util.get_stimulus_response; stimuli are stored under
  'stim_<listing index>' keys. After all datasets are loaded, each response
  record is fitted with fit_ln_population against the stimulus named by its
  own 'stimulus_key'. The resulting {'k', 'b', 'ttf'} dict is pickled to
  FLAGS.save_folder/<piece>_ln_model.pkl.
  """
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  print(datasets)
  # Iterate over the entries of the listing (each one a name that
  # os.path.join can accept), not over a list that wraps the listing.
  for icnt, idataset in enumerate(datasets):
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, _, _, _ = op
      stimuli[key] = stimulus
      responses += resp

  # Fit once per record, after loading, each against its own stimulus.
  # NOTE(review): an earlier comment suggested selecting a single record via
  # FLAGS.taskid; this fits all records — confirm that is intended.
  for response in responses:
    k, b, ttf = fit_ln_population(response['responses'],
                                  stimuli[response['stimulus_key']])
    save_dict = {'k': k, 'b': b, 'ttf': ttf}

    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          response['piece']
                                          + '_ln_model.pkl')
    # Close (and flush) the output file deterministically.
    with gfile.Open(save_analysis_filename, 'wb') as save_file:
      pickle.dump(save_dict, save_file)
Example #5
0
def _cell_window(center, window, max_x=80 - 1, max_y=40 - 1):
  """Return the [start, stop] x and y bounds of a cell's analysis window.

  Each bound is `center -/+ window`, clipped below at 0 and above at
  `max_x` / `max_y`. Callers use the stop value as an exclusive slice end.
  """
  windx = [np.maximum(center[0] - window, 0),
           np.minimum(center[0] + window, max_x)]
  windy = [np.maximum(center[1] - window, 0),
           np.minimum(center[1] + window, max_y)]
  return windx, windy


def main(unused_argv=()):
  """Estimate STAs, time courses and non-linearities for one retina.

  All sub-directories of FLAGS.src_dir are loaded with
  data_util.get_stimulus_response. The response record at index
  FLAGS.taskid is analysed against its own stimulus ('stimulus_key'). For
  every cell:

  * a 30-delay STA is computed in a window of half-width `window` around
    the cell's center;
  * its time course is taken as the first right-singular vector of that STA;
  * a cubic polynomial is fitted to the binned relation between the
    STA-filtered stimulus and the response.

  Per-cell results and averages over cell types 1 and 2 are pickled to
  FLAGS.save_folder/<piece>. The STAs alone go to
  FLAGS.save_folder/stas<piece>.
  """
  np.random.seed(23)
  tf.set_random_seed(1234)
  random.seed(50)

  # Load stimulus-response data.
  # Collect population response across retinas in the list 'responses'.
  # Stimulus for each retina is indicated by 'stimulus_key',
  # which is found in the 'stimuli' dictionary.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  for icnt, idataset in enumerate(datasets):
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key,
                                           boundary=FLAGS.valid_cells_boundary)
      stimulus, resp, _, _, _ = op
      stimuli[key] = stimulus
      responses += resp

  dat = responses[FLAGS.taskid]
  stimulus = stimuli[dat['stimulus_key']]

  window = 5     # half-width of the per-cell analysis window
  n_delays = 30  # number of time lags in each STA

  n_cells = dat['responses'].shape[1]
  T = np.minimum(stimulus.shape[0], dat['responses'].shape[0])
  stim_short = stimulus[:T, :, :]
  resp_short = dat['responses'][:T, :].astype(np.float32)

  cell_type = np.squeeze(dat['cell_type'])
  is_type1 = cell_type == 1
  is_type2 = cell_type == 2

  # `int` replaces `np.int`, which was removed in NumPy 1.24.
  windows = [_cell_window(dat['centers'][icell, :].astype(int), window)
             for icell in range(n_cells)]

  save_dict = {}

  # Separation between cell types.
  save_dict['cell_type'] = dat['cell_type']
  save_dict['dist_nn_cell_type'] = dat['dist_nn_cell_type']

  # Mean firing rate, overall and per cell type.
  mean_fr = dat['responses'].mean(0)
  save_dict['mean_fr_dict'] = {'mean_fr': mean_fr,
                               'mean_fr_1': np.mean(mean_fr[is_type1]),
                               'mean_fr_2': np.mean(mean_fr[is_type2])}

  # STAs, restricted to each cell's window.
  stas = np.zeros((n_cells, 80, 40, n_delays))
  for icell, (windx, windy) in enumerate(windows):
    print(icell)
    stim_cell = np.reshape(stim_short[:, windx[0]:windx[1], windy[0]:windy[1]],
                           [stim_short.shape[0], -1])
    for idelay in range(n_delays):
      # Response-weighted average of the stimulus `idelay` frames earlier.
      resp_delayed = resp_short[idelay:, icell]
      weighted = resp_delayed.dot(stim_cell[:T - idelay, :])
      stas[icell, windx[0]:windx[1], windy[0]:windy[1], idelay] = (
          np.reshape(weighted, [windx[1] - windx[0], windy[1] - windy[0]])
          / np.sum(resp_delayed))

  # Time course of each cell: first right-singular vector of its
  # (pixels x delays) STA.
  ttf_log = []
  for icell, (windx, windy) in enumerate(windows):
    ll = stas[icell, windx[0]:windx[1], windy[0]:windy[1], :]
    _, _, v = np.linalg.svd(np.reshape(ll, [-1, ll.shape[-1]]))
    ttf_log.append(v[0, :])
  ttf_log = np.array(ttf_log)

  # Fix the arbitrary SVD sign so the largest-magnitude entry is positive,
  # then flip type-1 cells.
  peak = np.argmax(np.abs(ttf_log), axis=1)
  signs = np.sign(ttf_log[np.arange(ttf_log.shape[0]), peak])
  ttf_corrected = signs[:, np.newaxis] * ttf_log
  ttf_corrected[is_type1, :] *= -1

  ttf_mean_1 = ttf_corrected[is_type1, :].mean(0)
  ttf_mean_2 = ttf_corrected[is_type2, :].mean(0)
  save_dict['ttf_dict'] = {'ttf_log': ttf_log,
                           'ttf_mean_1': ttf_mean_1, 'ttf_mean_2': ttf_mean_2,
                           'ttf_params_1': get_times(ttf_mean_1),
                           'ttf_params_2': get_times(ttf_mean_2)}

  # Non-linearity: cubic fit of mean response vs. STA-filtered stimulus.
  def f_nl(x, p0, p1, p2, p3):
    """Cubic polynomial p0 + p1*x + p2*x**2 + p3*x**3."""
    return p0 + p1 * x + p2 * np.power(x, 2) + p3 * np.power(x, 3)

  nl_params_log = []
  stim_resp_log = []
  for icell, (windx, windy) in enumerate(windows):
    print(icell)
    stim_cell = np.reshape(stim_short[:, windx[0]:windx[1], windy[0]:windy[1]],
                           [stim_short.shape[0], -1])
    sta_cell = np.reshape(stas[icell, windx[0]:windx[1], windy[0]:windy[1], :],
                          [-1, stas.shape[-1]])

    # Filter the stimulus with the cell's STA, delay by delay.
    stim_filter = np.zeros(stim_short.shape[0])
    for idelay in range(n_delays):
      stim_filter[idelay:] += stim_cell[:T - idelay, :].dot(sta_cell[:, idelay])

    # Normalize to zero mean and unit variance.
    stim_filter -= np.mean(stim_filter)
    stim_filter /= np.sqrt(np.var(stim_filter))

    resp_cell = resp_short[:, icell]

    # Overlapping bins spanning +/-3 percentiles of the filter output; record
    # the mean filter output and mean response inside each bin.
    stim_nl = []
    resp_nl = []
    for ipercentile in range(3, 97, 1):
      lb = np.percentile(stim_filter, ipercentile - 3)
      ub = np.percentile(stim_filter, ipercentile + 3)
      tms = np.logical_and(stim_filter >= lb, stim_filter < ub)
      stim_nl.append(np.mean(stim_filter[tms]))
      resp_nl.append(np.mean(resp_cell[tms]))
    stim_nl = np.array(stim_nl)
    resp_nl = np.array(resp_nl)

    popt, _ = scipy.optimize.curve_fit(f_nl, stim_nl, resp_nl,
                                       p0=[1, 0, 0, 0])
    nl_params_log.append(popt)
    stim_resp_log.append([stim_nl, resp_nl])

  nl_params_log = np.array(nl_params_log)
  save_dict['nl_params_dict'] = {
      'nl_params_log': nl_params_log,
      'np_params_mean_1': np.mean(nl_params_log[is_type1, :], 0),
      'np_params_mean_2': np.mean(nl_params_log[is_type2, :], 0),
      'stim_resp_log': stim_resp_log}

  # Close (and flush) the output files deterministically.
  with gfile.Open(os.path.join(FLAGS.save_folder, dat['piece']),
                  'wb') as save_file:
    pickle.dump(save_dict, save_file)
  with gfile.Open(os.path.join(FLAGS.save_folder, 'stas' + dat['piece']),
                  'wb') as stas_file:
    pickle.dump({'stas': stas}, stas_file)
Example #6
0
def main(argv):
    """Restore the latest saved population subunit model and plot its subunits.

    Reads ``data_details.mat`` from FLAGS.data_location to select cells and
    the stimulus mask. It then rebuilds the TF variables for FLAGS.model_id
    ('poisson', 'poisson_full' or 'logistic') in their original creation
    order, so checkpoint names line up. The checkpoint with the highest
    iteration found in the model's save folder is restored, and each subunit
    filter ``w[:, isu]`` is shown on the 40x80 stimulus grid, cropped to the
    mask.

    Raises:
      ValueError: if FLAGS.model_id is not a supported model.
      FileNotFoundError: if no checkpoint with an iteration number is found.
    """
    print('\nCode started')
    if FLAGS.model_id not in ('poisson', 'logistic', 'poisson_full'):
        raise ValueError('Unsupported model_id: %r' % FLAGS.model_id)

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)

    ## Load data summary
    filename = FLAGS.data_location + 'data_details.mat'
    # Binary mode: loadmat needs bytes, and the file is closed afterwards.
    with gfile.Open(filename, 'rb') as summary_file:
        data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.ones(np.shape(cells), dtype='bool')
    else:  # 'poisson' / 'logistic' use four specific cell ids.
        cells_choose = ((cells == 3287) | (cells == 3318) | (cells == 3155) |
                        (cells == 3066))
    n_cells = np.sum(cells_choose)

    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    # Union of the masks of the chosen cells.
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))
    stim_dim = np.sum(chosen_mask)
    print('\ndataset summary loaded')

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')
    else:
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Variables are created in the same order as during training (w, a,
        # then b for 'logistic'); initial values are overwritten by restore.
        w = tf.Variable(
            np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32'))
        if FLAGS.model_id == 'logistic':
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b = tf.Variable(np.random.randn(n_cells), dtype='float32')
        else:
            a = tf.Variable(
                np.array(0.1 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))

        # get relevant files
        file_list = gfile.ListDirectory(FLAGS.save_location)
        print('\nLoading: ', save_filename)
        # Literal base name followed by '-', matching the restore path below.
        prefix = re.escape(short_filename) + '-'
        step_re = re.compile(prefix + r'(\d+)(?:\.|$)')
        bin_files = []
        meta_files = []
        for file_n in file_list:
            if re.search(prefix, file_n):
                if file_n.endswith('.meta'):
                    meta_files.append(file_n)
                else:
                    bin_files.append(file_n)
        print(len(meta_files), len(bin_files), len(file_list))

        # get iteration numbers (deduped: a V2 checkpoint has several files)
        iterations = set()
        for file_name in bin_files:
            match = step_re.search(file_name)
            if match is None:
                print('Could not load filename: ' + file_name)
                continue
            iterations.add(int(match.group(1)))
        iterations = sorted(iterations)
        print(iterations)
        if not iterations:
            raise FileNotFoundError('No checkpoint for %r found in %r'
                                    % (short_filename, FLAGS.save_location))
        iter_plot = iterations[-1]
        print(iter_plot)

        # load tensorflow variables
        saver_var = tf.train.Saver(tf.global_variables())
        saver_var.restore(sess, save_filename + '-' + str(iter_plot))

        a_eval = a.eval()
        print(np.exp(np.squeeze(a_eval)))

        # get 2D region to plot
        mask2D = np.reshape(chosen_mask, [40, 80])
        nz_idx = np.nonzero(mask2D)
        print(nz_idx)
        ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
        xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])
        w_eval = w.eval()

        plt.figure()
        n_su = w_eval.shape[1]
        # plt.subplot requires integer grid dimensions.
        n_grid = int(np.ceil(np.sqrt(n_su)))
        for isu in range(n_su):
            # Scatter the masked filter back onto the full 40x80 grid.
            xx = np.zeros(3200)
            xx[chosen_mask] = w_eval[:, isu]
            fig = plt.subplot(n_grid, n_grid, isu + 1)
            plt.imshow(np.reshape(xx, [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)

        plt.suptitle('Iteration:' + str(iter_plot) + ' batchSz:' +
                     str(FLAGS.batchsz) + ' step size:' + str(FLAGS.step_sz),
                     fontsize=18)
        plt.show()
        plt.draw()
def main(argv):
    #plt.ion() # interactive plotting
    window = FLAGS.window
    n_pix = (2 * window + 1)**2
    dimx = np.floor(1 + ((40 - (2 * window + 1)) / FLAGS.stride)).astype('int')
    dimy = np.floor(1 + ((80 - (2 * window + 1)) / FLAGS.stride)).astype('int')
    nCells = 107
    # load model
    # load filename
    print(FLAGS.model_id)
    with tf.Session() as sess:
        if FLAGS.model_id == 'relu':
            # lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0
            short_filename = ('data_model=' + str(FLAGS.model_id) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                              '_ratioSU=' + str(FLAGS.ratio_SU) +
                              '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                              '_normalized_bg')
            w = tf.Variable(
                np.array(np.random.randn(3200, 749), dtype='float32'))
            a = tf.Variable(
                np.array(np.random.randn(749, 107), dtype='float32'))

        if FLAGS.model_id == 'relu_window':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))  # exp 5
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')

            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_sfm':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_sfm_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.01 + 0.005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.02 + np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        if FLAGS.model_id == 'relu_window_mother_exp':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w_del = tf.Variable(
                np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            w_mother = tf.Variable(
                np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)),
                         dtype='float32'))
            a = tf.Variable(
                np.array(np.random.rand(dimx * dimy, nCells), dtype='float32'))

        if FLAGS.model_id == 'relu_window_a_support':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.002 * np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        if FLAGS.model_id == 'exp_window_a_support':
            short_filename = ('data_model=' + str(FLAGS.model_id) +
                              '_window=' + str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_bg')
            w = tf.Variable(
                np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.002 * np.random.rand(dimx * dimy, nCells),
                         dtype='float32'))

        parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
        FLAGS.save_location = parent_folder + short_filename + '/'

        # get relevant files
        file_list = gfile.ListDirectory(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename
        print('\nLoading: ', save_filename)
        bin_files = []
        meta_files = []
        for file_n in file_list:
            if re.search(short_filename + '.', file_n):
                if re.search('.meta', file_n):
                    meta_files += [file_n]
                else:
                    bin_files += [file_n]
        #print(bin_files)
        print(len(meta_files), len(bin_files), len(file_list))

        # get iteration numbers
        iterations = np.array([])
        for file_name in bin_files:
            try:
                iterations = np.append(
                    iterations, int(file_name.split('/')[-1].split('-')[-1]))
            except:
                print('Could not load filename: ' + file_name)
        iterations.sort()
        print(iterations)

        iter_plot = iterations[-1]
        print(int(iter_plot))

        # load tensorflow variables
        saver_var = tf.train.Saver(tf.all_variables())

        restore_file = save_filename + '-' + str(int(iter_plot))
        saver_var.restore(sess, restore_file)

        # plot subunit - cell connections
        plt.figure()
        plt.cla()
        plt.imshow(a.eval(), cmap='gray', interpolation='nearest')
        print(np.shape(a.eval()))
        plt.title('Iteration: ' + str(int(iter_plot)))
        plt.show()
        plt.draw()

        # plot all subunits on 40x80 grid
        try:
            wts = w.eval()
            for isu in range(100):
                fig = plt.subplot(10, 10, isu + 1)
                plt.imshow(np.reshape(wts[:, isu], [40, 80]),
                           interpolation='nearest',
                           cmap='gray')
            plt.title('Iteration: ' + str(int(iter_plot)))
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
        except:
            print('w full does not exist? ')

        # plot a few subunits - wmother + wdel
        try:
            wts = w.eval()
            print('wts shape:', np.shape(wts))
            icnt = 1
            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    fig = plt.subplot(dimx, dimy, icnt)
                    plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]),
                                          (2 * window + 1, 2 * window + 1)),
                               interpolation='nearest',
                               cmap='gray')
                    icnt = icnt + 1
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)
            plt.show()
            plt.draw()
        except:
            print('w does not exist?')

        # plot wmother
        try:
            w_mot = np.squeeze(w_mother.eval())
            print(w_mot)
            plt.imshow(w_mot, interpolation='nearest', cmap='gray')
            plt.title('Mother subunit')
            plt.show()
            plt.draw()
        except:
            print('w mother does not exist')

        # plot wmother + wdel
        try:
            w_mot = np.squeeze(w_mother.eval())
            w_del = np.squeeze(w_del.eval())
            wts = np.array(np.random.randn(dimx, dimy, (2 * window + 1)**2))
            for idimx in np.arange(dimx):
                print(idimx)
                for idimy in np.arange(dimy):
                    wts[idimx,
                        idimy, :] = np.ndarray.flatten(w_mot) + w_del[idimx,
                                                                      idimy, :]
        except:
            print('w mother + w delta do not exist? ')
        '''
    try:
      
      icnt=1
      for idimx in np.arange(dimx):
        for idimy in np.arange(dimy):
          fig = plt.subplot(dimx, dimy, icnt)
          plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]), (2*window+1,2*window+1)), interpolation='nearest', cmap='gray')
          fig.axes.get_xaxis().set_visible(False)
          fig.axes.get_yaxis().set_visible(False)
    except:
      print('w mother + w delta plotting error? ')
    
    # plot wdel
    try:
      w_del = np.squeeze(w_del.eval())
      icnt=1
      for idimx in np.arange(dimx):
        for idimy in np.arange(dimy):
          fig = plt.subplot(dimx, dimy, icnt)
          plt.imshow( np.reshape(w_del[idimx, idimy, :], (2*window+1,2*window+1)), interpolation='nearest', cmap='gray')
          icnt = icnt+1
          fig.axes.get_xaxis().set_visible(False)
          fig.axes.get_yaxis().set_visible(False)
    except:
      print('w delta do not exist? ')
    plt.suptitle('Iteration: ' + str(int(iter_plot)))
    plt.show()
    plt.draw()
    '''
        # select a cell, and show its subunits.
        #try:

        ## Load data summary, get mask
        filename = FLAGS.data_location + 'data_details.mat'
        summary_file = gfile.Open(filename, 'r')
        data_summary = sio.loadmat(summary_file)
        total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
        stas = data_summary['stas']
        print(np.shape(total_mask))

        # a is 2D

        a_eval = a.eval()
        print(np.shape(a_eval))
        # get softmax numpy
        if FLAGS.model_id == 'relu_window_mother_sfm' or FLAGS.model_id == 'relu_window_mother_sfm_exp':
            b = np.exp(a_eval) / np.sum(np.exp(a_eval), 0)
        else:
            b = a_eval

        plt.figure()
        plt.imshow(b, interpolation='nearest', cmap='gray')
        plt.show()
        plt.draw()

        # plot subunits for multiple cells.
        n_cells = 10
        n_plots_max = 20
        plt.figure()
        for icell_cnt, icell in enumerate(np.arange(n_cells)):
            mask2D = np.reshape(total_mask[icell, :], [40, 80])
            nz_idx = np.nonzero(mask2D)
            np.shape(nz_idx)
            print(nz_idx)
            ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
            xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])

            icnt = -1
            a_thr = np.percentile(np.abs(b[:, icell]), 99.5)
            n_plots = np.sum(np.abs(b[:, icell]) > a_thr)
            nx = np.ceil(np.sqrt(n_plots)).astype('int')
            ny = np.ceil(np.sqrt(n_plots)).astype('int')
            ifig = 0
            ww_sum = np.zeros((40, 80))

            for idimx in np.arange(dimx):
                for idimy in np.arange(dimy):
                    icnt = icnt + 1
                    if (np.abs(b[icnt, icell]) > a_thr):
                        ifig = ifig + 1
                        fig = plt.subplot(n_cells, n_plots_max,
                                          icell_cnt * n_plots_max + ifig + 2)
                        ww = np.zeros((40, 80))
                        ww[idimx * FLAGS.stride:idimx * FLAGS.stride +
                           (2 * window + 1),
                           idimy * FLAGS.stride:idimy * FLAGS.stride +
                           (2 * window + 1)] = b[icnt, icell] * (np.reshape(
                               wts[idimx, idimy, :],
                               (2 * window + 1, 2 * window + 1)))
                        plt.imshow(ww, interpolation='nearest', cmap='gray')
                        plt.ylim(ylim)
                        plt.xlim(xlim)
                        plt.title(b[icnt, icell])
                        fig.axes.get_xaxis().set_visible(False)
                        fig.axes.get_yaxis().set_visible(False)

                        ww_sum = ww_sum + ww

            fig = plt.subplot(n_cells, n_plots_max,
                              icell_cnt * n_plots_max + 2)
            plt.imshow(ww_sum, interpolation='nearest', cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('STA from model')

            fig = plt.subplot(n_cells, n_plots_max,
                              icell_cnt * n_plots_max + 1)
            plt.imshow(np.reshape(stas[:, icell], [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('True STA')

        plt.show()
        plt.draw()

        #except:
        #  print('a not 2D?')

        # using xlim and ylim, and plot the 'windows' which are relevant with their weights
        sq_flat = np.zeros((dimx, dimy))
        icnt = 0
        for idimx in np.arange(dimx):
            for idimy in np.arange(dimy):
                sq_flat[idimx, idimy] = icnt
                icnt = icnt + 1

        n_cells = 1
        n_plots_max = 10
        plt.figure()
        for icell_cnt, icell in enumerate(np.array(
            [1, 2, 3, 4, 5])):  #enumerate(np.arange(n_cells)):
            a_thr = np.percentile(np.abs(b[:, icell]), 99.5)
            mask2D = np.reshape(total_mask[icell, :], [40, 80])
            nz_idx = np.nonzero(mask2D)
            np.shape(nz_idx)
            print(nz_idx)
            ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
            xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])
            print(xlim, ylim)

            win_startx = np.ceil((xlim[0] - (2 * window + 1)) / FLAGS.stride)
            win_endx = np.floor((xlim[1] - 1) / FLAGS.stride)
            win_starty = np.ceil((ylim[0] - (2 * window + 1)) / FLAGS.stride)
            win_endy = np.floor((ylim[1] - 1) / FLAGS.stride)
            dimx_plot = win_endx - win_startx + 1
            dimy_plot = win_endy - win_starty + 1
            ww_sum = np.zeros((40, 80))
            for irow, idimy in enumerate(np.arange(win_startx, win_endx + 1)):
                for icol, idimx in enumerate(
                        np.arange(win_starty, win_endy + 1)):
                    fig = plt.subplot(dimx_plot + 1, dimy_plot,
                                      (irow + 1) * dimy_plot + icol + 1)
                    ww = np.zeros((40, 80))
                    ww[idimx * FLAGS.stride:idimx * FLAGS.stride +
                       (2 * window + 1),
                       idimy * FLAGS.stride:idimy * FLAGS.stride +
                       (2 * window + 1)] = (np.reshape(
                           wts[idimx,
                               idimy, :], (2 * window + 1, 2 * window + 1)))
                    plt.imshow(ww, interpolation='nearest', cmap='gray')
                    plt.ylim(ylim)
                    plt.xlim(xlim)
                    if b[sq_flat[idimx, idimy], icell] > a_thr:
                        plt.title(b[sq_flat[idimx, idimy], icell],
                                  fontsize=10,
                                  color='g')
                    else:
                        plt.title(b[sq_flat[idimx, idimy], icell],
                                  fontsize=10,
                                  color='r')
                    fig.axes.get_xaxis().set_visible(False)
                    fig.axes.get_yaxis().set_visible(False)

                    ww_sum = ww_sum + ww * b[sq_flat[idimx, idimy], icell]

            fig = plt.subplot(dimx_plot + 1, dimy_plot, 2)
            plt.imshow(ww_sum, interpolation='nearest', cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('STA from model')

            fig = plt.subplot(dimx_plot + 1, dimy_plot, 1)
            plt.imshow(np.reshape(stas[:, icell], [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.title('True STA')

            plt.show()
            plt.draw()
Example #8
0
def main(unused_argv=()):
    """Build a stimulus-response model, then train or evaluate it.

    Loads stimulus/response data for every directory in `FLAGS.src_dir`,
    builds the model selected by `FLAGS.sr_model`, initializes it via
    `bookkeeping.initialize_model`, sets up per-retina summaries and then
    dispatches on `FLAGS.mode`: 0 trains, 1 runs `testing.test_metric`,
    2 runs `prosthesis.stimulate` and 3 runs `testing.test_encoding`.

    Args:
      unused_argv: Command-line arguments (ignored).

    Raises:
      ValueError: If `FLAGS.mode` or `FLAGS.sr_model` is not recognized, if no
        dataset directory is found in `FLAGS.src_dir`, or if (for the rank-1
        linear baselines) the partitions are not a single retina used for
        both training and testing.
    """
    # Fail fast: any other mode previously crashed with an unbound
    # `is_training` after all the data had been loaded.
    if FLAGS.mode not in (0, 1, 2, 3):
        raise ValueError('Unsupported mode: %r (expected 0, 1, 2 or 3)' %
                         (FLAGS.mode,))

    np.random.seed(23)
    tf.set_random_seed(1234)
    random.seed(50)

    # Load stimulus-response data.
    # Collect population response across retinas in the list 'responses'.
    # Stimulus for each retina is indicated by 'stim_id',
    # which is found in 'stimuli' dictionary.
    datasets = gfile.ListDirectory(FLAGS.src_dir)
    stimuli = {}
    responses = []
    dimx = dimy = None
    for icnt, idataset in enumerate(datasets):
        fullpath = os.path.join(FLAGS.src_dir, idataset)
        if gfile.IsDirectory(fullpath):
            key = 'stim_%d' % icnt
            op = data_util.get_stimulus_response(
                FLAGS.src_dir,
                idataset,
                key,
                boundary=FLAGS.valid_cells_boundary,
                if_get_stim=True)
            # NOTE(review): dimx/dimy of the last dataset are used for the
            # model; presumably all datasets share them - confirm.
            stimulus, resp, dimx, dimy, _ = op
            stimuli[key] = stimulus
            responses += resp

    if dimx is None:
        raise ValueError('No dataset directories found in %s' % FLAGS.src_dir)

    # Get training and testing partitions.
    # The partitions for the taskid should be listed in partition_file.
    training_datasets, testing_datasets = partitions.get_partitions(
        FLAGS.partition_file, FLAGS.taskid)

    expt_models = ('convolutional_embedding_expt',
                   'convolutional_embedding_margin_expt',
                   'convolutional_embedding_inner_product_expt',
                   'convolutional_embedding_gauss_expt',
                   'convolutional_embedding_kernel_expt')
    autoembedder_models = ('convolutional_autoembedder',
                           'convolutional_autoembedder_l2')

    with tf.Session() as sess:

        # Get stimulus-response embedding.
        # The graph is built in training mode for every mode, tests included.
        is_training = True
        if FLAGS.mode in (2, 3):
            print('NOTE: is_training = True in test')

        sample_fcn = sample_datasets
        if FLAGS.sr_model == 'convolutional_embedding':
            embedding = sr_models.convolutional_embedding(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        elif FLAGS.sr_model in expt_models:
            embedding = sr_models_expt.convolutional_embedding_experimental(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        elif FLAGS.sr_model == 'convolutional_autoembedder':
            embedding = sr_models_expt.convolutional_autoembedder(
                sess, is_training, dimx, dimy)

        elif FLAGS.sr_model == 'convolutional_autoembedder_l2':
            embedding = sr_models_expt.convolutional_autoembedder(
                sess, is_training, dimx, dimy, loss='log_sum_exp')

        elif FLAGS.sr_model in ('convolutional_encoder',
                                'convolutional_encoder_2'):
            embedding = encoding_models_expt.convolutional_encoder(
                sess, is_training, dimx, dimy)

        elif FLAGS.sr_model == 'convolutional_encoder_using_retina_id':
            model = encoding_models_expt.convolutional_encoder_using_retina_id
            embedding = model(sess, is_training, dimx, dimy, len(responses))
            sample_fcn = sample_datasets_2

        elif FLAGS.sr_model in ('residual', 'residual_inner_product'):
            embedding = sr_models_expt.residual_experimental(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        elif FLAGS.sr_model in ('lin_rank1', 'lin_rank1_blind'):
            # Exactly one retina, used for both training and testing. Either
            # violation is an error (the previous `and` only caught both).
            if (len(training_datasets) != 1 or
                    list(training_datasets) != list(testing_datasets)):
                raise ValueError('Identical training/testing data'
                                 ' (exactly 1) supported')

            train_resp = responses[training_datasets[0]]
            n_cells = train_resp['responses'].shape[1]

            model_fn = sr_baseline_models.linear_rank1_models
            embedding = model_fn(FLAGS.sr_model,
                                 sess,
                                 dimx,
                                 dimy,
                                 n_cells,
                                 center_locations=train_resp['map_cell_grid'],
                                 cell_masks=train_resp['mask_cells'],
                                 firing_rates=train_resp['mean_firing_rate'],
                                 cell_type=train_resp['cell_type'].squeeze(),
                                 time_window=30)

        else:
            raise ValueError('Unknown sr_model: %s' % FLAGS.sr_model)

        # print model graph
        PrintModelAnalysis(tf.get_default_graph())

        # Get filename, initialize model
        file_name = bookkeeping.get_filename(training_datasets,
                                             testing_datasets, FLAGS.beta,
                                             FLAGS.sr_model)
        tf.logging.info('Filename: %s', file_name)
        saver_var, start_iter = bookkeeping.initialize_model(
            FLAGS.save_folder, file_name, sess)

        # Setup summary ops.
        # Save separate summary for each retina (both training/testing).
        summary_ops = []
        for iret in range(len(responses)):
            r_list = [tf.summary.scalar('loss_%d' % iret, embedding.loss)]

            if hasattr(embedding, 'accuracy_tf'):
                r_list.append(
                    tf.summary.scalar('accuracy_%d' % iret,
                                      embedding.accuracy_tf))

            if FLAGS.sr_model in autoembedder_models:
                # Each summary is named after the attribute it tracks.
                for loss_name in ('loss_triplet',
                                  'loss_stim_decode_from_resp',
                                  'loss_stim_decode_from_stim',
                                  'loss_resp_decode_from_resp',
                                  'loss_resp_decode_from_stim'):
                    r_list.append(
                        tf.summary.scalar('%s_%d' % (loss_name, iret),
                                          getattr(embedding, loss_name)))

            summary_ops.append(tf.summary.merge(r_list))

        # Setup summary writers.
        summary_writers = []
        for loc in ['train', 'test']:
            summary_location = os.path.join(FLAGS.save_folder, file_name,
                                            'summary_' + loc)
            summary_writers.append(
                tf.summary.FileWriter(summary_location, sess.graph))

        # Separate tests for encoding or metric learning,
        #  prosthesis usage or just neuroscience usage.
        if FLAGS.mode == 3:
            testing.test_encoding(training_datasets, testing_datasets,
                                  responses, stimuli, embedding, sess,
                                  file_name, sample_fcn)

        elif FLAGS.mode == 2:
            prosthesis.stimulate(embedding, sess, file_name, dimx, dimy)

        elif FLAGS.mode == 1:
            testing.test_metric(training_datasets, testing_datasets, responses,
                                stimuli, embedding, sess, file_name)

        else:  # FLAGS.mode == 0
            training.training(start_iter,
                              sess,
                              embedding,
                              summary_writers,
                              summary_ops,
                              saver_var,
                              training_datasets,
                              testing_datasets,
                              responses,
                              stimuli,
                              file_name,
                              sample_fcn,
                              summary_freq=500,
                              save_freq=500)
def main(argv):
    """Restore the latest checkpoint and plot subunit weights.

    Builds the checkpoint prefix from `FLAGS.model_id` and the hyper-parameter
    flags, and looks for checkpoints under
    `FLAGS.save_location + FLAGS.folder_name + '/' + <prefix> + '/'`. Note
    that `FLAGS.save_location` is overwritten with that folder. It restores
    the highest iteration found and plots the subunit-to-cell matrix `a` and
    the first 100 subunits of `w`, each reshaped to 40 x 80.

    Args:
      argv: Command-line arguments (unused).

    Raises:
      ValueError: If `FLAGS.model_id` is not recognized or no checkpoint
        iteration number can be parsed from the folder's file names.
    """
    del argv  # Unused.

    print(FLAGS.model_id)
    print(FLAGS.folder_name)

    model_id = FLAGS.model_id
    su_suffix = ('_ratioSU=' + str(FLAGS.ratio_SU) + '_grid_spacing=' +
                 str(FLAGS.su_grid_spacing))
    if model_id == 'relu':
        # lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0
        short_filename = ('data_model=' + str(model_id) + '_lam_w=' +
                          str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                          su_suffix + '_normalized_bg')
    elif model_id == 'exp':
        short_filename = ('data_model3=' + str(model_id) + '_bias_init=' +
                          str(FLAGS.bias_init_scale) + su_suffix +
                          '_normalized_bg')
    elif model_id in ('mel_re_pow2', 'relu_logistic'):
        short_filename = ('data_model3=' + str(model_id) + '_lam_w=' +
                          str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                          su_suffix + '_normalized_bg')
    elif model_id == 'relu_proximal':
        short_filename = ('data_model3=' + str(model_id) + '_lam_w=' +
                          str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                          '_eta_w=' + str(FLAGS.eta_w) + '_eta_a=' +
                          str(FLAGS.eta_a) + su_suffix + '_proximal_bg')
    elif model_id == 'relu_eg':
        short_filename = ('data_model3=' + str(model_id) + '_lam_w=' +
                          str(FLAGS.lam_w) + '_eta_w=' + str(FLAGS.eta_w) +
                          '_eta_a=' + str(FLAGS.eta_a) + su_suffix + '_eg_bg')
    else:
        raise ValueError('Unknown model_id: %s' % model_id)

    # get relevant files
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    FLAGS.save_location = parent_folder + short_filename + '/'
    file_list = gfile.ListDirectory(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    # Escape the prefix so '.' and other characters in hyper-parameter values
    # match literally. The trailing '.' still accepts any separator
    # character, as before (e.g. '-' in '<prefix>-<iteration>').
    prefix_pattern = re.compile(re.escape(short_filename) + '.')
    bin_files = []
    meta_files = []
    for file_n in file_list:
        if prefix_pattern.search(file_n):
            if '.meta' in file_n:
                meta_files.append(file_n)
            else:
                bin_files.append(file_n)
    print(len(meta_files), len(bin_files), len(file_list))

    # Iteration number is the text after the last '-' in the file name.
    iteration_list = []
    for file_name in bin_files:
        try:
            iteration_list.append(int(file_name.split('/')[-1].split('-')[-1]))
        except ValueError:
            print('Bad filename' + file_name)
    if not iteration_list:
        raise ValueError('No checkpoint iterations found in %s' %
                         FLAGS.save_location)
    iterations = np.sort(np.array(iteration_list, dtype=float))
    print(iterations)

    iter_plot = iterations[-1]
    print(int(iter_plot))
    title = 'Iteration: ' + str(int(iter_plot))
    with tf.Session() as sess:
        # load tensorflow variables; the shapes are hard-coded and presumably
        # must match those saved in the checkpoint.
        w = tf.Variable(np.array(np.random.randn(3200, 749), dtype='float32'))
        a = tf.Variable(np.array(np.random.randn(749, 107), dtype='float32'))
        # tf.all_variables is deprecated/removed; global_variables replaces it.
        saver_var = tf.train.Saver(tf.global_variables())
        restore_file = save_filename + '-' + str(int(iter_plot))
        saver_var.restore(sess, restore_file)

        # plot subunit - cell connections
        plt.figure()
        plt.cla()
        plt.imshow(a.eval(), cmap='gray', interpolation='nearest')
        plt.title(title)
        plt.show()
        plt.draw()

        # plot a few subunits, on a fresh figure so the connection matrix
        # above is not overwritten.
        wts = w.eval()
        plt.figure()
        for isu in range(100):
            ax = plt.subplot(10, 10, isu + 1)
            plt.imshow(np.reshape(wts[:, isu], [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.suptitle(title)

        plt.show()
        plt.draw()
def main(argv):
    """Plan stimulation for random binary targets with several algorithms.

    Loads the decoder `A` and dictionary `D` of task `FLAGS.taskid` from
    `FLAGS.src_dict` (a .pkl or .mat file). It asks the user for a threshold
    used to prune dictionary elements and appends an all-zero "dummy"
    element. Then, for 20 random targets at several stixel resolutions, it
    runs greedy stimulation followed by four simultaneous-planning variants.
    Every solver is given `FLAGS.save_dir` extended with the dictionary's
    file name (minus its 4-character extension).

    Args:
      argv: Command-line arguments (unused).

    Raises:
      ValueError: If the dictionary file is neither a .pkl nor a .mat file.
    """
    del argv  # Unused.
    np.random.seed(23)

    # Figure out dictionary path.
    dict_list = gfile.ListDirectory(FLAGS.src_dict)
    dict_name = dict_list[FLAGS.taskid]
    dict_path = os.path.join(FLAGS.src_dict, dict_name)

    # Load the dictionary. Both formats are binary, so open in 'rb' mode and
    # close the handle once loaded.
    if dict_path.endswith('pkl'):
        with gfile.Open(dict_path, 'rb') as f:
            data = pickle.load(f)
    elif dict_path.endswith('mat'):
        with gfile.Open(dict_path, 'rb') as f:
            data = sio.loadmat(f)
    else:
        raise ValueError('Unsupported dictionary file (expected .pkl or .mat):'
                         ' %s' % dict_path)

    FLAGS.save_dir = FLAGS.save_dir + dict_name[:-4]
    if not gfile.Exists(FLAGS.save_dir):
        gfile.MkDir(FLAGS.save_dir)

    A = data['A']  # Decoder
    D = data['D'].T  # Dictionary

    # Clean the dictionary: plot how many elements have at least one entry
    # above each candidate threshold, then ask the user to pick one.
    thr_list = np.arange(0, 1, 0.01)
    dict_val = [np.sum(np.sum(D.T > thr, 1) != 0) for thr in thr_list]
    plt.ion()
    plt.figure()
    plt.plot(thr_list, dict_val)
    plt.xlabel('Threshold')
    plt.ylabel(
        'Number of dictionary elements with \n atleast one element above threshold'
    )
    plt.title('Please choose threshold')
    thr_use = float(input('What threshold to use?'))
    plt.axvline(thr_use)
    plt.title('Using threshold: %.5f' % thr_use)

    dict_valid = np.sum(D.T > thr_use, 1) > 0
    D = D[:, dict_valid]
    D = np.append(D, np.zeros((D.shape[0], 1)), 1)
    print(
        'Appending a "dummy" dictionary element that does not activate any cell'
    )

    # (planner, file-suffix tag, freeze_freq). freeze_freq=np.inf disables
    # interleaved discretization; 500 enables it.
    planners = [
        (simultaneous_planning_interleaved_discretization, 'pgd', np.inf),
        (simultaneous_planning_interleaved_discretization, 'pgd_od', 500),
        (simultaneous_planning_interleaved_discretization_exp_gradient, 'ew',
         np.inf),
        (simultaneous_planning_interleaved_discretization_exp_gradient,
         'ew_od', 500),
    ]

    # Vary stimulus resolution
    n_targets = 1
    for itarget in range(20):
        for stix_resolution in [32, 64, 16, 8]:

            # Random +/-0.5 target at this stixel resolution, upsampled to a
            # 40 x 80 grid. Integer division: np.repeat needs an integer
            # repeat count (a float here raises under Python 3).
            x_dim = 640 // stix_resolution
            y_dim = 320 // stix_resolution
            targets = (np.random.rand(y_dim, x_dim, n_targets) < 0.5) - 0.5
            upscale = stix_resolution // 8
            targets = np.repeat(np.repeat(targets, upscale, axis=0),
                                upscale,
                                axis=1)
            targets = np.reshape(targets, [-1, targets.shape[-1]])
            S_actual = targets[:, 0]

            # Remove null component of A from S
            S = A.dot(np.linalg.pinv(A).dot(S_actual))

            # Run Greedy first to initialize
            x_greedy = greedy_stimulation(S,
                                          A,
                                          D,
                                          max_stims=FLAGS.t_max * FLAGS.delta,
                                          file_suffix='%d_%d' %
                                          (stix_resolution, itarget),
                                          save=True,
                                          save_dir=FLAGS.save_dir)

            # Plan for multiple time points: initialize step `it` with the
            # greedy column at index (it + 1) * delta - 1.
            x_init = np.zeros((x_greedy.shape[0], FLAGS.t_max))
            for it in range(FLAGS.t_max):
                x_init[:, it] = x_greedy[:, (it + 1) * FLAGS.delta - 1]

            for planner, tag, freeze_freq in planners:
                planner(S,
                        A,
                        D,
                        t_max=FLAGS.t_max,
                        lr=FLAGS.learning_rate,
                        delta=FLAGS.delta,
                        normalization=FLAGS.normalization,
                        file_suffix='%d_%d_%s' % (stix_resolution, itarget,
                                                  tag),
                        x_init=x_init,
                        save_dir=FLAGS.save_dir,
                        freeze_freq=freeze_freq,
                        steps_max=500 * 20 - 1)
def subunit_discriminability(dataset_dict,
                             stimuli,
                             responses,
                             sr_graph,
                             num_examples=1000,
                             subunit_fit_loc=('/home/bhaishahster/stim-resp_'
                                              'collection_big_wn_retina_'
                                              'subunit_properties_train')):
    """Measure embedding distances to subunit-model response predictions.

    For each retina, loads the responses predicted by subunit models with
    1..10 subunits ('response_prediction.pkl' in the subunit-fit folder whose
    name shares the first 12 characters of the retina's 'piece'). On
    `num_examples` random test time points it computes, per subunit count:
      * 'rr': `sr_graph.distances_arbitrary` between the embeddings of the
        predicted and the recorded responses;
      * 'sr': `sr_graph.d_s_r_pos` between the stimulus history and the
        predicted response.

    Args:
      dataset_dict: Mapping from a dataset name to a collection of retinas.
        Only its length is used (see the review note in the body).
      stimuli: Mapping from stimulus key to a stimulus array whose first
        axis is time.
      responses: List of per-retina dicts with at least 'piece',
        'cellID_list', 'stimulus_key' and 'responses' (time x cells).
      sr_graph: Stimulus-response graph providing `sess`, `anchor_model`,
        `distances_arbitrary`, `arbitrary_embedding_1/2` and `d_s_r_pos`.
      num_examples: Number of sampled time points; must be a multiple of 100.
      subunit_fit_loc: Directory that holds one subunit-fit folder per piece.

    Returns:
      Dict with 'dataset_dict' and 'datasets_log', where datasets_log maps
      each dataset name to {'rr': [...], 'sr': [...]}, one
      (num_examples, 10) array per retina. Column k holds the model with
      k + 1 subunits.

    Raises:
      ValueError: If `num_examples` is not a multiple of 100 or no subunit
        fit matches a retina.
    """
    batch_size = 100
    stim_history = 30  # Number of past stimulus frames per example.
    n_subunits_max = 10  # Subunit models with 1..10 subunits are compared.
    if num_examples % batch_size != 0:
        raise ValueError('Only supports examples which are multiples of 100.')
    # Integer division (np.int no longer exists in recent NumPy).
    n_batches = num_examples // batch_size
    history_offsets = np.arange(stim_history)

    subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

    datasets_log = {}
    for dat_key, datasets in dataset_dict.items():
        distances_retina_rr_log = []
        distances_retina_sr_log = []
        # NOTE(review): retinas are taken as indices 0..len(datasets)-1 of
        # `responses`, not as the entries of `datasets`; confirm intended.
        for iretina in range(len(datasets)):
            resp_info = responses[iretina]

            # Find the relevant subunit fit
            piece = resp_info['piece']
            matched_dataset = [
                ifit for ifit in subunits_datasets if piece[:12] == ifit[:12]
            ]
            if not matched_dataset:
                raise ValueError('Could not find subunit fit')
            subunit_fit_path = os.path.join(subunit_fit_loc,
                                            matched_dataset[0])

            # Predicted spikes indexed (n_subunits - 1, time, cell). They
            # include rejected cells as well, so keep only listed cell IDs.
            with gfile.Open(
                    os.path.join(subunit_fit_path, 'response_prediction.pkl'),
                    'rb') as f:
                dat_resp_su = pickle.load(f)
            resp_su = dat_resp_su['resp_su']
            kept_ids = resp_info['cellID_list'].squeeze()
            select_cells = np.array([
                icell for icell in range(resp_su.shape[2])
                if dat_resp_su['cell_ids'][icell] in kept_ids
            ],
                                    dtype=int)
            resp_su = resp_su[:, :, select_cells].astype(np.float32)

            # Test partition of stimulus and recorded responses.
            stimulus = stimuli[resp_info['stimulus_key']]
            stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]
            responses_recorded_test = resp_info['responses'][
                FLAGS.test_min:FLAGS.test_max, :]

            # Sample time points; starting at 40 leaves room for the full
            # stimulus history.
            random_times = np.random.randint(40, stimulus_test.shape[0],
                                             num_examples)

            # Recorded response - predicted response distances.
            distances_rr = np.full((num_examples, n_subunits_max), np.nan)
            for Nsub in range(1, n_subunits_max + 1):
                for ibatch in range(n_batches):
                    rows = slice(batch_size * ibatch, batch_size * (ibatch + 1))
                    times = random_times[rows]
                    resp_pred_batch = resp_su[Nsub - 1, times, :].astype(
                        np.float64)
                    resp_rec_batch = np.asarray(
                        responses_recorded_test[times, :], dtype=np.float64)

                    # Embed predicted responses
                    feed_dict = make_feed_dict(sr_graph,
                                               resp_info,
                                               responses=resp_pred_batch)
                    embed_predicted = sr_graph.sess.run(
                        sr_graph.anchor_model.responses_embed,
                        feed_dict=feed_dict)

                    # Embed recorded responses
                    feed_dict = make_feed_dict(sr_graph,
                                               resp_info,
                                               responses=resp_rec_batch)
                    embed_recorded = sr_graph.sess.run(
                        sr_graph.anchor_model.responses_embed,
                        feed_dict=feed_dict)

                    dd = sr_graph.sess.run(sr_graph.distances_arbitrary,
                                           feed_dict={
                                               sr_graph.arbitrary_embedding_1:
                                               embed_predicted,
                                               sr_graph.arbitrary_embedding_2:
                                               embed_recorded
                                           })
                    distances_rr[rows, Nsub - 1] = dd
                    print(iretina, Nsub, ibatch)

            distances_retina_rr_log.append(distances_rr)

            # Stimulus - predicted response distances.
            distances_sr = np.full((num_examples, n_subunits_max), np.nan)
            for Nsub in range(1, n_subunits_max + 1):
                for ibatch in range(n_batches):
                    rows = slice(batch_size * ibatch, batch_size * (ibatch + 1))
                    times = random_times[rows]
                    # Frames itime, itime-1, ..., itime-stim_history+1 per
                    # sample, moved to the last axis:
                    # (batch, dim1, dim2, stim_history).
                    frame_idx = (times[:, np.newaxis] -
                                 history_offsets[np.newaxis, :])
                    stim_batch = np.transpose(stimulus_test[frame_idx],
                                              [0, 2, 3, 1]).astype(np.float64)
                    resp_batch = resp_su[Nsub - 1, times, :].astype(np.float64)

                    feed_dict = make_feed_dict(sr_graph, resp_info, resp_batch,
                                               stim_batch)

                    # Get distances
                    d_pos = sr_graph.sess.run(sr_graph.d_s_r_pos,
                                              feed_dict=feed_dict)
                    distances_sr[rows, Nsub - 1] = d_pos
                    print(iretina, Nsub, ibatch)

            distances_retina_sr_log.append(distances_sr)

        datasets_log[dat_key] = {
            'rr': distances_retina_rr_log,
            'sr': distances_retina_sr_log
        }

    return {'datasets_log': datasets_log, 'dataset_dict': dataset_dict}
def response_transformation_increase_nl(stimuli,
                                        responses,
                                        sr_graph,
                                        time_start_list,
                                        time_len=100,
                                        alpha_list=(1.5, 1.25, 0.8, 0.6),
                                        ln_save_folder=('/home/bhaishahster/'
                                                        'stim-resp_collection_ln_model_exp'),
                                        first_retina=3):
    """Embed LN-model spikes simulated with progressively scaled nonlinearity.

    Question asked: take an LN model, increase its non-linearity -- how do the
    embedded points move in response space?

    For every start time and every retina from ``first_retina`` onwards that has
    a matching LN model, simulate ``time_len`` frames of spikes from the LN
    model. Then re-simulate with filter ``k`` and bias ``b`` scaled by each
    ``alpha``. For each alpha the bias is shifted by the log of the mean-rate
    ratio, which renormalizes the firing rate. All spike segments are stacked
    and embedded with ``sr_graph.anchor_model``.

    Args:
      stimuli: dict mapping a stimulus key to a stimulus array whose first axis
        is sliced by time (``[time_start:time_start + time_len, :, :]``).
      responses: list of per-retina dicts. Keys read: 'piece', 'stimulus_key',
        'valid_cells', 'map_cell_grid', 'ctype_1hot', 'mean_firing_rate' and,
        if the anchor model has ``dist_nn``, 'dist_nn_cell_type'.
      sr_graph: object exposing ``sess``, ``anchor_model`` and ``neg_model``.
      time_start_list: iterable of start frames of the stimulus window.
      time_len: number of frames per simulated segment.
      alpha_list: scale factors applied to the LN filter and bias.
      ln_save_folder: directory of pickled LN models. A file matches a retina
        when its first 12 characters equal those of the retina's 'piece'.
      first_retina: index of the first entry of ``responses`` to analyse.

    Returns:
      dict mapping each time_start to a list with one entry per retina that had
      an LN model. Each entry is either a dict with keys 'spikes_log',
      'alpha_log', 'resp_embed' and 'piece', or ``np.nan`` if embedding failed.
      Retinas without a matching LN model are skipped, so list positions do not
      necessarily equal retina indices.
    """
    # Load every pickled LN model in the folder.
    # NOTE: unpickling can execute arbitrary code; only use trusted model files.
    files = gfile.ListDirectory(ln_save_folder)
    ln_models = []
    for ifile in files:
        print(ifile)
        # Pickles are binary: open in 'rb' and close the handle after loading.
        with gfile.Open(os.path.join(ln_save_folder, ifile), 'rb') as f:
            ln_models.append(pickle.load(f))

    t_start_dict = {}
    for time_start in time_start_list:

        print('Start time %d' % time_start)
        retina_log = []
        for iretina_test in range(first_retina, len(responses)):

            print('Retina: %d' % iretina_test)
            iresp = responses[iretina_test]
            piece_id = iresp['piece']

            # Find LN model(s) whose file name shares the piece's prefix.
            matched_ln_model = [
                ifile for ifile, fname in enumerate(files)
                if fname[:12] == piece_id[:12]
            ]
            if not matched_ln_model:
                print('LN model not found')
                continue
            if len(matched_ln_model) > 1:
                print('More than 1 LN model found')

            # Sample a sequence of stimuli and predict spikes.
            iln_model = ln_models[matched_ln_model[0]]
            stimulus_test = stimuli[iresp['stimulus_key']]
            stim_sample = stimulus_test[time_start:time_start + time_len, :, :]
            spikes, lam_np = analysis_utils.predict_responses_ln(
                stim_sample,
                iln_model['k'],
                iln_model['b'],
                iln_model['ttf'],
                n_trials=1)
            spikes_log = np.copy(spikes[0, :, :])
            alpha_log = np.ones(time_len)

            # Increase nonlinearity, normalize firing rate and embed.
            for alpha in alpha_list:
                _, lam_np_alpha = analysis_utils.predict_responses_ln(
                    stim_sample,
                    alpha * iln_model['k'],
                    alpha * iln_model['b'],
                    iln_model['ttf'],
                    n_trials=1)
                # Shift the bias by log(rate ratio) to renormalize the mean rate.
                # NOTE(review): this presumes an exponential output nonlinearity
                # in predict_responses_ln -- confirm.
                correction_firing_rate = np.mean(lam_np) / np.mean(lam_np_alpha)
                correction_b = np.log(correction_firing_rate)
                spikes_corrected, lam_np_corrected = (
                    analysis_utils.predict_responses_ln(
                        stim_sample,
                        alpha * iln_model['k'],
                        alpha * iln_model['b'] + correction_b,
                        iln_model['ttf'],
                        n_trials=1))
                print(alpha, np.mean(lam_np), np.mean(lam_np_alpha),
                      np.mean(lam_np_corrected))
                spikes_log = np.append(spikes_log,
                                       spikes_corrected[0, :, :],
                                       axis=0)
                alpha_log = np.append(alpha_log,
                                      alpha * np.ones(time_len),
                                      axis=0)

            # Embed responses.
            try:
                resp_trans = np.expand_dims(
                    spikes_log[:, iresp['valid_cells']], 2)
                feed_dict = {
                    sr_graph.anchor_model.map_cell_grid_tf:
                    iresp['map_cell_grid'],
                    sr_graph.anchor_model.cell_types_tf: iresp['ctype_1hot'],
                    sr_graph.anchor_model.mean_fr_tf:
                    iresp['mean_firing_rate'],
                    sr_graph.anchor_model.responses_tf: resp_trans
                }

                if hasattr(sr_graph.anchor_model, 'dist_nn'):
                    dist_nn = np.array([
                        iresp['dist_nn_cell_type'][1],
                        iresp['dist_nn_cell_type'][2]
                    ]).astype(np.float32)
                    feed_dict.update({
                        sr_graph.anchor_model.dist_nn: dist_nn,
                        sr_graph.neg_model.dist_nn: dist_nn
                    })

                rr = sr_graph.sess.run(sr_graph.anchor_model.responses_embed,
                                       feed_dict=feed_dict)

                retina_log.append({
                    'spikes_log': spikes_log,
                    'alpha_log': alpha_log,
                    'resp_embed': rr,
                    'piece': piece_id
                })

            except Exception as err:  # best-effort: continue with other retinas
                print('Error! %s' % err)
                retina_log.append(np.nan)

        t_start_dict[time_start] = retina_log

    return t_start_dict
def main(unused_argv=()):
  """Train, or with ``FLAGS.is_test == 1`` evaluate, a stimulus-response embedding.

  Steps:
    1. Copy every file in ``FLAGS.src_dir`` to ``FLAGS.tmp_dir``. Files that
       already exist there are skipped.
    2. Load the 'stimulus' dataset from ``stimulus.mat`` (read with h5py and
       shifted by -0.5). Load one .mat response file per line of
       ``datasets.txt``.
    3. Build a stimulus embedding and two response embeddings: an anchor model
       and a negative model. The negative model is built with
       ``reuse_variables=True``, presumably sharing the anchor's weights. Build
       a log-sum-exp loss on squared embedding distances with temperature
       ``beta``.
    4. Choose training retinas from ``FLAGS.taskid``. Then either run the test
       analyses, pickled to ``<save_folder>/<file_name>_analysis.pkl``, or
       train with Adagrad from the restored iteration up to
       ``FLAGS.max_iter``. Training writes train/test summaries and a
       checkpoint every 10 iterations.

  Args:
    unused_argv: ignored command-line arguments.
  """

  ## copy data locally
  dst = FLAGS.tmp_dir
  print('Starting Copy')
  if not gfile.IsDirectory(dst):
    gfile.MkDir(dst)

  files = gfile.ListDirectory(FLAGS.src_dir)
  for ifile in files:
    ffile = os.path.join(dst, ifile)
    if not gfile.Exists(ffile):
      gfile.Copy(os.path.join(FLAGS.src_dir, ifile), ffile)
      print('Copied %s' % os.path.join(FLAGS.src_dir, ifile))
    else:
      print('File exists %s' % ffile)

  print('File copied to destination')

  ## load data
  # load stimulus. Open read-only explicitly (h5py's default mode changed
  # across versions) and close the file once the array has been copied out.
  with h5py.File(os.path.join(dst, 'stimulus.mat'), 'r') as data:
    stimulus = np.array(data.get('stimulus')) - 0.5

  # load responses from multiple retina; datasets.txt lists one .mat file
  # name per line.
  with open(os.path.join(dst, 'datasets.txt'), 'r') as f:
    dataset_files = f.read().splitlines()

  responses = []
  for idata in dataset_files:
    print(idata)
    data = sio.loadmat(os.path.join(dst, idata))
    responses += [data]
    print(np.max(data['centers'], 0))

  # generate additional features for responses
  num_cell_types = 2
  dimx = 80
  dimy = 40
  for iresp in responses:
    # remove cells which are outside 80x40 window.
    process_dataset(iresp, dimx, dimy, num_cell_types)

  ## generate graph
  # Both training (is_test == 0) and testing (is_test == 1) build the graph
  # with is_training=True ('False' was left commented out for test mode).
  # A single assignment also avoids an unbound name for other is_test values.
  is_training = True

  with tf.Session() as sess:

    ## Make graph
    # embed stimulus.
    time_window = 30
    stimx = stimulus.shape[1]
    stimy = stimulus.shape[2]
    stim_tf = tf.placeholder(tf.float32,
                             shape=[None, stimx,
                                    stimy, time_window])  # batch x X x Y x time_window
    batch_norm = FLAGS.batch_norm
    stim_embed = embed_stimulus(FLAGS.stim_layers.split(','),
                                batch_norm, stim_tf, is_training,
                                reuse_variables=False)

    '''
    ttf_tf = tf.Variable(np.ones(time_window).astype(np.float32)/10, name='stim_ttf')
    filt = tf.expand_dims(tf.expand_dims(tf.expand_dims(ttf_tf, 0), 0), 3)
    stim_time_filt = tf.nn.conv2d(stim_tf, filt,
                                    strides=[1, 1, 1, 1], padding='SAME') # batch x X x Y x 1


    ilayer = 0
    stim_time_filt = slim.conv2d(stim_time_filt, 1, [3, 3],
                        stride=1,
                        scope='stim_layer_wt_%d' % ilayer,
                        reuse=False,
                        normalizer_fn=slim.batch_norm,
                        activation_fn=tf.nn.softplus,
                        normalizer_params={'is_training': is_training},
                        padding='SAME')
    '''

    # embed responses.
    layers = FLAGS.resp_layers  # format: window x filters x stride .. NOTE: final filters=1, stride =1 throughout
    batch_norm = FLAGS.batch_norm
    anchor_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                     layers=layers,
                                                     batch_norm=batch_norm,
                                                     is_training=is_training,
                                                     reuse_variables=False,
                                                     num_cell_types=num_cell_types,
                                                     dimx=dimx, dimy=dimy)

    neg_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                  layers=layers,
                                                  batch_norm=batch_norm,
                                                  is_training=is_training,
                                                  reuse_variables=True,
                                                  num_cell_types=num_cell_types,
                                                  dimx=dimx, dimy=dimy)

    d_s_r_pos = tf.reduce_sum((stim_embed - anchor_model.responses_embed)**2, [1, 2, 3])  # batch
    d_pairwise_s_rneg = tf.reduce_sum((tf.expand_dims(stim_embed, 1) -
                                       tf.expand_dims(neg_model.responses_embed, 0))**2, [2, 3, 4])  # batch x batch_neg
    beta = 10
    # Soft ranking loss: beta * logsumexp over negatives of (d_pos - d_neg)/beta,
    # summed over the batch.
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(tf.expand_dims(d_s_r_pos / beta, 1) -
                                                    d_pairwise_s_rneg / beta, 1), 0)
    # Alternative tried earlier:
    # loss = tf.reduce_sum(tf.nn.softplus(1 +  tf.expand_dims(d_s_r_pos, 1) - d_pairwise_s_rneg))
    accuracy_tf = tf.reduce_mean(tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg))

    lr = 0.001
    train_op = tf.train.AdagradOptimizer(lr).minimize(loss)

    # set up training and testing data
    training_datasets_all = [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]
    testing_datasets = [0, 4, 8, 12]
    print('Testing datasets', testing_datasets)

    n_training_datasets_log = [1, 3, 5, 7, 9, 11, 12]

    task_group = FLAGS.taskid // 4  # four task ids per training-set size
    if task_group < len(n_training_datasets_log):
      # do randomly sampled training data for 0<= FLAGS.taskid < 28
      prng = RandomState(23)
      n_training_datasets = n_training_datasets_log[task_group]
      # Advance the seeded RNG by 10*taskid draws so each task id gets a
      # different but reproducible subset of training retinas.
      for _ in range(10 * FLAGS.taskid):
        print(prng.choice(training_datasets_all,
                          n_training_datasets, replace=False))
      training_datasets = prng.choice(training_datasets_all,
                                      n_training_datasets, replace=False)
    else:
      # do 1 training data, chosen in order for FLAGS.taskid >= 28
      # NOTE(review): taskid % 28 can exceed 15, which overruns the 16
      # datasets -- confirm the intended task-id range.
      datasets_all = np.arange(16)
      training_datasets = [int(datasets_all[FLAGS.taskid % (4 * len(n_training_datasets_log))])]

    print('Task ID %d' % FLAGS.taskid)
    print('Training datasets', training_datasets)

    # Initialize stuff.
    file_name = ('end_to_end_stim_%s_resp_%s_beta_%d_taskid_%d'
                 '_training_%s_testing_%s_%s' % (FLAGS.stim_layers,
                                                 FLAGS.resp_layers, beta, FLAGS.taskid,
                                                 str(training_datasets)[1: -1],
                                                 str(testing_datasets)[1: -1],
                                                 FLAGS.save_suffix))
    saver_var, start_iter = initialize_model(FLAGS.save_folder, file_name, sess)

    # print model graph
    PrintModelAnalysis(tf.get_default_graph())

    # Add summary ops
    retina_number = tf.placeholder(tf.int16, name='input_retina')

    summary_ops = []
    for iret in np.arange(len(responses)):
      print(iret)
      r1 = tf.summary.scalar('loss_%d' % iret, loss)
      r2 = tf.summary.scalar('accuracy_%d' % iret, accuracy_tf)
      summary_ops += [tf.summary.merge([r1, r2])]

    # Setup summary writers
    summary_writers = []
    for loc in ['train', 'test']:
      summary_location = os.path.join(FLAGS.save_folder, file_name,
                                      'summary_' + loc)
      summary_writer = tf.summary.FileWriter(summary_location, sess.graph)
      summary_writers += [summary_writer]

    # start training
    batch_neg_train_sz = 100
    batch_train_sz = 100

    def batch(dataset_id):
      """Build a feed_dict with a random batch from retina `dataset_id`."""
      batch_train = get_batch(stimulus, responses[dataset_id]['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   anchor_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   neg_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   neg_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],
                   retina_number: dataset_id}

      return feed_dict

    def batch_few_cells(resp_subset):
      """Build a feed_dict from a response dict restricted to a cell subset."""
      batch_train = get_batch(stimulus, resp_subset['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: resp_subset['map_cell_grid'],
                   anchor_model.cell_types_tf: resp_subset['ctype_1hot'],
                   anchor_model.mean_fr_tf: resp_subset['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: resp_subset['map_cell_grid'],
                   neg_model.cell_types_tf: resp_subset['ctype_1hot'],
                   neg_model.mean_fr_tf: resp_subset['mean_firing_rate'],
                   }

      return feed_dict

    if FLAGS.is_test == 1:
      print('Testing')
      save_dict = {}

      ## Estimate one, fix others (disabled experiment)
      '''
      grad_resp = tf.gradients(d_s_r_pos, anchor_model.responses_tf)

      t_start = 1000
      t_len = 100
      stim_history = 30
      stim_batch = np.zeros((t_len, stimulus.shape[1],
                         stimulus.shape[2], stim_history))
      for isample, itime in enumerate(np.arange(t_start, t_start + t_len)):
        stim_batch[isample, :, :, :] = np.transpose(stimulus[itime: itime-stim_history:-1, :, :], [1, 2, 0])

      iretina = testing_datasets[0]
      resp_batch = np.expand_dims(np.random.rand(t_len, responses[iretina]['responses'].shape[1]), 2)

      step_sz = 0.01
      eps = 1e-2
      dist_prev = np.inf
      for iiter in range(10000):
        feed_dict = {stim_tf: stim_batch,
                     anchor_model.map_cell_grid_tf: responses[iretina]['map_cell_grid'],
                     anchor_model.cell_types_tf: responses[iretina]['ctype_1hot'],
                     anchor_model.mean_fr_tf: responses[iretina]['mean_firing_rate'],
                     anchor_model.responses_tf: resp_batch}
        dist_np, resp_grad_np = sess.run([d_s_r_pos, grad_resp], feed_dict=feed_dict)
        if np.sum(np.abs(dist_prev - dist_np)) < eps:
          break
        print(np.sum(dist_np), np.sum(np.abs(dist_prev - dist_np)))
        dist_prev = dist_np
        resp_batch = resp_batch - step_sz * resp_grad_np[0]

      resp_batch = resp_batch.squeeze()
      '''

      ## compute distances between s-r pairs for small number of cells

      test_retina = []
      for iretina in range(len(testing_datasets)):
        dataset_id = testing_datasets[iretina]

        num_cells_total = responses[dataset_id]['responses'].shape[1]
        centers = responses[dataset_id]['centers']
        dataset_center = centers.mean(0)
        # Euclidean distance of every cell center from the mean center.
        dataset_cell_distances = np.sqrt(np.sum((centers - dataset_center) ** 2, 1))
        order_cells = np.argsort(dataset_cell_distances)

        test_sr_few_cells = {}
        for num_cells_prc in [5, 10, 20, 30, 50, 100]:
          num_cells = int(np.percentile(np.arange(num_cells_total),
                                        num_cells_prc))

          # Keep the `num_cells` cells closest to the array center.
          choose_cells = order_cells[:num_cells]

          responses_few_cells = {'responses': responses[dataset_id]['responses'][:, choose_cells],
                                 'map_cell_grid': responses[dataset_id]['map_cell_grid'][:, :, choose_cells],
                                 'ctype_1hot': responses[dataset_id]['ctype_1hot'][choose_cells, :],
                                 'mean_firing_rate': responses[dataset_id]['mean_firing_rate'][choose_cells]}
          # get a batch
          d_pos_log = np.array([])
          d_neg_log = np.array([])
          for test_iter in range(1000):
            print(iretina, num_cells_prc, test_iter)
            feed_dict = batch_few_cells(responses_few_cells)
            d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
            d_neg = np.diag(d_neg)  # np.mean(d_neg, 1) #
            d_pos_log = np.append(d_pos_log, d_pos)
            d_neg_log = np.append(d_neg_log, d_neg)

          precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

          print(np.sum(d_pos_log > d_neg_log))
          print(np.sum(d_pos_log < d_neg_log))
          test_sr = {'precision': precision_log, 'recall': recall_log,
                     'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                     'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log,
                     'num_cells': num_cells}

          test_sr_few_cells.update({'num_cells_prc_%d' % num_cells_prc: test_sr})
        test_retina += [test_sr_few_cells]
      save_dict.update({'few_cell_analysis': test_retina})

      ## compute distances between s-r pairs - pos and neg.

      test_retina = []
      for iretina in range(len(testing_datasets)):
        # stim-resp log
        d_pos_log = np.array([])
        d_neg_log = np.array([])
        for test_iter in range(1000):
          print(test_iter)
          feed_dict = batch(testing_datasets[iretina])
          d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
          d_neg = np.diag(d_neg)  # np.mean(d_neg, 1) #
          d_pos_log = np.append(d_pos_log, d_pos)
          d_neg_log = np.append(d_neg_log, d_neg)

        precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

        print(np.sum(d_pos_log > d_neg_log))
        print(np.sum(d_pos_log < d_neg_log))
        test_sr = {'precision': precision_log, 'recall': recall_log,
                   'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                   'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log}

        test_retina += [test_sr]

      save_dict.update({'test_sr': test_retina})

      ## ROC curves of responses from repeats - dataset 1
      repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
      with gfile.Open(repeats_datafile, 'rb') as f:  # .mat files are binary
        repeats_data = sio.loadmat(f)
      repeats_data['cell_type'] = repeats_data['cell_type'].T

      # process repeats data
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result
      test_reps = analyse_response_repeats(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2015-09-23-7': test_reps})

      ## ROC curves of responses from repeats - dataset 2
      repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
      with gfile.Open(repeats_datafile, 'rb') as f:
        repeats_data = sio.loadmat(f)
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result (disabled)
      '''
      test_clustering = analyse_response_repeats_all_trials(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2005_08_03_0': test_clustering})
      '''

      # get model params
      save_dict.update({'model_pars': sess.run(tf.trainable_variables())})

      save_analysis_filename = os.path.join(FLAGS.save_folder, file_name + '_analysis.pkl')
      # Pickle needs a binary handle; close it so the file is fully flushed.
      with gfile.Open(save_analysis_filename, 'wb') as f:
        pickle.dump(save_dict, f)
      print(save_analysis_filename)
      return

    test_iiter = 0
    for iiter in range(start_iter, FLAGS.max_iter):

      # training step: cycle through the training retinas.
      train_dataset = training_datasets[iiter % len(training_datasets)]
      feed_dict_train = batch(train_dataset)
      _, loss_np_train = sess.run([train_op, loss], feed_dict=feed_dict_train)
      print(train_dataset, loss_np_train)

      # write summary
      if iiter % 10 == 0:
        # write train summary
        test_iiter = test_iiter + 1

        train_dataset = training_datasets[test_iiter % len(training_datasets)]
        feed_dict_train = batch(train_dataset)
        summary_train = sess.run(summary_ops[train_dataset], feed_dict=feed_dict_train)
        summary_writers[0].add_summary(summary_train, iiter)

        # write test summary
        test_dataset = testing_datasets[test_iiter % len(testing_datasets)]
        feed_dict_test = batch(test_dataset)
        l_test, summary_test = sess.run([loss, summary_ops[test_dataset]], feed_dict=feed_dict_test)
        summary_writers[1].add_summary(summary_test, iiter)
        print('Test retina: %d, loss: %.3f' % (test_dataset, l_test))

      # save model
      if iiter % 10 == 0:
        save_model(saver_var, FLAGS.save_folder, file_name, sess, iiter)
def main(unused_argv=()):
    """Predict subunit-model spikes for one retina and pickle them.

    The retina index is ``FLAGS.taskid``. Its subunit fit is the entry of
    ``FLAGS.save_folder`` whose first 12 characters match the retina's 'piece'.
    Spikes are predicted with ``get_su_spks`` on the stimulus frames
    ``[FLAGS.test_min:FLAGS.test_max]``. When ``FLAGS.is_null`` is set, those
    frames are first replaced with a null stimulus from ``get_null_stimulus``.
    The output is written to ``<subunit_fit_path>/response_prediction.pkl``, or
    ``response_prediction_null.pkl`` when ``FLAGS.is_null`` is set.

    Raises:
      ValueError: if no subunit fit matches the retina's piece.
    """
    # Optional seeding (left disabled):
    # np.random.seed(23); tf.set_random_seed(1234); random.seed(50)

    # 1. Load stimulus-response data.
    # Collect population responses across retinas in the list 'responses'.
    # Each retina's stimulus is named by its 'stimulus_key' entry and is
    # looked up in the 'stimuli' dictionary.
    datasets = gfile.ListDirectory(FLAGS.src_dir)
    stimuli = {}
    responses = []
    for icnt, idataset in enumerate(datasets):

        fullpath = os.path.join(FLAGS.src_dir, idataset)
        if gfile.IsDirectory(fullpath):
            key = 'stim_%d' % icnt
            op = data_util.get_stimulus_response(
                FLAGS.src_dir,
                idataset,
                key,
                boundary=FLAGS.valid_cells_boundary)
            stimulus, resp, _, _, _ = op

            stimuli[key] = stimulus
            responses += resp

    # 2. Do response prediction for a retina.
    iretina = FLAGS.taskid
    resp_ret = responses[iretina]
    subunit_fit_loc = FLAGS.save_folder
    subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

    piece = resp_ret['piece']
    matched_dataset = [
        ifit for ifit in subunits_datasets if piece[:12] == ifit[:12]
    ]
    if not matched_dataset:
        raise ValueError('Could not find subunit fit for piece %s in %s'
                         % (piece, subunit_fit_loc))

    subunit_fit_path = os.path.join(subunit_fit_loc, matched_dataset[0])
    stimulus = stimuli[resp_ret['stimulus_key']]

    # sample test data
    stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]

    # Optionally, replace the test stimulus with a null stimulus.
    if FLAGS.is_null:
        stimulus_test = get_null_stimulus(resp_ret, subunit_fit_path,
                                          stimulus_test)

    resp_su = get_su_spks(subunit_fit_path, stimulus_test, resp_ret)

    save_dict = {
        'resp_su': resp_su.astype(np.int8),
        'cell_ids': resp_ret['cellID_list'].squeeze()
    }

    if FLAGS.is_null:
        save_dict['stimulus_null'] = stimulus_test
        save_suff = '_null'
    else:
        save_suff = ''

    # Pickle needs a binary handle; the context manager closes/flushes it.
    save_path = os.path.join(subunit_fit_path,
                             'response_prediction%s.pkl' % save_suff)
    with gfile.Open(save_path, 'wb') as f:
        pickle.dump(save_dict, f)