def get_link(pkl_p='', selected_p=''):
    """Return the object pickled at ``pkl_p``, else ``img_search(selected_p)``.

    Args:
        pkl_p: path of a pickle to load first; ignored when empty. A missing
            file is reported on stdout and the function falls back to
            ``selected_p``.
        selected_p: value passed to ``img_search`` when no pickle was loaded.

    Returns:
        The loaded object, the result of ``img_search``, or None when neither
        source is available.
    """
    if pkl_p:
        try:
            return load_pickle(pkl_p)
        except FileNotFoundError:
            print(f'No {pkl_p}')
    return img_search(selected_p) if selected_p else None
def compute_prediction(fn, model):
    """Predict a per-pixel class map for the image file ``fn``.

    Features are read from the pickle cache in ``cache_dir`` when present,
    otherwise computed with ``train.create_features`` and cached. They are
    scaled with the scaler pickled at ``model_dir/scaler_fn``, classified with
    ``model.predict_proba`` and converted with ``prob2class``.

    Args:
        fn: path of the image to classify.
        model: fitted classifier exposing ``predict_proba``.

    Returns:
        Array reshaped to (H - 2*border, W - 2*border, -1), where H and W are
        the image height and width and ``border = int((h_neigh - 1) / 2)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fn``.
    """
    # Pixels closer than `border` to the image edge are not predicted.
    border = int((h_neigh - 1) / 2)
    img = cv2.imread(fn, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(fn))

    # Cache key: file name up to its first '.', matching create_dataset.
    cached_fn = os.path.join(
        cache_dir, fn.split(os.path.sep)[-1].split('.')[0] + ".pkl")
    if os.path.isfile(cached_fn):
        features = load_pickle(cached_fn)
    else:
        features = train.create_features(img)
        dump2pickle(features, cached_fn)

    scaler = load_pickle(os.path.join(model_dir, scaler_fn))
    features = features.reshape(-1, features.shape[1])
    features = scaler.transform(features)
    model_predictions = model.predict_proba(features)
    model_predictions = prob2class(model_predictions)
    predictions_image = model_predictions.reshape(
        [img.shape[0] - 2 * border, img.shape[1] - 2 * border, -1])
    return predictions_image
def weird_load(fp, chunk_size=3408):
    """Merge the dicts stored as consecutive fixed-size pickles in ``fp``.

    The file is read in ``chunk_size``-byte chunks. Each chunk is written to a
    private temporary file, loaded with ``load_pickle`` and merged into the
    result with ``dict.update``, so later chunks override earlier keys.

    Args:
        fp: path of the concatenated-pickle file.
        chunk_size: byte length of each pickled chunk (default 3408).

    Returns:
        dict with the merged contents of all chunks.
    """
    import os
    import tempfile

    container = {}
    # A private temp file replaces the old hard-coded '_' in the CWD, which
    # clobbered any existing file of that name and was never removed.
    fd, tmp_path = tempfile.mkstemp(suffix='.pkl')
    os.close(fd)
    try:
        with open(fp, 'rb') as file:
            while True:
                chunk = file.read(chunk_size)
                if not chunk:
                    break
                with open(tmp_path, 'wb') as temp:
                    temp.write(chunk)
                # Each chunk is assumed to unpickle to a mapping (update() needs one).
                container.update(load_pickle(tmp_path))
    finally:
        os.remove(tmp_path)
    return container
def __init__(self):
    """Build the German spaCy pipeline, the case matcher and the word lookup.

    Sets ``nlp``, ``doc`` (a parsed sample sentence), ``matcher``, ``worter``,
    ``patterns``, ``pattern_akk`` / ``pattern_nom`` / ``pattern_dat`` and
    ``found_matches`` (initially None).
    """
    def article_adjectives_then(dep_label):
        # Article, zero or more attributive adjectives, then a token whose
        # dependency label is `dep_label`.
        return [[{'TAG': 'ART'}, {'TAG': 'ADJA', 'OP': '*'}, {'DEP': dep_label}]]

    # spaCy pipeline, a sample parsed sentence and a matcher on its vocab.
    self.nlp = spacy.load('de_core_news_sm')
    self.doc = self.nlp('Ich schicke meiner Mutter ein Geschenk')
    self.matcher = Matcher(self.nlp.vocab)

    # Word dictionary loaded from disk.
    self.worter = load_pickle("Worter.p")

    # Patterns registered below under the labels 'Akk', 'Nom' and 'Dat'.
    self.patterns = get_patterns()
    self.pattern_akk = article_adjectives_then('oa')
    self.pattern_nom = article_adjectives_then('sb')
    self.pattern_dat = article_adjectives_then('da')

    for match_label, pattern in (('Akk', self.pattern_akk),
                                 ('Nom', self.pattern_nom),
                                 ('Dat', self.pattern_dat)):
        self.matcher.add(match_label, pattern)

    self.found_matches = None
