def testFunctionCallsFromFunction(self):
  """A defun nested inside another defun round-trips via function_def_to_graph."""
  lhs = constant_op.constant(5.0)
  rhs = constant_op.constant(10.0)

  @function.defun
  def fn():

    @function.defun
    def inner_fn():
      return lhs + rhs

    return inner_fn()

  def fn2():
    return 2 * fn()

  fn2_defun = function.make_defun_op(fn2)
  # Run `fn2` once so the nested defun `fn` is instantiated; otherwise
  # `function_def_to_graph` is unable to resolve it.
  fn2_defun()
  func_graph = function_def_to_graph.function_def_to_graph(
      fn2_defun._inference_function.definition)
  with func_graph.as_default():
    # The two captured constants become the graph's placeholder inputs.
    x_ph, y_ph = func_graph.inputs
    with self.session(graph=func_graph) as sess:
      feeds = {x_ph: 5.0, y_ph: 10.0}
      # 2 * (5 + 10) == 30.
      result = sess.run(func_graph.outputs[0], feed_dict=feeds)
      self.assertEqual(result, 30.0)
def testDefunOpGraphModeNoneOutput(self):
  """A defun op built from a function returning None reports no outputs."""

  def fn(unused_a, unused_b):
    return None

  x = constant_op.constant(1)
  fn_op = function.make_defun_op(fn, x, x)
  # `assertIsNone` checks identity (`is None`) and reports a clearer failure
  # message than `assertEqual(..., None)`.
  self.assertIsNone(fn_op.output_dtypes)
  self.assertIsNone(fn_op.output_shapes)
  self.assertAllEqual(fn_op(x, x), None)
def testDefunOpGraphModeNoneOutput(self):
  """A defun op built from a function returning None reports no outputs."""

  def fn(unused_a, unused_b):
    return None

  x = constant_op.constant(1)
  fn_op = function.make_defun_op(fn, x, x)
  # `assertIsNone` checks identity (`is None`) and reports a clearer failure
  # message than `assertEqual(..., None)`.
  self.assertIsNone(fn_op.output_dtypes)
  self.assertIsNone(fn_op.output_shapes)
  self.assertAllEqual(fn_op(x, x), None)
