Exemplo n.º 1
0
    def _create_definition_if_needed(self):
        """Creates the function definition if it's not created yet.

        Traces ``self._func`` inside a fresh ``_ExperimentalFuncGraph``,
        feeding it one placeholder per ``(argname, argtype)`` pair in
        ``self._args``, and converts the traced graph into
        ``self._definition`` (a FunctionDef). Also records
        ``self._extra_inputs``, ``self._sub_functions`` and
        ``self._hash_str``, and fills in ``self._func_name`` when none was
        given. Returns immediately if ``self._definition`` is already set.

        Raises:
          ValueError: If the traced function returns ``None`` or any of its
            returned values is ``None``.
        """
        if self._definition is not None:
            return

        # Create the func_def object.
        temp_graph = _ExperimentalFuncGraph(
            capture_by_value=self._capture_by_value)
        with temp_graph.as_default():
            # One placeholder per declared argument, in declaration order.
            inputs = [
                array_ops.placeholder(argtype, name=argname)
                for (argname, argtype) in self._args
            ]
            # Call func and gather the output tensors. Variable lookups made
            # by the function go through temp_graph.getvar (presumably so the
            # graph can capture outside variables -- confirm in
            # _ExperimentalFuncGraph).
            with vs.variable_scope("", custom_getter=temp_graph.getvar):
                outputs = self._func(*inputs)
            # If func only returned one value, make it a tuple.
            if not isinstance(outputs, (list, tuple)):
                outputs = (outputs,)
            # Generator (not a list) so any() stops at the first None.
            if any(out is None for out in outputs):
                raise ValueError("Function can not return None.")
            # Ensures each output is a Tensor.
            outputs = [ops.convert_to_tensor(out) for out in outputs]
        # Values the graph collected while tracing (presumably captured
        # outside tensors) are appended as extra function inputs.
        self._extra_inputs = temp_graph.extra_inputs
        inputs.extend(temp_graph.extra_args)
        self._sub_functions = temp_graph._functions

        # Build the FunctionDef.
        self._definition = function._graph_to_function_def(
            temp_graph,
            temp_graph.get_operations(),
            inputs,
            outputs,
            out_names=self._out_names)

        # Extra kwargs are treated as attrs on the function def. The Python
        # function's name is only looked up when no explicit name was given.
        sig_pre_func_name = self._func_name or function._get_func_name(
            self._func)
        kwargs_attr = function._parse_kwargs_as_attrs(sig_pre_func_name,
                                                      **self._extra_kwargs)
        for k in kwargs_attr:
            self._definition.attr[k].CopyFrom(kwargs_attr[k])

        # Hash the definition and its dependencies.
        self._hash_str = self._create_hash_str(
            self._definition.signature.input_arg,
            self._definition.signature.output_arg,
            self._definition.node_def)

        # Finally, we decide the function name to use. If not specified,
        # make up something which is almost certainly unique (but
        # deterministic) by suffixing the Python name with the hash.
        if not self._func_name:
            self._func_name = "_".join(
                [function._get_func_name(self._func), self._hash_str])
        self._definition.signature.name = self._func_name
        # Propagate the Python docstring, if any, into the signature.
        if self._func.__doc__:
            self._definition.signature.description = self._func.__doc__
