def wrap_with_training_arg(*args, **kwargs):
    """Invoke `wrapped_call` with an explicit boolean `training` argument.

    The `training` value is looked up in the call arguments through
    `call_spec`. If the lookup raises `KeyError` or yields `None`, the value
    falls back to the first truthy entry among `default_training_value`, the
    current call context's `training`, and `backend.learning_phase()` (the
    last one is used when none is truthy).

    The resolved value is passed as the predicate to
    `control_flow_util.smart_cond`; the two branches call `wrapped_call` with
    `training` set to `True` and `False` respectively.
    """
    try:
        training = call_spec.get_arg_value(
            "training", args, kwargs, inputs_in_args=True
        )
    except KeyError:
        # `call_spec` could not provide a `training` value; treat it as unset.
        training = None

    if training is None:
        training = (
            default_training_value
            or base_layer_utils.call_context().training
            or backend.learning_phase()
        )

    # Work on copies so the caller's argument containers are never touched.
    positional = list(args)
    keywords = kwargs.copy()

    def call_with_training(value):
        call_args, call_kwargs = call_spec.set_arg_value(
            "training", value, positional, keywords, inputs_in_args=True
        )
        return wrapped_call(*call_args, **call_kwargs)

    def when_training():
        return call_with_training(True)

    def when_not_training():
        return call_with_training(False)

    return control_flow_util.smart_cond(
        training, when_training, when_not_training
    )
def _wrapped_model(*args, **kwargs):
    """Call `model` with `training=False` and return its outputs as a dict.

    The `training` argument is forced to `False` through the model's call
    spec, and the model is invoked inside a call context entered with
    `training=False` and `saving=True`. The outputs are flattened with
    `tf.nest.flatten` and keyed by `model.output_names`, or by pseudo output
    names when `model.output_names` is `None`.
    """
    call_args, call_kwargs = model._call_spec.set_arg_value(
        "training", False, args, kwargs, inputs_in_args=True
    )

    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True
    ):
        outputs = model(*call_args, **call_kwargs)

    # Outputs always has to be a flat dict.
    names = model.output_names  # Functional Model.
    if names is None:  # Subclassed Model.
        from keras.engine import compile_utils

        # The pseudo names are derived from the still-nested outputs.
        names = compile_utils.create_pseudo_output_names(outputs)

    flat_outputs = tf.nest.flatten(outputs)
    return dict(zip(names, flat_outputs))
def wrapper(*args, **kwargs):
    """Calls method within call context.

    The layer's losses are reset before `method` runs and restored afterwards,
    including when `method` raises, so a failed call does not leave the layer
    with its losses cleared.

    Args:
        *args: Positional arguments forwarded to `method`.
        **kwargs: Keyword arguments forwarded to `method`.

    Returns:
        Whatever `method(*args, **kwargs)` returns.
    """
    layer = call_collection.layer
    training = None
    inputs = _filtered_inputs([args, kwargs])

    # Only query the training value when something was actually passed in.
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs
    ):
        training = call_collection.get_training_arg_value(args, kwargs)

    original_losses = _reset_layer_losses(layer)
    try:
        with base_layer_utils.call_context().enter(
            layer,
            inputs=inputs,
            build_graph=False,
            training=training,
            saving=True,
        ):
            with autocast_variable.enable_auto_cast_variables(
                layer._compute_dtype_object
            ):
                ret = method(*args, **kwargs)
    finally:
        # Restore even on failure so the layer is left as it was found.
        _restore_layer_losses(original_losses)
    return ret
def _wrapped_model(*args, **kwargs):
    """Call `model` with `training=False` and return its outputs as a dict.

    `training=False` is written into the keyword arguments before the model is
    invoked inside a call context entered with `training=False` and
    `saving=True`. The outputs are flattened with `tf.nest.flatten` and keyed
    by `model.output_names`, or by pseudo output names when
    `model.output_names` is `None`.
    """
    kwargs.update(training=False)

    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True
    ):
        outputs = model(*args, **kwargs)

    # Outputs always has to be a flat dict.
    names = model.output_names  # Functional Model.
    if names is None:  # Subclassed Model.
        from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top

        # The pseudo names are derived from the still-nested outputs.
        names = compile_utils.create_pseudo_output_names(outputs)

    flat_outputs = tf.nest.flatten(outputs)
    return dict(zip(names, flat_outputs))
