コード例 #1
0
ファイル: executor.py プロジェクト: zorrock/tfx
    def Do(self, input_dict: Dict[Text, List[types.TfxArtifact]],
           output_dict: Dict[Text, List[types.TfxArtifact]],
           exec_properties: Dict[Text, Any]):
        """Overrides the tfx_pusher_executor to also deploy the model for serving.

        When the model is blessed, the exported model is handed to
        cmle_runner.deploy_model_for_cmle_serving() and then the parent
        executor's Do() is invoked so output artifacts are populated in the
        standard way.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - model_export: exported model from trainer.
            - model_blessing: model blessing path from model_validator.
          output_dict: Output dict from key to a list of artifacts, including:
            - model_push: A list of 'ModelPushPath' artifact of size one. It
              will include the model in this push execution if the model was
              pushed.
          exec_properties: Mostly a passthrough input dict for
            tfx.components.Pusher.executor.
            custom_config.ai_platform_serving_args is consumed by this class.
            For the full set of parameters supported by Google Cloud AI
            Platform, refer to
            https://cloud.google.com/ml-engine/docs/tensorflow/deploying-models#creating_a_model_version.

        Returns:
          None

        Raises:
          KeyError: if 'ai_platform_serving_args' is not in
            exec_properties['custom_config'] (the key is read with a plain
            subscript).
          RuntimeError: per the original documentation, if the Google Cloud AI
            Platform job failed. This would originate in cmle_runner, which
            is not visible here -- TODO confirm.
        """
        self._log_startup(input_dict, output_dict, exec_properties)
        if not self.CheckBlessing(input_dict, output_dict):
            return

        model_export = types.get_single_instance(input_dict['model_export'])
        model_export_uri = model_export.uri
        model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
        model_push = types.get_single_instance(output_dict['model_push'])
        # TODO(jyzhao): should this be in driver or executor.
        if not tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED')):
            model_push.set_int_custom_property('pushed', 0)
            # The '%s' placeholder needs its argument; without it the raw
            # format string is logged.
            tf.logging.info('Model on %s was not blessed', model_blessing_uri)
            return

        # Work on a copy so the caller's exec_properties are left untouched
        # when 'custom_config' is removed.
        exec_properties_copy = exec_properties.copy()
        custom_config = exec_properties_copy.pop('custom_config', {})
        ai_platform_serving_args = custom_config['ai_platform_serving_args']

        # Deploy the model.
        model_path = path_utils.serving_model_path(model_export_uri)
        # Note: we do not have a logical model version right now. This
        # model_version is a timestamp mapped to trainer's exporter.
        model_version = os.path.basename(model_path)
        if ai_platform_serving_args is not None:
            cmle_runner.deploy_model_for_cmle_serving(
                model_path, model_version, ai_platform_serving_args)

        # Make sure artifacts are populated in a standard way by calling
        # tfx.pusher.executor.Executor.Do().
        # NOTE(review): the fallback destination is computed eagerly, even
        # when 'push_destination' is supplied -- confirm that
        # _make_local_temp_destination() has no unwanted side effects.
        exec_properties_copy['push_destination'] = exec_properties.get(
            'push_destination', self._make_local_temp_destination())
        super(Executor, self).Do(input_dict, output_dict, exec_properties_copy)
