Exemplo n.º 1
0
  def __init__(
      self,
      examples: Optional[types.Channel] = None,
      transformed_examples: Optional[types.Channel] = None,
      transform_graph: Optional[types.Channel] = None,
      schema: Optional[types.Channel] = None,
      base_model: Optional[types.Channel] = None,
      hyperparameters: Optional[types.Channel] = None,
      module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      # TODO(b/147702778): deprecate trainer_fn.
      trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
      train_args: Optional[Union[trainer_pb2.TrainArgs,
                                 Dict[Text, Any]]] = None,
      eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                Dict[Text, Any]]] = None,
      custom_config: Optional[Dict[Text, Any]] = None,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
      model: Optional[types.Channel] = None,
      model_run: Optional[types.Channel] = None,
      instance_name: Optional[Text] = None):
    """Construct a Trainer component.

    Args:
      examples: A Channel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_graph: An optional Channel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema:  An optional Channel of type `standard_artifacts.Schema`, serving
        as the schema of training and eval data. Schema is optional when
        1) transform_graph is provided which contains schema.
        2) user module bypasses the usage of schema, e.g., hardcoded.
      base_model: A Channel of type `Model`, containing model that will be used
        for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A Channel of type `standard_artifacts.HyperParameters`,
        serving as the hyperparameters for training module. Tuner's output best
        hyperparameters can be fed into this.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For Estimator based Executor, the module_file must implement a function
        named `trainer_fn` at its top level. The function must have the
        following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.

      run_fn:  A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default).
      trainer_fn:  A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer
        uses Estimator based Executor.
      train_args: A trainer_pb2.TrainArgs instance or a dict, containing args
        used for training. Currently only splits and num_steps are available. If
        it's provided as a dict and any field is a RuntimeParameter, it should
        have the same field names as a TrainArgs proto message. Default
        behavior (when splits is empty) is train on `train` split.
      eval_args: A trainer_pb2.EvalArgs instance or a dict, containing args
        used for evaluation. Currently only splits and num_steps are available.
        If it's provided as a dict and any field is a RuntimeParameter, it
        should have the same field names as a EvalArgs proto message. Default
        behavior (when splits is empty) is evaluate on `eval` split.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec.
      model: Optional `Model` channel for result of exported models.
      model_run: Optional `ModelRun` channel, as the working dir of models,
        can be used to output non-model related output (e.g., TensorBoard logs).
      instance_name: Optional unique instance name. Necessary iff multiple
        Trainer components are declared in the same pipeline.

    Raises:
      ValueError:
        - Unless exactly one of 'module_file', 'run_fn' and 'trainer_fn' is
          supplied.
        - When both or neither of 'examples' and 'transformed_examples'
          is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
          is not supplied.
    """
    # The user code entry point must be unambiguous: exactly one of the three
    # ways of pointing at it may be truthy.
    if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
      raise ValueError(
          "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
          "supplied.")

    # 'transformed_examples' is the deprecated spelling of 'examples'; accept
    # either one, but never both and never neither.
    if bool(examples) == bool(transformed_examples):
      raise ValueError(
          "Exactly one of 'examples' or 'transformed_examples' must be "
          "supplied.")

    if transformed_examples and not transform_graph:
      raise ValueError("If 'transformed_examples' is supplied, "
                       "'transform_graph' must be supplied too.")

    # From here on only 'examples' is used, whichever keyword it came from.
    examples = examples or transformed_examples
    # Create default output channels when the caller did not provide any.
    model = model or types.Channel(type=standard_artifacts.Model)
    model_run = model_run or types.Channel(type=standard_artifacts.ModelRun)
    spec = TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        train_args=train_args,
        eval_args=eval_args,
        module_file=module_file,
        run_fn=run_fn,
        trainer_fn=trainer_fn,
        # The spec receives custom_config in serialized form.
        custom_config=json_utils.dumps(custom_config),
        model=model,
        model_run=model_run)
    super(Trainer, self).__init__(
        spec=spec,
        custom_executor_spec=custom_executor_spec,
        instance_name=instance_name)
