def compute_optimal_rcscale(cls, M, R, population_code_type='conjunctive'):
   '''
       Compute the optimal rcscale, depending on the type of code we use.
   '''
   if population_code_type == 'conjunctive':
     # We use the optimum heuristic for the rc_scale: try to cover the space fully, assuming uniform coverage with squares of size 2*(2*utils.kappa_to_stddev(kappa)). We assume that 2*stddev gives a good approximation to the appropriate coverage required.
     return utils.stddev_to_kappa(2. * np.pi / int(M**(1. / R)))
   elif population_code_type == 'feature':
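      # For a feature code, each of the R feature dimensions is tiled by its own block of M/R units, so aim for squares of size 2*pi/(M/R).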
     return utils.stddev_to_kappa(2. * np.pi / int(M / R))
  def create_full_features(cls,
                           M,
                           R=2,
                           scale=0.3,
                           ratio=40.,
                           autoset_parameters=False,
                           response_maxout=False):
     '''
         Create a RandomFactorialNetwork instance, using a pure feature code
     '''
    print "create feature network, R=%d, M=%d, autoset: %d" % (
        R, M, autoset_parameters)

    rn = HighDimensionNetwork(M, R=R, response_maxout=response_maxout)

    if autoset_parameters:
      # Use optimal values for the parameters. Careful: this assumes M/R units per dimension (M/2 for R=2) and coverage of the full 2*pi space.
      # One direction should cover a width of pi, the other should be scaled so that its M/R units tile the full 2*pi space.
      # width -> kappa conversion via utils.stddev_to_kappa(stddev)

      scale = utils.stddev_to_kappa(np.pi)
      scale2 = cls.compute_optimal_rcscale(
          M, R, population_code_type='feature')
      ratio = scale2 / scale
    else:
      if ratio < 0.0:
        # Setting ratio < 0 triggers semi-automatic parameter setting:
        # assume only one scale is really desired, and set the other automatically.
        scale_fixed = utils.stddev_to_kappa(np.pi)
        ratio = np.max((scale / scale_fixed, scale_fixed / scale))

        print "Semi auto ratio: %f %f %f" % (scale, scale_fixed, ratio)

    M_sub = M / rn.R
    # Assign centers
    rn.assign_prefered_stimuli(tiling_type='features', reset=True)

    # Now assign scales
    resetted = False
    for r in xrange(rn.R):
      rn.assign_aligned_eigenvectors(
          scale=scale,
          ratio=ratio,
          scaled_dimension=r,
          specific_neurons=np.arange(r * M_sub, (r + 1) * M_sub),
          reset=not resetted)

      resetted = True

    rn.population_code_type = 'feature'

    return rn
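
# Hedged usage sketch (not part of the original code), assuming the classmethods above
# are defined on HighDimensionNetwork, which is the class instantiated inside them:
#   feature_net = HighDimensionNetwork.create_full_features(M=100, R=2,
#                                                           autoset_parameters=True)
#   print feature_net.population_code_type   # -> 'feature'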
    def create_mixed(cls, M, R=2, ratio_feature_conjunctive=0.5, conjunctive_parameters=None, feature_parameters=None, autoset_parameters=False, response_maxout=False):
        '''
            Create a RandomFactorialNetwork instance, using a mixed conjunctive/feature code
        '''
        print "Create mixed network, R=%d autoset: %d" % (R, autoset_parameters)

        conj_scale = 1.0
        feat_scale = 0.3
        feat_ratio = 40.0

        if conjunctive_parameters is not None:
            # Heavily refactored, but keeps compatibility...
            conj_scale = conjunctive_parameters['scale']

        if feature_parameters is not None:
            feat_scale = feature_parameters['scale']
            feat_ratio = feature_parameters['ratio']

            # nb_feature_centers = feature_parameters.get('nb_feature_centers', 1)

        rn = HighDimensionNetwork(M, R=R, response_maxout=response_maxout)

        rn.conj_subpop_size = int(M*ratio_feature_conjunctive)
        rn.feat_subpop_size = M - rn.conj_subpop_size

        if autoset_parameters:
            # Use optimal values for the parameters. Careful: this assumes coverage of the full 2*pi space.
            # For the feature subpopulation, one direction covers a width of pi while the other is scaled so that its units tile the 2*pi space.
            # width -> kappa conversion via utils.stddev_to_kappa(stddev)
            if rn.conj_subpop_size > 0:
                conj_scale = cls.compute_optimal_rcscale(rn.conj_subpop_size, R, population_code_type='conjunctive')
            if rn.feat_subpop_size > 0:
                feat_scale = utils.stddev_to_kappa(np.pi)
                feat_ratio = cls.compute_optimal_rcscale(rn.feat_subpop_size, R, population_code_type='feature')/feat_scale

        print "Population sizes: ratio: %.1f conj: %d, feat: %d, autoset: %d" % (ratio_feature_conjunctive, rn.conj_subpop_size, rn.feat_subpop_size, autoset_parameters)

        # Create the conjunctive subpopulation
        rn.assign_prefered_stimuli(tiling_type='conjunctive', reset=True, specific_neurons=np.arange(rn.conj_subpop_size))
        rn.assign_aligned_eigenvectors(scale=conj_scale, ratio=1.0, specific_neurons=np.arange(rn.conj_subpop_size), reset=True)

        # Create the feature subpopulation
        # Assign centers
        rn.assign_prefered_stimuli(tiling_type='features', specific_neurons=np.arange(rn.conj_subpop_size, M))
        # Now assign scales
        feat_sub_M = rn.feat_subpop_size/rn.R
        for r in xrange(rn.R):
            rn.assign_aligned_eigenvectors(scale=feat_scale, ratio=feat_ratio, scaled_dimension=r, specific_neurons=np.arange(rn.conj_subpop_size + r*feat_sub_M, rn.conj_subpop_size + (r+1)*feat_sub_M))

        rn.population_code_type = 'mixed'
        rn.ratio_conj = ratio_feature_conjunctive

        return rn
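
# Hedged usage sketch (not part of the original code), assuming create_mixed is a
# classmethod of HighDimensionNetwork:
#   mixed_net = HighDimensionNetwork.create_mixed(M=200, R=2,
#                                                 ratio_feature_conjunctive=0.5,
#                                                 autoset_parameters=True)
#   print mixed_net.conj_subpop_size, mixed_net.feat_subpop_size   # -> 100 100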
def plots_ratioMscaling(data_pbs, generator_module=None):
    '''
        Reload and plot precision/fits of a Mixed code.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plots_pcolor_all = False
    plots_effect_M_target_kappa = False

    plots_kappa_fi_comparison = False
    plots_multiple_fisherinfo = False
    specific_plot_effect_R = True


    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_precisions_mean = (utils.nanmean(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_all_precisions_std = (utils.nanstd(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_em_fits_mean = (utils.nanmean(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_em_fits_std = (utils.nanstd(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_fisherinfo_mean = (utils.nanmean(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))
    result_fisherinfo_std = (utils.nanstd(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))

    all_args = data_pbs.loaded_data['args_list']

    result_em_fits_kappa = result_em_fits_mean[..., 0]
    result_em_fits_target = result_em_fits_mean[..., 1]
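    # Only keep kappa where the fitted target mixture weight is high enough (here >= 0.9) for the fit to be trusted.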
    result_em_fits_kappa_valid = np.ma.masked_where(result_em_fits_target < 0.9, result_em_fits_kappa)

    M_space = data_pbs.loaded_data['parameters_uniques']['M'].astype(int)
    ratio_space = data_pbs.loaded_data['parameters_uniques']['ratio_conj']
    R_space = data_pbs.loaded_data['parameters_uniques']['R'].astype(int)
    num_repetitions = generator_module.num_repetitions

    print M_space
    print ratio_space
    print R_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape

    dataio = DataIO.DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    MAX_DISTANCE = 100.

    if plots_pcolor_all:
        # Do one pcolor for M and ratio per R
        for R_i, R in enumerate(R_space):
            # Check evolution of precision given M and ratio
            # utils.pcolor_2d_data(result_all_precisions_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='precision, R=%d' % R)
            # if savefigs:
            #     dataio.save_current_figure('pcolor_precision_R%d_log_{label}_{unique_id}.pdf' % R)

            # Show kappa
            try:
                utils.pcolor_2d_data(result_em_fits_kappa_valid[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='kappa, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_kappa_R%d_log_{label}_{unique_id}.pdf' % R)
            except ValueError:
                pass

            # Show probability on target
            # utils.pcolor_2d_data(result_em_fits_target[..., R_i], log_scale=False, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='target, R=%d' % R)
            # if savefigs:
            #     dataio.save_current_figure('pcolor_target_R%d_{label}_{unique_id}.pdf' % R)

            # # Show Fisher info
            # utils.pcolor_2d_data(result_fisherinfo_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='fisher info, R=%d' % R)
            # if savefigs:
            #     dataio.save_current_figure('pcolor_fisherinfo_R%d_log_{label}_{unique_id}.pdf' % R)

            plt.close('all')

    if plots_effect_M_target_kappa:
        def plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R):
            f, ax = plt.subplots()
            ax.plot(M_space, ratio_target_kappa_given_M)
            ax.set_xlabel('M')
            ax.set_ylabel('Optimal ratio')
            ax.set_title('Optimal Ratio for kappa %d, R=%d' % (target_kappa, R))

            if savefigs:
                dataio.save_current_figure('optratio_M_targetkappa%d_R%d_{label}_{unique_id}.pdf' % (target_kappa, R))

        target_kappas = np.array([100, 200, 300, 500, 1000, 3000])
        for R_i, R in enumerate(R_space):
            for target_kappa in target_kappas:
                dist_to_target_kappa = (result_em_fits_kappa[..., R_i] - target_kappa)**2.
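                # For each M, pick the ratio minimizing the squared distance to the target kappa (argmin over the ratio axis),
                # then mask out M values whose best distance still exceeds MAX_DISTANCE.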
                best_dist_to_target_kappa = np.argmin(dist_to_target_kappa, axis=1)
                ratio_target_kappa_given_M = np.ma.masked_where(dist_to_target_kappa[np.arange(dist_to_target_kappa.shape[0]), best_dist_to_target_kappa] > MAX_DISTANCE, ratio_space[best_dist_to_target_kappa])

                # replot
                plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R)

            plt.close('all')


    if plots_kappa_fi_comparison:

        # result_em_fits_kappa and fisher info
        if True:
            for R_i, R in enumerate(R_space):
                for M_tot_selected_i, M_tot_selected in enumerate(M_space[::2]):

                    # M_conj_space = ((1.-ratio_space)*M_tot_selected).astype(int)
                    # M_feat_space = M_tot_selected - M_conj_space

                    f, axes = plt.subplots(2, 1)
                    axes[0].plot(ratio_space, result_em_fits_kappa[2*M_tot_selected_i, ..., R_i])
                    axes[0].set_xlabel('ratio')
                    axes[0].set_title('Fitted kappa')

                    axes[1].plot(ratio_space, utils.stddev_to_kappa(1./result_fisherinfo_mean[2*M_tot_selected_i, ..., R_i]**0.5))
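                    # Converts Fisher information to an equivalent concentration: stddev ~ 1/sqrt(FI) (Cramer-Rao), then stddev -> kappa.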
                    axes[1].set_xlabel('ratio')
                    axes[1].set_title('kappa_FI')

                    f.suptitle('M_tot %d' % M_tot_selected, fontsize=15)
                    f.set_tight_layout(True)

                    if savefigs:
                        dataio.save_current_figure('comparison_kappa_fisher_R%d_M%d_{label}_{unique_id}.pdf' % (R, M_tot_selected))

                    plt.close(f)

        if plots_multiple_fisherinfo:
            target_fisherinfos = np.array([100, 200, 300, 500, 1000])
            for R_i, R in enumerate(R_space):
                for target_fisherinfo in target_fisherinfos:
                    dist_to_target_fisherinfo = (result_fisherinfo_mean[..., R_i] - target_fisherinfo)**2.

                    utils.pcolor_2d_data(dist_to_target_fisherinfo, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='Fisher info, R=%d' % R)
                    if savefigs:
                        dataio.save_current_figure('pcolor_distfi%d_R%d_log_{label}_{unique_id}.pdf' % (target_fisherinfo, R))

                plt.close('all')

    if specific_plot_effect_R:
        # Choose a M, find which ratio gives best fit to a given kappa
        M_target = 228
        M_target_i = np.argmin(np.abs(M_space - M_target))

        # target_kappa = np.ma.mean(result_em_fits_kappa_valid[M_target_i])
        # target_kappa = 5*1e3
        target_kappa = 1e3

        dist_target_kappa = (result_em_fits_kappa_valid[M_target_i] - target_kappa)**2.

        utils.pcolor_2d_data(dist_target_kappa, log_scale=True, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='Kappa dist %.2f, M %d' % (target_kappa, M_target))
        if savefigs:
            dataio.save_current_figure('pcolor_distkappa%d_M%d_log_{label}_{unique_id}.pdf' % (target_kappa, M_target))




    all_args = data_pbs.loaded_data['args_list']
    variables_to_save = []

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)

        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='higher_dimensions_R')

    plt.show()

    return locals()
def plots_ratioMscaling(data_pbs, generator_module=None):
    '''
        Reload and plot precision/fits of a Mixed code.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plots_pcolor_all = False
    plots_effect_M_target_kappa = False

    plots_kappa_fi_comparison = False
    plots_multiple_fisherinfo = False
    specific_plot_effect_R = True
    specific_plot_effect_ratio_M = False

    convert_M_realsizes = False

    plots_pcolor_realsizes_Msubs = True
    plots_pcolor_realsizes_Mtot = True


    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    # interpolation_method = 'linear'
    interpolation_method = 'nearest'
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_precisions_mean = (utils.nanmean(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_all_precisions_std = (utils.nanstd(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_em_fits_mean = (utils.nanmean(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_em_fits_std = (utils.nanstd(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_fisherinfo_mean = (utils.nanmean(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))
    result_fisherinfo_std = (utils.nanstd(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))

    all_args = data_pbs.loaded_data['args_list']

    result_em_fits_kappa = result_em_fits_mean[..., 0]
    result_em_fits_target = result_em_fits_mean[..., 1]
    result_em_fits_kappa_valid = np.ma.masked_where(result_em_fits_target < 0.8, result_em_fits_kappa)

    # flat versions
    result_parameters_flat = np.array(data_pbs.dict_arrays['result_all_precisions']['parameters_flat'])
    result_all_precisions_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_all_precisions']['results_flat']), axis=-1)
    result_em_fits_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_em_fits']['results_flat']), axis=-1)
    result_fisherinfor_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_fisher_info']['results_flat']), axis=-1)
    result_em_fits_kappa_flat = result_em_fits_mean_flat[..., 0]
    result_em_fits_target_flat = result_em_fits_mean_flat[..., 1]
    result_em_fits_kappa_valid_flat = np.ma.masked_where(result_em_fits_target_flat < 0.8, result_em_fits_kappa_flat)



    M_space = data_pbs.loaded_data['parameters_uniques']['M'].astype(int)
    ratio_space = data_pbs.loaded_data['parameters_uniques']['ratio_conj']
    R_space = data_pbs.loaded_data['parameters_uniques']['R'].astype(int)
    num_repetitions = generator_module.num_repetitions
    T = generator_module.T

    print M_space
    print ratio_space
    print R_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape

    dataio = DataIO.DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    MAX_DISTANCE = 100.

    if convert_M_realsizes:
        # Currently M*ratio_conj gives the conjunctive subpopulation size,
        # but only floor(M_conj**(1/R))**R neurons are really used. So we should
        # convert to M_conj_real and M_feat_real instead of M and ratio.
        result_parameters_flat_subM_converted = []
        result_parameters_flat_Mtot_converted = []

        for params in result_parameters_flat:
            M = params[0]; ratio_conj = params[1]; R = int(params[2])

            M_conj_prior = int(M*ratio_conj)
            M_conj_true = int(np.floor(M_conj_prior**(1./R))**R)
            M_feat_true = int(np.floor((M-M_conj_prior)/R)*R)
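            # Illustrative example (assumed values): M=100, ratio_conj=0.5, R=2 gives M_conj_prior=50,
            # M_conj_true = floor(sqrt(50))**2 = 49 and M_feat_true = floor(50/2)*2 = 50.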

            # result_parameters_flat_subM_converted contains (M_conj, M_feat, R)
            result_parameters_flat_subM_converted.append(np.array([M_conj_true, M_feat_true, R]))
            # result_parameters_flat_Mtot_converted contains (M_tot, ratio_conj, R)
            result_parameters_flat_Mtot_converted.append(np.array([float(M_conj_true+M_feat_true), float(M_conj_true)/float(M_conj_true+M_feat_true), R]))

        result_parameters_flat_subM_converted = np.array(result_parameters_flat_subM_converted)
        result_parameters_flat_Mtot_converted = np.array(result_parameters_flat_Mtot_converted)

    if plots_pcolor_all:
        if convert_M_realsizes:
            def plot_interp(points, data, currR_indices, title='', points_label='', xlabel='', ylabel=''):
                utils.contourf_interpolate_data_interactive_maxvalue(points[currR_indices][..., :2], data[currR_indices], xlabel=xlabel, ylabel=ylabel, title='%s, R=%d' % (title, R), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False)

                if savefigs:
                    dataio.save_current_figure('pcolortrueM%s_%s_R%d_log_%s_{label}_{unique_id}.pdf' % (points_label, title, R, interpolation_method))

            all_datas = [dict(name='precision', data=result_all_precisions_mean_flat), dict(name='kappa', data=result_em_fits_kappa_flat), dict(name='kappavalid', data=result_em_fits_kappa_valid_flat), dict(name='target', data=result_em_fits_target_flat), dict(name='fisherinfo', data=result_fisherinfor_mean_flat)]
            all_points = []
            if plots_pcolor_realsizes_Msubs:
                all_points.append(dict(name='sub', data=result_parameters_flat_subM_converted, xlabel='M_conj', ylabel='M_feat'))
            if plots_pcolor_realsizes_Mtot:
                all_points.append(dict(name='tot', data=result_parameters_flat_Mtot_converted, xlabel='Mtot', ylabel='ratio_conj'))

            for curr_points in all_points:
                for curr_data in all_datas:
                    for R_i, R in enumerate(R_space):
                        currR_indices = curr_points['data'][:, 2] == R

                        plot_interp(curr_points['data'], curr_data['data'], currR_indices, title=curr_data['name'], points_label=curr_points['name'], xlabel=curr_points['xlabel'], ylabel=curr_points['ylabel'])


        else:
            # Do one pcolor for M and ratio per R
            for R_i, R in enumerate(R_space):
                # Check evolution of precision given M and ratio
                utils.pcolor_2d_data(result_all_precisions_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='precision, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_precision_R%d_log_{label}_{unique_id}.pdf' % R)

                # Show kappa
                try:
                    utils.pcolor_2d_data(result_em_fits_kappa_valid[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='kappa, R=%d' % R)
                    if savefigs:
                        dataio.save_current_figure('pcolor_kappa_R%d_log_{label}_{unique_id}.pdf' % R)
                except ValueError:
                    pass

                # Show probability on target
                utils.pcolor_2d_data(result_em_fits_target[..., R_i], log_scale=False, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='target, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_target_R%d_{label}_{unique_id}.pdf' % R)

                # # Show Fisher info
                utils.pcolor_2d_data(result_fisherinfo_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='fisher info, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_fisherinfo_R%d_log_{label}_{unique_id}.pdf' % R)

                plt.close('all')

    if plots_effect_M_target_kappa:
        def plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R):
            f, ax = plt.subplots()
            ax.plot(M_space, ratio_target_kappa_given_M)
            ax.set_xlabel('M')
            ax.set_ylabel('Optimal ratio')
            ax.set_title('Optimal Ratio for kappa %d, R=%d' % (target_kappa, R))

            if savefigs:
                dataio.save_current_figure('optratio_M_targetkappa%d_R%d_{label}_{unique_id}.pdf' % (target_kappa, R))

        target_kappas = np.array([100, 200, 300, 500, 1000, 3000])
        for R_i, R in enumerate(R_space):
            for target_kappa in target_kappas:
                dist_to_target_kappa = (result_em_fits_kappa[..., R_i] - target_kappa)**2.
                best_dist_to_target_kappa = np.argmin(dist_to_target_kappa, axis=1)
                ratio_target_kappa_given_M = np.ma.masked_where(dist_to_target_kappa[np.arange(dist_to_target_kappa.shape[0]), best_dist_to_target_kappa] > MAX_DISTANCE, ratio_space[best_dist_to_target_kappa])

                # replot
                plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R)

            plt.close('all')


    if plots_kappa_fi_comparison:

        # result_em_fits_kappa and fisher info
        if True:
            for R_i, R in enumerate(R_space):
                for M_tot_selected_i, M_tot_selected in enumerate(M_space[::2]):

                    # M_conj_space = ((1.-ratio_space)*M_tot_selected).astype(int)
                    # M_feat_space = M_tot_selected - M_conj_space

                    f, axes = plt.subplots(2, 1)
                    axes[0].plot(ratio_space, result_em_fits_kappa[2*M_tot_selected_i, ..., R_i])
                    axes[0].set_xlabel('ratio')
                    axes[0].set_title('Fitted kappa')

                    axes[1].plot(ratio_space, utils.stddev_to_kappa(1./result_fisherinfo_mean[2*M_tot_selected_i, ..., R_i]**0.5))
                    axes[1].set_xlabel('ratio')
                    axes[1].set_title('kappa_FI')

                    f.suptitle('M_tot %d' % M_tot_selected, fontsize=15)
                    f.set_tight_layout(True)

                    if savefigs:
                        dataio.save_current_figure('comparison_kappa_fisher_R%d_M%d_{label}_{unique_id}.pdf' % (R, M_tot_selected))

                    plt.close(f)

        if plots_multiple_fisherinfo:
            target_fisherinfos = np.array([100, 200, 300, 500, 1000])
            for R_i, R in enumerate(R_space):
                for target_fisherinfo in target_fisherinfos:
                    dist_to_target_fisherinfo = (result_fisherinfo_mean[..., R_i] - target_fisherinfo)**2.

                    utils.pcolor_2d_data(dist_to_target_fisherinfo, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='Fisher info, R=%d' % R)
                    if savefigs:
                        dataio.save_current_figure('pcolor_distfi%d_R%d_log_{label}_{unique_id}.pdf' % (target_fisherinfo, R))

                plt.close('all')

    if specific_plot_effect_R:
        M_target = 356
        if convert_M_realsizes:
            M_tot_target = M_target
            delta_around_target = 30

            filter_points_totM_indices = np.abs(result_parameters_flat_Mtot_converted[:, 0] - M_tot_target) < delta_around_target

            # first check landscape
            utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_Mtot_converted[filter_points_totM_indices][..., 1:], result_em_fits_kappa_flat[filter_points_totM_indices], xlabel='ratio_conj', ylabel='R', title='kappa, M_target=%d +- %d' % (M_tot_target, delta_around_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, show_slider=False)

            if savefigs:
                dataio.save_current_figure('specific_pcolortrueMtot_kappa_M%d_log_%s_{label}_{unique_id}.pdf' % (M_tot_target, interpolation_method))

            # Then plot distance to specific kappa
            # target_kappa = 1.2e3
            target_kappa = 580
            dist_target_kappa_flat = np.abs(result_em_fits_kappa_flat - target_kappa)
            mask_greater_than = 5e3

            utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_Mtot_converted[filter_points_totM_indices][..., 1:], dist_target_kappa_flat[filter_points_totM_indices], xlabel='ratio_conj', ylabel='R', title='dist kappa %d, M_target=%d +- %d' % (target_kappa, M_tot_target, delta_around_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, mask_greater_than=mask_greater_than, mask_smaller_than=0)

            if savefigs:
                dataio.save_current_figure('specific_pcolortrueMtot_distkappa%d_M%d_log_%s_{label}_{unique_id}.pdf' % (target_kappa, M_tot_target, interpolation_method))

        else:
            # Choose a M, find which ratio gives best fit to a given kappa
            M_target_i = np.argmin(np.abs(M_space - M_target))

            utils.pcolor_2d_data(result_em_fits_kappa[M_target_i], log_scale=True, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='Kappa, M %d' % (M_target))
            plt.gcf().set_tight_layout(True)
            plt.gcf().canvas.draw()
            if savefigs:
                dataio.save_current_figure('specific_Reffect_pcolor_kappa_M%dT%d_log_{label}_{unique_id}.pdf' % (M_target, T))
            # target_kappa = np.ma.mean(result_em_fits_kappa_valid[M_target_i])
            # target_kappa = 5*1e3
            # target_kappa = 1.2e3
            target_kappa = 580

            # dist_target_kappa = np.abs(result_em_fits_kappa_valid[M_target_i] - target_kappa)
            dist_target_kappa = result_em_fits_kappa[M_target_i]/target_kappa
            dist_target_kappa[dist_target_kappa > 2.0] = 2.0
            # dist_target_kappa[dist_target_kappa < 0.5] = 0.5

            utils.pcolor_2d_data(dist_target_kappa, log_scale=False, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='Kappa dist %.2f, M %d' % (target_kappa, M_target), cmap='RdBu_r')
            plt.gcf().set_tight_layout(True)
            plt.gcf().canvas.draw()
            if savefigs:
                dataio.save_current_figure('specific_Reffect_pcolor_distkappa%d_M%dT%d_log_{label}_{unique_id}.pdf' % (target_kappa, M_target, T))

            # Plot the probability of being on-target
            utils.pcolor_2d_data(result_em_fits_target[M_target_i], log_scale=False, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='target mixture proportion, M %d' % (M_target), vmin=0.0, vmax=1.0)  #cmap='RdBu_r'
            plt.gcf().set_tight_layout(True)
            plt.gcf().canvas.draw()
            if savefigs:
                dataio.save_current_figure('specific_Reffect_pcolor_target_M%dT%d_{label}_{unique_id}.pdf' % (M_target, T))



    if specific_plot_effect_ratio_M:
        # try to do the same plots as in ratio_scaling_M but with the current data
        R_target = 2
        interpolation_method = 'cubic'
        mask_greater_than = 1e3

        if convert_M_realsizes:
            # filter_points_targetR_indices = np.abs(result_parameters_flat_Mtot_converted[:, -1] - R_target) == 0
            filter_points_targetR_indices = np.abs(result_parameters_flat_subM_converted[:, -1] - R_target) == 0

            # utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., :-1], result_em_fits_kappa_flat[filter_points_targetR_indices], xlabel='M', ylabel='ratio_conj', title='kappa, R=%d' % (R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False)
            utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_subM_converted[filter_points_targetR_indices][..., :-1], result_em_fits_kappa_flat[filter_points_targetR_indices], xlabel='M_conj', ylabel='M_feat', title='kappa, R=%d' % (R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, show_slider=False)

            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolorsubM_kappa_R%d_log_%s_{label}_{unique_id}.pdf' % (R_target, interpolation_method))

            # Then plot distance to specific kappa
            target_kappa = 580
            # target_kappa = 2000
            dist_target_kappa_flat = np.abs(result_em_fits_kappa_flat - target_kappa)
            dist_target_kappa_flat = result_em_fits_kappa_flat/target_kappa
            dist_target_kappa_flat[dist_target_kappa_flat > 1.45] = 1.45
            dist_target_kappa_flat[dist_target_kappa_flat < 0.5] = 0.5

            # utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., :-1], dist_target_kappa_flat[filter_points_targetR_indices], xlabel='M', ylabel='ratio_conj', title='kappa, R=%d' % (R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, mask_smaller_than=0, show_slider=False, mask_greater_than =mask_greater_than)

            # utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat[filter_points_targetR_indices][..., :-1], dist_target_kappa_flat[filter_points_targetR_indices], xlabel='M', ylabel='ratio_conj', title='kappa, R=%d' % (R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, mask_smaller_than=0, show_slider=False, mask_greater_than =mask_greater_than)

            utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_subM_converted[filter_points_targetR_indices][..., :-1], dist_target_kappa_flat[filter_points_targetR_indices], xlabel='M_conj', ylabel='M_feat', title='Kappa dist %.2f, R=%d' % (target_kappa, R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, mask_greater_than=mask_greater_than, mask_smaller_than=0, show_slider=False)

            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolorsubM_distkappa%d_R%d_log_%s_{label}_{unique_id}.pdf' % (target_kappa, R_target, interpolation_method))


            ## Scatter plot, works better...

            f, ax = plt.subplots()
            ss = ax.scatter(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0], result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1], 500, c=(dist_target_kappa_flat[filter_points_targetR_indices]),
                norm=matplotlib.colors.LogNorm())
            plt.colorbar(ss)
            ax.set_xlabel('M')
            ax.set_ylabel('ratio')
            ax.set_xlim((result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0].min()*0.8, result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0].max()*1.03))
            ax.set_ylim((-0.05, result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1].max()*1.05))
            ax.set_title('Kappa dist %.2f, R=%d' % (target_kappa, R_target))

            if savefigs:
                dataio.save_current_figure('specific_ratioM_scattertotM_distkappa%d_R%d_log_%s_{label}_{unique_id}.pdf' % (target_kappa, R_target, interpolation_method))


            ## Spline interpolation
            distkappa_spline_int_params = spint.bisplrep(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0], result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1], np.log(dist_target_kappa_flat[filter_points_targetR_indices]), kx=3, ky=3, s=1)
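            # Bicubic spline fit on log(kappa/target_kappa) over (M_tot, ratio_conj); np.exp of bisplev below maps it back to the ratio scale.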
            if True:
                M_interp_space = np.linspace(100, 740, 100)
                ratio_interp_space = np.linspace(0.0, 1.0, 100)
                # utils.pcolor_2d_data(spint.bisplev(M_interp_space, ratio_interp_space, distkappa_spline_int_params), y=ratio_interp_space, x=M_interp_space, ylabel='ratio', xlabel='M', xlabel_format="%d", title='Kappa dist %.2f, R %d' % (target_kappa, R_target), ticks_interpolate=11)
                utils.pcolor_2d_data(np.exp(spint.bisplev(M_interp_space, ratio_interp_space, distkappa_spline_int_params)), y=ratio_interp_space, x=M_interp_space, ylabel='ratio conjunctivity', xlabel='M', xlabel_format="%d", title='Ratio kappa/%.2f, R %d' % (target_kappa, R_target), ticks_interpolate=11, log_scale=False, cmap='RdBu_r')
                # plt.scatter(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0], result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1], marker='o', c='b', s=5)
                # plt.scatter(np.argmin(np.abs(M_interp_space[:, np.newaxis] - result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0]), axis=0), np.argmin(np.abs(ratio_interp_space[:, np.newaxis] - result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1]), axis=0), marker='o', c='b', s=5)
            else:
                M_interp_space = M_space
                ratio_interp_space = ratio_space
                utils.pcolor_2d_data(spint.bisplev(M_interp_space, ratio_interp_space, distkappa_spline_int_params), y=ratio_interp_space, x=M_interp_space, ylabel='ratio', xlabel='M', xlabel_format="%d", title='Kappa dist %.2f, R %d' % (target_kappa, R_target))


            plt.gcf().set_tight_layout(True)
            plt.gcf().canvas.draw()
            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolorsplinetotM_distkappa%d_R%d_log_%s_{label}_{unique_id}.pdf' % (target_kappa, R_target, interpolation_method))

            ### Distance to Fisher Info

            target_fi = 2*target_kappa
            dist_target_fi_flat = np.abs(result_fisherinfor_mean_flat - target_fi)

            dist_target_fi_flat = result_fisherinfor_mean_flat/target_fi
            dist_target_fi_flat[dist_target_fi_flat > 1.45] = 1.45
            dist_target_fi_flat[dist_target_fi_flat < 0.5] = 0.5

            # utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., :-1], dist_target_kappa_flat[filter_points_targetR_indices], xlabel='M', ylabel='ratio_conj', title='kappa, R=%d' % (R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False)
            utils.contourf_interpolate_data_interactive_maxvalue(result_parameters_flat_subM_converted[filter_points_targetR_indices][..., :-1], dist_target_fi_flat[filter_points_targetR_indices], xlabel='M_conj', ylabel='M_feat', title='FI dist %.2f, R=%d' % (target_fi, R_target), interpolation_numpoints=200, interpolation_method=interpolation_method, log_scale=False, mask_greater_than=mask_greater_than, show_slider=False)

            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolorsubM_distfi%d_R%d_log_%s_{label}_{unique_id}.pdf' % (target_fi, R_target, interpolation_method))


            distFI_spline_int_params = spint.bisplrep(result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 0], result_parameters_flat_Mtot_converted[filter_points_targetR_indices][..., 1], np.log(dist_target_fi_flat[filter_points_targetR_indices]), kx=3, ky=3, s=1)

            M_interp_space = np.linspace(100, 740, 100)
            ratio_interp_space = np.linspace(0.0, 1.0, 100)
            # utils.pcolor_2d_data(spint.bisplev(M_interp_space, ratio_interp_space, spline_int_params), y=ratio_interp_space, x=M_interp_space, ylabel='ratio', xlabel='M', xlabel_format="%d", title='Kappa dist %.2f, R %d' % (target_kappa, R_target), ticks_interpolate=11)
            utils.pcolor_2d_data(np.exp(spint.bisplev(M_interp_space, ratio_interp_space, distFI_spline_int_params)), y=ratio_interp_space, x=M_interp_space, ylabel='ratio conjunctivity', xlabel='M', xlabel_format="%d", title='Ratio FI/%.2f, R %d' % (target_fi, R_target), ticks_interpolate=11, log_scale=False, cmap='RdBu_r')
            plt.gcf().set_tight_layout(True)
            plt.gcf().canvas.draw()
            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolorsplinetotM_distfi%d_R%d_log_%s_{label}_{unique_id}.pdf' % (target_fi, R_target, interpolation_method))



        else:
            R_target_i = np.argmin(np.abs(R_space - R_target))

            utils.pcolor_2d_data(result_em_fits_kappa_valid[..., R_target_i], log_scale=True, y=ratio_space, x=M_space, ylabel='ratio', xlabel='M', xlabel_format="%d", title='Kappa, R %d' % (R_target))
            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolor_kappa_R%d_log_{label}_{unique_id}.pdf' % (R_target))
            # target_kappa = np.ma.mean(result_em_fits_kappa_valid[R_target_i])
            # target_kappa = 5*1e3
            target_kappa = 580

            dist_target_kappa = np.ma.masked_greater(np.abs(result_em_fits_kappa_valid[..., R_target_i] - target_kappa), mask_greater_than*5)

            utils.pcolor_2d_data(dist_target_kappa, log_scale=True, y=ratio_space, x=M_space, ylabel='ratio', xlabel='M', xlabel_format="%d", title='Kappa dist %.2f, R %d' % (target_kappa, R_target))
            if savefigs:
                dataio.save_current_figure('specific_ratioM_pcolor_distkappa%d_R%d_log_{label}_{unique_id}.pdf' % (target_kappa, R_target))






    all_args = data_pbs.loaded_data['args_list']
    variables_to_save = []

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)

        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='higher_dimensions_R')

    plt.show()

    return locals()
def plots_specific_stimuli_hierarchical(data_pbs, generator_module=None):
    '''
        Reload and plot behaviour of mixed population code on specific Stimuli
        of 3 items.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plot_per_min_dist_all = False
    specific_plots_paper = False
    plots_emfit_allitems = False
    plot_min_distance_effect = True

    should_fit_allitems_model = True
    # caching_emfit_filename = None
    caching_emfit_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_emfitallitems_uniquekappa.pickle')


    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_precisions_mean = utils.nanmean(np.squeeze(data_pbs.dict_arrays['result_all_precisions']['results']), axis=-1)
    result_all_precisions_std = utils.nanstd(np.squeeze(data_pbs.dict_arrays['result_all_precisions']['results']), axis=-1)
    result_em_fits_mean = utils.nanmean(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results']), axis=-1)
    result_em_fits_std = utils.nanstd(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results']), axis=-1)
    result_em_kappastddev_mean = utils.nanmean(utils.kappa_to_stddev(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results'])[..., 0, :]), axis=-1)
    result_em_kappastddev_std = utils.nanstd(utils.kappa_to_stddev(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results'])[..., 0, :]), axis=-1)
    result_responses_all = np.squeeze(data_pbs.dict_arrays['result_responses']['results'])
    result_target_all = np.squeeze(data_pbs.dict_arrays['result_target']['results'])
    result_nontargets_all = np.squeeze(data_pbs.dict_arrays['result_nontargets']['results'])

    all_args = data_pbs.loaded_data['args_list']

    nb_repetitions = np.squeeze(data_pbs.dict_arrays['result_em_fits']['results']).shape[-1]
    print nb_repetitions
    nb_repetitions = result_responses_all.shape[-1]
    print nb_repetitions
    K = result_nontargets_all.shape[-2]
    N = result_responses_all.shape[-2]


    enforce_min_distance_space = data_pbs.loaded_data['parameters_uniques']['enforce_min_distance']
    sigmax_space = data_pbs.loaded_data['parameters_uniques']['sigmax']

    MMlower_valid_space = data_pbs.loaded_data['datasets_list'][0]['MMlower_valid_space']
    ratio_space = MMlower_valid_space[:, 0]/float(np.sum(MMlower_valid_space[0]))

    print enforce_min_distance_space
    print sigmax_space
    print MMlower_valid_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape

    dataio = DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    # Reload cached emfitallitems
    if caching_emfit_filename is not None:
        if os.path.exists(caching_emfit_filename):
            # Got file, open it and try to use its contents
            try:
                with open(caching_emfit_filename, 'rb') as file_in:
                    # Load and assign values
                    cached_data = pickle.load(file_in)
                    result_emfitallitems = cached_data['result_emfitallitems']
                    should_fit_allitems_model = False

            except IOError:
                print "Error while loading ", caching_emfit_filename, "falling back to computing the EM fits"

    if plot_per_min_dist_all:
        # Do one plot per min distance.
        for min_dist_i, min_dist in enumerate(enforce_min_distance_space):
            # Show log precision
            utils.pcolor_2d_data(result_all_precisions_mean[min_dist_i].T, x=ratio_space, y=sigmax_space, xlabel='ratio layer two', ylabel='sigma_x', title='Precision, min_dist=%.3f' % min_dist)
            if savefigs:
                dataio.save_current_figure('precision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Show log precision
            utils.pcolor_2d_data(result_all_precisions_mean[min_dist_i].T, x=ratio_space, y=sigmax_space, xlabel='ratio layer two', ylabel='sigma_x', title='Precision, min_dist=%.3f' % min_dist, log_scale=True)
            if savefigs:
                dataio.save_current_figure('logprecision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)


            # Plot estimated model precision (kappa)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 0].T, x=ratio_space, y=sigmax_space, xlabel='ratio layer two', ylabel='sigma_x', title='EM precision, min_dist=%.3f' % min_dist, log_scale=False)
            if savefigs:
                dataio.save_current_figure('logemprecision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Plot estimated Target, nontarget and random mixture components, in multiple subplots
            _, axes = plt.subplots(1, 3, figsize=(18, 6))
            plt.subplots_adjust(left=0.05, right=0.97, wspace = 0.3, bottom=0.15)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 1].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Target, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[0], ticks_interpolate=5)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 2].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Nontarget, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[1], ticks_interpolate=5)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 3].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Random, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[2], ticks_interpolate=5)

            if savefigs:
                dataio.save_current_figure('em_mixtureprobs_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Plot Log-likelihood of Mixture model, sanity check
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., -1].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='EM loglik, min_dist=%.3f' % min_dist, log_scale=False)
            if savefigs:
                dataio.save_current_figure('em_loglik_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

    if specific_plots_paper:
        # We need to choose 3 levels of min_distances
        target_sigmax = 0.25
        target_mindist_low = 0.09
        target_mindist_medium = 0.36
        target_mindist_high = 1.5

        sigmax_level_i = np.argmin(np.abs(sigmax_space - target_sigmax))
        min_dist_level_low_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_low))
        min_dist_level_medium_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_medium))
        min_dist_level_high_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_high))

        ## Do for each distance
        # for min_dist_i in [min_dist_level_low_i, min_dist_level_medium_i, min_dist_level_high_i]:
        for min_dist_i in xrange(enforce_min_distance_space.size):

            # Plot precision
            if False:
                utils.plot_mean_std_area(ratio_space, result_all_precisions_mean[min_dist_i, sigmax_level_i], result_all_precisions_std[min_dist_i, sigmax_level_i]) #, xlabel='Ratio conjunctivity', ylabel='Precision of recall')
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                plt.ylim([0, np.max(result_all_precisions_mean[min_dist_i, sigmax_level_i] + result_all_precisions_std[min_dist_i, sigmax_level_i])])

                if savefigs:
                    dataio.save_current_figure('mindist%.2f_precisionrecall_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

            # Plot kappa fitted
            ax_handle = utils.plot_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 0], result_em_fits_std[min_dist_i, sigmax_level_i, :, 0]) #, xlabel='Ratio conjunctivity', ylabel='Fitted kappa')
            # Add distance between items in kappa units
            dist_items_kappa = utils.stddev_to_kappa(enforce_min_distance_space[min_dist_i])
            ax_handle.plot(ratio_space, dist_items_kappa*np.ones(ratio_space.size), 'k--', linewidth=3)
            plt.ylim([-0.1, np.max((np.max(result_em_fits_mean[min_dist_i, sigmax_level_i, :, 0] + result_em_fits_std[min_dist_i, sigmax_level_i, :, 0]), 1.1*dist_items_kappa))])
            # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
            if savefigs:
                dataio.save_current_figure('mindist%.2f_emkappa_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

            # Plot kappa-stddev fitted. Easier to visualize
            ax_handle = utils.plot_mean_std_area(ratio_space, result_em_kappastddev_mean[min_dist_i, sigmax_level_i], result_em_kappastddev_std[min_dist_i, sigmax_level_i]) #, xlabel='Ratio conjunctivity', ylabel='Fitted kappa_stddev')
            # Add distance between items in std dev units
            dist_items_std = (enforce_min_distance_space[min_dist_i])
            ax_handle.plot(ratio_space, dist_items_std*np.ones(ratio_space.size), 'k--', linewidth=3)
            # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
            plt.ylim([0, 1.1*np.max((np.max(result_em_kappastddev_mean[min_dist_i, sigmax_level_i] + result_em_kappastddev_std[min_dist_i, sigmax_level_i]), dist_items_std))])
            if savefigs:
                dataio.save_current_figure('mindist%.2f_emkappastddev_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])


            if False:
                # Plot LLH
                utils.plot_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, -1], result_em_fits_std[min_dist_i, sigmax_level_i, :, -1]) #, xlabel='Ratio conjunctivity', ylabel='Loglikelihood of Mixture model fit')
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emllh_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

                # Plot mixture parameters, std
                utils.plot_multiple_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 1:4].T, result_em_fits_std[min_dist_i, sigmax_level_i, :, 1:4].T)
                plt.ylim([0.0, 1.1])
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                # plt.legend("Target", "Non-target", "Random")
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emprobs_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

                # Mixture parameters, SEM
                utils.plot_multiple_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 1:4].T, result_em_fits_std[min_dist_i, sigmax_level_i, :, 1:4].T/np.sqrt(nb_repetitions))
                plt.ylim([0.0, 1.1])
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                # plt.legend("Target", "Non-target", "Random")
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emprobs_forpaper_sem_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

    if plots_emfit_allitems:
        # We need to choose 3 levels of min_distances
        target_sigmax = 0.25
        target_mindist_low = 0.15
        target_mindist_medium = 0.36
        target_mindist_high = 1.5

        sigmax_level_i = np.argmin(np.abs(sigmax_space - target_sigmax))
        min_dist_level_low_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_low))
        min_dist_level_medium_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_medium))
        min_dist_level_high_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_high))

        min_dist_i_plotting_space = np.array([min_dist_level_low_i, min_dist_level_medium_i, min_dist_level_high_i])

        if should_fit_allitems_model:

            # kappa, mixt_target, mixt_nontargets (K), mixt_random, LL, bic
            # result_emfitallitems = np.empty((min_dist_i_plotting_space.size, ratio_space.size, 2*K+5))*np.nan
            result_emfitallitems = np.empty((enforce_min_distance_space.size, ratio_space.size, K+5))*np.nan

            ## Do for each distance
            # for min_dist_plotting_i, min_dist_i in enumerate(min_dist_i_plotting_space):
            for min_dist_i in xrange(enforce_min_distance_space.size):
                # Fit the mixture model
                for ratio_i, ratio in enumerate(ratio_space):
                    print "Refitting EM all items. Ratio:", ratio, "Dist:", enforce_min_distance_space[min_dist_i]
                    em_fit = em_circularmixture_allitems_uniquekappa.fit(
                        result_responses_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                        result_target_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                        result_nontargets_all[min_dist_i, sigmax_level_i, ratio_i].transpose((0, 2, 1)).reshape((N*nb_repetitions, K)))

                    result_emfitallitems[min_dist_i, ratio_i] = [em_fit['kappa'], em_fit['mixt_target']] + em_fit['mixt_nontargets'].tolist() + [em_fit[key] for key in ('mixt_random', 'train_LL', 'bic')]
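                    # Each row stores [kappa, p(target), p(nontarget_1..K), p(random), train_LL, bic], i.e. K+5 entries.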

            # Save everything to a file, for faster later plotting
            if caching_emfit_filename is not None:
                try:
                    with open(caching_emfit_filename, 'wb') as filecache_out:
                        data_em = dict(result_emfitallitems=result_emfitallitems)
                        pickle.dump(data_em, filecache_out, protocol=2)
                except IOError:
                    print "Error writing out to caching file ", caching_emfit_filename


        ## Plots now, for each distance!
        # for min_dist_plotting_i, min_dist_i in enumerate(min_dist_i_plotting_space):
        for min_dist_i in xrange(enforce_min_distance_space.size):

            # Plot now
            _, ax = plt.subplots()
            ax.plot(ratio_space, result_emfitallitems[min_dist_i, :, 1:5], linewidth=3)
            plt.ylim([0.0, 1.1])
            plt.legend(['Target', 'Nontarget 1', 'Nontarget 2', 'Random'], loc='upper left')

            if savefigs:
                dataio.save_current_figure('mindist%.2f_emprobsfullitems_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

    if plot_min_distance_effect:
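        # Approximate conjunctive receptive-field width: M*ratio_conj units tile the 2D space, so each
        # dimension holds about sqrt(M*ratio_conj) centers of width 2*pi/sqrt(M*ratio_conj).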
        conj_receptive_field_size = 2.*np.pi/((all_args[0]['M']*ratio_space)**0.5)

        target_vs_nontargets_mindist_ratio = result_emfitallitems[..., 1]/np.sum(result_emfitallitems[..., 1:4], axis=-1)
        nontargetsmean_vs_targnontarg_mindist_ratio = np.mean(result_emfitallitems[..., 2:4]/np.sum(result_emfitallitems[..., 1:4], axis=-1)[..., np.newaxis], axis=-1)

        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            # Do one plot per ratio, putting the receptive field size on each
            f, ax = plt.subplots()

            ax.plot(enforce_min_distance_space[1:], target_vs_nontargets_mindist_ratio[1:, ratio_conj_i], linewidth=3, label='target mixture')
            ax.plot(enforce_min_distance_space[1:], nontargetsmean_vs_targnontarg_mindist_ratio[1:, ratio_conj_i], linewidth=3, label='non-target mixture')
            # ax.plot(enforce_min_distance_space[1:], result_emfitallitems[1:, ratio_conj_i, 1:5], linewidth=3)

            ax.axvline(x=conj_receptive_field_size[ratio_conj_i]/2., color='k', linestyle='--', linewidth=2)
            ax.axvline(x=conj_receptive_field_size[ratio_conj_i]*2., color='r', linestyle='--', linewidth=2)

            plt.legend(loc='upper left')
            plt.grid()
            # ax.set_xlabel('Stimuli separation')
            # ax.set_ylabel('Ratio Target to Non-targets')
            plt.axis('tight')
            ax.set_ylim([0.0, 1.0])
            ax.set_xlim([enforce_min_distance_space[1:].min(), enforce_min_distance_space[1:].max()])

            if savefigs:
                dataio.save_current_figure('ratio%.2f_mindistpred_ratiotargetnontarget_{label}_{unique_id}.pdf' % ratio_conj)



    variables_to_save = ['nb_repetitions']

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)
        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='specific_stimuli')

    plt.show()

    return locals()
def plots_specific_stimuli_mixed(data_pbs, generator_module=None):
    '''
        Reload and plot behaviour of mixed population code on specific Stimuli
        of 3 items.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plot_per_min_dist_all = False
    specific_plots_paper = False
    plots_emfit_allitems = False
    plot_min_distance_effect = True

    compute_bootstraps = False

    should_fit_allitems_model = True
    # caching_emfit_filename = None
    mixturemodel_to_use = 'allitems_uniquekappa'
    # caching_emfit_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_emfitallitems_uniquekappa.pickle')
    # mixturemodel_to_use = 'allitems_fikappa'

    caching_emfit_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_emfit%s.pickle' % mixturemodel_to_use)

    compute_fisher_info_perratioconj = True
    caching_fisherinfo_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_fisherinfo.pickle')

    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    all_args = data_pbs.loaded_data['args_list']
    result_all_precisions_mean = utils.nanmean(np.squeeze(data_pbs.dict_arrays['result_all_precisions']['results']), axis=-1)
    result_all_precisions_std = utils.nanstd(np.squeeze(data_pbs.dict_arrays['result_all_precisions']['results']), axis=-1)
    result_em_fits_mean = utils.nanmean(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results']), axis=-1)
    result_em_fits_std = utils.nanstd(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results']), axis=-1)
    result_em_kappastddev_mean = utils.nanmean(utils.kappa_to_stddev(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results'])[..., 0, :]), axis=-1)
    result_em_kappastddev_std = utils.nanstd(utils.kappa_to_stddev(np.squeeze(data_pbs.dict_arrays['result_em_fits']['results'])[..., 0, :]), axis=-1)
    result_responses_all = np.squeeze(data_pbs.dict_arrays['result_responses']['results'])
    result_target_all = np.squeeze(data_pbs.dict_arrays['result_target']['results'])
    result_nontargets_all = np.squeeze(data_pbs.dict_arrays['result_nontargets']['results'])

    nb_repetitions = result_responses_all.shape[-1]
    K = result_nontargets_all.shape[-2]
    N = result_responses_all.shape[-2]
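    # Shape conventions assumed below: the last axis of the response arrays indexes
    # repetitions, K is the number of nontarget items per trial, and N the number of
    # trials per repetition.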

    enforce_min_distance_space = data_pbs.loaded_data['parameters_uniques']['enforce_min_distance']
    sigmax_space = data_pbs.loaded_data['parameters_uniques']['sigmax']
    ratio_space = data_pbs.loaded_data['datasets_list'][0]['ratio_space']

    print enforce_min_distance_space
    print sigmax_space
    print ratio_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape
    print result_responses_all.shape

    dataio = DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    # Reload cached emfitallitems
    if caching_emfit_filename is not None:
        if os.path.exists(caching_emfit_filename):
            # Got file, open it and try to use its contents
            try:
                with open(caching_emfit_filename, 'r') as file_in:
                    # Load and assign values
                    print "Reloader EM fits from cache", caching_emfit_filename
                    cached_data = pickle.load(file_in)
                    result_emfitallitems = cached_data['result_emfitallitems']
                    mixturemodel_used = cached_data.get('mixturemodel_used', '')

                    if mixturemodel_used != mixturemodel_to_use:
                        print "warning, reloaded model used a different mixture model class"
                    should_fit_allitems_model = False

            except IOError:
                print "Error while loading ", caching_emfit_filename, "falling back to computing the EM fits"


    # Load the Fisher Info from cache if exists. If not, compute it.
    if caching_fisherinfo_filename is not None:
        if os.path.exists(caching_fisherinfo_filename):
            # Got file, open it and try to use its contents
            try:
                with open(caching_fisherinfo_filename, 'r') as file_in:
                    # Load and assign values
                    cached_data = pickle.load(file_in)
                    result_fisherinfo_mindist_sigmax_ratio = cached_data['result_fisherinfo_mindist_sigmax_ratio']
                    compute_fisher_info_perratioconj = False

            except IOError:
                print "Error while loading ", caching_fisherinfo_filename, "falling back to computing the Fisher Info"

    if compute_fisher_info_perratioconj:
        # We did not save the Fisher info, but need it if we want to fit the mixture model with fixed kappa. So recompute them using the args_dicts

        result_fisherinfo_mindist_sigmax_ratio = np.empty((enforce_min_distance_space.size, sigmax_space.size, ratio_space.size))

        # Invert the all_args_i -> min_dist, sigmax indexing
        parameters_indirections = data_pbs.loaded_data['parameters_dataset_index']

        # min_dist_i, sigmax_level_i, ratio_i
        for min_dist_i, min_dist in enumerate(enforce_min_distance_space):
            for sigmax_i, sigmax in enumerate(sigmax_space):
                # Get index of first dataset with the current (min_dist, sigmax) (no need for the others, I think)
                arg_index = parameters_indirections[(min_dist, sigmax)][0]

                # Now using this dataset, reconstruct a RandomFactorialNetwork and compute the fisher info
                curr_args = all_args[arg_index]

                for ratio_conj_i, ratio_conj in enumerate(ratio_space):
                    # Update param
                    curr_args['ratio_conj'] = ratio_conj
                    # curr_args['stimuli_generation'] = 'specific_stimuli'

                    (_, _, _, sampler) = launchers.init_everything(curr_args)

                    # Theo Fisher info
                    result_fisherinfo_mindist_sigmax_ratio[min_dist_i, sigmax_i, ratio_conj_i] = sampler.estimate_fisher_info_theocov()

                    print "Min dist: %.2f, Sigmax: %.2f, Ratio: %.2f: %.3f" % (min_dist, sigmax, ratio_conj, result_fisherinfo_mindist_sigmax_ratio[min_dist_i, sigmax_i, ratio_conj_i])


        # Save everything to a file, for faster later plotting
        if caching_fisherinfo_filename is not None:
            try:
                with open(caching_fisherinfo_filename, 'w') as filecache_out:
                    data_cache = dict(result_fisherinfo_mindist_sigmax_ratio=result_fisherinfo_mindist_sigmax_ratio)
                    pickle.dump(data_cache, filecache_out, protocol=2)
            except IOError:
                print "Error writing out to caching file ", caching_fisherinfo_filename


    if plot_per_min_dist_all:
        # Do one plot per min distance.
        for min_dist_i, min_dist in enumerate(enforce_min_distance_space):
            # Show precision
            utils.pcolor_2d_data(result_all_precisions_mean[min_dist_i].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Precision, min_dist=%.3f' % min_dist)
            if savefigs:
                dataio.save_current_figure('precision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Show log precision
            utils.pcolor_2d_data(result_all_precisions_mean[min_dist_i].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Precision, min_dist=%.3f' % min_dist, log_scale=True)
            if savefigs:
                dataio.save_current_figure('logprecision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)


            # Plot estimated model precision
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 0].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='EM precision, min_dist=%.3f' % min_dist, log_scale=False)
            if savefigs:
                dataio.save_current_figure('logemprecision_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Plot estimated Target, nontarget and random mixture components, in multiple subplots
            _, axes = plt.subplots(1, 3, figsize=(18, 6))
            plt.subplots_adjust(left=0.05, right=0.97, wspace=0.3, bottom=0.15)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 1].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Target, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[0], ticks_interpolate=5)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 2].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Nontarget, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[1], ticks_interpolate=5)
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., 3].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='Random, min_dist=%.3f' % min_dist, log_scale=False, ax_handle=axes[2], ticks_interpolate=5)

            if savefigs:
                dataio.save_current_figure('em_mixtureprobs_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

            # Plot Log-likelihood of Mixture model, sanity check
            utils.pcolor_2d_data(result_em_fits_mean[min_dist_i, ..., -1].T, x=ratio_space, y=sigmax_space, xlabel='ratio', ylabel='sigma_x', title='EM loglik, min_dist=%.3f' % min_dist, log_scale=False)
            if savefigs:
                dataio.save_current_figure('em_loglik_permindist_mindist%.2f_ratiosigmax_{label}_{unique_id}.pdf' % min_dist)

    if specific_plots_paper:
        # We need to choose 3 levels of min_distances
        target_sigmax = 0.25
        target_mindist_low = 0.15
        target_mindist_medium = 0.36
        target_mindist_high = 1.5

        sigmax_level_i = np.argmin(np.abs(sigmax_space - target_sigmax))
        min_dist_level_low_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_low))
        min_dist_level_medium_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_medium))
        min_dist_level_high_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_high))

        ## Do for each distance
        # for min_dist_i in [min_dist_level_low_i, min_dist_level_medium_i, min_dist_level_high_i]:
        for min_dist_i in xrange(enforce_min_distance_space.size):
            # Plot precision
            if False:
                utils.plot_mean_std_area(ratio_space, result_all_precisions_mean[min_dist_i, sigmax_level_i], result_all_precisions_std[min_dist_i, sigmax_level_i]) #, xlabel='Ratio conjunctivity', ylabel='Precision of recall')
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                plt.ylim([0, np.max(result_all_precisions_mean[min_dist_i, sigmax_level_i] + result_all_precisions_std[min_dist_i, sigmax_level_i])])

                if savefigs:
                    dataio.save_current_figure('mindist%.2f_precisionrecall_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

            # Plot kappa fitted
            ax_handle = utils.plot_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 0], result_em_fits_std[min_dist_i, sigmax_level_i, :, 0]) #, xlabel='Ratio conjunctivity', ylabel='Fitted kappa')
            # Add distance between items in kappa units
            dist_items_kappa = utils.stddev_to_kappa(enforce_min_distance_space[min_dist_i])
            ax_handle.plot(ratio_space, dist_items_kappa*np.ones(ratio_space.size), 'k--', linewidth=3)
            plt.ylim([-0.1, np.max((np.max(result_em_fits_mean[min_dist_i, sigmax_level_i, :, 0] + result_em_fits_std[min_dist_i, sigmax_level_i, :, 0]), 1.1*dist_items_kappa))])
            # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
            if savefigs:
                dataio.save_current_figure('mindist%.2f_emkappa_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

            # Plot kappa-stddev fitted. Easier to visualize
            ax_handle = utils.plot_mean_std_area(ratio_space, result_em_kappastddev_mean[min_dist_i, sigmax_level_i], result_em_kappastddev_std[min_dist_i, sigmax_level_i]) #, xlabel='Ratio conjunctivity', ylabel='Fitted kappa_stddev')
            # Add distance between items in std dev units
            dist_items_std = (enforce_min_distance_space[min_dist_i])
            ax_handle.plot(ratio_space, dist_items_std*np.ones(ratio_space.size), 'k--', linewidth=3)
            # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
            plt.ylim([0, 1.1*np.max((np.max(result_em_kappastddev_mean[min_dist_i, sigmax_level_i] + result_em_kappastddev_std[min_dist_i, sigmax_level_i]), dist_items_std))])
            if savefigs:
                dataio.save_current_figure('mindist%.2f_emkappastddev_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])


            if False:
                # Plot LLH
                utils.plot_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, -1], result_em_fits_std[min_dist_i, sigmax_level_i, :, -1]) #, xlabel='Ratio conjunctivity', ylabel='Loglikelihood of Mixture model fit')
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emllh_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

                # Plot mixture parameters, std
                utils.plot_multiple_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 1:4].T, result_em_fits_std[min_dist_i, sigmax_level_i, :, 1:4].T)
                plt.ylim([0.0, 1.1])
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                # plt.legend("Target", "Non-target", "Random")
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emprobs_forpaper_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

                # Mixture parameters, SEM
                utils.plot_multiple_mean_std_area(ratio_space, result_em_fits_mean[min_dist_i, sigmax_level_i, :, 1:4].T, result_em_fits_std[min_dist_i, sigmax_level_i, :, 1:4].T/np.sqrt(nb_repetitions))
                plt.ylim([0.0, 1.1])
                # plt.title('Min distance %.3f' % enforce_min_distance_space[min_dist_i])
                # plt.legend("Target", "Non-target", "Random")
                if savefigs:
                    dataio.save_current_figure('mindist%.2f_emprobs_forpaper_sem_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

    if plots_emfit_allitems:
        # We need to choose 3 levels of min_distances
        target_sigmax = 0.25
        target_mindist_low = 0.15
        target_mindist_medium = 0.36
        target_mindist_high = 1.5

        sigmax_level_i = np.argmin(np.abs(sigmax_space - target_sigmax))
        min_dist_level_low_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_low))
        min_dist_level_medium_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_medium))
        min_dist_level_high_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_high))

        min_dist_i_plotting_space = np.array([min_dist_level_low_i, min_dist_level_medium_i, min_dist_level_high_i])

        if should_fit_allitems_model:

            # kappa, mixt_target, mixt_nontargets (K), mixt_random, LL, bic
            # result_emfitallitems = np.empty((min_dist_i_plotting_space.size, ratio_space.size, 2*K+5))*np.nan
            result_emfitallitems = np.empty((enforce_min_distance_space.size, ratio_space.size, K+5))*np.nan

            ## Do for each distance
            # for min_dist_plotting_i, min_dist_i in enumerate(min_dist_i_plotting_space):
            for min_dist_i in xrange(enforce_min_distance_space.size):
                # Fit the mixture model
                for ratio_i, ratio in enumerate(ratio_space):
                    print "Refitting EM all items. Ratio:", ratio, "Dist:", enforce_min_distance_space[min_dist_i]

                    if mixturemodel_to_use == 'allitems_uniquekappa':
                        em_fit = em_circularmixture_allitems_uniquekappa.fit(
                            result_responses_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                            result_target_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                            result_nontargets_all[min_dist_i, sigmax_level_i, ratio_i].transpose((0, 2, 1)).reshape((N*nb_repetitions, K)))
                    elif mixturemodel_to_use == 'allitems_fikappa':
                        em_fit = em_circularmixture_allitems_kappafi.fit(result_responses_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                            result_target_all[min_dist_i, sigmax_level_i, ratio_i].flatten(),
                            result_nontargets_all[min_dist_i, sigmax_level_i, ratio_i].transpose((0, 2, 1)).reshape((N*nb_repetitions, K)),
                            kappa=result_fisherinfo_mindist_sigmax_ratio[min_dist_i, sigmax_level_i, ratio_i])
                    else:
                        raise ValueError("Wrong mixturemodel_to_use, %s" % mixturemodel_to_use)

                    result_emfitallitems[min_dist_i, ratio_i] = [em_fit['kappa'], em_fit['mixt_target']] + em_fit['mixt_nontargets'].tolist() + [em_fit[key] for key in ('mixt_random', 'train_LL', 'bic')]

            # Save everything to a file, for faster later plotting
            if caching_emfit_filename is not None:
                try:
                    with open(caching_emfit_filename, 'w') as filecache_out:
                        data_em = dict(result_emfitallitems=result_emfitallitems, target_sigmax=target_sigmax)
                        pickle.dump(data_em, filecache_out, protocol=2)
                except IOError:
                    print "Error writing out to caching file ", caching_emfit_filename


        ## Plots now, for each distance!
        # for min_dist_plotting_i, min_dist_i in enumerate(min_dist_i_plotting_space):
        for min_dist_i in xrange(enforce_min_distance_space.size):

            # Plot now
            _, ax = plt.subplots()
            ax.plot(ratio_space, result_emfitallitems[min_dist_i, :, 1:5], linewidth=3)
            plt.ylim([0.0, 1.1])
            plt.legend(['Target', 'Nontarget 1', 'Nontarget 2', 'Random'], loc='upper left')
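            # NB: these legend labels assume K == 2 nontarget items, so that columns 1:5
            # of result_emfitallitems are [target, nontarget 1, nontarget 2, random].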

            if savefigs:
                dataio.save_current_figure('mindist%.2f_emprobsfullitems_{label}_{unique_id}.pdf' % enforce_min_distance_space[min_dist_i])

    if plot_min_distance_effect:
        conj_receptive_field_size = 2.*np.pi/((all_args[0]['M']*ratio_space)**0.5)
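        # Heuristic receptive field size of the conjunctive subpopulation: with
        # M * ratio conjunctive units tiling the 2D feature space, the spacing per
        # dimension (and hence the assumed field width) is roughly 2*pi / sqrt(M_conj).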

        target_vs_nontargets_mindist_ratio = result_emfitallitems[..., 1]/np.sum(result_emfitallitems[..., 1:4], axis=-1)
        nontargetsmean_vs_targnontarg_mindist_ratio = np.mean(result_emfitallitems[..., 2:4]/np.sum(result_emfitallitems[..., 1:4], axis=-1)[..., np.newaxis], axis=-1)
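        # result_emfitallitems columns are [kappa, target, nontarget_1..K, random, LL, bic]
        # (built in plots_emfit_allitems above, with K = 2 nontargets here). The two ratios
        # are the target weight, resp. the mean nontarget weight, relative to the summed
        # target + nontarget weights (random component excluded from the denominator).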

        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            # Do one plot per ratio, putting the receptive field size on each
            f, ax = plt.subplots()

            ax.plot(enforce_min_distance_space[1:], target_vs_nontargets_mindist_ratio[1:, ratio_conj_i], linewidth=3, label='target mixture')
            ax.plot(enforce_min_distance_space[1:], nontargetsmean_vs_targnontarg_mindist_ratio[1:, ratio_conj_i], linewidth=3, label='non-target mixture')
            # ax.plot(enforce_min_distance_space[1:], result_emfitallitems[1:, ratio_conj_i, 1:5], linewidth=3)

            ax.axvline(x=conj_receptive_field_size[ratio_conj_i]/2., color='k', linestyle='--', linewidth=2)
            ax.axvline(x=conj_receptive_field_size[ratio_conj_i]*2., color='r', linestyle='--', linewidth=2)

            plt.legend(loc='upper left')
            plt.grid()
            # ax.set_xlabel('Stimuli separation')
            # ax.set_ylabel('Ratio Target to Non-targets')
            plt.axis('tight')
            ax.set_ylim([0.0, 1.0])
            ax.set_xlim([enforce_min_distance_space[1:].min(), enforce_min_distance_space[1:].max()])

            if savefigs:
                dataio.save_current_figure('ratio%.2f_mindistpred_ratiotargetnontarget_{label}_{unique_id}.pdf' % ratio_conj)


    if compute_bootstraps:
        ## Bootstrap evaluation

        # We need to choose 3 levels of min_distances
        target_sigmax = 0.25
        target_mindist_low = 0.15
        target_mindist_medium = 0.5
        target_mindist_high = 1.

        sigmax_level_i = np.argmin(np.abs(sigmax_space - target_sigmax))
        min_dist_level_low_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_low))
        min_dist_level_medium_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_medium))
        min_dist_level_high_i = np.argmin(np.abs(enforce_min_distance_space - target_mindist_high))

        # cache_bootstrap_fn = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'outputs', 'cache_bootstrap.pickle')
        cache_bootstrap_fn = '/Users/loicmatthey/Dropbox/UCL/1-phd/Work/Visual_working_memory/code/git-bayesian-visual-working-memory/Experiments/specific_stimuli/specific_stimuli_corrected_mixed_sigmaxmindistance_autoset_repetitions5mult_collectall_281113_outputs/cache_bootstrap.pickle'
        try:
            with open(cache_bootstrap_fn, 'r') as file_in:
                # Load and assign values
                cached_data = pickle.load(file_in)
                bootstrap_ecdf_bays_sigmax_T = cached_data['bootstrap_ecdf_bays_sigmax_T']
                bootstrap_ecdf_allitems_sum_sigmax_T = cached_data['bootstrap_ecdf_allitems_sum_sigmax_T']
                bootstrap_ecdf_allitems_all_sigmax_T = cached_data['bootstrap_ecdf_allitems_all_sigmax_T']
                should_fit_bootstrap = False

        except IOError:
            print "Error while loading ", cache_bootstrap_fn

        ratio_i = 0

        # bootstrap_allitems_nontargets_allitems_uniquekappa = em_circularmixture_allitems_uniquekappa.bootstrap_nontarget_stat(
        # result_responses_all[min_dist_level_low_i, sigmax_level_i, ratio_i].flatten(),
        # result_target_all[min_dist_level_low_i, sigmax_level_i, ratio_i].flatten(),
        # result_nontargets_all[min_dist_level_low_i, sigmax_level_i, ratio_i].transpose((0, 2, 1)).reshape((N*nb_repetitions, K)),
        # sumnontargets_bootstrap_ecdf=bootstrap_ecdf_allitems_sum_sigmax_T[sigmax_level_i][K]['ecdf'],
        # allnontargets_bootstrap_ecdf=bootstrap_ecdf_allitems_all_sigmax_T[sigmax_level_i][K]['ecdf']

        # TODO FINISH HERE

    variables_to_save = ['nb_repetitions']

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)

        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='specific_stimuli')


    plt.show()


    return locals()


def plots_ratioMscaling(data_pbs, generator_module=None):
    '''
        Reload and plot precision/fits of a Mixed code.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plots_pcolor_all = False
    plots_effect_M_target_precision = False
    plots_multiple_precisions = False

    plots_effect_M_target_kappa = False

    plots_subpopulations_effects = False

    plots_subpopulations_effects_kappa_fi = True
    compute_fisher_info_perratioconj = True
    caching_fisherinfo_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_fisherinfo.pickle')

    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_precisions_mean = (utils.nanmean(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_all_precisions_std = (utils.nanstd(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_em_fits_mean = (utils.nanmean(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_em_fits_std = (utils.nanstd(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))

    all_args = data_pbs.loaded_data['args_list']

    result_em_fits_kappa = result_em_fits_mean[..., 0]

    M_space = data_pbs.loaded_data['parameters_uniques']['M'].astype(int)
    ratio_space = data_pbs.loaded_data['parameters_uniques']['ratio_conj']
    num_repetitions = generator_module.num_repetitions

    print M_space
    print ratio_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape

    dataio = DataIO.DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    target_precision = 100.
    dist_to_target_precision = (result_all_precisions_mean - target_precision)**2.
    best_dist_to_target_precision = np.argmin(dist_to_target_precision, axis=1)
    MAX_DISTANCE = 100.

    ratio_target_precision_given_M = np.ma.masked_where(dist_to_target_precision[np.arange(dist_to_target_precision.shape[0]), best_dist_to_target_precision] > MAX_DISTANCE, ratio_space[best_dist_to_target_precision])
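    # For each M, pick the ratio whose mean precision is closest to target_precision,
    # and mask out entries where even that best match lies further than MAX_DISTANCE
    # (in squared precision units) from the target.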

    if plots_pcolor_all:
        # Check evolution of precision given M and ratio
        utils.pcolor_2d_data(result_all_precisions_mean, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='precision wrt M / ratio')
        if savefigs:
            dataio.save_current_figure('precision_log_pcolor_{label}_{unique_id}.pdf')

        # See distance to target precision evolution
        utils.pcolor_2d_data(dist_to_target_precision, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='Dist to target precision %d' % target_precision)
        if savefigs:
            dataio.save_current_figure('dist_targetprecision_log_pcolor_{label}_{unique_id}.pdf')


        # Show kappa
        utils.pcolor_2d_data(result_em_fits_kappa, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='kappa wrt M / ratio')
        if savefigs:
            dataio.save_current_figure('kappa_log_pcolor_{label}_{unique_id}.pdf')

        utils.pcolor_2d_data((result_em_fits_kappa - 200)**2., log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='dist to kappa')
        if savefigs:
            dataio.save_current_figure('dist_kappa_log_pcolor_{label}_{unique_id}.pdf')


    if plots_effect_M_target_precision:
        def plot_ratio_target_precision(ratio_target_precision_given_M, target_precision):
            f, ax = plt.subplots()
            ax.plot(M_space, ratio_target_precision_given_M)
            ax.set_xlabel('M')
            ax.set_ylabel('Optimal ratio')
            ax.set_title('Optimal Ratio for precision %d' % target_precision)

            if savefigs:
                dataio.save_current_figure('effect_ratio_M_targetprecision%d_{label}_{unique_id}.pdf' % target_precision)

        plot_ratio_target_precision(ratio_target_precision_given_M, target_precision)

        if plots_multiple_precisions:
            target_precisions = np.array([100, 200, 300, 500, 1000])
            for target_precision in target_precisions:
                dist_to_target_precision = (result_all_precisions_mean - target_precision)**2.
                best_dist_to_target_precision = np.argmin(dist_to_target_precision, axis=1)
                ratio_target_precision_given_M = np.ma.masked_where(dist_to_target_precision[np.arange(dist_to_target_precision.shape[0]), best_dist_to_target_precision] > MAX_DISTANCE, ratio_space[best_dist_to_target_precision])

                # replot
                plot_ratio_target_precision(ratio_target_precision_given_M, target_precision)

    if plots_effect_M_target_kappa:
        def plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa):
            f, ax = plt.subplots()
            ax.plot(M_space, ratio_target_kappa_given_M)
            ax.set_xlabel('M')
            ax.set_ylabel('Optimal ratio')
            ax.set_title('Optimal Ratio for kappa %d' % target_kappa)

            if savefigs:
                dataio.save_current_figure('effect_ratio_M_targetkappa%d_{label}_{unique_id}.pdf' % target_kappa)

        target_kappas = np.array([100, 200, 300, 500, 1000, 3000])
        for target_kappa in target_kappas:
            dist_to_target_kappa = (result_em_fits_kappa - target_kappa)**2.
            best_dist_to_target_kappa = np.argmin(dist_to_target_kappa, axis=1)
            ratio_target_kappa_given_M = np.ma.masked_where(dist_to_target_kappa[np.arange(dist_to_target_kappa.shape[0]), best_dist_to_target_kappa] > MAX_DISTANCE, ratio_space[best_dist_to_target_kappa])

            # replot
            plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa)

    if plots_subpopulations_effects:
        # result_all_precisions_mean
        for M_tot_selected_i, M_tot_selected in enumerate(M_space[::2]):

            M_conj_space = ((1.-ratio_space)*M_tot_selected).astype(int)
            M_feat_space = M_tot_selected - M_conj_space

            f, axes = plt.subplots(2, 2)
            axes[0, 0].plot(ratio_space, result_all_precisions_mean[2*M_tot_selected_i])
            axes[0, 0].set_xlabel('ratio')
            axes[0, 0].set_title('Measured precision')

            axes[1, 0].plot(ratio_space, M_conj_space**2*M_feat_space)
            axes[1, 0].set_xlabel('ratio')
            axes[1, 0].set_title('M_c**2*M_f')

            axes[0, 1].plot(ratio_space, M_conj_space**2.)
            axes[0, 1].set_xlabel('ratio')
            axes[0, 1].set_title('M_c**2')

            axes[1, 1].plot(ratio_space, M_feat_space)
            axes[1, 1].set_xlabel('ratio')
            axes[1, 1].set_title('M_f')

            f.suptitle('M_tot %d' % M_tot_selected, fontsize=15)
            f.set_tight_layout(True)

            if savefigs:
                dataio.save_current_figure('scaling_precision_subpop_Mtot%d_{label}_{unique_id}.pdf' % M_tot_selected)

            plt.close(f)

    if plots_subpopulations_effects_kappa_fi:
        # From cache
        if caching_fisherinfo_filename is not None:
            if os.path.exists(caching_fisherinfo_filename):
                # Got file, open it and try to use its contents
                try:
                    with open(caching_fisherinfo_filename, 'r') as file_in:
                        # Load and assign values
                        cached_data = pickle.load(file_in)
                        result_fisherinfo_Mratio = cached_data['result_fisherinfo_Mratio']
                        compute_fisher_info_perratioconj = False

                except IOError:
                    print "Error while loading ", caching_fisherinfo_filename, "falling back to computing the Fisher Info"

        if compute_fisher_info_perratioconj:
            # We did not save the Fisher info, but need it if we want to fit the mixture model with fixed kappa. So recompute them using the args_dicts

            result_fisherinfo_Mratio = np.empty((M_space.size, ratio_space.size))

            # Invert the all_args_i -> M, ratio_conj direction
            parameters_indirections = data_pbs.loaded_data['parameters_dataset_index']

            for M_i, M in enumerate(M_space):
                for ratio_conj_i, ratio_conj in enumerate(ratio_space):
                    # Get index of first dataset with the current ratio_conj (no need for the others, I think)
                    try:
                        arg_index = parameters_indirections[(M, ratio_conj)][0]

                        # Now using this dataset, reconstruct a RandomFactorialNetwork and compute the fisher info
                        curr_args = all_args[arg_index]

                        # curr_args['stimuli_generation'] = lambda T: np.linspace(-np.pi*0.6, np.pi*0.6, T)

                        (_, _, _, sampler) = launchers.init_everything(curr_args)

                        # Theo Fisher info
                        result_fisherinfo_Mratio[M_i, ratio_conj_i] = sampler.estimate_fisher_info_theocov()

                        # del curr_args['stimuli_generation']
                    except KeyError:
                        result_fisherinfo_Mratio[M_i, ratio_conj_i] = np.nan


            # Save everything to a file, for faster later plotting
            if caching_fisherinfo_filename is not None:
                try:
                    with open(caching_fisherinfo_filename, 'w') as filecache_out:
                        data_cache = dict(result_fisherinfo_Mratio=result_fisherinfo_Mratio)
                        pickle.dump(data_cache, filecache_out, protocol=2)
                except IOError:
                    print "Error writing out to caching file ", caching_fisherinfo_filename

        # result_em_fits_kappa
        if False:

            for M_tot_selected_i, M_tot_selected in enumerate(M_space[::2]):

                M_conj_space = ((1.-ratio_space)*M_tot_selected).astype(int)
                M_feat_space = M_tot_selected - M_conj_space

                f, axes = plt.subplots(2, 2)
                axes[0, 0].plot(ratio_space, result_em_fits_kappa[2*M_tot_selected_i])
                axes[0, 0].set_xlabel('ratio')
                axes[0, 0].set_title('Fitted kappa')

                axes[1, 0].plot(ratio_space, utils.stddev_to_kappa(1./result_fisherinfo_Mratio[2*M_tot_selected_i]**0.5))
                axes[1, 0].set_xlabel('ratio')
                axes[1, 0].set_title('kappa_FI_mixed')

                f.suptitle('M_tot %d' % M_tot_selected, fontsize=15)
                f.set_tight_layout(True)

                if savefigs:
                    dataio.save_current_figure('scaling_kappa_subpop_Mtot%d_{label}_{unique_id}.pdf' % M_tot_selected)

                plt.close(f)

        utils.pcolor_2d_data((result_fisherinfo_Mratio - 2000)**2., log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='Dist to Fisher info 2000')
        if savefigs:
            dataio.save_current_figure('dist2000_fi_log_pcolor_{label}_{unique_id}.pdf')




    all_args = data_pbs.loaded_data['args_list']
    variables_to_save = []

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)

        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='ratio_scaling_M')

    plt.show()

    return locals()


def receptivesize_effect_plots(variables_launcher_running, plotting_parameters):
    '''
        Do some plots (possibly live) with outputs from launcher_do_receptivesize_effect

    '''

    ### Load the experimental data for the plots
    # data_gorgo11 = load_experimental_data.load_data_gorgo11(fit_mixture_model=True)
    # gorgo11_T_space = data_gorgo11['data_to_fit']['n_items']
    # gorgo11_emfits_meanstd = data_gorgo11['em_fits_nitems_arrays']
    # gorgo11_emfits_meanstd['mean'][2, 0] = 0.0
    # gorgo11_emfits_meanstd['std'][2, 0] = 0.0

    # data_bays09 = load_experimental_data.load_data_bays09(fit_mixture_model=True)
    # bays09_T_space = data_bays09['data_to_fit']['n_items']
    # bays09_emfits_meanstd = data_bays09['em_fits_nitems_arrays']

    # Compute the "optimal" rcscale
    if variables_launcher_running['all_parameters']['code_type'] == 'conj':
        optimal_scale = utils.stddev_to_kappa(2.*np.pi/int(variables_launcher_running['all_parameters']['M']**0.5))
        optimal_scale_corrected = utils.stddev_to_kappa(np.pi/(int(variables_launcher_running['all_parameters']['M']**0.5)))
    elif variables_launcher_running['all_parameters']['code_type'] == 'feat':
        optimal_scale = utils.stddev_to_kappa(2.*np.pi/int(variables_launcher_running['all_parameters']['M']/2.))
        optimal_scale_corrected = utils.stddev_to_kappa(np.pi/int(variables_launcher_running['all_parameters']['M']/2.))
    else:
        optimal_scale = 0.0
        optimal_scale_corrected = 0.0

    ### Now do the plots
    current_axes = plotting_parameters['axes']
    dataio = variables_launcher_running['dataio']
    rcscale_space = variables_launcher_running['rcscale_space']

    result_precision_stats = utils.nanstats(variables_launcher_running['result_all_precisions'], axis=-1)
    result_em_fits_stats = utils.nanstats(variables_launcher_running['result_em_fits'], axis=-1)
    result_marginal_fi_stats = utils.nanstats(variables_launcher_running['result_marginal_inv_fi'][:, 2], axis=-1)

    plt.ion()

    # Precision wrt rcscale_space
    def plot_precision_rcscale(ax=None):
        if ax is not None:
            plt.figure(ax.get_figure().number)
            ax.hold(False)

        # Curve of precision evolution.
        ax = utils.plot_mean_std_area(rcscale_space, result_precision_stats['mean'], result_precision_stats['std'], linewidth=3, fmt='o-', markersize=8, label='Precision', ax_handle=ax)

        ax.hold(True)

        ax.axvline(x=optimal_scale, color='r', linewidth=3)
        ax.axvline(x=optimal_scale_corrected, color='k', linewidth=3)

        ax.legend()
        ax.set_title("Precision {code_type} {M} {sigmax:.3f} {sigmay:.2f}".format(**variables_launcher_running['all_parameters']))


        ax.set_xlim(rcscale_space.min(), rcscale_space.max())
        ax.set_ylim(bottom=0.0)

        ax.get_figure().canvas.draw()

        dataio.save_current_figure('precision_rcscale_{code_type}_M{M}_sigmax{sigmax}_sigmay{sigmay}_{{label}}_{{unique_id}}.pdf'.format(**variables_launcher_running['all_parameters']))

        return ax

    # Precision wrt rcscale_space
    def plot_fisherinfo_rcscale(ax=None):
        if ax is not None:
            plt.figure(ax.get_figure().number)
            ax.hold(False)

        # Curve of precision evolution.
        ax = utils.plot_mean_std_area(rcscale_space, result_marginal_fi_stats['mean'], result_marginal_fi_stats['std'], linewidth=3, fmt='o-', markersize=8, label='Fisher info', ax_handle=ax)

        ax.hold(True)

        ax.axvline(x=optimal_scale, color='r', linewidth=3)
        ax.axvline(x=optimal_scale_corrected, color='k', linewidth=3)

        ax.legend()
        ax.set_title("FI {code_type} {M} {sigmax:.3f} {sigmay:.2f}".format(**variables_launcher_running['all_parameters']))

        ax.set_xlim(rcscale_space.min(), rcscale_space.max())
        ax.set_ylim(bottom=0.0)
        ax.get_figure().canvas.draw()

        dataio.save_current_figure('fi_rcscale_{code_type}_M{M}_sigmax{sigmax}_sigmay{sigmay}_{{label}}_{{unique_id}}.pdf'.format(**variables_launcher_running['all_parameters']))

        return ax

    # Memory curve kappa
    def plot_kappa_rcscale(ax=None):

        if ax is not None:
            plt.figure(ax.get_figure().number)
            ax.hold(False)

        ax = utils.plot_mean_std_area(rcscale_space, result_em_fits_stats['mean'][..., 0], result_em_fits_stats['std'][..., 0], linewidth=3, fmt='o-', markersize=8, label='Fitted $\kappa$ $[rad^{-2}]$', ax_handle=ax)

        ax.hold(True)

        ax.axvline(x=optimal_scale, color='r', linewidth=3)
        ax.axvline(x=optimal_scale_corrected, color='k', linewidth=3)

        ax.legend()
        ax.set_xlim(rcscale_space.min(), rcscale_space.max())
        ax.set_ylim(bottom=0.0)

        ax.get_figure().canvas.draw()

        ax.set_title("kappa {code_type} {M} {sigmax:.3f} {sigmay:.2f}".format(**variables_launcher_running['all_parameters']))

        dataio.save_current_figure('kappa_rcscale_{code_type}_M{M}_sigmax{sigmax}_sigmay{sigmay}_{{label}}_{{unique_id}}.pdf'.format(**variables_launcher_running['all_parameters']))

        return ax

    # Plot EM Mixtures proportions
    def plot_mixtures_rcscale(ax=None):

        if ax is None:
            _, ax = plt.subplots()

        if ax is not None:
            plt.figure(ax.get_figure().number)
            ax.hold(False)

        result_em_fits_stats['mean'][np.isnan(result_em_fits_stats['mean'])] = 0.0
        result_em_fits_stats['std'][np.isnan(result_em_fits_stats['std'])] = 0.0

        utils.plot_mean_std_area(rcscale_space, result_em_fits_stats['mean'][..., 1], result_em_fits_stats['std'][..., 1], xlabel='Receptive field scale', ylabel="Mixture probabilities", ax_handle=ax, linewidth=3, fmt='o-', markersize=5, label='Target')
        ax.hold(True)
        utils.plot_mean_std_area(rcscale_space, result_em_fits_stats['mean'][..., 2], result_em_fits_stats['std'][..., 2], xlabel='Receptive field scale', ylabel="Mixture probabilities", ax_handle=ax, linewidth=3, fmt='o-', markersize=5, label='Nontarget')
        utils.plot_mean_std_area(rcscale_space, result_em_fits_stats['mean'][..., 3], result_em_fits_stats['std'][..., 3], xlabel='Receptive field scale', ylabel="Mixture probabilities", ax_handle=ax, linewidth=3, fmt='o-', markersize=5, label='Random')

        ax.axvline(x=optimal_scale, color='r', linewidth=3)
        ax.axvline(x=optimal_scale_corrected, color='k', linewidth=3)

        ax.set_xlim(rcscale_space.min(), rcscale_space.max())
        ax.set_ylim(bottom=0.0, top=1.0)

        ax.legend(prop={'size':15})
        ax.set_title("mixts {code_type} {M} {sigmax:.3f} {sigmay:.2f}".format(**variables_launcher_running['all_parameters']))
        ax.get_figure().canvas.draw()

        dataio.save_current_figure('em_mixts_rcscale_{code_type}_M{M}_sigmax{sigmax}_sigmay{sigmay}_{{label}}_{{unique_id}}.pdf'.format(**variables_launcher_running['all_parameters']))

        return ax

    # Do all plots for all datasets
    current_axes['precision_rcscale'] = plot_precision_rcscale(ax=current_axes.get('precision_rcscale', None))
    current_axes['fisherinfo_rcscale'] = plot_fisherinfo_rcscale(ax=current_axes.get('fisherinfo_rcscale', None))
    current_axes['kappa_rcscale'] = plot_kappa_rcscale(ax=current_axes.get('kappa_rcscale', None))
    current_axes['mixtures_rcscale'] = plot_mixtures_rcscale(ax=current_axes.get('mixtures_rcscale', None))


def plots_misbinding_logposterior(data_pbs, generator_module=None):
    '''
        Reload 3D volume runs from PBS and plot them

    '''


    #### SETUP
    #
    savedata = False
    savefigs = True

    plot_logpost = False
    plot_error = False
    plot_mixtmodel = True
    plot_hist_responses_fisherinfo = True
    compute_plot_bootstrap = False
    compute_fisher_info_perratioconj = True

    # mixturemodel_to_use = 'original'
    mixturemodel_to_use = 'allitems'
    # mixturemodel_to_use = 'allitems_kappafi'

    caching_fisherinfo_filename = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'cache_fisherinfo.pickle')


    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_log_posterior = np.squeeze(data_pbs.dict_arrays['result_all_log_posterior']['results'])
    result_all_thetas = np.squeeze(data_pbs.dict_arrays['result_all_thetas']['results'])

    ratio_space = data_pbs.loaded_data['parameters_uniques']['ratio_conj']

    print ratio_space
    print result_all_log_posterior.shape

    N = result_all_thetas.shape[-1]

    result_prob_wrong = np.zeros((ratio_space.size, N))
    result_em_fits = np.empty((ratio_space.size, 6))*np.nan
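    # result_em_fits columns: [kappa, target, nontarget, random, train_LL, bic]
    # (filled per ratio in the plot_mixtmodel block below).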

    all_args = data_pbs.loaded_data['args_list']

    fixed_means = [-np.pi*0.6, np.pi*0.6]
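    # Two fixed item locations used for the misbinding stimuli; fixed_means[1] is
    # treated as the target and fixed_means[0] as the nontarget in the fits/plots below.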
    all_angles = np.linspace(-np.pi, np.pi, result_all_log_posterior.shape[-1])

    dataio = DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])


    plt.rcParams['font.size'] = 18


    if plot_hist_responses_fisherinfo:

        # From cache
        if caching_fisherinfo_filename is not None:
            if os.path.exists(caching_fisherinfo_filename):
                # Got file, open it and try to use its contents
                try:
                    with open(caching_fisherinfo_filename, 'r') as file_in:
                        # Load and assign values
                        cached_data = pickle.load(file_in)
                        result_fisherinfo_ratio = cached_data['result_fisherinfo_ratio']
                        compute_fisher_info_perratioconj = False

                except IOError:
                    print "Error while loading ", caching_fisherinfo_filename, "falling back to computing the Fisher Info"

        if compute_fisher_info_perratioconj:
            # We did not save the Fisher info, but need it if we want to fit the mixture model with fixed kappa. So recompute them using the args_dicts

            result_fisherinfo_ratio = np.empty(ratio_space.shape)

            # Invert the all_args_i -> ratio_conj direction
            parameters_indirections = data_pbs.loaded_data['parameters_dataset_index']

            for ratio_conj_i, ratio_conj in enumerate(ratio_space):
                # Get index of first dataset with the current ratio_conj (no need for the others, I think)
                arg_index = parameters_indirections[(ratio_conj,)][0]

                # Now using this dataset, reconstruct a RandomFactorialNetwork and compute the fisher info
                curr_args = all_args[arg_index]

                curr_args['stimuli_generation'] = lambda T: np.linspace(-np.pi*0.6, np.pi*0.6, T)

                (random_network, data_gen, stat_meas, sampler) = launchers.init_everything(curr_args)

                # Theo Fisher info
                result_fisherinfo_ratio[ratio_conj_i] = sampler.estimate_fisher_info_theocov()

                del curr_args['stimuli_generation']

            # Save everything to a file, for faster later plotting
            if caching_fisherinfo_filename is not None:
                try:
                    with open(caching_fisherinfo_filename, 'w') as filecache_out:
                        data_cache = dict(result_fisherinfo_ratio=result_fisherinfo_ratio)
                        pickle.dump(data_cache, filecache_out, protocol=2)
                except IOError:
                    print "Error writing out to caching file ", caching_fisherinfo_filename

        # Now plots. Do histograms of responses (around -pi/6 and pi/6), add Von Mises derived from Theo FI on top, and vertical lines for the correct target/nontarget angles.
        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            # Histogram
            ax = utils.hist_angular_data(result_all_thetas[ratio_conj_i], bins=100, title='ratio %.2f, fi %.0f' % (ratio_conj, result_fisherinfo_ratio[ratio_conj_i]))
            bar_heights, _, _ = utils.histogram_binspace(result_all_thetas[ratio_conj_i], bins=100, norm='density')
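            # Density-normalised bar heights, used below to scale the overlaid
            # predicted response distribution to the histogram.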

            # Add Fisher info prediction on top
            x = np.linspace(-np.pi, np.pi, 1000)
            if result_fisherinfo_ratio[ratio_conj_i] < 700:
                # Von Mises PDF
                utils.plot_vonmises_pdf(x, utils.stddev_to_kappa(1./result_fisherinfo_ratio[ratio_conj_i]**0.5), mu=fixed_means[-1], ax_handle=ax, linewidth=3, color='r', scale=np.max(bar_heights), fmt='-')
            else:
                # Switch to Gaussian instead
                utils.plot_normal_pdf(x, mu=fixed_means[-1], std=1./result_fisherinfo_ratio[ratio_conj_i]**0.5, ax_handle=ax, linewidth=3, color='r', scale=np.max(bar_heights), fmt='-')
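            # (For very large concentrations a Von Mises is essentially Gaussian with
            # std 1/sqrt(FI); the switch above presumably avoids numerical overflow in
            # the Von Mises normalisation, with 700 as a pragmatic cutoff.)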

            # ax.set_xticks([])
            # ax.set_yticks([])

            # Add vertical line to correct target/nontarget
            ax.axvline(x=fixed_means[0], color='g', linewidth=2)
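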
            ax.axvline(x=fixed_means[1], color='r', linewidth=2)

            ax.get_figure().canvas.draw()

            if savefigs:
                # plt.tight_layout()
                dataio.save_current_figure('results_misbinding_histresponses_vonmisespdf_ratioconj%.2f_{label}_{unique_id}.pdf' % (ratio_conj))



    if plot_logpost:
        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            # ax = utils.plot_mean_std_area(all_angles, nanmean(result_all_log_posterior[ratio_conj_i], axis=0), nanstd(result_all_log_posterior[ratio_conj_i], axis=0))

            # ax.set_xlim((-np.pi, np.pi))
            # ax.set_xticks((-np.pi, -np.pi / 2, 0, np.pi / 2., np.pi))
            # ax.set_xticklabels((r'$-\pi$', r'$-\frac{\pi}{2}$', r'$0$', r'$\frac{\pi}{2}$', r'$\pi$'))
            # ax.set_yticks(())

            # ax.get_figure().canvas.draw()

            # if savefigs:
            #     dataio.save_current_figure('results_misbinding_logpost_ratioconj%.2f_{label}_global_{unique_id}.pdf' % ratio_conj)


            # Compute the probability of answering wrongly (from fitting mixture distrib onto posterior)
            for n in xrange(result_all_log_posterior.shape[1]):
                result_prob_wrong[ratio_conj_i, n], _, _ = utils.fit_gaussian_mixture_fixedmeans(all_angles, np.exp(result_all_log_posterior[ratio_conj_i, n]), fixed_means=fixed_means, normalise=True, return_fitted_data=False, should_plot=False)

        # ax = utils.plot_mean_std_area(ratio_space, nanmean(result_prob_wrong, axis=-1), nanstd(result_prob_wrong, axis=-1))
        plt.figure()
        plt.plot(ratio_space, utils.nanmean(result_prob_wrong, axis=-1))

        # ax.get_figure().canvas.draw()
        if savefigs:
            dataio.save_current_figure('results_misbinding_probwrongpost_allratioconj_{label}_global_{unique_id}.pdf')

    if plot_error:

        ## Compute Standard deviation/precision from samples and plot it as a function of ratio_conj
        stats = utils.compute_mean_std_circular_data(utils.wrap_angles(result_all_thetas - fixed_means[1]).T)

        f = plt.figure()
        plt.plot(ratio_space, stats['std'])
        plt.ylabel('Standard deviation [rad]')

        if savefigs:
            dataio.save_current_figure('results_misbinding_stddev_allratioconj_{label}_global_{unique_id}.pdf')

        f = plt.figure()
        plt.plot(ratio_space, utils.compute_angle_precision_from_std(stats['std'], square_precision=False), linewidth=2)
        plt.ylabel('Precision [$1/rad$]')
        plt.xlabel('Proportion of conjunctive units')
        plt.grid()

        if savefigs:
            dataio.save_current_figure('results_misbinding_precision_allratioconj_{label}_global_{unique_id}.pdf')

        ## Compute the probability of misbinding
        # 1) Just count samples below a fixed threshold (1 rad here) / samples tot
        # 2) Fit a mixture model, average over mixture probabilities
        prob_smaller0 = np.sum(result_all_thetas <= 1, axis=1)/float(result_all_thetas.shape[1])

        em_centers = np.zeros((ratio_space.size, 2))
        em_covs = np.zeros((ratio_space.size, 2))
        em_pk = np.zeros((ratio_space.size, 2))
        em_ll = np.zeros(ratio_space.size)
        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            cen_lst, cov_lst, em_pk[ratio_conj_i], em_ll[ratio_conj_i] = pygmm.em(result_all_thetas[ratio_conj_i, np.newaxis].T, K = 2, max_iter = 400, init_kw={'cluster_init':'fixed', 'fixed_means': fixed_means})

            em_centers[ratio_conj_i] = np.array(cen_lst).flatten()
            em_covs[ratio_conj_i] = np.array(cov_lst).flatten()

        # print em_centers
        # print em_covs
        # print em_pk

        f = plt.figure()
        plt.plot(ratio_space, prob_smaller0)
        plt.ylabel('Misbound proportion')
        if savefigs:
            dataio.save_current_figure('results_misbinding_countsmaller0_allratioconj_{label}_global_{unique_id}.pdf')

        f = plt.figure()
        plt.plot(ratio_space, np.max(em_pk, axis=-1), 'g', linewidth=2)
        plt.ylabel('Mixture proportion, correct')
        plt.xlabel('Proportion of conjunctive units')
        plt.grid()
        if savefigs:
            dataio.save_current_figure('results_misbinding_emmixture_allratioconj_{label}_global_{unique_id}.pdf')


        # Put everything on one figure
        f = plt.figure(figsize=(10, 6))
        norm_for_plot = lambda x: (x - np.min(x))/np.max((x - np.min(x)))
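        # Rescale each curve to [0, 1] so the different metrics can be compared on
        # a single set of axes.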
        plt.plot(ratio_space, norm_for_plot(stats['std']), ratio_space, norm_for_plot(utils.compute_angle_precision_from_std(stats['std'], square_precision=False)), ratio_space, norm_for_plot(prob_smaller0), ratio_space, norm_for_plot(em_pk[:, 1]), ratio_space, norm_for_plot(em_pk[:, 0]))
        plt.legend(('Std dev', 'Precision', 'Prob smaller 1', 'Mixture proportion correct', 'Mixture proportion misbinding'))
        # plt.plot(ratio_space, norm_for_plot(compute_angle_precision_from_std(stats['std'], square_precision=False)), ratio_space, norm_for_plot(em_pk[:, 1]), linewidth=2)
        # plt.legend(('Precision', 'Mixture proportion correct'), loc='best')
        plt.grid()
        if savefigs:
            dataio.save_current_figure('results_misbinding_allmetrics_allratioconj_{label}_global_{unique_id}.pdf')


    if plot_mixtmodel:
        # Fit Paul's model
        target_angle = np.ones(N)*fixed_means[1]
        nontarget_angles = np.ones((N, 1))*fixed_means[0]

        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            print "Ratio: ", ratio_conj

            responses = result_all_thetas[ratio_conj_i]

            if mixturemodel_to_use == 'allitems_kappafi':
                curr_params_fit = em_circularmixture_allitems_kappafi.fit(responses, target_angle, nontarget_angles, kappa=result_fisherinfo_ratio[ratio_conj_i])
            elif mixturemodel_to_use == 'allitems':
                curr_params_fit = em_circularmixture_allitems_uniquekappa.fit(responses, target_angle, nontarget_angles)
            else:
                curr_params_fit = em_circularmixture.fit(responses, target_angle, nontarget_angles)

            result_em_fits[ratio_conj_i] = [curr_params_fit['kappa'], curr_params_fit['mixt_target']] + utils.arrnum_to_list(curr_params_fit['mixt_nontargets']) + [curr_params_fit[key] for key in ('mixt_random', 'train_LL', 'bic')]

            print curr_params_fit


        if False:
            f, ax = plt.subplots()
            ax2 = ax.twinx()

            # left axis, kappa
            ax = utils.plot_mean_std_area(ratio_space, result_em_fits[:, 0], 0*result_em_fits[:, 0], xlabel='Proportion of conjunctive units', ylabel="Inverse variance $[rad^{-2}]$", ax_handle=ax, linewidth=3, fmt='o-', markersize=8, label='Fitted kappa', color='k')

            # Right axis, mixture probabilities
            utils.plot_mean_std_area(ratio_space, result_em_fits[:, 1], 0*result_em_fits[:, 1], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", ax_handle=ax2, linewidth=3, fmt='o-', markersize=8, label='Target')
            utils.plot_mean_std_area(ratio_space, result_em_fits[:, 2], 0*result_em_fits[:, 2], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", ax_handle=ax2, linewidth=3, fmt='o-', markersize=8, label='Nontarget')
            utils.plot_mean_std_area(ratio_space, result_em_fits[:, 3], 0*result_em_fits[:, 3], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", ax_handle=ax2, linewidth=3, fmt='o-', markersize=8, label='Random')

            lines, labels = ax.get_legend_handles_labels()
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax.legend(lines + lines2, labels + labels2, fontsize=12, loc='right')

            # ax.set_xlim([0.9, 5.1])
            # ax.set_xticks(range(1, 6))
            # ax.set_xticklabels(range(1, 6))
            plt.grid()

            f.canvas.draw()

        if True:
            # Mixture probabilities
            ax = utils.plot_mean_std_area(ratio_space, result_em_fits[:, 1], 0*result_em_fits[:, 1], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", linewidth=3, fmt='-', markersize=8, label='Target')
            utils.plot_mean_std_area(ratio_space, result_em_fits[:, 2], 0*result_em_fits[:, 2], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", ax_handle=ax, linewidth=3, fmt='-', markersize=8, label='Nontarget')
            utils.plot_mean_std_area(ratio_space, result_em_fits[:, 3], 0*result_em_fits[:, 3], xlabel='Proportion of conjunctive units', ylabel="Mixture probabilities", ax_handle=ax, linewidth=3, fmt='-', markersize=8, label='Random')

            ax.legend(loc='right')

            # ax.set_xlim([0.9, 5.1])
            # ax.set_xticks(range(1, 6))
            # ax.set_xticklabels(range(1, 6))
            plt.grid()

            if savefigs:
                dataio.save_current_figure('results_misbinding_emmixture_allratioconj_{label}_global_{unique_id}.pdf')

        if True:
            # Kappa
            # ax = utils.plot_mean_std_area(ratio_space, result_em_fits[:, 0], 0*result_em_fits[:, 0], xlabel='Proportion of conjunctive units', ylabel="$\kappa [rad^{-2}]$", linewidth=3, fmt='-', markersize=8, label='Kappa')
            ax = utils.plot_mean_std_area(ratio_space, utils.kappa_to_stddev(result_em_fits[:, 0]), 0*result_em_fits[:, 2], xlabel='Proportion of conjunctive units', ylabel="Standard deviation [rad]", linewidth=3, fmt='-', markersize=8, label='Mixture model $\kappa$')

            # Add Fisher Info theo
            ax = utils.plot_mean_std_area(ratio_space, utils.kappa_to_stddev(result_fisherinfo_ratio), 0*result_em_fits[:, 2], xlabel='Proportion of conjunctive units', ylabel="Standard deviation [rad]", linewidth=3, fmt='-', markersize=8, label='Fisher Information', ax_handle=ax)

            ax.legend(loc='best')

            # ax.set_xlim([0.9, 5.1])
            # ax.set_xticks(range(1, 6))
            # ax.set_xticklabels(range(1, 6))
            plt.grid()

            if savefigs:
                dataio.save_current_figure('results_misbinding_kappa_allratioconj_{label}_global_{unique_id}.pdf')

    if compute_plot_bootstrap:
        ## Compute the bootstrap pvalue for each ratio
        #       use the bootstrap CDF from mixed runs, not the exact current ones, not sure if good idea.

        bootstrap_to_load = 1
        if bootstrap_to_load == 1:
            cache_bootstrap_fn = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'outputs', 'cache_bootstrap_mixed_from_bootstrapnontargets.pickle')
            bootstrap_ecdf_sum_label = 'bootstrap_ecdf_allitems_sum_sigmax_T'
            bootstrap_ecdf_all_label = 'bootstrap_ecdf_allitems_all_sigmax_T'
        elif bootstrap_to_load == 2:
            cache_bootstrap_fn = os.path.join(generator_module.pbs_submission_infos['simul_out_dir'], 'outputs', 'cache_bootstrap_misbinding_mixed.pickle')
            bootstrap_ecdf_sum_label = 'bootstrap_ecdf_allitems_sum_ratioconj'
            bootstrap_ecdf_all_label = 'bootstrap_ecdf_allitems_all_ratioconj'

        try:
            with open(cache_bootstrap_fn, 'rb') as file_in:
                # Load and assign values
                cached_data = pickle.load(file_in)
                assert bootstrap_ecdf_sum_label in cached_data
                assert bootstrap_ecdf_all_label in cached_data
                should_fit_bootstrap = False

        except IOError:
            # Without the cached ECDFs the selection below cannot proceed, so re-raise.
            print "Error while loading ", cache_bootstrap_fn
            raise

        # Select the ECDF to use
        if bootstrap_to_load == 1:
            sigmax_i = 3    # hardcoded index: corresponds to sigmax = 2 in the cached data
            T_i = 1         # hardcoded index: two possible targets
            bootstrap_ecdf_sum_used = cached_data[bootstrap_ecdf_sum_label][sigmax_i][T_i]['ecdf']
            bootstrap_ecdf_all_used = cached_data[bootstrap_ecdf_all_label][sigmax_i][T_i]['ecdf']
        elif bootstrap_to_load == 2:
            ratio_conj_i = 4
            bootstrap_ecdf_sum_used = cached_data[bootstrap_ecdf_sum_label][ratio_conj_i]['ecdf']
            bootstrap_ecdf_all_used = cached_data[bootstrap_ecdf_all_label][ratio_conj_i]['ecdf']


        result_pvalue_bootstrap_sum = np.empty(ratio_space.size)*np.nan
        result_pvalue_bootstrap_all = np.empty((ratio_space.size, nontarget_angles.shape[-1]))*np.nan

        for ratio_conj_i, ratio_conj in enumerate(ratio_space):
            print "Ratio: ", ratio_conj

            responses = result_all_thetas[ratio_conj_i]

            bootstrap_allitems_nontargets_allitems_uniquekappa = em_circularmixture_allitems_uniquekappa.bootstrap_nontarget_stat(responses, target_angle, nontarget_angles,
                sumnontargets_bootstrap_ecdf=bootstrap_ecdf_sum_used,
                allnontargets_bootstrap_ecdf=bootstrap_ecdf_all_used)
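            # The returned dict is assumed to expose at least 'p_value' (for the summed
            # nontarget statistic) and 'allnontarget_p_value' (one value per nontarget),
            # which is all that is used below.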

            result_pvalue_bootstrap_sum[ratio_conj_i] = bootstrap_allitems_nontargets_allitems_uniquekappa['p_value']
            result_pvalue_bootstrap_all[ratio_conj_i] = bootstrap_allitems_nontargets_allitems_uniquekappa['allnontarget_p_value']

        ## Plots
        # f, ax = plt.subplots()
        # ax.plot(ratio_space, result_pvalue_bootstrap_all, linewidth=2)

        # if savefigs:
        #     dataio.save_current_figure("pvalue_bootstrap_all_ratioconj_{label}_{unique_id}.pdf")

        f, ax = plt.subplots()
        ax.plot(ratio_space, result_pvalue_bootstrap_sum, linewidth=2)
        plt.grid()

        if savefigs:
            dataio.save_current_figure("pvalue_bootstrap_sum_ratioconj_{label}_{unique_id}.pdf")


    # plt.figure()
    # plt.plot(ratio_MMlower, results_filtered_smoothed/np.max(results_filtered_smoothed, axis=0), linewidth=2)
    # plt.plot(ratio_MMlower[np.argmax(results_filtered_smoothed, axis=0)], np.ones(results_filtered_smoothed.shape[-1]), 'ro', markersize=10)
    # plt.grid()
    # plt.ylim((0., 1.1))
    # plt.subplots_adjust(right=0.8)
    # plt.legend(['%d item' % i + 's'*(i>1) for i in xrange(1, T+1)], loc='center right', bbox_to_anchor=(1.3, 0.5))
    # plt.xticks(np.linspace(0, 1.0, 5))

    variables_to_save = ['target_angle', 'nontarget_angles']

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)
        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='misbindings')


    plt.show()

    return locals()


def check_precision_sensitivity_determ():
    ''' Construct a situation with one von Mises component and one uniform random component,
        and check how the random component affects the basic precision estimator used elsewhere.
    '''
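    # Setup sketch: each dataset of N responses is drawn as a mixture,
    #   theta_i ~ VonMises(0, kappa)   for roughly (N - N_rnd) samples
    #   theta_i ~ Uniform(-pi, pi)     for roughly N_rnd samples
    # and the naive precision estimator is compared against the mixture-model kappa
    # as the uniform fraction N_rnd/N grows.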

    N = 1000
    kappa_space = np.array([3., 10., 20.])
    # kappa_space = np.array([3.])
    nb_repeats = 20
    ratio_to_kappa = False
    savefigs = True
    precision_nb_samples = 101

    N_rnd_space             = np.linspace(0, N/2, precision_nb_samples).astype(int)
    precision_all           = np.zeros((N_rnd_space.size, nb_repeats))
    kappa_estimated_all     = np.zeros((N_rnd_space.size, nb_repeats))
    precision_squared_all   = np.zeros((N_rnd_space.size, nb_repeats))
    kappa_mixtmodel_all     = np.zeros((N_rnd_space.size, nb_repeats))
    mixtmodel_all           = np.zeros((N_rnd_space.size, nb_repeats, 2))

    dataio = DataIO.DataIO()

    target_samples = np.zeros(N)

    for kappa in kappa_space:

        true_kappa = kappa*np.ones(N_rnd_space.size)

        # First sample all as von mises
        samples_all = spst.vonmises.rvs(kappa, size=(N_rnd_space.size, nb_repeats, N))

        for repeat in progress.ProgressDisplay(xrange(nb_repeats)):
            for i, N_rnd in enumerate(N_rnd_space):
                samples = samples_all[i, repeat]

                # Then set N_rnd of them to random values in [-np.pi, np.pi].
                samples[np.random.randint(N, size=N_rnd)] = utils.sample_angle(N_rnd)
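                # Note: the indices are drawn with replacement, so slightly fewer than
                # N_rnd distinct samples may actually be replaced when N_rnd is large.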

                # Estimate precision from those samples.
                precision_all[i, repeat] = utils.compute_precision_samples(samples, square_precision=False, remove_chance_level=False)
                precision_squared_all[i, repeat] = utils.compute_precision_samples(samples, square_precision=True)

                # precision is 1/stddev here, so invert it before converting the circular stddev back to kappa
                kappa_estimated_all[i, repeat] = utils.stddev_to_kappa(1./precision_all[i, repeat])

                # Fit mixture model
                params_fit = em_circularmixture.fit(samples, target_samples)
                kappa_mixtmodel_all[i, repeat] = params_fit['kappa']
                mixtmodel_all[i, repeat] = params_fit['mixt_target'], params_fit['mixt_random']

                print "%d/%d N_rnd: %d, Kappa: %.3f, precision: %.3f, kappa_tilde: %.3f, precision^2: %.3f, kappa_mixtmod: %.3f" % (repeat, nb_repeats, N_rnd, kappa, precision_all[i, repeat], kappa_estimated_all[i, repeat], precision_squared_all[i, repeat], kappa_mixtmodel_all[i, repeat])


        if ratio_to_kappa:
            precision_all /= kappa
            precision_squared_all /= kappa
            kappa_estimated_all /= kappa
            true_kappa /= kappa

        f, ax = plt.subplots()
        ax.plot(N_rnd_space/float(N), true_kappa, 'k-', linewidth=3, label='Kappa_true')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(precision_all, axis=-1), np.std(precision_all, axis=-1), ax_handle=ax, label='precision')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(precision_squared_all, axis=-1), np.std(precision_squared_all, axis=-1), ax_handle=ax, label='precision^2')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(kappa_estimated_all, axis=-1), np.std(kappa_estimated_all, axis=-1), ax_handle=ax, label='kappa_tilde')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(kappa_mixtmodel_all, axis=-1), np.std(kappa_mixtmodel_all, axis=-1), ax_handle=ax, label='kappa mixt model')

        ax.legend()
        ax.set_title('Effect of random samples on precision. kappa: %.2f. ratiokappa %s' % (kappa, ratio_to_kappa))
        ax.set_xlabel('Proportion random samples. N tot %d' % N)
        ax.set_ylabel('Kappa/precision (not same units)')
        f.canvas.draw()

        if savefigs:
            dataio.save_current_figure("precision_sensitivity_kappa%dN%d_{unique_id}.pdf" % (kappa, N))

        # Do another plot, with kappa and mixt_target/mixt_random. Use left/right axis separately
        f, ax = plt.subplots()
        ax2 = ax.twinx()

        # left axis, kappa
        ax.plot(N_rnd_space/float(N), true_kappa, 'k-', linewidth=3, label='kappa true')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(kappa_mixtmodel_all, axis=-1), np.std(kappa_mixtmodel_all, axis=-1), ax_handle=ax, label='kappa')

        # Right axis, mixture probabilities
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(mixtmodel_all[..., 0], axis=-1), np.std(mixtmodel_all[..., 0], axis=-1), ax_handle=ax2, label='mixt target', color='r')
        utils.plot_mean_std_area(N_rnd_space/float(N), np.mean(mixtmodel_all[..., 1], axis=-1), np.std(mixtmodel_all[..., 1], axis=-1), ax_handle=ax2, label='mixt random', color='g')
        ax.set_title('Mixture model parameters evolution. kappa: %.2f, ratiokappa %s' % (kappa, ratio_to_kappa))
        ax.set_xlabel('Proportion random samples. N tot %d' % N)
        ax.set_ylabel('Kappa')
        ax2.set_ylabel('Mixture proportions')

        lines, labels = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines + lines2, labels + labels2)

        if savefigs:
            dataio.save_current_figure("precision_sensitivity_mixtmodel_kappa%dN%d_{unique_id}.pdf" % (kappa, N))



    return locals()
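

# Minimal sketch (illustrative only; assumed to mirror what utils.kappa_to_stddev and
# utils.stddev_to_kappa do above, the exact utils implementation is not reproduced here):
# for a von Mises with concentration kappa, the circular standard deviation is
# sqrt(-2*log(I1(kappa)/I0(kappa))), and the inverse mapping is found numerically.
# The *_sketch names and the bracketing interval are assumptions for illustration.
def _kappa_to_stddev_sketch(kappa):
    import scipy.special as spsp
    # Use the exponentially-scaled Bessel functions to avoid overflow for large kappa
    return np.sqrt(-2. * np.log(spsp.i1e(kappa) / spsp.i0e(kappa)))


def _stddev_to_kappa_sketch(stddev):
    import scipy.optimize as spopt
    # Invert numerically, assuming the solution lies in [1e-3, 1e3]
    return spopt.brentq(lambda kappa: _kappa_to_stddev_sketch(kappa) - stddev, 1e-3, 1e3)

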
def plots_ratioMscaling(data_pbs, generator_module=None):
    '''
        Reload and plot precision/fits of a Mixed code.
    '''

    #### SETUP
    #
    savefigs = True
    savedata = True

    plots_pcolor_all = True
    plots_effect_M_target_kappa = False

    plots_kappa_fi_comparison = False
    plots_multiple_fisherinfo = False
    specific_plot_effect_R = False

    convert_M_realsizes = True

    plots_pcolor_realsizes_Msubs = True
    plots_pcolor_realsizes_Mtot = True

    colormap = None  # or 'cubehelix'
    plt.rcParams['font.size'] = 16
    #
    #### /SETUP

    print "Order parameters: ", generator_module.dict_parameters_range.keys()

    result_all_precisions_mean = (utils.nanmean(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_all_precisions_std = (utils.nanstd(data_pbs.dict_arrays['result_all_precisions']['results'], axis=-1))
    result_em_fits_mean = (utils.nanmean(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_em_fits_std = (utils.nanstd(data_pbs.dict_arrays['result_em_fits']['results'], axis=-1))
    result_fisherinfo_mean = (utils.nanmean(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))
    result_fisherinfo_std = (utils.nanstd(data_pbs.dict_arrays['result_fisher_info']['results'], axis=-1))

    result_em_fits_kappa = result_em_fits_mean[..., 0]
    result_em_fits_target = result_em_fits_mean[..., 1]
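    # Keep kappa only where the fitted target mixture weight is at least 0.8;
    # below that, the kappa estimate is treated as unreliable.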
    result_em_fits_kappa_valid = np.ma.masked_where(result_em_fits_target < 0.8, result_em_fits_kappa)


    # flat versions
    result_parameters_flat = np.array(data_pbs.dict_arrays['result_all_precisions']['parameters_flat'])
    result_all_precisions_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_all_precisions']['results_flat']), axis=-1)
    result_em_fits_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_em_fits']['results_flat']), axis=-1)
    result_fisherinfo_mean_flat = np.mean(np.array(data_pbs.dict_arrays['result_fisher_info']['results_flat']), axis=-1)
    result_em_fits_kappa_flat = result_em_fits_mean_flat[..., 0]
    result_em_fits_target_flat = result_em_fits_mean_flat[..., 1]
    result_em_fits_kappa_valid_flat = np.ma.masked_where(result_em_fits_target_flat < 0.8, result_em_fits_kappa_flat)

    all_args = data_pbs.loaded_data['args_list']


    M_space = data_pbs.loaded_data['parameters_uniques']['M'].astype(int)
    ratio_space = data_pbs.loaded_data['parameters_uniques']['ratio_conj']
    R_space = data_pbs.loaded_data['parameters_uniques']['R'].astype(int)
    num_repetitions = generator_module.num_repetitions

    print M_space
    print ratio_space
    print R_space
    print result_all_precisions_mean.shape, result_em_fits_mean.shape

    dataio = DataIO.DataIO(output_folder=generator_module.pbs_submission_infos['simul_out_dir'] + '/outputs/', label='global_' + dataset_infos['save_output_filename'])

    MAX_DISTANCE = 100.

    if convert_M_realsizes:
        # Currently M*ratio_conj gives the conjunctive subpopulation size,
        # but only floor(M_conj**(1/R))**R of those neurons are really used
        # (and floor(M_feat/R)*R on the feature side). So we convert to
        # M_conj_real and M_feat_real instead of working with M and ratio.
        result_parameters_flat_subM_converted = []
        result_parameters_flat_Mtot_converted = []

        for params in result_parameters_flat:
            M = params[0]; ratio_conj = params[1]; R = int(params[2])

            M_conj_prior = int(M*ratio_conj)
            M_conj_true = int(np.floor(M_conj_prior**(1./R))**R)
            M_feat_true = int(np.floor((M-M_conj_prior)/R)*R)
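            # Worked example (illustrative): M=200, ratio_conj=0.5, R=2 gives
            # M_conj_prior=100, M_conj_true=floor(100**0.5)**2=100 and
            # M_feat_true=floor(100/2)*2=100.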

            # result_parameters_flat_subM_converted contains (M_conj, M_feat, R)
            result_parameters_flat_subM_converted.append(np.array([M_conj_true, M_feat_true, R]))
            # result_parameters_flat_Mtot_converted contains (M_tot, ratio_conj, R)
            result_parameters_flat_Mtot_converted.append(np.array([float(M_conj_true+M_feat_true), M_conj_true/float(M_conj_true+M_feat_true), R]))

        result_parameters_flat_subM_converted = np.array(result_parameters_flat_subM_converted)
        result_parameters_flat_Mtot_converted = np.array(result_parameters_flat_Mtot_converted)

    if plots_pcolor_all:
        if convert_M_realsizes:
            def plot_interp(points, data, currR_indices, title='', points_label='', xlabel='', ylabel=''):
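                # Note: R is not a parameter here; it is taken from the enclosing loop
                # over R_space below, so plot_interp must be called from inside that loop.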
                utils.contourf_interpolate_data_interactive_maxvalue(points[currR_indices][..., :2], data[currR_indices], xlabel=xlabel, ylabel=ylabel, title='%s, R=%d' % (title, R), interpolation_numpoints=200, interpolation_method='nearest', log_scale=False)

                if savefigs:
                    dataio.save_current_figure('pcolortrueM%s_%s_R%d_log_{label}_{unique_id}.pdf' % (points_label, title, R))

            all_datas = [dict(name='precision', data=result_all_precisions_mean_flat), dict(name='kappa', data=result_em_fits_kappa_flat), dict(name='kappavalid', data=result_em_fits_kappa_valid_flat), dict(name='target', data=result_em_fits_target_flat), dict(name='fisherinfo', data=result_fisherinfo_mean_flat)]
            all_points = []
            if plots_pcolor_realsizes_Msubs:
                all_points.append(dict(name='sub', data=result_parameters_flat_subM_converted, xlabel='M_conj', ylabel='M_feat'))
            if plots_pcolor_realsizes_Mtot:
                all_points.append(dict(name='tot', data=result_parameters_flat_Mtot_converted, xlabel='Mtot', ylabel='ratio_conj'))

            for curr_points in all_points:
                for curr_data in all_datas:
                    for R_i, R in enumerate(R_space):
                        currR_indices = curr_points['data'][:, 2] == R

                        plot_interp(curr_points['data'], curr_data['data'], currR_indices, title=curr_data['name'], points_label=curr_points['name'], xlabel=curr_points['xlabel'], ylabel=curr_points['ylabel'])

                # # show precision
                # plot_interp(result_parameters_flat_subM_converted, result_all_precisions_mean_flat, currR_indices, title='precision')

                # # show kappa
                # plot_interp(result_parameters_flat_subM_converted, result_em_fits_kappa_flat, currR_indices, title='kappa')

                # plot_interp(result_parameters_flat_subM_converted, result_em_fits_kappa_valid_flat, currR_indices, title='kappavalid')

                # # show probability on target
                # plot_interp(result_parameters_flat_subM_converted, result_em_fits_target_flat, currR_indices, title='target')

                # # show fisher info
                # plot_interp(result_parameters_flat_subM_converted, result_fisherinfo_mean_flat, currR_indices, title='fisherinfo')

        else:
            # Do one pcolor for M and ratio per R
            for R_i, R in enumerate(R_space):
                # Check evolution of precision given M and ratio
                utils.pcolor_2d_data(result_all_precisions_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='precision, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_precision_R%d_log_{label}_{unique_id}.pdf' % R)

                # Show kappa
                try:
                    utils.pcolor_2d_data(result_em_fits_kappa_valid[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='kappa, R=%d' % R)
                    if savefigs:
                        dataio.save_current_figure('pcolor_kappa_R%d_log_{label}_{unique_id}.pdf' % R)
                except ValueError:
                    pass

                # Show probability on target
                utils.pcolor_2d_data(result_em_fits_target[..., R_i], log_scale=False, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='target, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_target_R%d_{label}_{unique_id}.pdf' % R)

                # Show Fisher info
                utils.pcolor_2d_data(result_fisherinfo_mean[..., R_i], log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='fisher info, R=%d' % R)
                if savefigs:
                    dataio.save_current_figure('pcolor_fisherinfo_R%d_log_{label}_{unique_id}.pdf' % R)

                plt.close('all')

    if plots_effect_M_target_kappa:
        def plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R):
            f, ax = plt.subplots()
            ax.plot(M_space, ratio_target_kappa_given_M)
            ax.set_xlabel('M')
            ax.set_ylabel('Optimal ratio')
            ax.set_title('Optimal Ratio for kappa %d, R=%d' % (target_kappa, R))

            if savefigs:
                dataio.save_current_figure('optratio_M_targetkappa%d_R%d_{label}_{unique_id}.pdf' % (target_kappa, R))

        target_kappas = np.array([100, 200, 300, 500, 1000, 3000])
        for R_i, R in enumerate(R_space):
            for target_kappa in target_kappas:
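                # For each M, pick the ratio whose fitted kappa is closest to target_kappa,
                # masking out M values where even the best squared distance exceeds MAX_DISTANCE.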
                dist_to_target_kappa = (result_em_fits_kappa[..., R_i] - target_kappa)**2.
                best_dist_to_target_kappa = np.argmin(dist_to_target_kappa, axis=1)
                ratio_target_kappa_given_M = np.ma.masked_where(dist_to_target_kappa[np.arange(dist_to_target_kappa.shape[0]), best_dist_to_target_kappa] > MAX_DISTANCE, ratio_space[best_dist_to_target_kappa])

                # replot
                plot_ratio_target_kappa(ratio_target_kappa_given_M, target_kappa, R)

            plt.close('all')


    if plots_kappa_fi_comparison:

        # result_em_fits_kappa and fisher info
        if True:
            for R_i, R in enumerate(R_space):
                for M_tot_selected_i, M_tot_selected in enumerate(M_space[::2]):
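                    # M_space is subsampled by 2 here, so index back into the full result
                    # arrays with 2*M_tot_selected_i.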

                    # M_conj_space = ((1.-ratio_space)*M_tot_selected).astype(int)
                    # M_feat_space = M_tot_selected - M_conj_space

                    f, axes = plt.subplots(2, 1)
                    axes[0].plot(ratio_space, result_em_fits_kappa[2*M_tot_selected_i, ..., R_i])
                    axes[0].set_xlabel('ratio')
                    axes[0].set_title('Fitted kappa')

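                    # Fisher information approximates an inverse variance, so the equivalent
                    # circular stddev is 1/sqrt(FI), which is converted back to a kappa for
                    # comparison with the fitted kappa above.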
                    axes[1].plot(ratio_space, utils.stddev_to_kappa(1./result_fisherinfo_mean[2*M_tot_selected_i, ..., R_i]**0.5))
                    axes[1].set_xlabel('ratio')
                    axes[1].set_title('kappa_FI')

                    f.suptitle('M_tot %d' % M_tot_selected, fontsize=15)
                    f.set_tight_layout(True)

                    if savefigs:
                        dataio.save_current_figure('comparison_kappa_fisher_R%d_M%d_{label}_{unique_id}.pdf' % (R, M_tot_selected))

                    plt.close(f)

        if plots_multiple_fisherinfo:
            target_fisherinfos = np.array([100, 200, 300, 500, 1000])
            for R_i, R in enumerate(R_space):
                for target_fisherinfo in target_fisherinfos:
                    dist_to_target_fisherinfo = (result_fisherinfo_mean[..., R_i] - target_fisherinfo)**2.

                    utils.pcolor_2d_data(dist_to_target_fisherinfo, log_scale=True, x=M_space, y=ratio_space, xlabel='M', ylabel='ratio', xlabel_format="%d", title='Distance to Fisher info %d, R=%d' % (target_fisherinfo, R))
                    if savefigs:
                        dataio.save_current_figure('pcolor_distfi%d_R%d_log_{label}_{unique_id}.pdf' % (target_fisherinfo, R))

                plt.close('all')

    if specific_plot_effect_R:
        # Choose a M, find which ratio gives best fit to a given kappa
        M_target = 356
        M_target_i = np.argmin(np.abs(M_space - M_target))

        utils.pcolor_2d_data(result_em_fits_kappa_valid[M_target_i], log_scale=True, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='Kappa, M %d' % (M_target))
        if savefigs:
            dataio.save_current_figure('specific_pcolor_kappa_M%d_log_{label}_{unique_id}.pdf' % (M_target))
        # target_kappa = np.ma.mean(result_em_fits_kappa_valid[M_target_i])
        # target_kappa = 5*1e3
        target_kappa = 1.2e3

        dist_target_kappa = np.abs(result_em_fits_kappa_valid[M_target_i] - target_kappa)

        utils.pcolor_2d_data(dist_target_kappa, log_scale=True, x=ratio_space, y=R_space, xlabel='ratio', ylabel='R', ylabel_format="%d", title='Kappa dist %.2f, M %d' % (target_kappa, M_target))
        if savefigs:
            dataio.save_current_figure('specific_pcolor_distkappa%d_M%d_log_{label}_{unique_id}.pdf' % (target_kappa, M_target))




    all_args = data_pbs.loaded_data['args_list']
    variables_to_save = []

    if savedata:
        dataio.save_variables_default(locals(), variables_to_save)

        dataio.make_link_output_to_dropbox(dropbox_current_experiment_folder='higher_dimensions_R')

    plt.show()

    return locals()