Example no. 1
  def _init_from_proto(self, hparam_def):
    """Creates a new HParams from `HParamDef` protocol buffer.

    Args:
      hparam_def: `HParamDef` protocol buffer.
    """
    assert isinstance(hparam_def, hparam_pb2.HParamDef)
    for name, value in hparam_def.hparam.items():
      kind = value.WhichOneof('kind')
      if kind.endswith('_value'):
        # Single value.
        if kind.startswith('int64'):
          # Setting attribute value to be 'int' to ensure the type is compatible
          # with both Python2 and Python3.
          self.add_hparam(name, int(getattr(value, kind)))
        elif kind.startswith('bytes'):
          # Setting attribute value to be 'str' to ensure the type is compatible
          # with both Python2 and Python3. UTF-8 encoding is assumed.
          self.add_hparam(name, compat.as_str(getattr(value, kind)))
        else:
          self.add_hparam(name, getattr(value, kind))
      else:
        # List of values.
        if kind.startswith('int64'):
          # Setting attribute value to be 'int' to ensure the type is compatible
          # with both Python2 and Python3.
          self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
        elif kind.startswith('bytes'):
          # Setting attribute value to be 'str' to ensure the type is compatible
          # with both Python2 and Python3. UTF-8 encoding is assumed.
          self.add_hparam(
              name, [compat.as_str(v) for v in getattr(value, kind).value])
        else:
          self.add_hparam(name, [v for v in getattr(value, kind).value])
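
A minimal pure-Python sketch of the kind dispatch above, using a plain value in place of the HParamDef proto. The kind names ('int64_value', 'bytes_list', ...) follow the '_value'/'_list' convention the code relies on and are illustrative here, not taken from hparam.proto.

def _coerce_hparam(kind, raw):
  # Single values end in '_value'; repeated kinds carry a list of raw values.
  if kind.endswith('_value'):
    if kind.startswith('int64'):
      return int(raw)
    if kind.startswith('bytes'):
      return raw.decode('utf-8')  # what compat.as_str does for bytes input
    return raw
  if kind.startswith('int64'):
    return [int(v) for v in raw]
  if kind.startswith('bytes'):
    return [v.decode('utf-8') for v in raw]
  return list(raw)

assert _coerce_hparam('int64_value', 7) == 7
assert _coerce_hparam('bytes_list', [b'a', b'b']) == ['a', 'b']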
Example no. 2
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Populates the TF_ImportGraphDefOptions `options`."""
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)

  for input_src, input_dst in input_map.items():
    input_src = compat.as_str(input_src)
    if input_src.startswith('^'):
      src_name = compat.as_bytes(input_src[1:])
      dst_op = input_dst._as_tf_output().oper  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(options, src_name,
                                                           dst_op)
    else:
      src_name, src_idx = _ParseTensorName(input_src)
      src_name = compat.as_str(src_name)
      dst_output = input_dst._as_tf_output()  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name,
                                                    src_idx, dst_output)
  for name in return_elements or []:
    if ':' in name:
      op_name, index = _ParseTensorName(name)
      op_name = compat.as_str(op_name)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index)
    else:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
                                                       compat.as_str(name))
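
_ParseTensorName is not shown in this snippet; below is a rough, hypothetical stand-in assuming it splits an "op_name:output_index" string and maps a bare op name to output 0, which is how the call sites above use it.

def _parse_tensor_name_sketch(name):
  # Hypothetical stand-in for _ParseTensorName, for illustration only.
  if ':' in name:
    op_name, idx = name.rsplit(':', 1)
    return op_name, int(idx)
  return name, 0

assert _parse_tensor_name_sketch('decoder/logits:1') == ('decoder/logits', 1)
assert _parse_tensor_name_sketch('init') == ('init', 0)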
Example no. 3
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' transform.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  meta_graph_def = _meta_graph_pb2.MetaGraphDef()

  initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)

  transformed_graph_def = _do_transforms(
      base_meta_graph_def.graph_def,
      input_names,
      output_names,
      initializer_names,
      transforms,
      base_meta_graph_def.saver_def,
      checkpoint_path)

  meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
  meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  meta_graph_def.meta_info_def.ClearField('tags')
  for tag in tags:
    meta_graph_def.meta_info_def.tags.append(tag)

  base_op_names = [compat.as_str(node.name)
                   for node in base_meta_graph_def.graph_def.node]
  retained_op_names = [compat.as_str(node.name)
                       for node in meta_graph_def.graph_def.node]
  removed_op_names = set(base_op_names) - set(retained_op_names)

  # Copy saver, excluding any pruned nodes
  _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)

  # Copy collections, excluding any pruned nodes
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name,
        removed_op_names)

  # Copy signature_defs, excluding any pruned nodes
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, meta_graph_def, signature_name,
        removed_op_names)

  return meta_graph_def
Example no. 4
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
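
A hedged usage sketch, assuming a TF 1.x-style environment where this helper is exposed as tf.test.assert_equal_graph_def. Node names match between the two independently built graphs, so they compare equal even though node ordering may differ.

import tensorflow as tf

def _build_graph_def():
  g = tf.Graph()
  with g.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    tf.add(a, b, name='sum')
  return g.as_graph_def()

tf.test.assert_equal_graph_def(_build_graph_def(), _build_graph_def())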
Example no. 5
def _create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name
      None,  # opers
      [t._as_tf_output() for t in func_graph.inputs],
      [t._as_tf_output() for t in func_graph.outputs],
      [],
      None,  # opts
      None)  # description
  _ = c_api_util.ScopedTFFunction(c_func)

  # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
  # deserializing it into a Python FunctionDef, then reserializing it to create
  # a new TF_Function that we add to the graph.
  fdef = _function.function_def_from_tf_function(c_func)
  defined_func = _function._from_definition(fdef)
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)

  return func_graph.name
Example no. 6
def _ProcessReturnElementsParam(return_elements):
  """Type-checks and possibly canonicalizes `return_elements`."""
  if return_elements is None: return None
  if not all(isinstance(x, compat.bytes_or_text_types)
             for x in return_elements):
    raise TypeError('return_elements must be a list of strings.')
  return tuple(compat.as_str(x) for x in return_elements)
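
What the canonicalization above produces, sketched without TensorFlow: a mix of bytes and text names comes back as a tuple of Python str (compat.as_str assumes UTF-8 when decoding bytes).

return_elements = [b'logits:0', u'init_op']
canonical = tuple(
    x.decode('utf-8') if isinstance(x, bytes) else x for x in return_elements)
assert canonical == ('logits:0', 'init_op')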
Example no. 7
  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join([str(p) for p in path])
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
    )):
      return arg
    return UnknownArgument()
Example no. 8
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Updates the dtypes attribute of the save / restore op and the associated name
  and shape tensors to remove entries for variables that have been removed.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """
  name = op.name + '/tensor_names'
  shape = op.name + '/shape_and_slices'
  name_op = _find_op(graph_def, name)
  shape_op = _find_op(graph_def, shape)
  name_op_value_tensor = name_op.attr['value'].tensor
  shape_op_value_tensor = shape_op.attr['value'].tensor
  names = []
  shapes = []
  dtypes = []
  for index, value in enumerate(name_op_value_tensor.string_val):
    if not _is_removed(compat.as_str(value), removed_op_names):
      names.append(value)
      shapes.append(shape_op_value_tensor.string_val[index])
      dtypes.append(op.attr['dtypes'].list.type[index])
  name_op_value_tensor.string_val[:] = names
  name_op_value_tensor.tensor_shape.dim[0].size = len(names)
  shape_op_value_tensor.string_val[:] = shapes
  shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes)
  op.attr['dtypes'].list.type[:] = dtypes

  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)
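
A pure-Python sketch of the filtering above: the parallel lists of tensor names, shape/slice specs and dtypes are kept in lock-step while entries belonging to removed variables are dropped. _is_removed is not shown in this snippet, so a simple prefix test stands in for it here.

def _is_removed_sketch(name, removed_op_names):
  # Hypothetical stand-in for _is_removed, for illustration only.
  return any(name.startswith(removed) for removed in removed_op_names)

tensor_names = [b'w1', b'w2', b'pruned/w3']
shape_slices = [b'', b'', b'']
dtype_enums = [1, 1, 1]
removed_op_names = {'pruned/'}

kept = [(n, s, d) for n, s, d in zip(tensor_names, shape_slices, dtype_enums)
        if not _is_removed_sketch(n.decode('utf-8'), removed_op_names)]
names, shapes, dtypes = (list(col) for col in zip(*kept))
assert names == [b'w1', b'w2']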
Example no. 9
  def __init__(self, name, graph, operations, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        None,
        compat.as_str(""))

    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature, but also in general it's nice not to depend on it).
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
Example no. 10
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
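
A pure-Python sketch of the prefixing step above: the regex keeps an optional leading '^' (the control-input marker) and inserts the unbound-input prefix right after it. _UNBOUND_INPUT_PREFIX is assumed to be '$unbound_inputs_' here, matching the literal prefix used by the variant of this helper in Example no. 30.

import re

_UNBOUND_INPUT_PREFIX = '$unbound_inputs_'  # assumed value

def _mark_unbound(name):
  return re.sub(r"([\^]|^)(.*)", r"\1" + _UNBOUND_INPUT_PREFIX + r"\2", name)

assert _mark_unbound('other_scope/w') == '$unbound_inputs_other_scope/w'
assert _mark_unbound('^other_scope/init') == '^$unbound_inputs_other_scope/init'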
Example no. 11
    def save(self, sess, save_path, global_step=None, latest_filename=None):
        """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: string.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
        if latest_filename is None:
            latest_filename = "checkpoint"

        if os.path.split(latest_filename)[0]:
            raise ValueError(
                "'latest_filename' must not contain path components")

        if global_step is not None:
            if not isinstance(global_step, compat.integral_types):
                global_step = training_util.global_step(sess, global_step)
            checkpoint_file = "%s-%d" % (save_path, global_step)
        else:
            checkpoint_file = save_path
        save_path = os.path.dirname(save_path)
        if not isinstance(sess, session.SessionInterface):
            raise TypeError("'sess' must be a Session; %s" % sess)

        model_checkpoint_path = sess.run(
            self._save_tensor_name,
            {self._filename_tensor_name: checkpoint_file})
        model_checkpoint_path = compat.as_str(model_checkpoint_path)
        self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
        update_checkpoint_state(save_path, model_checkpoint_path,
                                self.last_checkpoints, latest_filename)
        return model_checkpoint_path
Example no. 12
def _ProcessReturnElementsParam(return_elements):
    """Type-checks and possibly canonicalizes `return_elements`."""
    if return_elements is None:
        return None
    if not all(
            isinstance(x, compat.bytes_or_text_types)
            for x in return_elements):
        raise TypeError('return_elements must be a list of strings.')
    return tuple(compat.as_str(x) for x in return_elements)
Example no. 13
  def save(self, sess, save_path, global_step=None, latest_filename=None):
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
    if latest_filename is None:
      latest_filename = "checkpoint"

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
    else:
      checkpoint_file = save_path
    save_path = os.path.dirname(save_path)
    if not isinstance(sess, session.SessionInterface):
      raise TypeError("'sess' must be a Session; %s" % sess)

    model_checkpoint_path = sess.run(
        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
    model_checkpoint_path = compat.as_str(model_checkpoint_path)
    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
    update_checkpoint_state(save_path, model_checkpoint_path,
                            self.last_checkpoints, latest_filename)
    return model_checkpoint_path
Example no. 14
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
Example no. 15
  def __init__(self, name, graph, operations, inputs, outputs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
    with errors.raise_exception_on_not_ok_status() as status:
      fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
          graph._c_graph,  # pylint: disable=protected-access
          compat.as_str(name),
          False,
          [o._c_op for o in operations],  # pylint: disable=protected-access
          [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
          [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
          [],
          None,
          compat.as_str(""),
          status)
    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature, but also in general it's nice not to depend on it).
    with c_api_util.tf_buffer() as buffer_:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = fn
    self._grad_func = None
Example no. 16
    def __init__(self, name, graph, operations, inputs, outputs):
        """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
        with errors.raise_exception_on_not_ok_status() as status:
            fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
                graph._c_graph,  # pylint: disable=protected-access
                compat.as_str(name),
                False,
                [o._c_op for o in operations],  # pylint: disable=protected-access
                [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
                [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
                [],
                None,
                compat.as_str(""),
                status)
        # TODO(apassos) avoid creating a FunctionDef (especially to grab the
        # signature, but also in general it's nice not to depend on it).
        with c_api_util.tf_buffer() as buffer_:
            with errors.raise_exception_on_not_ok_status() as status:
                pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.in_eager_mode():
            _register(fn)
        self.definition = function_def
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = fn
        self._grad_func = None
Example no. 17
def canonicalize_signatures(signatures):
    """Converts `signatures` into a dictionary of concrete functions."""
    if signatures is None:
        return {}
    if not isinstance(signatures, collections.Mapping):
        signatures = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
        }
    concrete_signatures = {}
    for signature_key, function in signatures.items():
        signature_function = _get_signature(function)
        if signature_function is None:
            raise ValueError((
                "Expected a TensorFlow function to generate a signature for, but "
                "got {}. Only `tf.functions` with an input signature or "
                "concrete functions can be used as a signature."
            ).format(function))

        # Re-wrap the function so that it returns a dictionary of Tensors. This
        # matches the format of 1.x-style signatures.
        # pylint: disable=cell-var-from-loop
        @def_function.function
        def signature_wrapper(**kwargs):
            structured_outputs = signature_function(**kwargs)
            return _normalize_outputs(structured_outputs,
                                      signature_function.name, signature_key)

        # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
        # always match keyword arguments.
        tensor_spec_signature = {}
        for keyword, tensor in zip(
                signature_function._arg_keywords,  # pylint: disable=protected-access
                signature_function.inputs):
            keyword = compat.as_str(keyword)
            tensor_spec_signature[
                keyword] = tensor_spec.TensorSpec.from_tensor(tensor,
                                                              name=keyword)
        final_concrete = signature_wrapper.get_concrete_function(
            **tensor_spec_signature)
        # pylint: disable=protected-access
        if len(final_concrete._arg_keywords) == 1:
            # If there is only one input to the signature, a very common case, then
            # ordering is unambiguous and we can let people pass a positional
            # argument. Since SignatureDefs are unordered (protobuf "map") multiple
            # arguments means we need to be keyword-only.
            final_concrete._num_positional_args = 1
        else:
            final_concrete._num_positional_args = 0
        # pylint: enable=protected-access
        concrete_signatures[signature_key] = final_concrete
        # pylint: enable=cell-var-from-loop
    return concrete_signatures
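
A hedged usage sketch of what this function consumes, assuming a TF 2.x environment: either a single tf.function with an input signature (mapped to the default serving key) or an explicit {key: function} mapping passed to tf.saved_model.save. The export directory is a placeholder.

import tensorflow as tf

class Adder(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def add_one(self, x):
    # Returning a dict of Tensors matches the 1.x-style signature format.
    return {'y': x + 1.0}

module = Adder()
tf.saved_model.save(module, '/tmp/adder',  # placeholder export directory
                    signatures={'serving_default': module.add_one})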
Example no. 18
    def _ReadAndCheckRowsUsingFeatures(self, num_rows):
        self.server.handler.num_rows = num_rows

        with self.test_session() as sess:
            feature_configs = {
                "int64_col":
                parsing_ops.FixedLenFeature([1], dtype=dtypes.int64),
                "string_col":
                parsing_ops.FixedLenFeature([1],
                                            dtype=dtypes.string,
                                            default_value="s_default"),
            }
            reader = cloud.BigQueryReader(
                project_id=_PROJECT,
                dataset_id=_DATASET,
                table_id=_TABLE,
                num_partitions=4,
                features=feature_configs,
                timestamp_millis=1,
                test_end_point=("%s:%s" %
                                (self.server.httpd.server_address[0],
                                 self.server.httpd.server_address[1])))

            key, value = _SetUpQueue(reader)

            seen_rows = []
            features = parsing_ops.parse_example(array_ops.reshape(value, [1]),
                                                 feature_configs)
            for _ in range(num_rows):
                int_value, str_value = sess.run(
                    [features["int64_col"], features["string_col"]])

                # Parse values returned from the session.
                self.assertEqual(int_value.shape, (1, 1))
                self.assertEqual(str_value.shape, (1, 1))
                int64_col = int_value[0][0]
                string_col = str_value[0][0]
                seen_rows.append(int64_col)

                # Compare.
                expected_row = _ROWS[int64_col]
                self.assertEqual(int64_col, expected_row[0])
                self.assertEqual(
                    compat.as_str(string_col),
                    ("s_%d" % int64_col) if expected_row[1] else "s_default")

            self.assertItemsEqual(seen_rows, range(num_rows))

            with self.assertRaisesOpError(
                    "is closed and has insufficient elements "
                    "\\(requested 1, current size 0\\)"):
                sess.run([key, value])
Example no. 19
  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    with self._lock:
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str(unicode(ex[1])))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str(unicode(ex)))
            self._exc_info_to_raise = sys.exc_info()
        self._stop_event.set()
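
A hedged usage sketch of the Coordinator protocol around request_stop(), assuming tf.train.Coordinator is available: worker threads poll should_stop(), and the first exception handed to request_stop() is recorded and re-raised from join().

import threading
import tensorflow as tf

coord = tf.train.Coordinator()

def _worker():
  try:
    while not coord.should_stop():
      raise RuntimeError('simulated failure')
  except RuntimeError as e:
    coord.request_stop(e)  # recorded, then re-raised from coord.join()

thread = threading.Thread(target=_worker)
thread.start()
try:
  coord.join([thread])
except RuntimeError as e:
  print('Coordinator re-raised:', e)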
Example no. 20
    def request_stop(self, ex=None):
        """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
        with self._lock:
            if not self._stop_event.is_set():
                if ex and self._exc_info_to_raise is None:
                    if isinstance(ex, tuple):
                        logging.info("Error reported to Coordinator: %s",
                                     compat.as_str(unicode(ex[1])))
                        self._exc_info_to_raise = ex
                    else:
                        logging.info("Error reported to Coordinator: %s",
                                     compat.as_str(unicode(ex)))
                        self._exc_info_to_raise = sys.exc_info()
                self._stop_event.set()
Example no. 21
def _node_def_unbound(from_node_def,
                      export_scope,
                      unbound_inputs,
                      as_unbound_inputs,
                      clear_devices=False):
    """Create a `NodeDef` proto with export_scope stripped given input names
  that are treated as unbound.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.
    as_unbound_inputs: A list of `String`s. Input names that are treated as
      unbound when exporting Operations.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
    node_def = copy.deepcopy(from_node_def)
    as_unbound_inputs = set(as_unbound_inputs)
    for i, v in enumerate(node_def.input):
        if node_def.input[i] in as_unbound_inputs:
            # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
            # identifiable.
            node_def.input[i] = _unbound_name(v)
            unbound_inputs.append(node_def.input[i])
        else:
            node_def.input[i] = ops.strip_name_scope(v, export_scope)
    node_def.name = compat.as_bytes(
        ops.strip_name_scope(from_node_def.name, export_scope))
    for k, v in six.iteritems(from_node_def.attr):
        if k == "_class":
            new_s = []
            for s in v.list.s:
                if compat.as_str(s) in as_unbound_inputs:
                    new_s.append(compat.as_bytes(_unbound_name(s)))
                else:
                    new_s.append(
                        compat.as_bytes(ops.strip_name_scope(s, export_scope)))
            node_def.attr[k].CopyFrom(
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
        else:
            node_def.attr[k].CopyFrom(v)

    if clear_devices:
        node_def.device = ""

    return node_def
Example no. 22
    def _init_from_proto(self, hparam_def):
        """Creates a new HParams from `HParamDef` protocol buffer.

        Args:
          hparam_def: `HParamDef` protocol buffer.
        """
        if not isinstance(hparam_def, hparam_pb2.HParamDef):
            raise AssertionError('Wrong "hparam_def" type')
        for name, value in hparam_def.hparam.items():
            kind = value.WhichOneof('kind')
            if kind.endswith('_value'):
                # Single value.
                if kind.startswith('int64'):
                    # Setting attribute value to be 'int' to ensure the type is compatible
                    # with both Python2 and Python3.
                    self.add_hparam(name, int(getattr(value, kind)))
                elif kind.startswith('bytes'):
                    # Setting attribute value to be 'str' to ensure the type is compatible
                    # with both Python2 and Python3. UTF-8 encoding is assumed.
                    self.add_hparam(name, compat.as_str(getattr(value, kind)))
                else:
                    self.add_hparam(name, getattr(value, kind))
            else:
                # List of values.
                if kind.startswith('int64'):
                    # Setting attribute value to be 'int' to ensure the type is compatible
                    # with both Python2 and Python3.
                    self.add_hparam(
                        name, [int(v) for v in getattr(value, kind).value])
                elif kind.startswith('bytes'):
                    # Setting attribute value to be 'str' to ensure the type is compatible
                    # with both Python2 and Python3. UTF-8 encoding is assumed.
                    self.add_hparam(
                        name,
                        [compat.as_str(v) for v in getattr(value, kind).value])
                else:
                    self.add_hparam(name,
                                    [v for v in getattr(value, kind).value])
Example no. 23
  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)
Example no. 24
  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)
Example no. 25
def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument."""
  if namespace_whitelist is None:
    return []
  if not isinstance(namespace_whitelist, list):
    raise TypeError("Namespace whitelist must be a list of strings.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: {} of type"
                       " {}.".format(namespace, type(namespace)))
    processed.append(compat.as_str(namespace))
  return processed
Example no. 26
  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      feature_configs = {
          "int64_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.int64),
          "string_col":
              parsing_ops.FixedLenFeature(
                  [1], dtype=dtypes.string, default_value="s_default"),
      }
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
                                     self.server.httpd.server_address[1])))

      key, value = _SetUpQueue(reader)

      seen_rows = []
      features = parsing_ops.parse_example(
          array_ops.reshape(value, [1]), feature_configs)
      for _ in range(num_rows):
        int_value, str_value = sess.run(
            [features["int64_col"], features["string_col"]])

        # Parse values returned from the session.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        int64_col = int_value[0][0]
        string_col = str_value[0][0]
        seen_rows.append(int64_col)

        # Compare.
        expected_row = _ROWS[int64_col]
        self.assertEqual(int64_col, expected_row[0])
        self.assertEqual(
            compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
            else "s_default")

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
Example no. 27
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.
  """
    if cluster_resolver is None:
        cluster_resolver = TPUClusterResolver("")
    master = cluster_resolver.master()

    logging.info("Initializing the TPU system.")

    if context.executing_eagerly():
        # This function looks as it is for the following non-intuitive reasons.
        # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
        # DistributedTPURewritePass. This pass actually adds real ops that
        # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
        # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
        # The easiest way to trigger a rewrite is to run the function with
        # TPUPartitionedCallOp.
        @function.defun
        def _tpu_init_fn():
            return tpu.initialize_system()

        # We can't call _tpu_init_fn normally (because it contains just a dummy op,
        # see above) but need to define it to get it added to eager context
        # and get its assigned name.
        # pylint: disable=protected-access
        graph_func = _tpu_init_fn._get_concrete_function_internal()
        func_name = compat.as_str(graph_func._inference_function.name)
        # pylint: enable=protected-access

        output = tpu_functional_ops.TPUPartitionedCall(args=[],
                                                       device_ordinal=0,
                                                       Tout=[dtypes.string],
                                                       f=func_name)
        serialized_topology = output[0].numpy()
    else:
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())

    logging.info("Finished initializing TPU system.")
    return topology.Topology(serialized=serialized_topology)
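
A hedged usage sketch, assuming a TF 2.x environment where this helper is exposed as tf.tpu.experimental.initialize_tpu_system and a TPU worker is reachable; the TPU address is a placeholder.

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://10.0.0.2:8470')  # placeholder address
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)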
Example no. 28
  def lookup(self, name):
    """Looks up "name".

    Args:
      name: a string specifying the registry key for the candidate.
    Returns:
      Registered object if found
    Raises:
      LookupError: if "name" has not been registered.
    """
    name = compat.as_str(name)
    if name in self._registry:
      return self._registry[name][_TYPE_TAG]
    else:
      raise LookupError("%s registry has no entry for: %s" %
                        (self._name, name))
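
A minimal self-contained sketch of the registry pattern behind lookup(), using a plain dict. _TYPE_TAG and the register() side are assumptions filled in for illustration, not copied from the TensorFlow registry module.

_TYPE_TAG = 'type'  # assumed tag name


class Registry(object):
  """Tiny registry mirroring the lookup() contract above."""

  def __init__(self, name):
    self._name = name
    self._registry = {}

  def register(self, candidate, name):
    self._registry[name] = {_TYPE_TAG: candidate}

  def lookup(self, name):
    if name in self._registry:
      return self._registry[name][_TYPE_TAG]
    raise LookupError("%s registry has no entry for: %s" % (self._name, name))


reg = Registry('gradient')
reg.register(object, 'Identity')
assert reg.lookup('Identity') is object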
Example no. 29
    def lookup(self, name):
        """Looks up "name".

        Args:
          name: a string specifying the registry key for the candidate.
        Returns:
          Registered object if found
        Raises:
          LookupError: if "name" has not been registered.
        """
        name = compat.as_str(name)
        if name in self._registry:
            return self._registry[name][_TYPE_TAG]
        else:
            raise LookupError(
                "%s registry has no entry for: %s" % (self._name, name))
Example no. 30
def _node_def(from_node_def, export_scope, unbound_inputs):
  """Create a `NodeDef` proto with export_scope stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    else:
      node_def.attr[k].CopyFrom(v)

  return node_def
Example no. 31
def _GetColocationNames(op):
  """Returns names of the ops that `op` should be colocated with."""
  colocation_names = []
  try:
    class_values = op.get_attr('_class')
  except ValueError:
    # No _class attr
    return
  for val in class_values:
    val = compat.as_str(val)
    if val.startswith('loc:@'):
      colocation_node_name = val[len('loc:@'):]
      if colocation_node_name != op.name:
        colocation_names.append(colocation_node_name)
  return colocation_names
Example no. 32
def _GetColocationNames(op):
    """Returns names of the ops that `op` should be colocated with."""
    colocation_names = []
    try:
        class_values = op.get_attr('_class')
    except ValueError:
        # No _class attr
        return
    for val in class_values:
        val = compat.as_str(val)
        if val.startswith('loc:@'):
            colocation_node_name = val[len('loc:@'):]
            if colocation_node_name != op.name:
                colocation_names.append(colocation_node_name)
    return colocation_names
Example no. 33
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions."""
  if signatures is None:
    return {}
  if not isinstance(signatures, collections.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)
    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
    # always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    final_concrete = signature_wrapper.get_concrete_function(
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
Example no. 34
def _validate_namespace_whitelist(namespace_whitelist):
  """Validates namespace whitelist argument."""
  if namespace_whitelist is None:
    return None
  if not isinstance(namespace_whitelist, list):
    raise TypeError("`namespace_whitelist` must be a list of strings. Got: "
                    f"{namespace_whitelist} with type "
                    f"{type(namespace_whitelist)}.")

  processed = []
  for namespace in namespace_whitelist:
    if not isinstance(namespace, six.string_types):
      raise ValueError("Whitelisted namespace must be a string. Got: "
                       f"{namespace} of type {type(namespace)}.")
    processed.append(compat.as_str(namespace))
  return processed
Example no. 35
def _get_signature_name_changes(concrete_function):
    """Checks for user-specified signature input names that are normalized."""
    # Map of {user-given name: normalized name} if the names are un-identical.
    name_changes = {}
    for signature_input_name, graph_input in zip(
            concrete_function.function_def.signature.input_arg,
            concrete_function.graph.inputs):
        try:
            user_specified_name = compat.as_str(
                graph_input.op.get_attr("_user_specified_name"))
            if signature_input_name.name != user_specified_name:
                name_changes[user_specified_name] = signature_input_name.name
        except ValueError:
            # Signature input does not have a user-specified name.
            pass
    return name_changes
Example no. 36
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices in a separate session and graph.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  master = cluster_resolver.master()

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    # The easiest way to trigger a rewrite is to run the function with
    # TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # We can't call _tpu_init_fn normally (because it contains just a dummy op,
    # see above) but need to define it to get it added to eager context
    # and get its assigned name.
    # pylint: disable=protected-access
    graph_func = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(graph_func._inference_function.name)
    # pylint: enable=protected-access

    output = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = output[0].numpy()
  else:
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
Example no. 37
def imperative_grad(tape,
                    target,
                    sources,
                    output_gradients=None,
                    sources_raw=None,
                    unconnected_gradients=UnconnectedGradients.NONE):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages are of each
  tensor and entry, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
   tape: the gradient tape which stores the trace.
   target: either a Tensor or list of Tensors to be differentiated.
   sources: list of Tensors for which we want gradients
   output_gradients: if not None, a list of gradient provided for each Target,
    or None if we are to use the target's computed downstream gradient.
   sources_raw: if not None, a list of the source python objects from which the
    sources were generated. Should have the same length as sources. Only needs
    to be populated if unconnected_gradients is 'zero'.
   unconnected_gradients: determines the value returned if the target and
    sources are unconnected. When 'none' the value returned is None, whereas
    when 'zero' a zero tensor in the same shape as the sources is returned.

  Returns:
   the gradient wrt each of the sources.

  Raises:
    ValueError: if the arguments are invalid.
    RuntimeError: if something goes wrong.
  """
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  return pywrap_tfe.TFE_Py_TapeGradient(
      tape._tape,  # pylint: disable=protected-access
      target,
      sources,
      output_gradients,
      sources_raw,
      compat.as_str(unconnected_gradients.value))
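
A hedged usage sketch of the public entry point that ends up in imperative_grad, assuming a TF 2.x environment: tf.GradientTape.gradient(), with unconnected_gradients controlling what comes back for sources the target does not depend on.

import tensorflow as tf

x = tf.constant(3.0)
y = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch([x, y])
  z = x * x  # z does not depend on y

dx, dy = tape.gradient(
    z, [x, y], unconnected_gradients=tf.UnconnectedGradients.ZERO)
print(dx.numpy(), dy.numpy())  # 6.0 and 0.0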
Example no. 38
  def _revive_metric_from_config(self, metadata, node_id):
    class_name = compat.as_str(metadata['class_name'])
    config = metadata.get('config')

    if not generic_utils.validate_config(config):
      return None

    try:
      obj = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    build_input_shape = metadata.get('build_input_shape')
    if build_input_shape is not None and hasattr(obj, '_build'):
      obj._build(build_input_shape)  # pylint: disable=protected-access

    return obj
Example no. 39
def get_temp_export_dir(timestamped_export_dir):
    """Builds a directory name based on the argument but starting with 'temp-'.

  This relies on the fact that TensorFlow Serving ignores subdirectories of
  the base directory that can't be parsed as integers.

  Args:
    timestamped_export_dir: the name of the eventual export directory, e.g.
      /foo/bar/<timestamp>

  Returns:
    A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>.
  """
    (dirname, basename) = os.path.split(timestamped_export_dir)
    temp_export_dir = os.path.join(
        compat.as_bytes(dirname),
        compat.as_bytes('temp-{}'.format(compat.as_str(basename))))
    return temp_export_dir
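
A pure-Python sketch of the naming rule above: the temp directory is a sibling of the timestamped export directory, prefixed with 'temp-' (the bytes round-trip through compat is dropped here).

import os

def _temp_export_dir_sketch(timestamped_export_dir):
  dirname, basename = os.path.split(timestamped_export_dir)
  return os.path.join(dirname, 'temp-{}'.format(basename))

assert _temp_export_dir_sketch('/foo/bar/1577836800') == '/foo/bar/temp-1577836800'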
Example no. 40
def get_matching_files(filename):
  """Returns a list of files that match the given pattern.

  Args:
    filename: string, the pattern

  Returns:
    Returns a list of strings containing filenames that match the given pattern.

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # Convert each element to string, since the return values of the
    # vector of string should be interpreted as strings, not bytes.
    return [compat.as_str(matching_filename)
            for matching_filename in pywrap_tensorflow.GetMatchingFiles(
                compat.as_bytes(filename), status)]
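
A hedged usage sketch: in current TensorFlow the public wrapper for this helper is tf.io.gfile.glob (tf.gfile.Glob in 1.x), and the matches come back as Python str. The pattern is a placeholder.

import tensorflow as tf

for path in tf.io.gfile.glob('/tmp/data/*.tfrecord'):  # placeholder pattern
  print(path)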
Example no. 41
    def testReadWrite(self):
        with self.test_session() as sess:
            contents = "ASDASDASDASDASDAS"
            filename = "iptf://repo/root/foo"
            meta_filename = "iptf://meta/repo/root/foo"

            wf = io_ops.write_file(filename=constant_op.constant(filename),
                                   contents=constant_op.constant(contents))
            reader = io_ops.WholeFileReader("test_reader")
            queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            queue.enqueue_many([[filename]]).run()
            queue.close().run()
            with sess.graph.control_dependencies([wf]):
                key, value = sess.run(reader.read(queue))
            self.assertEqual(key, compat.as_bytes(filename))
            self.assertEqual(value, compat.as_bytes(contents))

            queue2 = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            queue2.enqueue_many([[meta_filename]]).run()
            queue2.close().run()
            key, value = sess.run(reader.read(queue2))

            d = json.loads(compat.as_str(value))
            ipfs_path = d["IpfsPath"]
            queue3 = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            queue3.enqueue_many([[ipfs_path]]).run()
            queue3.close().run()
            with sess.graph.control_dependencies([wf]):
                key, value = sess.run(reader.read(queue3))
            self.assertEqual(key, compat.as_bytes(ipfs_path))
            self.assertEqual(value, compat.as_bytes(contents))

            with gfile.Open(meta_filename, "wb") as f:
                f.write(compat.as_bytes('{"command": "publish"}'))

            ipns_path = d["IpnsPath"]
            queue4 = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
            queue4.enqueue_many([[ipns_path]]).run()
            queue4.close().run()
            with sess.graph.control_dependencies([wf]):
                key, value = sess.run(reader.read(queue4))
            self.assertEqual(key, compat.as_bytes(ipns_path))
            self.assertEqual(value, compat.as_bytes(contents))
Esempio n. 42
0
def canonicalize_signatures(signatures):
    """Converts `signatures` into a dictionary of concrete functions."""
    if signatures is None:
        return {}
    if not isinstance(signatures, collections.abc.Mapping):
        signatures = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures
        }
    concrete_signatures = {}
    for signature_key, function in signatures.items():
        signature_function = _get_signature(function)
        if signature_function is None:
            raise ValueError((
                "Expected a TensorFlow function to generate a signature for, but "
                "got {}. Only `tf.functions` with an input signature or "
                "concrete functions can be used as a signature."
            ).format(function))

        # Re-wrap the function so that it only takes keyword arguments and it
        # returns a dictionary of Tensors. This matches the format of 1.x-style
        # signatures.
        # pylint: disable=cell-var-from-loop
        @def_function.function
        def signature_wrapper(**kwargs):
            structured_outputs = signature_function(**kwargs)
            return _normalize_outputs(structured_outputs,
                                      signature_function.name, signature_key)

        # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
        # always match keyword arguments.
        tensor_spec_signature = {}
        for keyword, tensor in zip(
                signature_function._arg_keywords,  # pylint: disable=protected-access
                signature_function.inputs):
            keyword = compat.as_str(keyword)
            tensor_spec_signature[
                keyword] = tensor_spec.TensorSpec.from_tensor(tensor,
                                                              name=keyword)
        concrete_signatures[signature_key] = (
            signature_wrapper.get_concrete_function(**tensor_spec_signature))
        # pylint: enable=cell-var-from-loop
    return concrete_signatures
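This is essentially the normalization applied to whatever callers hand to `tf.saved_model.save(..., signatures=...)`. A minimal sketch of the calling convention it supports, assuming TensorFlow 2.x (the class, path and output names are illustrative):

import tensorflow as tf

class Adder(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  def add_one(self, x):
    return {'result': x + 1.0}

module = Adder()
# A bare tf.function is wrapped into {'serving_default': ...} by
# canonicalize_signatures before the SavedModel is written.
tf.saved_model.save(module, '/tmp/adder', signatures=module.add_one)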
Esempio n. 43
0
def imperative_grad(
    tape,
    target,
    sources,
    output_gradients=None,
    unconnected_gradients=gradients_impl.UnconnectedGradients.NONE):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages are of each
  tensor and entry, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
   tape: the gradient tape which stores the trace.
   target: either a Tensor or list of Tensors to be differentiated.
   sources: list of Tensors for which we want gradients.
   output_gradients: if not None, a list of gradients provided for each target,
    or None if we are to use the target's computed downstream gradient.
   unconnected_gradients: determines the value returned if the target and
    sources are unconnected. When 'none' the value returned is None, whereas when
    'zero' a zero tensor in the same shape as the sources is returned.

  Returns:
   the gradient wrt each of the sources.

  Raises:
    ValueError: if the arguments are invalid.
    RuntimeError: if something goes wrong.
  """
  try:
    unconnected_gradients = gradients_impl.UnconnectedGradients(
        unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  return pywrap_tensorflow.TFE_Py_TapeGradient(
      tape._tape,  # pylint: disable=protected-access
      target,
      sources,
      output_gradients,
      compat.as_str(unconnected_gradients.value))
Esempio n. 44
0
def get_matching_files(filename):
    """Returns a list of files that match the given pattern.

  Args:
    filename: string, the pattern

  Returns:
    A list of strings containing filenames that match the given pattern.

  Raises:
    errors.OpError: If there are filesystem / directory listing errors.
  """
    with errors.raise_exception_on_not_ok_status() as status:
        # Convert each element to string, since the return values of the
        # vector of string should be interpreted as strings, not bytes.
        return [
            compat.as_str(matching_filename)
            for matching_filename in pywrap_tensorflow.GetMatchingFiles(
                compat.as_bytes(filename), status)
        ]
Esempio n. 45
0
    def _recreate_base_user_object(self, proto):
        revived_classes = {
            '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
            '_tf_keras_network': (RevivedNetwork, network_lib.Network),
            '_tf_keras_model': (RevivedModel, training_lib.Model),
            '_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
        }

        parent_classes = revived_classes.get(proto.identifier, None)

        if parent_classes is not None:
            parent_classes = revived_classes[proto.identifier]
            metadata = json.loads(proto.metadata)
            revived_cls = type(compat.as_str(metadata['class_name']),
                               parent_classes,
                               {'__setattr__': parent_classes[1].__setattr__})
            obj = revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
            return obj, revived_cls._revive_setter  # pylint: disable=protected-access

        return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
Esempio n. 46
0
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Updates the dtypes attribute of the save / restore op and the associated name
  and shape tensors to remove entries for variables that have been removed.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """
  name = op.name + '/tensor_names'
  shape = op.name + '/shape_and_slices'
  name_op = _find_op(graph_def, name)
  shape_op = _find_op(graph_def, shape)
  name_op_value_tensor = name_op.attr['value'].tensor
  shape_op_value_tensor = shape_op.attr['value'].tensor
  names = []
  shapes = []
  dtypes = []
  for index, value in enumerate(name_op_value_tensor.string_val):
    if not _is_removed(compat.as_str(value), removed_op_names):
      names.append(value)
      shapes.append(shape_op_value_tensor.string_val[index])
      dtypes.append(op.attr['dtypes'].list.type[index])
  name_op_value_tensor.string_val[:] = names
  name_op_value_tensor.tensor_shape.dim[0].size = len(names)
  shape_op_value_tensor.string_val[:] = shapes
  shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes)
  op.attr['dtypes'].list.type[:] = dtypes

  if not name_op.attr['_output_shapes'].list.shape:
    name_op.attr['_output_shapes'].list.shape.add()
    name_op.attr['_output_shapes'].list.shape[0].dim.add()
  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)

  if not shape_op.attr['_output_shapes'].list.shape:
    shape_op.attr['_output_shapes'].list.shape.add()
    shape_op.attr['_output_shapes'].list.shape[0].dim.add()
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)
Esempio n. 47
0
def revive_custom_object(identifier, metadata):
    """Revives object from SavedModel."""
    if ops.executing_eagerly_outside_functions():
        model_class = training_lib.Model
    else:
        model_class = training_lib_v1.Model

    revived_classes = {
        '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
        '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
        '_tf_keras_network': (RevivedNetwork, network_lib.Network),
        '_tf_keras_model': (RevivedNetwork, model_class),
        '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
    }

    parent_classes = revived_classes.get(identifier, None)

    if parent_classes is not None:
        parent_classes = revived_classes[identifier]
        revived_cls = type(compat.as_str(metadata['class_name']),
                           parent_classes, {})
        return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
Esempio n. 48
0
  def _revive_graph_network(self, identifier, metadata, node_id):
    """Revives a graph network from config."""
    # Determine whether the metadata contains information for reviving a
    # functional or Sequential model.
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = compat.as_str(metadata['class_name'])
    if generic_utils.get_registered_object(class_name) is not None:
      return None
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        class_name == 'Sequential' or
        class_name == 'Functional')
    if not model_is_functional_or_sequential:
      return None

    # Revive functional and sequential models as blank model objects for now (
    # must be initialized to enable setattr tracking and attribute caching).
    # Reconstruction of the network is deferred until all of the model's layers
    # have been revived.
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    # The model is a custom Sequential model.
    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
      # Uses the custom class name, since the config does not have one.
      model = models_lib.Sequential(name=class_name)
    else:
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
    layers = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, layers)
    if not layers:
      self._models_to_reconstruct.append(node_id)
    return model
Esempio n. 49
0
    def _revive_graph_network(self, metadata, node_id):
        """Revives a graph network from config."""
        class_name = compat.as_str(metadata['class_name'])
        config = metadata.get('config')

        # Determine whether the metadata contains information for reviving a
        # functional or Sequential model.
        model_is_functional_or_sequential = (
            metadata.get('is_graph_network', False)
            or metadata['class_name'] == 'Sequential'
            or metadata['class_name'] == 'Functional')
        if not (
                generic_utils.validate_config(config)
                and model_is_functional_or_sequential
        ) or generic_utils.get_registered_object(class_name) is not None:
            # Model should not be revived as a graph network. Try reviving directly
            # from config or as a custom model.
            return None

        # Revive functional and sequential models as blank model objects for now (
        # must be initialized to enable setattr tracking and attribute caching).
        # Reconstruction of the network is deferred until all of the model's layers
        # have been revived.
        if class_name == 'Sequential':
            model = models_lib.Sequential(name=config['name'])
        else:
            model = models_lib.Functional(inputs=[],
                                          outputs=[],
                                          name=config['name'])

        # Record this model and its layers. This will later be used to reconstruct
        # the model.
        layers = self._get_child_layer_node_ids(node_id, model.name)
        self.model_layer_dependencies[node_id] = (model, layers)
        if not layers:
            self._models_to_reconstruct.append(node_id)
        return model
Esempio n. 50
0
def get_timestamped_export_dir(export_dir_base):
    """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path of the new subdirectory (which is not actually created yet).

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
    attempts = 0
    while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
        timestamp = int(time.time())

        result_dir = os.path.join(compat.as_bytes(export_dir_base),
                                  compat.as_bytes(str(timestamp)))
        if not gfile.Exists(result_dir):
            # Collisions are still possible (though extremely unlikely): this
            # directory is not actually created yet, but it will be created
            # almost instantly on return from this function.
            return result_dir
        time.sleep(1)
        attempts += 1
        logging.warn(
            'Directory {} already exists; retrying (attempt {}/{})'.format(
                compat.as_str(result_dir), attempts,
                MAX_DIRECTORY_CREATION_ATTEMPTS))
    raise RuntimeError('Failed to obtain a unique export directory name after '
                       '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
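Taken together with `get_temp_export_dir` above, the usual export flow writes into the temporary sibling directory and only then renames it, so Serving never sees a half-written export. A sketch, assuming `gfile` from tensorflow.python.platform is in scope:

export_dir = get_timestamped_export_dir('/foo/bar')  # e.g. b'/foo/bar/1541024642'
temp_dir = get_temp_export_dir(export_dir)           # b'/foo/bar/temp-1541024642'
# ... write the SavedModel contents into temp_dir ...
gfile.Rename(temp_dir, export_dir)  # the integer-named directory appears only once complete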
Esempio n. 51
0
  def _recreate_base_user_object(self, proto):
    if ops.executing_eagerly_outside_functions():
      model_class = training_lib.Model
    else:
      model_class = training_lib_v1.Model

    revived_classes = {
        '_tf_keras_layer': (RevivedLayer, base_layer.Layer),
        '_tf_keras_input_layer': (RevivedInputLayer, input_layer.InputLayer),
        '_tf_keras_network': (RevivedNetwork, network_lib.Network),
        '_tf_keras_model': (RevivedNetwork, model_class),
        '_tf_keras_sequential': (RevivedNetwork, models_lib.Sequential)
    }

    parent_classes = revived_classes.get(proto.identifier, None)

    if parent_classes is not None:
      parent_classes = revived_classes[proto.identifier]
      metadata = json.loads(proto.metadata)
      revived_cls = type(
          compat.as_str(metadata['class_name']), parent_classes, {})
      return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access

    return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
Esempio n. 52
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer, it
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e., whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value
                                     if not input_map or v not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
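A round-trip sketch, assuming `export_scoped_meta_graph` from the same module and the `ops`/`variables` aliases used elsewhere in these examples (names and paths are illustrative):

with ops.Graph().as_default():
  v = variables.Variable(1.0, name='v')
  # Writes /tmp/scoped.meta and returns (MetaGraphDef, var_list).
  export_scoped_meta_graph(filename='/tmp/scoped.meta')

with ops.Graph().as_default():
  var_list = import_scoped_meta_graph('/tmp/scoped.meta', import_scope='restored')
  # var_list maps the original (scope-stripped) names to the imported Variables,
  # e.g. {'v:0': <tf.Variable 'restored/v:0' ...>}.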
Esempio n. 53
0
def meta_graph_transform(base_meta_graph_def,
                         input_names,
                         output_names,
                         transforms,
                         tags,
                         checkpoint_path=None):
    """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' transform.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
    meta_graph_def = _meta_graph_pb2.MetaGraphDef()

    initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)

    transformed_graph_def, updated_initializer_names = _do_transforms(
        base_meta_graph_def.graph_def, input_names, output_names,
        initializer_names, transforms, base_meta_graph_def.saver_def,
        checkpoint_path)

    meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
    meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
    meta_graph_def.meta_info_def.ClearField('tags')
    for tag in tags:
        meta_graph_def.meta_info_def.tags.append(tag)

    base_op_names = [
        compat.as_str(node.name) for node in base_meta_graph_def.graph_def.node
    ]
    retained_op_names = [
        compat.as_str(node.name) for node in meta_graph_def.graph_def.node
    ]
    removed_op_names = set(base_op_names) - set(retained_op_names)

    # Copy saver, excluding any pruned nodes if graph was not frozen.
    # TODO(b/63447631): Revisit this once the problem is addressed. Currently
    # _add_pruned_saver assumes that the save and restore nodes have not been
    # removed but freeze_graph (correctly) removes them.
    if _FREEZE_GRAPH_TRANSFORM not in transforms:
        _add_pruned_saver(base_meta_graph_def, meta_graph_def,
                          removed_op_names)

    # Copy collections, excluding any pruned nodes
    for collection_name in base_meta_graph_def.collection_def:
        _add_pruned_collection(base_meta_graph_def, meta_graph_def,
                               collection_name, removed_op_names)

    # Append newly added initializers to the collection.
    _add_new_inits_to_collection(meta_graph_def, updated_initializer_names)

    # Copy signature_defs, excluding any pruned nodes
    for signature_name in base_meta_graph_def.signature_def:
        _add_pruned_signature(base_meta_graph_def, meta_graph_def,
                              signature_name, removed_op_names)

    return meta_graph_def
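An illustrative call, using transform names from the Graph Transform Tool plus the extra 'freeze_graph' transform mentioned above (node names and paths are hypothetical):

transformed_mgd = meta_graph_transform(
    base_meta_graph_def,
    input_names=['inputs/x'],
    output_names=['logits'],
    transforms=['strip_unused_nodes',
                'fold_constants(ignore_errors=true)',
                'freeze_graph'],
    tags=['serve'],
    checkpoint_path='/tmp/model.ckpt-10000')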
Esempio n. 54
0
  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: Either a string, or a list of strings corresponding to the TPUs to
        use. If the single string is the empty string, the string 'local', or a
        string that begins with 'grpc://' or '/bns', then it is assumed to not
        correspond with a Cloud TPU and will instead be passed as the session
        master and no ClusterSpec propagation will be done.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of
        the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    in_gke = self._inGke()
    # When using GKE with Cloud TPUs, the env variable will be set.
    if tpu is None:
      if in_gke:
        tpu = self._gkeEndpoints()
      else:
        tpu = self._envVarFallback()

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
    self._job_name = job_name
    self._credentials = credentials

    should_resolve = self._shouldResolve()

    if not project and should_resolve:
      project = compat.as_str(
          self._requestComputeMetadata('project/project-id'))

    if not zone and should_resolve:
      zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
      zone = zone_path.split('/')[-1]

    self._project = project
    self._zone = zone

    if credentials == 'default' and should_resolve:
      if _GOOGLE_API_CLIENT_INSTALLED:
        self._credentials = GoogleCredentials.get_application_default()

    if service is None and should_resolve:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient and oauth2client must be installed '
                          'before using the TPU cluster resolver. Execute: '
                          '`pip install --upgrade google-api-python-client` '
                          'and `pip install --upgrade oauth2client` to '
                          'install with pip.')

      final_discovery_url = self._discoveryUrl() or discovery_url
      if final_discovery_url:
        self._service = discovery.build(
            'tpu', 'v1alpha1',
            credentials=self._credentials,
            discoveryServiceUrl=final_discovery_url)
      else:
        self._service = discovery.build(
            'tpu', 'v1alpha1',
            credentials=self._credentials)
    else:
      self._service = service

    self._coordinator_name = coordinator_name
    if coordinator_name and not coordinator_address and (should_resolve or
                                                         in_gke):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address
Esempio n. 55
0
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

  # LINT.IfChange
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # TODO(ashankar): Should this just copy over or should it do some
    # more nuanced merging? For example, the graph may already have some
    # marked "bad versions" and we don't want to lose those because of
    # what's in graph_def.versions? The C++ ImportGraphDef does something
    # more nuanced.
    g.graph_def_versions.CopyFrom(graph_def.versions)

    if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
      if not scope:
        # The caller must have passed `name=''`.
        raise ValueError(
            'tf.import_graph_def() requires a non-empty `name` if `input_map` '
            'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
            '`input_map` values before calling tf.import_graph_def().')
      with ops.name_scope('_inputs'):
        input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      if node.op not in op_dict:
        raise ValueError('No op named %s in defined operations.' % node.op)
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)
      if producer_op_dict:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
          producer_op_def = producer_op_dict[node.op]
          # We make a copy of node.attr to iterate through since we
          # may modify node.attr inside the loop.
          for key in list(node.attr):
            if _FindAttrInOpDef(key, op_def) is None:
              # No attr_def in consumer, look in producer.
              attr_def = _FindAttrInOpDef(key, producer_op_def)
              if (attr_def and attr_def.HasField('default_value') and
                  node.attr[key] == attr_def.default_value):
                # Unknown attr had default value in producer, delete it
                # so it can be understood by consumer.
                del node.attr[key]

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # Rewrite the colocation attributes in the graph, since the
      # names of new ops may have changed.
      for key, value in op.node_def.attr.items():
        if key == '_class':
          class_values = value.list
          new_class_values = []
          for class_value in class_values.s:
            if class_value.startswith(b'loc:@'):
              op_to_bind_to = class_value[5:].decode()
              # Find the op by its original name.
              if op_to_bind_to not in name_to_op:
                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' % (
                                     op_to_bind_to, node.name))
              original_op = name_to_op[op_to_bind_to]
              new_class_values.append(compat.as_bytes(
                  'loc:@' + original_op.name))
            else:
              new_class_values.append(class_value)
          value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
              s=new_class_values))

      # NOTE(mrry): We cannot use zip here because control inputs do not appear
      # in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.' % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, te)))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                   ', '.join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      if not g._is_function(op.type):  # pylint: disable=protected-access
        # Execute shape inference for this op.
        # NOTE(mrry): If the graph contains a cycle, the full shape information
        # may not be available for this op's inputs.
        ops.set_shapes_for_outputs(op)
      # For nodes with _output_shapes set, set the output shapes.
      if '_output_shapes' in op.node_def.attr:
        for i, output in enumerate(op.outputs):
          dims = op.node_def.attr['_output_shapes'].list.shape[i]
          output_shape = tensor_shape.TensorShape(
              None if dims.unknown_rank else
              [dim.size if dim.size >= 0 else None for dim in dims.dim])

          try:
            output.set_shape(output_shape)
          except ValueError as e:
            # If the output shape is incompatible with what is inferred
            # by the graph for a very specific whitelist of ops, then we
            # ignore this output shape.  This can happen if there is a
            # bug in the shape function for some operation, and the
            # serialized graph def has the incorrect shape set when
            # running on a newer binary with the fixed shape function.
            # This is an escape hatch that allows us to correct shape
            # functions that are not critical to correct execution but
            # would cause graphs to fail if imported after correcting.
            #
            # This can be removed after 2017/03/08.
            if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                           'FIFOQueue', 'PriorityQueue', 'QueueSize',
                           'Stack', 'Barrier', 'BarrierReadySize',
                           'BarrierIncompleteSize', 'HashTable',
                           'MutableHashTable',
                           'MutableHashTableOfTensors', 'Mutex',
                           'CuckooTable', 'IndexTable',
                           'WholeFileReader', 'TextLineReader',
                           'FixedLengthRecordReader',
                           'TFRecordReader', 'IdentityReader',
                           'RefSwitch', 'RefEnter', 'RefNextIteration',
                           'RefMerge', 'RefIdentity']:
              pass
            elif op.type in [
                'ConditionalAccumulator', 'SparseConditionalAccumulator',
                'Table'
            ]:
              # This can be removed after 2017/04/24.
              pass
            else:
              raise e

        del op.node_def.attr['_output_shapes']

      # Apply device functions for this op.
      # NOTE(mrry): We do this after configuring the inputs, because
      # the result of the device functions may depend on the inputs.
      with _MaybeDevice(node.device):
        g._apply_device_functions(op)  # pylint: disable=protected-access

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(unused_input_keys))

    if return_elements is None:
      return None
    else:
      ret = []
      for name in return_elements:
        name = compat.as_str(name)
        if ':' in name:
          try:
            operation_name, output_index = _ParseTensorName(name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
        else:
          try:
            ret.append(name_to_op[name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.' % name)
      return ret
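A usage sketch for this pre-C-API code path, reusing the `ops`/`constant_op` aliases seen elsewhere in these examples; the tensor names 'input:0' and 'logits:0' are hypothetical names from the serialized graph:

with ops.Graph().as_default():
  replacement = constant_op.constant([[1.0, 2.0]], name='replacement')
  # Remap one of the graph's original inputs and fetch a tensor back by name.
  logits, = import_graph_def(
      graph_def,
      input_map={'input:0': replacement},
      return_elements=['logits:0'],
      name='imported')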
Esempio n. 56
0
def _CanonicalInputName(input_name):
  input_name = compat.as_str(input_name)
  if _IsControlInput(input_name):
    return input_name
  input_op_name, output_index = _ParseTensorName(input_name)
  return '%s:%d' % (input_op_name, output_index)
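A few hypothetical inputs, assuming `_ParseTensorName` defaults the output index to 0 when no ':<index>' suffix is present:

# _CanonicalInputName('foo')    -> 'foo:0'
# _CanonicalInputName('foo:2')  -> 'foo:2'
# _CanonicalInputName('^bar')   -> '^bar'   (control inputs pass through unchanged)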
Esempio n. 57
0
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
    # they are likely to be due to a typo.
    missing_unused_input_keys = (
        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
            results))
    if missing_unused_input_keys:
      missing_unused_input_keys = [compat.as_str(s)
                                   for s in missing_unused_input_keys]
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(missing_unused_input_keys))

    if return_elements is None:
      return None
    else:
      return _GatherReturnElements(return_elements, graph, results)

  else:
    g = graph

    # Use a canonical representation for all tensor names.
    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
    used_input_keys = set()
    name_to_op = {}

    # Add any functions defined in `graph_def` to `g`
    if graph_def.library and graph_def.library.function:
      # Copy op_dict so we don't clobber the original
      op_dict = copy.copy(op_dict)
      # pylint: disable=protected-access
      # Note that we do not prepend `name` to the function name. The reasoning
      # is that function names are similar to op definition names, which
      # currently do not have a scoped name or namespace scheme.
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(g)
        op_dict[f.name] = f.definition.signature
      # pylint: enable=protected-access

    # LINT.IfChange
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # TODO(ashankar): Should this just copy over or should it do some
      # more nuanced merging? For example, the graph may already have some
      # marked "bad versions" and we don't want to lose those because of
      # what's in graph_def.versions? The C++ ImportGraphDef does something
      # more nuanced.
      g.graph_def_versions.CopyFrom(graph_def.versions)

      input_map = _ConvertInputMapValues(name, input_map)

      # NOTE(mrry): We do this in two passes, because there may be a cycle in
      # `graph_def`.

      # 1. Add operations without their inputs.
      for node in graph_def.node:
        # Check to see if this op's name matches a previously seen op
        if node.name in name_to_op:
          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
        # Set any default attr values that aren't present.
        if node.op not in op_dict:
          raise ValueError('No op named %s in defined operations.' % node.op)
        op_def = op_dict[node.op]
        for attr_def in op_def.attr:
          key = attr_def.name
          if attr_def.HasField('default_value'):
            value = node.attr[key]
            if value is None or value.WhichOneof('value') is None:
              node.attr[key].CopyFrom(attr_def.default_value)

        output_types = _OutputTypes(node, op_dict)
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False, compute_device=False,
            op_def=op_def)

      # Maps from a node to the ops it is colocated with, if colocation
      # is specified in the attributes.
      colocation_pairs = collections.defaultdict(list)

      # 2. Add inputs to the operations.
      for node in graph_def.node:
        op = name_to_op[node.name]
        input_types = _InputTypes(node, op_dict)
        apply_device_function = True

        # Rewrite the colocation attributes in the graph, since the
        # names of new ops may have changed.
        for key, value in op.node_def.attr.items():
          if key == '_class':
            class_values = value.list
            new_class_values = []
            for class_value in class_values.s:
              if class_value.startswith(b'loc:@'):
                op_to_bind_to = class_value[5:].decode()
                # Find the op by its original name.
                if op_to_bind_to not in name_to_op:
                  raise ValueError('Specified colocation to an op that '
                                   'does not exist during import: %s in %s' % (
                                       op_to_bind_to, node.name))
                original_op = name_to_op[op_to_bind_to]
                new_class_values.append(compat.as_bytes(
                    'loc:@' + original_op.name))
                if op_to_bind_to != node.name:
                  # Keep track of this mapping for a later phase.
                  colocation_pairs[op].append(original_op)
                  # Don't apply this op's device function,
                  # the colocation constraint will ensure
                  # the proper device gets assigned at runtime.
                  apply_device_function = False

              else:
                new_class_values.append(class_value)
            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
                s=new_class_values))

        # NOTE(mrry): We cannot use zip here because control inputs do not
        # appear in the list of input_types.
        for i, input_name in enumerate(
            [_CanonicalInputName(x) for x in node.input]):

          if _IsControlInput(input_name):
            # (a) Input is a control input that should be taken from an op
            #     in "graph_def".
            try:
              source_op = name_to_op[input_name[1:]]
            except KeyError:
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Control input %r not found in graph_def.'
                      % (input_name,)))
            # pylint: disable=protected-access
            op._add_control_input(source_op)
            # pylint: enable=protected-access

          else:
            try:
              input_type = input_types[i]
            except IndexError:
              raise ValueError(_InvalidNodeMessage(
                  node, 'More inputs specified (%r) than the op expects.'
                  % (input_name,)))

            if input_name in input_map:
              # (b) Input should be replaced by a tensor from the caller.
              source_tensor = input_map[input_name]
              used_input_keys.add(input_name)

            else:
              # (c) Input should be taken from an op in `graph_def`.
              operation_name, output_index = _ParseTensorName(input_name)
              try:
                source_op = name_to_op[operation_name]
                source_tensor = list(source_op.values())[output_index]
              except (KeyError, IndexError):
                raise ValueError(
                    _InvalidNodeMessage(
                        node,
                        'Input tensor %r not found in graph_def.'
                        % (input_name,)))

            try:
              # pylint: disable=protected-access
              op._add_input(source_tensor, dtype=input_type)
              # pylint: enable=protected-access
            except TypeError as te:
              raise ValueError(_InvalidNodeMessage(
                  node, 'Input tensor %r %s' % (input_name, te)))

        # pylint: disable=protected-access
        if op._input_types != input_types:
          raise ValueError(
              _InvalidNodeMessage(
                  node,
                  'Input types mismatch (expected %r but got %r)'
                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
                     ', '.join(x.name for x in op._input_types))))
        # pylint: enable=protected-access

        if not g._is_function(op.type):  # pylint: disable=protected-access
          # Execute shape inference for this op.
          # NOTE(mrry): If the graph contains a cycle, the full shape
          # information may not be available for this op's inputs.
          ops.set_shapes_for_outputs(op)
        # For nodes with _output_shapes set, set the output shapes.
        if '_output_shapes' in op.node_def.attr:
          for i, output in enumerate(op.outputs):
            dims = op.node_def.attr['_output_shapes'].list.shape[i]
            output_shape = tensor_shape.TensorShape(
                None if dims.unknown_rank else
                [dim.size if dim.size >= 0 else None for dim in dims.dim])

            try:
              output.set_shape(output_shape)
            except ValueError as e:
              # If the output shape is incompatible with what is inferred
              # by the graph for a very specific whitelist of ops, then we
              # ignore this output shape.  This can happen if there is a
              # bug in the shape function for some operation, and the
              # serialized graph def has the incorrect shape set when
              # running on a newer binary with the fixed shape function.
              # This is an escape hatch that allows us to correct shape
              # functions that are not critical to correct execution but
              # would cause graphs to fail if imported after correcting.
              #
              # This can be removed after 2017/03/08.
              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
                             'Stack', 'Barrier', 'BarrierReadySize',
                             'BarrierIncompleteSize', 'HashTable',
                             'MutableHashTable',
                             'MutableHashTableOfTensors', 'Mutex',
                             'CuckooTable', 'IndexTable',
                             'WholeFileReader', 'TextLineReader',
                             'FixedLengthRecordReader',
                             'TFRecordReader', 'IdentityReader',
                             'LMDBReader',
                             'RefSwitch', 'RefEnter', 'RefNextIteration',
                             'RefMerge', 'RefIdentity']:
                pass
              elif op.type in [
                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
                  'Table'
              ]:
                # This can be removed after 2017/04/24.
                pass
              else:
                raise e

          del op.node_def.attr['_output_shapes']

        # NOTE(mrry): We do this after configuring the inputs, because
        # the result of the device functions may depend on the inputs.
        if apply_device_function:
          with _MaybeDevice(node.device):
            g._apply_device_functions(op)  # pylint: disable=protected-access

      # The following loop populates the device field of ops that are
      # colocated with another op.  This is implied by the colocation
      # attribute, but we propagate the device field for completeness.
      for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find the first device set on any op in the list of colocated ops,
        # if one exists.  We assume that if multiple ops have devices set,
        # they refer to the same device.  Otherwise, a runtime error will
        # occur since the colocation property cannot be guaranteed.
        #
        # One possible improvement is to try to check for compatibility
        # of all devices in this list at import time here, which would
        # require implementing a compatibility function for device specs
        # in python.
        for coloc_op in coloc_op_list:
          if coloc_op.device:
            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
            break
        if coloc_device:
          op._set_device(coloc_device)  # pylint: disable=protected-access

      # Treat input mappings that don't appear in the graph as an error,
      # because they are likely to be due to a typo.
      def _IsImportedNodeOutput(tensor_name):
        operation_name, output_index = _ParseTensorName(tensor_name)
        try:
          return output_index < len(name_to_op[operation_name].outputs)
        except KeyError:
          return False
      absent_input_keys = [
          k for k in frozenset(input_map.keys()).difference(used_input_keys)
          if not _IsImportedNodeOutput(k)]
      if absent_input_keys:
        raise ValueError(
            'Attempted to map inputs that were not found in graph_def: [%s]'
            % ', '.join(absent_input_keys))

      if return_elements is None:
        return None
      else:
        ret = []
        for name in return_elements:
          name = compat.as_str(name)
          if ':' in name:
            try:
              operation_name, output_index = _ParseTensorName(name)
              ret.append(name_to_op[operation_name].outputs[output_index])
            except (ValueError, KeyError, IndexError):
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
          else:
            try:
              ret.append(name_to_op[name])
            except KeyError:
              raise ValueError(
                  'Requested return_element %r not found in graph_def.' % name)
        return ret
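A minimal usage sketch of the public TF 1.x tf.import_graph_def API that the importer logic above serves (the graph, tensor names and shapes below are made up for illustration; in practice the GraphDef would come from a serialized model):

import tensorflow as tf

# Build a small GraphDef to import; in practice this would be parsed from a
# serialized (e.g. frozen) model.
with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
  tf.identity(2.0 * x, name='output')
  graph_def = tf.get_default_graph().as_graph_def()

with tf.Graph().as_default():
  new_input = tf.placeholder(tf.float32, shape=[None, 784], name='new_input')
  # input_map feeds our placeholder in place of the imported 'input:0';
  # return_elements retrieves the corresponding (renamed) output tensor.
  output, = tf.import_graph_def(
      graph_def,
      input_map={'input:0': new_input},
      return_elements=['output:0'],
      name='imported')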
Example n. 58
0
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    if self._definition is not None or self._c_func is not None:
      return

    temp_graph = func_graph_from_py_func(
        self._func, self._arg_names, self._arg_types, self._func_name,
        self._capture_by_value, self._caller_device)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = _get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op.op_def.is_stateful]
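A rough sketch of how this definition machinery is typically exercised, assuming the Defun decorator from tensorflow.python.framework.function (the module this method lives in); the function body and constants are made up:

import tensorflow as tf
from tensorflow.python.framework import function

# The first use of the decorated function triggers the definition path shown
# above (graph_to_function_def or the C API, depending on the build).
@function.Defun(tf.float32, tf.float32)
def scaled_add(x, y):
  """Computes 2*x + y as a library function."""
  return 2.0 * x + y

with tf.Graph().as_default():
  c = scaled_add(tf.constant(1.0), tf.constant(3.0))
  with tf.Session() as sess:
    print(sess.run(c))  # 5.0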
Example n. 59
0
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  meta_graph_def = _meta_graph_pb2.MetaGraphDef()

  initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)

  transformed_graph_def, updated_initializer_names = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      initializer_names, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)

  meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
  meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  meta_graph_def.meta_info_def.ClearField('tags')
  for tag in tags:
    meta_graph_def.meta_info_def.tags.append(tag)

  base_op_names = [compat.as_str(node.name)
                   for node in base_meta_graph_def.graph_def.node]
  retained_op_names = [compat.as_str(node.name)
                       for node in meta_graph_def.graph_def.node]
  removed_op_names = set(base_op_names) - set(retained_op_names)

  # Copy saver, excluding any pruned nodes if graph was not frozen.
  # TODO(b/63447631): Revisit this once the problem is addressed. Currently
  # _add_pruned_saver assumes that the save and restore nodes have not been
  # removed but freeze_graph (correctly) removes them.
  if _FREEZE_GRAPH_TRANSFORM not in transforms:
    _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)

  # Copy collections, excluding any pruned nodes
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name,
        removed_op_names)

  # Append newly added initializers to collection.
  _add_new_inits_to_collection(meta_graph_def, updated_initializer_names)

  # Copy signature_defs, excluding any pruned nodes
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, meta_graph_def, signature_name,
        removed_op_names)

  return meta_graph_def
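A hedged usage sketch of the function above (presumably also reachable via tf.contrib.meta_graph_transform); the toy graph, node names and transform list are made up for illustration:

import tensorflow as tf

# Build a tiny graph and export its MetaGraphDef just for this sketch.
with tf.Graph().as_default():
  images = tf.placeholder(tf.float32, [None, 784], name='images')
  tf.identity(tf.nn.softmax(images), name='prediction')
  base_meta_graph_def = tf.train.export_meta_graph()

# Transform names are those of the Graph Transform Tool; no checkpoint_path
# is needed here because 'freeze_graph' is not in the transform list.
transformed_mgd = meta_graph_transform(
    base_meta_graph_def,
    input_names=['images'],
    output_names=['prediction'],
    transforms=['strip_unused_nodes', 'fold_constants(ignore_errors=true)'],
    tags=['serve'])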
Example n. 60
0
  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. If the string is the empty
        string, the string 'local', or a string that begins with 'grpc://' or
        '/bns', then it is assumed to not correspond with a Cloud TPU and will
        instead be passed as the session master and no ClusterSpec propagation
        will be done. In the future, this may also support a list of strings
        when multiple Cloud TPUs are used.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None and coordinator_name is provided, a TF server
        will be started locally. If coordinator_name is None, a TF server will
        not be started even if coordinator_address is None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client library.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of
        the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    in_gke = self._inGke()
    # When using GKE with Cloud TPUs, the env variable will be set.
    if tpu is None:
      if in_gke:
        tpu = self._gkeEndpoints()
      else:
        tpu = self._envVarFallback()

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes

    # If we are running in Cloud and don't specify a TPU name
    if self._isRunningInGCE() and not self._tpu:
      raise RuntimeError('You need to specify a TPU Name if you are running in '
                         'the Google Cloud environment.')

    # By default the task_type is 'worker' and the task_id is 0 (i.e. the
    # first worker in the job).
    self.task_type = job_name
    self.task_id = 0

    if tpu.startswith('grpc://'):
      # Cloud environment, where we are using GRPC to communicate to TPUs.
      self._environment = ''
    elif tpu == 'local' or not tpu:
      # Google environment, where the TPU is attached to the host.
      self._environment = 'google'
    elif tpu.startswith('/bns') or tpu.startswith('uptc://'):
      # Google environment, where we reach the TPU through BNS.
      self._environment = 'google'

    # If TPU is in the Google environment or exists locally, we don't use any
    # RPC layer.
    if tpu.startswith('/bns') or tpu.startswith(
        'uptc://') or tpu == 'local' or not tpu:
      self.rpc_layer = None
    else:
      self.rpc_layer = 'grpc'

    # Setting this overrides the return value of self._shouldResolve()
    self._should_resolve_override = None

    # We strip out the protocol if it is included, and override the
    # shouldResolve function to never resolve. We are adding the protocol back
    # in later in self.master().
    if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
      tpu = tpu[len(self.rpc_layer + '://'):]
      self._tpu = compat.as_bytes(tpu)  # keep the bytes invariant noted above
      self._should_resolve_override = False

    # Whether we should actually attempt to contact Cloud APIs
    should_resolve = self._shouldResolve()

    # Error out if we need to resolve through the Cloud APIs but the standard
    # client library is not installed and no custom service object was passed.
    self._service = service
    if (self._service is None and should_resolve and
        not _GOOGLE_API_CLIENT_INSTALLED):
      raise ImportError('googleapiclient and oauth2client must be installed '
                        'before using the TPU cluster resolver. Execute: '
                        '`pip install --upgrade google-api-python-client` '
                        'and `pip install --upgrade oauth2client` to '
                        'install with pip.')

    # Save the user-passed credentials. If the 'default' sentinel was left in
    # place and we will contact the Cloud APIs, clear it so the client library
    # falls back to application default credentials.
    self._credentials = credentials
    if (credentials == 'default' and should_resolve and
        _GOOGLE_API_CLIENT_INSTALLED):
      self._credentials = None

    # Automatically detect project and zone if unspecified.
    if not project and should_resolve:
      project = compat.as_str(
          self._requestComputeMetadata('project/project-id'))
    if not zone and should_resolve:
      zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
      zone = zone_path.split('/')[-1]
    self._project = project
    self._zone = zone

    self._discovery_url = self._environmentDiscoveryUrl() or discovery_url

    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address and
        (should_resolve or in_gke)):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address
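A minimal usage sketch, assuming the tf.contrib.cluster_resolver export of this class; the TPU name, zone and project below are hypothetical, and resolving them requires access to the Cloud TPU APIs:

from tensorflow.contrib.cluster_resolver import TPUClusterResolver

resolver = TPUClusterResolver(
    tpu='my-tpu', zone='us-central1-b', project='my-gcp-project')
master = resolver.master()              # grpc:// address to pass to tf.Session
cluster_spec = resolver.cluster_spec()  # ClusterSpec with the 'worker' job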