def wrapper(self, *args, **kwargs):
  """Calls `method` inside the layer's call context with losses reset.

  The layer's losses are reset before the call and restored afterwards, so
  whatever `method` does to the layer's losses does not outlive the call.

  Args:
    *args: Positional arguments forwarded to `method`.
    **kwargs: Keyword arguments forwarded to `method`.

  Returns:
    Whatever `method(self, *args, **kwargs)` returns.
  """
  layer = self.call_collection.layer
  original_losses = _reset_layer_losses(layer)
  try:
    # Positional arguments after `layer`: inputs=None, build_graph=True,
    # training=None.
    with base_layer_utils.call_context().enter(layer, None, True, None):
      return method(self, *args, **kwargs)
  finally:
    # Restore even when `method` raises; otherwise the layer would be left
    # with its losses permanently reset.
    _restore_layer_losses(original_losses)
def _wrapped_model(*args):
  """A concrete tf.function that wraps the model's call function."""
  # A lone argument is handed to the call context as the bare tensor, not as
  # a one-element list, mirroring how Keras models are called on one input.
  if len(args) == 1:
    inputs = args[0]
  else:
    inputs = list(args)

  context = base_layer_utils.call_context()
  with context.enter(
      model, inputs=inputs, build_graph=False, training=False, saving=True):
    return model(*args, training=False)
def wrapper(self, *args, **kwargs):
  """Calls method within call context.

  The `training` value, when the caller supplied one, is looked up from the
  call arguments and propagated to the call context. The layer's losses are
  reset for the duration of the call and restored afterwards.

  Args:
    *args: Positional arguments forwarded to `method`.
    **kwargs: Keyword arguments forwarded to `method`.

  Returns:
    Whatever `method(self, *args, **kwargs)` returns.
  """
  layer = self.call_collection.layer
  training = None
  # pylint: disable=protected-access
  if (args or kwargs) and layer._call_arg_was_passed(
      'training', args, kwargs, inputs_in_args=True):
    training = layer._get_call_arg_value(
        'training', args, kwargs, inputs_in_args=True)
  # pylint: enable=protected-access
  original_losses = _reset_layer_losses(layer)
  try:
    # Positional arguments after `layer`: inputs=None, build_graph=True.
    with base_layer_utils.call_context().enter(layer, None, True, training):
      return method(self, *args, **kwargs)
  finally:
    # Restore even when `method` raises; otherwise the layer would be left
    # with its losses permanently reset.
    _restore_layer_losses(original_losses)
def _wrapped_model(*args):
  """A concrete tf.function that wraps the model's call function.

  Returns:
    A dict mapping each output name to the corresponding flattened output.
  """
  # When the signature has a single input, the model is called on the tensor
  # itself rather than on a one-element list.
  if len(input_signature) == 1:
    inputs = args[0]
  else:
    inputs = list(args)

  context = base_layer_utils.call_context()
  with context.enter(
      model, inputs=inputs, build_graph=False, training=False, saving=True):
    flat_outputs = nest.flatten(model(inputs=inputs, training=False))

  try:
    names = model.output_names
  except AttributeError:
    # Models without `output_names` get generic names instead.
    from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
    names = training_utils.generic_output_names(flat_outputs)
  return dict(zip(names, flat_outputs))
def _wrapped_model(*args):
  """A concrete tf.function that wraps the model's call function.

  Returns:
    A flat dict mapping each output name to the corresponding output.
  """
  # When the signature has a single input, the model is called on the tensor
  # itself rather than on a one-element list.
  if len(input_signature) == 1:
    inputs = args[0]
  else:
    inputs = list(args)

  context = base_layer_utils.call_context()
  with context.enter(
      model, inputs=inputs, build_graph=False, training=False, saving=True):
    outputs = model(inputs, training=False)

  # Outputs always has to be a flat dict.
  names = model.output_names  # Functional Model.
  if names is None:  # Subclassed Model.
    from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
    names = compile_utils.create_pseudo_output_names(outputs)
  flat_outputs = nest.flatten(outputs)
  return dict(zip(names, flat_outputs))
