def plot_divisive_normalization_weber_law(data_flags,
                                          axes_to_plot=[0, 1],
                                          projected_variable_components=dict()
                                          ):
    """
    Plot decoding-accuracy curves for divisive-normalization runs.

    Args:
        data_flags: spec-file identifiers; first half are Weber-law
            runs, second half non-Weber-law runs.
        axes_to_plot: indices of the two iterated-variable axes to plot.
        projected_variable_components: fixed indices for any extra
            iterated-variable axes beyond the two plotted.
    """

    assert len(data_flags) % 2 == 0, \
     "Need command line arguments to be normalization_idxs*2, " \
     "first half Weber law and second half non-Weber law."

    # Integer count required: linspace/num and the index math below
    # must not receive floats (Python 3 true division).
    num_normalizations = len(data_flags) // 2

    # Red for Weber-law runs; blue for non-Weber-law runs
    cmaps = [cm.Reds, cm.Blues]
    # Shade darkens for successive normalization indices
    shades = sp.linspace(0.3, 0.85, num_normalizations)

    # Plot success error figures
    for data_flag_idx, data_flag in enumerate(data_flags):

        # First half of the flags -> Weber_idx 0; second half -> 1
        Weber_idx = data_flag_idx // num_normalizations
        normalization_idx = data_flag_idx % num_normalizations

        # Decoding accuracy subfigures
        fig = divisive_normalization_subfigures()

        cmap = cmaps[Weber_idx]
        shade = shades[normalization_idx]

        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        Nn = list_dict['params']['Nn']
        # list() needed: dict views are not indexable in Python 3
        iter_var_names = list(iter_vars.keys())
        iter_plot_var = iter_var_names[axes_to_plot[0]]
        x_axis_var = iter_var_names[axes_to_plot[1]]

        data = load_signal_decoding_weber_law(data_flag)
        successes = data['successes']

        # Project out extra axes when more than 2 variables were iterated
        if len(successes.shape) > 2:
            successes = project_tensor(successes, iter_vars,
                                       projected_variable_components,
                                       axes_to_plot)

        # Switch axes if necessary
        if axes_to_plot[0] > axes_to_plot[1]:
            successes = successes.T

        # Plot successes, averaged over second axis, as a percentage
        avg_successes = sp.average(successes, axis=1) * 100.0
        plt.plot(iter_vars[iter_plot_var],
                 avg_successes,
                 color=cmap(shade),
                 zorder=normalization_idx,
                 lw=4.0)

        # Save same plot in both Weber Law and non-Weber Law folders
        # NOTE(review): both iterations save with the same data_flag, so
        # the same file is written twice -- presumably the intent was to
        # save under both flag folders; confirm against save_decoding_
        # accuracy_fig.
        for Weber_idx in range(2):
            data_flag = data_flags[data_flag_idx]
            save_decoding_accuracy_fig(fig, data_flag)
# --- Ejemplo n.º 2 ---
def plot_signal_decoding_vs_Kk(data_flag,
                               nonzero_bounds=[0.75, 1.25],
                               zero_bound=1. / 10.,
                               threshold_pct_nonzero=100.0,
                               threshold_pct_zero=100.0):
    """
    Plot decoding accuracy versus odor complexity Kk as a colormesh.

    Args:
        data_flag: specs-file identifier for loading and saving.
        nonzero_bounds, zero_bound, threshold_pct_nonzero,
            threshold_pct_zero: not used by this plotting routine; kept
            for signature compatibility with the calculation routines.
    """

    list_dict = read_specs_file(data_flag)
    iter_vars = list_dict['iter_vars']
    Nn = list_dict['params']['Nn']

    # Decoding accuracy subfigures
    fig = signal_decoding_vs_Kk_subfigures()

    assert len(iter_vars) == 3, "Need 3 iter_vars"
    # list() needed: dict views are not indexable in Python 3
    iter_var_names = list(iter_vars.keys())
    iter_plot_var = iter_var_names[0]
    x_axis_var = iter_var_names[1]
    Kk_axis_var = iter_var_names[2]

    data = load_signal_decoding_weber_law(data_flag)
    successes = data['successes']

    # Plot successes, averaged over second axis of successes array
    avg_successes = sp.average(successes, axis=1) * 100.0
    plt.pcolormesh(avg_successes.T,
                   cmap=plt.cm.hot,
                   rasterized=True,
                   vmin=-5,
                   vmax=110)
    plt.colorbar()
    save_decoding_accuracy_vs_Kk_fig(fig, data_flag)
def plot_tuning_curves(data_flags):
    """
    Plot ORN tuning curves and adapted epsilons for a grid of
    (mu_dSs, tuning_width) spec-file indices.

    Args:
        data_flags: list of spec-file identifiers to overlay.
    """

    # First entries are for mu_dSs, second are for tuning_width
    #plot_vars = [[0, 9, 18], [7, 11, 19]]
    plot_vars = [[0, 1, 2], [0, 1, 2]]
    cmaps = [[cm.Greys, cm.Purples, cm.Blues],
             [cm.Greys, cm.Purples, cm.Blues]]

    for data_idx, data_flag in enumerate(data_flags):

        # Explicit lookups instead of exec-based unpacking; exec into
        # function locals does not bind names in Python 3.
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        params = list_dict['params']

        tuning_curve_data = load_tuning_curve(data_flag)
        tuning_curve = tuning_curve_data['tuning_curve']
        epsilons = tuning_curve_data['epsilons']

        # Figure is laid out once, from the first flag's specs
        if data_idx == 0:
            fig, plot_dims, axes_tuning, axes_eps, axes_signal = \
              tuning_curve_plot_epsilon(plot_vars, iter_vars, params)

        for idx, idx_var in enumerate(plot_vars[0]):
            for idy, idy_var in enumerate(plot_vars[1]):

                # One shade per receptor; lighter for higher iM
                colors = cmaps[data_idx][idx](sp.linspace(
                    0.75, 0.3, params['Mm']))

                # Integer division so arange gets int bounds on Python 3
                half_Nn = params['Nn'] // 2
                for iM in range(params['Mm']):
                    # Ascending half of the sorted tuning curve
                    axes_tuning[idx, idy].plot(
                        sp.arange(half_Nn),
                        sp.sort(tuning_curve[idx_var, idy_var, ::2, iM]),
                        color=colors[iM],
                        linewidth=0.7,
                        zorder=params['Mm'] - iM)
                    # Descending half, mirrored so the curve forms a peak
                    axes_tuning[idx, idy].plot(
                        sp.arange(half_Nn - 1, params['Nn'] - 1),
                        sp.sort(tuning_curve[idx_var, idy_var, 1::2,
                                             iM])[::-1],
                        color=colors[iM],
                        linewidth=0.7,
                        zorder=params['Mm'] - iM)

                axes_eps[idy].plot(range(params['Mm']),
                                   epsilons[idx_var, idy_var],
                                   color=colors[4],
                                   linewidth=1.5,
                                   zorder=0)
                for iM in range(params['Mm']):
                    axes_eps[idy].scatter(iM,
                                          epsilons[idx_var, idy_var][iM],
                                          c=colors[iM],
                                          s=3)

    save_tuning_curve_fig(fig, data_flag)
# --- Ejemplo n.º 4 ---
def plot_activities(data_flags,
                    axes_to_plot=[0, 1],
                    projected_variable_components=dict()):
    """
    Plot histograms of receptor activity magnitudes as a colormesh
    versus the first plotted iterated variable.

    Args:
        data_flags: identifiers, alternating Weber law / non-Weber law.
        axes_to_plot: indices of the two iterated-variable axes to use.
        projected_variable_components: not used by this routine; kept
            for signature compatibility with sibling plotting routines.
    """

    assert len(data_flags) % 2 == 0, \
     "Need command line arguments to be normalization_idxs*2, " \
     "alternating Weber law and non-Weber law."

    # Red for Weber-law runs, blue for non-Weber; reversed maps for mesh
    cmaps = [cm.Reds, cm.Blues]
    cmaps_r = [cm.Reds_r, cm.Blues_r]

    for data_flag_idx, data_flag in enumerate(data_flags):

        # Blue or red depends on this index
        Weber_idx = data_flag_idx % 2

        # Explicit lookup instead of exec-based unpacking; exec into
        # function locals does not bind names in Python 3.
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']

        # list() needed: dict views are not indexable in Python 3
        iter_var_names = list(iter_vars.keys())
        iter_plot_var = iter_var_names[axes_to_plot[0]]
        x_axis_var = iter_var_names[axes_to_plot[1]]

        data = load_signal_decoding_weber_law(data_flag)
        activities = data['Yys']

        for iter_var_idx in range(len(iter_vars[iter_plot_var])):
            if iter_var_idx == 0:
                bins = sp.linspace(-4, 1.05, 20000)
                activities_data = sp.zeros(
                    (len(iter_vars[iter_plot_var]), len(bins) - 1))

            # Histogram takes all activities for all stimuli choices, for
            #  this particular background mean. Plot absolute value of dYy
            hist, bin_edges = sp.histogram(
                 sp.log(abs(activities[iter_var_idx, :, :]))\
                 /sp.log(10), bins=200)

            # Interpolate to identical scale for all columns.
            interp_hist = sp.interp(bins, bin_edges[:-1], hist)[:-1]
            activities_data[iter_var_idx, :] = interp_hist

        # Normalize each column to its own maximum
        normed_activities = activities_data.T / sp.amax(activities_data.T,
                                                        axis=0)
        fig = activities_subfigures()
        plt.pcolormesh(iter_vars[iter_plot_var],
                       10.**bins,
                       normed_activities,
                       cmap=cmaps_r[Weber_idx],
                       rasterized=True,
                       vmin=0,
                       vmax=1.05)

        save_activities_fig(fig, data_flag)
