def main(argv):
  """Fit joint subunit models for the task selected by FLAGS.taskid.

  Copies 'Off_parasol.mat' from FLAGS.src_dir into FLAGS.tmp_dir (if it is
  not already there). Loads the masked movie, cell ids, spike responses and
  per-cell masks from that file. Then parses line FLAGS.taskid of
  FLAGS.task_params_file, which has the format

    [cell,...];nsub;projection_type;lam_proj;[partition,...][;[mask cell,...]]

  For every requested partition whose result file does not exist yet, this
  fits `su_model.Flat_clustering_jnt` and evaluates
  `su_model.compute_fr_loss` on the held-out test bins. The results are
  pickled into FLAGS.save_path.

  Args:
    argv: Command-line arguments (unused).
  """
  del argv  # Unused.

  # Copy data to local scratch space if needed.
  dst = os.path.join(FLAGS.tmp_dir, 'Off_parasol.mat')
  if not gfile.Exists(dst):
    print('Started Copy')
    src = os.path.join(FLAGS.src_dir, 'Off_parasol.mat')
    if not gfile.IsDirectory(FLAGS.tmp_dir):
      gfile.MkDir(FLAGS.tmp_dir)
    gfile.Copy(src, dst)
    print('File copied to destination')
  else:
    print('File exists')

  # Load everything needed into memory; the `with` closes the HDF5 file.
  with h5py.File(dst, 'r') as file_data:
    stimulus = np.array(file_data.get('maskedMovdd'))  # Masked movie.
    cells = np.squeeze(np.array(file_data.get('cells')))
    responses = np.array(file_data.get('Y'))  # Spike responses of cells.
    total_mask_log = np.array(file_data.get('totalMaskAccept_log'))
  print('Got data')

  # Read the line corresponding to this task.
  with gfile.Open(FLAGS.task_params_file, 'r') as f:
    for _ in range(FLAGS.taskid + 1):
      line = f.readline()
  # Strip only the newline: `line[:-1]` would cut a real character from a
  # last line that has no trailing '\n'.
  line = line.rstrip('\n')
  print(line)

  # Get task parameters by parsing the line.
  line_split = line.split(';')
  cell_idx = [int(i) for i in line_split[0][1:-1].split(',')]
  nsub = int(line_split[1])
  projection_type = line_split[2]
  lam_proj = float(line_split[3])
  partitions_fit = [int(i) for i in line_split[4][1:-1].split(',')]
  if len(line_split) == 5:
    cell_idx_mask = cell_idx
  else:
    # Use the parsed mask cells (previously this re-parsed `cell_idx` and
    # silently ignored the sixth field).
    cell_idx_mask = [int(i) for i in line_split[5][1:-1].split(',')]

  print(cell_idx)
  print(nsub)
  print(cell_idx_mask)

  # Union of the masks of the mask cells, viewed as a 40 x 80 grid.
  mask = total_mask_log[cell_idx_mask, :].sum(0) != 0
  mask_matrix = np.reshape(mask, [40, 80])
  # Make mask bigger: fill its bounding box padded by one pixel on every
  # side. The start is clamped at 0 (a -1 start would wrap around and give
  # an empty slice), and the end is max + 2 because slice ends are exclusive.
  r, c = np.where(mask_matrix)
  mask_matrix[max(r.min() - 1, 0):r.max() + 2,
              max(c.min() - 1, 0):c.max() + 2] = True
  neighbor_mat = su_model.get_neighbormat(mask_matrix, nbd=1)
  mask = np.ndarray.flatten(mask_matrix)
  stim_use = stimulus[:, mask]
  resp_use = responses[:, cell_idx]
  print('Prepared data')

  # Hold out the last 10% of time bins as test data. Split the remaining
  # bins at random into train (90%) / validation (10%), once per partition.
  # The fixed seed makes the partitions reproducible.
  np.random.seed(23)
  frac_test = 0.1
  frac_validate = 0.1
  n_partitions = 10
  n_samples = stim_use.shape[0]
  n_train_validate = int(np.floor(n_samples * (1 - frac_test)))
  n_train = int(np.floor((1 - frac_validate) * n_train_validate))
  tms_test = np.arange(n_train_validate, n_samples)
  tms_train_validate = np.arange(n_train_validate)
  partitions = []
  for _ in range(n_partitions):
    perm = np.random.permutation(tms_train_validate)
    partitions.append({
        'tms_train': perm[:n_train],
        'tms_validate': perm[n_train:],
        'tms_test': tms_test,
    })
  print('Made partitions')

  # Do fitting.
  ss = '_'.join([str(cells[ic]) for ic in cell_idx])
  for ipartition in partitions_fit:
    file_name = 'Cell_%s_nsub_%d_%s_%.3f_part_%d_jnt.pkl' % (
        ss, nsub, projection_type, lam_proj, ipartition)
    save_filename = os.path.join(FLAGS.save_path, file_name)
    save_filename_partial = os.path.join(FLAGS.save_path_partial, file_name)
    if gfile.Exists(save_filename):
      continue  # Already fitted.

    print('Fitting started')
    partition = partitions[ipartition]
    op = su_model.Flat_clustering_jnt(
        stim_use, resp_use, nsub,
        partition['tms_train'],
        partition['tms_validate'],
        steps_max=10000, eps=1e-9,
        projection_type=projection_type,
        neighbor_mat=neighbor_mat,
        lam_proj=lam_proj, eps_proj=0.01,
        save_filename_partial=save_filename_partial)
    k, b, _, lam_log, lam_log_test, fitting_phase, fit_params = op

    # Loss of the fitted model on the held-out test bins.
    # NOTE(review): fit_params[2] presumably holds the parameters of the
    # last fitting phase (filters, bias, nl_params) -- confirm in su_model.
    tms_test_part = partition['tms_test']
    lam_test_on_test_data = su_model.compute_fr_loss(
        fit_params[2][0], fit_params[2][1],
        stim_use[tms_test_part, :], resp_use[tms_test_part, :],
        nl_params=fit_params[2][2])
    print('Fitting done')

    save_dict = {
        'K': k,
        'b': b,
        'lam_log': lam_log,
        'lam_log_validate': lam_log_test,
        'lam_test': lam_test_on_test_data,
        'fitting_phase': fitting_phase,
        'fit_params': fit_params,
    }
    # Binary mode for pickle; `with` ensures the file is flushed and closed.
    with gfile.Open(save_filename, 'wb') as f:
      pickle.dump(save_dict, f)
    print('Saved results')