コード例 #2
0
  def Do(self, input_dict: Dict[Text, List[types.TfxType]],
         output_dict: Dict[Text, List[types.TfxType]],
         exec_properties: Dict[Text, Any]) -> None:
    """Collects a human verdict on a model via Slack and records it.

    A model counts as blessed only if the model validator left a 'BLESSED'
    file in its output AND _fetch_slack_blessing() returns a truthy result.
    A timeout while waiting is treated as "not blessed".

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from model_validator.
      output_dict: Output dict from key to a list of artifacts, including:
        - slack_blessing: model blessing result.
      exec_properties: A dict of execution properties, including:
        - slack_token: Token used to setup connection with slack server.
        - channel_id: The id of the Slack channel to send and receive messages.
        - timeout_sec: How long do we wait for response, in seconds.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # Execution properties.
    slack_token = exec_properties['slack_token']
    channel_id = exec_properties['channel_id']
    timeout_sec = exec_properties['timeout_sec']

    # Input URIs, followed by the output artifact that receives the verdict.
    model_export_uri = types.get_single_uri(input_dict['model_export'])
    model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
    slack_blessing = types.get_single_instance(output_dict['slack_blessing'])

    try:
      with Timeout(timeout_sec):
        validator_marker = os.path.join(model_blessing_uri, 'BLESSED')
        # Short-circuit: Slack is only consulted when the validator marker
        # file exists.
        blessed = tf.gfile.Exists(validator_marker) and (
            self._fetch_slack_blessing(slack_token, channel_id,
                                       model_export_uri))
    except TimeoutError:  # pylint: disable=undefined-variable
      tf.logging.info('Timeout fetching manual model evaluation result.')
      blessed = False

    # An empty marker file named after the verdict is written to the output
    # path, and the same verdict is mirrored in a custom property.
    if blessed:
      marker_name, blessed_flag = 'BLESSED', 1
    else:
      marker_name, blessed_flag = 'NOT_BLESSED', 0
    io_utils.write_string_file(
        os.path.join(slack_blessing.uri, marker_name), '')
    slack_blessing.set_int_custom_property('blessed', blessed_flag)
    tf.logging.info('Blessing result %s written to %s.', blessed,
                    slack_blessing.uri)
コード例 #3
0
    def Do(self, input_dict: Dict[Text, List[types.TfxArtifact]],
           output_dict: Dict[Text, List[types.TfxArtifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Reads the input example, appends a marker text and writes it out.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - input_example: an example for an input
          output_dict: Output dict from key to a list of artifacts, including:
            - output_example: an example for an output
          exec_properties: A dict of execution properties, including:
            - string_execution_parameter: A string execution parameter (only
              printed here).
            - integer_execution_parameter: An integer execution parameter
              (only printed here).
            - input_config / output_config: not read by this method; per the
              original documentation they are only relevant for the Driver.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Fetch execution properties from exec_properties dict.
        string_parameter = exec_properties['string_execution_parameter']
        integer_parameter = exec_properties['integer_execution_parameter']

        # Fetch input URIs from input_dict.
        input_example_uri = types.get_single_uri(input_dict['input_example'])

        # Fetch output artifact from output_dict.
        output_example = types.get_single_instance(
            output_dict['output_example'])

        print("I AM RUNNING!")
        print(string_parameter)
        print(integer_parameter)
        print(input_example_uri)
        print(output_example)

        input_data = ""

        # Load the input. Reading goes through tf.gfile as well, so the read
        # uses the same filesystem layer as the existence check (the builtin
        # open() only understands local paths).
        if tf.gfile.Exists(input_example_uri):
            with tf.gfile.GFile(input_example_uri, "r") as file:
                input_data = file.read()

        # make some changes
        output_data = input_data + " changed by an awesome custom executor!"

        # Point the output artifact at the concrete file so that downstream
        # components know the filename.
        output_example.uri = os.path.join(output_example.uri,
                                          _DEFAULT_FILE_NAME)

        # write the changes back to your output
        io_utils.write_string_file(output_example.uri, output_data)

        # Custom properties let downstream components run quick checks
        # without opening the file. This is optional.
        output_example.set_string_custom_property('stringProperty', "Awesome")
        output_example.set_int_custom_property('intProperty', 42)
コード例 #4
0
ファイル: executor.py プロジェクト: zwcdp/tfx
def _CsvToExample(  # pylint: disable=invalid-name
        pipeline, input_dict, exec_properties):  # pylint: disable=unused-argument
    """Builds the Beam steps that turn a CSV file into TF examples.

    Args:
      pipeline: beam pipeline.
      input_dict: Input dict from input key to a list of Artifacts.
        - input-base: input dir that contains csv data. csv files must have
          header line.
      exec_properties: A dict of execution properties (unused here).

    Returns:
      PCollection of TF examples.
    """
    csv_uri = io_utils.get_only_uri_in_dir(
        types.get_single_instance(input_dict['input-base']).uri)
    tf.logging.info(
        'Processing input csv data {} to TFExample.'.format(csv_uri))

    # The first line is skipped when reading rows; column names are loaded
    # separately from the same file.
    raw_lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
        csv_uri, skip_header_lines=1)
    column_names = io_utils.load_csv_column_names(csv_uri)
    parsed_rows = raw_lines | 'ParseCSV' >> csv_decoder.DecodeCSV(column_names)
    return parsed_rows | 'ToTFExample' >> beam.Map(_dict_to_example)
コード例 #5
0
 def testGetSingleInstanceDeprecated(self):
     """Checks get_single_instance still works and logs exactly one warning."""
     with mock.patch.object(tf_logging, 'warning'):
         # Install our own mock so the recorded call can be inspected below.
         warning_spy = mock.MagicMock()
         tf_logging.warning = warning_spy

         only_artifact = artifact.Artifact('TestType')
         returned = types.get_single_instance([only_artifact])
         self.assertIs(only_artifact, returned)

         warning_spy.assert_called_once()
         # call_args[0] is the tuple of positional arguments of the logged
         # call. NOTE(review): index 5 presumably holds the "instructions"
         # text of the deprecation message -- confirm against the logger
         # call made by the deprecation decorator.
         positional_args = warning_spy.call_args[0]
         self.assertIn(
             'tfx.utils.types.get_single_instance has been renamed to',
             positional_args[5])
コード例 #6
0
  def test_get_from_split_list(self):
    """Exercises the retrieval helpers on a list holding two split artifacts."""
    split_names = ('train', 'eval')
    split_list = []
    for split_name in split_names:
      artifact_for_split = types.TfxType('MyTypeName', split=split_name)
      artifact_for_split.uri = '/tmp/' + split_name
      split_list.append(artifact_for_split)

    # The list holds two artifacts, so the single-item getters are expected
    # to raise.
    for single_getter in (types.get_single_instance, types.get_single_uri):
      with self.assertRaises(ValueError):
        single_getter(split_list)

    # Per-split lookups return the matching artifact and its URI.
    for position, split_name in enumerate(split_names):
      self.assertEqual(split_list[position],
                       types._get_split_instance(split_list, split_name))
      self.assertEqual('/tmp/' + split_name,
                       types.get_split_uri(split_list, split_name))
コード例 #7
0
ファイル: types_test.py プロジェクト: luvneries/tfx
  def test_get_from_split_list(self):
    """Verifies single and per-split getters on a train/eval artifact list."""
    split_list = []
    for split_name in ['train', 'eval']:
      split_artifact = types.TfxType('MyTypeName', split=split_name)
      split_artifact.uri = '/tmp/' + split_name
      split_list.append(split_artifact)
    train_artifact, eval_artifact = split_list

    # Two artifacts are present, so asking for "the single one" must fail.
    with self.assertRaises(ValueError):
      types.get_single_instance(split_list)
    with self.assertRaises(ValueError):
      types.get_single_uri(split_list)

    # 'train' split.
    found_train = types._get_split_instance(split_list, 'train')
    self.assertEqual(train_artifact, found_train)
    self.assertEqual('/tmp/train', types.get_split_uri(split_list, 'train'))

    # 'eval' split.
    found_eval = types._get_split_instance(split_list, 'eval')
    self.assertEqual(eval_artifact, found_eval)
    self.assertEqual('/tmp/eval', types.get_split_uri(split_list, 'eval'))
コード例 #8
0
 def test_get_from_single_list(self):
   """Exercises the retrieval helpers on a one-element TfxType list."""
   eval_artifact = types.TfxType('MyTypeName', split='eval')
   eval_artifact.uri = '/tmp/evaluri'
   single_list = [eval_artifact]

   # With exactly one artifact, both the single and the matching-split
   # getters resolve to it.
   self.assertEqual(single_list[0], types.get_single_instance(single_list))
   self.assertEqual('/tmp/evaluri', types.get_single_uri(single_list))
   self.assertEqual(single_list[0],
                    types._get_split_instance(single_list, 'eval'))
   self.assertEqual('/tmp/evaluri', types.get_split_uri(single_list, 'eval'))

   # Asking for a split that is not in the list is expected to raise.
   for split_getter in (types._get_split_instance, types.get_split_uri):
     with self.assertRaises(ValueError):
       split_getter(single_list, 'train')
コード例 #9
0
ファイル: types_test.py プロジェクト: luvneries/tfx
 def test_get_from_single_list(self):
   """Verifies single and per-split getters on a list with one eval artifact."""
   expected_uri = '/tmp/evaluri'
   single_list = [types.TfxType('MyTypeName', split='eval')]
   only_artifact = single_list[0]
   only_artifact.uri = expected_uri

   # Lookups that should succeed.
   self.assertEqual(only_artifact, types.get_single_instance(single_list))
   self.assertEqual(expected_uri, types.get_single_uri(single_list))
   found_eval = types._get_split_instance(single_list, 'eval')
   self.assertEqual(only_artifact, found_eval)
   self.assertEqual(expected_uri, types.get_split_uri(single_list, 'eval'))

   # Lookups for the absent 'train' split are expected to raise.
   with self.assertRaises(ValueError):
     types._get_split_instance(single_list, 'train')
   with self.assertRaises(ValueError):
     types.get_split_uri(single_list, 'train')
コード例 #10
0
  def Do(self, input_dict,
         output_dict,
         exec_properties):
    """Reads csv input and writes train and eval tf example records.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - input-base: input dir that contains csv data. csv files must have
          header line.
      output_dict: Output dict from output key to a list of Artifacts.
        - examples: train and eval split of tf examples.
      exec_properties: A dict of execution properties.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # One entry per partition index produced by beam.Partition below:
    # (shuffle label, output label, output directory).
    split_plan = (
        ('ShuffleTrainSplit', 'OutputTrainSplit',
         types.get_split_uri(output_dict['examples'], 'train')),
        ('ShuffleEvalSplit', 'OutputEvalSplit',
         types.get_split_uri(output_dict['examples'], 'eval')),
    )

    input_base_uri = types.get_single_instance(input_dict['input-base']).uri

    tf.logging.info('Generating examples.')

    raw_data = io_utils.get_only_uri_in_dir(input_base_uri)
    tf.logging.info('No split {}.'.format(raw_data))

    with beam.Pipeline(argv=self._get_beam_pipeline_args()) as pipeline:
      # pylint: disable=no-value-for-parameter
      serialized_examples = (
          pipeline
          | 'CsvToSerializedExample' >> _CsvToSerializedExample(raw_data))
      example_splits = (
          serialized_examples
          | 'SplitData' >> beam.Partition(_partition_fn, 2))
      # TODO(jyzhao): make shuffle optional.
      # pylint: disable=expression-not-assigned
      for partition_index, plan_entry in enumerate(split_plan):
        shuffle_label, output_label, output_dir = plan_entry
        (example_splits[partition_index]
         | shuffle_label >> beam.transforms.Reshuffle()
         | output_label >> beam.io.WriteToTFRecord(
             os.path.join(output_dir, DEFAULT_FILE_NAME),
             file_name_suffix='.gz'))
      # pylint: enable=expression-not-assigned

    tf.logging.info('Examples generated.')