def __init__(self, model_dir):
    """Analyze based on units.

    Loads ``variance_rule.pkl`` from ``model_dir`` and restricts it to the
    'contextdm1' and 'contextdm2' rules. It keeps units whose summed variance
    over those two rules exceeds 1e-3, then groups them by the fraction of
    variance attributable to 'contextdm1':

    - '1':   fraction > 0.9
    - '2':   fraction < 0.1
    - '12':  0.4 < fraction < 0.6
    - '1+2': groups '1' and '2' concatenated

    ``group_ind`` indexes into the active units, and ``group_ind_orig`` holds
    the corresponding indices into the full unit array.
    """
    res = tools.load_pickle(os.path.join(model_dir, 'variance_rule.pkl'))
    h_var_all = res['h_var_all']
    keys = res['keys']
    rules = ['contextdm1', 'contextdm2']
    h_var_all = h_var_all[:, [keys.index(name) for name in rules]]

    # Active units: total variance across the two rules larger than 1e-3.
    ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0]
    h_var_all = h_var_all[ind_active, :]

    # Per-unit share of variance for each rule (each row sums to 1).
    h_normvar_all = h_var_all / np.sum(h_var_all, axis=1, keepdims=True)
    share_first = h_normvar_all[:, 0]

    group_ind = {
        '1': np.where(share_first > 0.9)[0],
        '2': np.where(share_first < 0.1)[0],
        '12': np.where((share_first > 0.4) & (share_first < 0.6))[0],
    }
    group_ind['1+2'] = np.concatenate((group_ind['1'], group_ind['2']))
    # Translate active-unit positions back to indices of the full unit array.
    group_ind_orig = {name: ind_active[idx] for name, idx in group_ind.items()}

    self.model_dir = model_dir
    self.group_ind = group_ind
    self.group_ind_orig = group_ind_orig
    self.h_normvar_all = h_normvar_all
    self.rules = rules
    self.ind_active = ind_active

    palette = ['orange', 'green', 'pink', 'sky blue']
    self.colors = {group: 'xkcd:' + shade
                   for group, shade in zip([None, '1', '2', '12'], palette)}
    self.lesion_group_names = {
        None: 'Intact',
        '1': 'Lesion group 1',
        '2': 'Lesion group 2',
        '12': 'Lesion group 12',
        '1+2': 'Lesion group 1 & 2',
    }
def _compute_hist_varprop(model_dir, rule_pair, random_rotation=False): data_type = 'rule' assert len(rule_pair) == 2 assert data_type == 'rule' fname = os.path.join(model_dir, 'variance_' + data_type) if random_rotation: fname += '_rr' fname += '.pkl' if not os.path.isfile(fname): # If not computed, compute now compute_variance(model_dir, random_rotation=random_rotation) res = tools.load_pickle(fname) h_var_all = res['h_var_all'] keys = res['keys'] ind_rules = [keys.index(rule) for rule in rule_pair] h_var_all = h_var_all[:, ind_rules] # First only get active units. Total variance across tasks larger than 1e-3 ind_active = np.where(h_var_all.sum(axis=1) > 1e-3)[0] # Temporary: Mimicking biased sampling. Notice the free parameter though. # print('Mimicking selective sampling') # ind_active = np.where((h_var_all.sum(axis=1) > 1e-3)*(h_var_all[:,0]>1*1e-2))[0] h_var_all = h_var_all[ind_active, :] # Normalize by the total variance across tasks h_normvar_all = (h_var_all.T / np.sum(h_var_all, axis=1)).T # Plot the proportion of variance for the first rule # data_plot = h_normvar_all[:, 0] data_plot = (h_var_all[:, 0] - h_var_all[:, 1]) / ( (h_var_all[:, 0] + h_var_all[:, 1])) hist, bins_edge = np.histogram(data_plot, bins=20, range=(-1, 1)) # # Plot the percentage instead of the total count # hist = hist/np.sum(hist) return hist, bins_edge
def compute_replacerule_performance(model_dir, setup, restore=False):
    """Compute the performance of one task given a replaced rule input.

    Args:
        model_dir: model directory; results are cached there as
            ``taskset<setup>_perf.pkl``.
        setup: 1, 2 or 3, selecting the evaluated rule, the candidate rule
            inputs and the list of rule-strength mixtures to try.
        restore: reload cached results when the cache file exists.

    Returns:
        (perfs, rule, names): performance per rule-strength mixture, the
        evaluated rule, and a display name for each mixture.

    Raises:
        ValueError: for an unknown ``setup``.
    """
    # setup -> (evaluated rule, replacement rule inputs, rule-strength mixtures)
    setups = {
        1: ('delayanti',
            ['delayanti', 'fdanti', 'delaygo', 'fdgo'],
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0], [0, 1, 1, -1]]),
        2: ('contextdelaydm1',
            ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'],
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 0],
             [0, 1, 1, -1]]),
        3: ('dmsgo',
            ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'],
            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 1, 1, 0], [0, 1, 1, -1]]),
    }
    try:
        rule, replacement_names, rule_strengths = setups[setup]
    except (KeyError, TypeError):
        raise ValueError('Unknown setup value') from None
    replace_rule = np.array(replacement_names)

    fname = os.path.join(model_dir, 'taskset{:d}_perf'.format(setup) + '.pkl')
    if restore and os.path.isfile(fname):
        print('Reloading results from ' + fname)
        cached = tools.load_pickle(fname)
        return cached['perfs'], cached['rule'], cached['names']

    perfs = []
    names = []
    for rule_strength in rule_strengths:
        perf, _ = run_network_replacerule(model_dir, rule, replace_rule,
                                          rule_strength)
        perfs.append(perf)
        names.append(replace_rule_name(replace_rule, rule_strength))
    perfs = np.array(perfs)
    print(perfs)

    with open(fname, 'wb') as f:
        pickle.dump({'perfs': perfs, 'rule': rule, 'names': names}, f)
    print('Results stored at : ' + fname)
    return perfs, rule, names
