def test_create_snapshot_by_dictionary(self):
    """Creating from the input dict persists its fields and assigns ids.

    Two creates from the same dict receive distinct ids, and an ``id``
    supplied in the input dict is kept as the snapshot id.
    """
    source = self.snapshot_input_dict
    created = self.dal.snapshot.create(Snapshot(source))
    assert created.id
    # Every supplied field must round-trip unchanged through the DAL
    for field in ("model_id", "session_id", "message", "code_id",
                  "environment_id", "file_collection_id", "config",
                  "stats"):
        assert getattr(created, field) == source[field]
    assert created.created_at
    assert created.updated_at

    # A second create from the identical dict must get its own id
    duplicate = self.dal.snapshot.create(Snapshot(source))
    assert duplicate.id != created.id

    # A caller-provided id is honoured
    explicit = source.copy()
    explicit['id'] = "snapshot_id"
    with_id = self.dal.snapshot.create(Snapshot(explicit))
    assert with_id.id == explicit['id']
def test_query_snapshots(self):
    """Query snapshots by id, by code_id and by the visible flag.

    Assumes the DAL starts empty for this test, since the counts below are
    absolute.
    """
    snapshot = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    # Querying by a unique id must match exactly the snapshot just created
    assert len(self.dal.snapshot.query({"id": snapshot.id})) == 1
    # A second snapshot built from the same input dict shares its code_id,
    # so both are returned
    _ = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    assert len(
        self.dal.snapshot.query(
            {"code_id": self.snapshot_input_dict['code_id']})) == 2
    # Both snapshots are expected to be visible (presumably the input dict
    # leaves ``visible`` at its default)
    assert len(self.dal.snapshot.query({"visible": True})) == 2
def test_sort_snapshots(self):
    """Query results honour ``sort_key`` and ``sort_order``.

    An unrecognised ``sort_order`` raises InvalidArgumentType whether or not
    ``sort_key`` is valid. An unrecognised ``sort_key`` with a valid order
    still returns the same set of snapshots.
    """
    def run_query(key, order):
        # Build a fresh filter dict on every call, as each query expects
        return self.dal.snapshot.query(
            {"model_id": self.snapshot_input_dict["model_id"]},
            sort_key=key,
            sort_order=order)

    older = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    newer = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))

    # Descending puts the most recently created snapshot first
    newest_first = run_query("created_at", "descending")
    assert newest_first[0].created_at == newer.created_at
    # Ascending puts the earliest created snapshot first
    oldest_first = run_query("created_at", "ascending")
    assert oldest_first[0].created_at == older.created_at

    # An invalid order is rejected with either a valid or an invalid key
    for key in ("created_at", "wrong_key"):
        raised = False
        try:
            _ = run_query(key, "wrong_order")
        except InvalidArgumentType:
            raised = True
        assert raised

    # An invalid key with a valid order still yields every snapshot
    reference = run_query("created_at", "ascending")
    fallback = run_query("wrong_key", "ascending")
    reference_ids = {item.id for item in reference}
    fallback_ids = {item.id for item in fallback}
    assert reference_ids == fallback_ids
def test_query_snapshots_range_query(self):
    """A ``$lt`` filter on ``created_at`` returns only strictly older snapshots."""
    for _ in range(3):
        self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    newest_first = self.dal.snapshot.query(
        {}, sort_key="created_at", sort_order="descending")
    # Use the middle snapshot's timestamp as the cutoff; only the oldest
    # of the three was created before it
    cutoff = newest_first[1].created_at.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
    earlier = self.dal.snapshot.query({"created_at": {"$lt": cutoff}})
    assert len(newest_first) == 3
    assert len(earlier) == 1
def test_to_dictionary(self):
    """``to_dictionary`` mirrors the entity's attributes.

    With ``stringify=True`` the free-form and timestamp fields must come back
    as strings, while every other field still equals the attribute value.
    """
    snapshot_entity = Snapshot(self.input_dict)
    output_dict = snapshot_entity.to_dictionary()
    for k, v in output_dict.items():
        assert v == getattr(snapshot_entity, k)

    # Test stringify
    stringified_keys = {
        "config", "stats", "message", "label", "created_at", "updated_at"
    }
    output_dict = snapshot_entity.to_dictionary(stringify=True)
    for k, v in output_dict.items():
        if k in stringified_keys:
            # Check the value, not the key: dictionary keys here are always
            # str, so checking ``k`` could never fail
            assert isinstance(v, str)
        else:
            assert v == getattr(snapshot_entity, k)