Exemplo n.º 2
0
    def __init__(
            self,
            examples: Optional[types.Channel] = None,
            transformed_examples: Optional[types.Channel] = None,
            transform_graph: Optional[types.Channel] = None,
            schema: Optional[types.Channel] = None,
            base_model: Optional[types.Channel] = None,
            hyperparameters: Optional[types.Channel] = None,
            module_file: Optional[Union[Text,
                                        data_types.RuntimeParameter]] = None,
            trainer_fn: Optional[Union[Text,
                                       data_types.RuntimeParameter]] = None,
            train_args: Optional[Union[trainer_pb2.TrainArgs,
                                       Dict[Text, Any]]] = None,
            eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                      Dict[Text, Any]]] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None,
            output: Optional[types.Channel] = None,
            transform_output: Optional[types.Channel] = None,
            instance_name: Optional[Text] = None):
        """Construct a Trainer component.

        Args:
          examples: A Channel of type `standard_artifacts.Examples`, serving
            as the source of examples used in training (required). May be raw
            or transformed.
          transformed_examples: Deprecated field. Please set 'examples'
            instead.
          transform_graph: An optional Channel of type
            `standard_artifacts.TransformGraph`, serving as the input
            transform graph if present.
          schema:  A Channel of type `standard_artifacts.Schema`, serving as
            the schema of training and eval data.
          base_model: A Channel of type `Model`, containing model that will be
            used for training. This can be used for warmstart, transfer
            learning or model ensembling.
          hyperparameters: A Channel of type
            `standard_artifacts.HyperParameters`, serving as the
            hyperparameters for training module. Tuner's output best
            hyperparameters can be fed into this.
          module_file: A path to python module file containing UDF model
            definition. The module_file must implement a function named
            `trainer_fn` at its top level. The function must have the
            following signature.

            def trainer_fn(trainer.executor._TrainerFnArgs,
                           tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
              ...

            where the returned Dict has the following key-values.
              'estimator': an instance of tf.estimator.Estimator
              'train_spec': an instance of tf.estimator.TrainSpec
              'eval_spec': an instance of tf.estimator.EvalSpec
              'eval_input_receiver_fn': an instance of
                tfma.export.EvalInputReceiver.

            Exactly one of 'module_file' or 'trainer_fn' must be supplied.
          trainer_fn:  A python path to UDF model definition function. See
            'module_file' for the required signature of the UDF. Exactly one
            of 'module_file' or 'trainer_fn' must be supplied.
          train_args: A trainer_pb2.TrainArgs instance or a dict, containing
            args used for training. Currently only num_steps is available.
          eval_args: A trainer_pb2.EvalArgs instance or a dict, containing
            args used for eval. Currently only num_steps is available.
          custom_config: A dict which contains additional training job
            parameters that will be passed into user module.
          custom_executor_spec: Optional custom executor spec.
          output: Optional `Model` channel for result of exported models.
          transform_output: Backwards compatibility alias for the
            'transform_graph' argument. Only consulted when 'transform_graph'
            is not set.
          instance_name: Optional unique instance name. Necessary iff multiple
            Trainer components are declared in the same pipeline.

        Raises:
          ValueError:
            - When both or neither of 'module_file' and 'trainer_fn' is
              supplied.
            - When both or neither of 'examples' and 'transformed_examples'
              is supplied.
            - When 'transformed_examples' is supplied but neither
              'transform_graph' nor its alias 'transform_output' is supplied.
        """
        # Resolve the legacy alias first so that the validation below sees a
        # single value regardless of which keyword the caller used.
        transform_graph = transform_graph or transform_output
        if bool(module_file) == bool(trainer_fn):
            raise ValueError(
                "Exactly one of 'module_file' or 'trainer_fn' must be supplied"
            )

        # 'transformed_examples' is the deprecated spelling of 'examples';
        # accept either one, but never both and never neither.
        if bool(examples) == bool(transformed_examples):
            raise ValueError(
                "Exactly one of 'examples' or 'transformed_examples' must be "
                "supplied.")

        if transformed_examples and not transform_graph:
            raise ValueError("If 'transformed_examples' is supplied, "
                             "'transform_graph' must be supplied too.")

        # From here on only 'examples' is used, whichever keyword it came from.
        examples = examples or transformed_examples
        # Default output channel holding a fresh Model artifact when the
        # caller did not pass one.
        output = output or types.Channel(
            type=standard_artifacts.Model,
            artifacts=[standard_artifacts.Model()])
        # Note: the spec keyword is still 'transform_output' even though the
        # preferred constructor argument is 'transform_graph'.
        spec = TrainerSpec(examples=examples,
                           transform_output=transform_graph,
                           schema=schema,
                           base_model=base_model,
                           hyperparameters=hyperparameters,
                           train_args=train_args,
                           eval_args=eval_args,
                           module_file=module_file,
                           trainer_fn=trainer_fn,
                           custom_config=custom_config,
                           output=output)
        super(Trainer,
              self).__init__(spec=spec,
                             custom_executor_spec=custom_executor_spec,
                             instance_name=instance_name)
