lu.print_blue("Finished rf training, validating and testing") ################ # VALIDATION ################ if settings.validate_rnn: if settings.model_files is None: validate_rnn.get_predictions(settings) # Compute metrics metrics.get_metrics_singlemodel(settings, model_type="rnn") else: for model_file in settings.model_files: # Restore model settings model_settings = conf.get_settings_from_dump( model_file, override_source_data=settings.override_source_data) # Get predictions prediction_file = validate_rnn.get_predictions( model_settings, model_file=model_file) # Compute metrics metrics.get_metrics_singlemodel( model_settings, prediction_file=prediction_file, model_type="rnn") if settings.validate_rf: if settings.model_files is None: validate_randomforest.get_predictions(settings) # Compute metrics
def towards_cosmo(df, df_delta, df_delta_ood, settings, plots, debug):
    """Reproduce the "towards cosmology" analysis: calibration,
    representativeness, out-of-distribution behaviour and cosmology plots.

    Prints summary tables (via ``sm``) and optionally produces figures
    (via ``sp`` / ``early_prediction``). Temporarily overrides
    ``settings.prediction_files`` for plotting and restores it afterwards.

    Args:
        df: metrics dataframe (one row per trained model).
        df_delta: pairwise model-comparison dataframe
            (``model_name_left`` / ``model_name_right`` columns).
        df_delta_ood: out-of-distribution comparison dataframe.
        settings: experiment settings (``models_dir``, ``prediction_files``...).
        plots (bool): if True, generate figures.
        debug (bool): if True, only print selections instead of computing.
    """
    # 1. Calibration
    print(lu.str_to_bluestr("Calibration"))
    df_sel = df.copy()
    df_sel = df_sel.round(4)
    # rf can't be done with photometry
    df_sel = df_sel[
        (
            df_sel["model_name_noseed"].isin(
                [l.replace("photometry", "saltfit") for l in list_models]
            )
        )
        & (df_sel["source_data"] == "saltfit")
    ]
    print("saltfit")
    if not debug:
        sm.nice_df_print(
            df_sel,
            keys=[
                "model_name_noseed",
                "calibration_dispersion_mean",
                "calibration_dispersion_std",
            ],
        )
    # without rf
    print("photometry")
    df_sel = df.copy()
    df_sel = df_sel[
        (df_sel["model_name_noseed"].isin([Base, Var, BBB]))
        & (df_sel["source_data"] == "photometry")
    ]
    if not debug:
        sm.nice_df_print(
            df_sel,
            keys=[
                "model_name_noseed",
                "calibration_dispersion_mean",
                "calibration_dispersion_std",
            ],
        )

    # Calibration vs. training set size
    # using salt
    print(lu.str_to_bluestr("Calibration vs. data set size"))
    print("Baseline")
    sel_criteria = Base_salt.split("DF_1.0")
    if debug:
        print(sel_criteria)
    else:
        sm.get_metric_ranges(
            df, sel_criteria, metric="calibration_dispersion", round_output=5
        )

    # Calibration vs. dataset nature
    sel_criteria = [
        "DES_vanilla_CLF_2_R_None_saltfit_DF_1.0_N_global_lstm_32x2_0.05_128_True_mean_C"
    ]
    if debug:
        print(sel_criteria)
    else:
        sm.get_metric_ranges(
            df, sel_criteria, metric="calibration_dispersion", round_output=5
        )

    # Calibration figure
    if plots:
        print(lu.str_to_yellowstr("Plotting reliability diagram (Figure 6)"))
        tmp_pred_files = settings.prediction_files
        # NOTE: use replace("DES_", ""), not strip("DES_") — str.strip removes
        # any of the characters D, E, S, _ from *both* ends of the string,
        # which only worked by accident for these model names.
        settings.prediction_files = [
            settings.models_dir
            + "/"
            + model.replace("DES_", "")
            .replace("CLF_2", "S_0_CLF_2")
            .replace("photometry", "saltfit")
            + "/"
            + model.replace("DES_", "PRED_DES_")
            .replace("CLF_2", "S_0_CLF_2")
            .replace("photometry", "saltfit")
            + ".pickle"
            for model in [RF, Base, Var, BBB]
        ]
        if debug:
            print(settings.prediction_files)
        else:
            sp.plot_calibration(settings)
        # restore so later sections see the caller's prediction files
        settings.prediction_files = tmp_pred_files

    # 2. Representativeness
    print(lu.str_to_bluestr("Representativeness"))
    for model in [Base, Var, BBB]:
        m_left = model.replace("photometry", "saltfit")
        m_right = model.replace("DF_1.0", "DF_0.43")
        print(m_left, m_right)
        df_sel = df_delta[
            (df_delta["model_name_left"] == m_left)
            & (df_delta["model_name_right"] == m_right)
        ]
        if not debug:
            sm.nice_df_print(
                df_sel,
                keys=[
                    "all_accuracy_mean_delta",
                    "mean_all_class0_std_dev_mean_delta",
                    "all_entropy_mean_delta",
                ],
            )

    # 3. OOD
    print(lu.str_to_bluestr("Out-of-distribution light-curves"))
    # OOD type assignment figure
    if plots:
        print(lu.str_to_yellowstr("Plotting OOD classification percentages (Figure 8)"))
        if not debug:
            sp.create_OOD_classification_plots(df, list_models_rnn, settings)
    # Get entropy for binary / ternary / seven-way classification
    print("binary")
    df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_rnn)]
    if not debug:
        sm.nice_df_print(df_sel)
    print("ternary")
    list_models_sel = [l.replace("CLF_2", "CLF_3") for l in list_models_rnn]
    df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_sel)]
    if not debug:
        sm.nice_df_print(df_sel)
    print("seven-way")
    list_models_sel = [l.replace("CLF_2", "CLF_7") for l in list_models_rnn]
    df_sel = df_delta_ood[df_delta_ood["model_name_noseed"].isin(list_models_sel)]
    if not debug:
        sm.nice_df_print(df_sel)

    if plots:
        print(
            lu.str_to_yellowstr(
                "Plotting OOD candidates with seven-way classification (Figure 9)"
            )
        )
        model_file = (
            f"{settings.models_dir}/"
            + Var.replace("DES_variational_", "variational_S_0_").replace(
                "CLF_2", "CLF_7"
            )
            + "/"
            + Var.replace("DES_variational_", "variational_S_0_").replace(
                "CLF_2", "CLF_7"
            )
            + ".pt"
        )
        if os.path.exists(model_file):
            if debug:
                print(model_file)
            else:
                model_settings = conf.get_settings_from_dump(model_file)
                early_prediction.make_early_prediction(model_settings, nb_lcs=20)
        else:
            print(lu.str_to_redstr(f"File not found {model_file}"))

    # 4. Cosmology
    print(lu.str_to_bluestr("SNe Ia for cosmology"))
    # Plotting Hubble residuals and other science plots for just one seed
    if plots:
        print(lu.str_to_yellowstr("Plotting Hubble residuals (Figures 10 and 11)"))
        tmp_pred_files = settings.prediction_files
        # replace("DES_", "") instead of strip("DES_"): see calibration note above
        settings.prediction_files = [
            settings.models_dir
            + "/"
            + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_2")
            + "/"
            + model.replace("DES_", "PRED_DES_").replace("CLF_2", "S_0_CLF_2")
            + ".pickle"
            for model in [Base, Var, BBB]
        ]
        if debug:
            print(settings.prediction_files)
        else:
            sp.science_plots(settings)
        settings.prediction_files = tmp_pred_files