def test_Bidirectional_updates(self):
  """Checks updates added to both wrapped RNNs show up on the wrapper."""
  if context.executing_eagerly():
    self.skipTest('layer.updates is only available in graph mode.')

  with self.cached_session():
    inputs = keras.layers.Input(shape=(3, 2))
    reachable_update = inputs * inputs
    bidirectional = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
    bidirectional(inputs)
    assert not bidirectional.updates

    # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
    call_ctx = base_layer_utils.call_context()
    with call_ctx.enter(bidirectional, inputs, True, None):
      # Each direction gets one update tied to the inputs and one with no
      # inputs, forward layer first.
      for sublayer in (bidirectional.forward_layer,
                       bidirectional.backward_layer):
        sublayer.add_update(reachable_update, inputs=inputs)
        sublayer.add_update(1, inputs=None)
    assert len(bidirectional.updates) == 4
def wrapper(*args, **kwargs):
  """Calls method within call context.

  The input and `training` values are looked up through `call_collection` and
  propagated to the call context, which is entered in saving mode. The
  layer's losses are reset for the duration of the call and restored
  afterwards.

  Args:
    *args: Positional arguments forwarded to `method`.
    **kwargs: Keyword arguments forwarded to `method`.

  Returns:
    Whatever `method(*args, **kwargs)` returns.
  """
  layer = call_collection.layer
  training = None
  inputs = call_collection.get_input_arg_value(args, kwargs)
  if (args or kwargs) and call_collection.training_arg_was_passed(
      args, kwargs):
    training = call_collection.get_training_arg_value(args, kwargs)
  original_losses = _reset_layer_losses(layer)
  try:
    with base_layer_utils.call_context().enter(
        layer, inputs=inputs, build_graph=False, training=training,
        saving=True):
      with base_layer_utils.autocast_context_manager(layer._compute_dtype):  # pylint: disable=protected-access
        return method(*args, **kwargs)
  finally:
    # Restore even when `method` raises; otherwise the layer would be left
    # with its losses permanently reset.
    _restore_layer_losses(original_losses)
def wrapper(*args, **kwargs):
  """Calls method within call context.

  When a `training` argument was passed, its value and the first positional
  argument (taken as the inputs) are propagated to the call context. The
  layer's losses are reset for the duration of the call and restored
  afterwards.

  Args:
    *args: Positional arguments forwarded to `method`.
    **kwargs: Keyword arguments forwarded to `method`.

  Returns:
    Whatever `method(*args, **kwargs)` returns.
  """
  layer = call_collection.layer
  training = None
  inputs = None
  if (args or kwargs) and call_collection.training_arg_was_passed(
      args, kwargs):
    # Everything may have been passed by keyword, so only read the first
    # positional argument when there is one instead of raising IndexError.
    inputs = args[0] if args else None
    training = call_collection.get_training_arg_value(args, kwargs)
  original_losses = _reset_layer_losses(layer)
  try:
    with base_layer_utils.call_context().enter(
        layer, inputs=inputs, build_graph=False, training=training):
      return method(*args, **kwargs)
  finally:
    # Restore even when `method` raises; otherwise the layer would be left
    # with its losses permanently reset.
    _restore_layer_losses(original_losses)
def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  While the functions are built and traced, the call functions of the child
  layers are replaced and the layer's losses are reset. Both are restored
  before returning, including when wrapping or tracing raises.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, keras_load.RevivedSequential)):
    return {
        fn_name: getattr(layer.keras_api, fn_name, None)
        for fn_name in serialized_attributes.LayerAttributes.all_functions
    }

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)

  try:
    # Wrap all the layer call and activity regularizer functions.

    # Use LayerCallCollection to ensure that all layer call functions
    # (__call__, call with losses) are traced with the same inputs.
    call_collection = LayerCallCollection(layer)
    call_fn_with_losses = call_collection.add_function(
        _wrap_call_and_conditional_losses(layer),
        '{}_layer_call_and_return_conditional_losses'.format(layer.name))
    call_fn = call_collection.add_function(
        _extract_outputs_from_fn(layer, call_fn_with_losses),
        '{}_layer_call_fn'.format(layer.name))

    fns = {
        'call_and_return_conditional_losses': call_fn_with_losses,
        '__call__': call_fn,
    }

    if layer.activity_regularizer is not None:
      fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
      fns['call_and_return_all_conditional_losses'] = (
          call_collection.add_function(
              _append_activity_regularizer_loss(
                  layer, call_fn_with_losses, fns['activity_regularizer_fn']),
              '{}_layer_call_and_return_all_conditional_losses'.format(
                  layer.name)))
    else:
      fns['activity_regularizer_fn'] = None
      fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

    # Manually trigger traces before restoring the overwritten functions. The
    # functions are traced within the layer call context to ensure that layer
    # functions (e.g. add_loss) behave as though running in graph mode.
    with base_layer_utils.call_context().enter(
        layer, inputs=None, build_graph=True, training=None, saving=True):
      for fn in fns.values():
        if fn is not None and fn.input_signature is not None:
          fn.get_concrete_function()
  finally:
    # Restore overwritten functions and losses. Doing this in `finally` keeps
    # a failed wrap or trace from leaving the child layers with replaced call
    # functions and the layer with its losses reset.
    _restore_child_layer_functions(original_fns)
    _restore_layer_losses(original_losses)

  return fns