コード例 #11
0
 def pre_execution(
     self,
     input_dict: Dict[Text, channel.Channel],
     output_dict: Dict[Text, channel.Channel],
     exec_properties: Dict[Text, Any],
     driver_args: data_types.DriverArgs,
     pipeline_info: data_types.PipelineInfo,
     component_info: data_types.ComponentInfo,
 ) -> data_types.ExecutionDecision:
   """Prepares artifacts for execution and returns a fixed decision.

   Unwraps the input and output channels, creates the pipeline root
   directory and points the single 'output' artifact at
   '<pipeline_root>/output'.

   Args:
     input_dict: Input channels keyed by name.
     output_dict: Output channels keyed by name; must contain 'output'.
     exec_properties: Execution properties, passed through unchanged.
     driver_args: Not read by this method.
     pipeline_info: Supplies pipeline_root.
     component_info: Not read by this method.

   Returns:
     An ExecutionDecision built from the unwrapped artifacts and
     exec_properties.
   """
   input_artifacts = channel.unwrap_channel_dict(input_dict)
   output_artifacts = channel.unwrap_channel_dict(output_dict)

   pipeline_root = pipeline_info.pipeline_root
   tf.gfile.MakeDirs(pipeline_root)

   output_uri = os.path.join(pipeline_info.pipeline_root, 'output')
   output_artifact = types.get_single_instance(output_artifacts['output'])
   output_artifact.uri = output_uri

   # NOTE(review): 123 and False are passed positionally; presumably an
   # execution id and a "use cached results" flag -- confirm against
   # data_types.ExecutionDecision.
   return data_types.ExecutionDecision(
       input_artifacts, output_artifacts, exec_properties, 123, False)
