def main(unused_argv=()):
  # Load stimulus-response data.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  print(datasets)
  for icnt, idataset in enumerate(datasets):
    # for icnt, idataset in enumerate([datasets[2]]):  # HACK - only one dataset used!!
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op
      stimuli.update({key: stimulus})
      responses += resp

  print('# Responses %d' % len(responses))

  # Fit a linear embedding for the retina selected by the task id.
  stimulus = stimuli[responses[FLAGS.taskid]['stimulus_key']]
  save_filename = ('linear_taskid_%d_piece_%s.pkl' %
                   (FLAGS.taskid, responses[FLAGS.taskid]['piece']))
  print(save_filename)
  learn_lin_embedding(stimulus,
                      np.double(responses[FLAGS.taskid]['responses']),
                      filename=save_filename, lam_l1=0.00001, beta=10,
                      time_window=30, lr=0.01)
  print('DONE!')
def get_latest_file(save_location, short_filename):
  """Returns the path of the latest checkpoint matching short_filename."""
  # Get relevant files.
  file_list = gfile.ListDirectory(save_location)
  print(save_location, short_filename)
  save_filename = save_location + short_filename
  print('\nLoading: ', save_filename)
  bin_files = []
  meta_files = []
  for file_n in file_list:
    if re.search(short_filename + '.', file_n):
      if re.search('.meta', file_n):
        meta_files += [file_n]
      else:
        bin_files += [file_n]
  print(len(meta_files), len(bin_files), len(file_list))

  # Get iteration numbers from the checkpoint suffix (after the last '-').
  iterations = np.array([])
  for file_name in bin_files:
    try:
      iterations = np.append(
          iterations, int(file_name.split('/')[-1].split('-')[-1]))
    except ValueError:
      print('Could not parse filename: ' + file_name)
  iterations.sort()
  print(iterations)
  iter_plot = iterations[-1]
  print(int(iter_plot))

  restore_file = save_filename + '-' + str(int(iter_plot))
  return restore_file
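# A minimal usage sketch for get_latest_file (an illustrative helper, not part
# of the original code): it assumes a TF1-style graph whose variables were
# written by tf.train.Saver under `save_location` with the given
# `short_filename` prefix, mirroring the restore pattern used elsewhere in
# this codebase.
import tensorflow as tf

def restore_latest_checkpoint(sess, save_location, short_filename):
  # Find the newest checkpoint and restore all graph variables from it.
  restore_file = get_latest_file(save_location, short_filename)
  saver = tf.train.Saver(tf.all_variables())
  saver.restore(sess, restore_file)
  return restore_file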
def inputs(name, data_location, batch_size, num_epochs, stim_dim, resp_dim):
  """Gives a batch of stimuli and responses from a .tfrecords file.

  Works for .tfrecords files made using CoarseDataUtils.convert_to_TFRecords.
  The actual source is either a file 'name', a file 'name.tfrecords', or a
  folder 'name' containing a list of .tfrecords files.
  """
  with tf.name_scope('input'):
    # Get filename queue.
    filename = os.path.join(data_location, name)
    filename_extension = os.path.join(data_location, name + '.tfrecords')
    if gfile.Exists(filename) and not gfile.IsDirectory(filename):
      tf.logging.info('%s Exists' % filename)
      filenames = [filename]
    elif gfile.Exists(filename_extension) and not gfile.IsDirectory(
        filename_extension):
      tf.logging.info('%s Exists' % filename_extension)
      filenames = [filename_extension]
    elif gfile.IsDirectory(filename):
      tf.logging.info('%s Exists and is a directory' % filename)
      filenames_short = gfile.ListDirectory(filename)
      filenames = [os.path.join(filename, ifilenames_short)
                   for ifilenames_short in filenames_short]
    tf.logging.info(filenames)
    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=num_epochs, capacity=10000)

    # Even when reading in multiple threads, share the filename queue.
    stimulus, response = read_and_decode(filename_queue, stim_dim, resp_dim)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in multiple threads to avoid it becoming a bottleneck.
    stimulus_batch, response_batch = tf.train.shuffle_batch(
        [stimulus, response], batch_size=batch_size, num_threads=30,
        capacity=5000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=2000)
    '''
    stimulus_batch, response_batch = tf.train.batch(
        [stimulus, response], batch_size=batch_size,
        num_threads=30, capacity=50000 + 3 * batch_size)
    '''
    return stimulus_batch, response_batch
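# A minimal sketch of consuming the queue-based pipeline above (TF1 style).
# The dataset name, data location and dimensions below are illustrative
# placeholders, not values from this codebase. Note that string_input_producer
# with num_epochs creates local variables, so the local-variable initializer
# must run before starting the queue runners.
import tensorflow as tf

def run_input_pipeline(data_location):
  stim_batch, resp_batch = inputs(name='example_data',
                                  data_location=data_location,
                                  batch_size=256, num_epochs=1,
                                  stim_dim=3200, resp_dim=107)
  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      while not coord.should_stop():
        stim_np, resp_np = sess.run([stim_batch, resp_batch])
        # ... train / evaluate on the batch here.
    except tf.errors.OutOfRangeError:
      print('Epoch limit reached.')
    finally:
      coord.request_stop()
      coord.join(threads)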
def main(unused_argv=()):
  # Load stimulus-response data.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  responses = []
  print(datasets)
  for icnt, idataset in enumerate(datasets):
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op
      responses += resp

  # Fit an LN model for each retina and save the parameters.
  # TODO(bhaishahster): Use FLAGS.taskid to parallelize across retinas.
  for idataset in range(len(responses)):
    k, b, ttf = fit_ln_population(responses[idataset]['responses'], stimulus)
    save_dict = {'k': k, 'b': b, 'ttf': ttf}
    save_analysis_filename = os.path.join(
        FLAGS.save_folder, responses[idataset]['piece'] + '_ln_model.pkl')
    pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
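# A minimal sketch (an illustrative helper, not part of the original code) for
# reading back one of the LN fits saved above. It reuses the module's existing
# pickle/gfile/os imports; the piece name in the usage comment is hypothetical,
# and predict_responses_ln refers to the analysis_utils helper used elsewhere
# in this codebase to simulate responses from k, b and ttf.
def load_ln_model(save_folder, piece):
  ln_fit = pickle.load(gfile.Open(
      os.path.join(save_folder, piece + '_ln_model.pkl'), 'r'))
  return ln_fit['k'], ln_fit['b'], ln_fit['ttf']

# Example usage (hypothetical piece name):
# k, b, ttf = load_ln_model(FLAGS.save_folder, '2015-09-23-7')
# spikes, lam = analysis_utils.predict_responses_ln(stimulus, k, b, ttf,
#                                                   n_trials=1)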
def main(unused_argv=()): np.random.seed(23) tf.set_random_seed(1234) random.seed(50) # Load stimulus-response data. # Collect population response across retinas in the list 'responses'. # Stimulus for each retina is indicated by 'stim_id', # which is found in 'stimuli' dictionary. datasets = gfile.ListDirectory(FLAGS.src_dir) stimuli = {} responses = [] for icnt, idataset in enumerate(datasets): fullpath = os.path.join(FLAGS.src_dir, idataset) if gfile.IsDirectory(fullpath): key = 'stim_%d' % icnt op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key, boundary=FLAGS.valid_cells_boundary) stimulus, resp, dimx, dimy, _ = op stimuli.update({key: stimulus}) responses += resp taskid = FLAGS.taskid dat = responses[taskid] stimulus = stimuli[dat['stimulus_key']] # parameters window = 5 # Compute time course and non-linearity as two parameters which might be should be explored in embedded space. n_cells = dat['responses'].shape[1] T = np.minimum(stimulus.shape[0], dat['responses'].shape[0]) stim_short = stimulus[:T, :, :] resp_short = dat['responses'][:T, :].astype(np.float32) save_dict = {} # Find time course, non-linearity and RF parameters ######################################################################## # Separation between cell types ######################################################################## save_dict.update({'cell_type': dat['cell_type']}) save_dict.update({'dist_nn_cell_type': dat['dist_nn_cell_type']}) ######################################################################## # Find mean firing rate ######################################################################## mean_fr = dat['responses'].mean(0) mean_fr_1 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==1]) mean_fr_2 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==2]) mean_fr_dict = {'mean_fr': mean_fr, 'mean_fr_1': mean_fr_1, 'mean_fr_2': mean_fr_2} save_dict.update({'mean_fr_dict': mean_fr_dict}) ######################################################################## # compute STAs ######################################################################## stas = np.zeros((n_cells, 80, 40, 30)) for icell in range(n_cells): print(icell) center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1]) for idelay in range(30): stas[icell, windx[0]: windx[1], windy[0]: windy[1], idelay] = np.reshape(resp_short[idelay:, icell].dot(stim_cell[:T-idelay, :]), [windx[1] - windx[0], windy[1] - windy[0]]) / np.sum(resp_short[idelay:, icell]) stas_dict = {'stas': stas} # save_dict.update({'stas_dict': stas_dict}) ######################################################################## # Find time courses for each cell ######################################################################## ttf_log = [] for icell in range(n_cells): center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] ll = stas[icell, windx[0]: windx[1], windy[0]: windy[1], :] ll_2d = np.reshape(ll, [-1, ll.shape[-1]]) u, s, v = np.linalg.svd(ll_2d) ttf_log += [v[0, :]] ttf_log = np.array(ttf_log) signs = [np.sign(ttf_log[icell, np.argmax(np.abs(ttf_log[icell, :]))]) for icell in range(ttf_log.shape[0])] ttf_corrected = np.expand_dims(np.array(signs), 1) * 
ttf_log ttf_corrected[np.squeeze(dat['cell_type'])==1, :] = ttf_corrected[np.squeeze(dat['cell_type'])==1, :] * -1 ttf_mean_1 = ttf_corrected[np.squeeze(dat['cell_type'])==1, :].mean(0) ttf_mean_2 = ttf_corrected[np.squeeze(dat['cell_type'])==2, :].mean(0) ttf_params_1 = get_times(ttf_mean_1) ttf_params_2 = get_times(ttf_mean_2) ttf_dict = {'ttf_log': ttf_log, 'ttf_mean_1': ttf_mean_1, 'ttf_mean_2': ttf_mean_2, 'ttf_params_1': ttf_params_1, 'ttf_params_2': ttf_params_2} save_dict.update({'ttf_dict': ttf_dict}) ''' plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==1, :].T, 'r', alpha=0.3) plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==2, :].T, 'k', alpha=0.3) plt.plot(ttf_mean_1, 'r--') plt.plot(ttf_mean_2, 'k--') ''' ######################################################################## ## Find non-linearity ######################################################################## f_nl = lambda x, p0, p1, p2, p3: p0 + p1*x + p2* np.power(x, 2) + p3* np.power(x, 3) nl_params_log = [] stim_resp_log = [] for icell in range(n_cells): print(icell) center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1]) sta_cell = np.reshape(stas[icell, windx[0]: windx[1], windy[0]: windy[1], :], [-1, stas.shape[-1]]) stim_filter = np.zeros(stim_short.shape[0]) for idelay in range(30): stim_filter[idelay: ] += stim_cell[:T-idelay, :].dot(sta_cell[:, idelay]) # Normalize stim_filter stim_filter -= np.mean(stim_filter) stim_filter /= np.sqrt(np.var(stim_filter)) resp_cell = resp_short[:, icell] stim_nl = [] resp_nl = [] for ipercentile in range(3, 97, 1): lb = np.percentile(stim_filter, ipercentile-3) ub = np.percentile(stim_filter, ipercentile+3) tms = np.logical_and(stim_filter >= lb, stim_filter < ub) stim_nl += [np.mean(stim_filter[tms])] resp_nl += [np.mean(resp_cell[tms])] stim_nl = np.array(stim_nl) resp_nl = np.array(resp_nl) popt, pcov = scipy.optimize.curve_fit(f_nl, stim_nl, resp_nl, p0=[1, 0, 0, 0]) nl_params_log += [popt] stim_resp_log += [[stim_nl, resp_nl]] nl_params_log = np.array(nl_params_log) np_params_mean_1 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==1, :], 0) np_params_mean_2 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==2, :], 0) nl_params_dict = {'nl_params_log': nl_params_log, 'np_params_mean_1': np_params_mean_1, 'np_params_mean_2': np_params_mean_2, 'stim_resp_log': stim_resp_log} save_dict.update({'nl_params_dict': nl_params_dict}) ''' # Visualize Non-linearities for icell in range(n_cells): stim_in = np.arange(-3, 3, 0.1) fr = f_nl(stim_in, *nl_params_log[icell, :]) if np.squeeze(dat['cell_type'])[icell] == 1: c = 'r' else: c = 'k' plt.plot(stim_in, fr, c, alpha=0.2) fr = f_nl(stim_in, *np_params_mean_1) plt.plot(stim_in, fr, 'r--') fr = f_nl(stim_in, *np_params_mean_2) plt.plot(stim_in, fr, 'k--') ''' pickle.dump(save_dict, gfile.Open(os.path.join(FLAGS.save_folder , dat['piece']), 'w')) pickle.dump(stas_dict, gfile.Open(os.path.join(FLAGS.save_folder , 'stas' + dat['piece']), 'w'))
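# A standalone sketch of the spike-triggered average (STA) computed in the
# analysis above: for each time delay, stimulus frames are weighted by the
# delayed spike counts and normalized by the total spike count. Unlike the
# original loop, which restricts the computation to a small window around each
# cell's receptive-field center, this sketch averages over the full frame; the
# 80 x 40 stimulus size and 30-frame delay window match the values used here.
import numpy as np

def compute_sta(stimulus, spikes, n_delays=30):
  """stimulus: (T, dimx, dimy) array; spikes: (T,) spike counts of one cell."""
  T, dimx, dimy = stimulus.shape
  stim_flat = np.reshape(stimulus, [T, -1])
  sta = np.zeros((dimx, dimy, n_delays))
  for idelay in range(n_delays):
    # Weight stimulus frames by spikes occurring idelay frames later.
    weighted = spikes[idelay:].dot(stim_flat[:T - idelay, :])
    sta[:, :, idelay] = (np.reshape(weighted, [dimx, dimy]) /
                         np.sum(spikes[idelay:]))
  return sta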
def main(argv): print('\nCode started') np.random.seed(FLAGS.np_randseed) random.seed(FLAGS.randseed) ## Load data summary filename = FLAGS.data_location + 'data_details.mat' summary_file = gfile.Open(filename, 'r') data_summary = sio.loadmat(summary_file) cells = np.squeeze(data_summary['cells']) if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic': cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | ( cells == 3066) if FLAGS.model_id == 'poisson_full': cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool') n_cells = np.sum(cells_choose) tot_spks = np.squeeze(data_summary['tot_spks']) total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T tot_spks_chosen_cells = tot_spks[cells_choose] chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0, dtype='bool') print(np.shape(chosen_mask)) print(np.sum(chosen_mask)) stim_dim = np.sum(chosen_mask) print('\ndataset summary loaded') # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells # decide the number of subunits to fit n_su = FLAGS.ratio_SU * n_cells # saving details #short_filename = 'data_model=ASM_pop_bg' # short_filename = ('data_model=ASM_pop_batch_sz='+ str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + # '_step_sz'+ str(FLAGS.step_sz)+'_bg') # saving details if FLAGS.model_id == 'poisson': short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') if FLAGS.model_id == 'logistic': short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') if FLAGS.model_id == 'poisson_full': short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') parent_folder = FLAGS.save_location + FLAGS.folder_name + '/' FLAGS.save_location = parent_folder + short_filename + '/' print(gfile.IsDirectory(FLAGS.save_location)) print(FLAGS.save_location) save_filename = FLAGS.save_location + short_filename with tf.Session() as sess: # Learn population model! 
stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim') resp = tf.placeholder(tf.float32, name='resp') data_len = tf.placeholder(tf.float32, name='data_len') # variables if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full': w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.1 * np.random.rand(n_cells, 1, n_su), dtype='float32')) if FLAGS.model_id == 'logistic': w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32')) b_init = np.random.randn( n_cells ) #np.log((np.sum(response,0))/(response.shape[0]-np.sum(response,0))) b = tf.Variable(b_init, dtype='float32') # get relevant files file_list = gfile.ListDirectory(FLAGS.save_location) save_filename = FLAGS.save_location + short_filename print('\nLoading: ', save_filename) bin_files = [] meta_files = [] for file_n in file_list: if re.search(short_filename + '.', file_n): if re.search('.meta', file_n): meta_files += [file_n] else: bin_files += [file_n] #print(bin_files) print(len(meta_files), len(bin_files), len(file_list)) # get iteration numbers iterations = np.array([]) for file_name in bin_files: try: iterations = np.append( iterations, int(file_name.split('/')[-1].split('-')[-1])) except: print('Could not load filename: ' + file_name) iterations.sort() print(iterations) iter_plot = iterations[-1] print(int(iter_plot)) # load tensorflow variables saver_var = tf.train.Saver(tf.all_variables()) restore_file = save_filename + '-' + str(int(iter_plot)) saver_var.restore(sess, restore_file) a_eval = a.eval() print(np.exp(np.squeeze(a_eval))) #print(np.shape(a_eval)) # get 2D region to plot mask2D = np.reshape(chosen_mask, [40, 80]) nz_idx = np.nonzero(mask2D) np.shape(nz_idx) print(nz_idx) ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1]) xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1]) w_eval = w.eval() plt.figure() n_su = w_eval.shape[1] for isu in np.arange(n_su): xx = np.zeros((3200)) xx[chosen_mask] = w_eval[:, isu] fig = plt.subplot(np.ceil(np.sqrt(n_su)), np.ceil(np.sqrt(n_su)), isu + 1) plt.imshow(np.reshape(xx, [40, 80]), interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) #if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge': # plt.title(str(a_eval[isu, :])) #else: # plt.title(str(np.squeeze(np.exp(a_eval[:, 0, isu]))), fontsize=12) plt.suptitle('Iteration:' + str(int(iter_plot)) + ' batchSz:' + str(FLAGS.batchsz) + ' step size:' + str(FLAGS.step_sz), fontsize=18) plt.show() plt.draw()
def main(argv): #plt.ion() # interactive plotting window = FLAGS.window n_pix = (2 * window + 1)**2 dimx = np.floor(1 + ((40 - (2 * window + 1)) / FLAGS.stride)).astype('int') dimy = np.floor(1 + ((80 - (2 * window + 1)) / FLAGS.stride)).astype('int') nCells = 107 # load model # load filename print(FLAGS.model_id) with tf.Session() as sess: if FLAGS.model_id == 'relu': # lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0 short_filename = ('data_model=' + str(FLAGS.model_id) + '_lam_w=' + str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) + '_ratioSU=' + str(FLAGS.ratio_SU) + '_grid_spacing=' + str(FLAGS.su_grid_spacing) + '_normalized_bg') w = tf.Variable( np.array(np.random.randn(3200, 749), dtype='float32')) a = tf.Variable( np.array(np.random.randn(749, 107), dtype='float32')) if FLAGS.model_id == 'relu_window': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) # exp 5 a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_sfm': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_sfm_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.01 + 0.005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.02 + np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_a_support': short_filename = ('data_model=' 
+ str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.002 * np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'exp_window_a_support': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.002 * np.random.rand(dimx * dimy, nCells), dtype='float32')) parent_folder = FLAGS.save_location + FLAGS.folder_name + '/' FLAGS.save_location = parent_folder + short_filename + '/' # get relevant files file_list = gfile.ListDirectory(FLAGS.save_location) save_filename = FLAGS.save_location + short_filename print('\nLoading: ', save_filename) bin_files = [] meta_files = [] for file_n in file_list: if re.search(short_filename + '.', file_n): if re.search('.meta', file_n): meta_files += [file_n] else: bin_files += [file_n] #print(bin_files) print(len(meta_files), len(bin_files), len(file_list)) # get iteration numbers iterations = np.array([]) for file_name in bin_files: try: iterations = np.append( iterations, int(file_name.split('/')[-1].split('-')[-1])) except: print('Could not load filename: ' + file_name) iterations.sort() print(iterations) iter_plot = iterations[-1] print(int(iter_plot)) # load tensorflow variables saver_var = tf.train.Saver(tf.all_variables()) restore_file = save_filename + '-' + str(int(iter_plot)) saver_var.restore(sess, restore_file) # plot subunit - cell connections plt.figure() plt.cla() plt.imshow(a.eval(), cmap='gray', interpolation='nearest') print(np.shape(a.eval())) plt.title('Iteration: ' + str(int(iter_plot))) plt.show() plt.draw() # plot all subunits on 40x80 grid try: wts = w.eval() for isu in range(100): fig = plt.subplot(10, 10, isu + 1) plt.imshow(np.reshape(wts[:, isu], [40, 80]), interpolation='nearest', cmap='gray') plt.title('Iteration: ' + str(int(iter_plot))) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w full does not exist? ') # plot a few subunits - wmother + wdel try: wts = w.eval() print('wts shape:', np.shape(wts)) icnt = 1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]), (2 * window + 1, 2 * window + 1)), interpolation='nearest', cmap='gray') icnt = icnt + 1 fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.show() plt.draw() except: print('w does not exist?') # plot wmother try: w_mot = np.squeeze(w_mother.eval()) print(w_mot) plt.imshow(w_mot, interpolation='nearest', cmap='gray') plt.title('Mother subunit') plt.show() plt.draw() except: print('w mother does not exist') # plot wmother + wdel try: w_mot = np.squeeze(w_mother.eval()) w_del = np.squeeze(w_del.eval()) wts = np.array(np.random.randn(dimx, dimy, (2 * window + 1)**2)) for idimx in np.arange(dimx): print(idimx) for idimy in np.arange(dimy): wts[idimx, idimy, :] = np.ndarray.flatten(w_mot) + w_del[idimx, idimy, :] except: print('w mother + w delta do not exist? 
') ''' try: icnt=1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]), (2*window+1,2*window+1)), interpolation='nearest', cmap='gray') fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w mother + w delta plotting error? ') # plot wdel try: w_del = np.squeeze(w_del.eval()) icnt=1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow( np.reshape(w_del[idimx, idimy, :], (2*window+1,2*window+1)), interpolation='nearest', cmap='gray') icnt = icnt+1 fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w delta do not exist? ') plt.suptitle('Iteration: ' + str(int(iter_plot))) plt.show() plt.draw() ''' # select a cell, and show its subunits. #try: ## Load data summary, get mask filename = FLAGS.data_location + 'data_details.mat' summary_file = gfile.Open(filename, 'r') data_summary = sio.loadmat(summary_file) total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T stas = data_summary['stas'] print(np.shape(total_mask)) # a is 2D a_eval = a.eval() print(np.shape(a_eval)) # get softmax numpy if FLAGS.model_id == 'relu_window_mother_sfm' or FLAGS.model_id == 'relu_window_mother_sfm_exp': b = np.exp(a_eval) / np.sum(np.exp(a_eval), 0) else: b = a_eval plt.figure() plt.imshow(b, interpolation='nearest', cmap='gray') plt.show() plt.draw() # plot subunits for multiple cells. n_cells = 10 n_plots_max = 20 plt.figure() for icell_cnt, icell in enumerate(np.arange(n_cells)): mask2D = np.reshape(total_mask[icell, :], [40, 80]) nz_idx = np.nonzero(mask2D) np.shape(nz_idx) print(nz_idx) ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1]) xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1]) icnt = -1 a_thr = np.percentile(np.abs(b[:, icell]), 99.5) n_plots = np.sum(np.abs(b[:, icell]) > a_thr) nx = np.ceil(np.sqrt(n_plots)).astype('int') ny = np.ceil(np.sqrt(n_plots)).astype('int') ifig = 0 ww_sum = np.zeros((40, 80)) for idimx in np.arange(dimx): for idimy in np.arange(dimy): icnt = icnt + 1 if (np.abs(b[icnt, icell]) > a_thr): ifig = ifig + 1 fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + ifig + 2) ww = np.zeros((40, 80)) ww[idimx * FLAGS.stride:idimx * FLAGS.stride + (2 * window + 1), idimy * FLAGS.stride:idimy * FLAGS.stride + (2 * window + 1)] = b[icnt, icell] * (np.reshape( wts[idimx, idimy, :], (2 * window + 1, 2 * window + 1))) plt.imshow(ww, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) plt.title(b[icnt, icell]) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) ww_sum = ww_sum + ww fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + 2) plt.imshow(ww_sum, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('STA from model') fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + 1) plt.imshow(np.reshape(stas[:, icell], [40, 80]), interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('True STA') plt.show() plt.draw() #except: # print('a not 2D?') # using xlim and ylim, and plot the 'windows' which are relevant with their weights sq_flat = np.zeros((dimx, dimy)) icnt = 0 for idimx in np.arange(dimx): for idimy in np.arange(dimy): 
sq_flat[idimx, idimy] = icnt icnt = icnt + 1 n_cells = 1 n_plots_max = 10 plt.figure() for icell_cnt, icell in enumerate(np.array( [1, 2, 3, 4, 5])): #enumerate(np.arange(n_cells)): a_thr = np.percentile(np.abs(b[:, icell]), 99.5) mask2D = np.reshape(total_mask[icell, :], [40, 80]) nz_idx = np.nonzero(mask2D) np.shape(nz_idx) print(nz_idx) ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1]) xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1]) print(xlim, ylim) win_startx = np.ceil((xlim[0] - (2 * window + 1)) / FLAGS.stride) win_endx = np.floor((xlim[1] - 1) / FLAGS.stride) win_starty = np.ceil((ylim[0] - (2 * window + 1)) / FLAGS.stride) win_endy = np.floor((ylim[1] - 1) / FLAGS.stride) dimx_plot = win_endx - win_startx + 1 dimy_plot = win_endy - win_starty + 1 ww_sum = np.zeros((40, 80)) for irow, idimy in enumerate(np.arange(win_startx, win_endx + 1)): for icol, idimx in enumerate( np.arange(win_starty, win_endy + 1)): fig = plt.subplot(dimx_plot + 1, dimy_plot, (irow + 1) * dimy_plot + icol + 1) ww = np.zeros((40, 80)) ww[idimx * FLAGS.stride:idimx * FLAGS.stride + (2 * window + 1), idimy * FLAGS.stride:idimy * FLAGS.stride + (2 * window + 1)] = (np.reshape( wts[idimx, idimy, :], (2 * window + 1, 2 * window + 1))) plt.imshow(ww, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) if b[sq_flat[idimx, idimy], icell] > a_thr: plt.title(b[sq_flat[idimx, idimy], icell], fontsize=10, color='g') else: plt.title(b[sq_flat[idimx, idimy], icell], fontsize=10, color='r') fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) ww_sum = ww_sum + ww * b[sq_flat[idimx, idimy], icell] fig = plt.subplot(dimx_plot + 1, dimy_plot, 2) plt.imshow(ww_sum, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('STA from model') fig = plt.subplot(dimx_plot + 1, dimy_plot, 1) plt.imshow(np.reshape(stas[:, icell], [40, 80]), interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('True STA') plt.show() plt.draw()
def main(unused_argv=()): np.random.seed(23) tf.set_random_seed(1234) random.seed(50) # Load stimulus-response data. # Collect population response across retinas in the list 'responses'. # Stimulus for each retina is indicated by 'stim_id', # which is found in 'stimuli' dictionary. datasets = gfile.ListDirectory(FLAGS.src_dir) stimuli = {} responses = [] for icnt, idataset in enumerate(datasets): fullpath = os.path.join(FLAGS.src_dir, idataset) if gfile.IsDirectory(fullpath): key = 'stim_%d' % icnt op = data_util.get_stimulus_response( FLAGS.src_dir, idataset, key, boundary=FLAGS.valid_cells_boundary, if_get_stim=True) stimulus, resp, dimx, dimy, _ = op stimuli.update({key: stimulus}) responses += resp # Get training and testing partitions. # Generate partitions # The partitions for the taskid should be listed in partition_file. op = partitions.get_partitions(FLAGS.partition_file, FLAGS.taskid) training_datasets, testing_datasets = op with tf.Session() as sess: # Get stimulus-response embedding. if FLAGS.mode == 0: is_training = True if FLAGS.mode == 1: is_training = True if FLAGS.mode == 2: is_training = True print('NOTE: is_training = True in test') if FLAGS.mode == 3: is_training = True print('NOTE: is_training = True in test') sample_fcn = sample_datasets if (FLAGS.sr_model == 'convolutional_embedding'): embedding = sr_models.convolutional_embedding( FLAGS.sr_model, sess, is_training, dimx, dimy) if (FLAGS.sr_model == 'convolutional_embedding_expt' or FLAGS.sr_model == 'convolutional_embedding_margin_expt' or FLAGS.sr_model == 'convolutional_embedding_inner_product_expt' or FLAGS.sr_model == 'convolutional_embedding_gauss_expt' or FLAGS.sr_model == 'convolutional_embedding_kernel_expt'): embedding = sr_models_expt.convolutional_embedding_experimental( FLAGS.sr_model, sess, is_training, dimx, dimy) if FLAGS.sr_model == 'convolutional_autoembedder': embedding = sr_models_expt.convolutional_autoembedder( sess, is_training, dimx, dimy) if FLAGS.sr_model == 'convolutional_autoembedder_l2': embedding = sr_models_expt.convolutional_autoembedder( sess, is_training, dimx, dimy, loss='log_sum_exp') if FLAGS.sr_model == 'convolutional_encoder' or FLAGS.sr_model == 'convolutional_encoder_2': embedding = encoding_models_expt.convolutional_encoder( sess, is_training, dimx, dimy) if FLAGS.sr_model == 'convolutional_encoder_using_retina_id': model = encoding_models_expt.convolutional_encoder_using_retina_id embedding = model(sess, is_training, dimx, dimy, len(responses)) sample_fcn = sample_datasets_2 if (FLAGS.sr_model == 'residual') or (FLAGS.sr_model == 'residual_inner_product'): embedding = sr_models_expt.residual_experimental( FLAGS.sr_model, sess, is_training, dimx, dimy) if FLAGS.sr_model == 'lin_rank1' or FLAGS.sr_model == 'lin_rank1_blind': if ((len(training_datasets) != 1) and (training_datasets != testing_datasets)): raise ValueError('Identical training/testing data' ' (exactly 1) supported') n_cells = responses[training_datasets[0]]['responses'].shape[1] cell_locations = responses[training_datasets[0]]['map_cell_grid'] cell_masks = responses[training_datasets[0]]['mask_cells'] firing_rates = responses[training_datasets[0]]['mean_firing_rate'] cell_type = responses[training_datasets[0]]['cell_type'].squeeze() model_fn = sr_baseline_models.linear_rank1_models embedding = model_fn(FLAGS.sr_model, sess, dimx, dimy, n_cells, center_locations=cell_locations, cell_masks=cell_masks, firing_rates=firing_rates, cell_type=cell_type, time_window=30) # print model graph 
PrintModelAnalysis(tf.get_default_graph()) # Get filename, initialize model file_name = bookkeeping.get_filename(training_datasets, testing_datasets, FLAGS.beta, FLAGS.sr_model) tf.logging.info('Filename: %s' % file_name) saver_var, start_iter = bookkeeping.initialize_model( FLAGS.save_folder, file_name, sess) # Setup summary ops. # Save separate summary for each retina (both training/testing). summary_ops = [] for iret in np.arange(len(responses)): r_list = [] r1 = tf.summary.scalar('loss_%d' % iret, embedding.loss) r_list += [r1] if hasattr(embedding, 'accuracy_tf'): r2 = tf.summary.scalar('accuracy_%d' % iret, embedding.accuracy_tf) r_list += [r2] if FLAGS.sr_model == 'convolutional_autoembedder' or FLAGS.sr_model == 'convolutional_autoembedder_l2': r3 = tf.summary.scalar('loss_triplet_%d' % iret, embedding.loss_triplet) r4 = tf.summary.scalar('loss_stim_decode_from_resp_%d' % iret, embedding.loss_stim_decode_from_resp) r5 = tf.summary.scalar('loss_stim_decode_from_stim_%d' % iret, embedding.loss_stim_decode_from_stim) r6 = tf.summary.scalar('loss_resp_decode_from_resp_%d' % iret, embedding.loss_resp_decode_from_resp) r7 = tf.summary.scalar('loss_resp_decode_from_stim_%d' % iret, embedding.loss_resp_decode_from_stim) r_list += [r3, r4, r5, r6, r7] ''' chosen_stim = 2 bound = FLAGS.valid_cells_boundary r8 = tf.summary.image('stim_decode_from_stim_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.stim_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 3], 0), 3)) r9 = tf.summary.image('stim_decode_from_resp_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.stim_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 3], 0), 3)) r10 = tf.summary.image('resp_decode_from_stim_chann0_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3)) r11 = tf.summary.image('resp_decode_from_resp_chann0_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3)) r12 = tf.summary.image('resp_decode_from_stim_chann1_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_stim[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3)) r13 = tf.summary.image('resp_decode_from_resp_chann1_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.resp_decode_from_resp[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3)) r14 = tf.summary.image('resp_chann0_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.anchor_model.responses_embed_1[chosen_stim, bound:80-bound, bound:40-bound, 0], 0), 3)) r15 = tf.summary.image('resp_chann1_%d' % iret, tf.expand_dims(tf.expand_dims(embedding.anchor_model.responses_embed_1[chosen_stim, bound:80-bound, bound:40-bound, 1], 0), 3)) r_list += [r8, r9, r10, r11, r12, r13, r14, r15] ''' summary_ops += [tf.summary.merge(r_list)] # Setup summary writers. summary_writers = [] for loc in ['train', 'test']: summary_location = os.path.join(FLAGS.save_folder, file_name, 'summary_' + loc) summary_writer = tf.summary.FileWriter(summary_location, sess.graph) summary_writers += [summary_writer] # Separate tests for encoding or metric learning, # prosthesis usage or just neuroscience usage. 
if FLAGS.mode == 3: testing.test_encoding(training_datasets, testing_datasets, responses, stimuli, embedding, sess, file_name, sample_fcn) elif FLAGS.mode == 2: prosthesis.stimulate(embedding, sess, file_name, dimx, dimy) elif FLAGS.mode == 1: testing.test_metric(training_datasets, testing_datasets, responses, stimuli, embedding, sess, file_name) else: training.training(start_iter, sess, embedding, summary_writers, summary_ops, saver_var, training_datasets, testing_datasets, responses, stimuli, file_name, sample_fcn, summary_freq=500, save_freq=500)
def main(argv):
  # plt.ion()  # interactive plotting

  # Load model: build the filename from the model id and hyperparameters.
  print(FLAGS.model_id)
  print(FLAGS.folder_name)
  if FLAGS.model_id == 'relu':
    # lam_c(X) = sum_s(a_cs relu(k_s.x)), a_cs > 0
    short_filename = ('data_model=' + str(FLAGS.model_id) +
                      '_lam_w=' + str(FLAGS.lam_w) +
                      '_lam_a=' + str(FLAGS.lam_a) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_normalized_bg')
  if FLAGS.model_id == 'exp':
    short_filename = ('data_model3=' + str(FLAGS.model_id) +
                      '_bias_init=' + str(FLAGS.bias_init_scale) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_normalized_bg')
  if FLAGS.model_id == 'mel_re_pow2':
    short_filename = ('data_model3=' + str(FLAGS.model_id) +
                      '_lam_w=' + str(FLAGS.lam_w) +
                      '_lam_a=' + str(FLAGS.lam_a) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_normalized_bg')
  if FLAGS.model_id == 'relu_logistic':
    short_filename = ('data_model3=' + str(FLAGS.model_id) +
                      '_lam_w=' + str(FLAGS.lam_w) +
                      '_lam_a=' + str(FLAGS.lam_a) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_normalized_bg')
  if FLAGS.model_id == 'relu_proximal':
    short_filename = ('data_model3=' + str(FLAGS.model_id) +
                      '_lam_w=' + str(FLAGS.lam_w) +
                      '_lam_a=' + str(FLAGS.lam_a) +
                      '_eta_w=' + str(FLAGS.eta_w) +
                      '_eta_a=' + str(FLAGS.eta_a) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_proximal_bg')
  if FLAGS.model_id == 'relu_eg':
    short_filename = ('data_model3=' + str(FLAGS.model_id) +
                      '_lam_w=' + str(FLAGS.lam_w) +
                      '_eta_w=' + str(FLAGS.eta_w) +
                      '_eta_a=' + str(FLAGS.eta_a) +
                      '_ratioSU=' + str(FLAGS.ratio_SU) +
                      '_grid_spacing=' + str(FLAGS.su_grid_spacing) +
                      '_eg_bg')

  # Get relevant files.
  parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
  FLAGS.save_location = parent_folder + short_filename + '/'
  file_list = gfile.ListDirectory(FLAGS.save_location)
  save_filename = FLAGS.save_location + short_filename
  bin_files = []
  meta_files = []
  for file_n in file_list:
    if re.search(short_filename + '.', file_n):
      if re.search('.meta', file_n):
        meta_files += [file_n]
      else:
        bin_files += [file_n]
  print(len(meta_files), len(bin_files), len(file_list))

  # Get iteration numbers from the checkpoint suffix.
  iterations = np.array([])
  for file_name in bin_files:
    try:
      iterations = np.append(
          iterations, int(file_name.split('/')[-1].split('-')[-1]))
    except ValueError:
      print('Bad filename: ' + file_name)
  iterations.sort()
  print(iterations)
  iter_plot = iterations[-1]
  print(int(iter_plot))

  with tf.Session() as sess:
    # Load tensorflow variables.
    w = tf.Variable(np.array(np.random.randn(3200, 749), dtype='float32'))
    a = tf.Variable(np.array(np.random.randn(749, 107), dtype='float32'))
    saver_var = tf.train.Saver(tf.all_variables())
    restore_file = save_filename + '-' + str(int(iter_plot))
    saver_var.restore(sess, restore_file)

    # Plot subunit - cell connections.
    plt.figure()
    plt.cla()
    plt.imshow(a.eval(), cmap='gray', interpolation='nearest')
    plt.title('Iteration: ' + str(int(iter_plot)))
    plt.show()
    plt.draw()

    # Plot a few subunits.
    wts = w.eval()
    for isu in range(100):
      fig = plt.subplot(10, 10, isu + 1)
      plt.imshow(np.reshape(wts[:, isu], [40, 80]),
                 interpolation='nearest', cmap='gray')
      plt.title('Iteration: ' + str(int(iter_plot)))
      fig.axes.get_xaxis().set_visible(False)
      fig.axes.get_yaxis().set_visible(False)
    plt.show()
    plt.draw()
def main(argv): np.random.seed(23) # Figure out dictionary path. dict_list = gfile.ListDirectory(FLAGS.src_dict) dict_path = os.path.join(FLAGS.src_dict, dict_list[FLAGS.taskid]) # Load the dictionary if dict_path[-3:] == 'pkl': data = pickle.load(gfile.Open(dict_path, 'r')) if dict_path[-3:] == 'mat': data = sio.loadmat(gfile.Open(dict_path, 'r')) #FLAGS.save_dir = '/home/bhaishahster/stimulation_algos/dictionaries/' + dict_list[FLAGS.taskid][:-4] FLAGS.save_dir = FLAGS.save_dir + dict_list[FLAGS.taskid][:-4] if not gfile.Exists(FLAGS.save_dir): gfile.MkDir(FLAGS.save_dir) # S_collection = data['S'] # Target A = data['A'] # Decoder D = data['D'].T # Dictionary # clean dictionary thr_list = np.arange(0, 1, 0.01) dict_val = [] for thr in thr_list: dict_val += [np.sum(np.sum(D.T > thr, 1) != 0)] plt.ion() plt.figure() plt.plot(thr_list, dict_val) plt.xlabel('Threshold') plt.ylabel( 'Number of dictionary elements with \n atleast one element above threshold' ) plt.title('Please choose threshold') thr_use = float(input('What threshold to use?')) plt.axvline(thr_use) plt.title('Using threshold: %.5f' % thr_use) dict_valid = np.sum(D.T > thr_use, 1) > 0 D = D[:, dict_valid] D = np.append(D, np.zeros((D.shape[0], 1)), 1) print( 'Appending a "dummy" dictionary element that does not activate any cell' ) # Vary stimulus resolution for itarget in range(20): n_targets = 1 for stix_resolution in [32, 64, 16, 8]: # Get the target x_dim = int(640 / stix_resolution) y_dim = int(320 / stix_resolution) targets = (np.random.rand(y_dim, x_dim, n_targets) < 0.5) - 0.5 upscale = stix_resolution / 8 targets = np.repeat(np.repeat(targets, upscale, axis=0), upscale, axis=1) targets = np.reshape(targets, [-1, targets.shape[-1]]) S_actual = targets[:, 0] # Remove null component of A from S S = A.dot(np.linalg.pinv(A).dot(S_actual)) # Run Greedy first to initialize x_greedy = greedy_stimulation(S, A, D, max_stims=FLAGS.t_max * FLAGS.delta, file_suffix='%d_%d' % (stix_resolution, itarget), save=True, save_dir=FLAGS.save_dir) # Load greedy output from previous run #data_greedy = pickle.load(gfile.Open('/home/bhaishahster/greedy_2000_32_0.pkl', 'r')) #x_greedy = data_greedy['x_chosen'] # Plan for multiple time points x_init = np.zeros((x_greedy.shape[0], FLAGS.t_max)) for it in range(FLAGS.t_max): print((it + 1) * FLAGS.delta - 1) x_init[:, it] = x_greedy[:, (it + 1) * FLAGS.delta - 1] #simultaneous_planning(S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate, # delta=FLAGS.delta, normalization=FLAGS.normalization, # file_suffix='%d_%d_normal' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir) from IPython import embed embed() simultaneous_planning_interleaved_discretization( S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate, delta=FLAGS.delta, normalization=FLAGS.normalization, file_suffix='%d_%d_pgd' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=np.inf, steps_max=500 * 20 - 1) # Interleaved discretization. simultaneous_planning_interleaved_discretization( S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate, delta=FLAGS.delta, normalization=FLAGS.normalization, file_suffix='%d_%d_pgd_od' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=500, steps_max=500 * 20 - 1) # Exponential weighing. 
simultaneous_planning_interleaved_discretization_exp_gradient( S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate, delta=FLAGS.delta, normalization=FLAGS.normalization, file_suffix='%d_%d_ew' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=np.inf, steps_max=500 * 20 - 1) # Exponential weighing with interleaved discretization. simultaneous_planning_interleaved_discretization_exp_gradient( S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate, delta=FLAGS.delta, normalization=FLAGS.normalization, file_suffix='%d_%d_ew_od' % (stix_resolution, itarget), x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=500, steps_max=500 * 20 - 1) ''' # Plot results data_fractional = pickle.load(gfile.Open('/home/bhaishahster/2012-09-24-0_SAD_fr1/pgd_20_5.000000_C_100_32_0_normal.pkl', 'r')) plt.ion() start = 0 end = 20 error_target_frac = np.linalg.norm(S - data_fractional['S']) #error_target_greedy = np.linalg.norm(S - data_greedy['S']) print('Did target change? \n Err from fractional: %.3f ' % error_target_frac) #'\n Err from greedy %.3f' % (error_target_frac, # error_target_greedy)) # from IPython import embed; embed() normalize = np.sum(data_fractional['S'] ** 2) plt.ion() plt.plot(np.arange(120, len(data_fractional['f_log'])), data_fractional['f_log'][120:] / normalize, 'b') plt.axhline(data_fractional['errors'][5:].mean() / normalize, color='m') plt.axhline(data_fractional['errors_ht_discrete'][5:].mean() / normalize, color='y') plt.axhline(data_fractional['errors_rr_discrete'][5:].mean() / normalize, color='r') plt.axhline(data_greedy['error_curve'][120:].mean() / normalize, color='k') plt.legend(['fraction_curve', 'fractional error', 'HT', 'RR', 'Greedy']) plt.pause(1.0) ''' from IPython import embed embed()
def subunit_discriminability(dataset_dict, stimuli, responses, sr_graph, num_examples=1000): ## compute distances between s-r pairs - pos and neg. ## negative_stim - if the negative is a stimulus or a response if num_examples % 100 != 0: raise ValueError('Only supports examples which are multiples of 100.') subunit_fit_loc = '/home/bhaishahster/stim-resp_collection_big_wn_retina_subunit_properties_train' subunits_datasets = gfile.ListDirectory(subunit_fit_loc) save_dict = {} datasets_log = {} for dat_key, datasets in dataset_dict.items(): distances_log = {} distances_retina_sr_log = [] distances_retina_rr_log = [] for iretina in range(len(datasets)): # Find the relevant subunit fit piece = responses[iretina]['piece'] matched_dataset = [ ifit for ifit in subunits_datasets if piece[:12] == ifit[:12] ] if matched_dataset == []: raise ValueError('Could not find subunit fit') subunit_fit_path = os.path.join(subunit_fit_loc, matched_dataset[0]) # Get predicted spikes. dat_resp_su = pickle.load( gfile.Open( os.path.join(subunit_fit_path, 'response_prediction.pkl'), 'r')) resp_su = dat_resp_su[ 'resp_su'] # it has non-rejected cells as well. # Remove some cells. select_cells = [ icell for icell in range(resp_su.shape[2]) if dat_resp_su['cell_ids'][icell] in responses[iretina] ['cellID_list'].squeeze() ] select_cells = np.array(select_cells) resp_su = resp_su[:, :, select_cells].astype(np.float32) # Get stimulus stimulus = stimuli[responses[iretina]['stimulus_key']] stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :] responses_recorded_test = responses[iretina]['responses'][ FLAGS.test_min:FLAGS.test_max, :] # Sample stimuli and responses. random_times = np.random.randint(40, stimulus_test.shape[0], num_examples) batch_size = 100 # Recorded response - predicted response distances. distances_retina = np.zeros((num_examples, 10)) + np.nan for Nsub in range(1, 11): for ibatch in range( np.floor(num_examples / batch_size).astype(np.int)): # construct stimulus tensor. stim_history = 30 resp_pred_batch = np.zeros((batch_size, resp_su.shape[2])) resp_rec_batch = np.zeros((batch_size, resp_su.shape[2])) for isample in range(batch_size): itime = random_times[batch_size * ibatch + isample] resp_pred_batch[isample, :] = resp_su[Nsub - 1, itime, :] resp_rec_batch[isample, :] = responses_recorded_test[ itime, :] # Embed predicted responses feed_dict = make_feed_dict(sr_graph, responses[iretina], responses=resp_pred_batch) embed_predicted = sr_graph.sess.run( sr_graph.anchor_model.responses_embed, feed_dict=feed_dict) # Embed recorded responses feed_dict = make_feed_dict(sr_graph, responses[iretina], responses=resp_rec_batch) embed_recorded = sr_graph.sess.run( sr_graph.anchor_model.responses_embed, feed_dict=feed_dict) dd = sr_graph.sess.run(sr_graph.distances_arbitrary, feed_dict={ sr_graph.arbitrary_embedding_1: embed_predicted, sr_graph.arbitrary_embedding_2: embed_recorded }) distances_retina[batch_size * ibatch:batch_size * (ibatch + 1), Nsub - 1] = dd print(iretina, Nsub, ibatch) distances_retina_rr_log += [distances_retina] # Stimulus - predicted response distances. distances_retina = np.zeros((num_examples, 10)) + np.nan for Nsub in range(1, 11): for ibatch in range( np.floor(num_examples / batch_size).astype(np.int)): # construct stimulus tensor. 
stim_history = 30 stim_batch = np.zeros( (batch_size, stimulus_test.shape[1], stimulus_test.shape[2], stim_history)) resp_batch = np.zeros((batch_size, resp_su.shape[2])) for isample in range(batch_size): itime = random_times[batch_size * ibatch + isample] stim_batch[isample, :, :, :] = np.transpose( stimulus_test[itime:itime - stim_history:-1, :, :], [1, 2, 0]) resp_batch[isample, :] = resp_su[Nsub - 1, itime, :] feed_dict = make_feed_dict(sr_graph, responses[iretina], resp_batch, stim_batch) # Get distances d_pos = sr_graph.sess.run(sr_graph.d_s_r_pos, feed_dict=feed_dict) distances_retina[batch_size * ibatch:batch_size * (ibatch + 1), Nsub - 1] = d_pos print(iretina, Nsub, ibatch) distances_retina_sr_log += [distances_retina] distances_log.update({'rr': distances_retina_rr_log}) distances_log.update({'sr': distances_retina_sr_log}) datasets_log.update({dat_key: distances_log}) save_dict.update({ 'datasets_log': datasets_log, 'dataset_dict': dataset_dict }) return save_dict
def response_transformation_increase_nl(stimuli, responses, sr_graph, time_start_list, time_len=100, alpha_list=[1.5, 1.25, 0.8, 0.6]): # 1. Take an LN model and increase non-linearity. # How do the points move in response space? # Load LN models ln_save_folder = '/home/bhaishahster/stim-resp_collection_ln_model_exp' files = gfile.ListDirectory(ln_save_folder) ln_models = [] for ifile in files: print(ifile) ln_models += [ pickle.load(gfile.Open(os.path.join(ln_save_folder, ifile), 'r')) ] t_start_dict = {} t_min = FLAGS.test_min t_max = FLAGS.test_max for time_start in time_start_list: print('Start time %d' % time_start) retina_log = [] for iretina_test in range(3, len(responses)): print('Retina: %d' % iretina_test) piece_id = responses[iretina_test]['piece'] # find piece in ln_models matched_ln_model = [ ifile for ifile in range(len(files)) if files[ifile][:12] == piece_id[:12] ] if len(matched_ln_model) == 0: print('LN model not found') continue if len(matched_ln_model) > 1: print('More than 1 LN model found') # Sample a sequence of stimuli and predict spikes iresp = responses[iretina_test] iln_model = ln_models[matched_ln_model[0]] stimulus_test = stimuli[iresp['stimulus_key']] stim_sample = stimulus_test[time_start:time_start + time_len, :, :] spikes, lam_np = analysis_utils.predict_responses_ln( stim_sample, iln_model['k'], iln_model['b'], iln_model['ttf'], n_trials=1) spikes_log = np.copy(spikes[0, :, :]) alpha_log = np.ones(time_len) # Increase nonlinearity, normalize firing rate and embed. for alpha in alpha_list: _, lam_np_alpha = analysis_utils.predict_responses_ln( stim_sample, alpha * iln_model['k'], alpha * iln_model['b'], iln_model['ttf'], n_trials=1) correction_firing_rate = np.mean(lam_np) / np.mean( lam_np_alpha) correction_b = np.log(correction_firing_rate) spikes_corrected, lam_np_corrected = analysis_utils.predict_responses_ln( stim_sample, alpha * iln_model['k'], alpha * iln_model['b'] + correction_b, iln_model['ttf'], n_trials=1) print(alpha, np.mean(lam_np), np.mean(lam_np_alpha), np.mean(lam_np_corrected)) spikes_log = np.append(spikes_log, spikes_corrected[0, :, :], axis=0) alpha_log = np.append(alpha_log, alpha * np.ones(time_len), axis=0) # plt.figure() # analysis_utils.plot_raster(spikes_corrected[:, :, 23]) # plt.title(alpha) # Embed responses try: resp_trans = np.expand_dims( spikes_log[:, iresp['valid_cells']], 2) feed_dict = { sr_graph.anchor_model.map_cell_grid_tf: iresp['map_cell_grid'], sr_graph.anchor_model.cell_types_tf: iresp['ctype_1hot'], sr_graph.anchor_model.mean_fr_tf: iresp['mean_firing_rate'], sr_graph.anchor_model.responses_tf: resp_trans } if hasattr(sr_graph.anchor_model, 'dist_nn'): dist_nn = np.array([ iresp['dist_nn_cell_type'][1], iresp['dist_nn_cell_type'][2] ]).astype(np.float32) feed_dict.update({ sr_graph.anchor_model.dist_nn: dist_nn, sr_graph.neg_model.dist_nn: dist_nn }) rr = sr_graph.sess.run(sr_graph.anchor_model.responses_embed, feed_dict=feed_dict) retina_log += [{ 'spikes_log': spikes_log, 'alpha_log': alpha_log, 'resp_embed': rr, 'piece': piece_id }] except: print('Error! ') retina_log += [np.nan] pass t_start_dict.update({time_start: retina_log}) return t_start_dict
def main(unused_argv=()): ## copy data locally dst = FLAGS.tmp_dir print('Starting Copy') if not gfile.IsDirectory(dst): gfile.MkDir(dst) files = gfile.ListDirectory(FLAGS.src_dir) for ifile in files: ffile = os.path.join(dst, ifile) if not gfile.Exists(ffile): gfile.Copy(os.path.join(FLAGS.src_dir, ifile), ffile) print('Copied %s' % os.path.join(FLAGS.src_dir, ifile)) else: print('File exists %s' % ffile) print('File copied to destination') ## load data # load stimulus data = h5py.File(os.path.join(dst, 'stimulus.mat')) stimulus = np.array(data.get('stimulus')) - 0.5 # load responses from multiple retina datasets_list = os.path.join(dst, 'datasets.txt') datasets = open(datasets_list, "r").read() training_datasets = [line for line in datasets.splitlines()] responses = [] for idata in training_datasets: print(idata) data_file = os.path.join(dst, idata) data = sio.loadmat(data_file) responses += [data] print(np.max(data['centers'], 0)) # generate additional features for responses num_cell_types = 2 dimx = 80 dimy = 40 for iresp in responses: # remove cells which are outside 80x40 window. process_dataset(iresp, dimx, dimy, num_cell_types) ## generate graph - if FLAGS.is_test == 0: is_training = True if FLAGS.is_test == 1: is_training = True # False with tf.Session() as sess: ## Make graph # embed stimulus. time_window = 30 stimx = stimulus.shape[1] stimy = stimulus.shape[2] stim_tf = tf.placeholder(tf.float32, shape=[None, stimx, stimy, time_window]) # batch x X x Y x time_window batch_norm = FLAGS.batch_norm stim_embed = embed_stimulus(FLAGS.stim_layers.split(','), batch_norm, stim_tf, is_training, reuse_variables=False) ''' ttf_tf = tf.Variable(np.ones(time_window).astype(np.float32)/10, name='stim_ttf') filt = tf.expand_dims(tf.expand_dims(tf.expand_dims(ttf_tf, 0), 0), 3) stim_time_filt = tf.nn.conv2d(stim_tf, filt, strides=[1, 1, 1, 1], padding='SAME') # batch x X x Y x 1 ilayer = 0 stim_time_filt = slim.conv2d(stim_time_filt, 1, [3, 3], stride=1, scope='stim_layer_wt_%d' % ilayer, reuse=False, normalizer_fn=slim.batch_norm, activation_fn=tf.nn.softplus, normalizer_params={'is_training': is_training}, padding='SAME') ''' # embed responses. num_cell_types = 2 layers = FLAGS.resp_layers # format: window x filters x stride .. 
NOTE: final filters=1, stride =1 throughout batch_norm = FLAGS.batch_norm time_window = 1 anchor_model = conv.ConvolutionalProsthesisScore(sess, time_window=1, layers=layers, batch_norm=batch_norm, is_training=is_training, reuse_variables=False, num_cell_types=2, dimx=dimx, dimy=dimy) neg_model = conv.ConvolutionalProsthesisScore(sess, time_window=1, layers=layers, batch_norm=batch_norm, is_training=is_training, reuse_variables=True, num_cell_types=2, dimx=dimx, dimy=dimy) d_s_r_pos = tf.reduce_sum((stim_embed - anchor_model.responses_embed)**2, [1, 2, 3]) # batch d_pairwise_s_rneg = tf.reduce_sum((tf.expand_dims(stim_embed, 1) - tf.expand_dims(neg_model.responses_embed, 0))**2, [2, 3, 4]) # batch x batch_neg beta = 10 # if FLAGS.save_suffix == 'lr=0.001': loss = tf.reduce_sum(beta * tf.reduce_logsumexp(tf.expand_dims(d_s_r_pos / beta, 1) - d_pairwise_s_rneg / beta, 1), 0) # else : # loss = tf.reduce_sum(tf.nn.softplus(1 + tf.expand_dims(d_s_r_pos, 1) - d_pairwise_s_rneg)) accuracy_tf = tf.reduce_mean(tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg)) lr = 0.001 train_op = tf.train.AdagradOptimizer(lr).minimize(loss) # set up training and testing data training_datasets_all = [1, 2, 3, 5, 6, 7, 9, 10, 11, 13, 14, 15] testing_datasets = [0, 4, 8, 12] print('Testing datasets', testing_datasets) n_training_datasets_log = [1, 3, 5, 7, 9, 11, 12] if (np.floor(FLAGS.taskid / 4)).astype(np.int) < len(n_training_datasets_log): # do randomly sampled training data for 0<= FLAGS.taskid < 28 prng = RandomState(23) n_training_datasets = n_training_datasets_log[(np.floor(FLAGS.taskid / 4)).astype(np.int)] for _ in range(10*FLAGS.taskid): print(prng.choice(training_datasets_all, n_training_datasets, replace=False)) training_datasets = prng.choice(training_datasets_all, n_training_datasets, replace=False) # training_datasets = [i for i in range(7) if i< 7-FLAGS.taskid] #[0, 1, 2, 3, 4, 5] else: # do 1 training data, chosen in order for FLAGS.taskid >= 28 datasets_all = np.arange(16) training_datasets = [datasets_all[FLAGS.taskid % (4 * len(n_training_datasets_log))]] print('Task ID %d' % FLAGS.taskid) print('Training datasets', training_datasets) # Initialize stuff. 
  # Initialize model: build the checkpoint file name and restore if available.
  file_name = ('end_to_end_stim_%s_resp_%s_beta_%d_taskid_%d'
               '_training_%s_testing_%s_%s' % (FLAGS.stim_layers,
                                               FLAGS.resp_layers, beta,
                                               FLAGS.taskid,
                                               str(training_datasets)[1: -1],
                                               str(testing_datasets)[1: -1],
                                               FLAGS.save_suffix))
  saver_var, start_iter = initialize_model(FLAGS.save_folder, file_name, sess)

  # Print model graph.
  PrintModelAnalysis(tf.get_default_graph())

  # Add summary ops.
  retina_number = tf.placeholder(tf.int16, name='input_retina')

  summary_ops = []
  for iret in np.arange(len(responses)):
    print(iret)
    r1 = tf.summary.scalar('loss_%d' % iret, loss)
    r2 = tf.summary.scalar('accuracy_%d' % iret, accuracy_tf)
    summary_ops += [tf.summary.merge([r1, r2])]

  # tf.summary.scalar('loss', loss)
  # tf.summary.scalar('accuracy', accuracy_tf)
  # summary_op = tf.summary.merge_all()

  # Set up summary writers.
  summary_writers = []
  for loc in ['train', 'test']:
    summary_location = os.path.join(FLAGS.save_folder, file_name,
                                    'summary_' + loc)
    summary_writer = tf.summary.FileWriter(summary_location, sess.graph)
    summary_writers += [summary_writer]

  # Start training.
  batch_neg_train_sz = 100
  batch_train_sz = 100

  def batch(dataset_id):
    # Sample a training batch (stimulus, positive responses, negative
    # responses) from one retina and build the feed dict for both model copies.
    batch_train = get_batch(stimulus, responses[dataset_id]['responses'],
                            batch_size=batch_train_sz,
                            batch_neg_resp=batch_neg_train_sz,
                            stim_history=30, min_window=10)
    stim_batch, resp_batch, resp_batch_neg = batch_train
    feed_dict = {stim_tf: stim_batch,
                 anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                 neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                 anchor_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                 anchor_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                 anchor_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],

                 neg_model.map_cell_grid_tf: responses[dataset_id]['map_cell_grid'],
                 neg_model.cell_types_tf: responses[dataset_id]['ctype_1hot'],
                 neg_model.mean_fr_tf: responses[dataset_id]['mean_firing_rate'],

                 retina_number: dataset_id}
    return feed_dict

  def batch_few_cells(responses):
    # Same as batch(), but takes a response dictionary directly
    # (e.g., one restricted to a subset of cells).
    batch_train = get_batch(stimulus, responses['responses'],
                            batch_size=batch_train_sz,
                            batch_neg_resp=batch_neg_train_sz,
                            stim_history=30, min_window=10)
    stim_batch, resp_batch, resp_batch_neg = batch_train
    feed_dict = {stim_tf: stim_batch,
                 anchor_model.responses_tf: np.expand_dims(resp_batch, 2),
                 neg_model.responses_tf: np.expand_dims(resp_batch_neg, 2),

                 anchor_model.map_cell_grid_tf: responses['map_cell_grid'],
                 anchor_model.cell_types_tf: responses['ctype_1hot'],
                 anchor_model.mean_fr_tf: responses['mean_firing_rate'],

                 neg_model.map_cell_grid_tf: responses['map_cell_grid'],
                 neg_model.cell_types_tf: responses['ctype_1hot'],
                 neg_model.mean_fr_tf: responses['mean_firing_rate'],
                 }
    return feed_dict

  if FLAGS.is_test == 1:
    print('Testing')
    save_dict = {}

    from IPython import embed; embed()

    ## Estimate one, fix others
    '''
    grad_resp = tf.gradients(d_s_r_pos, anchor_model.responses_tf)

    t_start = 1000
    t_len = 100
    stim_history = 30
    stim_batch = np.zeros((t_len, stimulus.shape[1], stimulus.shape[2],
                           stim_history))
    for isample, itime in enumerate(np.arange(t_start, t_start + t_len)):
      stim_batch[isample, :, :, :] = np.transpose(
          stimulus[itime: itime-stim_history:-1, :, :], [1, 2, 0])

    iretina = testing_datasets[0]
    resp_batch = np.expand_dims(
        np.random.rand(t_len, responses[iretina]['responses'].shape[1]), 2)

    step_sz = 0.01
    eps = 1e-2
    dist_prev = np.inf
    for iiter in range(10000):
      feed_dict = {stim_tf: stim_batch,
                   anchor_model.map_cell_grid_tf: responses[iretina]['map_cell_grid'],
                   anchor_model.cell_types_tf: responses[iretina]['ctype_1hot'],
                   anchor_model.mean_fr_tf: responses[iretina]['mean_firing_rate'],
                   anchor_model.responses_tf: resp_batch}
      dist_np, resp_grad_np = sess.run([d_s_r_pos, grad_resp],
                                       feed_dict=feed_dict)
      if np.sum(np.abs(dist_prev - dist_np)) < eps:
        break
      print(np.sum(dist_np), np.sum(np.abs(dist_prev - dist_np)))
      dist_prev = dist_np
      resp_batch = resp_batch - step_sz * resp_grad_np[0]

    resp_batch = resp_batch.squeeze()
    '''
    # from IPython import embed; embed()

    ## Compute distances between s-r pairs for a small number of cells.
    test_retina = []
    for iretina in range(len(testing_datasets)):
      dataset_id = testing_datasets[iretina]

      num_cells_total = responses[dataset_id]['responses'].shape[1]
      dataset_center = responses[dataset_id]['centers'].mean(0)
      # Euclidean distance of each cell's center from the dataset center.
      dataset_cell_distances = np.sqrt(np.sum(
          (responses[dataset_id]['centers'] - dataset_center)**2, 1))
      order_cells = np.argsort(dataset_cell_distances)

      test_sr_few_cells = {}
      for num_cells_prc in [5, 10, 20, 30, 50, 100]:
        num_cells = np.percentile(np.arange(num_cells_total),
                                  num_cells_prc).astype(np.int)
        choose_cells = order_cells[:num_cells]

        responses_few_cells = {'responses': responses[dataset_id]['responses'][:, choose_cells],
                               'map_cell_grid': responses[dataset_id]['map_cell_grid'][:, :, choose_cells],
                               'ctype_1hot': responses[dataset_id]['ctype_1hot'][choose_cells, :],
                               'mean_firing_rate': responses[dataset_id]['mean_firing_rate'][choose_cells]}

        # Get batches and accumulate positive / negative distances.
        d_pos_log = np.array([])
        d_neg_log = np.array([])
        for test_iter in range(1000):
          print(iretina, num_cells_prc, test_iter)
          feed_dict = batch_few_cells(responses_few_cells)
          d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg],
                                  feed_dict=feed_dict)
          d_neg = np.diag(d_neg)  # np.mean(d_neg, 1)
          d_pos_log = np.append(d_pos_log, d_pos)
          d_neg_log = np.append(d_neg_log, d_neg)

        # ROC (defined/imported elsewhere; an illustrative sketch appears
        # after this function).
        precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log,
                                                                  d_neg_log)
        print(np.sum(d_pos_log > d_neg_log))
        print(np.sum(d_pos_log < d_neg_log))

        test_sr = {'precision': precision_log, 'recall': recall_log,
                   'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                   'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log,
                   'num_cells': num_cells}
        test_sr_few_cells.update({'num_cells_prc_%d' % num_cells_prc: test_sr})

      test_retina += [test_sr_few_cells]

    save_dict.update({'few_cell_analysis': test_retina})

    ## Compute distances between s-r pairs - positives and negatives.
    test_retina = []
    for iretina in range(len(testing_datasets)):
      # stim-resp log
      d_pos_log = np.array([])
      d_neg_log = np.array([])
      for test_iter in range(1000):
        print(test_iter)
        feed_dict = batch(testing_datasets[iretina])
        d_pos, d_neg = sess.run([d_s_r_pos, d_pairwise_s_rneg],
                                feed_dict=feed_dict)
        d_neg = np.diag(d_neg)  # np.mean(d_neg, 1)
        d_pos_log = np.append(d_pos_log, d_pos)
        d_neg_log = np.append(d_neg_log, d_neg)

      precision_log, recall_log, F1_log, FPR_log, TPR_log = ROC(d_pos_log,
                                                                d_neg_log)
      print(np.sum(d_pos_log > d_neg_log))
      print(np.sum(d_pos_log < d_neg_log))

      test_sr = {'precision': precision_log, 'recall': recall_log,
                 'F1': F1_log, 'FPR': FPR_log, 'TPR': TPR_log,
                 'd_pos_log': d_pos_log, 'd_neg_log': d_neg_log}
      test_retina += [test_sr]

    save_dict.update({'test_sr': test_retina})

    ## ROC curves of responses from repeats - dataset 1
    repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
    repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
    repeats_data['cell_type'] = repeats_data['cell_type'].T

    # Process repeats data.
    process_dataset(repeats_data, dimx, dimy, num_cell_types)

    # Analyse and store the result.
    test_reps = analyse_response_repeats(repeats_data, anchor_model, neg_model,
                                         sess)
    save_dict.update({'test_reps_2015-09-23-7': test_reps})

    ## ROC curves of responses from repeats - dataset 2
    repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
    repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
    process_dataset(repeats_data, dimx, dimy, num_cell_types)

    # Analyse and store the result.
    '''
    test_clustering = analyse_response_repeats_all_trials(repeats_data,
                                                          anchor_model,
                                                          neg_model, sess)
    save_dict.update({'test_reps_2005_08_03_0': test_clustering})
    '''

    # Get model params.
    save_dict.update({'model_pars': sess.run(tf.trainable_variables())})

    save_analysis_filename = os.path.join(FLAGS.save_folder,
                                          file_name + '_analysis.pkl')
    pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
    print(save_analysis_filename)
    return

  # Main training loop.
  test_iiter = 0
  for iiter in range(start_iter, FLAGS.max_iter):  # TODO(bhaishahster): add FLAGS.max_iter

    # Get a new batch:
    # stim_tf, anchor_model.responses_tf, neg_model.responses_tf

    # Training step.
    train_dataset = training_datasets[iiter % len(training_datasets)]
    feed_dict_train = batch(train_dataset)
    _, loss_np_train = sess.run([train_op, loss], feed_dict=feed_dict_train)
    print(train_dataset, loss_np_train)

    # Write summaries.
    if iiter % 10 == 0:
      # Write train summary.
      test_iiter = test_iiter + 1
      train_dataset = training_datasets[test_iiter % len(training_datasets)]
      feed_dict_train = batch(train_dataset)
      summary_train = sess.run(summary_ops[train_dataset],
                               feed_dict=feed_dict_train)
      summary_writers[0].add_summary(summary_train, iiter)

      # Write test summary.
      test_dataset = testing_datasets[test_iiter % len(testing_datasets)]
      feed_dict_test = batch(test_dataset)
      l_test, summary_test = sess.run([loss, summary_ops[test_dataset]],
                                      feed_dict=feed_dict_test)
      summary_writers[1].add_summary(summary_test, iiter)
      print('Test retina: %d, loss: %.3f' % (test_dataset, l_test))

    # Save model.
    if iiter % 10 == 0:
      save_model(saver_var, FLAGS.save_folder, file_name, sess, iiter)
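
# ------------------------------------------------------------------------------
# Illustrative sketch, not the codebase's implementation: the evaluation code
# above calls a ROC(d_pos_log, d_neg_log) helper that is defined or imported
# elsewhere. The hypothetical _roc_sketch below shows one plausible way such
# precision / recall / F1 / FPR / TPR curves could be computed, by sweeping a
# distance threshold and treating "distance below threshold" as a predicted
# match. It is provided for reference only and is never called.
def _roc_sketch(d_pos_log, d_neg_log, num_thresholds=100):
  thresholds = np.linspace(min(d_pos_log.min(), d_neg_log.min()),
                           max(d_pos_log.max(), d_neg_log.max()),
                           num_thresholds)
  precision, recall, f1, fpr, tpr = [], [], [], [], []
  for thr in thresholds:
    tp = float(np.sum(d_pos_log <= thr))  # true pairs predicted as matches
    fn = float(np.sum(d_pos_log > thr))   # true pairs predicted as non-matches
    fp = float(np.sum(d_neg_log <= thr))  # mismatched pairs predicted as matches
    tn = float(np.sum(d_neg_log > thr))   # mismatched pairs predicted as non-matches
    prec = tp / max(tp + fp, 1.0)
    rec = tp / max(tp + fn, 1.0)
    precision += [prec]
    recall += [rec]
    f1 += [2 * prec * rec / max(prec + rec, 1e-8)]
    fpr += [fp / max(fp + tn, 1.0)]
    tpr += [rec]
  return (np.array(precision), np.array(recall), np.array(f1),
          np.array(fpr), np.array(tpr))
# ------------------------------------------------------------------------------
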
def main(unused_argv=()):

  # np.random.seed(23)
  # tf.set_random_seed(1234)
  # random.seed(50)

  # 1. Load stimulus-response data.
  # Collect population response across retinas in the list 'responses'.
  # Stimulus for each retina is indicated by 'stim_id',
  # which is found in 'stimuli' dictionary.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  stimuli = {}
  responses = []
  for icnt, idataset in enumerate(datasets):
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(
          FLAGS.src_dir, idataset, key,
          boundary=FLAGS.valid_cells_boundary)
      stimulus, resp, dimx, dimy, _ = op
      stimuli.update({key: stimulus})
      responses += resp

  # 2. Do response prediction for a retina
  iretina = FLAGS.taskid
  subunit_fit_loc = FLAGS.save_folder
  subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

  piece = responses[iretina]['piece']
  matched_dataset = [ifit for ifit in subunits_datasets
                     if piece[:12] == ifit[:12]]
  if matched_dataset == []:
    raise ValueError('Could not find subunit fit')
  subunit_fit_path = os.path.join(subunit_fit_loc, matched_dataset[0])

  stimulus = stimuli[responses[iretina]['stimulus_key']]

  # sample test data
  stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]

  # Optionally, create a null stimulus for all the cells.
  resp_ret = responses[iretina]
  if FLAGS.is_null:
    # Make null stimulus
    stimulus_test = get_null_stimulus(resp_ret, subunit_fit_path,
                                      stimulus_test)

  resp_su = get_su_spks(subunit_fit_path, stimulus_test, responses[iretina])
  save_dict = {'resp_su': resp_su.astype(np.int8),
               'cell_ids': responses[iretina]['cellID_list'].squeeze()}

  if FLAGS.is_null:
    save_dict.update({'stimulus_null': stimulus_test})
    save_suff = '_null'
  else:
    save_suff = ''

  pickle.dump(save_dict,
              gfile.Open(os.path.join(subunit_fit_path,
                                      'response_prediction%s.pkl' % save_suff),
                         'w'))
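
# ------------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: one way the pickle
# written by main() above could be read back for downstream analysis. The
# helper name _load_response_prediction_sketch is hypothetical and the helper
# is never called here; it simply mirrors the save-path logic used in main().
def _load_response_prediction_sketch(subunit_fit_path, is_null=False):
  # '_null' suffix corresponds to runs where a null stimulus was used.
  save_suff = '_null' if is_null else ''
  filename = os.path.join(subunit_fit_path,
                          'response_prediction%s.pkl' % save_suff)
  data = pickle.load(gfile.Open(filename, 'r'))
  # data['resp_su'] holds the predicted spikes (int8) and data['cell_ids'] the
  # corresponding cell IDs; 'stimulus_null' is present only for null runs.
  return data
# ------------------------------------------------------------------------------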