Example #1
def import_func_from_module(module_path: str, fn_name: str) -> Callable:  # pylint: disable=g-bare-generic
    """Imports a function from a module provided as source file or module path."""
    original_module_path = module_path
    wheel_context_manager = None
    if '@' in module_path:
        # The module path is a combined specification of a module name and the
        # path of the wheel file that provides it.
        module_path, wheel_path = module_path.split('@', maxsplit=1)
        wheel_path = io_utils.ensure_local(wheel_path)
        # Install pip dependencies and add them to the import resolution path.
        # TODO(b/187122070): Move `udf_utils.py` to `tfx/utils` and fix circular
        # dependency.
        from tfx.components.util import udf_utils  # pylint: disable=g-import-not-at-top # pytype: disable=import-error
        wheel_context_manager = udf_utils.TempPipInstallContext([wheel_path])
        with _imported_modules_from_source_lock:
            wheel_context_manager.__enter__()
    if module_path in sys.modules:
        importlib.reload(sys.modules[module_path])
    try:
        user_module = importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError('Could not import requested module path %r.' %
                          original_module_path) from e
    # Restore original sys.path.
    if wheel_context_manager:
        with _imported_modules_from_source_lock:
            wheel_context_manager.__exit__(None, None, None)
    return getattr(user_module, fn_name)
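
A brief usage sketch (the module name, wheel path, and function name below are hypothetical): a plain module path is imported directly, while a combined 'module@wheel' specification pip-installs the wheel before importing.

# Plain module path: resolved through the normal import machinery.
run_fn = import_func_from_module('my_trainer_module', 'run_fn')

# Combined 'module@wheel' spec: the wheel is fetched locally via
# io_utils.ensure_local() and pip-installed before the import.
run_fn = import_func_from_module(
    'my_trainer_module@gs://bucket/pipeline/_wheels/tfx_user_code-0.0+abc.whl',
    'run_fn')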
Example #2
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable:  # pylint: disable=g-bare-generic
  """Imports a function from a module provided as source file."""

  # If module path is not local, download to local file-system first,
  # because importlib can't import from GCS
  source_path = io_utils.ensure_local(source_path)

  module = None
  with _imported_modules_from_source_lock:
    if _tfx_module_finder.get_module_name_by_path(source_path) is None:
      logging.info('Loading %s because it has not been loaded before.',
                   source_path)
      # Create a unique module name.
      module_name = 'user_module_%d' % _tfx_module_finder.count_registered
      try:
        loader = importlib.machinery.SourceFileLoader(
            fullname=module_name,
            path=source_path,
        )
        spec = importlib.util.spec_from_loader(
            loader.name, loader, origin=source_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[loader.name] = module
        loader.exec_module(module)
        _tfx_module_finder.register_module(module_name, source_path)
      except IOError:
        raise ImportError('{} in {} not found in '
                          'import_func_from_source()'.format(
                              fn_name, source_path))
    else:
      logging.info('%s is already loaded, reloading', source_path)
      module_name = _tfx_module_finder.get_module_name_by_path(source_path)
      module = sys.modules[module_name]
      importlib.reload(module)
  return getattr(module, fn_name)
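
This variant delegates its bookkeeping to a module-level _tfx_module_finder. A minimal sketch of the registry interface the code above relies on (an assumption; the real object presumably also acts as a sys.meta_path finder):

class _SourceModuleRegistry:
    """Tracks which source paths have been loaded, and under what name."""

    def __init__(self):
        self._module_name_by_path = {}  # source path -> generated module name

    @property
    def count_registered(self):
        return len(self._module_name_by_path)

    def get_module_name_by_path(self, path):
        return self._module_name_by_path.get(path)

    def register_module(self, module_name, path):
        self._module_name_by_path[path] = module_name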
Example #3
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable:  # pylint: disable=g-bare-generic
    """Imports a function from a module provided as source file."""

    # If module path is not local, download to local file-system first,
    # because importlib can't import from GCS
    source_path = io_utils.ensure_local(source_path)

    try:
        if six.PY2:
            import imp  # pylint: disable=g-import-not-at-top
            try:
                user_module = imp.load_source('user_module', source_path)
                return getattr(user_module, fn_name)
            except IOError:
                raise

        else:
            spec = importlib.util.spec_from_file_location(
                'user_module', source_path)

            if not spec:
                raise ImportError()

            user_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(user_module)  # pytype: disable=attribute-error
            return getattr(user_module, fn_name)

    except IOError:
        raise ImportError(
            '{} in {} not found in import_func_from_source()'.format(
                fn_name, source_path))
Example #4
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable:  # pylint: disable=g-bare-generic
    """Imports a function from a module provided as source file."""

    # If module path is not local, download to local file-system first,
    # because importlib can't import from GCS
    source_path = io_utils.ensure_local(source_path)

    try:
        if six.PY2:
            import imp  # pylint: disable=g-import-not-at-top
            try:
                user_module = imp.load_source('user_module', source_path)
                return getattr(user_module, fn_name)
            except IOError:
                raise

        else:
            loader = importlib.machinery.SourceFileLoader(
                fullname='user_module',
                path=source_path,
            )
            user_module = types.ModuleType(loader.name)
            loader.exec_module(user_module)
            return getattr(user_module, fn_name)

    except IOError:
        raise ImportError(
            '{} in {} not found in import_func_from_source()'.format(
                fn_name, source_path))
Example #5
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable:  # pylint: disable=g-bare-generic
    """Imports a function from a module provided as source file."""

    # If module path is not local, download to local file-system first,
    # because importlib can't import from GCS
    source_path = io_utils.ensure_local(source_path)

    with _imported_modules_from_source_lock:
        if source_path not in _imported_modules_from_source:
            logging.info('Loading %s because it has not been loaded before.',
                         source_path)
            # Create a unique module name.
            module_name = 'user_module_%d' % len(_imported_modules_from_source)
            try:
                loader = importlib.machinery.SourceFileLoader(
                    fullname=module_name,
                    path=source_path,
                )
                spec = importlib.util.spec_from_loader(loader.name,
                                                       loader,
                                                       origin=source_path)
                module = importlib.util.module_from_spec(spec)
                sys.modules[loader.name] = module
                loader.exec_module(module)
                sys.meta_path.append(_ModuleFinder({module_name: source_path}))
                _imported_modules_from_source[source_path] = module
            except IOError:
                raise ImportError('{} in {} not found in '
                                  'import_func_from_source()'.format(
                                      fn_name, source_path))
        else:
            logging.info('%s is already loaded.', source_path)
    return getattr(_imported_modules_from_source[source_path], fn_name)
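
The _ModuleFinder appended to sys.meta_path is what lets importlib later re-locate modules that were loaded from arbitrary source files (e.g., on reload or in subprocesses). A minimal sketch consistent with its usage above (an assumption, not the actual TFX implementation):

import importlib.abc
import importlib.util

class _SourceFileFinder(importlib.abc.MetaPathFinder):
    """Maps generated module names back to their source files."""

    def __init__(self, path_by_name):
        self._path_by_name = path_by_name  # module name -> source file path

    def find_spec(self, fullname, path, target=None):
        if fullname not in self._path_by_name:
            return None
        return importlib.util.spec_from_file_location(
            fullname, self._path_by_name[fullname])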
Example #6
 def testEnsureLocalFromGCS(self, mock_copy_file):
   file_path = 'gs://path/to/testdata/test_fn.py'
   local_file_path = io_utils.ensure_local(file_path)
   self.assertEndsWith(local_file_path, '/test_fn.py')
   self.assertFalse(
       any([
           local_file_path.startswith(prefix)
           for prefix in io_utils._REMOTE_FS_PREFIX
       ]))
   mock_copy_file.assert_called_once_with(file_path, local_file_path, True)
Example #7
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable:  # pylint: disable=g-bare-generic
    """Imports a function from a module provided as source file."""

    # If module path is not local, download to local file-system first,
    # because importlib can't import from GCS
    source_path = io_utils.ensure_local(source_path)

    try:
        loader = importlib.machinery.SourceFileLoader(
            fullname='user_module',
            path=source_path,
        )
        spec = importlib.util.spec_from_loader(loader.name,
                                               loader,
                                               origin=source_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[loader.name] = module
        loader.exec_module(module)
        return getattr(module, fn_name)

    except IOError:
        raise ImportError(
            '{} in {} not found in import_func_from_source()'.format(
                fn_name, source_path))
Example #8
 def testEnsureLocalFromGCS(self, mock_copy_file):
     file_path = 'gs://path/to/testdata/test_fn.py'
     self.assertEqual('test_fn.py', io_utils.ensure_local(file_path))
     mock_copy_file.assert_called_once_with(file_path, 'test_fn.py', True)
Example #9
 def testEnsureLocal(self):
     file_path = os.path.join(os.path.dirname(__file__), 'testdata',
                              'test_fn.py')
     self.assertEqual(file_path, io_utils.ensure_local(file_path))
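
The three ensure_local tests above pin down its contract: a local path comes back unchanged (Example #9), while a remote gs:// path is copied to a local file that keeps its basename, via the copy helper the tests mock out. The newer test (Example #6) expects the copy in a fresh local directory; the older one (Example #8) expected just the basename. A minimal sketch consistent with the newer contract (the prefix list and the copy_file signature are assumptions taken from the tests, not the actual TFX implementation):

import os
import tempfile

_REMOTE_FS_PREFIX = ['gs://', 'hdfs://', 's3://']  # assumed prefix list

def ensure_local(file_path):
    """Returns a local path for file_path, copying remote files down first."""
    if not any(file_path.startswith(prefix) for prefix in _REMOTE_FS_PREFIX):
        return file_path  # already local: returned unchanged (Example #9)
    local_path = os.path.join(tempfile.mkdtemp(), os.path.basename(file_path))
    # copy_file(src, dst, overwrite) is the helper the tests mock out; a real
    # implementation needs a remote-filesystem-aware copy here.
    copy_file(file_path, local_path, True)
    return local_path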
Example #10
def package_user_module_file(instance_name: Text, module_path: Text,
                             pipeline_root: Text) -> Tuple[Text, Text]:
    """Package the given user module file into a Python Wheel package.

  Args:
      instance_name: Name of the component instance, for creating a unique wheel
        package name.
      module_path: Path to the module file to be packaged.
      pipeline_root: Path to the pipeline root directory; the built wheel
        file is staged under its "_wheels" subdirectory.

  Returns:
      dist_file_path: Path to the generated wheel file.
      user_module_path: Path for referencing the user module when stored
        as the _MODULE_PATH_KEY execution property. Format should be treated
        as opaque by the user.

  Raises:
      RuntimeError: When wheel building fails.
  """
    module_path = os.path.abspath(io_utils.ensure_local(module_path))
    if not module_path.endswith('.py'):
        raise ValueError('Module path %r is not a ".py" file.' % module_path)
    if not os.path.exists(module_path):
        raise ValueError('Module path %r does not exist.' % module_path)

    user_module_dir, module_file_name = os.path.split(module_path)
    user_module_name = re.sub(r'\.py$', '', module_file_name)
    source_files = []

    # Discover all Python source files in this directory for inclusion.
    for file_name in os.listdir(user_module_dir):
        if file_name.endswith('.py'):
            source_files.append(file_name)
    module_names = []
    for file_name in source_files:
        if file_name in (_EPHEMERAL_SETUP_PY_FILE_NAME, '__init__.py'):
            continue
        module_name = re.sub(r'\.py$', '', file_name)
        module_names.append(module_name)

    # Set up build directory.
    build_dir = tempfile.mkdtemp()
    for source_file in source_files:
        shutil.copyfile(os.path.join(user_module_dir, source_file),
                        os.path.join(build_dir, source_file))

    # Generate an ephemeral wheel for this module.
    logging.info(
        'Generating ephemeral wheel package for %r (including modules: %s).',
        module_path, module_names)

    version_hash = _get_version_hash(user_module_dir, source_files)
    logging.info('User module package has hash fingerprint version %s.',
                 version_hash)

    setup_py_path = os.path.join(build_dir, _EPHEMERAL_SETUP_PY_FILE_NAME)
    with open(setup_py_path, 'w') as f:
        f.write(
            _get_ephemeral_setup_py_contents(
                'tfx-user-code-%s' % instance_name, '0.0+%s' % version_hash,
                module_names))

    temp_dir = tempfile.mkdtemp()
    dist_dir = tempfile.mkdtemp()
    bdist_command = [
        sys.executable, setup_py_path, 'bdist_wheel', '--bdist-dir', temp_dir,
        '--dist-dir', dist_dir
    ]
    logging.info('Executing: %s', bdist_command)
    try:
        subprocess.check_call(bdist_command, cwd=build_dir)
    except subprocess.CalledProcessError as e:
        raise RuntimeError('Failed to build wheel.') from e

    dist_files = os.listdir(dist_dir)
    if len(dist_files) != 1:
        raise RuntimeError(
            'Unexpectedly found %d output files in wheel output directory %s.'
            % (len(dist_files), dist_dir))
    build_dist_file_path = os.path.join(dist_dir, dist_files[0])
    # Copy wheel file atomically to wheel staging directory.
    dist_wheel_directory = os.path.join(pipeline_root, '_wheels')
    dist_file_path = os.path.join(dist_wheel_directory, dist_files[0])
    temp_dist_file_path = dist_file_path + '.tmp'
    fileio.makedirs(dist_wheel_directory)
    fileio.copy(build_dist_file_path, temp_dist_file_path, overwrite=True)
    fileio.rename(temp_dist_file_path, dist_file_path, overwrite=True)
    logging.info(
        ('Successfully built user code wheel distribution at %r; target user '
         'module is %r.'), dist_file_path, user_module_name)

    # Encode the user module key as a specification of a user module name within
    # a packaged wheel path.
    assert '@' not in user_module_name, ('Unexpected invalid module name: %s' %
                                         user_module_name)
    user_module_path = '%s@%s' % (user_module_name, dist_file_path)
    logging.info('Full user module path is %r', user_module_path)

    return dist_file_path, user_module_path
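
The resulting 'module_name@wheel_path' key is what Example #1 splits apart on '@', and what Example #11 decodes with udf_utils.decode_user_module_key. A plausible decoder sketch (a hypothetical body; only its shape is implied by the callers):

def decode_user_module_key(module_key):
    """Splits 'module@wheel_path' into (module_path, extra_pip_packages)."""
    if module_key and '@' in module_key:
        module_name, wheel_path = module_key.split('@', maxsplit=1)
        return module_name, [wheel_path]
    return module_key, []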
Example #11
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Runs a batch job to evaluate the eval_model against the given input.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - model: exported model.
        - examples: examples for eval the model.
      output_dict: Output dict from output key to a list of Artifacts.
        - evaluation: model evaluation results.
      exec_properties: A dict of execution properties.
        - eval_config: JSON string of tfma.EvalConfig.
        - feature_slicing_spec: JSON string of evaluator_pb2.FeatureSlicingSpec
          instance, providing the way to slice the data. Deprecated, use
          eval_config.slicing_specs instead.
        - example_splits: JSON-serialized list of names of splits on which the
          metrics are computed. Default behavior (when example_splits is set to
          None) is using the 'eval' split.

    Returns:
      None
    """
        if EXAMPLES_KEY not in input_dict:
            raise ValueError('EXAMPLES_KEY is missing from input dict.')
        if EVALUATION_KEY not in output_dict:
            raise ValueError('EVALUATION_KEY is missing from output dict.')
        if MODEL_KEY in input_dict and len(input_dict[MODEL_KEY]) > 1:
            raise ValueError(
                'There can be only one candidate model, there are %d.' %
                (len(input_dict[MODEL_KEY])))
        if BASELINE_MODEL_KEY in input_dict and len(
                input_dict[BASELINE_MODEL_KEY]) > 1:
            raise ValueError(
                'There can be only one baseline model, there are %d.' %
                (len(input_dict[BASELINE_MODEL_KEY])))

        self._log_startup(input_dict, output_dict, exec_properties)

        # Add fairness indicator metric callback if necessary.
        fairness_indicator_thresholds = exec_properties.get(
            'fairness_indicator_thresholds', None)
        add_metrics_callbacks = None
        if fairness_indicator_thresholds:
            add_metrics_callbacks = [
                tfma.post_export_metrics.fairness_indicators(  # pytype: disable=module-attr
                    thresholds=fairness_indicator_thresholds),
            ]

        output_uri = artifact_utils.get_single_uri(
            output_dict[constants.EVALUATION_KEY])

        # Make sure user packages get propagated to the remote Beam worker.
        unused_module_path, extra_pip_packages = udf_utils.decode_user_module_key(
            exec_properties.get(MODULE_PATH_KEY, None))
        for pip_package_path in extra_pip_packages:
            local_pip_package_path = io_utils.ensure_local(pip_package_path)
            self._beam_pipeline_args.append('--extra_package=%s' %
                                            local_pip_package_path)

        eval_shared_model_fn = udf_utils.try_get_fn(
            exec_properties=exec_properties,
            fn_name='custom_eval_shared_model'
        ) or tfma.default_eval_shared_model

        run_validation = False
        models = []
        if EVAL_CONFIG_KEY in exec_properties and exec_properties[
                EVAL_CONFIG_KEY]:
            slice_spec = None
            has_baseline = bool(input_dict.get(BASELINE_MODEL_KEY))
            eval_config = tfma.EvalConfig()
            proto_utils.json_to_proto(exec_properties[EVAL_CONFIG_KEY],
                                      eval_config)
            eval_config = tfma.update_eval_config_with_defaults(
                eval_config, has_baseline=has_baseline)
            tfma.verify_eval_config(eval_config)
            # Do not validate model when there is no thresholds configured. This is to
            # avoid accidentally blessing models when users forget to set thresholds.
            run_validation = bool(
                tfma.metrics.metric_thresholds_from_metrics_specs(
                    eval_config.metrics_specs))
            if len(eval_config.model_specs) > 2:
                raise ValueError(
                    """Cannot support more than two models. There are %d models in this
             eval_config.""" % (len(eval_config.model_specs)))
            # Extract model artifacts.
            for model_spec in eval_config.model_specs:
                if MODEL_KEY not in input_dict:
                    if not model_spec.prediction_key:
                        raise ValueError(
                            'model_spec.prediction_key required if model not provided'
                        )
                    continue
                if model_spec.is_baseline:
                    model_artifact = artifact_utils.get_single_instance(
                        input_dict[BASELINE_MODEL_KEY])
                else:
                    model_artifact = artifact_utils.get_single_instance(
                        input_dict[MODEL_KEY])
                if tfma.get_model_type(model_spec) == tfma.TF_ESTIMATOR:
                    model_path = path_utils.eval_model_path(
                        model_artifact.uri,
                        path_utils.is_old_model_artifact(model_artifact))
                else:
                    model_path = path_utils.serving_model_path(
                        model_artifact.uri,
                        path_utils.is_old_model_artifact(model_artifact))
                logging.info('Using %s as %s model.', model_path,
                             model_spec.name)
                models.append(
                    eval_shared_model_fn(
                        eval_saved_model_path=model_path,
                        model_name=model_spec.name,
                        eval_config=eval_config,
                        add_metrics_callbacks=add_metrics_callbacks))
        else:
            eval_config = None
            assert (FEATURE_SLICING_SPEC_KEY in exec_properties
                    and exec_properties[FEATURE_SLICING_SPEC_KEY]
                    ), 'both eval_config and feature_slicing_spec are unset.'
            feature_slicing_spec = evaluator_pb2.FeatureSlicingSpec()
            proto_utils.json_to_proto(
                exec_properties[FEATURE_SLICING_SPEC_KEY],
                feature_slicing_spec)
            slice_spec = self._get_slice_spec_from_feature_slicing_spec(
                feature_slicing_spec)
            model_artifact = artifact_utils.get_single_instance(
                input_dict[MODEL_KEY])
            model_path = path_utils.eval_model_path(
                model_artifact.uri,
                path_utils.is_old_model_artifact(model_artifact))
            logging.info('Using %s for model eval.', model_path)
            models.append(
                eval_shared_model_fn(
                    eval_saved_model_path=model_path,
                    model_name='',
                    eval_config=None,
                    add_metrics_callbacks=add_metrics_callbacks))

        eval_shared_model = models[0] if len(models) == 1 else models
        schema = None
        if SCHEMA_KEY in input_dict:
            schema = io_utils.SchemaReader().read(
                io_utils.get_only_uri_in_dir(
                    artifact_utils.get_single_uri(input_dict[SCHEMA_KEY])))

        # Load and deserialize example splits from execution properties.
        example_splits = json_utils.loads(
            exec_properties.get(EXAMPLE_SPLITS_KEY, 'null'))
        if not example_splits:
            example_splits = ['eval']
            logging.info(
                "The 'example_splits' parameter is not set, using 'eval' "
                'split.')

        logging.info('Evaluating model.')
        # TempPipInstallContext is needed here so that subprocesses (which
        # may be created by the Beam multi-process DirectRunner) can find the
        # needed dependencies.
        # TODO(b/187122662): Move this to the ExecutorOperator or Launcher.
        with udf_utils.TempPipInstallContext(extra_pip_packages):
            with self._make_beam_pipeline() as pipeline:
                examples_list = []
                tensor_adapter_config = None
                # pylint: disable=expression-not-assigned
                if tfma.is_batched_input(eval_shared_model, eval_config):
                    tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
                        examples=[
                            artifact_utils.get_single_instance(
                                input_dict[EXAMPLES_KEY])
                        ],
                        telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
                        schema=schema,
                        raw_record_column_name=tfma_constants.
                        ARROW_INPUT_COLUMN)
                    # TODO(b/161935932): refactor after TFXIO supports multiple patterns.
                    for split in example_splits:
                        file_pattern = io_utils.all_files_pattern(
                            artifact_utils.get_split_uri(
                                input_dict[EXAMPLES_KEY], split))
                        tfxio = tfxio_factory(file_pattern)
                        data = (pipeline
                                | 'ReadFromTFRecordToArrow[%s]' % split >>
                                tfxio.BeamSource())
                        examples_list.append(data)
                    if schema is not None:
                        # Use last tfxio as TensorRepresentations and ArrowSchema are fixed.
                        tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
                            arrow_schema=tfxio.ArrowSchema(),
                            tensor_representations=tfxio.TensorRepresentations(
                            ))
                else:
                    for split in example_splits:
                        file_pattern = io_utils.all_files_pattern(
                            artifact_utils.get_split_uri(
                                input_dict[EXAMPLES_KEY], split))
                        data = (pipeline
                                | 'ReadFromTFRecord[%s]' % split >> beam.io.
                                ReadFromTFRecord(file_pattern=file_pattern))
                        examples_list.append(data)

                custom_extractors = udf_utils.try_get_fn(
                    exec_properties=exec_properties,
                    fn_name='custom_extractors')
                extractors = None
                if custom_extractors:
                    extractors = custom_extractors(
                        eval_shared_model=eval_shared_model,
                        eval_config=eval_config,
                        tensor_adapter_config=tensor_adapter_config)

                (examples_list | 'FlattenExamples' >> beam.Flatten()
                 | 'ExtractEvaluateAndWriteResults' >>
                 (tfma.ExtractEvaluateAndWriteResults(
                     eval_shared_model=models[0]
                     if len(models) == 1 else models,
                     eval_config=eval_config,
                     extractors=extractors,
                     output_path=output_uri,
                     slice_spec=slice_spec,
                     tensor_adapter_config=tensor_adapter_config)))
        logging.info('Evaluation complete. Results written to %s.', output_uri)

        if not run_validation:
            # TODO(jinhuang): delete the BLESSING_KEY from output_dict when supported.
            logging.info('No threshold configured, will not validate model.')
            return
        # Set up blessing artifact
        blessing = artifact_utils.get_single_instance(
            output_dict[BLESSING_KEY])
        blessing.set_string_custom_property(
            constants.ARTIFACT_PROPERTY_CURRENT_MODEL_URI_KEY,
            artifact_utils.get_single_uri(input_dict[MODEL_KEY]))
        blessing.set_int_custom_property(
            constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY,
            input_dict[MODEL_KEY][0].id)
        if input_dict.get(BASELINE_MODEL_KEY):
            baseline_model = input_dict[BASELINE_MODEL_KEY][0]
            blessing.set_string_custom_property(
                constants.ARTIFACT_PROPERTY_BASELINE_MODEL_URI_KEY,
                baseline_model.uri)
            blessing.set_int_custom_property(
                constants.ARTIFACT_PROPERTY_BASELINE_MODEL_ID_KEY,
                baseline_model.id)
        if 'current_component_id' in exec_properties:
            blessing.set_string_custom_property(
                'component_id', exec_properties['current_component_id'])
        # Check validation result and write BLESSED file accordingly.
        logging.info('Checking validation results.')
        validation_result = tfma.load_validation_result(output_uri)
        if validation_result.validation_ok:
            io_utils.write_string_file(
                os.path.join(blessing.uri, constants.BLESSED_FILE_NAME), '')
            blessing.set_int_custom_property(
                constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                constants.BLESSED_VALUE)
        else:
            io_utils.write_string_file(
                os.path.join(blessing.uri, constants.NOT_BLESSED_FILE_NAME),
                '')
            blessing.set_int_custom_property(
                constants.ARTIFACT_PROPERTY_BLESSED_KEY,
                constants.NOT_BLESSED_VALUE)
        logging.info('Blessing result %s written to %s.',
                     validation_result.validation_ok, blessing.uri)
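
Per the Do() docstring, the evaluator is driven by JSON-encoded execution properties. A minimal sketch of how such properties might be assembled (key strings are taken from the docstring; the config values are placeholders, not a definitive setup):

import json

from google.protobuf import json_format
import tensorflow_model_analysis as tfma

eval_config = tfma.EvalConfig()  # populate model_specs / metrics_specs as needed
exec_properties = {
    'eval_config': json_format.MessageToJson(eval_config),
    'example_splits': json.dumps(['eval']),  # omitted -> defaults to 'eval'
}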