def _update_samples_with_object_indices(
        database_class: DataBase, object_indices: list,
        is_queryable: bool, epoch: int) -> DataBase:
    """
    Call `update_samples` on a database for a list of object indices.

    Parameters
    ----------
    database_class
        An instance of DataBase class; its `update_samples` method
        is invoked.
    object_indices
        List of indexes identifying objects to be moved.
    is_queryable
        Forwarded to `update_samples` as the `queryable` keyword.
    epoch
        Forwarded to `update_samples` as the `epoch` keyword
        (day since beginning of survey).

    Returns
    -------
    DataBase
        The very same object received as `database_class`.
    """
    # Both values travel as keywords, exactly as update_samples expects.
    forwarded_options = {'queryable': is_queryable, 'epoch': epoch}
    database_class.update_samples(object_indices, **forwarded_options)
    return database_class
def learn_loop(nloops: int, strategy: str, path_to_features: str,
               output_metrics_file: str, output_queried_file: str,
               features_method: str = 'Bazin',
               classifier: str = 'RandomForest',
               training: str = 'original', batch: int = 1,
               survey: str = 'DES', nclass: int = 2,
               photo_class_thr: float = 0.5,
               photo_ids_to_file: bool = False,
               photo_ids_froot: str = ' ',
               classifier_bootstrap: bool = False,
               save_predictions: bool = False, sep_files=False,
               pred_dir: str = None, queryable: bool = False,
               metric_label: str = 'snpcc', save_alt_class: bool = False,
               SNANA_types: bool = False, metadata_fname: str = None,
               bar: bool = True, **kwargs):
    """
    Perform the active learning loop. All results are saved to file.

    Parameters
    ----------
    nloops: int
        Number of active learning loops to run.
    strategy: str
        Query strategy. Options are 'UncSampling', 'RandomSampling',
        'UncSamplingEntropy', 'UncSamplingLeastConfident',
        'UncSamplingMargin', 'QBDMI' and 'QBDEntropy'.
    path_to_features: str or dict
        Complete path to input features file.
        if dict, keywords should be 'train' and 'test',
        and values must contain the path for separate train
        and test sample files.
    output_metrics_file: str
        Full path to output file to store metric values of each loop.
    output_queried_file: str
        Full path to output file to store the queried sample.
    features_method: str (optional)
        Feature extraction method. Currently only 'Bazin' is implemented.
    classifier: str (optional)
        Machine Learning algorithm.
        Currently implemented options are 'RandomForest',
        'GradientBoostedTrees', 'K-NNclassifier', 'MLPclassifier',
        'SVMclassifier' and 'NBclassifier'.
        Default is 'RandomForest'.
    training: str or int (optional)
        Choice of initial training sample.
        If 'original': begin from the train sample flagged in the file
        If int: choose the required number of samples at random,
        ensuring that at least half are SN Ia
        Default is 'original'.
    batch: int (optional)
        Size of batch to be queried in each loop. Default is 1.
    survey: str (optional)
        'DES' or 'LSST'. Default is 'DES'.
        Name of the survey which characterizes filter set.
    nclass: int (optional)
        Number of classes to consider in the classification
        Currently only nclass == 2 is implemented.
    photo_class_thr: float (optional)
        Threshold for photometric classification. Default is 0.5.
        Forwarded to `save_photo_ids`.
    photo_ids_to_file: bool (optional)
        If True, save photometric ids to file. Default is False.
    photo_ids_froot: str (optional)
        Output root of file name to store photo ids.
        Forwarded to `save_photo_ids`.
    classifier_bootstrap: bool (optional)
        Flag for bootstrapping on the classifier.
        Must be True if using a disagreement based ('QBD') strategy.
    save_predictions: bool (optional)
        If True, save classification predictions to file in each loop.
        Default is False.
    sep_files: bool (optional)
        If True, consider train and test samples separately read
        from independent files. Default is False.
    pred_dir: str (optional)
        Output directory to store prediction file for each loop.
        Only used if `save_predictions==True`.
    queryable: bool (optional)
        If True, check if randomly chosen object is queryable.
        Default is False.
    metric_label: str (optional)
        Choice of metric.
        Currently only "snpcc", "cosmo" or "snpcc_cosmo" are accepted.
        Default is "snpcc".
    save_alt_class: bool (optional)
        If True, train the model and save classifications for alternative
        query label (this is necessary to calculate impact on cosmology).
        Only supported together with batch == 1. Default is False.
    SNANA_types: bool (optional)
        If True, translate zenodo types to SNANA codes.
        Default is False.
    metadata_fname: str (optional)
        Complete path to PLAsTiCC zenodo test metadata. Only used if
        SNANA_types == True. Default is None.
    bar: bool (optional)
        If True, display progress bar. Otherwise the iteration number
        is printed at the start of each loop.
    kwargs: extra parameters
        All keywords required by the classifier function.

    Returns
    -------
    DataBase
        The database object after the last loop.

    Raises
    ------
    ValueError
        If a 'QBD' strategy is requested without classifier_bootstrap,
        or if save_alt_class is requested with batch > 1.
    """
    # NOTE(review): another function named `learn_loop` appears later in
    # this source; if both live in the same module the later definition
    # replaces this one at import time.

    if 'QBD' in strategy and not classifier_bootstrap:
        raise ValueError(
            'Bootstrap must be true when using disagreement strategy')

    # Reject the unsupported option combination before any feature loading
    # or classification happens, instead of failing in the middle of the
    # first loop after predictions / photo ids may already have been saved.
    if save_alt_class and batch > 1:
        raise ValueError('Alternative label only works with batch=1!')

    # initiate object
    database_class = DataBase()
    logging.info('Loading features')
    database_class = load_features(database_class, path_to_features, survey,
                                   features_method, nclass, training,
                                   queryable, sep_files)

    logging.info('Running active learning loop')
    if bar:
        ensemble = progressbar.progressbar(range(nloops))
    else:
        ensemble = range(nloops)

    for iteration_step in ensemble:
        if not bar:
            print(iteration_step)

        database_class = run_classification(
            database_class, classifier, classifier_bootstrap, pred_dir,
            save_predictions, iteration_step, **kwargs)
        run_evaluation(database_class, metric_label)
        save_photo_ids(database_class, photo_ids_to_file, SNANA_types,
                       metadata_fname, photo_class_thr, iteration_step,
                       photo_ids_froot, '.dat')
        indices_to_query = run_make_query(database_class, strategy, batch,
                                          queryable)

        if save_alt_class and batch == 1:
            # Work on a deep copy so the alternative-label training does
            # not leak into the main database.
            database_class_alternative = copy.deepcopy(database_class)
            database_class_alternative = update_alternative_label(
                database_class_alternative, indices_to_query,
                iteration_step, classifier, pred_dir, save_predictions,
                metric_label, SNANA_types, photo_ids_to_file,
                metadata_fname, photo_class_thr, photo_ids_froot, **kwargs)
            _save_metrics_and_queried_samples(
                database_class_alternative, output_metrics_file,
                output_queried_file, iteration_step, batch, False,
                '_alt_label.dat')

        database_class.update_samples(indices_to_query,
                                      epoch=iteration_step,
                                      queryable=queryable,
                                      alternative_label=False)
        _save_metrics_and_queried_samples(database_class,
                                          output_metrics_file,
                                          output_queried_file,
                                          iteration_step, batch, False)

    return database_class
