예제 #1
0
    def testSimpleFunctionParse(self):
        """Parses a function taking primitive values and one `Parameter`.

        Expects primitive-annotated arguments to be reported as value
        artifacts, the `Parameter[float]` argument as an execution parameter,
        and the single `OutputDict` key as both an output and a returned value.
        """
        def func_a(a: int, b: int, unused_c: Text, unused_d: bytes,
                   unused_e: Parameter[float]) -> OutputDict(c=float):
            return {'c': float(a + b)}

        parsed = parse_typehint_component_function(func_a)
        (inputs, outputs, parameters, arg_formats, arg_defaults,
         returned_values) = parsed

        expected_inputs = {
            'a': standard_artifacts.Integer,
            'b': standard_artifacts.Integer,
            'unused_c': standard_artifacts.String,
            'unused_d': standard_artifacts.Bytes,
        }
        expected_outputs = {'c': standard_artifacts.Float}
        expected_parameters = {'unused_e': float}

        # Every primitive-typed argument is passed by artifact value; only
        # `unused_e` is passed as a parameter.
        expected_arg_formats = {
            name: ArgFormats.ARTIFACT_VALUE
            for name in ('a', 'b', 'unused_c', 'unused_d')
        }
        expected_arg_formats['unused_e'] = ArgFormats.PARAMETER

        self.assertDictEqual(inputs, expected_inputs)
        self.assertDictEqual(outputs, expected_outputs)
        self.assertDictEqual(parameters, expected_parameters)
        self.assertDictEqual(arg_formats, expected_arg_formats)
        # No argument of `func_a` declares a default value.
        self.assertDictEqual(arg_defaults, {})
        self.assertEqual(returned_values, {'c'})
예제 #2
0
  def testOptionalArguments(self):
    """Parses a function mixing required and defaulted arguments.

    Expects every argument that declares a default to show up in
    `arg_defaults`, regardless of whether it is a value artifact, a
    parameter, or an input artifact.
    """
    # Various optional argument schemes.
    def func_a(a: float,
               b: int,
               c: Parameter[Text],
               d: int = 123,
               e: Optional[int] = 345,
               f: Text = 'abc',
               g: bytes = b'xyz',
               h: Parameter[Text] = 'default',
               i: Parameter[int] = 999,
               examples: InputArtifact[standard_artifacts.Examples] = None):
      del a, b, c, d, e, f, g, h, i, examples

    parsed = parse_typehint_component_function(func_a)
    (inputs, outputs, parameters, arg_formats, arg_defaults,
     returned_values) = parsed

    # Parameters 'c', 'h' and 'i' are intentionally absent from the inputs.
    expected_inputs = {
        'a': standard_artifacts.Float,
        'b': standard_artifacts.Integer,
        'd': standard_artifacts.Integer,
        'e': standard_artifacts.Integer,
        'f': standard_artifacts.String,
        'g': standard_artifacts.Bytes,
        'examples': standard_artifacts.Examples,
    }
    expected_parameters = {'c': Text, 'h': Text, 'i': int}

    expected_arg_formats = {
        name: ArgFormats.ARTIFACT_VALUE
        for name in ('a', 'b', 'd', 'e', 'f', 'g')
    }
    expected_arg_formats.update(
        {name: ArgFormats.PARAMETER for name in ('c', 'h', 'i')})
    expected_arg_formats['examples'] = ArgFormats.INPUT_ARTIFACT

    expected_defaults = {
        'd': 123,
        'e': 345,
        'f': 'abc',
        'g': b'xyz',
        'h': 'default',
        'i': 999,
        'examples': None,
    }

    self.assertDictEqual(inputs, expected_inputs)
    # `func_a` has no return typehint and no output artifacts.
    self.assertDictEqual(outputs, {})
    self.assertDictEqual(parameters, expected_parameters)
    self.assertDictEqual(arg_formats, expected_arg_formats)
    self.assertDictEqual(arg_defaults, expected_defaults)
    self.assertEqual(returned_values, set())
