def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn,
                                use_core_columns=False):
    """Makes a custom exporter that also emits the GTFlow tree ensemble.

    The returned strategy exports a regular SavedModel (default-valued attrs
    stripped), reloads it in a fresh graph, reads the serialized tree ensemble
    out of it, optionally passes the parsed proto to `convert_fn`, and writes a
    `feature_importances` file into the `assets.extra` folder of the export.

    Args:
      name: A string, for the name of the export strategy.
      convert_fn: A function that converts the tree proto to desired format and
        saves it to the desired location. Can be None to skip conversion.
      feature_columns: A list of feature columns.
      export_input_fn: A function that takes no arguments and returns an
        `InputFnOps`. It is invoked once, when this strategy is created, so the
        feature layout can be extracted from its features.
      use_core_columns: Forwarded to `gbdt_batch.extract_features` (presumably
        selects core feature-column handling -- confirm against that helper).

    Returns:
      An `ExportStrategy`.
    """
    saved_model_strategy = saved_model_export_utils.make_export_strategy(
        serving_input_fn=export_input_fn, strip_default_attrs=True)
    serving_input_ops = export_input_fn()
    # extract_features yields an 8-tuple; only the sorted names and the dense /
    # sparse-float / sparse-int groups (positions 0, 1, 2 and 5) are used here.
    (feature_names, dense_float_features, sparse_float_feature_indices, _, _,
     sparse_int_feature_indices, _, _) = gbdt_batch.extract_features(
         serving_input_ops.features, feature_columns, use_core_columns)

    def export_fn(estimator,
                  export_dir,
                  checkpoint_path=None,
                  eval_result=None):
        """Exports a SavedModel, then derives tree artifacts from it.

        Returns:
          The SavedModel directory produced by the underlying strategy.
        """
        export_path = saved_model_strategy.export(estimator, export_dir,
                                                  checkpoint_path, eval_result)
        num_dense = len(dense_float_features)
        num_sparse_float = len(sparse_float_feature_indices)
        num_sparse_int = len(sparse_int_feature_indices)

        with ops.Graph().as_default() as graph:
            with tf_session.Session(graph=graph) as session:
                saved_model_loader.load(session, [tag_constants.SERVING],
                                        export_path)
                # Note: This is GTFlow internal API and might change.
                serialize_op = graph.get_operation_by_name(
                    "ensemble_model/TreeEnsembleSerialize")
                _, serialized_ensemble = session.run(serialize_op.outputs)
                ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
                ensemble_proto.ParseFromString(serialized_ensemble)

                # Converted output goes into the same folder as the SavedModel.
                if convert_fn:
                    convert_fn(ensemble_proto, feature_names, num_dense,
                               num_sparse_float, num_sparse_int, export_path,
                               eval_result)

                importance_by_feature = _get_feature_importances(
                    ensemble_proto, feature_names, num_dense, num_sparse_float,
                    num_sparse_int)
                # Highest importance first.
                ranked_features = sorted(importance_by_feature.items(),
                                         key=lambda pair: -pair[1])

                extra_assets_dir = os.path.join(
                    compat.as_bytes(export_path),
                    compat.as_bytes("assets.extra"))
                gfile.MakeDirs(extra_assets_dir)
                importances_file = os.path.join(
                    compat.as_bytes(extra_assets_dir),
                    compat.as_bytes("feature_importances"))
                with gfile.GFile(importances_file, "w") as out_file:
                    # One "<feature>, <importance>" line per feature.
                    out_file.write("\n".join(
                        ["%s, %f" % (feature, score)
                         for feature, score in ranked_features]))
        return export_path

    return export_strategy.ExportStrategy(name,
                                          export_fn,
                                          strip_default_attrs=True)
