Example #1
0
def sample_matrix_store(sample_df, sample_metadata):
    """Return a matrix store (uuid "1234") holding ``sample_df`` / ``sample_metadata``.

    The store belongs to a ``ProjectStorage`` rooted at a fresh temporary
    directory. That directory has to outlive this function. Returning from
    inside a ``with tempfile.TemporaryDirectory()`` block would delete it
    before the caller touches the store, so any later file access through
    the store would hit a vanished path. Instead, the directory is removed
    when the returned store is garbage-collected (or at interpreter exit).

    NOTE(review): if this is registered as a pytest fixture, ``yield store``
    inside ``with tempfile.TemporaryDirectory()`` is the more idiomatic fix.

    Args:
        sample_df: assigned to ``store.matrix``.
        sample_metadata: assigned to ``store.metadata``.

    Returns:
        The store returned by ``matrix_storage_engine().get_store("1234")``.
    """
    import shutil
    import weakref

    tempdir = tempfile.mkdtemp()
    try:
        project_storage = ProjectStorage(tempdir)
        store = project_storage.matrix_storage_engine().get_store("1234")
        store.matrix = sample_df
        store.metadata = sample_metadata
    except BaseException:
        # Don't leak the directory if building the store fails.
        shutil.rmtree(tempdir, ignore_errors=True)
        raise
    # Tie the directory's lifetime to the store rather than to this call.
    weakref.finalize(store, shutil.rmtree, tempdir, ignore_errors=True)
    return store