def test_delete_snapshot(self):
    """After ``delete`` the snapshot can no longer be fetched by id."""
    snapshot = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    self.dal.snapshot.delete(snapshot.id)
    try:
        self.dal.snapshot.get_by_id(snapshot.id)
    except EntityNotFound:
        still_present = False
    else:
        still_present = True
    assert not still_present
def test_init_with_id(self):
    """An explicit ``id`` is kept, and unset optional fields get defaults."""
    # Work on a copy so the fixture dict is not mutated for any other use
    # of self.input_dict
    input_dict = self.input_dict.copy()
    input_dict['id'] = "test"
    snapshot_entity = Snapshot(input_dict)
    for k, v in input_dict.items():
        assert getattr(snapshot_entity, k) == v
    # Compare against singletons by identity (PEP 8)
    assert snapshot_entity.task_id is None
    assert snapshot_entity.label is None
    assert snapshot_entity.visible is True
    assert snapshot_entity.created_at
    assert snapshot_entity.updated_at
def test_get_by_id_snapshot_new_driver_instance(self):
    """A stored snapshot can be fetched through a freshly built DAL.

    Both a DAL over a brand-new BlitzDB driver and a DAL over the existing
    driver must find the snapshot.
    """
    snapshot = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    # New DAL backed by a new driver instance on self.temp_dir (presumably
    # the same directory self.datadriver uses); lookup is expected to succeed
    new_driver_instance = BlitzDBDALDriver("file", self.temp_dir)
    new_dal_instance = LocalDAL(new_driver_instance)
    new_snapshot_1 = new_dal_instance.snapshot.get_by_id(snapshot.id)
    assert new_snapshot_1.id == snapshot.id
    # New DAL instance reusing the existing driver; lookup must also succeed
    new_dal_instance = LocalDAL(self.datadriver)
    new_snapshot_2 = new_dal_instance.snapshot.get_by_id(snapshot.id)
    assert new_snapshot_2.id == snapshot.id
def test_update_snapshot(self):
    """Updating message and label keeps the id and advances ``updated_at``."""
    original = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    # Same input, targeting the created id, with new message and label
    changes = self.snapshot_input_dict.copy()
    changes.update(id=original.id, message="this is really cool", label="new")
    updated = self.dal.snapshot.update(changes)
    assert original.id == updated.id
    assert original.updated_at < updated.updated_at
    assert updated.message == changes['message']
    assert updated.label == changes['label']
def test_to_dictionary(self):
    """Every entry of ``to_dictionary()`` equals the matching attribute."""
    # NOTE(review): this file also defines another ``test_to_dictionary``;
    # if both sit in the same class, the later one shadows the earlier.
    # Confirm they belong to different test classes.
    entity = Snapshot(self.input_dict)
    assert all(
        value == getattr(entity, name)
        for name, value in entity.to_dictionary().items())
def test_eq(self):
    """Two entities built from the same input dict compare equal."""
    first, second = Snapshot(self.input_dict), Snapshot(self.input_dict)
    assert first == second
