Exemple #1
0
    def pre_execution(
        self,
        input_dict: Dict[Text, types.Channel],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        """Imports artifacts from a source URI and registers an execution.

        The entry of `output_dict` under IMPORT_RESULT_KEY is replaced (in
        place) by a channel wrapping the imported artifacts. `input_dict` and
        `driver_args` are not consulted.

        Returns:
          An ExecutionDecision with empty inputs and exec properties, the
          imported artifacts as outputs, a newly registered execution id and
          use_cached_results=False.
        """
        imported = self._import_artifacts(
            source_uri=exec_properties[SOURCE_URI_KEY],
            destination_channel=output_dict[IMPORT_RESULT_KEY],
            reimport=exec_properties[REIMPORT_OPTION_KEY],
            split_names=exec_properties[SPLIT_KEY])
        output_artifacts = {IMPORT_RESULT_KEY: imported}

        # Rebind the output channel so downstream consumers see the imports.
        output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(imported)

        # The execution is registered with empty exec properties.
        new_execution_id = self._register_execution(
            exec_properties={},
            pipeline_info=pipeline_info,
            component_info=component_info)
        return data_types.ExecutionDecision(
            input_dict={},
            output_dict=output_artifacts,
            exec_properties={},
            execution_id=new_execution_id,
            use_cached_results=False)
Exemple #2
0
  def pre_execution(
      self,
      input_dict: Dict[Text, types.Channel],
      output_dict: Dict[Text, types.Channel],
      exec_properties: Dict[Text, Any],
      driver_args: data_types.DriverArgs,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> data_types.ExecutionDecision:
    """Resolves the input channels with the configured resolver.

    The resolver class comes from exec_properties[RESOLVER_CLASS]. When
    exec_properties[RESOLVER_CONFIGS] is truthy, it is passed to the
    constructor as keyword arguments; otherwise the class is built with no
    arguments.

    Returns:
      An ExecutionDecision whose outputs are the per-key resolve result, with
      use_cached_results=True.

    Raises:
      RuntimeError: if the resolver did not produce a complete result.
    """
    resolver_cls = exec_properties[RESOLVER_CLASS]
    resolver_configs = exec_properties[RESOLVER_CONFIGS]
    resolver = resolver_cls(**resolver_configs) if resolver_configs else resolver_cls()

    # The resolver receives a shallow copy of the channel mapping.
    result = resolver.resolve(
        metadata_handler=self._metadata_handler,
        source_channels=input_dict.copy())
    if not result.has_complete_result:
      raise RuntimeError('Cannot resolve all artifacts as needed.')

    registered_id = self._register_execution(
        exec_properties={},
        pipeline_info=pipeline_info,
        component_info=component_info)
    return data_types.ExecutionDecision(
        input_dict={},
        output_dict=result.per_key_resolve_result,
        exec_properties=exec_properties,
        execution_id=registered_id,
        use_cached_results=True)
Exemple #3
0
    def test_new_execution(self, mock_metadata_class, mock_driver_class,
                           mock_executor_class, mock_get_logger):
        """Uncached path: execution info is pushed to XCom, uncached branch taken."""
        self._setup_mocks(mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_get_logger)
        (adapter, input_dict, output_dict, exec_properties,
         driver_args) = self._setup_adapter_and_args()

        self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]

        # A decision carrying an execution_id is expected to route the adapter
        # to the uncached branch (asserted at the end of this test).
        self.mock_driver.prepare_execution.return_value = data_types.ExecutionDecision(
            input_dict, output_dict, exec_properties, execution_id=12345)

        check_result = adapter.check_cache_and_maybe_prepare_execution(
            'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

        mock_driver_class.assert_called_with(
            metadata_handler=self.mock_metadata)
        # Must be `assert_called_with`: a bare `called_with` on a Mock is not
        # an assertion and never fails (or raises AttributeError on 3.12+).
        self.mock_driver.prepare_execution.assert_called_with(
            input_dict, output_dict, exec_properties, driver_args)
        self.mock_task_instance.xcom_pull.assert_called_with(
            dag_id='input_one_component_id', key='input_one_key')

        expected_pushes = [
            mock.call(key='_exec_inputs',
                      value=types.jsonify_tfx_type_dict(input_dict)),
            mock.call(key='_exec_outputs',
                      value=types.jsonify_tfx_type_dict(output_dict)),
            mock.call(key='_exec_properties',
                      value=json.dumps(exec_properties)),
            mock.call(key='_execution_id', value=12345)
        ]
        self.mock_task_instance.xcom_push.assert_has_calls(expected_pushes)

        self.assertEqual(check_result, 'uncached_branch')
Exemple #4
0
    def test_cached_execution(self, mock_metadata_class, mock_driver_class,
                              mock_executor_class, mock_get_logger):
        """Cached path: outputs are pushed to XCom, cached branch taken."""
        self._setup_mocks(mock_metadata_class, mock_driver_class,
                          mock_executor_class, mock_get_logger)
        (adapter, input_dict, output_dict, exec_properties,
         driver_args) = self._setup_adapter_and_args()

        self.mock_task_instance.xcom_pull.side_effect = [self.input_one_json]

        # A decision without an execution_id is expected to route the adapter
        # to the cached branch (asserted at the end of this test).
        self.mock_driver.prepare_execution.return_value = data_types.ExecutionDecision(
            input_dict, output_dict, exec_properties)

        check_result = adapter.check_cache_and_maybe_prepare_execution(
            'cached_branch', 'uncached_branch', ti=self.mock_task_instance)

        mock_get_logger.assert_called_with(self._logger_config)
        mock_driver_class.assert_called_with(
            metadata_handler=self.mock_metadata)
        # Must be `assert_called_with`: a bare `called_with` on a Mock is not
        # an assertion and never fails (or raises AttributeError on 3.12+).
        self.mock_driver.prepare_execution.assert_called_with(
            input_dict, output_dict, exec_properties, driver_args)
        self.mock_task_instance.xcom_pull.assert_called_with(
            dag_id='input_one_component_id', key='input_one_key')
        self.mock_task_instance.xcom_push.assert_called_with(
            key='output_one_key', value=self.output_one_json)

        self.assertEqual(check_result, 'cached_branch')
Exemple #5
0
    def pre_execution(
        self,
        input_dict: Dict[Text, types.Channel],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        """Unwraps the channels and fills in missing output artifact URIs.

        Every output artifact whose `uri` is falsy is given
        `<pipeline_root>/artifacts/<key><suffix>/data`. The suffix is empty for
        the first artifact under a key and "2", "3", ... for later ones. The
        parent directory of the new URI is created with fileio.makedirs.

        Returns:
          An ExecutionDecision with the unwrapped inputs and outputs, the given
          exec_properties, execution id 123 and use_cached_results=False.
        """
        input_artifacts = channel_utils.unwrap_channel_dict(input_dict)
        output_artifacts = channel_utils.unwrap_channel_dict(output_dict)

        for key, artifact_list in output_artifacts.items():
            for position, artifact in enumerate(artifact_list):
                if artifact.uri:
                    continue
                # First artifact: no suffix; subsequent ones are numbered from 2.
                suffix = '' if position == 0 else str(position + 1)
                artifact.uri = os.path.join(pipeline_info.pipeline_root,
                                            'artifacts', key + suffix, 'data')
                fileio.makedirs(os.path.dirname(artifact.uri))

        return data_types.ExecutionDecision(input_artifacts, output_artifacts,
                                            exec_properties, 123, False)
Exemple #6
0
  def _default_caching_handling(
      self,
      input_dict: Dict[Text, List[types.TfxArtifact]],
      output_dict: Dict[Text, List[types.TfxArtifact]],
      exec_properties: Dict[Text, Any],
      driver_args: data_types.DriverArgs,
  ) -> data_types.ExecutionDecision:
    """Check cache for desired and applicable identical execution.

    If driver_args.enable_cache is set and a previous run yields a truthy
    output result, that result is returned in a decision with no execution
    id. Otherwise a new execution is prepared in metadata, the inputs are
    verified, and each output artifact's URI is regenerated from
    driver_args.base_output_dir, the output key and the new execution id.

    Args:
      input_dict: key -> list of input artifacts.
      output_dict: key -> list of output artifacts; URIs are overwritten in
        place on the non-cached path.
      exec_properties: Execution properties.
      driver_args: Supplies enable_cache, base_output_dir and worker_name.

    Returns:
      data_types.ExecutionDecision for either the cached or the new execution.
    """
    use_cache = driver_args.enable_cache
    output_root = driver_args.base_output_dir
    worker = driver_args.worker_name

    if use_cache:
      previous_outputs = self._get_output_from_previous_run(
          input_dict, output_dict, exec_properties, driver_args)
      if previous_outputs:
        tf.logging.info('Found cache from previous run.')
        return data_types.ExecutionDecision(
            input_dict=input_dict,
            output_dict=previous_outputs,
            exec_properties=exec_properties)

    # No usable cached result: register a new execution in metadata.
    new_execution_id = self._metadata_handler.prepare_execution(
        worker, exec_properties)
    tf.logging.info('Preparing new execution.')

    # Original intent: make sure inputs exist and are in a valid state, and
    # lock them so they are not garbage-collected midway.
    self._verify_artifacts(input_dict)

    for key, artifacts in output_dict.items():
      for artifact in artifacts:
        artifact.uri = self._generate_output_uri(
            artifact, output_root, key, new_execution_id)

    return data_types.ExecutionDecision(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        execution_id=new_execution_id)
Exemple #7
0
 def _run_driver(
         self, input_dict: Dict[Text, List[types.TfxArtifact]],
         output_dict: Dict[Text, List[types.TfxArtifact]],
         exec_properties: Dict[Text, Any]) -> data_types.ExecutionDecision:
     """Prepare inputs, outputs and execution properties for actual execution.

     Currently a pass-through: logs the component name and returns the given
     dicts unchanged, with a placeholder execution id of 1.
     """
     # TODO(jyzhao): support driver after go/tfx-oss-artifact-passing.
     tf.logging.info('Run driver for %s', self._name)
     # Placeholder id so that execution_decision.execution_needed is true and
     # the executor is always triggered.
     placeholder_execution_id = 1
     return data_types.ExecutionDecision(
         input_dict=input_dict,
         output_dict=output_dict,
         exec_properties=exec_properties,
         execution_id=placeholder_execution_id)
Exemple #8
0
    def pre_execution(
        self,
        input_dict: Dict[str, types.BaseChannel],
        output_dict: Dict[str, types.Channel],
        exec_properties: Dict[str, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        """Resolves artifacts with a resolver strategy and records a cached run.

        Registers the pipeline contexts and an execution, then instantiates the
        strategy from exec_properties[RESOLVER_STRATEGY_CLASS]. When
        exec_properties[RESOLVER_CONFIG] is truthy, it is used as keyword
        arguments. The strategy resolves the built input artifacts against the
        metadata store, and each channel in `output_dict` is replaced (in
        place) by one carrying the resolved artifacts for its key. Finally the
        execution is updated with state EXECUTION_STATE_CACHED.

        If the strategy returns None, every input key maps to an empty list.

        Returns:
          An ExecutionDecision with empty inputs, the resolved artifacts as
          outputs, the registered execution id and use_cached_results=True.
        """
        handler = self._metadata_handler
        contexts = handler.register_pipeline_contexts_if_not_exists(
            pipeline_info)
        execution = handler.register_execution(
            exec_properties=exec_properties,
            pipeline_info=pipeline_info,
            component_info=component_info,
            contexts=contexts)

        strategy_cls = exec_properties[RESOLVER_STRATEGY_CLASS]
        strategy_config = exec_properties[RESOLVER_CONFIG]
        strategy = (strategy_cls(**strategy_config)
                    if strategy_config else strategy_cls())
        input_artifacts = self._build_input_dict(pipeline_info, input_dict)
        resolved = strategy.resolve_artifacts(
            store=handler.store,
            input_dict=input_artifacts,
        )
        if resolved is None:
            # Nothing resolvable, but the driver still needs an
            # ExecutionDecision: use a fresh empty list per input key.
            resolved = {key: [] for key in input_artifacts}

        # TODO(b/148828122): This is a temporary workaround for interactive mode.
        for key, chan in output_dict.items():
            output_dict[key] = types.Channel(type=chan.type).set_artifacts(
                resolved[key])

        # Record the resolution results and mark the execution as cached.
        handler.update_execution(
            execution=execution,
            component_info=component_info,
            output_artifacts=resolved,
            execution_state=metadata.EXECUTION_STATE_CACHED,
            contexts=contexts)

        return data_types.ExecutionDecision(
            input_dict={},
            output_dict=resolved,
            exec_properties=exec_properties,
            execution_id=execution.id,
            use_cached_results=True)
 def pre_execution(
     self,
     input_dict: Dict[Text, channel.Channel],
     output_dict: Dict[Text, channel.Channel],
     exec_properties: Dict[Text, Any],
     driver_args: data_types.DriverArgs,
     pipeline_info: data_types.PipelineInfo,
     component_info: data_types.ComponentInfo,
 ) -> data_types.ExecutionDecision:
   """Unwraps the channels and points the single 'output' artifact at the root.

   Creates pipeline_info.pipeline_root, then sets the URI of the single
   artifact under the 'output' key to `<pipeline_root>/output`.

   Returns:
     An ExecutionDecision with the unwrapped inputs and outputs, the given
     exec_properties, execution id 123 and use_cached_results=False.
   """
   root = pipeline_info.pipeline_root
   unwrapped_inputs = channel.unwrap_channel_dict(input_dict)
   unwrapped_outputs = channel.unwrap_channel_dict(output_dict)
   tf.gfile.MakeDirs(root)
   output_artifact = types.get_single_instance(unwrapped_outputs['output'])
   output_artifact.uri = os.path.join(root, 'output')
   return data_types.ExecutionDecision(
       unwrapped_inputs, unwrapped_outputs, exec_properties, 123, False)
Exemple #10
0
    def pre_execution(
        self,
        input_dict: Dict[Text, types.Channel],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        """Resolves input channels with a resolver and records a cached run.

        Registers the pipeline contexts and an execution, then builds the
        resolver from exec_properties[RESOLVER_CLASS]. When
        exec_properties[RESOLVER_CONFIGS] is truthy, it is used as keyword
        arguments. The resolver receives a shallow copy of `input_dict`. Each
        channel in `output_dict` is replaced (in place) by one carrying the
        resolved artifacts for its key. The execution is then updated with
        state EXECUTION_STATE_CACHED.

        Returns:
          An ExecutionDecision with empty inputs, the per-key resolve result as
          outputs, the registered execution id and use_cached_results=True.
        """
        handler = self._metadata_handler
        contexts = handler.register_pipeline_contexts_if_not_exists(
            pipeline_info)
        execution = handler.register_execution(
            exec_properties=exec_properties,
            pipeline_info=pipeline_info,
            component_info=component_info,
            contexts=contexts)

        resolver_cls = exec_properties[RESOLVER_CLASS]
        resolver_configs = exec_properties[RESOLVER_CONFIGS]
        resolver = (resolver_cls(**resolver_configs)
                    if resolver_configs else resolver_cls())
        result = resolver.resolve(
            pipeline_info=pipeline_info,
            metadata_handler=handler,
            source_channels=input_dict.copy())
        per_key = result.per_key_resolve_result

        # TODO(b/148828122): This is a temporary walkaround for interactive mode.
        for key, chan in output_dict.items():
            output_dict[key] = types.Channel(type=chan.type,
                                             artifacts=per_key[key])

        # Record the resolution results and mark the execution as cached.
        handler.update_execution(
            execution=execution,
            component_info=component_info,
            output_artifacts=per_key,
            execution_state=metadata.EXECUTION_STATE_CACHED,
            contexts=contexts)

        return data_types.ExecutionDecision(
            input_dict={},
            output_dict=per_key,
            exec_properties=exec_properties,
            execution_id=execution.id,
            use_cached_results=True)
Exemple #11
0
  def pre_execution(
      self,
      input_dict: Dict[Text, types.Channel],
      output_dict: Dict[Text, types.Channel],
      exec_properties: Dict[Text, Any],
      driver_args: data_types.DriverArgs,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> data_types.ExecutionDecision:
    """Imports a single artifact and records it on a cached execution.

    Registers the pipeline contexts and an execution, prepares one artifact
    from the URI, properties and custom properties in `exec_properties`, and
    updates the execution with state EXECUTION_STATE_CACHED. Finally it
    replaces (in place) the IMPORT_RESULT_KEY channel of `output_dict` with a
    channel over the imported artifact.

    Returns:
      An ExecutionDecision with empty inputs, the imported artifact as the only
      output, the registered execution id and use_cached_results=False.
    """
    handler = self._metadata_handler
    contexts = handler.register_pipeline_contexts_if_not_exists(pipeline_info)
    execution = handler.register_execution(
        exec_properties=exec_properties,
        pipeline_info=pipeline_info,
        component_info=component_info,
        contexts=contexts)

    imported_artifact = self._prepare_artifact(
        uri=exec_properties[SOURCE_URI_KEY],
        properties=exec_properties[PROPERTIES_KEY],
        custom_properties=exec_properties[CUSTOM_PROPERTIES_KEY],
        destination_channel=output_dict[IMPORT_RESULT_KEY],
        reimport=exec_properties[REIMPORT_OPTION_KEY])
    imported = [imported_artifact]
    output_artifacts = {IMPORT_RESULT_KEY: imported}

    handler.update_execution(
        execution=execution,
        component_info=component_info,
        output_artifacts=output_artifacts,
        execution_state=metadata.EXECUTION_STATE_CACHED,
        contexts=contexts)

    output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(imported)

    return data_types.ExecutionDecision(
        input_dict={},
        output_dict=output_artifacts,
        exec_properties=exec_properties,
        execution_id=execution.id,
        use_cached_results=False)
Exemple #12
0
  def pre_execution(
      self,
      input_dict: Dict[Text, types.Channel],
      output_dict: Dict[Text, types.Channel],
      exec_properties: Dict[Text, Any],
      driver_args: data_types.DriverArgs,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> data_types.ExecutionDecision:
    """Handle pre-execution logic.

    Steps:
      1. Resolve input artifacts from metadata and verify them.
      2. Register the execution.
      3. If caching is enabled, look for an identical previous execution.
      4b. If one is found and its outputs can be fetched, reuse them.
      4a. Otherwise prepare fresh output artifacts and resolve the execution
          properties for the upcoming run.

    Args:
      input_dict: key -> Channel for inputs.
      output_dict: key -> Channel for outputs. Uris of the outputs are not
        assigned.
      exec_properties: Dict of other execution properties.
      driver_args: An instance of data_types.DriverArgs class.
      pipeline_info: An instance of data_types.PipelineInfo, holding pipeline
        related properties including pipeline_name, pipeline_root and run_id
      component_info: An instance of data_types.ComponentInfo, holding component
        related properties including component_type and component_id.

    Returns:
      data_types.ExecutionDecision object.

    Raises:
      RuntimeError: if any input has an empty uri.
    """
    # Step 1: resolve and validate inputs.
    input_artifacts = self.resolve_input_artifacts(input_dict, exec_properties,
                                                   driver_args, pipeline_info)
    self.verify_input_artifacts(artifacts_dict=input_artifacts)
    absl.logging.debug('Resolved input artifacts are: %s', input_artifacts)

    # Step 2: register the execution in metadata.
    execution_id = self._register_execution(
        exec_properties=exec_properties,
        pipeline_info=pipeline_info,
        component_info=component_info)

    output_artifacts = {}
    use_cached_results = False

    # TODO(b/136031301): Combine Step 3 and Step 4b after finish migration to
    # go/tfx-oss-artifact-passing.
    if driver_args.enable_cache:
      # Step 3: look for an identical earlier execution.
      cached_execution_id = self._metadata_handler.previous_execution(
          input_artifacts=input_artifacts,
          exec_properties=exec_properties,
          pipeline_info=pipeline_info,
          component_info=component_info)
      if cached_execution_id:
        absl.logging.debug('Found cached_execution: %s', cached_execution_id)
        # Step 4b: reuse its outputs; on failure fall through to a new run.
        try:
          output_artifacts = self._fetch_cached_artifacts(
              output_dict=output_dict, cached_execution_id=cached_execution_id)
        except RuntimeError:
          absl.logging.warning(
              'Error when trying to get cached output artifacts')
        else:
          absl.logging.debug('Cached output artifacts are: %s',
                             output_artifacts)
          use_cached_results = True

    if not use_cached_results:
      absl.logging.debug('Cached results not found, move on to new execution')
      # Step 4a: prepare outputs and exec properties for a fresh run.
      output_artifacts = self._prepare_output_artifacts(
          output_dict=output_dict,
          execution_id=execution_id,
          pipeline_info=pipeline_info,
          component_info=component_info)
      absl.logging.debug(
          'Output artifacts skeleton for the upcoming execution are: %s',
          output_artifacts)
      exec_properties = self.resolve_exec_properties(exec_properties,
                                                     pipeline_info,
                                                     component_info)
      absl.logging.debug(
          'Execution properties for the upcoming execution are: %s',
          exec_properties)

    return data_types.ExecutionDecision(
        input_dict=input_artifacts,
        output_dict=output_artifacts,
        exec_properties=exec_properties,
        execution_id=execution_id,
        use_cached_results=use_cached_results)
Exemple #13
0
    def pre_execution(
        self,
        input_dict: Dict[Text, types.Channel],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        """Handle pre-execution logic.

        Steps:
          1. Resolve exec properties and input artifacts; verify the inputs.
          2. Register pipeline contexts and the execution in metadata.
          3. If caching is enabled, look up cached outputs for these inputs
             and exec properties.
          4a. On a miss, prepare fresh output artifacts and update the
              execution with them and the exec properties.
          4b. On a hit, mark the execution as cached and reuse those outputs.

        Args:
          input_dict: key -> Channel for inputs.
          output_dict: key -> Channel for outputs. Uris of the outputs are not
            assigned.
          exec_properties: Dict of other execution properties.
          driver_args: An instance of data_types.DriverArgs class.
          pipeline_info: An instance of data_types.PipelineInfo, holding
            pipeline related properties including pipeline_name,
            pipeline_root and run_id
          component_info: An instance of data_types.ComponentInfo, holding
            component related properties including component_type and
            component_id.

        Returns:
          data_types.ExecutionDecision object.

        Raises:
          RuntimeError: if any input has an empty uri.
        """
        # Step 1: resolve exec properties first; input resolution uses them.
        exec_properties = self.resolve_exec_properties(exec_properties,
                                                       pipeline_info,
                                                       component_info)
        input_artifacts = self.resolve_input_artifacts(input_dict,
                                                       exec_properties,
                                                       driver_args,
                                                       pipeline_info)
        self.verify_input_artifacts(artifacts_dict=input_artifacts)
        absl.logging.debug('Resolved input artifacts are: %s', input_artifacts)

        # Step 2: register contexts and the execution.
        handler = self._metadata_handler
        contexts = handler.register_pipeline_contexts_if_not_exists(
            pipeline_info)
        execution = handler.register_execution(
            input_artifacts=input_artifacts,
            exec_properties=exec_properties,
            pipeline_info=pipeline_info,
            component_info=component_info,
            contexts=contexts)

        # Step 3: consult the cache only when enabled.
        cached_outputs = None
        if driver_args.enable_cache:
            cached_outputs = handler.get_cached_outputs(
                input_artifacts=input_artifacts,
                exec_properties=exec_properties,
                pipeline_info=pipeline_info,
                component_info=component_info)

        if cached_outputs is None:
            absl.logging.debug(
                'Cached results not found, move on to new execution')
            # Step 4a: prepare outputs for a fresh run.
            output_artifacts = self._prepare_output_artifacts(
                input_artifacts=input_artifacts,
                output_dict=output_dict,
                exec_properties=exec_properties,
                execution_id=execution.id,
                pipeline_info=pipeline_info,
                component_info=component_info)
            absl.logging.debug(
                'Output artifacts skeleton for the upcoming execution are: %s',
                output_artifacts)
            # Refresh the execution with the prepared outputs and properties.
            handler.update_execution(
                execution=execution,
                component_info=component_info,
                output_artifacts=output_artifacts,
                exec_properties=exec_properties,
                contexts=contexts)
            absl.logging.debug(
                'Execution properties for the upcoming execution are: %s',
                exec_properties)
            use_cached_results = False
        else:
            # Step 4b: mark the execution as cached. Per the original note,
            # the publisher should / will be skipped after this update.
            handler.update_execution(
                execution=execution,
                component_info=component_info,
                output_artifacts=cached_outputs,
                execution_state=metadata.EXECUTION_STATE_CACHED,
                contexts=contexts)
            output_artifacts = cached_outputs
            use_cached_results = True

        return data_types.ExecutionDecision(
            input_dict=input_artifacts,
            output_dict=output_artifacts,
            exec_properties=exec_properties,
            execution_id=execution.id,
            use_cached_results=use_cached_results)
Exemple #14
0
  def _default_caching_handling(
      self,
      input_dict: Dict[Text, List[types.TfxArtifact]],
      output_dict: Dict[Text, List[types.TfxArtifact]],
      exec_properties: Dict[Text, Any],
      driver_args: data_types.DriverArgs,
  ) -> data_types.ExecutionDecision:
    """Check cache for desired and applicable identical execution.

    If driver_args.enable_cache is set and a previous run yields a truthy
    output result, that result is returned in a decision with no execution
    id. Otherwise a new execution is prepared and the inputs are verified.
    Each output artifact then gets the URI
    `<base_output_dir>/<key>/<execution_id>/<split>/`, and that directory is
    created. The artifact's span is set to the maximum input span (at least
    0).

    Args:
      input_dict: key -> list of input artifacts.
      output_dict: key -> list of output artifacts; uri and span are
        overwritten in place on the non-cached path.
      exec_properties: Execution properties.
      driver_args: Supplies enable_cache, base_output_dir and worker_name.

    Returns:
      data_types.ExecutionDecision for either the cached or the new execution.

    Raises:
      RuntimeError: if an output artifact uri already exists.
    """
    use_cache = driver_args.enable_cache
    output_root = driver_args.base_output_dir
    worker = driver_args.worker_name

    if use_cache:
      previous_outputs = self._get_output_from_previous_run(
          input_dict, output_dict, exec_properties, driver_args)
      if previous_outputs:
        tf.logging.info('Found cache from previous run.')
        return data_types.ExecutionDecision(
            input_dict=input_dict,
            output_dict=previous_outputs,
            exec_properties=exec_properties)

    # No usable cached result: register a new execution in metadata.
    new_execution_id = self._metadata_handler.prepare_execution(
        worker, exec_properties)
    tf.logging.info('Preparing new execution.')

    # Original intent: make sure inputs exist and are in a valid state, and
    # lock them so they are not garbage-collected midway.
    self._verify_inputs(input_dict)

    # Outputs default to the largest input span; 0 is the floor.
    max_input_span = max(
        [0] + [art.span for arts in input_dict.values() for art in arts])

    # TODO(ruoyu): This location is dangerous because this function is not
    # guaranteed to be called on custom driver.
    for output_name, artifacts in output_dict.items():
      for artifact in artifacts:
        # URI is derived from execution id and optional split; the trailing
        # '' makes the joined path end in a separator (a directory).
        artifact.uri = os.path.join(output_root, output_name,
                                    str(new_execution_id), artifact.split, '')
        if tf.gfile.Exists(artifact.uri):
          msg = 'Output artifact uri {} already exists'.format(artifact.uri)
          tf.logging.error(msg)
          raise RuntimeError(msg)
        # TODO(zhitaoli): Consider refactoring this out into something
        # which can handle permission bits.
        tf.logging.info('Creating output artifact uri %s as directory',
                        artifact.uri)
        tf.gfile.MakeDirs(artifact.uri)
        artifact.span = max_input_span

    return data_types.ExecutionDecision(
        input_dict=input_dict,
        output_dict=output_dict,
        exec_properties=exec_properties,
        execution_id=new_execution_id)