Example #1
0
def print_cluster_data(args):

    """loading precomputed fits"""
    from util import depickle_fits

    res = depickle_fits(args.file, suffix="cfits")

    """retrieving weightmatrix from fits dic"""
    W = res["map"]["W"]
    dic_W = res["map"]

    """get fit and cluster data"""
    cl = res["dic_cluster"]
    num_dead = cl["num_dead_rf"]
    alive = W.shape[0] - num_dead

    import pprint as pp

    print

    print "*****+ Trained map data"
    print "visible units", W.shape[1]
    print "hidden units", W.shape[0]
    print "alive hidden units", alive, "(", N.round(alive / float(W.shape[0]), 2), "% )"
    print "*****+ Map Parameter"
    print pp.pprint([(key, dic_W[key]) if key != "W" else () for key in dic_W.keys()])

    print "*****+ Cluster data"
    print "num cluster", cl["num_clusters"]
    print "num rf in each cluster", cl["num_each_cluster"]
    print "prototype weights each cluster", N.round(cl["prototype_cluster"], 2)
    print "*****+ Cluster Parameters"
    print pp.pprint(cl["args"])
Example #2
0
def rec_map_from_pickled_result(
	map_file,
	odir
	):
	from util import depickle_fits
	dic = depickle_fits(map_file)

	W_rec = __rec_map(dic['map'], dic['fits'], dic['model'])

	from ..base import make_working_dir_sub
	work_dir = make_working_dir_sub(odir, 'rec')

	from ..base import make_filename
	writefilename = make_filename(map_file,
		'rec_'+dic['model']+'_rgc',
		'.png', odir=work_dir)

	from ..base.plots import save_w_as_image
	save_w_as_image(W_rec, dic['map']['vis'], dic['map']['vis'], dic['map']['hid'], dic['map']['hid'], 
		mode=dic['map']['mode'], 
		outfile=writefilename, 
		verbose=False)

	print 
	print 'reconstruction ', writefilename+'.png', ' written.'
Example #3
0
def plot_cluster_cov(args):
    """Plot the coverage of the clustered receptive fields as a PNG.

    Loads the clustered fits ``args.file`` (suffix ``cfits``). It then
    delegates the drawing to ``__plot_clusters_coverage``, writing into the
    ``cl`` sub-directory of ``args.odir/../``. The file name gets a
    ``plain_`` tag when ``args.plain`` is set.

    :param args: namespace providing file, odir, plain, scale, a, pad, dpi.
    """
    import os

    from util import depickle_fits
    from ..base import make_working_dir_sub
    from ..base import make_filename

    res = depickle_fits(args.file, suffix="cfits")

    parent_dir = os.path.join(args.odir, "../")
    cluster_dir = make_working_dir_sub(parent_dir, "cl")

    tag = "plain_" if args.plain else ""
    filepath = make_filename(args.file, "image_clustered_rgc_" + tag, ".png", odir=cluster_dir)

    plot_options = dict(
        args_plain=args.plain,
        args_scale=args.scale,
        args_a=args.a,
        args_pad=args.pad,
        args_dpi=args.dpi,
    )
    __plot_clusters_coverage(filepath=filepath, res=res, **plot_options)
Example #4
0
def print_pickled_fits(
	map_file,
	complete=False,
	):
	from util import depickle_fits
	dic = depickle_fits(map_file, exit_on_errer=False)
	if not dic:
		dic = depickle_fits(map_file, suffix='cfits')

	from pprint import pprint
	if not complete:
		keys = dic.keys()
		for key in keys:
			print key
			if type(dic[key]) == dict:
				if key=='map' or key=='dic_cluster':
					for k in dic[key].keys():
						print '\t', k, ':', dic[key][k]
				else:
					print '\t', dic[key].keys()
			else:
				print '\t', dic[key]
	else:	
		print pprint(dic)
Example #5
0
def make_size_hist(
    args_file,
    args_odir,
    odir_nosubdir = False,
    luminosity_channels = False,
    ):
    """Draw the mean receptive-field area of every cluster as a bar chart.

    Every fitted RF (the model must be ``'edog'``) contributes an area of
    ``mean(c_area, c_area * c_to_s_ratio)``, with
    ``c_area = pi * sigma_x * sigma_y``.

    Bars are ordered by cluster population. Each bar's height is the
    cluster's mean area, and its width is the cluster's share of all fits
    in percent, so the bars tile the x axis from 0 to ~100.

    Previously the function ended in ``exit()``, which terminated the calling
    process, and it always wrote to './'. It now returns normally and writes
    to the directory given by ``args_odir`` / ``odir_nosubdir``.

    :param args_file: path of the clustered fits (``cfits``) pickle.
    :param args_odir: output directory. The image goes to the ``hist``
        sub-directory of ``args_odir/../``, or directly into ``args_odir``
        when ``odir_nosubdir`` is True.
    :param odir_nosubdir: see ``args_odir``.
    :param luminosity_channels: unused; kept for backward compatibility.
    """
    import os

    import numpy as N

    # NOTE: ``plt`` is expected to be matplotlib.pyplot imported at module level
    plt.rcParams['xtick.major.pad'] = '12'
    plt.rcParams['ytick.major.pad'] = '12'
    plt.rc('axes', linewidth = 0)

    # loading precomputed fits
    from util import depickle_fits
    res = depickle_fits(args_file, suffix='cfits')

    # fit and cluster data
    fits = res['fits']
    cl = res['dic_cluster']
    num_channel = cl['num_clusters']
    colors = cl['prototype_cluster']
    num_each_cluster = cl['num_each_cluster']

    # share of all fits per cluster in percent; dtype=float avoids Python 2
    # integer division if the counts are stored as ints
    perc_cluster = N.round(N.asarray(num_each_cluster, dtype=float) / len(fits) * 100)

    model = res['model']
    assert model=='edog'

    areas_by_cluster = [[] for _ in range(num_channel)]
    for key in sorted(fits.keys()):
        # params: 0: mu_x  1: mu_y  2: sigma_x  3: sigma_y  4: c_to_s_ratio
        #         5: theta (rotation) 6: k_s (k_c is implicitly fixed as 1)
        #         7: cdir_a, 8: cdir_b,  9: cdir_c
        #         10: bias_a, 11: bias_b, 12: bias_c
        p = fits[key]['p']
        c_area = N.pi * p[2] * p[3]
        s_area = c_area * p[4]
        areas_by_cluster[fits[key]['cl']].append(N.mean([c_area, s_area]))

    # (cluster id, population, mean area, percent), sorted by population
    data = [(i, num_each_cluster[i], N.mean(areas), int(perc_cluster[i]))
            for i, areas in enumerate(areas_by_cluster)]
    data.sort(key=lambda rec: rec[1])

    fig = plt.figure()
    ax = fig.add_subplot(111)

    left = 0
    for cl_id, _, mean_area, perc in data:
        # align='edge' so bars start at ``left`` and tile the axis; newer
        # matplotlib defaults to 'center', which would shift every bar
        ax.bar(left, mean_area, perc, color=colors[cl_id], alpha=.9, linewidth=0, align='edge')
        left += perc
    ax.get_yaxis().set_ticks(N.arange(0, 9, 2))
    ax.get_xaxis().set_ticks(N.arange(0, 101, 25))

    if odir_nosubdir:
        work_dir = args_odir
    else:
        from ..base import make_working_dir_sub
        work_dir = make_working_dir_sub(os.path.join(args_odir, '../'), 'hist')

    from ..base import make_filename
    fname = make_filename(args_file, 'hist_RF_sizes', '.png', work_dir)

    plt.savefig(fname, bbox_inches='tight', dpi=144)
    plt.close(fig)