def create(self, dictionary):
    """Create a snapshot, or return an existing one with the same components.

    Parameters
    ----------
    dictionary : dict
        Input describing the five key components below. For each component
        the listed variables are tried from the top down; if none is given,
        the documented default applies.

        code :
            code_id : str, optional
                code reference associated with the snapshot; if absent, the
                inputs below are used to create code
            commit_id : str, optional
                commit id supplied by the user if already available
            Default
            -------
            commits are taken and code is created via the CodeController,
            then attached to the snapshot when it is created

        environment :
            environment_id : str, optional
                id of the environment used to create the snapshot
            environment_paths : list, optional
                absolute or relative filepaths and/or dirpaths to collect,
                with destination names (e.g. "/path/to/file>hello",
                "/path/to/file2", "/path/to/dir>newdir")
            Default
            -------
            default environment files are searched for, an environment is
            created with the EnvironmentController and attached to the
            snapshot when it is created

        file_collection :
            file_collection_id : str, optional
                file collection associated with the snapshot
            paths : list, optional
                absolute or relative filepaths and/or dirpaths to collect,
                with destination names (e.g. "/path/to/file:hello",
                "/path/to/file2", "/path/to/dir:newdir")
                NOTE(review): this example uses ':' as the destination
                separator while environment_paths uses '>'; confirm which
                separator is expected here.
            Default
            -------
            paths are treated as empty ([]) and the FileCollectionController
            creates an empty FileCollection

        config :
            config : dict, optional
                key, value pairs of configurations
            config_filepath : str, optional
                absolute filepath to a configuration parameters file
            config_filename : str, optional
                name of a file with configuration parameters
            Default
            -------
            config is treated as empty ({}) and saved to the snapshot

        stats :
            stats : dict, optional
                key, value pairs of metrics and statistics
            stats_filepath : str, optional
                absolute filepath to a stats parameters file
            stats_filename : str, optional
                name of a file with metrics and statistics
            Default
            -------
            stats are treated as empty ({}) and saved to the snapshot

        The remaining arguments are also read from the input dictionary:

        message : str
            long description of the snapshot (required)
        session_id : str, optional
            session id within which the snapshot is created; meant to
            overwrite the default if given.
            NOTE(review): this method always sets session_id from
            self.current_session and never reads dictionary['session_id'];
            confirm whether the override is handled elsewhere.
        task_id : str, optional
            task id associated with the snapshot
        label : str, optional
            short description of the snapshot
        visible : bool, optional
            True if visible to the user via the list command, else False

    Returns
    -------
    datmo.core.entity.snapshot.Snapshot
        A newly created snapshot. If one already exists with identical
        model_id, code_id, environment_id, file_collection_id, config and
        stats, that existing snapshot is returned instead. In that case
        task_id, label and visible from the input are not applied.

    Raises
    ------
    RequiredArgumentMissing
        if ``message`` is not present in the input dictionary
    FileIOError
        if files are not present or there is an error in file IO
        (presumably raised by the component setup helpers)
    """
    # Base fields come from controller state rather than from `dictionary`
    create_dict = {
        "model_id": self.model.id,
        "session_id": self.current_session.id,
    }
    validate("create_snapshot", dictionary)

    # A message is mandatory
    if "message" not in dictionary:
        raise RequiredArgumentMissing(
            __("error", "controller.snapshot.create.arg", "message"))
    create_dict['message'] = dictionary['message']

    # Resolve code, environment, files, config and stats, in that order;
    # each helper writes its results into create_dict
    component_setups = (
        self._code_setup,
        self._env_setup,
        self._file_setup,
        self._config_setup,
        self._stats_setup,
    )
    for setup in component_setups:
        setup(dictionary, create_dict)

    # Never create a second snapshot with the same required arguments:
    # return the first match if one already exists
    dedup_keys = ("model_id", "code_id", "environment_id",
                  "file_collection_id", "config", "stats")
    matches = self.dal.snapshot.query(
        {key: create_dict[key] for key in dedup_keys})
    if matches:
        return matches[0]

    # Optional entity fields are copied only when the caller supplied them
    create_dict.update({
        key: dictionary[key]
        for key in ("task_id", "label", "visible")
        if key in dictionary
    })
    return self.dal.snapshot.create(Snapshot(create_dict))
def test_str(self):
    """``str(entity)`` mentions every input value except ``model_id``."""
    snapshot_entity = Snapshot(self.input_dict)
    for key, value in self.input_dict.items():
        if key == "model_id":
            continue
        assert str(value) in str(snapshot_entity)
def test_get_by_shortened_id_snapshot(self):
    """A 10-character id prefix resolves to the full snapshot."""
    created = self.dal.snapshot.create(Snapshot(self.snapshot_input_dict))
    prefix = created.id[:10]
    found = self.dal.snapshot.get_by_shortened_id(prefix)
    assert found.id == created.id