def get_perf_times(rootdir, kf, lcset_name):
    """Collect the ``'segs'`` values of files whose new light curves are all synthetic.

    Args:
        rootdir: Root directory handed to ``fcfiles.gather_files_by_kfold``.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.

    Returns:
        An ``XError`` built from the ``'segs'`` entry of every file for which
        ``all_synthetic()`` is true on each object in ``'new_lcobjs'``.
    """
    files, _ = fcfiles.gather_files_by_kfold(rootdir, kf, lcset_name)
    times = []
    for f in files:
        # Invoke the file handle once; the previous code called it again for
        # every field it read.
        data = f()
        if all(new_lcobj.all_synthetic() for new_lcobj in data['new_lcobjs']):
            times.append(data['segs'])
    return XError(times)
def get_all_incorrects_fittings(rootdir, kf, lcset_name):
    """List the ``'lcobj_name'`` of files whose new light curves are all real.

    Args:
        rootdir: Root directory handed to ``fcfiles.gather_files_by_kfold``.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.

    Returns:
        A list with the ``'lcobj_name'`` entry of every file for which
        ``all_real()`` is true on each object in ``'new_lcobjs'``.
    """
    files, _ = fcfiles.gather_files_by_kfold(rootdir, kf, lcset_name)
    lcobj_names = []
    for f in files:
        # Invoke the file handle once instead of once per accessed field.
        data = f()
        if all(new_lcobj.all_real() for new_lcobj in data['new_lcobjs']):
            lcobj_names.append(data['lcobj_name'])
    return lcobj_names
# Example #3
# 0
def get_ps_times_df(
    rootdir,
    cfilename,
    kf,
    method,
    model_names,
    train_mode='pre-training',
):
    """Build a table of timing information, one row per model.

    For every model name (sorted by ``utils.get_sorted_model_names``) the
    ``model_info`` files are gathered and the following columns are filled:
    ``'params'``, ``'time-per-iteration [segs]'``, ``'time-per-epoch [segs]'``
    and ``'total-time [mins]'``. Models without files are skipped.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
        method: Not used by this function.
        model_names: Model names to report.
        train_mode: Path component between the model name and ``model_info``.

    Returns:
        The populated ``DFBuilder`` instance.
    """
    loss_name = 'wmse-xentropy'
    info_df = DFBuilder()
    for model_name in utils.get_sorted_model_names(model_names):
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/model_info/{cfilename}'
        # NOTE(review): `set_name` is neither a parameter nor a local of this
        # function, so it must exist as a module-level name or this call raises
        # NameError. Confirm whether the unused `method` parameter was meant to
        # be `set_name`.
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            set_name,
            fext='d',
            disbalanced_kf_mode='oversampling',  # error oversampling
            random_state=RANDOM_STATE,
        )
        print(
            f'{model_name} {files_ids}({len(files_ids)}#); model={model_name}')
        if len(files) == 0:
            continue

        # Invoke every file handle exactly once. The previous code re-invoked
        # them for each field and loaded all files just to keep the first
        # file's 'parameters'.
        parameters = None
        monitors = []
        for kfile, f in enumerate(files):
            data = f()
            if kfile == 0:
                parameters = data['parameters']
            monitors.append(data['monitors'][loss_name])

        print(monitors[0].keys())
        print(monitors[0]['time_per_epoch'])

        d = {}
        d['params'] = parameters
        d['time-per-iteration [segs]'] = sum(
            [monitor['time_per_iteration'] for monitor in monitors])
        # NOTE(review): 1500 is a hard-coded divisor, presumably the number of
        # training epochs; confirm against the training configuration.
        d['time-per-epoch [segs]'] = XError(
            [monitor['total_time'] / 1500 for monitor in monitors])
        d['total-time [mins]'] = XError(
            [monitor['total_time'] / 60 for monitor in monitors])

        index = f'model={utils.get_fmodel_name(model_name)}'
        info_df.append(index, d)

    return info_df
def get_info_dict(
    rootdir,
    methods,
    cfilename,
    kf,
    lcset_name,
    band_names=['g', 'r'],
):
    """Summarise fitting time and per-band fit errors for several methods.

    One row holds the ``'segs'`` values of every method (all bands together).
    Then, for each band, one row holds the log of the valid trace errors and
    another the percentage of valid errors over the total trace length.

    Args:
        rootdir: Root directory containing one sub-directory per method.
        methods: Method names; each becomes a column.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.
        band_names: Bands to report (the default list is only read).

    Returns:
        The result of ``DFBuilder.get_df()`` on the populated builder.
    """

    def _gather(method_name):
        # Files of one method; the ids returned alongside are not needed here.
        gathered, _ = fcfiles.gather_files_by_kfold(
            f'{rootdir}/{method_name}/{cfilename}', kf, lcset_name)
        return gathered

    builder = DFBuilder()

    # Row shared by all bands: fitting time per method.
    time_row = {}
    for method in methods:
        time_row[method] = XError([f()['segs'] for f in _gather(method)])
    builder.append('metric=trace-time [segs]~band=.', time_row)

    # Two rows per band: log-error and success percentage.
    for band in band_names:
        band_rows = nested_dict()
        for method in methods:
            traces = [f()['trace_bdict'][band] for f in _gather(method)]
            valid_errors = flat_list(
                [trace.get_valid_errors() for trace in traces])
            log_errors = np.log(np.array(valid_errors) + _C.EPS)
            band_rows['error'][method] = XError(log_errors)
            total_length = sum([len(trace) for trace in traces])
            band_rows['success'][method] = (
                len(valid_errors) / total_length * 100)

        band_rows = band_rows.to_dict()
        builder.append(f'metric=fit-log-error~band={band}', band_rows['error'])
        builder.append(f'metric=fits-success [%]~band={band}',
                       band_rows['success'])

    return builder.get_df()
def get_ranks(
    rootdir,
    kf,
    lcset_name,
    band_names=['g', 'r'],
):
    """Rank the gathered files by their mean valid fit error, per band.

    Args:
        rootdir: Root directory handed to ``fcfiles.gather_files_by_kfold``.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.
        band_names: Bands to rank (the default list is only read).

    Returns:
        A dict mapping each band to a ``TopRank`` on which ``calcule()`` has
        been called. Files with no valid errors in a band are left out of that
        band's rank.
    """
    files, files_ids = fcfiles.gather_files_by_kfold(rootdir, kf, lcset_name)
    rank_bdict = {b: TopRank(f'band={b}') for b in band_names}
    for f, fid in zip(files, files_ids):
        # Invoke the file handle once per file; the previous code invoked it
        # again for every band and once more for a value it never used.
        trace_bdict = f()['trace_bdict']
        for b in band_names:
            errors = trace_bdict[b].get_valid_errors()
            if len(errors) == 0:
                continue
            rank_bdict[b].append(fid, XError(errors).mean)

    for b in band_names:
        rank_bdict[b].calcule()
    return rank_bdict
