Example #1
0
 def testUnwrapChannelDictDeprecated(self):
   """Checks that the deprecated unwrap_channel_dict logs a rename warning."""
   # patch.object hands back the mock it installs, so it is bound directly
   # instead of building a second MagicMock and assigning it by hand.
   with mock.patch.object(tf_logging, 'warning') as warn_mock:
     channel.unwrap_channel_dict({})
     warn_mock.assert_called_once()
     # NOTE(review): positional argument 5 of the warning call presumably
     # carries the rename instructions -- confirm against the code that emits
     # the deprecation warning.
     self.assertIn('tfx.utils.channel.unwrap_channel_dict has been renamed to',
                   warn_mock.call_args[0][5])
Example #2
0
 def _fetch_cached_artifacts(
         self, output_dict: Dict[Text, channel.Channel],
         cached_execution_id: int) -> Dict[Text, List[types.TfxArtifact]]:
     """Look up the output artifacts of a previous (cached) execution.

     Args:
       output_dict: output channels keyed by name; unwrapped with
         channel.unwrap_channel_dict before the lookup.
       cached_execution_id: execution id forwarded to the metadata handler.

     Returns:
       The result of the metadata handler's fetch_previous_result_artifacts
       call for the unwrapped outputs and the given execution id.
     """
     unwrapped_outputs = channel.unwrap_channel_dict(output_dict)
     metadata_handler = self._metadata_handler
     return metadata_handler.fetch_previous_result_artifacts(
         unwrapped_outputs, cached_execution_id)
Example #3
0
 def pre_execution(
     self,
     input_dict: Dict[Text, channel.Channel],
     output_dict: Dict[Text, channel.Channel],
     exec_properties: Dict[Text, Any],
     driver_args: data_types.DriverArgs,
     pipeline_info: data_types.PipelineInfo,
     component_info: data_types.ComponentInfo,
 ) -> data_types.ExecutionDecision:
   """Unwraps the channels, sets the 'output' artifact uri, and returns a decision.

   The pipeline root directory is created, and the single artifact stored
   under the 'output' key gets the uri '<pipeline_root>/output'. driver_args
   and component_info are accepted but not used here.

   Returns:
     An ExecutionDecision built from the unwrapped inputs, the unwrapped
     outputs, exec_properties, the fixed value 123 and False.
   """
   unwrapped_inputs = channel.unwrap_channel_dict(input_dict)
   unwrapped_outputs = channel.unwrap_channel_dict(output_dict)

   # The root directory is created before any uri below it is handed out.
   tf.gfile.MakeDirs(pipeline_info.pipeline_root)

   output_uri = os.path.join(pipeline_info.pipeline_root, 'output')
   single_output = types.get_single_instance(unwrapped_outputs['output'])
   single_output.uri = output_uri

   return data_types.ExecutionDecision(
       unwrapped_inputs, unwrapped_outputs, exec_properties, 123, False)
Example #4
0
 def setUp(self):
     """Builds the mock metadata, channel dicts, artifacts and ids for the tests."""
     self._mock_metadata = tf.test.mock.Mock()

     input_artifact = types.TfxArtifact(type_name='input_data')
     self._input_dict = {
         'input_data':
         channel.Channel(type_name='input_data', artifacts=[input_artifact])
     }

     tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
     input_dir = os.path.join(tmp_root, self._testMethodName, 'input_dir')

     # Valid input artifacts must have a uri pointing to an existing
     # directory, so each artifact gets an id (starting at 1) and its own
     # freshly created directory.
     for channel_key, input_channel in self._input_dict.items():
         for artifact_id, artifact in enumerate(input_channel.get(), start=1):
             artifact.id = artifact_id
             artifact_uri = os.path.join(
                 input_dir, channel_key, str(artifact.id), '')
             artifact.uri = artifact_uri
             tf.gfile.MakeDirs(artifact_uri)

     output_artifact = types.TfxArtifact(type_name='output_data',
                                         split='split')
     self._output_dict = {
         'output_data':
         channel.Channel(type_name='output_data',
                         artifacts=[output_artifact])
     }

     self._input_artifacts = channel.unwrap_channel_dict(self._input_dict)
     self._output_artifacts = {
         'output_data': [types.TfxArtifact(type_name='OutputType')],
     }
     self._exec_properties = {'key': 'value'}
     self._execution_id = 100
Example #5
0
 def resolve_input_artifacts(
     self,
     input_dict: Dict[Text, channel.Channel],
     exec_properties: Dict[Text, Any],
     pipeline_info: data_types.PipelineInfo,
 ) -> Dict[Text, List[types.TfxArtifact]]:
   """Overrides BaseDriver.resolve_input_artifacts().

   Unwraps the input channels and passes the result, together with
   exec_properties, to self._prepare_input_for_processing. pipeline_info is
   not used here.
   """
   unwrapped_inputs = channel.unwrap_channel_dict(input_dict)
   return self._prepare_input_for_processing(
       unwrapped_inputs, exec_properties)
Example #6
0
 def test_unwrap_channel_dict(self):
     """unwrap_channel_dict maps each key to its channel's artifact list."""
     first_artifact = types.TfxArtifact('MyTypeName')
     second_artifact = types.TfxArtifact('MyTypeName')
     wrapped = {
         'id': channel.Channel('MyTypeName',
                               artifacts=[first_artifact, second_artifact])
     }
     expected = {'id': [first_artifact, second_artifact]}

     unwrapped = channel.unwrap_channel_dict(wrapped)

     self.assertDictEqual(unwrapped, expected)
Example #7
0
 def _prepare_output_artifacts(
     self,
     output_dict: Dict[Text, channel.Channel],
     execution_id: int,
     pipeline_info: data_types.PipelineInfo,
     component_info: data_types.ComponentInfo,
 ) -> Dict[Text, List[types.TfxArtifact]]:
     """Unwrap the output channels and assign a uri to every artifact.

     Each uri comes from _generate_output_uri, called with the artifact, the
     directory '<pipeline_root>/<component_id>', the output name and the
     execution id.

     Returns:
       The unwrapped output dict, whose artifacts now carry their uris.
     """
     unwrapped_outputs = channel.unwrap_channel_dict(output_dict)
     component_output_dir = os.path.join(pipeline_info.pipeline_root,
                                         component_info.component_id)
     for output_name, artifacts in unwrapped_outputs.items():
         for output_artifact in artifacts:
             output_artifact.uri = _generate_output_uri(
                 output_artifact, component_output_dir, output_name,
                 execution_id)
     return unwrapped_outputs
Example #8
0
 def testUnwrapChannelDictDeprecated(self):
     """Calling the deprecated unwrap_channel_dict emits the rename warning."""
     expected_warning = (
         'tfx.utils.channel.unwrap_channel_dict has been renamed to '
         'tfx.types.channel_utils.unwrap_channel_dict')
     channel.unwrap_channel_dict({})
     self._assertDeprecatedWarningRegex(expected_warning)