Example #6
0
def make_spatal_hist(args_file, args_odir, abs_weight_threshold=0.7):

    """loading precomputed fits"""
    from util import depickle_fits

    res = depickle_fits(args_file, suffix="cfits")

    """retrieving weightmatrix from fits dic"""
    map_as_dict = res["map"]
    W = map_as_dict["W"]
    vis = map_as_dict["vis"]
    ch_width = vis ** 2

    mode = map_as_dict["mode"]
    assert mode == "rgb"

    """get fit and cluster data"""
    fits = res["fits"]
    fit_keys = fits.keys()
    cl = res["dic_cluster"]
    num_cluster = cl["num_clusters"]

    spat_hist = []

    def make_record(loc, key, cl, ch, value):
        if abs(value) > abs_weight_threshold:
            return (loc, {"key": key, "cl": cl, "ch": ch, "value": value})
        else:
            return None

    keys = sorted(fit_keys)
    for i, key in enumerate(keys):
        cl = fits[key]["cl"]
        rf = W[key]
        for j in xrange(0, ch_width):
            sploc = (j / vis, j % vis)

            vr = make_record(loc=sploc, key=key, cl=cl, ch=0, value=rf[j])
            if vr:
                spat_hist.append(vr)
            vg = make_record(loc=sploc, key=key, cl=cl, ch=1, value=rf[ch_width + j])
            if vg:
                spat_hist.append(vg)
            vb = make_record(loc=sploc, key=key, cl=cl, ch=2, value=rf[2 * ch_width + j])
            if vb:
                spat_hist.append(vb)

            # import pprint as pp
            # # print pp.pprint(spat_hist)

            # # print pp.pprint( __filter_ch(spat_hist, 0) )
            # # print pp.pprint( __filter_value(spat_hist, 0.2) )
            # print pp.pprint(__filter_cl(
            # 					__filter_ch(
            # 						__filter_loc_fromto(spat_hist, (1,1), (3,3)),
            # 					1),
            # 				0) )

    def hist_2d(ch, cl, onoff=None):
        """make a 2d histogramm"""
        import numpy as N

        hist_2d = N.zeros((vis, vis), dtype=float)

        def sign_ok(on, val):
            if on == None:
                return True
            else:
                if val > 0 and on:
                    return True
                elif val < 0 and not on:
                    return True
                else:
                    return False

        for i in xrange(0, len(spat_hist)):
            loc = spat_hist[i][0]
            if spat_hist[i][1]["ch"] == ch and spat_hist[i][1]["cl"] == cl and sign_ok(onoff, spat_hist[i][1]["value"]):
                # hist_2d[loc[0],loc[1]] += abs(spat_hist[i][1]['value']) #1
                hist_2d[loc[0], loc[1]] += 1
        return hist_2d

    import os

    mod_odir = os.path.join(args_odir, "../")
    from ..base import make_working_dir_sub

    work_dir = make_working_dir_sub(mod_odir, "hist")
    from ..base import make_filename

    def make_plots_of_channel(ch, ison=None, ch_name=None):
        plots = []
        str_name = ch_name + " " if ch_name != None else ""
        for cl in xrange(0, num_cluster):
            plots += [
                {
                    "name": str_name + " cluster: " + str(cl),
                    "value": hist_2d(ch, cl, ison),
                    "maxmin": False,
                    "cmap": "Greys",
                    "balance": True,
                    "invert": True,
                    "patch_width": vis,
                    "interp": "nearest",
                }
            ]
            #'catrom'}]
        return plots

    def hist_input(ch, ison=None, ch_name=None):
        plots = make_plots_of_channel(ch, ison, ch_name)

        if ison == None:
            miscstr = ""
        elif ison:
            miscstr = "_ON_"
        else:
            miscstr = "_OFF_"
        fname = make_filename(
            args_file, "thresh" + str(abs_weight_threshold) + "_hist_inputch_" + str(ch) + miscstr, ".png", work_dir
        )

        from ..base.plots import write_row_col_fig

        write_row_col_fig(plots, rows=2, cols=3, filepath=fname + ".png", dpi=144, alpha=1.0, fontsize=5.5)
        print "file:", fname + ".png", "written."

        # hist_input(2, ch_name='Blue')
        # hist_input(1, ch_name='Green')
        # hist_input(0, ch_name='Red')

    hist_input(2, ison=True, ch_name="input Blue ON")
    hist_input(2, ison=False, ch_name="input Blue OFF")
    hist_input(1, ison=True, ch_name="input Green ON")
    hist_input(1, ison=False, ch_name="input Green OFF")
    hist_input(0, ison=True, ch_name="input Red ON")
    hist_input(0, ison=False, ch_name="input Red OFF")

    def hist_complete(onoff=None):
        plots = []

        for inch in xrange(0, 3):
            plots += make_plots_of_channel(inch)
            if onoff != None:
                plots += make_plots_of_channel(inch, onoff)
                plots += make_plots_of_channel(inch, not onoff)

        fname = make_filename(args_file, "thresh" + str(abs_weight_threshold) + "_hist_", ".png", work_dir)

        from ..base.plots import numplots_to_rowscols

        rows, cols = numplots_to_rowscols(num_cluster * 9)

        print "plen", len(plots)
        print rows, cols

        from ..base.plots import write_row_col_fig

        write_row_col_fig(plots, rows=rows, cols=cols, filepath=fname + ".png", dpi=144, alpha=1.0, fontsize=12.5)
        print "file:", fname + ".png", "written."