예제 #3
0
    def testArtifactFunctionParse(self):
        """Parses a function using input/output artifacts and an `OutputDict`.

        Expects `OutputArtifact` arguments and `OutputDict` keys to be merged
        into the outputs, while only the `OutputDict` keys appear in
        `returned_values`.
        """
        def func_a(
            examples: InputArtifact[standard_artifacts.Examples],
            model: OutputArtifact[standard_artifacts.Model],
            schema: InputArtifact[standard_artifacts.Schema],
            statistics: OutputArtifact[standard_artifacts.ExampleStatistics],
            num_steps: Parameter[int]
        ) -> OutputDict(precision=float,
                        recall=float,
                        message=str,
                        serialized_value=bytes,
                        is_blessed=bool):
            del examples, model, schema, statistics, num_steps
            return {
                'precision': 0.9,
                'recall': 0.8,
                'message': 'foo',
                'serialized_value': b'bar',
                'is_blessed': False,
            }

        parsed = parse_typehint_component_function(func_a)
        (inputs, outputs, parameters, arg_formats, arg_defaults,
         returned_values) = parsed

        expected_inputs = {
            'examples': standard_artifacts.Examples,
            'schema': standard_artifacts.Schema,
        }
        expected_outputs = {
            # Declared as `OutputArtifact` arguments.
            'model': standard_artifacts.Model,
            'statistics': standard_artifacts.ExampleStatistics,
            # Declared through the `OutputDict` return typehint.
            'precision': standard_artifacts.Float,
            'recall': standard_artifacts.Float,
            'message': standard_artifacts.String,
            'serialized_value': standard_artifacts.Bytes,
            'is_blessed': standard_artifacts.Boolean,
        }
        expected_arg_formats = {
            'examples': ArgFormats.INPUT_ARTIFACT,
            'model': ArgFormats.OUTPUT_ARTIFACT,
            'schema': ArgFormats.INPUT_ARTIFACT,
            'statistics': ArgFormats.OUTPUT_ARTIFACT,
            'num_steps': ArgFormats.PARAMETER,
        }
        # Each returned key is expected to map to False here.
        expected_returned_values = dict.fromkeys(
            ('precision', 'recall', 'message', 'serialized_value',
             'is_blessed'), False)

        self.assertDictEqual(inputs, expected_inputs)
        self.assertDictEqual(outputs, expected_outputs)
        self.assertDictEqual(parameters, {'num_steps': int})
        self.assertDictEqual(arg_formats, expected_arg_formats)
        self.assertDictEqual(arg_defaults, {})
        self.assertEqual(returned_values, expected_returned_values)
예제 #4
0
    def testEmptyReturnValue(self):
        """Checks that a missing return typehint and `-> None` parse alike."""
        # No output typehint.
        def func_a(examples: InputArtifact[standard_artifacts.Examples],
                   model: OutputArtifact[standard_artifacts.Model], a: int,
                   b: float, c: Parameter[int], d: Parameter[Text],
                   e: Parameter[bytes]):
            del examples, model, a, b, c, d, e

        # `None` output typehint.
        def func_b(examples: InputArtifact[standard_artifacts.Examples],
                   model: OutputArtifact[standard_artifacts.Model], a: int,
                   b: float, c: Parameter[int], d: Parameter[Text],
                   e: Parameter[bytes]) -> None:
            del examples, model, a, b, c, d, e

        # The expectations are identical for both functions, so build them
        # once up front.
        expected_inputs = {
            'examples': standard_artifacts.Examples,
            'a': standard_artifacts.Integer,
            'b': standard_artifacts.Float,
        }
        expected_outputs = {'model': standard_artifacts.Model}
        expected_parameters = {'c': int, 'd': Text, 'e': bytes}
        expected_arg_formats = {
            'examples': ArgFormats.INPUT_ARTIFACT,
            'model': ArgFormats.OUTPUT_ARTIFACT,
            'a': ArgFormats.ARTIFACT_VALUE,
            'b': ArgFormats.ARTIFACT_VALUE,
            'c': ArgFormats.PARAMETER,
            'd': ArgFormats.PARAMETER,
            'e': ArgFormats.PARAMETER,
        }

        # Both functions should be parsed in the same way.
        for func in (func_a, func_b):
            parsed = parse_typehint_component_function(func)
            (inputs, outputs, parameters, arg_formats, arg_defaults,
             returned_values) = parsed
            self.assertDictEqual(inputs, expected_inputs)
            self.assertDictEqual(outputs, expected_outputs)
            self.assertDictEqual(parameters, expected_parameters)
            self.assertDictEqual(arg_formats, expected_arg_formats)
            self.assertDictEqual(arg_defaults, {})
            self.assertEqual(returned_values, set())
