def test_channel_utils_as_channel_success(self):
  """`as_channel` applied to a Channel yields a Channel equal to the input."""
  same_type_artifacts = [Artifact('MyTypeName'), Artifact('MyTypeName')]
  source_channel = Channel('MyTypeName', artifacts=same_type_artifacts)

  converted = channel_utils.as_channel(source_channel)

  self.assertEqual(source_channel, converted)
def testUnwrapChannelDict(self):
  """`unwrap_channel_dict` maps each key to the artifacts of its Channel."""
  first, second = Artifact('MyTypeName'), Artifact('MyTypeName')
  wrapped = {'id': Channel('MyTypeName', artifacts=[first, second])}
  expected = {'id': [first, second]}

  unwrapped = channel_utils.unwrap_channel_dict(wrapped)

  self.assertDictEqual(unwrapped, expected)
def deserialize_artifact(
    artifact_type: metadata_store_pb2.ArtifactType,
    artifact: Optional[metadata_store_pb2.Artifact] = None) -> Artifact:
  """Reconstruct Artifact object from MLMD proto descriptors.

  Internal method, no backwards compatibility guarantees.

  Args:
    artifact_type: A metadata_store_pb2.ArtifactType proto object describing the
      type of the artifact.
    artifact: A metadata_store_pb2.Artifact proto object describing the contents
      of the artifact. If None, an Artifact of the desired type with empty
      contents is created.

  Returns:
    Artifact subclass object for the given MLMD proto descriptors. When no
    direct subclass of Artifact has a matching TYPE_NAME, a plain Artifact
    built from `artifact_type` is returned instead and a warning is logged.

  Raises:
    ValueError: if `artifact_type` is not a metadata_store_pb2.ArtifactType, or
      if `artifact` is neither None nor a metadata_store_pb2.Artifact.
  """
  # Validate inputs.
  if not isinstance(artifact_type, metadata_store_pb2.ArtifactType):
    raise ValueError(
        ('Expected metadata_store_pb2.ArtifactType for artifact_type, got %s '
         'instead') % (artifact_type,))
  # Compare against None instead of relying on truthiness, so that a falsy
  # value of the wrong type (e.g. '' or 0) is rejected rather than being
  # silently treated as "no artifact given".
  if artifact is not None and not isinstance(artifact,
                                             metadata_store_pb2.Artifact):
    raise ValueError(
        ('Expected metadata_store_pb2.Artifact for artifact, got %s '
         'instead') % (artifact,))

  # Make sure this module path containing the standard Artifact subclass
  # definitions is imported. Modules containing custom artifact subclasses that
  # need to be deserialized should be imported by the entrypoint of the
  # application or container.
  from tfx.types import standard_artifacts  # pylint: disable=g-import-not-at-top,unused-variable

  # Attempt to find the appropriate Artifact subclass for reconstructing this
  # object. Only direct subclasses are inspected, and there is deliberately no
  # early exit: if several share the type name, the last one listed wins.
  artifact_cls = None
  for cls in Artifact.__subclasses__():
    if cls.TYPE_NAME == artifact_type.name:
      artifact_cls = cls

  # Construct the Artifact object, using a concrete Artifact subclass when
  # possible.
  if artifact_cls:
    result = artifact_cls()
    result.set_mlmd_artifact_type(artifact_type)
  else:
    # Arguments are passed separately so formatting only happens when the
    # record is actually emitted.
    absl.logging.warning(
        'Could not load artifact class for type %r; using fallback '
        'deserialization for the relevant artifact. If this is not intended, '
        'please make sure that the artifact class for this type can be '
        'imported within your container or environment where a component is '
        'executed to consume this type.', artifact_type.name)
    result = Artifact(mlmd_artifact_type=artifact_type)
  if artifact is not None:
    result.set_mlmd_artifact(artifact)
  return result