Example #7
0
def prune_map(
	args_file,
	args_odir,
	sort_row_then_col=True,
	normalize_weights=False
	):
	import numpy as N


	'''loading precomputed fits'''
	from util import depickle_fits
	res = depickle_fits(args_file, suffix='cfits')


	'''retrieving weightmatrix from fits dic'''
	W = res['map']['W']
	dic_W = res['map']
	patch_w = dic_W['vis']

	print N.max(W), N.min(W)

	'''normalize weights'''
	if normalize_weights:
		W = W / N.max(N.abs(W))
	
	print N.max(W), N.min(W)

	'''get fit and cluster data'''
	fits = res['fits']
	cl = res['dic_cluster']


	num_channel = cl['num_clusters']


	from ..base.receptivefield import is_close_to_zero
	'''move RF ids in an appropiate data struc, ignore zeros'''
	num_zeros = W.shape[0] - len(fits.keys())
	clusters_by_index = [[] for c in [None]*num_channel]	
	keys = sorted(res['fits'].keys())
	for i, key in enumerate(keys):
		rf = W[key]
		abs_rf = N.abs(rf)
		abs_max = N.max( abs_rf )
		abs_min = N.min( abs_rf )
		value_spectrum = abs_max - abs_min
		if value_spectrum > 0.2 and \
		not is_close_to_zero(rf, verbose=False, atol=1e-02):
			cl = res['fits'][key]['cl']
			clusters_by_index[cl].append(key)



	'''vectorize fit center coords clusterwise'''
	sorted_clusters_by_index = [[] for s in [None]*num_channel]
	for c in xrange(0,len(clusters_by_index)):
		cl_ids = clusters_by_index[c]
		dtype = [('coord', float), ('id', int)]
		values = []
		for i in xrange(0, len(cl_ids)):
			p = fits[cl_ids[i]]['p']
			p01 = (int(N.round(p[0])), int(N.round(p[1])))
			if sort_row_then_col:
				'''vectorize col wise'''
				values.append( (p01[1]+p01[0]*patch_w, cl_ids[i]) )
			else:
				'''vectorize row wise'''
				values.append( (p01[0]+p01[1]*patch_w, cl_ids[i]) )
		vectorized_coords = N.array(values, dtype)
		for pair in N.sort(vectorized_coords):
			sorted_clusters_by_index[c].append( pair[1] )


	not_zero = W.shape[0] - num_zeros
	'''new build pruned map'''
	if dic_W['mode'] == 'rgb': n_ch = 3
	elif dic_W['mode'] == 'rg_vs_b': n_ch = 2
	elif dic_W['mode'] == 'rg': n_ch = 2
	else: n_ch = 1
	pruned_patch_h = int(N.ceil(not_zero**.5))+1#2
	pruned_W_shape = (pruned_patch_h**2, patch_w**2*n_ch)
	rec_W = N.zeros(pruned_W_shape)
	rec_W_channel_dim = [(0,0) for dim in [None]*num_channel]
	rec_W_index = 0
	# rec_W_channel_coords = []
	for c in xrange(0,len(sorted_clusters_by_index)):
		cl_ids = sorted_clusters_by_index[c]
		rec_W_channel_dim[c] = (rec_W_index, rec_W_index+len(cl_ids)-1)
		to_nextrow = rec_W_index % pruned_patch_h
		if to_nextrow != 0:
			rec_W_index += pruned_patch_h - to_nextrow
		'''coords: store fitted center coords for each id'''
		# rec_W_channel_coords.append([])
		for i in xrange(0, len(cl_ids)):
			p = fits[cl_ids[i]]['p']
			rec_W[rec_W_index] = W[cl_ids[i]]
			rec_W_index += 1
			# pair = (int(N.round(p[0])), int(N.round(p[1])))
			# rec_W_channel_coords[c].append(pair)
			


	'''write pruned map'''
	d = res['map']
	W_rec_args = {
		'hid': pruned_patch_h,
		'vis': d['vis'],
		'mode': d['mode'],
		'k': d['k'],
		'p': d['p'], 
		'lr': d['lr'], 
		'clip': d['clip'],
		'version': d['version'], 
		'epochs_done': d['epochs_done'], 
		'chdim': rec_W_channel_dim,
		# 'chcoords': rec_W_channel_coords,
	}
	from pprint import pprint as pp
	print pp(W_rec_args)

	from ..base import rgc_filename_str
	str_mapfile = rgc_filename_str(
		mode=d['mode'], clip=d['clip'], 
		k=d['k'], p=d['p'], lr=d['lr'], 
		vis=d['vis'], hid=d['hid'],
		)


	import os
	mod_odir = os.path.join(args_odir, '../')
	from ..base import make_working_dir_sub
	work_dir = make_working_dir_sub(mod_odir, 'pr')
	from ..base import make_filename
	mapfile = make_filename(args_file, 'pruned_and_sorted_'+str_mapfile, '.map', odir=work_dir)

	from ..base.weightmatrix import save_2
	save_2(filepath=mapfile+'.map', 
		W=rec_W, 
		W_args=W_rec_args, 
		version=2, 
		verbose=True)


	imagefile = make_filename(args_file, 'pruned_and_sorted_'+str_mapfile, 
		'.png', odir=work_dir)

	'''write reconstructed map'''
	from ..base.plots import save_w_as_image
	save_w_as_image(X=rec_W, 
		in_w=dic_W['vis'], in_h=dic_W['vis'],
		out_w=pruned_patch_h, out_h=pruned_patch_h,
		outfile=imagefile+'.png',
		mode=dic_W['mode'],
		dpi=288)