예제 #5
0
    def testOptionalReturnValues(self):
        """Parses an `OutputDict` that mixes required and `Optional` values.

        Expects `Optional[T]` entries to map to the same artifact type as `T`
        and to be flagged `True` in `returned_values`; all other entries are
        flagged `False`.
        """
        def func_a() -> OutputDict(precision=float,
                                   recall=float,
                                   message=str,
                                   serialized_value=bytes,
                                   optional_label=Optional[str],
                                   optional_metric=Optional[float]):
            return {
                'precision': 0.9,
                'recall': 0.8,
                'message': 'foo',
                'serialized_value': b'bar',
                'optional_label': None,
                'optional_metric': 1.0,
            }

        parsed = parse_typehint_component_function(func_a)
        (inputs, outputs, parameters, arg_formats, arg_defaults,
         returned_values) = parsed

        expected_outputs = {
            'precision': standard_artifacts.Float,
            'recall': standard_artifacts.Float,
            'message': standard_artifacts.String,
            'serialized_value': standard_artifacts.Bytes,
            'optional_label': standard_artifacts.String,
            'optional_metric': standard_artifacts.Float,
        }
        expected_returned_values = dict.fromkeys(
            ('precision', 'recall', 'message', 'serialized_value'), False)
        expected_returned_values.update(
            dict.fromkeys(('optional_label', 'optional_metric'), True))

        # `func_a` takes no arguments, so every argument-derived result is
        # expected to be empty.
        self.assertDictEqual(inputs, {})
        self.assertDictEqual(outputs, expected_outputs)
        self.assertDictEqual(parameters, {})
        self.assertDictEqual(arg_formats, {})
        self.assertDictEqual(arg_defaults, {})
        self.assertEqual(returned_values, expected_returned_values)
