Example #1
    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: types.Channel = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
            # TODO(b/147702778): deprecate trainer_fn.
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
            eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            output: Optional[types.Channel] = None,
            model_run: Optional[types.Channel] = None,
            transform_output: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  A Channel of type `standard_artifacts.Schema`, serving as the
        schema of training and eval data.
      base_model: A Channel of type `Model`, containing a model that will be
        used for training. This can be used for warm starting, transfer
        learning, or model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for the training module. The best
        hyperparameters from Tuner's output can be fed into this.
      module_file: A path to python module file containing UDF model definition.

        For the default executor, the module_file must implement a function named
        `trainer_fn` at its top level. The function must have the following
        signature.

        def trainer_fn(trainer.executor.TrainerFnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of
            tfma.export.EvalInputReceiver.

        Exactly one of 'module_file' or 'trainer_fn' must be supplied.

        For the generic executor, the module_file must implement a function named
        `run_fn` at its top level with the function signature
        `def run_fn(trainer.executor.TrainerFnArgs)`, and the trained model must
        be saved to TrainerFnArgs.serving_model_dir when this function is
        executed.
      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor.
      trainer_fn:  A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
        used for training. Currently only num_steps is available. If it's provided
        as a dict and any field is a RuntimeParameter, it should have the same
        field names as a TrainArgs proto message.
      eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only num_steps is available. If it's
        provided as a dict and any field is a RuntimeParameter, it should have
        the same field names as an EvalArgs proto message.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec.
      output: Optional `Model` channel for result of exported models.
      model_run: Optional `ModelRun` channel, used as the working dir of models;
        it can be used to output non-model-related artifacts (e.g., TensorBoard
        logs).
      transform_output: Backwards compatibility alias for the 'transform_graph'
        argument.
      instance_name: Optional unique instance name. Necessary iff multiple
        Trainer components are declared in the same pipeline.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and user function
          (e.g., trainer_fn and run_fn) is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
        if [bool(module_file),
                bool(run_fn),
                bool(trainer_fn)].count(True) != 1:
            raise ValueError(
                "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be supplied."
            )

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'example' or 'transformed_example' must be supplied."
            )

        if transform_output:
            absl.logging.warning(
                'The "transform_output" argument to the Trainer component has '
                'been renamed to "transform_graph" and is deprecated. Please update '
                "your usage as support for this argument will be removed soon."
            )
            transform_graph = transform_output
        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")
        examples = examples or transformed_examples
        output = output or types.Channel(
            type=standard_artifacts.Model,
            artifacts=[standard_artifacts.Model()])
        model_run = model_run or types.Channel(
            type=standard_artifacts.ModelRun,
            artifacts=[standard_artifacts.ModelRun()])
        spec = TrainerSpec(
            examples=examples,
            transform_graph=transform_graph,
            schema=schema,
            base_model=base_model,
            hyperparameters=hyperparameters,
            train_args=train_args,
            eval_args=eval_args,
            module_file=module_file,
            run_fn=run_fn,
            trainer_fn=trainer_fn,
            custom_config=json_utils.dumps(custom_config),
            model=output,
            # TODO(b/158106209): change the model_run as optional output artifact
            model_run=model_run)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec,
                             instance_name=instance_name)
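
A minimal usage sketch for this constructor, wiring the component into a pipeline with the generic trainer. The upstream handles (`transform`, `schema_gen`), the module path, and the exact import paths are assumptions and depend on the TFX version in use; `trainer_pb2` and `executor_spec` are the same modules referenced in the snippet above.

from tfx.components.trainer.executor import GenericExecutor  # import path may vary by TFX version

# 'transform' and 'schema_gen' are hypothetical upstream pipeline components.
trainer = Trainer(
    module_file='path/to/trainer_module.py',  # hypothetical module defining run_fn
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))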
Example #2
    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: types.Channel = None,
            transform_output: Optional[types.Channel] = None,
            schema: types.Channel = None,
            module_file: Optional[Text] = None,
            trainer_fn: Optional[Text] = None,
            train_args: trainer_pb2.TrainArgs = None,
            eval_args: trainer_pb2.EvalArgs = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
            output: Optional[types.Channel] = None,
            name: Optional[Text] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of 'ExamplesPath' type, serving as the source of
        examples that are used in training. May be raw or transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_output: An optional Channel of 'TransformPath' type, serving as
        the input transform graph if present.
      schema:  A Channel of 'SchemaPath' type, serving as the schema of training
        and eval data.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `trainer_fn` at its
        top level. The function must have the following signature.

        def trainer_fn(tf.contrib.training.HParams,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of tfma.export.EvalInputReceiver

        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      trainer_fn:  A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only num_steps is available.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only num_steps is available.
      custom_config: A dict which contains the training job parameters to be
        passed to Google Cloud ML Engine.  For the full set of parameters
        supported by Google Cloud ML Engine, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      executor_class: Optional custom executor class.
      output: Optional 'ModelExportPath' channel for result of exported models.
      name: Optional unique name. Necessary iff multiple Trainer components are
        declared in the same pipeline.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and 'trainer_fn' is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_output'
            is not supplied.
    """
        if bool(module_file) == bool(trainer_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'trainer_fn' must be supplied"
            )

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'example' or 'transformed_example' must be supplied."
            )

        if transformed_examples and not transform_output:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_output' must be supplied too.")
        examples = examples or transformed_examples
        output = output or types.Channel(
            type=standard_artifacts.Model,
            artifacts=[standard_artifacts.Model()])
        spec = TrainerSpec(examples=examples,
                           transform_output=transform_output,
                           schema=schema,
                           train_args=train_args,
                           eval_args=eval_args,
                           module_file=module_file,
                           trainer_fn=trainer_fn,
                           custom_config=custom_config,
                           output=output)
        super(Trainer, self).__init__(spec=spec,
                                      custom_executor_class=executor_class,
                                      name=name)
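
A minimal usage sketch for this older API. The module path and upstream handles are hypothetical, and the output-dictionary keys used for the upstream components are assumptions that may differ across TFX versions.

# 'transform' and 'schema_gen' are hypothetical upstream pipeline components.
trainer = Trainer(
    module_file='path/to/trainer_module.py',  # hypothetical module defining trainer_fn
    examples=transform.outputs['transformed_examples'],
    transform_output=transform.outputs['transform_output'],
    schema=schema_gen.outputs['output'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))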
Example #3
    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
            # TODO(b/147702778): deprecate trainer_fn.
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
            eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            model: Optional[types.Channel] = None,
            model_run: Optional[types.Channel] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. The schema is optional when
        1) transform_graph is provided and already contains the schema, or
        2) the user module bypasses the schema, e.g., features are hardcoded.
      base_model: A Channel of type `Model`, containing a model that will be
        used for training. This can be used for warm starting, transfer
        learning, or model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for the training module. The best
        hyperparameters from Tuner's output can be fed into this.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For the Estimator-based executor, the module_file must implement a
        function named `trainer_fn` at its top level. The function must have
        the following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.

      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default).
      trainer_fn:  A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer
        uses the Estimator-based executor.
      train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
        used for training. Currently only splits and num_steps are available. If
        it's provided as a dict and any field is a RuntimeParameter, it should
        have the same field names as a TrainArgs proto message. The default
        behavior (when splits is empty) is to train on the `train` split.
      eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as an EvalArgs proto message. The
        default behavior (when splits is empty) is to evaluate on the `eval`
        split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec. This is experimental
        and is subject to change in the future.
      model: Optional `Model` channel for result of exported models.
      model_run: Optional `ModelRun` channel, used as the working dir of models;
        it can be used to output non-model-related artifacts (e.g., TensorBoard
        logs).

    Raises:
      ValueError:
        - When both or neither of 'module_file' and user function
          (e.g., trainer_fn and run_fn) is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
        if [bool(module_file),
                bool(run_fn),
                bool(trainer_fn)].count(True) != 1:
            raise ValueError(
                "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
                "supplied.")

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'example' or 'transformed_example' must be supplied."
            )

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")
        examples = examples or transformed_examples
        model = model or types.Channel(type=standard_artifacts.Model)
        model_run = model_run or types.Channel(
            type=standard_artifacts.ModelRun)
        spec = TrainerSpec(examples=examples,
                           transform_graph=transform_graph,
                           schema=schema,
                           base_model=base_model,
                           hyperparameters=hyperparameters,
                           train_args=train_args,
                           eval_args=eval_args,
                           module_file=module_file,
                           run_fn=run_fn,
                           trainer_fn=trainer_fn,
                           custom_config=json_utils.dumps(custom_config),
                           model=model,
                           model_run=model_run)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec)
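
A minimal sketch of the run_fn module this docstring describes. Only the `run_fn(fn_args)` entry point and the requirement to save the model to fn_args.serving_model_dir come from the docstring; the Keras model and the in-memory training data below are placeholders, and a real module would build its datasets from the materialized examples and the transform graph.

import numpy as np
import tensorflow as tf


def run_fn(fn_args):
    """Train a placeholder model and export it to fn_args.serving_model_dir."""
    # A real run_fn would construct tf.data pipelines from the example files
    # handed over in fn_args instead of this in-memory stub.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer='adam', loss='mse')
    x = np.array([[0.0], [1.0]], dtype=np.float32)
    y = np.array([[0.0], [1.0]], dtype=np.float32)
    model.fit(x, y, epochs=1, verbose=0)
    # The Trainer's Model output artifact points at this directory.
    model.save(fn_args.serving_model_dir, save_format='tf')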
Example #4
    def __init__(
            self,
            examples: types.Channel = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: types.Channel = None,
            base_model: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
            eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            output: Optional[types.Channel] = None,
            transform_output: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  A Channel of type `standard_artifacts.Schema`, serving as the
        schema of training and eval data.
      base_model: A Channel of type `Model`, containing a model that will be
        used for training. This can be used for warm starting, transfer
        learning, or model ensembling.
      module_file: A path to python module file containing UDF model
        definition. The module_file must implement a function named
        `trainer_fn` at its top level. The function must have the following
        signature.

        def trainer_fn(trainer.executor._TrainerFnArgs,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of
            tfma.export.EvalInputReceiver.

        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      trainer_fn:  A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF. Exactly one of
        'module_file' or 'trainer_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only num_steps is available.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only num_steps is available.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec.
      output: Optional `Model` channel for result of exported models.
      transform_output: Backwards compatibility alias for the 'transform_graph'
        argument.
      instance_name: Optional unique instance name. Necessary iff multiple
        Trainer components are declared in the same pipeline.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and 'trainer_fn' is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
            is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
            is not supplied.
    """
        transform_graph = transform_graph or transform_output
        if bool(module_file) == bool(trainer_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'trainer_fn' must be supplied"
            )

        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'example' or 'transformed_example' must be supplied."
            )

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")
        examples = examples or transformed_examples
        output = output or types.Channel(
            type=standard_artifacts.Model,
            artifacts=[standard_artifacts.Model()])
        spec = TrainerSpec(examples=examples,
                           transform_output=transform_graph,
                           schema=schema,
                           base_model=base_model,
                           train_args=train_args,
                           eval_args=eval_args,
                           module_file=module_file,
                           trainer_fn=trainer_fn,
                           custom_config=custom_config,
                           output=output)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec,
                             instance_name=instance_name)
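
A skeletal trainer_fn matching the contract described in the docstring above. Only the four required return keys come from the docstring; the feature column, the input function, the receiver tensors, and the train_steps/eval_steps attributes read from the first argument are illustrative assumptions.

import tensorflow as tf
import tensorflow_model_analysis as tfma


def trainer_fn(trainer_fn_args, schema):
    """Return the estimator and specs expected by the estimator-based executor."""
    feature_columns = [tf.feature_column.numeric_column('x')]  # placeholder feature
    estimator = tf.estimator.LinearRegressor(feature_columns=feature_columns)

    def _input_fn():
        # Placeholder input pipeline; a real trainer_fn would read the
        # materialized examples using the provided schema.
        return tf.data.Dataset.from_tensors(({'x': [[1.0]]}, [[1.0]])).repeat()

    def _eval_input_receiver_fn():
        # Placeholder receiver; assumes the legacy tfma.export.EvalInputReceiver API.
        features = {'x': tf.compat.v1.placeholder(tf.float32, [None, 1])}
        labels = tf.compat.v1.placeholder(tf.float32, [None, 1])
        return tfma.export.EvalInputReceiver(
            features=features, labels=labels, receiver_tensors=features)

    return {
        'estimator': estimator,
        'train_spec': tf.estimator.TrainSpec(
            _input_fn, max_steps=trainer_fn_args.train_steps),
        'eval_spec': tf.estimator.EvalSpec(
            _input_fn, steps=trainer_fn_args.eval_steps),
        'eval_input_receiver_fn': _eval_input_receiver_fn,
    }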