  def initialize_folder(self, save_location, folder_name):
    """Intialize saving location of the model."""

    parent_folder = os.path.join(save_location, folder_name)
    # Make folder if it does not exist.
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    self.parent_folder = parent_folder

    save_location = os.path.join(parent_folder, self.short_filename)
    if not gfile.IsDirectory(save_location):
      gfile.MkDir(save_location)
    self.save_location = save_location

    self.save_filename = os.path.join(self.save_location, self.short_filename)
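
# Hypothetical usage sketch (comments only); the owning class and its
# `short_filename` attribute are assumptions, and only the call pattern of
# initialize_folder is illustrated:
#
#   model = Model(...)                       # sets self.short_filename
#   model.initialize_folder('/tmp/models', 'experiment1')
#   # model.save_filename is now
#   # '/tmp/models/experiment1/<short_filename>/<short_filename>'
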
    def convert_to_TFRecords_chunks(self,
                                    prefix,
                                    save_location='~/tmp',
                                    examples_per_file=1000):
        # Converts the stimulus-response data into chunked TFRecords files.

        tf.logging.info('making TFRecords chunks')
        stimulus = self.stimulus.astype(np.float32)
        response = self.response.astype(np.float32)
        num_examples = stimulus.shape[0]
        num_files = int(np.ceil(num_examples / float(examples_per_file)))
        tf.logging.info('Number of files: %d, examples: %d' %
                        (num_files, num_examples))

        def _bytes_feature(value):
            # value: bytes to convert to tf.train.Feature
            return tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[value]))

        # make folder for storing .tfrecords files
        folder_tfrec = os.path.join(save_location, prefix)
        # Make folder if it does not exist.
        if not gfile.IsDirectory(folder_tfrec):
            tf.logging.info('making folder to store tfrecords')
            gfile.MkDir(folder_tfrec)
        else:
            tf.logging.info('folder exists, will overwrite results')

        index = -1
        for ifile in range(num_files):
            filename = os.path.join(folder_tfrec,
                                    'chunk_' + str(ifile) + '.tfrecords')
            tf.logging.info('Writing %s , starting index %d' %
                            (filename, index + 1))
            writer = tf.python_io.TFRecordWriter(filename)
            for iexample in range(examples_per_file):
                index += 1
                if index >= num_examples:
                    # The last chunk may hold fewer than examples_per_file examples.
                    break
                stimulus_raw = stimulus[index, :].tostring()
                response_raw = response[index, :].tostring()
                #print(index, stimulus[index,:].shape, response[index,:].shape)
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'stimulus': _bytes_feature(stimulus_raw),
                        'response': _bytes_feature(response_raw)
                    }))
                writer.write(example.SerializeToString())
            writer.close()
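
# A minimal sketch (not code from this repository) of how the chunks written
# above could be read back with the TensorFlow 1.x queue-based input pipeline.
# The helper name read_tfrecords_chunk and the stim_dim / resp_dim arguments
# (the flattened example sizes) are assumptions made for illustration.
import tensorflow as tf


def read_tfrecords_chunk(filename_queue, stim_dim, resp_dim):
    """Parse one (stimulus, response) example from a chunk_*.tfrecords file."""
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'stimulus': tf.FixedLenFeature([], tf.string),
            'response': tf.FixedLenFeature([], tf.string)
        })
    # The writer above stored raw float32 bytes, so decode and restore shapes.
    stimulus = tf.reshape(
        tf.decode_raw(features['stimulus'], tf.float32), [stim_dim])
    response = tf.reshape(
        tf.decode_raw(features['response'], tf.float32), [resp_dim])
    return stimulus, response
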
def initialize_model(save_folder, file_name, sess):
  """Setup model variables and saving information.

  Args:
    save_folder (string) : Folder to store model.
                           Makes one if it does not exist.
    file_name (string) : Prefix of model/checkpoint files.
    sess : Tensorflow session.

  Returns:
    saver_var, start_iter : Saver object and starting iteration returned by
      initialize_variables.
  """

  # Make folder.
  if not gfile.IsDirectory(save_folder):
    gfile.MkDir(save_folder)

  # Initialize variables.
  saver_var, start_iter = initialize_variables(sess, save_folder, file_name)
  return saver_var, start_iter
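
# Hypothetical call sketch (the initialize_variables helper used above belongs
# to this codebase and is not shown here):
#
#   with tf.Session() as sess:
#     saver_var, start_iter = initialize_model('/tmp/model_dir', 'model_a', sess)
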
def RunComputation():
    """Main function which runs all TensorFlow computations."""

    # filename for saving file
    if FLAGS.architecture == '2 layer_stimulus':
        architecture_string = ('_architecture=' + str(FLAGS.architecture) +
                               '_stim_downsample_window=' +
                               str(FLAGS.stim_downsample_window) +
                               '_stim_downsample_stride=' +
                               str(FLAGS.stim_downsample_stride))
    else:
        architecture_string = ('_architecture=' + str(FLAGS.architecture))

    short_filename = ('model=' + str(FLAGS.model_id) + '_loss=' +
                      str(FLAGS.loss) + '_batch_sz=' + str(FLAGS.batchsz) +
                      '_lam_w=' + str(FLAGS.lam_w) + '_step_sz' +
                      str(FLAGS.step_sz) + '_tlen=' + str(FLAGS.train_len) +
                      '_window=' + str(FLAGS.window) + '_stride=' +
                      str(FLAGS.stride) + str(architecture_string) + '_jitter')

    # Make a folder with name derived from the parameters of the algorithm.
    # It stores checkpoint files and summaries used in TensorBoard.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # make folder if it does not exist
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print('Does the folder exist?', gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)

    save_filename = FLAGS.save_location + short_filename
    """Main function which runs all TensorFlow computations."""
    with tf.Graph().as_default() as gra:
        with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
            print(FLAGS.config_params)
            tf.logging.info(FLAGS.config_params)
            # set up training dataset
            tc_mean = get_data_mat.init_chunks(FLAGS.n_chunks)
            '''
      # plot histogram of a training dataset
      stim_train, resp_train, train_len = get_data_mat.get_stim_resp('train',
                                                                     num_chunks=FLAGS.num_chunks_to_load)
      plt.hist(np.ndarray.flatten(stim_train[:,:,0:]))
      plt.show()
      plt.draw()
      '''
            # Create computation graph.
            #
            # Graph should be fully constructed before you create supervisor.
            # Attempting to modify the graph after the supervisor is created will cause an error.

            with tf.name_scope('model'):
                if FLAGS.architecture == '1 layer':
                    # single GPU model
                    if False:
                        global_step = tf.contrib.framework.create_global_step()
                        model, stim, resp = jitter_model.approximate_conv_jitter(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels)

                    # multiGPU model
                    if True:
                        model, stim, resp, global_step = jitter_model.approximate_conv_jitter_multigpu(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params)

                if FLAGS.architecture == '2 layer_stimulus':
                    # The stimulus is first smoothed to a lower dimension, then the same model is applied.
                    print('Projecting stimulus to lower dimensions!')
                    model, stim, resp, global_step, stim_tuple = jitter_model.approximate_conv_jitter_multigpu_stim_lr(
                        FLAGS.n_cells, FLAGS.lam_w, FLAGS.window, FLAGS.stride,
                        FLAGS.step_sz, tc_mean, FLAGS.su_channels,
                        FLAGS.config_params, FLAGS.stim_downsample_window,
                        FLAGS.stim_downsample_stride)

            # Print the number of variables in graph
            print('Calculating model size')  # Hope we do not exceed memory
            PrintModelAnalysis(gra, max_depth=10)
            #import pdb; pdb.set_trace()

            # Builds our summary op.
            summary_op = model.merged_summary

            # Create a Supervisor.  It will take care of initialization, summaries,
            # checkpoints, and recovery.
            #
            # When multiple replicas of this program are running, the first one,
            # identified by --task=0, is the 'chief' supervisor.  It is the only one
            # that takes care of initialization, etc.
            is_chief = (FLAGS.task == 0)  # & (FLAGS.learn==1)
            print(save_filename)
            if FLAGS.learn == 1:
                # Use the supervisor only for learning; otherwise it interferes
                # with the analysis by trying to save variables while you are analysing.

                sv = tf.train.Supervisor(logdir=save_filename,
                                         is_chief=is_chief,
                                         saver=tf.train.Saver(),
                                         summary_op=None,
                                         save_model_secs=100,
                                         global_step=global_step,
                                         recovery_wait_secs=5)

                if (is_chief and FLAGS.learn == 1):
                    tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                                         save_filename, 'graph.pbtxt')

                # Get an initialized, and possibly recovered session.  Launch the
                # services: Checkpointing, Summaries, step counting.
                #
                # When multiple replicas of this program are running the services are
                # only launched by the 'chief' replica.
                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                #import pdb; pdb.set_trace()
                sess = sv.PrepareSession(FLAGS.master, config=session_config)

                FitComputation(sv, sess, model, stim, resp, global_step,
                               summary_op, stim_tuple)
                sv.Stop()

            else:
                # if not learn, then analyse

                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                with tf.Session(config=session_config) as sess:
                    saver_var = tf.train.Saver(
                        tf.all_variables(),
                        keep_checkpoint_every_n_hours=float('inf'))
                    restore_file = tf.train.latest_checkpoint(save_filename)
                    print(restore_file)
                    start_iter = int(
                        restore_file.split('/')[-1].split('-')[-1])
                    saver_var.restore(sess, restore_file)

                    if FLAGS.architecture == '2 layer_stimulus':
                        AnalyseModel_lr(sess, model)
                    else:
                        AnalyseModel(sv, sess, model)
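
# A minimal sketch (an assumption, not code from this repository) of the usual
# TensorFlow 1.x entry point that would invoke RunComputation() after flag
# parsing:
#
#   def main(argv):
#       RunComputation()
#
#   if __name__ == '__main__':
#       tf.app.run()
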
def main(argv):

    # global variables will be used for getting training data
    global cells_choose
    global chosen_mask
    global chunk_order

    # set random seeds: when same algorithm run with different FLAGS,
    # the sequence of random data is same.
    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    # initial chunk order (will be re-shuffled everytime we go over a chunk)
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    # Load data summary
    data_filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(data_filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])

    # which cells to train subunits for
    if FLAGS.all_cells == 'True':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    else:
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    n_cells = np.sum(cells_choose)  # number of cells

    # load spikes and relevant stimulus pixels for chosen cells
    tot_spks = np.squeeze(data_summary['tot_spks'])
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    # chosen_mask = which pixels to learn subunits over
    if FLAGS.masked_stimulus == 'True':
        chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                               dtype='bool')
    else:
        chosen_mask = np.array(np.ones(3200).astype('bool'))
    stim_dim = np.sum(chosen_mask)  # stimulus dimensions
    print('\ndataset summary loaded')

    # print parameters
    print('Save folder name: ' + str(FLAGS.folder_name) + '\nmodel:' +
          str(FLAGS.model_id) + '\nLoss:' + str(FLAGS.loss) +
          '\nmasked stimulus:' + str(FLAGS.masked_stimulus) + '\nall_cells?' +
          str(FLAGS.all_cells) + '\nbatch size' + str(FLAGS.batchsz) +
          '\nstep size' + str(FLAGS.step_sz) + '\ntraining length: ' +
          str(FLAGS.train_len) + '\nn_cells: ' + str(n_cells))

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # filename for saving file
    short_filename = ('_masked_stim=' + str(FLAGS.masked_stimulus) +
                      '_all_cells=' + str(FLAGS.all_cells) + '_loss=' +
                      str(FLAGS.loss) + '_batch_sz=' + str(FLAGS.batchsz) +
                      '_step_sz' + str(FLAGS.step_sz) + '_tlen=' +
                      str(FLAGS.train_len) + '_bg')

    with tf.Session() as sess:
        # set up stimulus and response placeholders
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.loss == 'poisson':
            b_init = np.array(
                0.000001 * np.ones(n_cells)
            )  # a very small positive bias needed to avoid log(0) in poisson loss
        else:
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells)
            )  # log-odds, a good initialization for some losses (like logistic)
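            # 216000 is presumably the total number of time bins in the recording
            # (an assumption), so b_init is the per-bin log-odds of a spike.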

        # different firing rate models
        if FLAGS.model_id == 'exp_additive':
            # This model was implemented for earlier work.
            # firing rate for cell c: lam_c = sum_s exp(w_s.x + a_sc)

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                                     dtype='float32'),
                            name='a')
            # firing rate model
            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
            regularization = 0
            vars_fit = [w, a]

            def proj():
                # Called after every training step to project onto parameter constraints.
                pass

        if FLAGS.model_id == 'relu':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b
            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_lam_w=' +
                              str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) +
                              '_nsu=' + str(n_su) + short_filename)
            # variables
            w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(0.01 * np.random.rand(n_su, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            # firing rate model
            lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # regularization of parameters
            regularization = (FLAGS.lam_w * tf.reduce_sum(tf.abs(w)) +
                              FLAGS.lam_a * tf.reduce_sum(tf.abs(a)))
            # projection to satisfy constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)
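            # Note: (x + tf.abs(x)) / 2 equals max(x, 0), so running a_pos / b_pos
            # projects the variables back onto the nonnegative orthant.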

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i act over small windows that are convolutionally related to each other.
            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows(
            )  # get convolutional windows

            # variables
            w = tf.Variable(np.array(0.1 +
                                     0.05 * np.random.rand(dimx, dimy, n_pix),
                                     dtype='float32'),
                            name='w')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w, a]  # which variables are learnt
            if not FLAGS.loss == 'poisson':  # don't learn b for poisson loss
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_wts = tf.nn.relu(tf.reduce_sum(tf.mul(stim_masked, w), 3))
            # get firing rate
            lam = tf.matmul(tf.reshape(stim_wts, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother':
            # firing rate for cell c: lam_c = a_c'.relu(w.x) + b,
            # where the w_i act over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to learn
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            #  stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                sess.run(a_pos)
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm':
            # firing rate for cell c: lam_c = a_sfm_c'.relu(w.x) + b,
            # a_sfm_c = softmax(a) : so a cell cannot be connected to all subunits equally well.

            # where the w_i act over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            mask_tf, dimx, dimy, n_pix = get_windows()

            # variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
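            # The transposes apply the softmax along the subunit dimension, so each
            # cell's weights over subunits are nonnegative and sum to 1.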
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]  # which variables to fit
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # stimulus filtered with convolutional windows
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # activation of different subunits
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get firing rate
            lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        if FLAGS.model_id == 'relu_window_mother_sfm_exp':
            # firing rate for cell c: lam_c = exp(a_sfm_c'.relu(w.x)) + b,
            # a_sfm_c = softmax(a) : so a cell cannot be connected to all subunits equally well.
            # The exponential output nonlinearity cancels the log() in the Poisson loss and might give better estimation properties.

            # where the w_i act over small windows that are convolutionally related to each other.
            # w_i = w_mother + w_del_i,
            # where w_mother is common across all 'windows' and w_del is different for different windows.

            # we know a>0 and for poisson loss, b>0
            # for poisson loss: small b added to prevent lam_c going to 0

            # filename
            short_filename = ('model=' + str(FLAGS.model_id) + '_window=' +
                              str(FLAGS.window) + '_stride=' +
                              str(FLAGS.stride) + '_lam_w=' +
                              str(FLAGS.lam_w) + short_filename)
            # get windows
            mask_tf, dimx, dimy, n_pix = get_windows()

            # declare variables
            w_del = tf.Variable(np.array(0.05 *
                                         np.random.randn(dimx, dimy, n_pix),
                                         dtype='float32'),
                                name='w_del')
            w_mother = tf.Variable(np.array(np.ones(
                (2 * FLAGS.window + 1, 2 * FLAGS.window + 1, 1, 1)),
                                            dtype='float32'),
                                   name='w_mother')
            a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                                     dtype='float32'),
                            name='a')
            a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
            b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
            vars_fit = [w_mother, w_del, a]
            if not FLAGS.loss == 'poisson':
                vars_fit = vars_fit + [b]

            # filter stimulus
            stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
            stim_convolved = tf.reduce_sum(
                tf.nn.conv2d(stim4D,
                             w_mother,
                             strides=[1, FLAGS.stride, FLAGS.stride, 1],
                             padding="VALID"), 3)
            stim_masked = tf.nn.conv2d(
                stim4D,
                mask_tf,
                strides=[1, FLAGS.stride, FLAGS.stride, 1],
                padding="VALID")
            stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

            # get subunit activation
            su_act = tf.nn.relu(stim_del + stim_convolved)

            # get cell firing rates
            lam = tf.exp(
                tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm)) + b

            # regularization
            regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

            # projection to satisfy hard variable constraints
            b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

            def proj():
                if FLAGS.loss == 'poisson':
                    sess.run(b_pos)

        # different loss functions
        if FLAGS.loss == 'poisson':
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len
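            # Up to constants independent of the parameters, this is the negative
            # Poisson log-likelihood with an (assumed) bin width of 1/120 s,
            # averaged over data_len samples.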

        if FLAGS.loss == 'logistic':
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * lam)) / data_len

        if FLAGS.loss == 'hinge':
            loss_inter = tf.reduce_sum(
                tf.nn.relu(1 - 2 * (resp - 0.5) * lam)) / data_len
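        # In the logistic and hinge losses, 2 * (resp - 0.5) maps the {0, 1}
        # responses to {-1, +1} labels y, giving the usual softplus(-y * lam)
        # and max(0, 1 - y * lam) forms.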

        loss = loss_inter + regularization  # add regularization to get final loss function

        # training consists of calling training()
        # which performs a train step and
        # project parameters to model specific constraints using proj()
        train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
            loss, var_list=vars_fit)

        def training(inp_dict):
            sess.run(train_step,
                     feed_dict=inp_dict)  # one step of gradient descent
            proj()  # model specific projection operations

        # evaluate loss on given data.
        def get_loss(inp_dict):
            ls = sess.run(loss, feed_dict=inp_dict)
            return ls

        # saving details
        # make a folder with name derived from parameters of the algorithm
        # - it saves checkpoint files and summaries used in tensorboard
        parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
        # make folder if it does not exist
        if not gfile.IsDirectory(parent_folder):
            gfile.MkDir(parent_folder)
        FLAGS.save_location = parent_folder + short_filename + '/'
        if not gfile.IsDirectory(FLAGS.save_location):
            gfile.MkDir(FLAGS.save_location)
        save_filename = FLAGS.save_location + short_filename

        # create summary writers
        # create histogram summary for all parameters which are learnt
        for ivar in vars_fit:
            tf.histogram_summary(ivar.name, ivar)
        # loss summary
        l_summary = tf.scalar_summary('loss', loss)
        # loss without regularization summary
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)
        # Merge all the summary writer ops into one op (this way,
        # calling one op stores all summaries)
        merged = tf.merge_all_summaries()
        # training and testing has separate summary writers
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        ## Fitting procedure
        print('Start fitting')
        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            # restore previous fits if they are available
            # - useful when jobs are frequently preempted.
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            # restore previous iteration count and start from there.
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)  # restore variables
            load_prev = True
        except:
            print('No previous dataset')

        if load_prev:
            print('Previous results loaded')
        else:
            print('Variables initialized')

        # Finally, do fitting
        icnt = 0
        # get test data and make test dictionary
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        for istep in np.arange(start_iter, 400000):
            print(istep)
            # get training data and make the training dictionary
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }

            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)

                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
def main(argv):
    print('\nCode started')

    global cells_choose
    global chosen_mask

    np.random.seed(FLAGS.np_randseed)
    random.seed(FLAGS.randseed)
    global chunk_order
    chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

    ## Load data summary

    filename = FLAGS.data_location + 'data_details.mat'
    summary_file = gfile.Open(filename, 'r')
    data_summary = sio.loadmat(summary_file)
    cells = np.squeeze(data_summary['cells'])
    if FLAGS.model_id in ('poisson', 'logistic', 'hinge', 'poisson_relu'):
        cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
            cells == 3066)
    if FLAGS.model_id == 'poisson_full':
        cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
    n_cells = np.sum(cells_choose)

    tot_spks = np.squeeze(data_summary['tot_spks'])
    total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
    tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
    print(np.shape(chosen_mask))
    print(np.sum(chosen_mask))

    stim_dim = np.sum(chosen_mask)

    print(FLAGS.model_id)

    print('\ndataset summary loaded')
    # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells

    # decide the number of subunits to fit
    n_su = FLAGS.ratio_SU * n_cells

    # saving details
    if FLAGS.model_id == 'poisson':
        short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                          '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')
    else:
        short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' +
                          str(FLAGS.batchsz) + '_n_b_in_c' +
                          str(FLAGS.n_b_in_c) + '_step_sz' +
                          str(FLAGS.step_sz) + '_tlen=' +
                          str(FLAGS.train_len) + '_bg')

    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print(gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)
    print(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    with tf.Session() as sess:
        # Learn population model!
        stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
        resp = tf.placeholder(tf.float32, name='resp')
        data_len = tf.placeholder(tf.float32, name='data_len')

        if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                         dtype='float32'))

            lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a),
                                             2))
            #loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            loss_inter = (tf.reduce_sum(lam) / 120. -
                          tf.reduce_sum(resp * tf.log(lam))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'poisson_relu':
            # variables
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            # loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len
            f = tf.matmul(tf.exp(tf.nn.relu(tf.matmul(stim, w))), a) + b
            loss_inter = (tf.reduce_sum(f) / 120. -
                          tf.reduce_sum(resp * tf.log(f))) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a])

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'logistic':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.softplus(
                -2 * (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        if FLAGS.model_id == 'hinge':
            w = tf.Variable(
                np.array(0.01 * np.random.randn(stim_dim, n_su),
                         dtype='float32'))
            a = tf.Variable(
                np.array(0.01 * np.random.rand(n_su, n_cells),
                         dtype='float32'))
            b_init = np.log(
                (tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))
            b = tf.Variable(b_init, dtype='float32')
            f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
            loss_inter = tf.reduce_sum(tf.nn.relu(1 - 2 *
                                                  (resp - 0.5) * f)) / data_len
            loss = loss_inter
            train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
                loss, var_list=[w, a, b])
            a_pos = tf.assign(a, (a + tf.abs(a)) / 2)

            def training(inp_dict):
                sess.run(train_step, feed_dict=inp_dict)
                sess.run(a_pos)

            def get_loss(inp_dict):
                ls = sess.run(loss, feed_dict=inp_dict)
                return ls

        # summaries
        l_summary = tf.scalar_summary('loss', loss)
        l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)

        # Merge all the summaries into a single op.
        merged = tf.merge_all_summaries()
        train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                              sess.graph)
        test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

        print('\nStarting new code')

        sess.run(tf.initialize_all_variables())
        saver_var = tf.train.Saver(tf.all_variables(),
                                   keep_checkpoint_every_n_hours=0.05)
        load_prev = False
        start_iter = 0
        try:
            latest_filename = short_filename + '_latest_fn'
            restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                      latest_filename)
            start_iter = int(restore_file.split('/')[-1].split('-')[-1])
            saver_var.restore(sess, restore_file)
            load_prev = True
        except:
            print('No previous dataset')

        if load_prev:
            print('\nPrevious results loaded')
        else:
            print('\nVariables initialized')
        # Do the fitting

        icnt = 0
        stim_test, resp_test, test_length = get_test_data()
        fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

        #logfile.close()

        for istep in np.arange(start_iter, 400000):

            print(istep)
            # get training data
            stim_train, resp_train, train_len = get_next_training_batch(istep)
            fd_train = {
                stim: stim_train,
                resp: resp_train,
                data_len: train_len
            }
            # take training step
            training(fd_train)

            if istep % 10 == 0:
                # compute training and testing losses
                ls_train = get_loss(fd_train)
                ls_test = get_loss(fd_test)
                latest_filename = short_filename + '_latest_fn'
                saver_var.save(sess,
                               save_filename,
                               global_step=istep,
                               latest_filename=latest_filename)

                # add training summary
                summary = sess.run(merged, feed_dict=fd_train)
                train_writer.add_summary(summary, istep)
                # add testing summary
                summary = sess.run(merged, feed_dict=fd_test)
                test_writer.add_summary(summary, istep)
                print(istep, ls_train, ls_test)

            icnt += FLAGS.batchsz
            if icnt > 216000 - 1000:
                icnt = 0
                tms = np.random.permutation(np.arange(216000 - 1000))
def main(argv):

  # initialize training and testing chunks
  init_chunks()

  # setup dataset
  _, cids_select, n_cells, tc_select, tc_mean = setup_dataset()

  # print parameters
  print('Save folder name: ' + str(FLAGS.folder_name) +
        '\nmodel:' + str(FLAGS.model_id) +
        '\nLoss:' + str(FLAGS.loss) +
        '\nbatch size' + str(FLAGS.batchsz) +
        '\nstep size' + str(FLAGS.step_sz) +
        '\ntraining length: ' + str(FLAGS.train_len) +
        '\nn_cells: '+str(n_cells))

  # filename for saving file
  short_filename = ('_loss='+
                    str(FLAGS.loss) + '_batch_sz='+ str(FLAGS.batchsz) +
                    '_step_sz'+ str(FLAGS.step_sz) +
                    '_tlen=' + str(FLAGS.train_len) + '_jitter')

  # setup model
  with tf.Session() as sess:
    # initialize stuff
    if FLAGS.loss == 'poisson':
      # a very small positive bias needed to avoid log(0) in poisson loss
      b_init = np.array(0.000001 * np.ones(n_cells))
    else:
      # log-odds, a good initialization for some losses (like logistic)
      b_init = np.log((tot_spks_chosen_cells) / (216000. - tot_spks_chosen_cells))

    # RGB time filter
    tm4D = np.zeros((30,1,3,3))
    for ichannel in range(3):
      tm4D[:,0,ichannel,ichannel] = tc_mean[:,ichannel]
    tc = tf.Variable((tm4D).astype('float32'),name = 'tc')

    d1=640
    d2=320
    colors=3

    # make data placeholders
    stim = tf.placeholder(tf.float32,shape=[None,d1,d2,colors],name='stim')
    resp = tf.placeholder(tf.float32,shape=[None,n_cells],name='resp')
    data_len = tf.placeholder(tf.float32,name='data_len')

    # time convolution
    # The time-course filter tc has axes (time, d1, color_in, color_out).
    # The original stimulus is (time, d1, d2, color); permute it to
    # (d2, time, d1, color) so the 1D temporal convolution can be mimicked with conv2d.
    stim_time_filtered = tf.transpose(
        tf.nn.conv2d(tf.transpose(stim, (2, 0, 1, 3)), tc,
                     strides=[1, 1, 1, 1], padding='VALID'), (1, 2, 0, 3))
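    # After the transposes, conv2d slides the length-30 temporal filter along the
    # time axis only, so stim_time_filtered has shape (time - 29, d1, d2, colors)
    # with VALID padding.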

    # learn almost convolutional model
    short_filename = ('model=' + str(FLAGS.model_id) + short_filename)
    mask_tf, dimx, dimy, n_pix = get_windows()
    w_del = tf.Variable(np.array( 0.05*np.random.randn(dimx, dimy, n_pix),dtype='float32'), name='w_del')
    w_mother = tf.Variable(np.array( np.ones((2 * FLAGS.window + 1, 2 * FLAGS.window + 1, FLAGS.su_channels, 1)),dtype='float32'), name='w_mother')
    a = tf.Variable(np.array(np.random.randn(dimx*dimy, n_cells),dtype='float32'), name='a')
    a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
    b = tf.Variable(np.array(b_init,dtype='float32'), name='b')
    vars_fit = [w_mother, w_del, a] # which variables to fit
    if not FLAGS.loss == 'poisson':
      vars_fit = vars_fit + [b]

    # stimulus filtered with convolutional windows
    stim4D = stim_time_filtered  # tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
    stim_convolved = tf.reduce_sum(tf.nn.conv2d(stim4D, w_mother, strides=[1, FLAGS.stride, FLAGS.stride, 1], padding="VALID"),3)
    stim_masked = tf.nn.conv2d(stim4D, mask_tf, strides=[1, FLAGS.stride, FLAGS.stride, 1], padding="VALID" )
    stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

    # activation of different subunits
    su_act = tf.nn.relu(stim_del + stim_convolved)

    # get firing rate
    lam = tf.matmul(tf.reshape(su_act, [-1, dimx*dimy]), a_sfm) + b

    # regularization
    regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

    # projection to satisfy hard variable constraints
    b_pos = tf.assign(b, (b + tf.abs(b))/2)
    def proj():
      if FLAGS.loss == 'poisson':
        sess.run(b_pos)


    if FLAGS.loss == 'poisson':
      loss_inter = (tf.reduce_sum(lam)/120. - tf.reduce_sum(resp*tf.log(lam))) / data_len

    loss = loss_inter + regularization # add regularization to get final loss function

    # training consists of calling training()
    # which performs a train step and project parameters to model specific constraints using proj()
    train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(loss, var_list=vars_fit)
    def training(inp_dict):
      sess.run(train_step, feed_dict=inp_dict) # one step of gradient descent
      proj() # model specific projection operations

    # evaluate loss on given data.
    def get_loss(inp_dict):
      ls = sess.run(loss,feed_dict = inp_dict)
      return ls


    # saving details
    # Make a folder with name derived from the parameters of the algorithm.
    # It stores checkpoint files and summaries used in TensorBoard.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # make folder if it does not exist
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    if not gfile.IsDirectory(FLAGS.save_location):
      gfile.MkDir(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename


    # create summary writers
    # create histogram summary for all parameters which are learnt
    for ivar in vars_fit:
      tf.histogram_summary(ivar.name, ivar)
    # loss summary
    l_summary = tf.scalar_summary('loss',loss)
    # loss without regularization summary
    l_inter_summary = tf.scalar_summary('loss_inter',loss_inter)
    # Merge all the summary writer ops into one op (this way, calling one op stores all summaries)
    merged = tf.merge_all_summaries()
    # training and testing has separate summary writers
    train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                      sess.graph)
    test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')


    ## load previous results
    sess.run(tf.initialize_all_variables())
    saver_var = tf.train.Saver(tf.all_variables(), keep_checkpoint_every_n_hours=0.05)
    load_prev = False
    start_iter=0
    try:
      # Restore previous fits if they are available - useful when jobs are frequently preempted.
      latest_filename = short_filename + '_latest_fn'
      restore_file = tf.train.latest_checkpoint(FLAGS.save_location, latest_filename)
      start_iter = int(restore_file.split('/')[-1].split('-')[-1]) # restore previous iteration count and start from there.
      saver_var.restore(sess, restore_file) # restore variables
      load_prev = True
    except:
      print('No previous dataset')

    if load_prev:
      print('\nPrevious results loaded')
    else:
      print('\nVariables initialized')

      
    # Finally, do fitting
    # get test data and make test dictionary
    stim_test,resp_test,test_length = get_stim_resp('test')
    fd_test = {stim: stim_test,
               resp: resp_test,
               data_len: test_length}

    for istep in np.arange(start_iter,400000):
      print(istep)
      # get training data and make the training dictionary
      stim_train, resp_train, train_len = get_stim_resp('train')
      fd_train = {stim: stim_train,
                  resp: resp_train,
                  data_len: train_len}
      print('loaded a training set')

      # take training step
      training(fd_train)
      print('performed a training step')

      if istep%10 == 0:
        # save variables
        latest_filename = short_filename + '_latest_fn'
        saver_var.save(sess, save_filename, global_step=istep, latest_filename = latest_filename)
        # add training summary
        summary = sess.run(merged, feed_dict=fd_train)
        train_writer.add_summary(summary,istep)
        print('training loss calculated')

        # add testing summary
        summary = sess.run(merged, feed_dict=fd_test)
        test_writer.add_summary(summary,istep)
        print('testing loss calculated')
def main(argv):
    np.random.seed(23)

    # Figure out dictionary path.
    dict_list = gfile.ListDirectory(FLAGS.src_dict)
    dict_path = os.path.join(FLAGS.src_dict, dict_list[FLAGS.taskid])

    # Load the dictionary
    if dict_path[-3:] == 'pkl':
        data = pickle.load(gfile.Open(dict_path, 'r'))
    if dict_path[-3:] == 'mat':
        data = sio.loadmat(gfile.Open(dict_path, 'r'))

    #FLAGS.save_dir = '/home/bhaishahster/stimulation_algos/dictionaries/' + dict_list[FLAGS.taskid][:-4]
    FLAGS.save_dir = FLAGS.save_dir + dict_list[FLAGS.taskid][:-4]
    if not gfile.Exists(FLAGS.save_dir):
        gfile.MkDir(FLAGS.save_dir)

    # S_collection = data['S']  # Target
    A = data['A']  # Decoder
    D = data['D'].T  # Dictionary

    # clean dictionary
    thr_list = np.arange(0, 1, 0.01)
    dict_val = []
    for thr in thr_list:
        dict_val += [np.sum(np.sum(D.T > thr, 1) != 0)]
    plt.ion()
    plt.figure()
    plt.plot(thr_list, dict_val)
    plt.xlabel('Threshold')
    plt.ylabel(
        'Number of dictionary elements with \n at least one element above threshold'
    )
    plt.title('Please choose threshold')
    thr_use = float(input('What threshold to use?'))
    plt.axvline(thr_use)
    plt.title('Using threshold: %.5f' % thr_use)

    dict_valid = np.sum(D.T > thr_use, 1) > 0
    D = D[:, dict_valid]
    D = np.append(D, np.zeros((D.shape[0], 1)), 1)
    print(
        'Appending a "dummy" dictionary element that does not activate any cell'
    )

    # Vary stimulus resolution
    for itarget in range(20):
        n_targets = 1
        for stix_resolution in [32, 64, 16, 8]:

            # Get the target
            x_dim = int(640 / stix_resolution)
            y_dim = int(320 / stix_resolution)
            targets = (np.random.rand(y_dim, x_dim, n_targets) < 0.5) - 0.5
            upscale = stix_resolution // 8
            targets = np.repeat(np.repeat(targets, upscale, axis=0),
                                upscale,
                                axis=1)
            targets = np.reshape(targets, [-1, targets.shape[-1]])
            S_actual = targets[:, 0]
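            # Each target is a random +/-0.5 binary image at the chosen stixel
            # resolution, upsampled to the 8-pixel grid and flattened into a vector.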

            # Remove null component of A from S
            S = A.dot(np.linalg.pinv(A).dot(S_actual))
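            # A.dot(np.linalg.pinv(A)) is the orthogonal projector onto the column
            # space of A, so S keeps only the component of the target that the
            # decoder A can actually reproduce.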

            # Run Greedy first to initialize
            x_greedy = greedy_stimulation(S,
                                          A,
                                          D,
                                          max_stims=FLAGS.t_max * FLAGS.delta,
                                          file_suffix='%d_%d' %
                                          (stix_resolution, itarget),
                                          save=True,
                                          save_dir=FLAGS.save_dir)

            # Load greedy output from previous run
            #data_greedy = pickle.load(gfile.Open('/home/bhaishahster/greedy_2000_32_0.pkl', 'r'))
            #x_greedy = data_greedy['x_chosen']

            # Plan for multiple time points
            x_init = np.zeros((x_greedy.shape[0], FLAGS.t_max))
            for it in range(FLAGS.t_max):
                print((it + 1) * FLAGS.delta - 1)
                x_init[:, it] = x_greedy[:, (it + 1) * FLAGS.delta - 1]

            #simultaneous_planning(S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
            #                     delta=FLAGS.delta, normalization=FLAGS.normalization,
            #                    file_suffix='%d_%d_normal' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir)

            from IPython import embed
            embed()
            simultaneous_planning_interleaved_discretization(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_pgd' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=np.inf,
                steps_max=500 * 20 - 1)

            # Interleaved discretization.
            simultaneous_planning_interleaved_discretization(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_pgd_od' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=500,
                steps_max=500 * 20 - 1)

            # Exponential weighing.
            simultaneous_planning_interleaved_discretization_exp_gradient(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_ew' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=np.inf,
                steps_max=500 * 20 - 1)

            # Exponential weighing with interleaved discretization.
            simultaneous_planning_interleaved_discretization_exp_gradient(
                S,
                A,
                D,
                t_max=FLAGS.t_max,
                lr=FLAGS.learning_rate,
                delta=FLAGS.delta,
                normalization=FLAGS.normalization,
                file_suffix='%d_%d_ew_od' % (stix_resolution, itarget),
                x_init=x_init,
                save_dir=FLAGS.save_dir,
                freeze_freq=500,
                steps_max=500 * 20 - 1)
            '''
      # Plot results
      data_fractional = pickle.load(gfile.Open('/home/bhaishahster/2012-09-24-0_SAD_fr1/pgd_20_5.000000_C_100_32_0_normal.pkl', 'r'))
      plt.ion()
      start = 0
      end = 20

      error_target_frac = np.linalg.norm(S - data_fractional['S'])
      #error_target_greedy = np.linalg.norm(S - data_greedy['S'])
      print('Did target change? \n Err from fractional: %.3f ' % error_target_frac)
      #'\n Err from greedy %.3f' % (error_target_frac,
      #                             error_target_greedy))
      # from IPython import embed; embed()
      normalize = np.sum(data_fractional['S'] ** 2)
      plt.ion()
      plt.plot(np.arange(120, len(data_fractional['f_log'])), data_fractional['f_log'][120:] / normalize, 'b')
      plt.axhline(data_fractional['errors'][5:].mean() / normalize, color='m')
      plt.axhline(data_fractional['errors_ht_discrete'][5:].mean() / normalize, color='y')
      plt.axhline(data_fractional['errors_rr_discrete'][5:].mean() / normalize, color='r')
      plt.axhline(data_greedy['error_curve'][120:].mean() / normalize, color='k')
      plt.legend(['fraction_curve', 'fractional error', 'HT', 'RR', 'Greedy'])
      plt.pause(1.0)
      '''

    # Debug hook (commented out): drop into an IPython shell after all resolutions are done.
    # from IPython import embed; embed()


def main(unused_argv=()):
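  """Copy data locally, build the joint stimulus-response embedding graph, then train or test."""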

  ## copy data locally
  dst = FLAGS.tmp_dir
  print('Starting Copy')
  if not gfile.IsDirectory(dst):
    gfile.MkDir(dst)

  files = gfile.ListDirectory(FLAGS.src_dir)
  for ifile in files:
    ffile = os.path.join(dst, ifile)
    if not gfile.Exists(ffile):
      gfile.Copy(os.path.join(FLAGS.src_dir, ifile), ffile)
      print('Copied %s' % os.path.join(FLAGS.src_dir, ifile))
    else:
      print('File exists %s' % ffile)

  print('Files copied to destination')


  ## load data
  # load stimulus
  data = h5py.File(os.path.join(dst, 'stimulus.mat'))
  stimulus = np.array(data.get('stimulus')) - 0.5

  # load responses from multiple retinas
  datasets_list = os.path.join(dst, 'datasets.txt')
  datasets = open(datasets_list, 'r').read()
  training_datasets = datasets.splitlines()

  responses = []
  for idata in training_datasets:
    print(idata)
    data_file = os.path.join(dst, idata)
    data = sio.loadmat(data_file)
    responses += [data]
    print(np.max(data['centers'], 0))

  # generate additional features for responses
  num_cell_types = 2
  dimx = 80
  dimy = 40
  for iresp in responses:
    # remove cells which are outside 80x40 window.
    process_dataset(iresp, dimx, dimy, num_cell_types)

  ## generate graph -
  if FLAGS.is_test == 0:
    is_training = True
  if FLAGS.is_test == 1:
    is_training = True  # Note: kept True (rather than False) even at test time.

  with tf.Session() as sess:

    ## Make graph
    # embed stimulus.
    time_window = 30
    stimx = stimulus.shape[1]
    stimy = stimulus.shape[2]
    stim_tf = tf.placeholder(tf.float32,
                             shape=[None, stimx, stimy, time_window])  # batch x X x Y x time_window
    batch_norm = FLAGS.batch_norm
    stim_embed = embed_stimulus(FLAGS.stim_layers.split(','),
                                batch_norm, stim_tf, is_training,
                                reuse_variables=False)

    '''
    ttf_tf = tf.Variable(np.ones(time_window).astype(np.float32)/10, name='stim_ttf')
    filt = tf.expand_dims(tf.expand_dims(tf.expand_dims(ttf_tf, 0), 0), 3)
    stim_time_filt = tf.nn.conv2d(stim_tf, filt,
                                    strides=[1, 1, 1, 1], padding='SAME') # batch x X x Y x 1


    ilayer = 0
    stim_time_filt = slim.conv2d(stim_time_filt, 1, [3, 3],
                        stride=1,
                        scope='stim_layer_wt_%d' % ilayer,
                        reuse=False,
                        normalizer_fn=slim.batch_norm,
                        activation_fn=tf.nn.softplus,
                        normalizer_params={'is_training': is_training},
                        padding='SAME')
    '''


    # embed responses.
    num_cell_types = 2
    layers = FLAGS.resp_layers  # Format: window x filters x stride ... NOTE: final filters=1, stride=1 throughout.
    batch_norm = FLAGS.batch_norm
    time_window = 1
    anchor_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                     layers=layers,
                                                     batch_norm=batch_norm,
                                                     is_training=is_training,
                                                     reuse_variables=False,
                                                     num_cell_types=2,
                                                     dimx=dimx, dimy=dimy)

    neg_model = conv.ConvolutionalProsthesisScore(sess, time_window=1,
                                                  layers=layers,
                                                  batch_norm=batch_norm,
                                                  is_training=is_training,
                                                  reuse_variables=True,
                                                  num_cell_types=2,
                                                  dimx=dimx, dimy=dimy)

    d_s_r_pos = tf.reduce_sum((stim_embed - anchor_model.responses_embed)**2, [1, 2, 3]) # batch
    d_pairwise_s_rneg = tf.reduce_sum((tf.expand_dims(stim_embed, 1) -
                                 tf.expand_dims(neg_model.responses_embed, 0))**2, [2, 3, 4]) # batch x batch_neg
    beta = 10
    # if FLAGS.save_suffix == 'lr=0.001':
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(tf.expand_dims(d_s_r_pos / beta, 1) -
                                                    d_pairwise_s_rneg / beta, 1), 0)
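    # The loss above is a soft-max over negatives: beta * logsumexp((d_pos - d_neg) / beta)
    # smoothly approximates the maximum over negatives of (d_pos - d_neg), so the loss is
    # small when each stimulus embedding is closer to its own response than to every negative.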
    # else :

    # loss = tf.reduce_sum(tf.nn.softplus(1 +  tf.expand_dims(d_s_r_pos, 1) - d_pairwise_s_rneg))
    accuracy_tf = tf.reduce_mean(tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg))
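    # tf.sign gives +/-1, so accuracy_tf above lies in [-1, 1]: the fraction of
    # correctly ordered (pos, neg) pairs minus the fraction ordered incorrectly.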

    lr = 0.001
    train_op = tf.train.AdagradOptimizer(lr).minimize(loss)

    # set up training and testing data
    training_datasets_all = [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15]
    testing_datasets = [0, 4, 8, 12]
    print('Testing datasets', testing_datasets)

    n_training_datasets_log = [1, 3, 5, 7, 9, 11, 12]

    if (np.floor(FLAGS.taskid / 4)).astype(np.int) < len(n_training_datasets_log):
      # Use randomly sampled training datasets for 0 <= FLAGS.taskid < 28.
      prng = RandomState(23)
      n_training_datasets = n_training_datasets_log[(np.floor(FLAGS.taskid / 4)).astype(np.int)]
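      # Advance the RNG by 10 * taskid draws (printing them) so that each task id
      # ends up with a different random subset of training retinas.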
      for _ in range(10*FLAGS.taskid):
        print(prng.choice(training_datasets_all,
                          n_training_datasets, replace=False))
      training_datasets = prng.choice(training_datasets_all,
                                      n_training_datasets, replace=False)

      # training_datasets = [i for i in range(7) if i< 7-FLAGS.taskid] #[0, 1, 2, 3, 4, 5]


    else:
      # Use a single training dataset, chosen in order, for FLAGS.taskid >= 28.
      datasets_all = np.arange(16)
      training_datasets = [datasets_all[FLAGS.taskid % (4 * len(n_training_datasets_log))]]

    print('Task ID %d' % FLAGS.taskid)
    print('Training datasets', training_datasets)

    # Initialize model variables and saving information.
    file_name = ('end_to_end_stim_%s_resp_%s_beta_%d_taskid_%d'
                 '_training_%s_testing_%s_%s' % (FLAGS.stim_layers,
                                                 FLAGS.resp_layers, beta, FLAGS.taskid,
                                                 str(training_datasets)[1: -1],
                                                 str(testing_datasets)[1: -1],
                                                 FLAGS.save_suffix))
    saver_var, start_iter = initialize_model(FLAGS.save_folder, file_name, sess)

    # print model graph
    PrintModelAnalysis(tf.get_default_graph())

    # Add summary ops
    retina_number = tf.placeholder(tf.int16, name='input_retina')

    summary_ops = []
    for iret in np.arange(len(responses)):
      print(iret)
      r1 = tf.summary.scalar('loss_%d' % iret , loss)
      r2 = tf.summary.scalar('accuracy_%d' % iret , accuracy_tf)
      summary_ops += [tf.summary.merge([r1, r2])]

    # tf.summary.scalar('loss', loss)
    # tf.summary.scalar('accuracy', accuracy_tf)
    # summary_op = tf.summary.merge_all()

    # Setup summary writers
    summary_writers = []
    for loc in ['train', 'test']:
      summary_location = os.path.join(FLAGS.save_folder, file_name,
                                      'summary_' + loc)
      summary_writer = tf.summary.FileWriter(summary_location, sess.graph)
      summary_writers += [summary_writer]


    # start training
    batch_neg_train_sz = 100
    batch_train_sz = 100


    def batch(dataset_id):
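      """Sample a stimulus batch with positive and negative responses for one retina."""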
      batch_train = get_batch(stimulus, responses[dataset_id]['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   anchor_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                   neg_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                   neg_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],
                   retina_number : dataset_id}

      return feed_dict

    def batch_few_cells(responses):
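      """Same as batch(), but for a response dictionary restricted to a subset of cells."""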
      batch_train = get_batch(stimulus, responses['responses'],
                              batch_size=batch_train_sz,
                              batch_neg_resp=batch_neg_train_sz,
                              stim_history=30, min_window=10)
      stim_batch, resp_batch, resp_batch_neg = batch_train
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                   neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                   anchor_model.map_cell_grid_tf: responses['map_cell_grid'],
                   anchor_model.cell_types_tf: responses['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses['mean_firing_rate'],

                   neg_model.map_cell_grid_tf: responses['map_cell_grid'],
                   neg_model.cell_types_tf: responses['ctype_1hot'],
                   neg_model.mean_fr_tf: responses['mean_firing_rate'],
                   }

      return feed_dict

    if FLAGS.is_test == 1:
      print('Testing')
      save_dict = {}

      # Debug hook (commented out): drop into an IPython shell before the test analysis.
      # from IPython import embed; embed()
      ## Estimate one, fix others
      '''
      grad_resp = tf.gradients(d_s_r_pos, anchor_model.responses_tf)

      t_start = 1000
      t_len = 100
      stim_history = 30
      stim_batch = np.zeros((t_len, stimulus.shape[1],
                         stimulus.shape[2], stim_history))
      for isample, itime in enumerate(np.arange(t_start, t_start + t_len)):
        stim_batch[isample, :, :, :] = np.transpose(stimulus[itime: itime-stim_history:-1, :, :], [1, 2, 0])

      iretina = testing_datasets[0]
      resp_batch = np.expand_dims(np.random.rand(t_len, responses[iretina]['responses'].shape[1]), 2)

      step_sz = 0.01
      eps = 1e-2
      dist_prev = np.inf
      for iiter in range(10000):
        feed_dict = {stim_tf: stim_batch,
                     anchor_model.map_cell_grid_tf: responses[iretina]['map_cell_grid'],
                     anchor_model.cell_types_tf: responses[iretina]['ctype_1hot'],
                     anchor_model.mean_fr_tf: responses[iretina]['mean_firing_rate'],
                     anchor_model.responses_tf: resp_batch}
        dist_np, resp_grad_np = sess.run([d_s_r_pos, grad_resp], feed_dict=feed_dict)
        if np.sum(np.abs(dist_prev - dist_np)) < eps:
          break
        print(np.sum(dist_np), np.sum(np.abs(dist_prev - dist_np)))
        dist_prev = dist_np
        resp_batch = resp_batch - step_sz * resp_grad_np[0]

      resp_batch = resp_batch.squeeze()
      '''


      # from IPython import embed; embed()
      ## compute distances between s-r pairs for small number of cells

      test_retina = []
      for iretina in range(len(testing_datasets)):
        dataset_id = testing_datasets[iretina]

        num_cells_total = responses[dataset_id]['responses'].shape[1]
        dataset_center = responses[dataset_id]['centers'].mean(0)
        # Euclidean distance of each cell from the retina center.
        dataset_cell_distances = np.sqrt(np.sum((responses[dataset_id]['centers'] -
                                                 dataset_center)**2, 1))
        order_cells = np.argsort(dataset_cell_distances)

        test_sr_few_cells = {}
        for num_cells_prc in [5, 10, 20, 30, 50, 100]:
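          # Convert the percentile to an absolute cell count and keep the cells
          # closest to the retina center.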
          num_cells = np.percentile(np.arange(num_cells_total),
                                    num_cells_prc).astype(np.int)

          choose_cells = order_cells[:num_cells]

          responses_few_cells = {'responses': responses[dataset_id]['responses'][:, choose_cells],
                                 'map_cell_grid': responses[dataset_id]['map_cell_grid'][:, :, choose_cells],
                                 'ctype_1hot': responses[dataset_id]['ctype_1hot'][choose_cells, :],
                                 'mean_firing_rate': responses[dataset_id]['mean_firing_rate'][choose_cells]}
          # get a batch
          d_pos_log = np.array([])
          d_neg_log = np.array([])
          for test_iter in range(1000):
            print(iretina, num_cells_prc, test_iter)
            feed_dict = batch_few_cells(responses_few_cells)
            d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
            d_neg = np.diag(d_neg) # np.mean(d_neg, 1) #
            d_pos_log = np.append(d_pos_log, d_pos)
            d_neg_log = np.append(d_neg_log, d_neg)

          precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

          print(np.sum(d_pos_log > d_neg_log))
          print(np.sum(d_pos_log < d_neg_log))
          test_sr = {'precision': precision_log, 'recall': recall_log,
                     'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                     'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log,
                     'num_cells': num_cells}

          test_sr_few_cells.update({'num_cells_prc_%d' % num_cells_prc : test_sr})
        test_retina += [test_sr_few_cells]
      save_dict.update({'few_cell_analysis': test_retina})

      ## compute distances between s-r pairs - pos and neg.

      test_retina = []
      for iretina in range(len(testing_datasets)):
        # stim-resp log
        d_pos_log = np.array([])
        d_neg_log = np.array([])
        for test_iter in range(1000):
          print(test_iter)
          feed_dict = batch(testing_datasets[iretina])
          d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg], feed_dict=feed_dict)
          d_neg = np.diag(d_neg) # np.mean(d_neg, 1) #
          d_pos_log = np.append(d_pos_log, d_pos)
          d_neg_log = np.append(d_neg_log, d_neg)

        precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log, d_neg_log)

        print(np.sum(d_pos_log > d_neg_log))
        print(np.sum(d_pos_log < d_neg_log))
        test_sr = {'precision': precision_log, 'recall': recall_log,
                   'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                   'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log}

        test_retina += [test_sr]

      save_dict.update({'test_sr': test_retina})


      ## ROC curves of responses from repeats - dataset 1
      repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
      repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
      repeats_data['cell_type'] = repeats_data['cell_type'].T

      # process repeats data
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result
      test_reps = analyse_response_repeats(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2015-09-23-7': test_reps})

      ## ROC curves of responses from repeats - dataset 2
      repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
      repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
      process_dataset(repeats_data, dimx, dimy, num_cell_types)

      # analyse and store the result
      '''
      test_clustering = analyse_response_repeats_all_trials(repeats_data, anchor_model, neg_model, sess)
      save_dict.update({'test_reps_2005_08_03_0': test_clustering})
      '''
      # Get model parameters.
      save_dict.update({'model_pars': sess.run(tf.trainable_variables())})


      save_analysis_filename = os.path.join(FLAGS.save_folder, file_name + '_analysis.pkl')
      pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
      print(save_analysis_filename)
      return

    test_iiter = 0
    for iiter in range(start_iter, FLAGS.max_iter):  # TODO(bhaishahster): add FLAGS.max_iter.

      # get a new batch
      # stim_tf, anchor_model.responses_tf, neg_model.responses_tf

      # training step
      train_dataset = training_datasets[iiter % len(training_datasets)]
      feed_dict_train = batch(train_dataset)
      _, loss_np_train = sess.run([train_op, loss], feed_dict=feed_dict_train)
      print(train_dataset, loss_np_train)

      # write summary
      if iiter % 10 == 0:
        # write train summary
        test_iiter = test_iiter + 1

        train_dataset = training_datasets[test_iiter % len(training_datasets)]
        feed_dict_train = batch(train_dataset)
        summary_train = sess.run(summary_ops[train_dataset], feed_dict=feed_dict_train)
        summary_writers[0].add_summary(summary_train, iiter)

        # write test summary
        test_dataset = testing_datasets[test_iiter % len(testing_datasets)]
        feed_dict_test = batch(test_dataset)
        l_test, summary_test = sess.run([loss, summary_ops[test_dataset]], feed_dict=feed_dict_test)
        summary_writers[1].add_summary(summary_test, iiter)
        print('Test retina: %d, loss: %.3f' % (test_dataset, l_test))

      # save model
      if iiter % 10 == 0:
        save_model(saver_var, FLAGS.save_folder, file_name, sess, iiter)
Exemple #11
def RunComputation():
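    """Build the jitter model specified by FLAGS, then either fit it (FLAGS.learn == 1)
    or restore and analyse a saved checkpoint."""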

    # filename for saving files, derived from FLAGS.
    short_filename = get_filename()

    # Make a folder with a name derived from the parameters of the algorithm;
    # it stores checkpoint files and summaries used in TensorBoard.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # Make the folder if it does not exist.
    if not gfile.IsDirectory(parent_folder):
        gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    print('Does the file exist?', gfile.IsDirectory(FLAGS.save_location))
    if not gfile.IsDirectory(FLAGS.save_location):
        gfile.MkDir(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    if FLAGS.learn == 0:
        # for analysis, use smaller batch sizes, so that we can work with single GPU.
        FLAGS.batchsz = 600

    # Set up TensorFlow.
    with tf.Graph().as_default() as gra:
        with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks)):
            print(FLAGS.config_params)
            tf.logging.info(FLAGS.config_params)

            # set up training dataset
            # tc_mean = get_data_mat.init_chunks(FLAGS.n_chunks) <- use this with old get_data_mat
            tc_mean = get_data_mat.init_chunks(FLAGS.batchsz)
            #plt.plot(tc_mean)
            #plt.show()
            #plt.draw()

            # Create computation graph.
            #
            # Graph should be fully constructed before you create supervisor.
            # Attempting to modify the graph after the supervisor is created will cause an error.
            with tf.name_scope('model'):
                if FLAGS.architecture == '1 layer':
                    # single GPU model
                    if False:
                        global_step = tf.contrib.framework.create_global_step()
                        model, stim, resp = jitter_model.approximate_conv_jitter(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels)

                    # multiGPU model
                    if True:
                        model, stim, resp, global_step = jitter_model.approximate_conv_jitter_multigpu(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params)

                if FLAGS.architecture == '2 layer_stimulus':
                    # The stimulus is first smoothed to a lower resolution, then the same model is applied.
                    print('First take a low resolution version of stimulus')
                    model, stim, resp, global_step, stim_tuple = (
                        jitter_model.approximate_conv_jitter_multigpu_stim_lr(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params,
                            FLAGS.stim_downsample_window,
                            FLAGS.stim_downsample_stride))

                if FLAGS.architecture == 'complex':
                    print(' Multiple modifications over 2 layered model above')
                    model, stim, resp, global_step = (
                        jitter_model_2.
                        approximate_conv_jitter_multigpu_complex(
                            FLAGS.n_cells, FLAGS.lam_w, FLAGS.window,
                            FLAGS.stride, FLAGS.step_sz, tc_mean,
                            FLAGS.su_channels, FLAGS.config_params,
                            FLAGS.stim_downsample_window,
                            FLAGS.stim_downsample_stride))

            # Print the number of variables in graph
            print('Calculating model size')  # Hope we do not exceed memory
            PrintModelAnalysis(gra, max_depth=10)

            # Builds our summary op.
            summary_op = model.merged_summary

            # Create a Supervisor.  It will take care of initialization, summaries,
            # checkpoints, and recovery.
            #
            # When multiple replicas of this program are running, the first one,
            # identified by --task=0 is the 'chief' supervisor.  It is the only one
            # that takes care of initialization, etc.
            is_chief = (FLAGS.task == 0)  # & (FLAGS.learn==1)
            print(save_filename)

            if FLAGS.learn == 1:
                # Use the supervisor only for learning; otherwise it interferes
                # with analysis by trying to checkpoint variables.

                sv = tf.train.Supervisor(logdir=save_filename,
                                         is_chief=is_chief,
                                         saver=tf.train.Saver(),
                                         summary_op=None,
                                         save_model_secs=100,
                                         global_step=global_step,
                                         recovery_wait_secs=5)

                if (is_chief and FLAGS.learn == 1):
                    # Save the graph only if task id = 0 (is_chief) and the model is being learned.
                    tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                                         save_filename, 'graph.pbtxt')

                # Get an initialized, and possibly recovered session.  Launch the
                # services: Checkpointing, Summaries, step counting.
                #
                # When multiple replicas of this program are running the services are
                # only launched by the 'chief' replica.
                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                sess = sv.PrepareSession(FLAGS.master, config=session_config)

                # Finally, learn the parameters of the model
                FitComputation(sv, sess, model, stim, resp, global_step,
                               summary_op)
                sv.Stop()

            else:
                # Analyse the model

                session_config = tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False)
                with tf.Session(config=session_config) as sess:

                    # First, recover the model
                    saver_var = tf.train.Saver(
                        tf.all_variables(),
                        keep_checkpoint_every_n_hours=float('inf'))
                    restore_file = tf.train.latest_checkpoint(save_filename)
                    print(restore_file)
                    start_iter = int(
                        restore_file.split('/')[-1].split('-')[-1])
                    saver_var.restore(sess, restore_file)

                    # model specific analysis
                    if FLAGS.architecture == '2 layer_stimulus':
                        AnalyseModel_lr(sess, model)
                    elif FLAGS.architecture == 'complex':
                        AnalyseModel_complex(sess, model, stim, resp,
                                             save_filename)
                    else:
                        AnalyseModel(sv, sess, model)