def main(argv):
  """Fit a subunit model on white noise (WN), then refit scales on NULL data.

  Copies 'Off_parasol.mat' (WN data) from FLAGS.src_dir into FLAGS.tmp_dir
  if needed. Loads NULL-stimulus trial data from
  'OFF_parasol_trial_resp_data.mat', using condition FLAGS.datarun. Then
  parses line FLAGS.taskid of FLAGS.task_params_file, which has the format

    [cell,...];nsub;projection_type;lam_proj;[partition,...][;[mask cell,...]]

  For every requested partition whose result file does not exist, this:
    * fits `su_model.Flat_clustering_jnt` (fitting phase 1 only) on WN data;
    * passes the fitted parameters to `su_model.fit_scales` on NULL data;
    * pickles the collected results into FLAGS.save_path.

  Args:
    argv: Command-line arguments (unused).
  """
  del argv  # Unused.

  # Copy WN data to local scratch space if needed.
  dst = os.path.join(FLAGS.tmp_dir, 'Off_parasol.mat')
  if not gfile.Exists(dst):
    print('Started Copy')
    src = os.path.join(FLAGS.src_dir, 'Off_parasol.mat')
    if not gfile.IsDirectory(FLAGS.tmp_dir):
      gfile.MkDir(FLAGS.tmp_dir)
    gfile.Copy(src, dst)
    print('File copied to destination')
  else:
    print('File exists')

  # Load WN data into memory; the `with` closes the HDF5 file.
  with h5py.File(dst, 'r') as wn_file:
    stimulus = np.array(wn_file.get('maskedMovdd'))  # Masked movie.
    cells = np.squeeze(np.array(wn_file.get('cells')))
    ttf_log = np.array(wn_file.get('ttf_log'))
    responses = np.array(wn_file.get('Y'))  # Spike responses of cells.
    total_mask_log = np.array(wn_file.get('totalMaskAccept_log'))
  print('Got WN data')

  # Get NULL data. The .mat file is binary, so open it in 'rb' mode and
  # close it once loaded.
  null_path = os.path.join(FLAGS.src_dir, 'OFF_parasol_trial_resp_data.mat')
  with gfile.Open(null_path, 'rb') as f:
    dat_null = sio.loadmat(f)
  condition_idx = FLAGS.datarun
  stimulus_null = dat_null['condMov'][condition_idx, 0]
  stimulus_null = np.transpose(stimulus_null, [2, 1, 0])
  stimulus_null = np.reshape(stimulus_null, [stimulus_null.shape[0], -1])
  resp_cell_log = dat_null['resp_cell_log']
  print('Got Null data')

  # Read the line corresponding to this task.
  with gfile.Open(FLAGS.task_params_file, 'r') as f:
    for _ in range(FLAGS.taskid + 1):
      line = f.readline()
  # Strip only the newline: `line[:-1]` would cut a real character from a
  # last line that has no trailing '\n'.
  line = line.rstrip('\n')
  print(line)

  # Get task parameters by parsing the line.
  line_split = line.split(';')
  cell_idx = [int(i) for i in line_split[0][1:-1].split(',')]
  nsub = int(line_split[1])
  projection_type = line_split[2]
  lam_proj = float(line_split[3])
  partitions_fit = [int(i) for i in line_split[4][1:-1].split(',')]
  if len(line_split) == 5:
    cell_idx_mask = cell_idx
  else:
    # Use the parsed mask cells (previously this re-parsed `cell_idx` and
    # silently ignored the sixth field).
    cell_idx_mask = [int(i) for i in line_split[5][1:-1].split(',')]

  print(cell_idx)
  print(nsub)
  print(cell_idx_mask)

  # Union of the masks of the mask cells, viewed as a 40 x 80 grid.
  mask = total_mask_log[cell_idx_mask, :].sum(0) != 0
  mask_matrix = np.reshape(mask, [40, 80])
  # Make mask bigger: fill its bounding box padded by one pixel on every
  # side. The start is clamped at 0 (a -1 start would wrap around and give
  # an empty slice), and the end is max + 2 because slice ends are exclusive.
  r, c = np.where(mask_matrix)
  mask_matrix[max(r.min() - 1, 0):r.max() + 2,
              max(c.min() - 1, 0):c.max() + 2] = True
  neighbor_mat = su_model.get_neighbormat(mask_matrix, nbd=1)
  mask = np.ndarray.flatten(mask_matrix)

  # WN preprocess: hold out the last 10% of time bins as test data. Split
  # the rest at random into train (90%) / validation (10%) per partition.
  stim_use_wn = stimulus[:, mask]
  resp_use_wn = responses[:, cell_idx]
  np.random.seed(23)
  frac_test = 0.1
  frac_validate = 0.1
  n_partitions = 10
  n_samples = stim_use_wn.shape[0]
  n_train_validate = int(np.floor(n_samples * (1 - frac_test)))
  n_train = int(np.floor((1 - frac_validate) * n_train_validate))
  tms_test = np.arange(n_train_validate, n_samples)
  tms_train_validate = np.arange(n_train_validate)
  partitions_wn = []
  for _ in range(n_partitions):
    perm = np.random.permutation(tms_train_validate)
    partitions_wn.append({
        'tms_train': perm[:n_train],
        'tms_validate': perm[n_train:],
        'tms_test': tms_test,
    })
  print('Made partitions')
  print('WN data preprocessed')

  # NULL preprocess. The stimulus is passed through `filterMov_time` with
  # the cells' ttf_log rows (presumably temporal filtering -- confirm).
  stimulus_use_null = stimulus_null[:, mask]
  ttf_use = np.array(ttf_log[cell_idx, :]).astype(np.float32).squeeze()
  stimulus_use_null = filterMov_time(stimulus_use_null, ttf_use)
  if len(cell_idx) > 1:
    print('More than 1 cell not supported!')
  # Only the first cell's NULL responses are used. Its entry is indexed as
  # [condition, 0] and, if that layout does not fit, as [0, condition].
  cell_resp = resp_cell_log[cell_idx[0], 0]
  try:
    resp_use_null = np.array(cell_resp[condition_idx, 0]).T.astype(np.float32)
  except IndexError:
    resp_use_null = np.array(cell_resp[0, condition_idx].T).astype(np.float32)

  # Remove first 30 frames due to convolution artifact.
  stimulus_use_null = stimulus_use_null[30:, :]
  resp_use_null = resp_use_null[30:, :]
  n_trials = resp_use_null.shape[1]
  t_null = resp_use_null.shape[0]
  # Within each trial: first half of the frames for training, second half
  # for testing.
  tms_train_1tr_null = np.arange(np.floor(t_null / 2)).astype(int)
  tms_test_1tr_null = np.arange(np.ceil(t_null / 2), t_null).astype(int)

  # Repeat the stimulus once per trial and flatten responses trial by
  # trial. Then offset the per-trial indices into the concatenated time
  # axis, trial-major.
  stimulus_use_null = np.tile(stimulus_use_null.T, n_trials).T
  resp_use_null = np.expand_dims(np.ndarray.flatten(resp_use_null.T), 1)
  trial_offsets = (np.arange(n_trials) * t_null)[:, np.newaxis]
  tms_train_null = (trial_offsets + tms_train_1tr_null).ravel().astype(int)
  tms_test_null = (trial_offsets + tms_test_1tr_null).ravel().astype(int)
  print('NULL data preprocessed')

  ss = '_'.join([str(cells[ic]) for ic in cell_idx])
  for ipartition in partitions_fit:
    file_name = 'Cell_%s_nsub_%d_%s_%.3f_part_%d_jnt.pkl' % (
        ss, nsub, projection_type, lam_proj, ipartition)
    save_filename = os.path.join(FLAGS.save_path, file_name)
    save_filename_partial = os.path.join(FLAGS.save_path_partial, file_name)
    if gfile.Exists(save_filename):
      continue  # Already fitted.

    # Fit SU on WN.
    print('Fitting started on WN')
    partition = partitions_wn[ipartition]
    op = su_model.Flat_clustering_jnt(
        stim_use_wn, resp_use_wn, nsub,
        partition['tms_train'],
        partition['tms_validate'],
        steps_max=10000, eps=1e-9,
        projection_type=projection_type,
        neighbor_mat=neighbor_mat,
        lam_proj=lam_proj, eps_proj=0.01,
        save_filename_partial=save_filename_partial,
        fitting_phases=[1])
    _, _, _, lam_log_wn, lam_log_test_wn, _, fit_params_wn = op
    print('Fitting done on WN')

    # Fit on NULL, starting from the WN fit (fit_params_wn[0]).
    op = su_model.fit_scales(
        stimulus_use_null[tms_train_null, :],
        resp_use_null[tms_train_null, :],
        stimulus_use_null[tms_test_null, :],
        resp_use_null[tms_test_null, :],
        Ns=nsub,
        K=fit_params_wn[0][0],
        b=fit_params_wn[0][1],
        params=fit_params_wn[0][2],
        lr=0.01, eps=1e-9)
    k_null, b_null, nl_params_null, lam_log_null, lam_log_test_null = op

    # Collect results and save.
    fit_params = fit_params_wn + [[k_null, b_null, nl_params_null]]
    save_dict = {
        'lam_log': [lam_log_wn, np.array(lam_log_null)],
        'lam_log_test': [lam_log_test_wn, np.array(lam_log_test_null)],
        'fit_params': fit_params,
    }
    with gfile.Open(save_filename, 'wb') as f:
      pickle.dump(save_dict, f)
    print('Saved results')
