def stimulate(sr_graph, sess, file_name, dimx, dimy):
  piece = '2015-11-09-3.mat'

  # Saving filename.
  save_analysis_filename = os.path.join(FLAGS.save_folder,
                                        file_name + '_prosthesis.pkl')
  save_dict = {}

  # Load dictionary.
  dict_dir = '/home/bhaishahster/Downloads/dictionary'
  dict_src = os.path.join(dict_dir, piece)
  _, dictionary, cellID_list, EA, elec_loc = load_single_elec_stim_data(
      dict_src)
  dictionary = dictionary.T

  # Load cell properties.
  cell_data_dir = '/home/bhaishahster/Downloads/rgb-8-1-0.48-11111'
  cell_file = os.path.join(cell_data_dir, piece)
  data_cell = sio.loadmat(gfile.Open(cell_file, 'r'))
  data_util.process_dataset(data_cell, dimx=80, dimy=40, num_cell_types=2)

  # Load stimulus.
  data = h5py.File(os.path.join(cell_data_dir, 'stimulus.mat'))
  stimulus = np.array(data.get('stimulus')) - 0.5

  # Generate targets: 100 random samples.
  t_len = 100
  stim_history = 30
  stim_batch = np.zeros(
      (t_len, stimulus.shape[1], stimulus.shape[2], stim_history))
  for isample, itime in enumerate(
      np.random.randint(0, stimulus.shape[0], t_len)):
    stim_batch[isample, :, :, :] = np.transpose(
        stimulus[itime:itime - stim_history:-1, :, :], [1, 2, 0])

  # Use regression to decide dictionary elements.
  regress_dictionary(sr_graph, stim_batch, dictionary, 10, dimx, dimy,
                     data_cell)

  # Select stimulation pattern.
  dict_sel_np_logit, r_s, dictionary, d_log = get_optimal_stimulation(
      stim_batch, sr_graph, dictionary, data_cell, sess)
  save_dict.update({
      'dict_sel': dict_sel_np_logit,
      'resp_sample': r_s,
      'dictionary': dictionary,
      'd_log': d_log
  })

  pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))

def compute_sta(cell_id, response, cid_idx, delay=4):
  # Compute the STA for one cell: load a stimulus chunk, accumulate, repeat.
  sta_frame = np.zeros((640, 320, 3))
  for ichunk in np.arange(1000) + 1:
    print('Loading chunk: %d' % ichunk)

    # Get a stimulus chunk.
    stim_path = FLAGS.data_location + 'Stimulus/'
    stim_file = sio.loadmat(
        gfile.Open(stim_path + 'stim_chunk_' + str(ichunk) + '.mat'))
    chunk_start = np.squeeze(stim_file['chunk_start'])
    chunk_end = np.squeeze(stim_file['chunk_end'])
    jump = stim_file['jump']
    stim_chunk = stim_file['stimulus_chunk']
    stim_chunk = np.transpose(stim_chunk, [3, 0, 1, 2])
    print(np.shape(stim_chunk))
    print(chunk_start, chunk_end)

    # Find spike times and accumulate the stimulus frames preceding them.
    for itime in np.arange(29, chunk_end - chunk_start + 1):
      if response[chunk_start + itime - 1, cid_idx] > 0:
        sta_frame = sta_frame + np.squeeze(
            stim_chunk[itime - delay, :, :, :])

  # Normalize by the total spike count to get the STA.
  sta_frame = sta_frame / np.sum(response[:, cid_idx])

  # Plot STA.
  plt.imshow(sta_frame[:, :, 2])
  plt.show()
  plt.draw()

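# Illustrative sketch: the loop in compute_sta accumulates a spike-triggered
# average (STA), i.e. the mean stimulus frame a fixed delay before each
# spike. A minimal self-contained version on toy data (all shapes and rates
# here are made up):
def _sta_sketch():
  rng = np.random.RandomState(0)
  t_len, dim_x, dim_y, delay = 5000, 16, 8, 4
  stimulus = rng.randn(t_len, dim_x, dim_y)  # white-noise frames
  spikes = rng.poisson(0.1, t_len)  # spike count per frame

  # Accumulate the frame `delay` steps before every spike, then normalize
  # by the total spike count -- the same accumulate-then-divide as above.
  sta = np.zeros((dim_x, dim_y))
  for itime in range(delay, t_len):
    if spikes[itime] > 0:
      sta += spikes[itime] * stimulus[itime - delay, :, :]
  return sta / np.sum(spikes[delay:])
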
def get_partitions(partition_file, taskid):
  """Reads the partition file and returns train/test partitions for a task.

  The partition file has one task per row, with each row formatted as
  taskid: training dataset list: testing dataset list

  Args:
    partition_file: File containing all the partitions.
    taskid: Index of the partition to load.

  Returns:
    training_datasets: Dataset ids to train on.
    testing_datasets: Dataset ids to test on.
  """
  # Load partitions.
  with gfile.Open(partition_file, 'r') as f:
    content = f.readlines()

  for iline in content:
    tokens = iline.split(':')
    tokens = [i.replace(' ', '') for i in tokens]
    if taskid == int(tokens[0]):
      training_datasets = tokens[1].split(',')
      testing_datasets = tokens[2].split(',')
      training_datasets = [int(i) for i in training_datasets]
      testing_datasets = [int(i) for i in testing_datasets]
      break

  return training_datasets, testing_datasets

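# Usage sketch for get_partitions, assuming a plain-text partition file and
# that gfile.Open reads local paths like the built-in open. The file name
# and its contents below are hypothetical:
#
#   # partitions.txt:
#   # 0: 1,2,3: 4
#   # 1: 2,3,4: 1
#
#   train_ids, test_ids = get_partitions('partitions.txt', taskid=1)
#   # train_ids == [2, 3, 4], test_ids == [1]
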
def setup_dataset():
  # Initialize paths, get dataset properties, etc.
  path = FLAGS.data_location

  # Load cell responses.
  response_path = path + 'response.mat'
  response_file = sio.loadmat(gfile.Open(response_path))
  resp_mat = response_file['binned_spikes']
  resp_file_cids = np.squeeze(response_file['cell_ids'])

  # Load OFF parasol cell IDs.
  cids_path = path + 'cell_ids/cell_ids_OFF parasol.mat'
  cids_file = sio.loadmat(gfile.Open(cids_path))
  cids_select = np.squeeze(cids_file['cids'])

  # Find indices of the chosen cells in resp_mat.
  resp_file_choose_idx = np.array([])
  for icell in np.array(cids_select):
    idx = np.where(resp_file_cids == icell)
    resp_file_choose_idx = np.append(resp_file_choose_idx, idx[0])

  # Finally, get the selected cells from resp_mat.
  global response
  response = resp_mat[resp_file_choose_idx.astype('int'), :].T
  print(cids_select)
  print(resp_file_choose_idx.astype('int'))

  # Load population time courses.
  time_c_file_path = path + 'cell_ids/time_courses.mat'
  time_c_file = sio.loadmat(gfile.Open(time_c_file_path))
  tc_mat = time_c_file['time_courses']
  tm_cids = np.squeeze(time_c_file['cids'])

  # Find the average time course of the cells of interest.
  tc_file_choose_idx = np.array([])
  for icell in np.array(cids_select):
    idx = np.where(tm_cids == icell)
    tc_file_choose_idx = np.append(tc_file_choose_idx, idx[0])
  tc_select = tc_mat[tc_file_choose_idx.astype('int'), :, :]
  tc_mean = np.squeeze(np.mean(tc_select, axis=0))

  n_cells = cids_select.shape[0]
  FLAGS.n_cells = n_cells

  # 'response' holds binned spikes; cell ids are 'cids_select' with
  # 'n_cells' cells; 'tc_select' are time courses, 'tc_mean' their mean.
  return response, cids_select, n_cells, tc_select, tc_mean

def get_stimulus_batch(ichunk):
  stim_path = FLAGS.data_location + 'Stimulus/'
  stim_file = sio.loadmat(
      gfile.Open(stim_path + 'stim_chunk_' + str(ichunk) + '.mat'))
  chunk_start = np.squeeze(stim_file['chunk_start'])
  chunk_end = np.squeeze(stim_file['chunk_end'])
  jump = stim_file['jump']
  stim_chunk = stim_file['stimulus_chunk']
  stim_chunk = np.transpose(stim_chunk, [3, 0, 1, 2])
  return stim_chunk, chunk_start, chunk_end

def get_test_data():
  # The last chunk of data is test data.
  test_data_chunks = [FLAGS.n_chunks]
  for ichunk in test_data_chunks:
    filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
    file_r = gfile.Open(filename, 'r')
    data = sio.loadmat(file_r)
    stim_part = data['maskedMovdd_part'].T
    resp_part = data['Y_part'].T
    test_len = stim_part.shape[0]
  stim_part = stim_part[:, chosen_mask]
  resp_part = resp_part[:, cells_choose]
  return stim_part, resp_part, test_len

def get_su_spks(subunit_fit_path, stimulus_test, resp_ret):
  """Predict spikes for each cell and each number of subunits."""
  # Predict responses to 'stimulus_test'.
  n_valid_cells = resp_ret['valid_cells'].sum()
  resp_su = np.zeros((10, stimulus_test.shape[0], n_valid_cells))
  cell_ids = resp_ret['cellID_list'].squeeze()

  # time x cells x subunits
  for Nsub in range(1, 11):
    for icell in range(n_valid_cells):
      # Get subunits.
      fit_file = os.path.join(
          subunit_fit_path, 'Cell_%d_Nsub_%d.pkl' % (cell_ids[icell], Nsub))
      try:
        su_fit = pickle.load(gfile.Open(fit_file, 'r'))
      except:
        print('Cell %d not loaded ' % cell_ids[icell])
        continue
      print(Nsub, '%d' % cell_ids[icell])

      # Get window to extract stimulus around RF.
      windx = su_fit['windx']
      windy = su_fit['windy']
      stim_cell = np.reshape(
          stimulus_test[:, windx[0]:windx[1], windy[0]:windy[1]],
          [stimulus_test.shape[0], -1])

      # Filter in time.
      ttf = su_fit['ttf']
      stim_filter = np.zeros_like(stim_cell)
      for idelay in range(30):
        length = stim_filter[idelay:, :].shape[0]
        stim_filter[idelay:, :] += stim_cell[:length, :] * ttf[idelay]
      stim_filter -= np.mean(stim_filter)
      stim_filter /= np.sqrt(np.var(stim_filter))

      # Compute firing rate.
      K = su_fit['K']
      b = su_fit['b']
      firing_rate = np.exp(stim_filter.dot(K) + b[:, 0]).sum(-1)

      # Sample spikes.
      resp_su[Nsub - 1, :, icell] = np.random.poisson(firing_rate)

  return resp_su

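# Illustrative sketch of the generative model used in get_su_spks: an LN
# cascade where each subunit filters the (time-filtered, z-scored) stimulus,
# activations pass through an exponential, and their sum drives a Poisson
# spike count. Toy shapes; all values here are made up:
def _subunit_ln_sketch():
  rng = np.random.RandomState(0)
  t_len, n_pix, n_sub = 200, 25, 3
  stim_filter = rng.randn(t_len, n_pix)  # filtered stimulus in RF window
  K = 0.1 * rng.randn(n_pix, n_sub)  # one spatial filter per subunit
  b = -2. * np.ones(n_sub)  # per-subunit offsets

  # Firing rate and sampled spikes, as in the loop above.
  firing_rate = np.exp(stim_filter.dot(K) + b).sum(-1)  # shape (t_len,)
  return rng.poisson(firing_rate)
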
def get_test_data():
  # The last chunk of data is test data (roughly the final 1000 samples of
  # the 216000-sample stimulus and response arrays).
  test_data_chunks = [FLAGS.n_chunks]
  for ichunk in test_data_chunks:
    filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
    file_r = gfile.Open(filename, 'r')
    data = sio.loadmat(file_r)
    stim_part = data['maskedMovdd_part'].T
    resp_part = data['Y_part'].T
    test_len = stim_part.shape[0]
  stim_part = stim_part[:, chosen_mask]
  resp_part = resp_part[:, cells_choose]
  return stim_part, resp_part, test_len

def get_data_retina(piece='2005-04-06-4'):
  """Load data for 1 piece."""
  # Load data.
  file_d = h5py.File('/home/bhaishahster/'
                     'Downloads/%s.mat' % piece, 'r')
  stimulus = np.array(file_d.get('stimulus'))
  responses = np.array(file_d.get('response'))
  ei = np.array(file_d.get('ei'))
  cell_type = np.array(file_d.get('cell_type'))

  elec_loc_file = gfile.Open('/home/bhaishahster/'
                             'Downloads/Elec_loc512.mat', 'r')
  data = sio.loadmat(elec_loc_file)
  elec_loc = np.squeeze(np.array([data['elec_locx'],
                                  data['elec_locy']])).T

  # Process.
  stimulus = np.mean(stimulus, 1) - 0.5
  stimulus = stimulus[:-1, :, :]
  stimx = stimulus.shape[1]
  stimy = stimulus.shape[2]

  ei_magnitude = np.sqrt(np.sum(ei**2, 0))

  # Receptive fields by reverse correlation with a 4-frame delay.
  rfs = np.reshape(stimulus[:-4, :],
                   [-1, stimulus.shape[1] * stimulus.shape[2]]).T.dot(
                       responses[4:, :])
  rfs = np.reshape(rfs, [stimulus.shape[1], stimulus.shape[2], -1])
  rfs = clean_rfs(rfs, nbd=1)

  ei_embedding_matrix = embed_ei_grid(elec_loc, smoothness_sigma=15)
  eix, eiy, _ = ei_embedding_matrix.shape
  n_elec = 512

  # Compile data.
  data = {'stimulus': stimulus, 'responses': responses,
          'ei_magnitude': ei_magnitude, 'rfs': rfs, 'elec_loc': elec_loc,
          'stimx': stimx, 'stimy': stimy, 'eix': eix, 'eiy': eiy,
          'ei_embedding_matrix': ei_embedding_matrix, 'n_elec': n_elec,
          'cell_type': cell_type}
  return data

def get_null_stimulus(resp_ret, subunit_fit_path, stimulus_test):
  """Project the stimulus into the null space of the cells' RFs."""
  A = []

  # Collect the RF for each cell.
  n_valid_cells = resp_ret['valid_cells'].sum()
  cell_ids = resp_ret['cellID_list'].squeeze()

  # time x cells x subunits
  Nsub = 1
  for icell in range(n_valid_cells):
    # Get subunits.
    fit_file = os.path.join(
        subunit_fit_path, 'Cell_%d_Nsub_%d.pkl' % (cell_ids[icell], Nsub))
    try:
      su_fit = pickle.load(gfile.Open(fit_file, 'r'))
    except:
      print('Cell %d not loaded ' % cell_ids[icell])
      continue
    print(Nsub, '%d' % cell_ids[icell])

    # Get window to extract stimulus around RF.
    windx = su_fit['windx']
    windy = su_fit['windy']
    sta = np.zeros((stimulus_test.shape[1], stimulus_test.shape[2]))
    sta[windx[0]:windx[1], windy[0]:windy[1]] = np.reshape(
        su_fit['K'].squeeze(), (windx[1] - windx[0], windy[1] - windy[0]))
    A += [sta]

  # Remove the component of the stimulus spanned by the RFs:
  # s_null = s - A'(AA')^{-1} A s.
  A = np.array(A)
  A_2d = np.reshape(A, (A.shape[0], -1))
  stim_test_2d = np.reshape(stimulus_test, (stimulus_test.shape[0], -1))
  stimulus_test_null = stim_test_2d.T - A_2d.T.dot(
      np.linalg.solve(A_2d.dot(A_2d.T), A_2d.dot(stim_test_2d.T)))
  stimulus_test_null = stimulus_test_null.T
  stimulus_test_null = np.reshape(
      stimulus_test_null,
      [-1, stimulus_test.shape[1], stimulus_test.shape[2]])
  return stimulus_test_null

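# The projection in get_null_stimulus is the standard least-squares
# residual: with RF matrix A (cells x pixels), s_null = s - A'(AA')^{-1}As,
# so every cell's linear RF response to s_null is zero. A quick
# self-contained check on random stand-in data:
def _null_projection_check():
  rng = np.random.RandomState(0)
  n_cells, n_pix, t_len = 5, 50, 20
  A = rng.randn(n_cells, n_pix)  # stand-in RF matrix
  S = rng.randn(t_len, n_pix)  # stand-in stimulus frames

  S_null = S.T - A.T.dot(np.linalg.solve(A.dot(A.T), A.dot(S.T)))
  return np.abs(A.dot(S_null)).max()  # ~1e-13: linear responses are nulled
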
def main(unused_argv=()):
  # Load stimulus-response data.
  datasets = gfile.ListDirectory(FLAGS.src_dir)
  responses = []
  print(datasets)
  for icnt, idataset in enumerate([datasets]):  # TODO(bhaishahster): HACK.
    fullpath = os.path.join(FLAGS.src_dir, idataset)
    if gfile.IsDirectory(fullpath):
      key = 'stim_%d' % icnt
      op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key)
      stimulus, resp, dimx, dimy, num_cell_types = op
      responses += resp

  for idataset in range(len(responses)):
    k, b, ttf = fit_ln_population(responses[idataset]['responses'],
                                  stimulus)  # Use FLAGS.taskID
    save_dict = {'k': k, 'b': b, 'ttf': ttf}
    save_analysis_filename = os.path.join(
        FLAGS.save_folder, responses[idataset]['piece'] + '_ln_model.pkl')
    pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))

def get_test_data():
  # Test data: concatenate the last ~20 chunks of the recording.
  global chosen_mask
  global cells_choose
  stim_part = np.array([]).reshape(0, np.sum(chosen_mask))
  resp_part = np.array([]).reshape(0, np.sum(cells_choose))

  test_data_chunks = np.arange(FLAGS.n_chunks - 20, FLAGS.n_chunks + 1)
  for ichunk in test_data_chunks:
    filename = FLAGS.data_location + 'Off_par_data_' + str(ichunk) + '.mat'
    file_r = gfile.Open(filename, 'r')
    data = sio.loadmat(file_r)
    s = data['maskedMovdd_part'].T
    r = data['Y_part'].T
    print(np.shape(s))
    print(np.shape(stim_part))
    stim_part = np.append(stim_part, s[:, chosen_mask], axis=0)
    resp_part = np.append(resp_part, r[:, cells_choose], axis=0)
  test_len = stim_part.shape[0]
  return stim_part, resp_part, test_len

def main(argv):
  print('\nCode started')

  np.random.seed(FLAGS.np_randseed)
  random.seed(FLAGS.randseed)

  ## Load data summary
  filename = FLAGS.data_location + 'data_details.mat'
  summary_file = gfile.Open(filename, 'r')
  data_summary = sio.loadmat(summary_file)
  cells = np.squeeze(data_summary['cells'])
  if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic':
    cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
        cells == 3066)
  if FLAGS.model_id == 'poisson_full':
    cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
  n_cells = np.sum(cells_choose)

  tot_spks = np.squeeze(data_summary['tot_spks'])
  total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
  tot_spks_chosen_cells = tot_spks[cells_choose]
  chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                         dtype='bool')
  print(np.shape(chosen_mask))
  print(np.sum(chosen_mask))
  stim_dim = np.sum(chosen_mask)
  print('\ndataset summary loaded')
  # Use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells.

  # Decide the number of subunits to fit.
  n_su = FLAGS.ratio_SU * n_cells

  # Saving details.
  if FLAGS.model_id == 'poisson':
    short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) +
                      '_n_b_in_c' + str(FLAGS.n_b_in_c) +
                      '_step_sz' + str(FLAGS.step_sz) + '_bg')
  if FLAGS.model_id == 'logistic':
    short_filename = ('data_model=' + str(FLAGS.model_id) +
                      '_batch_sz=' + str(FLAGS.batchsz) +
                      '_n_b_in_c' + str(FLAGS.n_b_in_c) +
                      '_step_sz' + str(FLAGS.step_sz) + '_bg')
  if FLAGS.model_id == 'poisson_full':
    short_filename = ('data_model=' + str(FLAGS.model_id) +
                      '_batch_sz=' + str(FLAGS.batchsz) +
                      '_n_b_in_c' + str(FLAGS.n_b_in_c) +
                      '_step_sz' + str(FLAGS.step_sz) + '_bg')

  parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
  FLAGS.save_location = parent_folder + short_filename + '/'
  print(gfile.IsDirectory(FLAGS.save_location))
  print(FLAGS.save_location)
  save_filename = FLAGS.save_location + short_filename

  with tf.Session() as sess:
    # Learn population model!
    stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
    resp = tf.placeholder(tf.float32, name='resp')
    data_len = tf.placeholder(tf.float32, name='data_len')

    # Variables.
    if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full':
      w = tf.Variable(
          np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32'))
      a = tf.Variable(
          np.array(0.1 * np.random.rand(n_cells, 1, n_su), dtype='float32'))
    if FLAGS.model_id == 'logistic':
      w = tf.Variable(
          np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32'))
      a = tf.Variable(
          np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32'))
      b_init = np.random.randn(n_cells)
      # Alternative init: np.log(np.sum(response, 0) /
      #                          (response.shape[0] - np.sum(response, 0)))
      b = tf.Variable(b_init, dtype='float32')

    # Get relevant checkpoint files.
    file_list = gfile.ListDirectory(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename
    print('\nLoading: ', save_filename)
    bin_files = []
    meta_files = []
    for file_n in file_list:
      if re.search(short_filename + '.', file_n):
        if re.search('.meta', file_n):
          meta_files += [file_n]
        else:
          bin_files += [file_n]
    print(len(meta_files), len(bin_files), len(file_list))

    # Get iteration numbers.
    iterations = np.array([])
    for file_name in bin_files:
      try:
        iterations = np.append(
            iterations, int(file_name.split('/')[-1].split('-')[-1]))
      except:
        print('Could not load filename: ' + file_name)
    iterations.sort()
    print(iterations)

    iter_plot = iterations[-1]
    print(int(iter_plot))

    # Load tensorflow variables.
    saver_var = tf.train.Saver(tf.all_variables())
    restore_file = save_filename + '-' + str(int(iter_plot))
    saver_var.restore(sess, restore_file)

    a_eval = a.eval()
    print(np.exp(np.squeeze(a_eval)))

    # Get the 2D region to plot.
    mask2D = np.reshape(chosen_mask, [40, 80])
    nz_idx = np.nonzero(mask2D)
    print(nz_idx)
    ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1])
    xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1])

    # Plot the learned subunit filters.
    w_eval = w.eval()
    plt.figure()
    n_su = w_eval.shape[1]
    for isu in np.arange(n_su):
      xx = np.zeros((3200))
      xx[chosen_mask] = w_eval[:, isu]
      fig = plt.subplot(np.ceil(np.sqrt(n_su)), np.ceil(np.sqrt(n_su)),
                        isu + 1)
      plt.imshow(np.reshape(xx, [40, 80]), interpolation='nearest',
                 cmap='gray')
      plt.ylim(ylim)
      plt.xlim(xlim)
      fig.axes.get_xaxis().set_visible(False)
      fig.axes.get_yaxis().set_visible(False)

    plt.suptitle('Iteration:' + str(int(iter_plot)) +
                 ' batchSz:' + str(FLAGS.batchsz) +
                 ' step size:' + str(FLAGS.step_sz), fontsize=18)
    plt.show()
    plt.draw()

def subunit_discriminability(dataset_dict, stimuli, responses, sr_graph,
                             num_examples=1000):
  """Compute distances between s-r pairs: recorded-vs-predicted responses
  ('rr') and stimulus-vs-predicted responses ('sr')."""
  if num_examples % 100 != 0:
    raise ValueError('Only supports examples which are multiples of 100.')

  subunit_fit_loc = ('/home/bhaishahster/stim-resp_collection_big_wn'
                     '_retina_subunit_properties_train')
  subunits_datasets = gfile.ListDirectory(subunit_fit_loc)

  save_dict = {}
  datasets_log = {}
  for dat_key, datasets in dataset_dict.items():
    distances_log = {}
    distances_retina_sr_log = []
    distances_retina_rr_log = []
    for iretina in range(len(datasets)):
      # Find the relevant subunit fit.
      piece = responses[iretina]['piece']
      matched_dataset = [
          ifit for ifit in subunits_datasets if piece[:12] == ifit[:12]
      ]
      if matched_dataset == []:
        raise ValueError('Could not find subunit fit')
      subunit_fit_path = os.path.join(subunit_fit_loc, matched_dataset[0])

      # Get predicted spikes.
      dat_resp_su = pickle.load(
          gfile.Open(
              os.path.join(subunit_fit_path, 'response_prediction.pkl'),
              'r'))
      resp_su = dat_resp_su['resp_su']  # Also has cells not in final list.

      # Remove some cells.
      select_cells = [
          icell for icell in range(resp_su.shape[2])
          if dat_resp_su['cell_ids'][icell] in
          responses[iretina]['cellID_list'].squeeze()
      ]
      select_cells = np.array(select_cells)
      resp_su = resp_su[:, :, select_cells].astype(np.float32)

      # Get stimulus.
      stimulus = stimuli[responses[iretina]['stimulus_key']]
      stimulus_test = stimulus[FLAGS.test_min:FLAGS.test_max, :, :]
      responses_recorded_test = responses[iretina]['responses'][
          FLAGS.test_min:FLAGS.test_max, :]

      # Sample stimuli and responses.
      random_times = np.random.randint(40, stimulus_test.shape[0],
                                       num_examples)
      batch_size = 100

      # Recorded response - predicted response distances.
      distances_retina = np.zeros((num_examples, 10)) + np.nan
      for Nsub in range(1, 11):
        for ibatch in range(
            np.floor(num_examples / batch_size).astype(np.int)):
          # Construct response batches.
          resp_pred_batch = np.zeros((batch_size, resp_su.shape[2]))
          resp_rec_batch = np.zeros((batch_size, resp_su.shape[2]))
          for isample in range(batch_size):
            itime = random_times[batch_size * ibatch + isample]
            resp_pred_batch[isample, :] = resp_su[Nsub - 1, itime, :]
            resp_rec_batch[isample, :] = responses_recorded_test[itime, :]

          # Embed predicted responses.
          feed_dict = make_feed_dict(sr_graph, responses[iretina],
                                     responses=resp_pred_batch)
          embed_predicted = sr_graph.sess.run(
              sr_graph.anchor_model.responses_embed, feed_dict=feed_dict)

          # Embed recorded responses.
          feed_dict = make_feed_dict(sr_graph, responses[iretina],
                                     responses=resp_rec_batch)
          embed_recorded = sr_graph.sess.run(
              sr_graph.anchor_model.responses_embed, feed_dict=feed_dict)

          dd = sr_graph.sess.run(
              sr_graph.distances_arbitrary,
              feed_dict={
                  sr_graph.arbitrary_embedding_1: embed_predicted,
                  sr_graph.arbitrary_embedding_2: embed_recorded
              })
          distances_retina[batch_size * ibatch:
                           batch_size * (ibatch + 1), Nsub - 1] = dd
          print(iretina, Nsub, ibatch)
      distances_retina_rr_log += [distances_retina]

      # Stimulus - predicted response distances.
      distances_retina = np.zeros((num_examples, 10)) + np.nan
      for Nsub in range(1, 11):
        for ibatch in range(
            np.floor(num_examples / batch_size).astype(np.int)):
          # Construct the stimulus tensor.
          stim_history = 30
          stim_batch = np.zeros((batch_size, stimulus_test.shape[1],
                                 stimulus_test.shape[2], stim_history))
          resp_batch = np.zeros((batch_size, resp_su.shape[2]))
          for isample in range(batch_size):
            itime = random_times[batch_size * ibatch + isample]
            stim_batch[isample, :, :, :] = np.transpose(
                stimulus_test[itime:itime - stim_history:-1, :, :],
                [1, 2, 0])
            resp_batch[isample, :] = resp_su[Nsub - 1, itime, :]
          feed_dict = make_feed_dict(sr_graph, responses[iretina],
                                     resp_batch, stim_batch)

          # Get distances.
          d_pos = sr_graph.sess.run(sr_graph.d_s_r_pos, feed_dict=feed_dict)
          distances_retina[batch_size * ibatch:
                           batch_size * (ibatch + 1), Nsub - 1] = d_pos
          print(iretina, Nsub, ibatch)
      distances_retina_sr_log += [distances_retina]

    distances_log.update({'rr': distances_retina_rr_log})
    distances_log.update({'sr': distances_retina_sr_log})
    datasets_log.update({dat_key: distances_log})

  save_dict.update({
      'datasets_log': datasets_log,
      'dataset_dict': dataset_dict
  })
  return save_dict

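# Sketch of how the arrays saved by subunit_discriminability might be
# summarized downstream (a hypothetical analysis, not part of the original
# pipeline): each 'sr' or 'rr' entry is (num_examples x 10), one column per
# subunit count 1..10, and smaller distances mean better predictions.
def _summarize_distances(distances):
  mean_dist = np.nanmean(distances, 0)  # mean distance per subunit count
  best_nsub = np.argmin(mean_dist) + 1  # count with the smallest distance
  return mean_dist, best_nsub
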
def main(argv):
  # Global variables will be used for getting training data.
  global cells_choose
  global chosen_mask
  global chunk_order

  # Set random seeds: when the same algorithm is run with different FLAGS,
  # the sequence of random data is the same.
  np.random.seed(FLAGS.np_randseed)
  random.seed(FLAGS.randseed)

  # Initial chunk order (re-shuffled after every pass over the chunks).
  chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1))

  # Load data summary.
  data_filename = FLAGS.data_location + 'data_details.mat'
  summary_file = gfile.Open(data_filename, 'r')
  data_summary = sio.loadmat(summary_file)
  cells = np.squeeze(data_summary['cells'])

  # Which cells to train subunits for.
  if FLAGS.all_cells == 'True':
    cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
  else:
    cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | (
        cells == 3066)
  n_cells = np.sum(cells_choose)  # number of cells

  # Load spikes and relevant stimulus pixels for the chosen cells.
  tot_spks = np.squeeze(data_summary['tot_spks'])
  tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32')
  total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
  # chosen_mask = which pixels to learn subunits over.
  if FLAGS.masked_stimulus == 'True':
    chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0,
                           dtype='bool')
  else:
    chosen_mask = np.array(np.ones(3200).astype('bool'))
  stim_dim = np.sum(chosen_mask)  # stimulus dimensions
  print('\ndataset summary loaded')

  # Print parameters.
  print('Save folder name: ' + str(FLAGS.folder_name) +
        '\nmodel:' + str(FLAGS.model_id) +
        '\nLoss:' + str(FLAGS.loss) +
        '\nmasked stimulus:' + str(FLAGS.masked_stimulus) +
        '\nall_cells?' + str(FLAGS.all_cells) +
        '\nbatch size' + str(FLAGS.batchsz) +
        '\nstep size' + str(FLAGS.step_sz) +
        '\ntraining length: ' + str(FLAGS.train_len) +
        '\nn_cells: ' + str(n_cells))

  # Decide the number of subunits to fit.
  n_su = FLAGS.ratio_SU * n_cells

  # Filename for saving results.
  short_filename = ('_masked_stim=' + str(FLAGS.masked_stimulus) +
                    '_all_cells=' + str(FLAGS.all_cells) +
                    '_loss=' + str(FLAGS.loss) +
                    '_batch_sz=' + str(FLAGS.batchsz) +
                    '_step_sz' + str(FLAGS.step_sz) +
                    '_tlen=' + str(FLAGS.train_len) + '_bg')

  with tf.Session() as sess:
    # Set up stimulus and response placeholders.
    stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim')
    resp = tf.placeholder(tf.float32, name='resp')
    data_len = tf.placeholder(tf.float32, name='data_len')

    if FLAGS.loss == 'poisson':
      # A very small positive bias, needed to avoid log(0) in Poisson loss.
      b_init = np.array(0.000001 * np.ones(n_cells))
    else:
      # Log-odds: a good initialization for some losses (like logistic).
      b_init = np.log(tot_spks_chosen_cells /
                      (216000. - tot_spks_chosen_cells))

    # Different firing rate models.
    if FLAGS.model_id == 'exp_additive':
      # This model was implemented for earlier work.
      # Firing rate for cell c: lam_c = sum_s exp(w_s.x + a_sc)

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) + short_filename)

      # Variables.
      w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                               dtype='float32'), name='w')
      a = tf.Variable(np.array(0.01 * np.random.rand(n_cells, 1, n_su),
                               dtype='float32'), name='a')

      # Firing rate model.
      lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a), 2))
      regularization = 0
      vars_fit = [w, a]

      def proj():
        # Called after every training step to project onto parameter
        # constraints; this model has none.
        pass

    if FLAGS.model_id == 'relu':
      # Firing rate for cell c: lam_c = a_c'.relu(w.x) + b.
      # We know a > 0, and for Poisson loss b > 0
      # (a small b prevents lam_c from going to 0).

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) +
                        '_lam_w=' + str(FLAGS.lam_w) +
                        '_lam_a=' + str(FLAGS.lam_a) +
                        '_nsu=' + str(n_su) + short_filename)

      # Variables.
      w = tf.Variable(np.array(0.01 * np.random.randn(stim_dim, n_su),
                               dtype='float32'), name='w')
      a = tf.Variable(np.array(0.01 * np.random.rand(n_su, n_cells),
                               dtype='float32'), name='a')
      b = tf.Variable(np.array(b_init, dtype='float32'), name='b')

      # Firing rate model.
      lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b

      # Which variables are learned; don't learn b for Poisson loss.
      vars_fit = [w, a]
      if not FLAGS.loss == 'poisson':
        vars_fit = vars_fit + [b]

      # Regularization of parameters.
      regularization = (FLAGS.lam_w * tf.reduce_sum(tf.abs(w)) +
                        FLAGS.lam_a * tf.reduce_sum(tf.abs(a)))

      # Projection to satisfy constraints: (x + |x|)/2 = max(x, 0).
      a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
      b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

      def proj():
        sess.run(a_pos)
        if FLAGS.loss == 'poisson':
          sess.run(b_pos)

    if FLAGS.model_id == 'relu_window':
      # Firing rate for cell c: lam_c = a_c'.relu(w.x) + b, where each w_i
      # is over a small window, and the windows are convolutionally related.
      # We know a > 0, and for Poisson loss b > 0
      # (a small b prevents lam_c from going to 0).

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) +
                        '_window=' + str(FLAGS.window) +
                        '_stride=' + str(FLAGS.stride) +
                        '_lam_w=' + str(FLAGS.lam_w) + short_filename)

      # Get convolutional windows.
      mask_tf, dimx, dimy, n_pix = get_windows()

      # Variables.
      w = tf.Variable(np.array(0.1 + 0.05 * np.random.rand(dimx, dimy,
                                                           n_pix),
                               dtype='float32'), name='w')
      a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                               dtype='float32'), name='a')
      b = tf.Variable(np.array(b_init, dtype='float32'), name='b')

      # Which variables are learned; don't learn b for Poisson loss.
      vars_fit = [w, a]
      if not FLAGS.loss == 'poisson':
        vars_fit = vars_fit + [b]

      # Stimulus filtered with convolutional windows.
      stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
      stim_masked = tf.nn.conv2d(stim4D, mask_tf,
                                 strides=[1, FLAGS.stride, FLAGS.stride, 1],
                                 padding="VALID")
      stim_wts = tf.nn.relu(tf.reduce_sum(tf.mul(stim_masked, w), 3))

      # Get firing rate.
      lam = tf.matmul(tf.reshape(stim_wts, [-1, dimx * dimy]), a) + b

      # Regularization.
      regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w))

      # Projection to satisfy hard variable constraints.
      a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
      b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

      def proj():
        sess.run(a_pos)
        if FLAGS.loss == 'poisson':
          sess.run(b_pos)

    if FLAGS.model_id == 'relu_window_mother':
      # Firing rate for cell c: lam_c = a_c'.relu(w.x) + b, where each w_i
      # is over a small window, and the windows are convolutionally related:
      # w_i = w_mother + w_del_i, where w_mother is common across all
      # windows and w_del differs between windows.
      # We know a > 0, and for Poisson loss b > 0
      # (a small b prevents lam_c from going to 0).

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) +
                        '_window=' + str(FLAGS.window) +
                        '_stride=' + str(FLAGS.stride) +
                        '_lam_w=' + str(FLAGS.lam_w) + short_filename)

      mask_tf, dimx, dimy, n_pix = get_windows()

      # Variables.
      w_del = tf.Variable(np.array(0.05 * np.random.randn(dimx, dimy,
                                                          n_pix),
                                   dtype='float32'), name='w_del')
      w_mother = tf.Variable(np.array(np.ones((2 * FLAGS.window + 1,
                                               2 * FLAGS.window + 1, 1, 1)),
                                      dtype='float32'), name='w_mother')
      a = tf.Variable(np.array(np.random.rand(dimx * dimy, n_cells),
                               dtype='float32'), name='a')
      b = tf.Variable(np.array(b_init, dtype='float32'), name='b')

      # Which variables are learned; don't learn b for Poisson loss.
      vars_fit = [w_mother, w_del, a]
      if not FLAGS.loss == 'poisson':
        vars_fit = vars_fit + [b]

      # Stimulus filtered with convolutional windows.
      stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
      stim_convolved = tf.reduce_sum(
          tf.nn.conv2d(stim4D, w_mother,
                       strides=[1, FLAGS.stride, FLAGS.stride, 1],
                       padding="VALID"), 3)
      stim_masked = tf.nn.conv2d(stim4D, mask_tf,
                                 strides=[1, FLAGS.stride, FLAGS.stride, 1],
                                 padding="VALID")
      stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

      # Activation of different subunits.
      su_act = tf.nn.relu(stim_del + stim_convolved)

      # Get firing rate.
      lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a) + b

      # Regularization.
      regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

      # Projection to satisfy hard variable constraints.
      a_pos = tf.assign(a, (a + tf.abs(a)) / 2)
      b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

      def proj():
        sess.run(a_pos)
        if FLAGS.loss == 'poisson':
          sess.run(b_pos)

    if FLAGS.model_id == 'relu_window_mother_sfm':
      # Firing rate for cell c: lam_c = a_sfm_c'.relu(w.x) + b, with
      # a_sfm_c = softmax(a): a cell cannot be connected to all subunits
      # equally well. Each w_i is over a small window, and the windows are
      # convolutionally related: w_i = w_mother + w_del_i, where w_mother
      # is common across all windows and w_del differs between windows.
      # We know a > 0, and for Poisson loss b > 0
      # (a small b prevents lam_c from going to 0).

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) +
                        '_window=' + str(FLAGS.window) +
                        '_stride=' + str(FLAGS.stride) +
                        '_lam_w=' + str(FLAGS.lam_w) + short_filename)

      mask_tf, dimx, dimy, n_pix = get_windows()

      # Variables.
      w_del = tf.Variable(np.array(0.05 * np.random.randn(dimx, dimy,
                                                          n_pix),
                                   dtype='float32'), name='w_del')
      w_mother = tf.Variable(np.array(np.ones((2 * FLAGS.window + 1,
                                               2 * FLAGS.window + 1, 1, 1)),
                                      dtype='float32'), name='w_mother')
      a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                               dtype='float32'), name='a')
      a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
      b = tf.Variable(np.array(b_init, dtype='float32'), name='b')

      # Which variables are learned; don't learn b for Poisson loss.
      vars_fit = [w_mother, w_del, a]
      if not FLAGS.loss == 'poisson':
        vars_fit = vars_fit + [b]

      # Stimulus filtered with convolutional windows.
      stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
      stim_convolved = tf.reduce_sum(
          tf.nn.conv2d(stim4D, w_mother,
                       strides=[1, FLAGS.stride, FLAGS.stride, 1],
                       padding="VALID"), 3)
      stim_masked = tf.nn.conv2d(stim4D, mask_tf,
                                 strides=[1, FLAGS.stride, FLAGS.stride, 1],
                                 padding="VALID")
      stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

      # Activation of different subunits.
      su_act = tf.nn.relu(stim_del + stim_convolved)

      # Get firing rate.
      lam = tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]), a_sfm) + b

      # Regularization.
      regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

      # Projection to satisfy hard variable constraints.
      b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

      def proj():
        if FLAGS.loss == 'poisson':
          sess.run(b_pos)

    if FLAGS.model_id == 'relu_window_mother_sfm_exp':
      # Firing rate for cell c: lam_c = exp(a_sfm_c'.relu(w.x)) + b, with
      # a_sfm_c = softmax(a): a cell cannot be connected to all subunits
      # equally well. The exponential output nonlinearity cancels the log()
      # in the Poisson loss and might give better estimation properties.
      # Each w_i is over a small window, and the windows are convolutionally
      # related: w_i = w_mother + w_del_i, where w_mother is common across
      # all windows and w_del differs between windows.
      # We know a > 0, and for Poisson loss b > 0
      # (a small b prevents lam_c from going to 0).

      # Filename.
      short_filename = ('model=' + str(FLAGS.model_id) +
                        '_window=' + str(FLAGS.window) +
                        '_stride=' + str(FLAGS.stride) +
                        '_lam_w=' + str(FLAGS.lam_w) + short_filename)

      # Get windows.
      mask_tf, dimx, dimy, n_pix = get_windows()

      # Declare variables.
      w_del = tf.Variable(np.array(0.05 * np.random.randn(dimx, dimy,
                                                          n_pix),
                                   dtype='float32'), name='w_del')
      w_mother = tf.Variable(np.array(np.ones((2 * FLAGS.window + 1,
                                               2 * FLAGS.window + 1, 1, 1)),
                                      dtype='float32'), name='w_mother')
      a = tf.Variable(np.array(np.random.randn(dimx * dimy, n_cells),
                               dtype='float32'), name='a')
      a_sfm = tf.transpose(tf.nn.softmax(tf.transpose(a)))
      b = tf.Variable(np.array(b_init, dtype='float32'), name='b')
      vars_fit = [w_mother, w_del, a]
      if not FLAGS.loss == 'poisson':
        vars_fit = vars_fit + [b]

      # Filter stimulus.
      stim4D = tf.expand_dims(tf.reshape(stim, (-1, 40, 80)), 3)
      stim_convolved = tf.reduce_sum(
          tf.nn.conv2d(stim4D, w_mother,
                       strides=[1, FLAGS.stride, FLAGS.stride, 1],
                       padding="VALID"), 3)
      stim_masked = tf.nn.conv2d(stim4D, mask_tf,
                                 strides=[1, FLAGS.stride, FLAGS.stride, 1],
                                 padding="VALID")
      stim_del = tf.reduce_sum(tf.mul(stim_masked, w_del), 3)

      # Get subunit activation.
      su_act = tf.nn.relu(stim_del + stim_convolved)

      # Get cell firing rates.
      lam = tf.exp(tf.matmul(tf.reshape(su_act, [-1, dimx * dimy]),
                             a_sfm)) + b

      # Regularization.
      regularization = FLAGS.lam_w * tf.reduce_sum(tf.nn.l2_loss(w_del))

      # Projection to satisfy hard variable constraints.
      b_pos = tf.assign(b, (b + tf.abs(b)) / 2)

      def proj():
        if FLAGS.loss == 'poisson':
          sess.run(b_pos)

    # Different loss functions.
    if FLAGS.loss == 'poisson':
      loss_inter = (tf.reduce_sum(lam) / 120. -
                    tf.reduce_sum(resp * tf.log(lam))) / data_len
    if FLAGS.loss == 'logistic':
      loss_inter = tf.reduce_sum(
          tf.nn.softplus(-2 * (resp - 0.5) * lam)) / data_len
    if FLAGS.loss == 'hinge':
      loss_inter = tf.reduce_sum(
          tf.nn.relu(1 - 2 * (resp - 0.5) * lam)) / data_len

    # Add regularization to get the final loss function.
    loss = loss_inter + regularization

    # Training consists of calling training(), which performs a train step
    # and projects parameters to model-specific constraints using proj().
    train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
        loss, var_list=vars_fit)

    def training(inp_dict):
      sess.run(train_step, feed_dict=inp_dict)  # one gradient-descent step
      proj()  # model-specific projection operations

    # Evaluate loss on given data.
    def get_loss(inp_dict):
      ls = sess.run(loss, feed_dict=inp_dict)
      return ls

    # Saving details: make a folder with a name derived from the parameters
    # of the algorithm; it holds checkpoints and TensorBoard summaries.
    parent_folder = FLAGS.save_location + FLAGS.folder_name + '/'
    # Make the folder if it does not exist.
    if not gfile.IsDirectory(parent_folder):
      gfile.MkDir(parent_folder)
    FLAGS.save_location = parent_folder + short_filename + '/'
    if not gfile.IsDirectory(FLAGS.save_location):
      gfile.MkDir(FLAGS.save_location)
    save_filename = FLAGS.save_location + short_filename

    # Create summary writers.
    # Create a histogram summary for every parameter that is learned.
    for ivar in vars_fit:
      tf.histogram_summary(ivar.name, ivar)
    # Loss summary.
    l_summary = tf.scalar_summary('loss', loss)
    # Loss-without-regularization summary.
    l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)
    # Merge all the summary writer ops into one op (this way, calling one
    # op stores all summaries).
    merged = tf.merge_all_summaries()
    # Training and testing have separate summary writers.
    train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                          sess.graph)
    test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

    ## Fitting procedure.
    print('Start fitting')
    sess.run(tf.initialize_all_variables())
    saver_var = tf.train.Saver(tf.all_variables(),
                               keep_checkpoint_every_n_hours=0.05)
    load_prev = False
    start_iter = 0
    try:
      # Restore previous fits if they are available
      # (useful when programs are preempted frequently).
      latest_filename = short_filename + '_latest_fn'
      restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                latest_filename)
      # Restore the previous iteration count and start from there.
      start_iter = int(restore_file.split('/')[-1].split('-')[-1])
      saver_var.restore(sess, restore_file)  # restore variables
      load_prev = True
    except:
      print('No previous dataset')

    if load_prev:
      print('Previous results loaded')
    else:
      print('Variables initialized')

    # Finally, do the fitting.
    icnt = 0
    # Get test data and make the test feed dictionary.
    stim_test, resp_test, test_length = get_test_data()
    fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

    for istep in np.arange(start_iter, 400000):
      print(istep)
      # Get training data and make the training feed dictionary.
      stim_train, resp_train, train_len = get_next_training_batch(istep)
      fd_train = {stim: stim_train, resp: resp_train, data_len: train_len}

      # Take a training step.
      training(fd_train)

      if istep % 10 == 0:
        # Compute training and testing losses.
        ls_train = get_loss(fd_train)
        ls_test = get_loss(fd_test)
        latest_filename = short_filename + '_latest_fn'
        saver_var.save(sess, save_filename, global_step=istep,
                       latest_filename=latest_filename)

        # Add training summary.
        summary = sess.run(merged, feed_dict=fd_train)
        train_writer.add_summary(summary, istep)

        # Add testing summary.
        summary = sess.run(merged, feed_dict=fd_test)
        test_writer.add_summary(summary, istep)
        print(istep, ls_train, ls_test)

      icnt += FLAGS.batchsz
      if icnt > 216000 - 1000:
        icnt = 0
        tms = np.random.permutation(np.arange(216000 - 1000))

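# The training()/proj() pair above is projected gradient descent: after each
# optimizer step, constrained variables are clipped back to the feasible
# set. Note (x + |x|)/2 = max(x, 0), so the assigns project elementwise onto
# the nonnegative orthant. A minimal NumPy sketch of the same loop structure
# on a toy quadratic loss (shapes and step size made up):
def _projected_gd_sketch():
  rng = np.random.RandomState(0)
  target = rng.randn(5)
  a = rng.randn(5)  # constrained to stay nonnegative, like 'a' above
  step_sz = 0.1
  for _ in range(1000):
    grad = 2 * (a - target)  # gradient of ||a - target||^2
    a = a - step_sz * grad  # gradient step
    a = (a + np.abs(a)) / 2  # projection: max(a, 0)
  return a  # ~= np.maximum(target, 0)
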
def get_next_training_batch(iteration):
  # Returns a new batch of training data: stimulus and response arrays.
  # Global stimulus/response variables hold the currently loaded chunk,
  # and chunk_order stores where we are in the list of training chunks.
  # Each chunk might have multiple training batches, so go through all
  # batches in a chunk before moving on to the next chunk.
  global stim_train_part
  global resp_train_part
  global chunk_order

  togo = True
  while togo:
    if iteration % FLAGS.n_b_in_c == 0:
      # Iteration is a multiple of the number of batches in a chunk:
      # finished going through a chunk, so load a new chunk of data.
      ichunk = (iteration / FLAGS.n_b_in_c) % (
          FLAGS.train_len - 1)  # -1 as the last chunk is used for testing
      if ichunk == 0:
        # If starting over the chunks again, shuffle the chunk order.
        chunk_order = np.random.permutation(np.arange(FLAGS.train_len))

      # Skip the first chunk - it was weird for the dataset used.
      if chunk_order[ichunk] + 1 != 1:
        filename = (FLAGS.data_location + 'Off_par_data_' +
                    str(chunk_order[ichunk] + 1) + '.mat')
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_train_part = data['maskedMovdd_part']  # stimulus
        resp_train_part = data['Y_part']  # response

      ichunk = chunk_order[ichunk] + 1
      while stim_train_part.shape[1] < FLAGS.batchsz:
        # If the currently loaded data is smaller than the batch size,
        # load more chunks.
        if ichunk > FLAGS.n_chunks:
          ichunk = 2
        filename = (FLAGS.data_location + 'Off_par_data_' +
                    str(ichunk) + '.mat')
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_train_part = np.append(stim_train_part,
                                    data['maskedMovdd_part'], axis=1)
        resp_train_part = np.append(resp_train_part, data['Y_part'],
                                    axis=1)
        ichunk = ichunk + 1

    ibatch = iteration % FLAGS.n_b_in_c  # section of the chunk to use
    try:
      stim_train = np.array(
          stim_train_part[:, ibatch:ibatch + FLAGS.batchsz],
          dtype='float32').T
      resp_train = np.array(
          resp_train_part[:, ibatch:ibatch + FLAGS.batchsz],
          dtype='float32').T
      togo = False
    except:
      iteration = np.random.randint(1, 100000)
      print('Load exception iteration: ' + str(iteration) +
            ' chunk: ' + str(chunk_order[ichunk]) +
            ' batch: ' + str(ibatch))
      togo = True

  stim_train = stim_train[:, chosen_mask]
  resp_train = resp_train[:, cells_choose]
  return stim_train, resp_train, FLAGS.batchsz

def linear_rank1_models(sr_model, sess, dimx, dimy, n_cells,
                        center_locations, cell_masks, firing_rates,
                        cell_type, time_window=30):
  """Learn a linear rank-1 (space-time separable) stimulus-response metric.

  The metric is d(s, r) = -r' F s g, with r = response, s = space x time,
  g a time filter, and F (# cells x space) a spatial filter.

  g is initialized with a precomputed average time filter across multiple
  retinas. F is initialized with signed gaussians at the center locations
  such that it forms a mosaic.

  g, F are learned by minimizing the distance between positive examples and
  all the negatives in a batch. Since we think F must be spatially
  localized, we project at each step to minimize a locally normalized L1
  regularization.

  Args:
    sr_model: The variant of linear model ('lin_rank1' or
      'lin_rank1_blind').
    sess: Tensorflow session.
    dimx: X dimension of the stimulus.
    dimy: Y dimension of the stimulus.
    n_cells: Number of cells.
    center_locations: Location of centers for cells in the 2D grid
      (dimx x dimy x n_cells).
    cell_masks: Masks for the spatial filter of each cell
      (dimx x dimy x n_cells).
    firing_rates: Mean firing rate of different cells (# cells).
    cell_type: Cell type of each cell (# cells).
    time_window: Length of stimulus (in time).

  Returns:
    sr_graph: Container of the embedding parameters and losses.

  Raises:
    ValueError: If the model type is not supported.
  """
  # Embed stimulus: batch x X x Y x time_window.
  stim_tf = tf.placeholder(tf.float32,
                           shape=[None, dimx, dimy, time_window])
  pos_resp = tf.placeholder(dtype=tf.float32, shape=[None, n_cells, 1],
                            name='pos')
  neg_resp = tf.placeholder(dtype=tf.float32, shape=[None, n_cells, 1],
                            name='neg')

  # n_cells x number of selected cells.
  select_cells_mat = tf.placeholder(dtype=tf.float32,
                                    shape=[n_cells, None],
                                    name='selected_cells')

  # Declare spatial and temporal filter variables.
  k_init = approximate_rfs_from_centers(center_locations, cell_type,
                                        firing_rates, dimx, dimy, n_cells)
  ttf_data = pickle.load(gfile.Open(FLAGS.ttf_file, 'r'))
  ttf_init = ttf_data['ttf'].astype(np.float32)
  k_tf = tf.Variable(k_init.astype(np.float32))
  ttf_tf = tf.Variable(ttf_init.astype(np.float32))

  # Do filtering in space and time.
  tfd = tf.expand_dims
  ttf_4d = tfd(tfd(tfd(ttf_tf, 0), 0), 3)
  stim_time_filtered = tf.nn.conv2d(stim_tf, ttf_4d, strides=[1, 1, 1, 1],
                                    padding='VALID')  # batch x X x Y x 1
  stim_time_filtered_3d = stim_time_filtered[:, :, :, 0]
  stim_time_filtered_2d = tf.reshape(stim_time_filtered_3d,
                                     [-1, dimx * dimy])
  k_tf_flat = tf.reshape(k_tf, [dimx * dimy, n_cells])
  lam_raw = tf.matmul(stim_time_filtered_2d, k_tf_flat)  # batch x n_cells

  # Select cells - all outputs of size batch x n_selected_cells.
  pos_resp_sel = tf.matmul(pos_resp[:, :, 0], select_cells_mat)
  neg_resp_sel = tf.matmul(neg_resp[:, :, 0], select_cells_mat)
  lam_raw_sel = tf.matmul(lam_raw, select_cells_mat)

  # Loss: log(1 + \sum_j exp(d+ - dj-)).
  beta = FLAGS.beta
  d_pos = -tf.reduce_mean(pos_resp_sel * lam_raw_sel, 1)  # batch
  d_pairwise_neg = -tf.reduce_mean(
      tfd(neg_resp_sel, 0) * tfd(lam_raw_sel, 1), 2)  # batch
  difference = (tf.expand_dims(d_pos / beta, 1) -
                d_pairwise_neg / beta)  # positives x negatives
  difference_padded = tf.pad(difference, [[0, 0], [0, 1]])
  loss = tf.reduce_sum(beta * tf.reduce_logsumexp(difference_padded, 1), 0)

  accuracy_tf = tf.reduce_mean(
      tf.sign(-tf.expand_dims(d_pos, 1) + d_pairwise_neg))

  ## Train model.
  if sr_model == 'lin_rank1':
    train_op_part = tf.train.AdagradOptimizer(
        FLAGS.learning_rate).minimize(loss)

    # Project K after each step.
    with tf.control_dependencies([train_op_part]):
      # Alternative: locally reweighted L1 for spatial locality, e.g.
      # project_spatially_locaized_l1(k_tf, dimx, dimy, n_cells,
      #                               lnl1_reg=0.00001, eps_neigh=0.01).
      # Here: project K to the mask.
      cell_masks_tf = tf.constant(cell_masks.astype(np.float32))
      proj_k = tf.assign(k_tf, k_tf * cell_masks_tf)
    train_op = tf.group(train_op_part, proj_k)

  elif sr_model == 'lin_rank1_blind':
    # Only the time filter would be trained in the blind model:
    # train_op = tf.train.AdagradOptimizer(
    #     FLAGS.learning_rate).minimize(loss, var_list=[ttf_tf])
    train_op = []  # Blind model not trained.

  else:
    raise ValueError('Only lin_rank1 and lin_rank1_blind supported')

  # Store everything in a graph.
  sr_graph = collections.namedtuple(
      'SR_Graph', 'sess train_op '
      'select_cells_mat '
      'd_s_r_pos d_pairwise_s_rneg '
      'loss accuracy_tf stim_tf '
      'pos_resp neg_resp ttf_tf k_tf ')
  sr_graph = sr_graph(sess=sess, train_op=train_op,
                      select_cells_mat=select_cells_mat,
                      d_s_r_pos=d_pos,
                      d_pairwise_s_rneg=d_pairwise_neg,
                      loss=loss, accuracy_tf=accuracy_tf, stim_tf=stim_tf,
                      pos_resp=pos_resp, neg_resp=neg_resp,
                      ttf_tf=ttf_tf, k_tf=k_tf)
  return sr_graph

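# The loss in linear_rank1_models is a softmax margin over negatives: for a
# positive distance d+ and negatives d-_j it is
#   beta * log(1 + sum_j exp((d+ - d-_j) / beta)),
# where the padded zero column supplies the '1'. A NumPy sketch with
# made-up distances:
def _softmax_margin_sketch():
  from scipy.special import logsumexp
  beta = 1.
  d_pos = np.array([0.5, 1.2])  # one positive distance per example
  d_neg = np.array([[1.0, 2.0, 0.4],
                    [1.5, 0.9, 2.2]])  # positives x negatives

  difference = (d_pos[:, np.newaxis] - d_neg) / beta
  # Pad a zero column so logsumexp gives log(1 + sum_j exp(.)).
  difference_padded = np.pad(difference, [(0, 0), (0, 1)], 'constant')
  return (beta * logsumexp(difference_padded, 1)).sum()
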
def test_encoding(training_datasets, testing_datasets, responses, stimuli,
                  sr_graph, sess, file_name, sample_fcn):
  print('Testing for encoding model')
  tf.logging.info('Testing for encoding model')

  # Saving filename.
  save_analysis_filename = os.path.join(
      FLAGS.save_folder, file_name + '_analysis_sample_resps')
  retina_ids = [resp['piece'] for resp in responses]

  # 4) Find `retina_embed` for a test retina.
  # Find the mean embedding of the training retinas.
  latent_dimensionality = np.int(FLAGS.resp_layers.split(',')[-2])
  batch_sz = 500
  ret_params_list = []
  for idataset in training_datasets:
    if idataset >= len(responses):
      continue
    print(idataset)
    rng = numpy.random.RandomState(23)
    feed_dict = sample_fcn.batch(stimuli, responses, idataset, sr_graph,
                                 batch_pos_sz=batch_sz,
                                 batch_neg_sz=batch_sz,
                                 batch_type='test', if_continuous=False,
                                 rng=rng)
    op = sess.run(sr_graph.retina_params, feed_dict=feed_dict)
    ret_params_list += [op]
  ret_params_list = np.array(ret_params_list)
  ret_params_init = np.mean(ret_params_list, 0)  # Use for initialization.

  # Now optimize ret_params for each test retina.
  loss_arbit_ret_params = sr_graph.loss_arbit_ret_params
  ret_params_grad = tf.gradients(loss_arbit_ret_params,
                                 sr_graph.retina_params_arbitrary)
  ret_params_dict = {}
  for idataset in testing_datasets:
    rng = numpy.random.RandomState(23)
    ret_params_new = np.copy(ret_params_init)
    lr = 0.001
    ret_log = []
    for iiter in range(100):
      feed_dict = sample_fcn.batch(stimuli, responses, idataset, sr_graph,
                                   batch_pos_sz=500, batch_neg_sz=0,
                                   batch_type='train', if_continuous=True,
                                   rng=rng)
      feed_dict.update({sr_graph.retina_params_arbitrary: ret_params_new})
      delta_ret_param, l_np = sess.run(
          [ret_params_grad, loss_arbit_ret_params], feed_dict=feed_dict)
      print('Retina: %d, step: %d, loss: %.3f, Ret_params: %s' %
            (idataset, iiter, l_np, ret_params_new))
      ret_log += [np.copy(ret_params_new)]
      ret_params_new -= lr * delta_ret_param[0]
    dataset_log = {'final_embedding': np.copy(ret_params_new),
                   'path': np.copy(ret_log)}
    ret_params_dict.update({idataset: dataset_log})
  pickle.dump(ret_params_dict,
              gfile.Open(save_analysis_filename + '_test6.pkl', 'w'))

  # Latent embedding of different retinas.
  # 3) Interpolate between latent representations and see how responses
  # change.
  interpolation_retinas_log = [[63, 5], [73, 65], [41, 83], [38, 57],
                               [72, 76], [48, 58], [2, 5]]
  batch_sz = 500
  save_dict_log = []
  for interpolation_retinas in interpolation_retinas_log:
    retina_params_end_pts = []
    cell_info = []
    valid_cell_log = []

    # 3a) Find the latent representation of each retina.
    for idataset in interpolation_retinas:
      rng = numpy.random.RandomState(23)
      feed_dict = sample_fcn.batch(stimuli, responses, idataset, sr_graph,
                                   batch_pos_sz=batch_sz,
                                   batch_neg_sz=batch_sz,
                                   batch_type='test', if_continuous=False,
                                   rng=rng)
      op = sess.run([sr_graph.retina_params, sr_graph.stim_tf,
                     sr_graph.anchor_model.embed_locations_original,
                     sr_graph.anchor_model.map_cell_grid_tf],
                    feed_dict=feed_dict)
      ret_params_np, stim_np, cell_locs, map_cell_grid = op
      retina_params_end_pts += [ret_params_np]

      rcct_log = []
      for icell in range(map_cell_grid.shape[2]):
        r, c = np.where(map_cell_grid[:, :, icell] > 0)
        ct = np.where(np.squeeze(cell_locs[r, c, :]) > 0)[0]
        rcct_log += [[r[0], c[0], ct[0]]]
      cell_info += [np.array(rcct_log)]
      valid_cell_log += [responses[idataset]['valid_cells']]

    # 3b) Now, interpolate and check the responses.
    fr_interpolate_log = []
    alpha_log = np.arange(0, 1.1, 0.1)
    for alpha in alpha_log:
      retina_params_interpolate = (alpha * retina_params_end_pts[0] +
                                   (1 - alpha) * retina_params_end_pts[1])
      feed_dict = {sr_graph.stim_tf: stim_np,
                   sr_graph.retina_params_arbitrary:
                       retina_params_interpolate}
      fr_interpolate = sess.run(
          sr_graph.response_pred_from_arbit_ret_params,
          feed_dict=feed_dict)
      fr_interpolate_log += [fr_interpolate]
    fr_interpolate_log = np.array(fr_interpolate_log)

    save_dict = {'interpolation_retinas': interpolation_retinas,
                 'alpha_log': alpha_log,
                 'fr_interpolate_log': fr_interpolate_log,
                 'cell_info': cell_info,
                 'valid_cell_log': valid_cell_log,
                 'retina_params_end_pts': retina_params_end_pts}
    save_dict_log += [save_dict]
  pickle.dump(save_dict_log,
              gfile.Open(save_analysis_filename + '_test5.pkl', 'w'))

  # 1) Predict responses of different retinas - training AND testing
  # retinas.
  tag_list = ['training_datasets', 'testing_datasets']
  results_tr_tst = {}
  for itr_tst, tr_test_datasets in enumerate([training_datasets,
                                              testing_datasets]):
    results_log = {}
    for idataset in tr_test_datasets:
      if idataset >= len(responses):
        continue
      print(idataset)
      rng = numpy.random.RandomState(23)
      feed_dict = sample_fcn.batch(stimuli, responses, idataset, sr_graph,
                                   batch_pos_sz=200, batch_neg_sz=200,
                                   batch_type='test', if_continuous=True,
                                   rng=rng)
      op = sess.run([sr_graph.fr_predicted,
                     sr_graph.anchor_model.embed_locations_original,
                     sr_graph.anchor_model.map_cell_grid_tf,
                     sr_graph.anchor_model.responses_tf,
                     sr_graph.retina_params], feed_dict=feed_dict)
      fr_pred_np, cell_locs, map_cell_grid, responses_np, ret_params_np = op

      # Map predicted firing rates on the grid back to individual cells.
      fr_pred_cell = np.squeeze(np.zeros_like(responses_np))
      for icell in range(responses_np.shape[1]):
        r, c = np.where(map_cell_grid[:, :, icell] > 0)
        ct = np.where(np.squeeze(cell_locs[r, c, :]) > 0)[0]
        r = r[0]
        c = c[0]
        ct = ct[0]
        fr_pred_cell[:, icell] = fr_pred_np[:, r, c, ct]

      # Find the stimulus.
      stim_np = feed_dict[sr_graph.stim_tf]
      t_len, dx, dy, t_depth = stim_np.shape  # e.g. 500 x 80 x 40 x 30
      stim_np_compressed = np.zeros((t_len + t_depth - 1, dx, dy))
      stim_np_compressed[:t_depth] = np.transpose(stim_np[0, :, :, :],
                                                  [2, 0, 1])
      for itm in np.arange(1, t_len):
        stim_np_compressed[itm + t_depth - 1, :, :] = stim_np[itm, :, :, 0]

      save_dict = {'fr_pred_cell': fr_pred_cell,
                   'responses_recorded': np.squeeze(responses_np),
                   'valid_cells': responses[idataset]['valid_cells'],
                   'ctype_1hot': responses[idataset]['ctype_1hot'],
                   'cell_locs': cell_locs,
                   'map_cell_grid': map_cell_grid,
                   'ret_params_np': ret_params_np,
                   'fr_pred_np': fr_pred_np,
                   'stimulus_key': responses[idataset]['stimulus_key']}
      results_tr_tst.update(
          {responses[idataset]['stimulus_key']: stim_np_compressed})
      results_log.update({idataset: save_dict})
    results_tr_tst.update({tag_list[itr_tst]: results_log})
  results_tr_tst.update({'retina_ids': retina_ids})
  pickle.dump(results_tr_tst,
              gfile.Open(save_analysis_filename + '_test3.pkl', 'w'))

  # 2) Is the latent representation consistent for each retina, across
  # responses?
  latent_dimensionality = np.int(FLAGS.resp_layers.split(',')[-2])
  batch_sz_list = [500]
  n_repeats = 1
  tag_list = ['training_datasets', 'testing_datasets']
  results_tr_tst = {}
  for itr_tst, tr_test_datasets in enumerate([training_datasets,
                                              testing_datasets]):
    results_log = {}
    for idataset in tr_test_datasets:
      if idataset >= len(responses):
        continue
      print(idataset)
      rng = numpy.random.RandomState(23)
      ret_params_np = np.zeros((len(batch_sz_list), n_repeats,
                                latent_dimensionality))
      for ibatch_sz, batch_sz in enumerate(batch_sz_list):
        for iresp in range(n_repeats):
          print(idataset, ibatch_sz, iresp)
          feed_dict = sample_fcn.batch(stimuli, responses, idataset,
                                       sr_graph, batch_pos_sz=batch_sz,
                                       batch_neg_sz=batch_sz,
                                       batch_type='test',
                                       if_continuous=False, rng=rng)
          op = sess.run(sr_graph.retina_params, feed_dict=feed_dict)
          print(op)
          ret_params_np[ibatch_sz, iresp, :] = op
      save_dict = {'ret_params_np': ret_params_np,
                   'batch_sz_list': batch_sz_list,
                   'valid_cells': responses[idataset]['valid_cells']}
      results_log.update({idataset: save_dict})
    results_tr_tst.update({tag_list[itr_tst]: results_log})
  results_tr_tst.update({'retina_ids': retina_ids})
  pickle.dump(results_tr_tst,
              gfile.Open(save_analysis_filename + '_test4.pkl', 'w'))

  # 3) Embedding of EIs.
  if hasattr(sr_graph, 'retina_params_from_ei'):
    latent_dimensionality = np.int(FLAGS.resp_layers.split(',')[-2])
    batch_sz = 500
    tag_list = ['training_datasets', 'testing_datasets']
    results_tr_tst = {}
    for itr_tst, tr_test_datasets in enumerate([training_datasets,
                                                testing_datasets]):
      results_log = {}
      for idataset in tr_test_datasets:
        if idataset >= len(responses):
          continue
        print(idataset)
        rng = numpy.random.RandomState(23)
        feed_dict = sample_fcn.batch(stimuli, responses, idataset,
                                     sr_graph, batch_pos_sz=batch_sz,
                                     batch_neg_sz=batch_sz,
                                     batch_type='test',
                                     if_continuous=False, rng=rng)
        op = sess.run([sr_graph.retina_params_from_ei,
                       sr_graph.retina_params], feed_dict=feed_dict)
        ret_params_from_ei_np, ret_params_np = op
        print(op)
        save_dict = {'ret_params_np': ret_params_np,
                     'ret_params_from_ei_np': ret_params_from_ei_np,
                     'valid_cells': responses[idataset]['valid_cells']}
        results_log.update({idataset: save_dict})
      results_tr_tst.update({tag_list[itr_tst]: results_log})
    results_tr_tst.update({'retina_ids': retina_ids})
    pickle.dump(results_tr_tst,
                gfile.Open(save_analysis_filename + '_test4_ei.pkl', 'w'))

def get_next_training_batch(iteration):
  # Returns a new batch of training data (FLAGS.batchsz samples of stimulus
  # and response). Global stimulus/response variables hold the currently
  # loaded chunk of data.
  global stim_train_part
  global resp_train_part
  global chunk_order

  togo = True
  while togo:
    if iteration % FLAGS.n_b_in_c == 0:
      # Load a new chunk of data.
      ichunk = (iteration / FLAGS.n_b_in_c) % (
          FLAGS.train_len - 1)  # the last chunk is used for testing
      if ichunk == 0:
        # Shuffle training chunks at the start of a pass over the data.
        chunk_order = np.random.permutation(np.arange(FLAGS.train_len))

      # Skip the first chunk - it was weird for the dataset used.
      if chunk_order[ichunk] + 1 != 1:
        filename = (FLAGS.data_location + 'Off_par_data_' +
                    str(chunk_order[ichunk] + 1) + '.mat')
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_train_part = data['maskedMovdd_part']
        resp_train_part = data['Y_part']

      ichunk = chunk_order[ichunk] + 1
      while stim_train_part.shape[1] < FLAGS.batchsz:
        # Need to add an extra chunk.
        if ichunk > FLAGS.n_chunks:
          ichunk = 2
        filename = (FLAGS.data_location + 'Off_par_data_' +
                    str(ichunk) + '.mat')
        file_r = gfile.Open(filename, 'r')
        data = sio.loadmat(file_r)
        stim_train_part = np.append(stim_train_part,
                                    data['maskedMovdd_part'], axis=1)
        resp_train_part = np.append(resp_train_part, data['Y_part'],
                                    axis=1)
        ichunk = ichunk + 1

    ibatch = iteration % FLAGS.n_b_in_c
    try:
      stim_train = np.array(
          stim_train_part[:, ibatch:ibatch + FLAGS.batchsz],
          dtype='float32').T
      resp_train = np.array(
          resp_train_part[:, ibatch:ibatch + FLAGS.batchsz],
          dtype='float32').T
      togo = False
    except:
      iteration = np.random.randint(1, 100000)
      print('Load exception iteration: ' + str(iteration) +
            ' chunk: ' + str(chunk_order[ichunk]) +
            ' batch: ' + str(ibatch))
      togo = True

  stim_train = stim_train[:, chosen_mask]
  resp_train = resp_train[:, cells_choose]
  return stim_train, resp_train, FLAGS.batchsz

def main(argv): print('\nCode started') global cells_choose global chosen_mask np.random.seed(FLAGS.np_randseed) random.seed(FLAGS.randseed) global chunk_order chunk_order = np.random.permutation(np.arange(FLAGS.n_chunks - 1)) ## Load data summary filename = FLAGS.data_location + 'data_details.mat' summary_file = gfile.Open(filename, 'r') data_summary = sio.loadmat(summary_file) cells = np.squeeze(data_summary['cells']) if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge' or FLAGS.model_id == 'poisson_relu': cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | ( cells == 3066) if FLAGS.model_id == 'poisson_full': cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool') n_cells = np.sum(cells_choose) tot_spks = np.squeeze(data_summary['tot_spks']) total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T tot_spks_chosen_cells = np.array(tot_spks[cells_choose], dtype='float32') chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0, dtype='bool') print(np.shape(chosen_mask)) print(np.sum(chosen_mask)) stim_dim = np.sum(chosen_mask) print(FLAGS.model_id) print('\ndataset summary loaded') # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells # decide the number of subunits to fit n_su = FLAGS.ratio_SU * n_cells # saving details if FLAGS.model_id == 'poisson': short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_tlen=' + str(FLAGS.train_len) + '_bg') else: short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_tlen=' + str(FLAGS.train_len) + '_bg') parent_folder = FLAGS.save_location + FLAGS.folder_name + '/' if not gfile.IsDirectory(parent_folder): gfile.MkDir(parent_folder) FLAGS.save_location = parent_folder + short_filename + '/' print(gfile.IsDirectory(FLAGS.save_location)) if not gfile.IsDirectory(FLAGS.save_location): gfile.MkDir(FLAGS.save_location) print(FLAGS.save_location) save_filename = FLAGS.save_location + short_filename with tf.Session() as sess: # Learn population model! stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim') resp = tf.placeholder(tf.float32, name='resp') data_len = tf.placeholder(tf.float32, name='data_len') if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full': # variables w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.01 * np.random.rand(n_cells, 1, n_su), dtype='float32')) lam = tf.transpose(tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a), 2)) #loss_inter = (tf.reduce_sum(lam/tot_spks_chosen_cells)/120. - tf.reduce_sum(resp*tf.log(lam)/tot_spks_chosen_cells)) / data_len loss_inter = (tf.reduce_sum(lam) / 120. - tf.reduce_sum(resp * tf.log(lam))) / data_len loss = loss_inter train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize( loss, var_list=[w, a]) def training(inp_dict): sess.run(train_step, feed_dict=inp_dict) def get_loss(inp_dict): ls = sess.run(loss, feed_dict=inp_dict) return ls if FLAGS.model_id == 'poisson_relu': # variables w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32')) b_init = np.log( (tot_spks_chosen_cells) / (216000. 
- tot_spks_chosen_cells))
      b = tf.Variable(b_init, dtype='float32')
      # The original had tf.nn.relu(stim, w), which is not a valid relu call;
      # the intended drive is assumed to be relu(stim . w), as in the other
      # branches. Note: b can still make f non-positive, giving NaNs in
      # tf.log; that behavior is left as in the original.
      f = tf.matmul(tf.exp(tf.nn.relu(tf.matmul(stim, w))), a) + b
      loss_inter = (tf.reduce_sum(f) / 120. -
                    tf.reduce_sum(resp * tf.log(f))) / data_len
      loss = loss_inter
      train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
          loss, var_list=[w, a])

      def training(inp_dict):
        sess.run(train_step, feed_dict=inp_dict)

      def get_loss(inp_dict):
        ls = sess.run(loss, feed_dict=inp_dict)
        return ls

    if FLAGS.model_id == 'logistic':
      w = tf.Variable(
          np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32'))
      a = tf.Variable(
          np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32'))
      b_init = np.log(tot_spks_chosen_cells /
                      (216000. - tot_spks_chosen_cells))
      b = tf.Variable(b_init, dtype='float32')
      f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
      loss_inter = tf.reduce_sum(
          tf.nn.softplus(-2 * (resp - 0.5) * f)) / data_len
      loss = loss_inter
      train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
          loss, var_list=[w, a, b])
      a_pos = tf.assign(a, (a + tf.abs(a)) / 2)  # project a onto a >= 0

      def training(inp_dict):
        sess.run(train_step, feed_dict=inp_dict)
        sess.run(a_pos)

      def get_loss(inp_dict):
        ls = sess.run(loss, feed_dict=inp_dict)
        return ls

    if FLAGS.model_id == 'hinge':
      w = tf.Variable(
          np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32'))
      a = tf.Variable(
          np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32'))
      b_init = np.log(tot_spks_chosen_cells /
                      (216000. - tot_spks_chosen_cells))
      b = tf.Variable(b_init, dtype='float32')
      f = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b
      loss_inter = tf.reduce_sum(
          tf.nn.relu(1 - 2 * (resp - 0.5) * f)) / data_len
      loss = loss_inter
      train_step = tf.train.AdagradOptimizer(FLAGS.step_sz).minimize(
          loss, var_list=[w, a, b])
      a_pos = tf.assign(a, (a + tf.abs(a)) / 2)  # project a onto a >= 0

      def training(inp_dict):
        sess.run(train_step, feed_dict=inp_dict)
        sess.run(a_pos)

      def get_loss(inp_dict):
        ls = sess.run(loss, feed_dict=inp_dict)
        return ls

    # summaries
    l_summary = tf.scalar_summary('loss', loss)
    l_inter_summary = tf.scalar_summary('loss_inter', loss_inter)
    # Merge all the summaries and write them out.
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter(FLAGS.save_location + 'train',
                                          sess.graph)
    test_writer = tf.train.SummaryWriter(FLAGS.save_location + 'test')

    print('\nStarting new code')
    sess.run(tf.initialize_all_variables())
    saver_var = tf.train.Saver(tf.all_variables(),
                               keep_checkpoint_every_n_hours=0.05)
    load_prev = False
    start_iter = 0
    try:
      latest_filename = short_filename + '_latest_fn'
      restore_file = tf.train.latest_checkpoint(FLAGS.save_location,
                                                latest_filename)
      start_iter = int(restore_file.split('/')[-1].split('-')[-1])
      saver_var.restore(sess, restore_file)
      load_prev = True
    except Exception:
      # No (or unreadable) previous checkpoint; start from scratch.
      print('No previous dataset')

    if load_prev:
      print('\nPrevious results loaded')
    else:
      print('\nVariables initialized')

    # Do the fitting
    icnt = 0
    stim_test, resp_test, test_length = get_test_data()
    fd_test = {stim: stim_test, resp: resp_test, data_len: test_length}

    for istep in np.arange(start_iter, 400000):
      print(istep)
      # get training data
      stim_train, resp_train, train_len = get_next_training_batch(istep)
      fd_train = {stim: stim_train, resp: resp_train, data_len: train_len}
      # take training step
      training(fd_train)

      if istep % 10 == 0:
        # compute training and testing losses
        ls_train = get_loss(fd_train)
        ls_test = get_loss(fd_test)
        latest_filename = short_filename + '_latest_fn'
        saver_var.save(sess, save_filename, global_step=istep,
                       latest_filename=latest_filename)
        # add training summary
        summary = sess.run(merged, feed_dict=fd_train)
        train_writer.add_summary(summary, istep)
        # add testing summary
        summary = sess.run(merged, feed_dict=fd_test)
        test_writer.add_summary(summary, istep)
        print(istep, ls_train, ls_test)

      # Leftover bookkeeping from an earlier in-memory batching scheme;
      # tms is not used by get_next_training_batch.
      icnt += FLAGS.batchsz
      if icnt > 216000 - 1000:
        icnt = 0
        tms = np.random.permutation(np.arange(216000 - 1000))
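# All four branches above share one skeleton: a subunit drive k_s . x, a
# pointwise nonlinearity, and a per-cell loss normalized by data length. As a
# reference, the Poisson branch's loss can be transcribed into plain numpy
# (a minimal sketch; the 1/120. factor mirrors the graph above, presumably
# the sampling rate in Hz):
import numpy as np


def poisson_loss_np(stim, resp, w, a):
  """Population Poisson loss with rate lam_c = sum_s exp(x.w_s + a_cs)."""
  # stim: T x stim_dim, resp: T x n_cells, w: stim_dim x n_su,
  # a: n_cells x 1 x n_su (broadcast over time, as in the TF graph above).
  gen = stim.dot(w)                                          # T x n_su
  lam = np.exp(gen[:, None, :] + a[None, :, 0, :]).sum(-1)   # T x n_cells
  return (lam.sum() / 120. - (resp * np.log(lam)).sum()) / stim.shape[0]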
def main(argv): logfile = gfile.Open( FLAGS.save_location + 'log_bias=' + str(FLAGS.bias_ratio) + '_lam_W=' + str(FLAGS.lam_W) + '_lam_a=' + str(FLAGS.lam_a) + '.txt', "w") logfile.write('Starting new thread\n') logfile.flush() print('\nlog file written once') np.random.seed(FLAGS.np_randseed) random.seed(FLAGS.randseed) #plt.ion() ## Load data file = h5py.File(FLAGS.data_location + 'Off_parasol.mat', 'r') logfile.write('\ndataset loaded') # Load Masked movie data = file.get('maskedMovdd') maskedMov = np.array(data) cells = file.get('cells') nCells = cells.shape[0] ttf_log = file.get('ttf_log') ttf_avg = file.get('ttf_avg') stimulus = maskedMov total_mask_log = file.get('totalMaskAccept_log') # Load spike Response of cells data = file.get('Y') biSpkResp_coll = np.array(data, dtype='float32') mask = np.array(np.ones(3200), dtype=bool) ## Nsub = FLAGS.ratio_SU * nCells StimDim = maskedMov.shape[1] # initialize subunits W_init = initialize_SU(nSU=Nsub) a_init = np.random.rand(Nsub, nCells) su_act = stimulus[10000:12000, :].dot(W_init) su_std = np.sqrt(np.diag(su_act.T.dot(su_act)) / stimulus.shape[0]) bias_init = FLAGS.bias_ratio * su_std print(bias_init) logfile.write('bias = ' + str(bias_init)) logfile.write('\nSU initialized') with tf.Session() as sess: stim = tf.placeholder(tf.float32, shape=[None, StimDim], name="stim") resp = tf.placeholder(tf.float32, name="resp") data_len = tf.placeholder(tf.float32, name="data_len") #W = tf.Variable(tf.random_uniform([StimDim,Nsub])) W = tf.Variable(np.array(W_init, dtype='float32')) a = tf.Variable(np.array(a_init, dtype='float32')) bias = tf.Variable(np.array(bias_init, dtype='float32')) #a = tf.Variable(np.identity(Nsub,dtype='float32')*0.01) lam = tf.matmul(tf.nn.relu(tf.matmul(stim, W) + bias), tf.nn.relu(a)) + 0.0001 # collapse a dimension loss = (tf.reduce_sum(lam) / 120. 
- tf.reduce_sum( resp * tf.log(lam))) / data_len + FLAGS.lam_W * tf.reduce_sum( tf.abs(W)) + FLAGS.lam_a * tf.reduce_sum(tf.abs(a)) train_step = tf.train.AdamOptimizer(1e-4).minimize( loss, var_list=[W, a, bias]) sess.run(tf.initialize_all_variables()) # Do the fitting batchSz = 100 icnt = 0 fd_test = { stim: stimulus.astype('float32')[216000 - 10000:216000 - 1, :], resp: biSpkResp_coll.astype('float32')[216000 - 10000:216000 - 1, :], data_len: 10000 } ls_train_log = np.array([]) ls_test_log = np.array([]) tms = np.random.permutation(np.arange(216000 - 1000)) for istep in range(100000): time_start = timeit.timeit() fd_train = { stim: stimulus.astype('float32')[tms[icnt:icnt + batchSz], :], resp: biSpkResp_coll.astype('float32')[tms[icnt:icnt + batchSz], :], data_len: batchSz } sess.run(train_step, feed_dict=fd_train) if istep % 10 == 0: ls_train = sess.run(loss, feed_dict=fd_train) ls_test = sess.run(loss, feed_dict=fd_test) ls_train_log = np.append(ls_train_log, ls_train) ls_test_log = np.append(ls_test_log, ls_test) logfile.write('\nIterations: ' + str(istep) + ' Training error: ' + str(ls_train) + ' Testing error: ' + str(ls_test)) logfile.flush() sio.savemat( FLAGS.save_location + 'data_bias=' + str(FLAGS.bias_ratio) + '_lam_W=' + str(FLAGS.lam_W) + '_lam_a' + str(FLAGS.lam_a) + '_ratioSU' + str(FLAGS.ratio_SU) + '_grid_spacing_' + str(FLAGS.su_grid_spacing) + '.mat', { 'bias_ratio': FLAGS.bias_ratio, 'bias_init': bias_init, 'bias': bias.eval(), 'W': W.eval(), 'a': a.eval(), 'W_init': W_init, 'a_init': a_init, 'ls_train_log': ls_train_log, 'ls_test_log': ls_test_log }) icnt = icnt + batchSz if icnt > 216000 - 10000: icnt = 0 tms = np.random.permutation(np.arange(216000 - 10000)) logfile.close()
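# The bias above is initialized as a fixed fraction of each subunit's RMS
# drive, so that relu(k_s . x + b_s) is active on a controlled fraction of
# stimuli. The initialization step in isolation (a minimal sketch; note that,
# as in the code above, the normalization uses the full record length rather
# than the length of the slice):
import numpy as np


def init_bias(stimulus, W_init, bias_ratio, t0=10000, t1=12000):
  """One bias per subunit, scaled to the subunit's RMS activation."""
  su_act = stimulus[t0:t1, :].dot(W_init)                      # T' x n_su
  su_std = np.sqrt(np.diag(su_act.T.dot(su_act)) / stimulus.shape[0])
  return bias_ratio * su_std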
def learn_lin_embedding(stimulus, responses, filename, lam_l1=0.01, beta=10,
                        time_window=30, lr=0.01):
  num_cell_types = 2
  dimx = 80
  dimy = 40

  leng = np.minimum(stimulus.shape[0], responses.shape[0])
  resp_short = responses[:leng, :]
  stim_short = np.reshape(stimulus[:leng, :, :], [leng, -1])
  # Initialize spatial filters from a (delayed) spike-triggered average.
  init_A = stim_short[:-4, :].T.dot(resp_short[4:, :])
  init_A_3d = np.reshape(init_A.T, [-1, stimulus.shape[1], stimulus.shape[2]])
  n_cells = responses.shape[1]

  with tf.Session() as sess:
    stim_tf = tf.placeholder(
        tf.float32,
        shape=[None, dimx, dimy, time_window])  # batch x X x Y x time_window

    # Linear filter in time.
    ttf = tf.Variable(np.ones((time_window, 1)).astype(np.float32))
    stim_tf_2d = tf.reshape(stim_tf, [-1, time_window])
    stim_filtered_2d = tf.matmul(stim_tf_2d, ttf)
    stim_filtered = tf.reshape(stim_filtered_2d, [-1, dimx, dimy, 1])
    stim_filtered = tf.gather(tf.transpose(stim_filtered, [3, 0, 1, 2]), 0)

    # Filter in space.
    # A = tf.Variable(np.ones((n_cells, dimx, dimy)).astype(np.float32))
    A = tf.Variable(init_A_3d.astype(np.float32))
    A_2d = tf.reshape(A, [n_cells, dimx * dimy])
    stim_filtered_perm_2d = tf.reshape(stim_filtered, [-1, dimx * dimy])
    stim_space_filtered_2d = tf.matmul(stim_filtered_perm_2d,
                                       tf.transpose(A_2d))
    stim_out = tf.expand_dims(stim_space_filtered_2d, 2)

    responses_anchor_tf = tf.placeholder(dtype=tf.float32,
                                         shape=[None, n_cells, 1],
                                         name='anchor')  # batch x n_cells x 1
    # The original also named this placeholder 'anchor'.
    responses_neg_tf = tf.placeholder(dtype=tf.float32,
                                      shape=[None, n_cells, 1],
                                      name='negative')  # batch x n_cells x 1

    d_s_r_pos = -tf.reduce_sum((stim_out * responses_anchor_tf)**2,
                               [1, 2])  # batch
    d_pairwise_s_rneg = -tf.reduce_sum(
        (tf.expand_dims(stim_out, 1) * tf.expand_dims(responses_neg_tf, 0))**2,
        [2, 3])  # batch x batch_neg
    difference = (tf.expand_dims(d_s_r_pos / beta, 1) -
                  d_pairwise_s_rneg / beta)  # positives x negatives

    # log(1 + \sum_j(exp(d+ - dj-))); the zero pad column provides the 1.
    difference_padded = tf.pad(difference, [[0, 0], [0, 1]])
    loss = tf.reduce_sum(beta * tf.reduce_logsumexp(difference_padded, 1), 0)

    accuracy_tf = tf.reduce_mean(
        tf.sign(-tf.expand_dims(d_s_r_pos, 1) + d_pairwise_s_rneg))

    # Note: the original reset lr = 0.01 here, overriding the argument.
    train_op = tf.train.AdagradOptimizer(lr).minimize(loss)
    # Proximal step for the L1 penalty on A (soft-thresholding).
    with tf.control_dependencies([train_op]):
      prox_op = tf.assign(A, tf.nn.relu(A - lam_l1) - tf.nn.relu(-A - lam_l1))
    update = tf.group(train_op, prox_op)

    # Now train
    sess.run(tf.global_variables_initializer())
    a_log = []
    for iiter in range(200000):
      stim_batch, resp_batch, resp_batch_neg = sample_datasets.get_batch(
          stimulus, responses, batch_size=100, batch_neg_resp=100,
          stim_history=time_window, min_window=10, batch_type='train')
      feed_dict = {
          stim_tf: stim_batch.astype(np.float32),
          responses_anchor_tf:
              np.expand_dims(resp_batch, 2).astype(np.float32),
          responses_neg_tf:
              np.expand_dims(resp_batch_neg, 2).astype(np.float32)}
      _, l, a = sess.run([update, loss, accuracy_tf], feed_dict=feed_dict)
      a_log += [a]
      if iiter % 10 == 0:
        print(a)
      if iiter % 1000 == 0:
        save_dict = {'A': sess.run(A), 'ttf': sess.run(ttf)}
        pickle.dump(save_dict,
                    gfile.Open(os.path.join(FLAGS.save_folder, filename), 'w'))

    return [sess.run(A), sess.run(ttf)]
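# The loss above is a softened triplet ranking loss:
# sum_i log(1 + sum_j exp((d_i+ - d_ij-)/beta)), implemented by padding the
# difference matrix with a zero column before logsumexp. The same computation
# in plain numpy, for reference (the distances here play the role of the
# negated squared dot products in the graph above):
import numpy as np
from scipy.special import logsumexp


def softmax_triplet_loss(d_pos, d_neg, beta=10.0):
  """d_pos: (B,) anchor-positive distances; d_neg: (B, Bneg) to negatives."""
  diff = (d_pos[:, None] - d_neg) / beta                  # B x Bneg
  padded = np.concatenate([diff, np.zeros((diff.shape[0], 1))], axis=1)
  return np.sum(beta * logsumexp(padded, axis=1))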
def main(unused_argv=()): # set random seed np.random.seed(121) print('random seed reset') # Get details of stored model. model_savepath, model_filename = config.get_filepaths() # Load responses to two trials of long white noise. data_wn = du.DataUtilsMetric(os.path.join(FLAGS.data_path, FLAGS.data_test)) # Quadratic score function. with tf.Session() as sess: # Define and restore/initialize the model. tf.logging.info('Model : %s ' % FLAGS.model) met = config.get_model(sess, model_savepath, model_filename, data_wn, True) print('IS_TRAINING = TRUE!!! ') tf.logging.info('IS_TRAINING = TRUE!!! ') PrintModelAnalysis(tf.get_default_graph()) # get triplets # triplet A outputs = data_wn.get_triplets(batch_size=FLAGS.batch_size_test, time_window=FLAGS.time_window) anchor_test, pos_test, neg_test, _, _, _ = outputs triplet_a = (anchor_test, pos_test, neg_test) # triplet B outputs = data_wn.get_tripletsB(batch_size=FLAGS.batch_size_test, time_window=FLAGS.time_window) anchor_test, pos_test, neg_test, _, _, _ = outputs triplet_b = (anchor_test, pos_test, neg_test) triplets = [triplet_a, triplet_b] triplet_labels = ['triplet A', 'triplet B'] analysis_results = {} # collect analysis results in a dictionary # 1. Plot distances between positive and negative pairs. # analyse.plot_pos_neg_distances(met, anchor_test, pos_test, neg_test) # tf.logging.info('Distances plotted') # 2. Accuracy of triplet orderings - fraction of triplets where # distance with positive is smaller than distance with negative. triplet_dict = {} for iitriplet, itriplet in enumerate(triplets): dist_pos, dist_neg, accuracy = analyse.compute_distances(met, *itriplet) dist_analysis = {'pos': dist_pos, 'neg': dist_neg, 'accuracy': accuracy} triplet_dict.update({triplet_labels[iitriplet]: dist_analysis}) analysis_results.update({'distances': triplet_dict}) tf.logging.info('Accuracy computed') # 3. Precision-Recall analysis : declare positive if s(x,y)<t and # negative otherwise. Vary threshold t, and plot precision-recall and # ROC curves. triplet_dict = {} for iitriplet, itriplet in enumerate(triplets): output = analyse.precision_recall(met, *itriplet, toplot=False) precision_log, recall_log, f1_log, fpr_log, tpr_log, pr_data = output pr = {'precision': precision_log, 'recall': recall_log, 'pr_data': pr_data} roc = {'TPR': tpr_log, 'FPR': fpr_log} pr_results = {'PR': pr, 'F1': f1_log, 'ROC': roc} triplet_dict.update({triplet_labels[iitriplet]: pr_results}) analysis_results.update({'PR_analysis': triplet_dict}) tf.logging.info('Precision Recall, F1 score and ROC curves computed') # 4. Clustering analysis: How well clustered are responses for a stimulus? # Get all trials for a few (1000) stimuli and compute # distances between all pairs of points. # See how many of responses generated by same stimulus are actually # near to each other. 
n_tests = 10 p_log = [] r_log = [] s_log = [] resp_log = [] dist_log = [] embedding_log = [] for itest in range(n_tests): n_stims = 10 # previously 100 tf.logging.info('Number of random samples is : %d' % n_stims) resp_fcn = data_wn.get_response_all_trials resp_all_trials, stim_id = resp_fcn(n_stims, FLAGS.time_window, random_seed=itest) # TODO(bhaishahster) : Remove duplicates from resp_all_trials distance_pairs = analyse.get_pairwise_distances(met, resp_all_trials) k_log = [1, 2, 3, 4, 5, 10, 15, 20, 50, 75, 100, 200, 300, 400, 500] precision_log = [] recall_log = [] for k in k_log: precision, recall = analyse.topK_retrieval(distance_pairs, k, stim_id) precision_log += [precision] recall_log += [recall] p_log += [precision_log] r_log += [recall_log] s_log += [stim_id] resp_log += [resp_all_trials] dist_log += [distance_pairs] #tf.logging.info('Getting 2D t-SNE embedding') #model = manifold.TSNE(n_components=2) #tSNE_embedding = model.fit_transform(distance_pairs) #embedding_log += [tSNE_embedding] all_trials = {'distances': dist_log, 'K': k_log, 'precision': p_log, 'recall': r_log, 'probe_stim_idx': s_log, 'probes': resp_log, 'embedding': embedding_log} analysis_results_clustering = {'all_trials': all_trials} pickle_file_clustering = (os.path.join(model_savepath, model_filename) + '_' + FLAGS.data_test + '_analysis_clustering.pkl') pickle.dump(analysis_results_clustering, gfile.Open(pickle_file_clustering, 'w')) tf.logging.info('Clustering analysis done.') ''' # sample few/all repeats of stimuli which are continous. repeats = data_wn.get_repeats() n_samples_max = 10 samples = np.random.randint(0, repeats.shape[0], np.minimum(n_samples_max, repeats.shape[0])) n_start_times = 5 time_window = 15 resps_cont = np.zeros((n_start_times, n_samples_max, time_window, repeats.shape[-1])) from IPython import embed; embed() for istart in range(n_start_times): start_tm = np.random.randint(repeats.shape[1] - time_window) resps_cont[istart, :, :, :] = repeats[samples, start_tm: start_tm+time_window, :] resps_cont_2d = np.reshape(resps_cont, [-1, resps_cont.shape[-1]]) resps_cont_3d = np.expand_dims(resps_cont_2d, 2) distances_cont_resp = analyse.get_pairwise_distances(met, resps_cont_3d) n_components = 2 model = manifold.TSNE(n_components=n_components) ts = model.fit_transform(distances_cont_resp) tts = np.reshape(ts, [n_start_times, n_samples_max, time_window, n_components]) from IPython import embed; embed() plt.figure() for istart in [1]: # range(n_start_times): for isample in range(n_samples_max): pts = tts[istart, isample, :, :] plt.plot(pts[:, 0], pts[:, 1]) plt.show() ''' # 5. Store the parameters of the score function. score_params = met.get_parameters() analysis_results.update({'score_params': score_params}) tf.logging.info('Got interesting parameters of score') # 6. Retreival analysis on training data. # Retrieve the nearest responses in training data for a probe test response. # Load training data. # data_wn_train = du.DataUtilsMetric(os.path.join(FLAGS.data_path, # FLAGS.data_train)) # # out_data = data_wn_train.get_all_responses(FLAGS.time_window) # train_all_resp, train_stim_time = out_data # # # Get a few test stimuli. Here we use all repreats of a few stimuli. 
# n_stims = 100 # resp_all_trials, stim_id = data_wn.get_response_all_trials(n_stims, # FLAGS.time_window) # k = 1000 # retrieved, retrieved_stim = analyse.topK_retrieval_probes(train_all_resp, # train_stim_time, # resp_all_trials, # k, met) # retrieval_dict = {'probe': resp_all_trials, 'probe_stim_idx': stim_id, # 'retrieved': retrieved, # 'retrieved_stim_idx': retrieved_stim} # analysis_results.update({'retrieval': retrieval_dict}) # tf.logging.info('Retrieved nearest points in training data' # ' for some probes in test data') # TODO(bhaishahster) : Decode stimulus using retrieved responses. # 7. Learn encoding model. # Learn mapping from stimulus to response. # from IPython import embed; embed() ''' data_wn_train = du.DataUtilsMetric(os.path.join(FLAGS.data_path, 'example_long_wn_2rep_' 'ON_OFF_with_stim.mat')) data_wn_test = du.DataUtilsMetric(os.path.join(FLAGS.data_path, 'example_wn_30reps_ON_' 'OFF_with_stimulus.mat')) stimulus_test = data_wn_test.get_stimulus() response_test = data_wn_test.get_repeats() stimulus = data_wn_train.get_stimulus() response = data_wn_train.get_repeats() ttf = data_wn_train.ttf[::-1] encoding_fcn = encoding_model.learn_encoding_model_ln # Initialize ttf, RF using ttf and scale ttf to match firing rate RF_np, ttf_np, model = encoding_fcn(sess, met, stimulus, response, ttf_in=ttf, lr=0.1) firing_rate_pred = sess.run(model.firing_rate, feed_dict={model.stimulus: stimulus_test}) initialize_all = {'RF': RF_np, 'ttf': ttf, 'firing_rate_test': firing_rate_pred} # Initialize ttf and do no other initializations RF_np_noinit, ttf_np_noinit, model = encoding_fcn(sess,met, stimulus, response, ttf_in=ttf, initialize_RF_using_ttf=False, scale_ttf=False, lr=0.1) firing_rate_pred = sess.run(model.firing_rate, feed_dict={model.stimulus: stimulus_test}) initialize_only_ttf = {'RF': RF_np_noinit, 'ttf': ttf_np_noinit, 'firing_rate_test': firing_rate_pred} # Initialize ttf and do no other initializations RF_np_noinit2, ttf_np_noinit2, model = encoding_fcn(sess, met, stimulus, response, ttf_in=None, initialize_RF_using_ttf=False, scale_ttf=False, lr=0.1) firing_rate_pred = sess.run(model.firing_rate, feed_dict={model.stimulus: stimulus_test}) initialize_none = {'RF': RF_np_noinit2, 'ttf': ttf_np_noinit2, 'firing_rate_test': firing_rate_pred} encoding_models = {'Init_all': initialize_all, 'Init_ttf': initialize_only_ttf, 'Init_none': initialize_none, 'responses_test': response_test} analysis_results.update({'Encoding_models': encoding_models}) ''' # 8. Is similarity in images implicitly learnt in the metric ? # Reconstruction done in colab notebook ''' class StimulusMetric(object): """Compute MSE between stimuli.""" def get_distance(self, in1, in2): return np.sqrt(np.sum(np.sum((in1 - in2)**2, 2), 1)) # TODO(bhaishahster) : Filtering by time is remaining! stimuli_met = StimulusMetric() stim_distance, resp_distance, times, responses = analyse.compare_stimulus_score_similarity(data_wn, stimuli_met, met) compare_stim_mse_resp_met = {'stimulus_mse': stim_distance, 'response_metric': resp_distance, 'times': times, 'response_pairs': responses} analysis_results.update({'perception': compare_stim_mse_resp_met}) ''' # 9. Retrieve nearest responses from ALL possible response patterns # Retrieve the nearest responses in training data for a probe test response. ''' import itertools lst = list(map(list, itertools.product([0, 1], repeat=data_wn.n_cells))) all_resp = np.array(lst) all_resp = np.expand_dims(all_resp, 2) # Get a few test stimuli. Here we use all repreats of a few stimuli. 
n_stims = 100 probe_responses, stim_id = data_wn.get_response_all_trials(n_stims, FLAGS.time_window) distances_corpus = analyse.compute_all_distances(all_resp, probe_responses, met) retrieval_dict = {'probe': probe_responses, 'probe_stim_idx': stim_id, 'corpus': all_resp, 'distance_corpus': distances_corpus} analysis_results.update({'retrieval_ALL_responses': retrieval_dict}) tf.logging.info('Distance of probe to ALL possible response patterns') ''' # 10. Get embedding for all possible responses, # only if there are less than 15 cells if data_wn.n_cells < 15: import itertools lst = list(map(list, itertools.product([0, 1], repeat=data_wn.n_cells))) all_resp = np.expand_dims(np.array(lst), 2) # use time_window of 1. all_resp_embedding = met.get_embedding(all_resp) analysis_results.update({'all_resp_embedding': all_resp_embedding}) # save analysis in a pickle file # from IPython import embed; embed() pickle_file = (os.path.join(model_savepath, model_filename) + '_' + FLAGS.data_test + '_analysis.pkl') pickle.dump(analysis_results, gfile.Open(pickle_file, 'w')) # pickle.dump(analysis_results, file_io.FileIO(pickle_file, 'w')) tf.logging.info('File: ' + pickle_file) tf.logging.info('Analysis results saved') print('File: ' + pickle_file)
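# Step 4 above asks: among the K nearest neighbours of each response, how many
# come from a repeat of the same stimulus? A minimal sketch of precision and
# recall at K computed from a pairwise-distance matrix follows; the actual
# analyse.topK_retrieval may differ in tie-breaking and in whether the query
# itself is excluded, so treat this only as an illustration of the idea.
import numpy as np


def topk_precision_recall(distances, k, stim_id):
  """distances: N x N pairwise distances; stim_id: (N,) stimulus label."""
  precisions, recalls = [], []
  for i in range(distances.shape[0]):
    order = np.argsort(distances[i])
    order = order[order != i][:k]                    # exclude the query itself
    hits = np.sum(stim_id[order] == stim_id[i])
    relevant = np.sum(stim_id == stim_id[i]) - 1     # other same-stimulus trials
    precisions.append(hits / float(k))
    recalls.append(hits / float(max(relevant, 1)))
  return np.mean(precisions), np.mean(recalls)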
def main(argv): #plt.ion() # interactive plotting window = FLAGS.window n_pix = (2 * window + 1)**2 dimx = np.floor(1 + ((40 - (2 * window + 1)) / FLAGS.stride)).astype('int') dimy = np.floor(1 + ((80 - (2 * window + 1)) / FLAGS.stride)).astype('int') nCells = 107 # load model # load filename print(FLAGS.model_id) with tf.Session() as sess: if FLAGS.model_id == 'relu': # lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0 short_filename = ('data_model=' + str(FLAGS.model_id) + '_lam_w=' + str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) + '_ratioSU=' + str(FLAGS.ratio_SU) + '_grid_spacing=' + str(FLAGS.su_grid_spacing) + '_normalized_bg') w = tf.Variable( np.array(np.random.randn(3200, 749), dtype='float32')) a = tf.Variable( np.array(np.random.randn(749, 107), dtype='float32')) if FLAGS.model_id == 'relu_window': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) # exp 5 a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_sfm': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_sfm_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.01 + 0.005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.02 + np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_mother_exp': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w_del = tf.Variable( np.array(0.1 + 0.05 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) w_mother = tf.Variable( np.array(np.ones((2 * window + 1, 2 * window + 1, 1, 1)), dtype='float32')) a = tf.Variable( np.array(np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'relu_window_a_support': short_filename = ('data_model=' 
+ str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.002 * np.random.rand(dimx * dimy, nCells), dtype='float32')) if FLAGS.model_id == 'exp_window_a_support': short_filename = ('data_model=' + str(FLAGS.model_id) + '_window=' + str(FLAGS.window) + '_stride=' + str(FLAGS.stride) + '_lam_w=' + str(FLAGS.lam_w) + '_bg') w = tf.Variable( np.array(0.001 + 0.0005 * np.random.rand(dimx, dimy, n_pix), dtype='float32')) a = tf.Variable( np.array(0.002 * np.random.rand(dimx * dimy, nCells), dtype='float32')) parent_folder = FLAGS.save_location + FLAGS.folder_name + '/' FLAGS.save_location = parent_folder + short_filename + '/' # get relevant files file_list = gfile.ListDirectory(FLAGS.save_location) save_filename = FLAGS.save_location + short_filename print('\nLoading: ', save_filename) bin_files = [] meta_files = [] for file_n in file_list: if re.search(short_filename + '.', file_n): if re.search('.meta', file_n): meta_files += [file_n] else: bin_files += [file_n] #print(bin_files) print(len(meta_files), len(bin_files), len(file_list)) # get iteration numbers iterations = np.array([]) for file_name in bin_files: try: iterations = np.append( iterations, int(file_name.split('/')[-1].split('-')[-1])) except: print('Could not load filename: ' + file_name) iterations.sort() print(iterations) iter_plot = iterations[-1] print(int(iter_plot)) # load tensorflow variables saver_var = tf.train.Saver(tf.all_variables()) restore_file = save_filename + '-' + str(int(iter_plot)) saver_var.restore(sess, restore_file) # plot subunit - cell connections plt.figure() plt.cla() plt.imshow(a.eval(), cmap='gray', interpolation='nearest') print(np.shape(a.eval())) plt.title('Iteration: ' + str(int(iter_plot))) plt.show() plt.draw() # plot all subunits on 40x80 grid try: wts = w.eval() for isu in range(100): fig = plt.subplot(10, 10, isu + 1) plt.imshow(np.reshape(wts[:, isu], [40, 80]), interpolation='nearest', cmap='gray') plt.title('Iteration: ' + str(int(iter_plot))) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w full does not exist? ') # plot a few subunits - wmother + wdel try: wts = w.eval() print('wts shape:', np.shape(wts)) icnt = 1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]), (2 * window + 1, 2 * window + 1)), interpolation='nearest', cmap='gray') icnt = icnt + 1 fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.show() plt.draw() except: print('w does not exist?') # plot wmother try: w_mot = np.squeeze(w_mother.eval()) print(w_mot) plt.imshow(w_mot, interpolation='nearest', cmap='gray') plt.title('Mother subunit') plt.show() plt.draw() except: print('w mother does not exist') # plot wmother + wdel try: w_mot = np.squeeze(w_mother.eval()) w_del = np.squeeze(w_del.eval()) wts = np.array(np.random.randn(dimx, dimy, (2 * window + 1)**2)) for idimx in np.arange(dimx): print(idimx) for idimy in np.arange(dimy): wts[idimx, idimy, :] = np.ndarray.flatten(w_mot) + w_del[idimx, idimy, :] except: print('w mother + w delta do not exist? 
') ''' try: icnt=1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow(np.reshape(np.squeeze(wts[idimx, idimy, :]), (2*window+1,2*window+1)), interpolation='nearest', cmap='gray') fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w mother + w delta plotting error? ') # plot wdel try: w_del = np.squeeze(w_del.eval()) icnt=1 for idimx in np.arange(dimx): for idimy in np.arange(dimy): fig = plt.subplot(dimx, dimy, icnt) plt.imshow( np.reshape(w_del[idimx, idimy, :], (2*window+1,2*window+1)), interpolation='nearest', cmap='gray') icnt = icnt+1 fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) except: print('w delta do not exist? ') plt.suptitle('Iteration: ' + str(int(iter_plot))) plt.show() plt.draw() ''' # select a cell, and show its subunits. #try: ## Load data summary, get mask filename = FLAGS.data_location + 'data_details.mat' summary_file = gfile.Open(filename, 'r') data_summary = sio.loadmat(summary_file) total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T stas = data_summary['stas'] print(np.shape(total_mask)) # a is 2D a_eval = a.eval() print(np.shape(a_eval)) # get softmax numpy if FLAGS.model_id == 'relu_window_mother_sfm' or FLAGS.model_id == 'relu_window_mother_sfm_exp': b = np.exp(a_eval) / np.sum(np.exp(a_eval), 0) else: b = a_eval plt.figure() plt.imshow(b, interpolation='nearest', cmap='gray') plt.show() plt.draw() # plot subunits for multiple cells. n_cells = 10 n_plots_max = 20 plt.figure() for icell_cnt, icell in enumerate(np.arange(n_cells)): mask2D = np.reshape(total_mask[icell, :], [40, 80]) nz_idx = np.nonzero(mask2D) np.shape(nz_idx) print(nz_idx) ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1]) xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1]) icnt = -1 a_thr = np.percentile(np.abs(b[:, icell]), 99.5) n_plots = np.sum(np.abs(b[:, icell]) > a_thr) nx = np.ceil(np.sqrt(n_plots)).astype('int') ny = np.ceil(np.sqrt(n_plots)).astype('int') ifig = 0 ww_sum = np.zeros((40, 80)) for idimx in np.arange(dimx): for idimy in np.arange(dimy): icnt = icnt + 1 if (np.abs(b[icnt, icell]) > a_thr): ifig = ifig + 1 fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + ifig + 2) ww = np.zeros((40, 80)) ww[idimx * FLAGS.stride:idimx * FLAGS.stride + (2 * window + 1), idimy * FLAGS.stride:idimy * FLAGS.stride + (2 * window + 1)] = b[icnt, icell] * (np.reshape( wts[idimx, idimy, :], (2 * window + 1, 2 * window + 1))) plt.imshow(ww, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) plt.title(b[icnt, icell]) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) ww_sum = ww_sum + ww fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + 2) plt.imshow(ww_sum, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('STA from model') fig = plt.subplot(n_cells, n_plots_max, icell_cnt * n_plots_max + 1) plt.imshow(np.reshape(stas[:, icell], [40, 80]), interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('True STA') plt.show() plt.draw() #except: # print('a not 2D?') # using xlim and ylim, and plot the 'windows' which are relevant with their weights sq_flat = np.zeros((dimx, dimy)) icnt = 0 for idimx in np.arange(dimx): for idimy in np.arange(dimy): 
sq_flat[idimx, idimy] = icnt icnt = icnt + 1 n_cells = 1 n_plots_max = 10 plt.figure() for icell_cnt, icell in enumerate(np.array( [1, 2, 3, 4, 5])): #enumerate(np.arange(n_cells)): a_thr = np.percentile(np.abs(b[:, icell]), 99.5) mask2D = np.reshape(total_mask[icell, :], [40, 80]) nz_idx = np.nonzero(mask2D) np.shape(nz_idx) print(nz_idx) ylim = np.array([np.min(nz_idx[0]) - 1, np.max(nz_idx[0]) + 1]) xlim = np.array([np.min(nz_idx[1]) - 1, np.max(nz_idx[1]) + 1]) print(xlim, ylim) win_startx = np.ceil((xlim[0] - (2 * window + 1)) / FLAGS.stride) win_endx = np.floor((xlim[1] - 1) / FLAGS.stride) win_starty = np.ceil((ylim[0] - (2 * window + 1)) / FLAGS.stride) win_endy = np.floor((ylim[1] - 1) / FLAGS.stride) dimx_plot = win_endx - win_startx + 1 dimy_plot = win_endy - win_starty + 1 ww_sum = np.zeros((40, 80)) for irow, idimy in enumerate(np.arange(win_startx, win_endx + 1)): for icol, idimx in enumerate( np.arange(win_starty, win_endy + 1)): fig = plt.subplot(dimx_plot + 1, dimy_plot, (irow + 1) * dimy_plot + icol + 1) ww = np.zeros((40, 80)) ww[idimx * FLAGS.stride:idimx * FLAGS.stride + (2 * window + 1), idimy * FLAGS.stride:idimy * FLAGS.stride + (2 * window + 1)] = (np.reshape( wts[idimx, idimy, :], (2 * window + 1, 2 * window + 1))) plt.imshow(ww, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) if b[sq_flat[idimx, idimy], icell] > a_thr: plt.title(b[sq_flat[idimx, idimy], icell], fontsize=10, color='g') else: plt.title(b[sq_flat[idimx, idimy], icell], fontsize=10, color='r') fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) ww_sum = ww_sum + ww * b[sq_flat[idimx, idimy], icell] fig = plt.subplot(dimx_plot + 1, dimy_plot, 2) plt.imshow(ww_sum, interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('STA from model') fig = plt.subplot(dimx_plot + 1, dimy_plot, 1) plt.imshow(np.reshape(stas[:, icell], [40, 80]), interpolation='nearest', cmap='gray') plt.ylim(ylim) plt.xlim(xlim) fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) plt.title('True STA') plt.show() plt.draw()
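# Both plotting loops above rebuild a full 40x80 frame from a windowed subunit
# weight vector, and the stride/window indexing is easy to get wrong. The
# placement step in isolation, under the same convention as the code above
# (a minimal sketch):
import numpy as np


def place_subunit(wts_flat, idimx, idimy, stride, window, shape=(40, 80)):
  """Embeds a (2w+1)^2 weight vector at grid location (idimx, idimy)."""
  side = 2 * window + 1
  frame = np.zeros(shape)
  frame[idimx * stride: idimx * stride + side,
        idimy * stride: idimy * stride + side] = wts_flat.reshape(side, side)
  return frame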
def __init__(self, data_location, batch_sz_in=1000, masked_stimulus=False,
             chosen_cells=None, total_samples=216000, storage_chunk_sz=1000,
             data_chunk_prefix='Off_par_data_', stimulus_dimension=3200,
             test_length=20000,
             sr_key={'stimulus': 'maskedMovdd_part', 'response': 'Y_part'}):
  """Loads the dataset and sets up internal variables.

  data_location: where data is stored after splitting into chunks, each chunk
    having prefix 'data_chunk_prefix' and 'storage_chunk_sz' samples out of
    'total_samples'.
  masked_stimulus: clip the stimulus to pixels relevant for the selected cells.
  chosen_cells: which cells to include in 'response'; if None, include all.
  sr_key: which variables in the stored data hold the stimulus and response;
    the stored matrices have shapes stimulus_dimension x time and
    n_chosen_cells x time respectively.
  batch_sz_in: number of examples in each batch of training data.
  test_length: number of examples in the test data.
  """
  # The original set all_cells only when chosen_cells was None, leaving it
  # undefined (a NameError) otherwise.
  all_cells = chosen_cells is None
  self.batch_sz = batch_sz_in
  print('Initializing datasets')

  # Load summary of datasets.
  data_filename = data_location + 'data_details.mat'
  tf.logging.info('Loading summary file: ' + data_filename)
  summary_file = gfile.Open(data_filename, 'r')
  data_summary = sio.loadmat(summary_file)
  cells = np.squeeze(data_summary['cells'])
  print('Dataset details loaded')

  # Choose cells to build the model for.
  if all_cells:
    self.cells_choose = np.array(np.ones(np.shape(cells)), dtype='bool')
  else:
    cells_choose = np.zeros(len(cells))
    for icell in chosen_cells:
      cells_choose += np.array(cells == icell).astype(np.int)
    self.cells_choose = cells_choose > 0
  n_cells = np.sum(self.cells_choose)  # number of cells

  self.stas = np.array(data_summary['stas'])
  self.stas = self.stas[:, self.cells_choose]
  print('Cells selected: %d' % n_cells)

  # Count total number of spikes for the chosen cells.
  tot_spks = np.squeeze(data_summary['tot_spks'])
  tot_spks_chosen_cells = np.array(tot_spks[self.cells_choose],
                                   dtype='float32')
  self.total_spikes_chosen_cells = tot_spks_chosen_cells
  print('Total number of spikes loaded')

  # Choose the stimulus mask: self.chosen_mask gives the pixels to learn
  # subunits over.
  total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T
  self.cell_mask = total_mask
  if masked_stimulus:
    self.chosen_mask = np.array(
        np.sum(total_mask[self.cells_choose, :], 0) > 0, dtype='bool')
  else:
    self.chosen_mask = np.array(np.ones(stimulus_dimension).astype('bool'))
  stim_dim = np.sum(self.chosen_mask)  # stimulus dimensions
  print('Stimulus mask size: %d' % np.sum(stim_dim))
  print('Stimulus mask loaded')

  # Load stimulus and response from .mat files. Python cannot read very large
  # .mat files, so the data is stored in smaller pieces and stitched here.
  self.stimulus = np.zeros((total_samples, np.sum(self.chosen_mask)))
  self.response = np.zeros((total_samples, n_cells))
  n_chunks_load = np.int(total_samples / storage_chunk_sz)
  for ichunk in range(n_chunks_load):  # the original hardcoded range(216)
    print('Loading %d' % ichunk)
    filename = data_location + data_chunk_prefix + str(ichunk + 1) + '.mat'
    tf.logging.info('Trying to load: ' + filename)
    file_r = gfile.Open(filename, 'r')
    data = sio.loadmat(file_r)
    X = data[sr_key['stimulus']].T
    Y = data[sr_key['response']].T
    self.stimulus[ichunk * storage_chunk_sz:
                  (ichunk + 1) * storage_chunk_sz, :] = X[:, self.chosen_mask]
    self.response[ichunk * storage_chunk_sz:
                  (ichunk + 1) * storage_chunk_sz, :] = Y[:, self.cells_choose]

  # Set up training and testing chunk IDs.
  n_chunks = np.int(total_samples / self.batch_sz)
  test_num_chunks = np.int(test_length / self.batch_sz)
  self.test_chunks = np.arange(test_num_chunks)
  self.train_chunks = np.random.permutation(
      np.arange(test_num_chunks + 1, n_chunks))
  self.ichunk_train = 0
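# The class this constructor belongs to is not named in this excerpt; assuming
# it is called DataUtils, a usage sketch follows. The path and cell ids are
# placeholders, and the batch extraction just reads the shuffled train_chunks
# set up at the end of __init__:
dataset = DataUtils('/path/to/data/', batch_sz_in=1000,
                    masked_stimulus=True, chosen_cells=[3287, 3318])
# Test batches come from the first test_length samples; training chunks are a
# shuffled permutation of the rest.
chunk = dataset.train_chunks[dataset.ichunk_train]
sl = slice(chunk * dataset.batch_sz, (chunk + 1) * dataset.batch_sz)
stim_batch = dataset.stimulus[sl, :]
resp_batch = dataset.response[sl, :]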
def main(argv): print('\nCode started') np.random.seed(FLAGS.np_randseed) random.seed(FLAGS.randseed) ## Load data file = h5py.File(FLAGS.data_location + 'Off_parasol.mat', 'r') print('\ndataset loaded') # Load Masked movie data = file.get('maskedMovdd') stimulus = np.array(data) cells = file.get('cells') nCells = cells.shape[0] total_mask_log = file.get('totalMaskAccept_log') Nsub = FLAGS.ratio_SU * nCells stim_dim = stimulus.shape[1] # Load spike Response of cells data = file.get('Y') response = np.array(data, dtype='float32') tot_spks = np.squeeze(np.sum(response, axis=0)) print(sys.getsizeof(file)) print(sys.getsizeof(stimulus)) print(sys.getsizeof(response)) with tf.Session() as sess: stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim') resp = tf.placeholder(tf.float32, name='resp') data_len = tf.placeholder(tf.float32, name='data_len') if FLAGS.model_id == 0: # MODEL: lam_c(X) = sum_s(a_cs relu(k_s.x)) , a_cs>0 w_init = initialize_su(n_su=Nsub) a_init = np.random.rand(Nsub, nCells) w = tf.Variable(np.array(w_init, dtype='float32')) a = tf.Variable(np.array(a_init, dtype='float32')) lam = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + 0.0001 loss = (tf.reduce_sum(lam / tot_spks) / 120. - tf.reduce_sum(resp * tf.log(lam) / tot_spks)) / data_len loss_with_reg = loss + FLAGS.lam_w * tf.reduce_sum( tf.abs(w)) + FLAGS.lam_a * tf.reduce_sum(tf.abs(a)) # training steps for a. train_step_a = tf.train.AdagradOptimizer(FLAGS.eta_a).minimize( loss, var_list=[a]) # as 'a' is positive, this is op soft-thresholding for L1 and projecting to feasible set soft_th_a = tf.assign(a, tf.nn.relu(a - FLAGS.eta_a * FLAGS.lam_a)) # training steps for w train_step_w = tf.train.AdagradOptimizer(FLAGS.eta_w).minimize( loss, var_list=[w]) # do soft thresholding for 'w' soft_th_w = tf.assign( w, tf.nn.relu(w - FLAGS.eta_w * FLAGS.lam_w) - tf.nn.relu(-w - FLAGS.eta_w * FLAGS.lam_w)) save_filename = (FLAGS.save_location + 'data_model=' + str(FLAGS.model_id) + '_lam_w=' + str(FLAGS.lam_w) + '_lam_a=' + str(FLAGS.lam_a) + '_eta_w=' + str(FLAGS.eta_w) + '_eta_a=' + str(FLAGS.eta_a) + '_ratioSU=' + str(FLAGS.ratio_SU) + '_grid_spacing=' + str(FLAGS.su_grid_spacing) + '_proximal') # initialize the model logfile = gfile.Open(save_filename + '.txt', "w") logfile.write('Starting new code\n') logfile.flush() sess.run(tf.initialize_all_variables()) # Do the fitting batchsz = 100 icnt = 0 fd_test = { stim: stimulus.astype('float32')[216000 - 1000:216000 - 1, :], resp: response.astype('float32')[216000 - 1000:216000 - 1, :], data_len: 1000 } ls_train_log = np.array([]) ls_train_reg_log = np.array([]) ls_test_log = np.array([]) ls_test_reg_log = np.array([]) tms = np.random.permutation(np.arange(216000 - 1000)) for istep in range(100000): fd_train = { stim: stimulus.astype('float32')[tms[icnt:icnt + batchsz], :], resp: response.astype('float32')[tms[icnt:icnt + batchsz], :], data_len: batchsz } # gradient step for 'w' sess.run(train_step_w, feed_dict=fd_train) # soft thresholding for w sess.run(soft_th_w, feed_dict=fd_train) # gradient step for 'a' sess.run(train_step_a, feed_dict=fd_train) # soft thresholding for a, and project in constraint set sess.run(soft_th_a, feed_dict=fd_train) if istep % 10 == 0: # compute training and testing losses0.0 ls_train = sess.run(loss, feed_dict=fd_train) ls_train_log = np.append(ls_train_log, ls_train) ls_train_reg = sess.run(loss_with_reg, feed_dict=fd_train) ls_train_reg_log = np.append(ls_train_reg_log, ls_train_reg) ls_test = sess.run(loss, feed_dict=fd_test) ls_test_log 
= np.append(ls_test_log, ls_test) ls_test_reg = sess.run(loss_with_reg, feed_dict=fd_test) ls_test_reg_log = np.append(ls_test_reg_log, ls_test_reg) # log results logfile.write('\nIterations: ' + str(istep) + ' Training loss: ' + str(ls_train) + ' with reg: ' + str(ls_train_reg) + ' Testing loss: ' + str(ls_test) + ' with reg: ' + str(ls_test_reg) + ' w_l1_norm: ' + str(np.sum(np.abs(w.eval()))) + ' a_l1_norm: ' + str(np.sum(np.abs(a.eval())))) logfile.flush() sio.savemat( save_filename + '.mat', { 'w': w.eval(), 'a': a.eval(), 'w_init': w_init, 'a_init': a_init, 'ls_train_log': ls_train_log, 'ls_train_reg_log': ls_train_reg_log, 'ls_test_log': ls_test_log, 'ls_test_reg_log': ls_test_reg_log }) icnt += batchsz if icnt > 216000 - 1000: icnt = 0 tms = np.random.permutation(np.arange(216000 - 1000)) logfile.close()
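# The two assign ops above implement one proximal step for the L1 penalties:
# w <- soft_threshold(w, eta * lam), and, since 'a' is constrained to be
# nonnegative, a <- max(a - eta * lam, 0) (soft-thresholding composed with
# projection onto the positive orthant). The numpy equivalents, for reference:
import numpy as np


def prox_l1(w, thr):
  """Soft-thresholding: argmin_u 0.5 * ||u - w||^2 + thr * ||u||_1."""
  return np.maximum(w - thr, 0) - np.maximum(-w - thr, 0)


def prox_l1_nonneg(a, thr):
  """Same proximal step, restricted to a >= 0."""
  return np.maximum(a - thr, 0)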
def main(unused_argv=()): np.random.seed(23) tf.set_random_seed(1234) random.seed(50) # Load stimulus-response data. # Collect population response across retinas in the list 'responses'. # Stimulus for each retina is indicated by 'stim_id', # which is found in 'stimuli' dictionary. datasets = gfile.ListDirectory(FLAGS.src_dir) stimuli = {} responses = [] for icnt, idataset in enumerate(datasets): fullpath = os.path.join(FLAGS.src_dir, idataset) if gfile.IsDirectory(fullpath): key = 'stim_%d' % icnt op = data_util.get_stimulus_response(FLAGS.src_dir, idataset, key, boundary=FLAGS.valid_cells_boundary) stimulus, resp, dimx, dimy, _ = op stimuli.update({key: stimulus}) responses += resp taskid = FLAGS.taskid dat = responses[taskid] stimulus = stimuli[dat['stimulus_key']] # parameters window = 5 # Compute time course and non-linearity as two parameters which might be should be explored in embedded space. n_cells = dat['responses'].shape[1] T = np.minimum(stimulus.shape[0], dat['responses'].shape[0]) stim_short = stimulus[:T, :, :] resp_short = dat['responses'][:T, :].astype(np.float32) save_dict = {} # Find time course, non-linearity and RF parameters ######################################################################## # Separation between cell types ######################################################################## save_dict.update({'cell_type': dat['cell_type']}) save_dict.update({'dist_nn_cell_type': dat['dist_nn_cell_type']}) ######################################################################## # Find mean firing rate ######################################################################## mean_fr = dat['responses'].mean(0) mean_fr_1 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==1]) mean_fr_2 = np.mean(mean_fr[np.squeeze(dat['cell_type'])==2]) mean_fr_dict = {'mean_fr': mean_fr, 'mean_fr_1': mean_fr_1, 'mean_fr_2': mean_fr_2} save_dict.update({'mean_fr_dict': mean_fr_dict}) ######################################################################## # compute STAs ######################################################################## stas = np.zeros((n_cells, 80, 40, 30)) for icell in range(n_cells): print(icell) center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1]) for idelay in range(30): stas[icell, windx[0]: windx[1], windy[0]: windy[1], idelay] = np.reshape(resp_short[idelay:, icell].dot(stim_cell[:T-idelay, :]), [windx[1] - windx[0], windy[1] - windy[0]]) / np.sum(resp_short[idelay:, icell]) stas_dict = {'stas': stas} # save_dict.update({'stas_dict': stas_dict}) ######################################################################## # Find time courses for each cell ######################################################################## ttf_log = [] for icell in range(n_cells): center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] ll = stas[icell, windx[0]: windx[1], windy[0]: windy[1], :] ll_2d = np.reshape(ll, [-1, ll.shape[-1]]) u, s, v = np.linalg.svd(ll_2d) ttf_log += [v[0, :]] ttf_log = np.array(ttf_log) signs = [np.sign(ttf_log[icell, np.argmax(np.abs(ttf_log[icell, :]))]) for icell in range(ttf_log.shape[0])] ttf_corrected = np.expand_dims(np.array(signs), 1) * 
ttf_log ttf_corrected[np.squeeze(dat['cell_type'])==1, :] = ttf_corrected[np.squeeze(dat['cell_type'])==1, :] * -1 ttf_mean_1 = ttf_corrected[np.squeeze(dat['cell_type'])==1, :].mean(0) ttf_mean_2 = ttf_corrected[np.squeeze(dat['cell_type'])==2, :].mean(0) ttf_params_1 = get_times(ttf_mean_1) ttf_params_2 = get_times(ttf_mean_2) ttf_dict = {'ttf_log': ttf_log, 'ttf_mean_1': ttf_mean_1, 'ttf_mean_2': ttf_mean_2, 'ttf_params_1': ttf_params_1, 'ttf_params_2': ttf_params_2} save_dict.update({'ttf_dict': ttf_dict}) ''' plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==1, :].T, 'r', alpha=0.3) plt.plot(ttf_corrected[np.squeeze(dat['cell_type'])==2, :].T, 'k', alpha=0.3) plt.plot(ttf_mean_1, 'r--') plt.plot(ttf_mean_2, 'k--') ''' ######################################################################## ## Find non-linearity ######################################################################## f_nl = lambda x, p0, p1, p2, p3: p0 + p1*x + p2* np.power(x, 2) + p3* np.power(x, 3) nl_params_log = [] stim_resp_log = [] for icell in range(n_cells): print(icell) center = dat['centers'][icell, :].astype(np.int) windx = [np.maximum(center[0]-window, 0), np.minimum(center[0]+window, 80-1)] windy = [np.maximum(center[1]-window, 0), np.minimum(center[1]+window, 40-1)] stim_cell = np.reshape(stim_short[:, windx[0]: windx[1], windy[0]: windy[1]], [stim_short.shape[0], -1]) sta_cell = np.reshape(stas[icell, windx[0]: windx[1], windy[0]: windy[1], :], [-1, stas.shape[-1]]) stim_filter = np.zeros(stim_short.shape[0]) for idelay in range(30): stim_filter[idelay: ] += stim_cell[:T-idelay, :].dot(sta_cell[:, idelay]) # Normalize stim_filter stim_filter -= np.mean(stim_filter) stim_filter /= np.sqrt(np.var(stim_filter)) resp_cell = resp_short[:, icell] stim_nl = [] resp_nl = [] for ipercentile in range(3, 97, 1): lb = np.percentile(stim_filter, ipercentile-3) ub = np.percentile(stim_filter, ipercentile+3) tms = np.logical_and(stim_filter >= lb, stim_filter < ub) stim_nl += [np.mean(stim_filter[tms])] resp_nl += [np.mean(resp_cell[tms])] stim_nl = np.array(stim_nl) resp_nl = np.array(resp_nl) popt, pcov = scipy.optimize.curve_fit(f_nl, stim_nl, resp_nl, p0=[1, 0, 0, 0]) nl_params_log += [popt] stim_resp_log += [[stim_nl, resp_nl]] nl_params_log = np.array(nl_params_log) np_params_mean_1 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==1, :], 0) np_params_mean_2 = np.mean(nl_params_log[np.squeeze(dat['cell_type'])==2, :], 0) nl_params_dict = {'nl_params_log': nl_params_log, 'np_params_mean_1': np_params_mean_1, 'np_params_mean_2': np_params_mean_2, 'stim_resp_log': stim_resp_log} save_dict.update({'nl_params_dict': nl_params_dict}) ''' # Visualize Non-linearities for icell in range(n_cells): stim_in = np.arange(-3, 3, 0.1) fr = f_nl(stim_in, *nl_params_log[icell, :]) if np.squeeze(dat['cell_type'])[icell] == 1: c = 'r' else: c = 'k' plt.plot(stim_in, fr, c, alpha=0.2) fr = f_nl(stim_in, *np_params_mean_1) plt.plot(stim_in, fr, 'r--') fr = f_nl(stim_in, *np_params_mean_2) plt.plot(stim_in, fr, 'k--') ''' pickle.dump(save_dict, gfile.Open(os.path.join(FLAGS.save_folder , dat['piece']), 'w')) pickle.dump(stas_dict, gfile.Open(os.path.join(FLAGS.save_folder , 'stas' + dat['piece']), 'w'))
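# The time course above is the leading right singular vector of the
# space x time STA matrix, sign-flipped so the dominant lobe is positive
# (cell-type-1 traces are then negated to match polarity). The rank-1 step in
# isolation (a minimal sketch):
import numpy as np


def sta_time_course(sta_window):
  """sta_window: (pixels_x, pixels_y, n_delays) STA crop around the RF."""
  m = sta_window.reshape(-1, sta_window.shape[-1])     # space x time
  _, _, v = np.linalg.svd(m, full_matrices=False)
  ttf = v[0, :]
  return ttf * np.sign(ttf[np.argmax(np.abs(ttf))])    # fix the sign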
def main(argv): print('\nCode started') np.random.seed(FLAGS.np_randseed) random.seed(FLAGS.randseed) global chosen_mask global cells_choose ## Load data summary filename = FLAGS.data_location + 'data_details.mat' summary_file = gfile.Open(filename, 'r') data_summary = sio.loadmat(summary_file) cells = np.squeeze(data_summary['cells']) cells_choose = (cells == 3287) | (cells == 3318) | (cells == 3155) | ( cells == 3066) n_cells = np.sum(cells_choose) tot_spks = np.squeeze(data_summary['tot_spks']) total_mask = np.squeeze(data_summary['totalMaskAccept_log']).T tot_spks_chosen_cells = tot_spks[cells_choose] chosen_mask = np.array(np.sum(total_mask[cells_choose, :], 0) > 0, dtype='bool') print(np.shape(chosen_mask)) print(np.sum(chosen_mask)) stim_dim = np.sum(chosen_mask) # get test data stim_test, resp_test, test_length = get_test_data() print('\ndataset summary loaded') # use stim_dim, chosen_mask, cells_choose, tot_spks_chosen_cells, n_cells # decide the number of subunits to fit n_su = FLAGS.ratio_SU * n_cells # saving details #short_filename = 'data_model=ASM_pop_bg' # short_filename = ('data_model=ASM_pop_batch_sz='+ str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + # '_step_sz'+ str(FLAGS.step_sz)+'_bg') # saving details batchsz = np.array([1000, 1000, 1000], dtype='int') n_b_in_c = np.array([1, 1, 1], dtype='int') step_sz = np.array([1, 1, 1], dtype='float32') folder_names = ['experiment25', 'experiment22', 'experiment23'] roc_data = [[]] * n_cells for icnt, FLAGS.model_id in enumerate(['hinge', 'poisson', 'logistic']): # restore file FLAGS.batchsz = batchsz[icnt] FLAGS.n_b_in_c = n_b_in_c[icnt] FLAGS.step_sz = step_sz[icnt] folder_name = folder_names[icnt] if FLAGS.model_id == 'poisson': short_filename = ('data_model=ASM_pop_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') if FLAGS.model_id == 'poisson_full': short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge': short_filename = ('data_model=' + str(FLAGS.model_id) + '_batch_sz=' + str(FLAGS.batchsz) + '_n_b_in_c' + str(FLAGS.n_b_in_c) + '_step_sz' + str(FLAGS.step_sz) + '_bg') print(FLAGS.model_id) parent_folder = FLAGS.save_location + folder_name + '/' save_location = parent_folder + short_filename + '/' restore_file = get_latest_file(save_location, short_filename) tf.reset_default_graph() with tf.Session() as sess: # Learn population model! 
stim = tf.placeholder(tf.float32, shape=[None, stim_dim], name='stim') resp = tf.placeholder(tf.float32, name='resp') data_len = tf.placeholder(tf.float32, name='data_len') # define models if FLAGS.model_id == 'poisson' or FLAGS.model_id == 'poisson_full': w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.01 * np.random.rand(n_cells, 1, n_su), dtype='float32')) z = tf.transpose( tf.reduce_sum(tf.exp(tf.matmul(stim, w) + a), 2)) if FLAGS.model_id == 'logistic' or FLAGS.model_id == 'hinge': w = tf.Variable( np.array(0.01 * np.random.randn(stim_dim, n_su), dtype='float32')) a = tf.Variable( np.array(0.01 * np.random.rand(n_su, n_cells), dtype='float32')) b_init = np.random.randn( n_cells ) #np.log((np.sum(response,0))/(response.shape[0]-np.sum(response,0))) b = tf.Variable(b_init, dtype='float32') z = tf.matmul(tf.nn.relu(tf.matmul(stim, w)), a) + b # restore variables # load tensorflow variables print(tf.all_variables()) saver_var = tf.train.Saver(tf.all_variables()) saver_var.restore(sess, restore_file) fd_test = {stim: stim_test, resp: resp_test, data_len: test_length} z_eval = sess.run(z, feed_dict=fd_test) print(z_eval[0:20, :]) print(resp_test[0:20, :]) for roc_cell in np.arange(n_cells): roc = get_roc(z_eval[:, roc_cell], resp_test[:, roc_cell]) roc_data[roc_cell] = roc_data[roc_cell] + [roc] print(roc_cell) plt.figure() for icell in range(n_cells): plt.subplot(1, n_cells, icell + 1) for icnt in np.arange(3): plt.plot(roc_data[icell][icnt][0], roc_data[icell][icnt][1]) plt.hold(True) plt.xlabel('recall') plt.ylabel('precision') plt.legend(['hinge', 'poisson', 'logistic']) cells_ch = cells[cells_choose] plt.title(cells_ch[icell]) plt.show() plt.draw()
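# get_roc is not shown in this excerpt; from how its output is plotted
# (recall on x, precision on y), it sweeps a threshold over the model drive z
# and scores the binarized spike response. A minimal sketch under that
# assumption (not necessarily the actual implementation):
import numpy as np


def precision_recall_curve(scores, labels, n_thresholds=100):
  """Returns (recall, precision) arrays over a sweep of score thresholds."""
  thresholds = np.percentile(scores, np.linspace(0, 100, n_thresholds))
  recall, precision = [], []
  for t in thresholds:
    pred = scores > t
    tp = np.sum(pred & (labels > 0))
    precision.append(tp / float(max(np.sum(pred), 1)))
    recall.append(tp / float(max(np.sum(labels > 0), 1)))
  return np.array(recall), np.array(precision)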
def test_metric(training_datasets, testing_datasets, responses, stimuli,
                sr_graph, sess, file_name):
  print('Testing for metric learning')
  tf.logging.info('Testing for metric learning')

  # Saving filename.
  save_analysis_filename = os.path.join(FLAGS.save_folder,
                                        file_name + '_analysis_sample_resps')
  save_dict = {}

  # if FLAGS.taskid in [2, 8, 14, 20]:
  #   return

  retina_ids = [resp['piece'] for resp in responses]

  '''
  # NULL + SUBUNITS: Test subunits on null stimulus.
  # Responses with different number of subunits - which is closer to stimulus?
  # Take embedding of population responses to same stimulus but
  # different number of subunits and see which is closer to the stimulus.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  su_analysis = experiments.subunit_discriminability_null(
      dataset_dict, stimuli, responses, sr_graph, num_examples=10000)
  su_analysis.update({'rand_key': rand_number})
  su_analysis.update({'retina_ids': retina_ids})
  pickle.dump(su_analysis,
              gfile.Open(save_analysis_filename +
                         '_test14_su_analysis_null.pkl', 'w'))

  # SUBUNITS: Responses with different number of subunits - which is closer
  # to the stimulus? Take embedding of population responses to same stimulus
  # but different number of subunits and see which is closer to the stimulus.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  su_analysis = experiments.subunit_discriminability(
      dataset_dict, stimuli, responses, sr_graph, num_examples=10000)
  su_analysis.update({'rand_key': rand_number})
  su_analysis.update({'retina_ids': retina_ids})
  pickle.dump(su_analysis,
              gfile.Open(save_analysis_filename + '_test12_su_analysis.pkl',
                         'w'))
  '''

  # Embed stimuli-response + decode.
  # WN
  for stim_key in stimuli.keys():
    expt = experiments.stimulus_response_embedding_expt
    stim_resp_embedding = expt(stimuli, responses, sr_graph,
                               stim_key=stim_key, n_samples=1000)
    save_dict.update({'stim_resp_embedding_wn': stim_resp_embedding})
    save_dict.update({'retina_ids': retina_ids})
    pickle.dump(stim_resp_embedding,
                gfile.Open((save_analysis_filename + '_test3_%s.pkl') %
                           stim_key, 'w'))

  '''
  # Stimulus + response embedding, continuous version.
  for stim_key in stimuli.keys():
    expt = experiments.stimulus_response_embedding_expt
    stim_resp_embedding = expt(stimuli, responses, sr_graph,
                               stim_key=stim_key, n_samples=1000,
                               if_continuous=True)
    save_dict.update({'stim_resp_embedding_wn': stim_resp_embedding})
    save_dict.update({'retina_ids': retina_ids})
    pickle.dump(stim_resp_embedding,
                gfile.Open((save_analysis_filename +
                            '_test3_continuous_%s.pkl') % stim_key, 'w'))
  '''

  # ROC analysis on training and testing datasets.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  roc_analysis = experiments.roc_analysis(dataset_dict, stimuli, responses,
                                          sr_graph, num_examples=10000)
  roc_analysis.update({'rand_key': rand_number})
  roc_analysis.update({'retina_ids': retina_ids})
  pickle.dump(roc_analysis,
              gfile.Open(save_analysis_filename + '_test4.pkl', 'w'))

  # ROC analysis with subsampling of cells.
  # dataset_dict = {'train_sr': training_datasets,
  #                 'test_sr': testing_datasets}
  dataset_dict = {'test_sr': testing_datasets}
  if len(training_datasets) == 1:
    dataset_dict.update({'train_sr': training_datasets})
  roc_frac_cells_dict = {}
  for frac_cells in [0.05, 0.1, 0.15, 0.2, 0.5, 0.8, 1.0]:
    print('frac_cells : %.3f' % frac_cells)
    roc_analysis = experiments.roc_analysis(dataset_dict, stimuli, responses,
                                            sr_graph, num_examples=10000,
                                            frac_cells=frac_cells)
    roc_frac_cells_dict.update({frac_cells: roc_analysis})
  roc_frac_cells_dict.update({'retina_ids': retina_ids})
  pickle.dump(roc_frac_cells_dict,
              gfile.Open(save_analysis_filename + '_test4_frac_cells_new.pkl',
                         'w'))
  print(save_analysis_filename + '_test4_frac_cells_new.pkl')

  # ROC analysis - RSS triplets - on training and testing datasets.
  rand_number = np.random.rand(2)
  dataset_dict = {'test_sr': testing_datasets, 'train_sr': training_datasets}
  roc_analysis = experiments.roc_analysis(dataset_dict, stimuli, responses,
                                          sr_graph, num_examples=10000,
                                          negative_stim=True)
  roc_analysis.update({'rand_key': rand_number})
  roc_analysis.update({'retina_ids': retina_ids})
  pickle.dump(roc_analysis,
              gfile.Open(save_analysis_filename + '_test4_rss.pkl', 'w'))

  # Response transformations.
  expt = experiments.response_transformation_increase_nl
  nl_expt = expt(stimuli, responses, sr_graph,
                 time_start_list=[100, 4000, 10000], time_len=100,
                 alpha_list=[1.5, 1.25, 0.8, 0.6])
  nl_expt.update({'retina_ids': retina_ids})
  pickle.dump(nl_expt,
              gfile.Open(save_analysis_filename +
                         '_test7_resp_transform_wn.pkl', 'w'))

  # Embed all stimuli.
  '''
  stim_embedding_dict = experiments.stimulus_embedding_expt(
      stimuli, sr_graph, n_stims_per_type=1000)
  save_dict.update({'stimulus_embedding': stim_embedding_dict})
  pickle.dump(stim_embedding_dict,
              gfile.Open(save_analysis_filename + '_test1.pkl', 'w'))
  print('Saved after test 1')
  '''

  # Explore invariances of the stimulus embedding.
  # Change luminance, contrast, geometry (translate, rotate) and see how the
  # changes are reflected in embedded space.
  #
  # Get stimuli which are transformed.
  stim_transformations_dict = experiments.stimulus_transformations_expt(
      stimuli, sr_graph, n_stims_per_type_transform=50,
      n_stims_per_type_bkg=500)  # 2000
  save_dict.update({'stimulus_transformations': stim_transformations_dict})
  pickle.dump(stim_transformations_dict,
              gfile.Open(save_analysis_filename + '_test2.pkl', 'w'))
  print('Saved after test 2')

  # NSEM
  '''
  stim_resp_embedding = expt(stimuli, responses, sr_graph,
                             stim_key='stim_2', n_samples=100)
  save_dict.update({'stim_resp_embedding_nsem': stim_resp_embedding})
  save_dict.update({'retina_ids': retina_ids})
  pickle.dump(stim_resp_embedding,
              gfile.Open(save_analysis_filename + '_test3_nsem.pkl', 'w'))
  print('Saved after test 3')
  '''

  # Drop different cell types.
  expt = experiments.resp_drop_cells_expt
  for stim_key in stimuli.keys():
    resp_drop_cells = expt(stimuli, responses, sr_graph,
                           stim_key=stim_key, n_samples=100)
    save_dict.update({'resp_drop_cells': resp_drop_cells})
    save_dict.update({'retina_ids': retina_ids})
    pickle.dump(resp_drop_cells,
                gfile.Open(save_analysis_filename + '_test8_%s.pkl' %
                           stim_key, 'w'))

  # TODO(bhaishahster): joint-auto-embed model response prediction for some
  # stimuli.

  # Rasters across different retinas embedded in same space.
  expt = experiments.resp_wn_repeats_multiple_retina
  resp_repeats = expt(sr_graph, n_samples=100)
  pickle.dump(resp_repeats,
              gfile.Open(save_analysis_filename + '_test9_repeats.pkl', 'w'))

  # Predict responses by embedding from rasters of different retina.
  expt = experiments.resp_wn_repeats_multiple_retina_encoding
  resp_repeats, response_embeddings, cell_geometry_log = expt(
      sr_graph, n_repeats=10)  # n_repeats=None uses all repeats.
  pickle.dump(resp_repeats,
              gfile.Open(save_analysis_filename +
                         '_test10_repeats_prediction.pkl', 'w'))

  # Interpolate between retinas.
  expt = experiments.resp_wn_repeats_interpolate_embeddings
  interpolate_dict = expt(
      sr_graph, [ri[:7, 600:, :, :, :] for ri in response_embeddings])
  pickle.dump([interpolate_dict, cell_geometry_log],
              gfile.Open(save_analysis_filename +
                         '_test11_repeats_pred_interpolation.pkl', 'w'))

  # Interpolate between retinas - mean of embedding across repeats.
  expt = experiments.resp_wn_repeats_interpolate_embeddings
  interpolate_dict = expt(sr_graph, [
      np.expand_dims(ri[:, 600:, :, :, :].mean(0), 0)
      for ri in response_embeddings
  ])
  pickle.dump([interpolate_dict, cell_geometry_log],
              gfile.Open(save_analysis_filename +
                         '_test11_repeats_pred_interpolation_mean.pkl', 'w'))

  # NOTE: everything below this early return is unreachable; it is kept only
  # for reference.
  return

  '''
  # Load multiple checkpoints and analyse accuracy.
  # /retina/response_model/python/metric_learning/end_to_end/stimulus_response_embedding \
  #   --logtostderr --mode=1 --taskid=2 --save_suffix='_stim-resp_wn_nsem' \
  #   --stim_layers='1, 5, 1, 3, 64, 1, 3, 64, 1, 3, 64, 1, 3, 64, 2, 3, 64, 2, 3, 1, 1' \
  #   --resp_layers='3, 64, 1, 3, 64, 1, 3, 64, 1, 3, 64, 2, 3, 64, 2, 3, 1, 1' \
  #   --batch_norm=True --save_folder='//home/bhaishahster/end_to_end_feb_5_6PM' \
  #   --learning_rate=0.001 --batch_train_sz=100 --batch_neg_train_sz=100 \
  #   --sr_model='convolutional_embedding'
  for frac_cells in [0.2, 1.0, 0.5]:
    dataset_dict = {'train_sr': training_datasets}
    saver = tf.train.Saver()
    filename = bookkeeping.get_filename(training_datasets, testing_datasets,
                                        FLAGS.beta, FLAGS.sr_model)
    long_filename = os.path.join(FLAGS.save_folder, filename)
    checkpoints = gfile.Glob(long_filename + '*.meta')
    checkpoints = [cpts[:-5] for cpts in checkpoints]
    roc_dict = {}
    for cpts in checkpoints:
      try:
        print(cpts)
        saver.restore(sess, cpts)
        iteration = int(cpts.split('/')[-1].split('-')[-1])
        roc_analysis = experiments.roc_analysis(dataset_dict, stimuli,
                                                responses, sr_graph,
                                                num_examples=10000,
                                                frac_cells=frac_cells)
        roc_dict.update({iteration: roc_analysis})
      except:
        pass
    pickle.dump(roc_dict,
                gfile.Open((save_analysis_filename +
                            '_test6_frac_cells_%.2f.pkl') % frac_cells, 'w'))

  # ROC analysis with negatives generated at small difference to positives.
  dataset_dict = {'train_sr': training_datasets}
  roc_delta_t_dict = {}
  for delta_t in [1, 2, 3, 4, 5]:
    print('Delta t : %d' % delta_t)
    roc_analysis = experiments.roc_analysis(dataset_dict, stimuli, responses,
                                            sr_graph, num_examples=10000,
                                            delta_t=delta_t)
    roc_delta_t_dict.update({delta_t: roc_analysis})
  pickle.dump(roc_delta_t_dict,
              gfile.Open(save_analysis_filename + '_test4_delta_t.pkl', 'w'))
  '''

  # ROC curves of responses from repeats - dataset 1.
  '''
  repeats_datafile = '/home/bhaishahster/metric_learning/datasets/2015-09-23-7.mat'
  repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
  repeats_data['cell_type'] = repeats_data['cell_type'].T

  # Process repeats data.
  num_cell_types = 2
  dimx = 80
  dimy = 40
  data_util.process_dataset(repeats_data, dimx, dimy, num_cell_types)

  # Analyse and store the result.
  test_reps = analyse_response_repeats(repeats_data, sr_graph.anchor_model,
                                       sr_graph.neg_model, sr_graph.sess)
  save_dict.update({'test_reps_2015-09-23-7': test_reps})
  pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
  '''

  # ROC curves of responses from repeats - dataset 2.
  '''
  repeats_datafile = '/home/bhaishahster/metric_learning/examples_pc2005_08_03_0/data005_test.mat'
  repeats_data = sio.loadmat(gfile.Open(repeats_datafile, 'r'))
  data_util.process_dataset(repeats_data, dimx, dimy, num_cell_types)

  # Analyse and store the result.
  test_clustering = analyse_response_repeats_all_trials(
      repeats_data, sr_graph.anchor_model, sr_graph.neg_model, sr_graph.sess)
  save_dict.update({'test_reps_2005_08_03_0': test_clustering})
  pickle.dump(save_dict, gfile.Open(save_analysis_filename, 'w'))
  '''

  # Get model params.
  model_pars_dict = {
      'model_pars': sr_graph.sess.run(tf.trainable_variables())
  }
  pickle.dump(model_pars_dict,
              gfile.Open(save_analysis_filename + '_test5.pkl', 'w'))
  print(save_analysis_filename)
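# A hypothetical driver for test_metric, sketched here for clarity: the
# partition file format follows get_partitions, and `responses`, `stimuli`,
# `sr_graph`, `sess` are assumed to be built earlier in the pipeline (the
# names below are illustrative, not flags or helpers this module defines).
#
# train_ids, test_ids = get_partitions(partition_file, FLAGS.taskid)
# test_metric(train_ids, test_ids, responses, stimuli, sr_graph, sess,
#             file_name='metric_expt')
#
# Any of the '_testN' pickles written above can then be reloaded with, e.g.:
# roc_analysis = pickle.load(
#     gfile.Open(save_analysis_filename + '_test4.pkl', 'r'))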
def main(argv):
  np.random.seed(23)

  # Figure out dictionary path.
  dict_list = gfile.ListDirectory(FLAGS.src_dict)
  dict_path = os.path.join(FLAGS.src_dict, dict_list[FLAGS.taskid])

  # Load the dictionary.
  if dict_path[-3:] == 'pkl':
    data = pickle.load(gfile.Open(dict_path, 'r'))
  elif dict_path[-3:] == 'mat':
    data = sio.loadmat(gfile.Open(dict_path, 'r'))

  # FLAGS.save_dir = ('/home/bhaishahster/stimulation_algos/dictionaries/' +
  #                   dict_list[FLAGS.taskid][:-4])
  FLAGS.save_dir = FLAGS.save_dir + dict_list[FLAGS.taskid][:-4]
  if not gfile.Exists(FLAGS.save_dir):
    gfile.MkDir(FLAGS.save_dir)

  # S_collection = data['S']  # Target
  A = data['A']  # Decoder
  D = data['D'].T  # Dictionary

  # Clean the dictionary: interactively pick a threshold below which a
  # dictionary element is considered to activate no cell.
  thr_list = np.arange(0, 1, 0.01)
  dict_val = []
  for thr in thr_list:
    dict_val += [np.sum(np.sum(D.T > thr, 1) != 0)]
  plt.ion()
  plt.figure()
  plt.plot(thr_list, dict_val)
  plt.xlabel('Threshold')
  plt.ylabel('Number of dictionary elements with\n'
             'at least one element above threshold')
  plt.title('Please choose threshold')
  thr_use = float(input('What threshold to use?'))
  plt.axvline(thr_use)
  plt.title('Using threshold: %.5f' % thr_use)

  dict_valid = np.sum(D.T > thr_use, 1) > 0
  D = D[:, dict_valid]
  D = np.append(D, np.zeros((D.shape[0], 1)), 1)
  print('Appending a "dummy" dictionary element that does not activate '
        'any cell')

  # Vary stimulus resolution.
  for itarget in range(20):
    n_targets = 1
    for stix_resolution in [32, 64, 16, 8]:

      # Get the target.
      x_dim = int(640 / stix_resolution)
      y_dim = int(320 / stix_resolution)
      targets = (np.random.rand(y_dim, x_dim, n_targets) < 0.5) - 0.5
      upscale = stix_resolution // 8  # integer repeats for np.repeat.
      targets = np.repeat(np.repeat(targets, upscale, axis=0),
                          upscale, axis=1)
      targets = np.reshape(targets, [-1, targets.shape[-1]])
      S_actual = targets[:, 0]

      # Remove null component of A from S: A.dot(pinv(A)) is the orthogonal
      # projector onto the column space of the decoder A, so the target
      # becomes one the decoder can actually reproduce.
      S = A.dot(np.linalg.pinv(A).dot(S_actual))

      # Run greedy first to initialize.
      x_greedy = greedy_stimulation(S, A, D,
                                    max_stims=FLAGS.t_max * FLAGS.delta,
                                    file_suffix='%d_%d' % (stix_resolution,
                                                           itarget),
                                    save=True, save_dir=FLAGS.save_dir)

      # Load greedy output from a previous run.
      # data_greedy = pickle.load(
      #     gfile.Open('/home/bhaishahster/greedy_2000_32_0.pkl', 'r'))
      # x_greedy = data_greedy['x_chosen']

      # Plan for multiple time points.
      x_init = np.zeros((x_greedy.shape[0], FLAGS.t_max))
      for it in range(FLAGS.t_max):
        print((it + 1) * FLAGS.delta - 1)
        x_init[:, it] = x_greedy[:, (it + 1) * FLAGS.delta - 1]

      # simultaneous_planning(S, A, D, t_max=FLAGS.t_max,
      #                       lr=FLAGS.learning_rate, delta=FLAGS.delta,
      #                       normalization=FLAGS.normalization,
      #                       file_suffix='%d_%d_normal' % (stix_resolution,
      #                                                     itarget),
      #                       x_init=x_init, save_dir=FLAGS.save_dir)

      # Debugging hook from the original code; comment out for batch runs.
      from IPython import embed
      embed()

      # Projected gradient descent.
      simultaneous_planning_interleaved_discretization(
          S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
          delta=FLAGS.delta, normalization=FLAGS.normalization,
          file_suffix='%d_%d_pgd' % (stix_resolution, itarget),
          x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=np.inf,
          steps_max=500 * 20 - 1)

      # Interleaved discretization.
      simultaneous_planning_interleaved_discretization(
          S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
          delta=FLAGS.delta, normalization=FLAGS.normalization,
          file_suffix='%d_%d_pgd_od' % (stix_resolution, itarget),
          x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=500,
          steps_max=500 * 20 - 1)

      # Exponential weighing.
      simultaneous_planning_interleaved_discretization_exp_gradient(
          S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
          delta=FLAGS.delta, normalization=FLAGS.normalization,
          file_suffix='%d_%d_ew' % (stix_resolution, itarget),
          x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=np.inf,
          steps_max=500 * 20 - 1)

      # Exponential weighing with interleaved discretization.
      simultaneous_planning_interleaved_discretization_exp_gradient(
          S, A, D, t_max=FLAGS.t_max, lr=FLAGS.learning_rate,
          delta=FLAGS.delta, normalization=FLAGS.normalization,
          file_suffix='%d_%d_ew_od' % (stix_resolution, itarget),
          x_init=x_init, save_dir=FLAGS.save_dir, freeze_freq=500,
          steps_max=500 * 20 - 1)

      '''
      # Plot results.
      data_fractional = pickle.load(gfile.Open(
          '/home/bhaishahster/2012-09-24-0_SAD_fr1/'
          'pgd_20_5.000000_C_100_32_0_normal.pkl', 'r'))
      plt.ion()
      start = 0
      end = 20
      error_target_frac = np.linalg.norm(S - data_fractional['S'])
      # error_target_greedy = np.linalg.norm(S - data_greedy['S'])
      print('Did target change? \n Err from fractional: %.3f ' %
            error_target_frac)
      # '\n Err from greedy %.3f' % (error_target_frac, error_target_greedy)

      normalize = np.sum(data_fractional['S'] ** 2)
      plt.ion()
      plt.plot(np.arange(120, len(data_fractional['f_log'])),
               data_fractional['f_log'][120:] / normalize, 'b')
      plt.axhline(data_fractional['errors'][5:].mean() / normalize,
                  color='m')
      plt.axhline(data_fractional['errors_ht_discrete'][5:].mean() /
                  normalize, color='y')
      plt.axhline(data_fractional['errors_rr_discrete'][5:].mean() /
                  normalize, color='r')
      plt.axhline(data_greedy['error_curve'][120:].mean() / normalize,
                  color='k')
      plt.legend(['fraction_curve', 'fractional error', 'HT', 'RR', 'Greedy'])
      plt.pause(1.0)
      '''

  # Debugging hook from the original code; comment out for batch runs.
  from IPython import embed
  embed()
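# The script entry point is not shown in this excerpt; a conventional
# TF1-style runner (assuming the flags above are defined via tf.app.flags)
# would be:
if __name__ == '__main__':
  tf.app.run(main)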