def _wrapped_model(*args):
    """Call `model` on the positional inputs with `training=False`.

    The model is invoked inside a call context entered with `training=False`
    and `saving=True`. The outputs are flattened with `tf.nest.flatten` and
    returned as a dict keyed by `model.output_names`, or by pseudo output
    names when `model.output_names` is `None`.
    """
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    if len(input_signature) == 1:
        inputs = args[0]
    else:
        inputs = list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True
    ):
        outputs = model(inputs, training=False)

    # Outputs always has to be a flat dict.
    names = model.output_names  # Functional Model.
    if names is None:  # Subclassed Model.
        from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top

        # The pseudo names are derived from the still-nested outputs.
        names = compile_utils.create_pseudo_output_names(outputs)

    flat_outputs = tf.nest.flatten(outputs)
    return dict(zip(names, flat_outputs))
def test_Bidirectional_updates(self):
    """Updates added to both wrapped RNN layers are reported by the wrapper.

    A freshly called `Bidirectional` layer must have no updates. After two
    updates are added to each of its forward and backward layers, the wrapper
    must report four.
    """
    if tf.executing_eagerly():
        self.skipTest("layer.updates is only available in graph mode.")

    with self.cached_session():
        inputs = keras.layers.Input(shape=(3, 2))
        reachable_update = inputs * inputs
        layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
        layer(inputs)
        assert not layer.updates

        # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
        with base_layer_utils.call_context().enter(layer, inputs, True, None):
            # Forward layer first, then backward; each gets one update built
            # from the inputs and one constant update.
            for sublayer in (layer.forward_layer, layer.backward_layer):
                sublayer.add_update(reachable_update)
                sublayer.add_update(1)

        assert len(layer.updates) == 4
def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

    While the functions are wrapped and traced, the call functions of the
    child layers are replaced and the layer losses are reset. Both are
    restored before returning, including when wrapping or tracing raises, so
    a failure does not leave the layer in its modified state.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if (isinstance(layer, keras_load.RevivedLayer) and
            not isinstance(layer, sequential_lib.Sequential)):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in each
    # child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    try:
        # Wrap all the layer call and activity regularizer functions.

        # Use LayerCallCollection to ensure that all layer call functions
        # (__call__, call with losses) are traced with the same inputs.
        call_collection = LayerCallCollection(layer)
        call_fn_with_losses = call_collection.add_function(
            _wrap_call_and_conditional_losses(layer),
            '{}_layer_call_and_return_conditional_losses'.format(layer.name))
        call_fn = call_collection.add_function(
            _extract_outputs_from_fn(layer, call_fn_with_losses),
            '{}_layer_call_fn'.format(layer.name))

        fns = {'call_and_return_conditional_losses': call_fn_with_losses,
               '__call__': call_fn}

        if layer._activity_regularizer is not None:  # pylint: disable=protected-access
            fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
            fns['call_and_return_all_conditional_losses'] = (
                call_collection.add_function(
                    _append_activity_regularizer_loss(
                        layer, call_fn_with_losses,
                        fns['activity_regularizer_fn']),
                    '{}_layer_call_and_return_all_conditional_losses'.format(
                        layer.name)
                ))
        else:
            fns['activity_regularizer_fn'] = None
            fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

        # Manually trigger traces before restoring the overwritten functions.
        # The functions are traced within the layer call context to ensure that
        # layer functions (e.g. add_loss) behave as though running in graph
        # mode.
        with tracing_scope():
            call_collection.trace_with_input_signature()
            with base_layer_utils.call_context().enter(
                    layer, inputs=None, build_graph=True, training=None,
                    saving=True):
                for fn in fns.values():
                    if fn is not None and fn.input_signature is not None:
                        # For LayerCall objects, trace the wrapped function.
                        if isinstance(fn, LayerCall):
                            fn = fn.wrapped_call
                        fn.get_concrete_function()
    finally:
        # Restore overwritten functions and losses, even if wrapping or tracing
        # failed part-way through.
        _restore_child_layer_functions(original_fns)
        _restore_layer_losses(original_losses)

    return fns