コード例 #12
0
ファイル: executor.py プロジェクト: dizcology/tfx
    def CheckBlessing(self, input_dict: Dict[Text, List[types.TfxType]],
                      output_dict: Dict[Text, List[types.TfxType]]) -> bool:
        """Reports whether the upstream ModelValidator blessed the model.

        When the model is not blessed, the 'pushed' custom property of the
        model_push artifact is set to 0 before returning.

        Args:
          input_dict: Input dict from input key to a list of artifacts:
            - model_blessing: model blessing path from model_validator. Pusher
              looks for a file named 'BLESSED' to consider the model blessed
              and safe to push.
          output_dict: Output dict from key to a list of artifacts, including:
            - model_push: A list of 'ModelPushPath' artifact of size one.

        Returns:
          True if the model is blessed by validator, False otherwise.
        """
        model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
        model_push = types.get_single_instance(output_dict['model_push'])
        # TODO(jyzhao): should this be in driver or executor.
        blessed_marker = os.path.join(model_blessing_uri, 'BLESSED')
        if tf.gfile.Exists(blessed_marker):
            return True

        # No marker file: record on the output artifact that nothing was
        # pushed.
        model_push.set_int_custom_property('pushed', 0)
        tf.logging.info('Model on %s was not blessed', model_blessing_uri)
        return False