예제 #6
0
def component(func: types.FunctionType) -> Callable[..., Any]:
    """Decorator: creates a component from a typehint-annotated Python function.

    The component's interface is derived from the typehint annotations on the
    function's arguments and return value. Arguments may be annotated with the
    following types:

    * `Parameter[T]` where `T` is `int`, `float`, `str`, or `bytes`: a
      primitive type execution parameter whose value is known at pipeline
      construction time. These parameters will be recorded in ML Metadata as
      part of the component's execution record. Can be an optional argument.
    * `int`, `float`, `str`, `bytes`: a primitive type value. This value is
      tracked as an `Integer`, `Float`, `String` or `Bytes` artifact (see
      `tfx.types.standard_artifacts`) whose value is read and passed into the
      given Python component function. Can be an optional argument.
    * `InputArtifact[ArtifactType]`: an input artifact object of type
      `ArtifactType` (deriving from `tfx.types.Artifact`). This artifact is
      intended to be consumed as an input by this component (possibly reading
      from the path specified by its `.uri`). Can be an optional argument by
      specifying a default value of `None`.
    * `OutputArtifact[ArtifactType]`: an output artifact object of type
      `ArtifactType` (deriving from `tfx.types.Artifact`). This artifact is
      intended to be emitted as an output by this component (and written to
      the path specified by its `.uri`). Cannot be an optional argument.

    The return value typehint should be either empty or `None`, for a
    component function that has no return values, or an instance of
    `OutputDict(key_1=type_1, ...)`, where each key maps to a primitive value
    type (i.e. `int`, `float`, `str` or `bytes`), to indicate that the return
    value is a dictionary with the specified keys and value types.

    Note that output artifacts should not be included in the return value
    typehint; they should be declared as `OutputArtifact` annotations in the
    function arguments, as described above.

    The decorated function must be at the top level of its Python module (it
    may not be defined within nested classes or function closures).

    Example component definition using this decorator:

        from tfx.components.base.annotations import InputArtifact
        from tfx.components.base.annotations import OutputArtifact
        from tfx.components.base.annotations import OutputDict
        from tfx.components.base.annotations import Parameter
        from tfx.components.base.decorators import component
        from tfx.types.standard_artifacts import Examples
        from tfx.types.standard_artifacts import Model

        @component
        def MyTrainerComponent(
            training_data: InputArtifact[Examples],
            model: OutputArtifact[Model],
            dropout_hyperparameter: float,
            num_iterations: Parameter[int] = 10
            ) -> OutputDict(loss=float, accuracy=float):
          '''My simple trainer component.'''

          records = read_examples(training_data.uri)
          model_obj = train_model(
              records, num_iterations, dropout_hyperparameter)
          model_obj.write_to(model.uri)

          return {
            'loss': model_obj.loss,
            'accuracy': model_obj.accuracy
          }

        # Example usage in a pipeline graph definition:
        # ...
        trainer = MyTrainerComponent(
            training_data=example_gen.outputs['examples'],
            dropout_hyperparameter=other_component.outputs['dropout'],
            num_iterations=1000)
        pusher = Pusher(model=trainer.outputs['model'])
        # ...

    Experimental: no backwards compatibility guarantees.

    Args:
      func: Typehint-annotated component executor function.

    Returns:
      A newly created `_SimpleComponent` subclass, named after `func`, whose
      `SPEC_CLASS` and `EXECUTOR_SPEC` are generated from the given component
      executor function.

    Raises:
      EnvironmentError: if the current Python interpreter is not Python 3.
      ValueError: if `func` is not defined at module level.
    """
    if six.PY2:
        raise EnvironmentError('`@component` is only supported in Python 3.')

    # A component defined within a nested class or function closure cannot be
    # referenced via its qualified module path, which the generated classes
    # rely on.
    #
    # See https://www.python.org/dev/peps/pep-3155/ for details about the
    # special '<locals>' namespace marker.
    qualname_parts = func.__qualname__.split('.')
    if '<locals>' in qualname_parts:
        raise ValueError(
            'The @component decorator can only be applied to a function defined '
            'at the module level. It cannot be used to construct a component for a '
            'function defined in a nested class or function closure.')

    parsed = function_parser.parse_typehint_component_function(func)
    (inputs, outputs, parameters, arg_formats, arg_defaults,
     returned_values) = parsed

    # An input or parameter is optional exactly when the function declares a
    # default value for it.
    spec_inputs = {
        name: component_spec.ChannelParameter(
            type=artifact_type, optional=(name in arg_defaults))
        for name, artifact_type in inputs.items()
    }
    spec_outputs = {}
    for name, artifact_type in outputs.items():
        assert name not in arg_defaults, 'Optional outputs are not supported.'
        spec_outputs[name] = component_spec.ChannelParameter(
            type=artifact_type)
    spec_parameters = {
        name: component_spec.ExecutionParameter(
            type=primitive_type, optional=(name in arg_defaults))
        for name, primitive_type in parameters.items()
    }

    spec_class_name = '{}_Spec'.format(func.__name__)
    spec_class_attrs = {
        'INPUTS': spec_inputs,
        'OUTPUTS': spec_outputs,
        'PARAMETERS': spec_parameters,
    }
    component_spec_class = type(
        spec_class_name, (tfx_types.ComponentSpec,), spec_class_attrs)

    executor_class_name = '{}_Executor'.format(func.__name__)
    executor_class_attrs = {
        '_ARG_FORMATS': arg_formats,
        '_ARG_DEFAULTS': arg_defaults,
        # Wrapping in `staticmethod` keeps later `self._FUNCTION` lookups from
        # producing a bound method (i.e. one receiving `self` as its first
        # argument).
        '_FUNCTION': staticmethod(func),
        '_RETURNED_VALUES': returned_values,
        '__module__': func.__module__,
    }
    executor_class = type(
        executor_class_name, (_FunctionExecutor,), executor_class_attrs)

    # Publish the generated executor class in the decorated function's module
    # so that it can be accessed at the proper module path. One place this is
    # needed is in the Dill pickler used by Apache Beam serialization.
    owning_module = sys.modules[func.__module__]
    setattr(owning_module, executor_class_name, executor_class)

    executor_spec_instance = executor_spec.ExecutorClassSpec(
        executor_class=executor_class)

    component_class_attrs = {
        'SPEC_CLASS': component_spec_class,
        'EXECUTOR_SPEC': executor_spec_instance,
        '__module__': func.__module__,
    }
    return type(func.__name__, (_SimpleComponent,), component_class_attrs)