Example #2
0
def basic_integration_test(
    cohort_names,
    feature_group_create_rules,
    feature_group_mix_rules,
    expected_matrix_multiplier,
    expected_group_lists,
):
    """Run the cohort -> labels -> features -> matrices pipeline end to end and check its outputs.

    The test starts a throwaway Postgres instance and populates it with
    source data. It then wires together Timechop, EntityDateTableGenerator,
    LabelGenerator, FeatureGenerator, FeatureDictionaryCreator,
    FeatureGroupCreator, FeatureGroupMixer, Planner and MatrixBuilder.
    Matrices are written into a ProjectStorage rooted at a temporary
    directory.

    Args:
        cohort_names: passed to ``Planner(cohort_names=...)``.
        feature_group_create_rules: passed to ``FeatureGroupCreator``.
        feature_group_mix_rules: passed to ``FeatureGroupMixer``.
        expected_matrix_multiplier: expected number of matrices built for
            each split matrix (one train matrix plus each test matrix) that
            Timechop produces.
        expected_group_lists: the expected ``feature_groups`` lists across
            all written metadata files. They are compared as a set of
            tuples, so file order and duplicates are ignored.

    Raises:
        AssertionError: if any of the following does not match expectations:
            the matrix file count, the metadata file count, the
            ``triage_metadata.matrices`` records, or the feature groups.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        Base.metadata.create_all(db_engine)
        populate_source_data(db_engine)

        with TemporaryDirectory() as temp_dir:
            chopper = Timechop(
                feature_start_time=datetime(2010, 1, 1),
                feature_end_time=datetime(2014, 1, 1),
                label_start_time=datetime(2011, 1, 1),
                label_end_time=datetime(2014, 1, 1),
                model_update_frequency="1year",
                training_label_timespans=["6months"],
                test_label_timespans=["6months"],
                training_as_of_date_frequencies="1day",
                test_as_of_date_frequencies="3months",
                max_training_histories=["1months"],
                test_durations=["1months"],
            )

            entity_date_table_generator = EntityDateTableGenerator(
                db_engine=db_engine,
                entity_date_table_name="cohort_abcd",
                query="select distinct(entity_id) from events")

            label_generator = LabelGenerator(
                db_engine=db_engine,
                query=sample_config()["label_config"]["query"])

            feature_generator = FeatureGenerator(
                db_engine=db_engine,
                features_schema_name="features",
                replace=True)

            feature_dictionary_creator = FeatureDictionaryCreator(
                db_engine=db_engine, features_schema_name="features")

            feature_group_creator = FeatureGroupCreator(
                feature_group_create_rules)

            feature_group_mixer = FeatureGroupMixer(feature_group_mix_rules)
            project_storage = ProjectStorage(temp_dir)
            planner = Planner(
                feature_start_time=datetime(2010, 1, 1),
                label_names=["outcome"],
                label_types=["binary"],
                cohort_names=cohort_names,
                user_metadata={},
            )

            # cohort_table_name must match the entity_date_table_name used
            # by the EntityDateTableGenerator above
            builder = MatrixBuilder(
                engine=db_engine,
                db_config={
                    "features_schema_name": "features",
                    "labels_schema_name": "public",
                    "labels_table_name": "labels",
                    "cohort_table_name": "cohort_abcd",
                },
                experiment_hash=None,
                matrix_storage_engine=project_storage.matrix_storage_engine(),
                replace=True,
            )

            # chop time
            split_definitions = chopper.chop_time()
            # each split contributes one train matrix plus its test matrices
            num_split_matrices = sum(1 + len(split["test_matrices"])
                                     for split in split_definitions)

            # generate as_of_times for feature/label/state generation
            all_as_of_times = []
            for split in split_definitions:
                all_as_of_times.extend(split["train_matrix"]["as_of_times"])
                for test_matrix in split["test_matrices"]:
                    all_as_of_times.extend(test_matrix["as_of_times"])
            # de-duplicate: different splits can share as-of dates
            all_as_of_times = list(set(all_as_of_times))

            # generate entity_date state table
            entity_date_table_generator.generate_entity_date_table(
                as_of_dates=all_as_of_times)

            # create labels table
            label_generator.generate_all_labels(
                labels_table="labels",
                as_of_dates=all_as_of_times,
                label_timespans=["6months"],
            )

            # create feature table tasks
            # we would use FeatureGenerator#create_all_tables but want to use
            # the tasks dict directly to create a feature dict
            aggregations = feature_generator.aggregations(
                feature_aggregation_config=[
                    {
                        "prefix":
                        "cat",
                        "from_obj":
                        "cat_complaints",
                        "knowledge_date_column":
                        "as_of_date",
                        "aggregates": [{
                            "quantity": "cat_sightings",
                            "metrics": ["count", "avg"],
                            "imputation": {
                                "all": {
                                    "type": "mean"
                                }
                            },
                        }],
                        "intervals": ["1y"],
                        "groups": ["entity_id"],
                    },
                    {
                        "prefix":
                        "dog",
                        "from_obj":
                        "dog_complaints",
                        "knowledge_date_column":
                        "as_of_date",
                        "aggregates_imputation": {
                            "count": {
                                "type": "constant",
                                "value": 7
                            },
                            "sum": {
                                "type": "mean"
                            },
                            "avg": {
                                "type": "zero"
                            },
                        },
                        "aggregates": [{
                            "quantity": "dog_sightings",
                            "metrics": ["count", "avg"]
                        }],
                        "intervals": ["1y"],
                        "groups": ["entity_id"],
                    },
                ],
                feature_dates=all_as_of_times,
                state_table=entity_date_table_generator.entity_date_table_name,
            )
            feature_table_agg_tasks = feature_generator.generate_all_table_tasks(
                aggregations, task_type="aggregation")

            # create feature aggregation tables
            feature_generator.process_table_tasks(feature_table_agg_tasks)

            feature_table_imp_tasks = feature_generator.generate_all_table_tasks(
                aggregations, task_type="imputation")

            # create feature imputation tables
            feature_generator.process_table_tasks(feature_table_imp_tasks)

            # build feature dictionaries from feature tables and
            # subsetting config
            master_feature_dict = feature_dictionary_creator.feature_dictionary(
                feature_table_names=feature_table_imp_tasks.keys(),
                index_column_lookup=feature_generator.index_column_lookup(
                    aggregations),
            )

            feature_dicts = feature_group_mixer.generate(
                feature_group_creator.subsets(master_feature_dict))

            # figure out what matrices need to be built
            _, matrix_build_tasks = planner.generate_plans(
                split_definitions, feature_dicts)

            # go and build the matrices
            builder.build_all_matrices(matrix_build_tasks)

            # super basic assertion: did matrices we expect get created?
            matrices_records = list(
                db_engine.execute(
                    """select matrix_uuid, num_observations, matrix_type
                    from triage_metadata.matrices
                    """))
            matrix_directory = os.path.join(temp_dir, "matrices")
            # classify written files by substring of their file name
            matrices = [
                path for path in os.listdir(matrix_directory) if ".csv" in path
            ]
            metadatas = [
                path for path in os.listdir(matrix_directory)
                if ".yaml" in path
            ]
            assert len(matrices) == num_split_matrices * \
                expected_matrix_multiplier
            assert len(metadatas) == num_split_matrices * \
                expected_matrix_multiplier
            # every matrix file has a row in the metadata table
            assert len(matrices) == len(matrices_records)
            feature_group_name_lists = []
            for metadata_path in metadatas:
                with open(os.path.join(matrix_directory, metadata_path)) as f:
                    metadata = yaml.full_load(f)
                    feature_group_name_lists.append(metadata["feature_groups"])

            for matrix_uuid, num_observations, matrix_type in matrices_records:
                assert matrix_uuid in matrix_build_tasks  # the hashes of the matrices
                assert type(num_observations) is int
                assert matrix_type == matrix_build_tasks[matrix_uuid][
                    "matrix_type"]

            # order-insensitive comparison: turn each inner list into a
            # tuple and collect them into a set
            def deep_unique_tuple(l):
                return set([tuple(i) for i in l])

            assert deep_unique_tuple(
                feature_group_name_lists) == deep_unique_tuple(
                    expected_group_lists)
Example #3
0
def add_predictions(db_engine,
                    model_groups,
                    project_path,
                    experiment_hashes=None,
                    train_end_times_range=None,
                    rank_order='worst',
                    replace=True):
    """For a set of model groups, generate test predictions and write them to the DB.

    Args:
        db_engine: SQLAlchemy engine.
        model_groups (list): Model group ids to score (ideally chosen through
            audition). Every id must still have at least one model after the
            optional filters below; otherwise ``ValueError`` is raised.
        project_path (str): Path where the experiment's matrices and trained
            model objects are stored.
        experiment_hashes (List[str]): Optional. Hash(es) of the experiments
            of interest. Used to narrow down the model ids within
            ``model_groups``.
        train_end_times_range (Dict): Optional. If provided, only models whose
            ``train_end_time`` falls in the range are scored. This also
            narrows down the model ids within ``model_groups``. Accepted keys
            are ``'range_start_date'`` and/or ``'range_end_date'``; either or
            both may be set, and both bounds are inclusive.
        rank_order (str): How to deal with ties in the scores; passed to
            ``Predictor``.
        replace (bool): Whether to overwrite the predictions for a model_id
            if they are already found in the DB.

    Returns:
        None. Predictions are written by ``Predictor`` with
        ``save_predictions=True`` (into the test_results.predictions table).

    Raises:
        ValueError: if no models remain after filtering, or if any requested
            model group has no models left after filtering.
    """
    model_matrix_info = _fetch_relevant_model_matrix_info(
        db_engine=db_engine,
        model_groups=model_groups,
        experiment_hashes=experiment_hashes)

    # If we are only generating predictions for a specific time range
    # (both bounds inclusive)
    if train_end_times_range is not None:
        if 'range_start_date' in train_end_times_range:
            range_start = train_end_times_range['range_start_date']
            msk = model_matrix_info['train_end_time'] >= range_start
            logging.info(
                'Filtering out models with a train_end_time before %s',
                range_start)
            model_matrix_info = model_matrix_info[msk]

        if 'range_end_date' in train_end_times_range:
            range_end = train_end_times_range['range_end_date']
            msk = model_matrix_info['train_end_time'] <= range_end
            logging.info(
                'Filtering out models with a train_end_time after %s',
                range_end)
            model_matrix_info = model_matrix_info[msk]

    if len(model_matrix_info) == 0:
        raise ValueError('Config is not valid. No models were found!')

    # All the model groups specified in the config must still be present,
    # even after filtering by experiment_hashes / train_end_times_range.
    # Build the lookup set once instead of scanning the array per group.
    fetched_model_groups = set(model_matrix_info['model_group_id'].unique())
    not_fetched_model_grps = [
        grp for grp in model_groups if grp not in fetched_model_groups
    ]

    if not_fetched_model_grps:
        raise ValueError(
            'The config is not valid. No models were found for the model group(s) {}. All specified model groups should be present'
            .format(not_fetched_model_grps))

    logging.info('Scoring %s model ids', len(model_matrix_info))

    # summary of the models that we are scoring, to surface anything worth noting
    _summary_of_models(model_matrix_info)

    logging.info('Instantiating storage engines and the predictor')

    # Storage objects to handle already stored models and matrices
    project_storage = ProjectStorage(project_path)
    model_storage_engine = project_storage.model_storage_engine()
    matrix_storage_engine = project_storage.matrix_storage_engine()

    # Prediction generation is handled by the Predictor class in catwalk
    predictor = Predictor(model_storage_engine=model_storage_engine,
                          db_engine=db_engine,
                          rank_order=rank_order,
                          replace=replace,
                          save_predictions=True)

    # Organize the prediction run over unique (train_mat, test_mat) pairs.
    # This reduces the number of times the matrices get loaded into memory.
    matrix_pairs = model_matrix_info.groupby(
        ['train_matrix_uuid', 'test_matrix_uuid'])

    for (train_uuid, test_uuid), df_grp in matrix_pairs:
        logging.info(
            'Processing %s model_ids for train matrix %s and test matrix %s',
            len(df_grp), train_uuid, test_uuid)

        train_matrix_store = matrix_storage_engine.get_store(
            matrix_uuid=train_uuid)

        # Ensure the column order used for predictions matches the order
        # used in model training
        train_matrix_columns = list(train_matrix_store.design_matrix.columns)

        test_matrix_store = matrix_storage_engine.get_store(
            matrix_uuid=test_uuid)

        for model_id in df_grp['model_id'].tolist():
            logging.info('Writing predictions for model_id %s', model_id)
            predictor.predict(model_id=model_id,
                              matrix_store=test_matrix_store,
                              train_matrix_columns=train_matrix_columns,
                              misc_db_parameters={})

    logging.info('Successfully generated predictions for %s models!',
                 len(model_matrix_info))
Example #4
0
def basic_integration_test(state_filters, feature_group_create_rules,
                           feature_group_mix_rules, expected_matrix_multiplier,
                           expected_group_lists):
    """Run the states -> labels -> features -> matrices pipeline end to end and check its outputs.

    The test starts a throwaway Postgres instance and populates it with
    source data. It then wires together Timechop,
    StateTableGeneratorFromDense, LabelGenerator, FeatureGenerator,
    FeatureDictionaryCreator, FeatureGroupCreator, FeatureGroupMixer,
    Planner and MatrixBuilder. Matrices are written into a ProjectStorage
    rooted at a temporary directory.

    Args:
        state_filters: passed to ``Planner(states=...)``.
        feature_group_create_rules: passed to ``FeatureGroupCreator``.
        feature_group_mix_rules: passed to ``FeatureGroupMixer``.
        expected_matrix_multiplier: expected number of matrices built for
            each split matrix (one train matrix plus each test matrix) that
            Timechop produces.
        expected_group_lists: the expected ``feature_groups`` lists across
            all written metadata files. They are compared as a set of
            tuples, so file order and duplicates are ignored.

    Raises:
        AssertionError: if any of the following does not match expectations:
            the matrix file count, the metadata file count, the
            ``model_metadata.matrices`` records, or the feature groups.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        Base.metadata.create_all(db_engine)
        populate_source_data(db_engine)

        with TemporaryDirectory() as temp_dir:
            chopper = Timechop(
                feature_start_time=datetime(2010, 1, 1),
                feature_end_time=datetime(2014, 1, 1),
                label_start_time=datetime(2011, 1, 1),
                label_end_time=datetime(2014, 1, 1),
                model_update_frequency='1year',
                training_label_timespans=['6months'],
                test_label_timespans=['6months'],
                training_as_of_date_frequencies='1day',
                test_as_of_date_frequencies='3months',
                max_training_histories=['1months'],
                test_durations=['1months'],
            )

            state_table_generator = StateTableGeneratorFromDense(
                db_engine=db_engine,
                experiment_hash='abcd',
                dense_state_table='states',
            )

            label_generator = LabelGenerator(
                db_engine=db_engine,
                query=sample_config()['label_config']['query'])

            feature_generator = FeatureGenerator(
                db_engine=db_engine,
                features_schema_name='features',
                replace=True,
            )

            feature_dictionary_creator = FeatureDictionaryCreator(
                db_engine=db_engine, features_schema_name='features')

            feature_group_creator = FeatureGroupCreator(
                feature_group_create_rules)

            feature_group_mixer = FeatureGroupMixer(feature_group_mix_rules)
            project_storage = ProjectStorage(temp_dir)
            planner = Planner(
                feature_start_time=datetime(2010, 1, 1),
                label_names=['outcome'],
                label_types=['binary'],
                states=state_filters,
                user_metadata={},
            )

            builder = MatrixBuilder(
                engine=db_engine,
                db_config={
                    'features_schema_name': 'features',
                    'labels_schema_name': 'public',
                    'labels_table_name': 'labels',
                    'sparse_state_table_name': 'tmp_sparse_states_abcd',
                },
                matrix_storage_engine=project_storage.matrix_storage_engine(),
                replace=True)

            # chop time
            split_definitions = chopper.chop_time()
            # each split contributes one train matrix plus its test matrices
            num_split_matrices = sum(1 + len(split['test_matrices'])
                                     for split in split_definitions)

            # generate as_of_times for feature/label/state generation
            all_as_of_times = []
            for split in split_definitions:
                all_as_of_times.extend(split['train_matrix']['as_of_times'])
                for test_matrix in split['test_matrices']:
                    all_as_of_times.extend(test_matrix['as_of_times'])
            # de-duplicate: different splits can share as-of dates
            all_as_of_times = list(set(all_as_of_times))

            # generate sparse state table
            state_table_generator.generate_sparse_table(
                as_of_dates=all_as_of_times)

            # create labels table
            label_generator.generate_all_labels(labels_table='labels',
                                                as_of_dates=all_as_of_times,
                                                label_timespans=['6months'])

            # create feature table tasks
            # we would use FeatureGenerator#create_all_tables but want to use
            # the tasks dict directly to create a feature dict
            aggregations = feature_generator.aggregations(
                feature_aggregation_config=[{
                    'prefix':
                    'cat',
                    'from_obj':
                    'cat_complaints',
                    'knowledge_date_column':
                    'as_of_date',
                    'aggregates': [{
                        'quantity': 'cat_sightings',
                        'metrics': ['count', 'avg'],
                        'imputation': {
                            'all': {
                                'type': 'mean'
                            }
                        }
                    }],
                    'intervals': ['1y'],
                    'groups': ['entity_id']
                }, {
                    'prefix':
                    'dog',
                    'from_obj':
                    'dog_complaints',
                    'knowledge_date_column':
                    'as_of_date',
                    'aggregates_imputation': {
                        'count': {
                            'type': 'constant',
                            'value': 7
                        },
                        'sum': {
                            'type': 'mean'
                        },
                        'avg': {
                            'type': 'zero'
                        }
                    },
                    'aggregates': [{
                        'quantity': 'dog_sightings',
                        'metrics': ['count', 'avg'],
                    }],
                    'intervals': ['1y'],
                    'groups': ['entity_id']
                }],
                feature_dates=all_as_of_times,
                state_table=state_table_generator.sparse_table_name)
            feature_table_agg_tasks = feature_generator.generate_all_table_tasks(
                aggregations, task_type='aggregation')

            # create feature aggregation tables
            feature_generator.process_table_tasks(feature_table_agg_tasks)

            feature_table_imp_tasks = feature_generator.generate_all_table_tasks(
                aggregations, task_type='imputation')

            # create feature imputation tables
            feature_generator.process_table_tasks(feature_table_imp_tasks)

            # build feature dictionaries from feature tables and
            # subsetting config
            master_feature_dict = feature_dictionary_creator.feature_dictionary(
                feature_table_names=feature_table_imp_tasks.keys(),
                index_column_lookup=feature_generator.index_column_lookup(
                    aggregations))

            feature_dicts = feature_group_mixer.generate(
                feature_group_creator.subsets(master_feature_dict))

            # figure out what matrices need to be built
            _, matrix_build_tasks = planner.generate_plans(
                split_definitions, feature_dicts)

            # go and build the matrices
            builder.build_all_matrices(matrix_build_tasks)

            # super basic assertion: did matrices we expect get created?
            matrices_records = list(
                db_engine.execute(
                    '''select matrix_uuid, num_observations, matrix_type
                    from model_metadata.matrices
                    '''))
            matrix_directory = os.path.join(temp_dir, 'matrices')
            # classify written files by substring of their file name
            matrices = [
                path for path in os.listdir(matrix_directory) if '.csv' in path
            ]
            metadatas = [
                path for path in os.listdir(matrix_directory)
                if '.yaml' in path
            ]
            assert len(matrices) == num_split_matrices * \
                expected_matrix_multiplier
            assert len(metadatas) == num_split_matrices * \
                expected_matrix_multiplier
            # every matrix file has a row in the metadata table
            assert len(matrices) == len(matrices_records)
            feature_group_name_lists = []
            for metadata_path in metadatas:
                with open(os.path.join(matrix_directory, metadata_path)) as f:
                    # yaml.load without an explicit Loader is a TypeError on
                    # PyYAML >= 6 (and unsafe/deprecated before that)
                    metadata = yaml.full_load(f)
                    feature_group_name_lists.append(metadata['feature_groups'])

            for matrix_uuid, num_observations, matrix_type in matrices_records:
                assert matrix_uuid in matrix_build_tasks  # the hashes of the matrices
                assert type(num_observations) is int
                assert matrix_type == matrix_build_tasks[matrix_uuid][
                    'matrix_type']

            def deep_unique_tuple(lists):
                """Order-insensitive view: the set of inner lists as tuples."""
                return {tuple(inner) for inner in lists}

            assert deep_unique_tuple(
                feature_group_name_lists) == deep_unique_tuple(
                    expected_group_lists)