Exemplo n.º 2
0
  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet.

    Traces ``self._func`` inside a fresh ``_ExperimentalFuncGraph``, feeding it
    one placeholder per ``(argname, argtype)`` pair in ``self._args``, and
    converts the traced graph into ``self._definition`` (a FunctionDef). Also
    records ``self._extra_inputs``, ``self._sub_functions`` and
    ``self._hash_str``, and fills in ``self._func_name`` when none was given.
    Returns immediately if ``self._definition`` is already set.

    Raises:
      ValueError: If the traced function returns ``None`` or any of its
        returned values is ``None``.
    """
    if self._definition is not None:
      return

    # Create the func_def object.
    temp_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # One placeholder per declared argument, in declaration order.
      inputs = [
          array_ops.placeholder(argtype, name=argname)
          for (argname, argtype) in self._args
      ]
      # Call func and gather the output tensors. Variable lookups made by the
      # function go through temp_graph.getvar (presumably so the graph can
      # capture outside variables -- confirm in _ExperimentalFuncGraph).
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      # Generator (not a list) so any() stops at the first None.
      if any(out is None for out in outputs):
        raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(out) for out in outputs]
    # Values the graph collected while tracing (presumably captured outside
    # tensors) are appended as extra function inputs.
    self._extra_inputs = temp_graph.extra_inputs
    inputs.extend(temp_graph.extra_args)
    self._sub_functions = temp_graph._functions

    # Build the FunctionDef.
    self._definition = function._graph_to_function_def(
        temp_graph, temp_graph.get_operations(), inputs, outputs,
        out_names=self._out_names)

    # Extra kwargs are treated as attrs on the function def. The Python
    # function's name is only looked up when no explicit name was given.
    sig_pre_func_name = self._func_name or function._get_func_name(self._func)
    kwargs_attr = function._parse_kwargs_as_attrs(
        sig_pre_func_name, **self._extra_kwargs)
    for k in kwargs_attr:
      self._definition.attr[k].CopyFrom(kwargs_attr[k])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg,
        self._definition.node_def)

    # Finally, we decide the function name to use. If not specified, make up
    # something which is almost certainly unique (but deterministic) by
    # suffixing the Python name with the hash.
    if not self._func_name:
      self._func_name = "_".join([function._get_func_name(self._func),
                                  self._hash_str])
    self._definition.signature.name = self._func_name
    # Propagate the Python docstring, if any, into the signature.
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__
Exemplo n.º 3
0
 def testTwoInputsSameOp(self):
   """All three outputs of a single SVD op become separate function inputs."""
   graph = ops.Graph()
   with graph.as_default():
     matrix = array_ops.placeholder(dtypes.float32)
     sing, left, right = linalg_ops.svd(matrix)
     # Reduce each SVD output in order, then add the partial sums together.
     partial_sums = [math_ops.reduce_sum(t) for t in (sing, left, right)]
     total = partial_sums[0] + partial_sums[1] + partial_sums[2]
   # The placeholder was created first; leave it out of the body. The test
   # expects exactly the three SVD outputs as the function's inputs.
   body_ops = graph.get_operations()[1:]
   fdef = function._graph_to_function_def(
       graph, body_ops, [sing, left, right], [total])
   self.assertEqual(len(fdef.signature.input_arg), 3)
Exemplo n.º 4
0
  def testBasic(self):
    """Two call sites of a hand-built "foo" FunctionDef compute correctly.

    foo(a:float, b:float, c:float) -> u:float, v:float, w:float
      u = matmul(a, b) + c
      v = u^2
      w = u + v
    """
    g = tf.Graph()

    # Trace foo's body into its own graph.
    foo = tf.Graph()
    with foo.as_default():
      a = tf.placeholder(tf.float32, name="a")
      b = tf.placeholder(tf.float32, name="b")
      c = tf.placeholder(tf.float32, name="c")
      ab = tf.matmul(a, b)
      u = tf.add(ab, c, name="u")
      v = tf.square(u, name="v")
      w = tf.add_n([u, v], name="w")
    # NOTE(review): the second argument here is the name "foo"; confirm it
    # matches the _graph_to_function_def signature of the module in use.
    fdef = function._graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])

    class Mock(function._DefinedFunction):
      """Bare-bones _DefinedFunction wrapping a prebuilt FunctionDef."""

      def __init__(self, func_def):
        # The base-class __init__ is not called; only the attributes set
        # below are provided.
        self._func_name = "foo"
        self._definition = func_def
        self._sub_functions = collections.OrderedDict()
        self._grad_func = None
        self._python_grad_func = None
        self._hash = hash(func_def.SerializeToString())

    g._add_function(Mock(fdef))

    # Evaluate foo on (2, 3, 4) and on (4, 2, 3).
    with g.as_default(), tf.Session() as sess:
      two = tf.constant(self._mat(2.0), name="two")
      three = tf.constant(self._mat(3.0), name="three")
      four = tf.constant(self._mat(4.0), name="four")
      # TODO(zhifengc): w/ @decorator sugar, we will just do:
      #   y, s, t = foo_func(two, three, four)

      # The graph contains two ops each of which calls foo.
      u0, v0, w0 = g.create_op(
          "foo", [two, three, four], [tf.float32] * 3,
          compute_shapes=False).outputs
      u1, v1, w1 = g.create_op(
          "foo", [four, two, three], [tf.float32] * 3,
          compute_shapes=False).outputs

      # Checks some property of the graph def.
      gdef = g.as_graph_def()
      self.assertEqual(len(gdef.node), 5)  # 3 constants + 2 foo calls.
      self.assertEqual(len(gdef.library.function), 1)  # Only foo is defined.

      fetches = [u0, v0, w0, u1, v1, w1]
      expected = (
          10.0,   # u0 = 2 * 3 + 4
          100.0,  # v0 = 10^2
          110.0,  # w0 = 10 + 100
          11.0,   # u1 = 4 * 2 + 3
          121.0,  # v1 = 11^2
          132.0,  # w1 = 11 + 121
      )
      for _ in range(10):
        # Run the graph, which is basically two function calls.
        for got, want in zip(sess.run(fetches), expected):
          self.assertAllEqual(got, self._mat(want))
Exemplo n.º 5
0
    def testBasic(self):
        """Two call sites of a hand-built "foo" FunctionDef compute correctly.

        foo(a:float, b:float, c:float) -> u:float, v:float, w:float
          u = matmul(a, b) + c
          v = u^2
          w = u + v
        """
        g = tf.Graph()

        # Trace foo's body into its own graph.
        foo = tf.Graph()
        with foo.as_default():
            a = tf.placeholder(tf.float32, name="a")
            b = tf.placeholder(tf.float32, name="b")
            c = tf.placeholder(tf.float32, name="c")
            ab = tf.matmul(a, b)
            u = tf.add(ab, c, name="u")
            v = tf.square(u, name="v")
            w = tf.add_n([u, v], name="w")
        # NOTE(review): the second argument here is the name "foo"; confirm
        # it matches the _graph_to_function_def signature in use.
        fdef = function._graph_to_function_def(foo, "foo", [a, b, c],
                                               [u, v, w])

        class Mock(function._DefinedFunction):
            """Bare-bones _DefinedFunction wrapping a prebuilt FunctionDef."""

            def __init__(self, func_def):
                # The base-class __init__ is not called; only the
                # attributes set below are provided.
                self._func_name = "foo"
                self._definition = func_def
                self._sub_functions = collections.OrderedDict()
                self._grad_func = None
                self._python_grad_func = None
                self._hash = hash(func_def.SerializeToString())

        g._add_function(Mock(fdef))

        # Evaluate foo on (2, 3, 4) and on (4, 2, 3).
        with g.as_default(), tf.Session() as sess:
            two = tf.constant(self._mat(2.0), name="two")
            three = tf.constant(self._mat(3.0), name="three")
            four = tf.constant(self._mat(4.0), name="four")
            # TODO(zhifengc): w/ @decorator sugar, we will just do:
            #   y, s, t = foo_func(two, three, four)

            # The graph contains two ops each of which calls foo.
            u0, v0, w0 = g.create_op("foo", [two, three, four],
                                     [tf.float32] * 3,
                                     compute_shapes=False).outputs
            u1, v1, w1 = g.create_op("foo", [four, two, three],
                                     [tf.float32] * 3,
                                     compute_shapes=False).outputs

            # Checks some property of the graph def.
            gdef = g.as_graph_def()
            self.assertEqual(len(gdef.node), 5)  # 3 constants + 2 foo calls.
            self.assertEqual(len(gdef.library.function), 1)  # Only foo.

            fetches = [u0, v0, w0, u1, v1, w1]
            expected = (
                10.0,   # u0 = 2 * 3 + 4
                100.0,  # v0 = 10^2
                110.0,  # w0 = 10 + 100
                11.0,   # u1 = 4 * 2 + 3
                121.0,  # v1 = 11^2
                132.0,  # w1 = 11 + 121
            )
            for _ in range(10):
                # Run the graph, which is basically two function calls.
                for got, want in zip(sess.run(fetches), expected):
                    self.assertAllEqual(got, self._mat(want))