예제 #1
0
class PastExperiment(object):
    """Read-only view of one stored experiment run and its metrics.

    On construction it connects through ``PyMongoDataAccess`` (collection
    ``'runs'``), loads the run document identified by ``experiment_id`` and
    caches its ``info`` section and ``info['metrics']`` entries.
    """

    def __init__(self, database_name, experiment_id, uri=None):
        """Connect to the database and load the run ``experiment_id``.

        Args:
            database_name: Name of the database holding the runs.
            experiment_id: Identifier passed to the run DAO's ``get``.
            uri: Optional connection URI forwarded to ``PyMongoDataAccess``.
        """
        self.database_name = database_name
        self.experiment_id = experiment_id
        self.uri = uri
        self.collection_name = 'runs'

        self.generic_dao = PyMongoDataAccess(
            uri=self.uri,
            database_name=self.database_name,
            collection_name=self.collection_name)
        # NOTE(review): the connection is opened here; nothing in this class
        # ever closes it.
        self.generic_dao.connect()

        self.metrics_dao = self.generic_dao.get_metrics_dao()
        self.run_dao = self.generic_dao.get_run_dao()

        self.run = self.run_dao.get(experiment_id)
        # A run without an 'info' section previously crashed on the next line
        # ('NoneType' has no attribute 'get'); treat it as empty instead.
        self.info = self.run.get('info') or {}
        # List of dicts carrying at least 'name' and 'id' (see get_metric);
        # None when the run recorded no metrics.
        self.metrics_info = self.info.get('metrics')

    def get_metric(self, metric_name):
        """Return the stored metric named ``metric_name`` for this run.

        The metric's id is looked up in ``self.metrics_info`` and the metric
        itself is fetched with ``self.metrics_dao.get(experiment_id, id)``.

        Raises:
            AttributeError: if the run has no metric with that name,
                including when the run recorded no metrics at all.
        """
        # metrics_info is None when the run has no metrics; iterating None
        # would raise an unrelated TypeError, so treat it as empty.
        for metric_info in self.metrics_info or ():
            if metric_info.get('name') == metric_name:
                metric_id = metric_info.get('id')
                break
        else:
            raise AttributeError(f"Can't find metric {metric_name}")

        return self.metrics_dao.get(self.experiment_id, metric_id)

    def get_config(self):
        """Return the run's ``config`` entry (KeyError if it is absent)."""
        return self.run['config']
예제 #2
0
def test_get_runs_order(db_gateway: PyMongoDataAccess):
    """Runs can be sorted by ``host.python_version`` in either direction."""

    def sorted_versions(**sort_kwargs):
        # Fetch both runs with the given sort options and return their
        # python versions in the order the DAO yielded them.
        fetched = list(db_gateway.get_run_dao().get_runs(**sort_kwargs))
        assert len(fetched) == 2
        return [run["host"]["python_version"] for run in fetched]

    # Without sort_direction the order is expected to be ascending.
    assert sorted_versions(sort_by="host.python_version") == [
        "3.4.3", "3.5.2"]
    assert sorted_versions(
        sort_by="host.python_version", sort_direction="desc") == [
            "3.5.2", "3.4.3"]
예제 #3
0
    def __init__(self, database_name, experiment_id, uri=None):
        """Connect to the database and load the run ``experiment_id``.

        Args:
            database_name: Name of the database holding the runs.
            experiment_id: Identifier passed to the run DAO's ``get``.
            uri: Optional connection URI forwarded to ``PyMongoDataAccess``.
        """
        self.database_name = database_name
        self.experiment_id = experiment_id
        self.uri = uri
        self.collection_name = 'runs'

        self.generic_dao = PyMongoDataAccess(
            uri=self.uri,
            database_name=self.database_name,
            collection_name=self.collection_name)
        # NOTE(review): the connection is opened here and is not closed in
        # this method.
        self.generic_dao.connect()

        self.metrics_dao = self.generic_dao.get_metrics_dao()
        self.run_dao = self.generic_dao.get_run_dao()

        self.run = self.run_dao.get(experiment_id)
        # A run without an 'info' section previously crashed on the next line
        # ('NoneType' has no attribute 'get'); treat it as empty instead.
        self.info = self.run.get('info') or {}
        # None when the run recorded no metrics.
        self.metrics_info = self.info.get('metrics')