# Example #6
# 0
def get_ps_performance_df(
    rootdir,
    cfilename,
    kf,
    set_name,
    model_names,
    dmetrics,
    target_class=None,
    thday=None,
    train_mode='fine-tuning',
    n=1e3,
    uses_avg=False,
    baseline_roodir=None,
):
    """Build a table of performance metrics, one row per model.

    When ``baseline_roodir`` is given, ``BASELINE_MODEL_NAME`` is prepended to
    the sorted model names and its files are read from that directory. A model
    is treated as the baseline when its name contains ``'BRF'``; baseline files
    use the ``'th'``-prefixed keys (``'thdays_class_metrics_df'``,
    ``'_thday'``, ...), the others the unprefixed ones.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
        set_name: Set name handed to ``ftfiles.gather_files_by_kfold``.
        model_names: Model names to report.
        dmetrics: Mapping ``metric_name -> {'mn': display_name_or_None}``.
        target_class: Class to report; ``None`` selects the ``b-`` metrics of
            the global frame.
        thday: Day at which the metric is read when ``uses_avg`` is false.
        train_mode: Path component between the model name and ``performance``.
        n: Not used by this function.
        uses_avg: If true, report the mean of the whole metric curve instead
            of the value at ``thday``; the baseline then gets ``XError([-999])``.
        baseline_roodir: Directory with the baseline files, or ``None``.

    Returns:
        The populated ``DFBuilder`` instance.
    """

    def _metrics_frame(f, prefix):
        # Invoke the file handle once and pick the frame holding the metrics.
        data = f()
        if target_class is None:
            return data[f'{prefix}days_class_metrics_df']
        return data[f'{prefix}days_class_metrics_cdf'][target_class]

    info_df = DFBuilder()
    new_model_names = utils.get_sorted_model_names(model_names)
    if baseline_roodir is not None:
        new_model_names = [BASELINE_MODEL_NAME] + new_model_names
    class_prefix = 'b' if target_class is None else target_class

    for model_name in new_model_names:
        is_baseline = 'BRF' in model_name
        if is_baseline:
            load_roodir = baseline_roodir
        else:
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            set_name,
            fext='d',
            disbalanced_kf_mode='oversampling',  # error oversampling
            random_state=RANDOM_STATE,
        )
        print(f'{files_ids}({len(files_ids)}#); model={model_name}')
        if len(files) == 0:
            continue

        # The key prefix follows the kind of file, not the loop position: the
        # previous `kmn == 0` test read the first regular model with baseline
        # keys whenever no baseline was requested.
        prefix = 'th' if is_baseline else ''
        day_column = f'_{prefix}day'

        d = {}
        for metric_name, metric_info in dmetrics.items():
            shown_name = metric_name if metric_info['mn'] is None else metric_info['mn']
            new_metric_name = f'{class_prefix}-{shown_name}'
            column = f'b-{metric_name}' if target_class is None else f'{metric_name}'

            if not uses_avg:
                values = []
                for f in files:
                    frame = _metrics_frame(f, prefix)
                    # The row mask comes from the same frame that is indexed;
                    # the per-class branch used to take it from the global one.
                    values.append(
                        frame.loc[frame[day_column] == thday][column].item())
                d[new_metric_name] = XError(values)
            elif is_baseline:
                # Placeholder value: no curve average is computed for the baseline.
                d[new_metric_name] = XError([-999])
            else:
                metric_curves = [
                    _metrics_frame(f, prefix)[column].values for f in files
                ]
                stacked = np.concatenate(
                    [metric_curve[None] for metric_curve in metric_curves],
                    axis=0)
                d[new_metric_name] = XError(np.mean(stacked, axis=-1))  # (b,t)>(b)

        index = f'model={utils.get_fmodel_name(model_name)}'
        info_df.append(index, d)

    return info_df
def plot_ocurve_models(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_class,
    thday,
    baselines_dict={},
    figsize=FIGSIZE_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE,
    shadow_alpha=SHADOW_ALPHA,
    ocurve_name='rocc',
    baseline_roodir=None,
):
    """Plot the operative curve of ``target_class`` at ``thday`` for every model.

    Two axes are drawn side by side, one per group returned by
    ``utils.get_sorted_model_names(model_names, merged=False)``; the axes are
    titled "Parallel models" and "Serial models". When ``baseline_roodir`` is
    given, the baseline curve is added in black with a dashed median to every
    axis that has models.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``ftfiles.gather_files_by_kfold``.
        model_names: Model names to plot.
        target_class: Class whose per-class metrics frame is read.
        thday: Day whose row is selected from the metrics frame.
        baselines_dict: Not used by this function.
        figsize: Figure size.
        train_mode: Path component between the model name and ``performance``.
        percentile: Forwarded to ``fill_beetween``.
        shadow_alpha: Alpha of the filled band.
        ocurve_name: Column holding the curve; the AUC column is
            ``'auc' + ocurve_name[:-1]``.
        baseline_roodir: Directory with the baseline files, or ``None``.
    """
    auc_key = 'auc' + ocurve_name[:-1]

    def _rows_at_thday(files, frames_key, day_column):
        # One handle invocation per file (it used to be four), keeping only
        # the row(s) of the requested day.
        rows = []
        for f in files:
            frame = f()[frames_key][target_class]
            rows.append(frame.loc[frame[day_column] == thday])
        return rows

    def _draw(ax, rows, model_label, color, linestyle=None):
        xe_auc = XError([row[auc_key].item() for row in rows])
        curves = [row[ocurve_name].item() for row in rows]
        median_kwargs = {'color': color, 'alpha': 1}
        if linestyle is not None:
            median_kwargs['linestyle'] = linestyle
        fill_beetween(
            ax,
            [curve['fpr'] for curve in curves],
            [curve['tpr'] for curve in curves],
            fill_kwargs={
                'color': color,
                'alpha': shadow_alpha,
                'lw': 0,
            },
            median_kwargs=median_kwargs,
            percentile=percentile,
        )
        # Empty line that only carries the legend entry.
        ax.plot([None], [None], color=color,
                label=f'{model_label}; AUC={xe_auc}')

    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ps_model_names = utils.get_sorted_model_names(model_names, merged=False)
    survey = None  # stays None until some model actually has files
    for kax, ax in enumerate(axs):
        if len(ps_model_names[kax]) == 0:
            continue
        color_dict = utils.get_color_dict(ps_model_names[kax])
        for model_name in ps_model_names[kax]:
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
            files, files_ids = ftfiles.gather_files_by_kfold(
                load_roodir,
                kf,
                lcset_name,
                fext='d',
                disbalanced_kf_mode='oversampling',  # error oversampling
                random_state=RANDOM_STATE,
            )
            print(f'{model_name} {files_ids}({len(files_ids)}#)')
            if len(files) == 0:
                continue

            survey = files[0]()['survey']
            fmodel_name = utils.get_fmodel_name(model_name)
            _draw(ax, _rows_at_thday(files, 'days_class_metrics_cdf', '_day'),
                  fmodel_name, color_dict[fmodel_name])

        if survey is not None:
            # Without any loaded file there is no survey name to show; the
            # previous code raised NameError here.
            title = ''
            title += f'{target_class}-ROC curve' + '\n'
            title += f'set={survey} [{lcset_name.replace(".@", "")}]' + '\n'
            title += f'th-day={thday:.3f} [days]' + '\n'
            fig.suptitle(title[:-1], va='bottom')

        if baseline_roodir is not None:
            files, files_ids = ftfiles.gather_files_by_kfold(
                baseline_roodir,
                kf,
                lcset_name,
                fext='d',
                disbalanced_kf_mode='oversampling',  # error oversampling
                random_state=RANDOM_STATE,
            )
            # Report the baseline, not whichever model was iterated last.
            print(f'{files_ids}({len(files_ids)}#); model={BASELINE_MODEL_NAME}')
            if len(files) == 0:
                continue

            _draw(ax,
                  _rows_at_thday(files, 'thdays_class_metrics_cdf', '_thday'),
                  utils.get_fmodel_name(BASELINE_MODEL_NAME), 'k',
                  linestyle='--')

    for kax, ax in enumerate(axs):
        ax.plot([0, 1], [0, 1], '--', color='k', alpha=1, lw=1)
        ax.set_xlabel('FPR')
        if kax == 0:
            ax.set_ylabel('TPR')
            ax.set_title(f'{bf_alphabet_count(0)} Parallel models')
        else:
            ax.set_yticklabels([])
            ax.set_title(f'{bf_alphabet_count(1)} Serial models')

        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.5)
        ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
