def test_two_nested_atomics():
    """Nested `atomic` contexts share one transaction; state is restored
    only after the outermost context exits."""
    base_conn = sqlite3.connect(':memory:')
    plus = ConnectionPlus(base_conn)

    initial_atomic_in_progress = plus.atomic_in_progress
    initial_isolation_level = plus.isolation_level

    assert False is plus.in_transaction

    with atomic(plus) as outer:
        for c in (plus, outer):
            assert conn_plus_in_transaction(c)

        with atomic(outer) as inner:
            for c in (plus, outer, inner):
                assert conn_plus_in_transaction(c)

        # leaving the inner context keeps the outer transaction open
        for c in (plus, outer, inner):
            assert conn_plus_in_transaction(c)

    # once the outermost context exits, everything is idle again and the
    # atomic_in_progress flag is back to its initial value
    for c in (plus, outer, inner):
        assert conn_plus_is_idle(c, initial_isolation_level)
    for c in (plus, outer, inner):
        assert initial_atomic_in_progress == c.atomic_in_progress
def test_atomic():
    """`atomic` rejects plain sqlite3 connections and restores the
    connection state after the context exits."""
    raw_conn = sqlite3.connect(':memory:')

    expected_msg = re.escape(
        'atomic context manager only accepts ConnectionPlus '
        'database connection objects.')
    with pytest.raises(ValueError, match=expected_msg):
        with atomic(raw_conn):
            pass

    plus = ConnectionPlus(raw_conn)
    assert False is plus.atomic_in_progress

    initial_atomic_in_progress = plus.atomic_in_progress
    initial_isolation = plus.isolation_level

    assert False is plus.in_transaction

    with atomic(plus) as ctx_conn:
        assert conn_plus_in_transaction(ctx_conn)
        assert conn_plus_in_transaction(plus)
        assert initial_isolation == plus.isolation_level

    assert False is plus.in_transaction
    assert initial_atomic_in_progress is plus.atomic_in_progress
    assert initial_isolation == plus.isolation_level
    assert False is ctx_conn.in_transaction
    assert initial_atomic_in_progress is ctx_conn.atomic_in_progress
def add_parameter_values(self, spec: ParamSpec, values: VALUES) -> None:
    """
    Add a parameter to the DataSet and associate result values with it.

    If the DataSet is not empty, the count of provided values must equal
    the current count of results in the DataSet, or an error will result.

    It is an error to add parameters to a completed DataSet.

    Args:
        spec: specification of the new parameter
        values: values to associate with the new parameter, one per
            existing result row (when the DataSet is non-empty)

    Raises:
        ValueError: if the DataSet is non-empty and ``len(values)``
            differs from the number of existing results
    """
    # first check that the len of values (if dataset is not empty)
    # is the right size i.e. the same as the dataset
    if len(self) > 0 and len(values) != len(self):
        raise ValueError("Need to have {} values but got {}.".format(
            len(self), len(values)
        ))

    # Bug fix: the original used `with atomic(self.conn) as self.conn`,
    # which rebound the instance attribute to the atomic connection and
    # left it rebound after the context exited. Bind to a local instead.
    with atomic(self.conn) as conn:
        add_parameter(conn, self.table_name, spec)
        # now add values!
        results = [{spec.name: value} for value in values]
        self.add_results(results)
def modify_result(self, index: int, results: Dict[str, VALUES]) -> None:
    """
    Modify a logically single result of existing parameters.

    It is an error to modify a result at an index less than zero or
    beyond the end of the DataSet.

    Args:
        index: zero-based index of the result to be modified.
        results: dictionary of updates with name of a parameter as the
            key and the value to associate as the value.

    Raises:
        CompletedError: if the DataSet is completed
        ValueError: if a key in ``results`` is not the name of a
            parameter in this DataSet
    """
    if self.completed:
        raise CompletedError

    # validate every parameter name before touching the database
    for param in results.keys():
        if param not in self.paramspecs.keys():
            raise ValueError(f'No such parameter: {param}.')

    # Bug fix: the original used `with atomic(self.conn) as self.conn`,
    # which rebound the instance attribute to the atomic connection and
    # left it rebound after the context exited. Bind to a local instead.
    with atomic(self.conn) as conn:
        modify_values(conn, self.table_name, index,
                      list(results.keys()),
                      list(results.values())
                      )