def validate(indexed_slice):
  """Checks that a defun returning `indexed_slice` passes it through intact.

  Uses `self` from the enclosing test method (it is not a parameter here).

  Args:
    indexed_slice: a value exposing `values`, `indices` and `dense_shape`;
      the defun's output is expected to be an `ops.IndexedSlices`.
  """

  def f():
    return indexed_slice

  output = function.defun(f)()
  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)).
  self.assertIsInstance(output, ops.IndexedSlices)
  self.assertAllEqual(indexed_slice.values, output.values)
  self.assertAllEqual(indexed_slice.indices, output.indices)
  self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
  # The defun op's static output shape should match the shape of the
  # slices' values.
  self.assertEqual(
      function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
def validate(indexed_slice):
  """Checks that a defun returning `indexed_slice` passes it through intact.

  Uses `self` from the enclosing test method (it is not a parameter here).

  Args:
    indexed_slice: a value exposing `values`, `indices` and `dense_shape`;
      the defun's output is expected to be an `ops.IndexedSlices`.
  """

  def f():
    return indexed_slice

  output = function.defun(f)()
  # assertIsInstance reports the actual type on failure, unlike
  # assertTrue(isinstance(...)).
  self.assertIsInstance(output, ops.IndexedSlices)
  self.assertAllEqual(indexed_slice.values, output.values)
  self.assertAllEqual(indexed_slice.indices, output.indices)
  self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
  # The defun op's static output shape should match the shape of the
  # slices' values.
  self.assertEqual(
      function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
def testBasicDefunOpGraphMode(self):
  """A defun op wrapping a defun'd matmul squares a 2x2 matrix."""
  matmul = function.defun(math_ops.matmul)

  def sq(a):
    return matmul(a, a)

  mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  sq_op = function.make_defun_op(sq, mat)
  expected_shape = tensor_shape.TensorShape([2, 2])
  self.assertEqual(sq_op.output_shapes, expected_shape)
  squared = sq_op(mat)
  self.assertAllEqual(squared, math_ops.matmul(mat, mat).numpy())
def testBasicDefunOpGraphMode(self):
  """A defun op wrapping a defun'd matmul squares a 2x2 matrix."""
  matmul = function.defun(math_ops.matmul)

  def sq(a):
    return matmul(a, a)

  mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  sq_op = function.make_defun_op(sq, mat)
  expected_shape = tensor_shape.TensorShape([2, 2])
  self.assertEqual(sq_op.output_shapes, expected_shape)
  squared = sq_op(mat)
  self.assertAllEqual(squared, math_ops.matmul(mat, mat).numpy())
def testDefunOpGraphModeWithGradients(self):
  """A defun op can compute the gradient of v*v w.r.t. a captured variable."""
  var = resource_variable_ops.ResourceVariable(1.0, name='v')

  def step():

    def inner():
      return var * var

    # Presumably a list of (gradient, variable) pairs; the first gradient is
    # the one for `var`.
    grads_and_vars = backprop.implicit_grad(inner)()
    return grads_and_vars[0][0]

  step_op = function.make_defun_op(step)
  self.assertEqual(step_op.output_dtypes, dtypes.float32)
  scalar_shape = tensor_shape.TensorShape([])
  self.assertEqual(step_op.output_shapes, scalar_shape)
  # d(v*v)/dv evaluated at v == 1.0.
  self.assertAllEqual(step_op(), 2.0)
def testDefunOpGraphModeWithGradients(self):
  """A defun op can compute the gradient of v*v w.r.t. a captured variable."""
  var = resource_variable_ops.ResourceVariable(1.0, name='v')

  def step():

    def inner():
      return var * var

    # Presumably a list of (gradient, variable) pairs; the first gradient is
    # the one for `var`.
    grads_and_vars = backprop.implicit_grad(inner)()
    return grads_and_vars[0][0]

  step_op = function.make_defun_op(step)
  self.assertEqual(step_op.output_dtypes, dtypes.float32)
  scalar_shape = tensor_shape.TensorShape([])
  self.assertEqual(step_op.output_shapes, scalar_shape)
  # d(v*v)/dv evaluated at v == 1.0.
  self.assertAllEqual(step_op(), 2.0)
def testNestedInputsDefunOpGraphMode(self):
  """make_defun_op accepts a namedtuple-of-dicts structure as its input."""
  matmul = function.defun(math_ops.matmul)

  pair = collections.namedtuple('pair', ['a', 'b'])

  def a_times_b(inputs):
    return matmul(inputs.a['a'], inputs.b['b'])

  mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  nested = pair({'a': mat}, {'b': mat})
  product_op = function.make_defun_op(a_times_b, nested)
  self.assertEqual(product_op.output_shapes,
                   tensor_shape.TensorShape([2, 2]))
  product = product_op(nested)
  self.assertAllEqual(product, math_ops.matmul(mat, mat).numpy())
def testNestedOutputDefunOpGraphMode(self):
  """A defun op reports nested output shapes/dtypes and returns a nest."""
  matmul = function.defun(math_ops.matmul)

  def sq(a):
    return (matmul(a, a), {'b': constant_op.constant(1.0)})

  mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  sq_op = function.make_defun_op(sq, mat)

  expected_shapes = (tensor_shape.TensorShape([2, 2]),
                     {'b': tensor_shape.TensorShape([])})
  self.assertEqual(sq_op.output_shapes, expected_shapes)
  expected_dtypes = (dtypes.float32, {'b': dtypes.float32})
  self.assertEqual(sq_op.output_dtypes, expected_dtypes)

  product, extras = sq_op(mat)
  self.assertAllEqual(product, math_ops.matmul(mat, mat).numpy())
  self.assertAllEqual(extras['b'].numpy(), 1.0)
def testNestedOutputDefunOpGraphMode(self):
  """A defun op reports nested output shapes/dtypes and returns a nest."""
  matmul = function.defun(math_ops.matmul)

  def sq(a):
    return (matmul(a, a), {'b': constant_op.constant(1.0)})

  mat = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  sq_op = function.make_defun_op(sq, mat)

  expected_shapes = (tensor_shape.TensorShape([2, 2]),
                     {'b': tensor_shape.TensorShape([])})
  self.assertEqual(sq_op.output_shapes, expected_shapes)
  expected_dtypes = (dtypes.float32, {'b': dtypes.float32})
  self.assertEqual(sq_op.output_dtypes, expected_dtypes)

  product, extras = sq_op(mat)
  self.assertAllEqual(product, math_ops.matmul(mat, mat).numpy())
  self.assertAllEqual(extras['b'].numpy(), 1.0)
def testControlDependencies(self):
  """Control deps created inside a defun survive function_def_to_graph."""

  def fn(inp):
    x = constant_op.constant(2.0, name="x")
    # TODO(b/79881896): Test external control dependency once that's
    # supported.
    with ops.control_dependencies([x, inp]):
      constant_op.constant(3.0, name="y")
    return 4.0

  inp = constant_op.constant(1.0)
  defun_op = function.make_defun_op(fn, inp)
  func_graph = function_def_to_graph.function_def_to_graph(
      defun_op._inference_function.definition)
  y_op = func_graph.get_operation_by_name("y")
  # Exactly two control inputs, in order: the in-function constant and the
  # placeholder standing in for `inp`.
  control_input_names = [dep.name for dep in y_op.control_inputs]
  self.assertEqual(control_input_names, ["x", "placeholder"])
def execute(self, fn, *args, **kwargs):
  """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

  Args:
    fn: The function to execute. Must return at least one tensor.
    *args: Additional positional arguments to `fn`.
    **kwargs: Additional keyword arguments to `fn`.
      Several keywords are reserved for `execute`. These are:

      - name: The name to use when creating the execute operation.
      - exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.

  Returns:
    The tensors returned from `fn(*args, **kwargs)`, packed into the same
    structure as `fn`'s outputs. If the underlying op yields a single
    `Operation`, that `Operation` is returned instead.

  Raises:
    ValueError: If `fn` attempts to use this `CriticalSection` in any nested
      way.
    ValueError: (Graph mode only.) If `exclusive_resource_access` is not
      provided (is `True`) and another `CriticalSection` has an execution
      requesting the same resources as in `*args`, `**kwargs`, and any
      additionally captured inputs in `fn`. Note, even if
      `exclusive_resource_access` is `False`, a `ValueError` is raised if an
      execution in another `CriticalSection` that uses the same resources
      was created with `exclusive_resource_access=True`.
  """
  name = kwargs.pop("name", None)
  exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

  args = nest.map_structure(ops.convert_to_tensor, args)
  with ops.name_scope(name, "critical_section_execute", []):
    fn_op = function.make_defun_op(fn, *args, **kwargs)
    flat_dtypes = nest.flatten(fn_op.output_dtypes)
    flat_shapes = nest.flatten(fn_op.output_shapes)

    # Explicit arguments plus any tensors `fn` captured from outer scopes.
    all_inputs = nest.flatten(args) + fn_op.captured_inputs
    if self._handle in all_inputs:
      raise ValueError("The function fn attempts to access the "
                       "CriticalSection in which it would be running. This "
                       "is illegal and would cause deadlocks. "
                       "CriticalSection: %s." % self._handle)

    if context.in_graph_mode():
      # Collections and op introspection does not work in eager
      # mode. This is generally ok; since eager mode (as of
      # writing) executes sequentially anyway.
      all_input_resources = [
          x for x in all_inputs if x.dtype == dtypes.resource]
      for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
        if sg.op.inputs[0].name == self._handle.name:
          # Other executions in the same critical section are allowed.
          continue
        if not (exclusive_resource_access or sg.exclusive_resource_access):
          # Neither execution requested exclusive access.
          continue
        # inputs[0] is the other execution's critical-section handle; the
        # rest are its arguments. Resources are matched by tensor name.
        sg_input_names = [y.name for y in sg.op.inputs[1:]]
        for res in all_input_resources:
          if res.name in sg_input_names:
            raise ValueError(
                "This execution would access resource %s; but either this "
                "execution (CriticalSection: %s) or Execution '%s' "
                "(CriticalSection: %s) requested exclusive resource access "
                "of this resource for their critical section. Did you mean "
                "to call execute with keyword argument "
                "exclusive_resource_access=False?" %
                (res.name, self.name, sg.op.name, sg.op.inputs[0].op.name))

    flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
        critical_section=self._handle,
        arguments=all_inputs,
        f=fn_op,
        output_types=flat_dtypes,
        output_shapes=flat_shapes)

    # The generated op wrapper may hand back a bare `Operation` (presumably
    # when there are no tensor outputs). Normalize to a list in every mode,
    # because the code below indexes and takes `len()` unconditionally.
    if isinstance(flat_outputs, ops.Operation):
      flat_outputs = [flat_outputs]

    if context.in_graph_mode():
      # Record this execution so later `execute` calls can detect
      # conflicting resource access.
      op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
            else flat_outputs[0])
      signature = _ExecutionSignature(
          op=op,
          exclusive_resource_access=exclusive_resource_access)
      ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

    return (flat_outputs[0]
            if (len(flat_outputs) == 1
                and isinstance(flat_outputs[0], ops.Operation))
            else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
def execute(self, fn, *args, **kwargs):
  """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

  Args:
    fn: The function to execute. Must return at least one tensor.
    *args: Additional positional arguments to `fn`.
    **kwargs: Additional keyword arguments to `fn`.
      Several keywords are reserved for `execute`. These are:

      - name: The name to use when creating the execute operation.
      - exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.

  Returns:
    The tensors returned from `fn(*args, **kwargs)`, packed into the same
    structure as `fn`'s outputs. If the underlying op yields a single
    `Operation`, that `Operation` is returned instead.

  Raises:
    ValueError: If `fn` attempts to use this `CriticalSection` in any nested
      way.
    ValueError: (Graph mode only.) If `exclusive_resource_access` is not
      provided (is `True`) and another `CriticalSection` has an execution
      requesting the same resources as in `*args`, `**kwargs`, and any
      additionally captured inputs in `fn`. Note, even if
      `exclusive_resource_access` is `False`, a `ValueError` is raised if an
      execution in another `CriticalSection` that uses the same resources
      was created with `exclusive_resource_access=True`.
  """
  name = kwargs.pop("name", None)
  exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

  args = nest.map_structure(ops.convert_to_tensor, args)
  with ops.name_scope(name, "critical_section_execute", []):
    fn_op = function.make_defun_op(fn, *args, **kwargs)
    flat_dtypes = nest.flatten(fn_op.output_dtypes)
    flat_shapes = nest.flatten(fn_op.output_shapes)

    # Explicit arguments plus any tensors `fn` captured from outer scopes.
    all_inputs = nest.flatten(args) + fn_op.captured_inputs
    if self._handle in all_inputs:
      raise ValueError(
          "The function fn attempts to access the "
          "CriticalSection in which it would be running. This "
          "is illegal and would cause deadlocks. "
          "CriticalSection: %s." % self._handle)

    if context.in_graph_mode():
      # Collections and op introspection does not work in eager
      # mode. This is generally ok; since eager mode (as of
      # writing) executes sequentially anyway.
      all_input_resources = [
          x for x in all_inputs if x.dtype == dtypes.resource
      ]
      for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
        if sg.op.inputs[0].name == self._handle.name:
          # Other executions in the same critical section are allowed.
          continue
        if not (exclusive_resource_access or sg.exclusive_resource_access):
          # Neither execution requested exclusive access.
          continue
        # inputs[0] is the other execution's critical-section handle; the
        # rest are its arguments. Resources are matched by tensor name.
        sg_input_names = [y.name for y in sg.op.inputs[1:]]
        for res in all_input_resources:
          if res.name in sg_input_names:
            raise ValueError(
                "This execution would access resource %s; but either this "
                "execution (CriticalSection: %s) or Execution '%s' "
                "(CriticalSection: %s) requested exclusive resource access "
                "of this resource for their critical section. Did you mean "
                "to call execute with keyword argument "
                "exclusive_resource_access=False?" %
                (res.name, self.name, sg.op.name, sg.op.inputs[0].op.name))

    flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
        critical_section=self._handle,
        arguments=all_inputs,
        f=fn_op,
        output_types=flat_dtypes,
        output_shapes=flat_shapes)

    # The generated op wrapper may hand back a bare `Operation` (presumably
    # when there are no tensor outputs). Normalize to a list in every mode,
    # because the code below indexes and takes `len()` unconditionally.
    if isinstance(flat_outputs, ops.Operation):
      flat_outputs = [flat_outputs]

    if context.in_graph_mode():
      # Record this execution so later `execute` calls can detect
      # conflicting resource access.
      op = (flat_outputs[0].op if isinstance(
          flat_outputs[0], ops.Tensor) else flat_outputs[0])
      signature = _ExecutionSignature(
          op=op, exclusive_resource_access=exclusive_resource_access)
      ops.add_to_collections(CRITICAL_SECTION_EXECUTIONS, signature)

    return (flat_outputs[0]
            if (len(flat_outputs) == 1
                and isinstance(flat_outputs[0], ops.Operation))
            else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))