def _get_outputs_of_execution(
    self, desired_input_ids: Set[int], execution_id: int,
    events: List[metadata_store_pb2.Event]
) -> Optional[Dict[Text, List[Artifact]]]:
  """Fetches outputs produced by a historical execution with desired inputs.

  If the desired input ids are not exactly the same as the input artifacts of
  the given execution id, return nothing. Otherwise, return the output artifacts
  in the format of key -> List[Artifact].

  Output events are ordered by `path.steps[1].index` and grouped under
  `path.steps[0].key`.

  Args:
    desired_input_ids: artifact ids of desired inputs.
    execution_id: the id of the execution that produced the outputs.
    events: events related to the execution id.

  Returns:
    A dict of key -> List[Artifact] as the result, or None when the inputs do
    not match or when an output artifact (or its type) referenced by an event
    cannot be found among the rows returned by the metadata store.
  """
  execution_input_ids = {
      event.artifact_id
      for event in events
      if event.type == metadata_store_pb2.Event.INPUT
  }

  # Only needs to compare the length of the input ids set since we only need
  # to rule out the case that past execution uses more inputs than given
  # inputs.
  if len(desired_input_ids) != len(execution_input_ids):
    absl.logging.debug('Execution %s does not match all inputs', execution_id)
    return None

  absl.logging.debug('Execution %s matches all inputs', execution_id)

  # sorted() is stable, so events sharing an index keep their incoming order.
  output_events = sorted(
      (event for event in events
       if event.type == metadata_store_pb2.Event.OUTPUT),
      key=lambda e: e.path.steps[1].index)

  cached_output_artifacts = self.store.get_artifacts_by_id(
      [e.artifact_id for e in output_events])
  artifact_types = self.store.get_artifact_types_by_id(
      [a.type_id for a in cached_output_artifacts])

  # Look rows up by their own id instead of zipping positionally with the
  # request: the store is not guaranteed to return one row per requested id in
  # request order (missing or repeated ids would silently mis-pair events,
  # artifacts and types, or truncate the result).
  artifacts_by_id = {a.id: a for a in cached_output_artifacts}
  types_by_id = {t.id: t for t in artifact_types}

  result = collections.defaultdict(list)
  for event in output_events:
    mlmd_artifact = artifacts_by_id.get(event.artifact_id)
    artifact_type = (
        None if mlmd_artifact is None else types_by_id.get(
            mlmd_artifact.type_id))
    if artifact_type is None:
      # An execution whose outputs cannot be fully reconstructed is reported
      # the same way as a non-matching one.
      absl.logging.warning(
          'Execution %s has output artifact %s that cannot be resolved from '
          'the metadata store; ignoring this execution', execution_id,
          event.artifact_id)
      return None
    tfx_artifact = Artifact(mlmd_artifact_type=artifact_type)
    tfx_artifact.set_mlmd_artifact(mlmd_artifact)
    result[event.path.steps[0].key].append(tfx_artifact)

  return result
def _get_data_view_info(
    examples: artifact.Artifact) -> Optional[Tuple[str, int]]:
  """Looks up the data view recorded on an Examples artifact.

  Args:
    examples: artifact whose `type` must be `standard_artifacts.Examples`.

  Returns:
    A `(data_view_uri, data_view_create_time)` tuple read from the artifact's
    custom properties when its payload format is FORMAT_PROTO and the data view
    URI property is truthy; None in every other case.
  """
  assert examples.type is standard_artifacts.Examples, (
      'examples must be of type standard_artifacts.Examples')

  # Data view properties are only consulted for the proto payload format.
  if (examples_utils.get_payload_format(examples) !=
      example_gen_pb2.PayloadFormat.FORMAT_PROTO):
    return None

  uri = examples.get_string_custom_property(
      constants.DATA_VIEW_URI_PROPERTY_KEY)
  if not uri:
    return None

  create_time = examples.get_int_custom_property(
      constants.DATA_VIEW_CREATE_TIME_KEY)
  return uri, create_time
def search_artifacts(self, artifact_name: Text, pipeline_name: Text,
                     run_id: Text,
                     producer_component_id: Text) -> List[Artifact]:
  """Search artifacts that matches given info.

  Args:
    artifact_name: the name of the artifact that set by producer component.
      The name is logged both in artifacts and the events when the execution
      being published.
    pipeline_name: the name of the pipeline that produces the artifact
    run_id: the run id of the pipeline run that produces the artifact
    producer_component_id: the id of the component that produces the artifact

  Returns:
    A list of Artifacts that matches the given info

  Raises:
    RuntimeError: when no matching execution is found given producer info.
  """
  # The scan has no early exit: when several executions match, the last one
  # returned by the store is the one used.
  producer_execution = None
  for execution in self._store.get_executions():
    if (execution.properties['pipeline_name'].string_value == pipeline_name and
        execution.properties['run_id'].string_value == run_id and
        execution.properties['component_id'].string_value ==
        producer_component_id):
      producer_execution = execution

  if not producer_execution:
    # Note the trailing space after the comma: the two literals are joined by
    # implicit concatenation.
    raise RuntimeError(
        'Cannot find matching execution with pipeline name %s, '
        'run id %s and component id %s' %
        (pipeline_name, run_id, producer_component_id))

  # Collect ids of artifacts this execution emitted under the requested name.
  matching_artifact_ids = set()
  for event in self._store.get_events_by_execution_ids(
      [producer_execution.id]):
    if (event.type == metadata_store_pb2.Event.OUTPUT and
        event.path.steps[0].key == artifact_name):
      matching_artifact_ids.add(event.artifact_id)

  result_artifacts = []
  for a in self._store.get_artifacts_by_id(list(matching_artifact_ids)):
    # The type name is read from the 'type_name' property stored on the
    # MLMD artifact itself.
    tfx_artifact = Artifact(a.properties['type_name'].string_value)
    tfx_artifact.artifact = a
    result_artifacts.append(tfx_artifact)
  return result_artifacts