コード例 #13
0
ファイル: executor.py プロジェクト: ashishML/tfx
    def Do(self, input_dict, output_dict, exec_properties):
        """Push model to target if blessed.

        The model is copied into the model_push artifact directory and, unless
        it is already there, into the serving directory named by
        push_destination. If custom_config carries 'cmle_serving_args', the
        model is additionally deployed through cmle_runner.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - model_export: exported model from trainer.
            - model_blessing: model blessing path from model_validator.
          output_dict: Output dict from key to a list of artifacts, including:
            - model_push: A list of 'ModelPushPath' artifact of size one. It
              will include the model in this push execution if the model was
              pushed.
          exec_properties: A dict of execution properties, including:
            - push_destination: JSON string of pusher_pb2.PushDestination
              instance, providing instruction of destination to push model.
            - custom_config: optional dict; its 'cmle_serving_args' entry
              triggers the cmle_runner deployment.
            - log_root: read only when 'cmle_serving_args' is present.

        Returns:
          None, unless 'cmle_serving_args' is present, in which case the
          result of cmle_runner.deploy_model_for_serving() is returned.
        """
        self._log_startup(input_dict, output_dict, exec_properties)
        model_export = types.get_single_instance(input_dict['model_export'])
        model_export_uri = model_export.uri
        model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
        model_push = types.get_single_instance(output_dict['model_push'])
        model_push_uri = model_push.uri
        # TODO(jyzhao): should this be in driver or executor.
        if not tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED')):
            model_push.set_int_custom_property('pushed', 0)
            # The '%s' placeholder needs its argument; without it the raw
            # format string is logged.
            tf.logging.info('Model on %s was not blessed', model_blessing_uri)
            return
        tf.logging.info('Model pushing.')
        # Copy the model we are pushing into
        model_path = path_utils.serving_model_path(model_export_uri)
        # Note: we do not have a logical model version right now. This
        # model_version is a timestamp mapped to trainer's exporter.
        model_version = os.path.basename(model_path)
        tf.logging.info('Model version is %s', model_version)
        io_utils.copy_dir(model_path,
                          os.path.join(model_push_uri, model_version))
        tf.logging.info('Model written to %s.', model_push_uri)

        # Copied to a fixed outside path, which can be listened by model server.
        #
        # If model is already successfully copied to outside before, stop copying.
        # This is because model validator might blessed same model twice (check
        # mv driver) with different blessing output, we still want Pusher to
        # handle the mv output again to keep metadata tracking, but no need to
        # copy to outside path again..
        # TODO(jyzhao): support rpc push and verification.
        push_destination = pusher_pb2.PushDestination()
        json_format.Parse(exec_properties['push_destination'],
                          push_destination)
        serving_path = os.path.join(push_destination.filesystem.base_directory,
                                    model_version)
        if tf.gfile.Exists(serving_path):
            tf.logging.info(
                'Destination directory %s already exists, skipping current push.',
                serving_path)
        else:
            # tf.serving won't load partial model, it will retry until fully copied.
            io_utils.copy_dir(model_path, serving_path)
            tf.logging.info('Model written to serving path %s.', serving_path)

        model_push.set_int_custom_property('pushed', 1)
        model_push.set_string_custom_property('pushed_model', model_export_uri)
        model_push.set_int_custom_property('pushed_model_id', model_export.id)
        tf.logging.info('Model pushed to %s.', serving_path)

        # Optional deployment, only when custom_config is present and truthy.
        custom_config = exec_properties.get('custom_config')
        if custom_config:
            cmle_serving_args = custom_config.get('cmle_serving_args')
            if cmle_serving_args is not None:
                return cmle_runner.deploy_model_for_serving(
                    serving_path, model_version, cmle_serving_args,
                    exec_properties['log_root'])