def generate_neuron_info(self, epochs, trial_list=None, rules=None, norm=True,
                         p_value=0.05, abs_active_thresh=1e-3,):
    """Classify every unit per trial, rule and epoch into ``self.neuron_info``.

    For each ``trial_num`` this loads ``<model_dir>/<trial_num>/H_<rule>.pkl``.
    For each neuron and location it takes the epoch-averaged activity of the
    trials at that location. When ``norm`` is set, the activity is expressed
    relative to the 'fix1' baseline as ``x / fix_level - 1``. A one-way ANOVA
    is then run across locations. ``self.neuron_info[trial_num][rule][epoch]``
    receives:

    - 'active_neurons': max absolute per-location rate > ``abs_active_thresh``;
    - 'selective_neurons': active and ANOVA p-value <= ``p_value``;
    - 'inh_neurons' / 'exh_neurons' / 'mix_neurons': selective neurons whose
      normalized per-location rates are all < 0 / all >= 0 / mixed;
    - 'firerate_loc_order': array with one row of per-location rates per neuron;
    - 'firerate_max_central': each row rotated so the largest rate sits in the
      middle. Rows have odd length: for an even number of locations the first
      value is repeated at the end.

    Raises:
        KeyError: if a rule lacks one of the requested epochs.
    """
    self.neuron_info = OrderedDict()
    if trial_list is None:
        trial_list = self.trial_list
    if rules is None:
        rules = self.rules

    for rule in rules:
        for epoch in epochs:
            if epoch not in self.epoch_info[rule]:
                raise KeyError('Rule {} does not have epoch {}!'.format(rule, epoch))

    info_types = ['selective_neurons', 'active_neurons', 'exh_neurons',
                  'inh_neurons', 'mix_neurons', 'firerate_loc_order',
                  'firerate_max_central']

    for trial_num in trial_list:
        self.neuron_info[trial_num] = OrderedDict()
        for rule in rules:
            H = tools.load_pickle(self.model_dir + '/' + str(trial_num) + '/' + 'H_' + rule + '.pkl')
            self.neuron_info[trial_num][rule] = OrderedDict()
            for epoch in epochs:
                info = OrderedDict()
                self.neuron_info[trial_num][rule][epoch] = info
                for info_type in info_types:
                    info[info_type] = list()
                epoch_range = self.epoch_info[rule][epoch]

                for neuron in range(self.neurons):
                    neuron_data_abs = OrderedDict()
                    neuron_data_norm = OrderedDict()
                    neuron_data = OrderedDict()
                    firerate_abs = list()
                    firerate_norm = list()
                    firerate = list()
                    for loc in self.in_loc_set[rule]:
                        loc_mask = self.in_loc[rule] == loc
                        fix1 = self.epoch_info[rule]['fix1']
                        # Baseline: mean activity over the 'fix1' epoch at this location.
                        fix_level = H[fix1[0]:fix1[1], loc_mask, neuron].mean(axis=1).mean(axis=0)
                        # Mean over the epoch's time steps, one value per trial at
                        # this location (axis=1 would average over trials instead).
                        neuron_data_abs[loc] = H[epoch_range[0]:epoch_range[1], loc_mask, neuron].mean(axis=0)
                        neuron_data_norm[loc] = neuron_data_abs[loc] / fix_level - 1
                        if norm:
                            neuron_data[loc] = neuron_data_norm[loc]
                        else:
                            neuron_data[loc] = neuron_data_abs[loc]
                        firerate_abs.append(neuron_data_abs[loc].mean())
                        firerate_norm.append(neuron_data_norm[loc].mean())
                        firerate.append(neuron_data[loc].mean())

                    # One-way ANOVA of the per-trial rates across locations.
                    data_frame_melt = pd.DataFrame(neuron_data).melt()
                    data_frame_melt.columns = ['Location', 'Fire_rate']
                    model = ols('Fire_rate~C(Location)', data=data_frame_melt).fit()
                    anova_table = anova_lm(model, typ=2)

                    if max(firerate_abs) > abs_active_thresh:
                        info['active_neurons'].append(neuron)
                        # Row 0 is the C(Location) effect. Use .iloc because an
                        # integer key on this string-indexed Series is a
                        # deprecated positional lookup.
                        if anova_table['PR(>F)'].iloc[0] <= p_value:
                            info['selective_neurons'].append(neuron)
                            if max(firerate_norm) < 0:
                                info['inh_neurons'].append(neuron)
                            elif min(firerate_norm) >= 0:
                                info['exh_neurons'].append(neuron)
                            else:
                                info['mix_neurons'].append(neuron)

                    # Rotate the per-location rates so the maximum is central.
                    max_index = firerate.index(max(firerate))
                    n_locs = len(firerate)
                    mc_len = n_locs + 1 if n_locs % 2 == 0 else n_locs
                    firerate_max_central = np.zeros(mc_len)
                    for i in range(n_locs):
                        firerate_max_central[(i - max_index + n_locs // 2) % n_locs] = firerate[i]
                    if n_locs % 2 == 0:
                        # Close the circle: repeat the first value at the end.
                        firerate_max_central[-1] = firerate_max_central[0]

                    info['firerate_loc_order'].append(firerate)
                    info['firerate_max_central'].append(firerate_max_central)

                info['firerate_loc_order'] = np.array(info['firerate_loc_order'])
                info['firerate_max_central'] = np.array(info['firerate_max_central'])
926, 928, 1455, 1811, 1816, 1819, 2372, 2765, 3726, 7086, 7358, 7504, 7516, 8231, 8235, 8381, 8393, 9108 ]] input_dims = [ 538, 529, 358, 13, 7, 2, 306, 7, 876, 7, 175, 7, 272, 20, 926, 2, 527, 356, 5, 3, 553, 393, 961, 3360, 272, 146, 12, 715, 4, 146, 12, 715, 4 ] batch_size = 256 data_path = 'feature/' samp = 'new_hard' nb_epoch = 20 job_train = load_npz(data_path + samp + '_train_job.npz') res_train = load_npz(data_path + samp + '_train_res.npz') gt_train = np.asarray(load_pickle(data_path + samp + '_train_label.pkl')) job_val = load_npz(data_path + samp + '_val_job.npz') res_val = load_npz(data_path + samp + '_val_res.npz') gt_val = np.asarray(load_pickle(data_path + samp + '_val_label.pkl')) print('[INFO] Compiling model...') model = iPNN(input_dims=input_dims, embedding_dim=128, prod_dim=128, hidden_dims=[128, 64, 64, 32]) model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc', 'mae']) model.summary() print('[INFO] Initiating training')
def compute_taskspace(model_dir, setup, restore=False, representation='rate'):
    """Embed a predefined set of rules in a 2-D task space.

    Args:
        model_dir: model directory. The 'rate' result is cached there as
            ``taskset<setup>_space.pkl``.
        setup: integer 1-6 selecting the rule set.
        restore: for 'rate', reload the cached result if the file exists.
        representation: 'rate' uses ``TaskSetAnalysis.compute_taskspace`` on
            the 'stim1' epoch with PCA. 'weight' runs a 2-component PCA on the
            input-weight columns of the rule units.

    Returns:
        h_trans: for 'weight', an OrderedDict mapping (rule, 'stim1') to an
        array of shape (1, 2), keyed like the 'rate' result.

    Raises:
        ValueError: for an unknown ``setup`` or ``representation``.
    """
    setup_rules = {
        1: ['fdgo', 'fdanti', 'delaygo', 'delayanti'],
        2: ['contextdelaydm1', 'contextdelaydm2', 'contextdm1', 'contextdm2'],
        3: ['dmsgo', 'dmcgo', 'dmsnogo', 'dmcnogo'],
        4: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'contextdm1', 'contextdm2', 'multidm'],
        5: ['contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
            'delaydm1', 'delaydm2', 'contextdm1', 'contextdm2', 'multidm',
            'dm1', 'dm2'],
        6: ['fdgo', 'delaygo', 'contextdm1', 'contextdelaydm1'],
    }
    try:
        rules = setup_rules[setup]
    except (KeyError, TypeError):
        # Fail fast instead of reaching an unbound `rules` further down.
        raise ValueError('Unknown setup value: {!r}'.format(setup)) from None

    if representation == 'rate':
        fname = 'taskset{:d}_space'.format(setup) + '.pkl'
        fname = os.path.join(model_dir, fname)

        if restore and os.path.isfile(fname):
            print('Reloading results from ' + fname)
            h_trans = tools.load_pickle(fname)
        else:
            tsa = TaskSetAnalysis(model_dir, rules=rules)
            h_trans = tsa.compute_taskspace(rules=rules, epochs=['stim1'],
                                            dim_reduction_type='PCA',
                                            setup=setup)
            with open(fname, 'wb') as f:
                pickle.dump(h_trans, f)
            print('Results stored at : ' + fname)

    elif representation == 'weight':
        from task import get_rule_index
        from sklearn.decomposition import PCA

        model = Model(model_dir)
        hp = model.hp
        with tf.Session() as sess:
            model.restore()
            w_in = sess.run(model.w_in).T

        rule_indices = [get_rule_index(r, hp) for r in rules]
        w_rules = w_in[:, rule_indices]

        # One sample per rule, described by that rule unit's weight column.
        data_trans = PCA(n_components=2).fit_transform(w_rules.T)

        h_trans = OrderedDict()
        for i, r in enumerate(rules):
            # Shape (1, 2). The 'stim1' key part only mirrors the 'rate' format.
            h_trans[(r, 'stim1')] = np.array([data_trans[i]])

    else:
        raise ValueError('Unknown representation: {!r}'.format(representation))

    return h_trans