Example #8
0
def cluster_map(args):
    import numpy as N

    """load precomputed fits"""
    from util import depickle_fits

    res = depickle_fits(args.file)

    """retrieving weightmatrix from fits dic"""
    W = res["map"]["W"]
    patch_w = res["map"]["vis"]
    channel_w = res["map"]["vis"] ** 2
    W_mode = res["map"]["mode"]
    W_clip = res["map"]["clip"]

    rf_total = W.shape[0]
    rf_non_zeros = len(res["fits"].keys())
    rf_zeros = rf_total - rf_non_zeros
    rf_dead_perc = N.round(rf_zeros / float(rf_total), 2)
    print
    print rf_zeros, "of", rf_total, "RF are zeros."
    print "ratio dead/total:", rf_dead_perc
    print

    """generate cluster data"""
    cluster_data = __collect_cluster_data(
        W_mode,
        patch_w,
        channel_w,
        W,
        W_clip,
        depickled_fits=res,
        reconstr=args.rec,
        csp=args.csp,
        chrm=args.chr,
        err=args.err,
        nz=args.nz,
        surround=args.surr,
    )

    """cluster obs"""
    idx, args.n = __apply_cluster_alg(cluster_data=cluster_data, alg=args.alg, prior_cluster_num=args.n, t=args.t)

    """assign cluster id to fits"""
    num_types = N.zeros(args.n)
    proto_color_arr = [[] for prot in [None] * args.n]
    keys = sorted(res["fits"].keys())
    for i, key in enumerate(keys):
        fit = res["fits"][key]
        num_types[idx[i]] += 1
        proto_color_arr[idx[i]].append(fit["color_rgb"])
    prototype_color = [N.mean(prot, axis=0) for prot in proto_color_arr]

    from pprint import pprint as pp

    # print 'proto colors'
    # print pp(N.round(prototype_color,1))

    """sort prototype colors ON / OFF super ugly"""
    val_proto_on, val_proto_off = [], []
    for pr in prototype_color:
        fmax = N.max(pr)
        on = fmax > 0.7  # can fail on not enough converged maps
        if on:
            val_proto_on.append((pr[0], pr[1], pr[2], N.argmax(pr)))
        else:
            val_proto_off.append((pr[0], pr[1], pr[2], N.argmin(pr)))

    p_dtype = [("r", float), ("g", float), ("b", float), ("id", int)]
    proto_on = N.sort(N.array(val_proto_on, dtype=p_dtype), order=["id"])
    proto_on_flat = [[pr[0], pr[1], pr[2]] for pr in proto_on]
    proto_off = N.sort(N.array(val_proto_off, dtype=p_dtype), order=["id"])
    proto_off_flat = [[pr[0], pr[1], pr[2]] for pr in proto_off]

    if proto_on_flat == []:
        s_prototype_color = proto_off_flat
    elif proto_off_flat == []:
        s_prototype_color = proto_on_flat
    else:
        s_prototype_color = N.concatenate(
            [[[pr[0], pr[1], pr[2]] for pr in proto_on], [[pr[0], pr[1], pr[2]] for pr in proto_off]]
        )

    sorted_ids = []
    sort_ch = [None] * args.n
    for i, pr in enumerate(prototype_color):
        sorted_ids.append(N.where(s_prototype_color == pr)[0][0])
    for i, sid in enumerate(sorted_ids):
        sort_ch[sid] = i

    s_num_types = [num_types[i] for i in sort_ch]
    s_idx = N.copy(idx)
    for i in xrange(0, len(sort_ch)):
        s_idx[idx == sort_ch[i]] = i
        s_num_types[i] = num_types[sort_ch[i]]

    """fold opposing channels"""
    if args.fold:
        assert args.n % 2 == 0, "number of clusters need to be even in order to fold opposing clusters."
        """leaving the prototype colors intact - just reducing the s_idx entries by half.
		so magnitude |prototype_color| = 6 is not touched but entries in s_idx originally ranging from 0 to 5
		will be reduced to 0 to 2. so all entries in prototype_color are invalid in terms of real clusters contents.
		... ugly"""
        fold_n = args.n / 2
        fold_num_types = [0 for i in xrange(0, args.n)]
        fold_idx = N.copy(s_idx)
        from math import ceil

        for i in xrange(0, fold_n):
            fold_i = int(ceil(fold_n + i))
            # fold_idx[fold_idx==i] = i
            fold_idx[fold_idx == fold_i] = i
            fold_num_types[i] = s_num_types[i] + s_num_types[fold_i]

        s_idx = N.copy(fold_idx)
        s_num_types = N.copy(fold_num_types)
        args.n = fold_n

    print "sorted proto colors"
    print pp(N.round(s_prototype_color, 1))
    for i in xrange(0, args.n):
        print i, "\t", int(s_num_types[i]), "\t"
    print sort_ch
    print s_idx

    """write cluster membership into fit data"""
    keys = sorted(res["fits"].keys())
    for i, key in enumerate(keys):
        res["fits"][key]["cl"] = s_idx[i]

    from ..base import make_working_dir_sub

    cluster_dir = make_working_dir_sub(args.odir, "cl")

    """write clustered data"""
    fname = __write_clustered_fits(
        args,
        num_types_list=N.copy(s_num_types),
        prototype_color=N.copy(s_prototype_color),
        idx_list=s_idx,
        depickled_fits=res,
        odir=cluster_dir,
        num_dead_rf=rf_zeros,
        per_dead_rf=rf_dead_perc,
    )

    if args.plot:
        from ..base import make_filename

        filepath = make_filename(args.file, "image__" + fname, ".png", odir=cluster_dir)

        __plot_clusters_coverage(
            filepath=filepath, res=res, args_plain=False, args_scale=1.2, args_a=1, args_pad=0.01, args_dpi=288
        )

    if args.pr:
        import os
        from prune import prune_map

        prune_map(args_file=os.path.join(cluster_dir, fname + ".cfits"), args_odir=cluster_dir)
