def testCondNested(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    variables.global_variables_initializer().run()
    p = array_ops.placeholder(dtype=dtypes.bool)
    q = array_ops.placeholder(dtype=dtypes.bool)
    with acd.AutomaticControlDependencies() as c:

      def true_fn():
        v.assign(v + 1, name='true')
        return 1.0

      def false_fn():

        def inner_true_fn():
          v.assign(v * 2, name='false_true')
          return 2.0

        def inner_false_fn():
          v.assign(v * 3, name='false_false')
          return 3.0

        control_flow_ops.cond(q, inner_true_fn, inner_false_fn)
        return 1.0

      control_flow_ops.cond(p, true_fn, false_fn)
      with ops.name_scope('final'):
        val = v.read_value()
      val = c.mark_as_return(val)
    self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0)
    self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0)
    self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0)
    self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
def testSendInOpsWithMustRun(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies() as c:
      send_op = gen_sendrecv_ops.send(v, "x", "/", 0, "/")

    # Send must be in `ops_which_must_run`.
    self.assertIn(send_op, c.ops_which_must_run)
def __init__(self, use_name_scope, function_scope_name, use_auto_deps):
  self.use_name_scope = use_name_scope
  if use_name_scope:
    self.name_scope = ops.name_scope(function_scope_name)

  self.use_auto_deps = use_auto_deps
  if use_auto_deps:
    self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
    self._return_value_marked = False
def testVariableReadsInOpsWithMustRun(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies() as c:
      read_op = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      # Read ops get added to control outputs only if they have consumers.
      c.mark_as_return(read_op.outputs[0])

    self.assertIn(read_op, c.ops_which_must_run)
def testBasic(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies() as c:
      v.assign(v + 1)
      v.assign(2 * v)
      val = v.read_value()
      val = c.mark_as_return(val)
    self.assertAllEqual(val.eval(), 4.0)
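# For contrast, a minimal sketch (a hypothetical test, not part of the
# original suite) of what `testBasic` would have to look like without
# automatic control dependencies: the ordering between the two assigns and
# the final read has to be spelled out by hand with
# `ops.control_dependencies`, which is exactly the bookkeeping that
# `AutomaticControlDependencies` performs for stateful resource ops.
def testBasicManualOrderingSketch(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    assign1 = v.assign(v + 1)
    # Force the second assign (and its read of `v`) to run after the first.
    with ops.control_dependencies([assign1]):
      assign2 = v.assign(2 * v)
    # Force the final read to observe both writes.
    with ops.control_dependencies([assign2]):
      val = v.read_value()
    self.assertAllEqual(val.eval(), 4.0)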
def testNoControlDepsBetweenVariableReads(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies():
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
    self.assertNotIn(read_op1, read_op2.control_inputs)
    self.assertNotIn(read_op2, read_op1.control_inputs)
def testIdentityPassThrough(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies():
      gen_resource_variable_ops.assign_variable_op(v.handle, v + 1)
      identity_handle = gen_array_ops.identity(v.handle)
      assign_op2 = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      read_op = gen_resource_variable_ops.read_variable_op(
          identity_handle, v.dtype).op
    # The read should have a control dep from the second (i.e. last) write,
    # even with Identity applied to the resource handle.
    self.assertIn(assign_op2, read_op.control_inputs)
def testVariableMultipleReadsAndWrites(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies() as c:
      # 2 reads -> 2 writes -> 2 reads -> 2 writes.
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      assign_op1 = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      assign_op2 = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      read_op3 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op4 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      assign_op3 = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      assign_op4 = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      # Read ops get added to control outputs only if they have consumers.
      c.mark_as_return(read_op1.outputs[0])
      c.mark_as_return(read_op2.outputs[0])
      c.mark_as_return(read_op3.outputs[0])
      c.mark_as_return(read_op4.outputs[0])

    # Verify the control edges.
    self.assertIn(read_op1, assign_op1.control_inputs)
    self.assertIn(read_op2, assign_op1.control_inputs)
    self.assertIn(assign_op1, assign_op2.control_inputs)
    self.assertIn(assign_op2, read_op3.control_inputs)
    self.assertIn(assign_op2, read_op4.control_inputs)
    self.assertIn(read_op3, assign_op3.control_inputs)
    self.assertIn(read_op4, assign_op3.control_inputs)
    self.assertIn(assign_op3, assign_op4.control_inputs)

    # There should be no control deps between reads.
    read_ops = [read_op1, read_op2, read_op3, read_op4]
    for src_op, tgt_op in itertools.product(read_ops, read_ops):
      self.assertNotIn(src_op, tgt_op.control_inputs)

    # Reads must be in `ops_which_must_run`.
    self.assertIn(read_op1, c.ops_which_must_run)
    self.assertIn(read_op2, c.ops_which_must_run)
    self.assertIn(read_op3, c.ops_which_must_run)
    self.assertIn(read_op4, c.ops_which_must_run)
    # Last write must be in `ops_which_must_run`.
    self.assertIn(assign_op4, c.ops_which_must_run)
def testVariableWriteThenRead(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies():
      assign_op = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
    # Reads should have a control dep from the last write.
    self.assertIn(assign_op, read_op1.control_inputs)
    self.assertIn(assign_op, read_op2.control_inputs)
    # There should be no control deps between reads.
    self.assertNotIn(read_op1, read_op2.control_inputs)
    self.assertNotIn(read_op2, read_op1.control_inputs)
def testVariableReadsNotInOpsWithMustRun(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies() as c:
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      assign_op = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
    # Reads must not be in `ops_which_must_run` since those get added to the
    # `control_outputs`.
    self.assertNotIn(read_op1, c.ops_which_must_run)
    self.assertNotIn(read_op2, c.ops_which_must_run)
    # Last write must be in `ops_which_must_run`.
    self.assertIn(assign_op, c.ops_which_must_run)
def testVariableReadThenWrite(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies():
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      assign_op = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
    # Writes should have control deps from "all" reads since the last write
    # or the start of the code block.
    self.assertIn(read_op1, assign_op.control_inputs)
    self.assertIn(read_op2, assign_op.control_inputs)
    # There should be no control deps between reads.
    self.assertNotIn(read_op1, read_op2.control_inputs)
    self.assertNotIn(read_op2, read_op1.control_inputs)
def __init__(self, function_name, scope_name, options):
  self.name = scope_name
  self.options = options

  if options.user_requested:
    self.autograph_ctx = ag_ctx.ControlStatusCtx(
        ag_ctx.Status.ENABLED, options)
  self.callopts = options.call_options()

  use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
  self.use_name_scope = use_name_scope
  if use_name_scope:
    self.name_scope = ops.name_scope(self._sanitize(function_name))

  use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
  self.use_auto_deps = use_auto_deps
  if use_auto_deps:
    self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
    self._return_value_marked = False
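# A minimal sketch (an assumption about the surrounding scope class, not
# verbatim library code) of how the two optional scopes initialized above
# would be entered and exited as a context manager: the name scope first,
# then the automatic-control-dependencies scope, unwound in reverse order
# on exit. Only the fields set in `__init__` above are used.
def __enter__(self):
  if self.use_name_scope:
    self.name_scope.__enter__()
  if self.use_auto_deps:
    self.autodeps_scope.__enter__()
  return self

def __exit__(self, exc_type, exc_val, exc_tb):
  # Exit in reverse order of entry.
  if self.use_auto_deps:
    self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb)
  if self.use_name_scope:
    self.name_scope.__exit__(exc_type, exc_val, exc_tb)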
def testCondOneBranch(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    variables.global_variables_initializer().run()
    p = array_ops.placeholder(dtype=dtypes.bool)
    with acd.AutomaticControlDependencies() as c:

      def true_fn():
        return 0.0

      def false_fn():
        v.assign(v + 4)
        return 1.0

      control_flow_ops.cond(p, true_fn, false_fn)
      val = v.read_value()
      val = c.mark_as_return(val)
    self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0)
    self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
def testManualControlDepMonitoringAttrNotAdded(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    self.evaluate(variables.global_variables_initializer())
    with acd.AutomaticControlDependencies():
      read_op1 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      read_op2 = gen_resource_variable_ops.read_variable_op(
          v.handle, v.dtype).op
      assign_op = gen_resource_variable_ops.assign_variable_op(
          v.handle, v + 1)
    # Writes should have control deps automatically added from "all" reads
    # since the last write or the start of the code block.
    self.assertIn(read_op1, assign_op.control_inputs)
    self.assertIn(read_op2, assign_op.control_inputs)
    # But we shouldn't add the monitoring attribute in this case.
    with self.assertRaises(ValueError):
      assign_op.get_attr("_has_manual_control_dependencies")
    with self.assertRaises(ValueError):
      read_op1.get_attr("_has_manual_control_dependencies")
    with self.assertRaises(ValueError):
      read_op2.get_attr("_has_manual_control_dependencies")
def testCondMustRunSeparateRead(self):
  with context.graph_mode(), self.cached_session():
    v = resource_variable_ops.ResourceVariable(1.0)
    variables.global_variables_initializer().run()
    p = array_ops.placeholder(dtype=dtypes.bool)
    with acd.AutomaticControlDependencies() as c:

      def true_fn():
        v.assign(v + 1)
        return 0.0

      def false_fn():
        v.assign(v + 4)
        return 1.0

      control_flow_ops.cond(p, true_fn, false_fn)
      one = constant_op.constant(1.0)
      one = c.mark_as_return(one)
    one.eval(feed_dict={p: False})
    self.assertAllEqual(v.read_value().eval(), 5.0)
    one.eval(feed_dict={p: True})
    self.assertAllEqual(v.read_value().eval(), 6.0)
def inference(self, inputs, *args, **kwargs):
  call_context = base_layer_utils.call_context()
  input_list = nest.flatten(inputs)

  # We will attempt to build a TF graph if & only if all inputs are symbolic.
  # This is always the case in graph mode. It can also be the case in eager
  # mode when all inputs can be traced back to `keras.Input()` (when building
  # models using the functional API).
  build_graph = tf_utils.are_all_symbolic_tensors(input_list)

  # Accept NumPy and scalar inputs by converting to Tensors.
  if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):

    def _convert_non_tensor(x):
      # Don't call `ops.convert_to_tensor` on all `inputs` because
      # `SparseTensors` can't be converted to `Tensor`.
      if isinstance(x, (np.ndarray, float, int)):
        return ops.convert_to_tensor(x)
      return x

    inputs = nest.map_structure(_convert_non_tensor, inputs)
    input_list = nest.flatten(inputs)

  # Handle `mask` propagation from previous layer to current layer. Masks can
  # be propagated explicitly via the `mask` argument, or implicitly via
  # setting the `_keras_mask` attribute on the inputs to a Layer. Masks
  # passed explicitly take priority.
  mask_arg_passed_by_framework = False
  input_masks = self._collect_input_masks(inputs, args, kwargs)
  if (self._expects_mask_arg and input_masks is not None and
      not self._call_arg_was_passed('mask', args, kwargs)):
    mask_arg_passed_by_framework = True
    kwargs['mask'] = input_masks

  # If the `training` argument was not explicitly passed, propagate the
  # `training` value from this layer's calling layer.
  training_arg_passed_by_framework = False
  # Priority 1: `training` was explicitly passed.
  if self._call_arg_was_passed('training', args, kwargs):
    training_value = self._get_call_arg_value('training', args, kwargs)
    if not self._expects_training_arg:
      kwargs.pop('training')
  else:
    training_value = None
    # Priority 2: `training` was passed to a parent layer.
    if call_context.training is not None:
      training_value = call_context.training
    # Priority 3a: `learning_phase()` has been set.
    elif backend.global_learning_phase_is_set():
      training_value = backend.learning_phase()
    # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
    elif build_graph:
      with backend.get_graph().as_default():
        if base_layer_utils.is_in_keras_graph():
          training_value = backend.learning_phase()

    if self._expects_training_arg and training_value is not None:
      # Force `training_value` to be of bool type, which matches the contract
      # for layer/model call args.
      if tensor_util.is_tensor(training_value):
        training_value = math_ops.cast(training_value, dtypes.bool)
      else:
        training_value = bool(training_value)
      kwargs['training'] = training_value
      training_arg_passed_by_framework = True

  # Only create Keras history if at least one tensor originates from a
  # `keras.Input`. Otherwise this Layer may be being used outside the Keras
  # framework.
  if build_graph and base_layer_utils.needs_keras_history(inputs):
    base_layer_utils.create_keras_history(inputs)

  # Clear eager losses on top level model call.
  # We are clearing the losses only on the top level model call and not on
  # every layer/model call because the layer/model may be reused.
  if (base_layer_utils.is_in_eager_or_tf_function() and
      not call_context.in_call):
    self._clear_losses()

  with call_context.enter(self, inputs, build_graph, training_value):
    # Check input assumptions set after layer building, e.g. input shape.
    if build_graph:
      # Symbolic execution on symbolic tensors. We will attempt to build
      # the corresponding TF subgraph inside `backend.get_graph()`.
      # TODO(reedwm): We should assert input compatibility after the inputs
      # are casted, not before.
      input_spec.assert_input_compatibility(self.input_spec, inputs,
                                            self.name)
      if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
          and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
        raise ValueError('Layer %s does not support RaggedTensors as input. '
                         'Inputs received: %s. You can try converting your '
                         'input to a uniform tensor.' % (self.name, inputs))

      graph = backend.get_graph()
      with graph.as_default(), backend.name_scope(self._name_scope()):
        # Build layer if applicable (if the `build` method has been
        # overridden).
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)

        # Wrapping the `call` function in autograph to allow for dynamic
        # control flow and control dependencies in `call`. We are limiting
        # this to subclassed layers as autograph is strictly needed only for
        # subclassed layers and models.
        # tf_convert will respect the value of the autograph setting in the
        # enclosing tf.function, if any.
        if (base_layer_utils.is_subclassed(self) and
            not base_layer_utils.from_saved_model(self)):
          call_fn = autograph.tf_convert(
              self._inference, ag_ctx.control_status_ctx())
        else:
          call_fn = self._inference

        if not self.dynamic:
          try:
            with base_layer_utils.autocast_context_manager(
                self._compute_dtype):
              # Add auto_control_deps in V2 when they are not already added
              # by a `tf.function`.
              if (ops.executing_eagerly_outside_functions() and
                  not base_layer_utils.is_in_eager_or_tf_function()):
                with auto_control_deps.AutomaticControlDependencies() as acd:
                  outputs = call_fn(cast_inputs, *args, **kwargs)
                  # Wrap Tensors in `outputs` in `tf.identity` to avoid
                  # circular dependencies.
                  outputs = base_layer_utils.mark_as_return(outputs, acd)
              else:
                outputs = call_fn(cast_inputs, *args, **kwargs)
          except errors.OperatorNotAllowedInGraphError as e:
            raise TypeError('You are attempting to use Python control '
                            'flow in a layer that was not declared to be '
                            'dynamic. Pass `dynamic=True` to the class '
                            'constructor.\nEncountered error:\n"""\n' +
                            str(e) + '\n"""')
        else:
          # We will use static shape inference to return symbolic tensors
          # matching the specifications of the layer outputs.
          # Since `self.dynamic` is True, we will never attempt to
          # run the underlying TF graph (which is disconnected).
          # TODO(fchollet): consider py_func as an alternative, which
          # would enable us to run the underlying graph if needed.
          outputs = self._symbolic_call(inputs)

        if outputs is None:
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        if base_layer_utils.have_all_keras_metadata(inputs):
          if training_arg_passed_by_framework:
            kwargs.pop('training')
          if mask_arg_passed_by_framework:
            kwargs.pop('mask')
          inputs, outputs = self._set_connectivity_metadata_(
              inputs, outputs, args, kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)
        if hasattr(self, '_set_inputs') and not self.inputs:
          # Subclassed network: explicitly set metadata normally set by
          # a call to self._set_inputs().
          # TODO(b/120997007): This should be done in Eager as well, but
          # causes garbage collection issues because of the placeholders
          # created on the default Keras graph.
          self._set_inputs(inputs, outputs)
    else:
      # Eager execution on data tensors.
      with backend.name_scope(self._name_scope()):
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)
        with base_layer_utils.autocast_context_manager(
            self._compute_dtype):
          outputs = self._inference(cast_inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)

  return outputs
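# A minimal sketch (a hypothetical helper, not the library's actual
# `base_layer_utils.mark_as_return`) of how nested layer outputs could be
# marked as returns of an ACD scope: every tensor in the (possibly nested)
# structure is identity-wrapped by
# `AutomaticControlDependencies.mark_as_return`, so downstream consumers
# depend on the identity rather than directly on ops created inside the
# scope, avoiding circular control dependencies.
def _mark_outputs_as_return(outputs, acd_scope):
  # `nest.map_structure` applies the wrapper to every leaf tensor while
  # preserving the structure of `outputs`.
  return nest.map_structure(acd_scope.mark_as_return, outputs)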