def plot_ocurve_classes(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_classes,
    thday,
    baselines_dict={},
    figsize=FIGSIZE_1X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE,
    shadow_alpha=SHADOW_ALPHA,
    ocurve_name='rocc',
    baseline_roodir=None,
):
    """Plot, per model, the operative curves of several classes at ``thday``.

    One figure is created for every model that has files; each class in
    ``target_classes`` contributes one curve coloured by ``CLASSES_STYLES``.
    All figures are shown with a single ``plt.show()`` at the end.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``ftfiles.gather_files_by_kfold``.
        model_names: Model names; one figure each.
        target_classes: Classes drawn on every figure.
        thday: Day whose row is selected from each per-class metrics frame.
        baselines_dict: Not used by this function.
        figsize: Figure size.
        train_mode: Path component between the model name and ``performance``.
        percentile: Forwarded to ``fill_beetween``.
        shadow_alpha: Alpha of the filled band.
        ocurve_name: Column holding the curve; also the key into
            ``XLABEL_DICT``, ``YLABEL_DICT`` and ``GUIDE_CURVE_DICT``. The AUC
            column is ``'auc' + ocurve_name[:-1]`` (``'aucroc'`` by default),
            the same rule ``plot_ocurve_models`` uses.
        baseline_roodir: Not used by this function.
    """
    auc_key = 'auc' + ocurve_name[:-1]
    x_key = XLABEL_DICT[ocurve_name]
    y_key = YLABEL_DICT[ocurve_name]

    for model_name in model_names:
        # The gathered files do not depend on the class, so gather once per
        # model rather than once per class.
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            lcset_name,
            fext='d',
            disbalanced_kf_mode='oversampling',  # error oversampling
            random_state=RANDOM_STATE,
        )
        print(
            f'{model_name} {files_ids}({len(files_ids)}#); model={model_name}'
        )
        if len(files) == 0:
            # Nothing to draw, and no survey name for the title.
            continue

        survey = files[0]()['survey']
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        for target_class in target_classes:
            # One handle invocation per file and class (it used to be four).
            rows = []
            for f in files:
                frame = f()['days_class_metrics_cdf'][target_class]
                rows.append(frame.loc[frame['_day'] == thday])

            xe_auc = XError([row[auc_key].item() for row in rows])
            ocurves = [row[ocurve_name].item() for row in rows]
            color = CLASSES_STYLES[target_class]['c']
            fill_beetween(
                ax,
                [ocurve[x_key] for ocurve in ocurves],
                [ocurve[y_key] for ocurve in ocurves],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                },
                percentile=percentile,
            )
            # Empty line that only carries the legend entry.
            ax.plot([None], [None], color=color,
                    label=f'{target_class}; AUC={xe_auc}')

        ax.plot([0, 1],
                GUIDE_CURVE_DICT[ocurve_name],
                '--',
                color='k',
                alpha=1,
                lw=1)
        ax.set_xlabel(x_key)
        ax.set_ylabel(y_key)
        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=.5)
        ax.set_axisbelow(True)
        ax.legend(loc='lower right')

        title = ''
        title += f'{ocurve_name.upper()[:-1]} operative curves for SNe classes' + '\n'
        title += f'set={survey} [{lcset_name.replace(".@", "")}]' + '\n'
        title += f'th-day={thday:.3f} [days]' + '\n'
        fig.suptitle(title[:-1], va='bottom')

        # Lay out every figure; this call used to run once after the loop, so
        # only the last figure was adjusted and an empty `model_names` raised
        # NameError on `fig`.
        fig.tight_layout()

    plt.show()
