def filter_same_train_end_times(self, engine):
    """Populate a fixture database and run ``model_groups_filter`` over it.

    Four model groups (ids 1-4) are created, each with one model per listed
    train end time (January 1st of the given year). The filter is then asked
    for the four yearly train end times 2014-2017 with all four groups as
    the starting set.

    Args:
        engine: database engine handed to ``ensure_db``, ``init_engine`` and
            ``model_groups_filter``.

    Returns:
        Whatever ``model_groups_filter`` returns for this fixture.
    """
    ensure_db(engine)
    init_engine(engine)

    # (model_group_id, years that get a model with train_end_time = Jan 1st)
    fixture = (
        # model group 1 has all four timestamps
        (1, (2014, 2015, 2016, 2017)),
        # model group 2 only has one timestamp
        (2, (2014,)),
        # model group 3 has all four timestamps
        (3, (2014, 2015, 2016, 2017)),
        # model group 4 only has two timestamps
        (4, (2015, 2016)),
    )

    # Create every model group before any model, in ascending id order.
    groups_by_id = {}
    for group_id, _ in fixture:
        groups_by_id[group_id] = ModelGroupFactory(
            model_group_id=group_id,
            model_type='modelType{}'.format(group_id),
        )

    for group_id, years in fixture:
        for year in years:
            ModelFactory(
                model_group_rel=groups_by_id[group_id],
                train_end_time=datetime(year, 1, 1),
            )

    session.commit()

    return model_groups_filter(
        train_end_times=['2014-01-01', '2015-01-01', '2016-01-01', '2017-01-01'],
        initial_model_group_ids=[1, 2, 3, 4],
        models_table='models',
        db_engine=engine,
    )