def baseline(df, settings, plots, debug):
    """Reproduce the Baseline RNN analysis: example light-curve figure,
    hyper-parameter / normalization metric ranges, comparison with other
    methods, and redshift/contamination + multiclass confusion matrices.

    Args:
        df: metrics dataframe (one row per trained model).
        settings: experiment settings (``models_dir``, ``prediction_files``...).
        plots (bool): if True, generate figures.
        debug (bool): if True, only print selections instead of computing.
    """
    # 0. Figure example
    if plots:
        print(
            lu.str_to_yellowstr("Plotting candidates for Baseline binary (Figure 2.)")
        )
        model_file = f"{settings.models_dir}/{Base.replace('DES_vanilla_','vanilla_S_0_')}/{Base.replace('DES_vanilla_','vanilla_S_0_')}.pt"
        if os.path.exists(model_file):
            if debug:
                print(model_file)
            else:
                model_settings = conf.get_settings_from_dump(model_file)
                early_prediction.make_early_prediction(
                    model_settings, nb_lcs=20, do_gifs=True
                )
        else:
            print(lu.str_to_redstr(f"File not found {model_file}"))

    # 1. Hyper-parameters
    # saltfit, DF 0.2
    sel_criteria = ["DES_vanilla_CLF_2_R_None_saltfit_DF_0.2"]
    print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}"))
    if not debug:
        sm.get_metric_ranges(df, sel_criteria)

    # 2. Normalization
    # saltfit, DF 0.5
    sel_criteria = Base_salt.replace("DF_1.0", "DF_0.5").split("global")
    print(lu.str_to_bluestr(f"Normalization {sel_criteria}"))
    if not debug:
        df_sel = sm.get_metric_ranges(df, sel_criteria)

    # 3. Comparing with other methods
    # plain string: the original f-string had no placeholders
    print(lu.str_to_bluestr("Other methods:"))
    # Figure: accuracy vs. number of SNe
    if plots:
        print(lu.str_to_yellowstr("Plotting accuracy vs. SNe (Figure 3.)"))
        if not debug:
            sp.performance_plots(settings)
    # baseline best, saltfit
    if debug:
        print(Base_salt)
        print(Base_salt.replace("DF_1.0", "DF_0.05"))
        print(Base_salt.replace("DF_1.0", "DF_0.05").replace("None", "zpho"))
    else:
        sm.acc_auc_df(df, [Base_salt], data="saltfit")
        # RF and baseline comparison with Charnock Moss
        sm.acc_auc_df(
            df,
            [
                RF,
                Base_salt.replace("DF_1.0", "DF_0.05"),
                Base_salt.replace("DF_1.0", "DF_0.05").replace("None", "zpho"),
            ],
        )

    # 4. Redshift, contamination
    # baseline saltfit, 1.0, all redshifts
    sel_criteria = Base_salt.split("None")
    print("salt")
    if debug:
        print(sel_criteria, Base.split("None"))
    if not debug:
        sm.print_contamination(df, sel_criteria, settings, data="saltfit")
    print("photometry")
    if not debug:
        sm.print_contamination(df, Base.split("None"), settings, data="photometry")

    # Multiclass
    # Plotting Confusion Matrix for just one seed
    if plots:
        print(
            lu.str_to_yellowstr(
                "Plotting confusion matrix for multiclass classification (Figure 4.)"
            )
        )
        for target in [3, 7]:
            # NOTE: use replace("DES_", ""), not strip("DES_") — str.strip
            # removes any of the chars D, E, S, _ from both string ends,
            # which only worked by accident for these model names.
            settings.prediction_files = [
                settings.models_dir
                + "/"
                + model.replace("DES_", "").replace("CLF_2", f"S_0_CLF_{target}")
                + "/"
                + model.replace("DES_", "PRED_DES_").replace(
                    "CLF_2", f"S_0_CLF_{target}"
                )
                + ".pickle"
                for model in [Base, Var, BBB]
            ]
            if debug:
                print(settings.prediction_files)
            else:
                sp.science_plots(settings, onlycnf=True)
