def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)
  candidate_dict = {}
  for k, c in source_channels.items():
    candidate_artifacts = metadata_handler.get_qualified_artifacts(
        contexts=[pipeline_context],
        type_name=c.type_name,
        producer_component_id=c.producer_component_id,
        output_key=c.output_key)
    candidate_dict[k] = [
        artifact_utils.deserialize_artifact(a.type, a.artifact)
        for a in candidate_artifacts
    ]
  resolved_dict = self._resolve(candidate_dict)
  resolve_state_dict = {
      k: len(artifact_list) >= self._desired_num_of_artifact
      for k, artifact_list in resolved_dict.items()
  }
  return base_resolver.ResolveResult(
      per_key_resolve_result=resolved_dict,
      per_key_resolve_state=resolve_state_dict)

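# The resolve() method above relies on two members that are not defined in
# this snippet: a `_resolve` hook that trims each candidate list, and a
# `_desired_num_of_artifact` attribute used to judge completeness. The class
# below is a minimal, hypothetical sketch of that assumed contract; it is not
# the actual TFX implementation. It assumes the deserialized artifacts expose
# an `id` property, as TFX artifacts do.
class _LatestNArtifactsResolver:
  """Hypothetical sketch: keep the newest N candidates per key."""

  def __init__(self, desired_num_of_artifact=1):
    self._desired_num_of_artifact = desired_num_of_artifact

  def _resolve(self, candidate_dict):
    # Sort each candidate list newest-first by MLMD id and keep the top N.
    return {
        key: sorted(
            artifacts, key=lambda a: a.id,
            reverse=True)[:self._desired_num_of_artifact]
        for key, artifacts in candidate_dict.items()
    }
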
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> resolver.ResolveResult:
  # First, checks whether we have exactly Model and ModelBlessing Channels.
  model_channel_key = None
  model_blessing_channel_key = None
  assert len(source_channels) == 2, 'Expecting 2 input Channels'
  for k, c in source_channels.items():
    if issubclass(c.type, standard_artifacts.Model):
      model_channel_key = k
    elif issubclass(c.type, standard_artifacts.ModelBlessing):
      model_blessing_channel_key = k
    else:
      raise RuntimeError('Only expecting Model or ModelBlessing, got %s' %
                         c.type)
  assert model_channel_key is not None, 'Expecting Model as input'
  assert model_blessing_channel_key is not None, ('Expecting ModelBlessing as'
                                                  ' input')
  model_channel = source_channels[model_channel_key]
  model_blessing_channel = source_channels[model_blessing_channel_key]

  # Gets the pipeline context as the artifact search space.
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

  candidate_dict = {}
  # Gets all Model artifacts in the search space.
  all_models = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_channel.type_name,
      producer_component_id=model_channel.producer_component_id,
      output_key=model_channel.output_key)
  candidate_dict[model_channel_key] = [
      artifact_utils.deserialize_artifact(a.type, a.artifact)
      for a in all_models
  ]
  # Gets all ModelBlessing artifacts in the search space.
  all_model_blessings = metadata_handler.get_qualified_artifacts(
      contexts=[pipeline_context],
      type_name=model_blessing_channel.type_name,
      producer_component_id=model_blessing_channel.producer_component_id,
      output_key=model_blessing_channel.output_key)
  candidate_dict[model_blessing_channel_key] = [
      artifact_utils.deserialize_artifact(a.type, a.artifact)
      for a in all_model_blessings
  ]

  resolved_dict = self._resolve(candidate_dict, model_channel_key,
                                model_blessing_channel_key)
  resolve_state_dict = {
      k: bool(artifact_list) for k, artifact_list in resolved_dict.items()
  }
  return resolver.ResolveResult(
      per_key_resolve_result=resolved_dict,
      per_key_resolve_state=resolve_state_dict)

def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  artifacts_dict = {}
  resolve_state_dict = {}
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)
  artifacts_in_context = (
      metadata_handler.get_published_artifacts_by_type_within_context(
          [c.type_name for c in source_channels.values()],
          pipeline_context.id))
  for k, c in source_channels.items():
    previous_artifacts = sorted(
        artifacts_in_context[c.type_name], key=lambda m: m.id, reverse=True)
    if len(previous_artifacts) >= self._desired_num_of_artifact:
      artifacts_dict[k] = [
          _generate_tfx_artifact(a, c.type)
          for a in previous_artifacts[:self._desired_num_of_artifact]
      ]
      resolve_state_dict[k] = True
    else:
      artifacts_dict[k] = [
          _generate_tfx_artifact(a, c.type) for a in previous_artifacts
      ]
      resolve_state_dict[k] = False
  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)

def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  artifacts_dict = {}
  resolve_state_dict = {}
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)
  for k, c in source_channels.items():
    candidate_artifacts = metadata_handler.get_qualified_artifacts(
        context=pipeline_context,
        type_name=c.type_name,
        producer_component_id=c.producer_component_id,
        output_key=c.output_key)
    previous_artifacts = sorted(
        candidate_artifacts, key=lambda a: a.artifact.id, reverse=True)
    if len(previous_artifacts) >= self._desired_num_of_artifact:
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts[:self._desired_num_of_artifact]
      ]
      resolve_state_dict[k] = True
    else:
      artifacts_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in previous_artifacts
      ]
      resolve_state_dict[k] = False
  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)

def resolve(
    self,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  artifacts_dict = {}
  resolve_state_dict = {}
  for k, c in source_channels.items():
    previous_artifacts = sorted(
        metadata_handler.get_artifacts_by_type(c.type_name),
        key=lambda m: m.id,
        reverse=True)
    if len(previous_artifacts) >= self._desired_num_of_artifact:
      artifacts_dict[k] = [
          _generate_tfx_artifact(a, c.type)
          for a in previous_artifacts[:self._desired_num_of_artifact]
      ]
      resolve_state_dict[k] = True
    else:
      artifacts_dict[k] = [
          _generate_tfx_artifact(a, c.type) for a in previous_artifacts
      ]
      resolve_state_dict[k] = False
  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)

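# Several resolvers in this section call a `_generate_tfx_artifact` helper
# that is not shown here. Judging from the patterns visible elsewhere in the
# section (an artifact object is created, then populated from a raw MLMD
# proto), it presumably looks like the minimal sketch below. Depending on the
# TFX version, the setter is named `set_mlmd_artifact` (as in
# `_prepare_artifact` later in this section) or `set_artifact`; this is an
# assumption, not the verified implementation.
def _generate_tfx_artifact(mlmd_artifact, artifact_type):
  """Wraps a raw MLMD artifact proto into a typed TFX Artifact."""
  result = artifact_type()
  result.set_mlmd_artifact(mlmd_artifact)
  return result
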
def test_fetch_previous_result(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    # Create a 'previous' execution.
    exec_properties = {'log_root': 'path'}
    eid = m.prepare_execution('Test', exec_properties)
    input_artifact = types.TfxType(type_name='ExamplesPath')
    m.publish_artifacts([input_artifact])
    output_artifact = types.TfxType(type_name='ExamplesPath')
    input_dict = {'input': [input_artifact]}
    output_dict = {'output': [output_artifact]}
    m.publish_execution(eid, input_dict, output_dict)

    # Test previous_run.
    self.assertEqual(None, m.previous_run('Test', input_dict, {}))
    self.assertEqual(None, m.previous_run('Test', {}, exec_properties))
    self.assertEqual(None,
                     m.previous_run('Test2', input_dict, exec_properties))
    self.assertEqual(eid, m.previous_run('Test', input_dict, exec_properties))

    # Test fetch_previous_result_artifacts.
    new_output_artifact = types.TfxType(type_name='ExamplesPath')
    self.assertNotEqual(types.ARTIFACT_STATE_PUBLISHED,
                        new_output_artifact.state)
    new_output_dict = {'output': [new_output_artifact]}
    updated_output_dict = m.fetch_previous_result_artifacts(
        new_output_dict, eid)
    previous_artifact = output_dict['output'][-1].artifact
    current_artifact = updated_output_dict['output'][-1].artifact
    self.assertEqual(types.ARTIFACT_STATE_PUBLISHED,
                     current_artifact.properties['state'].string_value)
    self.assertEqual(previous_artifact.id, current_artifact.id)
    self.assertEqual(previous_artifact.type_id, current_artifact.type_id)

def test_execution(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    # Test prepare_execution.
    exec_properties = {}
    eid = m.prepare_execution('Test', exec_properties)
    [execution] = m.store.get_executions()
    self.assertProtoEquals(
        """
        id: 1
        type_id: 1
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }""", execution)

    # Test publish_execution.
    input_artifact = types.TfxArtifact(type_name='ExamplesPath')
    m.publish_artifacts([input_artifact])
    output_artifact = types.TfxArtifact(type_name='ExamplesPath')
    input_dict = {'input': [input_artifact]}
    output_dict = {'output': [output_artifact]}
    m.publish_execution(eid, input_dict, output_dict)
    # Make sure artifacts in output_dict are published.
    self.assertEqual(types.ARTIFACT_STATE_PUBLISHED, output_artifact.state)
    # Make sure the execution state is changed.
    [execution] = m.store.get_executions_by_id([eid])
    self.assertEqual('complete', execution.properties['state'].string_value)
    # Make sure events are published.
    events = m.store.get_events_by_execution_ids([eid])
    self.assertEqual(2, len(events))
    self.assertEqual(input_artifact.id, events[0].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.DECLARED_INPUT, events[0].type)
    self.assertProtoEquals(
        """
        steps {
          key: "input"
        }
        steps {
          index: 0
        }""", events[0].path)
    self.assertEqual(output_artifact.id, events[1].artifact_id)
    self.assertEqual(metadata_store_pb2.Event.DECLARED_OUTPUT,
                     events[1].type)
    self.assertProtoEquals(
        """
        steps {
          key: "output"
        }
        steps {
          index: 0
        }""", events[1].path)

def test_empty_artifact(self):
  with Metadata(self._connection_config) as m:
    m.publish_artifacts([])
    eid = m.prepare_execution('Test', {})
    m.publish_execution(eid, {}, {})
    [execution] = m.store.get_executions_by_id([eid])
    self.assertProtoEquals(
        """
        id: 1
        type_id: 1
        properties {
          key: "state"
          value {
            string_value: "complete"
          }
        }""", execution)

def resolve(
    self,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  artifacts_dict = {}
  for k, c in source_channels.items():
    previous_artifacts = metadata_handler.get_artifacts_by_type(c.type_name)
    if previous_artifacts:
      # Pick the artifact with the largest MLMD id, i.e. the latest one.
      latest_mlmd_artifact = max(previous_artifacts, key=lambda m: m.id)
      result_artifact = types.Artifact(type_name=c.type_name)
      result_artifact.set_artifact(latest_mlmd_artifact)
      artifacts_dict[k] = ([result_artifact], True)
    else:
      artifacts_dict[k] = ([], False)
  return base_resolver.ResolveResult(per_key_resolve_result=artifacts_dict)

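# A hypothetical usage sketch for the resolver above, assuming an old-style
# TFX API in which Channel is constructed with `type_name` (matching the
# `types.Artifact(type_name=...)` call in the snippet) and assuming the
# method lives on a class named `LatestArtifactResolver` with an open MLMD
# connection config at hand; these names are illustrative, not taken from
# this section.
with metadata.Metadata(connection_config=connection_config) as m:
  result = LatestArtifactResolver().resolve(
      metadata_handler=m,
      source_channels={'examples': types.Channel(type_name='ExamplesPath')})
  # In this variant each key maps to an (artifacts, resolved) tuple.
  latest_examples, resolved = result.per_key_resolve_result['examples']
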
def test_artifact(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    self.assertListEqual([], m.get_all_artifacts())

    # Test publish artifact.
    artifact = types.TfxType(type_name='ExamplesPath')
    m.publish_artifacts([artifact])
    [artifact] = m.store.get_artifacts()
    self.assertProtoEquals(
        """id: 1
        type_id: 1
        uri: ""
        properties {
          key: "split"
          value {
            string_value: ""
          }
        }
        properties {
          key: "state"
          value {
            string_value: "published"
          }
        }
        properties {
          key: "type_name"
          value {
            string_value: "ExamplesPath"
          }
        }""", artifact)

    # Test get artifact.
    self.assertListEqual([artifact], m.get_all_artifacts())

    # Test artifact state.
    m.check_artifact_state(artifact, types.ARTIFACT_STATE_PUBLISHED)
    m.update_artifact_state(artifact, types.ARTIFACT_STATE_DELETED)
    m.check_artifact_state(artifact, types.ARTIFACT_STATE_DELETED)
    self.assertRaises(RuntimeError, m.check_artifact_state, artifact,
                      types.ARTIFACT_STATE_PUBLISHED)

def test_get_cached_execution_ids(self):
  with Metadata(
      connection_config=self._connection_config, logger=self._logger) as m:
    mock_store = mock.Mock()
    mock_store.get_events_by_execution_ids.side_effect = [
        [
            metadata_store_pb2.Event(
                artifact_id=1, type=metadata_store_pb2.Event.INPUT)
        ],
        [
            metadata_store_pb2.Event(
                artifact_id=1, type=metadata_store_pb2.Event.INPUT),
            metadata_store_pb2.Event(
                artifact_id=2, type=metadata_store_pb2.Event.INPUT),
            metadata_store_pb2.Event(
                artifact_id=3, type=metadata_store_pb2.Event.INPUT)
        ],
        [
            metadata_store_pb2.Event(
                artifact_id=1, type=metadata_store_pb2.Event.INPUT),
            metadata_store_pb2.Event(
                artifact_id=2, type=metadata_store_pb2.Event.INPUT),
        ],
    ]
    m._store = mock_store

    input_one = types.TfxArtifact(type_name='ExamplesPath')
    input_one.id = 1
    input_two = types.TfxArtifact(type_name='ExamplesPath')
    input_two.id = 2
    input_dict = {
        'input_one': [input_one],
        'input_two': [input_two],
    }
    self.assertEqual(1, m._get_cached_execution_id(input_dict, [3, 2, 1]))

def _prepare_artifact(
    metadata_handler: metadata.Metadata, uri: Text,
    properties: Dict[Text, Any], custom_properties: Dict[Text, Any],
    reimport: bool, output_artifact_class: Type[types.Artifact],
    mlmd_artifact_type: Optional[metadata_store_pb2.ArtifactType]
) -> types.Artifact:
  """Prepares the Importer's output artifact.

  If there is already an artifact in MLMD with the same URI and properties /
  custom properties, that artifact will be reused unless the `reimport`
  argument is set to True.

  Args:
    metadata_handler: The handler of MLMD.
    uri: The uri of the artifact.
    properties: The properties of the artifact, given as a dictionary from
      string keys to integer / string values. Must conform to the declared
      properties of the destination channel's output type.
    custom_properties: The custom properties of the artifact, given as a
      dictionary from string keys to integer / string values.
    reimport: If set to True, will register a new artifact even if it already
      exists in the database.
    output_artifact_class: The class of the output artifact.
    mlmd_artifact_type: The MLMD artifact type of the Artifact to be created.

  Returns:
    An Artifact object representing the imported artifact.
  """
  absl.logging.info(
      'Processing source uri: %s, properties: %s, custom_properties: %s' %
      (uri, properties, custom_properties))

  # Check types of custom properties.
  for key, value in custom_properties.items():
    if not isinstance(value, (int, Text, bytes)):
      raise ValueError(
          ('Custom property value for key %r must be a string or integer '
           '(got %r instead)') % (key, value))

  unfiltered_previous_artifacts = metadata_handler.get_artifacts_by_uri(uri)
  # Only consider previous artifacts as candidates to reuse, if the properties
  # of the imported artifact match those of the existing artifact.
  previous_artifacts = []
  for candidate_mlmd_artifact in unfiltered_previous_artifacts:
    is_candidate = True
    candidate_artifact = output_artifact_class(mlmd_artifact_type)
    candidate_artifact.set_mlmd_artifact(candidate_mlmd_artifact)
    for key, value in properties.items():
      if getattr(candidate_artifact, key) != value:
        is_candidate = False
        break
    for key, value in custom_properties.items():
      if isinstance(value, int):
        if candidate_artifact.get_int_custom_property(key) != value:
          is_candidate = False
          break
      elif isinstance(value, (Text, bytes)):
        if candidate_artifact.get_string_custom_property(key) != value:
          is_candidate = False
          break
    if is_candidate:
      previous_artifacts.append(candidate_mlmd_artifact)

  result = output_artifact_class(mlmd_artifact_type)
  result.uri = uri
  for key, value in properties.items():
    setattr(result, key, value)
  for key, value in custom_properties.items():
    if isinstance(value, int):
      result.set_int_custom_property(key, value)
    elif isinstance(value, (Text, bytes)):
      result.set_string_custom_property(key, value)

  # If a registered artifact has the same uri and properties and the user does
  # not explicitly ask for reimport, reuse that artifact.
  if bool(previous_artifacts) and not reimport:
    absl.logging.info('Reusing existing artifact')
    result.set_mlmd_artifact(max(previous_artifacts, key=lambda m: m.id))

  return result

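# A hypothetical invocation of _prepare_artifact, assuming an open Metadata
# handler `m` and a standard_artifacts.Schema output type; the uri, the
# custom property key/value, and the handler name are illustrative only.
imported_schema = _prepare_artifact(
    metadata_handler=m,
    uri='/tmp/imported/schema',  # hypothetical source location
    properties={},
    custom_properties={'origin': 'manual_import'},  # hypothetical key/value
    reimport=False,
    output_artifact_class=standard_artifacts.Schema,
    mlmd_artifact_type=None)
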
def resolve(
    self,
    pipeline_info: data_types.PipelineInfo,
    metadata_handler: metadata.Metadata,
    source_channels: Dict[Text, types.Channel],
) -> base_resolver.ResolveResult:
  # First, checks whether we have exactly Model and ModelBlessing Channels.
  model_channel_key = None
  model_blessing_channel_key = None
  assert len(source_channels) == 2, 'Expecting 2 input Channels'
  for k, c in source_channels.items():
    if issubclass(c.type, standard_artifacts.Model):
      model_channel_key = k
    elif issubclass(c.type, standard_artifacts.ModelBlessing):
      model_blessing_channel_key = k
    else:
      raise RuntimeError('Only expecting Model or ModelBlessing, got %s' %
                         c.type)
  assert model_channel_key is not None, 'Expecting Model as input'
  assert model_blessing_channel_key is not None, ('Expecting ModelBlessing as'
                                                  ' input')

  # Gets the pipeline context as the artifact search space.
  pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
  if pipeline_context is None:
    raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

  # Gets all artifacts of interest within the context with one call.
  artifacts_in_context = (
      metadata_handler.get_published_artifacts_by_type_within_context(
          [
              source_channels[model_channel_key].type_name,
              source_channels[model_blessing_channel_key].type_name
          ], pipeline_context.id))
  # Gets all models in the search space, sorted in reverse order by id.
  all_models = sorted(
      artifacts_in_context[source_channels[model_channel_key].type_name],
      key=lambda m: m.id,
      reverse=True)
  # Gets all ModelBlessing artifacts in the search space.
  all_model_blessings = artifacts_in_context[
      source_channels[model_blessing_channel_key].type_name]
  # Makes a dict of {model_id: ModelBlessing artifact} for blessed models.
  all_blessed_model_ids = dict((  # pylint: disable=g-complex-comprehension
      a.custom_properties[
          model_validator.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY].int_value,
      a) for a in all_model_blessings if a.custom_properties[
          model_validator.ARTIFACT_PROPERTY_BLESSED_KEY].int_value == 1)

  artifacts_dict = {model_channel_key: [], model_blessing_channel_key: []}
  resolve_state_dict = {
      model_channel_key: False,
      model_blessing_channel_key: False
  }
  # Iterates over all models; if one is blessed, sets it as the result. As
  # the model list was sorted, this is guaranteed to pick the latest blessed
  # model.
  for model in all_models:
    if model.id in all_blessed_model_ids:
      artifacts_dict[model_channel_key] = [
          _generate_tfx_artifact(model, standard_artifacts.Model)
      ]
      artifacts_dict[model_blessing_channel_key] = [
          _generate_tfx_artifact(all_blessed_model_ids[model.id],
                                 standard_artifacts.ModelBlessing)
      ]
      resolve_state_dict[model_channel_key] = True
      resolve_state_dict[model_blessing_channel_key] = True
      break

  return base_resolver.ResolveResult(
      per_key_resolve_result=artifacts_dict,
      per_key_resolve_state=resolve_state_dict)

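# A sketch of how a resolver such as the one above is typically attached to a
# pipeline, assuming the ResolverNode API and module layout of TFX 0.2x; the
# imports, class path, and instance name are assumptions based on that
# version, not taken from this section.
from tfx import types
from tfx.components.common_nodes import resolver_node
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.types import standard_artifacts

# The node emits the latest blessed Model / ModelBlessing pair on its output
# channels, which downstream components (e.g. Evaluator) can consume.
model_resolver = resolver_node.ResolverNode(
    instance_name='latest_blessed_model_resolver',
    resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
    model=types.Channel(type=standard_artifacts.Model),
    model_blessing=types.Channel(type=standard_artifacts.ModelBlessing))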