def get_glm_fit_data(n=50):
    """Compute GLM fit across empirical datasets.

    For the test fold of every entry in ``Datasets`` and every entry in
    ``GLMModels``, calls ``dataset.fit_glm_bootstrap(glm_model, n=n)`` and
    gathers the returned ``statistic`` and ``minmax`` bounds into one table.

    Args:
        n: Forwarded unchanged to ``fit_glm_bootstrap`` and echoed back in
            the returned metadata under ``'N'`` (presumably the number of
            bootstrap resamples -- confirm against ``fit_glm_bootstrap``).

    Returns:
        A dict with three keys:

        * ``'data'``: plain nested dict indexed as
          ``[dataset][glm][parameter][statistic][estimate]``; only the
          parameters ``AIC``/``nll``/``b0``/``b1`` and statistics
          ``mean``/``std`` are copied, and combinations absent from the
          table are skipped.
        * ``'dataframe'``: DataFrame indexed by ``(dataset_name, glm_name)``
          with three-level columns ``(parameter, statistic, estimate)``.
        * ``'metadata'``: ``{'N': n}``.
    """
    col_dict = collections.defaultdict(list)
    for dataset_name, dataset_dict in Datasets.items():
        dataset = dataset_dict[Folds.test]
        for glm_name, glm_model in GLMModels.items():
            print((dataset_name, glm_name))
            row_idx = pd.MultiIndex.from_tuples(
                [(dataset_name, glm_name)],
                names=('dataset_name', 'glm_name'))
            curr_glm_fit_data = dataset.fit_glm_bootstrap(glm_model, n=n)
            for metric_name, metric_dict in curr_glm_fit_data.items():
                for stat_name, stat_dict in metric_dict.items():
                    # One single-row Series per estimate; the column order
                    # (value, lower, upper) is fixed by this tuple.
                    estimates = (
                        ('value', stat_dict['statistic']),
                        ('lower', stat_dict['minmax'][0]),
                        ('upper', stat_dict['minmax'][1]),
                    )
                    for estimate_name, estimate in estimates:
                        col_idx = (metric_name, stat_name, estimate_name)
                        col_dict[col_idx].append(
                            pd.Series([estimate], index=row_idx))

    # Series.append was removed in pandas 2.0; a single concat per column
    # produces the same stacked Series without repeated copying.
    columns = {
        col_idx: pd.concat(pieces) for col_idx, pieces in col_dict.items()
    }
    df_glm_fit = pd.DataFrame(columns)
    df_glm_fit.columns.names = ['parameter', 'statistic', 'estimate']

    def nested_dict():
        # Arbitrarily deep auto-creating dict, so leaves can be assigned
        # without building the intermediate levels by hand.
        return collections.defaultdict(nested_dict)

    glm_fit_data_dict = nested_dict()
    for curr_ds, glm, parameter, statistic, estimate in itertools.product(
            Datasets, GLMModels,
            ['AIC', 'nll', 'b0', 'b1'],
            ['mean', 'std'],
            ['value', 'lower', 'upper']):
        try:
            datum = df_glm_fit.loc[curr_ds.name, glm.name].loc[
                parameter, statistic, estimate]
        except KeyError:
            # Not every model reports every parameter/statistic pair.
            continue
        glm_fit_data_dict[curr_ds.name][glm.name][parameter][statistic][
            estimate] = datum

    # The JSON round trip turns the nested defaultdicts into plain dicts.
    glm_fit_data_dict = json.loads(json.dumps(glm_fit_data_dict))
    return {
        'data': glm_fit_data_dict,
        'dataframe': df_glm_fit,
        'metadata': {
            'N': n
        }
    }
