def compare_posteriors_with_different_data(cfg_name, model, t, replace_indices,
                                           params) -> None:
    """Overlay posterior-sample histograms for several replace indices.

    Draws one subplot per entry of *params*; within each subplot, every
    replace index contributes a histogram in its own viridis colour so the
    posteriors under different removed records can be compared visually at
    iteration *t*. The figure is prepared but not saved here.
    """
    plt.clf()
    plt.close()
    fig, axarr = plt.subplots(nrows=1, ncols=len(params))
    colours = cm.viridis(np.linspace(0.2, 0.8, len(replace_indices)))

    for colour, replace_index in zip(colours, replace_indices):
        for ax, p in zip(axarr, params):
            samples = results_utils.get_posterior_samples(
                cfg_name,
                iter_range=(t, t + 1),
                model=model,
                replace_index=replace_index,
                params=[p])
            sns.distplot(samples,
                         ax=ax,
                         color=to_hex(colour),
                         label=str(replace_index),
                         kde=False)
    # save

    for ax, p in zip(axarr, params):
        ax.set_xlabel('parameter ' + p)

    axarr[0].set_title('iteration ' + str(t))
    axarr[-1].legend()
    vis_utils.beautify_axes(axarr)

    return
def weight_posterior(cfg_name,
                     model,
                     replace_indices='random',
                     t=500,
                     param='#0',
                     n_bins=25):
    """Plot the posterior distribution of a single weight at iteration t.

    Args:
        cfg_name: experiment configuration name.
        model: model identifier.
        replace_indices: 'random' to pick two datasets with enough runs,
            an int for a single dataset (legend suppressed), or a list of
            replace indices.
        t: training iteration whose posterior is shown.
        param: name of the weight/parameter column to plot.
        n_bins: number of histogram bins.
    """
    iter_range = (t, t + 1)
    nolegend = False

    if replace_indices == 'random':
        print('Picking two *random* replace indices for this setting...')
        df = results_utils.get_available_results(cfg_name, model)
        replace_counts = df['replace'].value_counts()
        # Only datasets with more than 2 runs are eligible.
        replaces = replace_counts[replace_counts > 2].index.values
        replace_indices = np.random.choice(replaces, 2, replace=False).tolist()
    elif isinstance(replace_indices, int):
        replace_indices = [replace_indices]
        nolegend = True

    assert isinstance(replace_indices, list)
    # Set up the plot
    fig, axarr = plt.subplots(nrows=1, ncols=1, figsize=(4, 2.5))
    # now load the data!
    for replace_index in replace_indices:
        df = results_utils.get_posterior_samples(cfg_name,
                                                 iter_range,
                                                 model,
                                                 replace_index=replace_index,
                                                 params=[param],
                                                 seeds='all')
        # Label "D\<index>" = dataset minus that record. The backslash is
        # escaped explicitly: the previous "\{" was an invalid escape
        # sequence (SyntaxWarning on newer Pythons) producing the same text.
        sns.distplot(df[param],
                     ax=axarr,
                     label=f'D\\{replace_index}',
                     kde=True,
                     bins=n_bins,
                     norm_hist=True)

    axarr.set_xlabel('weight ' + param)
    if not nolegend:
        axarr.legend()
    axarr.set_ylabel('# runs')
    vis_utils.beautify_axes(np.array([axarr]))
    plt.tight_layout()

    plot_identifier = f'weight_posterior_{cfg_name}_{param}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