Example #9
0
def plot_clusters_colorspace(args):
    """loading precomputed fits"""
    from util import depickle_fits

    res = depickle_fits(args.file, suffix="cfits")

    args_n = res["dic_cluster"]["num_clusters"]
    idx = res["dic_cluster"]["cluster_index_list"]
    prototype_color = res["dic_cluster"]["prototype_cluster"]

    cluster_data = []

    keys = sorted(res["fits"].keys())
    for i, key in enumerate(keys):
        fit = res["fits"][key]
        cluster_data.append(fit["color_rgb"])

    """normalize cluster data"""
    cluster_data -= N.min(cluster_data)
    cluster_data /= N.max(cluster_data)

    """plot"""
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    Axes3D(plt.figure())
    import matplotlib.gridspec as gridspec

    if args.single:
        gs = gridspec.GridSpec(1, 1)
        plt.rcParams["axes.labelsize"] = 18

    else:
        gs = gridspec.GridSpec(3, 1)
        plt.rcParams["axes.linewidth"] = 0.5
        plt.rcParams["axes.labelsize"] = 5
        plt.rcParams["lines.linewidth"] = 0.1

    fig = plt.figure()

    markers = args_n * ["."]
    markers[0] = "o"
    markers[1] = "h"
    markers[2] = "p"
    markers[3] = "^"
    markers[4] = "v"
    if args_n > 5:
        markers[5] = "d"

    def make_plot(fig, pos, proj="3d", fontsize=9):
        if proj == "3d":
            ax = fig.add_subplot(pos, aspect="equal", projection=proj, xmargin=0)
            ax.set_xlim([0.05, 0.95])
            ax.set_ylim([0.05, 0.95])
            ax.set_zlim([0.05, 0.95])

            ax.set_xlabel("red")
            ax.set_ylabel("green")
            ax.set_zlabel("blue")
            ax.grid(True)
            # ax.xaxis.set_rotate_label(False)
            # ax.yaxis.set_rotate_label(False)
            # ax.zaxis.set_rotate_label(False)
            ax.xaxis._axinfo["tick"]["inward_factor"] = 0
            ax.xaxis._axinfo["tick"]["outward_factor"] = 0.4
            ax.yaxis._axinfo["tick"]["inward_factor"] = 0
            ax.yaxis._axinfo["tick"]["outward_factor"] = 0.4
            ax.zaxis._axinfo["tick"]["inward_factor"] = 0
            ax.zaxis._axinfo["tick"]["outward_factor"] = 0.4
            ax.zaxis._axinfo["tick"]["outward_factor"] = 0.4
            [t.set_va("center") for t in ax.get_yticklabels()]
            [t.set_ha("left") for t in ax.get_yticklabels()]
            [t.set_va("center") for t in ax.get_xticklabels()]
            [t.set_ha("right") for t in ax.get_xticklabels()]
            [t.set_va("center") for t in ax.get_zticklabels()]
            [t.set_ha("left") for t in ax.get_zticklabels()]
            ax.xaxis.pane.fill = False
            ax.yaxis.pane.fill = False
            ax.zaxis.pane.fill = False
            ax.xaxis._axinfo["label"]["space_factor"] = 2.8
            ax.yaxis._axinfo["label"]["space_factor"] = 2.8
            ax.zaxis._axinfo["label"]["space_factor"] = 3.5
            ax.xaxis._axinfo["axisline"]["line_width"] = 3.0
            if args.single:
                pass
            else:
                ax.tick_params(axis="x", labelsize=3.5)
                ax.tick_params(axis="y", labelsize=3.5)
                ax.tick_params(axis="z", labelsize=3.5)

            ax.set_xticks([0.2, 0.5, 0.8])
            ax.set_yticks([0.2, 0.5, 0.8])
            ax.set_zticks([0.2, 0.5, 0.8])

        else:
            ax = fig.add_subplot(pos, projection=proj)
            ax.set_xlim([0.0, 1])
            ax.set_ylim([0.0, 1])

            ax.xaxis.labelpad = 2
            ax.yaxis.labelpad = 2
            ax.tick_params(axis="x", labelsize=4, pad=2, length=3, width=0.05)
            ax.tick_params(axis="y", labelsize=4, pad=2, length=3, width=0.05)

            ax.set_xticks([0, 0.2, 0.5, 0.8, 1])
            ax.set_yticks([0, 0.2, 0.5, 0.8, 1])

        return ax

    def make_3d_plot(fig, grid_spec):
        ax = make_plot(fig, grid_spec, proj="3d")
        if args.single:
            s = 84
            lw = 0.2
        else:
            s = 8
            lw = 0.1
            # ax.view_init(15, -60)
        for k in xrange(0, args_n):
            ax.scatter(
                cluster_data[idx == k, 0],
                cluster_data[idx == k, 1],
                cluster_data[idx == k, 2],
                c=prototype_color[k],
                marker=".",
                linewidth=lw,
                s=s,
                alpha=1.0,
            )

    def make_2d_plot(fig, grid_spec, axis_id, labels=("red", "green"), fontsize=9):
        ax = make_plot(fig, grid_spec, proj=None)
        if args.plain:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            ax.set_title("")
            ax.axis("off")
        else:
            ax.set_xlabel(labels[0])
            ax.set_ylabel(labels[1])
        ax.set_aspect(1)
        for k in xrange(0, args_n):
            ax.scatter(
                cluster_data[idx == k, axis_id[0]],
                cluster_data[idx == k, axis_id[1]],
                c=prototype_color[k],
                marker=markers[k],
                linewidth=0.01,
                s=6.0,
                alpha=1.0,
            )
        ax.grid = True
        return ax

    if args.single:
        make_3d_plot(fig, gs[0])
    else:
        make_2d_plot(fig, gs[0, 0], (2, 0), ("blue", "red"), fontsize=12)
        make_2d_plot(fig, gs[1, 0], (0, 1), ("red", "green"), fontsize=12)
        make_2d_plot(fig, gs[2, 0], (1, 2), ("green", "blue"), fontsize=12)

    if not args.odirnosub:
        import os

        mod_odir = os.path.join(args.odir, "../")
        from ..base import make_working_dir_sub

        cluster_dir = make_working_dir_sub(mod_odir, "cl")
    else:
        cluster_dir = args.odir

    if args.plain:
        plain_str = "plain_"
    else:
        plain_str = ""
    from ..base import make_filename

    filepath = make_filename(args.file, "image_clustered_rgc_colorspace_" + plain_str, ".png", odir=cluster_dir)

    if args.single:
        plt.savefig(filepath, dpi=args.dpi, pad_inches=args.pad)
    else:
        plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=0.5)
        plt.savefig(filepath, dpi=args.dpi, pad_inches=args.pad, bbox_inches="tight")

    plt.close(fig)
    print "plotted cluster data colorspace, file", filepath, "written."