# --- Ejemplo n.º 5 ---
def est_VA(data_flag, init_seed):
    """
    Estimate states and parameters of a single-cell FRET model via
    variational annealing, then save the estimates.

    Args:
        data_flag: specs-file identifier for loading and saving.
        init_seed: seed for the initial estimate; also used as the
            adolcID so parallel runs do not collide.
    """

    # Build the single_cell_FRET object from the specs file
    specs = read_specs_file(data_flag)
    scF = single_cell_FRET(**compile_all_run_vars(specs))

    # If stim and meas were not imported, then data was saved as data_flag
    if scF.stim_file is None:
        scF.stim_file = data_flag
    if scF.meas_file is None:
        scF.meas_file = data_flag
    scF.set_stim()
    scF.set_meas_data()

    # Initialize estimation; set the estimation and prediction windows
    scF.init_seed = init_seed
    scF.set_init_est()
    scF.set_est_pred_windows()

    # Hand the model, measurements, stimulus, and times to the annealer
    est_idxs = scF.est_wind_idxs
    annealer = va_ode.Annealer()
    annealer.set_model(scF.df_estimation, scF.nD)
    annealer.set_data(scF.meas_data[est_idxs, :],
                      stim=scF.stim[est_idxs],
                      t=scF.Tt[est_idxs])

    # Set Rm as inverse covariance; all parameters measured for now
    Rm = 1.0 / sp.asarray(scF.meas_noise)**2.0
    P_idxs = sp.arange(scF.nP)

    # Optimizer settings for the inner L-BFGS-B solves
    BFGS_options = {'gtol': 1.0e-8,
                    'ftol': 1.0e-8,
                    'maxfun': 1000000,
                    'maxiter': 1000000}

    # Run the annealing loop and report wall-clock time
    tstart = time.time()
    annealer.anneal(scF.x_init[est_idxs],
                    scF.p_init,
                    scF.alpha,
                    scF.beta_array,
                    Rm,
                    scF.Rf0,
                    scF.L_idxs,
                    P_idxs,
                    dt_model=None,
                    init_to_data=True,
                    bounds=scF.bounds,
                    disc='trapezoid',
                    method='L-BFGS-B',
                    opt_args=BFGS_options,
                    adolcID=init_seed)
    print("\nADOL-C annealing completed in %f s." % (time.time() - tstart))

    save_estimates(scF, annealer, data_flag)
def plot_tuning_curves(data_flags):
    """
    Plot ORN tuning curves and Kk2 matrices for a grid of
    (mu_dSs, Kk2-diversity) spec-file indices.

    Args:
        data_flags: list of spec-file identifiers to overlay.
    """

    # First entries are for mu_dSs, second are for Kk2 diversity
    plot_vars = [[0, 9, 19], [8, 15, 19]]
    cmaps = [[cm.Greys, cm.Purples, cm.Blues],
             [cm.Greys, cm.Purples, cm.Blues]]

    for data_idx, data_flag in enumerate(data_flags):

        # Explicit lookups instead of exec-based unpacking; exec into
        # function locals does not bind names in Python 3.
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        params = list_dict['params']

        tuning_curve_data = load_tuning_curve(data_flag)
        tuning_curve = tuning_curve_data['tuning_curve']
        epsilons = tuning_curve_data['epsilons']
        Kk2s = tuning_curve_data['Kk2s']

        # Figure is laid out once, from the first flag's specs
        if data_idx == 0:
            fig, plot_dims, axes_tuning, axes_Kk2, axes_signal = \
              tuning_curve_plot_Kk2(plot_vars, iter_vars, params)

        for idx, idx_var in enumerate(plot_vars[0]):
            for idy, idy_var in enumerate(plot_vars[1]):

                # One shade per receptor; lighter for higher iM
                colors = cmaps[data_idx][idx](sp.linspace(
                    0.75, 0.3, params['Mm']))

                # Integer division so arange gets int bounds on Python 3
                half_Nn = params['Nn'] // 2
                for iM in range(params['Mm']):
                    # Ascending half of the sorted tuning curve
                    axes_tuning[idx, idy].plot(
                        sp.arange(half_Nn),
                        sp.sort(tuning_curve[idx_var, idy_var, ::2, iM]),
                        color=colors[iM],
                        linewidth=0.7,
                        zorder=params['Mm'] - iM)
                    # Descending half, mirrored so the curve forms a peak
                    axes_tuning[idx, idy].plot(
                        sp.arange(half_Nn - 1, params['Nn'] - 1),
                        sp.sort(tuning_curve[idx_var, idy_var, 1::2,
                                             iM])[::-1],
                        color=colors[iM],
                        linewidth=0.7,
                        zorder=params['Mm'] - iM)

                # Show the Kk2 matrix once per column, rows sorted by spread
                if idx == 0:
                    sorted_idxs = sp.argsort(
                        sp.std(Kk2s[0, idy_var, :, :], axis=1))
                    axes_Kk2[idy].imshow(Kk2s[0, idy_var, sorted_idxs, :].T,
                                         aspect=0.3,
                                         cmap='bone',
                                         rasterized=True)

    save_tuning_curve_fig(fig, data_flag)
def calculate_signal_discrimination_weber_law(data_flag,
                                              nonzero_bounds=[0.5, 1.5],
                                              zero_bound=1. / 10.,
                                              threshold_pct_nonzero=75.0,
                                              threshold_pct_zero=75.0):
    """
    Calculate dual-odor discrimination success rates over the
    iterated-variable grid and save them.

    Args:
        data_flag: specs-file identifier for loading and saving.
        nonzero_bounds: relative bounds within which a nonzero signal
            component counts as correctly decoded.
        zero_bound: bound below which a zero component counts as
            correctly decoded.
        threshold_pct_nonzero: percent of nonzero components that must
            be correct for a trial to count as a success.
        threshold_pct_zero: same threshold for the zero components.
    """

    # Explicit lookup instead of exec-based unpacking; exec into
    # function locals does not bind names in Python 3.
    list_dict = read_specs_file(data_flag)
    iter_vars = list_dict['iter_vars']

    iter_vars_dims = []
    for iter_var in iter_vars:
        iter_vars_dims.append(len(iter_vars[iter_var]))

    print('Loading object list...')
    CS_object_array = load_aggregated_object_list(iter_vars_dims, data_flag)
    print('...loaded.')

    # Code not written for Kk_split = 0
    assert CS_object_array[0, 0].Kk_split != 0, "Need nonzero Kk_split"

    # Data structures
    errors_zero = sp.zeros(iter_vars_dims)
    errors_nonzero_2 = sp.zeros(iter_vars_dims)
    errors_nonzero = sp.zeros(iter_vars_dims)
    successes_2 = sp.zeros(iter_vars_dims)
    successes = sp.zeros(iter_vars_dims)

    # Calculate binary errors at every point of the iterated-variable grid
    it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
    while not it.finished:
        errors = binary_errors_dual_odor(CS_object_array[it.multi_index],
                                         nonzero_bounds=nonzero_bounds,
                                         zero_bound=zero_bound)

        errors_nonzero[it.multi_index] = errors['errors_nonzero']
        errors_nonzero_2[it.multi_index] = errors['errors_nonzero_2']
        errors_zero[it.multi_index] = errors['errors_zero']
        it.iternext()

    # Calculate success ratios from binary errors
    it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
    while not it.finished:
        successes_2[it.multi_index] = binary_success(
            errors_nonzero_2[it.multi_index],
            errors_zero[it.multi_index],
            threshold_pct_nonzero=threshold_pct_nonzero,
            threshold_pct_zero=threshold_pct_zero)
        successes[it.multi_index] = binary_success(
            errors_nonzero[it.multi_index],
            errors_zero[it.multi_index],
            threshold_pct_nonzero=threshold_pct_nonzero,
            threshold_pct_zero=threshold_pct_zero)
        it.iternext()

    save_signal_discrimination_weber_law(successes, successes_2, data_flag)