コード例 #14
0
ファイル: executor.py プロジェクト: luvneries/tfx
  def Do(self, input_dict,
         output_dict,
         exec_properties):
    """Push model to target if blessed.

    The model is copied into the model_push artifact directory and, unless it
    is already there, into the serving directory named by push_destination.
    If custom_config carries 'cmle_serving_args', the model is additionally
    deployed through cmle_runner.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from model_validator.
      output_dict: Output dict from key to a list of artifacts, including:
        - model_push: A list of 'ModelPushPath' artifact of size one. It will
          include the model in this push execution if the model was pushed.
      exec_properties: A dict of execution properties, including:
        - push_destination: JSON string of pusher_pb2.PushDestination instance,
          providing instruction of destination to push model.
        - custom_config: optional dict; its 'cmle_serving_args' entry triggers
          the cmle_runner deployment.
        - log_root: read only when 'cmle_serving_args' is present.

    Returns:
      None, unless 'cmle_serving_args' is present, in which case the result of
      cmle_runner.deploy_model_for_serving() is returned.
    """
    self._log_startup(input_dict, output_dict, exec_properties)
    model_export = types.get_single_instance(input_dict['model_export'])
    model_export_uri = model_export.uri
    model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
    model_push = types.get_single_instance(output_dict['model_push'])
    model_push_uri = model_push.uri
    # TODO(jyzhao): should this be in driver or executor.
    if not tf.gfile.Exists(os.path.join(model_blessing_uri, 'BLESSED')):
      model_push.set_int_custom_property('pushed', 0)
      # The '%s' placeholder needs its argument; without it the raw format
      # string is logged.
      tf.logging.info('Model on %s was not blessed', model_blessing_uri)
      return
    tf.logging.info('Model pushing.')
    # Copy the model we are pushing into
    model_path = path_utils.serving_model_path(model_export_uri)
    # Note: we do not have a logical model version right now. This
    # model_version is a timestamp mapped to trainer's exporter.
    model_version = os.path.basename(model_path)
    tf.logging.info('Model version is %s', model_version)
    io_utils.copy_dir(model_path, os.path.join(model_push_uri, model_version))
    tf.logging.info('Model written to %s.', model_push_uri)

    # Copied to a fixed outside path, which can be listened by model server.
    #
    # If model is already successfully copied to outside before, stop copying.
    # This is because model validator might blessed same model twice (check
    # mv driver) with different blessing output, we still want Pusher to
    # handle the mv output again to keep metadata tracking, but no need to
    # copy to outside path again..
    # TODO(jyzhao): support rpc push and verification.
    push_destination = pusher_pb2.PushDestination()
    json_format.Parse(exec_properties['push_destination'], push_destination)
    serving_path = os.path.join(push_destination.filesystem.base_directory,
                                model_version)
    if tf.gfile.Exists(serving_path):
      tf.logging.info(
          'Destination directory %s already exists, skipping current push.',
          serving_path)
    else:
      # tf.serving won't load partial model, it will retry until fully copied.
      io_utils.copy_dir(model_path, serving_path)
      tf.logging.info('Model written to serving path %s.', serving_path)

    model_push.set_int_custom_property('pushed', 1)
    model_push.set_string_custom_property('pushed_model', model_export_uri)
    model_push.set_int_custom_property('pushed_model_id', model_export.id)
    tf.logging.info('Model pushed to %s.', serving_path)

    # Optional deployment, only when custom_config is present and truthy.
    custom_config = exec_properties.get('custom_config')
    if custom_config:
      cmle_serving_args = custom_config.get('cmle_serving_args')
      if cmle_serving_args is not None:
        return cmle_runner.deploy_model_for_serving(serving_path, model_version,
                                                    cmle_serving_args,
                                                    exec_properties['log_root'])