# Example #9
# 0
def plot_metric(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    dmetrics,
    target_class=None,
    baselines_dict={},
    figsize=RECT_PLOT_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE_PLOT,
    shadow_alpha=SHADOW_ALPHA,
):
    """Plot each metric in ``dmetrics`` against days, one figure per metric.

    Every figure has two axes, one per group returned by
    ``utils.get_sorted_model_names(model_names, merged=False)``, titled
    "parallel models" and "serial models". Both axes get the same y-limits,
    computed from the curves drawn on either axis.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.
        model_names: Model names to plot.
        dmetrics: Mapping ``metric_name -> {'mn': display_name_or_None}``.
        target_class: Class to plot; ``None`` reads the global metrics frame.
            Otherwise ``'b-'`` is stripped from the metric name and the
            per-class frame is read.
        baselines_dict: Accepted for compatibility; no baseline is drawn.
        figsize: Figure size.
        train_mode: Path component between the model name and ``performance``.
        percentile: Forwarded to ``fill_beetween``.
        shadow_alpha: Alpha of the filled band.
    """
    for metric_name in dmetrics.keys():
        fig, axs = plt.subplots(1, 2, figsize=figsize)
        ps_model_names = utils.get_sorted_model_names(model_names,
                                                      merged=False)

        # The display name depends only on the metric, so it is computed up
        # front and is always available to the axis-decoration loop below.
        mn = metric_name if dmetrics[metric_name][
            'mn'] is None else dmetrics[metric_name]['mn']
        mn = mn if target_class is None else mn.replace(
            'b-', f'{target_class}-')

        # Shared across both axes. These lists used to be reset per axis, so
        # the common y-limits only reflected the last axis drawn.
        ylims = [[], []]
        days = None
        header = None  # (survey, band_names) of the last model with files

        for kax, ax in enumerate(axs):
            if len(ps_model_names[kax]) == 0:
                continue
            color_dict = utils.get_color_dict(ps_model_names[kax])
            for model_name in ps_model_names[kax]:
                load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
                files, files_ids = fcfiles.gather_files_by_kfold(load_roodir,
                                                                 kf,
                                                                 lcset_name,
                                                                 fext='d')
                print(f'{model_name} {files_ids}({len(files_ids)}#)')
                if len(files) == 0:
                    continue

                head = files[0]()
                header = (head['survey'], head['band_names'])
                days = head['days']

                if target_class is None:
                    metric_curves = [
                        f()['days_class_metrics_df'][metric_name].values
                        for f in files
                    ]
                else:
                    class_metric = metric_name.replace('b-', '')
                    metric_curves = [
                        f()['days_class_metrics_cdf'][target_class]
                        [class_metric].values for f in files
                    ]
                xe_metric_curve_avg = XError(
                    np.mean(np.concatenate(
                        [metric_curve[None] for metric_curve in metric_curves],
                        axis=0),
                            axis=-1))

                fmodel_name = utils.get_fmodel_name(model_name)
                color = color_dict[fmodel_name]
                fill_beetween(
                    ax,
                    [days for _ in metric_curves],
                    metric_curves,
                    fill_kwargs={
                        'color': color,
                        'alpha': shadow_alpha,
                        'lw': 0,
                    },
                    median_kwargs={
                        'color': color,
                        'alpha': 1,
                    },
                    percentile=percentile,
                )
                # Empty line that only carries the legend entry.
                ax.plot([None], [None], color=color,
                        label=f'{fmodel_name} | AUC={xe_metric_curve_avg}')
                ylow, yhigh = ax.get_ylim()
                ylims[0].append(ylow)
                ylims[1].append(yhigh)

            if header is not None:
                survey, band_names = header
                title = ''
                title += f'{mn} v/s days' + '\n'
                title += f'train-mode={train_mode} - survey={survey}-{"".join(band_names)} [{kf}@{lcset_name}]' + '\n'
                fig.suptitle(title[:-1], va='bottom')

        for kax, ax in enumerate(axs):
            ax.set_xlabel('time [days]')
            if kax == 1:
                ax.set_yticklabels([])
                ax.set_title('serial models')
            else:
                ax.set_ylabel(mn)
                ax.set_title('parallel models')

            # Limits are only known once at least one model has been drawn.
            if days is not None:
                ax.set_xlim([days.min(), days.max()])
            if ylims[0]:
                ax.set_ylim(min(ylims[0]), max(ylims[1]) * 1.05)
            ax.grid(alpha=0.5)
            ax.legend(loc='lower right')

        fig.tight_layout()
        plt.show()
# Example #10
# 0
def plot_temporal_encoding(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    train_mode='pre-training',
    layers=1,
    figsize=FIGSIZE,
    n=1e3,
    percentile=PERCENTILE_PLOT,
    shadow_alpha=SHADOW_ALPHA,
):
    """Plot "variation power" curves of the temporal-encoding weights per model.

    For every model whose ``temporal_encoding`` directory exists, has files and
    whose ``mdl`` field contains ``'Parallel'``, a 2 x len(band_names) figure
    is drawn: the top row is labelled "scale" and the bottom row "bias", one
    column per band. Each file contributes one thin black median curve per
    panel; a red curve with the median over files and a dashed vertical line
    at ``EMPIRICAL_TMAXS[b]`` are added at the end.

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``ftfiles.gather_files_by_kfold``.
        model_names: Model names to plot.
        train_mode: Path component between the model name and
            ``temporal_encoding``.
        layers: Not used in this function body.
        figsize: Figure size.
        n: Number of points of the time grid (cast with ``int``).
        percentile: Not used in this function body.
        shadow_alpha: Not used in this function body.
    """
    for kmn, model_name in enumerate(model_names):
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/temporal_encoding/{cfilename}'
        if not ftfiles.path_exists(load_roodir):
            continue
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            lcset_name,
            fext='d',
            disbalanced_kf_mode='ignore',  # error oversampling ignore
            random_state=RANDOM_STATE,
        )
        print(f'{model_name} {files_ids}({len(files_ids)}#)')
        if len(files) == 0:
            continue

        survey = files[0]()['survey']
        band_names = files[0]()['band_names']
        class_names = files[0]()['class_names']
        mn_dict = strings.get_dict_from_string(model_name)
        mdl = mn_dict['mdl']
        is_parallel = 'Parallel' in mdl
        # Only models whose 'mdl' field contains 'Parallel' are plotted.
        if not is_parallel:
            continue

        # Replace the stored days by an evenly spaced grid of int(n) points
        # spanning the same first and last day.
        days = files[0]()['days']
        days = np.linspace(days[0], days[-1], int(n))

        # Per-panel list of per-file median curves, keyed by 'kax/kb/b'.
        global_median_curves_d = {}
        fig, axs = plt.subplots(2, len(band_names), figsize=figsize)
        for kfile, file in enumerate(files):
            for kb, b in enumerate(band_names):
                d = file()['temporal_encoding_info']['encoder'][
                    f'ml_attn.{b}']['te_film']
                weight = d['weight']  # (f,2m)
                # Split the transposed weight matrix into two equal halves
                # along its last axis: the first half feeds the "scale" rows,
                # the second the "bias" rows.
                alpha_weights, beta_weights = np.split(
                    weight.T, 2, axis=-1)  # (f,2m)>(2m,f/2),(2m,f/2)
                scales = []
                biases = []
                for kfu in range(0, alpha_weights.shape[-1]):
                    te_ws = d['te_ws']
                    te_periods = d['te_periods']
                    te_phases = d['te_phases']
                    # NOTE(review): get_fourier/get_diff are defined elsewhere;
                    # presumably a Fourier series evaluated on `days` and its
                    # first difference - confirm against their definitions.
                    alpha = get_fourier(days, alpha_weights[:, kfu],
                                        te_periods, te_phases)
                    dalpha = (get_diff(alpha, 1)**2) * 1e6
                    beta = get_fourier(days, beta_weights[:, kfu], te_periods,
                                       te_phases)
                    dbeta = (get_diff(beta, 1)**2) * 1e6
                    scales += [dalpha]
                    biases += [dbeta]

                # `d` is rebound here: from now on it maps the curve name to
                # its list of curves (the 'c' entries are not read below).
                d = {
                    'scale': {
                        'curve': scales,
                        'c': 'r'
                    },
                    'bias': {
                        'curve': biases,
                        'c': 'g'
                    },
                }
                for kax, curve_name in enumerate(['scale', 'bias']):
                    ax = axs[kax, kb]
                    curves = d[curve_name]['curve']
                    c = 'k'
                    # Median over the stacked curves of this file and band.
                    median_curve = np.median(np.concatenate(
                        [curve[None] for curve in curves], axis=0),
                                             axis=0)
                    if not f'{kax}/{kb}/{b}' in global_median_curves_d.keys():
                        global_median_curves_d[f'{kax}/{kb}/{b}'] = []
                    global_median_curves_d[f'{kax}/{kb}/{b}'] += [median_curve]
                    ax.plot(
                        days,
                        median_curve,
                        c=c,
                        alpha=1,
                        lw=.5,
                    )
                    # Empty line carrying the legend entry; labelled only for
                    # the first file so the legend shows it once.
                    ax.plot([None], [None],
                            c=c,
                            label=f'variation power continuous-time function'
                            if kfile == 0 else None)
                    ax.legend(loc='upper right')
                    ax.grid(alpha=0.5)
                    ax.set_xlim((days[0], days[-1]))
                    ax.set_title(
                        f'{bf_alphabet_count(kb, kax)} variation power for {curve_name}; band={b}'
                    )
                    ax_styles.set_color_borders(ax, C_.COLOR_DICT[b])
                    # Y label only on the first column, x label only on the
                    # bottom row.
                    if kb == 0:
                        ax.set_ylabel(f'variation power [M]')
                    else:
                        pass
                    if kax == 0:
                        ax.set_xticklabels([])
                    else:
                        ax.set_xlabel(f'time [days]')
            model_label = utils.get_fmodel_name(model_name)
            suptitle = ''
            suptitle = f'{model_label}' + '\n'
            # suptitle += f'set={survey} [{lcset_name.replace(".@", "")}]'+'\n'
            fig.suptitle(suptitle[:-1], va='bottom')

        # Overlay, on every panel, the median across files of the per-file
        # median curves, plus a vertical line at EMPIRICAL_TMAXS[b].
        for k in global_median_curves_d.keys():
            kax, kb, b = k.split('/')
            median_curves = global_median_curves_d[k]
            ax = axs[int(kax), int(kb)]
            ax.plot(
                days,
                np.median(np.concatenate(
                    [median_curve[None] for median_curve in median_curves],
                    axis=0),
                          axis=0),
                '-',
                # c=['r', 'g'][int(kax)],
                c='r',
                label=f'median variation power continuous-time function',
            )

            ax.axvline(EMPIRICAL_TMAXS[b],
                       linestyle='--',
                       c='k',
                       label='empirical median SNe peak-time')
            ax.legend(loc='upper right')

        fig.tight_layout()
        plt.show()