def update_alternative_label(
        database_class_alternative: DataBase, indices_to_query: list,
        iteration_step: int, classifier: str, pred_dir: str,
        is_save_prediction: bool, metric_label: str,
        is_save_snana_types: bool, is_save_photoids_to_file: bool,
        meta_data_fname: str, photo_class_threshold: float,
        photo_ids_froot: str, **kwargs: dict):
    """
    Update the samples with the alternative label, then re-classify,
    evaluate and save photo ids for that alternative database.

    Parameters
    ----------
    database_class_alternative
        An instance of DataBase class for alternative label
    indices_to_query
        List of indexes identifying objects to be moved.
    iteration_step
        Active learning iteration number; passed as `epoch` to
        `update_samples` and forwarded to the classification and
        photo-id helpers.
    classifier
        Machine Learning algorithm.
        Currently implemented options are 'RandomForest',
        'GradientBoostedTrees', 'K-NNclassifier', 'MLPclassifier',
        'SVMclassifier' and 'NBclassifier'.
    pred_dir
        Output directory to store prediction file for each loop.
        Only used if predictions are saved.
    is_save_prediction
        If predictions should be saved.
    metric_label
        Choice of metric.
        Currently only "snpcc", "cosmo" or "snpcc_cosmo" are accepted.
    is_save_snana_types
        If True, translate type to SNANA codes.
    is_save_photoids_to_file
        Forwarded to `save_photo_ids` (flag controlling whether photo
        ids are written to file).
    meta_data_fname
        Full path to PLAsTiCC zenodo test metadata file.
    photo_class_threshold
        Probability threshold above which an object is considered Ia.
    photo_ids_froot
        Output root of file name to store photo ids.
    kwargs
        Additional arguments handed to `run_classification`.

    Returns
    -------
    The database object returned by `run_classification`.
    """
    alt_label_suffix = '_alt_label.dat'

    database_class_alternative.update_samples(
        indices_to_query, alternative_label=True, epoch=iteration_step)

    # The third positional argument (bootstrap flag) is always False here.
    reclassified = run_classification(
        database_class_alternative, classifier, False, pred_dir,
        is_save_prediction, iteration_step, **kwargs)

    run_evaluation(reclassified, metric_label)
    save_photo_ids(
        reclassified, is_save_photoids_to_file, is_save_snana_types,
        meta_data_fname, photo_class_threshold, iteration_step,
        photo_ids_froot, alt_label_suffix)

    return reclassified
