示例#1
0
    def testConstructWrapperWithExistingEmptyDumpRoot(self):
        """An already-existing but empty dump_root directory is accepted."""
        os.mkdir(self._tmp_dir)
        self.assertTrue(os.path.isdir(self._tmp_dir))

        # Construction must simply succeed; no exception is expected.
        wrapped_sess = session.Session()
        local_cli_wrapper.LocalCLIDebugWrapperSession(
            wrapped_sess, dump_root=self._tmp_dir, log_usage=False)
 def testConstructWrapperWithExistingFileDumpRoot(self):
   """A dump_root that points to a regular file is rejected with ValueError."""
   # self._tmp_dir may not exist on entry (the empty-dump-root test creates it
   # itself with os.mkdir), so create it before placing a file inside it.
   # Otherwise open() below would fail before the wrapper is ever exercised.
   if not os.path.isdir(self._tmp_dir):
     os.mkdir(self._tmp_dir)
   file_path = os.path.join(self._tmp_dir, "foo")
   with open(file_path, "a"):  # Create an empty file; handle closed on exit.
     pass
   self.assertTrue(os.path.isfile(file_path))
   with self.assertRaisesRegex(ValueError, "dump_root path points to a file"):
     local_cli_wrapper.LocalCLIDebugWrapperSession(
         session.Session(), dump_root=file_path, log_usage=False)
示例#3
0
    def before_run(self, run_context):
        """Prepares RunOptions for the upcoming Session.run() call.

        On first use this lazily builds the LocalCLIDebugWrapperSession around
        `run_context.session` and registers any tensor filters that were queued
        in `self._pending_tensor_filters` before the wrapper existed. It then
        bumps the wrapper's run-call counter and asks `on_run_start()` which
        action to take. Finally it decorates the returned RunOptions for a
        debug or profile run, or un-finalizes the graph for the stepper.

        Args:
          run_context: The hook's run context; its `session` and
            `original_args.fetches` / `original_args.feed_dict` are read.

        Returns:
          A `session_run_hook.SessionRunArgs` with no extra fetches or feeds
          whose `options` is a `config_pb2.RunOptions`, possibly decorated
          according to the action chosen by `on_run_start()`.
        """
        if not self._session_wrapper:
            self._session_wrapper = local_cli_wrapper.LocalCLIDebugWrapperSession(
                run_context.session,
                ui_type=self._ui_type,
                dump_root=self._dump_root,
                thread_name_filter=self._thread_name_filter)

            # Filters added before the wrapper existed were only queued; hand
            # them to the freshly created wrapper now.
            for name, filter_callable in self._pending_tensor_filters.items():
                self._session_wrapper.add_tensor_filter(name, filter_callable)

        wrapper = self._session_wrapper
        wrapper.increment_run_call_count()

        # on_run_start() is written against OnRunStartRequest, so translate the
        # hook's run_context into that shape.
        original_args = run_context.original_args
        start_request = framework.OnRunStartRequest(
            original_args.fetches,
            original_args.feed_dict,
            None,
            None,
            wrapper.run_call_count)
        start_response = wrapper.on_run_start(start_request)
        action = start_response.action
        self._performed_action = action

        run_args = session_run_hook.SessionRunArgs(
            None, feed_dict=None, options=config_pb2.RunOptions())

        if action == framework.OnRunStartAction.DEBUG_RUN:
            # pylint: disable=protected-access
            wrapper._decorate_run_options_for_debug(
                run_args.options,
                start_response.debug_urls,
                debug_ops=start_response.debug_ops,
                node_name_regex_whitelist=(
                    start_response.node_name_regex_whitelist),
                op_type_regex_whitelist=(
                    start_response.op_type_regex_whitelist),
                tensor_dtype_regex_whitelist=(
                    start_response.tensor_dtype_regex_whitelist),
                tolerate_debug_op_creation_failures=(
                    start_response.tolerate_debug_op_creation_failures))
            # pylint: enable=protected-access
        elif action == framework.OnRunStartAction.PROFILE_RUN:
            # pylint: disable=protected-access
            wrapper._decorate_run_options_for_profile(run_args.options)
            # pylint: enable=protected-access
        elif action == framework.OnRunStartAction.INVOKE_STEPPER:
            # The NodeStepper needs to add ops that fetch TensorHandles, which
            # a finalized graph would refuse, so clear the finalized flag.
            # pylint: disable=protected-access
            run_context.session.graph._finalized = False
            # pylint: enable=protected-access

        return run_args
  def testConstructWrapperWithExistingNonEmptyDumpRoot(self):
    """A dump_root directory that already has contents is rejected."""
    # self._tmp_dir may not exist on entry (the empty-dump-root test creates it
    # itself with os.mkdir), so create it before making a child directory.
    # Otherwise os.mkdir(dir_path) would fail before the wrapper is exercised.
    if not os.path.isdir(self._tmp_dir):
      os.mkdir(self._tmp_dir)
    dir_path = os.path.join(self._tmp_dir, "foo")
    os.mkdir(dir_path)
    self.assertTrue(os.path.isdir(dir_path))

    with self.assertRaisesRegex(
        ValueError, "dump_root path points to a non-empty directory"):
      local_cli_wrapper.LocalCLIDebugWrapperSession(
          session.Session(), dump_root=self._tmp_dir, log_usage=False)
