Example #1
def get_fn(exec_properties: Dict[Text, Any],
           fn_name: Text) -> Callable[..., Any]:
    """Loads and returns user-defined function."""
    # Trace the function lookup; debug level suits diagnostic output.
    logging.debug('udf_utils.get_fn %r %r', exec_properties, fn_name)

    has_module_file = bool(exec_properties.get(_MODULE_FILE_KEY))
    has_module_path = bool(exec_properties.get(_MODULE_PATH_KEY))
    has_fn = bool(exec_properties.get(fn_name))

    if has_module_path:
        module_path = exec_properties[_MODULE_PATH_KEY]
        return import_utils.import_func_from_module(module_path, fn_name)
    elif has_module_file:
        if has_fn:
            return import_utils.import_func_from_source(
                exec_properties[_MODULE_FILE_KEY], exec_properties[fn_name])
        else:
            return import_utils.import_func_from_source(
                exec_properties[_MODULE_FILE_KEY], fn_name)
    elif has_fn:
        fn_path_split = exec_properties[fn_name].split('.')
        return import_utils.import_func_from_module(
            '.'.join(fn_path_split[0:-1]), fn_path_split[-1])
    else:
        raise ValueError(
            'Neither a module file nor a user function has been supplied in '
            '`exec_properties`.')
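
For readers unfamiliar with the helper, here is a minimal, self-contained sketch of the dotted-path resolution that import_func_from_module presumably performs; the real helper lives in TFX's import_utils, and math.sqrt is only an illustrative target.

import importlib
from typing import Any, Callable


def _import_func_from_module(module_path: str, fn_name: str) -> Callable[..., Any]:
    # Import the module by its dotted path, then look up the function by name.
    module = importlib.import_module(module_path)
    return getattr(module, fn_name)


# Illustrative use against the standard library.
sqrt = _import_func_from_module('math', 'sqrt')
assert sqrt(16) == 4.0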
Example #2
 def testImportFuncFromSource(self):
     source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     test_fn_file = os.path.join(source_data_dir, 'test_fn.ext')
     fn_1 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
     fn_2 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
     self.assertIs(fn_1, fn_2)
     self.assertEqual(10, fn_1([1, 2, 3, 4]))
Example #3
  def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                          unused_outputs: Mapping[Text, Any]) -> Any:
    """Returns a user defined preprocessing_fn.

    Args:
      inputs: A dictionary of labelled input values.
      unused_outputs: A dictionary of labelled output values.

    Returns:
      User-defined function.

    Raises:
      ValueError: When neither or both of MODULE_FILE and PREPROCESSING_FN
        are present in inputs.
    """
    has_module_file = bool(
        common.GetSoleValue(inputs, labels.MODULE_FILE, strict=False))
    has_preprocessing_fn = bool(
        common.GetSoleValue(inputs, labels.PREPROCESSING_FN, strict=False))

    if has_module_file == has_preprocessing_fn:
      raise ValueError(
          'Neither or both of MODULE_FILE and PREPROCESSING_FN have been '
          'supplied in inputs.')

    if has_module_file:
      return import_utils.import_func_from_source(
          common.GetSoleValue(inputs, labels.MODULE_FILE), 'preprocessing_fn')

    preprocessing_fn_path_split = common.GetSoleValue(
        inputs, labels.PREPROCESSING_FN).split('.')
    return import_utils.import_func_from_module(
        '.'.join(preprocessing_fn_path_split[0:-1]),
        preprocessing_fn_path_split[-1])
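
Note that the fallback branch above is plain string manipulation: everything before the last dot is treated as the module path, and the final component as the function name. A tiny sketch with a hypothetical qualified name:

# 'my_pkg.my_module' and 'preprocessing_fn' below are hypothetical examples.
preprocessing_fn_path = 'my_pkg.my_module.preprocessing_fn'
parts = preprocessing_fn_path.split('.')
module_path, fn_name = '.'.join(parts[:-1]), parts[-1]
assert (module_path, fn_name) == ('my_pkg.my_module', 'preprocessing_fn')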
Example #4
 def testImportFuncFromModuleReload(self):
     temp_dir = self.create_tempdir().full_path
     test_fn_file = os.path.join(temp_dir, 'fn.py')
     with tf.io.gfile.GFile(test_fn_file, mode='w') as f:
         f.write("""def test_fn(inputs):
         return sum(inputs)
       """)
     i = import_utils._tfx_module_finder.count_registered
     fn_1 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
     self.assertEqual(10, fn_1([1, 2, 3, 4]))
     with tf.io.gfile.GFile(test_fn_file, mode='w') as f:
         f.write("""def test_fn(inputs):
         return 1+sum(inputs)
       """)
     fn_2 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
     self.assertEqual(11, fn_2([1, 2, 3, 4]))
     fn_3 = getattr(importlib.reload(sys.modules['user_module_%d' % i]),
                    'test_fn')
     self.assertEqual(11, fn_3([1, 2, 3, 4]))
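
The test above depends on TFX's internal module finder (_tfx_module_finder) to make dynamically loaded modules reloadable. A self-contained sketch of the same reload behavior using only the standard library (file and module names here are illustrative):

import importlib
import os
import sys
import tempfile

tmp_dir = tempfile.mkdtemp()
module_file = os.path.join(tmp_dir, 'user_fn_demo.py')
with open(module_file, 'w') as f:
    f.write('def test_fn(inputs):\n    return sum(inputs)\n')

# Make the module importable, load it, then rewrite the source and reload.
sys.path.insert(0, tmp_dir)
user_fn_demo = importlib.import_module('user_fn_demo')
assert user_fn_demo.test_fn([1, 2, 3, 4]) == 10

with open(module_file, 'w') as f:
    f.write('def test_fn(inputs):\n    return 1 + sum(inputs)\n')
assert importlib.reload(user_fn_demo).test_fn([1, 2, 3, 4]) == 11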
Example #5
  def _GetPreprocessingFn(self, inputs: Mapping[Text, Any],
                          unused_outputs: Mapping[Text, Any]) -> Any:
    """Returns a user defined preprocessing_fn.

    Args:
      inputs: A dictionary of labelled input values.
      unused_outputs: A dictionary of labelled output values.

    Returns:
      User-defined function.
    """
    return import_utils.import_func_from_source(
        common.GetSoleValue(inputs, labels.PREPROCESSING_FN),
        'preprocessing_fn')
Example #6
 def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> None:
     del input_dict
     if _MODULE_FILE_KEY in exec_properties:
         create_decoder_func = import_utils.import_func_from_source(
             exec_properties.get(_MODULE_FILE_KEY),
             exec_properties.get(_CREATE_DECODER_FUNC_KEY))
     else:
         create_decoder_func = udf_utils.get_fn(exec_properties,
                                                _CREATE_DECODER_FUNC_KEY)
     tf_graph_record_decoder.save_decoder(
         create_decoder_func(),
         value_utils.GetSoleValue(output_dict, _DATA_VIEW_KEY).uri)
Example #7
  def _GetTrainerFn(self, exec_properties: Dict[Text, Any]) -> Any:
    """Loads and returns user-defined trainer_fn."""

    has_module_file = bool(exec_properties.get('module_file'))
    has_trainer_fn = bool(exec_properties.get('trainer_fn'))

    if has_module_file == has_trainer_fn:
      raise ValueError(
          "Neither or both of 'module_file' 'trainer_fn' have been supplied in "
          "'exec_properties'.")

    if has_module_file:
      return import_utils.import_func_from_source(
          exec_properties['module_file'], 'trainer_fn')

    trainer_fn_path_split = exec_properties['trainer_fn'].split('.')
    return import_utils.import_func_from_module(
        '.'.join(trainer_fn_path_split[0:-1]), trainer_fn_path_split[-1])
Example #8
  def _GetFn(self, exec_properties: Dict[Text, Any], fn_name: Text) -> Any:
    """Loads and returns user-defined function."""

    has_module_file = bool(exec_properties.get('module_file'))
    has_fn = bool(exec_properties.get(fn_name))

    if has_module_file == has_fn:
      raise ValueError(
          'Neither or both of module file and user function have been supplied in '
          "'exec_properties'.")

    if has_module_file:
      return import_utils.import_func_from_source(
          exec_properties['module_file'], fn_name)

    fn_path_split = exec_properties[fn_name].split('.')
    return import_utils.import_func_from_module('.'.join(fn_path_split[0:-1]),
                                                fn_path_split[-1])
Example #9
def get_fn(exec_properties: Dict[Text, Any],
           fn_name: Text) -> Callable[..., Any]:
  """Loads and returns user-defined function."""

  has_module_file = bool(exec_properties.get(_MODULE_FILE_KEY))
  has_fn = bool(exec_properties.get(fn_name))

  if has_module_file == has_fn:
    raise ValueError(
        'Neither or both of module file and user function have been supplied '
        "in 'exec_properties'.")

  if has_module_file:
    return import_utils.import_func_from_source(
        exec_properties[_MODULE_FILE_KEY], fn_name)

  fn_path_split = exec_properties[fn_name].split('.')
  return import_utils.import_func_from_module('.'.join(fn_path_split[0:-1]),
                                              fn_path_split[-1])
Example #10
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Uses a user-supplied tf.estimator to train a TensorFlow model locally.

    The Trainer Executor invokes a training_fn callback function provided by
    the user via the module_file parameter.  With the tf.estimator returned by
    this function, the Trainer Executor then builds a TensorFlow model using the
    user-provided tf.estimator.

    Args:
      input_dict: Input dict from input key to a list of ML-Metadata Artifacts.
        - examples: Examples used for training, must include 'train' and 'eval'
          splits.
        - transform_output: Optional input transform graph.
        - schema: Schema of the data.
      output_dict: Output dict from output key to a list of Artifacts.
        - output: Exported model.
      exec_properties: A dict of execution properties.
        - train_args: JSON string of trainer_pb2.TrainArgs instance, providing
          args for training.
        - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing
          args for eval.
        - module_file: Python module file containing UDF model definition.
        - warm_starting: Whether or not we need to do warm starting.
        - warm_start_from: Optional. If warm_starting is True, this is the
          directory to find previous model to warm start on.

    Returns:
      None

    Raises:
      None
    """
        self._log_startup(input_dict, output_dict, exec_properties)

        # TODO(zhitaoli): Deprecate this in a future version.
        if exec_properties.get('custom_config', None):
            cmle_args = exec_properties.get('custom_config',
                                            {}).get('cmle_training_args')
            if cmle_args:
                executor_class_path = '.'.join(
                    [Executor.__module__, Executor.__name__])
                tf.logging.warn(
                    'Passing \'cmle_training_args\' to trainer directly is deprecated, '
                    'please use extension executor at '
                    'tfx.extensions.google_cloud_ai_platform.trainer.executor instead'
                )

                return runner.start_cmle_training(input_dict, output_dict,
                                                  exec_properties,
                                                  executor_class_path,
                                                  cmle_args)

        trainer_fn = import_utils.import_func_from_source(
            exec_properties['module_file'], 'trainer_fn')

        # Set up training parameters
        train_files = [
            _all_files_pattern(
                artifact_utils.get_split_uri(input_dict['examples'], 'train'))
        ]
        transform_output = artifact_utils.get_single_uri(
            input_dict['transform_output']
        ) if input_dict['transform_output'] else None
        eval_files = [
            _all_files_pattern(
                artifact_utils.get_split_uri(input_dict['examples'], 'eval'))
        ]
        schema_file = io_utils.get_only_uri_in_dir(
            artifact_utils.get_single_uri(input_dict['schema']))

        train_args = trainer_pb2.TrainArgs()
        eval_args = trainer_pb2.EvalArgs()
        json_format.Parse(exec_properties['train_args'], train_args)
        json_format.Parse(exec_properties['eval_args'], eval_args)

        # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
        # num_steps=None.  Conversion of the proto to python will set the default
        # value of an int as 0 so modify the value here.  Tensorflow will raise an
        # error if num_steps <= 0.
        train_steps = train_args.num_steps or None
        eval_steps = eval_args.num_steps or None

        output_path = artifact_utils.get_single_uri(output_dict['output'])
        serving_model_dir = path_utils.serving_model_dir(output_path)
        eval_model_dir = path_utils.eval_model_dir(output_path)

        # Assemble warm start path if needed.
        warm_start_from = None
        if exec_properties.get('warm_starting') and exec_properties.get(
                'warm_start_from'):
            previous_model_dir = os.path.join(
                exec_properties['warm_start_from'],
                path_utils.SERVING_MODEL_DIR)
            if previous_model_dir and tf.gfile.Exists(
                    os.path.join(previous_model_dir,
                                 self._CHECKPOINT_FILE_NAME)):
                warm_start_from = previous_model_dir

        # TODO(b/126242806) Use PipelineInputs when it is available in third_party.
        hparams = tf.contrib.training.HParams(
            # A list of uris for train files.
            train_files=train_files,
            # An optional single uri for transform graph produced by TFT. Will be
            # None if not specified.
            transform_output=transform_output,
            # A single uri for the output directory of the serving model.
            serving_model_dir=serving_model_dir,
            # A list of uris for eval files.
            eval_files=eval_files,
            # A single uri for schema file.
            schema_file=schema_file,
            # Number of train steps.
            train_steps=train_steps,
            # Number of eval steps.
            eval_steps=eval_steps,
            # A single uri for the model directory to warm start from.
            warm_start_from=warm_start_from)

        schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

        training_spec = trainer_fn(hparams, schema)

        # Train the model
        tf.logging.info('Training model.')
        tf.estimator.train_and_evaluate(training_spec['estimator'],
                                        training_spec['train_spec'],
                                        training_spec['eval_spec'])
        tf.logging.info('Training complete.  Model written to %s',
                        serving_model_dir)

        # Export an eval savedmodel for TFMA
        tf.logging.info('Exporting eval_savedmodel for TFMA.')
        tfma.export.export_eval_savedmodel(
            estimator=training_spec['estimator'],
            export_dir_base=eval_model_dir,
            eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

        tf.logging.info('Exported eval_savedmodel to %s.', eval_model_dir)
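
For context, here is a hedged sketch of the contract this executor expects from the user-supplied trainer_fn: a dict carrying the four values consumed at the end of Do(). The model_fn and the input/receiver functions are hypothetical placeholders, not names prescribed by TFX.

import tensorflow as tf


def trainer_fn(hparams, schema):
    # _my_model_fn, _my_train_input_fn, _my_eval_input_fn and
    # _my_eval_input_receiver_fn are hypothetical user-defined callables.
    estimator = tf.estimator.Estimator(
        model_fn=_my_model_fn, model_dir=hparams.serving_model_dir)
    return {
        'estimator': estimator,
        'train_spec': tf.estimator.TrainSpec(
            input_fn=_my_train_input_fn, max_steps=hparams.train_steps),
        'eval_spec': tf.estimator.EvalSpec(
            input_fn=_my_eval_input_fn, steps=hparams.eval_steps),
        'eval_input_receiver_fn': _my_eval_input_receiver_fn,
    }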
Example #11
 def testImportFuncFromSourceMissingFunction(self):
     source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     test_fn_file = os.path.join(source_data_dir, 'test_fn.ext')
     with self.assertRaises(AttributeError):
         import_utils.import_func_from_source(test_fn_file, 'non_existing')