def wrap_layer_functions(layer, serialization_cache):
    """Returns dict of wrapped layer call function and losses in tf.functions.

    While the functions are wrapped and traced, the call functions of the
    child layers are replaced and the layer losses are reset. Both are
    restored before returning, including when wrapping or tracing raises, so
    a failure does not leave the layer in its modified state.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all keras tf.functions to serialize. See
      LayerAttributes and ModelAttributes for the list of all attributes.
    """
    # Since Sequential models may be modified in place using model.add() or
    # model.pop(), don't use saved functions.
    if isinstance(layer, keras_load.RevivedLayer) and not isinstance(
        layer, sequential_lib.Sequential
    ):
        return {
            fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions
        }

    # Reset the losses of the layer and its children. The call function in each
    # child layer is replaced with tf.functions.
    original_fns = _replace_child_layer_functions(layer, serialization_cache)
    original_losses = _reset_layer_losses(layer)

    try:
        # Wrap all the layer call and activity regularizer functions.

        # Use LayerCallCollection to ensure that all layer call functions
        # (__call__, call with losses) are traced with the same inputs.
        call_collection = LayerCallCollection(layer)
        call_fn_with_losses = call_collection.add_function(
            _wrap_call_and_conditional_losses(layer),
            "{}_layer_call_and_return_conditional_losses".format(layer.name),
            # If any of this layer's child layers use the training arg, the
            # traced call functions of this layer will have a training keyword
            # argument. If the original layer does not expect the training arg,
            # then it will have to be removed (by setting
            # `match_layer_training_arg`).
            match_layer_training_arg=True,
        )
        call_fn = call_collection.add_function(
            _extract_outputs_from_fn(layer, call_fn_with_losses),
            "{}_layer_call_fn".format(layer.name),
            # Since `call_fn` wraps call_fn_with_losses and not the original
            # call function, `match_layer_training_arg` should be set to False.
            match_layer_training_arg=False,
        )

        fns = {
            "call_and_return_conditional_losses": call_fn_with_losses,
            "__call__": call_fn,
        }

        if layer._activity_regularizer is not None:  # pylint: disable=protected-access
            fns["activity_regularizer_fn"] = _wrap_activity_regularizer(layer)
            fns["call_and_return_all_conditional_losses"] = (
                call_collection.add_function(
                    _append_activity_regularizer_loss(
                        layer,
                        call_fn_with_losses,
                        fns["activity_regularizer_fn"],
                    ),
                    "{}_layer_call_and_return_all_conditional_losses".format(
                        layer.name
                    ),
                    match_layer_training_arg=False,
                )
            )
        else:
            fns["activity_regularizer_fn"] = None
            fns["call_and_return_all_conditional_losses"] = call_fn_with_losses

        # Manually trigger traces before restoring the overwritten functions.
        # The functions are traced within the layer call context to ensure that
        # layer functions (e.g. add_loss) behave as though running in graph
        # mode.
        with tracing_scope():
            call_collection.trace_with_input_signature()
            with base_layer_utils.call_context().enter(
                layer,
                inputs=None,
                build_graph=True,
                training=None,
                saving=True,
            ):
                for fn in fns.values():
                    # LayerCall objects are skipped by this loop.
                    if fn is not None and not isinstance(fn, LayerCall):
                        fn.get_concrete_function()
    finally:
        # Restore overwritten functions and losses, even if wrapping or tracing
        # failed part-way through.
        _restore_child_layer_functions(original_fns)
        _restore_layer_losses(original_losses)

    return fns