Example #10
0
def make_proto_filter(
    args_file,
    args_odir,
    args_model,
    synthetic_pos=True,
    mean_rec=True,
    max_dist_to_center=2,
    meanlen=None,
    min_err_metric=True,
    abs_weight_threshold=0.7,
    indices=None,
    debug=False,
    odir_nosubdir=False,
    ):
    """Build one prototype RF filter per cluster ("channel") and save them.

    Loads the clustered fits pickled for ``args_file`` (suffix ``cfits``),
    derives one prototype filter per cluster, writes an overview figure and
    pickles the filters (dict with keys ``'mode'``, ``'num_chn'``,
    ``'filters'``) into a ``.kern`` file in the working directory.

    A prototype is obtained in one of two ways:

    * ``mean_rec=True``: the fitted model parameters of the cluster members
      are averaged and a filter is reconstructed from the mean parameters.
    * ``mean_rec=False``: one real RF (row of the weight matrix ``W``) is
      taken per cluster, either from ``indices`` (hand-chosen) or chosen
      automatically, and rolled so its fitted center sits in the middle of
      the visual field.

    Args:
        args_file: map/fit file; used to load the ``cfits`` pickle and to
            derive output file names.
        args_odir: output directory; outputs go into a ``proto`` sub
            directory of its parent unless ``odir_nosubdir`` is set.
        args_model: unused; the fit model is read from the pickled data.
        synthetic_pos: (mean_rec only) replace the mean center position by
            the center of the visual field.
        mean_rec: choose between the two prototype strategies above.
        max_dist_to_center: (automatic selection, error metric) maximum
            distance between a RF's fitted center and the visual-field
            center, in the units of ``real_pix_center_coords``.
        meanlen: (mean_rec only) average at most the first ``meanlen``
            members of each cluster; ``None`` averages all members.
        min_err_metric: (automatic selection) pick the RF with the smallest
            fit error if True, otherwise the one closest to the center.
        abs_weight_threshold: (automatic selection) candidates must have a
            maximum absolute weight above this value.
        indices: hand-chosen RF indices, one per channel (extra ones are
            ignored); ``None`` selects automatically.
        debug: (mean_rec=False only) also write a fit-debug figure per
            filter.
        odir_nosubdir: write directly into ``args_odir``.
    """
    import os
    import sys
    import numpy as N

    # `is None` instead of `== None`: indices may be a numpy array, for which
    # `==` compares element-wise.
    hand_automatic = indices is None

    # loading precomputed fits
    from util import depickle_fits
    res = depickle_fits(args_file, suffix='cfits')

    # retrieving weight matrix and map parameters from the fits dict
    map_as_dict = res['map']
    W = map_as_dict['W']
    vis = map_as_dict['vis']
    mode = map_as_dict['mode']
    assert mode == 'rgb'

    # get fit and cluster data
    fits = res['fits']
    model = res['model']
    num_channel = res['dic_cluster']['num_clusters']

    # group the fit keys by their cluster id (in sorted key order)
    clusters_by_index = [[] for _ in range(num_channel)]
    for key in sorted(fits.keys()):
        clusters_by_index[fits[key]['cl']].append(key)

    from ..base.receptivefield import convert_rfvector_to_rgbmatrix
    from ..base.plots import colormap_for_mode
    from ..base import make_filename
    from cluster import __transpose_zero_to_one
    cmap = colormap_for_mode(mode)

    if mean_rec:
        # every channel: mean of the fitted parameters of (the first
        # `meanlen` of) its members; slicing tolerates meanlen > cluster size
        clusters_p_mean = [
            N.mean([fits[key]['p'] for key in cl_ids[:meanlen]], axis=0)
            for cl_ids in clusters_by_index]
        if model == 'dog':
            from dog import reconstruct
        else:
            from edog import reconstruct
    else:
        zpx = vis / 2.
        zpy = vis / 2.
        # None marks a channel without a chosen RF
        hand_by_index = [None] * num_channel
        # largest absolute weight of all inspected RFs; reported when no RF
        # passes the threshold so the user can pick a sensible -thr value
        max_absmax = 0

        if not hand_automatic:
            # use provided (hand-chosen) indices; extra ones are ignored
            for c, index in zip(range(num_channel), indices):
                hand_by_index[c] = index
        else:
            # choose automatically: for each channel find the RF with the
            # smallest fit error (max abs weight above the threshold and close
            # to the visual-field center) or the RF closest to the center.
            for c, cl_ids in enumerate(clusters_by_index):
                min_err, min_err_id = N.inf, -1
                min_dist, min_id = N.inf, -1
                for key in cl_ids:
                    fit = fits[key]
                    center = fit['real_pix_center_coords']
                    n = fit['n']
                    dist = N.sqrt((zpy - center[0])**2 + (zpx - center[1])**2)
                    absmax = N.max(N.abs(W[n]))
                    if absmax > max_absmax:
                        max_absmax = absmax
                    if not absmax > abs_weight_threshold:
                        continue
                    if fit['err'] < min_err and dist < max_dist_to_center:
                        min_err, min_err_id = fit['err'], n
                    if dist < min_dist:
                        min_dist, min_id = dist, n
                hand_by_index[c] = min_err_id if min_err_metric else min_id

        # offset of each chosen RF's fitted center to the visual-field center
        trans_rf = []
        for c, n in enumerate(hand_by_index):
            if n is None:
                print('Not enough RF indices for all channels given.\n'
                      'num channels: %d num indices: %d\n'
                      % (num_channel, len(indices)))
                sys.exit(1)
            if n == -1:
                print('No prototype RF found for channel %d.\n'
                      'Try lowering weight threshold (or raising '
                      'max_dist_to_center)\n'
                      'parameter -thr is: %s max abs weight of all RFs: %s\n'
                      % (c, abs_weight_threshold, max_absmax))
                sys.exit(1)
            center = fits[n]['real_pix_center_coords']
            trans_rf.append((int(N.floor(zpy - center[0])),
                             int(N.floor(zpx - center[1]))))

    filters = []
    plots = []
    for c in range(num_channel):
        if mean_rec:
            p = clusters_p_mean[c]
            if synthetic_pos:
                # first two parameters are the center; put it mid-field
                p = N.concatenate([[vis / 2., vis / 2.], p[2:]])
            proto_filter = reconstruct(p, mode, vis**2, vis, W[-1].shape)
            # no single RF index exists for a mean reconstruction
            name = 'ch: ' + str(c)
        else:
            proto_filter = W[hand_by_index[c]]
            name = 'RF: ' + str(hand_by_index[c]) + ' ' + ' ch: ' + str(c)

        # convert RF vector to matrix and normalize values
        proto_filter_matr = convert_rfvector_to_rgbmatrix(
            proto_filter, vis, vis, mode)
        proto_filter_matr = __transpose_zero_to_one(proto_filter_matr)

        # move a real RF to the visual field's center
        if not mean_rec:
            proto_filter_matr = N.roll(proto_filter_matr, trans_rf[c][0], axis=0)
            proto_filter_matr = N.roll(proto_filter_matr, trans_rf[c][1], axis=1)

        plots.append({
            'name': name,
            'value': proto_filter_matr,
            'maxmin': True,
            'cmap': cmap,
            'balance': False,
            'invert': False,
            })
        filters.append(N.copy(proto_filter_matr))

    if mean_rec:
        prefix = 'mean_'
    else:
        prefix = ('auto_' if hand_automatic else 'hand_') + \
                 ('err_' if min_err_metric else 'dist_')

    if not odir_nosubdir:
        mod_odir = os.path.join(args_odir, '../')
        from ..base import make_working_dir_sub
        work_dir = make_working_dir_sub(mod_odir, 'proto')
    else:
        work_dir = args_odir

    fname = make_filename(args_file, prefix + 'conv_proto', '.png', work_dir)

    def numplots_to_rowscols(num):
        # square grid large enough for `num` plots
        sq = int(num**.5) + 1
        return sq, sq

    from ..base.plots import write_row_col_fig
    row, col = numplots_to_rowscols(num_channel)
    write_row_col_fig(plots, row, col, fname + '.png', dpi=144, alpha=1.0, fontsize=6)

    dic = {
        'mode': mode,
        'num_chn': num_channel,
        'filters': filters,
    }

    writefilename = make_filename(args_file, prefix + 'conv_proto', '.kern', work_dir)
    from util import pickle_fits
    pickle_fits(writefilename + '.kern', dic)

    if debug and mean_rec:
        # the debug figure needs a real RF and its shift (trans_rf)
        print('debug figures are only available with mean_rec=False; skipped.')
    elif debug:
        from ..base.plots import write_rf_fit_debug_fig
        # choose by the fit model, like the mean_rec branch (mode is 'rgb')
        if model == 'dog':
            from dog import reconstruct
        else:
            from edog import reconstruct

        for i, fmat in enumerate(filters):
            fmat = N.swapaxes(fmat, 0, 1)
            fvec = fmat.reshape(fmat.shape[0] * fmat.shape[1], fmat.shape[2])
            # channel-major flat vector; undo the [0, 1] rescaling
            # (presumably done by __transpose_zero_to_one) back to [-1, 1]
            fvecflat = (N.concatenate([fvec[:, 0], fvec[:, 1], fvec[:, 2]]) - .5) * 2.

            # copy so the loaded fit parameters are not modified in place
            p = N.array(fits[hand_by_index[i]]['p'], dtype=float)
            p[0] += trans_rf[i][0]
            p[1] += trans_rf[i][1]

            rec = reconstruct(p, mode, vis**2, vis, fvecflat.shape)
            err = (fvecflat - rec)**2

            debug_fname = os.path.join(work_dir, str(i) + '_debug.png')
            write_rf_fit_debug_fig(debug_fname, fvecflat, vis, 'rgb', p, rec, err, model,
                scale=2.8, s_scale=3.8, alpha=.5, dpi=300, draw_ellipsoid=True,
                no_title=True, ellipsoid_line_width=1.2)
