    def __init__(self, cfg_name, model, data_privacy, iter_range=(None, None)):
        self.cfg_name = cfg_name
        self.model = model
        self.data_privacy = data_privacy
        # iter_range is read by generate() when loading each experiment's loss;
        # (None, None) loads the full trace
        self.iter_range = iter_range
        sample_experiment = results_utils.ExperimentIdentifier(
            cfg_name=cfg_name, model=model, data_privacy=data_privacy)
        self.derived_directory = sample_experiment.derived_path_stub()
        self.suffix = None
    def generate(self, diffinit):
        """
        Aggregate the per-seed train/validation loss traces for this setting
        into per-step mean and std, and save the result as a CSV.
        """
        path_string = self.path_string(diffinit)

        if path_string.exists():
            print(
                f'[AggregatedLoss] {path_string} already exists, not recomputing!'
            )

            return

        df = results_utils.get_available_results(self.cfg_name, self.model)
        train_list = []
        vali_list = []

        for i, row in df.iterrows():
            experiment = results_utils.ExperimentIdentifier(
                self.cfg_name,
                self.model,
                replace_index=row['replace'],
                seed=row['seed'],
                diffinit=diffinit,
                data_privacy=self.data_privacy)
            try:
                loss = experiment.load_loss(iter_range=self.iter_range,
                                            verbose=False)
            except FileNotFoundError:
                print(
                    f'WARNING: Could not find loss for {experiment.path_stub()}'
                )
                continue
            loss_train = loss.loc[loss['minibatch_id'] == 'ALL', :].set_index(
                't')
            loss_vali = loss.loc[loss['minibatch_id'] == 'VALI', :].set_index(
                't')
            train_list.append(loss_train)
            vali_list.append(loss_vali)
        print('All traces collected')
        # stack the per-seed traces into single dataframes
        train = pd.concat(train_list)
        vali = pd.concat(vali_list)
        # aggregate: mean and std
        train_mean = train.groupby('t').mean()
        train_std = train.groupby('t').std()
        vali_mean = vali.groupby('t').mean()
        vali_std = vali.groupby('t').std()
        # recombine
        train = train_mean.join(train_std, rsuffix='_std', lsuffix='_mean')
        vali = vali_mean.join(vali_std, rsuffix='_std', lsuffix='_mean')
        df = train.join(vali, lsuffix='_train', rsuffix='_vali')

        self.suffix = '.csv'
        df.to_csv(path_string, header=True, index=True)
        print(f'[AggregatedLoss] Saved to {path_string}')

        return
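# Hedged, self-contained sketch (not part of the pipeline): illustrates the
# aggregation performed in generate() above, i.e. collapsing per-seed loss
# traces into per-step mean and std columns. Column names and values here are
# invented for illustration only.
def _example_aggregate_traces():
    traces = pd.concat([
        pd.DataFrame({'t': [0, 1, 2], 'ce': [1.0, 0.8, 0.6]}).set_index('t'),
        pd.DataFrame({'t': [0, 1, 2], 'ce': [1.2, 0.9, 0.5]}).set_index('t'),
    ])
    mean = traces.groupby('t').mean()
    std = traces.groupby('t').std()
    # recombine, mirroring the _mean/_std suffix convention used in generate()
    return mean.join(std, lsuffix='_mean', rsuffix='_std')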
def find_convergence_point_for_single_experiment(cfg_name,
                                                 model,
                                                 replace_index,
                                                 seed,
                                                 diffinit=False,
                                                 tolerance=3,
                                                 metric='ce',
                                                 verbose=False,
                                                 data_privacy='all'):
    """
    Define the convergence point as the first step at which the validation
    loss has failed to decrease for `tolerance` consecutive recorded steps;
    returns np.nan if no such step exists.
    """
    experiment = results_utils.ExperimentIdentifier(cfg_name,
                                                    model,
                                                    replace_index,
                                                    seed,
                                                    diffinit=diffinit,
                                                    data_privacy=data_privacy)
    loss = experiment.load_loss(iter_range=(None, None))
    if metric not in loss.columns:
        print('ERROR:', metric, 'is not in columns...', loss.columns)

        return np.nan
    loss = loss.loc[:, ['t', 'minibatch_id', metric]]
    loss = loss.pivot(index='t', columns='minibatch_id', values=metric)
    vali_loss = loss['VALI']
    delta_vali = vali_loss - vali_loss.shift()
    # was there a decrease at that time point? (1 if yes --> good)
    decrease = (delta_vali < 0)
    counter = 0

    for t, dec in decrease.items():
        if not dec:
            counter += 1
        else:
            counter = 0

        if counter >= tolerance:
            convergence_point = t

            break
    else:
        if verbose:
            print(
                f'Did not find instance of validation loss failing to decrease for {tolerance} steps - returning nan'
            )
        convergence_point = np.nan

    return convergence_point
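# Hedged sketch (illustration only): the convergence criterion defined above,
# applied to an invented validation-loss series. `tolerance` consecutive
# recorded steps without a decrease mark the convergence point.
def _example_convergence_criterion(tolerance=3):
    vali_loss = pd.Series([1.0, 0.9, 0.8, 0.81, 0.82, 0.83, 0.7],
                          index=[0, 10, 20, 30, 40, 50, 60])
    decrease = (vali_loss - vali_loss.shift()) < 0
    counter = 0

    for t, dec in decrease.items():
        counter = 0 if dec else counter + 1

        if counter >= tolerance:
            return t  # here: 50, after three steps without a decrease

    return np.nan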
def visualise_weight_trajectory(cfg_name,
                                identifiers,
                                df=None,
                                save=True,
                                iter_range=(None, None),
                                params=['#4', '#2'],
                                include_optimum=False,
                                include_autocorrelation=False,
                                diffinit=False) -> None:
    """
    """
    df_list = []

    for identifier in identifiers:
        model = identifier['model']
        replace_index = identifier['replace']
        seed = identifier['seed']
        experiment = results_utils.ExperimentIdentifier(
            cfg_name, model, replace_index, seed, diffinit)
        df = experiment.load_weights(iter_range=iter_range, params=params)
        df_list.append(df)
    colors = cm.viridis(np.linspace(0.2, 0.8, len(df_list)))
    labels = [
        f'{x["model"]}:{x["replace"]}:{x["seed"]}' for x in identifiers
    ]

    if params is None:
        if len(df.columns) > 6:
            print('WARNING: No parameters indicated, choosing randomly...')
            params = np.random.choice(df_list[0].columns[1:], 4, replace=False)
        else:
            print('WARNING: No parameters indicated, selecting all')
            params = df_list[0].columns[1:]

    for p in params:
        for df in df_list:
            assert p in df.columns

    if include_optimum:
        # hack!
        optimum, hessian = data_utils.solve_with_linear_regression(cfg_name)

    if include_autocorrelation:
        ncols = 2
    else:
        ncols = 1
    fig, axarr = plt.subplots(nrows=len(params),
                              ncols=ncols,
                              sharex='col',
                              figsize=(4 * ncols, 1.5 * len(params) + 1))

    firstcol = axarr[:, 0] if include_autocorrelation else axarr

    for k, df in enumerate(df_list):
        color = to_hex(colors[k])

        for i, p in enumerate(params):
            firstcol[i].scatter(df['t'],
                                df[p],
                                c=color,
                                alpha=1,
                                s=4,
                                label=labels[k])
            firstcol[i].plot(df['t'],
                             df[p],
                             c=color,
                             alpha=0.75,
                             label='_nolegend_')
            firstcol[i].set_ylabel('param: ' + str(p))

            if include_optimum:
                firstcol[i].axhline(y=optimum[int(p[1:])],
                                    ls='--',
                                    color='red',
                                    alpha=0.5)
        firstcol[0].set_title('weight trajectory')
        firstcol[-1].set_xlabel('training steps')
        firstcol[0].legend()

        if include_autocorrelation:
            n_lags = 100
            autocorr = np.zeros(n_lags)
            axarr[0, 1].set_title('autocorrelation of weight trajectory')

            for i, p in enumerate(params):
                for lag in range(n_lags):
                    autocorr[lag] = df[p].autocorr(lag=lag)

                axarr[i, 1].plot(range(n_lags),
                                 autocorr,
                                 alpha=0.5,
                                 color=color)
                axarr[i, 1].scatter(range(n_lags),
                                    autocorr,
                                    s=4,
                                    zorder=2,
                                    color=color)
                axarr[i, 1].set_ylabel(p)
                axarr[i, 1].axhline(y=0, ls='--', alpha=0.5, color='black')
            axarr[-1, 1].set_xlabel('lag')

    vis_utils.beautify_axes(axarr)
    plt.tight_layout()

    if save:
        plot_identifier = f'weights_{cfg_name}_{"_".join(labels)}'
        plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
        plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))
    plt.clf()
    plt.close()

    return
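# Hedged sketch (illustration only): the per-lag autocorrelation plotted in
# the right-hand panels above. pd.Series.autocorr(lag) is the correlation of
# a weight trajectory with itself shifted by `lag` steps; the trajectory here
# is invented.
def _example_weight_autocorrelation(n_lags=100):
    trajectory = pd.Series(np.sin(np.linspace(0, 4 * np.pi, 500)))
    return np.array([trajectory.autocorr(lag=lag) for lag in range(n_lags)])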
def visualise_trace(cfg_names,
                    models,
                    replaces,
                    seeds,
                    privacys,
                    save=True,
                    include_batches=False,
                    iter_range=(None, None),
                    include_convergence=True,
                    diffinit=False,
                    convergence_tolerance=3,
                    include_vali=True,
                    labels=None) -> None:
    """
    Show the full training set loss as well as the gradient (at our element) over training
    """
    identifiers = vis_utils.process_identifiers(cfg_names, models, replaces,
                                                seeds, privacys)
    print(identifiers)

    if len(identifiers) > 1:
        print(
            'WARNING: When more than one experiment is included, we turn off visualisation of batches to avoid cluttering the plot'
        )
        include_batches = False

    if labels is None:
        labels = [
            f'{x["cfg_name"]}-{x["model"]}-{x["replace"]}-{x["seed"]}'
            for x in identifiers
        ]
    else:
        assert len(labels) == len(identifiers)
    loss_list = []
    kept_identifiers = []
    kept_labels = []

    for label, identifier in zip(labels, identifiers):
        cfg_name = identifier['cfg_name']
        model = identifier['model']
        replace_index = identifier['replace']
        seed = identifier['seed']
        data_privacy = identifier['data_privacy']
        experiment = results_utils.ExperimentIdentifier(
            cfg_name,
            model,
            replace_index,
            seed,
            data_privacy=data_privacy,
            diffinit=diffinit)
        df_loss = experiment.load_loss(iter_range=iter_range)

        if df_loss is False:
            # drop this identifier entirely so labels and colours stay aligned
            print('No fit data available for identifier:', identifier)

            continue
        loss_list.append(df_loss)
        kept_identifiers.append(identifier)
        kept_labels.append(label)
    identifiers = kept_identifiers
    labels = kept_labels

    if len(loss_list) == 0:
        print('Error: no valid data')

        return False

    if include_batches:
        minibatch_ids = loss_list[0]['minibatch_id'].unique()
        colormap = dict(
            zip(minibatch_ids, cm.viridis(np.linspace(0, 1,
                                                      len(minibatch_ids)))))
    colours = cm.viridis(np.linspace(0.2, 0.8, len(loss_list)))

    # what metrics were recorded for this run?
    metrics = loss_list[0].columns[2:]
    print('Visualising trace of', identifiers, 'with metrics', metrics)

    nrows = len(metrics)
    fig, axarr = plt.subplots(nrows=nrows,
                              ncols=1,
                              sharex='col',
                              figsize=(4, 3.2))

    if nrows == 1:
        axarr = np.array([axarr])

    for j, df in enumerate(loss_list):
        # this is just for the purpose of plotting the overall, not batches
        df_train = df.loc[df['minibatch_id'] == 'ALL', :]
        df_vali = df.loc[df['minibatch_id'] == 'VALI', :]

        # plot all

        for i, metric in enumerate(metrics):
            axarr[i].scatter(df_train['t'],
                             df_train[metric],
                             s=4,
                             color=colours[j],
                             zorder=2,
                             label='_nolegend_',
                             alpha=0.5)
            axarr[i].plot(df_train['t'],
                          df_train[metric],
                          alpha=0.25,
                          color=colours[j],
                          zorder=2,
                          label=labels[j])

            if include_vali:
                axarr[i].plot(df_vali['t'],
                              df_vali[metric],
                              ls='--',
                              color=colours[j],
                              zorder=2,
                              label='_nolegend_',
                              alpha=0.5)
            axarr[i].legend()

            if metric in ['mse']:
                axarr[i].set_yscale('log')
            axarr[i].set_ylabel(re.sub('_', '\n', metric))

            if include_batches:
                axarr[i].scatter(df['t'],
                                 df[metric],
                                 c=[colormap[x] for x in df['minibatch_id']],
                                 s=4,
                                 alpha=0.2,
                                 zorder=0)

                for minibatch_idx in df['minibatch_id'].unique():
                    df_temp = df.loc[df['minibatch_id'] == minibatch_idx, :]
                    axarr[i].plot(df_temp['t'],
                                  df_temp[metric],
                                  c=colormap[minibatch_idx],
                                  alpha=0.1,
                                  zorder=0)

    if include_convergence:
        for j, identifier in enumerate(identifiers):
            cfg_name = identifier['cfg_name']
            model = identifier['model']
            replace_index = identifier['replace']
            seed = identifier['seed']
            data_privacy = identifier['data_privacy']
            convergence_point = dr.find_convergence_point_for_single_experiment(
                cfg_name,
                model,
                replace_index,
                seed,
                diffinit,
                tolerance=convergence_tolerance,
                metric=metrics[0],
                data_privacy=data_privacy)
            print('Convergence point:', convergence_point)

            for ax in axarr:
                ax.axvline(x=convergence_point, ls='--', color=colours[j])
    axarr[-1].set_xlabel('training steps')

    vis_utils.beautify_axes(axarr)
    plt.tight_layout()

    if save:
        plot_label = '__'.join(
            [f'r{x["replace"]}-s{x["seed"]}' for x in identifiers])
        plot_identifier = f'trace_{cfg_name}_{plot_label}'
        plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
        plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))
    plt.clf()
    plt.close()

    return
def fit_pval_histogram(what,
                       cfg_name,
                       model,
                       t,
                       n_experiments=3,
                       diffinit=False,
                       xlim=None,
                       seed=1) -> None:
    """
    histogram of p-values (across parameters-?) for a given model etc.
    """
    assert what in ['weights', 'gradients']
    # set some stuff up
    iter_range = (t, t + 1)
    fig, axarr = plt.subplots(nrows=1, ncols=1, figsize=(3.5, 2.1))
    pval_colour = '#b237c4'
    # sample experiments
    df = results_utils.get_available_results(cfg_name,
                                             model,
                                             diffinit=diffinit)
    replace_indices = df['replace'].unique()
    replace_indices = np.random.choice(replace_indices,
                                       n_experiments,
                                       replace=False)
    print('Looking at replace indices...', replace_indices)
    all_pvals = []

    for i, replace_index in enumerate(replace_indices):
        experiment = results_utils.ExperimentIdentifier(
            cfg_name, model, replace_index, seed, diffinit)

        if what == 'gradients':
            print('Loading gradients...')
            df = experiment.load_gradients(noise=True,
                                           iter_range=iter_range,
                                           params=None)
            second_col = df.columns[1]
        elif what == 'weights':
            df = results_utils.get_posterior_samples(
                cfg_name,
                iter_range=iter_range,
                model=model,
                replace_index=replace_index,
                params=None,
                seeds='all')
            second_col = df.columns[1]
        params = df.columns[2:]
        n_params = len(params)
        print(n_params)

        if n_params < 50:
            print(
                'ERROR: Insufficient parameters for this kind of visualisation, please try something else'
            )

            return False
        print('Identified', n_params, 'parameters, proceeding with analysis')
        p_vals = np.zeros(shape=(n_params))

        for j, p in enumerate(params):
            print('getting fit for parameter', p)
            df_fit = dr.estimate_statistics_through_training(
                what=what,
                cfg_name=None,
                model=None,
                replace_index=None,
                seed=None,
                df=df.loc[:, ['t', second_col, p]],
                params=None,
                iter_range=None)
            p_vals[j] = df_fit.loc[t, 'norm_p']
            del df_fit
        log_pvals = np.log(p_vals)
        all_pvals.append(log_pvals)
    log_pvals = np.concatenate(all_pvals)

    if xlim is not None:
        # remove values below the limit
        number_below = (log_pvals < xlim[0]).sum()
        print('There are', number_below, 'p-values below the limit of',
              xlim[0])
        log_pvals = log_pvals[log_pvals > xlim[0]]
        print('Remaining pvals:', len(log_pvals))
    sns.distplot(log_pvals,
                 kde=True,
                 bins=min(100, int(len(log_pvals) * 0.25)),
                 ax=axarr,
                 color=pval_colour,
                 norm_hist=True)
    axarr.axvline(x=np.log(0.05),
                  ls=':',
                  label='p = 0.05',
                  color='black',
                  alpha=0.75)
    axarr.axvline(x=np.log(0.05 / n_params),
                  ls='--',
                  label='p = 0.05/' + str(n_params),
                  color='black',
                  alpha=0.75)
    axarr.legend()
    axarr.set_xlabel(r'$\log(p)$')
    axarr.set_ylabel('density')

    if xlim is not None:
        axarr.set_xlim(xlim)
    else:
        axarr.set_xlim((None, 0.01))


#    axarr.set_xscale('log')
    vis_utils.beautify_axes(np.array([axarr]))
    plt.tight_layout()
    plot_identifier = f'pval_histogram_{cfg_name}_{model}_{what}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
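# Hedged sketch (illustration only): the two reference lines drawn in
# fit_pval_histogram are the nominal threshold p = 0.05 and the
# Bonferroni-corrected threshold p = 0.05 / n_params. With invented uniform
# p-values, this counts how many parameters each threshold would flag.
def _example_bonferroni_thresholds(n_params=100):
    p_vals = np.random.uniform(size=n_params)
    nominal = int((p_vals < 0.05).sum())
    corrected = int((p_vals < 0.05 / n_params).sum())
    return nominal, corrected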
def qq_plot(what: str,
            cfg_name: str,
            model: str,
            replace_index: int,
            seed: int,
            times=[50],
            params='random') -> None:
    """
    grab trace file, do qq plot for gradient noise at specified time-point
    """
    plt.clf()
    plt.close()
    assert what in ['gradients', 'weights']

    if what == 'weights':
        print('Looking at weights, this means we consider all seeds!')
    colours = cm.viridis(np.linspace(0.2, 0.8, len(times)))

    experiment = results_utils.ExperimentIdentifier(cfg_name, model,
                                                    replace_index, seed)

    if params == 'random':
        if what == 'gradients':
            df = experiment.load_gradients(noise=True,
                                           params=None,
                                           iter_range=(min(times),
                                                       max(times) + 1))
        else:
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=(min(times), max(times) + 1),
                params=None)
        params = np.random.choice(df.columns[2:], 1)
        print('picking random parameter', params)
        first_two_cols = df.columns[:2].tolist()
        df = df.loc[:, first_two_cols + list(params)]
    else:
        if what == 'gradients':
            df = experiment.load_gradients(noise=True,
                                           params=params,
                                           iter_range=(min(times),
                                                       max(times) + 1))
        else:
            df = results_utils.get_posterior_samples(
                cfg_name,
                model=model,
                replace_index=replace_index,
                iter_range=(min(times), max(times) + 1),
                params=params)

    if df is False:
        print('ERROR: No data available')

        return False
    fig, axarr = plt.subplots(nrows=1, ncols=2, figsize=(7, 3.5))

    for i, t in enumerate(times):
        df_t = df.loc[df['t'] == t, :]
        X = df_t.iloc[:, 2:].values.flatten()
        print('number of samples:', X.shape[0])
        sns.distplot(X,
                     ax=axarr[0],
                     kde=False,
                     color=to_hex(colours[i]),
                     label=str(t))
        sm.qqplot(X,
                  line='45',
                  fit=True,
                  ax=axarr[1],
                  c=colours[i],
                  alpha=0.5,
                  label=str(t))
    plt.suptitle(f'cfg_name: {cfg_name}, model: {model}, {what}')
    axarr[0].legend()
    axarr[1].legend()
    axarr[0].set_xlabel('parameter:' + '.'.join(params))
    vis_utils.beautify_axes(axarr)
    plt.tight_layout()

    plot_identifier = f'qq_{what}_{cfg_name}_{model}_{"_".join(params)}'
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.png'))
    plt.savefig((PLOTS_DIR / plot_identifier).with_suffix('.pdf'))

    return
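# Hedged sketch (illustration only): the Q-Q-against-fitted-normal comparison
# used in qq_plot above, on invented Gaussian samples. sm.qqplot with
# fit=True standardises the samples before comparing to the 45-degree line.
def _example_qqplot():
    samples = np.random.normal(loc=0.0, scale=2.0, size=500)
    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    sm.qqplot(samples, line='45', fit=True, ax=ax)
    return fig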
def get_deltas(cfg_name,
               iter_range,
               model,
               vary_seed=True,
               vary_data=True,
               params=None,
               num_deltas=100,
               include_identifiers=False,
               diffinit=False,
               data_privacy='all',
               multivariate=False,
               verbose=False):
    """
    collect samples of weights from experiments on cfg_name+model, varying:
    - seed (vary_seed)
    - data (vary_data)

    to clarify, we want to estimate |w(S, r) - w(S', r')|,
    with potentially S' = S (vary_data = False), or r' = r (vary_seed = False)

    we need to make sure that we only compare like-with-like!

    we want to get num_deltas values of delta in the end
    """
    df = results_utils.get_available_results(cfg_name,
                                             model,
                                             diffinit=diffinit,
                                             data_privacy=data_privacy)
    # filter out replaces with too few seeds, and seeds with too few replaces
    seeds_per_replace = df['replace'].value_counts()
    good_replaces = seeds_per_replace[seeds_per_replace > 2].index
    replace_per_seed = df['seed'].value_counts()
    good_seeds = replace_per_seed[replace_per_seed > 2].index
    df = df[df['replace'].isin(good_replaces)]
    df = df[df['seed'].isin(good_seeds)]

    if num_deltas == 'max':
        num_deltas = int(df.shape[0] / 2)
        print('Using num_deltas:', num_deltas)

    if df.shape[0] < 2 * num_deltas:
        print('ERROR: Run more experiments, or set num_deltas to be at most',
              int(df.shape[0] / 2))

        return None, None
    w_rows = np.random.choice(df.shape[0], num_deltas, replace=False)
    remaining_rows = [x for x in range(df.shape[0]) if x not in w_rows]
    df_remaining = df.iloc[remaining_rows]
    seed_options = df_remaining['seed'].unique()

    if len(seed_options) < 2:
        print('ERROR: Insufficient seeds!')

        return None, None
    data_options = df_remaining['replace'].unique()

    if len(data_options) == 1:
        print('ERROR: Insufficient data!')

        return None, None

    w = df.iloc[w_rows]
    w.reset_index(inplace=True)
    # now let's get comparators for each row of w!
    wp_data_vals = [np.nan] * w.shape[0]
    wp_seed_vals = [np.nan] * w.shape[0]

    for i, row in w.iterrows():
        row_data = row['replace']
        row_seed = row['seed']

        if not vary_seed:
            wp_seed = row_seed
        else:
            # get a new seed
            new_seed = np.random.choice(seed_options)

            while new_seed == row_seed:
                new_seed = np.random.choice(seed_options)
            wp_seed = new_seed

        if not vary_data:
            wp_data = row_data
        else:
            # get a new replace index (i.e. different data)
            new_data = np.random.choice(data_options)

            while new_data == row_data:
                new_data = np.random.choice(data_options)
            wp_data = new_data
        wp_data_vals[i] = wp_data
        wp_seed_vals[i] = wp_seed
    wp = pd.DataFrame({'replace': wp_data_vals, 'seed': wp_seed_vals})

    if vary_seed:
        # make sure the seed is always different
        assert ((wp['seed'].astype(int).values -
                 w['seed'].astype(int).values) == 0).sum() == 0
    else:
        # make sure it's always the same
        assert ((wp['seed'].astype(int).values -
                 w['seed'].astype(int).values) == 0).mean() == 1

    if vary_data:
        # make sure the data is always different
        assert ((wp['replace'].astype(int).values -
                 w['replace'].astype(int).values) == 0).sum() == 0
    else:
        assert ((wp['replace'].astype(int).values -
                 w['replace'].astype(int).values) == 0).mean() == 1

    deltas = [0] * num_deltas

    for i in range(num_deltas):
        replace_index = w.iloc[i]['replace']
        seed = w.iloc[i]['seed']

        exp = results_utils.ExperimentIdentifier(cfg_name, model,
                                                 replace_index, seed, diffinit,
                                                 data_privacy)

        if exp.exists():
            w_weights = exp.load_weights(iter_range=iter_range,
                                         params=params,
                                         verbose=False).values[:, 1:]
            # the first column is the time-step
        else:
            print(f'WARNING: Missing data for (seed, replace) = ({seed}, {replace_index})')
            w_weights = np.array([np.nan])
        replace_index_p = wp.iloc[i]['replace']
        seed_p = wp.iloc[i]['seed']

        exp_p = results_utils.ExperimentIdentifier(cfg_name, model,
                                                   replace_index_p, seed_p,
                                                   diffinit, data_privacy)

        if exp_p.exists():
            wp_weights = exp_p.load_weights(iter_range=iter_range,
                                            params=params,
                                            verbose=False).values[:, 1:]
        else:
            print(f'WARNING: Missing data for (seed, replace) = ({seed_p}, {replace_index_p})')
            wp_weights = np.array([np.nan])

        if multivariate:
            delta = np.abs(w_weights - wp_weights)
        else:
            delta = np.linalg.norm(w_weights - wp_weights)
        deltas[i] = delta
    w_identifiers = list(zip(w['replace'], w['seed']))
    wp_identifiers = list(zip(wp['replace'], wp['seed']))
    identifiers = np.array(list(zip(w_identifiers, wp_identifiers)))

    deltas = np.array(deltas)
    return deltas, identifiers
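# Hedged sketch (illustration only): the distance computed for each pair in
# get_deltas. With multivariate=False the delta is the L2 norm of the
# difference between two weight vectors; with multivariate=True it is the
# per-parameter absolute difference. The weight vectors here are invented.
def _example_delta(multivariate=False):
    w = np.array([0.10, -0.20, 0.30])
    wp = np.array([0.12, -0.18, 0.25])

    if multivariate:
        return np.abs(w - wp)  # one delta per parameter

    return np.linalg.norm(w - wp)  # a single scalar delta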