Example #5
0
def predict_forward_with_existed_model(db_engine, project_path, model_id,
                                       as_of_date):
    """Predict forward for one model at one as-of date and store the scores in the database.

    The function works in six steps:

    1. Upgrade the DB schema.
    2. Rebuild the model's cohort under the ``triage_production`` schema.
    3. Build the feature aggregation tables there.
    4. Build the imputation tables, aligned with what training imputed.
    5. Build a single "production" matrix for ``as_of_date``.
    6. Score that matrix with the model.

    Args:
        db_engine (sqlalchemy.engine.Engine): Engine for the triage database.
        project_path (str): Root path passed to ``ProjectStorage``; holds the
            experiment's matrices and trained models.
        model_id (int): The id of a given model in the database.
        as_of_date (str): A date string like "YYYY-MM-DD".
    """
    logger.spam("In PREDICT LIST................")
    upgrade_db(db_engine=db_engine)
    project_storage = ProjectStorage(project_path)
    matrix_storage_engine = project_storage.matrix_storage_engine()

    # 1. Get feature and cohort config from database.
    # train_matrix_metadata describes the matrix the model was trained on;
    # it is kept distinct from the production matrix metadata built below.
    (train_matrix_uuid,
     train_matrix_metadata) = train_matrix_info_from_model_id(db_engine, model_id)
    experiment_config = experiment_config_from_model_id(db_engine, model_id)

    # 2. Generate cohort
    cohort_name = experiment_config['cohort_config']['name']
    cohort_table_name = f"triage_production.cohort_{cohort_name}"
    cohort_table_generator = EntityDateTableGenerator(
        db_engine=db_engine,
        query=experiment_config['cohort_config']['query'],
        entity_date_table_name=cohort_table_name)
    cohort_table_generator.generate_entity_date_table(
        as_of_dates=[dt_from_str(as_of_date)])

    # 3. Generate feature aggregations.
    # Read before temporal_config is updated in step 5, so it is the
    # experiment's configured feature_start_time.
    feature_start_time = experiment_config['temporal_config'][
        'feature_start_time']
    feature_generator = FeatureGenerator(
        db_engine=db_engine,
        features_schema_name="triage_production",
        feature_start_time=feature_start_time,
    )
    collate_aggregations = feature_generator.aggregations(
        feature_aggregation_config=experiment_config['feature_aggregations'],
        feature_dates=[as_of_date],
        state_table=cohort_table_name)
    feature_generator.process_table_tasks(
        feature_generator.generate_all_table_tasks(collate_aggregations,
                                                   task_type='aggregation'))

    # 4. Reconstruct the feature dictionary from the training matrix's
    # feature names and generate imputation tables
    reconstructed_feature_dict = FeatureGroup()
    imputation_table_tasks = OrderedDict()
    # private FeatureGenerator helper; loop-invariant, so look it up once
    generate_imp_table_tasks = feature_generator._generate_imp_table_tasks_for

    for aggregation in collate_aggregations:
        feature_group, feature_names = get_feature_names(
            aggregation, train_matrix_metadata)
        reconstructed_feature_dict[feature_group] = feature_names

        # Features imputed in training must also be imputed in production;
        # anything needing imputation in production data is imputed too.
        features_imputed_in_train = get_feature_needs_imputation_in_train(
            aggregation, feature_names)

        features_imputed_in_production = get_feature_needs_imputation_in_production(
            aggregation, db_engine)

        total_impute_cols = set(features_imputed_in_production) | set(
            features_imputed_in_train)
        # names containing '_imp' are excluded; presumably these are the
        # imputation-flag columns the imputation task produces itself
        # (TODO confirm)
        total_nonimpute_cols = {
            f for f in set(feature_names) if '_imp' not in f
        } - total_impute_cols

        imputation_table_tasks.update(
            generate_imp_table_tasks(aggregation,
                                     impute_cols=list(total_impute_cols),
                                     nonimpute_cols=list(total_nonimpute_cols)))
    feature_generator.process_table_tasks(imputation_table_tasks)

    # 5. Build matrix
    db_config = {
        "features_schema_name": "triage_production",
        "labels_schema_name": "public",
        "cohort_table_name": cohort_table_name,
    }

    matrix_builder = MatrixBuilder(
        db_config=db_config,
        matrix_storage_engine=matrix_storage_engine,
        engine=db_engine,
        experiment_hash=None,
        replace=True,
    )

    label_name = experiment_config['label_config']['name']
    label_type = 'binary'
    user_metadata = experiment_config['user_metadata']

    # Use timechop to get the time definition for production, with the
    # temporal parameters overridden by those recorded for this model
    temporal_config = experiment_config["temporal_config"]
    temporal_config.update(
        temporal_params_from_matrix_metadata(db_engine, model_id))
    timechopper = Timechop(**temporal_config)
    prod_definitions = timechopper.define_test_matrices(
        train_test_split_time=dt_from_str(as_of_date),
        test_duration=temporal_config['test_durations'][0],
        test_label_timespan=temporal_config['test_label_timespans'][0])

    prod_matrix_metadata = Planner.make_metadata(
        prod_definitions[-1],
        reconstructed_feature_dict,
        label_name,
        label_type,
        cohort_name,
        'production',
        feature_start_time,
        user_metadata,
    )

    prod_matrix_metadata['matrix_id'] = (
        f'{as_of_date}_model_id_{model_id}_risklist')

    matrix_uuid = filename_friendly_hash(prod_matrix_metadata)

    matrix_builder.build_matrix(
        as_of_times=[as_of_date],
        label_name=label_name,
        label_type=label_type,
        feature_dictionary=reconstructed_feature_dict,
        matrix_metadata=prod_matrix_metadata,
        matrix_uuid=matrix_uuid,
        matrix_type="production",
    )

    # 6. Predict the risk score for production, aligning columns with the
    # training matrix
    predictor = Predictor(
        model_storage_engine=project_storage.model_storage_engine(),
        db_engine=db_engine,
        rank_order='best')

    predictor.predict(
        model_id=model_id,
        matrix_store=matrix_storage_engine.get_store(matrix_uuid),
        misc_db_parameters={},
        train_matrix_columns=matrix_storage_engine.get_store(
            train_matrix_uuid).columns())