コード例 #15
0
    def Do(self, input_dict, output_dict, exec_properties):
        """Validates the current model and records the blessing verdict.

        The verdict comes from self._generate_blessing_result(); an empty
        'BLESSED' or 'NOT_BLESSED' file is written under the blessing
        artifact's URI and mirrored in its 'blessed' custom property.

        Args:
          input_dict: Input dict from input key to a list of Artifacts.
            - examples: examples for eval the model.
            - model: current model for validation.
          output_dict: Output dict from output key to a list of Artifacts.
            - blessing: model blessing result.
            - results: model validation results.
          exec_properties: A dict of execution properties.
            - blessed_model: last blessed model for validation.
            - blessed_model_id: last blessed model id.

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # TODO(b/125451545): Provide a safe temp path from base executor instead.
        results_uri = types.get_single_uri(output_dict['results'])
        self._temp_path = os.path.join(results_uri, '.temp')
        tf.logging.info('Using temp path {} for tft.beam'.format(
            self._temp_path))

        eval_examples_uri = types.get_split_uri(input_dict['examples'], 'eval')
        blessing = types.get_single_instance(output_dict['blessing'])

        # Current model: remember which model this verdict refers to.
        current_model = types.get_single_instance(input_dict['model'])
        current_model_uri = current_model.uri
        tf.logging.info('Using {} as current model.'.format(current_model_uri))
        blessing.set_string_custom_property('current_model', current_model.uri)
        blessing.set_int_custom_property('current_model_id', current_model.id)

        # Blessed model: only recorded when a previous blessed model is given.
        blessed_model_dir = exec_properties['blessed_model']
        blessed_model_id = exec_properties['blessed_model_id']
        tf.logging.info('Using {} as blessed model.'.format(blessed_model_dir))
        if blessed_model_dir:
            blessing.set_string_custom_property(
                'blessed_model', blessed_model_dir)
            blessing.set_int_custom_property(
                'blessed_model_id', blessed_model_id)

        tf.logging.info('Validating model.')
        # TODO(b/125853306): support customized slice spec.
        blessed = self._generate_blessing_result(
            eval_examples_uri=eval_examples_uri,
            slice_spec=[tfma.slicer.slicer.SingleSliceSpec()],
            current_model_dir=current_model.uri,
            blessed_model_dir=blessed_model_dir)

        # Pick the marker file name and property value from the verdict.
        if blessed:
            marker_file, blessed_flag = 'BLESSED', 1
        else:
            marker_file, blessed_flag = 'NOT_BLESSED', 0
        io_utils.write_string_file(os.path.join(blessing.uri, marker_file), '')
        blessing.set_int_custom_property('blessed', blessed_flag)
        tf.logging.info('Blessing result {} written to {}.'.format(
            blessed, blessing.uri))

        # Reached only when nothing above raised.
        io_utils.delete_dir(self._temp_path)
        tf.logging.info('Cleaned up temp path {} on executor success.'.format(
            self._temp_path))
コード例 #16
0
    def Do(self, input_dict: Dict[Text, List[types.TfxArtifact]],
           output_dict: Dict[Text, List[types.TfxArtifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Collects a human verdict on a model via Slack and records it.

        Slack is only consulted when the model validator left a 'BLESSED'
        file in its output. The model is blessed only if a Slack response
        was obtained and that response is approved.

        Args:
          input_dict: Input dict from input key to a list of artifacts,
            including:
            - model_export: exported model from trainer.
            - model_blessing: model blessing path from model_validator.
          output_dict: Output dict from key to a list of artifacts, including:
            - slack_blessing: model blessing result.
          exec_properties: A dict of execution properties, including:
            - slack_token: Token used to setup connection with slack server.
            - channel_id: The id of the Slack channel to send and receive
              messages.
            - timeout_sec: How long do we wait for response, in seconds.

        Returns:
          None

        Raises:
          TimeoutError:
            When there is no decision made within timeout_sec (not caught
            here; presumably raised by Timeout -- confirm).
          ConnectionError:
            Per the original documentation, when connection to slack server
            cannot be established (raised outside this block -- confirm).

        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Execution properties.
        slack_token = exec_properties['slack_token']
        channel_id = exec_properties['channel_id']
        timeout_sec = exec_properties['timeout_sec']

        # Input URIs, followed by the output artifact that receives the verdict.
        model_export_uri = types.get_single_uri(input_dict['model_export'])
        model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
        slack_blessing = types.get_single_instance(output_dict['slack_blessing'])

        # slack_response stays None when the validator did not bless the model.
        slack_response = None
        with Timeout(timeout_sec):
            validator_marker = os.path.join(model_blessing_uri, 'BLESSED')
            if tf.gfile.Exists(validator_marker):
                slack_response = self._fetch_slack_blessing(
                    slack_token, channel_id, model_export_uri)

        # An empty marker file named after the verdict is written to the
        # output path, and the verdict is mirrored in a custom property.
        if slack_response and slack_response.approved:
            marker_file, blessed_flag = 'BLESSED', 1
        else:
            marker_file, blessed_flag = 'NOT_BLESSED', 0
        io_utils.write_string_file(
            os.path.join(slack_blessing.uri, marker_file), '')
        slack_blessing.set_int_custom_property('blessed', blessed_flag)

        # Record who decided, and where, whenever a response was received.
        if slack_response:
            decision_fields = (
                ('slack_decision_maker', 'user_id'),
                ('slack_decision_message', 'message'),
                ('slack_decision_channel', 'channel_id'),
                ('slack_decision_thread', 'thread_ts'),
            )
            for property_name, response_attr in decision_fields:
                slack_blessing.set_string_custom_property(
                    property_name, getattr(slack_response, response_attr))
        tf.logging.info('Blessing result written to %s.', slack_blessing.uri)