def bayesian(df, df_delta, df_delta_ood, settings, plots, debug):
    """Reproduce the Bayesian RNN analysis (BBB and Variational):
    hyper-parameter ranges, best-model accuracies, contamination,
    epistemic-uncertainty comparisons and candidate figures/GIFs.

    Args:
        df: metrics dataframe (one row per trained model).
        df_delta: pairwise model-comparison dataframe
            (``model_name_left`` / ``model_name_right`` columns).
        df_delta_ood: out-of-distribution comparison dataframe (unused here;
            kept for signature parity with ``towards_cosmo``).
        settings: experiment settings (``models_dir``...).
        plots (bool): if True, generate figures and GIFs.
        debug (bool): if True, only print selections instead of computing.
    """
    # 2. Variational hyper-parameters
    sel_criteria = ["DES_variational_CLF_2_R_None_saltfit_DF_0.2_N_global_lstm_32x2_"]
    print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}"))
    if not debug:
        sm.get_metric_ranges(df, sel_criteria)

    # 3. BBB hyper-parameters
    sel_criteria = ["DES_bayesian_CLF_2_R_None_saltfit_DF_0.2_N_global_lstm_32x2_"]
    print(lu.str_to_bluestr(f"Hyperparameters {sel_criteria}"))
    if not debug:
        sm.get_metric_ranges(df, sel_criteria)

    # 2 and 3. Best models
    print(lu.str_to_bluestr("Best performing Bayesian accuracies"))
    if debug:
        print(
            [
                Var,
                BBB,
                Var.replace("None", "zpho"),
                BBB.replace("None", "zpho"),
                Var.replace("None", "zspe"),
                BBB.replace("None", "zspe"),
            ]
        )
    else:
        sm.acc_auc_df(
            df,
            [
                Var,
                BBB,
                Var.replace("None", "zpho"),
                BBB.replace("None", "zpho"),
                Var.replace("None", "zspe"),
                BBB.replace("None", "zspe"),
            ],
        )

    # contamination
    # baseline saltfit, 1.0, all redshifts
    for model in [Var, BBB]:
        sel_criteria = model.split("None")
        if debug:
            print("contamination")
            print(sel_criteria)
        else:
            sm.print_contamination(df, sel_criteria, settings, data="photometry")

    # 4. Uncertainties
    print(lu.str_to_bluestr("Best performing Bayesian uncertainties"))
    print("Epistemic behaviour")
    for model in [Var, BBB]:
        # saltfit: full vs. half training set
        m_right = model.replace("photometry", "saltfit")
        m_left = model.replace("photometry", "saltfit").replace("DF_1.0", "DF_0.5")
        print("salt", m_left, m_right)
        df_sel = df_delta[
            (df_delta["model_name_left"] == m_left)
            & (df_delta["model_name_right"] == m_right)
        ]
        if not debug:
            sm.nice_df_print(
                df_sel,
                keys=[
                    "all_accuracy_mean_delta",
                    "mean_all_class0_std_dev_mean_delta",
                    "all_entropy_mean_delta",
                ],
            )
        # complete dataset: full vs. non-representative (DF 0.43) training set
        m_right = model
        m_left = model.replace("DF_1.0", "DF_0.43")
        print("complete", m_left, m_right)
        df_sel = df_delta[
            (df_delta["model_name_left"] == m_left)
            & (df_delta["model_name_right"] == m_right)
        ]
        if not debug:
            sm.nice_df_print(
                df_sel,
                keys=[
                    "all_accuracy_mean_delta",
                    "mean_all_class0_std_dev_mean_delta",
                    "all_entropy_mean_delta",
                ],
            )

    print("Uncertainty size")
    df_sel = df[df["model_name_noseed"].isin([Var, BBB])]
    df_sel = df_sel.round(4)
    if not debug:
        sm.nice_df_print(
            df_sel, keys=["mean_all_class0_std_dev_mean", "mean_all_class0_std_dev_std"]
        )

    if plots:
        print(
            lu.str_to_yellowstr(
                "Plotting candidates for multiclass classification (Fig. 5)"
            )
        )
        for model in [Var, BBB]:
            model_file = (
                f"{settings.models_dir}/"
                + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_7")
                + "/"
                + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_7")
                + ".pt"
            )
            if os.path.exists(model_file):
                if debug:
                    print(model_file)
                else:
                    model_settings = conf.get_settings_from_dump(model_file)
                    early_prediction.make_early_prediction(
                        model_settings, nb_lcs=20, do_gifs=True
                    )
            else:
                print(lu.str_to_redstr(f"File not found {model_file}"))

        print(lu.str_to_yellowstr("Adding gifs for binary classification"))
        for model in [Var, BBB]:
            # NOTE: replace("DES_", ""), matching the seven-way loop above;
            # the original's strip("DES_") strips any of D, E, S, _ from both
            # ends of the string and only worked by accident.
            model_file = (
                f"{settings.models_dir}/"
                + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_2")
                + "/"
                + model.replace("DES_", "").replace("CLF_2", "S_0_CLF_2")
                + ".pt"
            )
            if os.path.exists(model_file):
                if debug:
                    print(model_file)
                else:
                    model_settings = conf.get_settings_from_dump(model_file)
                    early_prediction.make_early_prediction(
                        model_settings, nb_lcs=10, do_gifs=True
                    )