def main(argv):
  """Fit a subunit model on WN data, then refit its scales on NSEM data.

  Line FLAGS.taskid of FLAGS.task_params_file is parsed as
  '[cell,...];nsub;projection_type;lam_proj'. Only the first cell index is
  used; it selects a .mat file from the sorted listing of FLAGS.src_dir.
  Steps:
    * Copy that file to FLAGS.tmp_dir (if needed).
    * Fit the WN training data with `su_model.Flat_clustering_jnt`
      (fitting phase 1 only).
    * Pass the resulting parameters to `su_model.fit_scales` on NSEM
      training data.
    * Pickle the results into FLAGS.save_path unless that file already
      exists.

  Any exception raised while loading, fitting or saving is printed together
  with its traceback and is not re-raised (best-effort, as before).

  Args:
    argv: Command-line arguments (unused).
  """
  import traceback

  del argv  # Unused.

  # Read the line corresponding to this task.
  with gfile.Open(FLAGS.task_params_file, 'r') as f:
    for _ in range(FLAGS.taskid + 1):
      line = f.readline()
  # Strip only the newline: `line[:-1]` would cut a real character from a
  # last line that has no trailing '\n'.
  line = line.rstrip('\n')
  print(line)

  # Get task parameters by parsing the line.
  line_split = line.split(';')
  cell_idx = int(line_split[0][1:-1].split(',')[0])
  # ListDirectory returns entries in arbitrary order; sort so that a given
  # cell index maps to the same file on every run.
  file_list = sorted(gfile.ListDirectory(FLAGS.src_dir))
  cell_file = file_list[cell_idx]
  print('Cell file %s' % cell_file)
  nsub = int(line_split[1])
  projection_type = line_split[2]
  lam_proj = float(line_split[3])

  # Copy data to local scratch space if needed.
  dst = os.path.join(FLAGS.tmp_dir, cell_file)
  if not gfile.Exists(dst):
    print('Started Copy')
    src = os.path.join(FLAGS.src_dir, cell_file)
    if not gfile.IsDirectory(FLAGS.tmp_dir):
      gfile.MkDir(FLAGS.tmp_dir)
    gfile.Copy(src, dst)
    print('File copied to destination')
  else:
    print('File exists')

  try:
    # Load stimulus and response data (training portions only).
    data = sio.loadmat(dst)
    mask = data['mask']
    neighbor_mat = su_model.get_neighbormat(mask, nbd=1)
    stimulus_wn = np.array(data['trainMov_filterWN'].transpose(),
                           dtype='float32')
    response_wn = np.array(np.squeeze(data['trainSpksWN']), dtype='float32')
    stimulus_nsem = np.array(data['trainMov_filterNSEM'].transpose(),
                             dtype='float32')
    response_nsem = np.array(np.squeeze(data['trainSpksNSEM']),
                             dtype='float32')
    print('Prepared data')

    # Set random seed.
    np.random.seed(23)
    print('Made partitions')

    # Contiguous splits: first 80% of bins for training, the rest for
    # testing, separately for WN and NSEM.
    frac_train = 0.8
    n_wn = stimulus_wn.shape[0]
    n_train_wn = int(np.floor(n_wn * frac_train))
    tms_train_wn = np.arange(n_train_wn)
    tms_test_wn = np.arange(n_train_wn, n_wn)
    n_nsem = stimulus_nsem.shape[0]
    n_train_nsem = int(np.floor(n_nsem * frac_train))
    tms_train_nsem = np.arange(n_train_nsem)
    tms_test_nsem = np.arange(n_train_nsem, n_nsem)

    file_name = 'Cell_%s_nsub_%d_%s_%.3f_jnt.pkl' % (
        str(cell_idx), nsub, projection_type, lam_proj)
    save_filename = os.path.join(FLAGS.save_path, file_name)
    save_filename_partial = os.path.join(FLAGS.save_path_partial, file_name)

    if not gfile.Exists(save_filename):
      # Fit SU on WN.
      print('Fitting started on WN')
      op = su_model.Flat_clustering_jnt(
          stimulus_wn, np.expand_dims(response_wn, 1), nsub,
          tms_train_wn, tms_test_wn,
          steps_max=10000, eps=1e-9,
          projection_type=projection_type,
          neighbor_mat=neighbor_mat,
          lam_proj=lam_proj, eps_proj=0.01,
          save_filename_partial=save_filename_partial,
          fitting_phases=[1])
      _, _, _, lam_log_wn, lam_log_test_wn, _, fit_params_wn = op
      print('WN fit done')

      # Fit on NSEM, starting from the WN fit (fit_params_wn[0]).
      op = su_model.fit_scales(
          stimulus_nsem[tms_train_nsem, :],
          np.expand_dims(response_nsem[tms_train_nsem], 1),
          stimulus_nsem[tms_test_nsem, :],
          np.expand_dims(response_nsem[tms_test_nsem], 1),
          Ns=nsub,
          K=fit_params_wn[0][0],
          b=fit_params_wn[0][1],
          params=fit_params_wn[0][2],
          lr=0.01, eps=1e-9)
      k_nsem, b_nsem, nl_params_nsem, lam_log_nsem, lam_log_test_nsem = op

      # Collect results and save.
      fit_params = fit_params_wn + [[k_nsem, b_nsem, nl_params_nsem]]
      save_dict = {
          'lam_log': [lam_log_wn, np.array(lam_log_nsem)],
          'lam_log_test': [lam_log_test_wn, np.array(lam_log_test_nsem)],
          'fit_params': fit_params,
          'mask': mask,
      }
      with gfile.Open(save_filename, 'wb') as f:
        pickle.dump(save_dict, f)
      print('Saved results')
  except Exception:  # Top-level best-effort boundary; report, don't crash.
    print('Error')
    traceback.print_exc()