# --- Ejemplo n.º 8 ---
def calculate_signal_decoding_weber_law(data_flag,
                                        nonzero_bounds=[0.5, 1.5],
                                        zero_bound=1. / 10.,
                                        threshold_pct_nonzero=85.0,
                                        threshold_pct_zero=85.0):
    """
    Calculate decoding success rates, receptor gains, and epsilons over
    the iterated-variable grid and save them.

    Args:
        data_flag: specs-file identifier for loading and saving.
        nonzero_bounds: relative bounds within which a nonzero signal
            component counts as correctly decoded.
        zero_bound: bound below which a zero component counts as
            correctly decoded.
        threshold_pct_nonzero: percent of nonzero components that must
            be correct for a trial to count as a success.
        threshold_pct_zero: same threshold for the zero components.
    """

    # Explicit lookups instead of exec-based unpacking; exec into
    # function locals does not bind names in Python 3.
    list_dict = read_specs_file(data_flag)
    iter_vars = list_dict['iter_vars']
    params = list_dict['params']

    iter_vars_dims = []
    for iter_var in iter_vars:
        iter_vars_dims.append(len(iter_vars[iter_var]))

    print('Loading object list...')
    CS_object_array = load_aggregated_object_list(iter_vars_dims, data_flag)
    print('...loaded.')

    # Data structures
    errors_nonzero = sp.zeros(iter_vars_dims)
    errors_zero = sp.zeros(iter_vars_dims)
    epsilons = sp.zeros((iter_vars_dims[0], iter_vars_dims[1], params['Mm']))
    gains = sp.zeros(
        (iter_vars_dims[0], iter_vars_dims[1], params['Mm'], params['Nn']))
    successes = sp.zeros(iter_vars_dims)

    # Calculate binary errors at every point of the iterated-variable grid
    it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
    while not it.finished:
        errors = binary_errors(CS_object_array[it.multi_index],
                               nonzero_bounds=nonzero_bounds,
                               zero_bound=zero_bound)

        errors_nonzero[it.multi_index] = errors['errors_nonzero']
        errors_zero[it.multi_index] = errors['errors_zero']
        it.iternext()

    # Calculate success ratios from binary errors; also collect the
    # adapted epsilons and gain matrices from each CS object
    it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
    while not it.finished:
        successes[it.multi_index] = binary_success(
            errors_nonzero[it.multi_index],
            errors_zero[it.multi_index],
            threshold_pct_nonzero=threshold_pct_nonzero,
            threshold_pct_zero=threshold_pct_zero)
        epsilons[it.multi_index] = CS_object_array[it.multi_index].eps
        gains[it.multi_index] = CS_object_array[it.multi_index].Rr
        it.iternext()

    save_signal_decoding_weber_law(successes, gains, epsilons, data_flag)
def nn_run(data_flag, iter_var_idxs):
    """
    Run a supervised learning classification for a single specs file.

    Data is read from a specifications file in the data_dir/specs/
    folder, with proper formatting given in read_specs_file.py. The
    specs file indicates the full range of the iterated variable; this
    script only produces output from one of those indices, so multiple
    runs can be performed in parallel.

    Args:
        data_flag: specs-file identifier for loading and saving.
        iter_var_idxs: index of the iterated variable(s) for this run.

    Returns:
        The trained nn object, with tensorflow variables removed.
    """

    # Aggregate all run specifications from the specs file; instantiate model
    list_dict = read_specs_file(data_flag)
    vars_to_pass = compile_all_run_vars(list_dict, iter_var_idxs)
    obj = nn(**vars_to_pass)

    # Need this to save tensor flow objects on iterations
    obj.data_flag = data_flag

    # Set the signals and free energy, depending if adaptive or not.
    # (membership test works directly on the dict; no list() needed)
    if 'run_type' in list_dict['run_specs']:
        val = list_dict['run_specs']['run_type']
        if val[0] == 'nn':
            obj.init_nn_frontend()
        elif val[0] == 'nn_adapted':
            obj.init_nn_frontend_adapted()
        else:
            print('`%s` run type not accepted for '
                  'supervised learning calculation' % val[0])
            quit()
    else:
        # Message previously read "withunadapted" -- missing space fixed
        print('No learning calculation run type specified, proceeding with '
              'unadapted learning calculation')
        obj.init_nn_frontend()

    # Set the network variables, learning algorithm
    obj.set_AL_MB_connectome()
    obj.set_ORN_response_array()
    obj.set_PN_response_array()
    obj.init_tf()

    # Train and test performance
    obj.set_tf_class_labels()
    obj.train_and_test_tf()

    # Delete tensorflow variables to allow saving
    obj.del_tf_vars()
    dump_objects(obj, iter_var_idxs, data_flag)

    return obj
# --- Ejemplo n.º 10 ---
def calculate_tuning_curves(data_flag):
    """
    Calculate single-odorant tuning curves, epsilons, and Kk2 matrices
    over the first two iterated variables, and save them.

    Args:
        data_flag: specs-file identifier for loading and saving.
    """

    # Explicit lookups instead of exec-based unpacking; exec into
    # function locals does not bind names in Python 3.
    list_dict = read_specs_file(data_flag)
    iter_vars = list_dict['iter_vars']
    params = list_dict['params']
    rel_vars = list_dict['rel_vars']
    fixed_vars = list_dict['fixed_vars']
    run_specs = list_dict['run_specs']

    # Get the iterated variable dimensions
    iter_vars_dims = []
    for iter_var in iter_vars:
        iter_vars_dims.append(len(iter_vars[iter_var]))
    it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])

    # Set up array to hold tuning curve curves
    tuning_curve = sp.zeros(
        (iter_vars_dims[0], iter_vars_dims[1], params['Nn'], params['Mm']))

    # Set array to hold epsilons and Kk2
    epsilons = sp.zeros((iter_vars_dims[0], iter_vars_dims[1], params['Mm']))
    Kk2s = sp.zeros(
        (iter_vars_dims[0], iter_vars_dims[1], params['Mm'], params['Nn']))

    # Iterate tuning curve calculation over all iterable variables
    while not it.finished:
        iter_var_idxs = it.multi_index

        vars_to_pass = dict()
        vars_to_pass = parse_iterated_vars(iter_vars, iter_var_idxs,
                                           vars_to_pass)
        vars_to_pass = parse_relative_vars(rel_vars, iter_vars, vars_to_pass)
        vars_to_pass = merge_two_dicts(vars_to_pass, fixed_vars)
        vars_to_pass = merge_two_dicts(vars_to_pass, params)

        # Calculate tuning curve: present one odorant at a time
        for iN in range(vars_to_pass['Nn']):
            vars_to_pass['manual_dSs_idxs'] = sp.array([iN])
            obj = single_encode_CS(vars_to_pass, run_specs)
            tuning_curve[iter_var_idxs[0], iter_var_idxs[1], iN, :] = obj.dYy

        # eps and Kk2 are taken from the last odorant's encoding
        epsilons[it.multi_index] = obj.eps
        Kk2s[it.multi_index] = obj.Kk2

        it.iternext()

    save_tuning_curve(tuning_curve, epsilons, Kk2s, data_flag)
# --- Ejemplo n.º 11 ---
def gen_twin_data(data_flag):
    """
    Generate synthetic ("twin") stimulus, state, and measurement data
    for a single-cell FRET model defined by a specs file, and save it.

    Args:
        data_flag: specs-file identifier for loading and saving.
    """

    # Instantiate the model object from the specs file
    specs = read_specs_file(data_flag)
    scF = single_cell_FRET(**compile_all_run_vars(specs))

    assert scF.meas_file is None, "For generating twin data manually, cannot "\
     "import a measurement file; remove meas_file var in specs file"

    # Forward-generate stimulus, true states, and noisy measurements
    scF.set_stim()
    scF.gen_true_states()
    scF.set_meas_data()

    # Save the newly generated data; don't save stimulus if imported
    if scF.stim_file is None:
        save_stim(scF, data_flag)
    save_true_states(scF, data_flag)
    save_meas_data(scF, data_flag)
