def create_sample_distance_table(engine):
    """Seed a database with two model groups and a hand-filled ``dist_table``.

    A 'stable' and a 'spiky' model group each get three models (train end
    times 2014-01-01, 2015-01-01 and 2016-01-01). The distance table is
    created through ``DistanceFromBestTable._create`` and then filled directly
    with fixed precision@ and recall@ (``100_abs``) rows, one per model and
    metric, rather than being computed.

    Args:
        engine: SQLAlchemy engine connected to the test database.

    Returns:
        tuple: ``(distance_table, model_groups)`` where ``model_groups`` maps
        ``'stable'`` and ``'spiky'`` to the created model group objects.
    """
    ensure_db(engine)
    init_engine(engine)

    model_groups = {
        'stable': ModelGroupFactory(model_type='myStableClassifier'),
        'spiky': ModelGroupFactory(model_type='mySpikeClassifier'),
    }

    class StableModelFactory(ModelFactory):
        model_group_rel = model_groups['stable']

    class SpikyModelFactory(ModelFactory):
        model_group_rel = model_groups['spiky']

    # Models are created stable-first, oldest-first; keys look like
    # 'stable_3y_ago' / 'spiky_1y_ago'.
    models = {}
    for group_name, model_factory in (('stable', StableModelFactory),
                                      ('spiky', SpikyModelFactory)):
        for years_ago, end_time in ((3, '2014-01-01'),
                                    (2, '2015-01-01'),
                                    (1, '2016-01-01')):
            models['%s_%dy_ago' % (group_name, years_ago)] = model_factory(
                train_end_time=end_time)
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine, models_table='models', distance_table='dist_table')
    distance_table._create()

    # Read every generated id before any row is inserted.
    group_ids = {name: group.model_group_id
                 for name, group in model_groups.items()}
    model_refs = {key: (model.model_id, model.train_end_time)
                  for key, model in models.items()}

    # (model key, metric, values for the five trailing numeric columns of
    # dist_table -- see the DistanceFromBestTable schema for their order)
    samples = [
        ('stable_3y_ago', 'precision@', (0.5, 0.6, 0.1, 0.5, 0.15)),
        ('stable_2y_ago', 'precision@', (0.5, 0.84, 0.34, 0.5, 0.18)),
        ('stable_1y_ago', 'precision@', (0.46, 0.67, 0.21, 0.5, 0.11)),
        ('spiky_3y_ago', 'precision@', (0.45, 0.6, 0.15, 0.5, 0.19)),
        ('spiky_2y_ago', 'precision@', (0.84, 0.84, 0.0, 0.5, 0.3)),
        ('spiky_1y_ago', 'precision@', (0.45, 0.67, 0.22, 0.5, 0.12)),
        ('stable_3y_ago', 'recall@', (0.4, 0.4, 0.0, 0.4, 0.0)),
        ('stable_2y_ago', 'recall@', (0.5, 0.5, 0.0, 0.5, 0.0)),
        ('stable_1y_ago', 'recall@', (0.6, 0.6, 0.0, 0.6, 0.0)),
        ('spiky_3y_ago', 'recall@', (0.65, 0.65, 0.0, 0.65, 0.0)),
        ('spiky_2y_ago', 'recall@', (0.55, 0.55, 0.0, 0.55, 0.0)),
        ('spiky_1y_ago', 'recall@', (0.45, 0.45, 0.0, 0.45, 0.0)),
    ]
    insert_sql = (
        'insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)')
    for model_key, metric, numbers in samples:
        model_id, train_end_time = model_refs[model_key]
        # The key prefix ('stable' / 'spiky') names the owning model group.
        group_id = group_ids[model_key.split('_', 1)[0]]
        engine.execute(
            insert_sql,
            (group_id, model_id, train_end_time, metric, '100_abs') + numbers)

    return distance_table, model_groups