# Example #11
# 0
def plot_rocc(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    target_class,
    target_day,
    baselines_dict={},
    figsize=RECT_PLOT_2X1,
    train_mode='fine-tuning',
    percentile=PERCENTILE_PLOT,
    shadow_alpha=SHADOW_ALPHA,
):
    """Plot the ROC curve of ``target_class`` at ``target_day`` for every model.

    Two axes are drawn side by side, one per group returned by
    ``utils.get_sorted_model_names(model_names, merged=False)``, titled
    "parallel models" and "serial models".

    Args:
        rootdir: Root directory containing one sub-directory per model.
        cfilename: Last path component of the directory to load from.
        kf: K-fold identifier handed to ``fcfiles.gather_files_by_kfold``.
        lcset_name: Set name handed to ``fcfiles.gather_files_by_kfold``.
        model_names: Model names to plot.
        target_class: Class whose per-class metrics frame is read.
        target_day: Day whose row is selected from the metrics frame.
        baselines_dict: Not used by this function.
        figsize: Figure size.
        train_mode: Path component between the model name and ``performance``.
        percentile: Forwarded to ``fill_beetween``.
        shadow_alpha: Alpha of the filled band.
    """

    def _row_at_target_day(f):
        # One handle invocation per file (it used to be four), keeping only
        # the row(s) of the requested day.
        frame = f()['days_class_metrics_cdf'][target_class]
        return frame.loc[frame['_day'] == target_day]

    fig, axs = plt.subplots(1, 2, figsize=figsize)
    ps_model_names = utils.get_sorted_model_names(model_names, merged=False)
    header = None  # (survey, band_names) of the last model with files
    for kax, ax in enumerate(axs):
        if len(ps_model_names[kax]) == 0:
            continue
        color_dict = utils.get_color_dict(ps_model_names[kax])
        for model_name in ps_model_names[kax]:
            load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
            files, files_ids = fcfiles.gather_files_by_kfold(load_roodir,
                                                             kf,
                                                             lcset_name,
                                                             fext='d')
            print(f'{model_name} {files_ids}({len(files_ids)}#)')
            if len(files) == 0:
                continue

            head = files[0]()
            header = (head['survey'], head['band_names'])

            rows = [_row_at_target_day(f) for f in files]
            xe_aucroc = XError([row['aucroc'].item() for row in rows])
            roccs = [row['rocc'].item() for row in rows]

            fmodel_name = utils.get_fmodel_name(model_name)
            color = color_dict[fmodel_name]
            fill_beetween(
                ax,
                [rocc['fpr'] for rocc in roccs],
                [rocc['tpr'] for rocc in roccs],
                fill_kwargs={
                    'color': color,
                    'alpha': shadow_alpha,
                    'lw': 0,
                },
                median_kwargs={
                    'color': color,
                    'alpha': 1,
                },
                percentile=percentile,
            )
            # Empty line that only carries the legend entry.
            ax.plot([None], [None], color=color,
                    label=f'{fmodel_name} | AUC={xe_aucroc}')

        if header is not None:
            # Without any loaded file there is no survey/band information for
            # the title; the previous code raised NameError here.
            survey, band_names = header
            title = ''
            title += f'{target_class}-ROC curve ({target_day:.3f} [days])' + '\n'
            title += f'train-mode={train_mode} - survey={survey}-{"".join(band_names)} [{kf}@{lcset_name}]' + '\n'
            fig.suptitle(title[:-1], va='bottom')

    for kax, ax in enumerate(axs):
        ax.plot([0, 1], [0, 1], '--', color='k', alpha=1, lw=1)
        ax.set_xlabel('FPR')
        if kax == 0:
            ax.set_ylabel('TPR')
            ax.set_title('parallel models')
        else:
            ax.set_yticklabels([])
            ax.set_title('serial models')

        ax.set_xlim(0.0, 1.0)
        ax.set_ylim(0.0, 1.0)
        ax.grid(alpha=0.5)
        ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
