Example #1
    def testFromLibraryCyclicGradFuncs(self):
        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        @function.Defun(dtypes.float32)
        def F2(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        # Create an invalid FunctionDefLibrary in which F1 has gradient
        # function F2 and F2 has gradient function F1.
        library = function_pb2.FunctionDefLibrary()
        library.function.extend([F1.definition, F2.definition])

        gradient1 = function_pb2.GradientDef()
        gradient1.function_name = F1.name
        gradient1.gradient_func = F2.name

        gradient2 = function_pb2.GradientDef()
        gradient2.function_name = F2.name
        gradient2.gradient_func = F1.name

        library.gradient.extend([gradient1, gradient2])

        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary contains cyclic gradient functions!"):
            function._from_library(library)
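
For intuition, a minimal sketch of the kind of check this test exercises. `_has_gradient_cycle` is a hypothetical helper written for illustration, not TensorFlow's `_from_library` implementation:

def _has_gradient_cycle(library):
  # Follow function_name -> gradient_func links; revisiting a name is a cycle.
  grad_map = {g.function_name: g.gradient_func for g in library.gradient}
  for start in grad_map:
    seen = {start}
    cur = grad_map.get(start)
    while cur is not None:
      if cur in seen:
        return True  # e.g. F1 -> F2 -> F1 in the library built above
      seen.add(cur)
      cur = grad_map.get(cur)
  return False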
Example #2
    def testFromLibraryMissingFuncDef(self):
        @function.Defun(dtypes.float32, dtypes.float32)
        def G1(x, dy):
            return x * dy

        @function.Defun(dtypes.float32)
        def F1(x):
            return math_ops.exp(x) - math_ops.exp(-x)

        gradient = function_pb2.GradientDef()
        gradient.function_name = F1.name
        gradient.gradient_func = G1.name

        # Create an invalid library whose G1 FunctionDef is missing.
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([F1.definition])

        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary missing 'G1_........' FunctionDef"):
            function._from_library(library)

        # Create an invalid library whose F1 FunctionDef is missing.
        library = function_pb2.FunctionDefLibrary()
        library.gradient.extend([gradient])
        library.function.extend([G1.definition])

        with self.assertRaisesRegex(
                ValueError,
                "FunctionDefLibrary missing 'F1_........' FunctionDef"):
            function._from_library(library)
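
Both snippets are methods of a `tf.test.TestCase` subclass and assume the test module's imports, roughly (TF 1.x-era internal paths):

from tensorflow.core.framework import function_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import math_ops

The `function_def_to_graph_def` source below is likewise an excerpt and assumes its module-level imports, roughly:

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
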
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See the comment on `FunctionDef.node_def` for how tensor
     naming in FunctionDefs differs from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        if arg_attrs[k].WhichOneof("value") == "list":
          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0])
        elif arg_attrs[k].WhichOneof("value") == "shape":
          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].shape)
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping, then flatten the tensor names.
  # See the comment on `FunctionDef.node_def` for how tensor naming in
  # FunctionDefs differs from GraphDefs.
  nested_to_flat_tensor_name = {}
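  # Example: FunctionDef-style "MatMul:product:0" flattens to GraphDef-style
  # "MatMul:0".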

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  def is_function(fname):
    """Checks for a function definition with `fname` in the current context.

    Reconstructed helper (the excerpt omits its definition); assumes `context`
    is tensorflow.python.eager.context.
    """
    if context.executing_eagerly():
      return context.context().has_function(fname)
    graph = default_graph
    while graph is not None:
      if graph._is_function(fname):  # pylint: disable=protected-access
        return True
      graph = getattr(graph, "outer_graph", None)
    return False

  for node_def in fdef.node_def:
    graph = default_graph
    while True:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None or not hasattr(graph, "outer_graph"):
        break
      graph = graph.outer_graph

    if f is not None:
      # Use a local name so the `fdef` parameter, which the enclosing loop is
      # still iterating over, is not clobbered.
      func_def = f.definition
      op_def = func_def.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(func_def)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
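      # `_get_num_args` counts how many tensors this output_arg expands to:
      # the node's `number_attr` value for repeated outputs, the length of its
      # `type_list_attr` list for heterogeneous outputs, and 1 for a plain
      # single-tensor output (a sketch of the helper follows the function).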
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
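
For reference, a sketch of the module-level `_get_num_args` helper used above, reconstructed from the corresponding TensorFlow module (treat the details as assumptions):

def _get_num_args(arg_def, node_def):
  # `types_pb2` is tensorflow.core.framework.types_pb2.
  if arg_def.number_attr:                # N repeated tensors of one type
    return node_def.attr[arg_def.number_attr].i
  elif arg_def.type_list_attr:           # heterogeneous list of tensors
    return len(node_def.attr[arg_def.type_list_attr].list.type)
  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1                             # plain single-tensor output
  else:
    raise ValueError(f"Invalid arg_def:\n\n{arg_def}")

And a minimal usage sketch. The function `f` here is hypothetical, and `ConcreteFunction.function_def` is assumed to be available (as in recent TF 2.x builds):

import tensorflow as tf

@tf.function
def f(x):
  return tf.exp(x) - tf.exp(-x)

concrete = f.get_concrete_function(tf.TensorSpec([None], tf.float32))
graph_def, name_map = function_def_to_graph_def(concrete.function_def)
print(len(graph_def.node))   # placeholder inputs + body nodes
print(name_map.get("x"))     # expected "x:0" if the traced input is named "x"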