def _is_building_keras_layer():
  """Returns True when the current Keras call context has a layer set."""
  active_layer = base_layer_utils.call_context().layer
  return active_layer is not None
def inference(self, inputs, *args, **kwargs):
  """Applies `self._inference` to `inputs` with layer-call pre/post-processing.

  Before the wrapped call this method converts NumPy/float/int inputs to
  tensors, injects the `mask` and `training` keyword arguments when the caller
  did not pass them, and enters the Keras call context. After the call it
  applies activity regularization and mask metadata; on the graph-building
  path it may also record connectivity metadata.

  Args:
    inputs: Input tensor(s); may be a nested structure (it is flattened with
      `nest.flatten`).
    *args: Additional positional arguments forwarded to the wrapped call.
    **kwargs: Additional keyword arguments forwarded to the wrapped call.
      `mask` and `training` may be added by this method, and an explicitly
      passed `training` is dropped when the layer does not expect it.

  Returns:
    The outputs of the wrapped call. On the graph-building path these are the
    outputs of `self._symbolic_call(inputs)` when `self.dynamic` is set, and
    they may be replaced by what `self._set_connectivity_metadata_` returns.

  Raises:
    ValueError: On the graph-building path, if a RaggedTensor is passed while
      `self._supports_ragged_inputs` is False, or if the outputs are None.
    TypeError: If the wrapped call raises `OperatorNotAllowedInGraphError`
      on the graph-building path of a non-dynamic layer.
  """
  call_context = base_layer_utils.call_context()
  input_list = nest.flatten(inputs)

  # We will attempt to build a TF graph if & only if all inputs are symbolic.
  # This is always the case in graph mode. It can also be the case in eager
  # mode when all inputs can be traced back to `keras.Input()` (when building
  # models using the functional API).
  build_graph = tf_utils.are_all_symbolic_tensors(input_list)

  # Accept NumPy and scalar inputs by converting to Tensors.
  if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
    def _convert_non_tensor(x):
      # Don't call `ops.convert_to_tensor` on all `inputs` because
      # `SparseTensors` can't be converted to `Tensor`.
      if isinstance(x, (np.ndarray, float, int)):
        return ops.convert_to_tensor(x)
      return x
    inputs = nest.map_structure(_convert_non_tensor, inputs)
    input_list = nest.flatten(inputs)

  # Handle `mask` propagation from previous layer to current layer. Masks can
  # be propagated explicitly via the `mask` argument, or implicitly via
  # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
  # explicitly take priority.
  mask_arg_passed_by_framework = False
  input_masks = self._collect_input_masks(inputs, args, kwargs)
  if (self._expects_mask_arg and input_masks is not None and
      not self._call_arg_was_passed('mask', args, kwargs)):
    mask_arg_passed_by_framework = True
    kwargs['mask'] = input_masks

  # If `training` argument was not explicitly passed, propagate `training`
  # value from this layer's calling layer.
  training_arg_passed_by_framework = False
  # Priority 1: `training` was explicitly passed.
  if self._call_arg_was_passed('training', args, kwargs):
    training_value = self._get_call_arg_value('training', args, kwargs)
    if not self._expects_training_arg:
      kwargs.pop('training')
  else:
    training_value = None
    # Priority 2: `training` was passed to a parent layer.
    if call_context.training is not None:
      training_value = call_context.training
    # Priority 3a: `learning_phase()` has been set.
    elif backend.global_learning_phase_is_set():
      training_value = backend.learning_phase()
    # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
    elif build_graph:
      with backend.get_graph().as_default():
        if base_layer_utils.is_in_keras_graph():
          training_value = backend.learning_phase()

    if self._expects_training_arg and training_value is not None:
      # Force the training_value to be bool type which matches to the contract
      # for layer/model call args.
      if tensor_util.is_tensor(training_value):
        training_value = math_ops.cast(training_value, dtypes.bool)
      else:
        training_value = bool(training_value)
      kwargs['training'] = training_value
      training_arg_passed_by_framework = True

  # Only create Keras history if at least one tensor originates from a
  # `keras.Input`. Otherwise this Layer may be being used outside the Keras
  # framework.
  if build_graph and base_layer_utils.needs_keras_history(inputs):
    base_layer_utils.create_keras_history(inputs)

  # Clear eager losses on top level model call.
  # We are clearing the losses only on the top level model call and not on
  # every layer/model call because layer/model may be reused.
  if (base_layer_utils.is_in_eager_or_tf_function() and
      not call_context.in_call):
    self._clear_losses()

  with call_context.enter(self, inputs, build_graph, training_value):
    # Check input assumptions set after layer building, e.g. input shape.
    if build_graph:
      # Symbolic execution on symbolic tensors. We will attempt to build
      # the corresponding TF subgraph inside `backend.get_graph()`
      # TODO(reedwm): We should assert input compatibility after the inputs
      # are casted, not before.
      input_spec.assert_input_compatibility(self.input_spec, inputs,
                                            self.name)
      if (any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list)
          and self._supports_ragged_inputs is False):  # pylint: disable=g-bool-id-comparison
        raise ValueError('Layer %s does not support RaggedTensors as input. '
                         'Inputs received: %s. You can try converting your '
                         'input to an uniform tensor.' % (self.name, inputs))

      graph = backend.get_graph()
      with graph.as_default(), backend.name_scope(self._name_scope()):
        # Build layer if applicable (if the `build` method has been
        # overridden).
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)

        # Wrapping `call` function in autograph to allow for dynamic control
        # flow and control dependencies in call. We are limiting this to
        # subclassed layers as autograph is strictly needed only for
        # subclassed layers and models.
        # tf_convert will respect the value of autograph setting in the
        # enclosing tf.function, if any.
        if (base_layer_utils.is_subclassed(self) and
            not base_layer_utils.from_saved_model(self)):
          call_fn = autograph.tf_convert(
              self._inference, ag_ctx.control_status_ctx())
        else:
          call_fn = self._inference

        if not self.dynamic:
          try:
            with base_layer_utils.autocast_context_manager(
                self._compute_dtype):
              # Add auto_control_deps in V2 when they are not already added by
              # a `tf.function`.
              if (ops.executing_eagerly_outside_functions() and
                  not base_layer_utils.is_in_eager_or_tf_function()):
                with auto_control_deps.AutomaticControlDependencies() as acd:
                  outputs = call_fn(cast_inputs, *args, **kwargs)
                  # Wrap Tensors in `outputs` in `tf.identity` to avoid
                  # circular dependencies.
                  outputs = base_layer_utils.mark_as_return(outputs, acd)
              else:
                outputs = call_fn(cast_inputs, *args, **kwargs)

          except errors.OperatorNotAllowedInGraphError as e:
            raise TypeError('You are attempting to use Python control '
                            'flow in a layer that was not declared to be '
                            'dynamic. Pass `dynamic=True` to the class '
                            'constructor.\nEncountered error:\n"""\n' +
                            str(e) + '\n"""')
        else:
          # We will use static shape inference to return symbolic tensors
          # matching the specifications of the layer outputs.
          # Since `self.dynamic` is True, we will never attempt to
          # run the underlying TF graph (which is disconnected).
          # TODO(fchollet): consider py_func as an alternative, which
          # would enable us to run the underlying graph if needed.
          outputs = self._symbolic_call(inputs)

        if outputs is None:
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        if base_layer_utils.have_all_keras_metadata(inputs):
          # Remove the arguments this method injected itself so that only
          # caller-supplied kwargs are passed on to the connectivity metadata.
          if training_arg_passed_by_framework:
            kwargs.pop('training')
          if mask_arg_passed_by_framework:
            kwargs.pop('mask')
          inputs, outputs = self._set_connectivity_metadata_(
              inputs, outputs, args, kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)
        if hasattr(self, '_set_inputs') and not self.inputs:
          # Subclassed network: explicitly set metadata normally set by
          # a call to self._set_inputs().
          # TODO(b/120997007): This should be done in Eager as well, but
          # causes garbage collection issues because of the placeholders
          # created on the default Keras graph.
          self._set_inputs(inputs, outputs)
    else:
      # Eager execution on data tensors.
      with backend.name_scope(self._name_scope()):
        self._maybe_build(inputs)
        cast_inputs = self._maybe_cast_inputs(inputs)
        with base_layer_utils.autocast_context_manager(
            self._compute_dtype):
          outputs = self._inference(cast_inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        self._set_mask_metadata(inputs, outputs, input_masks)

  return outputs