def plot_metric(rootdir, cfilename, kf, lcset_name, model_names, dmetrics,
	target_class=None,
	figsize=FIGSIZE_2X1,
	train_mode='fine-tuning',
	percentile=PERCENTILE,
	shadow_alpha=SHADOW_ALPHA,
	baseline_roodir=None,
	):
	"""Plot each metric in ``dmetrics`` against days, one figure per metric.

	Every figure has two axes, one per group returned by
	``utils.get_sorted_model_names(model_names, merged=False)``, titled
	"Parallel models" and "Serial models". When ``baseline_roodir`` is given,
	the baseline median (``p50``) at each of its th-days is drawn as black
	diamonds on both axes, plus a dashed horizontal line at the value of the
	last th-day.

	Args:
		rootdir: Root directory containing one sub-directory per model.
		cfilename: Last path component of the directory to load from.
		kf: K-fold identifier handed to ``ftfiles.gather_files_by_kfold``.
		lcset_name: Set name handed to ``ftfiles.gather_files_by_kfold``.
		model_names: Model names to plot.
		dmetrics: Mapping ``metric_name -> {'mn': display_name_or_None}``.
		target_class: Class to plot; ``None`` reads the ``b-`` column of the
			global metrics frame.
		figsize: Figure size.
		train_mode: Path component between the model name and ``performance``.
		percentile: Forwarded to ``lines.fill_beetween``.
		shadow_alpha: Alpha of the filled band.
		baseline_roodir: Directory with the baseline files, or ``None``.
	"""
	def _metrics_frame(f, prefix):
		# Invoke the file handle once and pick the frame holding the metrics.
		data = f()
		if target_class is None:
			return data[f'{prefix}days_class_metrics_df']
		return data[f'{prefix}days_class_metrics_cdf'][target_class]

	class_prefix = 'b' if target_class is None else target_class
	for metric_name in dmetrics.keys():
		fig, axs = plt.subplots(1, 2, figsize=figsize)
		axis_lims = AxisLims({'x':(None, None), 'y':(0, 1)}, {'x':.0, 'y':.1})
		ps_model_names = utils.get_sorted_model_names(model_names, merged=False)

		# Both names depend only on the metric, so they are computed up front
		# and are always available to the axis-decoration loop below.
		column = f'b-{metric_name}' if target_class is None else f'{metric_name}'
		shown_name = metric_name if dmetrics[metric_name]['mn'] is None else dmetrics[metric_name]['mn']
		new_metric_name = f'{class_prefix}-{shown_name}'

		survey = None # stays None until some model actually has files
		for kax,ax in enumerate(axs):
			if len(ps_model_names[kax])==0:
				continue
			color_dict = utils.get_color_dict(ps_model_names[kax])
			for model_name in ps_model_names[kax]:
				load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
				files, files_ids = ftfiles.gather_files_by_kfold(load_roodir, kf, lcset_name,
					fext='d',
					disbalanced_kf_mode='oversampling', # error oversampling
					random_state=RANDOM_STATE,
					)
				print(f'{files_ids}({len(files_ids)}#); model={model_name}')
				if len(files)==0:
					continue

				head = files[0]()
				survey = head['survey']
				thdays = head['days']

				metric_curves = [_metrics_frame(f, '')[column].values for f in files]
				xe_metric_curve_auc = XError(np.mean(np.concatenate([metric_curve[None] for metric_curve in metric_curves], axis=0), axis=-1)) # (b,t)>(b)

				model_label = utils.get_fmodel_name(model_name)
				color = color_dict[model_label]
				lines.fill_beetween(ax, [thdays for _ in metric_curves], metric_curves,
					fill_kwargs={'color':color, 'alpha':shadow_alpha, 'lw':0,},
					median_kwargs={'color':color, 'alpha':1,},
					percentile=percentile,
					)
				# Empty line that only carries the legend entry.
				ax.plot([None], [None], color=color, label=f'{model_label}; AUC={xe_metric_curve_auc}')
				axis_lims.append('x', thdays)
				axis_lims.append('y', np.concatenate(metric_curves, axis=0))

			if survey is not None:
				suptitle = ''
				suptitle += f'{new_metric_name} v/s days using moving th-day'+'\n'
				suptitle += f'set={survey} [{lcset_name.replace(".@", "")}]'+'\n'
				fig.suptitle(suptitle[:-1], va='bottom')

		# The baseline points are the same for both axes: gather and read them
		# once per metric instead of once per axis and per th-day.
		baseline_points = [] # (thday, p50) pairs
		if baseline_roodir is not None:
			files, files_ids = ftfiles.gather_files_by_kfold(baseline_roodir, kf, lcset_name,
				fext='d',
				disbalanced_kf_mode='oversampling', # error oversampling
				random_state=RANDOM_STATE,
				)
			# Report the baseline, not whichever model was iterated last.
			print(f'{files_ids}({len(files_ids)}#); model={BASELINE_MODEL_NAME}')
			if len(files)>0:
				baseline_thdays = files[0]()['thdays']
				frames = [_metrics_frame(f, 'th') for f in files]
				for thday in baseline_thdays:
					# The row mask comes from the same frame that is indexed;
					# the per-class branch used to take it from the global one.
					xe_metric = XError([frame.loc[frame['_thday']==thday][column].item() for frame in frames])
					baseline_points.append((thday, xe_metric.p50))

		for kax,ax in enumerate(axs):
			if baseline_points:
				model_label = utils.get_fmodel_name(BASELINE_MODEL_NAME)
				for kthday,(thday, p50) in enumerate(baseline_points):
					# Label only the first marker so the legend shows one entry.
					ax.plot(thday, p50, 'D', c='k', label=f'{model_label}' if kthday==0 else None)

				# Horizontal guide at the value of the last th-day.
				ax.axhline(baseline_points[-1][1], linestyle='--', c='k')

			ax.set_xlabel('time [days]')
			if kax==0:
				ax.set_ylabel(new_metric_name)
				ax.set_title(f'{bf_alphabet_count(0)} Parallel models')
			else:
				ax.set_yticklabels([])
				ax.set_title(f'{bf_alphabet_count(1)} Serial models')

			axis_lims.set_ax_axis_lims(ax)
			ax.grid(alpha=0.5)
			ax.legend(loc='lower right')

		fig.tight_layout()
		plt.show()