def is_artifact_version_older_than(artifact: Artifact,
                                   artifact_version: Text) -> bool:
  """Tells whether `artifact` was produced by a version before the given one.

  Args:
    artifact: the artifact to inspect.
    artifact_version: version string to compare the artifact's recorded
      version against.

  Returns:
    False if the underlying MLMD artifact is in the UNKNOWN state. Otherwise
    True if the artifact has no version custom property, or if its recorded
    version parses as lower than `artifact_version`.
  """
  # Newly generated artifact should use the latest artifact payload format.
  if artifact.mlmd_artifact.state == metadata_store_pb2.Artifact.UNKNOWN:
    return False

  # From here on the artifact was resolved from MLMD. One that carries no
  # version at all is considered old.
  if not artifact.has_custom_property(ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY):
    return True

  recorded_version = version.parse(
      artifact.get_string_custom_property(
          ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY))
  return recorded_version < version.parse(artifact_version)
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
  """Builds a Channel from its JSON-compatible dict form.

  Args:
    dict_data: dict with a 'type' entry (JSON form of an MLMD ArtifactType),
      an 'artifacts' entry (iterable of JSON dicts accepted by
      `Artifact.from_json_dict`), and optional 'producer_component_id' and
      'output_key' entries, which default to None when absent.

  Returns:
    A Channel constructed from the decoded values.
  """
  # Round-trip the 'type' entry through a JSON string so that json_format can
  # populate the proto.
  type_proto = metadata_store_pb2.ArtifactType()
  json_format.Parse(json.dumps(dict_data['type']), type_proto)

  return Channel(
      type=artifact_utils.get_artifact_type_class(type_proto),
      artifacts=[
          Artifact.from_json_dict(entry) for entry in dict_data['artifacts']
      ],
      producer_component_id=dict_data.get('producer_component_id'),
      output_key=dict_data.get('output_key'))
def get_tfxio_factory_from_artifact(
    examples: artifact.Artifact,
    telemetry_descriptors: List[Text],
    schema: Optional[schema_pb2.Schema] = None,
    read_as_raw_records: bool = False,
    raw_record_column_name: Optional[Text] = None
) -> Callable[[Text], tfxio.TFXIO]:
  """Builds a factory that creates the TFXIO suited to an Examples artifact.

  Args:
    examples: The Examples artifact the TFXIO will read.
    telemetry_descriptors: Descriptors identifying the component that
      instantiates the TFXIO. They are used to build the namespace holding
      profiling metrics, so they should identify the component itself rather
      than individual uses of the source.
    schema: TFMD schema. Without one, some TFXIO interfaces of certain TFXIO
      implementations might be unavailable.
    read_as_raw_records: If True, the payload type of `examples` is ignored
      and a RawTfRecord TFXIO is always used.
    raw_record_column_name: If provided, the arrow RecordBatch produced by the
      TFXIO contains a string column of this name holding the raw records.
      Not every TFXIO supports this option, and an error is raised in that
      case. Required if read_as_raw_records == True.

  Returns:
    A function that takes a file pattern and returns a TFXIO instance.

  Raises:
    NotImplementedError: when given an unsupported example payload type
      (presumably raised by `make_tfxio` when the factory is invoked; not
      visible in this function).
  """
  assert examples.type is standard_artifacts.Examples, (
      'examples must be of type standard_artifacts.Examples')

  # Per the original note here, tf.Example is assumed when the payload format
  # custom property is not set; that fallback is not visible in this function.
  payload_format = examples_utils.get_payload_format(examples)

  # A data view only applies to the proto payload format; an empty property
  # value is normalized to None.
  if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
    data_view_uri = examples.get_string_custom_property(
        constants.DATA_VIEW_URI_PROPERTY_KEY) or None
  else:
    data_view_uri = None

  def _create_tfxio(file_pattern):
    return make_tfxio(
        file_pattern=file_pattern,
        telemetry_descriptors=telemetry_descriptors,
        payload_format=payload_format,
        data_view_uri=data_view_uri,
        schema=schema,
        read_as_raw_records=read_as_raw_records,
        raw_record_column_name=raw_record_column_name)

  return _create_tfxio