def test_DistanceFromBestTable():
    """Verify DistanceFromBestTable.create_and_populate on a small fixture.

    Three model groups ("stable", "bad", "spiky") get one model per train end
    time (2014-01-01, 2015-01-01, 2016-01-01). Every model receives precision@
    and recall@ (``100_abs``) evaluations twice:
    - one set whose evaluation window starts at the train end time;
    - one set starting 31 days later, with every value lowered by 0.15.

    The populated distance table is compared against hand-computed rows for
    two metric/time slices, and ``observed_bounds`` is checked.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        try:
            ensure_db(engine)
            init_engine(engine)
            model_groups = {
                "stable": ModelGroupFactory(model_type="myStableClassifier"),
                "bad": ModelGroupFactory(model_type="myBadClassifier"),
                "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
            }

            class StableModelFactory(ModelFactory):
                model_group_rel = model_groups["stable"]

            class BadModelFactory(ModelFactory):
                model_group_rel = model_groups["bad"]

            class SpikyModelFactory(ModelFactory):
                model_group_rel = model_groups["spiky"]

            models = {
                "stable_3y_ago": StableModelFactory(train_end_time="2014-01-01"),
                "stable_2y_ago": StableModelFactory(train_end_time="2015-01-01"),
                "stable_1y_ago": StableModelFactory(train_end_time="2016-01-01"),
                "bad_3y_ago": BadModelFactory(train_end_time="2014-01-01"),
                "bad_2y_ago": BadModelFactory(train_end_time="2015-01-01"),
                "bad_1y_ago": BadModelFactory(train_end_time="2016-01-01"),
                "spiky_3y_ago": SpikyModelFactory(train_end_time="2014-01-01"),
                "spiky_2y_ago": SpikyModelFactory(train_end_time="2015-01-01"),
                "spiky_1y_ago": SpikyModelFactory(train_end_time="2016-01-01"),
            }

            class ImmediateEvalFactory(EvaluationFactory):
                # Evaluation window: [train_end_time, train_end_time + 1 day]
                evaluation_start_time = factory.LazyAttribute(
                    lambda o: o.model_rel.train_end_time)
                evaluation_end_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 1))

            class MonthOutEvalFactory(EvaluationFactory):
                # Evaluation window: 31-32 days after train_end_time
                evaluation_start_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 31))
                evaluation_end_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 32))

            class Precision100Factory(ImmediateEvalFactory):
                metric = "precision@"
                parameter = "100_abs"

            class Precision100FactoryMonthOut(MonthOutEvalFactory):
                metric = "precision@"
                parameter = "100_abs"

            class Recall100Factory(ImmediateEvalFactory):
                metric = "recall@"
                parameter = "100_abs"

            class Recall100FactoryMonthOut(MonthOutEvalFactory):
                metric = "recall@"
                parameter = "100_abs"

            # (model key, precision@ value, recall@ value) before the
            # per-window offset is applied.
            metric_values = [
                ("stable_3y_ago", 0.6, 0.55),
                ("stable_2y_ago", 0.57, 0.56),
                ("stable_1y_ago", 0.59, 0.55),
                ("bad_3y_ago", 0.4, 0.35),
                ("bad_2y_ago", 0.39, 0.34),
                ("bad_1y_ago", 0.43, 0.36),
                ("spiky_3y_ago", 0.8, 0.35),
                ("spiky_2y_ago", 0.4, 0.8),
                ("spiky_1y_ago", 0.4, 0.36),
            ]
            for (add_val, PrecFac, RecFac) in (
                (0, Precision100Factory, Recall100Factory),
                (-0.15, Precision100FactoryMonthOut, Recall100FactoryMonthOut),
            ):
                # All precision evaluations first, then all recall ones.
                for model_key, precision, _ in metric_values:
                    PrecFac(model_rel=models[model_key], value=precision + add_val)
                for model_key, _, recall in metric_values:
                    RecFac(model_rel=models[model_key], value=recall + add_val)
            session.commit()

            distance_table = DistanceFromBestTable(
                db_engine=engine, models_table="models", distance_table="dist_table")
            metrics = [
                {"metric": "precision@", "parameter": "100_abs"},
                {"metric": "recall@", "parameter": "100_abs"},
            ]
            model_group_ids = [mg.model_group_id for mg in model_groups.values()]
            distance_table.create_and_populate(
                model_group_ids, ["2014-01-01", "2015-01-01", "2016-01-01"], metrics)

            # get an ordered list of the models/groups for a particular metric/time
            query = """
                select model_id, raw_value, dist_from_best_case, dist_from_best_case_next_time
                from dist_table
                where metric = %s
                and parameter = %s
                and train_end_time = %s
                order by dist_from_best_case
            """

            prec_3y_ago = engine.execute(query, ("precision@", "100_abs", "2014-01-01"))
            assert list(prec_3y_ago) == [
                (models["spiky_3y_ago"].model_id, 0.8, 0, 0.17),
                (models["stable_3y_ago"].model_id, 0.6, 0.2, 0),
                (models["bad_3y_ago"].model_id, 0.4, 0.4, 0.18),
            ]

            recall_2y_ago = engine.execute(query, ("recall@", "100_abs", "2015-01-01"))
            assert list(recall_2y_ago) == [
                (models["spiky_2y_ago"].model_id, 0.8, 0, 0.19),
                (models["stable_2y_ago"].model_id, 0.56, 0.24, 0),
                (models["bad_2y_ago"].model_id, 0.34, 0.46, 0.19),
            ]

            assert distance_table.observed_bounds == {
                ("precision@", "100_abs"): (0.39, 0.8),
                ("recall@", "100_abs"): (0.34, 0.8),
            }
        finally:
            # Close pooled connections before the temporary Postgres server
            # is stopped when the ``with`` block exits.
            engine.dispose()
def setup_data(self, engine):
    """Seed model groups and a hand-filled ``dist_table``; return a thresholder.

    Model groups 1-5 are created through ``ModelGroupFactory``. The distance
    table is created with ``DistanceFromBestTable._create`` and filled
    directly. Every listed model gets one precision@, one recall@ and one
    false positives@ row (parameter ``100_abs``), inserted in that order.

    Args:
        engine: SQLAlchemy engine connected to the test database.

    Returns:
        ModelGroupThresholder over train end times 2014-01-01 and 2015-01-01,
        starting from model groups [1, 2, 4, 5, 6] and ``self.metric_filters``.
    """
    ensure_db(engine)
    init_engine(engine)
    for gid in range(1, 6):
        ModelGroupFactory(model_group_id=gid, model_type='modelType%d' % gid)
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine, models_table='models', distance_table='dist_table')
    distance_table._create()

    # Each spec expands to three rows, one per metric. The four numbers per
    # metric fill the trailing dist_table columns positionally (presumably
    # raw value, best case, distance from best, distance next time -- confirm
    # against the DistanceFromBestTable schema).
    usual_fp = (40, 30, 10, 10)
    specs = [
        # (model_group_id, model_id, train_end_time, precision@, recall@, false positives@)
        # 2014: model group 1 should pass both close and min checks
        (1, 1, '2014-01-01', (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 1 should not pass close check
        (1, 2, '2015-01-01', (0.5, 0.88, 0.38, 0.0), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (1, 3, '2016-01-01', (0.46, 0.46, 0.0, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 2 should not pass min check
        (2, 4, '2014-01-01', (0.39, 0.5, 0.11, 0.5), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 2 should pass both checks
        (2, 5, '2015-01-01', (0.69, 0.88, 0.19, 0.12), (0.69, 0.88, 0.19, 0.0), usual_fp),
        (2, 6, '2016-01-01', (0.34, 0.46, 0.12, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # model group 3 not included in this round
        (3, 7, '2014-01-01', (0.28, 0.5, 0.22, 0.0), (0.5, 0.5, 0.0, 0.38), usual_fp),
        (3, 8, '2015-01-01', (0.88, 0.88, 0.0, 0.02), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (3, 9, '2016-01-01', (0.44, 0.46, 0.02, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 4 should not pass any checks
        (4, 10, '2014-01-01', (0.29, 0.5, 0.21, 0.21), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 4 should not pass close check
        (4, 11, '2015-01-01', (0.67, 0.88, 0.21, 0.21), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (4, 12, '2016-01-01', (0.25, 0.46, 0.21, 0.21), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 5 should not pass because precision is good but not recall
        # NOTE(review): precision's second value (0.38) is below its first (0.5),
        # unlike every other row -- possibly a typo; confirm.
        (5, 13, '2014-01-01', (0.5, 0.38, 0.0, 0.38), (0.3, 0.5, 0.2, 0.38), usual_fp),
        # 2015: model group 5 should not pass because precision is good but not recall
        (5, 14, '2015-01-01', (0.5, 0.88, 0.38, 0.0), (0.3, 0.88, 0.58, 0.0), usual_fp),
        (5, 15, '2016-01-01', (0.46, 0.46, 0.0, 0.11), (0.3, 0.46, 0.16, 0.11), usual_fp),
        # 2014: model group 6 is failed by false positives
        (6, 16, '2014-01-01', (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), (60, 30, 30, 10)),
        # 2015: model group 6 is failed by false positives
        # NOTE(review): recall's second value (0.38) is below its first (0.5); confirm.
        (6, 17, '2015-01-01', (0.5, 0.88, 0.38, 0.0), (0.5, 0.38, 0.0, 0.38), (60, 30, 30, 10)),
        (6, 18, '2016-01-01', (0.46, 0.46, 0.0, 0.11), (0.5, 0.5, 0.0, 0.38), usual_fp),
    ]
    insert_sql = 'insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s, %s)'
    for group_id, model_id, end_time, precision, recall, false_pos in specs:
        for metric, values in (('precision@', precision),
                               ('recall@', recall),
                               ('false positives@', false_pos)):
            engine.execute(
                insert_sql,
                (group_id, model_id, end_time, metric, '100_abs') + values)

    # NOTE(review): model group 6 has dist_table rows and is listed in
    # initial_model_group_ids, but no ModelGroupFactory row is created for it
    # above -- confirm this is intended.
    return ModelGroupThresholder(
        distance_from_best_table=distance_table,
        train_end_times=['2014-01-01', '2015-01-01'],
        initial_model_group_ids=[1, 2, 4, 5, 6],
        initial_metric_filters=self.metric_filters)
def test_DistanceFromBestTable():
    """Verify DistanceFromBestTable.create_and_populate on a small fixture.

    Three model groups ('stable', 'bad', 'spiky') get one model per train end
    time (2014-01-01, 2015-01-01, 2016-01-01). Every model receives precision@
    and recall@ (``100_abs``) evaluations twice:
    - one set whose evaluation window starts at the train end time;
    - one set starting 31 days later, with every value lowered by 0.15.

    The populated distance table is compared against hand-computed rows for
    two metric/time slices, and ``observed_bounds`` is checked.
    """
    with testing.postgresql.Postgresql() as postgresql:
        engine = create_engine(postgresql.url())
        try:
            ensure_db(engine)
            init_engine(engine)
            model_groups = {
                'stable': ModelGroupFactory(model_type='myStableClassifier'),
                'bad': ModelGroupFactory(model_type='myBadClassifier'),
                'spiky': ModelGroupFactory(model_type='mySpikeClassifier'),
            }

            class StableModelFactory(ModelFactory):
                model_group_rel = model_groups['stable']

            class BadModelFactory(ModelFactory):
                model_group_rel = model_groups['bad']

            class SpikyModelFactory(ModelFactory):
                model_group_rel = model_groups['spiky']

            models = {
                'stable_3y_ago': StableModelFactory(train_end_time='2014-01-01'),
                'stable_2y_ago': StableModelFactory(train_end_time='2015-01-01'),
                'stable_1y_ago': StableModelFactory(train_end_time='2016-01-01'),
                'bad_3y_ago': BadModelFactory(train_end_time='2014-01-01'),
                'bad_2y_ago': BadModelFactory(train_end_time='2015-01-01'),
                'bad_1y_ago': BadModelFactory(train_end_time='2016-01-01'),
                'spiky_3y_ago': SpikyModelFactory(train_end_time='2014-01-01'),
                'spiky_2y_ago': SpikyModelFactory(train_end_time='2015-01-01'),
                'spiky_1y_ago': SpikyModelFactory(train_end_time='2016-01-01'),
            }

            class ImmediateEvalFactory(EvaluationFactory):
                # Evaluation window: [train_end_time, train_end_time + 1 day]
                evaluation_start_time = factory.LazyAttribute(
                    lambda o: o.model_rel.train_end_time)
                evaluation_end_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 1))

            class MonthOutEvalFactory(EvaluationFactory):
                # Evaluation window: 31-32 days after train_end_time
                evaluation_start_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 31))
                evaluation_end_time = factory.LazyAttribute(
                    lambda o: _sql_add_days(o.model_rel.train_end_time, 32))

            class Precision100Factory(ImmediateEvalFactory):
                metric = 'precision@'
                parameter = '100_abs'

            class Precision100FactoryMonthOut(MonthOutEvalFactory):
                metric = 'precision@'
                parameter = '100_abs'

            class Recall100Factory(ImmediateEvalFactory):
                metric = 'recall@'
                parameter = '100_abs'

            class Recall100FactoryMonthOut(MonthOutEvalFactory):
                metric = 'recall@'
                parameter = '100_abs'

            # (model key, precision@ value, recall@ value) before the
            # per-window offset is applied.
            metric_values = [
                ('stable_3y_ago', 0.6, 0.55),
                ('stable_2y_ago', 0.57, 0.56),
                ('stable_1y_ago', 0.59, 0.55),
                ('bad_3y_ago', 0.4, 0.35),
                ('bad_2y_ago', 0.39, 0.34),
                ('bad_1y_ago', 0.43, 0.36),
                ('spiky_3y_ago', 0.8, 0.35),
                ('spiky_2y_ago', 0.4, 0.8),
                ('spiky_1y_ago', 0.4, 0.36),
            ]
            for (add_val, PrecFac, RecFac) in (
                    (0, Precision100Factory, Recall100Factory),
                    (-0.15, Precision100FactoryMonthOut, Recall100FactoryMonthOut)):
                # All precision evaluations first, then all recall ones.
                for model_key, precision, _ in metric_values:
                    PrecFac(model_rel=models[model_key], value=precision + add_val)
                for model_key, _, recall in metric_values:
                    RecFac(model_rel=models[model_key], value=recall + add_val)
            session.commit()

            distance_table = DistanceFromBestTable(
                db_engine=engine, models_table='models', distance_table='dist_table')
            metrics = [
                {'metric': 'precision@', 'parameter': '100_abs'},
                {'metric': 'recall@', 'parameter': '100_abs'},
            ]
            model_group_ids = [mg.model_group_id for mg in model_groups.values()]
            distance_table.create_and_populate(
                model_group_ids, ['2014-01-01', '2015-01-01', '2016-01-01'], metrics)

            # get an ordered list of the models/groups for a particular metric/time
            query = '''
                select model_id, raw_value, dist_from_best_case, dist_from_best_case_next_time
                from dist_table
                where metric = %s
                and parameter = %s
                and train_end_time = %s
                order by dist_from_best_case
            '''

            prec_3y_ago = engine.execute(query, ('precision@', '100_abs', '2014-01-01'))
            assert list(prec_3y_ago) == [
                (models['spiky_3y_ago'].model_id, 0.8, 0, 0.17),
                (models['stable_3y_ago'].model_id, 0.6, 0.2, 0),
                (models['bad_3y_ago'].model_id, 0.4, 0.4, 0.18),
            ]

            recall_2y_ago = engine.execute(query, ('recall@', '100_abs', '2015-01-01'))
            assert list(recall_2y_ago) == [
                (models['spiky_2y_ago'].model_id, 0.8, 0, 0.19),
                (models['stable_2y_ago'].model_id, 0.56, 0.24, 0),
                (models['bad_2y_ago'].model_id, 0.34, 0.46, 0.19),
            ]

            assert distance_table.observed_bounds == {
                ('precision@', '100_abs'): (0.39, 0.8),
                ('recall@', '100_abs'): (0.34, 0.8),
            }
        finally:
            # Close pooled connections before the temporary Postgres server
            # is stopped when the ``with`` block exits.
            engine.dispose()
def setup_data(self, engine):
    """Seed model groups and a hand-filled "worst"-aggregated ``dist_table``.

    Model groups 1-5 are created through ``ModelGroupFactory``. The distance
    table is built with ``agg_type="worst"``, created via ``_create`` and filled
    directly. Rows are keyed by model group (no model id column). Every listed
    group/time gets one precision@, one recall@ and one false positives@ row
    (parameter ``100_abs``), inserted in that order.

    Args:
        engine: SQLAlchemy engine connected to the test database.

    Returns:
        ModelGroupThresholder over train end times 2014-01-01 and 2015-01-01,
        starting from model groups [1, 2, 4, 5, 6] and ``self.metric_filters``.
    """
    ensure_db(engine)
    init_engine(engine)
    for gid in range(1, 6):
        ModelGroupFactory(model_group_id=gid, model_type="modelType%d" % gid)
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine,
        models_table="models",
        distance_table="dist_table",
        agg_type="worst",
    )
    distance_table._create()

    # Each spec expands to three rows, one per metric. The four numbers per
    # metric fill the trailing dist_table columns positionally (presumably
    # raw value, best case, distance from best, distance next time -- confirm
    # against the DistanceFromBestTable schema).
    usual_fp = (40, 30, 10, 10)
    specs = [
        # (model_group_id, train_end_time, precision@, recall@, false positives@)
        # 2014: model group 1 should pass both close and min checks
        (1, "2014-01-01", (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 1 should not pass close check
        (1, "2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (1, "2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 2 should not pass min check
        (2, "2014-01-01", (0.39, 0.5, 0.11, 0.5), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 2 should pass both checks
        (2, "2015-01-01", (0.69, 0.88, 0.19, 0.12), (0.69, 0.88, 0.19, 0.0), usual_fp),
        (2, "2016-01-01", (0.34, 0.46, 0.12, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # model group 3 not included in this round
        (3, "2014-01-01", (0.28, 0.5, 0.22, 0.0), (0.5, 0.5, 0.0, 0.38), usual_fp),
        (3, "2015-01-01", (0.88, 0.88, 0.0, 0.02), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (3, "2016-01-01", (0.44, 0.46, 0.02, 0.11), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 4 should not pass any checks
        (4, "2014-01-01", (0.29, 0.5, 0.21, 0.21), (0.5, 0.5, 0.0, 0.38), usual_fp),
        # 2015: model group 4 should not pass close check
        (4, "2015-01-01", (0.67, 0.88, 0.21, 0.21), (0.5, 0.88, 0.38, 0.0), usual_fp),
        (4, "2016-01-01", (0.25, 0.46, 0.21, 0.21), (0.46, 0.46, 0.0, 0.11), usual_fp),
        # 2014: model group 5 should not pass because precision is good but not recall
        # NOTE(review): precision's second value (0.38) is below its first (0.5),
        # unlike every other row -- possibly a typo; confirm.
        (5, "2014-01-01", (0.5, 0.38, 0.0, 0.38), (0.3, 0.5, 0.2, 0.38), usual_fp),
        # 2015: model group 5 should not pass because precision is good but not recall
        (5, "2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.3, 0.88, 0.58, 0.0), usual_fp),
        (5, "2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.3, 0.46, 0.16, 0.11), usual_fp),
        # 2014: model group 6 is failed by false positives
        (6, "2014-01-01", (0.5, 0.5, 0.0, 0.38), (0.5, 0.5, 0.0, 0.38), (60, 30, 30, 10)),
        # 2015: model group 6 is failed by false positives
        # NOTE(review): recall's second value (0.38) is below its first (0.5); confirm.
        (6, "2015-01-01", (0.5, 0.88, 0.38, 0.0), (0.5, 0.38, 0.0, 0.38), (60, 30, 30, 10)),
        (6, "2016-01-01", (0.46, 0.46, 0.0, 0.11), (0.5, 0.5, 0.0, 0.38), usual_fp),
    ]
    insert_sql = "insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s)"
    for group_id, end_time, precision, recall, false_pos in specs:
        for metric, values in (
            ("precision@", precision),
            ("recall@", recall),
            ("false positives@", false_pos),
        ):
            engine.execute(
                insert_sql,
                (group_id, end_time, metric, "100_abs") + values,
            )

    # NOTE(review): model group 6 has dist_table rows and is listed in
    # initial_model_group_ids, but no ModelGroupFactory row is created for it
    # above -- confirm this is intended.
    return ModelGroupThresholder(
        distance_from_best_table=distance_table,
        train_end_times=["2014-01-01", "2015-01-01"],
        initial_model_group_ids=[1, 2, 4, 5, 6],
        initial_metric_filters=self.metric_filters,
    )
def create_sample_distance_table(engine):
    """Seed a database with two model groups and a hand-filled ``dist_table``.

    A "stable" and a "spiky" model group each get three models (train end
    times 2014-01-01, 2015-01-01 and 2016-01-01). The distance table is
    created through ``DistanceFromBestTable._create`` and then filled directly
    with fixed precision@ and recall@ (``100_abs``) rows, one per model and
    metric, rather than being computed.

    Args:
        engine: SQLAlchemy engine connected to the test database.

    Returns:
        tuple: ``(distance_table, model_groups)`` where ``model_groups`` maps
        ``"stable"`` and ``"spiky"`` to the created model group objects.
    """
    ensure_db(engine)
    init_engine(engine)

    model_groups = {
        "stable": ModelGroupFactory(model_type="myStableClassifier"),
        "spiky": ModelGroupFactory(model_type="mySpikeClassifier"),
    }

    class StableModelFactory(ModelFactory):
        model_group_rel = model_groups["stable"]

    class SpikyModelFactory(ModelFactory):
        model_group_rel = model_groups["spiky"]

    # Models are created stable-first, oldest-first; keys look like
    # "stable_3y_ago" / "spiky_1y_ago".
    models = {}
    for group_name, model_factory in (
        ("stable", StableModelFactory),
        ("spiky", SpikyModelFactory),
    ):
        for years_ago, end_time in (
            (3, "2014-01-01"),
            (2, "2015-01-01"),
            (1, "2016-01-01"),
        ):
            models["%s_%dy_ago" % (group_name, years_ago)] = model_factory(
                train_end_time=end_time
            )
    session.commit()

    distance_table = DistanceFromBestTable(
        db_engine=engine, models_table="models", distance_table="dist_table"
    )
    distance_table._create()

    # Read every generated id before any row is inserted.
    group_ids = {name: group.model_group_id for name, group in model_groups.items()}
    model_refs = {
        key: (model.model_id, model.train_end_time) for key, model in models.items()
    }

    # (model key, metric, values for the five trailing numeric columns of
    # dist_table -- see the DistanceFromBestTable schema for their order)
    samples = [
        ("stable_3y_ago", "precision@", (0.5, 0.6, 0.1, 0.5, 0.15)),
        ("stable_2y_ago", "precision@", (0.5, 0.84, 0.34, 0.5, 0.18)),
        ("stable_1y_ago", "precision@", (0.46, 0.67, 0.21, 0.5, 0.11)),
        ("spiky_3y_ago", "precision@", (0.45, 0.6, 0.15, 0.5, 0.19)),
        ("spiky_2y_ago", "precision@", (0.84, 0.84, 0.0, 0.5, 0.3)),
        ("spiky_1y_ago", "precision@", (0.45, 0.67, 0.22, 0.5, 0.12)),
        ("stable_3y_ago", "recall@", (0.4, 0.4, 0.0, 0.4, 0.0)),
        ("stable_2y_ago", "recall@", (0.5, 0.5, 0.0, 0.5, 0.0)),
        ("stable_1y_ago", "recall@", (0.6, 0.6, 0.0, 0.6, 0.0)),
        ("spiky_3y_ago", "recall@", (0.65, 0.65, 0.0, 0.65, 0.0)),
        ("spiky_2y_ago", "recall@", (0.55, 0.55, 0.0, 0.55, 0.0)),
        ("spiky_1y_ago", "recall@", (0.45, 0.45, 0.0, 0.45, 0.0)),
    ]
    insert_sql = (
        "insert into dist_table values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"
    )
    for model_key, metric, numbers in samples:
        model_id, train_end_time = model_refs[model_key]
        # The key prefix ("stable" / "spiky") names the owning model group.
        group_id = group_ids[model_key.split("_", 1)[0]]
        engine.execute(
            insert_sql,
            (group_id, model_id, train_end_time, metric, "100_abs") + numbers,
        )

    return distance_table, model_groups