示例#1
0
    def testStartCMLETrainingWithUserContainer(self, mock_discovery):
        """A user-supplied container image is forwarded in masterConfig."""
        self._training_inputs['masterConfig'] = {'imageUri': 'my-custom-image'}
        mock_discovery.build.return_value = self._mock_api_client
        mock_create = mock.Mock()
        self._mock_api_client.projects().jobs().create = mock_create
        mock_get = mock.Mock()
        self._mock_api_client.projects().jobs().get.return_value = mock_get
        # Stub the job-status lookup to report SUCCEEDED.
        mock_get.execute.return_value = {
            'state': 'SUCCEEDED',
        }

        class_path = 'foo.bar.class'

        runner.start_cmle_training(self._inputs, self._outputs,
                                   self._exec_properties, class_path,
                                   self._training_inputs)

        mock_create.assert_called_with(body=mock.ANY,
                                       parent='projects/{}'.format(
                                           self._project_id))
        (_, kwargs) = mock_create.call_args
        body = kwargs['body']
        training_input = body['trainingInput']
        expected_subset = {
            'masterConfig': {
                'imageUri': 'my-custom-image',
            },
            'args': [
                '--executor_class_path', class_path, '--inputs', '{}',
                '--outputs', '{}', '--exec-properties',
                '{"custom_config": {}}'
            ],
        }
        # Subset check without assertDictContainsSubset (deprecated since
        # Python 3.2, removed in 3.12): overlaying the expected entries onto
        # the actual dict leaves it unchanged iff every expected key is
        # present with an equal value.
        self.assertEqual(training_input, {**training_input, **expected_subset})
        self.assertStartsWith(body['jobId'], 'tfx_')
        mock_get.execute.assert_called_with()
示例#2
0
    def testStartCMLETraining(self, mock_discovery, mock_dependency_utils):
        """Without a custom image, the job ships an ephemeral package.

        Verifies that the package is built, that the job runs
        tfx.scripts.run_executor from it, and that runtime/python versions and
        jobDir are filled in.
        """
        mock_discovery.build.return_value = self._mock_api_client
        mock_create = mock.Mock()
        self._mock_api_client.projects().jobs().create = mock_create
        mock_get = mock.Mock()
        self._mock_api_client.projects().jobs().get.return_value = mock_get
        # Stub the job-status lookup to report SUCCEEDED.
        mock_get.execute.return_value = {
            'state': 'SUCCEEDED',
        }
        mock_dependency_utils.build_ephemeral_package.return_value = self._fake_package

        class_path = 'foo.bar.class'

        runner.start_cmle_training(self._inputs, self._outputs,
                                   self._exec_properties, class_path,
                                   self._training_inputs)

        mock_dependency_utils.build_ephemeral_package.assert_called_with()

        mock_create.assert_called_with(body=mock.ANY,
                                       parent='projects/{}'.format(
                                           self._project_id))
        (_, kwargs) = mock_create.call_args
        body = kwargs['body']
        training_input = body['trainingInput']
        expected_subset = {
            'pythonVersion':
            runner._get_caip_python_version(),
            # Runtime version is the major.minor of the installed TF.
            'runtimeVersion':
            '.'.join(tf.__version__.split('.')[0:2]),
            'jobDir':
            self._job_dir,
            'args': [
                '--executor_class_path', class_path, '--inputs', '{}',
                '--outputs', '{}', '--exec-properties',
                '{"custom_config": {}}'
            ],
            'pythonModule':
            'tfx.scripts.run_executor',
            'packageUris': [os.path.join(self._job_dir, 'fake_package')],
        }
        # Subset check without assertDictContainsSubset (deprecated since
        # Python 3.2, removed in 3.12): overlaying the expected entries onto
        # the actual dict leaves it unchanged iff every expected key is
        # present with an equal value.
        self.assertEqual(training_input, {**training_input, **expected_subset})
        self.assertStartsWith(body['jobId'], 'tfx_')
        mock_get.execute.assert_called_with()