def process_trajectories(file_name, data_dir):
    """Load and return the pickled trajectories ``file_name`` from ``data_dir``.

    No further processing is implemented yet. The function no longer stops in
    an interactive debugger.

    NOTE(review): the two-argument ``load_pickle(directory, name)`` call form
    is kept as written; confirm it matches this project's ``load_pickle``.
    """
    # Use the caller's data_dir; the global DATA_DIR silently ignored it.
    data = load_pickle(data_dir, file_name)
    return data
def __init__(self, model_dir, data_type, normalization_method='max'): hp = tools.load_hp(model_dir) # If not computed, use variance.py fname = os.path.join(model_dir, 'variance_' + data_type + '.pkl') res = tools.load_pickle(fname) h_var_all_ = res['h_var_all'] self.keys = res['keys'] # First only get active units. Total variance across tasks larger than 1e-3 # ind_active = np.where(h_var_all_.sum(axis=1) > 1e-2)[0] ind_active = np.where(h_var_all_.sum(axis=1) > 1e-3)[0] h_var_all = h_var_all_[ind_active, :] # Normalize by the total variance across tasks if normalization_method == 'sum': h_normvar_all = (h_var_all.T / np.sum(h_var_all, axis=1)).T elif normalization_method == 'max': h_normvar_all = (h_var_all.T / np.max(h_var_all, axis=1)).T elif normalization_method == 'none': h_normvar_all = h_var_all else: raise NotImplementedError() ################################## Clustering ################################ from sklearn import metrics X = h_normvar_all # Clustering from sklearn.cluster import AgglomerativeClustering, KMeans # Choose number of clusters that maximize silhouette score n_clusters = range(2, 30) scores = list() labels_list = list() for n_cluster in n_clusters: # clustering = AgglomerativeClustering(n_cluster, affinity='cosine', linkage='average') clustering = KMeans(n_cluster, algorithm='full', n_init=20, random_state=0) clustering.fit( X) # n_samples, n_features = n_units, n_rules/n_epochs labels = clustering.labels_ # cluster labels score = metrics.silhouette_score(X, labels) scores.append(score) labels_list.append(labels) scores = np.array(scores) # Heuristic elbow method # Choose the number of cluster when Silhouette score first falls # Choose the number of cluster when Silhouette score is maximum if data_type == 'rule': #i = np.where((scores[1:]-scores[:-1])<0)[0][0] i = np.argmax(scores) else: # The more rigorous method doesn't work well in this case i = n_clusters.index(10) labels = labels_list[i] n_cluster = n_clusters[i] print('Choosing 
{:d} clusters'.format(n_cluster)) # Sort clusters by its task preference (important for consistency across nets) if data_type == 'rule': label_prefs = [ np.argmax(h_normvar_all[labels == l].sum(axis=0)) for l in set(labels) ] elif data_type == 'epoch': ## TODO: this may no longer work! label_prefs = [ self.keys[np.argmax(h_normvar_all[labels == l].sum(axis=0))][0] for l in set(labels) ] ind_label_sort = np.argsort(label_prefs) label_prefs = np.array(label_prefs)[ind_label_sort] # Relabel labels2 = np.zeros_like(labels) for i, ind in enumerate(ind_label_sort): labels2[labels == ind] = i labels = labels2 # # Sort data by labels and by input connectivity # model = Model(save_name) # hp = model.hp # with tf.Session() as sess: # model.restore(sess) # var_list = sess.run(model.var_list) # # # Get connectivity # w_out = var_list[0].T # b_out = var_list[1] # w_in = var_list[2][:n_input, :].T # w_rec = var_list[2][n_input:, :].T # b_rec = var_list[3] # # # nx, nh, ny = hp['shape'] # nr = hp['n_eachring'] # # sort_by = 'w_in' # if sort_by == 'w_in': # w_in_mod1 = w_in[ind_active, :][:, 1:nr+1] # w_in_mod2 = w_in[ind_active, :][:, nr+1:2*nr+1] # w_in_modboth = w_in_mod1 + w_in_mod2 # w_prefs = np.argmax(w_in_modboth, axis=1) # elif sort_by == 'w_out': # w_prefs = np.argmax(w_out[1:, ind_active], axis=0) # # ind_sort = np.lexsort((w_prefs, labels)) # sort by labels then by prefs ind_sort = np.argsort(labels) labels = labels[ind_sort] self.h_normvar_all = h_normvar_all[ind_sort, :] self.ind_active = ind_active[ind_sort] self.n_clusters = n_clusters self.scores = scores self.n_cluster = n_cluster self.h_var_all = h_var_all self.normalization_method = normalization_method self.labels = labels self.unique_labels = np.unique(labels) self.model_dir = model_dir self.hp = hp self.data_type = data_type self.rules = hp['rules']
def plot_choicefamily_varytime(model_dir, rule):
    """Plot how the discrimination threshold scales with stimulus duration.

    The function:

    1. Reads ``varytime_<rule>.pkl`` from ``model_dir``, which is written by
       ``performance.compute_choicefamily_varytime``.
    2. Fits a Weibull curve to the results for every stimulus duration.
    3. Plots log10 of the fitted thresholds against log10 duration, together
       with a slope -0.5 "perfect integrator" line fitted to them.
    4. Draws the chronometric curves with ``plot_psychometric_varytime``.

    The threshold figure is saved under ``figure/`` only when ``save`` is truthy.

    Raises:
        FileNotFoundError: if the results pickle has not been computed yet.
    """
    import seaborn as sns

    assert rule in ['dm1', 'dm2', 'contextdm1', 'contextdm2', 'multidm']

    savename = os.path.join(model_dir, 'varytime_' + rule + '.pkl')
    try:
        result = tools.load_pickle(savename)
    except FileNotFoundError:
        raise FileNotFoundError(
            'Run performance.compute_choicefamily_varytime first.')

    xdatas = result['xdatas']
    ydatas = result['ydatas']
    cohs = result['cohs']
    stim_times = xdatas[0]
    n_coh = len(xdatas)

    def weibull(x, a, b):
        return 1 - 0.5 * np.exp(-(x / a)**b)

    def fit_threshold(ydata):
        # Least-squares Weibull fit over coherences; returns the threshold alpha.
        def squared_error(param):
            return np.sum((weibull(cohs, param[0], param[1]) - ydata)**2)
        fit = minimize(squared_error, [0.1, 1],
                       bounds=([1e-3, 1], [1e-5, 10]), method='L-BFGS-B')
        alpha, _beta = fit.x
        return alpha

    alpha_fits = [fit_threshold(ydatas[:, col]) for col in range(len(stim_times))]

    def perfect_int(x, b):
        return -0.5 * x + b

    log_times = np.log10(stim_times)
    log_alphas = np.log10(alpha_fits)
    b, _ = curve_fit(perfect_int, log_times, log_alphas)

    fs = 7
    fig = plt.figure(figsize=(2.5, 1.5))
    ax = fig.add_axes([0.2, 0.25, 0.4, 0.6])
    ax.plot(log_times, log_alphas, 'o-', color='black', label='model',
            markersize=3)
    ax.plot(log_times, -0.5 * log_times + b, color='red', label='perfect int.')
    ax.set_xlabel('Stimulus duration (ms)', fontsize=fs)
    ax.set_ylabel('Discrim. thr. (x0.01)', fontsize=fs)
    ax.set_xticks(np.log10(np.array([200, 400, 800, 1600])))
    ax.set_xticklabels(['200', '400', '800', '1600'])
    ax.set_yticks(np.log10(np.array([0.005, 0.01, 0.02, 0.04])))
    ax.set_yticklabels(['0.5', '1', '2', '4'])
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.tick_params(axis='both', which='major', labelsize=fs)
    ax.set_title(rule_name[rule], fontsize=fs)
    plt.legend(fontsize=fs, frameon=False, bbox_to_anchor=[1, 1], loc=2)

    compact_rule = rule_name[rule].replace(' ', '')
    if save:
        plt.savefig('figure/' + 'varytime2_' + compact_rule + '.pdf',
                    transparent=True)

    # Chronometric curve
    plot_psychometric_varytime(
        xdatas, ydatas, 'varytime_' + compact_rule,
        labels=['{:0.3f}'.format(t) for t in 2 * cohs],
        colors=sns.dark_palette("light blue", n_coh, input="xkcd"),
        legtitle='Stim. 1 - Stim. 2', rule=rule)