def learn_loop(nloops: int, strategy: str, path_to_features: str,
               output_diag_file: str, output_queried_file: str,
               features_method='Bazin', classifier='RandomForest',
               training='original', batch=1, screen=True, survey='DES',
               perc=0.1, nclass=2):
    """Run the active learning loop, writing diagnostics and queries to file.

    Parameters
    ----------
    nloops: int
        Number of active learning loops to run.
    strategy: str
        Query strategy. Options are 'UncSampling' and 'RandomSampling'.
    path_to_features: str or dict
        Complete path to input features file.
        If dict, keywords should be 'train' and 'test',
        and values must contain the path for separate train
        and test sample files.
    output_diag_file: str
        Full path to output file to store diagnostics of each loop.
    output_queried_file: str
        Full path to output file to store the queried sample.
    features_method: str (optional)
        Feature extraction method. Currently only 'Bazin' is implemented.
    classifier: str (optional)
        Machine Learning algorithm.
        Currently only 'RandomForest' is implemented.
    training: str or int (optional)
        Choice of initial training sample.
        If 'original': begin from the train sample flagged in the file
        If int: choose the required number of samples at random,
        ensuring that at least half are SN Ia
        Default is 'original'.
    batch: int (optional)
        Size of batch to be queried in each loop. Default is 1.
    screen: bool (optional)
        If True, print on screen number of light curves processed.
    survey: str (optional)
        'DES' or 'LSST'. Default is 'DES'.
        Name of the survey which characterizes filter set.
    perc: float in [0,1] (optional)
        Percentile chosen to identify the new query.
        Only used for PercentileSampling.
        Default is 0.1.
    nclass: int (optional)
        Number of classes to consider in the classification
        Currently only nclass == 2 is implemented.
    """
    # NOTE(review): an earlier function in this source is also named
    # `learn_loop`; if both live in the same module this later definition
    # is the one that remains bound to the name.

    ## This module will need to be expanded for RESSPECT

    database = DataBase()

    if not isinstance(path_to_features, str):
        # Separate train / test files: load each one into its own sample.
        for sample_name in ('train', 'test'):
            database.load_features(path_to_features[sample_name],
                                   method=features_method, screen=screen,
                                   survey=survey, sample=sample_name)
        database.build_samples(initial_training=training, nclass=nclass,
                               screen=screen, sep_files=True)
    else:
        # Single file: load everything, then split into train and test.
        database.load_features(path_to_features, method=features_method,
                               screen=screen, survey=survey)
        database.build_samples(initial_training=training, nclass=nclass)

    for iteration in range(nloops):
        if screen:
            print('Processing... ', iteration)

        # classify, then compute metrics for the current state
        database.classify(method=classifier)
        database.evaluate_classification()

        # pick the object(s) to query and move them between samples
        query_indices = database.make_query(strategy=strategy, batch=batch,
                                            perc=perc)
        database.update_samples(query_indices, loop=iteration)

        # persist diagnostics and the queried sample
        database.save_metrics(loop=iteration,
                              output_metrics_file=output_diag_file,
                              batch=batch, epoch=iteration)
        database.save_queried_sample(output_queried_file, loop=iteration,
                                     full_sample=False)