def fit_pval_histogram(what,
                       cfg_name,
                       model,
                       t,
                       n_experiments=3,
                       diffinit=False,
                       xlim=None,
                       seed=1) -> None:
    """
    histogram of p-values (across parameters-?) for a given model etc.

    For each of n_experiments randomly chosen replace indices, fits a normal
    distribution to every parameter's samples at iteration t and pools the
    resulting log p-values into a single histogram, with reference lines at
    p = 0.05 and the Bonferroni-corrected p = 0.05 / n_params.

    Args:
        what: 'weights' (posterior samples over seeds) or 'gradients'
            (gradient-noise samples for one seed).
        cfg_name: experiment configuration name.
        model: model identifier.
        t: training iteration to analyse.
        n_experiments: number of replace indices to sample.
        diffinit: whether runs used differing initialisations.
        xlim: optional (lo, hi) x-limits; log-p values below lo are dropped
            from the histogram.
        seed: seed used when loading gradient data.

    Returns:
        None on success; False if a sampled experiment has too few parameters.
    """
    assert what in ['weights', 'gradients']
    # set some stuff up
    iter_range = (t, t + 1)
    fig, axarr = plt.subplots(nrows=1, ncols=1, figsize=(3.5, 2.1))
    pval_colour = '#b237c4'
    # sample experiments
    df = results_utils.get_available_results(cfg_name,
                                             model,
                                             diffinit=diffinit)
    replace_indices = df['replace'].unique()
    replace_indices = np.random.choice(replace_indices,
                                       n_experiments,
                                       replace=False)
    print('Looking at replace indices...', replace_indices)
    all_pvals = []

    for i, replace_index in enumerate(replace_indices):
        experiment = results_utils.ExperimentIdentifier(
            cfg_name, model, replace_index, seed, diffinit)

        if what == 'gradients':
            print('Loading gradients...')
            df = experiment.load_gradients(noise=True,
                                           iter_range=iter_range,
                                           params=None)
            # presumably the second column identifies the sample batch/seed;
            # it is carried through to the per-parameter fit below.
            second_col = df.columns[1]
        elif what == 'weights':
            df = results_utils.get_posterior_samples(
                cfg_name,
                iter_range=iter_range,
                model=model,
                replace_index=replace_index,
                params=None,
                seeds='all')
            second_col = df.columns[1]
        # columns after the first two (t + second_col) are the parameters
        params = df.columns[2:]
        n_params = len(params)
        print(n_params)

        if n_params < 50:
            print(
                'ERROR: Insufficient parameters for this kind of visualisation, please try something else'
            )

            return False
        print('Identified', n_params, 'parameters, proceeding with analysis')
        p_vals = np.zeros(shape=(n_params))

        for j, p in enumerate(params):
            print('getting fit for parameter', p)
            # Fit each parameter individually by passing a pre-sliced df;
            # the other identifying arguments are therefore None.
            df_fit = dr.estimate_statistics_through_training(
                what=what,
                cfg_name=None,
                model=None,
                replace_index=None,
                seed=None,
                df=df.loc[:, ['t', second_col, p]],
                params=None,
                iter_range=None)
            # p-value of the univariate normality test at iteration t
            p_vals[j] = df_fit.loc[t, 'norm_p']
            del df_fit
        log_pvals = np.log(p_vals)
        all_pvals.append(log_pvals)
    # pool the log p-values from all sampled experiments
    log_pvals = np.concatenate(all_pvals)

    if xlim is not None:
        # remove values below the limit
        number_below = (log_pvals < xlim[0]).sum()
        print('There are', number_below, 'p-values below the limit of',
              xlim[0])
        log_pvals = log_pvals[log_pvals > xlim[0]]
        print('Remaining pvals:', len(log_pvals))
    sns.distplot(log_pvals,
                 kde=True,
                 bins=min(100, int(len(log_pvals) * 0.25)),
                 ax=axarr,
                 color=pval_colour,
                 norm_hist=True)
    # reference line: uncorrected significance threshold
    axarr.axvline(x=np.log(0.05),
                  ls=':',
                  label='p = 0.05',
                  color='black',
                  alpha=0.75)
    # NOTE(review): n_params here is from the LAST experiment in the loop -
    # assumes all experiments share the same parameter count; confirm.
    axarr.axvline(x=np.log(0.05 / n_params),
                  ls='--',
                  label='p = 0.05/' + str(n_params),
                  color='black',
                  alpha=0.75)
    axarr.legend()
    axarr.set_xlabel(r'$\log(p)$')
    axarr.set_ylabel('density')

    if xlim is not None:
        axarr.set_xlim(xlim)
    else:
        axarr.set_xlim((None, 0.01))


#    axarr.set_xscale('log')
    vis_utils.beautify_axes(np.array([axarr]))
    plt.tight_layout()
    plot_identifier = f'pval_histogram_{cfg_name}_{model}_{what}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