def test_Audition():
    """Run ``AuditionRunner`` against factory-generated data.

    Ten model groups with four yearly models each are created, every model
    gets one evaluation per metric, and the runner is executed with a
    temporary directory as its output directory. The test passes when the
    directory reported by the patched ``os.getcwd`` holds exactly six entries.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        init_engine(db_engine)

        num_model_groups = 10

        # All model groups are created first, then all models.
        model_groups = []
        for group_number in range(0, num_model_groups):
            model_groups.append(
                ModelGroupFactory(
                    model_type="classifier type {}".format(group_number)
                )
            )

        train_end_times = [
            datetime(year, 1, 1) for year in (2013, 2014, 2015, 2016)
        ]

        models = []
        for model_group in model_groups:
            for train_end_time in train_end_times:
                models.append(
                    ModelFactory(
                        model_group_rel=model_group,
                        train_end_time=train_end_time,
                    )
                )

        metrics = [
            ("precision@", "100_abs"),
            ("recall@", "100_abs"),
            ("precision@", "50_abs"),
            ("recall@", "50_abs"),
            ("fpr@", "10_pct"),
        ]

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluations start at the train end time of their model.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)

        for model in models:
            for metric, parameter in metrics:
                ImmediateEvalFactory(
                    model_rel=model, metric=metric, parameter=parameter
                )

        session.commit()

        # os.getcwd is patched to report the temporary directory while the
        # runner executes and while the output is counted.
        with tempfile.TemporaryDirectory() as td, \
                mock.patch('os.getcwd', return_value=td):
            runner = AuditionRunner(
                config_dict=config, db_engine=db_engine, directory=td
            )
            runner.run()
            produced = os.listdir(os.getcwd())
            assert len(produced) == 6
def filter_train_end_times(self, engine, train_end_times):
    """Populate a fixture database and filter its model groups.

    Five model groups (ids 1-5) are created, each with one model per listed
    train end time (January 1st of the given year), and
    ``model_groups_filter`` is run over all five for the requested
    ``train_end_times``.

    Args:
        engine: database engine handed to ``ensure_db``, ``init_engine`` and
            ``model_groups_filter``.
        train_end_times: passed through unchanged to ``model_groups_filter``.

    Returns:
        Whatever ``model_groups_filter`` returns for this fixture.
    """
    ensure_db(engine)
    init_engine(engine)

    # (model_group_id, years that get a model with train_end_time = Jan 1st)
    fixture = (
        # model group 1: four timestamps
        (1, (2014, 2015, 2016, 2017)),
        # model group 2: a single timestamp
        (2, (2014,)),
        # model group 3: four timestamps
        (3, (2014, 2015, 2016, 2017)),
        # model group 4: two timestamps
        (4, (2015, 2016)),
        # model group 5: three timestamps
        (5, (2014, 2015, 2016)),
    )

    # Every group exists before the first model is created.
    groups_by_id = {}
    for group_id, _ in fixture:
        groups_by_id[group_id] = ModelGroupFactory(
            model_group_id=group_id,
            model_type="modelType{}".format(group_id),
        )

    for group_id, years in fixture:
        for year in years:
            ModelFactory(
                model_group_rel=groups_by_id[group_id],
                train_end_time=datetime(year, 1, 1),
            )

    session.commit()

    return model_groups_filter(
        train_end_times=train_end_times,
        initial_model_group_ids=[1, 2, 3, 4, 5],
        models_table="models",
        db_engine=engine,
    )
def create_sample_distance_table(engine):
    """Create ``dist_table`` and fill it with hand-written sample rows.

    Two model groups ('stable' and 'spiky') with three models each
    (train end times 2014, 2015 and 2016) are created through the factories.
    The distance table is created empty via ``_create`` and then receives
    one row per model for each of the metrics 'precision@' and 'recall@'
    (parameter '100_abs').

    Args:
        engine: database engine used for setup and for the inserts.

    Returns:
        A ``(distance_table, model_groups)`` tuple, where ``model_groups``
        maps 'stable' / 'spiky' to the created model group objects.
    """
    ensure_db(engine)
    init_engine(engine)

    model_groups = {
        'stable': ModelGroupFactory(model_type='myStableClassifier'),
        'spiky': ModelGroupFactory(model_type='mySpikeClassifier'),
    }

    class StableModelFactory(ModelFactory):
        model_group_rel = model_groups['stable']

    class SpikyModelFactory(ModelFactory):
        model_group_rel = model_groups['spiky']

    # Stable models are created first, then spiky ones, oldest first.
    periods = (('3y', '2014-01-01'), ('2y', '2015-01-01'), ('1y', '2016-01-01'))
    models = {}
    for group_name, model_factory in (
        ('stable', StableModelFactory),
        ('spiky', SpikyModelFactory),
    ):
        for age, end_time in periods:
            key = '{}_{}_ago'.format(group_name, age)
            models[key] = model_factory(train_end_time=end_time)

    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine,
        models_table='models',
        distance_table='dist_table',
    )
    distance_table._create()

    # (group name, model age, metric, the five numeric columns of the row)
    sample_values = (
        ('stable', '3y', 'precision@', (0.5, 0.6, 0.1, 0.5, 0.15)),
        ('stable', '2y', 'precision@', (0.5, 0.84, 0.34, 0.5, 0.18)),
        ('stable', '1y', 'precision@', (0.46, 0.67, 0.21, 0.5, 0.11)),
        ('spiky', '3y', 'precision@', (0.45, 0.6, 0.15, 0.5, 0.19)),
        ('spiky', '2y', 'precision@', (0.84, 0.84, 0.0, 0.5, 0.3)),
        ('spiky', '1y', 'precision@', (0.45, 0.67, 0.22, 0.5, 0.12)),
        ('stable', '3y', 'recall@', (0.4, 0.4, 0.0, 0.4, 0.0)),
        ('stable', '2y', 'recall@', (0.5, 0.5, 0.0, 0.5, 0.0)),
        ('stable', '1y', 'recall@', (0.6, 0.6, 0.0, 0.6, 0.0)),
        ('spiky', '3y', 'recall@', (0.65, 0.65, 0.0, 0.65, 0.0)),
        ('spiky', '2y', 'recall@', (0.55, 0.55, 0.0, 0.55, 0.0)),
        ('spiky', '1y', 'recall@', (0.45, 0.45, 0.0, 0.45, 0.0)),
    )

    insert_sql = (
        'insert into dist_table values '
        '(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)'
    )
    for group_name, age, metric, numbers in sample_values:
        model = models['{}_{}_ago'.format(group_name, age)]
        # Column order: group id, model id, train end time, metric,
        # parameter, then the five numeric values.
        dist_row = (
            model_groups[group_name].model_group_id,
            model.model_id,
            model.train_end_time,
            metric,
            '100_abs',
        ) + numbers
        engine.execute(insert_sql, dist_row)

    return distance_table, model_groups
def test_Auditioner():
    """Exercise ``Auditioner`` thresholding and selection-rule registration.

    Factory-generated data (ten model groups, four yearly models each, five
    evaluations per model) is loaded, after which the test checks that:

    * very loose metric filters keep all model groups,
    * impossible filters remove all of them,
    * loosening the filters again restores the full set, and
    * registering a selection rule grid yields a dict keyed by the
      descriptive names of the registered selection rules.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        init_engine(db_engine)

        # Set up data: randomly generated by the factories, but conforming
        # generally to what triage_metadata schema data is expected to look like.
        num_model_groups = 10

        model_groups = []
        for group_number in range(0, num_model_groups):
            model_groups.append(
                ModelGroupFactory(
                    model_type="classifier type {}".format(group_number)
                )
            )

        train_end_times = [
            datetime(year, 1, 1) for year in (2013, 2014, 2015, 2016)
        ]

        models = []
        for model_group in model_groups:
            for train_end_time in train_end_times:
                models.append(
                    ModelFactory(
                        model_group_rel=model_group,
                        train_end_time=train_end_time,
                    )
                )

        metrics = [
            ("precision@", "100_abs"),
            ("recall@", "100_abs"),
            ("precision@", "50_abs"),
            ("recall@", "50_abs"),
            ("fpr@", "10_pct"),
        ]

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluations start at the train end time of their model.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)

        for model in models:
            for metric, parameter in metrics:
                ImmediateEvalFactory(
                    model_rel=model, metric=metric, parameter=parameter
                )

        session.commit()

        def uniform_filters(max_from_best, threshold_value):
            # Same bounds for precision@ and recall@ at 100_abs.
            return [
                {
                    "metric": metric_name,
                    "parameter": "100_abs",
                    "max_from_best": max_from_best,
                    "threshold_value": threshold_value,
                }
                for metric_name in ("precision@", "recall@")
            ]

        # A very loose filtering that should admit all model groups.
        no_filtering = uniform_filters(1.0, 0.0)

        model_group_ids = [group.model_group_id for group in model_groups]
        auditioner = Auditioner(
            db_engine, model_group_ids, train_end_times, no_filtering
        )
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups
        auditioner.plot_model_groups()

        # Thresholding rules that are too strict for any model group to
        # survive.
        remove_all = uniform_filters(0.0, 1.1)
        auditioner.update_metric_filters(new_filters=remove_all)
        assert len(auditioner.thresholded_model_group_ids) == 0

        # Same outcome when a single filter is passed by argument.
        auditioner.set_one_metric_filter(
            metric="precision@",
            parameter="100_abs",
            max_from_best=0.0,
            threshold_value=1.1,
        )
        assert len(auditioner.thresholded_model_group_ids) == 0

        # Pulling back overly restrictive rules is a likely place for bugs:
        # thresholding must always start from the original list of model
        # groups, otherwise relaxing the rules could not bring groups back.
        auditioner.update_metric_filters(new_filters=no_filtering)
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups

        # Passing a loose filter by argument lets all model groups pass too.
        auditioner.set_one_metric_filter(
            metric="precision@",
            parameter="100_abs",
            max_from_best=1.0,
            threshold_value=0.0,
        )
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups

        # Run the thresholded list through a grid of selection rules meant
        # to pick winners by a variety of user-defined criteria.
        first_rule_group = {
            "shared_parameters": [
                {"metric": "precision@", "parameter": "100_abs"},
                {"metric": "recall@", "parameter": "100_abs"},
            ],
            "selection_rules": [
                {
                    "name": "most_frequent_best_dist",
                    "dist_from_best_case": [0.1, 0.2, 0.3],
                    "n": 1,
                },
                {"name": "best_current_value", "n": 1},
            ],
        }
        second_rule_group = {
            "shared_parameters": [
                {"metric1": "precision@", "parameter1": "100_abs"},
            ],
            "selection_rules": [
                {
                    "name": "best_average_two_metrics",
                    "metric2": ["recall@"],
                    "parameter2": ["100_abs"],
                    "metric1_weight": [0.4, 0.5, 0.6],
                    "n": 1,
                },
            ],
        }
        rule_grid = [first_rule_group, second_rule_group]

        auditioner.register_selection_rule_grid(rule_grid, plot=False)
        final_model_group_ids = auditioner.selection_rule_model_group_ids

        # The result is expected to map selection rule name to model group id.
        assert isinstance(final_model_group_ids, dict)

        # One entry is expected for each registered selection rule.
        expected_names = sorted(
            rule.descriptive_name for rule in auditioner.selection_rules
        )
        assert sorted(final_model_group_ids) == expected_names