def get_beta_fit_data():
    """Perform MLE of Beta distribution of best fit.

    For the test fold of every entry in ``Datasets``, runs
    ``recursive_beta_shift_fit`` once per candidate ``shift`` and keeps the
    result dict with the smallest ``'nll'``.

    Returns:
        A dict with two keys:

        * ``'data'``: plain nested dict indexed as ``[dataset][parameter]``
          for the parameters ``a``, ``b``, ``loc``, ``scale`` and ``p1``.
        * ``'dataframe'``: DataFrame indexed by ``dataset_name`` with one
          column per key of the best-fit result dict.
    """
    data_dict_beta_fit = collections.defaultdict(list)
    for dataset_name, dataset_dict in Datasets.items():
        dataset = dataset_dict[Folds.test]

        # Sentinel so that any fit with a finite nll replaces it.
        # NOTE(review): if no fit beats inf (e.g. nll is nan), only 'nll'
        # is recorded for this dataset and the columns end up with unequal
        # lengths -- confirm recursive_beta_shift_fit cannot return that.
        beta_fit_best_param_dict = {'nll': float('inf')}
        for shift in [1e-16]:
            print((dataset_name, shift))
            beta_fit_p1_dict = recursive_beta_shift_fit(
                dataset,
                arange=(0, 200),
                brange=(0, 50),
                n_s=11,
                tol=1e-5,
                cf=.5,
                shift=shift)
            if beta_fit_p1_dict['nll'] < beta_fit_best_param_dict['nll']:
                beta_fit_best_param_dict = beta_fit_p1_dict

        data_dict_beta_fit['dataset_name'].append(dataset_name)
        for key, val in beta_fit_best_param_dict.items():
            data_dict_beta_fit[key].append(val)

    df_beta_fit = pd.DataFrame(data_dict_beta_fit).set_index(
        ['dataset_name'])
    print(df_beta_fit)

    def nested_dict():
        # Arbitrarily deep auto-creating dict.
        return collections.defaultdict(nested_dict)

    beta_fit_data_dict = nested_dict()
    for curr_ds, parameter in itertools.product(
            Datasets, ['a', 'b', 'loc', 'scale', 'p1']):
        datum = df_beta_fit.loc[curr_ds.name, parameter]
        # json cannot serialise NumPy integer scalars of any width (not
        # only int64), so convert them all to built-in int first.
        if isinstance(datum, np.integer):
            datum = int(datum)
        beta_fit_data_dict[curr_ds.name][parameter] = datum

    # The JSON round trip turns the nested defaultdicts into plain dicts.
    beta_fit_data_dict = json.loads(json.dumps(beta_fit_data_dict))
    return {'data': beta_fit_data_dict, 'dataframe': df_beta_fit}
# Write glm_fit summary plot: for curr_Dataset, curr_GLMModel in itertools.product(Datasets, GLMModels): gm = curr_GLMModel.value ds = curr_Dataset.value[Folds.test] save_file_path = os.path.join('glm_fit_figs', curr_Dataset.name) if not os.path.exists(save_file_path): os.mkdir(save_file_path) save_file_name = os.path.join(save_file_path, '{}.png'.format(curr_GLMModel.name)) fig = gm.plot_fit_sequence(ds, figsize_single=3, fontsize=10) fig.savefig(save_file_name) # Write calibration curve plot: fig, ax = plt.subplots(figsize=(5.1, 3.1)) fontsize = 8 clrs = sns.color_palette('husl', n_colors=len(list(Datasets.items()))) LINE_STYLES = ['solid', 'dashed', 'dashdot', 'dotted'] NUM_STYLES = len(LINE_STYLES) for ii, (ds_name, ds_dict) in enumerate(Datasets.items()): ds = ds_dict[Folds.test] gm_name_AIC_dict = { gm_name: glm_fit_data['data'][ds_name][gm_name]['AIC']['mean']['value'] for gm_name, gm in GLMModels.items() } gm_best_name = min(gm_name_AIC_dict, key=gm_name_AIC_dict.get) gm_best = {key: val for key, val in GLMModels.items()}[gm_best_name] gm_best.plot_calibration(ax, ds, plot_yx=ii == 0, color=clrs[ii],