Example #1
    def testCondNested(self):
        with context.graph_mode(), self.cached_session():
            v = resource_variable_ops.ResourceVariable(1.0)
            variables.global_variables_initializer().run()
            p = array_ops.placeholder(dtype=dtypes.bool)
            q = array_ops.placeholder(dtype=dtypes.bool)
            with acd.AutomaticControlDependencies() as c:

                def true_fn():
                    v.assign(v + 1, name='true')
                    return 1.0

                def false_fn():
                    def inner_true_fn():
                        v.assign(v * 2, name='false_true')
                        return 2.0

                    def inner_false_fn():
                        v.assign(v * 3, name='false_false')
                        return 3.0

                    control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
                    return 1.0

                control_flow_ops.cond(p, true_fn, false_fn)
                with ops.name_scope('final'):
                    val = v.read_value()
                val = c.mark_as_return(val)
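            # Each eval below reuses the same variable, so the expected values
            # accumulate across runs: 1 * 3 = 3, 3 * 2 = 6, 6 + 1 = 7, 7 + 1 = 8.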
            self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
            self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
            self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
            self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
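The snippets on this page are taken from TensorFlow's internal test and library code, with their import blocks stripped, and the test methods are defined on a tf.test.TestCase subclass (which provides self.cached_session() and self.evaluate()). A minimal set of imports that should cover the test-style examples below looks roughly like this (assumed module paths under tensorflow.python; they can move between TensorFlow versions). Examples #4, #13, and #17 additionally rely on AutoGraph and Keras internals that are not reproduced here:

import itertools

from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables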
Example #2
 def testVariableReadsInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
             read_op = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
         self.assertIn(read_op, c.ops_which_must_run)
Example #3
    def testSendInOpsWithMustRun(self):
        with context.graph_mode(), self.cached_session():
            v = resource_variable_ops.ResourceVariable(1.0)
            self.evaluate(variables.global_variables_initializer())
            with acd.AutomaticControlDependencies() as c:
                send_op = gen_sendrecv_ops.send(v, "x", "/", 0, "/")

            # Send must be in `ops_which_must_run`.
            self.assertIn(send_op, c.ops_which_must_run)
Example #4
    def __init__(self, use_name_scope, function_scope_name, use_auto_deps):
        self.use_name_scope = use_name_scope
        if use_name_scope:
            self.name_scope = ops.name_scope(function_scope_name)

        self.use_auto_deps = use_auto_deps
        if use_auto_deps:
            self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
            self._return_value_marked = False
Example #5
 def testVariableReadsInOpsWithMustRun(self):
   with context.graph_mode(), self.cached_session():
     v = resource_variable_ops.ResourceVariable(1.0)
     self.evaluate(variables.global_variables_initializer())
     with acd.AutomaticControlDependencies() as c:
       read_op = gen_resource_variable_ops.read_variable_op(v.handle,
                                                            v.dtype).op
       # Read ops get added to control outputs only if they have consumers.
       c.mark_as_return(read_op.outputs[0])
     self.assertIn(read_op, c.ops_which_must_run)
Example #6
 def testBasic(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
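             # The scope adds control dependencies so that the two writes and the
             # read run in program order: v goes 1.0 -> 2.0 -> 4.0.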
             v.assign(v + 1)
             v.assign(2 * v)
             val = v.read_value()
             val = c.mark_as_return(val)
         self.assertAllEqual(val.eval(), 4.0)
Example #7
 def testNoControlDepsBetweenVariableReads(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #8
 def testIdentityPassThrough(self):
   with context.graph_mode(), self.cached_session():
     v = resource_variable_ops.ResourceVariable(1.0)
     self.evaluate(variables.global_variables_initializer())
     with acd.AutomaticControlDependencies():
       gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
       identity_handle = gen_array_ops.identity(v.handle)
       assign_op2 = gen_resource_variable_ops.assign_variable_op(
           v.handle, v + 1)
       read_op = gen_resource_variable_ops.read_variable_op(
           identity_handle, v.dtype).op
     # The read should have a control dep from the last write, even though the
     # resource handle was passed through an Identity op.
     self.assertIn(assign_op2, read_op.control_inputs)
Example #9
  def testVariableMultipleReadsAndWrites(self):
    with context.graph_mode(), self.cached_session():
      v = resource_variable_ops.ResourceVariable(1.0)
      self.evaluate(variables.global_variables_initializer())
      with acd.AutomaticControlDependencies() as c:
        # 2 reads -> 2 writes -> 2 reads -> 2 writes.
        read_op1 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op2 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op1 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op2 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        read_op3 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        read_op4 = gen_resource_variable_ops.read_variable_op(
            v.handle, v.dtype).op
        assign_op3 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        assign_op4 = gen_resource_variable_ops.assign_variable_op(
            v.handle, v + 1)
        # Read ops get added to control outputs only if they have consumers.
        c.mark_as_return(read_op1.outputs[0])
        c.mark_as_return(read_op2.outputs[0])
        c.mark_as_return(read_op3.outputs[0])
        c.mark_as_return(read_op4.outputs[0])

      # Verify the control edges.
      self.assertIn(read_op1, assign_op1.control_inputs)
      self.assertIn(read_op2, assign_op1.control_inputs)
      self.assertIn(assign_op1, assign_op2.control_inputs)
      self.assertIn(assign_op2, read_op3.control_inputs)
      self.assertIn(assign_op2, read_op4.control_inputs)
      self.assertIn(read_op3, assign_op3.control_inputs)
      self.assertIn(read_op4, assign_op3.control_inputs)
      self.assertIn(assign_op3, assign_op4.control_inputs)

      # There should be no control deps between reads.
      read_ops = [read_op1, read_op2, read_op3, read_op4]
      for src_op, tgt_op in itertools.product(read_ops, read_ops):
        self.assertNotIn(src_op, tgt_op.control_inputs)

      # Reads must be in `ops_which_must_run`.
      self.assertIn(read_op1, c.ops_which_must_run)
      self.assertIn(read_op2, c.ops_which_must_run)
      self.assertIn(read_op3, c.ops_which_must_run)
      self.assertIn(read_op4, c.ops_which_must_run)
      # Last write must be in `ops_which_must_run`.
      self.assertIn(assign_op4, c.ops_which_must_run)
Example #10
 def testVariableWriteThenRead(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
         # Reads should have a control dep from the last write.
         self.assertIn(assign_op, read_op1.control_inputs)
         self.assertIn(assign_op, read_op2.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #11
 def testVariableReadsNotInOpsWithMustRun(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies() as c:
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Reads must not be in `ops_which_must_run` since those get added to the
         # `control_outputs`.
         self.assertNotIn(read_op1, c.ops_which_must_run)
         self.assertNotIn(read_op2, c.ops_which_must_run)
         # Last write must be in `ops_which_must_run`.
         self.assertIn(assign_op, c.ops_which_must_run)
Example #12
 def testVariableReadThenWrite(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should have control deps from "all" reads since the last write
         # or the start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # There should be no control deps between reads.
         self.assertNotIn(read_op1, read_op2.control_inputs)
         self.assertNotIn(read_op2, read_op1.control_inputs)
Example #13
  def __init__(self, function_name, scope_name, options):
    self.name = scope_name
    self.options = options

    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                   options)
    self.callopts = options.call_options()

    use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    self.use_name_scope = use_name_scope
    if use_name_scope:
      self.name_scope = ops.name_scope(self._sanitize(function_name))

    use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
    self.use_auto_deps = use_auto_deps
    if use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
      self._return_value_marked = False
Example #14
    def testCondOneBranch(self):
        with context.graph_mode(), self.cached_session():
            v = resource_variable_ops.ResourceVariable(1.0)
            variables.global_variables_initializer().run()
            p = array_ops.placeholder(dtype=dtypes.bool)
            with acd.AutomaticControlDependencies() as c:

                def true_fn():
                    return 0.0

                def false_fn():
                    v.assign(v + 4)
                    return 1.0

                control_flow_ops.cond(p, true_fn, false_fn)
                val = v.read_value()
                val = c.mark_as_return(val)
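            # The first eval takes the false branch (1.0 + 4 = 5.0); the second takes
            # the true branch, which does not write, so v stays at 5.0.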
            self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
            self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
Example #15
 def testManualControlDepMonitoringAttrNotAdded(self):
     with context.graph_mode(), self.cached_session():
         v = resource_variable_ops.ResourceVariable(1.0)
         self.evaluate(variables.global_variables_initializer())
         with acd.AutomaticControlDependencies():
             read_op1 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             read_op2 = gen_resource_variable_ops.read_variable_op(
                 v.handle, v.dtype).op
             assign_op = gen_resource_variable_ops.assign_variable_op(
                 v.handle, v + 1)
         # Writes should automatically get control deps from "all" reads since the
         # last write or the start of the code block.
         self.assertIn(read_op1, assign_op.control_inputs)
         self.assertIn(read_op2, assign_op.control_inputs)
         # But, we shouldn't add the monitoring attribute in this case.
         with self.assertRaises(ValueError):
             assign_op.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op1.get_attr("_has_manual_control_dependencies")
         with self.assertRaises(ValueError):
             read_op2.get_attr("_has_manual_control_dependencies")
Example #16
    def testCondMustRunSeparateRead(self):
        with context.graph_mode(), self.cached_session():
            v = resource_variable_ops.ResourceVariable(1.0)
            variables.global_variables_initializer().run()
            p = array_ops.placeholder(dtype=dtypes.bool)
            with acd.AutomaticControlDependencies() as c:

                def true_fn():
                    v.assign(v + 1)
                    return 0.0

                def false_fn():
                    v.assign(v + 4)
                    return 1.0

                control_flow_ops.cond(p, true_fn, false_fn)
                one = constant_op.constant(1.0)
                one = c.mark_as_return(one)
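            # Evaluating `one` must still trigger the cond's side effects: the false
            # branch adds 4 (1.0 -> 5.0), then the true branch adds 1 (5.0 -> 6.0).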
            one.eval(feed_dict={p: False})
            self.assertAllEqual(v.read_value().eval(), 5.0)
            one.eval(feed_dict={p: True})
            self.assertAllEqual(v.read_value().eval(), 6.0)
Example #17
    def inference(self, inputs, *args, **kwargs):

        call_context = base_layer_utils.call_context()
        input_list = nest.flatten(inputs)

        # We will attempt to build a TF graph if & only if all inputs are symbolic.
        # This is always the case in graph mode. It can also be the case in eager
        # mode when all inputs can be traced back to `keras.Input()` (when building
        # models using the functional API).
        build_graph = tf_utils.are_all_symbolic_tensors(input_list)

        # Accept NumPy and scalar inputs by converting to Tensors.
        if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
            def _convert_non_tensor(x):
                # Don't call `ops.convert_to_tensor` on all `inputs` because
                # `SparseTensors` can't be converted to `Tensor`.
                if isinstance(x, (np.ndarray, float, int)):
                    return ops.convert_to_tensor(x)
                return x
            inputs = nest.map_structure(_convert_non_tensor, inputs)
            input_list = nest.flatten(inputs)

        # Handle `mask` propagation from previous layer to current layer. Masks can
        # be propagated explicitly via the `mask` argument, or implicitly via
        # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
        # explicitly take priority.
        mask_arg_passed_by_framework = False
        input_masks = self._collect_input_masks(inputs, args, kwargs)
        if (self._expects_mask_arg and input_masks is not None and
                not self._call_arg_was_passed('mask', args, kwargs)):
            mask_arg_passed_by_framework = True
            kwargs['mask'] = input_masks

        # If `training` argument was not explicitly passed, propagate `training`
        # value from this layer's calling layer.
        training_arg_passed_by_framework = False
        # Priority 1: `training` was explicitly passed.
        if self._call_arg_was_passed('training', args, kwargs):
            training_value = self._get_call_arg_value('training', args, kwargs)
            if not self._expects_training_arg:
                kwargs.pop('training')
        else:
            training_value = None
            # Priority 2: `training` was passed to a parent layer.
            if call_context.training is not None:
                training_value = call_context.training
            # Priority 3a: `learning_phase()` has been set.
            elif backend.global_learning_phase_is_set():
                training_value = backend.learning_phase()
            # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
            elif build_graph:
                with backend.get_graph().as_default():
                    if base_layer_utils.is_in_keras_graph():
                        training_value = backend.learning_phase()

            if self._expects_training_arg and training_value is not None:
                # Force the training_value to be bool type, which matches the
                # contract for layer/model call args.
                if tensor_util.is_tensor(training_value):
                    training_value = math_ops.cast(training_value, dtypes.bool)
                else:
                    training_value = bool(training_value)
                kwargs['training'] = training_value
                training_arg_passed_by_framework = True

        # Only create Keras history if at least one tensor originates from a
        # `keras.Input`. Otherwise this Layer may be being used outside the Keras
        # framework.
        if build_graph and base_layer_utils.needs_keras_history(inputs):
            base_layer_utils.create_keras_history(inputs)

        # Clear eager losses on top level model call.
        # We are clearing the losses only on the top level model call and not on
        # every layer/model call because layer/model may be reused.
        if (base_layer_utils.is_in_eager_or_tf_function() and
                not call_context.in_call):
            self._clear_losses()

        with call_context.enter(self, inputs, build_graph, training_value):
            # Check input assumptions set after layer building, e.g. input shape.
            if build_graph:
                # Symbolic execution on symbolic tensors. We will attempt to build
                # the corresponding TF subgraph inside `backend.get_graph()`
                # TODO(reedwm): We should assert input compatibility after the inputs
                # are casted, not before.
                input_spec.assert_input_compatibility(self.input_spec, inputs,
                                                      self.name)
                if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
                        and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
                    raise ValueError('Layer %s does not support RaggedTensors as input. '
                                     'Inputs received: %s. You can try converting your '
                                     'input to a uniform tensor.' % (self.name, inputs))

                graph = backend.get_graph()
                with graph.as_default(), backend.name_scope(self._name_scope()):
                    # Build layer if applicable (if the `build` method has been
                    # overridden).
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)

                    # Wrapping `call` function in autograph to allow for dynamic control
                    # flow and control dependencies in call. We are limiting this to
                    # subclassed layers as autograph is strictly needed only for
                    # subclassed layers and models.
                    # tf_convert will respect the value of autograph setting in the
                    # enclosing tf.function, if any.
                    if (base_layer_utils.is_subclassed(self) and
                            not base_layer_utils.from_saved_model(self)):
                        call_fn = autograph.tf_convert(
                                self._inference, ag_ctx.control_status_ctx())
                    else:
                        call_fn = self._inference

                    if not self.dynamic:
                        try:
                            with base_layer_utils.autocast_context_manager(
                                    self._compute_dtype):
                                # Add auto_control_deps in V2 when they are not already added by
                                # a `tf.function`.
                                if (ops.executing_eagerly_outside_functions() and
                                        not base_layer_utils.is_in_eager_or_tf_function()):
                                    with auto_control_deps.AutomaticControlDependencies() as acd:
                                        outputs = call_fn(cast_inputs, *args, **kwargs)
                                        # Wrap Tensors in `outputs` in `tf.identity` to avoid
                                        # circular dependencies.
                                        outputs = base_layer_utils.mark_as_return(outputs, acd)
                                else:
                                    outputs = call_fn(cast_inputs, *args, **kwargs)

                        except errors.OperatorNotAllowedInGraphError as e:
                            raise TypeError('You are attempting to use Python control '
                                            'flow in a layer that was not declared to be '
                                            'dynamic. Pass `dynamic=True` to the class '
                                            'constructor.\nEncountered error:\n"""\n' +
                                            str(e) + '\n"""')
                    else:
                        # We will use static shape inference to return symbolic tensors
                        # matching the specifications of the layer outputs.
                        # Since `self.dynamic` is True, we will never attempt to
                        # run the underlying TF graph (which is disconnected).
                        # TODO(fchollet): consider py_func as an alternative, which
                        # would enable us to run the underlying graph if needed.
                        outputs = self._symbolic_call(inputs)

                    if outputs is None:
                        raise ValueError('A layer\'s `call` method should return a '
                                         'Tensor or a list of Tensors, not None '
                                         '(layer: ' + self.name + ').')
                    if base_layer_utils.have_all_keras_metadata(inputs):
                        if training_arg_passed_by_framework:
                            kwargs.pop('training')
                        if mask_arg_passed_by_framework:
                            kwargs.pop('mask')
                        inputs, outputs = self._set_connectivity_metadata_(
                                inputs, outputs, args, kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)
                    if hasattr(self, '_set_inputs') and not self.inputs:
                        # Subclassed network: explicitly set metadata normally set by
                        # a call to self._set_inputs().
                        # TODO(b/120997007): This should be done in Eager as well, but
                        # causes garbage collection issues because of the placeholders
                        # created on the default Keras graph.
                        self._set_inputs(inputs, outputs)
            else:
                # Eager execution on data tensors.
                with backend.name_scope(self._name_scope()):
                    self._maybe_build(inputs)
                    cast_inputs = self._maybe_cast_inputs(inputs)
                    with base_layer_utils.autocast_context_manager(
                            self._compute_dtype):
                        outputs = self._inference(cast_inputs, *args, **kwargs)
                    self._handle_activity_regularization(inputs, outputs)
                    self._set_mask_metadata(inputs, outputs, input_masks)

        return outputs