def weight_evolution(cfg_name,
                     model,
                     n_seeds=50,
                     replace_indices=None,
                     iter_range=(None, None),
                     params=None,
                     diffinit=False,
                     aggregate=False):
    """Plot the evolution of selected weights over training.

    If aggregate is True, one curve per replace index is drawn (mean across
    seeds, with min/max and +/- std bands). Otherwise exactly one replace
    index is required and one trajectory is drawn per seed.

    Args:
        cfg_name: experiment configuration name.
        model: model identifier.
        n_seeds: number of seeds to load (must be > 1 when aggregating).
        replace_indices: list of replace indices (length 1 unless aggregating).
        iter_range: (start, end) iteration range to load.
        params: parameter names to plot; defaults to ['#4', '#2'].
        diffinit: whether runs used differing initialisations.
        aggregate: pool seeds per replace index instead of plotting each seed.
    """
    # Resolve the default here to avoid a mutable default argument.
    if params is None:
        params = ['#4', '#2']
    plt.clf()
    plt.close()
    fig, axarr = plt.subplots(nrows=len(params),
                              ncols=1,
                              sharex=True,
                              figsize=(4, 3))

    if aggregate:
        colours = cm.get_cmap('Set1')(np.linspace(0.2, 0.8,
                                                  len(replace_indices)))
        assert n_seeds > 1

        for i, replace_index in enumerate(replace_indices):
            vary_S = results_utils.get_posterior_samples(
                cfg_name,
                iter_range,
                model,
                replace_index=replace_index,
                params=params,
                seeds='all',
                n_seeds=n_seeds,
                diffinit=diffinit)
            # Per-iteration summary statistics across seeds.
            vary_S_min = vary_S.groupby('t').min()
            vary_S_std = vary_S.groupby('t').std()
            vary_S_max = vary_S.groupby('t').max()
            vary_S_mean = vary_S.groupby('t').mean()

            for j, p in enumerate(params):
                # Envelope between min and max across seeds.
                axarr[j].fill_between(vary_S_min.index,
                                      vary_S_min[p],
                                      vary_S_max[p],
                                      alpha=0.1,
                                      color=colours[i],
                                      label='_legend_')
                # +/- one standard deviation band around the mean.
                axarr[j].fill_between(vary_S_mean.index,
                                      vary_S_mean[p] - vary_S_std[p],
                                      vary_S_mean[p] + vary_S_std[p],
                                      alpha=0.1,
                                      color=colours[i],
                                      label='_nolegend_',
                                      linestyle='--')
                axarr[j].plot(vary_S_min.index,
                              vary_S_mean[p],
                              color=colours[i],
                              alpha=0.7,
                              label='D -' + str(replace_index))
                axarr[j].set_ylabel('weight ' + p)
    else:
        colours = cm.get_cmap('plasma')(np.linspace(0.2, 0.8, n_seeds))
        assert len(replace_indices) == 1
        replace_index = replace_indices[0]
        vary_S = results_utils.get_posterior_samples(
            cfg_name,
            iter_range,
            model,
            replace_index=replace_index,
            params=params,
            seeds='all',
            n_seeds=n_seeds,
            diffinit=diffinit)
        seeds = vary_S['seed'].unique()

        for i, s in enumerate(seeds):
            vary_Ss = vary_S.loc[vary_S['seed'] == s, :]

            for j, p in enumerate(params):
                axarr[j].plot(vary_Ss['t'],
                              vary_Ss[p],
                              color=colours[i],
                              label='seed ' + str(s),
                              alpha=0.8)

                if i == 0:
                    # e.g. '#4' -> w^{4}
                    axarr[j].set_ylabel(r'$\mathbf{w}^{' + p[1:] + '}$')

    axarr[-1].set_xlabel('training steps')
    vis_utils.beautify_axes(np.array([axarr]))
    plt.tight_layout()
    # BUG FIX: use '_' (not '.') between cfg_name and model. Previously
    # with_suffix() treated '.{model}' as the file extension and replaced it,
    # so the model name was silently dropped from the saved filename.
    plot_identifier = f'weight_trajectory_{cfg_name}_{model}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
