Example #1
0
    def export(self, estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
        """Exports the given `Estimator` to a specific format.

        Performs the export as defined by the base_exporter and invokes all of
        the specified rewriters. The rewritten model replaces the base export
        in place, so the returned path is the base export's path.

        Args:
          estimator: the `Estimator` to export.
          export_path: A string containing a directory where to write the
            export.
          checkpoint_path: The checkpoint path to export.
          eval_result: The output of `Estimator.evaluate` on this checkpoint.
          is_the_final_export: This boolean is True when this is an export in
            the end of training.  It is False for the intermediate exports
            during the training. When passing `Exporter` to
            `tf.estimator.train_and_evaluate` `is_the_final_export` is always
            False if `TrainSpec.max_steps` is `None`.

        Returns:
          The string path to the base exported directory or `None` if export is
          skipped.

        Raises:
          RuntimeError: Unable to create a temporary rewrite directory.
          Any exception raised by the rewriter is propagated after the
          temporary rewrite directory has been removed.
        """
        base_path = self._base_exporter.export(estimator, export_path,
                                               checkpoint_path, eval_result,
                                               is_the_final_export)
        if not base_path:
            # The base exporter skipped this export; nothing to rewrite.
            return None

        # The rewriter writes into a temporary directory under `export_path`,
        # which then replaces the base export at `base_path`.
        tmp_rewrite_folder = 'tmp-rewrite-' + str(int(time.time()))
        tmp_rewrite_path = os.path.join(export_path, tmp_rewrite_folder)
        if fileio.exists(tmp_rewrite_path):
            raise RuntimeError(
                'Unable to create a unique temporary rewrite path.')
        fileio.makedirs(tmp_rewrite_path)

        try:
            _invoke_rewriter(base_path, tmp_rewrite_path, self._rewriter_inst,
                             rewriter.ModelType.SAVED_MODEL,
                             rewriter.ModelType.ANY_MODEL)
        except Exception:
            # Don't leak the temporary directory into `export_path` when the
            # rewrite fails. `base_path` has not been removed at this point.
            # Cleanup is best-effort so the rewriter's error is what callers
            # see.
            try:
                fileio.rmtree(tmp_rewrite_path)
            except Exception:  # pylint: disable=broad-except
                pass
            raise

        # Swap the rewritten model into the base export's location.
        fileio.rmtree(base_path)
        fileio.rename(tmp_rewrite_path, base_path)
        return base_path
Example #2
0
    def _check_pipeline_existence(self,
                                  pipeline_name: Text,
                                  required: bool = True) -> None:
        """Validates whether the pipeline folder exists, exiting on mismatch.

        With `required=True` the pipeline folder must exist under the handler
        home directory. If it is missing there but present under the
        deprecated (pre-0.25) handler home, it is moved to the new location
        and a warning is echoed to stderr; if it is missing in both places,
        the process exits. With `required=False` the folder must not exist
        yet, and the process exits if it does.

        Args:
          pipeline_name: Name of the pipeline.
          required: True if the pipeline must already exist, False if it must
            not exist yet.
        """
        pipeline_dir = os.path.join(self._handler_home_dir, pipeline_name)
        pipeline_found = fileio.exists(pipeline_dir)

        # Creating a new pipeline: an existing folder is an error.
        if not required:
            if pipeline_found:
                sys.exit('Pipeline "{}" already exists.'.format(pipeline_name))
            return

        # Pipeline is required and already present: nothing to do.
        if pipeline_found:
            return

        # Pipelines created before TFX 0.25 live under the deprecated handler
        # home; migrate them to the new location automatically.
        legacy_dir = os.path.join(self._get_deprecated_handler_home(),
                                  pipeline_name)
        if fileio.exists(legacy_dir):
            fileio.makedirs(os.path.dirname(pipeline_dir))
            fileio.rename(legacy_dir, pipeline_dir)
            home_env_var = (
                self.flags_dict[labels.ENGINE_FLAG].upper() + '_HOME')
            warning_template = (
                '[WARNING] Pipeline "{pipeline_name}" was found in "{old_path}", '
                'but the location that TFX stores pipeline information was moved '
                'since TFX 0.25.0.\n'
                '[WARNING] Your files in "{old_path}" was automatically moved to '
                'the new location, "{new_path}".\n'
                '[WARNING] If you want to keep the files at the old location, set '
                '`{handler_home}` environment variable to "{old_handler_home}".')
            click.echo(
                warning_template.format(
                    pipeline_name=pipeline_name,
                    old_path=legacy_dir,
                    new_path=pipeline_dir,
                    handler_home=home_env_var,
                    old_handler_home=self._get_deprecated_handler_home()),
                err=True)
        else:
            sys.exit('Pipeline "{}" does not exist.'.format(pipeline_name))
Example #3
0
def package_user_module_file(instance_name: Text, module_path: Text,
                             pipeline_root: Text) -> Tuple[Text, Text]:
    """Package the given user module file into a Python Wheel package.

    Every `.py` file in the module's directory is bundled into an ephemeral
    wheel, which is then staged under `<pipeline_root>/_wheels`. Scratch
    directories used for the build are removed before returning, whether or
    not the build succeeds.

    Args:
      instance_name: Name of the component instance, for creating a unique
        wheel package name.
      module_path: Path to the module file to be packaged.
      pipeline_root: Root directory of the pipeline; the built wheel is copied
        into its `_wheels` subdirectory.

    Returns:
      A `(dist_file_path, user_module_path)` tuple:
        dist_file_path: Path to the generated wheel file.
        user_module_path: Path for referencing the user module when stored
          as the _MODULE_PATH_KEY execution property. Format should be treated
          as opaque by the user.

    Raises:
      ValueError: When the module path is not a ".py" file, does not exist, or
        its module name contains '@' (reserved as the separator in
        `user_module_path`).
      RuntimeError: When wheel building fails.
    """
    module_path = os.path.abspath(io_utils.ensure_local(module_path))
    if not module_path.endswith('.py'):
        raise ValueError('Module path %r is not a ".py" file.' % module_path)
    if not os.path.exists(module_path):
        raise ValueError('Module path %r does not exist.' % module_path)

    user_module_dir, module_file_name = os.path.split(module_path)
    user_module_name = re.sub(r'\.py$', '', module_file_name)
    # '@' separates the module name from the wheel path in the returned
    # `user_module_path`, so it must not occur in the module name. Validate it
    # with a real exception (`assert` is stripped under `python -O`) and
    # before the expensive wheel build.
    if '@' in user_module_name:
        raise ValueError('Unexpected invalid module name: %s' %
                         user_module_name)

    # Discover all Python source files in this directory for inclusion. Skip
    # non-files (e.g. a directory named "x.py"): they cannot be copied with
    # `shutil.copyfile` below.
    source_files = [
        file_name for file_name in os.listdir(user_module_dir)
        if file_name.endswith('.py') and
        os.path.isfile(os.path.join(user_module_dir, file_name))
    ]
    module_names = [
        re.sub(r'\.py$', '', file_name) for file_name in source_files
        if file_name not in (_EPHEMERAL_SETUP_PY_FILE_NAME, '__init__.py')
    ]

    # mkdtemp() leaves deletion to the caller; remove these scratch
    # directories on every exit path so repeated calls do not accumulate them.
    build_dir = tempfile.mkdtemp()
    temp_dir = tempfile.mkdtemp()
    dist_dir = tempfile.mkdtemp()
    try:
        # Set up build directory.
        for source_file in source_files:
            shutil.copyfile(os.path.join(user_module_dir, source_file),
                            os.path.join(build_dir, source_file))

        # Generate an ephemeral wheel for this module.
        logging.info(
            'Generating ephemeral wheel package for %r '
            '(including modules: %s).', module_path, module_names)

        version_hash = _get_version_hash(user_module_dir, source_files)
        logging.info('User module package has hash fingerprint version %s.',
                     version_hash)

        setup_py_path = os.path.join(build_dir, _EPHEMERAL_SETUP_PY_FILE_NAME)
        with open(setup_py_path, 'w') as f:
            f.write(
                _get_ephemeral_setup_py_contents(
                    'tfx-user-code-%s' % instance_name,
                    '0.0+%s' % version_hash, module_names))

        bdist_command = [
            sys.executable, setup_py_path, 'bdist_wheel', '--bdist-dir',
            temp_dir, '--dist-dir', dist_dir
        ]
        logging.info('Executing: %s', bdist_command)
        try:
            subprocess.check_call(bdist_command, cwd=build_dir)
        except subprocess.CalledProcessError as e:
            raise RuntimeError('Failed to build wheel.') from e

        dist_files = os.listdir(dist_dir)
        if len(dist_files) != 1:
            raise RuntimeError(
                'Unexpectedly found %d output files in wheel output directory %s.'
                % (len(dist_files), dist_dir))
        wheel_file_name = dist_files[0]
        build_dist_file_path = os.path.join(dist_dir, wheel_file_name)

        # Stage the wheel under a temporary name, then rename it into place so
        # readers do not observe a partially copied file (assuming rename is
        # atomic on the target filesystem).
        dist_wheel_directory = os.path.join(pipeline_root, '_wheels')
        dist_file_path = os.path.join(dist_wheel_directory, wheel_file_name)
        temp_dist_file_path = dist_file_path + '.tmp'
        fileio.makedirs(dist_wheel_directory)
        fileio.copy(build_dist_file_path, temp_dist_file_path, overwrite=True)
        fileio.rename(temp_dist_file_path, dist_file_path, overwrite=True)
    finally:
        # Best-effort cleanup; a failure here must not mask the real result.
        for scratch_dir in (build_dir, temp_dir, dist_dir):
            shutil.rmtree(scratch_dir, ignore_errors=True)

    logging.info(
        ('Successfully built user code wheel distribution at %r; target user '
         'module is %r.'), dist_file_path, user_module_name)

    # Encode the user module key as a specification of a user module name within
    # a packaged wheel path.
    user_module_path = '%s@%s' % (user_module_name, dist_file_path)
    logging.info('Full user module path is %r', user_module_path)

    return dist_file_path, user_module_path