Example #1
import numpy as np  # import added so the snippet is self-contained

from resspect import DataBase  # import path assumed from the package layout

def _update_light_curve_data_for_next_epoch(
        light_curve_data: DataBase, next_day_data: DataBase,
        canonical_data: DataBase, is_queryable: bool, strategy: str,
        is_separate_files: bool) -> DataBase:
    """
    Updates samples for next epoch

    Parameters
    ----------
    light_curve_data
        light curve learning data
    next_day_data
        next day light curve data
    canonical_data
        canonical strategy light curve data
    is_queryable
        If True, allow queries only on objects flagged as queryable.
    strategy
        Query strategy. Options are (all can be run with budget):
        "UncSampling", "UncSamplingEntropy", "UncSamplingLeastConfident",
        "UncSamplingMargin", "QBDMI", "QBDEntropy", "RandomSampling".
    is_separate_files
        If True, consider samples separately read
        from independent files.
    """
    light_curve_data.pool_metadata = next_day_data.pool_metadata
    light_curve_data.pool_features = next_day_data.pool_features
    light_curve_data.pool_labels = next_day_data.pool_labels

    if not is_separate_files:
        light_curve_data.test_metadata = next_day_data.test_metadata
        light_curve_data.test_features = next_day_data.test_features
        light_curve_data.test_labels = next_day_data.test_labels

        light_curve_data.validation_metadata = next_day_data.validation_metadata
        light_curve_data.validation_features = next_day_data.validation_features
        light_curve_data.validation_labels = next_day_data.validation_labels

    if strategy == 'canonical':
        light_curve_data.queryable_ids = canonical_data.queryable_ids
    # elif keeps the canonical ids above from being silently overwritten
    elif is_queryable:
        queryable_flag = light_curve_data.pool_metadata['queryable'].values
        light_curve_data.queryable_ids = light_curve_data.pool_metadata[
            'id'].values[queryable_flag]
    else:
        light_curve_data.queryable_ids = light_curve_data.pool_metadata[
            'id'].values
    return light_curve_data
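
A minimal sketch of how this helper could be called when advancing from one observation night to the next. The file paths, the day numbering, and the placeholder canonical_data object are illustrative assumptions, not part of the function above.

# Sketch: advance the learning data from day 20 to day 21 (paths assumed).
current_data = DataBase()
current_data.load_features('results/day_20.dat', method='Bazin')

tomorrow_data = DataBase()
tomorrow_data.load_features('results/day_21.dat', method='Bazin')

current_data = _update_light_curve_data_for_next_epoch(
    light_curve_data=current_data, next_day_data=tomorrow_data,
    canonical_data=DataBase(),  # placeholder, only read when strategy == 'canonical'
    is_queryable=True, strategy='UncSampling', is_separate_files=False)
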
Example #2
from typing import Union

from resspect import DataBase  # import path assumed from the package layout

def _update_light_curve_data_val_and_test_data(
        light_curve_data: DataBase, first_loop_data: DataBase,
        is_separate_files: bool = False,
        initial_training: Union[str, int] = 'original',
        is_queryable: bool = False, number_of_classes: int = 2) -> DataBase:
    """
    Updates initial light curve validation and test data

    Parameters
    ----------
    light_curve_data
        initial light curve training data
    first_loop_data
        first loop light curve data
    is_separate_files
        If True, consider samples separately read
        from independent files. Default is False.
    initial_training
        Choice of initial training sample.
        If 'original': begin from the train sample flagged in the file.
        If 'previous': read training and queried samples from the previous run.
        If int: choose the required number of samples at random,
        ensuring that at least half are SN Ia.
        Default is 'original'.
    is_queryable
        If True, allow queries only on objects flagged as queryable.
        Default is False.
    number_of_classes
        Number of classes to consider in the classification.
        Currently only number_of_classes == 2 is implemented. Default is 2.
    """
    if is_separate_files:
        light_curve_data.build_samples(
            nclass=number_of_classes, queryable=is_queryable,
            sep_files=is_separate_files, initial_training=initial_training)
    else:
        light_curve_data.test_features = first_loop_data.pool_features
        light_curve_data.test_metadata = first_loop_data.pool_metadata
        light_curve_data.test_labels = first_loop_data.pool_labels

        light_curve_data.validation_features = first_loop_data.pool_features
        light_curve_data.validation_metadata = first_loop_data.pool_metadata
        light_curve_data.validation_labels = first_loop_data.pool_labels
    return light_curve_data
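
A short sketch of the non-separate-files path, in which the first-loop pool sample doubles as the validation and test data. The file path and the build_samples call placement are illustrative assumptions.

# Sketch: mirror the first-loop pool into validation and test (path assumed).
light_curve_data = DataBase()
light_curve_data.load_features('results/day_20.dat', method='Bazin')

first_loop_data = DataBase()
first_loop_data.load_features('results/day_20.dat', method='Bazin')
first_loop_data.build_samples(initial_training='original')  # populates pool_*

light_curve_data = _update_light_curve_data_val_and_test_data(
    light_curve_data, first_loop_data, is_separate_files=False,
    initial_training='original', is_queryable=False, number_of_classes=2)
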
Example #3
import numpy as np

from resspect import DataBase  # import path assumed from the package layout

def _update_next_day_val_and_test_data(
        next_day_data: DataBase, metadata_value: int,
        id_key_name: str) -> DataBase:
    """
    Removes the object matching the given metadata value from the
    next-day validation and test samples

    Parameters
    ----------
    next_day_data
        next day light curve data
    metadata_value
        identifier value of the object to be removed
    id_key_name
        object identification key name
    """
    if (len(next_day_data.validation_metadata) > 0 and metadata_value
            in next_day_data.validation_metadata[id_key_name].values):
        # list.index() returns a single position, hence the singular name
        val_index = list(next_day_data.validation_metadata[
                             id_key_name].values).index(metadata_value)
        next_day_data.validation_metadata = (
            next_day_data.validation_metadata.drop(
                next_day_data.validation_metadata.index[val_index]))
        next_day_data.validation_labels = np.delete(
            next_day_data.validation_labels, val_index, axis=0)
        next_day_data.validation_features = np.delete(
            next_day_data.validation_features, val_index, axis=0)

    if (len(next_day_data.test_metadata) > 0 and metadata_value
            in next_day_data.test_metadata[id_key_name].values):
        test_index = list(next_day_data.test_metadata[
                              id_key_name].values).index(metadata_value)

        next_day_data.test_metadata = (
            next_day_data.test_metadata.drop(
                next_day_data.test_metadata.index[test_index]))
        next_day_data.test_labels = np.delete(
            next_day_data.test_labels, test_index, axis=0)
        next_day_data.test_features = np.delete(
            next_day_data.test_features, test_index, axis=0)
    return next_day_data
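
Once an object has been queried and moved into training, a call along these lines would drop it from the next-day samples. The identifier value is hypothetical, and next_day_data is assumed to hold the next-day DataBase; the 'id' column name follows the metadata layout used in Example #1.

queried_id = 619112  # hypothetical identifier of the queried object
next_day_data = _update_next_day_val_and_test_data(
    next_day_data, metadata_value=queried_id, id_key_name='id')
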
Example #4
import numpy as np

# Import paths assumed from the package layout; get_original_training is
# expected to ship with the same package as DataBase.
from resspect import DataBase
from resspect import get_original_training

def time_domain_loop(days: list, output_diag_file: str,
                     output_queried_file: str,
                     path_to_features_dir: str, strategy: str,
                     batch=1, canonical=False, classifier='RandomForest',
                     features_method='Bazin', path_to_canonical="",
                     path_to_full_lc_features="", queryable=True,
                     screen=True, training='original'):
    """Perform the active learning loop. All results are saved to file.

    Parameters
    ----------
    days: list
        List of 2 elements. First and last day of observations since the
        beginning of the survey.
    output_diag_file: str
        Full path to output file to store diagnostics of each loop.
    output_queried_file: str
        Full path to output file to store the queried sample.
    path_to_features_dir: str
        Complete path to directory holding features files for all days.
    strategy: str
        Query strategy. Options are 'UncSampling' and 'RandomSampling'.
    batch: int (optional)
        Size of batch to be queried in each loop. Default is 1.
    canonical: bool (optional)
        If True, restrict the search to the canonical sample.
    classifier: str (optional)
        Machine Learning algorithm.
        Currently only 'RandomForest' is implemented.
    features_method: str (optional)
        Feature extraction method. Currently only 'Bazin' is implemented.
    path_to_canonical: str (optional)
        Path to canonical sample features files.
        Only used if canonical is True.
    path_to_full_lc_features: str (optional)
        Path to full light curve features file.
        Only used if training is a number.
    queryable: bool (optional)
        If True, allow queries only on objects flagged as queryable.
        Default is True.
    screen: bool (optional)
        If True, print to screen the number of light curves processed.
    training: str or int (optional)
        Choice of initial training sample.
        If 'original': begin from the train sample flagged in the file
        If int: choose the required number of samples at random,
        ensuring that at least half are SN Ia
        Default is 'original'.

    """

    ## This will need to change for RESSPECT

    # initiate object
    data = DataBase()

    # load features for the first day
    path_to_features = path_to_features_dir + 'day_' + str(int(days[0])) + '.dat'
    data.load_features(path_to_features, method=features_method,
                       screen=screen)

    # change training
    if training == 'original':
        data.build_samples(initial_training='original')
        full_lc_features = get_original_training(path_to_features=path_to_full_lc_features)
        data.train_metadata = full_lc_features.train_metadata
        data.train_labels = full_lc_features.train_labels
        data.train_features = full_lc_features.train_features

    else:
        data.build_samples(initial_training=int(training))

    # get list of canonical ids
    if canonical:
        canonical_data = DataBase()  # avoid rebinding the boolean flag
        canonical_data.load_features(path_to_file=path_to_canonical)
        data.queryable_ids = canonical_data.queryable_ids


    for night in range(int(days[0]), int(days[-1]) - 1):

        if screen:
            print('Processing night: ', night)

        # loop counter relative to the first day
        loop = night - int(days[0])

        # classify
        data.classify(method=classifier)

        # calculate metrics
        data.evaluate_classification()

        # choose object to query
        indx = data.make_query(strategy=strategy, batch=batch)

        # update training and test samples
        data.update_samples(indx, loop=loop)

        # save diagnostics for current state
        data.save_metrics(loop=loop, output_metrics_file=output_diag_file,
                          batch=batch, epoch=night)

        # save query sample to file
        data.save_queried_sample(output_queried_file, loop=loop,
                                 full_sample=False)

        # load features for next day
        path_to_features2 = path_to_features_dir + 'day_' + str(night + 1) + '.dat'

        data_tomorrow = DataBase()
        data_tomorrow.load_features(path_to_features2, method=features_method,
                                    screen=False)

        # identify objects in the new day which must be in training
        train_flag = np.array([item in data.train_metadata['id'].values 
                              for item in data_tomorrow.metadata['id'].values])
   
        # use new data        
        data.train_metadata = data_tomorrow.metadata[train_flag]
        data.train_features = data_tomorrow.features.values[train_flag]
        data.test_metadata = data_tomorrow.metadata[~train_flag]
        data.test_features = data_tomorrow.features.values[~train_flag]

        # new labels
        data.train_labels = np.array([int(item == 'Ia') for item in
                                      data.train_metadata['type'].values])
        data.test_labels = np.array([int(item == 'Ia') for item in
                                     data.test_metadata['type'].values])

        if strategy == 'canonical':
            data.queryable_ids = canonical_data.queryable_ids
        # elif keeps the canonical ids above from being silently overwritten
        elif queryable:
            queryable_flag = data_tomorrow.metadata['queryable'].values
            queryable_test_flag = np.logical_and(~train_flag, queryable_flag)
            data.queryable_ids = data_tomorrow.metadata['id'].values[queryable_test_flag]
        else:
            data.queryable_ids = data_tomorrow.metadata['id'].values[~train_flag]

        if screen:
            print('Training set size: ', data.train_metadata.shape[0])
            print('Test set size: ', data.test_metadata.shape[0])
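
A sketch of a complete call, assuming feature files named 'day_<N>.dat' under path_to_features_dir, which matches the naming the loop builds internally. All paths and file names are illustrative.

# Sketch: run the active learning loop over days 20-180 of the survey.
time_domain_loop(days=[20, 180],
                 output_diag_file='results/diag.dat',
                 output_queried_file='results/queried.dat',
                 path_to_features_dir='results/time_domain/',
                 strategy='UncSampling', batch=1,
                 classifier='RandomForest', features_method='Bazin',
                 path_to_full_lc_features='results/Bazin.dat',
                 queryable=True, screen=True, training='original')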