def qq_plot(what: str,
            cfg_name: str,
            model: str,
            replace_index: int,
            seed: int,
            times=[50],
            params='random') -> None:
    """
    grab trace file, do qq plot for gradient noise at specified time-point

    Left panel: histogram of the pooled samples at each time in *times*;
    right panel: QQ plot against a fitted normal (line='45', fit=True).

    Args:
        what: 'gradients' (noise for this seed) or 'weights' (posterior
            across all seeds).
        cfg_name: experiment configuration name.
        model: model identifier.
        replace_index: which dataset (removed record) to load.
        seed: run seed (only relevant for gradients).
        times: iterations to plot; data is loaded for [min(times), max(times)].
        params: 'random' to pick one parameter at random, otherwise the
            explicit list of parameter names to use.

    Returns:
        None on success; False if no data is available.
    """
    plt.clf()
    plt.close()
    assert what in ['gradients', 'weights']

    if what == 'weights':
        print('Looking at weights, this means we consider all seeds!')
    colours = cm.viridis(np.linspace(0.2, 0.8, len(times)))

    experiment = results_utils.ExperimentIdentifier(cfg_name, model,
                                                    replace_index, seed)

    if params == 'random':
        if what == 'gradients':
            df = experiment.load_gradients(noise=True,
                                           params=None,
                                           iter_range=(min(times),
                                                       max(times) + 1))
        else:
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=(min(times), max(times) + 1),
                params=None)
        # first two columns are metadata; pick one parameter column at random
        params = np.random.choice(df.columns[2:], 1)
        print('picking random parameter', params)
        first_two_cols = df.columns[:2].tolist()
        df = df.loc[:, first_two_cols + list(params)]
    else:
        if what == 'gradients':
            df = experiment.load_gradients(noise=True,
                                           params=params,
                                           iter_range=(min(times),
                                                       max(times) + 1))
        else:
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=(min(times), max(times) + 1),
                params=params)

    # loaders return False (not None) when nothing could be loaded
    if df is False:
        print('ERROR: No data available')

        return False
    fig, axarr = plt.subplots(nrows=1, ncols=2, figsize=(7, 3.5))

    for i, t in enumerate(times):
        df_t = df.loc[df['t'] == t, :]
        # pool all selected parameter columns into one sample vector
        X = df_t.iloc[:, 2:].values.flatten()
        print('number of samples:', X.shape[0])
        sns.distplot(X,
                     ax=axarr[0],
                     kde=False,
                     color=to_hex(colours[i]),
                     label=str(t))
        sm.qqplot(X,
                  line='45',
                  fit=True,
                  ax=axarr[1],
                  c=colours[i],
                  alpha=0.5,
                  label=str(t))
    plt.suptitle('cfg_name: ' + cfg_name + ', model:' + model + ',' + what)
    axarr[0].legend()
    axarr[1].legend()
    axarr[0].set_xlabel('parameter:' + '.'.join(params))
    vis_utils.beautify_axes(axarr)
    plt.tight_layout()

    plot_identifier = f'qq_{what}_{cfg_name}_{model}_{"_".join(params)}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
