Exemplo n.º 1
0
 def testRuntimeInfoSimple(self):
     # A bare runtime_info('platform_config') placeholder must encode to a
     # RUNTIME_INFO placeholder proto keyed by "platform_config" (the helper,
     # judging by its name, also checks that the placeholder deep-copies).
     expected_pb_text = """
     placeholder {
       type: RUNTIME_INFO
       key: "platform_config"
     }
 """
     platform_config = ph.runtime_info('platform_config')
     self._assert_placeholder_pb_equal_and_deepcopyable(
         platform_config, expected_pb_text)
Exemplo n.º 2
0
 def testPlaceholdersInvolved(self):
     # A path string built from runtime-info, output-artifact and
     # exec-property placeholders must report exactly those three
     # placeholder kinds from placeholders_involved().
     path = ('google/' + ph.runtime_info('platform_config').user + '/' +
             ph.output('model').uri + '/model/' + '0/' +
             ph.exec_property('version'))
     involved_types = {type(item) for item in path.placeholders_involved()}
     expected_types = {
         ph.ArtifactPlaceholder,
         ph.ExecPropertyPlaceholder,
         ph.RuntimeInfoPlaceholder,
     }
     self.assertCountEqual(expected_types, involved_types)
Exemplo n.º 3
0
def create_test_pipeline():
    """Build a pipeline whose pipeline_root is a Placeholder.

    The root is ``ph.runtime_info("platform_config").base_dir``. The
    pipeline holds a single CsvExampleGen reading from
    ``tfx_root/data_path``, with caching enabled and SYNC execution mode.
    """
    data_dir = os.path.join("tfx_root", "data_path")
    root_placeholder = ph.runtime_info("platform_config").base_dir
    components = [CsvExampleGen(input_base=data_dir)]
    return pipeline.Pipeline(
        pipeline_name="pipeline_root_placeholder",
        pipeline_root=root_placeholder,
        components=components,
        enable_cache=True,
        execution_mode=pipeline.ExecutionMode.SYNC,
    )
Exemplo n.º 4
0
 def testRuntimeInfoInvalidKey(self):
     # An unrecognized runtime-info key such as 'invalid_key' must raise
     # ValueError.
     self.assertRaises(ValueError, ph.runtime_info, 'invalid_key')
Exemplo n.º 5
0
    def testExecutionParameterTypeCheck(self):
        # Exercises ExecutionParameter.type_check for scalar, List, Dict,
        # proto and placeholder parameter types. Each section first passes
        # values that must be accepted, then values that must be rejected.
        # The "(class|type)" and "u?" regex fragments presumably let the
        # patterns match both Python 2 and Python 3 reprs.

        # --- int ---
        int_param = ExecutionParameter(type=int)
        int_param.type_check('int_parameter', 8)
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <(class|type) 'int'> for parameter "
                "u?'int_parameter'"):
            int_param.type_check('int_parameter', 'string')

        # --- List[int] ---
        list_param = ExecutionParameter(type=List[int])
        for accepted in ([], [42]):
            list_param.type_check('list_parameter', accepted)
        with self.assertRaisesRegex(TypeError, 'Expecting a list for parameter'):
            list_param.type_check('list_parameter', 42)
        with self.assertRaisesRegex(
                TypeError,
                "Expecting item type <(class|type) 'int'> for parameter "
                "u?'list_parameter'"):
            list_param.type_check('list_parameter', [42, 'wrong item'])

        # --- Dict[str, int] ---
        dict_param = ExecutionParameter(type=Dict[str, int])
        for accepted in ({}, {'key1': 1, 'key2': 2}):
            dict_param.type_check('dict_parameter', accepted)
        with self.assertRaisesRegex(TypeError, 'Expecting a dict for parameter'):
            dict_param.type_check('dict_parameter', 'simple string')
        with self.assertRaisesRegex(
                TypeError, "Expecting value type <(class|type) 'int'>"):
            dict_param.type_check('dict_parameter', {'key1': '1'})

        # --- proto message ---
        # Accepted forms: a message instance, its proto_to_json() form and a
        # plain dict.
        proto_param = ExecutionParameter(type=example_gen_pb2.Input)
        proto_param.type_check('proto_parameter', example_gen_pb2.Input())
        proto_param.type_check(
            'proto_parameter', proto_utils.proto_to_json(example_gen_pb2.Input()))
        proto_param.type_check('proto_parameter', {'splits': [{'name': 'hello'}]})
        # A dict carrying an unrecognized field is expected NOT to raise.
        proto_param.type_check('proto_parameter', {'wrong_field': 42})
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <class 'tfx.proto.example_gen_pb2.Input'>"):
            proto_param.type_check('proto_parameter', 42)
        # A known field holding an ill-typed value surfaces json_format's
        # ParseError rather than a TypeError.
        with self.assertRaises(json_format.ParseError):
            proto_param.type_check('proto_parameter', {'splits': 42})

        # --- placeholders for a str parameter ---
        output_channel = channel.Channel(type=_OutputArtifact)
        str_param = ExecutionParameter(type=str)
        str_param.type_check('wrapped_channel_placeholder_parameter',
                             output_channel.future()[0].value)
        str_param.type_check('placeholder_parameter',
                             placeholder.runtime_info('platform_config').base_dir)
        # Concatenating a runtime-info placeholder with another placeholder
        # must be rejected.
        with self.assertRaisesRegex(
                TypeError, 'Only simple RuntimeInfoPlaceholders are supported'):
            compound_placeholder = (
                placeholder.runtime_info('platform_config').base_dir +
                placeholder.exec_property('version'))
            str_param.type_check('placeholder_parameter', compound_placeholder)