Пример #1
0
  def testFunctionCallsFromFunction(self):
    """Converts a defun that calls a nested defun into a graph and runs it.

    `fn2` doubles the result of `fn`, which in turn returns the sum produced
    by a nested defun.  The resulting function definition is turned into a
    graph via `function_def_to_graph`, its two inputs are fed 5.0 and 10.0,
    and the single output is expected to be 30.0.
    """
    lhs = constant_op.constant(5.0)
    rhs = constant_op.constant(10.0)

    @function.defun
    def fn():

      @function.defun
      def inner_fn():
        return lhs + rhs

      return inner_fn()

    def fn2():
      return 2 * fn()

    fn2_defun = function.make_defun_op(fn2)

    # `fn` has to be instantiated before the conversion below so that
    # `function_def_to_graph` can find it; calling `fn2` takes care of that.
    fn2_defun()

    func_graph = function_def_to_graph.function_def_to_graph(
        fn2_defun._inference_function.definition)
    with func_graph.as_default():
      # The converted graph exposes two inputs -- presumably the placeholders
      # standing in for the two closed-over constants.
      lhs_ph, rhs_ph = func_graph.inputs
      with self.session(graph=func_graph) as sess:
        fetched = sess.run(
            func_graph.outputs[0], feed_dict={lhs_ph: 5.0, rhs_ph: 10.0})
        self.assertEqual(fetched, 30.0)
Пример #2
0
  def testDefunOpGraphModeNoneOutput(self):
    """Checks the defun op built from a function that returns None.

    The op's `output_dtypes` and `output_shapes` must both be None, and
    calling the op must yield a result that compares equal to None.
    """
    def fn(unused_a, unused_b):
      return None

    x = constant_op.constant(1)
    fn_op = function.make_defun_op(fn, x, x)

    # assertIsNone states the intent directly and gives a clearer failure
    # message than assertEqual(..., None).
    self.assertIsNone(fn_op.output_dtypes)
    self.assertIsNone(fn_op.output_shapes)
    # Kept as assertAllEqual so the accepted call results are unchanged.
    self.assertAllEqual(fn_op(x, x), None)
Пример #3
0
    def testDefunOpGraphModeNoneOutput(self):
        """Checks the defun op built from a function that returns None.

        The op's `output_dtypes` and `output_shapes` must both be None, and
        calling the op must yield a result that compares equal to None.
        """
        def fn(unused_a, unused_b):
            return None

        x = constant_op.constant(1)
        fn_op = function.make_defun_op(fn, x, x)

        # assertIsNone states the intent directly and gives a clearer failure
        # message than assertEqual(..., None).
        self.assertIsNone(fn_op.output_dtypes)
        self.assertIsNone(fn_op.output_shapes)
        # Kept as assertAllEqual so the accepted call results are unchanged.
        self.assertAllEqual(fn_op(x, x), None)