def estimate_statistics_through_training(what,
                                         cfg_name,
                                         model,
                                         replace_index,
                                         seed,
                                         df=None,
                                         params=None,
                                         iter_range=(None, None),
                                         diffinit=True):
    """
    Grab a trace file for a model, estimate the alpha value for gradient noise throughout training
    NOTE: All weights taken together as IID (in the list of params supplied)

    Returns a DataFrame indexed by iteration t containing, per iteration:
    the sample count, alpha-stable fit, multivariate-normal p-value, and
    univariate normal/Laplace fits on the flattened centred samples. Column
    names are prefixed with `what`. Returns False if no data could be loaded.
    """
    assert what in ['gradients', 'weights']

    if replace_index is None:
        replace_index = results_utils.get_replace_index_with_most_seeds(
            cfg_name, model, diffinit=diffinit)

    if df is None:
        if what == 'gradients':
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=iter_range,
                params=params,
                diffinit=diffinit,
                what='gradients')
        else:
            print('Getting posterior for weights, seed is irrelevant')
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=iter_range,
                params=params,
                diffinit=diffinit)

        if df is False:
            print('ERROR: No data found')

            return False

    # now go through the iterations
    iterations = df['t'].unique()
    # store the results in this dataframe
    df_fits = pd.DataFrame(index=iterations)
    df_fits.index.name = 't'
    df_fits['N'] = np.nan
    df_fits['alpha'] = np.nan
    df_fits['alpha_fit'] = np.nan

    for t in iterations:
        df_t = df.loc[df['t'] == t, :]
        # zero it out by seed
        if what == 'gradients':
            seed_means = df_t.groupby('seed').transform('mean')
            df_t = (df_t - seed_means).drop(columns=['seed', 't'])
            X = df_t.values
        else:
            # first two columns are metadata (t, seed); centre the rest
            X = df_t.iloc[:, 2:].values
            X = X - X.mean(axis=0)
        # BUG FIX: record the count for THIS iteration only. Previously
        # `df_fits['N'] = ...` overwrote the whole column every pass, leaving
        # only the final iteration's count everywhere.
        df_fits.loc[t, 'N'] = X.shape[0]
        # fit alpha_stable
        alpha, fit = stats_utils.fit_alpha_stable(X)
        df_fits.loc[t, 'alpha'] = alpha
        df_fits.loc[t, 'alpha_fit'] = fit
        # fit multivariate gaussian - dont record the params since they don't fit...
        _, _, _, p = stats_utils.fit_multivariate_normal(X)
        df_fits.loc[t, 'mvnorm_mu'] = np.nan
        df_fits.loc[t, 'mvnorm_sigma'] = np.nan
        df_fits.loc[t, 'mvnorm_W'] = np.nan
        df_fits.loc[t, 'mvnorm_p'] = p
        # Now flatten and look at univariate distributions
        X_flat = X.reshape(-1, 1)
        # Same per-iteration fix for the flattened sample count.
        df_fits.loc[t, 'N_flat'] = X_flat.shape[0]
        # fit univariate gaussian
        mu, sigma, W, p = stats_utils.fit_normal(X_flat)
        df_fits.loc[t, 'norm_mu'] = mu
        df_fits.loc[t, 'norm_sigma'] = sigma
        df_fits.loc[t, 'norm_W'] = W
        df_fits.loc[t, 'norm_p'] = p
        # fit laplace
        loc, scale, D, p = stats_utils.fit_laplace(X_flat)
        df_fits.loc[t, 'lap_loc'] = loc
        df_fits.loc[t, 'lap_scale'] = scale
        df_fits.loc[t, 'lap_D'] = D
        df_fits.loc[t, 'lap_p'] = p

    # Attach what the fit was on
    df_fits.columns = [f'{what}_{x}' for x in df_fits.columns]
    return df_fits
    def generate(self, diffinit, verbose=True, ephemeral=False):
        """ ephemeral allows us generate it and return without saving

        Estimates the variability (sigma) of posterior samples at iteration
        self.t, for every replace index with enough runs.

        Args:
            diffinit: whether runs used differing initialisations.
            verbose: print progress information while computing.
            ephemeral: if True, skip the on-disk cache check and return the
                computed data instead of saving it.

        Returns:
            None when saving to disk (or when the cached file already
            exists); {'sigmas': ..., 'replaces': ...} when ephemeral is True.
        """

        if not ephemeral:
            path_string = self.path_string(diffinit)

            if path_string.exists():
                # Already computed and cached on disk - nothing to do.
                print(
                    f'[Sigmas] File {path_string} already exists, not computing again!'
                )

                return
        # now compute
        df = results_utils.get_available_results(
            self.cfg_name,
            self.model,
            data_privacy=self.data_privacy,
            diffinit=diffinit)
        replace_counts = df['replace'].value_counts()
        # only datasets (replace indices) with more than 2 runs are eligible
        replaces = replace_counts[replace_counts > 2].index.values

        if verbose:
            print(
                f'[Sigmas] Estimating variability across {len(replaces)} datasets!'
            )
            print('Warning: this can be slow...')
        sigmas = []

        # Subsample the replace indices unless self.num_replaces is 'max'
        # (meaning use all of them).
        if not self.num_replaces == 'max' and self.num_replaces < len(
                replaces):
            replaces = np.random.choice(replaces,
                                        self.num_replaces,
                                        replace=False)

        for replace_index in replaces:
            if verbose:
                print('replace index:', replace_index)
            samples = results_utils.get_posterior_samples(
                self.cfg_name, (self.t, self.t + 1),
                self.model,
                replace_index=replace_index,
                params=None,
                seeds='all',
                verbose=verbose,
                diffinit=diffinit,
                data_privacy=self.data_privacy,
                num_seeds=self.num_seeds)
            try:
                # first two columns assumed to be metadata (t, seed); the
                # rest are parameter values -- TODO confirm against loader
                params = samples.columns[2:]

                if self.multivariate:
                    # one sigma per parameter
                    this_sigma = samples.std(axis=0)
                    this_sigma = this_sigma[params]
                else:
                    # pool all centred parameters into a single scalar sigma
                    params_vals = samples[params].values
                    params_norm = params_vals - params_vals.mean(axis=0)
                    params_flat = params_norm.flatten()
                    this_sigma = np.std(params_flat)
            except AttributeError:
                # get_posterior_samples returns False on failure, which has
                # no .columns attribute - record NaN and move on.
                print(f'WARNING: data from {replace_index} is bad - skipping')
                assert samples is False
                this_sigma = np.nan
            sigmas.append(this_sigma)
        sigmas = np.array(sigmas)
        sigmas_data = {'sigmas': sigmas, 'replaces': replaces}

        if not ephemeral:
            # NOTE(review): np.save of a dict relies on pickle; loaders must
            # use allow_pickle=True -- confirm downstream usage.
            np.save(path_string, sigmas_data)
        else:
            return sigmas_data
