Example #1
0
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    root = tracking.AutoTrackable()
    if init_op is not None:
      asset_feed_tensors = []
      asset_paths = []
      for tensor_name, value in loader_impl.get_asset_tensors(
          self._export_dir, meta_graph_def).items():
        asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
        asset_paths.append(tracking.TrackableAsset(value))
      init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
      initializer = _Initializer(init_fn, asset_paths)
      initializer._initialize()  # pylint: disable=protected-access
      root.initializer = initializer
      root.asset_paths = asset_paths
    else:
      root.asset_paths = []
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
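This `load` method builds the object returned for TF1-style SavedModels. A minimal usage sketch, assuming it backs `tf.saved_model.load` when given a TF1 SavedModel; the export path and tag are placeholders:

import tensorflow as tf

# Hypothetical directory containing a TF1-style SavedModel.
export_dir = "/tmp/tf1_saved_model"

# Loading with explicit tags selects the MetaGraph, restores variables via the
# wrapped saver, and returns the AutoTrackable `root` built above.
root = tf.saved_model.load(export_dir, tags=["serve"])
print(list(root.signatures.keys()))  # e.g. ["serving_default"]
print(len(root.variables))           # variables collected from the wrapped graph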
Example #2
0
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(meta_graph_def)
    root = tracking.AutoTrackable()
    if init_op is not None:
      asset_feed_tensors = []
      asset_paths = []
      for tensor_name, value in loader_impl.get_asset_tensors(
          self._export_dir, meta_graph_def).items():
        asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
        asset_paths.append(tracking.TrackableAsset(value))
      init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
      initializer = _Initializer(init_fn, asset_paths)
      initializer.initialize()
      root.initializer = initializer
      root.asset_paths = asset_paths
    else:
      root.asset_paths = []
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    return root
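Example #2 differs from Example #1 only in calling the public `initialize()` on the `_Initializer` rather than the protected method; in both versions the initializer runs during load, so the returned object is immediately usable. A short sketch (path hypothetical):

import tensorflow as tf

root = tf.saved_model.load("/tmp/tf1_saved_model", tags=["serve"])
# The init op has already run during load, so restored variables and any
# asset-backed tables are ready to read.
for variable in root.variables:
  print(variable.name, variable.shape)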
Example #3
0
def get_asset_tensors(saved_model_dir, meta_graph_def_to_load):
  try:
    return loader_impl.get_asset_tensors(saved_model_dir,
                                         meta_graph_def_to_load)
  # TODO(b/124491249): Remove this backwards compatibility once TFT 0.14 is
  # released.
  except AttributeError:
    return loader_impl._get_asset_tensors(  # pylint: disable=protected-access
        saved_model_dir, meta_graph_def_to_load)
Example #4
0
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`."""
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.TrackableAsset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root
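This variant additionally falls back to a default local init op, anchors the init fetch on a dummy constant, wires local variables to the local init op in graph mode, and exposes `graph` and `prune` on the returned object. Because `prune` is exposed, callers can build callables straight from tensor names; a hedged sketch, with hypothetical path and tensor names:

import tensorflow as tf

imported = tf.saved_model.load("/tmp/tf1_saved_model")
# `imported.prune` is the WrappedFunction.prune handle attached above; the
# tensor names below stand in for whatever the model actually defines.
pruned_fn = imported.prune("x:0", "out:0")
print(pruned_fn(tf.constant([[1.0]])))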
Example #5
0
  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`."""
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    self.restore_variables(wrapped, saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(
          meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
      # Add a dummy Tensor we know we can fetch to add control dependencies to.
      init_anchor = constant_op.constant(0., name="dummy_fetch")

    root = tracking.AutoTrackable()
    asset_feed_tensors = []
    asset_paths = []
    for tensor_name, value in loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def).items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(tracking.TrackableAsset(value))
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
    initializer = _Initializer(init_fn, asset_paths)
    # pylint: disable=protected-access
    local_init_op, _ = initializer._initialize()
    # pylint: enable=protected-access
    with ops.init_scope():
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.LOCAL_VARIABLES):
          # pylint: disable=protected-access
          variable._initializer_op = local_init_op
          # pylint: enable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    root.tensorflow_version = (
        meta_graph_def.meta_info_def.tensorflow_version)
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    root.graph = wrapped.graph
    root.prune = wrapped.prune
    return root
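Example #5 is the same code as Example #4 in TensorFlow's standard 2-space style. In graph (non-eager) mode, the `ops.init_scope()` block registers the local init op in the `TABLE_INITIALIZERS` collection, so a conventional v1 session setup picks it up. A minimal sketch of that path, assuming the loader ran under a surrounding `tf.compat.v1` graph:

import tensorflow as tf

with tf.compat.v1.Session() as sess:
  # tables_initializer() groups every op registered in the TABLE_INITIALIZERS
  # collection, including the local_init_op added by the loader above.
  sess.run(tf.compat.v1.tables_initializer())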
Example #6
0
def get_asset_tensors(saved_model_dir, meta_graph_def_to_load):
  return loader_impl.get_asset_tensors(saved_model_dir, meta_graph_def_to_load)
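Example #6 drops the backwards-compatibility fallback of Example #3 and calls the public helper directly. A hedged usage sketch; the export directory is hypothetical, and `parse_saved_model` is used here only to obtain a MetaGraphDef to pass in:

from tensorflow.python.saved_model import loader_impl

export_dir = "/tmp/tf1_saved_model"
saved_model = loader_impl.parse_saved_model(export_dir)
meta_graph_def = saved_model.meta_graphs[0]

# The returned dict maps asset tensor names in the graph to the file paths
# they should be fed from under the SavedModel's assets/ directory.
assets = get_asset_tensors(export_dir, meta_graph_def)
for tensor_name, path in assets.items():
  print(tensor_name, "->", path)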
Example #7
0
    def load(self, tags):
        """Creates an object from the MetaGraph identified by `tags`."""
        meta_graph_def = self.get_meta_graph_def_from_tags(tags)
        load_shared_name_suffix = "_load_{}".format(ops.uid())
        functions = function_deserialization.load_function_def_library(
            meta_graph_def.graph_def.library,
            load_shared_name_suffix=load_shared_name_suffix)
        # Replace existing functions in the MetaGraphDef with renamed functions so
        # we don't have duplicates or name collisions.
        meta_graph_def.graph_def.library.Clear()
        for function in functions.values():
            meta_graph_def.graph_def.library.function.add().CopyFrom(
                function.function_def)
        # We've renamed functions and shared names. We need the same operation on
        # the GraphDef itself for consistency.
        for node_def in meta_graph_def.graph_def.node:
            function_deserialization.fix_node_def(
                node_def,
                functions,
                load_shared_name_suffix,
                debug_name="MetaGraph import")

        load_graph_returns = [None]
        wrapped = wrap_function.wrap_function(functools.partial(
            self.load_graph, load_graph_returns, meta_graph_def),
                                              signature=[])
        saver, = load_graph_returns
        self.restore_variables(wrapped, saver)
        with wrapped.graph.as_default():
            init_op = loader_impl.get_init_op(
                meta_graph_def
            ) or monitored_session.Scaffold.default_local_init_op()
            # Add a dummy Tensor we know we can fetch to add control dependencies to.
            init_anchor = constant_op.constant(0., name="dummy_fetch")

        root = tracking.AutoTrackable()
        asset_feed_tensors = []
        asset_paths = []
        for tensor_name, value in loader_impl.get_asset_tensors(
                self._export_dir, meta_graph_def).items():
            asset_feed_tensors.append(
                wrapped.graph.as_graph_element(tensor_name))
            asset_paths.append(tracking.Asset(value))
        init_fn = wrapped.prune(
            feeds=asset_feed_tensors,
            fetches=[init_anchor,
                     wrapped.graph.as_graph_element(init_op)])
        initializer = _Initializer(init_fn, asset_paths)
        # pylint: disable=protected-access
        local_init_op, _ = initializer._initialize()
        # pylint: enable=protected-access
        with ops.init_scope():
            if not context.executing_eagerly():
                ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                      local_init_op)
                for variable in wrapped.graph.get_collection_ref(
                        ops.GraphKeys.LOCAL_VARIABLES):
                    # pylint: disable=protected-access
                    variable._initializer_op = local_init_op
                    # pylint: enable=protected-access
        root.initializer = initializer
        root.asset_paths = asset_paths
        signature_functions = self._extract_signatures(wrapped, meta_graph_def)

        root.signatures = signature_serialization.create_signature_map(
            signature_functions)
        root.variables = list(wrapped.graph.variables)
        root.tensorflow_version = (
            meta_graph_def.meta_info_def.tensorflow_version)
        root.tensorflow_git_version = (
            meta_graph_def.meta_info_def.tensorflow_git_version)
        root.graph = wrapped.graph
        root.prune = wrapped.prune
        return root
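This latest variant also deduplicates function and shared names before wrapping (via `load_function_def_library` and `fix_node_def`), uses `tracking.Asset`, and records the producing TensorFlow version on the root. A small sketch of reading that provenance after loading (path hypothetical):

import tensorflow as tf

imported = tf.saved_model.load("/tmp/tf1_saved_model")
# Both attributes are populated from meta_info_def by the loader above.
print(imported.tensorflow_version)
print(imported.tensorflow_git_version)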
Example #8
0
    def _model_fn_from_saved_model(self, features, labels, mode):
        """Load a SavedModel graph and return an EstimatorSpec."""
        # TODO(kathywu): Model function loads placeholders from the graph. Calling
        # export_all_saved_models creates another placeholder for the inputs, on top
        # of the original placeholders. There should be a way to avoid this.
        self._validate_mode(mode)

        g = tf.compat.v1.get_default_graph()
        if tf.compat.v1.train.get_global_step(g) is not None:
            raise RuntimeError(
                'Graph must not contain a global step tensor before the SavedModel is'
                ' loaded. Please make sure that the input function does not create a '
                'global step.')

        # Extract SignatureDef for information about the input and output tensors.
        signature_def = self._get_signature_def_for_mode(mode)

        # Generate input map for replacing the inputs in the SavedModel graph with
        # the provided features and labels.
        input_map = _generate_input_map(signature_def, features, labels)

        # Create a list of the names of output tensors. When the graph is loaded,
        # names of the output tensors may be remapped. This ensures that the correct
        # tensors are returned in the EstimatorSpec.
        output_tensor_names = [
            value.name for value in six.itervalues(signature_def.outputs)
        ]

        # Load the graph. `output_tensors` contains output `Tensors` in the
        # same order as the `output_tensor_names` list.
        tags = export_lib.EXPORT_TAG_MAP[mode]
        _, output_tensors = self.saved_model_loader.load_graph(
            g, tags, input_map=input_map, return_elements=output_tensor_names)

        # Create saver object, and restore from the SavedModel `variables` directory
        # if no checkpoints have been saved in the `model_dir`.
        saver_obj = tf.compat.v1.train.Saver(
            saver_def=self._get_saver_def_from_mode(mode))
        init_fn = None
        if not super(SavedModelEstimator, self).latest_checkpoint():
            init_fn = self._restore_from_saver

        # Create a scaffold from the MetaGraphDef that contains ops to initialize
        # the graph. This should mirror the steps from _add_meta_graph_for_mode(),
        # which creates a MetaGraphDef from the EstimatorSpec's scaffold.
        # Get asset tensors, if any.
        meta_graph_def = self._get_meta_graph_def_for_mode(mode)
        asset_tensors_dictionary = loader_impl.get_asset_tensors(
            self.saved_model_loader.export_dir,
            meta_graph_def,
            import_scope=None)
        # TODO(kathywu): switch to loader_impl._get_main_op
        scaffold = tf.compat.v1.train.Scaffold(
            local_init_op=loader_impl._get_main_op_tensor(  # pylint: disable=protected-access
                meta_graph_def),
            local_init_feed_dict=asset_tensors_dictionary,
            saver=saver_obj,
            init_fn=init_fn)

        # Ensure that a global step tensor has been created.
        global_step_tensor = tf.compat.v1.train.get_global_step(g)
        tf.compat.v1.train.assert_global_step(global_step_tensor)

        # Extract values to return in the EstimatorSpec.
        output_map = dict(zip(output_tensor_names, output_tensors))
        outputs = {
            key: output_map[value.name]
            for key, value in six.iteritems(signature_def.outputs)
        }

        loss, predictions, metrics = _validate_and_extract_outputs(
            mode, outputs, signature_def.method_name)

        train_op = tf.compat.v1.get_collection(constants.TRAIN_OP_KEY)
        if len(train_op) > 1:
            raise RuntimeError(
                'Multiple ops found in the train_op collection.')
        train_op = None if not train_op else train_op[0]

        _clear_saved_model_collections()
        return model_fn_lib.EstimatorSpec(scaffold=scaffold,
                                          mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          predictions=predictions,
                                          eval_metric_ops=metrics)
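A hedged sketch of how the surrounding estimator might be driven, assuming `SavedModelEstimator` is constructed from a directory produced by `export_all_saved_models`; the import path, export directory, and `input_fn` are assumptions and vary across TF/estimator releases:

import tensorflow as tf
# Assumed import path; adjust for the installed estimator package.
from tensorflow_estimator.python.estimator.saved_model_estimator import (
    SavedModelEstimator)


def input_fn():
    # Hypothetical input pipeline; feature names must match the SavedModel's
    # predict signature.
    return tf.compat.v1.data.Dataset.from_tensor_slices(
        {"x": [[1.0], [2.0]]}).batch(1)


estimator = SavedModelEstimator("/tmp/exported_estimator")  # hypothetical dir
for prediction in estimator.predict(input_fn):
    print(prediction)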