def test_save_classifier(db, datarun, model, metrics):
    """Check that ``Worker.save_classifier`` persists a reloadable model and metrics.

    A classifier is started on the datarun's first hyperpartition. The worker
    saves it with ``complete_classifier`` replaced by a ``Mock``. The stored
    model and metrics are then read back from ``MODEL_DIR`` / ``METRIC_DIR``
    and compared with the originals.

    NOTE(review): another ``test_save_classifier`` is defined later in this
    file; at module level that later definition replaces this one, so pytest
    will only collect the later one. Confirm which version is meant to stay.
    """
    test_worker = Worker(db, datarun,
                         models_dir=MODEL_DIR, metrics_dir=METRIC_DIR)
    run_id = test_worker.datarun.id
    first_hp = db.get_hyperpartitions(datarun_id=run_id)[0]

    started = test_worker.db.start_classifier(
        hyperpartition_id=first_hp.id,
        datarun_id=run_id,
        host='localhost',
        hyperparameter_values=DT_PARAMS,
    )

    # Stub out completion so only the file-saving path of save_classifier runs.
    test_worker.db.complete_classifier = Mock()
    test_worker.save_classifier(started.id, model, metrics)
    test_worker.db.complete_classifier.assert_called()

    # Read everything back inside one DB session.
    with DBSession(test_worker.db):
        stored = db.get_classifier(started.id)

        reloaded = load_model(stored, MODEL_DIR)
        assert isinstance(reloaded, Model)
        assert reloaded.method == model.method
        assert reloaded.random_state == model.random_state

        assert load_metrics(stored, METRIC_DIR) == metrics
def test_save_classifier(db, datarun, model, metrics):
    """Check that ``Worker.save_classifier`` writes a reloadable model and metrics.

    The worker is configured through a ``LogConfig`` that points at
    ``MODEL_DIR`` and ``METRIC_DIR``. A classifier is started on the
    datarun's first hyperpartition and saved with ``complete_classifier``
    replaced by a ``Mock``. The test then reloads the model and metrics from
    those directories and compares them with the inputs.
    """
    log_conf = LogConfig(model_dir=MODEL_DIR, metric_dir=METRIC_DIR)
    worker = Worker(db, datarun, log_config=log_conf)

    hp = db.get_hyperpartitions(datarun_id=worker.datarun.id)[0]
    classifier = worker.db.start_classifier(
        hyperpartition_id=hp.id,
        datarun_id=worker.datarun.id,
        host='localhost',
        hyperparameter_values=DT_PARAMS,
    )

    # Replace the DB completion call with a Mock so the test only exercises
    # the save path, while still checking that completion was requested.
    worker.db.complete_classifier = Mock()
    worker.save_classifier(classifier.id, model, metrics)
    worker.db.complete_classifier.assert_called()

    with db_session(worker.db):
        clf = db.get_classifier(classifier.id)

        loaded = load_model(clf, MODEL_DIR)
        # Use isinstance rather than `type(...) == ...`, the idiomatic type check.
        assert isinstance(loaded, Model)
        assert loaded.method == model.method
        assert loaded.random_state == model.random_state

        assert load_metrics(clf, METRIC_DIR) == metrics
help='Only train on dataruns with these ids', nargs='+')
# NOTE(review): the line above closes a parser.add_argument(...) call whose
# opening lies outside this view.

# Limit how long the worker runs ("Number of seconds to run worker").
parser.add_argument('--time', help='Number of seconds to run worker', type=int)
parser.add_argument(
    '--choose-randomly',
    action='store_true',
    help='Choose dataruns to work on randomly (default = sequential order)')
# Equivalent to action='store_false': args.save_files stays True unless
# --no-save is passed, in which case it becomes False.
parser.add_argument('--no-save', dest='save_files', default=True,
                    action='store_const', const=False,
                    help="don't save models and metrics at all")

# Parse arguments and load configuration.
args = parser.parse_args()
# load_config's result is unpacked into four values. Only sql_config is used
# below; the second value is discarded, and aws_config / log_config are not
# used in the visible code.
sql_config, _, aws_config, log_config = load_config(**vars(args))

db = Database(**vars(sql_config))
with db_session(db):
    # Keep a database session open while the classifiers are queried and
    # their metrics are loaded.

    # Get all classifiers (get_classifiers is called with no filters)...
    classifiers = db.get_classifiers()
    # ...or, alternatively, fetch a single classifier by its ID:
    # classifier = db.get_classifier(classifier_id)

    print("total {} classifiers".format(len(classifiers)))
    for classifier in classifiers:
        # NOTE(review): the metrics directory is hard-coded to "./metrics"
        # instead of coming from the log_config loaded above; confirm this
        # is intended.
        metrics = load_metrics(classifier, metric_dir="./metrics")
        print(metrics)