def plot_PSTH(self, epochs, rules=None, trial_list=None,
              neuron_types=(('exh_neurons', 'mix_neurons'),), norm=True,
              separate_plot=False, fuse_rules=False):
    """Plot population PSTHs in each neuron's preferred and opposite direction.

    Neurons are taken from ``self.neuron_info``, which ``generate_neuron_info``
    fills. For each trial and rule, ``<model_dir>/<trial_num>/H_<rule>.pkl`` is
    loaded. Each selected neuron's time course is averaged at its preferred
    location (argmax of its 'firerate_loc_order' row) and at the opposite
    location. With an odd number of locations, the opposite is the mean of the
    two nearest locations. With ``norm``, traces are expressed relative to the
    'fix1' baseline as ``x / fix_level - 1``.

    Figures written under ``figure/figure_<model_dir>/<rule>/<epoch>/<types>/``:
    - one PSTH per trial (``separate_plot``) or one for all trials;
    - PSTHs averaged by training stage. A trial's stage is set by its
      performance ``self.log['perf_<rule>'][trial_num // self.log['trials'][1]]``:
      <= ``hp['infancy_target_perf']``, <= ``hp['young_target_perf']``, or above.

    With ``fuse_rules``, one extra figure per epoch and neuron-type pair
    overlays up to four rules, one line style each.

    Args:
        epochs: epoch names to plot.
        rules: rules to plot (default ``self.rules``).
        trial_list: trial numbers to plot (default ``self.trial_list``).
        neuron_types: iterable of tuples of ``neuron_info`` keys. The neurons of
            all types in one tuple are pooled into a single PSTH.
        norm: plot baseline-normalized traces instead of raw ones.
        separate_plot: one figure per trial instead of one per (rule, epoch, types).
        fuse_rules: additionally plot the rules together.
    """
    if trial_list is None:
        trial_list = self.trial_list
    if rules is None:
        rules = self.rules

    def _time_axis(trace):
        # Sample index scaled by hp['dt'] / 1000 (presumably ms -> s).
        return np.arange(len(trace)) * self.hp['dt'] / 1000

    def _save_current_figure(fig, rule, epoch, type_pair, filename):
        # Title, legend, save into the (rule, epoch, types) folder, then close.
        plt.title("Rule:" + rule + " Epoch:" + epoch + " Neuron_type:" + "_".join(type_pair))
        plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
        folder = 'figure/figure_' + self.model_dir + '/' + rule + '/' + epoch + '/' + '_'.join(type_pair) + '/'
        tools.mkdir_p(folder)
        plt.tight_layout()
        plt.savefig(folder + filename, transparent=False, bbox_inches='tight')
        plt.close(fig)

    # ---- Average PSTHs per trial / rule / epoch / neuron-type pair ----
    psth_to_plot = OrderedDict()
    for trial_num in trial_list:
        psth_to_plot[trial_num] = OrderedDict()
        for rule in rules:
            H = tools.load_pickle(self.model_dir + '/' + str(trial_num) + '/' + 'H_' + rule + '.pkl')
            psth_to_plot[trial_num][rule] = OrderedDict()
            for epoch in epochs:
                psth_to_plot[trial_num][rule][epoch] = OrderedDict()
                for type_pair in neuron_types:
                    type_psth = OrderedDict()
                    psth_to_plot[trial_num][rule][epoch][type_pair] = type_psth
                    psth_neuron = list()
                    anti_dir_psth = list()
                    info = self.neuron_info[trial_num][rule][epoch]
                    for n_type in type_pair:
                        for neuron in info[n_type]:
                            n_locs = len(self.in_loc_set[rule])
                            sel_loc = np.argmax(info['firerate_loc_order'][neuron])
                            anti_loc = (sel_loc + n_locs // 2) % n_locs
                            fix1 = self.epoch_info[rule]['fix1']
                            psth_temp = H[:, self.in_loc[rule] == sel_loc, neuron].mean(axis=1)
                            fix_level = H[fix1[0]:fix1[1], self.in_loc[rule] == sel_loc, neuron].mean(axis=1).mean(axis=0)
                            if n_locs % 2:
                                # No exact opposite: average the two neighbours.
                                # Wrap the second index, because anti_loc + 1 can
                                # equal n_locs, which would select no trials and
                                # give a NaN trace.
                                next_loc = (anti_loc + 1) % n_locs
                                anti_dir_psth_temp = (H[:, self.in_loc[rule] == anti_loc, neuron].mean(axis=1)
                                                      + H[:, self.in_loc[rule] == next_loc, neuron].mean(axis=1)) / 2.0
                            else:
                                anti_dir_psth_temp = H[:, self.in_loc[rule] == anti_loc, neuron].mean(axis=1)
                            if norm:
                                psth_neuron.append(psth_temp / fix_level - 1)
                                anti_dir_psth.append(anti_dir_psth_temp / fix_level - 1)
                            else:
                                psth_neuron.append(psth_temp)
                                anti_dir_psth.append(anti_dir_psth_temp)
                    # Best-effort: leave the entry out if the traces cannot be averaged.
                    try:
                        type_psth['sel_dir'] = np.array(psth_neuron).mean(axis=0)
                    except Exception:
                        pass
                    try:
                        type_psth['anti_sel_dir'] = np.array(anti_dir_psth).mean(axis=0)
                    except Exception:
                        pass

    # ---- Per-trial or all-trials figures ----
    for rule in rules:
        for epoch in epochs:
            for type_pair in neuron_types:
                if not separate_plot:
                    fig_psth = plt.figure()
                for trial_idx, trial_num in enumerate(trial_list):
                    if separate_plot:
                        fig_psth = plt.figure()
                        color = None
                    else:
                        color = kelly_colors[(trial_idx + 1) % len(kelly_colors)]
                    psth = psth_to_plot[trial_num][rule][epoch][type_pair]
                    # Best-effort plotting: skip traces that are missing or unplottable.
                    try:
                        plt.plot(_time_axis(psth['sel_dir']), psth['sel_dir'],
                                 label=str(trial_num) + 'sel', color=color)
                    except Exception:
                        pass
                    try:
                        plt.plot(_time_axis(psth['anti_sel_dir']), psth['anti_sel_dir'],
                                 linestyle='--', label=str(trial_num) + 'anti', color=color)
                    except Exception:
                        pass
                    if separate_plot:
                        _save_current_figure(fig_psth, rule, epoch, type_pair,
                                             'PSTH-' + str(trial_num) + '.png')
                if not separate_plot:
                    _save_current_figure(fig_psth, rule, epoch, type_pair,
                                         'PSTH_all_' + str(trial_list[0]) + 'to' + str(trial_list[-1]) + '.png')

    # ---- Group trials by training stage and average them ----
    growth_keys = ['less_than_I', 'I_to_Y', 'elder_than_Y']
    plot_by_growth = dict()
    for rule in rules:
        plot_by_growth[rule] = dict()
        for epoch in epochs:
            plot_by_growth[rule][epoch] = dict()
            for type_pair in neuron_types:
                groups = {key: {'sel': list(), 'anti': list(), 'growth': list()}
                          for key in growth_keys}
                plot_by_growth[rule][epoch][type_pair] = groups
                for trial_num in trial_list:
                    growth = self.log['perf_' + rule][trial_num // self.log['trials'][1]]
                    if growth <= self.hp['infancy_target_perf']:
                        group = groups['less_than_I']
                    elif growth <= self.hp['young_target_perf']:
                        group = groups['I_to_Y']
                    else:
                        group = groups['elder_than_Y']
                    psth = psth_to_plot[trial_num][rule][epoch][type_pair]
                    group['sel'].append(psth['sel_dir'])
                    group['anti'].append(psth['anti_sel_dir'])
                    group['growth'].append(growth)
                # Best-effort averaging of each group's traces and growth values.
                for group in groups.values():
                    for type_key, value in group.items():
                        try:
                            group[type_key] = np.array(value).mean(axis=0)
                        except Exception:
                            pass

    # ---- By-growth figures ----
    growth_colors = {'less_than_I': 'green', 'I_to_Y': 'blue'}
    for rule in rules:
        for epoch in epochs:
            for type_pair in neuron_types:
                fig_psth = plt.figure()
                for growth_key, value in plot_by_growth[rule][epoch][type_pair].items():
                    color = growth_colors.get(growth_key, 'red')
                    label = growth_key + '_' + str(value['growth'])[:4]
                    try:
                        plt.plot(_time_axis(value['sel']), value['sel'],
                                 label=label, color=color)
                    except Exception:
                        pass
                    try:
                        plt.plot(_time_axis(value['anti']), value['anti'],
                                 linestyle='--', label=label, color=color)
                    except Exception:
                        pass
                _save_current_figure(fig_psth, rule, epoch, type_pair,
                                     'PSTH_bygrowth_' + str(trial_list[0]) + 'to' + str(trial_list[-1]) + '.png')

    # ---- All rules overlaid (one line style per rule, at most four) ----
    if fuse_rules:
        line_styles = ['-', '--', '-.', ':']
        fused_rules = list(rules)[:len(line_styles)]
        for type_pair in neuron_types:
            for epoch in epochs:
                fig_psth_fr = plt.figure()
                for rule, linestyle in zip(fused_rules, line_styles):
                    for key, value in plot_by_growth[rule][epoch][type_pair].items():
                        try:
                            plt.plot(_time_axis(value['sel']), value['sel'],
                                     label=rule + '_' + key + '_' + str(value['growth'])[:4],
                                     color=growth_colors.get(key, 'red'),
                                     linestyle=linestyle)
                        except Exception:
                            pass
                # Title names every plotted rule, not just the last one.
                plt.title("Rule:" + '-'.join(fused_rules) + " Epoch:" + epoch + " Neuron_type:" + "_".join(type_pair))
                plt.legend(bbox_to_anchor=(1.05, 0), loc=3, borderaxespad=0)
                plt.tight_layout()
                plt.savefig('figure/figure_' + self.model_dir + '/' + '-'.join(rules) + '_' + epoch + '_' + '-'.join(type_pair)
                            + str(trial_list[0]) + 'to' + str(trial_list[-1]) + '.png',
                            transparent=False, bbox_inches='tight')
                plt.close(fig_psth_fr)
def _wae_kwargs():
    """Return a fresh copy of the WAE_MMD settings shared by both training runs."""
    return dict(dims=[784, 512, 64], loss='L2', activation='lrelu', z_dim=8,
                phase='train', scale=1, batch_size=50)


(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Flatten each image to a 784-long float32 vector and scale to [0, 1].
X_train = X_train.astype('float32').reshape(X_train.shape[0], 784)
X_train /= 255.

# Stage 1: pre-train and store the weights.
model = WAE_MMD(**_wae_kwargs())
print('Pre-train')
model.fit(data=X_train, nb_epoch=20, w=[0, 0.05])
model.save_weight('pre_weight.pkl')

# Stage 2: restart from the pre-trained weights with the regularizer weighting.
model = WAE_MMD(**_wae_kwargs(), params=load_pickle('pre_weight.pkl'))
model.fit(data=X_train, nb_epoch=300, w=[5000, 0], sample_path='samples/mnist_')
def create_dataset(image_dict, label_dict):
    """Build the scaled per-pixel feature matrix and one-hot label matrix.

    For each image, the steps are:

    1. Load features from the pickle cache in ``cache_dir`` if present;
       otherwise compute them with ``create_features`` and cache them.
    2. Crop the label by ``h_ind`` pixels on every side.
    3. Flatten the label to one row per pixel and binarize it.

    After all images are stacked:

    - Rows whose label does not have exactly one positive class are dropped.
    - If ``oversample`` is set, classes are rebalanced with
      ``RandomOverSampler``.
    - Features are min-max scaled, and the fitted scaler is pickled to
      ``model_dir/scaler_fn``.

    Args:
        image_dict: mapping filename -> image array.
        label_dict: mapping filename -> label array (same keys as image_dict).

    Returns:
        (X, y): scaled features with one row per pixel, and labels from
        ``int2onehot``.

    Raises:
        ValueError: if ``image_dict`` is empty.
    """
    from collections import Counter

    print('[INFO] Creating training dataset on %d image(s).' % len(image_dict.keys()))
    if not image_dict:
        raise ValueError('create_dataset needs at least one image.')

    X = []
    y = []
    for fn in image_dict.keys():
        img = image_dict[fn]
        label = label_dict[fn]
        # Cache key: file name up to its first '.', matching compute_prediction.
        cached_fn = os.path.join(
            cache_dir, fn.split(os.path.sep)[-1].split('.')[0] + ".pkl")
        if os.path.isfile(cached_fn):
            features = load_pickle(cached_fn)
        else:
            features = create_features(img)
            dump2pickle(features, cached_fn)
        # Explicit end indices keep h_ind == 0 working (label[0:-0] is empty).
        label = label[h_ind:label.shape[0] - h_ind, h_ind:label.shape[1] - h_ind]
        X.append(features.reshape(features.shape[0], -1))
        y.append(label.reshape(label.shape[0] * label.shape[1], -1))

    # Stack per-image rows. Unlike np.array(...).reshape(...), this also
    # works when images differ in size.
    X = np.concatenate(X, axis=0)
    y = (np.concatenate(y, axis=0) > 0).astype(np.uint8)

    # Delete ambiguous samples (not exactly one positive class).
    keep = y.sum(axis=1) == 1
    y = y[keep]
    X = X[keep]

    y_int = onehot2int(y)
    print("Original dataset shape {}".format(Counter(y_int)))
    if oversample:
        from imblearn.over_sampling import RandomOverSampler
        ros = RandomOverSampler(random_state=random_state)
        X, y_int = ros.fit_resample(X, y_int)
        print("Resampled dataset shape {}".format(Counter(y_int)))
    y = int2onehot(y_int)

    scaler = MinMaxScaler()
    scaler.fit(X)
    X = scaler.transform(X)
    dump2pickle(scaler, os.path.join(model_dir, scaler_fn))
    return X, y
cached_fn = os.path.join(cache_dir, fn.split( os.path.sep)[-1].split('.')[0]+".pkl") if os.path.isfile(cached_fn): features = load_pickle(cached_fn) else: features = train.create_features(img) dump2pickle(features, cached_fn) scaler = load_pickle(os.path.join(model_dir, scaler_fn)) features = features.reshape(-1, features.shape[1]) features = scaler.transform(features) model_predictions = model.predict_proba(features) model_predictions = prob2class(model_predictions) predictions_image = model_predictions.reshape( [img.shape[0]-2*border, img.shape[1]-2*border, -1]) return predictions_image if __name__ == '__main__': filelist = glob(os.path.join(infere_dir, '*.jpg')) print('[INFO] Running inference on %s test images' % len(filelist)) model = load_pickle(os.path.join(model_dir, model_fn)) for fn in filelist: print('[INFO] Processing images:', fn.split(os.path.sep)[-1]) inference_img = compute_prediction(fn, model) mapped_img = class2bgr(inference_img, palette_bgr) output_fn = os.path.join(output_dir, fn.split( os.path.sep)[-1].split('.')[0]+".png") cv2.imwrite(output_fn, mapped_img)