def test_PreAudition():
    """Check ``PreAudition`` lookups against the ``triage_metadata`` schema.

    Ten model groups are created, alternating between the label definitions
    'label_1' and 'label_2', each with eight models and five evaluations per
    model. The test then checks model-group lookups by label, by experiment
    hash and by custom SQL, plus train-end-time lookups by cutoff date and
    by custom SQL.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        init_engine(db_engine)

        # Set up data: randomly generated by the factories, but conforming
        # generally to what triage_metadata schema data is expected to look like.
        num_model_groups = 10

        # Even-numbered groups use label_1, odd-numbered groups label_2.
        model_configs = [
            {"label_definition": "label_1" if i % 2 == 0 else "label_2"}
            for i in range(0, num_model_groups)
        ]

        model_groups = []
        for group_number, model_config in enumerate(model_configs):
            model_groups.append(
                ModelGroupFactory(
                    model_type="classifier type {}".format(group_number),
                    model_config=model_config,
                )
            )

        train_end_times = [
            datetime(2013, 1, 1),
            datetime(2013, 7, 1),
            datetime(2014, 1, 1),
            datetime(2014, 7, 1),
            datetime(2015, 1, 1),
            datetime(2015, 7, 1),
            datetime(2016, 7, 1),
            datetime(2016, 1, 1),
        ]

        models = []
        for model_group in model_groups:
            for train_end_time in train_end_times:
                models.append(
                    ModelFactory(
                        model_group_rel=model_group,
                        train_end_time=train_end_time,
                    )
                )

        metrics = [
            ("precision@", "100_abs"),
            ("recall@", "100_abs"),
            ("precision@", "50_abs"),
            ("recall@", "50_abs"),
            ("fpr@", "10_pct"),
        ]

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluations start at the train end time of their model.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time
            )

        for model in models:
            for metric, parameter in metrics:
                ImmediateEvalFactory(
                    model_rel=model, metric=metric, parameter=parameter
                )

        session.commit()

        pre_aud = PreAudition(db_engine)

        # Expect the number of model groups with label_1.
        label_1_groups = pre_aud.get_model_groups_from_label("label_1")['model_groups']
        expected_label_1_count = sum(
            1 for cfg in model_configs if cfg["label_definition"] == "label_1"
        )
        assert len(label_1_groups) == expected_label_1_count

        # Expect no baseline model groups.
        baseline_groups = pre_aud.get_model_groups_from_label("label_1")['baseline_model_groups']
        assert len(baseline_groups) == 0

        # Expect the number of model groups with a certain experiment_hash.
        hash_frame = pd.read_sql(
            """SELECT experiment_hash FROM triage_metadata.models
            JOIN triage_metadata.experiment_models using (model_hash)
            limit 1""",
            con=db_engine,
        )
        experiment_hash = list(hash_frame["experiment_hash"])[0]
        experiment_groups = pre_aud.get_model_groups_from_experiment(experiment_hash)['model_groups']
        assert len(experiment_groups) == 1

        # Expect the number of model groups for a custom SQL query.
        query = """
            SELECT DISTINCT(model_group_id)
            FROM triage_metadata.models
            JOIN triage_metadata.experiment_models using (model_hash)
            WHERE train_end_time >= '2013-01-01'
            AND experiment_hash = '{}'
        """.format(
            experiment_hash
        )
        assert len(pre_aud.get_model_groups(query)) == 1

        # Expect the number of train_end_times after 2014-01-01.
        assert len(pre_aud.get_train_end_times(after="2014-01-01")) == 6

        group_id_list = ", ".join(str(group_id) for group_id in pre_aud.model_groups)
        query = """
            SELECT DISTINCT train_end_time
            FROM triage_metadata.models
            WHERE model_group_id IN ({})
                AND train_end_time >= '2014-01-01'
            ORDER BY train_end_time
        """.format(
            group_id_list
        )
        assert len(pre_aud.get_train_end_times(query=query)) == 6
def test_PreAudition():
    """Check ``PreAudition`` lookups against the ``results`` schema.

    Ten model groups are created, alternating between the label definitions
    'label_1' and 'label_2', each with eight models and five evaluations per
    model. The test then checks model-group lookups by label, by experiment
    hash and by custom SQL, plus train-end-time lookups by cutoff date and
    by custom SQL.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        init_engine(db_engine)

        # Set up data: randomly generated by the factories, but conforming
        # generally to what results schema data is expected to look like.
        num_model_groups = 10

        # Even-numbered groups use label_1, odd-numbered groups label_2.
        model_configs = [
            {'label_definition': 'label_1' if i % 2 == 0 else 'label_2'}
            for i in range(0, num_model_groups)
        ]

        model_groups = []
        for group_number, model_config in enumerate(model_configs):
            model_groups.append(
                ModelGroupFactory(
                    model_type='classifier type {}'.format(group_number),
                    model_config=model_config,
                )
            )

        train_end_times = [
            datetime(2013, 1, 1),
            datetime(2013, 7, 1),
            datetime(2014, 1, 1),
            datetime(2014, 7, 1),
            datetime(2015, 1, 1),
            datetime(2015, 7, 1),
            datetime(2016, 7, 1),
            datetime(2016, 1, 1),
        ]

        models = []
        for model_group in model_groups:
            for train_end_time in train_end_times:
                models.append(
                    ModelFactory(
                        model_group_rel=model_group,
                        train_end_time=train_end_time,
                    )
                )

        metrics = [
            ('precision@', '100_abs'),
            ('recall@', '100_abs'),
            ('precision@', '50_abs'),
            ('recall@', '50_abs'),
            ('fpr@', '10_pct'),
        ]

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluations start at the train end time of their model.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)

        for model in models:
            for metric, parameter in metrics:
                ImmediateEvalFactory(
                    model_rel=model, metric=metric, parameter=parameter
                )

        session.commit()

        pre_aud = PreAudition(db_engine)

        # Expect the number of model groups with label_1.
        label_1_groups = pre_aud.get_model_groups_from_label("label_1")
        expected_label_1_count = sum(
            1 for cfg in model_configs if cfg['label_definition'] == 'label_1'
        )
        assert len(label_1_groups) == expected_label_1_count

        # Expect the number of model groups with a certain experiment_hash.
        hash_frame = pd.read_sql(
            "SELECT experiment_hash FROM results.models limit 1",
            con=db_engine,
        )
        experiment_hash = list(hash_frame['experiment_hash'])[0]
        experiment_groups = pre_aud.get_model_groups_from_experiment(experiment_hash)
        assert len(experiment_groups) == 1

        # Expect the number of model groups for a custom SQL query.
        query = """
            SELECT DISTINCT(model_group_id)
            FROM results.models
            WHERE train_end_time >= '2013-01-01'
            AND experiment_hash = '{}'
        """.format(experiment_hash)
        assert len(pre_aud.get_model_groups(query)) == 1

        # Expect the number of train_end_times after 2014-01-01.
        assert len(pre_aud.get_train_end_times(after='2014-01-01')) == 6

        group_id_list = ', '.join(str(group_id) for group_id in pre_aud.model_groups)
        query = """
            SELECT DISTINCT train_end_time
            FROM results.models
            WHERE model_group_id IN ({})
                AND train_end_time >= '2014-01-01'
            ORDER BY train_end_time
        """.format(group_id_list)
        assert len(pre_aud.get_train_end_times(query=query)) == 6