Exemplo n.º 3
0
  def __init__(
      self,
      examples: Optional[types.Channel] = None,
      transformed_examples: Optional[types.Channel] = None,
      transform_output: Optional[types.Channel] = None,
      schema: Optional[types.Channel] = None,
      module_file: Optional[Text] = None,
      trainer_fn: Optional[Text] = None,
      train_args: Optional[trainer_pb2.TrainArgs] = None,
      eval_args: Optional[trainer_pb2.EvalArgs] = None,
      custom_config: Optional[Dict[Text, Any]] = None,
      executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
      output: Optional[types.Channel] = None,
      transform_graph: Optional[types.Channel] = None,
      name: Optional[Text] = None):
    """Construct a Trainer component.

    Args:
      examples: A Channel of 'ExamplesPath' type, serving as the source of
        examples that are used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated field. Please set 'examples' instead.
      transform_output: An optional Channel of 'TransformPath' type, serving as
        the input transform graph if present.
      schema:  A Channel of 'SchemaPath' type, serving as the schema of training
        and eval data.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `trainer_fn` at its
        top level. The function must have the following signature.

        def trainer_fn(tf.contrib.training.HParams,
                       tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
          ...

        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of tfma.export.EvalInputReceiver

        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      trainer_fn:  A python path to UDF model definition function. See
        'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied.
      train_args: A trainer_pb2.TrainArgs instance, containing args used for
        training. Currently only num_steps is available.
      eval_args: A trainer_pb2.EvalArgs instance, containing args used for eval.
        Currently only num_steps is available.
      custom_config: A dict which contains the training job parameters to be
        passed to Google Cloud ML Engine.  For the full set of parameters
        supported by Google Cloud ML Engine, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
      executor_class: Optional custom executor class.
      output: Optional 'ModelExportPath' channel for result of exported models.
      transform_graph: Forwards compatibility alias for the 'transform_output'
        argument. Only consulted when 'transform_output' is not set.
      name: Optional unique name. Necessary iff multiple Trainer components are
        declared in the same pipeline.

    Raises:
      ValueError:
        - When both or neither of 'module_file' and 'trainer_fn' is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
          is supplied.
        - When 'transformed_examples' is supplied but neither
          'transform_output' nor its alias 'transform_graph' is supplied.
    """
    # Resolve the forward-compatibility alias first so that the validation
    # below sees a single value regardless of which keyword was used.
    transform_output = transform_output or transform_graph
    if bool(module_file) == bool(trainer_fn):
      raise ValueError(
          "Exactly one of 'module_file' or 'trainer_fn' must be supplied")

    # 'transformed_examples' is the deprecated spelling of 'examples'; accept
    # either one, but never both and never neither.
    if bool(examples) == bool(transformed_examples):
      raise ValueError(
          "Exactly one of 'examples' or 'transformed_examples' must be "
          "supplied.")

    if transformed_examples and not transform_output:
      raise ValueError("If 'transformed_examples' is supplied, "
                       "'transform_output' must be supplied too.")

    # From here on only 'examples' is used, whichever keyword it came from.
    examples = examples or transformed_examples
    # Default output channel holding a fresh Model artifact when the caller
    # did not pass one.
    output = output or types.Channel(
        type=standard_artifacts.Model, artifacts=[standard_artifacts.Model()])
    spec = TrainerSpec(
        examples=examples,
        transform_output=transform_output,
        schema=schema,
        train_args=train_args,
        eval_args=eval_args,
        module_file=module_file,
        trainer_fn=trainer_fn,
        custom_config=custom_config,
        output=output)
    # The base component takes the executor under 'custom_executor_class'.
    super(Trainer, self).__init__(
        spec=spec, custom_executor_class=executor_class, name=name)