def make_export_strategy(serving_input_fn,
                         default_output_alternative_key=None,
                         assets_extra=None,
                         as_text=False,
                         exports_to_keep=5,
                         strip_default_attrs=None):
    """Creates a SavedModel-exporting ExportStrategy (named 'Servo').

    Args:
      serving_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      default_output_alternative_key: the name of the head to serve when an
        incoming serving request does not explicitly request a specific head.
        Must be `None` if the estimator inherits from @{tf.estimator.Estimator}
        or for single-headed models.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel. Each key gives the destination path
        (including the filename) relative to the assets.extra directory; the
        value gives the full path of the source file to copy. For example,
        copying a single file without renaming it is specified as
        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
      as_text: whether to write the SavedModel proto in text format.
      exports_to_keep: Number of exports to keep; older exports under the
        export base directory are garbage-collected after each export.
        Defaults to 5. Set to None to disable garbage collection.
      strip_default_attrs: Boolean. If True, default attrs in the `GraphDef`
        will be stripped on write. This is recommended for better forward
        compatibility of the resulting `SavedModel`. Stored on the returned
        ExportStrategy.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """
    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  strip_default_attrs=False):
        """Exports the given Estimator as a SavedModel.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If None (the
            default), the most recent checkpoint found within the model
            directory is chosen.
          strip_default_attrs: Boolean. If `True`, default-valued attributes
            will be removed from the NodeDefs. This is the per-call value; it
            shadows the factory argument of the same name.

        Returns:
          The string path to the exported directory.

        Raises:
          ValueError: If `estimator` is a @{tf.estimator.Estimator} instance
            and `default_output_alternative_key` was specified.
        """
        is_core_estimator = isinstance(estimator, core_estimator.Estimator)
        if is_core_estimator and default_output_alternative_key is not None:
            raise ValueError(
                'default_output_alternative_key is not supported in core '
                'Estimator. Given: {}'.format(default_output_alternative_key))

        export_kwargs = {
            'assets_extra': assets_extra,
            'as_text': as_text,
            'checkpoint_path': checkpoint_path,
            'strip_default_attrs': strip_default_attrs,
        }
        if not is_core_estimator:
            # Core estimators reject this key above, so it is only forwarded
            # for the non-core (contrib) export_savedmodel signature.
            export_kwargs['default_output_alternative_key'] = (
                default_output_alternative_key)
        export_result = estimator.export_savedmodel(
            export_dir_base, serving_input_fn, **export_kwargs)

        garbage_collect_exports(export_dir_base, exports_to_keep)
        return export_result

    return export_strategy.ExportStrategy('Servo', export_fn,
                                          strip_default_attrs)