Example #11
0
def convole_image_with_filter(
	args_file,
	input_image,
	odir = None,
	lum_channels = False,
	odir_lum_channels = False,
	odir_nosub = False,
	):
	"""Convolve an image with the filters of a pickled kernel file and write
	the filtered responses as image files.

	Args:
		args_file: path of the kernel file, loaded with
			``depickle_fits(args_file, suffix='kern')``; its ``'filters'`` and
			``'mode'`` entries are used, and output names are derived from it.
		input_image: path of the image to filter (read with ``read_image``);
			if None, the negated ``scipy.misc.lena()`` image is used.
		odir: output directory; defaults to the directory of ``args_file``.
		lum_channels: (6-kernel colour case) use the output naming and
			combinations for kernel sets whose first two kernels are
			luminosity ON/OFF filters (per the written file names).
		odir_lum_channels: (6-kernel colour case) write single-channel
			images (negated sum of the R, G and B responses) instead of RGB
			images.
		odir_nosub: write directly into ``odir`` instead of a ``conv`` sub
			directory of its parent.

	Colour images need at least 5 kernels; output is written only for exactly
	5 or 6 kernels. Grayscale images use only kernels 0 and 1. The input
	image is always written as ``org``.
	"""
	from util import depickle_fits
	res = depickle_fits(args_file, suffix='kern')
	kernels = res['filters']

	import os
	if odir == None: 
		odir = os.path.dirname(args_file)
	if not odir_nosub:
		mod_odir = os.path.join(odir, '../')
		from ..base import make_working_dir_sub
		work_dir = make_working_dir_sub(mod_odir, 'conv')
	else:
		work_dir = odir

	if input_image == None:
		# NOTE(review): scipy.misc.lena was removed in newer SciPy releases.
		from scipy import misc
		img = misc.lena()*-1
	else:
		from ..base.images import read_image
		img = read_image(input_image, verbose=True)


	import numpy as N
	from scipy import signal

	def filter_rgb(image, kernel):
		"""Filter each colour plane of `image` with the matching kernel plane.

		Each plane's response is the mean of the convolution with the kernel
		plane and with its transpose. Returns the (r, g, b) responses, each
		the size of the input plane (mode='same').
		"""
		r = signal.fftconvolve(image.T[0].T, kernel.T[0].T, mode='same')*.5 # full, same, valid
		r += signal.fftconvolve(image.T[0].T, kernel.T[0], mode='same')*.5
		g = signal.fftconvolve(image.T[1].T, kernel.T[1].T, mode='same')*.5
		g += signal.fftconvolve(image.T[1].T, kernel.T[1], mode='same')*.5
		b = signal.fftconvolve(image.T[2].T, kernel.T[2].T, mode='same')*.5
		b += signal.fftconvolve(image.T[2].T, kernel.T[2], mode='same')*.5

		# earlier spatial-domain variant (no transpose averaging):
		# r = signal.convolve2d(image.T[0].T, kernel.T[0].T, boundary='symm', mode='same')
		# g = signal.convolve2d(image.T[1].T, kernel.T[1].T, boundary='symm', mode='same')
		# b = signal.convolve2d(image.T[2].T, kernel.T[2].T, boundary='symm', mode='same')
		return r, g, b

	def write_image(arr, name, suffix='.png', mode=None):
		"""Save `arr` via matplotlib's imsave into the working directory.

		The colormap comes from `mode`, or from the kernel file's mode when
		`mode` is None.
		"""
		from ..base.plots import colormap_for_mode
		if mode==None:
			cmap = colormap_for_mode(res['mode'])
		else:
			cmap = colormap_for_mode(mode)
		from ..base import make_filename
		imagefile = make_filename(args_file, name, suffix, work_dir)
		import matplotlib.image as pltimg
		pltimg.imsave(imagefile, arr, dpi=144, cmap=cmap)

	def wr_rgb(a,b,c,name):
		"""Stack three planes into an RGB image, print its value range
		(max - min) and write it normalized via normalize_color."""
		from ..base.plots import normalize_color
		print name, '   \t\t', N.max(N.dstack([a, b, c]))-N.min(N.dstack([a, b, c]))
		write_image(normalize_color(N.dstack([a, b, c])), name)

	def wr_lum(a,name,invert=False):
		"""Write a single plane normalized (optionally negated) with the
		'lum' colormap. Not called below in this function."""
		from ..base.plots import normalize_color
		if invert:
			write_image(normalize_color(a)*-1, name, mode='lum')
		else:
			write_image(normalize_color(a), name, mode='lum')



	if len(img.shape) > 2:
		# colour image: filter with the first 5 kernels, each with its mean
		# along axis 0 subtracted

		assert len(kernels) >= 5
		ch0_r, ch0_g, ch0_b = filter_rgb(img, kernels[0]-N.mean(kernels[0], axis=0))
		ch1_r, ch1_g, ch1_b = filter_rgb(img, kernels[1]-N.mean(kernels[1], axis=0))
		ch2_r, ch2_g, ch2_b = filter_rgb(img, kernels[2]-N.mean(kernels[2], axis=0))
		ch3_r, ch3_g, ch3_b = filter_rgb(img, kernels[3]-N.mean(kernels[3], axis=0))
		ch4_r, ch4_g, ch4_b = filter_rgb(img, kernels[4]-N.mean(kernels[4], axis=0))

		if len(kernels) == 5:
			# file names map kernel 1 to 'red' and kernel 0 to 'green'
			wr_rgb(ch1_r, ch1_g, ch1_b, '0_red')
			wr_rgb(ch0_r, ch0_g, ch0_b, '1_green')
			wr_rgb(ch2_r, ch2_g, ch2_b, '2_blue')
			wr_rgb(ch3_r, ch3_g, ch3_b, '3_white')
			wr_rgb(ch4_r, ch4_g, ch4_b, '4_black')
	
			# opponent combinations of the responses above
			wr_rgb(ch1_r-ch0_r, ch1_g-ch0_g, ch1_b-ch0_b, '5_red_green')
			wr_rgb(ch0_r-ch1_r, ch0_g-ch1_g, ch0_b-ch1_b, '6_green_red')
			wr_rgb(ch2_r-(ch0_r+ch1_r)/2., ch2_g-(ch0_g+ch1_g)/2., ch2_b-(ch0_b+ch1_b)/2., '7_blue_yellow')

		if len(kernels) == 6:
			ch5_r, ch5_g, ch5_b = filter_rgb(img, kernels[5]-N.mean(kernels[5], axis=0))

			# NOTE(review): the '6_green_on_red_off' combination subtracts
			# in the RGB outputs (ch3-ch1, ch4-ch3) but adds in the
			# luminance outputs (ch3+ch1, ch4+ch3) — confirm which is meant.
			if not odir_lum_channels:
				if not lum_channels:
					wr_rgb(ch0_r, ch0_g, ch0_b, '0_red_on')
					wr_rgb(ch1_r, ch1_g, ch1_b, '1_green_on')
					wr_rgb(ch2_r, ch2_g, ch2_b, '2_blue_on')
					wr_rgb(ch4_r, ch4_g, ch4_b, '3_red_off')
					wr_rgb(ch3_r, ch3_g, ch3_b, '4_green_off')
					wr_rgb(ch5_r, ch5_g, ch5_b, '5_blue_off')
					wr_rgb(ch0_r+ch4_r, ch0_g+ch4_g, ch0_b+ch4_b, '6_red_on_green_off')
					wr_rgb(ch3_r-ch1_r, ch3_g-ch1_g, ch3_b-ch1_b, '6_green_on_red_off')

				else:
					wr_rgb(ch0_r, ch0_g, ch0_b, '0_luminosity_ON')
					wr_rgb(ch1_r, ch1_g, ch1_b, '1_luminosity_OFF')
					wr_rgb(ch2_r, ch2_g, ch2_b, '2_red_on')
					wr_rgb(ch4_r, ch4_g, ch4_b, '3_green_on')
					wr_rgb(ch3_r, ch3_g, ch3_b, '4_red_off')
					wr_rgb(ch5_r, ch5_g, ch5_b, '5_green_off')
					wr_rgb(ch2_r+ch5_r, ch2_g+ch5_g, ch2_b+ch5_b, '6_red_on_green_off')
					wr_rgb(ch4_r-ch3_r, ch4_g-ch3_g, ch4_b-ch3_b, '6_green_on_red_off')

			else:
				# single-channel outputs: negated sum of the R, G, B responses
				if not lum_channels:
					write_image((ch0_r+ch0_g+ch0_b)*-1, 'lum_0_red_on', mode='lum')
					write_image((ch1_r+ch1_g+ch1_b)*-1, 'lum_1_green_on', mode='lum')
					write_image((ch2_r+ch2_g+ch2_b)*-1, 'lum_2_blue_on', mode='lum')
					write_image((ch4_r+ch4_g+ch4_b)*-1, 'lum_3_red_off', mode='lum')
					write_image((ch3_r+ch3_g+ch3_b)*-1, 'lum_4_green_off', mode='lum')
					write_image((ch5_r+ch5_g+ch5_b)*-1, 'lum_5_blue_off', mode='lum')
					write_image((ch0_r+ch4_r+ch0_g+ch4_g+ch0_b+ch4_b)*-1, 'lum_6_red_on_green_off', mode='lum')
					write_image((ch3_r+ch1_r+ch3_g+ch1_g+ch3_b+ch1_b)*-1, 'lum_6_green_on_red_off', mode='lum')

				else:
					write_image((ch0_r+ch0_g+ch0_b)*-1, 'lum_0_luminosity_ON', mode='lum')
					write_image((ch1_r+ch1_g+ch1_b)*-1, 'lum_1_luminosity_OFF', mode='lum')
					write_image((ch2_r+ch2_g+ch2_b)*-1, 'lum_2_red_on', mode='lum')
					write_image((ch4_r+ch4_g+ch4_b)*-1, 'lum_3_green_on', mode='lum')
					write_image((ch3_r+ch3_g+ch3_b)*-1, 'lum_4_red_off', mode='lum')
					write_image((ch5_r+ch5_g+ch5_b)*-1, 'lum_5_green_off', mode='lum')
					write_image((ch2_r+ch5_r+ch2_g+ch5_g+ch2_b+ch5_b)*-1, 'lum_6_red_on_green_off', mode='lum')
					write_image((ch4_r+ch3_r+ch4_g+ch3_g+ch4_b+ch3_b)*-1, 'lum_6_green_on_red_off', mode='lum')				

	else:
		# grayscale image: ON (kernel 0) and OFF (kernel 1) responses with
		# symmetric boundary handling, plus their difference
		on_filterd = signal.convolve2d(img, kernels[0], boundary='symm', mode='same')
		off_filterd = signal.convolve2d(img, kernels[1], boundary='symm', mode='same')
		img_filterd = on_filterd - off_filterd

		write_image(img_filterd*-1, 'lum_convole')
		write_image(on_filterd, '0_lum_on')
		write_image(off_filterd*-1, '1_lum_off')


	# always write the (possibly negated test) input image for comparison;
	# an earlier variant skipped this when 'org.png' already existed:
	# import os
	# working_dir = os.path.split(args_file)[0]
	# original = os.path.join(working_dir, 'org.png')
	# if not os.path.exists(original):
	write_image(img, 'org')