def test_atomic_on_connection_plus_that_is_in_progress(in_transaction):
    """When an atomic is already in progress, `atomic` leaves the
    connection's transaction state and isolation level untouched."""
    plus = ConnectionPlus(sqlite3.connect(':memory:'))

    # explicitly set to True for testing purposes
    plus.atomic_in_progress = True

    # implement parametrizing over connection's `in_transaction` attribute
    if in_transaction:
        plus.cursor().execute('BEGIN')
    assert in_transaction is plus.in_transaction

    isolation_level = plus.isolation_level
    in_transaction = plus.in_transaction

    with atomic(plus) as ctx:
        for c in (plus, ctx):
            assert True is c.atomic_in_progress
            assert isolation_level == c.isolation_level
            assert in_transaction is c.in_transaction

    for c in (plus, ctx):
        assert True is c.atomic_in_progress
        assert isolation_level == c.isolation_level
        assert in_transaction is c.in_transaction
def modify_results(self, start_index: int,
                   updates: List[Dict[str, VALUES]]):
    """
    Modify a sequence of results in the DataSet.

    It is an error to modify a result at an index less than zero or
    beyond the end of the DataSet. It is an error to provide a value for
    a key or keyword that is not the name of a parameter in this
    DataSet. It is an error to modify a result in a completed DataSet.

    Args:
        start_index: zero-based index of the first result to be modified.
        updates: sequence of dictionaries of updates with name of a
            parameter as the key and the value to associate as the value.
    """
    if self.completed:
        raise CompletedError

    # flatten the per-row update dicts into parallel key/value lists
    flattened_keys = [key for row in updates for key in row.keys()]
    flattened_values = [value for row in updates for value in row.values()]

    mod_params = set(flattened_keys)
    old_params = set(self.paramspecs.keys())
    if not mod_params.issubset(old_params):
        raise ValueError('Can not modify values for parameter(s) '
                         f'{mod_params.difference(old_params)}, '
                         'no such parameter(s) in the dataset.')

    with atomic(self.conn):
        modify_many_values(self.conn,
                           self.table_name,
                           start_index,
                           flattened_keys,
                           flattened_values)
def unsubscribe(self, uuid: str) -> None:
    """
    Remove subscriber with the provided uuid.

    Stops the subscriber's thread and drops its database trigger inside
    a single atomic transaction.

    Args:
        uuid: identifier of the subscriber to remove
    """
    # Bug fix: the original used `with atomic(self.conn) as self.conn`,
    # which rebound the instance attribute to the atomic connection and
    # left it rebound after the context exited. The yielded connection
    # is not needed by name, so no rebinding at all is required.
    with atomic(self.conn):
        self._remove_trigger(uuid)
        sub = self.subscribers[uuid]
        sub.schedule_stop()
        sub.join()
        del self.subscribers[uuid]
def test_atomic_raises(experiment):
    """An invalid statement executed inside `atomic` raises an error."""
    conn = experiment.conn

    # it seems that the type of error raised differs between python versions
    # 3.6.0 (OperationalError) and 3.6.3 (RuntimeError)
    # -strange, huh?
    expected_errors = (OperationalError, RuntimeError)
    with pytest.raises(expected_errors):
        with mut.atomic(conn):
            mut.transaction(conn, '""')
def test_atomic_on_outmost_connection_that_is_in_transaction():
    """`atomic` refuses a connection that already has an open
    (uncommitted) transaction started outside of it."""
    conn = ConnectionPlus(sqlite3.connect(':memory:'))

    conn.execute('BEGIN')
    assert True is conn.in_transaction

    expected = re.escape('SQLite connection has uncommitted transactions. '
                         'Please commit those before starting an atomic '
                         'transaction.')
    with pytest.raises(RuntimeError, match=expected):
        with atomic(conn):
            pass