# Example #13
# 0
def plot_cm(rootdir, cfilename, kf, lcset_name, model_names,
    figsize=FIGSIZE,
    train_mode='fine-tuning',
    export_animation=False,
    animation_duration=12,
    new_order_classes=('SNIa', 'SNIbc', 'SNII-b-n', 'SLSN'),
    percentile=PERCENTILE,
    ):
    """Plot confusion matrices per model, aggregated over the k-fold result files.

    For every model, the result files under
    ``{rootdir}/{model_name}/{train_mode}/performance/{cfilename}`` are gathered
    with ``ftfiles.gather_files_by_kfold``. Models with no files are skipped.
    For each threshold day the per-fold confusion matrices are stacked along a
    new leading axis and passed to ``plot_custom_confusion_matrix``; the title
    shows the b-recall and b-f1score values wrapped in ``XError``
    (presumably a summary across folds -- TODO confirm).

    Args:
        rootdir: Root directory holding one sub-directory per model.
        cfilename: Name of the results sub-directory for the configuration.
        kf: K-fold identifier forwarded to ``gather_files_by_kfold``.
        lcset_name: Light-curve set name forwarded to ``gather_files_by_kfold``.
        model_names: Iterable of model names to plot.
        figsize: Figure size forwarded to ``plot_custom_confusion_matrix``.
        train_mode: Training-mode sub-directory name.
        export_animation: If True, one frame is rendered for every entry of
            ``days``; otherwise only the last day is rendered.
        animation_duration: Duration forwarded to ``PlotAnimator``.
        new_order_classes: Class order forwarded to the confusion-matrix helper
            (``None`` is passed through unchanged).
        percentile: Percentile forwarded to ``plot_custom_confusion_matrix``.
    """
    for model_name in model_names:
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/performance/{cfilename}'
        files, files_ids = ftfiles.gather_files_by_kfold(load_roodir, kf, lcset_name,
            fext='d',
            disbalanced_kf_mode='oversampling', # error oversampling
            random_state=RANDOM_STATE,
            )
        print(f'ids={files_ids}(n={len(files_ids)}#); model={model_name}')
        if len(files) == 0:
            continue

        # Load every fold once per model instead of calling f() repeatedly for
        # each metric and day (assumes f() is idempotent, as the repeated calls
        # elsewhere in this file already do).
        fold_data = [f() for f in files]
        class_names = fold_data[0]['class_names']
        days = fold_data[0]['days']
        metric_dfs = [data['days_class_metrics_df'] for data in fold_data]

        plot_animation = PlotAnimator(animation_duration,
            is_dummy=not export_animation,
            save_end_frame=True,
            )

        thdays = days if export_animation else [days[-1]]
        bar = ProgressBar(len(thdays), bar_format='{l_bar}{bar}{postfix}')
        for kd, thday in enumerate(thdays):
            bar(f'thday={thday:.3f} [days]')
            xe_dict = {
                metric_name: XError([
                    df.loc[df['_day'] == thday][metric_name].item()
                    for df in metric_dfs
                ])
                for metric_name in ['b-recall', 'b-f1score']
            }

            title = '\n'.join([
                f'{utils.get_fmodel_name(model_name)}',
                f'b-recall={xe_dict["b-recall"]}; b-f1score={xe_dict["b-f1score"]}',
                f'th-day={thday:.3f} [days]',
            ])
            # [None] adds a leading fold axis so the matrices stack on axis 0.
            cms = np.concatenate([data['days_cm'][thday][None] for data in fold_data], axis=0)
            fig, _, _ = plot_custom_confusion_matrix(cms, class_names,
                title=title,
                figsize=figsize,
                # Hand over a fresh list so the shared default is never mutated.
                new_order_classes=None if new_order_classes is None else list(new_order_classes),
                percentile=percentile,
                )
            # Keep only the last *rendered* frame open for plt.show(). Comparing
            # against len(days) flagged the single frame of the non-animation
            # mode for closing.
            uses_close_fig = kd < len(thdays) - 1
            plot_animation.append(fig, uses_close_fig)

        bar.done()
        plt.show()
        # save() is called even for a dummy animator; presumably a no-op in
        # that case -- TODO confirm against PlotAnimator.
        plot_animation.save(f'../temp/{model_name}.gif') # gif mp4
