def get_xerror_k(self, k):
	assert 0 <= k < len(self)
	sne_model = self.sne_models[k]
	if sne_model is not None and len(self) > 0:
		return XError([self.fit_errors[k]])
	else:
		return XError(None)
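# NOTE: `XError` is never defined in these snippets; from its usage (it wraps a list of
# per-fold values, accepts None/[], and exposes things like .mean, .p50 and .max()), it
# looks like a small percentile-summary container. The stand-in below is only an assumed
# sketch for running the snippets locally, not the project's implementation.
import numpy as np

class XErrorSketch:
    def __init__(self, values):
        self.values = np.asarray([] if values is None else values, dtype=float)

    @property
    def mean(self):
        return float(np.mean(self.values)) if self.values.size else float('nan')

    @property
    def p50(self):
        return float(np.median(self.values)) if self.values.size else float('nan')

    def max(self):
        return float(np.max(self.values)) if self.values.size else float('nan')

    def __len__(self):
        return len(self.values)

    def __repr__(self):
        if not self.values.size:
            return 'x~(empty)'
        return f'{self.mean:.3f}±{float(np.std(self.values)):.3f}'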
def get_ps_times_df(
    rootdir,
    cfilename,
    kf,
    set_name,
    model_names,
    train_mode='pre-training',
):
    info_df = DFBuilder()
    new_model_names = utils.get_sorted_model_names(model_names)
    for kmn, model_name in enumerate(new_model_names):
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/model_info/{cfilename}'
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            set_name,
            fext='d',
            disbalanced_kf_mode='oversampling',  # error oversampling
            random_state=RANDOM_STATE,
        )
        print(f'{files_ids}({len(files_ids)}#); model={model_name}')
        if len(files) == 0:
            continue

        survey = files[0]()['survey']
        band_names = files[0]()['band_names']
        class_names = files[0]()['class_names']
        is_parallel = 'Parallel' in model_name

        loss_name = 'wmse-xentropy'
        print(files[0]()['monitors'][loss_name].keys())

        #th = 1 # bug?
        d = {}
        parameters = [f()['parameters'] for f in files][0]
        d['params'] = parameters
        #d['best_epoch'] = XError([f()['monitors']['wmse-xentropy']['best_epoch'] for f in files])
        d['time-per-iteration [segs]'] = sum(
            [f()['monitors'][loss_name]['time_per_iteration'] for f in files])
        #print(d['time-per-iteration [segs]'].max())
        #d['time-per-iteration/params $1e6\\cdot$[segs]'] = sum([f()['monitors'][loss_name]['time_per_iteration']/parameters*1e6 for f in files])
        #d['time_per_epoch'] = sum([f()['monitors']['wmse-xentropy']['time_per_epoch'] for f in files])

        print(files[0]()['monitors'][loss_name]['time_per_epoch'])

        #d['time_per_epoch [segs]'] = sum([f()['monitors'][loss_name]['time_per_epoch'] for f in files])
        d['time-per-epoch [segs]'] = XError(
            [f()['monitors'][loss_name]['total_time'] / 1500 for f in files])

        d['total-time [mins]'] = XError(
            [f()['monitors'][loss_name]['total_time'] / 60 for f in files])

        index = f'model={utils.get_fmodel_name(model_name)}'
        info_df.append(index, d)

    return info_df
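# NOTE: `DFBuilder` is used throughout as a row accumulator: .append(index, row_dict)
# followed by .get_df(). An assumed, minimal stand-in on top of pandas (not the project's
# class) could be:
import pandas as pd

class DFBuilderSketch:
    def __init__(self):
        self.rows = {}  # index -> {column: value}

    def append(self, index, row_dict):
        self.rows[index] = dict(row_dict)

    def get_df(self):
        # one row per appended index, in insertion order
        return pd.DataFrame.from_dict(self.rows, orient='index')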
def get_perf_times(rootdir, kf, lcset_name):
    files, files_ids = fcfiles.gather_files_by_kfold(rootdir, kf, lcset_name)
    times = []
    for f in files:
        if all([new_lcobj.all_synthetic() for new_lcobj in f()['new_lcobjs']]):
            times.append(f()['segs'])
    return XError(times)
def get_info_dict(
    rootdir,
    methods,
    cfilename,
    kf,
    lcset_name,
    band_names=['g', 'r'],
):
    info_df = DFBuilder()

    ### all info
    d = {}
    for method in methods:
        _rootdir = f'{rootdir}/{method}/{cfilename}'
        files, files_ids = fcfiles.gather_files_by_kfold(
            _rootdir, kf, lcset_name)
        trace_time = [f()['segs'] for f in files]
        d[method] = XError(trace_time)

    info_df.append(f'metric=trace-time [segs]~band=.', d)

    ### per band info
    for kb, b in enumerate(band_names):
        d = nested_dict()
        for method in methods:
            _rootdir = f'{rootdir}/{method}/{cfilename}'
            files, files_ids = fcfiles.gather_files_by_kfold(
                _rootdir, kf, lcset_name)
            traces = [f()['trace_bdict'][b] for f in files]
            trace_errors = flat_list([t.get_valid_errors() for t in traces])
            trace_errors_xe = XError(np.log(np.array(trace_errors) + _C.EPS))
            d['error'][method] = trace_errors_xe
            d['success'][method] = len(trace_errors) / sum(
                [len(t) for t in traces]) * 100

        d = d.to_dict()
        info_df.append(f'metric=fit-log-error~band={b}', d['error'])
        info_df.append(f'metric=fits-success [%]~band={b}', d['success'])

    return info_df.get_df()
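# The fit-log-error rows above add a small epsilon before taking the logarithm so that a
# fit error of exactly zero does not produce -inf. A minimal illustration of that step
# (the EPS value here is an assumption; the real one lives in `_C.EPS`):
import numpy as np

EPS = 1e-10  # assumed magnitude
trace_errors = np.array([0.0, 1e-4, 2e-3, 5e-2])
log_errors = np.log(trace_errors + EPS)
print(log_errors)  # all finite, including the zero entry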
def get_ranks(
    rootdir,
    kf,
    lcset_name,
    band_names=['g', 'r'],
):
    files, files_ids = fcfiles.gather_files_by_kfold(rootdir, kf, lcset_name)
    rank_bdict = {b: TopRank(f'band={b}') for b in band_names}
    for f, fid in zip(files, files_ids):
        lcobj_name = f()['lcobj_name']
        for b in band_names:
            errors = f()['trace_bdict'][b].get_valid_errors()
            if len(errors) == 0:
                continue

            xe = XError(errors)
            rank_bdict[b].append(fid, xe.mean)

    for b in band_names:
        rank_bdict[b].calcule()
    return rank_bdict
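# NOTE: `TopRank` is used as a per-band name->score ranking: .append(name, value) for each
# object, then .calcule() once all values are in. An assumed stand-in (sorting direction is
# a guess; the real class is likely richer):
class TopRankSketch:
    def __init__(self, label):
        self.label = label
        self.items = []  # list of (name, value) pairs

    def append(self, name, value):
        self.items.append((name, value))

    def calcule(self):
        # rank by value, highest (e.g. worst mean fit error) first
        self.items.sort(key=lambda t: t[1], reverse=True)
        return self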
    def get_bstats_idf_c(
        self,
        c,
        b,
        index=None,
    ):
        lcobjs = self.get_lcobjs(c)
        if len(lcobjs) > 0:
            info_dict = {
                f'{c}-$x$':
                XError(np.concatenate([x.get_b(b).obs for x in lcobjs])),
                f'{c}-$L$':
                XError([len(x.get_b(b)) for x in lcobjs]),
                rf'{c}-$\Delta T$':
                XError([
                    x.get_b(b).get_days_duration() for x in lcobjs
                    if len(x.get_b(b)) >= 1
                ]),
                rf'{c}-$\Delta t$':
                XError(
                    np.concatenate(
                        [x.get_b(b).get_diff('days') for x in lcobjs])),
            }
        else:
            info_dict = {
                f'{c}-$x$': XError([]),
                f'{c}-$L$': XError([]),
                rf'{c}-$\Delta T$': XError([]),
                rf'{c}-$\Delta t$': XError([]),
            }

        info_dict = {id(self) if index is None else index: info_dict}
        df = pd.DataFrame.from_dict(info_dict, orient='index').reindex(
            list(info_dict.keys()))
        df.index.rename(C_.SET_NAME_STR, inplace=True)
        return df
 def get_serial_stats_idf_c(self, c):
     lcobjs = self.get_lcobjs(c)
     if len(lcobjs) > 0:
         xs = [lcobj.get_x_serial() for lcobj in lcobjs]
         info_dict = {
             f'{c}-$x$':
             XError(np.concatenate([x[:, C_.OBS_INDEX] for x in xs])),
             f'{c}-$L$':
             XError([len(lcobj) for lcobj in lcobjs]),
             rf'{c}-$\Delta T$':
             XError([lcobj.get_days_serial_duration() for lcobj in lcobjs]),
             rf'{c}-$\Delta t$':
             XError(
                 np.concatenate(
                     [diff_vector(x[:, C_.DAYS_INDEX]) for x in xs])),
         }
     else:
         info_dict = {
             f'{c}-$x$': XError([]),
             f'{c}-$L$': XError([]),
             rf'{c}-$\Delta T$': XError([]),
             rf'{c}-$\Delta t$': XError([]),
         }
     return info_dict
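# The per-class statistics above reduce each light curve to its observation values ($x$),
# number of points ($L$), duration in days ($\Delta T$) and the gaps between consecutive
# observations ($\Delta t$). With plain numpy, the same quantities for one toy curve
# (the days/obs arrays below are made up for illustration) are:
import numpy as np

days = np.array([0.0, 1.5, 3.0, 7.2, 12.5])
obs = np.array([18.9, 19.1, 19.4, 20.0, 20.3])

L = len(days)                      # $L$
delta_T = days.max() - days.min()  # $\Delta T$
delta_t = np.diff(days)            # $\Delta t$; presumably what diff_vector computes
print(L, delta_T, delta_t)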
def plot_ocurve_classes(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_classes,
    thday,
    baselines_dict={},
    figsize=FIGSIZE_1X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE,
    shadow_alpha=SHADOW_ALPHA,
    ocurve_name='rocc',
    baseline_roodir=None,
):
    for kmn, model_name in enumerate(model_names):
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        for target_class in target_classes:
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
            files, files_ids = ftfiles.gather_files_by_kfold(
                load_roodir,
                kf,
                lcset_name,
                fext='d',
                disbalanced_kf_mode='oversampling',  # error oversampling
                random_state=RANDOM_STATE,
            )
            print(f'{files_ids}({len(files_ids)}#); model={model_name}')
            if len(files) == 0:
                continue

            survey = files[0]()['survey']
            band_names = files[0]()['band_names']
            class_names = files[0]()['class_names']
            thdays = files[0]()['days']

            xe_aucroc = XError([
                f()['days_class_metrics_cdf']
                [target_class].loc[f()['days_class_metrics_cdf'][target_class]
                                   ['_day'] == thday]['aucroc'].item()
                for f in files
            ])
            label = f'{target_class}; AUC={xe_aucroc}'
            color = CLASSES_STYLES[target_class]['c']

            ocurves = [
                f()['days_class_metrics_cdf']
                [target_class].loc[f()['days_class_metrics_cdf'][target_class]
                                   ['_day'] == thday][ocurve_name].item()
                for f in files
            ]
            fill_beetween(
                ax,
                [ocurve[XLABEL_DICT[ocurve_name]] for ocurve in ocurves],
                [ocurve[YLABEL_DICT[ocurve_name]] for ocurve in ocurves],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                },
                percentile=percentile,
            )
            ax.plot([None], [None], color=color, label=label)

        ax.plot([0, 1],
                GUIDE_CURVE_DICT[ocurve_name],
                '--',
                color='k',
                alpha=1,
                lw=1)
        ax.set_xlabel(XLABEL_DICT[ocurve_name])
        ax.set_ylabel(YLABEL_DICT[ocurve_name])
        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=.5)
        ax.set_axisbelow(True)
        ax.legend(loc='lower right')

        title = ''
        title += f'{ocurve_name.upper()[:-1]} operative curves for SNe classes' + '\n'
        title += f'set={survey} [{lcset_name.replace(".@", "")}]' + '\n'
        title += f'th-day={thday:.3f} [days]' + '\n'
        fig.suptitle(title[:-1], va='bottom')

        fig.tight_layout()
        plt.show()
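# `fill_beetween` (library spelling) receives one x/y curve per k-fold run and draws the
# median curve plus a percentile shadow band. A self-contained sketch of that idea with
# numpy/matplotlib, assuming all folds share the same x grid (the real helper may
# interpolate instead):
import numpy as np
import matplotlib.pyplot as plt

def fill_between_sketch(ax, xs, ys, color='C0', percentile=95, shadow_alpha=0.25):
    x = np.asarray(xs[0])                      # shared x grid across folds (assumption)
    Y = np.stack([np.asarray(y) for y in ys])  # (folds, points)
    lo = np.percentile(Y, 100 - percentile, axis=0)
    hi = np.percentile(Y, percentile, axis=0)
    ax.fill_between(x, lo, hi, color=color, alpha=shadow_alpha, lw=0)  # shadow band
    ax.plot(x, np.median(Y, axis=0), color=color)                      # median curve

fig, ax = plt.subplots(1, 1)
x = np.linspace(0, 1, 50)
curves = [np.clip(x + np.random.normal(0, .05, x.size), 0, 1) for _ in range(5)]
fill_between_sketch(ax, [x] * 5, curves)
plt.show()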
def plot_ocurve_models(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_class,
    thday,
    baselines_dict={},
    figsize=FIGSIZE_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE,
    shadow_alpha=SHADOW_ALPHA,
    ocurve_name='rocc',
    baseline_roodir=None,
):
    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ps_model_names = utils.get_sorted_model_names(model_names, merged=False)
    for kax, ax in enumerate(axs):
        if len(ps_model_names[kax]) == 0:
            continue
        color_dict = utils.get_color_dict(ps_model_names[kax])
        for kmn, model_name in enumerate(ps_model_names[kax]):
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
            files, files_ids = ftfiles.gather_files_by_kfold(
                load_roodir,
                kf,
                lcset_name,
                fext='d',
                disbalanced_kf_mode='oversampling',  # error oversampling
                random_state=RANDOM_STATE,
            )
            print(f'{model_name} {files_ids}({len(files_ids)}#)')
            if len(files) == 0:
                continue

            survey = files[0]()['survey']
            band_names = files[0]()['band_names']
            class_names = files[0]()['class_names']
            days = files[0]()['days']

            xe_aucroc = XError([
                f()['days_class_metrics_cdf'][target_class].loc[
                    f()['days_class_metrics_cdf'][target_class]['_day'] ==
                    thday]['auc' + ocurve_name[:-1]].item() for f in files
            ])
            roccs = [
                f()['days_class_metrics_cdf']
                [target_class].loc[f()['days_class_metrics_cdf'][target_class]
                                   ['_day'] == thday][ocurve_name].item()
                for f in files
            ]

            label = f'{utils.get_fmodel_name(model_name)}; AUC={xe_aucroc}'
            color = color_dict[utils.get_fmodel_name(model_name)]
            fill_beetween(
                ax,
                [rocc['fpr'] for rocc in roccs],
                [rocc['tpr'] for rocc in roccs],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                },
                percentile=percentile,
            )
            ax.plot([None], [None], color=color, label=label)

        title = ''
        title += f'{target_class}-ROC curve' + '\n'
        title += f'set={survey} [{lcset_name.replace(".@", "")}]' + '\n'
        title += f'th-day={thday:.3f} [days]' + '\n'
        fig.suptitle(title[:-1], va='bottom')

        if baseline_roodir is not None:
            files, files_ids = ftfiles.gather_files_by_kfold(
                baseline_roodir,
                kf,
                lcset_name,
                fext='d',
                disbalanced_kf_mode='oversampling',  # error oversampling
                random_state=RANDOM_STATE,
            )
            print(f'{files_ids}({len(files_ids)}#); model={BASELINE_MODEL_NAME}')

            xe_aucroc = XError([
                f()['thdays_class_metrics_cdf'][target_class].loc[
                    f()['thdays_class_metrics_cdf'][target_class]['_thday'] ==
                    thday]['auc' + ocurve_name[:-1]].item() for f in files
            ])
            roccs = [
                f()['thdays_class_metrics_cdf'][target_class].loc[
                    f()['thdays_class_metrics_cdf'][target_class]['_thday'] ==
                    thday][ocurve_name].item() for f in files
            ]

            label = f'{utils.get_fmodel_name(BASELINE_MODEL_NAME)}; AUC={xe_aucroc}'
            color = 'k'
            fill_beetween(
                ax,
                [rocc['fpr'] for rocc in roccs],
                [rocc['tpr'] for rocc in roccs],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                    'linestyle': '--',
                },
                percentile=percentile,
            )
            ax.plot([None], [None], color=color, label=label)

    for kax, ax in enumerate(axs):
        ax.plot([0, 1], [0, 1], '--', color='k', alpha=1, lw=1)
        ax.set_xlabel('FPR')
        if kax == 0:
            ax.set_ylabel('TPR')
            ax.set_title(f'{bf_alphabet_count(0)} Parallel models')
        else:
            ax.set_yticklabels([])
            ax.set_title(f'{bf_alphabet_count(1)} Serial models')

        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.5)
        ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
 def get_time_per_epoch_set(self, set_name):
     loss_df_epoch = self.loss_df_epoch.get_df()
     return XError([
         v for v in loss_df_epoch['_dt'][loss_df_epoch['_set'].isin(
             [set_name])].values
     ])
	def get_xerror(self):
		errors = self.get_valid_errors()
		return XError(errors)
def plot_metric(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    dmetrics,
    target_class=None,
    baselines_dict={},
    figsize=RECT_PLOT_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE_PLOT,
    shadow_alpha=SHADOW_ALPHA,
):
    for metric_name in dmetrics.keys():
        fig, axs = plt.subplots(1, 2, figsize=figsize)
        ps_model_names = utils.get_sorted_model_names(model_names,
                                                      merged=False)
        for kax, ax in enumerate(axs):
            if len(ps_model_names[kax]) == 0:
                continue
            color_dict = utils.get_color_dict(ps_model_names[kax])
            ylims = [[], []]
            for kmn, model_name in enumerate(ps_model_names[kax]):
                load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
                files, files_ids = fcfiles.gather_files_by_kfold(load_roodir,
                                                                 kf,
                                                                 lcset_name,
                                                                 fext='d')
                print(f'{model_name} {files_ids}({len(files_ids)}#)')
                if len(files) == 0:
                    continue

                survey = files[0]()['survey']
                band_names = files[0]()['band_names']
                class_names = files[0]()['class_names']
                days = files[0]()['days']

                if target_class is None:
                    metric_curves = [
                        f()['days_class_metrics_df'][metric_name].values
                        for f in files
                    ]
                else:
                    metric_curves = [
                        f()['days_class_metrics_cdf'][target_class]
                        [metric_name.replace('b-', '')].values for f in files
                    ]
                xe_metric_curve_avg = XError(
                    np.mean(np.concatenate(
                        [metric_curve[None] for metric_curve in metric_curves],
                        axis=0),
                            axis=-1))

                label = f'{utils.get_fmodel_name(model_name)} | AUC={xe_metric_curve_avg}'
                color = color_dict[utils.get_fmodel_name(model_name)]
                fill_beetween(
                    ax,
                    [days for _ in metric_curves],
                    metric_curves,
                    fill_kwargs={
                        'color': color,
                        'alpha': shadow_alpha,
                        'lw': 0,
                    },
                    median_kwargs={
                        'color': color,
                        'alpha': 1,
                    },
                    percentile=percentile,
                )
                ax.plot([None], [None], color=color, label=label)
                ylims[0] += [ax.get_ylim()[0]]
                ylims[1] += [ax.get_ylim()[1]]

            mn = metric_name if dmetrics[metric_name][
                'mn'] is None else dmetrics[metric_name]['mn']
            mn = mn if target_class is None else mn.replace(
                'b-', f'{target_class}-')
            title = ''
            title += f'{mn} v/s days' + '\n'
            title += f'train-mode={train_mode} - survey={survey}-{"".join(band_names)} [{kf}@{lcset_name}]' + '\n'
            fig.suptitle(title[:-1], va='bottom')

        for kax, ax in enumerate(axs):
            if f'{kf}@{lcset_name}' in baselines_dict.keys():
                # ax.plot(days, np.full_like(days, baselines_dict[f'{kf}@{lcset_name}'][metric_name]), ':', c='k', label=f'FATS & b-RF Baseline (day={days[-1]:.3f})')
                pass

            ax.set_xlabel('time [days]')
            if kax == 1:
                ax.set_yticklabels([])
                ax.set_title('serial models')
            else:
                ax.set_ylabel(mn)
                ax.set_title('parallel models')

            ax.set_xlim([days.min(), days.max()])
            ax.set_ylim(min(ylims[0]), max(ylims[1]) * 1.05)
            ax.grid(alpha=0.5)
            ax.legend(loc='lower right')

        fig.tight_layout()
        plt.show()
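# The `AUC={...}` labels above come from collapsing each per-day metric curve to its mean
# over days and summarizing the per-fold means with XError. The reduction itself is an
# axis-wise mean over a (folds, days) stack:
import numpy as np

metric_curves = [np.random.rand(100) for _ in range(5)]           # one curve per fold
stack = np.concatenate([c[None] for c in metric_curves], axis=0)  # (folds, days)
per_fold_auc = np.mean(stack, axis=-1)                            # (folds,)
print(per_fold_auc)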
def plot_rocc(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_class,
    target_day,
    baselines_dict={},
    figsize=RECT_PLOT_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE_PLOT,
    shadow_alpha=SHADOW_ALPHA,
):
    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ps_model_names = utils.get_sorted_model_names(model_names, merged=False)
    for kax, ax in enumerate(axs):
        if len(ps_model_names[kax]) == 0:
            continue
        color_dict = utils.get_color_dict(ps_model_names[kax])
        for kmn, model_name in enumerate(ps_model_names[kax]):
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
            files, files_ids = fcfiles.gather_files_by_kfold(load_roodir,
                                                             kf,
                                                             lcset_name,
                                                             fext='d')
            print(f'{model_name} {files_ids}({len(files_ids)}#)')
            if len(files) == 0:
                continue

            survey = files[0]()['survey']
            band_names = files[0]()['band_names']
            class_names = files[0]()['class_names']
            days = files[0]()['days']

            xe_aucroc = XError([
                f()['days_class_metrics_cdf']
                [target_class].loc[f()['days_class_metrics_cdf'][target_class]
                                   ['_day'] == target_day]['aucroc'].item()
                for f in files
            ])
            label = f'{utils.get_fmodel_name(model_name)} | AUC={xe_aucroc}'
            color = color_dict[utils.get_fmodel_name(model_name)]

            roccs = [
                f()['days_class_metrics_cdf']
                [target_class].loc[f()['days_class_metrics_cdf'][target_class]
                                   ['_day'] == target_day]['rocc'].item()
                for f in files
            ]
            fill_beetween(
                ax,
                [rocc['fpr'] for rocc in roccs],
                [rocc['tpr'] for rocc in roccs],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                },
                percentile=percentile,
            )
            ax.plot([None], [None], color=color, label=label)

        title = ''
        title += f'{target_class}-ROC curve ({target_day:.3f} [days])' + '\n'
        title += f'train-mode={train_mode} - survey={survey}-{"".join(band_names)} [{kf}@{lcset_name}]' + '\n'
        fig.suptitle(title[:-1], va='bottom')

    for kax, ax in enumerate(axs):
        ax.plot([0, 1], [0, 1], '--', color='k', alpha=1, lw=1)
        ax.set_xlabel('FPR')
        if kax == 0:
            ax.set_ylabel('TPR')
            ax.set_title('parallel models')
        else:
            ax.set_yticklabels([])
            ax.set_title('serial models')

        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.5)
        ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
def plot_metric(rootdir, cfilename, kf, lcset_name, model_names, dmetrics,
	target_class=None,
	figsize=FIGSIZE_2X1,
	train_mode='fine-tuning',
	percentile=PERCENTILE,
	shadow_alpha=SHADOW_ALPHA,
	baseline_roodir=None,
	):
	for metric_name in dmetrics.keys():
		fig, axs = plt.subplots(1, 2, figsize=figsize)
		axis_lims = AxisLims({'x':(None, None), 'y':(0, 1)}, {'x':.0, 'y':.1})
		ps_model_names = utils.get_sorted_model_names(model_names, merged=False)
		for kax,ax in enumerate(axs):
			if len(ps_model_names[kax])==0:
				continue
			color_dict = utils.get_color_dict(ps_model_names[kax])
			for kmn,model_name in enumerate(ps_model_names[kax]):
				load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
				files, files_ids = ftfiles.gather_files_by_kfold(load_roodir, kf, lcset_name,
					fext='d',
					disbalanced_kf_mode='oversampling', # error oversampling
					random_state=RANDOM_STATE,
					)
				print(f'{files_ids}({len(files_ids)}#); model={model_name}')
				if len(files)==0:
					continue

				survey = files[0]()['survey']
				band_names = files[0]()['band_names']
				class_names = files[0]()['class_names']
				thdays = files[0]()['days']

				if target_class is None:
					metric_curves = [f()['days_class_metrics_df'][f'b-{metric_name}'].values for f in files]
				else:
					metric_curves = [f()['days_class_metrics_cdf'][target_class][f'{metric_name}'].values for f in files] 
				xe_metric_curve_auc = XError(np.mean(np.concatenate([metric_curve[None] for metric_curve in metric_curves], axis=0), axis=-1)) # (b,t)

				model_label = utils.get_fmodel_name(model_name)
				label = f'{model_label}; AUC={xe_metric_curve_auc}'
				color = color_dict[utils.get_fmodel_name(model_name)]
				lines.fill_beetween(ax, [thdays for _ in metric_curves], metric_curves,
					fill_kwargs={'color':color, 'alpha':shadow_alpha, 'lw':0,},
					median_kwargs={'color':color, 'alpha':1,},
					percentile=percentile,
					)
				ax.plot([None], [None], color=color, label=label)
				axis_lims.append('x', thdays)
				axis_lims.append('y', np.concatenate([metric_curve for metric_curve in metric_curves], axis=0))

			new_metric_name = f'{"b" if target_class is None else target_class}-{metric_name if dmetrics[metric_name]["mn"] is None else dmetrics[metric_name]["mn"]}'
			suptitle = ''
			suptitle += f'{new_metric_name} v/s days using moving th-day'+'\n'
			suptitle += f'set={survey} [{lcset_name.replace(".@", "")}]'+'\n'
			fig.suptitle(suptitle[:-1], va='bottom')

		for kax,ax in enumerate(axs):
			if baseline_roodir is not None:
				files, files_ids = ftfiles.gather_files_by_kfold(baseline_roodir, kf, lcset_name,
					fext='d',
					disbalanced_kf_mode='oversampling', # error oversampling
					random_state=RANDOM_STATE,
					)
				print(f'{files_ids}({len(files_ids)}#); model={BASELINE_MODEL_NAME}')
				thdays = files[0]()['thdays']
				for kthday,thday in enumerate(thdays):
					if target_class is None:
						xe_metric = XError([f()['thdays_class_metrics_df'].loc[f()['thdays_class_metrics_df']['_thday']==thday][f'b-{metric_name}'].item() for f in files])
					else:
						xe_metric = XError([f()['thdays_class_metrics_cdf'][target_class].loc[f()['thdays_class_metrics_cdf'][target_class]['_thday']==thday][f'{metric_name}'].item() for f in files])

					model_label = utils.get_fmodel_name(BASELINE_MODEL_NAME)
					ax.plot(thday, xe_metric.p50, 'D', c='k', label=f'{model_label}' if kthday==0 else None)

				ax.axhline(xe_metric.p50, linestyle='--', c='k')

			ax.set_xlabel('time [days]')
			if kax==0:
				ax.set_ylabel(new_metric_name)
				ax.set_title(f'{bf_alphabet_count(0)} Parallel models')
			else:
				ax.set_yticklabels([])
				ax.set_title(f'{bf_alphabet_count(1)} Serial models')

			axis_lims.set_ax_axis_lims(ax)
			ax.grid(alpha=0.5)
			ax.legend(loc='lower right')

		fig.tight_layout()
		plt.show()
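# NOTE: `AxisLims` appears to collect the plotted x/y values and later apply axis limits
# with a margin via .set_ax_axis_lims(ax). An assumed stand-in (margin and None handling
# are guesses, not the project's class):
import numpy as np

class AxisLimsSketch:
    def __init__(self, init_lims, margins):
        self.init_lims = init_lims  # e.g. {'x': (None, None), 'y': (0, 1)}
        self.margins = margins      # e.g. {'x': .0, 'y': .1}
        self.data = {k: [] for k in init_lims}

    def append(self, axis, values):
        self.data[axis] += list(np.ravel(values))

    def set_ax_axis_lims(self, ax):
        for axis in ('x', 'y'):
            lo, hi = self.init_lims[axis]
            if lo is None and self.data[axis]:
                lo = min(self.data[axis])
            if hi is None and self.data[axis]:
                hi = max(self.data[axis])
            if lo is None or hi is None:
                continue
            m = self.margins[axis] * (hi - lo)
            (ax.set_xlim if axis == 'x' else ax.set_ylim)(lo - m, hi + m)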
def plot_cm(rootdir, cfilename, kf, lcset_name, model_names,
	figsize=FIGSIZE,
	train_mode='fine-tuning',
	export_animation=False,
	animation_duration=12,
	new_order_classes=['SNIa', 'SNIbc', 'SNII-b-n', 'SLSN'],
	percentile=PERCENTILE,
	):
	for kmn,model_name in enumerate(model_names):
		load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
		files, files_ids = ftfiles.gather_files_by_kfold(load_roodir, kf, lcset_name,
			fext='d',
			disbalanced_kf_mode='oversampling', # error oversampling
			random_state=RANDOM_STATE,
			)
		print(f'ids={files_ids}(n={len(files_ids)}#); model={model_name}')
		if len(files)==0:
			continue

		survey = files[0]()['survey']
		band_names = files[0]()['band_names']
		class_names = files[0]()['class_names']
		is_parallel = 'Parallel' in model_name
		days = files[0]()['days']

		plot_animation = PlotAnimator(animation_duration,
			is_dummy=not export_animation,
			#save_init_frame=True,
			save_end_frame=True,
			)

		thdays = days if export_animation else [days[-1]]
		bar = ProgressBar(len(thdays), bar_format='{l_bar}{bar}{postfix}')
		for kd,thday in enumerate(thdays):
			bar(f'thday={thday:.3f} [days]')
			xe_dict = {}
			for metric_name in ['b-precision', 'b-recall', 'b-f1score']:
				xe_metric = XError([f()['days_class_metrics_df'].loc[f()['days_class_metrics_df']['_day']==thday][metric_name].item() for f in files])
				xe_dict[metric_name] = xe_metric

			bprecision_xe = xe_dict['b-precision']
			brecall_xe = xe_dict['b-recall']
			bf1score_xe = xe_dict['b-f1score']

			title = ''
			title += f'{utils.get_fmodel_name(model_name)}'+'\n'
			#title += f'survey={survey}-{"".join(band_names)} [{kf}@{lcset_name}]'+'\n'
			#title += f'train-mode={train_mode}; eval-set={kf}@{lcset_name}'+'\n'
			title += f'b-recall={brecall_xe}; b-f1score={bf1score_xe}'+'\n'
			title += f'th-day={thday:.3f} [days]'+'\n'
			#title += f'b-p/r={bprecision_xe} / {brecall_xe}'+'\n'
			#title += f'b-f1score={bf1score_xe}'+'\n'
			cms = np.concatenate([f()['days_cm'][thday][None] for f in files], axis=0)
			fig, ax, cm_norm = plot_custom_confusion_matrix(cms, class_names,
				#fig=fig,
				#ax=ax,
				title=title[:-1],
				figsize=figsize,
				new_order_classes=new_order_classes,
				percentile=percentile,
				)
			uses_close_fig = kd<len(thdays)-1
			plot_animation.append(fig, uses_close_fig)

		bar.done()
		plt.show()
		plot_animation.save(f'../temp/{model_name}.gif') # gif mp4
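# `plot_custom_confusion_matrix` receives the per-fold confusion matrices stacked along
# axis 0 (`cms` above). A minimal sketch of the aggregation it presumably performs:
# row-normalize each fold, then summarize cell-wise, e.g. with the median across folds.
import numpy as np

cms = np.random.randint(0, 50, size=(5, 4, 4)).astype(float)  # (folds, classes, classes)
row_sums = cms.sum(axis=-1, keepdims=True)
cms_norm = np.divide(cms, row_sums, out=np.zeros_like(cms), where=row_sums > 0)
cm_median = np.median(cms_norm, axis=0)  # cell-wise median, values in [0, 1]
print(cm_median.round(3))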
def get_ps_performance_df(
    rootdir,
    cfilename,
    kf,
    set_name,
    model_names,
    dmetrics,
    target_class=None,
    thday=None,
    train_mode='fine-tuning',
    n=1e3,
    uses_avg=False,
    baseline_roodir=None,
):
    info_df = DFBuilder()
    new_model_names = utils.get_sorted_model_names(model_names)
    new_model_names = ([BASELINE_MODEL_NAME] + new_model_names
                       if baseline_roodir is not None else new_model_names)
    for kmn, model_name in enumerate(new_model_names):
        is_baseline = 'BRF' in model_name
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}' if not is_baseline else baseline_roodir
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            set_name,
            fext='d',
            disbalanced_kf_mode='oversampling',  # error oversampling
            random_state=RANDOM_STATE,
        )
        print(f'{files_ids}({len(files_ids)}#); model={model_name}')
        if len(files) == 0:
            continue

        fixme = 'th' if is_baseline else ''
        # survey = files[0]()['survey'] # fixme
        band_names = files[0]()['band_names']
        class_names = files[0]()['class_names']
        thdays = files[0]()[fixme + 'days']

        d = {}
        for km, metric_name in enumerate(dmetrics.keys()):
            new_metric_name = f'{"b" if target_class is None else target_class}-{metric_name if dmetrics[metric_name]["mn"] is None else dmetrics[metric_name]["mn"]}'
            if not uses_avg:
                if target_class is None:
                    xe_metric = XError([
                        f()[fixme + 'days_class_metrics_df'].loc[f()[
                            fixme + 'days_class_metrics_df']['_' + fixme +
                                                             'day'] == thday]
                        [f'b-{metric_name}'].item() for f in files
                    ])
                else:
                    xe_metric = XError([
                        f()[fixme + 'days_class_metrics_cdf'][target_class].loc[
                            f()[fixme + 'days_class_metrics_cdf'][target_class]
                            ['_' + fixme + 'day'] == thday]
                        [f'{metric_name}'].item() for f in files
                    ])
                d[new_metric_name] = xe_metric
            else:
                if is_baseline:
                    d[new_metric_name] = XError([-999])
                else:
                    if target_class is None:
                        metric_curves = [
                            f()[fixme + 'days_class_metrics_df']
                            [f'b-{metric_name}'].values for f in files
                        ]
                    else:
                        metric_curves = [
                            f()[fixme + 'days_class_metrics_cdf'][target_class]
                            [f'{metric_name}'].values for f in files
                        ]
                    # print(np.concatenate([metric_curve[None] for metric_curve in metric_curves], axis=0).shape)
                    xe_metric_curve_auc = XError(
                        np.mean(np.concatenate([
                            metric_curve[None]
                            for metric_curve in metric_curves
                        ],
                                               axis=0),
                                axis=-1))  # (b,t)>(b)
                    # interp_metric_curve = interp1d(thdays, metric_curve)(np.linspace(thdays.min(), thday, int(n)))
                    # xe_metric_curve_avg = XError(np.mean(interp_metric_curve, axis=-1))
                    d[new_metric_name] = xe_metric_curve_auc

        index = f'model={utils.get_fmodel_name(model_name)}'
        info_df.append(index, d)

    return info_df
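# Most metric lookups above share one pandas pattern: mask a per-day metrics DataFrame by
# the threshold day and pull a single cell with .item(). In isolation (the frame below is
# made up; only the column names mirror the ones used above):
import pandas as pd

days_class_metrics_df = pd.DataFrame({
    '_day': [10.0, 20.0, 30.0],
    'b-f1score': [0.42, 0.55, 0.61],
})
thday = 20.0
mask = days_class_metrics_df['_day'] == thday
value = days_class_metrics_df.loc[mask]['b-f1score'].item()
print(value)  # 0.55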
 def get_time_per_iteration(self):
     loss_df = self.loss_df.get_df()
     return XError([v for v in loss_df['_dt'].values])