def main(argv):
  """Fit a subunit ('su') or 1-layer convolutional ('conv') model for a task.

  Copies 'Off_parasol.mat' into FLAGS.tmp_dir if needed and loads it. Then
  parses line FLAGS.taskid of FLAGS.task_params_file. The line is either

    [cell,...];su;nsub;projection_type;lam_proj;frac_train
  or
    [cell,...];conv;dim_filters;strides;num_filters;frac_train

  The chosen model is fit on partition 0 (a random `frac_train` subset of
  the first 90% of time bins, validated on the last 10% of that permutation)
  and the results are pickled into FLAGS.save_path, unless the result file
  already exists.

  Args:
    argv: Command-line arguments (unused).

  Raises:
    ValueError: If the model type is unknown, or if `frac_train` is large
      enough that the training and validation sets would overlap.
  """
  del argv  # Unused.

  # Copy data to local scratch space if needed.
  dst = os.path.join(FLAGS.tmp_dir, 'Off_parasol.mat')
  if not gfile.Exists(dst):
    print('Started Copy')
    src = os.path.join(FLAGS.src_dir, 'Off_parasol.mat')
    if not gfile.IsDirectory(FLAGS.tmp_dir):
      gfile.MkDir(FLAGS.tmp_dir)
    gfile.Copy(src, dst)
    print('File copied to destination')
  else:
    print('File exists')

  # Load everything needed into memory; the `with` closes the HDF5 file.
  with h5py.File(dst, 'r') as file_data:
    stimulus = np.array(file_data.get('maskedMovdd'))  # Masked movie.
    cells = np.squeeze(np.array(file_data.get('cells')))
    responses = np.array(file_data.get('Y'))  # Spike responses of cells.
    total_mask_log = np.array(file_data.get('totalMaskAccept_log'))
  print('Got data')

  # Read the line corresponding to this task.
  with gfile.Open(FLAGS.task_params_file, 'r') as f:
    for _ in range(FLAGS.taskid + 1):
      line = f.readline()
  # Strip only the newline: `line[:-1]` would cut a real character from a
  # last line that has no trailing '\n'.
  line = line.rstrip('\n')
  print(line)

  # Get task parameters by parsing the line.
  line_split = line.split(';')
  cell_idx = [int(i) for i in line_split[0][1:-1].split(',')]
  model = line_split[1]
  if model == 'su':
    nsub = int(line_split[2])
    projection_type = line_split[3]
    lam_proj = float(line_split[4])
  elif model == 'conv':
    dim_filters = int(line_split[2])
    strides = int(line_split[3])
    num_filters = int(line_split[4])
  else:
    raise ValueError('Unknown model type %r; expected "su" or "conv".' % model)
  # Shared by both models (previously only parsed for 'conv', so 'su' tasks
  # failed with a NameError).
  frac_train = float(line_split[5])

  frac_validate = 0.1
  if frac_train > 1 - frac_validate:
    raise ValueError(
        'frac_train must be <= %.2f so training and validation sets do not '
        'overlap; got %f.' % (1 - frac_validate, frac_train))

  # Union of the masks of the selected cells, viewed as a 40 x 80 grid.
  mask = total_mask_log[cell_idx, :].sum(0) != 0
  mask_matrix = np.reshape(mask, [40, 80])
  # Make mask bigger: fill its bounding box padded by one pixel on every
  # side. The start is clamped at 0 (a -1 start would wrap around and give
  # an empty slice), and the end is max + 2 because slice ends are exclusive.
  r, c = np.where(mask_matrix)
  row_slice = slice(max(r.min() - 1, 0), r.max() + 2)
  col_slice = slice(max(c.min() - 1, 0), c.max() + 2)
  mask_matrix[row_slice, col_slice] = True
  mask = np.ndarray.flatten(mask_matrix)

  # The same padded box, as a flattened pixel set (su) and as a 2-D crop
  # (conv).
  stimulus_2d = np.reshape(stimulus, [-1, 40, 80])
  stim_use_2d = stimulus_2d[:, row_slice, col_slice]
  stim_use = stimulus[:, mask]
  resp_use = responses[:, cell_idx]
  print('Prepared data')

  # Hold out the last 10% of time bins as test data. For each partition,
  # train on the first `frac_train` of a random permutation of the remaining
  # bins and validate on its last `frac_validate`. All partitions are drawn
  # so the RNG sequence matches earlier runs, even though only partition 0
  # is used.
  np.random.seed(23)
  frac_test = 0.1
  n_partitions = 10
  n_samples = stim_use.shape[0]
  n_train_validate = int(np.floor(n_samples * (1 - frac_test)))
  n_train = int(np.floor(frac_train * n_train_validate))
  validate_start = int(np.floor((1 - frac_validate) * n_train_validate))
  tms_test = np.arange(n_train_validate, n_samples)
  tms_train_validate = np.arange(n_train_validate)
  partitions = []
  for _ in range(n_partitions):
    perm = np.random.permutation(tms_train_validate)
    partitions.append({
        'tms_train': perm[:n_train],
        'tms_validate': perm[validate_start:],
        'tms_test': tms_test,
    })
  ipartition = 0
  print('Made partitions')

  ss = '_'.join([str(cells[ic]) for ic in cell_idx])
  partition = partitions[ipartition]

  if model == 'su':
    file_name = 'Cell_%s_su_nsub_%d_%s_%.3f_frac_train_%.4f_jnt.pkl' % (
        ss, nsub, projection_type, lam_proj, frac_train)
    save_filename = os.path.join(FLAGS.save_path, file_name)
    save_filename_partial = os.path.join(FLAGS.save_path_partial, file_name)
    if not gfile.Exists(save_filename):
      print('Fitting started for SU')
      neighbor_mat = su_model.get_neighbormat(mask_matrix, nbd=1)
      op = su_model.Flat_clustering_jnt(
          stim_use, resp_use, nsub,
          partition['tms_train'],
          partition['tms_validate'],
          steps_max=10000, eps=1e-9,
          projection_type=projection_type,
          neighbor_mat=neighbor_mat,
          lam_proj=lam_proj, eps_proj=0.01,
          save_filename_partial=save_filename_partial)
      k, b, _, lam_log, lam_log_test, fitting_phase, fit_params = op
      print('Fitting done')
      save_dict = {
          'K': k,
          'b': b,
          'lam_log': lam_log,
          'lam_log_test': lam_log_test,
          'fitting_phase': fitting_phase,
          'fit_params': fit_params,
      }
      with gfile.Open(save_filename, 'wb') as f:
        pickle.dump(save_dict, f)
      print('Saved results')

  if model == 'conv':
    file_name = ('Cell_%s_conv_dim_filter_%d_strides_%d_num_filters_%d_'
                 'frac_train_%.4f_jnt.pkl' % (
                     ss, dim_filters, strides, num_filters, frac_train))
    save_filename = os.path.join(FLAGS.save_path, file_name)
    if not gfile.Exists(save_filename):
      print('Fitting started for CONV')
      op = conv_model.convolutional_1layer(
          stim_use_2d, np.squeeze(resp_use),
          partition['tms_train'],
          partition['tms_validate'],
          dim_filters=dim_filters,
          num_filters=num_filters,
          lr=0.1, num_steps_max=100000,
          strides=strides, eps=1e-9)
      loss_train_log, loss_test_log, model_params = op
      print('Convolutional model fitting done')
      save_dict = {
          'lam_log': loss_train_log,
          'lam_log_test': loss_test_log,
          'model_params': model_params,
      }
      with gfile.Open(save_filename, 'wb') as f:
        pickle.dump(save_dict, f)
      print('Saved results')
