def _update_light_curve_data_for_next_epoch( light_curve_data: DataBase, next_day_data: DataBase, canonical_data: DataBase, is_queryable: bool, strategy: str, is_separate_files: bool) -> DataBase: """ Updates samples for next epoch Parameters ---------- light_curve_data light curve learning data next_day_data next day light curve data canonical_data canonical strategy light curve data is_queryable If True, allow queries only on objects flagged as queryable. Default is True. strategy Query strategy. Options are (all can be run with budget): "UncSampling", "UncSamplingEntropy", "UncSamplingLeastConfident", "UncSamplingMargin", "QBDMI", "QBDEntropy", "RandomSampling", is_separate_files If True, consider samples separately read from independent files. Default is False. """ light_curve_data.pool_metadata = next_day_data.pool_metadata light_curve_data.pool_features = next_day_data.pool_features light_curve_data.pool_labels = next_day_data.pool_labels if not is_separate_files: light_curve_data.test_metadata = next_day_data.test_metadata light_curve_data.test_features = next_day_data.test_features light_curve_data.test_labels = next_day_data.test_labels light_curve_data.validation_metadata = next_day_data.validation_metadata light_curve_data.validation_features = next_day_data.validation_features light_curve_data.validation_labels = next_day_data.validation_labels if strategy == 'canonical': light_curve_data.queryable_ids = canonical_data.queryable_ids if is_queryable: queryable_flag = light_curve_data.pool_metadata['queryable'].values light_curve_data.queryable_ids = light_curve_data.pool_metadata[ 'id'].values[queryable_flag] else: light_curve_data.queryable_ids = light_curve_data.pool_metadata[ 'id'].values return light_curve_data
def _update_queryable_ids(light_curve_data: DataBase, id_key_name: str, is_queryable: bool) -> DataBase: """ Updates queryable ids Parameters ---------- light_curve_data initial light curve training data id_key_name object identification key name is_queryable If True, allow queries only on objects flagged as queryable. Default is True. """ if is_queryable: queryable_flags = light_curve_data.pool_metadata['queryable'].values light_curve_data.queryable_ids = light_curve_data.pool_metadata[ id_key_name].values[queryable_flags] else: light_curve_data.queryable_ids = light_curve_data.pool_metadata[ id_key_name].values return light_curve_data
def _update_canonical_ids(light_curve_data: DataBase, canonical_file_name: str, is_restrict_canonical: bool) -> Tuple[ DataBase, DataBase]: """ Updates canonical ids Parameters ---------- light_curve_data initial light curve training data canonical_file_name Path to canonical sample features files. It is only used if "strategy==canonical". is_restrict_canonical If True, restrict the search to the canonical sample. """ database_class = None if is_restrict_canonical: database_class = DataBase() database_class.load_features(path_to_file=canonical_file_name) light_curve_data.queryable_ids = database_class.queryable_ids return light_curve_data, database_class
def time_domain_loop(days: list, output_diag_file: str, output_queried_file: str, path_to_features_dir: str, strategy: str, batch=1, canonical = False, classifier='RandomForest', features_method='Bazin', path_to_canonical="", path_to_full_lc_features="", screen=True, training='original'): """Perform the active learning loop. All results are saved to file. Parameters ---------- days: list List of 2 elements. First and last day of observations since the beginning of the survey. output_diag_file: str Full path to output file to store diagnostics of each loop. output_queried_file: str Full path to output file to store the queried sample. path_to_features_dir: str Complete path to directory holding features files for all days. strategy: str Query strategy. Options are 'UncSampling' and 'RandomSampling'. batch: int (optional) Size of batch to be queried in each loop. Default is 1. canonical: bool (optional) If True, restrict the search to the canonical sample. classifier: str (optional) Machine Learning algorithm. Currently only 'RandomForest' is implemented. features_method: str (optional) Feature extraction method. Currently only 'Bazin' is implemented. path_to_canonical: str (optional) Path to canonical sample features files. It is only used if "strategy==canonical". path_to_full_lc_features: str (optional) Path to full light curve features file. Only used if training is a number. screen: bool (optional) If True, print on screen number of light curves processed. training: str or int (optional) Choice of initial training sample. If 'original': begin from the train sample flagged in the file If int: choose the required number of samples at random, ensuring that at least half are SN Ia Default is 'original'. """ ## This will need to change for RESSPECT # initiate object data = DataBase() # load features for the first day path_to_features = path_to_features_dir + 'day_' + str(int(days[0])) + '.dat' data.load_features(path_to_features, method=features_method, screen=screen) # change training if training == 'original': data.build_samples(initial_training='original') full_lc_features = get_original_training(path_to_features=path_to_full_lc_features) data.train_metadata = full_lc_features.train_metadata data.train_labels = full_lc_features.train_labels data.train_features = full_lc_features.train_features else: data.build_samples(initial_training=int(training)) # get list of canonical ids if canonical: canonical = DataBase() canonical.load_features(path_to_file=path_to_canonical) data.queryable_ids = canonical.queryable_ids for night in range(int(days[0]), int(days[-1]) - 1): if screen: print('Processing night: ', night) # cont loop loop = night - int(days[0]) # classify data.classify(method=classifier) # calculate metrics data.evaluate_classification() # choose object to query indx = data.make_query(strategy=strategy, batch=batch) # update training and test samples data.update_samples(indx, loop=loop) # save diagnostics for current state data.save_metrics(loop=loop, output_metrics_file=output_diag_file, batch=batch, epoch=night) # save query sample to file data.save_queried_sample(output_queried_file, loop=loop, full_sample=False) # load features for next day path_to_features2 = path_to_features_dir + 'day_' + str(night + 1) + '.dat' data_tomorrow = DataBase() data_tomorrow.load_features(path_to_features2, method=features_method, screen=False) # identify objects in the new day which must be in training train_flag = np.array([item in data.train_metadata['id'].values for item in data_tomorrow.metadata['id'].values]) # use new data data.train_metadata = data_tomorrow.metadata[train_flag] data.train_features = data_tomorrow.features.values[train_flag] data.test_metadata = data_tomorrow.metadata[~train_flag] data.test_features = data_tomorrow.features.values[~train_flag] # new labels data.train_labels = np.array([int(item == 'Ia') for item in data.train_metadata['type'].values]) data.test_labels = np.array([int(item == 'Ia') for item in data.test_metadata['type'].values]) if strategy == 'canonical': data.queryable_ids = canonical.queryable_ids if queryable: queryable_flag = data_tomorrow.metadata['queryable'].values queryable_test_flag = np.logical_and(~train_flag, queryable_flag) data.queryable_ids = data_tomorrow.metadata['id'].values[queryable_test_flag] else: data.queryable_ids = data_tomorrow.metadata['id'].values[~train_flag] if screen: print('Training set size: ', data.train_metadata.shape[0]) print('Test set size: ', data.test_metadata.shape[0])