예제 #7
0
    def testFunctionParseErrors(self):
        """Checks that invalid component functions are rejected.

        Each case passes a malformed input to
        `parse_typehint_component_function` and expects a `ValueError` whose
        message matches the given regular expression.

        The helper functions are deliberately defined inside the `with`
        blocks: their annotations are evaluated at definition time, so any
        error raised while building the annotation is covered by the same
        assertion.
        """
        # Non-function arguments.
        with self.assertRaisesRegex(
                ValueError, 'Expected a typehint-annotated Python function'):
            parse_typehint_component_function(object())
        with self.assertRaisesRegex(
                ValueError, 'Expected a typehint-annotated Python function'):
            parse_typehint_component_function('foo')

        # Unannotated lambda.
        with self.assertRaisesRegex(
                ValueError,
                'must have all arguments annotated with typehints'):
            parse_typehint_component_function(lambda x: True)

        # Return typehint that is neither an OutputDict instance nor `None`.
        with self.assertRaisesRegex(
                ValueError,
                'must have either an OutputDict instance or `None` as its return'
        ):

            def func_a(a: int, b: int) -> object:
                del a, b
                return object()

            parse_typehint_component_function(func_a)

        # Function with *args and **kwargs.
        with self.assertRaisesRegex(
                ValueError,
                r'does not support \*args or \*\*kwargs arguments'):

            def func_b(a: int, b: int, *unused_args) -> OutputDict(c=float):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_b)
        with self.assertRaisesRegex(
                ValueError,
                r'does not support \*args or \*\*kwargs arguments'):

            def func_c(a: int, b: int, **unused_kwargs) -> OutputDict(c=float):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_c)

        # Not all arguments annotated with typehints.
        with self.assertRaisesRegex(
                ValueError, 'must have all arguments annotated with typehint'):

            def func_d(a: int, b) -> OutputDict(c=float):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_d)

        # Artifact type used in annotation without `InputArtifact[ArtifactType]` or
        # `OutputArtifact[ArtifactType]` wrapper.
        with self.assertRaisesRegex(
                ValueError, 'Invalid type hint annotation.*'
                'should indicate whether it is used as an input or output artifact'
        ):

            def func_e(
                    a: int, unused_b: standard_artifacts.Examples
            ) -> OutputDict(c=float):
                return {'c': float(a)}

            parse_typehint_component_function(func_e)

        # Invalid input typehint.
        with self.assertRaisesRegex(ValueError,
                                    'Unknown type hint annotation'):

            def func_f(a: int, b: Dict[int, int]) -> OutputDict(c=float):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_f)

        # Invalid output typehint.
        with self.assertRaisesRegex(ValueError,
                                    'Unknown type hint annotation'):

            def func_g(a: int, b: int) -> OutputDict(c='whatever'):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_g)

        # Output artifact in the wrong place.
        with self.assertRaisesRegex(
                ValueError,
                'Output artifacts .* should be declared as function parameters'
        ):

            def func_h(a: int,
                       b: int) -> OutputDict(c=standard_artifacts.Examples):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_h)
        with self.assertRaisesRegex(
                ValueError,
                'Output artifacts .* should be declared as function parameters'
        ):

            def func_i(
                a: int, b: int
            ) -> OutputDict(c=OutputArtifact[standard_artifacts.Examples]):
                return {'c': float(a + b)}

            parse_typehint_component_function(func_i)

        # Input artifact declared optional with non-`None` default value.
        with self.assertRaisesRegex(
                ValueError,
                'If an input artifact is declared as an optional argument, its default '
                'value must be `None`'):

            def func_j(
                a: int,
                b: int,
                examples: InputArtifact[standard_artifacts.Examples] = 123
            ) -> OutputDict(c=float):
                del examples
                return {'c': float(a + b)}

            parse_typehint_component_function(func_j)

        # Output artifact declared optional.
        with self.assertRaisesRegex(
                ValueError,
                'Output artifact of component function cannot be declared as optional'
        ):

            def func_k(
                a: int,
                b: int,
                model: OutputArtifact[standard_artifacts.Model] = None
            ) -> OutputDict(c=float):
                del model
                return {'c': float(a + b)}

            parse_typehint_component_function(func_k)

        # Optional input value's default value does not match declared type.
        with self.assertRaisesRegex(
                ValueError,
                'The default value for optional input value .* on function .* must be '
                'an instance of its declared type .* or `None`'):

            def func_l(a: int,
                       b: int,
                       num_iterations: int = 'abc') -> OutputDict(c=float):
                del num_iterations
                return {'c': float(a + b)}

            parse_typehint_component_function(func_l)

        # Optional parameter's default value does not match declared type.
        with self.assertRaisesRegex(
                ValueError,
                'The default value for optional parameter .* on function .* must be an '
                'instance of its declared type .* or `None`'):

            def func_m(
                    a: int,
                    b: int,
                    num_iterations: Parameter[int] = 'abc'
            ) -> OutputDict(c=float):
                del num_iterations
                return {'c': float(a + b)}

            parse_typehint_component_function(func_m)