# Example No. 3
0
def extend_export_strategy(base_export_strategy,
                           post_export_fn,
                           post_export_name=None):
    """Extend ExportStrategy, calling post_export_fn after export.

    Args:
      base_export_strategy: An ExportStrategy that can be passed to the
        Experiment constructor.
      post_export_fn: A user-specified function to call after exporting the
        SavedModel. Takes two arguments - the path to the SavedModel exported
        by base_export_strategy and the directory where to export the
        SavedModel modified by the post_export_fn. Returns the path to the
        exported SavedModel, which must be a strict sub-directory of the
        directory it was given.
      post_export_name: The directory name under the export base directory
        where SavedModels generated by the post_export_fn will be written. If
        None, the directory name of base_export_strategy is used.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """
    def export_fn(estimator, export_dir_base, checkpoint_path=None):
        """Exports the given Estimator as a SavedModel and invokes post_export_fn.

        The base export and the post-export result are both produced inside
        temporary directories under export_dir_base. The post-export result is
        then renamed to its final location. Both temporary directories are
        removed whether or not the export succeeds.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graphs and checkpoint.
          checkpoint_path: The checkpoint path to export. If None (the
            default), the most recent checkpoint found within the model
            directory is chosen.

        Returns:
          The path under export_dir_base that the SavedModel returned by
          post_export_fn was moved to.

        Raises:
          ValueError: If `estimator` is a `tf.estimator.Estimator` instance
            and `default_output_alternative_key` was specified (raised by the
            base strategy), or if post_export_fn does not return a strict
            sub-directory of the temporary post-export directory.
          RuntimeError: If unable to create temporary or final export directory.
        """
        tmp_base_export_folder = 'temp-base-export-' + str(int(time.time()))
        tmp_base_export_dir = os.path.join(export_dir_base,
                                           tmp_base_export_folder)
        if gfile.Exists(tmp_base_export_dir):
            raise RuntimeError('Failed to obtain base export directory')
        gfile.MakeDirs(tmp_base_export_dir)
        # From here on, always clean up the temporary directories so a failed
        # export does not leave 'temp-*' folders behind in export_dir_base.
        try:
            tmp_base_export = base_export_strategy.export(estimator,
                                                          tmp_base_export_dir,
                                                          checkpoint_path)

            tmp_post_export_folder = (
                'temp-post-export-' + str(int(time.time())))
            tmp_post_export_dir = os.path.join(export_dir_base,
                                               tmp_post_export_folder)
            if gfile.Exists(tmp_post_export_dir):
                raise RuntimeError('Failed to obtain temp export directory')

            gfile.MakeDirs(tmp_post_export_dir)
            try:
                tmp_post_export = post_export_fn(tmp_base_export,
                                                 tmp_post_export_dir)

                # A plain string-prefix test would accept siblings such as
                # 'temp-post-export-1234', paths containing '..', and the
                # directory itself. relpath exposes all of these as '.', '..'
                # or a '../' prefix.
                post_export_relpath = None
                if tmp_post_export:
                    post_export_relpath = os.path.relpath(tmp_post_export,
                                                          tmp_post_export_dir)
                if (not post_export_relpath
                        or post_export_relpath == os.curdir
                        or post_export_relpath == os.pardir
                        or post_export_relpath.startswith(os.pardir + os.sep)):
                    raise ValueError(
                        'post_export_fn must return a sub-directory of {}'
                        .format(tmp_post_export_dir))

                post_export = os.path.join(export_dir_base,
                                           post_export_relpath)
                if gfile.Exists(post_export):
                    raise RuntimeError(
                        'Failed to obtain final export directory')
                gfile.Rename(tmp_post_export, post_export)
            finally:
                gfile.DeleteRecursively(tmp_post_export_dir)
        finally:
            gfile.DeleteRecursively(tmp_base_export_dir)
        return post_export

    name = post_export_name if post_export_name else base_export_strategy.name
    return export_strategy.ExportStrategy(name, export_fn)
