Example #1
  def initialize_folder(self, save_location, folder_name):
    """Initialize the saving location of the model."""

    parent_folder = os.path.join(save_location, folder_name)
    # Make folder if it does not exist.
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    self.parent_folder = parent_folder

    save_location = os.path.join(parent_folder, self.short_filename)
    if not gfile.IsDirectory(save_location):
      gfile.MkDir(save_location)
    self.save_location = save_location

    self.save_filename = os.path.join(self.save_location, self.short_filename)
Example #2
def inputs(name, data_location, batch_size, num_epochs, stim_dim, resp_dim):
    """Gives a batch of stimuli and responses from a .tfrecords file.

    Works for .tfrecords files made using CoarseDataUtils.convert_to_TFRecords.
    """

    # Get filename queue.
    # Actual name is either 'name', 'name.tfrecords' or
    # folder 'name' with list of .tfrecords files.
    with tf.name_scope('input'):
        filename = os.path.join(data_location, name)
        filename_extension = os.path.join(data_location, name + '.tfrecords')
        if gfile.Exists(filename) and not gfile.IsDirectory(filename):
            tf.logging.info('%s Exists' % filename)
            filenames = [filename]
        elif (gfile.Exists(filename_extension)
              and not gfile.IsDirectory(filename_extension)):
            tf.logging.info('%s Exists' % filename_extension)
            filenames = [filename_extension]
        elif gfile.IsDirectory(filename):
            tf.logging.info('%s Exists and is a directory' % filename)
            filenames_short = gfile.ListDirectory(filename)
            filenames = [
                os.path.join(filename, ifilenames_short)
                for ifilenames_short in filenames_short
            ]
        else:
            raise IOError('Could not find %s or %s' %
                          (filename, filename_extension))
        tf.logging.info(filenames)
        filename_queue = tf.train.string_input_producer(filenames,
                                                        num_epochs=num_epochs,
                                                        capacity=10000)

    # Even when reading in multiple threads, share the filename
    # queue.
    stimulus, response = read_and_decode(filename_queue, stim_dim, resp_dim)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.

    stimulus_batch, response_batch = tf.train.shuffle_batch(
        [stimulus, response],
        batch_size=batch_size,
        num_threads=30,
        capacity=5000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=2000)
    # Alternative: plain (non-shuffled) batching.
    # stimulus_batch, response_batch = tf.train.batch(
    #     [stimulus, response], batch_size=batch_size, num_threads=30,
    #     capacity=50000 + 3 * batch_size)
    return stimulus_batch, response_batch
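
Note: the read_and_decode helper used above is not shown on this page. Below is a minimal sketch of what it might look like for records written by CoarseDataUtils.convert_to_TFRecords; the 'stimulus'/'response' feature names and float32 dtype follow the writer in Example #5 below, everything else is an assumption.

import tensorflow as tf

def read_and_decode(filename_queue, stim_dim, resp_dim):
    # Read one serialized tf.train.Example from the queue and decode the
    # 'stimulus' and 'response' byte strings back into float32 vectors.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'stimulus': tf.FixedLenFeature([], tf.string),
            'response': tf.FixedLenFeature([], tf.string),
        })
    stimulus = tf.decode_raw(features['stimulus'], tf.float32)
    response = tf.decode_raw(features['response'], tf.float32)
    # Static shapes must be set for tf.train.shuffle_batch.
    stimulus.set_shape([stim_dim])
    response.set_shape([resp_dim])
    return stimulus, response

When consuming the batches, remember that string_input_producer and shuffle_batch are queue-based: the session must run tf.local_variables_initializer() (for num_epochs) and start queue runners with tf.train.start_queue_runners() before fetching batches.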
Example #3
  def initialize_folder(self, save_location, folder_name):
    """Intialize saving location of the model."""

    parent_folder = os.path.join(save_location, folder_name)
    # Make folder if it does not exist.
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    self.parent_folder = parent_folder

    save_location = os.path.join(parent_folder, self.short_filename)
    if not gfile.IsDirectory(save_location):
      gfile.MkDir(save_location)
    self.save_location = save_location

    self.save_filename = os.path.join(self.save_location, self.short_filename)
Example #4
def main(unused_argv=()):

  # Load stimulus-response data
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  print(datasets)
  for icnt, idataset in enumerate(datasets):
    #for icnt, idataset in enumerate([datasets[2]]):
    #  print('HACK - only one dataset used!!')

    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op

      stimuli.update({key: stimulus})
      responses += resp

  print('# Responses %d' % len(responses))
  stimulus = stimuli[responses[FLAGS.taskid]['stimulus_key']]
  save_filename = ('linear_taskid_%d_piece_%s.pkl' % (FLAGS.taskid, responses[FLAGS.taskid]['piece']))
  print(save_filename)
  learn_lin_embedding(stimulus, np.double(responses[FLAGS.taskid]['responses']),
                      filename=save_filename,
                      lam_l1=0.00001, beta=10, time_window=30,
                      lr=0.01)


  print('DONE!')
Example #5
    def convert_to_TFRecords_chunks(self,
                                    prefix,
                                    save_location='~/tmp',
                                    examples_per_file=1000):
        """Converts the stimulus-response data into chunked .tfrecords files."""

        tf.logging.info('making TFRecords chunks')
        stimulus = self.stimulus.astype(np.float32)
        response = self.response.astype(np.float32)
        num_examples = stimulus.shape[0]
        # Use float division so the last, partially filled chunk is not dropped.
        num_files = int(np.ceil(num_examples / float(examples_per_file)))
        tf.logging.info('Number of files: %d, examples: %d' %
                        (num_files, num_examples))

        def _bytes_feature(value):
            # value: bytes to convert to tf.train.Feature
            return tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[value]))

        # make folder for storing .tfrecords files
        folder_tfrec = os.path.join(save_location, prefix)
        # Make folder if it does not exist.
        if not gfile.IsDirectory(folder_tfrec):
            tf.logging.info('making folder to store tfrecords')
            gfile.MkDir(folder_tfrec)
        else:
            tf.logging.info('folder exists, will overwrite results')

        index = -1
        for ifile in range(num_files):
            filename = os.path.join(folder_tfrec,
                                    'chunk_' + str(ifile) + '.tfrecords')
            tf.logging.info('Writing %s , starting index %d' %
                            (filename, index + 1))
            writer = tf.python_io.TFRecordWriter(filename)
            for iexample in range(examples_per_file):
                index += 1
                if index >= num_examples:
                    break  # the last chunk may contain fewer examples
                stimulus_raw = stimulus[index, :].tostring()
                response_raw = response[index, :].tostring()
                #print(index, stimulus[index,:].shape, response[index,:].shape)
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'stimulus': _bytes_feature(stimulus_raw),
                        'response': _bytes_feature(response_raw)
                    }))
                writer.write(example.SerializeToString())
            writer.close()
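
A small round-trip sanity check for the chunks written above; the path is illustrative and assumes prefix='coarse_run' was written to /tmp/tfrec.

import numpy as np
import tensorflow as tf

# Read the first serialized example back and decode it with numpy,
# mirroring the 'stimulus'/'response' bytes features written above.
record_path = '/tmp/tfrec/coarse_run/chunk_0.tfrecords'
for serialized in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    stim = np.frombuffer(
        example.features.feature['stimulus'].bytes_list.value[0], np.float32)
    resp = np.frombuffer(
        example.features.feature['response'].bytes_list.value[0], np.float32)
    print(stim.shape, resp.shape)
    break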
Example #6
def initialize_model(save_folder, file_name, sess):
  """Setup model variables and saving information.

  Args:
    save_folder (string) : Folder to store model.
                           Makes one if it does not exist.
    filename (string) : Prefix of model/checkpoint files.
    sess : Tensorflow session.
  """

  # Make folder.
  if not gfile.IsDirectory(save_folder):
    gfile.MkDir(save_folder)

  # Initialize variables.
  saver_var, start_iter = initialize_variables(sess, save_folder, file_name)
  return saver_var, start_iter
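
The initialize_variables helper called here is not shown on this page. A rough sketch of the restore-or-initialize pattern it presumably implements, modeled on the checkpoint-restoring code in the other examples; the exact signature and the '_latest_fn' naming convention are assumptions.

import tensorflow as tf


def initialize_variables(sess, save_folder, file_name):
  """Restore the latest checkpoint if one exists, otherwise initialize."""
  saver_var = tf.train.Saver(tf.global_variables(),
                             keep_checkpoint_every_n_hours=2)
  start_iter = 0
  latest_filename = file_name + '_latest_fn'  # naming convention is an assumption
  restore_file = tf.train.latest_checkpoint(save_folder, latest_filename)
  if restore_file is not None:
    # Iteration number is assumed to be the '-NNN' suffix of the checkpoint.
    start_iter = int(restore_file.split('/')[-1].split('-')[-1])
    saver_var.restore(sess, restore_file)
  else:
    sess.run(tf.global_variables_initializer())
  return saver_var, start_iter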
Example #7
def main(unused_argv=()):

  # Load stimulus-response data
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  responses = []
  print(datasets)
  for icnt, idataset in enumerate(datasets[:1]):  # TODO(bhaishahster): HACK - only the first dataset is used.

    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op

      responses += resp

  # Fit an LN model for each loaded piece (after the loading loop above).
  for idataset in range(len(responses)):
    k, b, ttf = fit_ln_population(responses[idataset]['responses'], stimulus)  # Use FLAGS.taskID
    save_dict = {'k': k, 'b': b, 'ttf': ttf}

    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          responses[idataset]['piece']
                                          + '_ln_model.pkl')
    pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
Example #8
def RunComputation():
    """Main function which runs all TensorFlow computations."""

    # Filename for saving files.
    if FLAGS.architecture == '2 layer_stimulus':
        architecture_string = ('_architecture=' + str(FLAGS.architecture) +
                               '_stim_downsample_window=' +
                               str(FLAGS.stim_downsample_window) +
                               '_stim_downsample_stride=' +
                               str(FLAGS.stim_downsample_stride))
    else:
        architecture_string = ('_architecture=' + str(FLAGS.architecture))

    short_filename = ('model=' + str(FLAGS.model_id) + '_loss=' +
                      str(FLAGS.loss) + '_batch_sz=' + str(FLAGS.batchsz) +
                      '_lam_w=' + str(FLAGS.lam_w) + '_step_sz' +
                      str(FLAGS.step_sz) + '_tlen=' + str(FLAGS.train_len) +
                      '_window=' + str(FLAGS.window) + '_stride=' +
                      str(FLAGS.stride) + str(architecture_string) + '_jitter')

    # Make a folder with a name derived from the parameters of the algorithm;
    # it stores checkpoint files and summaries used in TensorBoard.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # make folder if it does not exist
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print('Does the folder exist?', gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)

    save_filename = FLAGS.save_location + short_filename
    """Main function which runs all TensorFlow computations."""
    with tf.Graph().as_default() as gra:
        with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
            print(FLAGS.config_params)
            tf.logging.info(FLAGS.config_params)
            # set up training dataset
            tc_mean = get_data_mat.init_chunks(FLAGS.n_chunks)
            # Optional: plot a histogram of a training dataset.
            # stim_train, resp_train, train_len = get_data_mat.get_stim_resp(
            #     'train', num_chunks=FLAGS.num_chunks_to_load)
            # plt.hist(np.ndarray.flatten(stim_train[:, :, 0:]))
            # plt.show()
            # plt.draw()
            # Create computation graph.
            #
            # The graph should be fully constructed before you create the supervisor.
            # Attempting to modify the graph after the supervisor is created will cause an error.

            with tf.name_scope('model'):
                if FLAGS.architecture == '1 layer':
                    # single GPU model
                    if False:
                        global_step = tf.contrib.framework.create_global_step()
                        model, stim, resp = jitter_model.approximate_conv_jitter(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels)

                    # multiGPU model
                    if True:
                        model, stim, resp, global_step = jitter_model.approximate_conv_jitter_multigpu(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params)

                if FLAGS.architecture == '2 layer_stimulus':
                    # The stimulus is first smoothed to a lower dimension, then the same model is applied.
                    print('Put stimulus into lower dimensions!')
                    model, stim, resp, global_step, stim_tuple = jitter_model.approximate_conv_jitter_multigpu_stim_lr(
                        FLAGS.n_cells, FLAGS.lam_w, FLAGS.window, FLAGS.stride,
                        FLAGS.step_sz, tc_mean, FLAGS.su_channels,
                        FLAGS.config_params, FLAGS.stim_downsample_window,
                        FLAGS.stim_downsample_stride)

            # Print the number of variables in graph
            print('Calculating model size')  # Hope we do not exceed memory
            PrintModelAnalysis(gra, max_depth=10)
            #import pdb; pdb.set_trace()

            # Builds our summary op.
            summary_op = model.merged_summary

            # Create a Supervisor.  It will take care of initialization, summaries,
            # checkpoints, and recovery.
            #
            # When multiple replicas of this program are running, the first one,
            # identified by --task=0 is the 'chief' supervisor.  It is the only one
            # that takes care of initialization, etc.
            is_chief = (FLAGS.task == 0)  # & (FLAGS.learn==1)
            print(save_filename)
            if FLAGS.learn == 1:
                # Use the supervisor only for learning; otherwise it interferes
                # with analysis by trying to save variables while you work.

                sv = tf.train.Supervisor(logdir=save_filename,
                                         is_chief=is_chief,
                                         saver=tf.train.Saver(),
                                         summary_op=None,
                                         save_model_secs=100,
                                         global_step=global_step,
                                         recovery_wait_secs=5)

                if (is_chief and FLAGS.learn == 1):
                    tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                                         save_filename, 'graph.pbtxt')

                # Get an initialized, and possibly recovered session.  Launch the
                # services: Checkpointing, Summaries, step counting.
                #
                # When multiple replicas of this program are running the services are
                # only launched by the 'chief' replica.
                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                #import pdb; pdb.set_trace()
                sess = sv.PrepareSession(FLAGS.master, config=session_config)

                FitComputation(sv, sess, model, stim, resp, global_step,
                               summary_op, stim_tuple)
                sv.Stop()

            else:
                # if not learn, then analyse

                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                with tf.Session(config=session_config) as sess:
                    saver_var = tf.train.Saver(
                        tf.all_variables(),
                        keep_checkpoint_every_n_hours=float('inf'))
                    restore_file = tf.train.latest_checkpoint(save_filename)
                    print(restore_file)
                    start_iter = int(
                        restore_file.split('/')[-1].split('-')[-1])
                    saver_var.restore(sess, restore_file)

                    if FLAGS.architecture == '2 layer_stimulus':
                        AnalyseModel_lr(sess, model)
                    else:
                        AnalyseModel(sv, sess, model)
Example #9
def main(argv):
    print('\nCode started')

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic':
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = tot_spks[cells_choose]
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    #short_filename = 'data_model=ASM_pop_bg'
    #  short_filename = ('data_model=ASM_pop_batch_sz='+ str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) +
    #    '_step_sz'+ str(FLAGS.step_sz)+'_bg')

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    if FLAGS.model_id == 'logistic':
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    if FLAGS.model_id == 'poisson_full':
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Learn population model!
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        # variables
        if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.1 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))
        if FLAGS.model_id == 'logistic':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            # Alternative init:
            # np.log(np.sum(response, 0) / (response.shape[0] - np.sum(response, 0)))
            b_init = np.random.randn(n_cells)
            b = tf.Variable(b_init, dtype='float32')

        # get relevant files
        file_list = gfile.ListDirectory(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename
        print('\nLoading: ', save_filename)
        bin_files = []
        meta_files = []
        for file_n in file_list:
            if re.search(short_filename + '.', file_n):
                if re.search('.meta', file_n):
                    meta_files += [file_n]
                else:
                    bin_files += [file_n]
        #print(bin_files)
        print(len(meta_files), len(bin_files), len(file_list))

        # get iteration numbers
        iterations = np.array([])
        for file_name in bin_files:
            try:
                iterations = np.append(
                    iterations, int(file_name.split('/')[-1].split('-')[-1]))
            except ValueError:
                print('Could not parse iteration from filename: ' + file_name)
        iterations.sort()
        print(iterations)

        iter_plot = iterations[-1]
        print(int(iter_plot))

        # load tensorflow variables
        saver_var = tf.train.Saver(tf.all_variables())

        restore_file = save_filename + '-' + str(int(iter_plot))
        saver_var.restore(sess, restore_file)

        a_eval = a.eval()
        print(np.exp(np.squeeze(a_eval)))
        #print(np.shape(a_eval))

        # get 2D region to plot
        mask2D = np.reshape(chosen_mask, [40, 80])
        nz_idx = np.nonzero(mask2D)
        np.shape(nz_idx)
        print(nz_idx)
        ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
        xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])
        w_eval = w.eval()

        plt.figure()
        n_su = w_eval.shape[1]
        for isu in np.arange(n_su):
            xx = np.zeros((3200))
            xx[chosen_mask] = w_eval[:, isu]
            fig = plt.subplot(np.ceil(np.sqrt(n_su)), np.ceil(np.sqrt(n_su)),
                              isu + 1)
            plt.imshow(np.reshape(xx, [40, 80]),
                       interpolation='nearest',
                       cmap='gray')
            plt.ylim(ylim)
            plt.xlim(xlim)
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            #if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge':
            #  plt.title(str(a_eval[isu, :]))
            #else:
            #  plt.title(str(np.squeeze(np.exp(a_eval[:, 0, isu]))), fontsize=12)

        plt.suptitle('Iteration:' + str(int(iter_plot)) + ' batchSz:' +
                     str(FLAGS.batchsz) + ' step size:' + str(FLAGS.step_sz),
                     fontsize=18)
        plt.show()
        plt.draw()
Example #10
def main(argv):

    # global variables will be used for getting training data
    global cells_choose
    global chosen_mask
    global chunk_order

    # Set random seeds so that when the same algorithm is run with different
    # FLAGS, the sequence of random data is the same.
    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    # initial chunk order (re-shuffled every time we go over a chunk)
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    # Load data summary
    data_filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(data_filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])

    # which cells to train subunits for
    if FLAGS.all_cells == 'True':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    else:
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    n_cells = np.sum(cells_choose)  # number of cells

    # load spikes and relevant stimulus pixels for chosen cells
    tot_spks = np.squeeze(data_summary['tot_spks'])
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    # chosen_mask = which pixels to learn subunits over
    if FLAGS.masked_stimulus == 'True':
        chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                               dtype='bool')
    else:
        chosen_mask = np.array(np.ones(3200).astype('bool'))
    stim_dim = np.sum(chosen_mask)  # stimulus dimensions
    print('\ndataset summary loaded')

    # print parameters
    print('Save folder name: ' + str(FLAGS.folder_name) + '\nmodel:' +
          str(FLAGS.model_id) + '\nLoss:' + str(FLAGS.loss) +
          '\nmasked stimulus:' + str(FLAGS.masked_stimulus) + '\nall_cells?' +
          str(FLAGS.all_cells) + '\nbatch size' + str(FLAGS.batchsz) +
          '\nstep size' + str(FLAGS.step_sz) + '\ntraining length: ' +
          str(FLAGS.train_len) + '\nn_cells: ' + str(n_cells))

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # filename for saving file
    short_filename = ('_masked_stim=' + str(FLAGS.masked_stimulus) +
                      '_all_cells=' + str(FLAGS.all_cells) + '_loss=' +
                      str(FLAGS.loss) + '_batch_sz=' + str(FLAGS.batchsz) +
                      '_step_sz' + str(FLAGS.step_sz) + '_tlen=' +
                      str(FLAGS.train_len) + '_bg')

    with tf.Session() as sess:
        # set up stimulus and response placeholders
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.loss == 'poisson':
            # A very small positive bias is needed to avoid log(0) in the poisson loss.
            b_init = np.array(0.000001 * np.ones(n_cells))
        else:
            # Log-odds: a good initialization for some losses (like logistic).
            b_init = np.log(tot_spks_chosen_cells /
                            (216000. - tot_spks_chosen_cells))

        # different firing rate models
        if FLAGS.model_id == 'exp_additive':
            # This model was implemented for earlier work.
            # firing rate for cell c: lam_c = sum_s exp(w_s.x + a_sc)

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                                     dtype='float32'),
                            name='a')
            # firing rate model
            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
            regularization = 0
            vars_fit = [w, a]

            def proj():
                # Called after every training step to project parameters
                # onto their constraints (none for this model).
                pass

        if FLAGS.model_id == 'relu':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b
            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                              '_nsu=' + str(n_su) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_su, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            # firing rate model
            lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # regularization of parameters
            regularization = (FLAGS.lam_w * tf.reduce_sum(tf.abs(w)) +
                              FLAGS.lam_a * tf.reduce_sum(tf.abs(a)))
            # projection to satisfy constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i are defined over small windows that are convolutionally related to each other.
            # we know a > 0 and, for poisson loss, b > 0
            # for poisson loss: a small b is added to prevent lam_c from going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows(
            )  # get convolutional windows

            # variables
            w = tf.Variable(np.array(0.1 +
                                     0.05 * np.random.rand(dimx, dimy, n_pix),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_wts = tf.nn.relu(tf.reduce_sum(tf.mul(stim_masked, w), 3))
            # get firing rate
            lam = tf.matmul(tf.reshape(stim_wts, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i are defined over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a > 0 and, for poisson loss, b > 0
            # for poisson loss: a small b is added to prevent lam_c from going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to learn
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            #  stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm':
            # firing rate for cell c: lam_c = a_sfm_c'.relu(w.x) + b,
            # a_sfm_c = softmax(a): so a cell cannot be connected to all subunits equally well.

            # The w_i are defined over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a > 0 and, for poisson loss, b > 0
            # for poisson loss: a small b is added to prevent lam_c from going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to fit
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm_exp':
            # firing rate for cell c: lam_c = exp(a_sfm_c'.relu(w.x)) + b,
            # a_sfm_c = softmax(a): so a cell cannot be connected to all subunits equally well.
            # An exponential output nonlinearity cancels the log() in the poisson loss and might have better estimation properties.

            # The w_i are defined over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a > 0 and, for poisson loss, b > 0
            # for poisson loss: a small b is added to prevent lam_c from going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            # get windows
            mask_tf, dimx, dimy, n_pix = get_windows()

            # declare variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # filter stimulus
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # get subunit activation
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get cell firing rates
            lam = tf.exp(
                tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm)) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        # different loss functions
        if FLAGS.loss == 'poisson':
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len

        if FLAGS.loss == 'logistic':
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * lam)) / data_len

        if FLAGS.loss == 'hinge':
            loss_inter = tf.reduce_sum(
                tf.nn.relu(1 - 2 * (resp - 0.5) * lam)) / data_len

        loss = loss_inter + regularization  # add regularization to get final loss function

        # training consists of calling training()
        # which performs a train step and
        # project parameters to model specific constraints using proj()
        train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
            loss, var_list=vars_fit)

        def training(inp_dict):
            sess.run(train_step,
                     feed_dict=inp_dict)  # one step of gradient descent
            proj()  # model specific projection operations

        # evaluate loss on given data.
        def get_loss(inp_dict):
            ls = sess.run(loss, feed_dict=inp_dict)
            return ls

        # saving details
        # make a folder with name derived from parameters of the algorithm
        # - it saves checkpoint files and summaries used in tensorboard
        parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
        # make folder if it does not exist
        if not gfile.IsDirectory(parent_folder):
            gfile.MkDir(parent_folder)
        FLAGS.save_location = parent_folder + short_filename + '/'
        if not gfile.IsDirectory(FLAGS.save_location):
            gfile.MkDir(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename

        # create summary writers
        # create histogram summary for all parameters which are learnt
        for ivar in vars_fit:
            tf.histogram_summary(ivar.name, ivar)
        # loss summary
        l_summary = tf.scalar_summary('loss', loss)
        # loss without regularization summary
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)
        # Merge all the summary writer ops into one op (this way,
        # calling one op stores all summaries)
        merged = tf.merge_all_summaries()
        # training and testing has separate summary writers
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        ## Fitting procedure
        print('Start fitting')
        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            # restore previous fits if they are available
            # - useful when programs are preempted frequently on .
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            # restore previous iteration count and start from there.
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)  # restore variables
            load_prev = True
        except:
            print('No previous checkpoint found')

        if load_prev:
            print('Previous results loaded')
        else:
            print('Variables initialized')

        # Finally, do fitting
        icnt = 0
        # get test data and make test dictionary
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        for istep in np.arange(start_iter, 400000):
            print(istep)
            # get training data and make the training feed dictionary
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }

            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)

                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
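
The training()/proj() pair above is projected gradient descent: after each optimizer step, variables are clipped back onto their constraint set with tf.assign, using the (v + |v|)/2 trick for non-negativity. Below is a minimal self-contained illustration of the same pattern on a toy problem (not the subunit model above; all names and values are illustrative).

import numpy as np
import tensorflow as tf

# Toy problem: fit y ~= a * x subject to the constraint a >= 0.
a = tf.Variable(np.float32(-1.0), name='a_toy')
x = tf.placeholder(tf.float32, name='x_toy')
y = tf.placeholder(tf.float32, name='y_toy')
loss = tf.reduce_mean(tf.square(y - a * x))

train_step = tf.train.AdagradOptimizer(0.1).minimize(loss, var_list=[a])
# Projection op: same (v + |v|) / 2 trick used above to keep a variable non-negative.
a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    data_x = np.random.randn(1000).astype(np.float32)
    data_y = 0.5 * data_x
    for _ in range(200):
        sess.run(train_step, feed_dict={x: data_x, y: data_y})  # gradient step
        sess.run(a_pos)                                         # project back onto a >= 0
    print(sess.run(a))  # close to 0.5, and never negative after a projection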
Example #11
def main(unused_argv=()):

    np.random.seed(23)
    tf.set_random_seed(1234)
    random.seed(50)

    # Load stimulus-response data.
    # Collect population response across retinas in the list 'responses'.
    # Stimulus for each retina is indicated by 'stim_id',
    # which is found in 'stimuli' dictionary.
    datasets = gfile.ListDirectory(FLAGS.src_dir)
    stimuli = {}
    responses = []
    for icnt, idataset in enumerate(datasets):
        fullpath = os.path.join(FLAGS.src_dir, idataset)
        if gfile.IsDirectory(fullpath):
            key = 'stim_%d' % icnt
            op = data_util.get_stimulus_response(
                FLAGS.src_dir,
                idataset,
                key,
                boundary=FLAGS.valid_cells_boundary,
                if_get_stim=True)
            stimulus, resp, dimx, dimy, _ = op
            stimuli.update({key: stimulus})
            responses += resp

    # Get training and testing partitions.
    # Generate partitions
    # The partitions for the taskid should be listed in partition_file.

    op = partitions.get_partitions(FLAGS.partition_file, FLAGS.taskid)
    training_datasets, testing_datasets = op

    with tf.Session() as sess:

        # Get stimulus-response embedding.
        if FLAGS.mode == 0:
            is_training = True
        if FLAGS.mode == 1:
            is_training = True
        if FLAGS.mode == 2:
            is_training = True
            print('NOTE: is_training = True in test')
        if FLAGS.mode == 3:
            is_training = True
            print('NOTE: is_training = True in test')

        sample_fcn = sample_datasets
        if (FLAGS.sr_model == 'convolutional_embedding'):
            embedding = sr_models.convolutional_embedding(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        if (FLAGS.sr_model == 'convolutional_embedding_expt'
                or FLAGS.sr_model == 'convolutional_embedding_margin_expt' or
                FLAGS.sr_model == 'convolutional_embedding_inner_product_expt'
                or FLAGS.sr_model == 'convolutional_embedding_gauss_expt'
                or FLAGS.sr_model == 'convolutional_embedding_kernel_expt'):
            embedding = sr_models_expt.convolutional_embedding_experimental(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        if FLAGS.sr_model == 'convolutional_autoembedder':
            embedding = sr_models_expt.convolutional_autoembedder(
                sess, is_training, dimx, dimy)

        if FLAGS.sr_model == 'convolutional_autoembedder_l2':
            embedding = sr_models_expt.convolutional_autoembedder(
                sess, is_training, dimx, dimy, loss='log_sum_exp')

        if FLAGS.sr_model == 'convolutional_encoder' or FLAGS.sr_model == 'convolutional_encoder_2':
            embedding = encoding_models_expt.convolutional_encoder(
                sess, is_training, dimx, dimy)

        if FLAGS.sr_model == 'convolutional_encoder_using_retina_id':
            model = encoding_models_expt.convolutional_encoder_using_retina_id
            embedding = model(sess, is_training, dimx, dimy, len(responses))
            sample_fcn = sample_datasets_2

        if (FLAGS.sr_model == 'residual') or (FLAGS.sr_model
                                              == 'residual_inner_product'):
            embedding = sr_models_expt.residual_experimental(
                FLAGS.sr_model, sess, is_training, dimx, dimy)

        if FLAGS.sr_model == 'lin_rank1' or FLAGS.sr_model == 'lin_rank1_blind':
            if ((len(training_datasets) != 1)
                    or (training_datasets != testing_datasets)):
                raise ValueError('Identical training/testing data'
                                 ' (exactly 1) supported')

            n_cells = responses[training_datasets[0]]['responses'].shape[1]
            cell_locations = responses[training_datasets[0]]['map_cell_grid']
            cell_masks = responses[training_datasets[0]]['mask_cells']
            firing_rates = responses[training_datasets[0]]['mean_firing_rate']
            cell_type = responses[training_datasets[0]]['cell_type'].squeeze()

            model_fn = sr_baseline_models.linear_rank1_models
            embedding = model_fn(FLAGS.sr_model,
                                 sess,
                                 dimx,
                                 dimy,
                                 n_cells,
                                 center_locations=cell_locations,
                                 cell_masks=cell_masks,
                                 firing_rates=firing_rates,
                                 cell_type=cell_type,
                                 time_window=30)

        # print model graph
        PrintModelAnalysis(tf.get_default_graph())

        # Get filename, initialize model
        file_name = bookkeeping.get_filename(training_datasets,
                                             testing_datasets, FLAGS.beta,
                                             FLAGS.sr_model)
        tf.logging.info('Filename: %s' % file_name)
        saver_var, start_iter = bookkeeping.initialize_model(
            FLAGS.save_folder, file_name, sess)

        # Setup summary ops.
        # Save separate summary for each retina (both training/testing).
        summary_ops = []
        for iret in np.arange(len(responses)):
            r_list = []
            r1 = tf.summary.scalar('loss_%d' % iret, embedding.loss)
            r_list += [r1]

            if hasattr(embedding, 'accuracy_tf'):
                r2 = tf.summary.scalar('accuracy_%d' % iret,
                                       embedding.accuracy_tf)
                r_list += [r2]

            if FLAGS.sr_model == 'convolutional_autoembedder' or FLAGS.sr_model == 'convolutional_autoembedder_l2':
                r3 = tf.summary.scalar('loss_triplet_%d' % iret,
                                       embedding.loss_triplet)
                r4 = tf.summary.scalar('loss_stim_decode_from_resp_%d' % iret,
                                       embedding.loss_stim_decode_from_resp)
                r5 = tf.summary.scalar('loss_stim_decode_from_stim_%d' % iret,
                                       embedding.loss_stim_decode_from_stim)
                r6 = tf.summary.scalar('loss_resp_decode_from_resp_%d' % iret,
                                       embedding.loss_resp_decode_from_resp)
                r7 = tf.summary.scalar('loss_resp_decode_from_stim_%d' % iret,
                                       embedding.loss_resp_decode_from_stim)
                r_list += [r3, r4, r5, r6, r7]
                '''
        chosen_stim = 2
        bound = FLAGS.valid_cells_boundary
        
        r8 = tf.summary.image('stim_decode_from_stim_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.stim_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 3], 0), 3))

        r9 = tf.summary.image('stim_decode_from_resp_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.stim_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 3], 0), 3))

        r10 = tf.summary.image('resp_decode_from_stim_chann0_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3))

        r11 = tf.summary.image('resp_decode_from_resp_chann0_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3))

        r12 = tf.summary.image('resp_decode_from_stim_chann1_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3))

        r13 = tf.summary.image('resp_decode_from_resp_chann1_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3))

        r14 = tf.summary.image('resp_chann0_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.anchor_model.responses_embed_1[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3))

        r15 = tf.summary.image('resp_chann1_%d' % iret,
                              tf.expand_dims(tf.expand_dims(embedding.anchor_model.responses_embed_1[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3))

        r_list += [r8, r9, r10, r11, r12, r13, r14, r15]
        '''

            summary_ops += [tf.summary.merge(r_list)]

        # Setup summary writers.
        summary_writers = []
        for loc in ['train', 'test']:
            summary_location = os.path.join(FLAGS.save_folder, file_name,
                                            'summary_' + loc)
            summary_writer = tf.summary.FileWriter(summary_location,
                                                   sess.graph)
            summary_writers += [summary_writer]

        # Separate tests for encoding or metric learning,
        #  prosthesis usage or just neuroscience usage.
        if FLAGS.mode == 3:
            testing.test_encoding(training_datasets, testing_datasets,
                                  responses, stimuli, embedding, sess,
                                  file_name, sample_fcn)

        elif FLAGS.mode == 2:
            prosthesis.stimulate(embedding, sess, file_name, dimx, dimy)

        elif FLAGS.mode == 1:
            testing.test_metric(training_datasets, testing_datasets, responses,
                                stimuli, embedding, sess, file_name)

        else:
            training.training(start_iter,
                              sess,
                              embedding,
                              summary_writers,
                              summary_ops,
                              saver_var,
                              training_datasets,
                              testing_datasets,
                              responses,
                              stimuli,
                              file_name,
                              sample_fcn,
                              summary_freq=500,
                              save_freq=500)
Example #12
def RunComputation():

    # filename for saving files, derived from FLAGS.
    short_filename = get_filename()

    # make a folder with name derived from parameters of the algorithm
    # it saves checkpoint files and summaries used in tensorboard
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # make folder if it does not exist
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print('Does the folder exist?', gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    if FLAGS.learn == 0:
        # For analysis, use smaller batch sizes so that we can work with a single GPU.
        FLAGS.batchsz = 600

    # Set up TensorFlow.
    with tf.Graph().as_default() as gra:
        with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
            print(FLAGS.config_params)
            tf.logging.info(FLAGS.config_params)

            # set up training dataset
            # tc_mean = get_data_mat.init_chunks(FLAGS.n_chunks) <- use this with old get_data_mat
            tc_mean = get_data_mat.init_chunks(FLAGS.batchsz)
            #plt.plot(tc_mean)
            #plt.show()
            #plt.draw()

            # Create computation graph.
            #
            # The graph should be fully constructed before you create the supervisor.
            # Attempting to modify the graph after the supervisor is created will cause an error.
            with tf.name_scope('model'):
                if FLAGS.architecture == '1 layer':
                    # single GPU model
                    if False:
                        global_step = tf.contrib.framework.create_global_step()
                        model, stim, resp = jitter_model.approximate_conv_jitter(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels)

                    # multiGPU model
                    if True:
                        model, stim, resp, global_step = jitter_model.approximate_conv_jitter_multigpu(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params)

                if FLAGS.architecture == '2 layer_stimulus':
                    # The stimulus is first smoothed to a lower resolution, then the same model is applied.
                    print('First take a low resolution version of stimulus')
                    model, stim, resp, global_step, stim_tuple = (
                        jitter_model.approximate_conv_jitter_multigpu_stim_lr(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params,
                            FLAGS.stim_downsample_window,
                            FLAGS.stim_downsample_stride))

                if FLAGS.architecture == 'complex':
                    print(' Multiple modifications over 2 layered model above')
                    model, stim, resp, global_step = (
                        jitter_model_2.
                        approximate_conv_jitter_multigpu_complex(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params,
                            FLAGS.stim_downsample_window,
                            FLAGS.stim_downsample_stride))

            # Print the number of variables in graph
            print('Calculating model size')  # Hope we do not exceed memory
            PrintModelAnalysis(gra, max_depth=10)

            # Builds our summary op.
            summary_op = model.merged_summary

            # Create a Supervisor.  It will take care of initialization, summaries,
            # checkpoints, and recovery.
            #
            # When multiple replicas of this program are running, the first one,
            # identified by --task=0 is the 'chief' supervisor.  It is the only one
            # that takes care of initialization, etc.
            is_chief = (FLAGS.task == 0)  # & (FLAGS.learn==1)
            print(save_filename)

            if FLAGS.learn == 1:
                # Use the supervisor only for learning;
                # otherwise it interferes with analysis by trying to save variables while the model is being analysed.

                sv = tf.train.Supervisor(logdir=save_filename,
                                         is_chief=is_chief,
                                         saver=tf.train.Saver(),
                                         summary_op=None,
                                         save_model_secs=100,
                                         global_step=global_step,
                                         recovery_wait_secs=5)

                if (is_chief and FLAGS.learn == 1):
                    # save graph only if task id =0 (is_chief) and learning the model
                    tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                                         save_filename, 'graph.pbtxt')

                # Get an initialized, and possibly recovered session.  Launch the
                # services: Checkpointing, Summaries, step counting.
                #
                # When multiple replicas of this program are running the services are
                # only launched by the 'chief' replica.
                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                sess = sv.PrepareSession(FLAGS.master, config=session_config)

                # Finally, learn the parameters of the model
                FitComputation(sv, sess, model, stim, resp, global_step,
                               summary_op)
                sv.Stop()

            else:
                # Analyse the model

                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                with tf.Session(config=session_config) as sess:

                    # First, recover the model
                    saver_var = tf.train.Saver(
                        tf.all_variables(),
                        keep_checkpoint_every_n_hours=float('inf'))
                    restore_file = tf.train.latest_checkpoint(save_filename)
                    print(restore_file)
                    start_iter = int(
                        restore_file.split('/')[-1].split('-')[-1])
                    saver_var.restore(sess, restore_file)

                    # model specific analysis
                    if FLAGS.architecture == '2 layer_stimulus':
                        AnalyseModel_lr(sess, model)
                    elif FLAGS.architecture == 'complex':
                        AnalyseModel_complex(sess, model, stim, resp,
                                             save_filename)
                    else:
                        AnalyseModel(sv, sess, model)
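
A minimal sketch of the Supervisor pattern used in RunComputation above, assuming the same TF 1.x API: build the graph first, create a chief-aware tf.train.Supervisor, then ask it for a (possibly recovered) session. The toy graph, logdir and step count are made up for illustration, and managed_session is used here in place of the PrepareSession call above.

import tensorflow as tf  # TF 1.x, matching the listing above

# Toy graph: one variable pulled toward 3.0.
global_step = tf.train.get_or_create_global_step()
w = tf.Variable(0.0, name='w')
loss = tf.square(w - 3.0)
train_op = tf.train.AdagradOptimizer(0.5).minimize(loss, global_step=global_step)

# The Supervisor owns initialization, checkpointing and recovery;
# the graph must be fully built before it is created.
sv = tf.train.Supervisor(logdir='/tmp/sv_demo',  # hypothetical checkpoint directory
                         is_chief=True,
                         save_model_secs=60,
                         global_step=global_step)
with sv.managed_session('') as sess:  # '' starts an in-process master
  for _ in range(100):
    if sv.should_stop():
      break
    sess.run(train_op)
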
Example #13
0
def main(argv):

  # initialize training and testing chunks
  init_chunks()

  # setup dataset
  _, cids_select, n_cells, tc_select, tc_mean = setup_dataset()

  # print parameters
  print('Save folder name: ' + str(FLAGS.folder_name) +
        '\nmodel:' + str(FLAGS.model_id) +
        '\nLoss:' + str(FLAGS.loss) +
        '\nbatch size' + str(FLAGS.batchsz) +
        '\nstep size' + str(FLAGS.step_sz) +
        '\ntraining length: ' + str(FLAGS.train_len) +
        '\nn_cells: '+str(n_cells))

  # filename for saving file
  short_filename = ('_loss='+
                    str(FLAGS.loss) + '_batch_sz='+ str(FLAGS.batchsz) +
                    '_step_sz'+ str(FLAGS.step_sz) +
                    '_tlen=' + str(FLAGS.train_len) + '_jitter')

  # setup model
  with tf.Session() as sess:
    # initialize stuff
    if FLAGS.loss == 'poisson':
      b_init = np.array(0.000001*np.ones(n_cells)) # a very small positive bias needed to avoid log(0) in poisson loss
    else:
      b_init = np.log((tot_spks_chosen_cells)/(216000. - tot_spks_chosen_cells)) # log-odds of spiking, a reasonable initialization for the other losses

    # RGB time filter
    tm4D = np.zeros((30,1,3,3))
    for ichannel in range(3):
      tm4D[:,0,ichannel,ichannel] = tc_mean[:,ichannel]
    tc = tf.Variable((tm4D).astype('float32'),name = 'tc')

    d1=640
    d2=320
    colors=3

    # make data placeholders
    stim = tf.placeholder(tf.float32,shape=[None,d1,d2,colors],name='stim')
    resp = tf.placeholder(tf.float32,shape=[None,n_cells],name='resp')
    data_len = tf.placeholder(tf.float32,name='data_len')

    # time convolution
    # The time-course filter tc has shape (time, d1, color, color), with d1-extent 1.
    # The original stimulus is (time, d1, d2, color); permute it to (d2, time, d1, color)
    # so that the 1D temporal convolution can be mimicked with conv2d, then permute back.
    stim_time_filtered = tf.transpose(
        tf.nn.conv2d(tf.transpose(stim, (2, 0, 1, 3)), tc,
                     strides=[1, 1, 1, 1], padding='VALID'),
        (1, 2, 0, 3))

    # learn almost convolutional model
    short_filename = ('model=' + str(FLAGS.model_id) + short_filename)
    mask_tf, dimx, dimy, n_pix = get_windows()
    w_del = tf.Variable(np.array( 0.05*np.random.randn(dimx, dimy, n_pix),dtype='float32'), name='w_del')
    w_mother = tf.Variable(np.array( np.ones((2 * FLAGS.window + 1, 2 * FLAGS.window + 1, FLAGS.su_channels, 1)),dtype='float32'), name='w_mother')
    a = tf.Variable(np.array(np.random.randn(dimx*dimy, n_cells),dtype='float32'), name='a')
    a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
    b = tf.Variable(np.array(b_init,dtype='float32'), name='b')
    vars_fit = [w_mother, w_del, a] # which variables to fit
    if not FLAGS.loss == 'poisson':
      vars_fit = vars_fit + [b]

    # stimulus filtered with convolutional windows
    stim4D = stim_time_filtered  # tf.expand_dims(tf.reshape(stim, (-1,40,80)), 3)
    stim_convolved = tf.reduce_sum(tf.nn.conv2d(stim4D, w_mother, strides=[1, FLAGS.stride, FLAGS.stride, 1], padding="VALID"),3)
    stim_masked = tf.nn.conv2d(stim4D, mask_tf, strides=[1, FLAGS.stride, FLAGS.stride, 1], padding="VALID" )
    stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

    # activation of different subunits
    su_act = tf.nn.relu(stim_del + stim_convolved)

    # get firing rate
    lam = tf.matmul(tf.reshape(su_act, [-1, dimx*dimy]), a_sfm) + b

    # regularization
    regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

    # projection to satisfy hard variable constraints
    b_pos = tf.assign(b, (b + tf.abs(b))/2)
    def proj():
      if FLAGS.loss == 'poisson':
        sess.run(b_pos)


    if FLAGS.loss == 'poisson':
      loss_inter = (tf.reduce_sum(lam)/120. - tf.reduce_sum(resp*tf.log(lam))) / data_len

    loss = loss_inter + regularization # add regularization to get final loss function

    # training consists of calling training()
    # which performs a train step and project parameters to model specific constraints using proj()
    train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(loss, var_list=vars_fit)
    def training(inp_dict):
      sess.run(train_step, feed_dict=inp_dict) # one step of gradient descent
      proj() # model specific projection operations

    # evaluate loss on given data.
    def get_loss(inp_dict):
      ls = sess.run(loss,feed_dict = inp_dict)
      return ls


    # saving details
    # Make a folder whose name is derived from the parameters of the algorithm;
    # it stores checkpoint files and the summaries used in TensorBoard.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # make folder if it does not exist
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    if not gfile.IsDirectory(FLAGS.save_location):
      gfile.MkDir(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename


    # create summary writers
    # create histogram summary for all parameters which are learnt
    for ivar in vars_fit:
      tf.histogram_summary(ivar.name, ivar)
    # loss summary
    l_summary = tf.scalar_summary('loss',loss)
    # loss without regularization summary
    l_inter_summary = tf.scalar_summary('loss_inter',loss_inter)
    # Merge all the summary writer ops into one op (this way, calling one op stores all summaries)
    merged = tf.merge_all_summaries()
    # training and testing has separate summary writers
    train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                      sess.graph)
    test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')


    ## load previous results
    sess.run(tf.initialize_all_variables())
    saver_var = tf.train.Saver(tf.all_variables(), keep_checkpoint_every_n_hours=0.05)
    load_prev = False
    start_iter=0
    try:
      # Restore previous fits if they are available - useful when programs are preempted frequently.
      latest_filename = short_filename + '_latest_fn'
      restore_file = tf.train.latest_checkpoint(FLAGS.save_location, latest_filename)
      start_iter = int(restore_file.split('/')[-1].split('-')[-1]) # restore previous iteration count and start from there.
      saver_var.restore(sess, restore_file) # restore variables
      load_prev = True
    except:
      print('No previous checkpoint found')

    if load_prev:
      print('\nPrevious results loaded')
    else:
      print('\nVariables initialized')

      
    # Finally, do fitting
    # get test data and make test dictionary
    stim_test,resp_test,test_length = get_stim_resp('test')
    fd_test = {stim: stim_test,
               resp: resp_test,
               data_len: test_length}

    for istep in np.arange(start_iter,400000):
      print(istep)
      # get training data and make training dictionary
      stim_train, resp_train, train_len = get_stim_resp('train')
      fd_train = {stim: stim_train,
                  resp: resp_train,
                  data_len: train_len}
      print('loaded a training set')

      # take training step
      training(fd_train)
      print('performed a training step')

      if istep%10 == 0:
        # save variables
        latest_filename = short_filename + '_latest_fn'
        saver_var.save(sess, save_filename, global_step=istep, latest_filename = latest_filename)
        # add training summary
        summary = sess.run(merged, feed_dict=fd_train)
        train_writer.add_summary(summary,istep)
        print('training loss calculated')

        # add testing summary
        summary = sess.run(merged, feed_dict=fd_test)
        test_writer.add_summary(summary,istep)
        print('testing loss calculated')
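
For reference, the readout above (ReLU'd subunit activations pooled with per-cell softmax weights plus a bias) reduces to a few lines of NumPy. This is only a sketch of the shapes involved; the sizes are made up and the helper name subunit_firing_rate is hypothetical.

import numpy as np

def subunit_firing_rate(su_act, a, b):
  # su_act: (batch, dimx, dimy) ReLU'd subunit activations.
  # a: (dimx*dimy, n_cells) raw pooling weights, b: (n_cells,) bias.
  # Softmax over subunits for each cell, so each cell's weights sum to 1.
  e = np.exp(a - a.max(axis=0, keepdims=True))
  a_sfm = e / e.sum(axis=0, keepdims=True)
  su_flat = su_act.reshape(su_act.shape[0], -1)   # (batch, dimx*dimy)
  return su_flat.dot(a_sfm) + b                   # (batch, n_cells)

rng = np.random.RandomState(0)
su_act = np.maximum(rng.randn(4, 3, 5), 0.)       # batch=4, dimx=3, dimy=5
a = rng.randn(15, 2)                              # 15 subunits, 2 cells
b = 1e-6 * np.ones(2)
print(subunit_firing_rate(su_act, a, b).shape)    # (4, 2)
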
def main(unused_argv=()):

  ## copy data locally
  dst = FLAGS.tmp_dir
  print('Starting Copy')
  if not gfile.IsDirectory(dst):
    gfile.MkDir(dst)

  files = gfile.ListDirectory(FLAGS.src_dir)
  for ifile in files:
    ffile = os.path.join(dst, ifile)
    if not gfile.Exists(ffile):
      gfile.Copy(os.path.join(FLAGS.src_dir, ifile), ffile)
      print('Copied %s' % os.path.join(FLAGS.src_dir, ifile))
    else:
      print('File exists %s' % ffile)

  print('Files copied to destination')


  ## load data
  # load stimulus
  data = h5py.File(os.path.join(dst, 'stimulus.mat'))
  stimulus = np.array(data.get('stimulus')) - 0.5

  # load responses from multiple retina
  datasets_list = os.path.join(dst, 'datasets.txt')
  datasets = open(datasets_list, "r").read()
  training_datasets = [line for line in datasets.splitlines()]

  responses = []
  for idata in training_datasets:
    print(idata)
    data_file = os.path.join(dst, idata)
    data = sio.loadmat(data_file)
    responses += [data]
    print(np.max(data['centers'], 0))

  # generate additional features for responses
  num_cell_types = 2
  dimx = 80
  dimy = 40
  for iresp in responses:
    # remove cells which are outside 80x40 window.
    process_dataset(iresp, dimx, dimy, num_cell_types)

  ## generate graph -
  if FLAGS.is_test == 0:
    is_training = True
  if FLAGS.is_test == 1:
    is_training = True # False

  with tf.Session() as sess:

    ## Make graph
    # embed stimulus.
    time_window = 30
    stimx = stimulus.shape[1]
    stimy = stimulus.shape[2]
    stim_tf = tf.placeholder(tf.float32,
                             shape=[None, stimx,
                             stimy, time_window]) # batch x X x Y x time_window
    batch_norm = FLAGS.batch_norm
    stim_embed = embed_stimulus(FLAGS.stim_layers.split(','),
                                batch_norm, stim_tf, is_training,
                                reuse_variables=False)

    '''
    ttf_tf = tf.Variable(np.ones(time_window).astype(np.float32)/10, name='stim_ttf')
    filt = tf.expand_dims(tf.expand_dims(tf.expand_dims(ttf_tf, 0), 0), 3)
    stim_time_filt = tf.nn.conv2d(stim_tf, filt,
                                    strides=[1, 1, 1, 1], padding='SAME') # batch x X x Y x 1


    ilayer = 0
    stim_time_filt = slim.conv2d(stim_time_filt, 1, [3, 3],
                        stride=1,
                        scope='stim_layer_wt_%d' % ilayer,
                        reuse=False,
                        normalizer_fn=slim.batch_norm,
                        activation_fn=tf.nn.softplus,
                        normalizer_params={'is_training': is_training},
                        padding='SAME')
    '''


    # embed responses.
    num_cell_types = 2
    layers = FLAGS.resp_layers  # format: window x filters x stride .. NOTE: final filters=1, stride =1 throughout
    batch_norm = FLAGS.batch_norm
    time_window = 1
    anchor_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                     layers=layers,
                                                     batch_norm=batch_norm,
                                                     is_training=is_training,
                                                     reuse_variables=False,
                                                     num_cell_types=2,
                                                     dimx=dimx, dimy=dimy)

    neg_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                  layers=layers,
                                                  batch_norm=batch_norm,
                                                  is_training=is_training,
                                                  reuse_variables=True,
                                                  num_cell_types=2,
                                                  dimx=dimx, dimy=dimy)

    d_s_r_pos = tf.reduce_sum((stim_embed - anchor_model.responses_embed)**2, [1, 2, 3]) # batch
    d_pairwise_s_rneg = tf.reduce_sum((tf.expand_dims(stim_embed, 1) -
                                 tf.expand_dims(neg_model.responses_embed, 0))**2, [2, 3, 4]) # batch x batch_neg
    beta = 10
    # if FLAGS.save_suffix == 'lr=0.001':
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(tf.expand_dims(d_s_r_pos / beta, 1) -
                                                    d_pairwise_s_rneg / beta, 1), 0)
    # else :

    # loss = tf.reduce_sum(tf.nn.softplus(1 +  tf.expand_dims(d_s_r_pos, 1) - d_pairwise_s_rneg))
    accuracy_tf =  tf.reduce_mean(tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg))
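    # accuracy_tf is in [-1, 1]: +1 means every mismatched pair is farther apart than its matched pair.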

    lr = 0.001
    train_op = tf.train.AdagradOptimizer(lr).minimize(loss)

    # set up training and testing data
    training_datasets_all = [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]
    testing_datasets = [0, 4, 8, 12]
    print('Testing datasets', testing_datasets)

    n_training_datasets_log = [1, 3, 5, 7, 9, 11, 12]

    if (np.floor(FLAGS.taskid / 4)).astype(np.int) < len(n_training_datasets_log):
      # do randomly sampled training data for 0<= FLAGS.taskid < 28
      prng = RandomState(23)
      n_training_datasets = n_training_datasets_log[(np.floor(FLAGS.taskid / 4)).astype(np.int)]
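      # The loop below advances (and prints) the RNG so that each task id draws a different, reproducible subset.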
      for _ in range(10*FLAGS.taskid):
        print(prng.choice(training_datasets_all,
                          n_training_datasets, replace=False))
      training_datasets = prng.choice(training_datasets_all,
                                      n_training_datasets, replace=False)

      # training_datasets = [i for i in range(7) if i< 7-FLAGS.taskid] #[0, 1, 2, 3, 4, 5]


    else:
      # do 1 training data, chosen in order for FLAGS.taskid >= 28
      datasets_all = np.arange(16)
      training_datasets = [datasets_all[FLAGS.taskid % (4 * len(n_training_datasets_log))]]

    print('Task ID %d' % FLAGS.taskid)
    print('Training datasets', training_datasets)

    # Initialize stuff.
    file_name = ('end_to_end_stim_%s_resp_%s_beta_%d_taskid_%d'
                 '_training_%s_testing_%s_%s' % (FLAGS.stim_layers,
                                                 FLAGS.resp_layers, beta, FLAGS.taskid,
                                                 str(training_datasets)[1: -1],
                                                 str(testing_datasets)[1: -1],
                                                 FLAGS.save_suffix))
    saver_var, start_iter = initialize_model(FLAGS.save_folder, file_name, sess)

    # print model graph
    PrintModelAnalysis(tf.get_default_graph())

    # Add summary ops
    retina_number = tf.placeholder(tf.int16, name='input_retina');

    summary_ops = []
    for iret in np.arange(len(responses)):
      print(iret)
      r1 = tf.summary.scalar('loss_%d' % iret , loss)
      r2 = tf.summary.scalar('accuracy_%d' % iret , accuracy_tf)
      summary_ops += [tf.summary.merge([r1, r2])]

    # tf.summary.scalar('loss', loss)
    # tf.summary.scalar('accuracy', accuracy_tf)
    # summary_op = tf.summary.merge_all()

    # Setup summary writers
    summary_writers = []
    for loc in ['train', 'test']:
      summary_location = os.path.join(FLAGS.save_folder, file_name,
                                      'summary_' + loc )
      summary_writer = tf.summary.FileWriter(summary_location, sess.graph)
      summary_writers += [summary_writer]
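    # summary_writers[0] is the 'train' writer and summary_writers[1] the 'test' writer, used in the loop below.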


    # start training
    batch_neg_train_sz = 100
    batch_train_sz = 100


    def batch(dataset_id):
      batch_train = get_batch(stimulus, responses[dataset_id]['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   anchor_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   neg_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   neg_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],
                   retina_number : dataset_id}

      return feed_dict

    def batch_few_cells(responses):
      batch_train = get_batch(stimulus, responses['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: responses['map_cell_grid'],
                   anchor_model.cell_types_tf: responses['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: responses['map_cell_grid'],
                   neg_model.cell_types_tf: responses['ctype_1hot'],
                   neg_model.mean_fr_tf: responses['mean_firing_rate'],
                   }

      return feed_dict

    if FLAGS.is_test == 1:
      print('Testing')
      save_dict = {}

      from IPython import embed; embed()
      ## Estimate one, fix others
      '''
      grad_resp = tf.gradients(d_s_r_pos, anchor_model.responses_tf)

      t_start = 1000
      t_len = 100
      stim_history = 30
      stim_batch = np.zeros((t_len, stimulus.shape[1],
                         stimulus.shape[2], stim_history))
      for isample, itime in enumerate(np.arange(t_start, t_start + t_len)):
        stim_batch[isample, :, :, :] = np.transpose(stimulus[itime: itime-stim_history:-1, :, :], [1, 2, 0])

      iretina = testing_datasets[0]
      resp_batch = np.expand_dims(np.random.rand(t_len, responses[iretina]['responses'].shape[1]), 2)

      step_sz = 0.01
      eps = 1e-2
      dist_prev = np.inf
      for iiter in range(10000):
        feed_dict = {stim_tf: stim_batch,
                     anchor_model.map_cell_grid_tf: responses[iretina]['map_cell_grid'],
                     anchor_model.cell_types_tf: responses[iretina]['ctype_1hot'],
                     anchor_model.mean_fr_tf: responses[iretina]['mean_firing_rate'],
                     anchor_model.responses_tf: resp_batch}
        dist_np, resp_grad_np = sess.run([d_s_r_pos, grad_resp], feed_dict=feed_dict)
        if np.sum(np.abs(dist_prev - dist_np)) < eps:
          break
        print(np.sum(dist_np), np.sum(np.abs(dist_prev - dist_np)))
        dist_prev = dist_np
        resp_batch = resp_batch - step_sz * resp_grad_np[0]

      resp_batch = resp_batch.squeeze()
      '''


      # from IPython import embed; embed()
      ## compute distances between s-r pairs for small number of cells

      test_retina = []
      for iretina in range(len(testing_datasets)):
        dataset_id = testing_datasets[iretina]

        num_cells_total = responses[dataset_id]['responses'].shape[1]
        dataset_center = responses[dataset_id]['centers'].mean(0)
        dataset_cell_distances = np.sqrt(np.sum((responses[dataset_id]['centers'] -
                                                 dataset_center)**2, 1))
        order_cells = np.argsort(dataset_cell_distances)

        test_sr_few_cells = {}
        for num_cells_prc in [5, 10, 20, 30, 50, 100]:
          num_cells = np.percentile(np.arange(num_cells_total),
                                    num_cells_prc).astype(np.int)

          choose_cells = order_cells[:num_cells]

          responses_few_cells = {'responses': responses[dataset_id]['responses'][:, choose_cells],
                                 'map_cell_grid': responses[dataset_id]['map_cell_grid'][:, :, choose_cells],
                                 'ctype_1hot': responses[dataset_id]['ctype_1hot'][choose_cells, :],
                                 'mean_firing_rate': responses[dataset_id]['mean_firing_rate'][choose_cells]}
          # get a batch
          d_pos_log = np.array([])
          d_neg_log = np.array([])
          for test_iter in range(1000):
            print(iretina, num_cells_prc, test_iter)
            feed_dict = batch_few_cells(responses_few_cells)
            d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
            d_neg = np.diag(d_neg) # np.mean(d_neg, 1) #
            d_pos_log = np.append(d_pos_log, d_pos)
            d_neg_log = np.append(d_neg_log, d_neg)

          precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

          print(np.sum(d_pos_log > d_neg_log))
          print(np.sum(d_pos_log < d_neg_log))
          test_sr = {'precision': precision_log, 'recall': recall_log,
                     'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                     'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log,
                     'num_cells': num_cells}

          test_sr_few_cells.update({'num_cells_prc_%d' % num_cells_prc : test_sr})
        test_retina += [test_sr_few_cells]
      save_dict.update({'few_cell_analysis': test_retina})

      ## compute distances between s-r pairs - pos and neg.

      test_retina = []
      for iretina in range(len(testing_datasets)):
        # stim-resp log
        d_pos_log = np.array([])
        d_neg_log = np.array([])
        for test_iter in range(1000):
          print(test_iter)
          feed_dict = batch(testing_datasets[iretina])
          d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
          d_neg = np.diag(d_neg) # np.mean(d_neg, 1) #
          d_pos_log = np.append(d_pos_log, d_pos)
          d_neg_log = np.append(d_neg_log, d_neg)

        precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

        print(np.sum(d_pos_log > d_neg_log))
        print(np.sum(d_pos_log < d_neg_log))
        test_sr = {'precision': precision_log, 'recall': recall_log,
                     'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                   'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log}

        test_retina += [test_sr]

      save_dict.update({'test_sr': test_retina})


      ## ROC curves of responses from repeats - dataset 1
      repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
      repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'));
      repeats_data['cell_type'] = repeats_data['cell_type'].T

      # process repeats data
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result
      test_reps = analyse_response_repeats(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2015-09-23-7': test_reps})

      ## ROC curves of responses from repeats - dataset 2
      repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
      repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'));
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result
      '''
      test_clustering = analyse_response_repeats_all_trials(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2005_08_03_0': test_clustering})
      '''
      #
      # get model params
      save_dict.update({'model_pars': sess.run(tf.trainable_variables())})


      save_analysis_filename = os.path.join(FLAGS.save_folder, file_name + '_analysis.pkl')
      pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
      print(save_analysis_filename)
      return

    test_iiter = 0
    for iiter in range(start_iter, FLAGS.max_iter): # TODO(bhaishahster) :add FLAGS.max_iter

      # get a new batch
      # stim_tf, anchor_model.responses_tf, neg_model.responses_tf

      # training step
      train_dataset = training_datasets[iiter % len(training_datasets)]
      feed_dict_train = batch(train_dataset)
      _, loss_np_train = sess.run([train_op, loss], feed_dict=feed_dict_train)
      print(train_dataset, loss_np_train)

      # write summary
      if iiter % 10 == 0:
        # write train summary
        test_iiter = test_iiter + 1

        train_dataset = training_datasets[test_iiter % len(training_datasets)]
        feed_dict_train = batch(train_dataset)
        summary_train = sess.run(summary_ops[train_dataset], feed_dict=feed_dict_train)
        summary_writers[0].add_summary(summary_train, iiter)

        # write test summary
        test_dataset = testing_datasets[test_iiter % len(testing_datasets)]
        feed_dict_test = batch(test_dataset)
        l_test, summary_test = sess.run([loss, summary_ops[test_dataset]], feed_dict=feed_dict_test)
        summary_writers[1].add_summary(summary_test, iiter)
        print('Test retina: %d, loss: %.3f' % (test_dataset, l_test))

      # save model
      if iiter % 10 == 0:
        save_model(saver_var, FLAGS.save_folder, file_name, sess, iiter)
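
For reference, the logsumexp-based metric loss and the sign-based accuracy defined earlier in this example can be written directly in NumPy/SciPy. The batch sizes below are made up and the helper name metric_loss_and_accuracy is hypothetical.

import numpy as np
from scipy.special import logsumexp

def metric_loss_and_accuracy(d_pos, d_neg, beta=10.0):
  # d_pos: (batch,) squared distances between matched stimulus-response pairs.
  # d_neg: (batch, batch_neg) distances from each stimulus to negative responses.
  loss = np.sum(beta * logsumexp(d_pos[:, None] / beta - d_neg / beta, axis=1))
  accuracy = np.mean(np.sign(-d_pos[:, None] + d_neg))
  return loss, accuracy

rng = np.random.RandomState(0)
d_pos = rng.rand(8)              # matched pairs should be close
d_neg = 1.0 + rng.rand(8, 16)    # mismatched pairs should be farther
print(metric_loss_and_accuracy(d_pos, d_neg))
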
Example #15
0
def main(argv):
    print('\nCode started')

    global cells_choose
    global chosen_mask

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    global chunk_order
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge' or FLAGS.model_id == 'poisson_relu':
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)

    print(FLAGS.model_id)

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')
    else:
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Learn population model!
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))

            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
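            # lam has shape (batch, n_cells): each cell's rate is a sum of exponentiated subunit drives.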
            #loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'poisson_relu':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(
                tf.exp(tf.nn.relu(tf.matmul(stim, w))), a
            ) + b
            # loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            loss_inter = (tf.reduce_sum(f) / 120. -
                          tf.reduce_sum(resp * tf.log(f))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'logistic':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'hinge':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.relu(1 - 2 *
                                                  (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        # summaries
        l_summary = tf.scalar_summary('loss', loss)
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)

        # Merge all the summaries and write them out to the save location.
        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        print('\nStarting new code')

        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)
            load_prev = True
        except:
            print('No previous checkpoint found')

        if load_prev:
            print('\nPrevious results loaded')
        else:
            print('\nVariables initialized')
        # Do the fitting

        icnt = 0
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        #logfile.close()

        for istep in np.arange(start_iter, 400000):

            print(istep)
            # get training data
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }
            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)
                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
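
The Poisson objective used by these population models (rate term minus spike-weighted log-rate, normalized by data length) is compact enough to spell out in NumPy. The constant 120 from the code above is exposed here as a bin_rate argument (an assumption about its meaning); the sizes and the helper name poisson_loss are made up for illustration.

import numpy as np

def poisson_loss(lam, resp, data_len, bin_rate=120.):
  # lam, resp: (T, n_cells) predicted firing rates and observed spike counts.
  # Poisson negative log-likelihood up to constants, matching loss_inter above.
  return (np.sum(lam) / bin_rate - np.sum(resp * np.log(lam))) / data_len

rng = np.random.RandomState(0)
lam = 1e-6 + rng.rand(100, 4)          # predicted rates, kept strictly positive
resp = rng.poisson(0.1, (100, 4))      # observed spike counts
print(poisson_loss(lam, resp, data_len=100.))
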
Example #16
0
def main(unused_argv=()):

  np.random.seed(23)
  tf.set_random_seed(1234)
  random.seed(50)

  # Load stimulus-response data.
  # Collect population response across retinas in the list 'responses'.
  # Stimulus for each retina is indicated by 'stim_id',
  # which is found in 'stimuli' dictionary.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  for icnt, idataset in enumerate(datasets):

    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key,
                                           boundary=FLAGS.valid_cells_boundary)
      stimulus, resp, dimx, dimy, _ = op

      stimuli.update({key: stimulus})
      responses += resp

  taskid = FLAGS.taskid
  dat = responses[taskid]
  stimulus = stimuli[dat['stimulus_key']]

  # parameters
  window = 5

  # Compute the time course and non-linearity as two parameters that might be explored in the embedded space.
  n_cells = dat['responses'].shape[1]
  T = np.minimum(stimulus.shape[0], dat['responses'].shape[0])

  stim_short = stimulus[:T, :, :]
  resp_short = dat['responses'][:T, :].astype(np.float32)

  save_dict = {}

  # Find time course, non-linearity and RF parameters

  ########################################################################
  # Separation between cell types
  ########################################################################
  save_dict.update({'cell_type': dat['cell_type']})
  save_dict.update({'dist_nn_cell_type': dat['dist_nn_cell_type']})

  ########################################################################
  # Find mean firing rate
  ########################################################################
  mean_fr = dat['responses'].mean(0)
  mean_fr_1 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==1])
  mean_fr_2 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==2])

  mean_fr_dict = {'mean_fr': mean_fr,
                  'mean_fr_1': mean_fr_1, 'mean_fr_2': mean_fr_2}
  save_dict.update({'mean_fr_dict': mean_fr_dict})

  ########################################################################
  # compute STAs
  ########################################################################
  stas = np.zeros((n_cells, 80, 40, 30))
  for icell in range(n_cells):
    print(icell)
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]
    stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]],
                           [stim_short.shape[0], -1])
    for idelay in range(30):
      stas[icell, windx[0]: windx[1], windy[0]: windy[1], idelay] = np.reshape(
          resp_short[idelay:, icell].dot(stim_cell[:T-idelay, :]),
          [windx[1] - windx[0], windy[1] - windy[0]]) / np.sum(resp_short[idelay:, icell])

  stas_dict = {'stas': stas}
  # save_dict.update({'stas_dict': stas_dict})


  ########################################################################
  # Find time courses for each cell
  ########################################################################
  ttf_log = []
  for icell in range(n_cells):
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]
    ll = stas[icell, windx[0]: windx[1], windy[0]: windy[1], :]

    ll_2d = np.reshape(ll, [-1, ll.shape[-1]])

    u, s, v = np.linalg.svd(ll_2d)

    ttf_log += [v[0, :]]

  ttf_log = np.array(ttf_log)
  signs = [np.sign(ttf_log[icell, np.argmax(np.abs(ttf_log[icell, :]))]) for icell in range(ttf_log.shape[0])]
  ttf_corrected = np.expand_dims(np.array(signs), 1) * ttf_log
  ttf_corrected[np.squeeze(dat['cell_type'])==1, :] = ttf_corrected[np.squeeze(dat['cell_type'])==1, :] * -1

  ttf_mean_1 = ttf_corrected[np.squeeze(dat['cell_type'])==1, :].mean(0)
  ttf_mean_2 = ttf_corrected[np.squeeze(dat['cell_type'])==2, :].mean(0)

  ttf_params_1 = get_times(ttf_mean_1)
  ttf_params_2 = get_times(ttf_mean_2)

  ttf_dict = {'ttf_log': ttf_log,
              'ttf_mean_1': ttf_mean_1, 'ttf_mean_2': ttf_mean_2,
              'ttf_params_1': ttf_params_1, 'ttf_params_2': ttf_params_2}

  save_dict.update({'ttf_dict': ttf_dict})
  '''
  plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==1, :].T, 'r', alpha=0.3)
  plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==2, :].T, 'k', alpha=0.3)

  plt.plot(ttf_mean_1, 'r--')
  plt.plot(ttf_mean_2, 'k--')
  '''

  ########################################################################
  ## Find non-linearity
  ########################################################################
  f_nl = lambda x, p0, p1, p2, p3: p0 + p1*x + p2* np.power(x, 2) + p3* np.power(x, 3)
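  # Each cell's non-linearity is summarized below by a cubic fit to (generator signal, mean response) pairs collected in sliding percentile bins.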

  nl_params_log = []
  stim_resp_log = []
  for icell in range(n_cells):
    print(icell)
    center = dat['centers'][icell, :].astype(np.int)
    windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)]
    windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)]

    stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1])
    sta_cell = np.reshape(stas[icell, windx[0]: windx[1], windy[0]: windy[1], :], [-1, stas.shape[-1]])

    stim_filter = np.zeros(stim_short.shape[0])
    for idelay in range(30):
      stim_filter[idelay: ] += stim_cell[:T-idelay, :].dot(sta_cell[:, idelay])

    # Normalize stim_filter
    stim_filter -= np.mean(stim_filter)
    stim_filter /= np.sqrt(np.var(stim_filter))

    resp_cell = resp_short[:, icell]

    stim_nl = []
    resp_nl = []
    for ipercentile in range(3, 97, 1):
      lb = np.percentile(stim_filter, ipercentile-3)
      ub = np.percentile(stim_filter, ipercentile+3)
      tms = np.logical_and(stim_filter >= lb, stim_filter < ub)
      stim_nl += [np.mean(stim_filter[tms])]
      resp_nl += [np.mean(resp_cell[tms])]

    stim_nl = np.array(stim_nl)
    resp_nl = np.array(resp_nl)

    popt, pcov = scipy.optimize.curve_fit(f_nl, stim_nl, resp_nl, p0=[1, 0, 0, 0])
    nl_params_log += [popt]
    stim_resp_log += [[stim_nl, resp_nl]]

  nl_params_log = np.array(nl_params_log)

  np_params_mean_1 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==1, :], 0)
  np_params_mean_2 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==2, :], 0)

  nl_params_dict = {'nl_params_log': nl_params_log,
                    'np_params_mean_1': np_params_mean_1,
                    'np_params_mean_2': np_params_mean_2,
                    'stim_resp_log': stim_resp_log}

  save_dict.update({'nl_params_dict': nl_params_dict})

  '''
  # Visualize Non-linearities
  for icell in range(n_cells):

    stim_in = np.arange(-3, 3, 0.1)
    fr = f_nl(stim_in, *nl_params_log[icell, :])
    if np.squeeze(dat['cell_type'])[icell] == 1:
      c = 'r'
    else:
      c = 'k'
    plt.plot(stim_in, fr, c, alpha=0.2)

  fr = f_nl(stim_in, *np_params_mean_1)
  plt.plot(stim_in, fr, 'r--')

  fr = f_nl(stim_in, *np_params_mean_2)
  plt.plot(stim_in, fr, 'k--')
  '''

  pickle.dump(save_dict, gfile.Open(os.path.join(FLAGS.save_folder , dat['piece']), 'w'))
  pickle.dump(stas_dict, gfile.Open(os.path.join(FLAGS.save_folder , 'stas' + dat['piece']), 'w'))
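
Ignoring the spatial windowing, the per-cell STA and the SVD-based time course computed above reduce to a short NumPy routine. This is a sketch with made-up stimulus size and spike counts; the helper name sta_and_time_course is hypothetical.

import numpy as np

def sta_and_time_course(stim, resp, n_delays=30):
  # stim: (T, X, Y) stimulus movie, resp: (T,) spike counts of one cell.
  T = stim.shape[0]
  stim_flat = stim.reshape(T, -1)                  # (T, X*Y)
  sta = np.zeros((stim_flat.shape[1], n_delays))
  for d in range(n_delays):
    sta[:, d] = resp[d:].dot(stim_flat[:T - d, :]) / np.sum(resp[d:])
  # Rank-1 factorization: the leading right singular vector is the time course.
  _, _, v = np.linalg.svd(sta, full_matrices=False)
  return sta.reshape(stim.shape[1], stim.shape[2], n_delays), v[0, :]

rng = np.random.RandomState(0)
stim = rng.randn(2000, 8, 4)
resp = rng.poisson(0.2, 2000).astype(np.float64)
sta, ttf = sta_and_time_course(stim, resp)
print(sta.shape, ttf.shape)                        # (8, 4, 30) (30,)
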
def main(unused_argv=()):

    #np.random.seed(23)
    #tf.set_random_seed(1234)
    #random.seed(50)

    # 1. Load stimulus-response data.
    # Collect population response across retinas in the list 'responses'.
    # Stimulus for each retina is indicated by 'stim_id',
    # which is found in 'stimuli' dictionary.
    datasets = gfile.ListDirectory(FLAGS.src_dir)
    stimuli = {}
    responses = []
    for icnt, idataset in enumerate(datasets):

        fullpath = os.path.join(FLAGS.src_dir, idataset)
        if gfile.IsDirectory(fullpath):
            key = 'stim_%d' % icnt
            op = data_util.get_stimulus_response(
                FLAGS.src_dir,
                idataset,
                key,
                boundary=FLAGS.valid_cells_boundary)
            stimulus, resp, dimx, dimy, _ = op

            stimuli.update({key: stimulus})
            responses += resp

    # 2. Do response prediction for a retina
    iretina = FLAGS.taskid
    subunit_fit_loc = FLAGS.save_folder
    subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

    piece = responses[iretina]['piece']
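    # Match subunit fits to this retina by the first 12 characters of the piece name (presumably a date-based prefix).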
    matched_dataset = [
        ifit for ifit in subunits_datasets if piece[:12] == ifit[:12]
    ]
    if not matched_dataset:
        raise ValueError('Could not find subunit fit')

    subunit_fit_path = os.path.join(subunit_fit_loc, matched_dataset[0])
    stimulus = stimuli[responses[iretina]['stimulus_key']]

    # sample test data
    stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]

    # Optionally, create a null stimulus for all the cells.

    resp_ret = responses[iretina]

    if FLAGS.is_null:
        # Make null stimulus
        stimulus_test = get_null_stimulus(resp_ret, subunit_fit_path,
                                          stimulus_test)

    resp_su = get_su_spks(subunit_fit_path, stimulus_test, responses[iretina])

    save_dict = {
        'resp_su': resp_su.astype(np.int8),
        'cell_ids': responses[iretina]['cellID_list'].squeeze()
    }

    if FLAGS.is_null:
        save_dict.update({'stimulus_null': stimulus_test})
        save_suff = '_null'
    else:
        save_suff = ''

    pickle.dump(
        save_dict,
        gfile.Open(
            os.path.join(subunit_fit_path,
                         'response_prediction%s.pkl' % save_suff), 'w'))