Пример #4
0
    def validate(indexed_slice):
      """Checks that a defun returning `indexed_slice` hands it back intact.

      Args:
        indexed_slice: Value returned by the wrapped function.  Its `values`,
          `indices` and `dense_shape` are compared with those of the defun
          output, and the shape of its `values` is compared with the
          `output_shapes` reported by `make_defun_op`.
      """
      def f():
        return indexed_slice

      output = function.defun(f)()
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(output, ops.IndexedSlices)
      self.assertAllEqual(indexed_slice.values, output.values)
      self.assertAllEqual(indexed_slice.indices, output.indices)
      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)

      self.assertEqual(
          function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
    def validate(indexed_slice):
      """Checks that a defun returning `indexed_slice` hands it back intact.

      Args:
        indexed_slice: Value returned by the wrapped function.  Its `values`,
          `indices` and `dense_shape` are compared with those of the defun
          output, and the shape of its `values` is compared with the
          `output_shapes` reported by `make_defun_op`.
      """
      def f():
        return indexed_slice

      output = function.defun(f)()
      # assertIsInstance reports the actual type on failure, which
      # assertTrue(isinstance(...)) does not.
      self.assertIsInstance(output, ops.IndexedSlices)
      self.assertAllEqual(indexed_slice.values, output.values)
      self.assertAllEqual(indexed_slice.indices, output.indices)
      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)

      self.assertEqual(
          function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
Пример #6
0
  def testBasicDefunOpGraphMode(self):
    """Squares a 2x2 matrix through a defun op and checks shape and value.

    The op is built from `sq`, which multiplies its argument by itself using
    a defun-wrapped `matmul`.  The reported output shape must be [2, 2] and
    the result must match a direct `math_ops.matmul` call.
    """
    defun_matmul = function.defun(math_ops.matmul)

    def sq(a):
      return defun_matmul(a, a)

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = function.make_defun_op(sq, matrix)

    reported_shape = sq_op.output_shapes
    self.assertEqual(reported_shape, tensor_shape.TensorShape([2, 2]))

    actual = sq_op(matrix)
    expected = math_ops.matmul(matrix, matrix).numpy()
    self.assertAllEqual(actual, expected)
Пример #7
0
    def testBasicDefunOpGraphMode(self):
        """Verifies shape and value of a defun op that squares a matrix.

        `sq` multiplies its input by itself through a defun-wrapped
        `matmul`; the op made from it must report a [2, 2] output shape and
        agree with calling `math_ops.matmul` directly.
        """
        wrapped_matmul = function.defun(math_ops.matmul)

        def sq(a):
            product = wrapped_matmul(a, a)
            return product

        operand = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        sq_op = function.make_defun_op(sq, operand)

        self.assertEqual(
            sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))

        result = sq_op(operand)
        reference = math_ops.matmul(operand, operand)
        self.assertAllEqual(result, reference.numpy())
Пример #8
0
  def testDefunOpGraphModeWithGradients(self):
    """Builds a defun op around an implicit-gradient computation of `v * v`.

    The op must report a float32 scalar output and evaluate to 2.0 for the
    variable `v` initialised to 1.0.
    """
    v = resource_variable_ops.ResourceVariable(1.0, name='v')

    def step():
      def inner():
        return v * v

      # Only the first element of the first entry is returned; presumably
      # that is the gradient with respect to `v`.
      results = backprop.implicit_grad(inner)()
      first_entry = results[0]
      return first_entry[0]

    step_op = function.make_defun_op(step)

    scalar_shape = tensor_shape.TensorShape([])
    self.assertEqual(step_op.output_dtypes, dtypes.float32)
    self.assertEqual(step_op.output_shapes, scalar_shape)
    self.assertAllEqual(step_op(), 2.0)
Пример #9
0
    def testDefunOpGraphModeWithGradients(self):
        """Checks dtype, shape and value of a defun op computing a gradient.

        `step` runs `backprop.implicit_grad` on `v * v` and returns element
        [0][0] of the result.  With `v` initialised to 1.0 the op is expected
        to produce the float32 scalar 2.0.
        """
        v = resource_variable_ops.ResourceVariable(1.0, name='v')

        def step():
            def inner():
                return v * v

            grad_fn = backprop.implicit_grad(inner)
            return grad_fn()[0][0]

        step_op = function.make_defun_op(step)

        reported_dtype = step_op.output_dtypes
        self.assertEqual(reported_dtype, dtypes.float32)
        reported_shape = step_op.output_shapes
        self.assertEqual(reported_shape, tensor_shape.TensorShape([]))
        self.assertAllEqual(step_op(), 2.0)
  def testNestedInputsDefunOpGraphMode(self):
    """Feeds a namedtuple of dicts of tensors to a defun op.

    `a_times_b` pulls one tensor out of each dict and multiplies them with a
    defun-wrapped `matmul`.  The op must report a [2, 2] output shape and
    match a direct `math_ops.matmul` of the same tensor with itself.
    """
    defun_matmul = function.defun(math_ops.matmul)
    pair = collections.namedtuple('pair', ['a', 'b'])

    def a_times_b(inputs):
      left = inputs.a['a']
      right = inputs.b['b']
      return defun_matmul(left, right)

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    nested_args = pair({'a': matrix}, {'b': matrix})
    sq_op = function.make_defun_op(a_times_b, nested_args)

    reported_shape = sq_op.output_shapes
    self.assertEqual(reported_shape, tensor_shape.TensorShape([2, 2]))

    actual = sq_op(nested_args)
    expected = math_ops.matmul(matrix, matrix).numpy()
    self.assertAllEqual(actual, expected)
