def get_classes(rootdir):
    """Collect the distinct ``'c'`` values stored in the pickles under *rootdir*.

    Each path returned by ``fcfiles.get_filedirs(rootdir)`` is loaded with
    ``fcfiles.load_pickle`` and its ``'c'`` entry is recorded the first time it
    is encountered, so the returned list keeps first-appearance order.
    """
    found = []
    for path in fcfiles.get_filedirs(rootdir):
        label = fcfiles.load_pickle(path)['c']
        if label in found:
            continue
        found.append(label)
    return found
示例#2
0
def plot_mse(rootdir, model_names,
    figsize=C_.PLOT_FIGZISE_RECT,
    fext='metrics',
    set_name='???',
    mode='fine-tuning',
    ):
    """Plot the log reconstruction-MSE curve of every model against days.

    For each model name, the result files found under
    ``{rootdir}/{mode}/{model_name}`` are loaded, their ``'mse'`` columns are
    stacked (one row per file) and the median of the log-values is drawn
    together with a p15-p85 band.

    Args:
        rootdir: base directory that contains the saved results.
        model_names: model-name strings; ``strings.get_dict_from_string`` must
            yield the ``rsc`` and ``mdl`` fields for each of them.
        figsize: figure size handed to ``plt.subplots``.
        fext: file extension passed to ``search_for_filedirs``.
        set_name: not used by this function; kept for interface compatibility.
        mode: training-mode directory to read from. The body used ``mode``
            without ever defining it, so it is now an explicit trailing
            keyword with the same default as ``plot_metric`` and ``plot_cm``.
            NOTE(review): if callers relied on a module-level ``mode`` global,
            they must now pass it explicitly.

    Raises:
        ValueError: from ``np.concatenate`` when a model has no result files.
        NameError: when ``model_names`` is empty, because the title and axis
            limits reuse values read from the last loaded file.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    color_dict = utils.get_color_dict(model_names)
    for kmn, model_name in enumerate(model_names):
        new_rootdir = f'{rootdir}/{mode}/{model_name}'
        new_rootdir = new_rootdir.replace('mode=pre-training', f'mode={mode}') # patch
        new_rootdir = new_rootdir.replace('mode=fine-tuning', f'mode={mode}') # patch
        filedirs = search_for_filedirs(new_rootdir, fext=fext, verbose=0)
        print(f'[{kmn}] {model_name} (iters: {len(filedirs)})')
        mn_dict = strings.get_dict_from_string(model_name)
        rsc = mn_dict['rsc']
        mdl = mn_dict['mdl']
        is_parallel = 'Parallel' in mdl

        metric_curve = []
        for filedir in filedirs:
            rdict = load_pickle(filedir, verbose=0)
            # days/survey/band_names deliberately outlive the loops: the values
            # of the last loaded file label the x axis and the title.
            days = rdict['days']
            survey = rdict['survey']
            band_names = ''.join(rdict['band_names'])
            # One row per result file so the curves can be stacked below.
            metric_curve.append(rdict['days_rec_metrics_df']['mse'].values[None, :])

        metric_curve = np.concatenate(metric_curve, axis=0)
        # presumably XError summarises the stacked runs along axis 0 — TODO confirm
        xe_metric_curve = XError(np.log(metric_curve), 0)
        label = f'{mdl} {rsc}'
        color = color_dict[utils.get_cmodel_name(model_name)]
        # Dashed lines for 'Parallel' models, solid lines for the rest.
        ax.plot(days, xe_metric_curve.median, '--' if is_parallel else '-', label=label, c=color)
        ax.fill_between(days, xe_metric_curve.p15, xe_metric_curve.p85, alpha=0.25, fc=color)

    title = 'log-reconstruction-wmse v/s days\n'
    title += f'survey: {survey} - bands: {band_names}\n'
    ax.set_title(title[:-1])  # drop the trailing newline
    ax.set_xlabel('days')
    ax.set_ylabel('mse')
    ax.set_xlim([days.min(), days.max()])
    ax.grid(alpha=0.5)
    ax.legend(loc='upper right')
    plt.show()
示例#3
0
    import matplotlib.pyplot as plt
    from lchandler.plots.lc import plot_lightcurve
    from fuzzytools.files import save_time_stamp

    methods = [
        'linear-fstw', 'bspline-fstw', 'spm-mle-fstw', 'spm-mle-estw',
        'spm-mcmc-fstw', 'spm-mcmc-estw'
    ] if main_args.method == '.' else main_args.method
    methods = [methods] if isinstance(methods, str) else methods

    for method in methods:
        filedir = f'../../surveys-save/survey=alerceZTFv7.1~bands=gr~mode=onlySNe~method={method}.splcds'
        filedict = get_dict_from_filedir(filedir)
        rootdir = filedict['_rootdir']
        cfilename = filedict['_cfilename']
        lcdataset = load_pickle(filedir)
        lcset_info = lcdataset['raw'].get_info()
        print(lcdataset)

        lcset_names = lcdataset.get_lcset_names()
        for lcset_name in lcset_names:
            lcset = lcdataset[lcset_name]
            for lcobj_name in lcset.get_lcobj_names():
                print(
                    f'method={method} - lcset_name={lcset_name} - lcobj_name={lcobj_name}'
                )
                figsize = (12, 5)
                fig, ax = plt.subplots(1, 1, figsize=figsize)
                lcobj = lcset[lcobj_name]
                c = lcset.class_names[lcobj.y]
                for kb, b in enumerate(lcset.band_names):
示例#4
0
def plot_metric(rootdir, metric_name, model_names, baselines_dict,
    label_keys=None,
    figsize=C_.PLOT_FIGZISE_RECT,
    fext='metrics',
    mode='fine-tuning',
    set_name='???',
    p=C_.P_PLOT,
    alpha=0.2,
    ):
    """Plot one metric against days, Parallel models left and the rest right.

    For each model name, the result files found under
    ``{rootdir}/{mode}/{model_name}`` are loaded, the metric curve of every
    file is obtained through ``utils.get_metric_along_day`` and the median
    curve is drawn with a ``p``/``100-p`` percentile band.

    Args:
        rootdir: base directory that contains the saved results.
        metric_name: metric to plot; names containing ``'accuracy'`` use a
            0-100 scale and get a random-guess reference line.
        model_names: model-name strings; ``strings.get_dict_from_string`` must
            yield the ``rsc`` and ``mdl`` fields for each of them.
        baselines_dict: mapping ``metric_name -> baseline value`` drawn as a
            horizontal reference, or ``None`` to skip it.
        label_keys: model-name fields appended to the legend label when
            present. ``None`` (the default) means no extra fields.
        figsize: figure size handed to ``plt.subplots``.
        fext: file extension passed to ``search_for_filedirs``.
        mode: training-mode directory to read from.
        set_name: evaluation-set name, only used in the figure title.
        p: lower percentile of the shaded band (upper one is ``100 - p``).
        alpha: opacity of the shaded band.

    Raises:
        ValueError: from ``np.concatenate`` when a model has no result files.
        NameError: when ``model_names`` is empty, because the title and the
            reference lines reuse values read from the last loaded file.
    """
    # A None sentinel replaces the former mutable ``[]`` default.
    label_keys = [] if label_keys is None else label_keys
    fig, axs = plt.subplots(1, 2, figsize=figsize)
    color_dict = utils.get_color_dict(model_names)

    for kmn, model_name in enumerate(model_names):
        new_rootdir = f'{rootdir}/{mode}/{model_name}'
        new_rootdir = new_rootdir.replace('mode=pre-training', f'mode={mode}') # patch
        new_rootdir = new_rootdir.replace('mode=fine-tuning', f'mode={mode}') # patch
        filedirs = search_for_filedirs(new_rootdir, fext=fext, verbose=0)
        model_ids = sorted(int(strings.get_dict_from_string(f.split('/')[-1])['id']) for f in filedirs)
        print(f'[{kmn}][{"-".join([str(m) for m in model_ids])}]{len(model_ids)}#')
        print(f'\t{model_name}')
        mn_dict = strings.get_dict_from_string(model_name)
        rsc = mn_dict['rsc']
        mdl = mn_dict['mdl']
        is_parallel = 'Parallel' in mdl
        # Parallel models go to axs[0], everything else to axs[1].
        ax = axs[int(not is_parallel)]

        metric_curve = []
        for filedir in filedirs:
            rdict = load_pickle(filedir, verbose=0)
            # These names deliberately outlive the loops: the values of the
            # last loaded file feed the title and the reference lines below.
            days = rdict['days']
            survey = rdict['survey']
            band_names = ''.join(rdict['band_names'])
            class_names = rdict['class_names']
            _, vs, interp_days = utils.get_metric_along_day(days, rdict, metric_name, days[-1])
            # One row per result file so the curves can be stacked below.
            metric_curve.append(vs[None, :])

        metric_curve = np.concatenate(metric_curve, axis=0)
        # presumably XError summarises the stacked runs along axis 0 — TODO confirm
        xe_metric_curve = XError(metric_curve, 0)
        xe_curve_avg = XError(np.mean(metric_curve, axis=-1), 0)
        label = f'{mdl}'
        for label_key in label_keys:
            if label_key in mn_dict:
                label += f' - {label_key}={mn_dict[label_key]}'
        label += f' ({xe_curve_avg}*)'
        # Models whose 'rsc' field is not '0' are always drawn in black.
        color = color_dict[utils.get_cmodel_name(model_name)] if rsc == '0' else 'k'
        ax.plot(interp_days, xe_metric_curve.median, '--' if is_parallel else '-', label=label, c=color)
        ax.fill_between(interp_days, getattr(xe_metric_curve, f'p{p}'), getattr(xe_metric_curve, f'p{100-p}'), alpha=alpha, fc=color)

    title = f'{metric_name} v/s days\n'
    title += f'survey={survey} - mode={mode} - eval={set_name} - bands={band_names}\n'
    fig.suptitle(title[:-1], va='bottom')  # drop the trailing newline

    # Identical for both axes, so computed once.
    is_accuracy = 'accuracy' in metric_name
    random_guess = 100. / len(class_names)
    for kax, ax in enumerate(axs):
        # dtype=float: np.full_like would otherwise inherit the dtype of
        # ``days`` and truncate fractional levels when ``days`` is integer.
        if is_accuracy:
            ax.plot(days, np.full_like(days, random_guess, dtype=float), ':', c='k', label='random guess accuracy ($100/N_c$)', alpha=.5)

        if baselines_dict is not None:
            ax.plot(days, np.full_like(days, baselines_dict[metric_name], dtype=float), ':', c='k', label='FATS+b-RF baseline (complete light curves)')

        ax.set_xlabel('days')
        if kax == 1:
            # The right panel shares the y scale, so its tick labels are hidden.
            ax.set_ylabel(None)
            ax.set_yticklabels([])
            ax.set_title('Serial Models')
        else:
            ax.set_ylabel(metric_name)
            ax.set_title('Parallel Models')

        ax.set_xlim([days.min(), days.max()])
        ax.set_ylim([random_guess*.95, 100] if is_accuracy else [0, 1])
        ax.grid(alpha=0.5)
        ax.legend(loc='lower right')

    fig.tight_layout()
    plt.show()
示例#5
0
def plot_cm(rootdir, model_names, day_to_metric,
    figsize=C_.PLOT_FIGZISE_RECT,
    fext='metrics',
    mode='fine-tuning',
    lcset_name='???',
    export_animation=False,
    fps=15,
    ):
    """Draw, per model, one confusion matrix for every day up to a limit.

    For each model name, the result files found under
    ``{rootdir}/{mode}/{model_name}`` are loaded once. For every day
    ``d <= day_to_metric`` listed in the first file, the ``'days_cm'`` matrices
    of all files are stacked and plotted with the ``b-accuracy`` and
    ``b-f1score`` values in the title. Each figure is added to a
    ``PlotAnimation`` saved to ``../temp/{model_name}.gif``; only the figure
    of the last day is shown, the others are closed.

    Args:
        rootdir: base directory that contains the saved results.
        model_names: model-name strings; ``strings.get_dict_from_string`` must
            yield the ``mdl`` field for each of them.
        day_to_metric: last day (inclusive) for which a matrix is drawn.
        figsize: not used by this function; every matrix is drawn at (6, 5).
        fext: file extension passed to ``search_for_filedirs``.
        mode: training-mode directory to read from.
        lcset_name: evaluation-set name, only used in the figure title.
        export_animation: when false the ``PlotAnimation`` is created with
            ``dummy=True``.
        fps: not used by this function.
            NOTE(review): ``PlotAnimation`` receives a hard-coded ``10`` as its
            second argument; confirm whether ``fps`` was meant to go there.

    Raises:
        IndexError: when a model has no result files.
    """
    for kmn, model_name in enumerate(model_names):
        new_rootdir = f'{rootdir}/{mode}/{model_name}'
        new_rootdir = new_rootdir.replace('mode=pre-training', f'mode={mode}') # patch
        new_rootdir = new_rootdir.replace('mode=fine-tuning', f'mode={mode}') # patch
        filedirs = search_for_filedirs(new_rootdir, fext=fext, verbose=0)
        model_ids = sorted(int(strings.get_dict_from_string(f.split('/')[-1])['id']) for f in filedirs)
        print(f'[{kmn}][{"-".join([str(m) for m in model_ids])}]{len(model_ids)}#')
        print(f'\t{model_name}')
        mdl = strings.get_dict_from_string(model_name)['mdl']

        # Each result file is read once per model; the loop below used to
        # reload every file again for every target day.
        rdicts = [load_pickle(filedir, verbose=0) for filedir in filedirs]
        target_days = [d for d in rdicts[0]['days'] if d <= day_to_metric]
        # The class names of the last file label the matrix, as before.
        class_names = rdicts[-1]['class_names']

        plot_animation = PlotAnimation(len(target_days), 10, dummy=not export_animation)
        for kd, target_day in enumerate(target_days):
            cms = []
            accuracy = []
            f1score = []
            for rdict in rdicts:
                days = rdict['days']
                # Leading axis added so the per-file matrices can be stacked.
                cms.append(rdict['days_cm'][target_day][None, ...])
                v, _, _ = utils.get_metric_along_day(days, rdict, 'b-accuracy', target_day)
                accuracy.append(v)
                v, _, _ = utils.get_metric_along_day(days, rdict, 'b-f1score', target_day)
                f1score.append(v)

            accuracy_xe = XError(accuracy)
            f1score_xe = XError(f1score)
            title = ''
            title += f'{mdl}\n'
            title += f'eval={lcset_name} - day={target_day:.2f}/{day_to_metric:.2f}\n'
            title += f'b-f1score={f1score_xe}\n'
            title += f'b-accuracy={accuracy_xe}\n'
            cm_kwargs = {
                'title': title[:-1],  # drop the trailing newline
                'figsize': (6, 5),
                'new_order_classes': ['SNIa', 'SNIbc', 'allSNII', 'SLSN'],
            }
            fig, ax = plot_custom_confusion_matrix(np.concatenate(cms, axis=0), class_names, **cm_kwargs)
            plot_animation.add_frame(fig)
            # Only the final frame stays on screen; earlier figures are
            # closed so they do not pile up.
            if kd < len(target_days) - 1:
                plt.close(fig)
            else:
                plt.show()

        plot_animation.save(f'../temp/{model_name}.gif')
def get_band_names(rootdir):
    """Return the ``'band_names'`` entry of the first pickle under *rootdir*.

    "First" is the first path returned by ``fcfiles.get_filedirs(rootdir)``;
    an empty result raises ``IndexError``.
    """
    first_path = fcfiles.get_filedirs(rootdir)[0]
    contents = fcfiles.load_pickle(first_path)
    return contents['band_names']
# Register the string-valued command-line options on `parser` (created above
# this excerpt) and parse them into `main_args`.
# --method and --kf default to '.'; --setn defaults to 'train'.
# NOTE(review): '.' looks like an "all values" wildcard — confirm where
# main_args is consumed.
parser.add_argument('--method', type=str, default='.')
parser.add_argument('--kf', type=str, default='.')
parser.add_argument('--setn', type=str, default='train')
main_args = parser.parse_args()
print_big_bar()

###################################################################################################################################################
import numpy as np
from fuzzytools.files import load_pickle, save_pickle, get_dict_from_filedir

# Path of the pickled dataset, relative to the working directory.
filedir = '../../surveys-save/survey=alerceZTFv7.1~bands=gr~mode=onlySNe.splcds'
# presumably get_dict_from_filedir parses the path into its fields — TODO confirm
filedict = get_dict_from_filedir(filedir)
# Pull out the three fields used by the rest of the script.
rootdir, cfilename, survey = (
    filedict[key] for key in ('_rootdir', '_cfilename', 'survey')
)
lcdataset = load_pickle(filedir)
print(lcdataset)

###################################################################################################################################################
from synthsne.synthetic_datasets import generate_synthetic_dataset
import pandas as pd
import numpy as np
from synthsne import _C
import fuzzytools.files as ff
from fuzzytools.progress_bars import ProgressBar
from fuzzytools.files import load_pickle, save_pickle
from synthsne.distr_fittings import ObsErrorConditionalSampler
from synthsne.plots.samplers import plot_obse_samplers
from synthsne.plots.mcmc import plot_mcmc_prior
from fuzzytools.dicts import along_dict_obj_method
from nested_dict import nested_dict
def get_df_table(
    rootdir,
    metric_names,
    model_names,
    day_to_metric,
    format_f,
    fext='metrics',
    mode='fine-tuning',
    arch_modes=('Parallel', 'Serial'),
):
    """Build a metrics table with one column per formatted model/architecture.

    For each model name, the result files found under
    ``{rootdir}/{mode}/{model_name}`` are loaded once. For every metric, two
    ``dstats.XError`` summaries over those files are appended to the column
    ``'{format_f(model_name)} [{arch_mode}]'``: the metric value returned for
    ``day_to_metric`` and the mean of the returned metric curve.

    Args:
        rootdir: base directory that contains the saved results.
        metric_names: metrics to tabulate; each contributes two rows.
        model_names: model-name strings; ``strings.get_dict_from_string`` must
            yield the ``mdl`` field for each of them.
        day_to_metric: day handed to ``utils.get_metric_along_day``.
        format_f: callable that maps a model name to its column label.
        fext: file extension passed to ``search_for_filedirs``.
        mode: training-mode directory to read from.
        arch_modes: architectures to include; models whose architecture
            ('Parallel' if that word is in ``mdl``, else 'Serial') is not
            listed are skipped.

    Returns:
        pandas.DataFrame indexed by the ``metric=...`` row labels.

    Note:
        A column is created for every (arch_mode, model) combination, so each
        column must end up with the same number of entries for the DataFrame
        construction to succeed.
    """
    index_df = []
    # Pre-create every column so the column order is arch_mode-major.
    info_df = {
        f'{format_f(model_name)} [{arch_mode}]': []
        for arch_mode in arch_modes
        for model_name in model_names
    }

    for kmn, model_name in enumerate(model_names):
        new_rootdir = f'{rootdir}/{mode}/{model_name}'
        new_rootdir = new_rootdir.replace('mode=pre-training',
                                          f'mode={mode}')  # patch
        new_rootdir = new_rootdir.replace('mode=fine-tuning',
                                          f'mode={mode}')  # patch
        filedirs = search_for_filedirs(new_rootdir, fext=fext, verbose=0)
        print(f'[{kmn}][{len(filedirs)}#] {model_name}')
        mdl = strings.get_dict_from_string(model_name)['mdl']
        arch_mode = 'Parallel' if 'Parallel' in mdl else 'Serial'
        if arch_mode not in arch_modes:
            continue

        # Each result file is read once per model; it used to be reloaded
        # for every metric.
        rdicts = [load_pickle(filedir, verbose=0) for filedir in filedirs]
        column = f'{format_f(model_name)} [{arch_mode}]'

        for metric_name in metric_names:
            day_metric = []
            day_metric_avg = []
            for rdict in rdicts:
                v, vs, _ = utils.get_metric_along_day(
                    rdict['days'], rdict, metric_name, day_to_metric)
                day_metric.append(v)
                day_metric_avg.append(vs.mean())

            # Two entries per metric: the value for the day, then the curve mean.
            info_df[column] += [
                dstats.XError(day_metric, 0),
                dstats.XError(day_metric_avg, 0),
            ]

            # Row labels are registered once, by the first model that
            # reaches this metric.
            row = f'metric={utils.get_mday_str(metric_name, day_to_metric)}'
            if row not in index_df:
                index_df += [
                    row,
                    f'metric={utils.get_mday_avg_str(metric_name, day_to_metric)}',
                ]

    info_df = pd.DataFrame.from_dict(info_df)
    info_df.index = index_df
    return info_df
示例#9
0
# Register the string-valued command-line options on `parser` (created above
# this excerpt) and parse them into `main_args`.
# --method and --kf default to '.'; --setn defaults to 'train'.
# NOTE(review): '.' looks like an "all values" wildcard — confirm where
# main_args is consumed.
parser.add_argument('--method', type=str, default='.')
parser.add_argument('--kf', type=str, default='.')
parser.add_argument('--setn', type=str, default='train')
main_args = parser.parse_args()
print_big_bar()

###################################################################################################################################################
import numpy as np
from fuzzytools.files import load_pickle, save_pickle, get_dict_from_filedir

# Path of the pickled dataset, relative to the working directory.
filedir = '../../surveys-save/survey=alerceZTFv7.1~bands=gr~mode=onlySNe.splcds'
# presumably get_dict_from_filedir parses the path into its fields — TODO confirm
filedict = get_dict_from_filedir(filedir)
# Pull out the three fields used by the rest of the script.
rootdir, cfilename, survey = (
    filedict[key] for key in ('_rootdir', '_cfilename', 'survey')
)
lcdataset = load_pickle(filedir)
print(lcdataset)

###################################################################################################################################################
import numpy as np
import fuzzytools.files as fcfiles
from fuzzytools.progress_bars import ProgressBar
from fuzzytools.files import load_pickle, save_pickle
from synthsne import _C

# '.' selects every fold of the dataset; any other value is taken as given.
if main_args.kf == '.':
    kfs = lcdataset.kfolds
else:
    kfs = main_args.kf
# A single name is wrapped so the result is always a list.
if isinstance(kfs, str):
    kfs = [kfs]

# '.' selects both the 'train' and 'val' sets; any other value is taken as given.
if main_args.setn == '.':
    setns = ['train', 'val']
else:
    setns = main_args.setn
if isinstance(setns, str):
    setns = [setns]