def print_cluster_data(args): """loading precomputed fits""" from util import depickle_fits res = depickle_fits(args.file, suffix="cfits") """retrieving weightmatrix from fits dic""" W = res["map"]["W"] dic_W = res["map"] """get fit and cluster data""" cl = res["dic_cluster"] num_dead = cl["num_dead_rf"] alive = W.shape[0] - num_dead import pprint as pp print print "*****+ Trained map data" print "visible units", W.shape[1] print "hidden units", W.shape[0] print "alive hidden units", alive, "(", N.round(alive / float(W.shape[0]), 2), "% )" print "*****+ Map Parameter" print pp.pprint([(key, dic_W[key]) if key != "W" else () for key in dic_W.keys()]) print "*****+ Cluster data" print "num cluster", cl["num_clusters"] print "num rf in each cluster", cl["num_each_cluster"] print "prototype weights each cluster", N.round(cl["prototype_cluster"], 2) print "*****+ Cluster Parameters" print pp.pprint(cl["args"])
def rec_map_from_pickled_result( map_file, odir ): from util import depickle_fits dic = depickle_fits(map_file) W_rec = __rec_map(dic['map'], dic['fits'], dic['model']) from ..base import make_working_dir_sub work_dir = make_working_dir_sub(odir, 'rec') from ..base import make_filename writefilename = make_filename(map_file, 'rec_'+dic['model']+'_rgc', '.png', odir=work_dir) from ..base.plots import save_w_as_image save_w_as_image(W_rec, dic['map']['vis'], dic['map']['vis'], dic['map']['hid'], dic['map']['hid'], mode=dic['map']['mode'], outfile=writefilename, verbose=False) print print 'reconstruction ', writefilename+'.png', ' written.'
def plot_cluster_cov(args):
    """Plot the spatial coverage of every cluster stored in a ``.cfits`` file.

    The image is written into the ``cl`` working sub-directory of
    ``args.odir/../``; with ``args.plain`` the file name gets a ``plain_``
    marker. Drawing is delegated to ``__plot_clusters_coverage``.

    :param args: parsed command line arguments; reads ``file``, ``odir``,
        ``plain``, ``scale``, ``a``, ``pad`` and ``dpi``.
    """
    import os
    from util import depickle_fits
    from ..base import make_filename, make_working_dir_sub

    res = depickle_fits(args.file, suffix="cfits")

    parent_dir = os.path.join(args.odir, "../")
    cluster_dir = make_working_dir_sub(parent_dir, "cl")

    prefix = "image_clustered_rgc_" + ("plain_" if args.plain else "")
    filepath = make_filename(args.file, prefix, ".png", odir=cluster_dir)

    __plot_clusters_coverage(
        filepath=filepath,
        res=res,
        args_plain=args.plain,
        args_scale=args.scale,
        args_a=args.a,
        args_pad=args.pad,
        args_dpi=args.dpi,
    )
def print_pickled_fits( map_file, complete=False, ): from util import depickle_fits dic = depickle_fits(map_file, exit_on_errer=False) if not dic: dic = depickle_fits(map_file, suffix='cfits') from pprint import pprint if not complete: keys = dic.keys() for key in keys: print key if type(dic[key]) == dict: if key=='map' or key=='dic_cluster': for k in dic[key].keys(): print '\t', k, ':', dic[key][k] else: print '\t', dic[key].keys() else: print '\t', dic[key] else: print pprint(dic)
def make_size_hist(
    args_file,
    args_odir,
    odir_nosubdir=False,
    luminosity_channels=False,
):
    """Plot the mean receptive-field size of every cluster as a bar chart.

    Each cluster is one bar: its width is the cluster's share of all fitted
    RFs (in percent), its height the mean approximate RF area of its members.
    Bars are placed side by side, ordered by cluster size.

    Previously the function called ``exit()`` after saving the chart (which
    terminated the calling process) and left several unreachable plotting
    variants behind; it now returns normally.

    :param args_file: path of the clustered fits (loaded with suffix
        ``cfits``); the fits must come from the ``edog`` model.
    :param args_odir: output directory; the image goes into the ``hist``
        working sub-directory of ``args_odir/../``.
    :param odir_nosubdir: write directly into ``args_odir`` instead.
    :param luminosity_channels: unused; kept for call compatibility.
    :return: the file name passed to ``savefig``.
    """
    import os
    import numpy as N
    import matplotlib
    import matplotlib.pyplot as plt
    from util import depickle_fits
    from ..base import make_filename, make_working_dir_sub

    res = depickle_fits(args_file, suffix='cfits')

    fits = res['fits']
    cl = res['dic_cluster']
    num_channel = cl['num_clusters']
    colors = cl['prototype_cluster']
    num_each_cluster = cl['num_each_cluster']
    assert res['model'] == 'edog'

    # share of every cluster in percent; float division so that integer
    # cluster counts do not truncate to 0 under Python 2
    perc_cluster = N.round(N.asarray(num_each_cluster, dtype=float) / len(fits) * 100)

    # edog params: 0 mu_x, 1 mu_y, 2 sigma_x, 3 sigma_y, 4 c_to_s_ratio,
    # 5 theta (rotation), 6 k_s (k_c fixed to 1), 7-9 cdir_a..c, 10-12 bias_a..c
    areas_by_cluster = [[] for _ in xrange(num_channel)]
    for key in sorted(fits.keys()):
        fit = fits[key]
        p = fit['p']
        c_area = N.pi * p[2] * p[3]
        s_area = c_area * p[4]
        # approximate RF size: mean of centre and surround area
        areas_by_cluster[fit['cl']].append(N.mean([c_area, s_area]))

    # (cluster id, cluster size, mean area, share in percent)
    data = [(i, num_each_cluster[i], N.mean(areas), int(perc_cluster[i]))
            for i, areas in enumerate(areas_by_cluster)]
    data.sort(key=lambda entry: entry[1])

    if odir_nosubdir:
        work_dir = args_odir
    else:
        work_dir = make_working_dir_sub(os.path.join(args_odir, '../'), 'hist')
    fname = make_filename(args_file, 'hist_RF_sizes', '.png', work_dir)

    # scope the style changes to this figure instead of mutating global rcParams
    rc = {'xtick.major.pad': '12', 'ytick.major.pad': '12', 'axes.linewidth': 0}
    with matplotlib.rc_context(rc):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        left = 0
        for cl_id, _, mean_area, perc in data:
            ax.bar(left, mean_area, perc, color=colors[cl_id], alpha=.9, linewidth=0)
            left += perc
        ax.get_yaxis().set_ticks(N.arange(0, 9, 2))
        ax.get_xaxis().set_ticks(N.arange(0, 101, 25))
        plt.savefig(fname, bbox_inches='tight', dpi=144)
        plt.close(fig)
    return fname
