def testUnwrapChannelDictDeprecated(self):
  """Calling the old `channel.unwrap_channel_dict` logs a rename warning."""
  # patch.object already replaces `tf_logging.warning` with a MagicMock and
  # hands that mock back via `as`, restoring the original on exit. Creating
  # a second mock and assigning it by hand is redundant.
  with mock.patch.object(tf_logging, 'warning') as warn_mock:
    channel.unwrap_channel_dict({})
    warn_mock.assert_called_once()
    # NOTE(review): positional arg 5 is assumed to be the "instructions"
    # argument passed by the deprecation decorator to the warning format
    # call; confirm against the decorator in use if this starts failing.
    self.assertIn('tfx.utils.channel.unwrap_channel_dict has been renamed to',
                  warn_mock.call_args[0][5])
def _fetch_cached_artifacts(
    self, output_dict: Dict[Text, channel.Channel],
    cached_execution_id: int) -> Dict[Text, List[types.TfxArtifact]]:
  """Fetch cached output artifacts.

  Args:
    output_dict: Output channels keyed by output name.
    cached_execution_id: Id of the earlier execution whose results are
      looked up.

  Returns:
    The result of the metadata handler's `fetch_previous_result_artifacts`
    for the unwrapped outputs and `cached_execution_id`.
  """
  unwrapped_outputs = channel.unwrap_channel_dict(output_dict)
  handler = self._metadata_handler
  return handler.fetch_previous_result_artifacts(unwrapped_outputs,
                                                 cached_execution_id)
def pre_execution(
    self,
    input_dict: Dict[Text, channel.Channel],
    output_dict: Dict[Text, channel.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  """Fake pre-execution step that returns a fixed ExecutionDecision.

  Unwraps both channel dicts and creates the pipeline root directory. It then
  sets the uri of the artifact picked by `types.get_single_instance` from
  outputs['output'] to `<pipeline_root>/output`. The returned decision uses a
  hard-coded id of 123 and False as its last argument (presumably "use
  cached results"; confirm against ExecutionDecision). `driver_args` and
  `component_info` are unused.
  """
  input_artifacts = channel.unwrap_channel_dict(input_dict)
  output_artifacts = channel.unwrap_channel_dict(output_dict)
  root = pipeline_info.pipeline_root
  tf.gfile.MakeDirs(root)
  sole_output = types.get_single_instance(output_artifacts['output'])
  sole_output.uri = os.path.join(root, 'output')
  fake_execution_id = 123
  return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                      exec_properties, fake_execution_id,
                                      False)
def setUp(self):
  """Builds the channel/artifact fixtures shared by the tests.

  Each input artifact gets a 1-based id. Its uri is the directory
  `<TEST_TMP_DIR or temp dir>/<test name>/input_dir/<key>/<id>/`, which is
  created on disk.
  """
  self._mock_metadata = tf.test.mock.Mock()
  self._input_dict = {
      'input_data':
          channel.Channel(
              type_name='input_data',
              artifacts=[types.TfxArtifact(type_name='input_data')]),
  }
  tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
  input_dir = os.path.join(tmp_root, self._testMethodName, 'input_dir')
  # Valid input artifacts must have a uri pointing to an existing directory.
  for key, input_channel in self._input_dict.items():
    for artifact_id, artifact in enumerate(input_channel.get(), start=1):
      artifact.id = artifact_id
      # The trailing '' component makes the uri end with a path separator.
      artifact_uri = os.path.join(input_dir, key, str(artifact.id), '')
      artifact.uri = artifact_uri
      tf.gfile.MakeDirs(artifact_uri)
  split_output = types.TfxArtifact(type_name='output_data', split='split')
  self._output_dict = {
      'output_data':
          channel.Channel(type_name='output_data', artifacts=[split_output]),
  }
  self._input_artifacts = channel.unwrap_channel_dict(self._input_dict)
  self._output_artifacts = {
      'output_data': [types.TfxArtifact(type_name='OutputType')],
  }
  self._exec_properties = {'key': 'value'}
  self._execution_id = 100
def resolve_input_artifacts(
    self,
    input_dict: Dict[Text, channel.Channel],
    exec_properties: Dict[Text, Any],
    pipeline_info: data_types.PipelineInfo,
) -> Dict[Text, List[types.TfxArtifact]]:
  """Overrides BaseDriver.resolve_input_artifacts().

  Unwraps the input channels and delegates to
  `self._prepare_input_for_processing`. `pipeline_info` is not used.
  """
  prepare = self._prepare_input_for_processing
  raw_artifacts = channel.unwrap_channel_dict(input_dict)
  return prepare(raw_artifacts, exec_properties)
def test_unwrap_channel_dict(self):
  """unwrap_channel_dict maps each key to its channel's artifact list."""
  first, second = [types.TfxArtifact('MyTypeName') for _ in range(2)]
  channel_dict = {
      'id': channel.Channel('MyTypeName', artifacts=[first, second]),
  }
  expected = {'id': [first, second]}
  self.assertDictEqual(channel.unwrap_channel_dict(channel_dict), expected)
def _prepare_output_artifacts(
    self,
    output_dict: Dict[Text, channel.Channel],
    execution_id: int,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, List[types.TfxArtifact]]:
  """Prepare output artifacts by assigning uris to each artifact.

  Every artifact in the unwrapped output dict gets its `uri` set by
  `_generate_output_uri`, rooted at `<pipeline_root>/<component_id>`. The
  unwrapped dict, with its artifacts modified, is returned.
  """
  artifacts_by_name = channel.unwrap_channel_dict(output_dict)
  component_dir = os.path.join(pipeline_info.pipeline_root,
                               component_info.component_id)
  for output_name in artifacts_by_name:
    for artifact in artifacts_by_name[output_name]:
      artifact.uri = _generate_output_uri(artifact, component_dir,
                                          output_name, execution_id)
  return artifacts_by_name
def testUnwrapChannelDictDeprecated(self):
  """The old unwrap_channel_dict alias emits the expected rename warning."""
  expected_pattern = ('tfx.utils.channel.unwrap_channel_dict has been renamed to '
                      'tfx.types.channel_utils.unwrap_channel_dict')
  channel.unwrap_channel_dict({})
  self._assertDeprecatedWarningRegex(expected_pattern)