def test_DistanceFromBestTable():
    """Populate a distance-from-best table and check selected rows.

    Three model groups ("stable", "bad", "spiky") get one model for each of
    the train end times 2014, 2015 and 2016. Every model receives precision
    and recall evaluations starting at its train end time and a second set
    starting 31 days later with values lowered by 0.15. After
    ``create_and_populate`` the test checks the rows for precision in 2014
    and recall in 2015, ordered by ``dist_from_best_case``, and the table's
    ``observed_bounds``.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        ensure_db(engine)
        init_engine(engine)

        model_groups = {
            "stable": ModelGroupFactory(model_type="myStableClassifier"),
            "bad": ModelGroupFactory(model_type="myBadClassifier"),
            "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
        }

        class StableModelFactory(ModelFactory):
            model_group_rel = model_groups["stable"]

        class BadModelFactory(ModelFactory):
            model_group_rel = model_groups["bad"]

        class SpikyModelFactory(ModelFactory):
            model_group_rel = model_groups["spiky"]

        # Models are created group by group, oldest train end time first.
        periods = (("3y", "2014-01-01"), ("2y", "2015-01-01"), ("1y", "2016-01-01"))
        models = {}
        for group_name, model_factory in (
            ("stable", StableModelFactory),
            ("bad", BadModelFactory),
            ("spiky", SpikyModelFactory),
        ):
            for age, end_time in periods:
                key = "{}_{}_ago".format(group_name, age)
                models[key] = model_factory(train_end_time=end_time)

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluation window: train end time to one day later.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 1))

        class MonthOutEvalFactory(EvaluationFactory):
            # Evaluation window: 31 to 32 days after the train end time.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 31))
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 32))

        class Precision100Factory(ImmediateEvalFactory):
            metric = "precision@"
            parameter = "100_abs"

        class Precision100FactoryMonthOut(MonthOutEvalFactory):
            metric = "precision@"
            parameter = "100_abs"

        class Recall100Factory(ImmediateEvalFactory):
            metric = "recall@"
            parameter = "100_abs"

        class Recall100FactoryMonthOut(MonthOutEvalFactory):
            metric = "recall@"
            parameter = "100_abs"

        # Base metric values per model, in creation order.
        precision_values = (
            ("stable_3y_ago", 0.6),
            ("stable_2y_ago", 0.57),
            ("stable_1y_ago", 0.59),
            ("bad_3y_ago", 0.4),
            ("bad_2y_ago", 0.39),
            ("bad_1y_ago", 0.43),
            ("spiky_3y_ago", 0.8),
            ("spiky_2y_ago", 0.4),
            ("spiky_1y_ago", 0.4),
        )
        recall_values = (
            ("stable_3y_ago", 0.55),
            ("stable_2y_ago", 0.56),
            ("stable_1y_ago", 0.55),
            ("bad_3y_ago", 0.35),
            ("bad_2y_ago", 0.34),
            ("bad_1y_ago", 0.36),
            ("spiky_3y_ago", 0.35),
            ("spiky_2y_ago", 0.8),
            ("spiky_1y_ago", 0.36),
        )

        # First pass: immediate evaluations with the base values; second
        # pass: month-out evaluations with every value lowered by 0.15.
        eval_passes = (
            (0, Precision100Factory, Recall100Factory),
            (-0.15, Precision100FactoryMonthOut, Recall100FactoryMonthOut),
        )
        for add_val, PrecFac, RecFac in eval_passes:
            for model_key, base_value in precision_values:
                PrecFac(model_rel=models[model_key], value=base_value + add_val)
            for model_key, base_value in recall_values:
                RecFac(model_rel=models[model_key], value=base_value + add_val)

        session.commit()

        distance_table = DistanceFromBestTable(
            db_engine=engine,
            models_table="models",
            distance_table="dist_table",
        )
        metrics = [
            {"metric": "precision@", "parameter": "100_abs"},
            {"metric": "recall@", "parameter": "100_abs"},
        ]
        model_group_ids = [
            group.model_group_id for group in model_groups.values()
        ]
        distance_table.create_and_populate(
            model_group_ids,
            ["2014-01-01", "2015-01-01", "2016-01-01"],
            metrics,
        )

        # Ordered list of the models/groups for a particular metric/time.
        query = """
            select
                model_id,
                raw_value,
                dist_from_best_case,
                dist_from_best_case_next_time
            from dist_table
            where metric = %s
                and parameter = %s
                and train_end_time = %s
            order by dist_from_best_case
        """

        prec_3y_ago = engine.execute(
            query, ("precision@", "100_abs", "2014-01-01")
        )
        expected_prec_3y_ago = [
            (models["spiky_3y_ago"].model_id, 0.8, 0, 0.17),
            (models["stable_3y_ago"].model_id, 0.6, 0.2, 0),
            (models["bad_3y_ago"].model_id, 0.4, 0.4, 0.18),
        ]
        assert list(prec_3y_ago) == expected_prec_3y_ago

        recall_2y_ago = engine.execute(
            query, ("recall@", "100_abs", "2015-01-01")
        )
        expected_recall_2y_ago = [
            (models["spiky_2y_ago"].model_id, 0.8, 0, 0.19),
            (models["stable_2y_ago"].model_id, 0.56, 0.24, 0),
            (models["bad_2y_ago"].model_id, 0.34, 0.46, 0.19),
        ]
        assert list(recall_2y_ago) == expected_recall_2y_ago

        expected_bounds = {
            ("precision@", "100_abs"): (0.39, 0.8),
            ("recall@", "100_abs"): (0.34, 0.8),
        }
        assert distance_table.observed_bounds == expected_bounds
def setup_data(self, engine):
    """Build a ``ModelGroupThresholder`` over a hand-written distance table.

    Model groups 1-5 are created through the factory, ``dist_table`` is
    created empty via ``_create`` and then filled with one row per
    (model group, model, train end time, metric) combination listed below.
    The thresholder is constructed for train end times 2014 and 2015, the
    initial model groups 1, 2, 4, 5 and 6, and ``self.metric_filters``.

    Args:
        engine: database engine used for setup and for the inserts.

    Returns:
        The constructed ``ModelGroupThresholder``.
    """
    ensure_db(engine)
    init_engine(engine)

    for group_id in range(1, 6):
        ModelGroupFactory(
            model_group_id=group_id,
            model_type='modelType{}'.format(group_id),
        )
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine,
        models_table='models',
        distance_table='dist_table',
    )
    distance_table._create()

    metric_names = ('precision@', 'recall@', 'false positives@')

    # Each entry: (model_group_id, model_id, train_end_time) followed by the
    # four numeric columns for precision@, recall@ and false positives@,
    # in that order.
    fixture = [
        # 2014: model group 1 should pass both close and min checks
        ((1, 1, '2014-01-01'),
         (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), (40, 30, 10, 10)),
        # 2015: model group 1 should not pass close check
        ((1, 2, '2015-01-01'),
         (0.5, 0.88, 0.38, 0.0), (0.5, 0.88, 0.38, 0.0), (40, 30, 10, 10)),
        ((1, 3, '2016-01-01'),
         (0.46, 0.46, 0.0, 0.11), (0.46, 0.46, 0.0, 0.11), (40, 30, 10, 10)),
        # 2014: model group 2 should not pass min check
        ((2, 4, '2014-01-01'),
         (0.39, 0.5, 0.11, 0.5), (0.5, 0.5, 0.0, 0.38), (40, 30, 10, 10)),
        # 2015: model group 2 should pass both checks
        ((2, 5, '2015-01-01'),
         (0.69, 0.88, 0.19, 0.12), (0.69, 0.88, 0.19, 0.0), (40, 30, 10, 10)),
        ((2, 6, '2016-01-01'),
         (0.34, 0.46, 0.12, 0.11), (0.46, 0.46, 0.0, 0.11), (40, 30, 10, 10)),
        # model group 3 not included in this round
        ((3, 7, '2014-01-01'),
         (0.28, 0.5, 0.22, 0.0), (0.5, 0.5, 0.0, 0.38), (40, 30, 10, 10)),
        ((3, 8, '2015-01-01'),
         (0.88, 0.88, 0.0, 0.02), (0.5, 0.88, 0.38, 0.0), (40, 30, 10, 10)),
        ((3, 9, '2016-01-01'),
         (0.44, 0.46, 0.02, 0.11), (0.46, 0.46, 0.0, 0.11), (40, 30, 10, 10)),
        # 2014: model group 4 should not pass any checks
        ((4, 10, '2014-01-01'),
         (0.29, 0.5, 0.21, 0.21), (0.5, 0.5, 0.0, 0.38), (40, 30, 10, 10)),
        # 2015: model group 4 should not pass close check
        ((4, 11, '2015-01-01'),
         (0.67, 0.88, 0.21, 0.21), (0.5, 0.88, 0.38, 0.0), (40, 30, 10, 10)),
        ((4, 12, '2016-01-01'),
         (0.25, 0.46, 0.21, 0.21), (0.46, 0.46, 0.0, 0.11), (40, 30, 10, 10)),
        # 2014: model group 5 should not pass because precision is good
        # but not recall
        ((5, 13, '2014-01-01'),
         (0.5, 0.38, 0.0, 0.38), (0.3, 0.5, 0.2, 0.38), (40, 30, 10, 10)),
        # 2015: model group 5 should not pass because precision is good
        # but not recall
        ((5, 14, '2015-01-01'),
         (0.5, 0.88, 0.38, 0.0), (0.3, 0.88, 0.58, 0.0), (40, 30, 10, 10)),
        ((5, 15, '2016-01-01'),
         (0.46, 0.46, 0.0, 0.11), (0.3, 0.46, 0.16, 0.11), (40, 30, 10, 10)),
        # 2014: model group 6 is failed by false positives
        ((6, 16, '2014-01-01'),
         (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), (60, 30, 30, 10)),
        # 2015: model group 6 is failed by false positives
        ((6, 17, '2015-01-01'),
         (0.5, 0.88, 0.38, 0.0), (0.5, 0.38, 0.0, 0.38), (60, 30, 30, 10)),
        ((6, 18, '2016-01-01'),
         (0.46, 0.46, 0.0, 0.11), (0.5, 0.5, 0.0, 0.38), (40, 30, 10, 10)),
    ]

    insert_sql = (
        'insert into dist_table values '
        '(%s, %s, %s, %s, %s, %s, %s, %s, %s)'
    )
    for entry in fixture:
        identity = entry[0]
        # One inserted row per metric, in the order of metric_names.
        for metric_name, numbers in zip(metric_names, entry[1:]):
            dist_row = identity + (metric_name, '100_abs') + numbers
            engine.execute(insert_sql, dist_row)

    return ModelGroupThresholder(
        distance_from_best_table=distance_table,
        train_end_times=['2014-01-01', '2015-01-01'],
        initial_model_group_ids=[1, 2, 4, 5, 6],
        initial_metric_filters=self.metric_filters,
    )
def test_DistanceFromBestTable():
    """Populate a distance-from-best table and check selected rows.

    Three model groups ('stable', 'bad', 'spiky') get one model for each of
    the train end times 2014, 2015 and 2016. Every model receives precision
    and recall evaluations starting at its train end time and a second set
    starting 31 days later with values lowered by 0.15. After
    ``create_and_populate`` the test checks the rows for precision in 2014
    and recall in 2015, ordered by ``dist_from_best_case``, and the table's
    ``observed_bounds``.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        ensure_db(engine)
        init_engine(engine)

        model_groups = {
            'stable': ModelGroupFactory(model_type='myStableClassifier'),
            'bad': ModelGroupFactory(model_type='myBadClassifier'),
            'spiky': ModelGroupFactory(model_type='mySpikeClassifier'),
        }

        class StableModelFactory(ModelFactory):
            model_group_rel = model_groups['stable']

        class BadModelFactory(ModelFactory):
            model_group_rel = model_groups['bad']

        class SpikyModelFactory(ModelFactory):
            model_group_rel = model_groups['spiky']

        # Models are created group by group, oldest train end time first.
        periods = (('3y', '2014-01-01'), ('2y', '2015-01-01'), ('1y', '2016-01-01'))
        models = {}
        for group_name, model_factory in (
            ('stable', StableModelFactory),
            ('bad', BadModelFactory),
            ('spiky', SpikyModelFactory),
        ):
            for age, end_time in periods:
                key = '{}_{}_ago'.format(group_name, age)
                models[key] = model_factory(train_end_time=end_time)

        class ImmediateEvalFactory(EvaluationFactory):
            # Evaluation window: train end time to one day later.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 1))

        class MonthOutEvalFactory(EvaluationFactory):
            # Evaluation window: 31 to 32 days after the train end time.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 31))
            evaluation_end_time = factory.LazyAttribute(
                lambda o: _sql_add_days(o.model_rel.train_end_time, 32))

        class Precision100Factory(ImmediateEvalFactory):
            metric = 'precision@'
            parameter = '100_abs'

        class Precision100FactoryMonthOut(MonthOutEvalFactory):
            metric = 'precision@'
            parameter = '100_abs'

        class Recall100Factory(ImmediateEvalFactory):
            metric = 'recall@'
            parameter = '100_abs'

        class Recall100FactoryMonthOut(MonthOutEvalFactory):
            metric = 'recall@'
            parameter = '100_abs'

        # Base metric values per model, in creation order.
        precision_values = (
            ('stable_3y_ago', 0.6),
            ('stable_2y_ago', 0.57),
            ('stable_1y_ago', 0.59),
            ('bad_3y_ago', 0.4),
            ('bad_2y_ago', 0.39),
            ('bad_1y_ago', 0.43),
            ('spiky_3y_ago', 0.8),
            ('spiky_2y_ago', 0.4),
            ('spiky_1y_ago', 0.4),
        )
        recall_values = (
            ('stable_3y_ago', 0.55),
            ('stable_2y_ago', 0.56),
            ('stable_1y_ago', 0.55),
            ('bad_3y_ago', 0.35),
            ('bad_2y_ago', 0.34),
            ('bad_1y_ago', 0.36),
            ('spiky_3y_ago', 0.35),
            ('spiky_2y_ago', 0.8),
            ('spiky_1y_ago', 0.36),
        )

        # First pass: immediate evaluations with the base values; second
        # pass: month-out evaluations with every value lowered by 0.15.
        eval_passes = (
            (0, Precision100Factory, Recall100Factory),
            (-0.15, Precision100FactoryMonthOut, Recall100FactoryMonthOut),
        )
        for add_val, PrecFac, RecFac in eval_passes:
            for model_key, base_value in precision_values:
                PrecFac(model_rel=models[model_key], value=base_value + add_val)
            for model_key, base_value in recall_values:
                RecFac(model_rel=models[model_key], value=base_value + add_val)

        session.commit()

        distance_table = DistanceFromBestTable(
            db_engine=engine,
            models_table='models',
            distance_table='dist_table',
        )
        metrics = [
            {'metric': 'precision@', 'parameter': '100_abs'},
            {'metric': 'recall@', 'parameter': '100_abs'},
        ]
        model_group_ids = [
            group.model_group_id for group in model_groups.values()
        ]
        distance_table.create_and_populate(
            model_group_ids,
            ['2014-01-01', '2015-01-01', '2016-01-01'],
            metrics,
        )

        # Ordered list of the models/groups for a particular metric/time.
        query = '''
            select
                model_id,
                raw_value,
                dist_from_best_case,
                dist_from_best_case_next_time
            from dist_table
            where metric = %s
                and parameter = %s
                and train_end_time = %s
            order by dist_from_best_case
        '''

        prec_3y_ago = engine.execute(
            query, ('precision@', '100_abs', '2014-01-01')
        )
        expected_prec_3y_ago = [
            (models['spiky_3y_ago'].model_id, 0.8, 0, 0.17),
            (models['stable_3y_ago'].model_id, 0.6, 0.2, 0),
            (models['bad_3y_ago'].model_id, 0.4, 0.4, 0.18),
        ]
        assert list(prec_3y_ago) == expected_prec_3y_ago

        recall_2y_ago = engine.execute(
            query, ('recall@', '100_abs', '2015-01-01')
        )
        expected_recall_2y_ago = [
            (models['spiky_2y_ago'].model_id, 0.8, 0, 0.19),
            (models['stable_2y_ago'].model_id, 0.56, 0.24, 0),
            (models['bad_2y_ago'].model_id, 0.34, 0.46, 0.19),
        ]
        assert list(recall_2y_ago) == expected_recall_2y_ago

        expected_bounds = {
            ('precision@', '100_abs'): (0.39, 0.8),
            ('recall@', '100_abs'): (0.34, 0.8),
        }
        assert distance_table.observed_bounds == expected_bounds