def CS_run(data_flag, iter_var_idxs):
    """
    Run a CS decoding run for one given index of a set of iterated
    variables.

    Data is read from a specifications file in the data_dir/specs/
    folder, with proper formatting given in read_specs_file.py. The
    specs file indicates the full range of the iterated variable; this
    script only produces output from one of those indices, so multiple
    runs can be performed in parallel.
    """

    # Build the receptor model from the aggregated run specifications
    specs = read_specs_file(data_flag)
    run_vars = compile_all_run_vars(specs, iter_var_idxs)
    model = four_state_receptor_CS(**run_vars)

    # Encode the sparse signal, then decode the response
    model = single_encode_CS(model, specs['run_specs'])
    model.decode()

    dump_objects(model, iter_var_idxs, data_flag)
def aggregate_objects(data_flags, skip_missing=False):
    """
    Aggregate CS objects from separate .pklz files to a single .pklz file.

    Args:
        data_flags: identifier (or list of identifiers) for saving and
            loading.
        skip_missing: if True, indices whose per-object file is missing
            are populated with None instead of raising.
    """

    if skip_missing:
        print("Skipping missing files...will populate with `None`")

    # Allow a single flag to be passed as a bare string
    if isinstance(data_flags, str):
        data_flags = [data_flags]

    for data_flag in data_flags:
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        iter_vars_dims = []
        for iter_var in iter_vars:
            iter_vars_dims.append(len(iter_vars[iter_var]))
        it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])

        obj_list = []
        while not it.finished:
            sys.stdout.flush()
            print(it.multi_index)
            if not skip_missing:
                CS_obj = load_objects(list(it.multi_index), data_flag)
            else:
                # Best-effort load: missing files become None placeholders
                try:
                    CS_obj = load_objects(list(it.multi_index), data_flag)
                except (IOError, OSError):
                    print('Skipping item %s...' % list(it.multi_index))
                    CS_obj = None

            obj_list.append(CS_obj)
            it.iternext()

        save_aggregated_object_list(obj_list, data_flag)