def plot_slope_distance_attnstats(
        rootdir,
        cfilename,
        kf,
        lcset_name,
        model_names,
        train_mode='pre-training',
        figsize=FIGSIZE,
        attn_th=0.5,
        len_th=LEN_TH,
        n_bins=50,
        bins_xrange=(None, None),
        bins_yrange=(None, None),
        cmap_name='inferno',  # plasma viridis inferno
        dj=3,
        distance_mode='mean',  # local mean median
):
    """Plot 2-D histograms of local slope versus peak distance for each band.

    For every model whose ``attnstats`` directory exists, a 2 x n_bands figure
    is drawn. The top row holds the histogram of all samples, and the bottom
    row holds the histogram of the samples that satisfy
    ``attn_scores_min_max_k.j >= attn_th``, ``b_len >= len_th`` and whose class
    is in the file's ``class_names``.

    Args:
        rootdir: Root directory holding one sub-directory per model.
        cfilename: Name of the results sub-directory for the configuration.
        kf: K-fold identifier forwarded to ``gather_files_by_kfold``.
        lcset_name: Light-curve set name forwarded to ``gather_files_by_kfold``.
        model_names: Iterable of model names to plot.
        train_mode: Training-mode sub-directory name.
        figsize: Figure size forwarded to ``plt.subplots``.
        attn_th: Minimum ``attn_scores_min_max_k.j`` for the filtered row.
        len_th: Minimum ``b_len`` for the filtered row.
        n_bins: Number of bin *edges* per axis (yields ``n_bins - 1`` bins).
        bins_xrange: ``(min, max)`` of the x edges; a ``None`` entry is taken
            from the data.
        bins_yrange: ``(min, max)`` of the y edges; a ``None`` entry is taken
            from the data.
        cmap_name: Matplotlib colormap name.
        dj: Value embedded in the sample keys that are read.
        distance_mode: Mode embedded in the peak-distance key that is read.

    Raises:
        AssertionError: If the model directory exists but yields no files.
    """
    for model_name in model_names:
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/attnstats/{cfilename}'
        if not ftfiles.path_exists(load_roodir):
            continue
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            lcset_name,
            fext='d',
            disbalanced_kf_mode='ignore',  # error oversampling ignore
            random_state=RANDOM_STATE,
        )
        print(f'{model_name} {files_ids}({len(files_ids)}#)')
        assert len(files) > 0

        first_file = files[0]()
        band_names = first_file['band_names']
        class_names = first_file['class_names']

        target_class_names = class_names
        x_key = f'peak_distance.j~dj={dj}~mode={distance_mode}'
        y_key = f'local_slope_m.j~dj={dj}'
        xlabel = 'peak-distance [days]'
        ylabel = f'local-slope using $\\Delta j={dj}$'

        # squeeze=False keeps axs two-dimensional even with a single band, so
        # axs[kax, kb] is always valid.
        fig, axs = plt.subplots(2, len(band_names), figsize=figsize, squeeze=False)
        for kb, b in enumerate(band_names):
            marginal_points = []
            attn_points = []
            attn_scores_collection = flat_list(
                [f()['attn_scores_collection'][b] for f in files])
            for sample in attn_scores_collection:
                point = [sample[x_key], sample[y_key]]
                marginal_points.append(point)
                passes_filter = (sample['attn_scores_min_max_k.j'] >= attn_th
                                 and sample['b_len'] >= len_th
                                 and sample['c'] in target_class_names)
                if passes_filter:
                    attn_points.append(point)

            # reshape(-1, 2) turns an empty selection into shape (0, 2) so the
            # column slicing below stays valid.
            xy_marginal = np.array(marginal_points).reshape(-1, 2)
            xy_attn = np.array(attn_points).reshape(-1, 2)
            print('xy_marginal', xy_marginal.shape, 'xy_attn', xy_attn.shape)

            # The histogram range comes from the filtered samples; fall back to
            # the full sample when the filter selects nothing.
            range_ref = xy_attn if len(xy_attn) > 0 else xy_marginal
            needs_data_range = any(
                bound is None for bound in (*bins_xrange, *bins_yrange))
            if needs_data_range and len(range_ref) == 0:
                print(f'band={b}: no samples to derive histogram ranges from; skipping')
                continue

            xrange0 = range_ref[:, 0].min() if bins_xrange[0] is None else bins_xrange[0]
            xrange1 = range_ref[:, 0].max() if bins_xrange[1] is None else bins_xrange[1]
            yrange0 = range_ref[:, 1].min() if bins_yrange[0] is None else bins_yrange[0]
            yrange1 = range_ref[:, 1].max() if bins_yrange[1] is None else bins_yrange[1]

            # NOTE: linspace(..., n_bins) produces n_bins edges, i.e.
            # n_bins - 1 bins per axis.
            bin_edges = (np.linspace(xrange0, xrange1, n_bins),
                         np.linspace(yrange0, yrange1, n_bins))

            panels = [
                (xy_marginal, 'joint distribution'),
                (xy_attn,
                 'conditional joint distribution using ' +
                 '$\\bar{s}_{th}=' + str(attn_th) + '$'),
            ]
            for kax, (xy, panel_title) in enumerate(panels):
                ax = axs[kax, kb]
                H, xedges, yedges = np.histogram2d(xy[:, 0], xy[:, 1], bins=bin_edges)
                H = H.T  # Let each row list bins with common y range.
                extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                ax.imshow(
                    H,
                    interpolation='nearest',
                    origin='lower',
                    aspect='auto',
                    cmap=cmap_name,
                    extent=extent,
                )
                # White cross-hair through the origin.
                ax.axvline(0, linewidth=.5, color='w')
                ax.axhline(0, linewidth=.5, color='w')
                ax.set_title(f'{bf_alphabet_count(kb, kax)} {panel_title}; band={b}')

                # Annotate both sides of x=0 along the bottom edge.
                txt_y = yedges[0]
                ax.text(0,
                        txt_y,
                        'pre SNe-peak < ',
                        fontsize=12,
                        c='w',
                        ha='right',
                        va='bottom')
                ax.text(0,
                        txt_y,
                        ' > post SNe-peak',
                        fontsize=12,
                        c='w',
                        ha='left',
                        va='bottom')
                ax_styles.set_color_borders(ax, C_.COLOR_DICT[b])

                # Only the first column carries y labels and only the bottom
                # row carries x labels.
                if kb == 0:
                    ax.set_ylabel(ylabel)
                else:
                    ax.set_yticklabels([])
                if kax == 0:
                    ax.set_xticklabels([])
                else:
                    ax.set_xlabel(xlabel)

        model_label = utils.get_fmodel_name(model_name)
        suptitle = '\n'.join([
            'local-slope v/s peak-distance',
            f'{model_label}',
        ])
        fig.suptitle(suptitle, va='bottom')

        fig.tight_layout()
        plt.show()
def plot_attnentropy(
    rootdir,
    cfilename,
    kf,
    lcset_name,
    model_names,
    train_mode='pre-training',
    figsize=FIGSIZE,
    len_th=LEN_TH,
):
    """Scatter mean normalised attention entropy against SNR for each band.

    For the first model whose ``attnstats`` directory exists, every sample with
    ``b_len >= len_th`` contributes its ``attn_entropy_h`` values divided by
    ``log(b_len)``. The mean of those values is plotted against the sample's
    ``snr``, and the individual values are collected per band and class.

    NOTE(review): the function returns inside the model loop, so only the
    first model with an existing directory is processed. Kept as-is to
    preserve the return value callers receive -- confirm whether later models
    were meant to be plotted as well.

    Args:
        rootdir: Root directory holding one sub-directory per model.
        cfilename: Name of the results sub-directory for the configuration.
        kf: K-fold identifier forwarded to ``gather_files_by_kfold``.
        lcset_name: Light-curve set name forwarded to ``gather_files_by_kfold``.
        model_names: Iterable of model names; see the note above.
        train_mode: Training-mode sub-directory name.
        figsize: Figure size forwarded to ``plt.subplots``.
        len_th: Minimum ``b_len`` a sample needs to be included.

    Returns:
        ``{band: {class_name: [normalised entropy, ...]}}`` for the first
        processed model, or ``None`` if no model directory exists.

    Raises:
        AssertionError: If the model directory exists but yields no files.
    """
    for model_name in model_names:
        load_roodir = f'{rootdir}/{model_name}/{train_mode}/attnstats/{cfilename}'
        if not ftfiles.path_exists(load_roodir):
            continue
        files, files_ids = ftfiles.gather_files_by_kfold(
            load_roodir,
            kf,
            lcset_name,
            fext='d',
            disbalanced_kf_mode='ignore',  # error oversampling ignore
            random_state=RANDOM_STATE,
        )
        print(f'{model_name} {files_ids}({len(files_ids)}#)')
        assert len(files) > 0

        first_file = files[0]()
        band_names = first_file['band_names']
        class_names = first_file['class_names']

        # squeeze=False keeps axs two-dimensional even with a single band.
        fig, axs = plt.subplots(1, len(band_names), figsize=figsize, squeeze=False)
        entropy_d = {b: {c: [] for c in class_names} for b in band_names}
        for kb, b in enumerate(band_names):
            ax = axs[0, kb]
            # Labels are set once per axis rather than once per sample.
            ax.set_xlabel('$H(X)/\\log(L)$')
            ax.set_ylabel('SNR')
            for file in files:
                for d in file()['attn_scores_collection'][b]:
                    c = d['c']
                    b_len = d['b_len']
                    # log(b_len) is zero for b_len == 1 and undefined below
                    # that, so such samples cannot be normalised.
                    if b_len < len_th or b_len <= 1:
                        continue
                    log_len = np.log(b_len)
                    normalized_h = [h / log_len for h in d['attn_entropy_h']]
                    entropy_d[b][c] += normalized_h
                    ax.plot(
                        [np.mean(normalized_h)],
                        [d['snr']],
                        c='k',
                        marker=CLASSES_STYLES[c]['marker'],
                        markersize=2,
                        alpha=.5,
                    )
        fig.tight_layout()
        plt.show()
        return entropy_d