def time_domain_loop(days: list, output_diag_file: str,
                     output_queried_file: str,
                     path_to_features_dir: str, strategy: str,
                     batch=1, canonical=False, classifier='RandomForest',
                     features_method='Bazin', path_to_canonical="",
                     path_to_full_lc_features="",
                     screen=True, training='original', queryable=False):
    """Perform the time-domain active learning loop.

    All results are saved to file; nothing is returned.

    Parameters
    ----------
    days: list
        List of 2 elements. First and last day of observations since the
        beginning of the survey. The loop iterates over
        ``range(int(days[0]), int(days[-1]) - 1)`` and, at the end of each
        iteration, loads the features of the following day.
    output_diag_file: str
        Full path to output file to store diagnostics of each loop.
    output_queried_file: str
        Full path to output file to store the queried sample.
    path_to_features_dir: str
        Complete path to directory holding features files for all days.
        File names are built as ``path_to_features_dir + 'day_<N>.dat'``,
        so the directory string must carry its trailing separator.
    strategy: str
        Query strategy. Options are 'UncSampling' and 'RandomSampling'.
    batch: int (optional)
        Size of batch to be queried in each loop. Default is 1.
    canonical: bool (optional)
        If True, load the canonical sample from `path_to_canonical` and
        use its queryable ids for the first night.
    classifier: str (optional)
        Machine Learning algorithm.
        Currently only 'RandomForest' is implemented.
    features_method: str (optional)
        Feature extraction method. Currently only 'Bazin' is implemented.
    path_to_canonical: str (optional)
        Path to canonical sample features files.
        Only read when `canonical` is True.
    path_to_full_lc_features: str (optional)
        Path to full light curve features file.
        Only read when ``training == 'original'``.
    screen: bool (optional)
        If True, print progress and sample sizes on screen.
    training: str or int (optional)
        Choice of initial training sample.
        If 'original': begin from the train sample flagged in the file
        Otherwise the value is converted with ``int()`` and that number of
        samples is requested as initial training.
        Default is 'original'.
    queryable: bool (optional)
        If True, only test objects whose 'queryable' metadata flag is set
        are made available for querying on the next day. If False, every
        test object is. Default is False.
    """
    ## This will need to change for RESSPECT

    # initiate object
    data = DataBase()

    # load features for the first day
    path_to_features = path_to_features_dir + 'day_' + \
        str(int(days[0])) + '.dat'
    data.load_features(path_to_features, method=features_method,
                       screen=screen)

    # change training
    if training == 'original':
        data.build_samples(initial_training='original')
        # Replace the training sample by the one built from full light
        # curve features.
        full_lc_features = get_original_training(
            path_to_features=path_to_full_lc_features)
        data.train_metadata = full_lc_features.train_metadata
        data.train_labels = full_lc_features.train_labels
        data.train_features = full_lc_features.train_features
    else:
        data.build_samples(initial_training=int(training))

    # get list of canonical ids
    if canonical:
        # NOTE: from here on `canonical` no longer holds the boolean flag
        # but the DataBase with the canonical sample.
        canonical = DataBase()
        canonical.load_features(path_to_file=path_to_canonical)
        data.queryable_ids = canonical.queryable_ids

    # NOTE(review): the last night processed is days[-1] - 2 and the last
    # features file loaded is day days[-1] - 1; confirm this is intended.
    for night in range(int(days[0]), int(days[-1]) - 1):

        if screen:
            print('Processing night: ', night)

        # cont loop
        loop = night - int(days[0])

        # classify
        data.classify(method=classifier)

        # calculate metrics
        data.evaluate_classification()

        # choose object to query
        indx = data.make_query(strategy=strategy, batch=batch)

        # update training and test samples
        data.update_samples(indx, loop=loop)

        # save diagnostics for current state
        data.save_metrics(loop=loop, output_metrics_file=output_diag_file,
                          batch=batch, epoch=night)

        # save query sample to file
        data.save_queried_sample(output_queried_file, loop=loop,
                                 full_sample=False)

        # load features for next day
        path_to_features2 = path_to_features_dir + 'day_' + \
            str(night + 1) + '.dat'

        data_tomorrow = DataBase()
        data_tomorrow.load_features(path_to_features2,
                                    method=features_method,
                                    screen=False)

        # identify objects in the new day which must be in training;
        # np.isin gives a boolean mask even when the new day is empty
        train_flag = np.isin(data_tomorrow.metadata['id'].values,
                             data.train_metadata['id'].values)

        # use new data
        data.train_metadata = data_tomorrow.metadata[train_flag]
        data.train_features = data_tomorrow.features.values[train_flag]
        data.test_metadata = data_tomorrow.metadata[~train_flag]
        data.test_features = data_tomorrow.features.values[~train_flag]

        # new labels
        data.train_labels = np.array(
            [int(item == 'Ia')
             for item in data.train_metadata['type'].values])
        data.test_labels = np.array(
            [int(item == 'Ia')
             for item in data.test_metadata['type'].values])

        if strategy == 'canonical':
            # NOTE(review): this needs canonical=True (so `canonical` is a
            # DataBase) and the value is overwritten by the branch right
            # below; confirm the intended precedence.
            data.queryable_ids = canonical.queryable_ids

        if queryable:
            # keep only test objects flagged as queryable in the metadata
            queryable_flag = data_tomorrow.metadata['queryable'].values
            queryable_test_flag = np.logical_and(~train_flag,
                                                 queryable_flag)
            data.queryable_ids = \
                data_tomorrow.metadata['id'].values[queryable_test_flag]
        else:
            data.queryable_ids = \
                data_tomorrow.metadata['id'].values[~train_flag]

        if screen:
            print('Training set size: ', data.train_metadata.shape[0])
            print('Test set size: ', data.test_metadata.shape[0])