def get_su_nsub(stimulus, response, mask_matrix, cell_string, nsub,
                projection_type, lam_proj, ipartition):
  """Get 'nsub' subunits by fitting on one random data partition.

  The last 10% of time bins (axis 0 of `stimulus`) are held out as test
  data. The rest is permuted at random (seed 95) into 10 train (90%) /
  validation (10%) partitions. Partition `ipartition` is used to fit
  `su_model.Flat_clustering_jnt`, and the results are pickled into
  FLAGS.save_path unless that file already exists.

  Args:
    stimulus: Stimulus array indexed by time along axis 0 (presumably
      time x pixels -- confirm against callers).
    response: Responses passed to `su_model.Flat_clustering_jnt` together
      with `stimulus` (presumably time x cells).
    mask_matrix: 2-D mask used to build the pixel neighborhood matrix via
      `su_model.get_neighbormat`.
    cell_string: Identifier of the cell(s); used in the output file names.
    nsub: Number of subunits.
    projection_type: Passed through to `su_model.Flat_clustering_jnt`.
    lam_proj: Passed through as `lam_proj`; written with 6 decimals in the
      file names.
    ipartition: Index of the partition to fit (0 <= ipartition < 10).
  """
  np.random.seed(95)  # 23 for _jnt.pkl, 46 for _jnt_2.pkl, 93 for _nov

  # Get last 10% as test data; random train/validation partitions of the
  # rest. Slice bounds are cast to int: float slice indices raise TypeError
  # in Python 3 / recent NumPy.
  frac_test = 0.1
  frac_validate = 0.1
  n_partitions = 10
  n_samples = stimulus.shape[0]
  n_train_validate = int(np.floor(n_samples * (1 - frac_test)))
  n_train = int(np.floor((1 - frac_validate) * n_train_validate))
  tms_test = np.arange(n_train_validate, n_samples)
  tms_train_validate = np.arange(n_train_validate)
  partitions = []
  for _ in range(n_partitions):
    perm = np.random.permutation(tms_train_validate)
    partitions.append({
        'tms_train': perm[:n_train],
        'tms_validate': perm[n_train:],
        'tms_test': tms_test,
    })
  print('Made partitions')

  neighbor_mat = su_model.get_neighbormat(mask_matrix, nbd=1)
  file_name = 'Cell_%s_nsub_%d_%s_%.6f_part_%d_%s.pkl' % (
      cell_string, nsub, projection_type, lam_proj, ipartition,
      FLAGS.save_suffix)
  save_name = os.path.join(FLAGS.save_path, file_name)
  save_name_partial = os.path.join(FLAGS.save_path_partial, file_name)
  if gfile.Exists(save_name):
    return  # Already fitted.

  print(cell_string, nsub, projection_type, lam_proj, ipartition)
  partition = partitions[ipartition]
  op = su_model.Flat_clustering_jnt(
      stimulus, response, nsub,
      partition['tms_train'],
      partition['tms_validate'],
      steps_max=10000, eps=1e-9,
      projection_type=projection_type,
      neighbor_mat=neighbor_mat,
      lam_proj=lam_proj, eps_proj=0.01,
      save_filename_partial=save_name_partial)
  k_f, b_f, _, loss_log_f, loss_log_test_f, fitting_phase_f, fit_params_f = op
  print('Fitting done')
  save_dict = {
      'K': k_f,
      'b': b_f,
      'loss_log': loss_log_f,
      'loss_log_test': loss_log_test_f,
      'fitting_phase': fitting_phase_f,
      'fit_params': fit_params_f,
  }
  # Binary mode for pickle; `with` ensures the file is flushed and closed.
  with gfile.Open(save_name, 'wb') as f:
    pickle.dump(save_dict, f)
  print('Saved results')