示例#3
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]):
        """Starts a trainer job on Google Cloud AI Platform.

        Args:
          input_dict: Passthrough input dict for
            tfx.components.Trainer.executor.
          output_dict: Passthrough output dict for
            tfx.components.Trainer.executor.
          exec_properties: Mostly a passthrough input dict for
            tfx.components.Trainer.executor.
            custom_config.ai_platform_training_args is consumed by this class:
            it is popped from custom_config, so the caller's dict is modified.
            For the full set of parameters supported by Google Cloud AI
            Platform, refer to
            https://cloud.google.com/ml-engine/docs/tensorflow/training-jobs#configuring_the_job

        Returns:
          The value returned by runner.start_cmle_training.

        Raises:
          ValueError: if ai_platform_training_args is not in
            exec_properties.custom_config (including when custom_config is
            absent or None).
          RuntimeError: if the Google Cloud AI Platform training job failed.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # custom_config may be absent or explicitly None; treat both as empty
        # so the caller gets the documented ValueError, not AttributeError.
        # A non-empty dict is returned as the same object by `or`, so the pop
        # below still acts on exec_properties['custom_config'].
        custom_config = exec_properties.get('custom_config') or {}
        if not custom_config.get('ai_platform_training_args'):
            err_msg = '\'ai_platform_training_args\' not found in custom_config.'
            tf.logging.error(err_msg)
            raise ValueError(err_msg)

        # pop() removes the args from exec_properties['custom_config'], so the
        # exec_properties forwarded to start_cmle_training no longer carry them.
        training_inputs = custom_config.pop('ai_platform_training_args')
        # The remote job runs the stock TFX Trainer executor.
        executor_class_path = '%s.%s' % (
            tfx_trainer_executor.Executor.__module__,
            tfx_trainer_executor.Executor.__name__)
        return runner.start_cmle_training(input_dict, output_dict,
                                          exec_properties, executor_class_path,
                                          training_inputs)
示例#4
0
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Uses a user-supplied tf.estimator to train a TensorFlow model locally.

    The Trainer Executor invokes a training_fn callback function provided by
    the user via the module_file parameter.  With the tf.estimator returned by
    this function, the Trainer Executor then builds a TensorFlow model using the
    user-provided tf.estimator.

    Args:
      input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
        - examples: Examples used for training, must include 'train' and 'eval'
          splits.
        - transform_output: Optional input transform graph.
        - schema: Schema of the data.
      output_dict: Output dict from output key to a list of Artifacts.
        - output: Exported model.
      exec_properties: A dict of execution properties.
        - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
          args for training.
        - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing
          args for eval.
        - module_file: Python module file containing UDF model definition.
        - warm_starting: Whether or not we need to do warm starting.
        - warm_start_from: Optional. If warm_starting is True, this is the
          directory to find previous model to warm start on.
        - custom_config: Optional. If it contains 'cmle_training_args'
          (deprecated), training is delegated to runner.start_cmle_training.

    Returns:
      None for local training. In the deprecated 'cmle_training_args' path,
      the value returned by runner.start_cmle_training is returned instead.

    Raises:
      ValueError: When neither or both of 'module_file' and 'trainer_fn'
        are present in 'exec_properties'.
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # TODO(zhitaoli): Deprecate this in a future version.
    custom_config = exec_properties.get('custom_config', None)
    if custom_config:
      cmle_args = custom_config.get('cmle_training_args')
      if cmle_args:
        executor_class_path = '.'.join([Executor.__module__, Executor.__name__])
        # absl.logging.warn is a deprecated alias of warning().
        absl.logging.warning(
            'Passing \'cmle_training_args\' to trainer directly is deprecated, '
            'please use extension executor at '
            'tfx.extensions.google_cloud_ai_platform.trainer.executor instead')

        return runner.start_cmle_training(input_dict, output_dict,
                                          exec_properties, executor_class_path,
                                          cmle_args)

    trainer_fn = self._GetTrainerFn(exec_properties)

    # Set up training parameters
    train_files = [
        _all_files_pattern(
            artifact_utils.get_split_uri(input_dict['examples'], 'train'))
    ]
    transform_output = artifact_utils.get_single_uri(
        input_dict['transform_output']) if input_dict.get(
            'transform_output', None) else None
    eval_files = [
        _all_files_pattern(
            artifact_utils.get_split_uri(input_dict['examples'], 'eval'))
    ]
    schema_file = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(input_dict['schema']))

    train_args = trainer_pb2.TrainArgs()
    eval_args = trainer_pb2.EvalArgs()
    json_format.Parse(exec_properties['train_args'], train_args)
    json_format.Parse(exec_properties['eval_args'], eval_args)

    # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
    # num_steps=None.  Conversion of the proto to python will set the default
    # value of an int as 0 so modify the value here.  Tensorflow will raise an
    # error if num_steps <= 0.
    train_steps = train_args.num_steps or None
    eval_steps = eval_args.num_steps or None

    output_path = artifact_utils.get_single_uri(output_dict['output'])
    serving_model_dir = path_utils.serving_model_dir(output_path)
    eval_model_dir = path_utils.eval_model_dir(output_path)

    # Assemble warm start path if needed.
    warm_start_from = None
    if exec_properties.get('warm_starting') and exec_properties.get(
        'warm_start_from'):
      previous_model_dir = os.path.join(exec_properties['warm_start_from'],
                                        path_utils.SERVING_MODEL_DIR)
      # Only warm start when the previous serving model dir holds the
      # checkpoint file; otherwise train from scratch.
      if tf.io.gfile.exists(
          os.path.join(previous_model_dir, self._CHECKPOINT_FILE_NAME)):
        warm_start_from = previous_model_dir

    # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
    hparams = _HParamWrapper(
        # A list of uris for train files.
        train_files=train_files,
        # An optional single uri for transform graph produced by TFT. Will be
        # None if not specified.
        transform_output=transform_output,
        # A single uri for the output directory of the serving model.
        serving_model_dir=serving_model_dir,
        # A list of uris for eval files.
        eval_files=eval_files,
        # A single uri for schema file.
        schema_file=schema_file,
        # Number of train steps.
        train_steps=train_steps,
        # Number of eval steps.
        eval_steps=eval_steps,
        # A single uri for the model directory to warm start from.
        warm_start_from=warm_start_from)

    schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

    # The user's trainer_fn returns a dict; the keys read below are
    # 'estimator', 'train_spec', 'eval_spec' and 'eval_input_receiver_fn'.
    training_spec = trainer_fn(hparams, schema)

    # Train the model
    absl.logging.info('Training model.')
    tf.estimator.train_and_evaluate(training_spec['estimator'],
                                    training_spec['train_spec'],
                                    training_spec['eval_spec'])
    absl.logging.info('Training complete.  Model written to %s',
                      serving_model_dir)

    # Export an eval savedmodel for TFMA
    absl.logging.info('Exporting eval_savedmodel for TFMA.')
    tfma.export.export_eval_savedmodel(
        estimator=training_spec['estimator'],
        export_dir_base=eval_model_dir,
        eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

    absl.logging.info('Exported eval_savedmodel to %s.', eval_model_dir)