Пример #11
0
  def testNestedOutputDefunOpGraphMode(self):
    """Checks a defun op whose function returns a (tensor, dict) structure.

    `sq` returns the matrix product of its argument with itself together
    with a dict holding the constant 1.0.  The op's `output_shapes` and
    `output_dtypes` must mirror that structure, and the call results must
    match the directly computed values.
    """
    defun_matmul = function.defun(math_ops.matmul)

    def sq(a):
      squared = defun_matmul(a, a)
      return (squared, {'b': constant_op.constant(1.0)})

    matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = function.make_defun_op(sq, matrix)

    reported_shapes = sq_op.output_shapes
    expected_shapes = (tensor_shape.TensorShape([2, 2]),
                       {'b': tensor_shape.TensorShape([])})
    self.assertEqual(reported_shapes, expected_shapes)

    reported_dtypes = sq_op.output_dtypes
    expected_dtypes = (dtypes.float32, {'b': dtypes.float32})
    self.assertEqual(reported_dtypes, expected_dtypes)

    product, extras = sq_op(matrix)
    self.assertAllEqual(product, math_ops.matmul(matrix, matrix).numpy())
    self.assertAllEqual(extras['b'].numpy(), 1.0)
  def testNestedOutputDefunOpGraphMode(self):
    """Verifies nested outputs (a tuple containing a dict) of a defun op.

    The wrapped function returns `(matmul(a, a), {'b': 1.0})`; the test
    checks that the reported shapes and dtypes follow the same nesting and
    that both returned values are correct.
    """
    wrapped_matmul = function.defun(math_ops.matmul)

    def sq(a):
      return (wrapped_matmul(a, a), {'b': constant_op.constant(1.0)})

    operand = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = function.make_defun_op(sq, operand)

    self.assertEqual(
        sq_op.output_shapes,
        (tensor_shape.TensorShape([2, 2]),
         {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(
        sq_op.output_dtypes,
        (dtypes.float32, {'b': dtypes.float32}))

    outputs = sq_op(operand)
    (matrix_out, dict_out) = outputs
    reference = math_ops.matmul(operand, operand)
    self.assertAllEqual(matrix_out, reference.numpy())
    self.assertAllEqual(dict_out['b'].numpy(), 1.0)
Пример #13
0
  def testControlDependencies(self):
    """Checks that control inputs survive FunctionDef-to-graph conversion.

    Inside `fn`, op "y" is created under control dependencies on op "x" and
    on the function input.  After converting the function definition to a
    graph, "y" must have exactly two control inputs named "x" and
    "placeholder", in that order.
    """

    def fn(inp):
      x_const = constant_op.constant(2.0, name="x")
      # TODO(b/79881896): Test external control dependency once that's
      # supported.
      with ops.control_dependencies([x_const, inp]):
        constant_op.constant(3.0, name="y")
      return 4.0

    inp = constant_op.constant(1.0)
    defun_op = function.make_defun_op(fn, inp)
    func_graph = function_def_to_graph.function_def_to_graph(
        defun_op._inference_function.definition)

    y_op = func_graph.get_operation_by_name("y")
    self.assertEqual(len(y_op.control_inputs), 2)
    first_dep = y_op.control_inputs[0]
    self.assertEqual(first_dep.name, "x")
    second_dep = y_op.control_inputs[1]
    self.assertEqual(second_dep.name, "placeholder")
  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name; The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is `True`,
        if another execution in another `CriticalSection` was created without
        `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    # Reserved keywords are removed from `kwargs` so they are not forwarded
    # to `fn`.
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    # Only the positional arguments are converted to tensors here.
    args = nest.map_structure(ops.convert_to_tensor, args)
    with ops.name_scope(name, "critical_section_execute", []):
      # Wrap `fn` as a defun op so that its output dtypes/shapes and its
      # captured inputs can be inspected before the execute op is created.
      fn_op = function.make_defun_op(fn, *args, **kwargs)
      flat_dtypes = nest.flatten(fn_op.output_dtypes)
      flat_shapes = nest.flatten(fn_op.output_shapes)
      all_inputs = nest.flatten(args) + fn_op.captured_inputs
      # `fn` must not receive or capture this critical section's own handle.
      if self._handle in all_inputs:
        raise ValueError("The function fn attempts to access the "
                         "CriticalSection in which it would be running.  This "
                         "is illegal and would cause deadlocks.  "
                         "CriticalSection: %s." % self._handle)

      if context.in_graph_mode():
        # Collections and op introspection does not work in eager
        # mode.  This is generally ok; since eager mode (as of
        # writing) executes sequentially anyway.
        all_input_resources = [
            x for x in all_inputs if x.dtype == dtypes.resource]
        # Compare against every execution recorded so far in the collection.
        for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
          # The first input of a recorded execution op is its critical
          # section handle.
          if sg.op.inputs[0].name == self._handle.name:
            # Other executions in the same critical section are allowed.
            continue
          if not (exclusive_resource_access or sg.exclusive_resource_access):
            # Neither execution requested exclusive access.
            continue
          # The remaining inputs are the arguments of the other execution.
          sg_input_names = [y.name for y in sg.op.inputs[1:]]
          for res in all_input_resources:
            if res.name in sg_input_names:
              raise ValueError(
                  "This execution would access resource %s; but either this "
                  "execution (CriticalSection: %s) or Execution '%s' "
                  "(CriticalSection: %s) requested exclusive resource access "
                  "of this resource for their critical section.  Did you mean "
                  "to call execute with keyword argument "
                  "exclusive_resource_access=False?"
                  % (res.name,
                     self.name,
                     sg.op.name,
                     sg.op.inputs[0].op.name))

      flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
          critical_section=self._handle,
          arguments=all_inputs,
          f=fn_op,
          output_types=flat_dtypes,
          output_shapes=flat_shapes)

      if context.in_graph_mode():
        # A bare Operation result is wrapped in a list so that it can be
        # indexed uniformly below.
        if isinstance(flat_outputs, ops.Operation):
          flat_outputs = [flat_outputs]
        op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
              else flat_outputs[0])
        # Record this execution so that later `execute` calls can check for
        # conflicting resource access.
        signature = _ExecutionSignature(
            op=op,
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      # A single Operation is returned as-is; otherwise the flat outputs are
      # packed back into the structure of `fn_op.output_dtypes`.
      return (flat_outputs[0]
              if (len(flat_outputs) == 1
                  and isinstance(flat_outputs[0], ops.Operation))
              else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
    def execute(self, fn, *args, **kwargs):
        """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

        Args:
          fn: The function to execute.  Must return at least one tensor.
          *args: Additional positional arguments to `fn`.
          **kwargs: Additional keyword arguments to `fn`.
            Several keywords are reserved for `execute`.  These are:

            - name; The name to use when creating the execute operation.
            - exclusive_resource_access; Whether the resources required by
              `fn` should be exclusive to this `CriticalSection`.
              Default: `True`.  You may want to set this to `False` if you
              will be accessing a resource in read-only mode in two different
              CriticalSections.

        Returns:
          The tensors returned from `fn(*args, **kwargs)`.

        Raises:
          ValueError: If `fn` attempts to use this `CriticalSection` in any
            nested way.
          ValueError: If `exclusive_resource_access` is not provided (is
            `True`) and another `CriticalSection` has an execution requesting
            the same resources as in `*args`, `**kwargs`, and any additionally
            captured inputs in `fn`.  Note, even if
            `exclusive_resource_access` is `True`, if another execution in
            another `CriticalSection` was created without
            `exclusive_resource_access=True`, a `ValueError` will be raised.
        """
        # Reserved keywords are removed from `kwargs` so they are not
        # forwarded to `fn`.
        name = kwargs.pop("name", None)
        exclusive_resource_access = kwargs.pop("exclusive_resource_access",
                                               True)

        # Only the positional arguments are converted to tensors here.
        args = nest.map_structure(ops.convert_to_tensor, args)
        with ops.name_scope(name, "critical_section_execute", []):
            # Wrap `fn` as a defun op so that its output dtypes/shapes and
            # its captured inputs can be inspected before the execute op is
            # created.
            fn_op = function.make_defun_op(fn, *args, **kwargs)
            flat_dtypes = nest.flatten(fn_op.output_dtypes)
            flat_shapes = nest.flatten(fn_op.output_shapes)
            all_inputs = nest.flatten(args) + fn_op.captured_inputs
            # `fn` must not receive or capture this critical section's own
            # handle.
            if self._handle in all_inputs:
                raise ValueError(
                    "The function fn attempts to access the "
                    "CriticalSection in which it would be running.  This "
                    "is illegal and would cause deadlocks.  "
                    "CriticalSection: %s." % self._handle)

            if context.in_graph_mode():
                # Collections and op introspection does not work in eager
                # mode.  This is generally ok; since eager mode (as of
                # writing) executes sequentially anyway.
                all_input_resources = [
                    x for x in all_inputs if x.dtype == dtypes.resource
                ]
                # Compare against every execution recorded so far in the
                # collection.
                for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
                    # The first input of a recorded execution op is its
                    # critical section handle.
                    if sg.op.inputs[0].name == self._handle.name:
                        # Other executions in the same critical section are
                        # allowed.
                        continue
                    if not (exclusive_resource_access
                            or sg.exclusive_resource_access):
                        # Neither execution requested exclusive access.
                        continue
                    # The remaining inputs are the arguments of the other
                    # execution.
                    sg_input_names = [y.name for y in sg.op.inputs[1:]]
                    for res in all_input_resources:
                        if res.name in sg_input_names:
                            raise ValueError(
                                "This execution would access resource %s; but either this "
                                "execution (CriticalSection: %s) or Execution '%s' "
                                "(CriticalSection: %s) requested exclusive resource access "
                                "of this resource for their critical section.  Did you mean "
                                "to call execute with keyword argument "
                                "exclusive_resource_access=False?" %
                                (res.name, self.name, sg.op.name,
                                 sg.op.inputs[0].op.name))

            flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
                critical_section=self._handle,
                arguments=all_inputs,
                f=fn_op,
                output_types=flat_dtypes,
                output_shapes=flat_shapes)

            if context.in_graph_mode():
                # A bare Operation result is wrapped in a list so that it can
                # be indexed uniformly below.
                if isinstance(flat_outputs, ops.Operation):
                    flat_outputs = [flat_outputs]
                op = (flat_outputs[0].op if isinstance(
                    flat_outputs[0], ops.Tensor) else flat_outputs[0])
                # Record this execution so that later `execute` calls can
                # check for conflicting resource access.
                signature = _ExecutionSignature(
                    op=op, exclusive_resource_access=exclusive_resource_access)
                ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

            # A single Operation is returned as-is; otherwise the flat
            # outputs are packed back into the structure of
            # `fn_op.output_dtypes`.
            return (flat_outputs[0] if
                    (len(flat_outputs) == 1
                     and isinstance(flat_outputs[0], ops.Operation)) else
                    nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))