Example #6
0
class Retrainer:
    """Given a model_group_id and prediction_date, retrain a model using the all the data till prediction_date 
    Args:
        db_engine (sqlalchemy.engine)
        project_path (string)
        model_group_id (string)
    """
    def __init__(self, db_engine, project_path, model_group_id):
        """Load a model group's experiment config and build the components used to retrain it.

        Args:
            db_engine (sqlalchemy.engine): engine for the triage results database;
                ``upgrade_db`` is run on it before any other query.
            project_path (string): project root passed to ``ProjectStorage``; the
                matrix storage engine used by this class is derived from it.
            model_group_id (string): model group to retrain. Its last-split model
                (``model_group_info['model_id_last_split']``) supplies the temporal
                parameters used here; ``retrain()`` later reads its model type and
                hyperparameters from ``model_group_info``.

        Note:
            ``retrain_model_id`` and ``retrain_matrix_uuid`` are not set here; they
            are assigned by ``retrain()`` and read by ``predict()``.
        """
        # Filled in by retrain() with the hash returned by save_retrain_and_get_hash.
        self.retrain_hash = None
        self.db_engine = db_engine
        # Upgrade the database schema before any of the queries below run.
        upgrade_db(db_engine=self.db_engine)
        self.project_storage = ProjectStorage(project_path)
        self.model_group_id = model_group_id
        # Read for 'model_id_last_split' below, and for 'model_type' /
        # 'hyperparameters' in retrain().
        self.model_group_info = get_model_group_info(self.db_engine,
                                                     self.model_group_id)
        self.matrix_storage_engine = self.project_storage.matrix_storage_engine(
        )
        # The original run id is kept so ModelTrainer starts with it and retrain()
        # can retrieve that run's random seed (retrieve_experiment_seed_from_run_id).
        self.triage_run_id, self.experiment_config = experiment_config_from_model_group_id(
            self.db_engine, self.model_group_id)

        # This feels like it needs some refactoring since in some edge cases at least the test matrix temporal parameters
        # might differ across models in the model group (the training ones shouldn't), but this should probably work for
        # the vast majority of use cases...
        # NOTE: this updates self.experiment_config['temporal_config'] in place.
        self.experiment_config['temporal_config'].update(
            temporal_params_from_matrix_metadata(
                self.db_engine, self.model_group_info['model_id_last_split']))

        # Since "testing" here is predicting forward to a single new date, the test_duration should always be '0day'
        # (regardless of what it may have been before)
        self.experiment_config['temporal_config']['test_durations'] = ['0day']

        # These lists should now only contain one item (the value actually used for the last model in this group)
        self.training_label_timespan = self.experiment_config[
            'temporal_config']['training_label_timespans'][0]
        self.test_label_timespan = self.experiment_config['temporal_config'][
            'test_label_timespans'][0]
        self.test_duration = self.experiment_config['temporal_config'][
            'test_durations'][0]
        self.feature_start_time = self.experiment_config['temporal_config'][
            'feature_start_time']

        self.label_name = self.experiment_config['label_config']['name']
        self.cohort_name = self.experiment_config['cohort_config']['name']
        self.user_metadata = self.experiment_config['user_metadata']

        # Production features live in the 'triage_production' schema.
        self.feature_dictionary_creator = FeatureDictionaryCreator(
            features_schema_name='triage_production', db_engine=self.db_engine)
        # NOTE(review): replace=True presumably regenerates existing labels — confirm in LabelGenerator.
        self.label_generator = LabelGenerator(
            label_name=self.experiment_config['label_config'].get(
                "name", None),
            query=self.experiment_config['label_config']["query"],
            replace=True,
            db_engine=self.db_engine,
        )

        # "labels_<label name or 'default'>_<hash of the label query>_production"
        self.labels_table_name = "labels_{}_{}_production".format(
            self.experiment_config['label_config'].get('name', 'default'),
            filename_friendly_hash(
                self.experiment_config['label_config']['query']))

        self.feature_generator = FeatureGenerator(
            db_engine=self.db_engine,
            features_schema_name="triage_production",
            feature_start_time=self.feature_start_time,
        )

        # experiment_hash and run_id are overwritten for each retrain run in retrain().
        self.model_trainer = ModelTrainer(
            experiment_hash=None,
            model_storage_engine=ModelStorageEngine(self.project_storage),
            db_engine=self.db_engine,
            replace=True,
            run_id=self.triage_run_id,
        )

    def get_temporal_config_for_retrain(self, prediction_date):
        """Return a temporal config that yields one split ending at ``prediction_date``.

        Starts from a shallow copy of the experiment's ``temporal_config`` and
        overrides three keys:

        * ``feature_end_time``: ``prediction_date`` formatted as ``YYYY-MM-DD``;
        * ``label_end_time``: ``prediction_date`` plus the test label timespan;
        * ``model_update_frequency``: the label start/end year gap plus ten
          years. This always exceeds the label window, so Timechop produces a
          single time split.

        Args:
            prediction_date (datetime): the date predictions are made for.

        Returns:
            dict: the adjusted config. The stored config's top-level keys are
            left untouched.
        """
        day_format = "%Y-%m-%d"
        config = self.experiment_config['temporal_config'].copy()
        label_horizon = convert_str_to_relativedelta(self.test_label_timespan)

        config.update({
            'feature_end_time': datetime.strftime(prediction_date, day_format),
            'label_end_time': datetime.strftime(prediction_date + label_horizon,
                                                day_format),
        })

        # Any frequency longer than the label window guarantees exactly one split.
        last_year = dt_from_str(config['label_end_time']).year
        first_year = dt_from_str(config['label_start_time']).year
        config['model_update_frequency'] = f"{last_year - first_year + 10}years"
        return config

    def generate_all_labels(self, as_of_date):
        self.label_generator.generate_all_labels(
            labels_table=self.labels_table_name,
            as_of_dates=[as_of_date],
            label_timespans=[self.training_label_timespan])

    def generate_entity_date_table(self, as_of_date, entity_date_table_name):
        """Generate a cohort (entity-date) table for a single as-of date.

        Args:
            as_of_date (str): date string, parsed with ``dt_from_str``.
            entity_date_table_name (str): name of the cohort table to write,
                e.g. ``triage_production.cohort_<name>_retrain``.
        """
        generator = EntityDateTableGenerator(
            db_engine=self.db_engine,
            query=self.experiment_config['cohort_config']['query'],
            entity_date_table_name=entity_date_table_name,
        )
        parsed_dates = [dt_from_str(as_of_date)]
        generator.generate_entity_date_table(as_of_dates=parsed_dates)

    def get_collate_aggregations(self, as_of_date, state_table):
        collate_aggregations = self.feature_generator.aggregations(
            feature_aggregation_config=self.
            experiment_config['feature_aggregations'],
            feature_dates=[as_of_date],
            state_table=state_table)
        return collate_aggregations

    def get_feature_dict_and_imputation_task(self, collate_aggregations,
                                             model_id):
        """Rebuild a model's feature dictionary and the imputation tasks that reproduce it.

        Feature names come from the metadata of ``model_id``'s train matrix, so
        production matrices get the columns the model was trained on. For each
        aggregation, a column is imputed if it was imputed in training OR needs
        imputation in the current (production) data; all other non-flag feature
        columns are passed through as non-imputed.

        Args:
            collate_aggregations (iterable): aggregations, e.g. from
                ``get_collate_aggregations``.
            model_id: model whose train matrix defines the feature set.

        Returns:
            tuple: ``(reconstructed_feature_dict, imputation_table_tasks)``. The
            first is a ``FeatureGroup`` mapping each feature group to its feature
            names. The second is an ``OrderedDict`` merging the tasks returned by
            the feature generator for every aggregation.
        """
        _, matrix_metadata = train_matrix_info_from_model_id(
            self.db_engine, model_id)
        reconstructed_feature_dict = FeatureGroup()
        imputation_table_tasks = OrderedDict()
        for aggregation in collate_aggregations:
            feature_group, feature_names = get_feature_names(
                aggregation, matrix_metadata)
            reconstructed_feature_dict[feature_group] = feature_names

            # Features imputed in training must also be imputed in production,
            # so impute the union of both sets.
            features_imputed_in_train = get_feature_needs_imputation_in_train(
                aggregation, feature_names)
            features_imputed_in_production = get_feature_needs_imputation_in_production(
                aggregation, self.db_engine)
            total_impute_cols = set(features_imputed_in_production) | set(
                features_imputed_in_train)

            # Exclude imputation-flag columns by their '_imp' suffix.
            # NOTE(review): assumes flags are named '<column>_imp', as triage's
            # collate emits them.
            # The previous substring test ('_imp' in f) also dropped genuine
            # features such as '..._impact_sum'. Those columns ended up in neither
            # column list and so were missing from the imputation table.
            total_nonimpute_cols = {
                f for f in feature_names if not f.endswith('_imp')
            } - total_impute_cols

            task_generator = self.feature_generator._generate_imp_table_tasks_for

            imputation_table_tasks.update(
                task_generator(aggregation,
                               impute_cols=list(total_impute_cols),
                               nonimpute_cols=list(total_nonimpute_cols)))
        return reconstructed_feature_dict, imputation_table_tasks

    def retrain(self, prediction_date):
        """Retrain the model group's last model on data up to ``prediction_date``.

        Goes back one split from ``prediction_date``, so the as-of date used for
        training is (prediction_date - training_label_timespan), as computed by
        Timechop from ``get_temporal_config_for_retrain``. The method then:

        1. records a retrain run;
        2. builds labels, the cohort, features and a train matrix under
           ``triage_production``;
        3. trains one model with the last split's model type, hyperparameters
           and random seed.

        Args:
            prediction_date (str): date to predict for; parsed with ``dt_from_str``.

        Returns:
            dict: ``{'retrain_model_comment': str, 'retrain_model_id': <id>}``.

        Side effects:
            - Sets ``self.retrain_hash``, ``self.retrain_model_hash``,
              ``self.retrain_matrix_uuid`` and ``self.retrain_model_id``; the
              last two are read by ``predict()``.
            - Overwrites ``self.model_trainer.run_id`` and
              ``self.model_trainer.experiment_hash``.

        Raises:
            AssertionError: if Timechop does not produce exactly one split
                (only when assertions are enabled).
            ValueError: if the saved TriageRun row has no run_id.
        """
        # Retrain config and hash
        # predict() later reads test_duration / test_label_timespan back via
        # get_retrain_config_from_model_id.
        retrain_config = {
            "model_group_id": self.model_group_id,
            "prediction_date": prediction_date,
            "test_label_timespan": self.test_label_timespan,
            "test_duration": self.test_duration,
        }
        self.retrain_hash = save_retrain_and_get_hash(retrain_config,
                                                      self.db_engine)

        with get_for_update(self.db_engine, Retrain,
                            self.retrain_hash) as retrain:
            retrain.prediction_date = prediction_date

        # Timechop
        # From here on prediction_date is a datetime rather than the input str.
        prediction_date = dt_from_str(prediction_date)
        temporal_config = self.get_temporal_config_for_retrain(prediction_date)
        timechopper = Timechop(**temporal_config)
        chops = timechopper.chop_time()
        # The config's model_update_frequency is chosen to force a single split.
        # NOTE: `assert` is skipped under `python -O`.
        assert len(chops) == 1
        chops_train_matrix = chops[0]['train_matrix']
        # Train on the last as-of time of that split, as a 'YYYY-MM-DD' string.
        as_of_date = datetime.strftime(chops_train_matrix['last_as_of_time'],
                                       "%Y-%m-%d")
        # Matrix definition for Planner.make_metadata: the split's train-matrix
        # fields, with as_of_times restricted to the single as_of_date.
        retrain_definition = {
            'first_as_of_time':
            chops_train_matrix['first_as_of_time'],
            'last_as_of_time':
            chops_train_matrix['last_as_of_time'],
            'matrix_info_end_time':
            chops_train_matrix['matrix_info_end_time'],
            'as_of_times': [as_of_date],
            'training_label_timespan':
            chops_train_matrix['training_label_timespan'],
            'max_training_history':
            chops_train_matrix['max_training_history'],
            'training_as_of_date_frequency':
            chops_train_matrix['training_as_of_date_frequency'],
        }

        # Record a TriageRun row for this retrain (run_type "retrain", keyed by retrain_hash)
        run = TriageRun(
            start_time=datetime.now(),
            git_hash=infer_git_hash(),
            triage_version=infer_triage_version(),
            python_version=infer_python_version(),
            run_type="retrain",
            run_hash=self.retrain_hash,
            last_updated_time=datetime.now(),
            current_status=TriageRunStatus.started,
            installed_libraries=infer_installed_libraries(),
            platform=platform.platform(),
            os_user=getpass.getuser(),
            working_directory=os.getcwd(),
            ec2_instance_type=infer_ec2_instance_type(),
            log_location=infer_log_location(),
            experiment_class_path=classpath(self.__class__),
            random_seed=retrieve_experiment_seed_from_run_id(
                self.db_engine, self.triage_run_id),
        )
        run_id = None
        with scoped_session(self.db_engine) as session:
            session.add(run)
            session.commit()
            # Read run_id after commit, while the session is still open.
            run_id = run.run_id
        if not run_id:
            raise ValueError("Failed to retrieve run_id from saved row")

        # set ModelTrainer's run_id and experiment_hash for Retrain run
        self.model_trainer.run_id = run_id
        self.model_trainer.experiment_hash = self.retrain_hash

        # 1. Generate all labels
        self.generate_all_labels(as_of_date)
        record_labels_table_name(run_id, self.db_engine,
                                 self.labels_table_name)

        # 2. Generate cohort
        cohort_table_name = f"triage_production.cohort_{self.experiment_config['cohort_config']['name']}_retrain"
        self.generate_entity_date_table(as_of_date, cohort_table_name)
        record_cohort_table_name(run_id, self.db_engine, cohort_table_name)

        # 3. Generate feature aggregations
        collate_aggregations = self.get_collate_aggregations(
            as_of_date, cohort_table_name)
        feature_aggregation_table_tasks = self.feature_generator.generate_all_table_tasks(
            collate_aggregations, task_type='aggregation')
        self.feature_generator.process_table_tasks(
            feature_aggregation_table_tasks)

        # 4. Reconstruct feature dictionary from feature_names and generate imputation
        # (feature names come from the last split's train matrix)
        reconstructed_feature_dict, imputation_table_tasks = self.get_feature_dict_and_imputation_task(
            collate_aggregations,
            self.model_group_info['model_id_last_split'],
        )
        feature_group_creator = FeatureGroupCreator(
            self.experiment_config['feature_group_definition'])
        # Mixing with ["all"]; the first generated group is used.
        feature_group_mixer = FeatureGroupMixer(["all"])
        feature_group_dict = feature_group_mixer.generate(
            feature_group_creator.subsets(reconstructed_feature_dict))[0]
        self.feature_generator.process_table_tasks(imputation_table_tasks)
        # 5. Build new matrix
        db_config = {
            "features_schema_name": "triage_production",
            "labels_schema_name": "public",
            "cohort_table_name": cohort_table_name,
            "labels_table_name": self.labels_table_name,
        }

        record_matrix_building_started(run_id, self.db_engine)
        matrix_builder = MatrixBuilder(
            db_config=db_config,
            matrix_storage_engine=self.matrix_storage_engine,
            engine=self.db_engine,
            experiment_hash=None,
            replace=True,
        )
        new_matrix_metadata = Planner.make_metadata(
            matrix_definition=retrain_definition,
            feature_dictionary=feature_group_dict,
            label_name=self.label_name,
            label_type='binary',
            cohort_name=self.cohort_name,
            matrix_type='train',
            feature_start_time=dt_from_str(self.feature_start_time),
            user_metadata=self.user_metadata,
        )

        # Human-readable matrix id: "<label name>_binary_<as_of_date>_retrain"
        new_matrix_metadata['matrix_id'] = "_".join([
            self.label_name,
            'binary',
            str(as_of_date),
            'retrain',
        ])

        # The matrix uuid is derived from the full metadata (including matrix_id).
        matrix_uuid = filename_friendly_hash(new_matrix_metadata)
        matrix_builder.build_matrix(
            as_of_times=[as_of_date],
            label_name=self.label_name,
            label_type='binary',
            feature_dictionary=feature_group_dict,
            matrix_metadata=new_matrix_metadata,
            matrix_uuid=matrix_uuid,
            matrix_type="train",
        )
        retrain_model_comment = 'retrain_' + str(datetime.now())

        # Extra parameters passed to process_train_task for the retrained model.
        misc_db_parameters = {
            'train_end_time': dt_from_str(as_of_date),
            'test': False,
            'train_matrix_uuid': matrix_uuid,
            'training_label_timespan': self.training_label_timespan,
            'model_comment': retrain_model_comment,
        }

        # get the random seed from the last split
        last_split_train_matrix_uuid, last_split_matrix_metadata = train_matrix_info_from_model_id(
            self.db_engine,
            model_id=self.model_group_info['model_id_last_split'])

        random_seed = self.model_trainer.get_or_generate_random_seed(
            model_group_id=self.model_group_id,
            matrix_metadata=last_split_matrix_metadata,
            train_matrix_uuid=last_split_train_matrix_uuid)

        # create retrain model hash
        retrain_model_hash = self.model_trainer._model_hash(
            self.matrix_storage_engine.get_store(matrix_uuid).metadata,
            class_path=self.model_group_info['model_type'],
            parameters=self.model_group_info['hyperparameters'],
            random_seed=random_seed,
        )

        associate_models_with_retrain(self.retrain_hash,
                                      (retrain_model_hash, ), self.db_engine)

        record_model_building_started(run_id, self.db_engine)
        retrain_model_id = self.model_trainer.process_train_task(
            matrix_store=self.matrix_storage_engine.get_store(matrix_uuid),
            class_path=self.model_group_info['model_type'],
            parameters=self.model_group_info['hyperparameters'],
            model_hash=retrain_model_hash,
            misc_db_parameters=misc_db_parameters,
            random_seed=random_seed,
            retrain=True,
            model_group_id=self.model_group_id)

        # State consumed by predict() (retrain_model_id, retrain_matrix_uuid).
        self.retrain_model_hash = retrieve_model_hash_from_id(
            self.db_engine, retrain_model_id)
        self.retrain_matrix_uuid = matrix_uuid
        self.retrain_model_id = retrain_model_id
        return {
            'retrain_model_comment': retrain_model_comment,
            'retrain_model_id': retrain_model_id
        }

    def predict(self, prediction_date):
        """Predict forward with the retrained model at ``as_of_date = prediction_date``.

        Builds a cohort, features and a production matrix under
        ``triage_production`` for ``prediction_date``, then calls
        ``Predictor.predict`` with the model produced by ``retrain()``.

        Must be called after ``retrain()`` on the same instance. It reads
        ``self.retrain_model_id`` and ``self.retrain_matrix_uuid``, which only
        ``retrain()`` assigns (otherwise an AttributeError is raised).

        Args:
            prediction_date (str): date string, parsed with ``dt_from_str``.

        Side effects:
            Builds and stores the production matrix and sets
            ``self.predict_matrix_uuid``.
        """
        cohort_table_name = f"triage_production.cohort_{self.experiment_config['cohort_config']['name']}_predict"

        # 1. Generate cohort
        self.generate_entity_date_table(prediction_date, cohort_table_name)

        # 2. Generate feature aggregations
        collate_aggregations = self.get_collate_aggregations(
            prediction_date, cohort_table_name)
        self.feature_generator.process_table_tasks(
            self.feature_generator.generate_all_table_tasks(
                collate_aggregations, task_type='aggregation'))
        # 3. Reconstruct feature dictionary from feature_names and generate imputation
        # (feature names come from the retrained model's train matrix)
        reconstructed_feature_dict, imputation_table_tasks = self.get_feature_dict_and_imputation_task(
            collate_aggregations, self.retrain_model_id)
        self.feature_generator.process_table_tasks(imputation_table_tasks)

        # 4. Build matrix
        # No labels table here: this is a production matrix for forward prediction.
        db_config = {
            "features_schema_name": "triage_production",
            "labels_schema_name": "public",
            "cohort_table_name": cohort_table_name,
        }

        matrix_builder = MatrixBuilder(
            db_config=db_config,
            matrix_storage_engine=self.matrix_storage_engine,
            engine=self.db_engine,
            experiment_hash=None,
            replace=True,
        )
        # Use timechop to get the time definition for production
        temporal_config = self.get_temporal_config_for_retrain(
            dt_from_str(prediction_date))
        timechopper = Timechop(**temporal_config)

        # test_duration / test_label_timespan as saved by retrain() for this model.
        retrain_config = get_retrain_config_from_model_id(
            self.db_engine, self.retrain_model_id)

        prod_definitions = timechopper.define_test_matrices(
            train_test_split_time=dt_from_str(prediction_date),
            test_duration=retrain_config['test_duration'],
            test_label_timespan=retrain_config['test_label_timespan'])
        last_split_definition = prod_definitions[-1]
        # NOTE(review): retrain() passes dt_from_str(self.feature_start_time) to
        # make_metadata, while this passes the raw config value — confirm the
        # difference is intended.
        matrix_metadata = Planner.make_metadata(
            matrix_definition=last_split_definition,
            feature_dictionary=reconstructed_feature_dict,
            label_name=self.label_name,
            label_type='binary',
            cohort_name=self.cohort_name,
            matrix_type='production',
            feature_start_time=self.feature_start_time,
            user_metadata=self.user_metadata,
        )

        # Human-readable id: "<prediction_date>_model_id_<retrain model id>_risklist"
        matrix_metadata['matrix_id'] = str(
            prediction_date
        ) + f'_model_id_{self.retrain_model_id}' + '_risklist'

        matrix_uuid = filename_friendly_hash(matrix_metadata)

        matrix_builder.build_matrix(
            as_of_times=[prediction_date],
            label_name=self.label_name,
            label_type='binary',
            feature_dictionary=reconstructed_feature_dict,
            matrix_metadata=matrix_metadata,
            matrix_uuid=matrix_uuid,
            matrix_type="production",
        )

        # 5. Predict the risk score for production
        predictor = Predictor(
            model_storage_engine=self.project_storage.model_storage_engine(),
            db_engine=self.db_engine,
            rank_order='best')

        # train_matrix_columns comes from the retrain matrix, presumably so
        # production columns can be aligned to the model's training columns.
        predictor.predict(
            model_id=self.retrain_model_id,
            matrix_store=self.matrix_storage_engine.get_store(matrix_uuid),
            misc_db_parameters={},
            train_matrix_columns=self.matrix_storage_engine.get_store(
                self.retrain_matrix_uuid).columns(),
        )
        self.predict_matrix_uuid = matrix_uuid