예제 #4
0
def test_get_metrics_dao(db_gateway: PyMongoDataAccess):
    """The gateway returns a Mongo metrics DAO tied to its generic DAO."""
    metrics_dao = db_gateway.get_metrics_dao()

    assert metrics_dao is not None
    for expected_type in (MetricsDAO, MongoMetricsDAO):
        assert isinstance(metrics_dao, expected_type)
    # The DAO's generic_dao must compare equal to the gateway's own.
    assert metrics_dao.generic_dao == db_gateway._generic_dao
예제 #5
0
def test_get_run(db_gateway: PyMongoDataAccess):
    """Fetching a run by id returns the document recorded on 'ntbacer'."""
    run_id = "57f9efb2e4b8490d19d7c30e"
    fetched = db_gateway.get_run_dao().get(run_id)
    as_dict = dict(fetched)
    assert as_dict["host"]["hostname"] == "ntbacer"
예제 #6
0
def _eq_filter(field, value):
    """Return a single ``field == value`` filter clause."""
    return {"field": field, "operator": "==", "value": value}


def _result_or_py352_filter():
    """Return a new clause: result == 2403.52 OR python_version == 3.5.2."""
    # Built fresh on every call so no query dict is shared between requests.
    return {
        "type": "or",
        "filters": [
            _eq_filter("result", 2403.52),
            _eq_filter("host.python_version", "3.5.2"),
        ],
    }


def test_get_runs_filter_or(db_gateway: PyMongoDataAccess):
    """``or`` clauses nested inside ``and`` filters select the right runs."""
    # ntbacer AND (result == 2403.52 OR python 3.5.2): only the ntbacer run.
    query = {
        "type": "and",
        "filters": [
            _eq_filter("host.hostname", "ntbacer"),
            _result_or_py352_filter(),
        ],
    }
    runs = list(db_gateway.get_run_dao().get_runs(query=query))
    assert len(runs) == 1
    assert runs[0]["host"]["hostname"] == "ntbacer"
    assert runs[0]["result"] == 2403.52

    # martin-virtual-machine AND (result == 2403.52 OR python 3.5.2).
    query = {
        "type": "and",
        "filters": [
            _eq_filter("host.hostname", "martin-virtual-machine"),
            _result_or_py352_filter(),
        ],
    }
    runs = list(db_gateway.get_run_dao().get_runs(query=query))
    assert len(runs) == 1
    assert runs[0]["host"]["hostname"] == "martin-virtual-machine"
    assert runs[0]["host"]["python_version"] == "3.5.2"

    # An "and" wrapping only the "or" clause matches both runs.
    query = {"type": "and", "filters": [_result_or_py352_filter()]}
    runs = list(db_gateway.get_run_dao().get_runs(query=query))
    assert len(runs) == 2

    assert runs[0]["host"]["hostname"] == "ntbacer"
    assert runs[0]["host"]["python_version"] == "3.4.3"

    assert runs[1]["host"]["hostname"] == "martin-virtual-machine"
    assert runs[1]["host"]["python_version"] == "3.5.2"
예제 #7
0
def test_get_runs_filter(db_gateway: PyMongoDataAccess, query_filter):
    """The ``query_filter`` fixture matches exactly one run, on 'ntbacer'."""
    matched = [*db_gateway.get_run_dao().get_runs(query=query_filter)]
    assert len(matched) == 1
    first = matched[0]
    assert first["host"]["hostname"] == "ntbacer"
예제 #8
0
def test_get_runs_limit(db_gateway: PyMongoDataAccess):
    """``limit=1`` truncates the result to the first run ('ntbacer')."""
    limited = [*db_gateway.get_run_dao().get_runs(limit=1)]
    assert len(limited) == 1
    first = limited[0]
    assert first["host"]["hostname"] == "ntbacer"
예제 #9
0
def test_get_runs(db_gateway: PyMongoDataAccess):
    """With no arguments both runs are returned, 'ntbacer' first."""
    all_runs = list(db_gateway.get_run_dao().get_runs())
    assert len(all_runs) == 2
    hostnames = [run["host"]["hostname"] for run in all_runs]
    assert hostnames == ["ntbacer", "martin-virtual-machine"]
예제 #10
0
def db_gateway() -> PyMongoDataAccess:
    """Build a connected data-access gateway for 'testdb' / 'runs'.

    The client factory is replaced with ``create_mongomock_client`` so the
    tests run against mongomock rather than a real MongoDB server.
    """
    # "n/a" and 0 look like placeholder host/port values; presumably unused
    # once the client factory is swapped below — confirm.
    gateway = PyMongoDataAccess.build_data_access("n/a", 0, "testdb", "runs")
    # Install the mock factory before connect(), which presumably calls
    # _create_client to build its client.
    gateway._create_client = create_mongomock_client
    gateway.connect()
    return gateway