def unsubscribe_all(self):
    """
    Remove all subscribers.

    Drops every trigger found in the database and stops every
    subscriber thread inside a single atomic transaction.
    """
    sql = "select * from sqlite_master where type = 'trigger';"
    triggers = atomic_transaction(self.conn, sql).fetchall()
    # Bug fix: the original used `with atomic(self.conn) as self.conn`,
    # which rebound the instance attribute to the atomic connection and
    # left it rebound after the context exited. The yielded connection
    # is not needed by name, so no rebinding at all is required.
    with atomic(self.conn):
        for trigger in triggers:
            self._remove_trigger(trigger['name'])
        for sub in self.subscribers.values():
            sub.schedule_stop()
            sub.join()
        self.subscribers.clear()
def fix_version_4a_run_description_bug(conn: ConnectionPlus) -> Dict[str, int]:
    """
    Fix function to fix a bug where the RunDescriber accidentally wrote itself
    to string using the (new) InterDependencies_ object instead of the (old)
    InterDependencies object. After the first run, this function should be
    idempotent.

    Args:
        conn: the connection to the database

    Returns:
        A dict with the fix results ('runs_inspected', 'runs_fixed')

    Raises:
        RuntimeError: if the database is not of version 4
    """
    user_version = get_user_version(conn)

    if not user_version == 4:
        # Bug fix: the second string was missing its f-prefix, so the
        # literal text '{user_version}' was shown instead of the value.
        raise RuntimeError('Database of wrong version. Will not apply fix. '
                           f'Expected version 4, found version {user_version}')

    no_of_runs_query = "SELECT max(run_id) FROM runs"
    no_of_runs = one(atomic_transaction(conn, no_of_runs_query), 'max(run_id)')
    no_of_runs = no_of_runs or 0

    with atomic(conn) as conn:
        pbar = tqdm(range(1, no_of_runs + 1))
        pbar.set_description("Fixing database")

        # collect some metrics
        runs_inspected = 0
        runs_fixed = 0

        for run_id in pbar:
            desc_str = get_run_description(conn, run_id)
            desc_ser = json.loads(desc_str)
            idps_ser = desc_ser['interdependencies']

            # old-style descriptions are already correct; only new-style
            # ones were affected by the bug and need rewriting
            if not RunDescriber._is_description_old_style(idps_ser):
                new_desc = RunDescriber.from_json(desc_str)
                update_run_description(conn, run_id, new_desc.to_json())
                runs_fixed += 1

            runs_inspected += 1

    return {'runs_inspected': runs_inspected, 'runs_fixed': runs_fixed}
def add_metadata(self, tag: str, metadata: Any):
    """
    Store a piece of metadata in the DataSet under the given tag.

    Note that None is not allowed as a metadata value.

    Args:
        tag: represents the key in the metadata dictionary
        metadata: actual metadata
    """
    self._metadata[tag] = metadata
    tag_dict = {tag: metadata}
    # `add_meta_data` is not atomic by itself, hence using `atomic`
    with atomic(self.conn) as atomic_conn:
        add_meta_data(atomic_conn, self.run_id, tag_dict)
def test_atomic_with_exception():
    """Changes made inside `atomic` are rolled back when an unhandled
    exception escapes the context."""
    sqlite_conn = sqlite3.connect(':memory:')
    plus = ConnectionPlus(sqlite_conn)

    sqlite_conn.execute('PRAGMA user_version(25)')
    sqlite_conn.commit()

    def read_user_version():
        return sqlite_conn.execute('PRAGMA user_version').fetchall()[0][0]

    assert 25 == read_user_version()

    with pytest.raises(RuntimeError,
                       match="Rolling back due to unhandled exception") as e:
        with atomic(plus) as ctx_conn:
            ctx_conn.execute('PRAGMA user_version(42)')
            raise Exception('intended exception')
    assert error_caused_by(e, 'intended exception')

    # the in-context write to user_version must have been rolled back
    assert 25 == read_user_version()