def setup_data(self, engine):
    """Build a hand-written distance table and return a thresholder over it.

    Creates model groups 1 through 5, creates ``dist_table`` through
    ``DistanceFromBestTable`` (with ``agg_type="worst"``), inserts fixed rows
    for model groups 1 through 6 at three train end times, and returns a
    ``ModelGroupThresholder`` restricted to groups 1, 2, 4, 5 and 6 and to the
    2014 and 2015 train end times, filtered by ``self.metric_filters``.
    """
    ensure_db(engine)
    init_engine(engine)
    # Only groups 1-5 get a ModelGroupFactory row; the group 6 distance rows
    # below are inserted without one.
    for group_id in range(1, 6):
        ModelGroupFactory(
            model_group_id=group_id,
            model_type="modelType{}".format(group_id),
        )
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine,
        models_table="models",
        distance_table="dist_table",
        agg_type="worst",
    )
    distance_table._create()

    # Each train end time carries one 4-tuple of numbers per metric, in this
    # order. NOTE(review): the four numbers look like raw value, best case,
    # distance from best case and next-time distance -- confirm against the
    # dist_table schema.
    metric_names = ("precision@", "recall@", "false positives@")
    typical_fp = (40, 30, 10, 10)
    high_fp = (60, 30, 30, 10)
    fixture = (
        (1, (
            # 2014: group 1 is expected to pass both the close and min checks
            ("2014-01-01", (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), typical_fp),
            # 2015: group 1 is expected to fail the close check
            ("2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.5, 0.88, 0.38, 0.0), typical_fp),
            ("2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.46, 0.46, 0.0, 0.11), typical_fp),
        )),
        (2, (
            # 2014: group 2 is expected to fail the min check
            ("2014-01-01", (0.39, 0.5, 0.11, 0.5), (0.5, 0.5, 0.0, 0.38), typical_fp),
            # 2015: group 2 is expected to pass both checks
            ("2015-01-01", (0.69, 0.88, 0.19, 0.12), (0.69, 0.88, 0.19, 0.0), typical_fp),
            ("2016-01-01", (0.34, 0.46, 0.12, 0.11), (0.46, 0.46, 0.0, 0.11), typical_fp),
        )),
        (3, (
            # group 3 is left out of the thresholder's initial ids below
            ("2014-01-01", (0.28, 0.5, 0.22, 0.0), (0.5, 0.5, 0.0, 0.38), typical_fp),
            ("2015-01-01", (0.88, 0.88, 0.0, 0.02), (0.5, 0.88, 0.38, 0.0), typical_fp),
            ("2016-01-01", (0.44, 0.46, 0.02, 0.11), (0.46, 0.46, 0.0, 0.11), typical_fp),
        )),
        (4, (
            # 2014: group 4 is expected to fail every check
            ("2014-01-01", (0.29, 0.5, 0.21, 0.21), (0.5, 0.5, 0.0, 0.38), typical_fp),
            # 2015: group 4 is expected to fail the close check
            ("2015-01-01", (0.67, 0.88, 0.21, 0.21), (0.5, 0.88, 0.38, 0.0), typical_fp),
            ("2016-01-01", (0.25, 0.46, 0.21, 0.21), (0.46, 0.46, 0.0, 0.11), typical_fp),
        )),
        (5, (
            # 2014 and 2015: group 5 has good precision but poor recall, so it
            # is expected to fail
            ("2014-01-01", (0.5, 0.38, 0.0, 0.38), (0.3, 0.5, 0.2, 0.38), typical_fp),
            ("2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.3, 0.88, 0.58, 0.0), typical_fp),
            ("2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.3, 0.46, 0.16, 0.11), typical_fp),
        )),
        (6, (
            # 2014 and 2015: group 6 is expected to be failed by false positives
            ("2014-01-01", (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), high_fp),
            ("2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.5, 0.38, 0.0, 0.38), high_fp),
            ("2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.5, 0.5, 0.0, 0.38), typical_fp),
        )),
    )

    # Flatten to (group id, train end time, metric, parameter, four numbers),
    # ordered by group, then train end time, then metric.
    distance_rows = [
        (group_id, entry[0], metric_name, "100_abs") + numbers
        for group_id, per_end_time in fixture
        for entry in per_end_time
        for metric_name, numbers in zip(metric_names, entry[1:])
    ]

    insert_statement = (
        "insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s)"
    )
    for row in distance_rows:
        engine.execute(insert_statement, row)

    return ModelGroupThresholder(
        distance_from_best_table=distance_table,
        train_end_times=["2014-01-01", "2015-01-01"],
        initial_model_group_ids=[1, 2, 4, 5, 6],
        initial_metric_filters=self.metric_filters,
    )
def test_Auditioner():
    """Exercise Auditioner thresholding, selection-rule grids and config output.

    Against a throwaway PostgreSQL instance, the test:

    1. creates 10 model groups, each with a model at four train end times and
       an evaluation for each of five metric/parameter pairs;
    2. checks that loose filters admit every model group and strict filters
       admit none, through both ``update_metric_filters`` and
       ``set_one_metric_filter``, and that loosening after tightening
       restores the full list;
    3. registers a selection rule grid and checks that the result is a dict
       keyed by the descriptive name of every selection rule;
    4. writes the tyra config to a temporary file and checks that its
       ``selection_rule_model_groups`` keys match the in-memory result.
    """
    with testing.postgresql.Postgresql() as postgresql:
        db_engine = create_engine(postgresql.url())
        ensure_db(db_engine)
        init_engine(db_engine)

        # set up data, randomly generated by the factories but conforming
        # generally to what we expect model_metadata schema data to look like
        num_model_groups = 10
        model_types = [
            'classifier type {}'.format(i) for i in range(num_model_groups)
        ]
        model_groups = [
            ModelGroupFactory(model_type=model_type)
            for model_type in model_types
        ]
        train_end_times = [
            datetime(2013, 1, 1),
            datetime(2014, 1, 1),
            datetime(2015, 1, 1),
            datetime(2016, 1, 1),
        ]
        models = [
            ModelFactory(model_group_rel=model_group,
                         train_end_time=train_end_time)
            for model_group in model_groups
            for train_end_time in train_end_times
        ]
        metrics = [
            ('precision@', '100_abs'),
            ('recall@', '100_abs'),
            ('precision@', '50_abs'),
            ('recall@', '50_abs'),
            ('fpr@', '10_pct'),
        ]

        class ImmediateEvalFactory(EvaluationFactory):
            # Each evaluation starts at its model's train end time.
            evaluation_start_time = factory.LazyAttribute(
                lambda o: o.model_rel.train_end_time)

        for model in models:
            for (metric, parameter) in metrics:
                ImmediateEvalFactory(model_rel=model,
                                     metric=metric,
                                     parameter=parameter)

        session.commit()

        # define a very loose filtering that should admit all model groups
        no_filtering = [{
            'metric': 'precision@',
            'parameter': '100_abs',
            'max_from_best': 1.0,
            'threshold_value': 0.0
        }, {
            'metric': 'recall@',
            'parameter': '100_abs',
            'max_from_best': 1.0,
            'threshold_value': 0.0
        }]
        model_group_ids = [mg.model_group_id for mg in model_groups]
        auditioner = Auditioner(
            db_engine,
            model_group_ids,
            train_end_times,
            no_filtering,
        )
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups
        auditioner.plot_model_groups()

        # here, we pick thresholding rules that should definitely remove
        # all model groups from contention because they are too strict.
        remove_all = [{
            'metric': 'precision@',
            'parameter': '100_abs',
            'max_from_best': 0.0,
            'threshold_value': 1.1
        }, {
            'metric': 'recall@',
            'parameter': '100_abs',
            'max_from_best': 0.0,
            'threshold_value': 1.1
        }]

        auditioner.update_metric_filters(new_filters=remove_all)
        assert len(auditioner.thresholded_model_group_ids) == 0

        # pass the argument instead and remove all model groups
        auditioner.set_one_metric_filter(metric='precision@',
                                         parameter='100_abs',
                                         max_from_best=0.0,
                                         threshold_value=1.1)
        assert len(auditioner.thresholded_model_group_ids) == 0

        # one potential place for bugs would be when we pull back the rules
        # for being too restrictive. we want to make sure that the original
        # list is always used for thresholding, or else such a move would be
        # impossible
        auditioner.update_metric_filters(new_filters=no_filtering)
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups

        # pass the argument instead and let all model groups pass
        auditioner.set_one_metric_filter(metric='precision@',
                                         parameter='100_abs',
                                         max_from_best=1.0,
                                         threshold_value=0.0)
        assert len(auditioner.thresholded_model_group_ids) == num_model_groups

        # now, we want to take this partially thresholded list and run it
        # through a grid of selection rules, meant to pick winners by a
        # variety of user-defined criteria
        rule_grid = [{
            'shared_parameters': [
                {
                    'metric': 'precision@',
                    'parameter': '100_abs'
                },
                {
                    'metric': 'recall@',
                    'parameter': '100_abs'
                },
            ],
            'selection_rules': [{
                'name': 'most_frequent_best_dist',
                'dist_from_best_case': [0.1, 0.2, 0.3],
                'n': 1
            }, {
                'name': 'best_current_value',
                'n': 1
            }]
        }, {
            'shared_parameters': [
                {
                    'metric1': 'precision@',
                    'parameter1': '100_abs'
                },
            ],
            'selection_rules': [
                {
                    'name': 'best_average_two_metrics',
                    'metric2': ['recall@'],
                    'parameter2': ['100_abs'],
                    'metric1_weight': [0.4, 0.5, 0.6],
                    'n': 1
                },
            ]
        }]
        auditioner.register_selection_rule_grid(rule_grid, plot=False)
        final_model_group_ids = auditioner.selection_rule_model_group_ids

        # we expect the result to be a mapping of selection rule name to
        # model group id
        assert isinstance(final_model_group_ids, dict)

        # we expect that there is one winner for each selection rule
        assert sorted(final_model_group_ids.keys()) == \
            sorted([rule.descriptive_name for rule in auditioner.selection_rules])

        # we expect that the results written to the yaml file are the
        # chosen model groups and their rules
        # however because the source data is randomly generated we could have
        # a different list on consecutive runs
        # and don't want to introduce non-determinism to the test
        with tempfile.NamedTemporaryFile() as tf:
            auditioner.write_tyra_config(tf.name)
            # PyYAML 6 requires an explicit Loader, so a bare yaml.load(tf)
            # raises TypeError there. The file was written by
            # write_tyra_config on the line above, so it is trusted, and
            # yaml.Loader keeps the behaviour of the old implicit default.
            # Never use this loader on untrusted input.
            written_config = yaml.load(tf, Loader=yaml.Loader)
            assert sorted(written_config['selection_rule_model_groups'].keys()) == \
                sorted(final_model_group_ids.keys())
def create_sample_distance_table(engine):
    """Create and fill a small ``dist_table`` for two model groups.

    Sets up a "stable" and a "spiky" model group, each with one model at the
    2014, 2015 and 2016 train end times, creates ``dist_table`` through
    ``DistanceFromBestTable`` and inserts hand-written ``precision@`` and
    ``recall@`` rows (parameter ``100_abs``) for every model.

    Returns:
        A ``(distance_table, model_groups)`` tuple, where ``model_groups``
        maps ``"stable"`` and ``"spiky"`` to the created model groups.
    """
    ensure_db(engine)
    init_engine(engine)

    model_groups = {
        "stable": ModelGroupFactory(model_type="myStableClassifier"),
        "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
    }

    class StableModelFactory(ModelFactory):
        model_group_rel = model_groups["stable"]

    class SpikyModelFactory(ModelFactory):
        model_group_rel = model_groups["spiky"]

    # Models are created stable-first, oldest-first.
    ages = (
        ("3y_ago", "2014-01-01"),
        ("2y_ago", "2015-01-01"),
        ("1y_ago", "2016-01-01"),
    )
    factories = (("stable", StableModelFactory), ("spiky", SpikyModelFactory))
    models = {}
    for label, model_factory in factories:
        for age, end_time in ages:
            models[label + "_" + age] = model_factory(train_end_time=end_time)
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine,
        models_table="models",
        distance_table="dist_table",
    )
    distance_table._create()

    # Leading three columns of every row for a given model:
    # (model group id, model id, train end time).
    group_ids = {
        label: group.model_group_id for label, group in model_groups.items()
    }
    row_prefix = {
        key: (group_ids[key.split("_", 1)[0]], model.model_id, model.train_end_time)
        for key, model in models.items()
    }

    # (model key, metric, parameter, five numeric columns).
    # NOTE(review): the numeric columns presumably follow the dist_table
    # schema (raw value, best case, distances, ...) -- confirm against
    # DistanceFromBestTable.
    measurements = [
        ("stable_3y_ago", "precision@", "100_abs", 0.5, 0.6, 0.1, 0.5, 0.15),
        ("stable_2y_ago", "precision@", "100_abs", 0.5, 0.84, 0.34, 0.5, 0.18),
        ("stable_1y_ago", "precision@", "100_abs", 0.46, 0.67, 0.21, 0.5, 0.11),
        ("spiky_3y_ago", "precision@", "100_abs", 0.45, 0.6, 0.15, 0.5, 0.19),
        ("spiky_2y_ago", "precision@", "100_abs", 0.84, 0.84, 0.0, 0.5, 0.3),
        ("spiky_1y_ago", "precision@", "100_abs", 0.45, 0.67, 0.22, 0.5, 0.12),
        ("stable_3y_ago", "recall@", "100_abs", 0.4, 0.4, 0.0, 0.4, 0.0),
        ("stable_2y_ago", "recall@", "100_abs", 0.5, 0.5, 0.0, 0.5, 0.0),
        ("stable_1y_ago", "recall@", "100_abs", 0.6, 0.6, 0.0, 0.6, 0.0),
        ("spiky_3y_ago", "recall@", "100_abs", 0.65, 0.65, 0.0, 0.65, 0.0),
        ("spiky_2y_ago", "recall@", "100_abs", 0.55, 0.55, 0.0, 0.55, 0.0),
        ("spiky_1y_ago", "recall@", "100_abs", 0.45, 0.45, 0.0, 0.45, 0.0),
    ]
    distance_rows = [
        row_prefix[measurement[0]] + measurement[1:]
        for measurement in measurements
    ]

    insert_statement = (
        "insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
    )
    for row in distance_rows:
        engine.execute(insert_statement, row)

    return distance_table, model_groups