def search_artifacts(self, artifact_name: Text,
                     pipeline_info: data_types.PipelineInfo,
                     producer_component_id: Text) -> List[Artifact]:
  """Search artifacts that matches given info.

  Args:
    artifact_name: the name of the artifact that set by producer component.
      The name is logged both in artifacts and the events when the execution
      being published.
    pipeline_info: the information of the current pipeline
    producer_component_id: the id of the component that produces the artifact

  Returns:
    A list of Artifacts that matches the given info

  Raises:
    RuntimeError: when the pipeline run context does not exist, or when no
      matching execution is found given producer info.
  """
  # TODO(ruoyu): We need to revisit this when adding support for async
  # execution.
  context = self.get_pipeline_run_context(pipeline_info)
  if context is None:
    raise RuntimeError('Pipeline run context for %s does not exist' %
                       pipeline_info)

  # The first execution in the context with the producer's component id wins.
  producer_execution = None
  for execution in self.store.get_executions_by_context(context.id):
    if execution.properties[
        'component_id'].string_value == producer_component_id:
      producer_execution = execution
      break

  if not producer_execution:
    # Note the trailing space after the comma: the two literals are joined by
    # implicit concatenation.
    raise RuntimeError(
        'Cannot find matching execution with pipeline name %s, '
        'run id %s and component id %s' %
        (pipeline_info.pipeline_name, pipeline_info.run_id,
         producer_component_id))

  matching_artifact_ids = set()
  for event in self.store.get_events_by_execution_ids(
      [producer_execution.id]):
    if (event.type == metadata_store_pb2.Event.OUTPUT and
        event.path.steps[0].key == artifact_name):
      matching_artifact_ids.add(event.artifact_id)

  # Get relevant artifacts along with their types.
  artifacts_by_id = self.store.get_artifacts_by_id(
      list(matching_artifact_ids))
  matching_artifact_type_ids = list(set(a.type_id for a in artifacts_by_id))
  matching_artifact_types = self.store.get_artifact_types_by_id(
      matching_artifact_type_ids)
  # Key each type by the id it carries instead of zipping against the request
  # list, so the mapping does not depend on the store returning types in
  # request order.
  artifact_types = {t.id: t for t in matching_artifact_types}

  result_artifacts = []
  for a in artifacts_by_id:
    artifact_type = artifact_types[a.type_id]
    tfx_artifact = Artifact(mlmd_artifact_type=artifact_type)
    tfx_artifact.set_mlmd_artifact(a)
    tfx_artifact.set_mlmd_artifact_type(artifact_type)
    result_artifacts.append(tfx_artifact)
  return result_artifacts
def refactor_model_blessing(model_blessing: artifact.Artifact,
                            name_from_id: Mapping[int, str]) -> None:
  """Rewrites id-typed custom properties as string-typed artifact names.

  For the baseline-model-id and current-model-id custom properties, in that
  order, a property that is present is read as an int and overwritten, under
  the same key, with the string produced by `_get_full_name`.

  Args:
    model_blessing: the artifact whose custom properties are updated in place.
    name_from_id: mapping forwarded to `_get_full_name` together with the id.
  """
  id_property_keys = (
      constants.ARTIFACT_PROPERTY_BASELINE_MODEL_ID_KEY,
      constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY,
  )
  for property_key in id_property_keys:
    if not model_blessing.has_custom_property(property_key):
      continue
    model_id = model_blessing.get_int_custom_property(property_key)
    model_blessing.set_string_custom_property(
        property_key,
        _get_full_name(artifact_id=model_id, name_from_id=name_from_id))
def parse_artifact_dict(json_str: Text) -> Dict[Text, List[Artifact]]:
  """Decodes a JSON object of artifact lists into Artifact objects.

  Args:
    json_str: JSON text of an object whose values are lists of items accepted
      by `Artifact.from_json_dict`.

  Returns:
    A dict with the same keys, each mapped to the list of Artifacts built from
    the corresponding JSON list.
  """
  return {
      key: [Artifact.from_json_dict(item) for item in serialized_artifacts]
      for key, serialized_artifacts in json.loads(json_str).items()
  }
def testArtifactCollectionAsChannel(self):
  """`as_channel` turns a list of same-typed artifacts into a Channel."""
  first, second = Artifact('MyTypeName'), Artifact('MyTypeName')

  channel = channel_utils.as_channel([first, second])

  # The channel reports the artifacts' type name and holds both artifacts.
  self.assertEqual(channel.type_name, 'MyTypeName')
  self.assertItemsEqual(channel.get(), [first, second])
def test_invalid_channel_type(self):
  """Building a Channel with artifacts of another type name raises."""
  mismatched_artifacts = [Artifact('MyTypeName'), Artifact('MyTypeName')]
  with self.assertRaises(ValueError):
    Channel('AnotherTypeName', artifacts=mismatched_artifacts)
def test_valid_channel(self):
  """A Channel built from matching artifacts exposes its type and contents."""
  first, second = Artifact('MyTypeName'), Artifact('MyTypeName')

  channel = Channel('MyTypeName', artifacts=[first, second])

  self.assertEqual(channel.type_name, 'MyTypeName')
  self.assertItemsEqual(channel.get(), [first, second])