def extract_runs_into_db(source_db_path: str,
                         target_db_path: str, *run_ids: int,
                         upgrade_source_db: bool=False,
                         upgrade_target_db: bool=False) -> None:
    """
    Extract a selection of runs into another DB file. All runs must come from
    the same experiment. They will be added to an experiment with the same name
    and sample_name in the target db. If such an experiment does not exist, it
    will be created.

    Args:
        source_db_path: Path to the source DB file
        target_db_path: Path to the target DB file. The target DB file will be
            created if it does not exist.
        run_ids: The run_ids of the runs to copy into the target DB file
        upgrade_source_db: If the source DB is found to be in a version that
            is not the newest, should it be upgraded?
        upgrade_target_db: If the target DB is found to be in a version that
            is not the newest, should it be upgraded?
    """
    # Check for versions; refuse to proceed (with a warning, not an error)
    # if either DB file is outdated and auto-upgrade was not requested
    (s_v, new_v) = get_db_version_and_newest_available_version(source_db_path)
    if s_v < new_v and not upgrade_source_db:
        warn(f'Source DB version is {s_v}, but this function needs it to be'
             f' in version {new_v}. Run this function again with '
             'upgrade_source_db=True to auto-upgrade the source DB file.')
        return

    if os.path.exists(target_db_path):
        (t_v, new_v) = get_db_version_and_newest_available_version(
            target_db_path)
        if t_v < new_v and not upgrade_target_db:
            warn(f'Target DB version is {t_v}, but this function needs it to '
                 f'be in version {new_v}. Run this function again with '
                 'upgrade_target_db=True to auto-upgrade the target DB file.')
            return

    source_conn = connect(source_db_path)

    # Validate that all runs are in the source database
    do_runs_exist = is_run_id_in_database(source_conn, run_ids)
    if False in do_runs_exist.values():
        source_conn.close()
        non_existing_ids = [rid for rid in run_ids if not do_runs_exist[rid]]
        err_mssg = ("Error: not all run_ids exist in the source database. "
                    "The following run(s) is/are not present: "
                    f"{non_existing_ids}")
        raise ValueError(err_mssg)

    # Validate that all runs are from the same experiment
    source_exp_ids = np.unique(get_exp_ids_from_run_ids(source_conn, run_ids))
    if len(source_exp_ids) != 1:
        source_conn.close()
        raise ValueError('Did not receive runs from a single experiment. '
                         f'Got runs from experiments {source_exp_ids}')

    # Fetch the attributes of the runs' experiment
    # hopefully, this is enough to uniquely identify the experiment
    exp_attr_names = ['name', 'sample_name', 'start_time', 'end_time',
                      'format_string']

    exp_attr_vals = select_many_where(source_conn, 'experiments',
                                      *exp_attr_names,
                                      where_column='exp_id',
                                      where_value=source_exp_ids[0])

    exp_attrs = dict(zip(exp_attr_names, exp_attr_vals))

    # Massage the target DB file to accomodate the runs
    # (create new experiment if needed)

    target_conn = connect(target_db_path)

    # this function raises if the target DB file has several experiments
    # matching both the name and sample_name
    try:
        with atomic(target_conn) as target_conn:

            target_exp_id = _create_exp_if_needed(target_conn,
                                                  exp_attrs['name'],
                                                  exp_attrs['sample_name'],
                                                  exp_attrs['format_string'],
                                                  exp_attrs['start_time'],
                                                  exp_attrs['end_time'])

            # Finally insert the runs
            for run_id in run_ids:
                _extract_single_dataset_into_db(DataSet(run_id=run_id,
                                                        conn=source_conn),
                                                target_conn,
                                                target_exp_id)
    finally:
        # always close both connections, even if the copy fails mid-way
        source_conn.close()
        target_conn.close()