示例#5
0
def run_saved_model_with_feed_dict(saved_model_dir,
                                   tag_set,
                                   signature_def_key,
                                   input_tensor_key_feed_dict,
                                   outdir,
                                   overwrite_flag,
                                   tf_debug=False):
    """Runs SavedModel and fetch all outputs.

    Runs the input dictionary through the MetaGraphDef within a SavedModel
    specified by the given tag_set and SignatureDef. Every output is printed.
    If outdir is set, each output is also saved to `<outdir>/<output key>.npy`.

    Args:
      saved_model_dir: Directory containing the SavedModel to execute.
      tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
          string format, separated by ','. For tag-set contains multiple tags,
          all tags must be passed in.
      signature_def_key: A SignatureDef key string.
      input_tensor_key_feed_dict: A dictionary maps input keys to numpy ndarrays.
      outdir: A directory to save the outputs to. If the directory doesn't
          exist, it will be created.
      overwrite_flag: A boolean flag to allow overwrite output file if file with
          the same name exists.
      tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
          intermediate Tensor values and runtime GraphDefs while running the
          SavedModel.

    Raises:
      ValueError: When any of the input tensor keys is not valid.
      RuntimeError: When an output file already exists and overwrite is not
          enabled. All output paths are checked before the model is run, so no
          output file is written when this error is raised.
    """
    meta_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, tag_set)

    # session.run() is fed by tensor name rather than by signature key, so the
    # feed_dict is re-keyed using the SignatureDef's input tensor info.
    inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)

    # Reject unknown input keys before doing any expensive work.
    for input_key_name in input_tensor_key_feed_dict.keys():
        if input_key_name not in inputs_tensor_info.keys():
            raise ValueError(
                '"%s" is not a valid input key. Please choose from %s, or use '
                '--show option.' %
                (input_key_name,
                 '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))

    inputs_feed_dict = {
        inputs_tensor_info[key].name: tensor
        for key, tensor in input_tensor_key_feed_dict.items()
    }

    outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
        meta_graph_def, signature_def_key)
    # Sort to fix the order, so each fetched value can be mapped back to its key.
    output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
    output_tensor_names_sorted = [
        outputs_tensor_info[tensor_key].name
        for tensor_key in output_tensor_keys_sorted
    ]

    # Resolve every output path and enforce the overwrite policy up front.
    # Checking inside the save loop would leave earlier outputs already written
    # (and the model already run) by the time the error fires.
    output_full_paths = {}
    if outdir:
        for output_tensor_key in output_tensor_keys_sorted:
            output_full_path = os.path.join(outdir, output_tensor_key + '.npy')
            if not overwrite_flag and os.path.exists(output_full_path):
                raise RuntimeError(
                    'Output file %s already exists. Add \"--overwrite\" to overwrite'
                    ' the existing output files.' % output_full_path)
            output_full_paths[output_tensor_key] = output_full_path

    with session.Session(graph=ops_lib.Graph()) as sess:
        loader.load(sess, tag_set.split(','), saved_model_dir)

        if tf_debug:
            sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

        outputs = sess.run(output_tensor_names_sorted,
                           feed_dict=inputs_feed_dict)

        # Create the output directory once, and only if something will be saved.
        if output_full_paths and not os.path.isdir(outdir):
            os.makedirs(outdir)

        for output_tensor_key, output in zip(output_tensor_keys_sorted, outputs):
            print('Result for output key %s:\n%s' %
                  (output_tensor_key, output))

            # Only save if outdir is specified.
            if outdir:
                output_full_path = output_full_paths[output_tensor_key]
                np.save(output_full_path, output)
                print('Output %s is saved to %s' %
                      (output_tensor_key, output_full_path))
示例#6
0
 def testConstructWrapper(self):
     """The wrapper can be built around a fresh Session with default options."""
     wrapped_sess = session.Session()
     local_cli_wrapper.LocalCLIDebugWrapperSession(wrapped_sess, log_usage=False)
示例#7
0
 def __init__(self):
     """Loads a SavedModel into a new Session and stores it as `self._sess`.

     If `tf_debug` is truthy, the loaded session is wrapped in a
     LocalCLIDebugWrapperSession before being stored.

     NOTE(review): `saved_model_dir`, `tag_set` and `tf_debug` are not
     parameters of this method; they must resolve from an enclosing or global
     scope that is not visible in this snippet. Confirm.
     """
     loaded_sess = session.Session(graph=ops_lib.Graph())
     loader.load(loaded_sess, tag_set.split(','), saved_model_dir)
     self._sess = (local_cli_wrapper.LocalCLIDebugWrapperSession(loaded_sess)
                   if tf_debug else loaded_sess)