# Example No. 4
0
def make_export_strategy(train_config, args, keep_target, assets_extra=None):
    """Makes an ExportStrategy that exports a prediction SavedModel.

    Each export builds a fresh serving graph from `serving_from_csv_input` and
    `make_output_tensors`. It restores the chosen checkpoint and writes a
    timestamped SavedModel under the export base directory, keeping only the
    last 3. It then copies the newest export into a fixed folder under
    `args.job_dir`.

    Args:
      train_config: training configuration, forwarded to
        `serving_from_csv_input` and `make_output_tensors`.
      args: command line args; forwarded to the graph builders, and
        `args.job_dir` is the root of the final model copy.
      keep_target: forwarded to the graph builders (presumably controls
        whether the target column is part of the serving graph -- confirm).
        Also selects `<job_dir>/evaluation_model` vs `<job_dir>/model` as the
        final copy location, and the strategy name.
      assets_extra: optional dict mapping a destination path (relative to the
        export's `assets.extra` directory) to a source file path to copy.

    Returns:
      An ExportStrategy named 'intermediate_evaluation_models' when keep_target
      is true, otherwise 'intermediate_prediction_models'.
    """
    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  eval_result=None):
        """Builds the serving graph, restores a checkpoint and exports it.

        Returns:
          The timestamped export directory created under export_dir_base.

        Raises:
          NotFittedError: if no checkpoint_path is given and none can be found
            in the estimator's model directory.
        """
        with ops.Graph().as_default() as g:
            contrib_variables.create_global_step(g)

            input_ops = serving_from_csv_input(train_config, args, keep_target)
            model_fn_ops = estimator._call_model_fn(
                input_ops.features, None, model_fn_lib.ModeKeys.INFER)
            output_fetch_tensors = make_output_tensors(
                train_config=train_config,
                args=args,
                input_ops=input_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            signature_def_map = {
                'serving_default':
                signature_def_utils.predict_signature_def(
                    input_ops.default_inputs, output_fetch_tensors)
            }

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
            if not checkpoint_path:
                raise NotFittedError("Couldn't find trained model at %s." %
                                     estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                saver_for_restore = saver.Saver(variables.global_variables(),
                                                sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # Not run here: it is handed to the builder as legacy_init_op.
                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    gfile.MakeDirs(dest_path)
                    gfile.Copy(source, dest_absolute)

        # only keep the last 3 models
        saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                         exports_to_keep=3)

        # save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        if keep_target:
            final_dir = os.path.join(args.job_dir, 'evaluation_model')
        else:
            final_dir = os.path.join(args.job_dir, 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        _recursive_copy(export_dir, final_dir)

        return export_dir

    if keep_target:
        intermediate_dir = 'intermediate_evaluation_models'
    else:
        intermediate_dir = 'intermediate_prediction_models'

    return export_strategy.ExportStrategy(intermediate_dir, export_fn)
# Example No. 5
0
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    model_dir=None,
                                    event_file_pattern=None,
                                    compare_fn=None,
                                    default_output_alternative_key=None,
                                    strip_default_attrs=None):
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Only checkpoints that a `BestModelSelector` reports as best are exported;
    every other call is a no-op that returns ''.

    Args:
      serving_input_fn: a function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: an integer indicating how many historical best models
        need to be preserved; forwarded to `make_export_strategy`.
        NOTE(review): each best model is exported into its own
        `<export_dir_base>/<checkpoint basename>` directory, so any
        garbage collection done by the wrapped strategy only sees that one
        directory -- confirm this gives the intended retention.
      model_dir: Directory where model parameters, graph etc. are saved. Used
        to load eval metrics when the export strategy is created, so the best
        metrics survive preemption and only the best model is exported. If
        None, the export strategy is not preemption-safe. To be
        preemption-safe, both model_dir and event_file_pattern are needed.
      event_file_pattern: event file name pattern relative to model_dir, e.g.
        "eval_continuous/*.tfevents.*". If None, the export strategy is not
        preemption-safe. To be preemption-safe, both model_dir and
        event_file_pattern are needed.
      compare_fn: a function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: the key for default serving signature
        for multi-headed inference graphs.
      strip_default_attrs: Boolean. If True, default attrs in the `GraphDef`
        will be stripped on write. This is recommended for better forward
        compatibility of the resulting `SavedModel`.

    Returns:
      An ExportStrategy named 'best_model' that can be passed to the
      Experiment constructor.
    """
    servo_strategy = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key,
        strip_default_attrs=strip_default_attrs)

    # Historical eval results can only be located when both pieces are given.
    if model_dir and event_file_pattern:
        event_files_glob = os.path.join(model_dir, event_file_pattern)
    else:
        event_files_glob = None
    selector = BestModelSelector(event_files_glob, compare_fn)

    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel if it is the best so far.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: placeholder arg matching the call signature of
            ExportStrategy.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        if not checkpoint_path:
            # TODO(b/67425018): switch to
            #    checkpoint_path = estimator.latest_checkpoint()
            #  as soon as contrib is cleaned up and we can thus be sure that
            #  estimator is a tf.estimator.Estimator and not a
            #  tf.contrib.learn.Estimator
            checkpoint_path = checkpoint_management.latest_checkpoint(
                estimator.model_dir)
        best_checkpoint, best_eval_result = selector.update(
            checkpoint_path, eval_result)

        if not best_checkpoint or best_eval_result is None:
            # The selector handed back nothing to export for this call.
            return ''

        per_checkpoint_dir = os.path.join(export_dir_base,
                                          os.path.basename(best_checkpoint))
        return servo_strategy.export(estimator, per_checkpoint_dir,
                                     best_checkpoint, best_eval_result)

    return export_strategy.ExportStrategy('best_model', export_fn)
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    compare_fn=None,
                                    default_output_alternative_key=None):
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Only checkpoints that a `BestModelSelector` reports as best are exported;
    every other call is a no-op that returns ''.

    Args:
      serving_input_fn: a function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: an integer indicating how many historical best models
        need to be preserved; forwarded to `make_export_strategy`.
      compare_fn: a function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: the key for default serving signature
        for multi-headed inference graphs.

    Returns:
      An ExportStrategy named 'best_model' that can be passed to the
      Experiment constructor.
    """
    servo_strategy = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key)

    selector = BestModelSelector(compare_fn)

    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel if it is the best so far.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: placeholder arg matching the call signature of
            ExportStrategy.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        if not checkpoint_path:
            # TODO(b/67425018): switch to
            #    checkpoint_path = estimator.latest_checkpoint()
            #  as soon as contrib is cleaned up and we can thus be sure that
            #  estimator is a tf.estimator.Estimator and not a
            #  tf.contrib.learn.Estimator
            checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
        best_checkpoint, best_eval_result = selector.update(
            checkpoint_path, eval_result)

        if not best_checkpoint or best_eval_result is None:
            # The selector handed back nothing to export for this call.
            return ''

        per_checkpoint_dir = os.path.join(export_dir_base,
                                          os.path.basename(best_checkpoint))
        return servo_strategy.export(estimator, per_checkpoint_dir,
                                     best_checkpoint, best_eval_result)

    return export_strategy.ExportStrategy('best_model', export_fn)
# Example No. 7
0
def make_export_strategy(args, keep_target, assets_extra, features, schema,
                         stats):
    """Makes an export strategy for a prediction graph that takes json input.

    Serving tensors come from
    `feature_transforms.build_csv_serving_tensors_for_training_step`. Each
    export writes a timestamped SavedModel under the export base directory and
    keeps only the last 3. The newest export is then copied into a fixed
    folder under `args.job_dir`.

    Args:
      args: command line args; `args.analysis` and `args.job_dir` are read.
      keep_target: If true, the target column is returned in the prediction
        graph; the target column must then also exist in the input data.
        Also selects `<job_dir>/evaluation_model` vs `<job_dir>/model`.
      assets_extra: other files to copy to the output folder, as a dict of
        destination path (relative to `assets.extra`) -> source path.
      features: features dict.
      schema: schema list; each entry is a dict with a 'name' key.
      stats: stats dict.

    Returns:
      An ExportStrategy named 'intermediate_evaluation_models' or
      'intermediate_prediction_models'.
    """
    target_name = feature_transforms.get_target_name(features)
    header_names = [column['name'] for column in schema]
    if not keep_target:
        # header_names is not read again below, but this remove() still raises
        # ValueError if target_name is missing from the schema.
        header_names.remove(target_name)

    def _copy_assets_extra(export_dir):
        """Copies each assets_extra source into <export_dir>/assets.extra."""
        assets_extra_root = os.path.join(compat.as_bytes(export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
            dest_absolute = os.path.join(compat.as_bytes(assets_extra_root),
                                         compat.as_bytes(dest_relative))
            file_io.recursive_create_dir(os.path.dirname(dest_absolute))
            file_io.copy(source, dest_absolute)

    def _to_tensor_infos(named_tensors):
        """Maps each tensor of a name->tensor dict to its TensorInfo proto."""
        return {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in six.iteritems(named_tensors)
        }

    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  eval_result=None):
        """Builds the serving graph, restores a checkpoint and exports it.

        Returns:
          The timestamped export directory created under export_dir_base.

        Raises:
          ValueError: if no checkpoint_path is given and none can be found in
            the estimator's model directory.
        """
        with ops.Graph().as_default() as graph:
            contrib_variables.create_global_step(graph)

            input_ops = (
                feature_transforms.build_csv_serving_tensors_for_training_step(
                    args.analysis, features, schema, stats, keep_target))
            model_fn_ops = estimator._call_model_fn(
                input_ops.features, None, model_fn_lib.ModeKeys.INFER)
            output_fetch_tensors = make_prediction_output_tensors(
                args=args,
                features=features,
                input_ops=input_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            # Don't use signature_def_utils.predict_signature_def as that renames
            # tensor names if there is only 1 input/output tensor!
            serving_signature = signature_def_utils.build_signature_def(
                _to_tensor_infos(input_ops.default_inputs),
                _to_tensor_infos(output_fetch_tensors),
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
            signature_def_map = {'serving_default': serving_signature}

            if not checkpoint_path:
                # Fall back to the most recent checkpoint in the model dir.
                checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            # Prefer the model's own saver when its scaffold provides one.
            scaffold = model_fn_ops.scaffold
            if scaffold is not None and scaffold.saver is not None:
                restorer = scaffold.saver
            else:
                restorer = saver.Saver(sharded=True)

            with tf_session.Session('') as session:
                restorer.restore(session, checkpoint_path)
                # Not run here: it is handed to the builder as legacy_init_op.
                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    resources.initialize_resources(
                        resources.shared_resources()), tf.tables_initializer())

                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            if assets_extra:
                _copy_assets_extra(export_dir)

        # Retain only the 3 newest exports.
        saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                         exports_to_keep=3)

        # Mirror the newest export into a fixed folder under the job dir.
        # export_dir_base = A/B/intermediate_models/
        final_dir = os.path.join(
            args.job_dir, 'evaluation_model' if keep_target else 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        recursive_copy(export_dir, final_dir)

        return export_dir

    intermediate_dir = ('intermediate_evaluation_models' if keep_target
                        else 'intermediate_prediction_models')
    return export_strategy.ExportStrategy(intermediate_dir, export_fn)
# Example No. 8
0
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    compare_fn=None,
                                    default_output_alternative_key=None,
                                    strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Only checkpoints that a `BestModelSelector` reports as best are exported;
    every other call is a no-op that returns ''.

    Args:
      serving_input_fn: a function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: an integer indicating how many historical best models
        need to be preserved; forwarded to `make_export_strategy`.
      compare_fn: a function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: the key for default serving signature
        for multi-headed inference graphs.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      An ExportStrategy named 'best_model' that can be passed to the
      Experiment constructor.
    """
    # pylint: enable=line-too-long
    servo_strategy = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key,
        strip_default_attrs=strip_default_attrs)

    selector = BestModelSelector(compare_fn)

    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel if it is the best so far.

        Args:
          estimator: the Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: placeholder arg matching the call signature of
            ExportStrategy.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        if not checkpoint_path:
            # TODO(b/67425018): switch to
            #    checkpoint_path = estimator.latest_checkpoint()
            #  as soon as contrib is cleaned up and we can thus be sure that
            #  estimator is a tf.estimator.Estimator and not a
            #  tf.contrib.learn.Estimator
            checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
        best_checkpoint, best_eval_result = selector.update(
            checkpoint_path, eval_result)

        if not best_checkpoint or best_eval_result is None:
            # The selector handed back nothing to export for this call.
            return ''

        per_checkpoint_dir = os.path.join(export_dir_base,
                                          os.path.basename(best_checkpoint))
        return servo_strategy.export(estimator, per_checkpoint_dir,
                                     best_checkpoint, best_eval_result)

    return export_strategy.ExportStrategy('best_model', export_fn)
# Example No. 9
0
def make_export_strategy(args, keep_target, assets_extra, features, schema):
    """Makes an export strategy for a prediction graph that takes json input.

    Serving tensors are produced by the input fn that
    `input_fn_maker.build_default_transforming_serving_input_fn` builds from
    the raw metadata and transform function stored under
    `args.analysis_output_dir`. Each export writes a timestamped SavedModel
    under the export base directory and keeps only the last 3. The newest
    export is then copied into a fixed folder under `args.job_dir`.

    Args:
      args: command line args; `args.analysis_output_dir` and `args.job_dir`
        are read.
      keep_target: If true, target column is returned in prediction graph.
        Target column must also exist in input data. Also selects
        `<job_dir>/evaluation_model` vs `<job_dir>/model` as the final copy
        location.
      assets_extra: other files to copy to the output folder, as a dict of
        destination path (relative to `assets.extra`) -> source path.
      features: features dict.
      schema: schema list; each entry is a dict with a 'name' key.

    Returns:
      An ExportStrategy named 'intermediate_evaluation_models' or
      'intermediate_prediction_models'.
    """
    target_name = get_target_name(features)
    raw_metadata = metadata_io.read_metadata(
        os.path.join(args.analysis_output_dir, RAW_METADATA_DIR))

    csv_header = [col['name'] for col in schema]
    if not keep_target:
        csv_header.remove(target_name)

    def export_fn(estimator,
                  export_dir_base,
                  checkpoint_path=None,
                  eval_result=None):
        """Builds the serving graph, restores a checkpoint and exports it.

        Returns:
          The timestamped export directory created under export_dir_base.

        Raises:
          ValueError: if no checkpoint_path is given and none can be found in
            the estimator's model directory.
        """
        with ops.Graph().as_default() as g:
            contrib_variables.create_global_step(g)

            input_ops = input_fn_maker.build_default_transforming_serving_input_fn(
                raw_metadata=raw_metadata,
                transform_savedmodel_dir=os.path.join(args.analysis_output_dir,
                                                      TRANSFORM_FN_DIR),
                raw_label_keys=[target_name],
                raw_feature_keys=csv_header,
                convert_scalars_to_vectors=True)()

            model_fn_ops = estimator._call_model_fn(
                input_ops.features, None, model_fn_lib.ModeKeys.INFER)
            output_fetch_tensors = make_prediction_output_tensors(
                args=args,
                features=features,
                input_ops=input_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            signature_def_map = {
                'serving_default':
                signature_def_utils.predict_signature_def(
                    input_ops.default_inputs, output_fetch_tensors)
            }

            if not checkpoint_path:
                # Locate the latest checkpoint
                checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
            if not checkpoint_path:
                raise ValueError("Couldn't find trained model at %s." %
                                 estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                saver_for_restore = saver.Saver(variables.global_variables(),
                                                sharded=True)
                saver_for_restore.restore(session, checkpoint_path)

                # Not run here: it is handed to the builder as legacy_init_op.
                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session, [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Add the extra assets
            if assets_extra:
                assets_extra_path = os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    dest_absolute = os.path.join(
                        compat.as_bytes(assets_extra_path),
                        compat.as_bytes(dest_relative))
                    dest_path = os.path.dirname(dest_absolute)
                    file_io.recursive_create_dir(dest_path)
                    file_io.copy(source, dest_absolute)

        # only keep the last 3 models
        saved_model_export_utils.garbage_collect_exports(export_dir_base,
                                                         exports_to_keep=3)

        # save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        if keep_target:
            final_dir = os.path.join(args.job_dir, 'evaluation_model')
        else:
            final_dir = os.path.join(args.job_dir, 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        recursive_copy(export_dir, final_dir)

        return export_dir

    if keep_target:
        intermediate_dir = 'intermediate_evaluation_models'
    else:
        intermediate_dir = 'intermediate_prediction_models'

    return export_strategy.ExportStrategy(intermediate_dir, export_fn)