def test_that_use_of_atomic_commits_only_at_outermost_context(
        tmp_path, create_conn_plus):
    """
    This test tests the behavior of `ConnectionPlus` that is created from
    `sqlite3.Connection` with respect to `atomic` context manager and commits.
    """
    dbfile = str(tmp_path / 'temp.db')
    # just initialize the database file, connection objects needed for
    # testing in this test function are created separately, see below
    connect(dbfile)

    sqlite_conn = sqlite3.connect(dbfile)
    conn_plus = create_conn_plus(sqlite_conn)

    # this connection is going to be used to test whether changes have been
    # committed to the database file
    control_conn = connect(dbfile)

    get_all_runs = 'SELECT * FROM runs'
    insert_run_with_name = 'INSERT INTO runs (name) VALUES (?)'

    # assert that at the beginning of the test there are no runs in the
    # table; we'll be adding new rows to the runs table below
    assert 0 == len(conn_plus.execute(get_all_runs).fetchall())
    assert 0 == len(control_conn.execute(get_all_runs).fetchall())

    # add 1 new row, and assert the state of the runs table at every step
    # note that control_conn will only detect the change after the `atomic`
    # context manager is exited
    with atomic(conn_plus) as atomic_conn:
        assert 0 == len(conn_plus.execute(get_all_runs).fetchall())
        assert 0 == len(atomic_conn.execute(get_all_runs).fetchall())
        assert 0 == len(control_conn.execute(get_all_runs).fetchall())

        atomic_conn.cursor().execute(insert_run_with_name, ['aaa'])

        # the writing connections see the new row, the control one does not
        assert 1 == len(conn_plus.execute(get_all_runs).fetchall())
        assert 1 == len(atomic_conn.execute(get_all_runs).fetchall())
        assert 0 == len(control_conn.execute(get_all_runs).fetchall())

    # the context has exited: the row is committed and visible to everyone
    assert 1 == len(conn_plus.execute(get_all_runs).fetchall())
    assert 1 == len(atomic_conn.execute(get_all_runs).fetchall())
    assert 1 == len(control_conn.execute(get_all_runs).fetchall())

    # let's add two new rows but each inside its own `atomic` context manager
    # we expect to see the actual change in the database only after we exit
    # the outermost context.
    with atomic(conn_plus) as atomic_conn_1:
        assert 1 == len(conn_plus.execute(get_all_runs).fetchall())
        assert 1 == len(atomic_conn_1.execute(get_all_runs).fetchall())
        assert 1 == len(control_conn.execute(get_all_runs).fetchall())

        atomic_conn_1.cursor().execute(insert_run_with_name, ['bbb'])

        assert 2 == len(conn_plus.execute(get_all_runs).fetchall())
        assert 2 == len(atomic_conn_1.execute(get_all_runs).fetchall())
        assert 1 == len(control_conn.execute(get_all_runs).fetchall())

        with atomic(atomic_conn_1) as atomic_conn_2:
            assert 2 == len(conn_plus.execute(get_all_runs).fetchall())
            assert 2 == len(atomic_conn_1.execute(get_all_runs).fetchall())
            assert 2 == len(atomic_conn_2.execute(get_all_runs).fetchall())
            assert 1 == len(control_conn.execute(get_all_runs).fetchall())

            atomic_conn_2.cursor().execute(insert_run_with_name, ['ccc'])

            assert 3 == len(conn_plus.execute(get_all_runs).fetchall())
            assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall())
            assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall())
            assert 1 == len(control_conn.execute(get_all_runs).fetchall())

        # inner context exited, but the outer one is still open: the
        # control connection still does not see the new rows
        assert 3 == len(conn_plus.execute(get_all_runs).fetchall())
        assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall())
        assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall())
        assert 1 == len(control_conn.execute(get_all_runs).fetchall())

    # outermost context exited: both rows are now committed and visible
    # to the control connection as well
    assert 3 == len(conn_plus.execute(get_all_runs).fetchall())
    assert 3 == len(atomic_conn_1.execute(get_all_runs).fetchall())
    assert 3 == len(atomic_conn_2.execute(get_all_runs).fetchall())
    assert 3 == len(control_conn.execute(get_all_runs).fetchall())