def compute_mvn_laplace_fit_and_alpha(cfg_name,
                                      model,
                                      t,
                                      diffinit=True,
                                      just_on_normal_marginals=False,
                                      replace_index=None) -> dict:
    """Test distributional fits of the weight posterior at iteration t.

    Fits a multivariate normal (on subsets of at most 55 features when the
    dimension is larger), an alpha-stable distribution, and per-dimension
    Laplace distributions to the centred posterior samples, printing summary
    statistics along the way.

    Args:
        cfg_name: experiment configuration name.
        model: model identifier.
        t: training iteration to analyse.
        diffinit: whether runs used differing initialisations.
        just_on_normal_marginals: restrict analysis to dimensions whose
            marginal passes a normality test (p > 0.05).
        replace_index: dataset index; if None, the one with the most seeds
            is used.

    Returns:
        {'mvn p': multivariate-normal p-value, 'alpha': alpha-stable alpha}.
    """
    if replace_index is None:
        replace_index = results_utils.get_replace_index_with_most_seeds(
            cfg_name, model, diffinit=diffinit)

    iter_range = (t, t + 1)
    params = None
    df = results_utils.get_posterior_samples(cfg_name,
                                             model=model,
                                             replace_index=replace_index,
                                             iter_range=iter_range,
                                             params=params,
                                             diffinit=diffinit,
                                             what='weights')

    # first two columns are metadata (t, seed); centre the parameter values
    df_t = df.loc[df['t'] == t, :]
    X = df_t.iloc[:, 2:].values
    X = X - X.mean(axis=0)
    d = X.shape[1]
    if just_on_normal_marginals:
        print('Selecting just those parameters with normal marginals')
        normal_marginals = []
        for di in range(d):
            Xd = X[:, di]
            _, _, _, pval = stats_utils.fit_normal(Xd)
            if pval > 0.05:
                normal_marginals.append(di)
        print(
            f'Found {len(normal_marginals)} parameters with normally-distributed marginals!'
        )
        X = X[:, normal_marginals]
        d = X.shape[1]
    if d > 55:
        # High-dimensional case: the MVN test is run on random 55-feature
        # subsets; the reported p-value is the MINIMUM over replicates
        # (i.e. the most pessimistic subset).
        print(f'More than 55 features (d = {d}), selecting a random subset')
        n_replicates = 2 * (d // 55 + 1)
        print(f'Using {n_replicates} replicates')
        p_array = []
        for _ in range(n_replicates):
            idx_subset = np.random.choice(d, 55, replace=False)
            X_sub = X[:, idx_subset]
            _, _, _, p = stats_utils.fit_multivariate_normal(X_sub)
            p_array.append(p)
        print(p_array)
        p = np.min(p_array)
        print(np.mean(p_array), np.std(p_array))
    else:
        _, _, _, p = stats_utils.fit_multivariate_normal(X)

    alpha, _ = stats_utils.fit_alpha_stable(X)

    # now for laplace
    laplace_ps = []
    for di in range(d):
        Xd = X[:, di]
        _, _, _, pval = stats_utils.fit_laplace(Xd)
        laplace_ps.append(pval)
    laplace_ps = np.array(laplace_ps)
    # Report the fraction/count of dimensions consistent with a Laplace fit,
    # with and without Bonferroni correction across the d tests.
    print('without bonferroni...')
    fraction_of_laplace_vars = np.mean(laplace_ps > 0.05)
    print(f'\tfraction of laplace vars: {fraction_of_laplace_vars}')
    sum_of_laplace_vars = np.sum(laplace_ps > 0.05)
    print(f'\tsum of laplace vars: {sum_of_laplace_vars}')
    print('with bonferroni...')
    fraction_of_laplace_vars = np.mean(laplace_ps > 0.05 / d)
    print(f'\tfraction of laplace vars: {fraction_of_laplace_vars}')
    sum_of_laplace_vars = np.sum(laplace_ps > 0.05 / d)
    print(f'\tsum of laplace vars: {sum_of_laplace_vars}')
    mean_of_laplace_ps = np.mean(laplace_ps)
    print(f'mean of laplace ps: {mean_of_laplace_ps}')
    max_of_laplace_ps = np.max(laplace_ps)
    print(f'max of laplace ps: {max_of_laplace_ps}')

    return {'mvn p': p, 'alpha': alpha}
def compute_pairwise_sens_and_var(cfg_name,
                                  model,
                                  t,
                                  replace_indices,
                                  multivariate=False,
                                  verbose=True,
                                  diffinit=False):
    """
    for a pair of experiments...
    estimate sensitivity (distance between means)
    estimate variability (variance about means .. both?)
    given delta ... return this epsilon!
    optionally, by parameter (returns an array!)
    """

    if multivariate:
        raise NotImplementedError
    # Load posterior samples for both experiments in the pair.
    loaded = [
        results_utils.get_posterior_samples(cfg_name, (t, t + 1),
                                            model,
                                            replace_index=idx,
                                            params=None,
                                            seeds='all',
                                            verbose=verbose,
                                            diffinit=diffinit)
        for idx in (replace_indices[0], replace_indices[1])
    ]
    samples_1, samples_2 = loaded
    try:
        samples_1.set_index('seed', inplace=True)
        samples_2.set_index('seed', inplace=True)
    except AttributeError:
        # A failed load returns False, which has no set_index method.
        print('ERROR: Issue loading samples from', replace_indices)

        return np.nan, np.nan, np.nan
    # Keep only the parameter columns (drop the iteration column).
    params = [col for col in samples_1.columns if not col == 't']
    samples_1 = samples_1[params]
    samples_2 = samples_2[params]
    # get intersection of seeds
    shared = set(samples_1.index).intersection(set(samples_2.index))
    intersection = list(shared)
    num_seeds = len(intersection)

    if num_seeds < 10:
        print(
            f'WARNING: Experiments with replace indices {replace_indices} only have {num_seeds} overlapping seeds: {intersection}'
        )

        return np.nan, np.nan, num_seeds
    paired_1 = samples_1.loc[intersection, :]
    paired_2 = samples_2.loc[intersection, :]
    # compute the distances on the same seed
    seedwise_distances = np.linalg.norm(paired_1 - paired_2, axis=1)
    sensitivity = seedwise_distances.max()

    if verbose:
        print('Max sensitivity from same seed diff data:', sensitivity)
    # compute distance by getting average value and comparing
    mean_1 = samples_1.mean(axis=0)
    mean_2 = samples_2.mean(axis=0)
    sensitivity_bymean = np.linalg.norm(mean_1 - mean_2)

    if verbose:
        print('Sensitivity from averaging posteriors and comparing:',
              sensitivity_bymean)
    # Average the per-experiment spread about each mean.
    spread_1 = (samples_1 - mean_1).values.std()
    spread_2 = (samples_2 - mean_2).values.std()
    variability = 0.5 * (spread_1 + spread_2)

    if verbose:
        print('Variability:', variability)

    return sensitivity, variability, num_seeds