def plot_signal_decoding_weber_law(data_flags, axes_to_plot=[0, 1],
                projected_variable_components=dict()):
    """
    Plot decoding accuracy, gain histograms, and sample signal
    estimates for alternating Weber-law / non-Weber-law runs.

    Args:
        data_flags: identifiers, alternating Weber law / non-Weber law.
        axes_to_plot: indices of the two iterated-variable axes to plot.
        projected_variable_components: fixed indices for any extra
            iterated-variable axes beyond the two plotted.
    """

    data_idxs = len(data_flags)
    assert data_idxs % 2 == 0, \
        "Need even number of command line arguments, alternating " \
        "Weber law and non-Weber law"

    # Ready the plotting window
    fig, ax = single_decoding_weber_law_plot(data_idxs=data_idxs)
    cmaps = [cm.Reds, cm.Blues]
    cmaps_r = [cm.Reds_r, cm.Blues_r]
    # Integer division: these arrays are indexed by data_idx // 2 below,
    # and linspace's num argument must be an int on Python 3
    color_cycle_errors = sp.linspace(0.25, 0.7, data_idxs // 2)
    linewidths_errors = sp.linspace(6.0, 3.0, data_idxs // 2)

    # Plot for each command line argument
    for data_idx, data_flag in enumerate(data_flags):

        data_flag = str(data_flag)
        # Explicit lookups instead of exec-based unpacking; exec into
        # function locals does not bind names in Python 3.
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']

        # list() needed: dict views are not indexable in Python 3
        iter_var_names = list(iter_vars.keys())
        iter_plot_var = iter_var_names[axes_to_plot[0]]
        x_axis_var = iter_var_names[axes_to_plot[1]]
        Nn = list_dict['params']['Nn']

        data = load_signal_decoding_weber_law(data_flag)
        successes = data['successes']
        epsilons = data['epsilons']
        gains = data['gains']

        # Project out extra axes when more than 2 variables were iterated
        if len(successes.shape) > 2:
            successes = project_tensor(successes,
                                       iter_vars,
                                       projected_variable_components,
                                       axes_to_plot)

        # Switch axes if necessary
        if axes_to_plot[0] > axes_to_plot[1]:
            successes = successes.T

        # Plot average errors; color/width cycle per pair of flags
        average_successes = sp.average(successes, axis=1) * 100.0
        cmap = cmaps[data_idx % 2]
        color = cmap(color_cycle_errors[data_idx // 2])
        linewidth = linewidths_errors[data_idx // 2]
        ax['successes'].plot(iter_vars[iter_plot_var], average_successes,
                             color=color, linewidth=linewidth)

        # Plot gains, averaged over odorant, and binned over different
        # signals and over receptors. Even flags (Weber law) and odd
        # flags (non-Weber) use different log-ranges and bin counts.
        if data_idx % 2 == 0:
            lo_bound, hi_bound, num_bins = -8.0, -1.0, 1000
        else:
            lo_bound, hi_bound, num_bins = -2.0, 1.0, 200
        for iter_var_idx in range(len(iter_vars[iter_plot_var])):
            if iter_var_idx == 0:
                bins = sp.linspace(lo_bound, hi_bound, num_bins)
                gains_data = sp.zeros((len(iter_vars[iter_plot_var]),
                                       len(bins) - 1))
            odorant_averaged_gains = sp.average(gains[iter_var_idx],
                                                axis=-1)
            hist, bin_edges = sp.histogram(
                sp.log(odorant_averaged_gains.flatten()) / sp.log(10),
                bins=bins)
            gains_data[iter_var_idx, :] = hist
        ax['gains_%s' % data_idx].pcolormesh(iter_vars[iter_plot_var],
                                             10.**bins, gains_data.T,
                                             cmap=cmaps_r[data_idx % 2],
                                             rasterized=True)

    save_signal_decoding_weber_law_fig(fig)

    # Plot estimated signals
    samples_to_plot = [0, 1]
    bkgrnds_to_plot = [40, 76]
    stimuli_to_plot = [1, 2]
    for data_idx, data_flag_idx in enumerate(samples_to_plot):
        data_flag = data_flags[data_flag_idx]
        # NOTE(review): iter_vars here is left over from the last loop
        # iteration above -- presumably all flags share the same
        # iterated variables; confirm against the spec files.
        iter_vars_dims = []
        for iter_var in iter_vars:
            iter_vars_dims.append(len(iter_vars[iter_var]))

        print('Loading object list...')
        CS_object_array = load_aggregated_object_list(iter_vars_dims,
                                                      data_flag)
        print('...loaded.')

        cmap = cmaps[data_idx % 2]
        color = cmap(color_cycle_errors[-1])

        for bkgrnd_idx, bkgrnd_val in enumerate(bkgrnds_to_plot):
            # True sparse signal as open bars; estimate as filled bars
            ax['est_%s' % bkgrnd_idx].bar(
                sp.arange(Nn), CS_object_array[bkgrnd_val,
                stimuli_to_plot[bkgrnd_idx]].dSs, edgecolor='black',
                zorder=100, width=1.0, lw=1.0, fill=False)
            ax['est_%s' % bkgrnd_idx].bar(
                sp.arange(Nn), CS_object_array[bkgrnd_val,
                stimuli_to_plot[bkgrnd_idx]].dSs_est, color=color,
                zorder=2 + data_idx, width=1.0)

        save_signal_decoding_weber_law_fig(fig)
# --- Ejemplo n.º 15 ---
def temporal_CS_run(data_flag,
                    iter_var_idxs,
                    mu_dSs_offset=0,
                    mu_dSs_multiplier=1. / 3.,
                    sigma_dSs_offset=0,
                    sigma_dSs_multiplier=1. / 9.,
                    signal_window=None,
                    save_data=True,
                    decode=True):
    """
    Run a CS decoding run for a full temporal signal trace.

    Data is read from a specifications file in the data_dir/specs/
    folder, with proper formatting given in read_specs_file.py. The
    specs file indicates the full range of the iterated variable; this
    script only produces output from one of those indices, so multiple
    runs can be performed in parallel.

    Args:
        data_flag: specs-file identifier for loading and saving.
        iter_var_idxs: index of the iterated variable(s) for this run.
        mu_dSs_offset: offset added to the signal-derived mu_dSs;
            must be >= 0.
        mu_dSs_multiplier: factor scaling the signal value into mu_dSs.
        sigma_dSs_offset: offset added to the signal-derived sigma_dSs;
            must be >= 0.
        sigma_dSs_multiplier: factor scaling the signal into sigma_dSs.
        signal_window: optional (start, end) index pair used to
            truncate the signal trace(s) before running.
        save_data: if True, dump the per-timestep object list to file.
        decode: if True, decode the signal estimate at each timestep.

    Returns:
        List of deep-copied model objects, one per timestep of the
        (possibly truncated) signal trace.
    """

    assert mu_dSs_offset >= 0, "mu_dSs_offset kwarg must be >= 0"
    assert sigma_dSs_offset >= 0, "sigma_dSs_offset kwarg must be >= 0"

    # Aggregate all run specifications from the specs file; instantiate model
    list_dict = read_specs_file(data_flag)
    vars_to_pass = compile_all_run_vars(list_dict, iter_var_idxs)
    obj = four_state_receptor_CS(**vars_to_pass)

    # Set the temporal signal array from file; truncate to signal window
    obj.set_signal_trace()
    assert sp.sum(obj.signal_trace <= 0) == 0, \
     "Signal contains negative values; increase signal_trace_offset"
    if signal_window is not None:
        obj.signal_trace_Tt = obj.signal_trace_Tt[signal_window[0]: \
                   signal_window[1]]
        obj.signal_trace = obj.signal_trace[signal_window[0]:signal_window[1]]

    # Load dual odor dSs from file (this is considered non-adapted fluctuation
    # and should have a shorter timescale than the first odor). Can also use
    # Kk_1 and Kk_2 for separate complexities of odor 1 and 2, respectively.
    if (obj.Kk_1 is not None) and (obj.Kk_2 is not None):
        obj.Kk = obj.Kk_1 + obj.Kk_2
        obj.Kk_split = obj.Kk_2

    if (obj.Kk_split is not None) and (obj.Kk_split != 0):
        # Dual-odor runs require a second signal trace to be assigned
        try:
            obj.signal_trace_2
        except AttributeError:
            print('Need to assign signal_trace_2 if setting Kk_split or ' \
              'Kk_1 and Kk_2 nonzero')
            quit()
        assert sp.sum(obj.signal_trace_2 <= 0) == 0, \
          "Signal_2 contains neg values; increase signal_trace_offset_2"
        if signal_window is not None:
            obj.signal_trace_2 = obj.signal_trace_2[signal_window[0]: \
                      signal_window[1]]

    obj_list = []
    for iT, dt in enumerate(obj.signal_trace_Tt):
        # Progress indicator (stays on one line)
        print('%s/%s' % (iT + 1, len(obj.signal_trace)), end=' ')
        sys.stdout.flush()

        # Set mu_Ss0 from signal trace, if desired
        if obj.set_mu_Ss0_temporal_signal == True:
            obj.mu_Ss0 = obj.signal_trace[iT]

        # Set estimation dSs values from signal trace and kwargs
        signal = obj.signal_trace[iT]
        obj.mu_dSs = mu_dSs_offset + signal * mu_dSs_multiplier
        obj.sigma_dSs = sigma_dSs_offset + signal * sigma_dSs_multiplier

        # Set estimation dSs values for dual odor if needed
        if (obj.Kk_split is not None) and (obj.Kk_split != 0):
            signal_2 = obj.signal_trace_2[iT]
            obj.mu_dSs_2 = mu_dSs_offset + signal_2 * mu_dSs_multiplier
            obj.sigma_dSs_2 = sigma_dSs_offset + signal_2 * sigma_dSs_multiplier

        # Encode / decode fully first time; then just update eps and responses
        if iT == 0:
            obj = single_encode_CS(obj, list_dict['run_specs'])

            # Spread adaptation rates over the system
            if obj.temporal_adaptation_rate_sigma != 0:
                obj.set_ordered_temporal_adaptation_rate()
        else:
            obj.set_sparse_signals()
            obj.set_temporal_adapted_epsilon()
            obj.set_measured_activity()
            obj.set_linearized_response()

        # Estimate signal at point iT
        if decode == True:
            obj.decode()

        # Deep copy to take all aspects of the object but not update it
        obj_list.append(copy.deepcopy(obj))

    if save_data == True:
        dump_objects(obj_list, iter_var_idxs, data_flag)

    return obj_list
def calculate_signal_decoding_weber_law(data_flags,
                                        nonzero_bounds=[0.75, 1.25],
                                        zero_bound=1. / 10.,
                                        threshold_pct_nonzero=100.0,
                                        threshold_pct_zero=100.0):
    """
    Calculate binary decoding errors and success rates for aggregated CS
    objects, then save the results for each run.

    Args:
        data_flags: iterable of run identifiers to load and process.
        nonzero_bounds: [lower, upper] relative bounds within which an
            estimated nonzero signal component counts as correct.
        zero_bound: bound below which an estimated zero component counts
            as correct.
        threshold_pct_nonzero: percentage of nonzero components that must
            be correct for a decoding to count as a success.
        threshold_pct_zero: percentage of zero components that must be
            correct for a decoding to count as a success.
    """

    for data_flag in data_flags:

        # Read only the spec entries used below, instead of exec-injecting
        # every key of list_dict into the local namespace (opaque and
        # unsafe; also broke static analysis).
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        params = list_dict['params']

        iter_vars_dims = [len(iter_vars[iter_var]) for iter_var in iter_vars]

        print('Loading object list...')
        CS_object_array = load_aggregated_object_list(iter_vars_dims,
                                                      data_flag)
        print('...loaded.')

        # Data structures: per-index error percentages and output arrays
        data = dict()
        errors_nonzero = sp.zeros(iter_vars_dims)
        errors_zero = sp.zeros(iter_vars_dims)

        Mm_shape = iter_vars_dims + [params['Mm']]
        Mm_Nn_shape = iter_vars_dims + [params['Mm'], params['Nn']]
        data['epsilons'] = sp.zeros(Mm_shape)
        data['dYys'] = sp.zeros(Mm_shape)
        data['Yys'] = sp.zeros(Mm_shape)
        data['gains'] = sp.zeros(Mm_Nn_shape)
        data['Kk2s'] = sp.zeros(Mm_Nn_shape)
        data['successes'] = sp.zeros(iter_vars_dims)

        # Calculate binary errors at every iterated-variable index
        it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
        while not it.finished:
            errors = binary_errors(CS_object_array[it.multi_index],
                                   nonzero_bounds=nonzero_bounds,
                                   zero_bound=zero_bound)

            errors_nonzero[it.multi_index] = errors['errors_nonzero']
            errors_zero[it.multi_index] = errors['errors_zero']
            it.iternext()

        # Calculate success ratios from binary errors; also collect
        # per-object epsilons, gains, activities and Kk2 matrices.
        it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])
        while not it.finished:
            data['successes'][it.multi_index] = binary_success(
                errors_nonzero[it.multi_index],
                errors_zero[it.multi_index],
                threshold_pct_nonzero=threshold_pct_nonzero,
                threshold_pct_zero=threshold_pct_zero)
            data['epsilons'][it.multi_index] = CS_object_array[
                it.multi_index].eps
            data['gains'][it.multi_index] = CS_object_array[it.multi_index].Rr
            data['dYys'][it.multi_index] = CS_object_array[it.multi_index].dYy
            data['Yys'][it.multi_index] = CS_object_array[it.multi_index].Yy
            data['Kk2s'][it.multi_index] = CS_object_array[it.multi_index].Kk2
            it.iternext()

        save_signal_decoding_weber_law(data, data_flag)
def temporal_entropy_run(data_flag, iter_var_idxs, 
					mu_dSs_offset=0, mu_dSs_multiplier=1./3., 
					sigma_dSs_offset=0, sigma_dSs_multiplier=1./9., 
					signal_window=None, save_data=True):
	"""
	Run the response-entropy model over a temporal signal trace.

	At each timepoint the fluctuation statistics (mu_dSs, sigma_dSs) are
	set as affine functions of the instantaneous signal value, the free
	energy is set (first step) or temporally adapted (later steps), and
	the mutual information is calculated. A deep copy of the model object
	is stored per timepoint.

	Args:
		data_flag: specs-file identifier for this run.
		iter_var_idxs: indices of iterated variables for this job.
		mu_dSs_offset: constant offset for mu_dSs; must be >= 0.
		mu_dSs_multiplier: slope from signal value to mu_dSs.
		sigma_dSs_offset: constant offset for sigma_dSs; must be >= 0.
		sigma_dSs_multiplier: slope from signal value to sigma_dSs.
		signal_window: optional (start, end) index pair used to truncate
			the signal trace; None keeps the full trace.
		save_data: if True, dump the object list via dump_objects.

	Returns:
		List of deep-copied response_entropy objects, one per timepoint.
	"""
	
	assert mu_dSs_offset >= 0, "mu_dSs_offset kwarg must be >= 0"
	assert sigma_dSs_offset >= 0, "sigma_dSs_offset kwarg must be >= 0"
	
	# Aggregate all run specifications from the specs file; instantiate model
	list_dict = read_specs_file(data_flag)
	if 'run_type' in list_dict['run_specs'].keys():
		print ('!!\n\nrun_spec %s passed in specs file. run_specs are not '
				'accepted for temporal entropy calculations at this time. '
				'Ignoring...\n\n!!\n' % list_dict['run_specs']['run_type'])
	vars_to_pass = compile_all_run_vars(list_dict, iter_var_idxs)
	obj = response_entropy(**vars_to_pass)
	obj.encode_power_Kk()
	
	# Set the temporal signal array from file; truncate to signal window
	obj.set_signal_trace()
	
	assert sp.sum(obj.signal_trace <= 0) == 0, \
		"Signal contains negative values; increase signal_trace_offset"
	if signal_window is not None:
		obj.signal_trace_Tt = obj.signal_trace_Tt[signal_window[0]: \
													signal_window[1]]
		obj.signal_trace = obj.signal_trace[signal_window[0]: signal_window[1]]
	
	# Load dual odor dSs from file (this is considered non-adapted fluctuation
	# and should have a shorter timescale than the first odor). Can also use
	# Kk_1 and Kk_2 for separate complexities of odor 1 and 2, respectively.
	if (obj.Kk_1 is not None) and (obj.Kk_2 is not None):
		obj.Kk = obj.Kk_1 + obj.Kk_2
		obj.Kk_split = obj.Kk_2
	
	# A split (dual) odor requires its own signal trace to be assigned.
	if (obj.Kk_split is not None) and (obj.Kk_split != 0):
		try: 
			obj.signal_trace_2
		except AttributeError:
			print('Need to assign signal_trace_2 if setting Kk_split or ' \
					'Kk_1 and Kk_2 nonzero') 
			quit()
		assert sp.sum(obj.signal_trace_2 <= 0) == 0, \
				"Signal_2 contains neg values; increase signal_trace_offset_2"
		if signal_window is not None:
			obj.signal_trace_2 = obj.signal_trace_2[signal_window[0]: \
													signal_window[1]]
	
	obj_list = []
	
	for iT, dt in enumerate(obj.signal_trace_Tt):
		# Progress indicator: timepoint counter on a single line
		print('%s/%s' % (iT + 1, len(obj.signal_trace)), end=' ')
		sys.stdout.flush()
		
		# Set mu_Ss0 from signal trace, if desired
		if obj.set_mu_Ss0_temporal_signal == True:
			obj.mu_Ss0 = obj.signal_trace[iT]
		
		# Set estimation dSs values from signal trace and kwargs
		signal = obj.signal_trace[iT]
		obj.mu_dSs = mu_dSs_offset + signal*mu_dSs_multiplier
		obj.sigma_dSs = sigma_dSs_offset + signal*sigma_dSs_multiplier
		
		# Set estimation dSs values for dual odor if needed
		if (obj.Kk_split is not None) and (obj.Kk_split != 0):
			signal_2 = obj.signal_trace_2[iT]
			obj.mu_dSs_2 = mu_dSs_offset + signal_2*mu_dSs_multiplier
			obj.sigma_dSs_2 = sigma_dSs_offset + signal_2*sigma_dSs_multiplier
		
		# Set the full signal array from the above signal parameters
		obj.set_ordered_dual_signal_array()
			
		# At first step, set energy; from then on it is dynamical.
		
		if iT == 0:
			obj.set_normal_free_energy()
			
			# Spread adaptation rates over the system
			if obj.temporal_adaptation_rate_sigma != 0:
				obj.set_ordered_temporal_adaptation_rate()
		else:
			obj.set_temporal_adapted_epsilon()
		
		# Calculate MI
		obj.set_mean_response_array()
		obj.set_ordered_dual_response_pdf()
		obj.calc_MI_fore_only()
		
		# Report mean entropy at this timepoint
		print (sp.mean(obj.entropy))
		
		# Deep copy to take all aspects of the object but not update it
		obj_list.append(copy.deepcopy(obj))
	
	if save_data == True:
		dump_objects(obj_list, iter_var_idxs, data_flag)
	
	return obj_list
# Ejemplo n.º 18
# 0
def plot_signal_estimation_weber_law(data_flags, axes_to_plot=[0, 1], 
				projected_variable_components=dict()):
    """
    Plot true versus estimated signals for a non-Weber-law / Weber-law
    pair of runs, for both nonzero and zero signal components.

    Args:
        data_flags: exactly two run identifiers, [non-Weber law,
            Weber law].
        axes_to_plot: unused here; kept for interface uniformity with
            the other plotting entry points.
        projected_variable_components: unused here; kept for interface
            uniformity with the other plotting entry points.
    """

    # Define the plot indices
    assert len(data_flags) == 2, \
        "Need command line arguments to be two, alternating " \
        "non-Weber law and Weber law."

    # Colormaps and per-run plotting parameters; blue = non-adapted,
    # red = adapted (Weber law).
    cmaps = [cm.Blues, cm.Reds]
    alphas = [1.0, 0.8]
    mu_dSs_to_plot = sp.arange(0, 100, 5)
    seed_Kk2_to_plot = sp.arange(0, 100, 5)
    true_signal_lw = 1.0
    true_signal_color = 'black'
    CS_object_array = []

    # Load both object arrays only once
    for Weber_idx, data_flag in enumerate(data_flags):

        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        Nn = list_dict['params']['Nn']
        Kk = list_dict['params']['Kk']

        # NOTE(review): return value is unused; kept in case loading has
        # needed side effects -- confirm and remove if not.
        data = load_signal_decoding_weber_law(data_flag)

        # Load CS objects for single stimuli plotting
        iter_vars_dims = [len(iter_vars[iter_var]) for iter_var in iter_vars]
        print('Loading object list for single stimulus plot...')
        CS_object_array.append(load_aggregated_object_list(iter_vars_dims,
                                                           data_flag))
        print('...loaded.')

    # Nonzero components: bar plots of the Kk strongest true components
    # (outlined) with their estimates (filled), sign-flipped per run.
    for dSs_idx in mu_dSs_to_plot:
        for Kk2_idx in seed_Kk2_to_plot:

            # Ready the plotting window; colormaps; colors; signals to plot
            fig = signal_estimation_subfigures(nonzero=True)
            for Weber_idx, data_flag in enumerate(data_flags):

                # Blue for non-adapted; red for adapted
                color = cmaps[Weber_idx](0.6)

                # Plot the bar graphs for true signal and estimates, top
                # and bottom; order by signal strength
                true_signal = CS_object_array[Weber_idx][dSs_idx, Kk2_idx].dSs
                est_signal = CS_object_array[Weber_idx][dSs_idx, Kk2_idx].dSs_est
                sorted_idxs = sp.argsort(true_signal)[::-1]
                plt.bar(sp.arange(Kk),
                        true_signal[sorted_idxs][:Kk] * (-1)**Weber_idx,
                        lw=true_signal_lw, edgecolor=true_signal_color,
                        zorder=100, width=1.0, fill=False)
                plt.bar(sp.arange(Kk),
                        est_signal[sorted_idxs][:Kk] * (-1)**Weber_idx,
                        color=color, zorder=2, width=1.0)

            for data_flag in data_flags:
                save_signal_estimation_nonzeros_fig(fig, data_flag, dSs_idx,
                                                    Kk2_idx)
            plt.close()

    # Zero components: estimated values of the true-zero components
    # (everything beyond the top Kk), sorted and filled.
    for dSs_idx in mu_dSs_to_plot:
        for Kk2_idx in seed_Kk2_to_plot:

            # Ready the plotting window; colormaps; colors; signals to plot
            fig = signal_estimation_subfigures(nonzero=False)
            for Weber_idx, data_flag in enumerate(data_flags):

                # Blue for non-adapted; red for adapted
                color = cmaps[Weber_idx](0.6)
                alpha = alphas[Weber_idx]

                # Ordering here is actually not necessary
                true_signal = CS_object_array[Weber_idx][dSs_idx, Kk2_idx].dSs
                est_signal = CS_object_array[Weber_idx][dSs_idx, Kk2_idx].dSs_est
                sorted_idxs = sp.argsort(true_signal)[::-1]
                # Components beyond the top Kk are the true-zero ones
                # (was misleadingly named est_signal_nonzeros).
                est_signal_zeros = est_signal[sorted_idxs][Kk:]
                est_signal_sorted_idxs = sp.argsort(est_signal_zeros)
                plt.fill_between(sp.arange(Nn - Kk), 0,
                                 est_signal_zeros[est_signal_sorted_idxs],
                                 color=color, alpha=alpha)

            for data_flag in data_flags:
                save_signal_estimation_zeros_fig(fig, data_flag, dSs_idx,
                                                 Kk2_idx)
            plt.close()
# Ejemplo n.º 19
# 0
def plot_signal_discrimination_weber_law(data_flags,
                                         axes_to_plot=[0, 1],
                                         projected_variable_components=dict()):
    """
    Plot dual-odor discrimination accuracy and example signal estimates
    for paired Weber-law / non-Weber-law runs.

    Args:
        data_flags: run identifiers of length Kk_split*2, alternating
            Weber law and non-Weber law.
        axes_to_plot: indices of the two iterated variables used as the
            plotted axis and the averaged axis, respectively.
        projected_variable_components: fixed component choices for any
            extra iterated dimensions, passed to project_tensor.
    """

    # Function to plot signal and inset; odor 2 is overlaid in darker color.
    def signal_plot(ax):

        # True signal (outlined bars); the sign flip puts the two Weber
        # cases on opposite sides of the axis.
        ax.bar(sp.arange(Nn),
               CS_object_array[mu_dSs_to_plot,
                               seed_Kk2_to_plot[Kk_split_idx]].dSs *
               (-1)**Weber_idx,
               lw=true_signal_lw,
               edgecolor=true_signal_color,
               zorder=100,
               width=1.0,
               fill=False)
        # Estimated signal (filled bars)
        ax.bar(sp.arange(Nn),
               CS_object_array[mu_dSs_to_plot,
                               seed_Kk2_to_plot[Kk_split_idx]].dSs_est *
               (-1)**Weber_idx,
               color=cmap(dual_odor_color_shades[0]),
               zorder=2,
               width=1.0)

        # Generate just odor 2 signal and plot over first plot
        signal_2 = sp.zeros(Nn)
        idxs_2 = CS_object_array[mu_dSs_to_plot,
                                 seed_Kk2_to_plot[Kk_split_idx]].idxs_2
        for idx_2 in idxs_2:
            signal_2[idx_2] = CS_object_array[
                mu_dSs_to_plot, seed_Kk2_to_plot[Kk_split_idx]].dSs_est[idx_2]
        ax.bar(sp.arange(Nn),
               signal_2 * (-1)**Weber_idx,
               color=cmap(dual_odor_color_shades[1]),
               zorder=3,
               width=1.0)

        return ax

    # Define the plot indices. Integer division is required: under
    # Python 3, `/` yields a float, which would have produced axis keys
    # like 'successes_0.5' below.
    Kk_split_idxs = len(data_flags) // 2
    Weber_idxs = 2
    assert len(data_flags) % 2 == 0, \
     "Need command line arguments to be Kk_split*2, alternating " \
     "Weber law and non-Weber law."

    # Ready the plotting window; colormaps; colors; signals to plot
    fig, ax = signal_discrimination_weber_law_plot(Kk_split_idxs=Kk_split_idxs)
    cmaps = [cm.Reds, cm.Blues]
    cmaps_r = [cm.Reds_r, cm.Blues_r]
    dual_odor_color_shades = [0.7, 0.3]
    success_plots_linewidths = [4.0, 6.0]
    true_signal_color = 'black'
    true_signal_lw = 1.0
    mu_dSs_to_plot = 27
    seed_Kk2_to_plot = [9, 64, 25]

    # Plot
    for data_idx, data_flag in enumerate(data_flags):

        # Integer division so Kk_split_idx indexes the axis dict cleanly
        Weber_idx = data_idx % Weber_idxs
        Kk_split_idx = data_idx // Weber_idxs
        cmap = cmaps[Weber_idx]

        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        Nn = list_dict['params']['Nn']
        # dict views are not indexable in Python 3; materialize first
        iter_plot_var = list(iter_vars)[axes_to_plot[0]]

        data = load_signal_discrimination_weber_law(data_flag)
        successes = data['successes']
        successes_2 = data['successes_2']

        # Project away any extra iterated dimensions
        nAxes = len(successes.shape)
        if nAxes > 2:
            successes = project_tensor(successes, iter_vars,
                                       projected_variable_components,
                                       axes_to_plot)
            successes_2 = project_tensor(successes_2, iter_vars,
                                         projected_variable_components,
                                         axes_to_plot)

        # Switch axes if necessary
        if axes_to_plot[0] > axes_to_plot[1]:
            successes = successes.T
            successes_2 = successes_2.T

        # Plot successes, averaged over second axis of successes array
        avg_successes = sp.average(successes, axis=1) * 100.0
        avg_successes_2 = sp.average(successes_2, axis=1) * 100.0
        ax['successes_%s' % Kk_split_idx].plot(iter_vars[iter_plot_var],
                                               avg_successes,
                                               color=cmap(
                                                   dual_odor_color_shades[0]),
                                               zorder=2,
                                               lw=success_plots_linewidths[0])
        ax['successes_%s' % Kk_split_idx].plot(iter_vars[iter_plot_var],
                                               avg_successes_2,
                                               color=cmap(
                                                   dual_odor_color_shades[1]),
                                               zorder=1,
                                               lw=success_plots_linewidths[1])

        # Load CS objects for single stimuli plotting
        iter_vars_dims = [len(iter_vars[iter_var]) for iter_var in iter_vars]
        print('Loading object list for single stimulus plot...')
        CS_object_array = load_aggregated_object_list(iter_vars_dims,
                                                      data_flag)
        print('...loaded.')

        # Plot signal and inset
        ax['signal_%s' % Kk_split_idx] = \
         signal_plot(ax['signal_%s' % Kk_split_idx])
        ax['signal_insert_%s' % Kk_split_idx] = \
         signal_plot(ax['signal_insert_%s' % Kk_split_idx])
        if Weber_idx == 1:
            mark_inset(ax['signal_%s' % Kk_split_idx],
                       ax['signal_insert_%s' % Kk_split_idx],
                       loc1=3,
                       loc2=4,
                       fc="none",
                       ec="0.5")

        save_signal_discrimination_weber_law_fig(fig)
def plot_signal_decoding_weber_law(data_flags,
                                   axes_to_plot=[0, 1],
                                   projected_variable_components=dict()):
    """
    Plot decoding accuracy curves for paired Weber-law / non-Weber-law
    runs at several diversity levels, plus sorted Kk2 matrices.

    Args:
        data_flags: run identifiers of length diversity_idxs*2,
            alternating Weber law and non-Weber law.
        axes_to_plot: indices of the two iterated variables used as the
            plotted axis and the averaged axis, respectively.
        projected_variable_components: fixed component choices for any
            extra iterated dimensions, passed to project_tensor.
    """

    # Define the plot indices. Integer division is required: under
    # Python 3, `/` yields a float, and range(float) / linspace with a
    # float count both raise TypeError.
    diversity_idxs = len(data_flags) // 2
    assert len(data_flags) % 2 == 0, \
     "Need command line arguments to be diversity_idxs*2, alternating " \
     "Weber law and non-Weber law."

    # Ready the plotting window; colormaps; colors; signals to plot
    cmaps = [cm.Reds, cm.Blues]
    shades = sp.linspace(0.7, 0.3, diversity_idxs)
    success_plot_lws = sp.linspace(4.0, 3.0, diversity_idxs)

    # Decoding accuracy subfigures
    fig = decoding_accuracy_subfigures()

    # Plot success error figures
    for diversity_idx in range(diversity_idxs):

        shade = shades[diversity_idx]
        lw = success_plot_lws[diversity_idx]

        for Weber_idx in range(2):

            data_flag_idx = Weber_idx + diversity_idx * 2
            data_flag = data_flags[data_flag_idx]

            # Blue for non-adapted; red for adapted
            cmap = cmaps[Weber_idx]

            list_dict = read_specs_file(data_flag)
            iter_vars = list_dict['iter_vars']
            # dict views are not indexable in Python 3; materialize first
            iter_plot_var = list(iter_vars)[axes_to_plot[0]]

            data = load_signal_decoding_weber_law(data_flag)
            successes = data['successes']

            # Project away any extra iterated dimensions
            nAxes = len(successes.shape)
            if nAxes > 2:
                successes = project_tensor(successes, iter_vars,
                                           projected_variable_components,
                                           axes_to_plot)

            # Switch axes if necessary
            if axes_to_plot[0] > axes_to_plot[1]:
                successes = successes.T

            # Plot successes, averaged over second axis of successes array
            avg_successes = sp.average(successes, axis=1) * 100.0
            plt.plot(iter_vars[iter_plot_var],
                     avg_successes,
                     color=cmap(shade),
                     zorder=diversity_idx,
                     lw=lw)

        # Save same plot in both Weber Law and non-Weber Law folders
        for Weber_idx in range(2):
            data_flag = data_flags[Weber_idx + diversity_idx * 2]
            save_decoding_accuracy_fig(fig, data_flag)
        plt.close()

    # Plot Kk2 of index [0, 0], sorted by row mean
    for data_flag in data_flags:

        data = load_signal_decoding_weber_law(data_flag)
        Kk2s = data['Kk2s']
        reshape_idxs = sp.hstack((-1, Kk2s.shape[-2:]))
        Kk2 = Kk2s.reshape(reshape_idxs)[0]

        means = sp.average(Kk2, axis=1)
        sorted_idxs = sp.argsort(means)
        sorted_Kk2 = Kk2[sorted_idxs, :]

        # Log10-scale heatmap of the sorted Kk2 matrix
        fig = Kk2_subfigures()
        plt.imshow(sp.log(sorted_Kk2.T) / sp.log(10),
                   interpolation='nearest',
                   cmap=plt.cm.inferno,
                   vmin=-1.51,
                   vmax=0.01)
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=14)
        save_Kk2_fig(fig, data_flag)
def pred_plot(data_flag):
    """
    Plot the optimal predicted and estimated state paths against the
    measured data for a single-cell FRET run, then save the plots and
    the underlying data to text files.

    Args:
        data_flag: identifier used to load specs, stimulus, measurement
            and prediction data, and to name saved output.
    """

    # Load specs file data and object
    list_dict = read_specs_file(data_flag)
    vars_to_pass = compile_all_run_vars(list_dict)
    scF = single_cell_FRET(**vars_to_pass)

    # If stim and meas were not imported, then data was saved as data_flag
    if scF.stim_file is None:
        scF.stim_file = data_flag
    if scF.meas_file is None:
        scF.meas_file = data_flag
    scF.set_stim()
    scF.set_meas_data()

    # Initialize estimation; set the estimation and prediction windows
    scF.set_est_pred_windows()

    # Load all of the prediction data; pick the initial condition with
    # the smallest (non-NaN) error as the optimum.
    pred_dict = load_pred_data(data_flag)
    opt_IC = sp.nanargmin(pred_dict['errors'])
    opt_pred_path = pred_dict['pred_path'][:, :, opt_IC]
    est_path = pred_dict['est_path'][:, :, opt_IC]
    est_params = pred_dict['params'][:, opt_IC]
    est_range = scF.est_wind_idxs
    pred_range = scF.pred_wind_idxs
    full_range = sp.arange(scF.est_wind_idxs[0], scF.pred_wind_idxs[-1])
    est_Tt = scF.Tt[est_range]
    pred_Tt = scF.Tt[pred_range]
    full_Tt = scF.Tt[full_range]

    # Load true data if using synthetic data. Best-effort: true states
    # only exist for synthetic runs, so a failed load is expected; the
    # previous bare `except:` also swallowed KeyboardInterrupt etc.
    true_states = None
    try:
        true_states = load_true_file(data_flag)[:, 1:]
    except Exception:
        pass

    num_plots = scF.nD + 1

    # Plot the stimulus
    plt.subplot(num_plots, 1, 1)
    plt.plot(full_Tt, scF.stim[full_range], color='r', lw=2)
    plt.xlim(full_Tt[0], full_Tt[-1])
    plt.ylim(80, 160)

    # Plot the estimates; one subplot per state dimension
    iL_idx = 0
    for iD in range(scF.nD):
        plt.subplot(num_plots, 1, iD + 2)
        plt.xlim(full_Tt[0], full_Tt[-1])

        if iD in scF.L_idxs:

            # Plot measured data
            plt.plot(est_Tt,
                     scF.meas_data[scF.est_wind_idxs, iL_idx],
                     color='g')
            plt.plot(pred_Tt,
                     scF.meas_data[scF.pred_wind_idxs, iL_idx],
                     color='g')

            # Plot estimation and prediction
            plt.plot(est_Tt, est_path[:, iD], color='r', lw=3)
            plt.plot(pred_Tt, opt_pred_path[:, iD], color='r', lw=3)

            # Plot true states if this uses fake data
            if true_states is not None:
                plt.plot(scF.Tt, true_states[:, iD], color='k')

            iL_idx += 1
        else:
            # Unmeasured dimension: only estimates/predictions (+ truth)
            plt.plot(est_Tt, est_path[:, iD], color='r', lw=3)
            plt.plot(pred_Tt, opt_pred_path[:, iD], color='r', lw=3)
            if true_states is not None:
                plt.plot(scF.Tt, true_states[:, iD], color='k')
    save_opt_pred_plots(data_flag)
    plt.show()

    # Save all the optimal predictions, measurement and stimuli to txt files
    stim_to_save = sp.vstack((full_Tt.T, scF.stim[full_range].T)).T
    meas_to_save = sp.vstack((full_Tt.T, scF.meas_data[full_range].T)).T
    est_to_save = sp.vstack((est_Tt.T, est_path.T)).T
    pred_to_save = sp.vstack((pred_Tt.T, opt_pred_path.T)).T
    params_to_save = sp.vstack((scF.model.param_names, est_params)).T
    save_opt_pred_data(data_flag, stim_to_save, meas_to_save, est_to_save,
                       pred_to_save, params_to_save)
def aggregate_temporal_entropy_objects(data_flags):
    """
    Aggregate CS objects from separate .pklz files of temporal runs to a
    single .pklz object.

    Args:
        data_flags: Identifiers for saving and loading; a single string
            is accepted and treated as a one-element list.
    """

    temporal_structs_to_save = ['entropy']

    if isinstance(data_flags, str):
        data_flags = [data_flags]

    for data_flag in data_flags:
        list_dict = read_specs_file(data_flag)
        iter_vars = list_dict['iter_vars']
        iter_vars_dims = [len(iter_vars[iter_var]) for iter_var in iter_vars]
        it = sp.nditer(sp.zeros(iter_vars_dims), flags=['multi_index'])

        CS_init_array = load_objects(list(it.multi_index), data_flag)

        # Dictionary to save all object at time 0; this will contain all
        # non-temporal info for each iterated variable.
        data = dict()
        data['init_objs'] = []
        nT = len(CS_init_array[0].signal_trace_Tt)

        # Allocate data structures of appropriate shape for each temporal
        # variable. getattr replaces the previous exec-based attribute
        # lookup, and only attributes actually present on the CS object
        # are kept (previously a missing attribute was skipped here but
        # still accessed in the aggregation loop below, which would have
        # raised).
        structs_to_save = []
        for struct_name in temporal_structs_to_save:
            try:
                template = getattr(CS_init_array[0], struct_name)
            except AttributeError:
                print('%s not an attribute of the CS object' % struct_name)
                continue
            structs_to_save.append(struct_name)

            # shape is (num timesteps, iterated var ranges, variable shape);
            # if a float or integer, shape is just time and iter vars.
            struct_shape = (nT, ) + tuple(iter_vars_dims)
            if hasattr(template, 'shape'):
                struct_shape += template.shape
            data[struct_name] = sp.zeros(struct_shape)

        # Iterate over all objects to be aggregated
        while not it.finished:

            print('Loading index:', it.multi_index)
            temporal_CS_array = load_objects(list(it.multi_index), data_flag)

            # Save full object at time 0, contains non-temporal data.
            data['init_objs'].append(temporal_CS_array[0])

            # Grab all the temporal structures, timepoint-by-timepoint
            for iT in range(nT):
                full_idx = (iT, ) + it.multi_index
                for struct_name in structs_to_save:
                    data[struct_name][full_idx] = \
                        getattr(temporal_CS_array[iT], struct_name)

            it.iternext()

        save_aggregated_temporal_objects(data, data_flag)