def make_spatal_hist(args_file, args_odir, abs_weight_threshold=0.7): """loading precomputed fits""" from util import depickle_fits res = depickle_fits(args_file, suffix="cfits") """retrieving weightmatrix from fits dic""" map_as_dict = res["map"] W = map_as_dict["W"] vis = map_as_dict["vis"] ch_width = vis ** 2 mode = map_as_dict["mode"] assert mode == "rgb" """get fit and cluster data""" fits = res["fits"] fit_keys = fits.keys() cl = res["dic_cluster"] num_cluster = cl["num_clusters"] spat_hist = [] def make_record(loc, key, cl, ch, value): if abs(value) > abs_weight_threshold: return (loc, {"key": key, "cl": cl, "ch": ch, "value": value}) else: return None keys = sorted(fit_keys) for i, key in enumerate(keys): cl = fits[key]["cl"] rf = W[key] for j in xrange(0, ch_width): sploc = (j / vis, j % vis) vr = make_record(loc=sploc, key=key, cl=cl, ch=0, value=rf[j]) if vr: spat_hist.append(vr) vg = make_record(loc=sploc, key=key, cl=cl, ch=1, value=rf[ch_width + j]) if vg: spat_hist.append(vg) vb = make_record(loc=sploc, key=key, cl=cl, ch=2, value=rf[2 * ch_width + j]) if vb: spat_hist.append(vb) # import pprint as pp # # print pp.pprint(spat_hist) # # print pp.pprint( __filter_ch(spat_hist, 0) ) # # print pp.pprint( __filter_value(spat_hist, 0.2) ) # print pp.pprint(__filter_cl( # __filter_ch( # __filter_loc_fromto(spat_hist, (1,1), (3,3)), # 1), # 0) ) def hist_2d(ch, cl, onoff=None): """make a 2d histogramm""" import numpy as N hist_2d = N.zeros((vis, vis), dtype=float) def sign_ok(on, val): if on == None: return True else: if val > 0 and on: return True elif val < 0 and not on: return True else: return False for i in xrange(0, len(spat_hist)): loc = spat_hist[i][0] if spat_hist[i][1]["ch"] == ch and spat_hist[i][1]["cl"] == cl and sign_ok(onoff, spat_hist[i][1]["value"]): # hist_2d[loc[0],loc[1]] += abs(spat_hist[i][1]['value']) #1 hist_2d[loc[0], loc[1]] += 1 return hist_2d import os mod_odir = os.path.join(args_odir, "../") from ..base import 
make_working_dir_sub work_dir = make_working_dir_sub(mod_odir, "hist") from ..base import make_filename def make_plots_of_channel(ch, ison=None, ch_name=None): plots = [] str_name = ch_name + " " if ch_name != None else "" for cl in xrange(0, num_cluster): plots += [ { "name": str_name + " cluster: " + str(cl), "value": hist_2d(ch, cl, ison), "maxmin": False, "cmap": "Greys", "balance": True, "invert": True, "patch_width": vis, "interp": "nearest", } ] #'catrom'}] return plots def hist_input(ch, ison=None, ch_name=None): plots = make_plots_of_channel(ch, ison, ch_name) if ison == None: miscstr = "" elif ison: miscstr = "_ON_" else: miscstr = "_OFF_" fname = make_filename( args_file, "thresh" + str(abs_weight_threshold) + "_hist_inputch_" + str(ch) + miscstr, ".png", work_dir ) from ..base.plots import write_row_col_fig write_row_col_fig(plots, rows=2, cols=3, filepath=fname + ".png", dpi=144, alpha=1.0, fontsize=5.5) print "file:", fname + ".png", "written." # hist_input(2, ch_name='Blue') # hist_input(1, ch_name='Green') # hist_input(0, ch_name='Red') hist_input(2, ison=True, ch_name="input Blue ON") hist_input(2, ison=False, ch_name="input Blue OFF") hist_input(1, ison=True, ch_name="input Green ON") hist_input(1, ison=False, ch_name="input Green OFF") hist_input(0, ison=True, ch_name="input Red ON") hist_input(0, ison=False, ch_name="input Red OFF") def hist_complete(onoff=None): plots = [] for inch in xrange(0, 3): plots += make_plots_of_channel(inch) if onoff != None: plots += make_plots_of_channel(inch, onoff) plots += make_plots_of_channel(inch, not onoff) fname = make_filename(args_file, "thresh" + str(abs_weight_threshold) + "_hist_", ".png", work_dir) from ..base.plots import numplots_to_rowscols rows, cols = numplots_to_rowscols(num_cluster * 9) print "plen", len(plots) print rows, cols from ..base.plots import write_row_col_fig write_row_col_fig(plots, rows=rows, cols=cols, filepath=fname + ".png", dpi=144, alpha=1.0, fontsize=12.5) print "file:", fname 
+ ".png", "written."
def prune_map( args_file, args_odir, sort_row_then_col=True, normalize_weights=False ): import numpy as N '''loading precomputed fits''' from util import depickle_fits res = depickle_fits(args_file, suffix='cfits') '''retrieving weightmatrix from fits dic''' W = res['map']['W'] dic_W = res['map'] patch_w = dic_W['vis'] print N.max(W), N.min(W) '''normalize weights''' if normalize_weights: W = W / N.max(N.abs(W)) print N.max(W), N.min(W) '''get fit and cluster data''' fits = res['fits'] cl = res['dic_cluster'] num_channel = cl['num_clusters'] from ..base.receptivefield import is_close_to_zero '''move RF ids in an appropiate data struc, ignore zeros''' num_zeros = W.shape[0] - len(fits.keys()) clusters_by_index = [[] for c in [None]*num_channel] keys = sorted(res['fits'].keys()) for i, key in enumerate(keys): rf = W[key] abs_rf = N.abs(rf) abs_max = N.max( abs_rf ) abs_min = N.min( abs_rf ) value_spectrum = abs_max - abs_min if value_spectrum > 0.2 and \ not is_close_to_zero(rf, verbose=False, atol=1e-02): cl = res['fits'][key]['cl'] clusters_by_index[cl].append(key) '''vectorize fit center coords clusterwise''' sorted_clusters_by_index = [[] for s in [None]*num_channel] for c in xrange(0,len(clusters_by_index)): cl_ids = clusters_by_index[c] dtype = [('coord', float), ('id', int)] values = [] for i in xrange(0, len(cl_ids)): p = fits[cl_ids[i]]['p'] p01 = (int(N.round(p[0])), int(N.round(p[1]))) if sort_row_then_col: '''vectorize col wise''' values.append( (p01[1]+p01[0]*patch_w, cl_ids[i]) ) else: '''vectorize row wise''' values.append( (p01[0]+p01[1]*patch_w, cl_ids[i]) ) vectorized_coords = N.array(values, dtype) for pair in N.sort(vectorized_coords): sorted_clusters_by_index[c].append( pair[1] ) not_zero = W.shape[0] - num_zeros '''new build pruned map''' if dic_W['mode'] == 'rgb': n_ch = 3 elif dic_W['mode'] == 'rg_vs_b': n_ch = 2 elif dic_W['mode'] == 'rg': n_ch = 2 else: n_ch = 1 pruned_patch_h = int(N.ceil(not_zero**.5))+1#2 pruned_W_shape = 
(pruned_patch_h**2, patch_w**2*n_ch) rec_W = N.zeros(pruned_W_shape) rec_W_channel_dim = [(0,0) for dim in [None]*num_channel] rec_W_index = 0 # rec_W_channel_coords = [] for c in xrange(0,len(sorted_clusters_by_index)): cl_ids = sorted_clusters_by_index[c] rec_W_channel_dim[c] = (rec_W_index, rec_W_index+len(cl_ids)-1) to_nextrow = rec_W_index % pruned_patch_h if to_nextrow != 0: rec_W_index += pruned_patch_h - to_nextrow '''coords: store fitted center coords for each id''' # rec_W_channel_coords.append([]) for i in xrange(0, len(cl_ids)): p = fits[cl_ids[i]]['p'] rec_W[rec_W_index] = W[cl_ids[i]] rec_W_index += 1 # pair = (int(N.round(p[0])), int(N.round(p[1]))) # rec_W_channel_coords[c].append(pair) '''write pruned map''' d = res['map'] W_rec_args = { 'hid': pruned_patch_h, 'vis': d['vis'], 'mode': d['mode'], 'k': d['k'], 'p': d['p'], 'lr': d['lr'], 'clip': d['clip'], 'version': d['version'], 'epochs_done': d['epochs_done'], 'chdim': rec_W_channel_dim, # 'chcoords': rec_W_channel_coords, } from pprint import pprint as pp print pp(W_rec_args) from ..base import rgc_filename_str str_mapfile = rgc_filename_str( mode=d['mode'], clip=d['clip'], k=d['k'], p=d['p'], lr=d['lr'], vis=d['vis'], hid=d['hid'], ) import os mod_odir = os.path.join(args_odir, '../') from ..base import make_working_dir_sub work_dir = make_working_dir_sub(mod_odir, 'pr') from ..base import make_filename mapfile = make_filename(args_file, 'pruned_and_sorted_'+str_mapfile, '.map', odir=work_dir) from ..base.weightmatrix import save_2 save_2(filepath=mapfile+'.map', W=rec_W, W_args=W_rec_args, version=2, verbose=True) imagefile = make_filename(args_file, 'pruned_and_sorted_'+str_mapfile, '.png', odir=work_dir) '''write reconstructed map''' from ..base.plots import save_w_as_image save_w_as_image(X=rec_W, in_w=dic_W['vis'], in_h=dic_W['vis'], out_w=pruned_patch_h, out_h=pruned_patch_h, outfile=imagefile+'.png', mode=dic_W['mode'], dpi=288)
def cluster_map(args): import numpy as N """load precomputed fits""" from util import depickle_fits res = depickle_fits(args.file) """retrieving weightmatrix from fits dic""" W = res["map"]["W"] patch_w = res["map"]["vis"] channel_w = res["map"]["vis"] ** 2 W_mode = res["map"]["mode"] W_clip = res["map"]["clip"] rf_total = W.shape[0] rf_non_zeros = len(res["fits"].keys()) rf_zeros = rf_total - rf_non_zeros rf_dead_perc = N.round(rf_zeros / float(rf_total), 2) print print rf_zeros, "of", rf_total, "RF are zeros." print "ratio dead/total:", rf_dead_perc print """generate cluster data""" cluster_data = __collect_cluster_data( W_mode, patch_w, channel_w, W, W_clip, depickled_fits=res, reconstr=args.rec, csp=args.csp, chrm=args.chr, err=args.err, nz=args.nz, surround=args.surr, ) """cluster obs""" idx, args.n = __apply_cluster_alg(cluster_data=cluster_data, alg=args.alg, prior_cluster_num=args.n, t=args.t) """assign cluster id to fits""" num_types = N.zeros(args.n) proto_color_arr = [[] for prot in [None] * args.n] keys = sorted(res["fits"].keys()) for i, key in enumerate(keys): fit = res["fits"][key] num_types[idx[i]] += 1 proto_color_arr[idx[i]].append(fit["color_rgb"]) prototype_color = [N.mean(prot, axis=0) for prot in proto_color_arr] from pprint import pprint as pp # print 'proto colors' # print pp(N.round(prototype_color,1)) """sort prototype colors ON / OFF super ugly""" val_proto_on, val_proto_off = [], [] for pr in prototype_color: fmax = N.max(pr) on = fmax > 0.7 # can fail on not enough converged maps if on: val_proto_on.append((pr[0], pr[1], pr[2], N.argmax(pr))) else: val_proto_off.append((pr[0], pr[1], pr[2], N.argmin(pr))) p_dtype = [("r", float), ("g", float), ("b", float), ("id", int)] proto_on = N.sort(N.array(val_proto_on, dtype=p_dtype), order=["id"]) proto_on_flat = [[pr[0], pr[1], pr[2]] for pr in proto_on] proto_off = N.sort(N.array(val_proto_off, dtype=p_dtype), order=["id"]) proto_off_flat = [[pr[0], pr[1], pr[2]] for pr in proto_off] if 
proto_on_flat == []: s_prototype_color = proto_off_flat elif proto_off_flat == []: s_prototype_color = proto_on_flat else: s_prototype_color = N.concatenate( [[[pr[0], pr[1], pr[2]] for pr in proto_on], [[pr[0], pr[1], pr[2]] for pr in proto_off]] ) sorted_ids = [] sort_ch = [None] * args.n for i, pr in enumerate(prototype_color): sorted_ids.append(N.where(s_prototype_color == pr)[0][0]) for i, sid in enumerate(sorted_ids): sort_ch[sid] = i s_num_types = [num_types[i] for i in sort_ch] s_idx = N.copy(idx) for i in xrange(0, len(sort_ch)): s_idx[idx == sort_ch[i]] = i s_num_types[i] = num_types[sort_ch[i]] """fold opposing channels""" if args.fold: assert args.n % 2 == 0, "number of clusters need to be even in order to fold opposing clusters." """leaving the prototype colors intact - just reducing the s_idx entries by half. so magnitude |prototype_color| = 6 is not touched but entries in s_idx originally ranging from 0 to 5 will be reduced to 0 to 2. so all entries in prototype_color are invalid in terms of real clusters contents. ... 
ugly""" fold_n = args.n / 2 fold_num_types = [0 for i in xrange(0, args.n)] fold_idx = N.copy(s_idx) from math import ceil for i in xrange(0, fold_n): fold_i = int(ceil(fold_n + i)) # fold_idx[fold_idx==i] = i fold_idx[fold_idx == fold_i] = i fold_num_types[i] = s_num_types[i] + s_num_types[fold_i] s_idx = N.copy(fold_idx) s_num_types = N.copy(fold_num_types) args.n = fold_n print "sorted proto colors" print pp(N.round(s_prototype_color, 1)) for i in xrange(0, args.n): print i, "\t", int(s_num_types[i]), "\t" print sort_ch print s_idx """write cluster membership into fit data""" keys = sorted(res["fits"].keys()) for i, key in enumerate(keys): res["fits"][key]["cl"] = s_idx[i] from ..base import make_working_dir_sub cluster_dir = make_working_dir_sub(args.odir, "cl") """write clustered data""" fname = __write_clustered_fits( args, num_types_list=N.copy(s_num_types), prototype_color=N.copy(s_prototype_color), idx_list=s_idx, depickled_fits=res, odir=cluster_dir, num_dead_rf=rf_zeros, per_dead_rf=rf_dead_perc, ) if args.plot: from ..base import make_filename filepath = make_filename(args.file, "image__" + fname, ".png", odir=cluster_dir) __plot_clusters_coverage( filepath=filepath, res=res, args_plain=False, args_scale=1.2, args_a=1, args_pad=0.01, args_dpi=288 ) if args.pr: import os from prune import prune_map prune_map(args_file=os.path.join(cluster_dir, fname + ".cfits"), args_odir=cluster_dir)
def plot_clusters_colorspace(args): """loading precomputed fits""" from util import depickle_fits res = depickle_fits(args.file, suffix="cfits") args_n = res["dic_cluster"]["num_clusters"] idx = res["dic_cluster"]["cluster_index_list"] prototype_color = res["dic_cluster"]["prototype_cluster"] cluster_data = [] keys = sorted(res["fits"].keys()) for i, key in enumerate(keys): fit = res["fits"][key] cluster_data.append(fit["color_rgb"]) """normalize cluster data""" cluster_data -= N.min(cluster_data) cluster_data /= N.max(cluster_data) """plot""" import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D Axes3D(plt.figure()) import matplotlib.gridspec as gridspec if args.single: gs = gridspec.GridSpec(1, 1) plt.rcParams["axes.labelsize"] = 18 else: gs = gridspec.GridSpec(3, 1) plt.rcParams["axes.linewidth"] = 0.5 plt.rcParams["axes.labelsize"] = 5 plt.rcParams["lines.linewidth"] = 0.1 fig = plt.figure() markers = args_n * ["."] markers[0] = "o" markers[1] = "h" markers[2] = "p" markers[3] = "^" markers[4] = "v" if args_n > 5: markers[5] = "d" def make_plot(fig, pos, proj="3d", fontsize=9): if proj == "3d": ax = fig.add_subplot(pos, aspect="equal", projection=proj, xmargin=0) ax.set_xlim([0.05, 0.95]) ax.set_ylim([0.05, 0.95]) ax.set_zlim([0.05, 0.95]) ax.set_xlabel("red") ax.set_ylabel("green") ax.set_zlabel("blue") ax.grid(True) # ax.xaxis.set_rotate_label(False) # ax.yaxis.set_rotate_label(False) # ax.zaxis.set_rotate_label(False) ax.xaxis._axinfo["tick"]["inward_factor"] = 0 ax.xaxis._axinfo["tick"]["outward_factor"] = 0.4 ax.yaxis._axinfo["tick"]["inward_factor"] = 0 ax.yaxis._axinfo["tick"]["outward_factor"] = 0.4 ax.zaxis._axinfo["tick"]["inward_factor"] = 0 ax.zaxis._axinfo["tick"]["outward_factor"] = 0.4 ax.zaxis._axinfo["tick"]["outward_factor"] = 0.4 [t.set_va("center") for t in ax.get_yticklabels()] [t.set_ha("left") for t in ax.get_yticklabels()] [t.set_va("center") for t in ax.get_xticklabels()] [t.set_ha("right") for t in 
ax.get_xticklabels()] [t.set_va("center") for t in ax.get_zticklabels()] [t.set_ha("left") for t in ax.get_zticklabels()] ax.xaxis.pane.fill = False ax.yaxis.pane.fill = False ax.zaxis.pane.fill = False ax.xaxis._axinfo["label"]["space_factor"] = 2.8 ax.yaxis._axinfo["label"]["space_factor"] = 2.8 ax.zaxis._axinfo["label"]["space_factor"] = 3.5 ax.xaxis._axinfo["axisline"]["line_width"] = 3.0 if args.single: pass else: ax.tick_params(axis="x", labelsize=3.5) ax.tick_params(axis="y", labelsize=3.5) ax.tick_params(axis="z", labelsize=3.5) ax.set_xticks([0.2, 0.5, 0.8]) ax.set_yticks([0.2, 0.5, 0.8]) ax.set_zticks([0.2, 0.5, 0.8]) else: ax = fig.add_subplot(pos, projection=proj) ax.set_xlim([0.0, 1]) ax.set_ylim([0.0, 1]) ax.xaxis.labelpad = 2 ax.yaxis.labelpad = 2 ax.tick_params(axis="x", labelsize=4, pad=2, length=3, width=0.05) ax.tick_params(axis="y", labelsize=4, pad=2, length=3, width=0.05) ax.set_xticks([0, 0.2, 0.5, 0.8, 1]) ax.set_yticks([0, 0.2, 0.5, 0.8, 1]) return ax def make_3d_plot(fig, grid_spec): ax = make_plot(fig, grid_spec, proj="3d") if args.single: s = 84 lw = 0.2 else: s = 8 lw = 0.1 # ax.view_init(15, -60) for k in xrange(0, args_n): ax.scatter( cluster_data[idx == k, 0], cluster_data[idx == k, 1], cluster_data[idx == k, 2], c=prototype_color[k], marker=".", linewidth=lw, s=s, alpha=1.0, ) def make_2d_plot(fig, grid_spec, axis_id, labels=("red", "green"), fontsize=9): ax = make_plot(fig, grid_spec, proj=None) if args.plain: ax.get_xaxis().set_ticks([]) ax.get_yaxis().set_ticks([]) ax.set_title("") ax.axis("off") else: ax.set_xlabel(labels[0]) ax.set_ylabel(labels[1]) ax.set_aspect(1) for k in xrange(0, args_n): ax.scatter( cluster_data[idx == k, axis_id[0]], cluster_data[idx == k, axis_id[1]], c=prototype_color[k], marker=markers[k], linewidth=0.01, s=6.0, alpha=1.0, ) ax.grid = True return ax if args.single: make_3d_plot(fig, gs[0]) else: make_2d_plot(fig, gs[0, 0], (2, 0), ("blue", "red"), fontsize=12) make_2d_plot(fig, gs[1, 0], (0, 1), 
("red", "green"), fontsize=12) make_2d_plot(fig, gs[2, 0], (1, 2), ("green", "blue"), fontsize=12) if not args.odirnosub: import os mod_odir = os.path.join(args.odir, "../") from ..base import make_working_dir_sub cluster_dir = make_working_dir_sub(mod_odir, "cl") else: cluster_dir = args.odir if args.plain: plain_str = "plain_" else: plain_str = "" from ..base import make_filename filepath = make_filename(args.file, "image_clustered_rgc_colorspace_" + plain_str, ".png", odir=cluster_dir) if args.single: plt.savefig(filepath, dpi=args.dpi, pad_inches=args.pad) else: plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=0.5) plt.savefig(filepath, dpi=args.dpi, pad_inches=args.pad, bbox_inches="tight") plt.close(fig) print "plotted cluster data colorspace, file", filepath, "written."
def make_proto_filter( args_file, args_odir, args_model, synthetic_pos=True, mean_rec=True, max_dist_to_center=2, meanlen=None, min_err_metric=True, abs_weight_threshold=0.7, indices=None, debug=False, odir_nosubdir=False, ): hand_automatic = indices == None '''loading precomputed fits''' from util import depickle_fits res = depickle_fits(args_file, suffix='cfits') '''retrieving weightmatrix from fits dic''' W = res['map']['W'] vis = res['map']['vis'] map_as_dict = res['map'] assert map_as_dict['mode'] == 'rgb' '''get fit and cluster data''' fits = res['fits'] fit_keys = fits.keys() model = res['model'] cl = res['dic_cluster'] num_channel = cl['num_clusters'] '''move RF ids in an appropiate data struc, ignore zeros''' clusters_by_index = [[] for c in [None]*num_channel] keys = sorted(fit_keys) for i, key in enumerate(keys): cl = fits[key]['cl'] clusters_by_index[cl].append(key) from ..base.receptivefield import convert_rfvector_to_rgbmatrix from ..base.plots import colormap_for_mode cmap = colormap_for_mode(map_as_dict['mode']) from ..base import make_filename from cluster import __transpose_zero_to_one if mean_rec: '''of every channel mean the fitted parameters''' import numpy as N num_param = len(fits[fit_keys[-1]]['p']) clusters_p_mean = [N.zeros(num_param) for c in [None]*num_channel] for c in xrange(0,len(clusters_by_index)): cl_ids = clusters_by_index[c] if meanlen is None: len_cl_id = len(cl_ids) else: len_cl_id = meanlen p_mean_tmp = [] for i in xrange(0, len_cl_id): p = fits[cl_ids[i]]['p'] p_mean_tmp.append(p) p_mean = N.mean(p_mean_tmp, axis=0) clusters_p_mean[c] = N.copy(p_mean) if model=='dog': from dog import reconstruct else: from edog import reconstruct else: import numpy as N hand_by_index = [[] for c in [None]*num_channel] zpx = vis/2. zpy = vis/2. max_absmax = 0 if not hand_automatic: '''use provided indices ... 
handchosen''' for i, index in enumerate(indices): if i == num_channel: break hand_by_index[i] = index else: '''chose automatically: for each channel find the RF with smallest error (and max value > .7), in distance close to the center vis field.''' for c in xrange(0,len(clusters_by_index)): cl_ids = clusters_by_index[c] len_cl_id = len(cl_ids) min_err, min_err_id, min_dist, min_id = N.inf, -1, N.inf, -1 for i in xrange(0, len_cl_id): p = fits[cl_ids[i]]['real_pix_center_coords'] n = fits[cl_ids[i]]['n'] err = fits[cl_ids[i]]['err'] dist = N.sqrt((zpy - p[0])**2 + (zpx - p[1])**2) absmax = N.max(N.abs(W[n])) if err < min_err and absmax > abs_weight_threshold and dist < max_dist_to_center: min_err = err min_err_id = n if dist < min_dist and absmax > abs_weight_threshold: min_dist = dist min_id = n '''statistic''' if absmax > max_absmax: max_absmax = absmax if min_err_metric: hand_by_index[c] = min_err_id else: hand_by_index[c] = min_id '''store distance of fitted center point to center of visual field''' trans_rf = [] for n in hand_by_index: if type(n) == list: print 'Not enough RF indices for all channels given.\nnum channels:', num_channel, 'num indices:', len(indices), '\n' exit() if n == -1: print 'No prototype RF found. 
\nTry lowering wheight treshold\nparameter -thr is:', abs_weight_threshold, 'RF max abs weight:', absmax, '\n' exit() p = fits[n]['real_pix_center_coords'] dist_y = int(N.floor(zpy - p[0])) dist_x = int(N.floor(zpx - p[1])) trans_rf.append( (dist_y, dist_x) ) filters = [] plots = [] for c in xrange(0,num_channel): if mean_rec: p = clusters_p_mean[c] if synthetic_pos: p = N.concatenate([[map_as_dict['vis']/2., map_as_dict['vis']/2.], p[2:]]) proto_filter = reconstruct(p, map_as_dict['mode'], map_as_dict['vis']**2, map_as_dict['vis'], map_as_dict['W'][-1].shape) else: proto_filter = W[hand_by_index[c]] '''convert RF vector to matrix and normalize values''' proto_filter_matr = convert_rfvector_to_rgbmatrix( proto_filter, map_as_dict['vis'], map_as_dict['vis'], map_as_dict['mode']) proto_filter_matr = __transpose_zero_to_one(proto_filter_matr) '''move RF to visual fields center''' if not mean_rec: proto_filter_matr = N.roll(proto_filter_matr, trans_rf[c][0], axis=0) proto_filter_matr = N.roll(proto_filter_matr, trans_rf[c][1], axis=1) plots.append({ # 'name':('RF: '+str(hand_by_index[c])+' ' if not hand_automatic else '') + 'ch: '+str(c), 'name':'RF: '+str(hand_by_index[c])+' ' + ' ch: '+str(c), 'value':proto_filter_matr, 'maxmin':True, 'cmap':cmap, 'balance':False, 'invert':False, }) filters.append(N.copy(proto_filter_matr)) if not mean_rec: if hand_automatic: misc_str = 'auto_' else: misc_str = 'hand_' if min_err_metric: metr_str = 'err_' else: metr_str = 'dist_' else: misc_str = 'mean_' metr_str = '' if not odir_nosubdir: import os mod_odir = os.path.join(args_odir, '../') from ..base import make_working_dir_sub work_dir = make_working_dir_sub(mod_odir, 'proto') else: work_dir = args_odir fname = make_filename(args_file,misc_str+metr_str+'conv_proto','.png', work_dir) def numplots_to_rowscols(num): sq = int(num**.5)+1 return sq, sq from ..base.plots import write_row_col_fig row, col = numplots_to_rowscols(num_channel) write_row_col_fig(plots, row, col, fname+'.png', 
dpi=144, alpha=1.0, fontsize=6) dic = { 'mode': res['map']['mode'], 'num_chn': num_channel, 'filters': filters } writefilename = make_filename(args_file,misc_str+metr_str+'conv_proto','.kern', work_dir) from util import pickle_fits pickle_fits(writefilename+'.kern', dic) if debug: from ..base.plots import write_rf_fit_debug_fig for i, fmat in enumerate(filters): fmat = N.swapaxes(fmat, 0, 1) fvec = fmat.reshape(fmat.shape[0]*fmat.shape[1], fmat.shape[2]) fvecflat = N.copy(N.concatenate([fvec.T[0].T, fvec.T[1].T, fvec.T[2].T])) fvecflat -= .5 fvecflat *= 2. p = fits[hand_by_index[i]]['p'] p[0] = p[0] + trans_rf[i][0] p[1] = p[1] + trans_rf[i][1] if res['map']['mode'] == 'dog': from dog import reconstruct else: from edog import reconstruct rec = reconstruct(p, res['map']['mode'], vis**2, vis, fvecflat.shape) err = (fvecflat - rec)**2 # err = None # fname = make_filename(args_file,str(i)+'_debug','.png', work_dir) fname = work_dir + '/' + str(i)+'_debug'+'.png' write_rf_fit_debug_fig(fname, fvecflat, vis, 'rgb', p, rec, err, model, scale=2.8, s_scale=3.8, alpha=.5, dpi=300, draw_ellipsoid=True, no_title=True, ellipsoid_line_width=1.2)
def convole_image_with_filter( args_file, input_image, odir = None, lum_channels = False, odir_lum_channels = False, odir_nosub = False, ): from util import depickle_fits res = depickle_fits(args_file, suffix='kern') kernels = res['filters'] import os if odir == None: odir = os.path.dirname(args_file) if not odir_nosub: mod_odir = os.path.join(odir, '../') from ..base import make_working_dir_sub work_dir = make_working_dir_sub(mod_odir, 'conv') else: work_dir = odir if input_image == None: from scipy import misc img = misc.lena()*-1 else: from ..base.images import read_image img = read_image(input_image, verbose=True) import numpy as N from scipy import signal def filter_rgb(image, kernel): r = signal.fftconvolve(image.T[0].T, kernel.T[0].T, mode='same')*.5 # full, same, valid r += signal.fftconvolve(image.T[0].T, kernel.T[0], mode='same')*.5 g = signal.fftconvolve(image.T[1].T, kernel.T[1].T, mode='same')*.5 g += signal.fftconvolve(image.T[1].T, kernel.T[1], mode='same')*.5 b = signal.fftconvolve(image.T[2].T, kernel.T[2].T, mode='same')*.5 b += signal.fftconvolve(image.T[2].T, kernel.T[2], mode='same')*.5 # r = signal.convolve2d(image.T[0].T, kernel.T[0].T, boundary='symm', mode='same') # g = signal.convolve2d(image.T[1].T, kernel.T[1].T, boundary='symm', mode='same') # b = signal.convolve2d(image.T[2].T, kernel.T[2].T, boundary='symm', mode='same') return r, g, b def write_image(arr, name, suffix='.png', mode=None): from ..base.plots import colormap_for_mode if mode==None: cmap = colormap_for_mode(res['mode']) else: cmap = colormap_for_mode(mode) from ..base import make_filename imagefile = make_filename(args_file, name, suffix, work_dir) import matplotlib.image as pltimg pltimg.imsave(imagefile, arr, dpi=144, cmap=cmap) def wr_rgb(a,b,c,name): from ..base.plots import normalize_color print name, ' \t\t', N.max(N.dstack([a, b, c]))-N.min(N.dstack([a, b, c])) write_image(normalize_color(N.dstack([a, b, c])), name) def wr_lum(a,name,invert=False): from 
..base.plots import normalize_color if invert: write_image(normalize_color(a)*-1, name, mode='lum') else: write_image(normalize_color(a), name, mode='lum') if len(img.shape) > 2: assert len(kernels) >= 5 ch0_r, ch0_g, ch0_b = filter_rgb(img, kernels[0]-N.mean(kernels[0], axis=0)) ch1_r, ch1_g, ch1_b = filter_rgb(img, kernels[1]-N.mean(kernels[1], axis=0)) ch2_r, ch2_g, ch2_b = filter_rgb(img, kernels[2]-N.mean(kernels[2], axis=0)) ch3_r, ch3_g, ch3_b = filter_rgb(img, kernels[3]-N.mean(kernels[3], axis=0)) ch4_r, ch4_g, ch4_b = filter_rgb(img, kernels[4]-N.mean(kernels[4], axis=0)) if len(kernels) == 5: wr_rgb(ch1_r, ch1_g, ch1_b, '0_red') wr_rgb(ch0_r, ch0_g, ch0_b, '1_green') wr_rgb(ch2_r, ch2_g, ch2_b, '2_blue') wr_rgb(ch3_r, ch3_g, ch3_b, '3_white') wr_rgb(ch4_r, ch4_g, ch4_b, '4_black') wr_rgb(ch1_r-ch0_r, ch1_g-ch0_g, ch1_b-ch0_b, '5_red_green') wr_rgb(ch0_r-ch1_r, ch0_g-ch1_g, ch0_b-ch1_b, '6_green_red') wr_rgb(ch2_r-(ch0_r+ch1_r)/2., ch2_g-(ch0_g+ch1_g)/2., ch2_b-(ch0_b+ch1_b)/2., '7_blue_yellow') if len(kernels) == 6: ch5_r, ch5_g, ch5_b = filter_rgb(img, kernels[5]-N.mean(kernels[5], axis=0)) if not odir_lum_channels: if not lum_channels: wr_rgb(ch0_r, ch0_g, ch0_b, '0_red_on') wr_rgb(ch1_r, ch1_g, ch1_b, '1_green_on') wr_rgb(ch2_r, ch2_g, ch2_b, '2_blue_on') wr_rgb(ch4_r, ch4_g, ch4_b, '3_red_off') wr_rgb(ch3_r, ch3_g, ch3_b, '4_green_off') wr_rgb(ch5_r, ch5_g, ch5_b, '5_blue_off') wr_rgb(ch0_r+ch4_r, ch0_g+ch4_g, ch0_b+ch4_b, '6_red_on_green_off') wr_rgb(ch3_r-ch1_r, ch3_g-ch1_g, ch3_b-ch1_b, '6_green_on_red_off') else: wr_rgb(ch0_r, ch0_g, ch0_b, '0_luminosity_ON') wr_rgb(ch1_r, ch1_g, ch1_b, '1_luminosity_OFF') wr_rgb(ch2_r, ch2_g, ch2_b, '2_red_on') wr_rgb(ch4_r, ch4_g, ch4_b, '3_green_on') wr_rgb(ch3_r, ch3_g, ch3_b, '4_red_off') wr_rgb(ch5_r, ch5_g, ch5_b, '5_green_off') wr_rgb(ch2_r+ch5_r, ch2_g+ch5_g, ch2_b+ch5_b, '6_red_on_green_off') wr_rgb(ch4_r-ch3_r, ch4_g-ch3_g, ch4_b-ch3_b, '6_green_on_red_off') else: if not lum_channels: 
write_image((ch0_r+ch0_g+ch0_b)*-1, 'lum_0_red_on', mode='lum') write_image((ch1_r+ch1_g+ch1_b)*-1, 'lum_1_green_on', mode='lum') write_image((ch2_r+ch2_g+ch2_b)*-1, 'lum_2_blue_on', mode='lum') write_image((ch4_r+ch4_g+ch4_b)*-1, 'lum_3_red_off', mode='lum') write_image((ch3_r+ch3_g+ch3_b)*-1, 'lum_4_green_off', mode='lum') write_image((ch5_r+ch5_g+ch5_b)*-1, 'lum_5_blue_off', mode='lum') write_image((ch0_r+ch4_r+ch0_g+ch4_g+ch0_b+ch4_b)*-1, 'lum_6_red_on_green_off', mode='lum') write_image((ch3_r+ch1_r+ch3_g+ch1_g+ch3_b+ch1_b)*-1, 'lum_6_green_on_red_off', mode='lum') else: write_image((ch0_r+ch0_g+ch0_b)*-1, 'lum_0_luminosity_ON', mode='lum') write_image((ch1_r+ch1_g+ch1_b)*-1, 'lum_1_luminosity_OFF', mode='lum') write_image((ch2_r+ch2_g+ch2_b)*-1, 'lum_2_red_on', mode='lum') write_image((ch4_r+ch4_g+ch4_b)*-1, 'lum_3_green_on', mode='lum') write_image((ch3_r+ch3_g+ch3_b)*-1, 'lum_4_red_off', mode='lum') write_image((ch5_r+ch5_g+ch5_b)*-1, 'lum_5_green_off', mode='lum') write_image((ch2_r+ch5_r+ch2_g+ch5_g+ch2_b+ch5_b)*-1, 'lum_6_red_on_green_off', mode='lum') write_image((ch4_r+ch3_r+ch4_g+ch3_g+ch4_b+ch3_b)*-1, 'lum_6_green_on_red_off', mode='lum') else: on_filterd = signal.convolve2d(img, kernels[0], boundary='symm', mode='same') off_filterd = signal.convolve2d(img, kernels[1], boundary='symm', mode='same') img_filterd = on_filterd - off_filterd write_image(img_filterd*-1, 'lum_convole') write_image(on_filterd, '0_lum_on') write_image(off_filterd*-1, '1_lum_off') # import os # working_dir = os.path.split(args_file)[0] # original = os.path.join(working_dir, 'org.png') # if not os.path.exists(original): write_image(img, 'org')