Example #1
def test_get_description(experiment, some_interdeps):

    ds = DataSet()

    assert ds.run_id == 1

    desc = ds.description
    assert desc == RunDescriber(InterDependencies_())

    ds.set_interdependencies(some_interdeps[1])

    assert ds.description.interdeps == some_interdeps[1]

    # the run description gets written when the dataset is marked as started,
    # so at this point no description should be stored in the database yet
    prematurely_loaded_ds = DataSet(run_id=1)
    assert prematurely_loaded_ds.description == RunDescriber(
        InterDependencies_())

    ds.mark_started()

    loaded_ds = DataSet(run_id=1)

    expected_desc = RunDescriber(some_interdeps[1])

    assert loaded_ds.description == expected_desc
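The `some_interdeps` fixture is defined elsewhere; a minimal sketch of how an object like `some_interdeps[1]` could be built (hypothetical parameter names; import paths vary between qcodes versions):

from qcodes.dataset.descriptions.dependencies import InterDependencies_
from qcodes.dataset.descriptions.param_spec import ParamSpecBase

# two hypothetical parameters, where ps2 is measured as a function of ps1
ps1 = ParamSpecBase('ps1', paramtype='numeric', label='setpoint', unit='V')
ps2 = ParamSpecBase('ps2', paramtype='numeric', label='signal', unit='A')
idps = InterDependencies_(dependencies={ps2: (ps1,)})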
Example #2
def generate_DB_file_with_empty_runs():
    """
    Generate a DB file that holds empty runs and runs with no interdependencies
    """

    v2fixturepath = os.path.join(fixturepath, 'version2')
    os.makedirs(v2fixturepath, exist_ok=True)
    path = os.path.join(v2fixturepath, 'empty_runs.db')

    if os.path.exists(path):
        os.remove(path)

    from qcodes.dataset.sqlite_base import connect
    from qcodes.dataset.measurements import Measurement
    from qcodes.dataset.experiment_container import Experiment
    from qcodes import Parameter
    from qcodes.dataset.data_set import DataSet

    conn = connect(path)
    exp = Experiment(path)
    exp._new(name='experiment_1', sample_name='no_sample_1')

    # Now make some parameters to use in measurements
    params = []
    for n in range(5):
        params.append(Parameter(f'p{n}', label=f'Parameter {n}',
                                unit=f'unit {n}', set_cmd=None, get_cmd=None))

    # truly empty run, no layouts table, no nothing
    dataset = DataSet(path, conn=conn)
    dataset._new('empty_dataset', exp_id=exp.exp_id)

    # empty run
    meas = Measurement(exp)
    with meas.run() as datasaver:
        pass

    # run with no interdeps
    meas = Measurement(exp)
    for param in params:
        meas.register_parameter(param)

    with meas.run() as datasaver:
        pass

    with meas.run() as datasaver:
        for _ in range(10):
            res = tuple((p, 0.0) for p in params)
            datasaver.add_result(*res)
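A sketch of reading one of the generated runs back, assuming `path` still points at the freshly generated file:

ds = DataSet(path_to_db=path, run_id=1)  # run 1 is the truly empty dataset
print(ds.name)  # 'empty_dataset'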
Example #3
def test_foreground_after_background_raises(empty_temp_db_connection):
    new_experiment("test", "test1", conn=empty_temp_db_connection)
    ds1 = DataSet(conn=empty_temp_db_connection)
    ds1.mark_started(start_bg_writer=True)

    ds2 = DataSet(conn=empty_temp_db_connection)
    with pytest.raises(RuntimeError, match="All datasets written"):
        ds2.mark_started(start_bg_writer=False)
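For contrast, a sketch (under the same fixture assumptions) of the combination that does not raise: all datasets on the connection use the background writer.

def test_background_after_background_is_fine(empty_temp_db_connection):
    new_experiment("test", "test1", conn=empty_temp_db_connection)
    ds1 = DataSet(conn=empty_temp_db_connection)
    ds1.mark_started(start_bg_writer=True)

    ds2 = DataSet(conn=empty_temp_db_connection)
    ds2.mark_started(start_bg_writer=True)  # same writer mode: no RuntimeError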
Example #4
def test_old_versions_not_touched(two_empty_temp_db_connections,
                                  some_paramspecs):

    source_conn, target_conn = two_empty_temp_db_connections

    target_path = path_to_dbfile(target_conn)
    source_path = path_to_dbfile(source_conn)

    fixturepath = os.sep.join(qcodes.tests.dataset.__file__.split(os.sep)[:-1])
    fixturepath = os.path.join(fixturepath, 'fixtures', 'db_files', 'version2',
                               'some_runs.db')
    if not os.path.exists(fixturepath):
        pytest.skip("No db-file fixtures found. You can generate test db-files"
                    " using the scripts in the legacy_DB_generation folder")

    # First test that we cannot use an old version as source

    with raise_if_file_changed(fixturepath):
        with pytest.warns(UserWarning) as warning:
            extract_runs_into_db(fixturepath, target_path, 1)
            expected_mssg = ('Source DB version is 2, but this '
                             'function needs it to be in version 4. '
                             'Run this function again with '
                             'upgrade_source_db=True to auto-upgrade '
                             'the source DB file.')
            assert warning[0].message.args[0] == expected_mssg

    # Then test that we cannot use an old version as target

    # first create a run in the new version source
    source_exp = Experiment(conn=source_conn)
    source_ds = DataSet(conn=source_conn, exp_id=source_exp.exp_id)

    for ps in some_paramspecs[2].values():
        source_ds.add_parameter(ps)
    source_ds.mark_started()
    source_ds.add_result({ps.name: 0.0 for ps in some_paramspecs[2].values()})
    source_ds.mark_completed()

    with raise_if_file_changed(fixturepath):
        with pytest.warns(UserWarning) as warning:
            extract_runs_into_db(source_path, fixturepath, 1)
            expected_mssg = ('Target DB version is 2, but this '
                             'function needs it to be in version 4. '
                             'Run this function again with '
                             'upgrade_target_db=True to auto-upgrade '
                             'the target DB file.')
            assert warning[0].message.args[0] == expected_mssg
Example #5
def test_parent_dataset_links_invalid_input():
    """
    Test that invalid input is rejected
    """
    links = generate_some_links(3)

    ds = DataSet()

    match = re.escape('Invalid input. Did not receive a list of Links')
    with pytest.raises(ValueError, match=match):
        ds.parent_dataset_links = [ds.guid]

    match = re.escape('Invalid input. All links must point to this dataset. '
                      'Got link(s) with head(s) pointing to another dataset.')
    with pytest.raises(ValueError, match=match):
        ds.parent_dataset_links = links
Example #6
def test_create_dataset_pass_both_connection_and_path_to_db(experiment):
    with pytest.raises(ValueError,
                       match="Received BOTH conn and path_to_db. "
                       "Please provide only one or "
                       "the other."):
        some_valid_connection = experiment.conn
        _ = DataSet(path_to_db="some valid path", conn=some_valid_connection)
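A sketch of the two accepted alternatives, each passing only one of the two arguments:

ds_by_path = DataSet(path_to_db="some valid path")  # opens its own connection
ds_by_conn = DataSet(conn=some_valid_connection)    # reuses an existing connection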
Example #7
def test_integration_station_and_measurement(two_empty_temp_db_connections,
                                             inst):
    """
    An integration test where the runs in the source DB file are produced
    with the Measurement object and there is a Station as well
    """
    source_conn, target_conn = two_empty_temp_db_connections
    source_path = path_to_dbfile(source_conn)
    target_path = path_to_dbfile(target_conn)

    source_exp = Experiment(conn=source_conn)

    # Set up measurement scenario
    station = Station(inst)

    meas = Measurement(exp=source_exp, station=station)
    meas.register_parameter(inst.back)
    meas.register_parameter(inst.plunger)
    meas.register_parameter(inst.cutter, setpoints=(inst.back, inst.plunger))

    with meas.run() as datasaver:
        for back_v in [1, 2, 3]:
            for plung_v in [-3, -2.5, 0]:
                datasaver.add_result((inst.back, back_v),
                                     (inst.plunger, plung_v),
                                     (inst.cutter, back_v+plung_v))

    extract_runs_into_db(source_path, target_path, 1)

    target_ds = DataSet(conn=target_conn, run_id=1)

    assert datasaver.dataset.the_same_dataset_as(target_ds)
Example #8
def _get_timestamp_button(ds: DataSet) -> Box:
    try:
        total_time = str(
            datetime.fromtimestamp(ds.completed_timestamp_raw)  # type: ignore
            - datetime.fromtimestamp(ds.run_timestamp_raw)  # type: ignore
        )
    except TypeError:
        total_time = "?"
    start = ds.run_timestamp()
    body = _yaml_dump({
        ".run_timestamp": start,
        ".completed_timestamp": ds.completed_timestamp(),
        "total_time": total_time,
    })
    return button_to_text(start or "", body)
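`_yaml_dump` and `button_to_text` are private helpers of the surrounding widget module. The elapsed-time computation on its own, as a sketch for a hypothetical completed dataset `ds`:

from datetime import datetime

start = datetime.fromtimestamp(ds.run_timestamp_raw)
stop = datetime.fromtimestamp(ds.completed_timestamp_raw)
print(stop - start)  # a timedelta such as 0:00:12.345678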
Example #9
def test_create_dataset_pass_both_connection_and_path_to_db(experiment):
    with pytest.raises(ValueError, match="Both `path_to_db` and `conn` "
                                         "arguments have been passed together "
                                         "with non-None values. This is not "
                                         "allowed."):
        some_valid_connection = experiment.conn
        _ = DataSet(path_to_db="some valid path", conn=some_valid_connection)
Example #10
def test_mark_complete_is_deprecated_and_marks_as_completed(experiment):
    """Test that the deprecated `mark_complete` calls `mark_completed`"""
    ds = DataSet()

    with patch.object(ds, 'mark_completed', autospec=True) as mark_completed:
        pytest.deprecated_call(ds.mark_complete)
        mark_completed.assert_called_once()
Example #11
    def __init__(self, dataset: DataSet, write_period: numeric_types,
                 parameters: Dict[str, ParamSpec]) -> None:
        self._dataset = dataset
        if (DataSaver.default_callback is not None
                and 'run_tables_subscription_callback'
                in DataSaver.default_callback):
            callback = DataSaver.default_callback[
                'run_tables_subscription_callback']
            min_wait = DataSaver.default_callback[
                'run_tables_subscription_min_wait']
            min_count = DataSaver.default_callback[
                'run_tables_subscription_min_count']
            snapshot = dataset.get_metadata('snapshot')
            self._dataset.subscribe(callback, min_wait=min_wait,
                                    min_count=min_count,
                                    state={},
                                    callback_kwargs={'run_id':
                                                         self._dataset.run_id,
                                                     'snapshot': snapshot})

        self.write_period = float(write_period)
        self.parameters = parameters
        self._known_parameters = list(parameters.keys())
        self._results: List[dict] = []  # will be filled by add_result
        self._last_save_time = monotonic()
        self._known_dependencies: Dict[str, List[str]] = {}
        for param, parspec in parameters.items():
            if parspec.depends_on != '':
                self._known_dependencies.update(
                    {str(param): parspec.depends_on.split(', ')})
Example #12
    def __init__(self, dataset: DataSet, write_period: numeric_types,
                 interdeps: InterDependencies_) -> None:
        self._dataset = dataset
        if DataSaver.default_callback is not None \
                and 'run_tables_subscription_callback' \
                    in DataSaver.default_callback:
            callback = DataSaver.default_callback[
                'run_tables_subscription_callback']
            min_wait = DataSaver.default_callback[
                'run_tables_subscription_min_wait']
            min_count = DataSaver.default_callback[
                'run_tables_subscription_min_count']
            snapshot = dataset.get_metadata('snapshot')
            self._dataset.subscribe(callback,
                                    min_wait=min_wait,
                                    min_count=min_count,
                                    state={},
                                    callback_kwargs={
                                        'run_id': self._dataset.run_id,
                                        'snapshot': snapshot
                                    })
        default_subscribers = qcodes.config.subscription.default_subscribers
        for subscriber in default_subscribers:
            self._dataset.subscribe_from_config(subscriber)

        self._interdeps = interdeps
        self.write_period = float(write_period)
        # self._results will be filled by add_result
        self._results: List[Dict[str, VALUE]] = []
        self._last_save_time = perf_counter()
        self._known_dependencies: Dict[str, List[str]] = {}
        self.parent_datasets: List[DataSet] = []

        for link in self._dataset.parent_dataset_links:
            self.parent_datasets.append(load_by_guid(link.tail))
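The default subscribers are read from the qcodes configuration; a sketch of inspecting and, hypothetically, extending them (the name would have to match an entry under `subscription.subscribers` in your qcodesrc):

import qcodes

print(qcodes.config.subscription.default_subscribers)  # typically []
# hypothetical: attach 'MySubscriber' to every new DataSaver
qcodes.config.subscription.default_subscribers.append('MySubscriber')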
Example #13
def test_load_by_guid(some_interdeps):
    ds = DataSet()
    ds.set_interdependencies(some_interdeps[1])
    ds.mark_started()
    ds.add_results([{'ps1': 1, 'ps2': 2}])

    loaded_ds = load_by_guid(ds.guid)

    assert loaded_ds.the_same_dataset_as(ds)
Example #14
def do_experiment(experiment_name,
                  sweep_object,
                  setup=None,
                  cleanup=None,
                  station=None,
                  live_plot=False):

    if "/" in experiment_name:
        experiment_name, sample_name = experiment_name.split("/")
    else:
        sample_name = None

    try:
        experiment = load_experiment_by_name(experiment_name, sample_name)
    except ValueError:  # experiment does not exist yet
        db_location = qcodes.config["core"]["db_location"]
        DataSet(db_location)  # instantiating a DataSet initialises the db file
        experiment = new_experiment(experiment_name, sample_name)

    def add_actions(action, callables):
        if callables is None:
            return

        for call_spec in np.atleast_1d(callables):
            if not isinstance(call_spec, tuple):
                call_spec = (call_spec, ())

            action(*call_spec)

    if live_plot:
        try:
            from plottr.qcodes_dataset import QcodesDatasetSubscriber
            from plottr.tools import start_listener

            start_listener()

        except ImportError:
            warn("Cannot perform live plots, plottr not installed")
            live_plot = False

    meas = SweepMeasurement(exp=experiment, station=station)
    meas.register_sweep(sweep_object)

    add_actions(meas.add_before_run, setup)
    add_actions(meas.add_after_run, cleanup)

    with meas.run() as datasaver:

        if live_plot:
            datasaver.dataset.subscribe(QcodesDatasetSubscriber(
                datasaver.dataset),
                                        state=[],
                                        min_wait=0,
                                        min_count=1)

        for data in sweep_object:
            datasaver.add_result(*data.items())

    return _DataExtractor(datasaver)
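A hedged usage sketch; `sweep_object` comes from the surrounding sweep library, and `station` and the setup/cleanup callables are hypothetical:

result = do_experiment(
    "my_experiment/my_sample",   # split on '/' into experiment and sample name
    sweep_object,
    setup=[(source.on, ())],     # hypothetical callables paired with their args
    cleanup=[(source.off, ())],
    station=station,
    live_plot=False,
)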
Example #15
def get_ds_info_from_path(path: str, run_id: int, get_structure: bool = True):
    """
    Convenience function that determines the dataset from `path` and
    `run_id`, then calls `get_ds_info`.
    """

    ds = DataSet(path_to_db=path, run_id=run_id)
    return get_ds_info(ds.conn, run_id, get_structure=get_structure)
Example #16
def test_dataset_length():

    path_to_db = get_DB_location()
    ds = DataSet(path_to_db, run_id=None)

    assert len(ds) == 0

    parameter = ParamSpecBase(name='single', paramtype='numeric',
                              label='', unit='N/A')
    idps = InterDependencies_(standalones=(parameter,))
    ds.set_interdependencies(idps)

    ds.mark_started()
    ds.add_results([{parameter.name: 1}])
    ds.mark_completed()

    assert len(ds) == 1
Example #17
    def process(self, **kw):
        if None not in self._pathAndId:
            path, runId = self._pathAndId
            ds = DataSet(path_to_db=path, run_id=runId)
            if ds.number_of_results > self.nLoadedRecords:
                data = ds_to_datadict(ds)
                self.nLoadedRecords = ds.number_of_results
                return dict(dataOut=data)
Example #18
def test_metadata(experiment, request):

    metadata1 = {'number': 1, "string": "Once upon a time..."}
    metadata2 = {'more': 'meta'}

    ds1 = DataSet(metadata=metadata1)
    request.addfinalizer(ds1.conn.close)
    ds2 = DataSet(metadata=metadata2)
    request.addfinalizer(ds2.conn.close)

    assert ds1.run_id == 1
    assert ds1.metadata == metadata1
    assert ds2.run_id == 2
    assert ds2.metadata == metadata2

    loaded_ds1 = DataSet(run_id=1)
    request.addfinalizer(loaded_ds1.conn.close)
    assert loaded_ds1.metadata == metadata1
    loaded_ds2 = DataSet(run_id=2)
    request.addfinalizer(loaded_ds2.conn.close)
    assert loaded_ds2.metadata == metadata2

    badtag = 'lex luthor'
    sorry_metadata = {'superman': 1, badtag: None, 'spiderman': 'two'}

    bad_tag_msg = (f'Tag {badtag} has value None. '
                   ' That is not a valid metadata value!')

    with pytest.raises(RuntimeError,
                       match='Rolling back due to unhandled exception') as e:
        for tag, value in sorry_metadata.items():
            ds1.add_metadata(tag, value)

    assert error_caused_by(e, bad_tag_msg)
Example #19
def test_load_by_X_functions(two_empty_temp_db_connections, some_interdeps):
    """
    Test some different loading functions
    """
    source_conn, target_conn = two_empty_temp_db_connections

    source_path = path_to_dbfile(source_conn)
    target_path = path_to_dbfile(target_conn)

    source_exp1 = Experiment(conn=source_conn)
    source_ds_1_1 = DataSet(conn=source_conn, exp_id=source_exp1.exp_id)

    source_exp2 = Experiment(conn=source_conn)
    source_ds_2_1 = DataSet(conn=source_conn, exp_id=source_exp2.exp_id)

    source_ds_2_2 = DataSet(conn=source_conn,
                            exp_id=source_exp2.exp_id,
                            name="customname")

    for ds in (source_ds_1_1, source_ds_2_1, source_ds_2_2):
        ds.set_interdependencies(some_interdeps[1])
        ds.mark_started()
        ds.add_result({name: 0.0 for name in some_interdeps[1].names})
        ds.mark_completed()

    extract_runs_into_db(source_path, target_path, source_ds_2_2.run_id)

    test_ds = load_by_guid(source_ds_2_2.guid, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)

    test_ds = load_by_id(1, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)

    test_ds = load_by_counter(1, 1, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)
Example #20
def test_load_by_guid(some_paramspecs):
    paramspecs = some_paramspecs[2]
    ds = DataSet()
    ds.add_parameter(paramspecs['ps1'])
    ds.add_parameter(paramspecs['ps2'])
    ds.mark_started()
    ds.add_result({'ps1': 1, 'ps2': 2})

    loaded_ds = load_by_guid(ds.guid)

    assert loaded_ds.the_same_dataset_as(ds)
Example #21
def datadict_from_path_and_run_id(path: str, run_id: int) -> DataDictBase:
    """
    Load a qcodes dataset as a DataDict.

    :param path: file path of the qcodes .db file.
    :param run_id: run_id of the dataset.
    :return: DataDict containing the data.
    """
    ds = DataSet(path_to_db=path, run_id=run_id)
    return ds_to_datadict(ds)
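A hypothetical call, assuming a qcodes .db file at the given path that contains run 1:

dd = datadict_from_path_and_run_id('./experiments.db', run_id=1)
print(list(dd.keys()))  # names of the parameters held by the DataDict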
Example #22
def test_dataset_location(empty_temp_db_connection):
    """
    Test that a dataset and an experiment point to the correct db file when
    a connection is supplied.
    """
    exp = new_experiment("test", "test1", conn=empty_temp_db_connection)
    ds = DataSet(conn=empty_temp_db_connection)
    assert path_to_dbfile(empty_temp_db_connection) == \
           empty_temp_db_connection.path_to_dbfile
    assert exp.path_to_db == empty_temp_db_connection.path_to_dbfile
    assert ds.path_to_db == empty_temp_db_connection.path_to_dbfile
Example #23
def test_integer_timestamps_in_database_are_supported():
    ds = DataSet()

    ds.mark_started()
    ds.mark_completed()

    with atomic(ds.conn) as conn:
        _rewrite_timestamps(conn, ds.run_id, 42, 69)

    assert isinstance(ds.run_timestamp_raw, float)
    assert isinstance(ds.completed_timestamp_raw, float)
    assert isinstance(ds.run_timestamp(), str)
    assert isinstance(ds.completed_timestamp(), str)
Example #24
def make_shadow_dataset(dataset: DataSet):
    """
    Creates a new DataSet object that points to the same run_id in the same
    database file as the given dataset object.

    Note that this is achieved by passing `path_to_db` rather than reusing the
    given dataset's connection, because passing the path creates a new sqlite3
    connection object behind the scenes. This is very useful for situations
    where one needs to assert the underlying modifications to the database
    file.
    """
    return DataSet(path_to_db=dataset.path_to_db, run_id=dataset.run_id)
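A sketch of the assertion pattern this enables, for some hypothetical dataset `ds`:

shadow = make_shadow_dataset(ds)
ds.add_metadata('tag', 'value')               # write through the original connection
assert shadow.get_metadata('tag') == 'value'  # read back through the shadow's own connection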
Example #25
def test_experiments_with_NULL_sample_name(two_empty_temp_db_connections,
                                           some_paramspecs):
    """
    In older API versions (corresponding to DB version 3),
    users could get away with setting the sample name to None

    This test checks that such an experiment gets correctly recognised and
    is thus not ever re-inserted into the target DB
    """
    source_conn, target_conn = two_empty_temp_db_connections
    source_exp_1 = Experiment(conn=source_conn, name='null_sample_name')

    source_path = path_to_dbfile(source_conn)
    target_path = path_to_dbfile(target_conn)

    # make 5 runs in experiment

    exp_1_run_ids = []
    for _ in range(5):

        source_dataset = DataSet(conn=source_conn, exp_id=source_exp_1.exp_id)
        exp_1_run_ids.append(source_dataset.run_id)

        for ps in some_paramspecs[2].values():
            source_dataset.add_parameter(ps)

        for val in range(10):
            source_dataset.add_result(
                {ps.name: val
                 for ps in some_paramspecs[2].values()})
        source_dataset.mark_complete()

    sql = """
          UPDATE experiments
          SET sample_name = NULL
          WHERE exp_id = 1
          """
    source_conn.execute(sql)
    source_conn.commit()

    assert source_exp_1.sample_name is None

    extract_runs_into_db(source_path, target_path, 1, 2, 3, 4, 5)

    assert len(get_experiments(target_conn)) == 1

    extract_runs_into_db(source_path, target_path, 1, 2, 3, 4, 5)

    assert len(get_experiments(target_conn)) == 1

    assert len(Experiment(exp_id=1, conn=target_conn)) == 5
Example #26
def get_data_from_ds(ds: DataSet,
                     start: Optional[int] = None,
                     end: Optional[int] = None) -> Dict[str, List[List]]:
    """
    Returns a dictionary in the format {'name' : data}, where data
    is what dataset.get_data('name') returns, i.e., a list of lists, where
    the inner list is the row as inserted into the DB.

    With `start` and `end`, only a subset of the rows in the DB is selected.
    """
    names = list(ds.paramspecs.keys())
    return {n: np.squeeze(ds.get_data(n, start=start, end=end)) for n in names}
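A hypothetical call restricting the read to the first five rows of each parameter (the bounds are forwarded to `ds.get_data`):

subset = get_data_from_ds(ds, start=1, end=5)  # `ds` as above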
Example #27
def parameter_test_helper(ds: DataSet,
                          toplevel_names: Sequence[str],
                          expected_names: Dict[str, Sequence[str]],
                          expected_shapes: Dict[str, Sequence[Tuple[int, ...]]],
                          expected_values: Dict[str, Sequence[np.ndarray]],
                          start: Optional[int] = None,
                          end: Optional[int] = None):
    """
    A helper function to compare the data we actually read out of a given
    dataset with the expected data.

    Args:
        ds: the dataset in question
        toplevel_names: names of the toplevel parameters of the dataset
        expected_names: names of the parameters expected to be loaded for a
            given parameter as a sequence indexed by the parameter.
        expected_shapes: expected shapes of the parameters loaded. The shapes
            should be stored as a tuple per parameter in a sequence containing
            all the loaded parameters for a given requested parameter.
        expected_values: expected content of the data arrays stored in a
            sequence indexed by the parameter.
        start: optional first row of the data to load
        end: optional last row of the data to load

    """

    data = ds.get_parameter_data(*toplevel_names, start=start, end=end)
    dataframe = ds.get_data_as_pandas_dataframe(*toplevel_names,
                                                start=start,
                                                end=end)

    all_data = ds.get_parameter_data(start=start, end=end)
    all_dataframe = ds.get_data_as_pandas_dataframe(start=start, end=end)

    all_parameters = list(all_data.keys())
    assert set(data.keys()).issubset(set(all_parameters))
    assert list(data.keys()) == list(dataframe.keys())
    assert len(data.keys()) == len(toplevel_names)
    assert len(dataframe.keys()) == len(toplevel_names)

    verify_data_dict(data, dataframe, toplevel_names, expected_names,
                     expected_shapes, expected_values)
    verify_data_dict(all_data, all_dataframe, toplevel_names, expected_names,
                     expected_shapes, expected_values)

    # Now let's remove a random element from the list.
    # We do this one by one until only one element is left.
    subset_names = copy(all_parameters)
    while len(subset_names) > 1:
        elem_to_remove = random.randint(0, len(subset_names) - 1)
        name_removed = subset_names.pop(elem_to_remove)
        expected_names.pop(name_removed)
        expected_shapes.pop(name_removed)
        expected_values.pop(name_removed)

        subset_data = ds.get_parameter_data(*subset_names,
                                            start=start, end=end)
        subset_dataframe = ds.get_data_as_pandas_dataframe(*subset_names,
                                                           start=start,
                                                           end=end)
        verify_data_dict(subset_data, subset_dataframe, subset_names,
                         expected_names, expected_shapes, expected_values)
Example #28
def test_missing_runs_raises(two_empty_temp_db_connections, some_paramspecs):
    """
    Test that an error is raised if we attempt to extract a run not present in
    the source DB
    """
    source_conn, target_conn = two_empty_temp_db_connections

    source_exp_1 = Experiment(conn=source_conn)

    # make 5 runs in first experiment

    exp_1_run_ids = []
    for _ in range(5):

        source_dataset = DataSet(conn=source_conn, exp_id=source_exp_1.exp_id)
        exp_1_run_ids.append(source_dataset.run_id)

        for ps in some_paramspecs[2].values():
            source_dataset.add_parameter(ps)

        source_dataset.mark_started()

        for val in range(10):
            source_dataset.add_result(
                {ps.name: val
                 for ps in some_paramspecs[2].values()})
        source_dataset.mark_completed()

    source_path = path_to_dbfile(source_conn)
    target_path = path_to_dbfile(target_conn)

    run_ids = [1, 8, 5, 3, 2, 4, 4, 4, 7, 8]
    wrong_ids = [8, 7, 8]

    expected_err = ("Error: not all run_ids exist in the source database. "
                    "The following run(s) is/are not present: "
                    f"{wrong_ids}")

    with pytest.raises(ValueError, match=re.escape(expected_err)):
        extract_runs_into_db(source_path, target_path, *run_ids)
Example #29
def test_load_by_X_functions(two_empty_temp_db_connections,
                             some_interdeps):
    """
    Test some different loading functions
    """
    source_conn, target_conn = two_empty_temp_db_connections

    source_path = path_to_dbfile(source_conn)
    target_path = path_to_dbfile(target_conn)

    source_exp1 = Experiment(conn=source_conn)
    source_ds_1_1 = DataSet(conn=source_conn, exp_id=source_exp1.exp_id)

    source_exp2 = Experiment(conn=source_conn)
    source_ds_2_1 = DataSet(conn=source_conn, exp_id=source_exp2.exp_id)

    source_ds_2_2 = DataSet(conn=source_conn,
                            exp_id=source_exp2.exp_id,
                            name="customname")

    for ds in (source_ds_1_1, source_ds_2_1, source_ds_2_2):
        ds.set_interdependencies(some_interdeps[1])
        ds.mark_started()
        ds.add_results([{name: 0.0 for name in some_interdeps[1].names}])
        ds.mark_completed()

    extract_runs_into_db(source_path, target_path, source_ds_2_2.run_id)
    extract_runs_into_db(source_path, target_path, source_ds_2_1.run_id)
    extract_runs_into_db(source_path, target_path, source_ds_1_1.run_id)

    test_ds = load_by_guid(source_ds_2_2.guid, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)

    test_ds = load_by_id(1, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)

    test_ds = load_by_run_spec(captured_run_id=source_ds_2_2.captured_run_id,
                               conn=target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)

    assert source_exp2.exp_id == 2

    # the run extracted first becomes the first run in the target db,
    # so its run_id there is 1
    target_run_id = 1
    # and its experiment is created first in the target db, so the
    # experiment ids are interchanged relative to the source
    target_exp_id = 1

    test_ds = load_by_counter(target_run_id, target_exp_id, target_conn)
    assert source_ds_2_2.the_same_dataset_as(test_ds)
Example #30
    def data_set(self, counter: int) -> DataSet:
        """
        Get dataset with the specified counter from this experiment

        Args:
            counter: the counter of the run we want to load

        Returns:
            the dataset
        """
        run_id = get_runid_from_expid_and_counter(self.conn, self.exp_id,
                                                  counter)
        return DataSet(run_id=run_id, conn=self.conn)
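Hypothetical usage, given an existing experiment with at least one run:

exp = Experiment(exp_id=1, conn=conn)  # hypothetical: load experiment 1
first_run = exp.data_set(1)            # the run whose counter is 1 in this experiment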