def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    Args:
      identifier: The serialized object. A config dict is resolved to a class
        and rebuilt; a string is looked up by name; an already-deserialized
        function is returned unchanged; `None` yields `None`.
      module_objects: Optional dict of built-in objects to look string names
        up in. May be `None`, in which case only `custom_objects` and the
        global custom-object registry are consulted.
      custom_objects: Optional dict of user-supplied objects. For string
        names it is consulted before the global registry and
        `module_objects`.
      printable_module_name: Human-readable kind of object, used in error
        messages.

    Returns:
      The deserialized object. Classes looked up by string name are
      instantiated with no arguments; other looked-up objects are returned
      as-is.

    Raises:
      ValueError: If a string name cannot be resolved, or if `identifier` has
        an unsupported type.
    """
    if identifier is None:
        return None
    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # `from_config` takes custom objects explicitly: pass the
                # global registry merged with the caller's dict (caller's
                # entries win on key collisions).
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            # Otherwise expose the custom objects through the ambient scope.
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as an empty lookup
            # so the informative ValueError below is raised instead of an
            # AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))
        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def __init__(self, x, y, image_data_generator,
             batch_size=32,
             shuffle=False,
             sample_weight=None,
             seed=None,
             data_format=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png',
             subset=None,
             dtype=None):
    """Initializes the iterator by delegating to the base-class initializer.

    `data_format` falls back to `backend.image_data_format()` when None.
    `dtype` is forwarded only if `image.NumpyArrayIterator.__init__` declares
    a `dtype` among its regular (non-keyword-only) parameters; in that case a
    None `dtype` is replaced by `backend.floatx()`. Every other argument is
    passed through unchanged.
    """
    if data_format is None:
        data_format = backend.image_data_format()

    base_params = tf_inspect.getfullargspec(
        image.NumpyArrayIterator.__init__).args
    dtype_kwargs = {}
    if 'dtype' in base_params:
        dtype_kwargs['dtype'] = backend.floatx() if dtype is None else dtype

    super(NumpyArrayIterator, self).__init__(
        x, y, image_data_generator,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset,
        **dtype_kwargs)
def _recreate_layer_from_config(self, layer, go_backwards=False): # When recreating the layer from its config, it is possible that the # layer is a RNN layer that contains custom cells. In this case we # inspect the layer and pass the custom cell class as part of the # `custom_objects` argument when calling `from_config`. See # https://github.com/tensorflow/tensorflow/issues/26581 for more detail. config = layer.get_config() if go_backwards: config["go_backwards"] = not config["go_backwards"] if ( "custom_objects" in tf_inspect.getfullargspec(layer.__class__.from_config).args ): custom_objects = {} cell = getattr(layer, "cell", None) if cell is not None: custom_objects[cell.__class__.__name__] = cell.__class__ # For StackedRNNCells stacked_cells = getattr(cell, "cells", []) for c in stacked_cells: custom_objects[c.__class__.__name__] = c.__class__ return layer.__class__.from_config( config, custom_objects=custom_objects ) else: return layer.__class__.from_config(config)
def __init__(self, cell, *args, **kwargs):
    """Stores `cell` and records whether its `call` can receive `training`.

    All extra positional and keyword arguments are forwarded to the
    base-class initializer. The call spec is marked as expecting a training
    argument when `cell.call` names a regular `training` parameter or accepts
    `**kwargs`.
    """
    super().__init__(*args, **kwargs)
    self.cell = cell
    call_spec = tf_inspect.getfullargspec(cell.call)
    takes_training = "training" in call_spec.args or call_spec.varkw is not None
    self._call_spec.expects_training_arg = takes_training
def __init__(self, layer):
    """Sets up bookkeeping for tracing the call functions of `layer`.

    Records the layer's call method, whether it uses a training argument and
    at which index, and whether the call takes keyword-style arguments (in
    which case traced functions cannot carry an input signature). It also
    records the generated input signature, a weak-valued function registry
    and the name of the inputs argument.
    """
    self.layer = layer
    call_method = _get_layer_call_method(layer)
    self.layer_call_method = call_method
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    call_spec = tf_inspect.getfullargspec(call_method)
    self._has_kwargs = bool(
        self._expects_training_arg
        or call_spec.defaults
        or call_spec.kwonlyargs
        or call_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    self._functions = weakref.WeakValueDictionary()

    # True while this object is tracing the layer call functions.
    self.tracing = False

    # The first regular parameter (skipping self for bound methods) names the
    # inputs argument; fall back to 'inputs' when there is none.
    if tf_inspect.ismethod(call_method):
        positional = call_spec.args[1:]
    else:
        positional = call_spec.args
    self._input_arg_name = positional[0] if positional else 'inputs'
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Args:
      call_fn: Call function.

    Returns:
      - n: index of 'training' among the regular parameters of `call_fn`
        (excluding the leading bound argument when `call_is_method(call_fn)`
        holds). Only considered when `call_fn` has no `*args`.
      - -1: if 'training' is not found there, but it is keyword-only or
        `call_fn` accepts variable keyword arguments.
      - None: if the layer doesn't expect a training argument.
    """
    spec = tf_inspect.getfullargspec(call_fn)
    if not spec.varargs:
        # Without *args, 'training' may be a regular positional parameter.
        params = spec.args[1:] if call_is_method(call_fn) else spec.args
        if 'training' in params:
            return params.index('training')
    # With *args present, training must be a keyword arg.
    if 'training' in spec.kwonlyargs or spec.varkw:
        return -1
    return None
def fn_args(fn):
    """Get argument names for function-like object.

    Args:
      fn: Function, or function-like object (e.g., result of
        `functools.partial`, or an object whose `__call__` is a method).

    Returns:
      `tuple` of string argument names. For a `functools.partial`, the
      parameters filled by its bound positional arguments and those bound by
      keyword are dropped. For a bound method, the first regular parameter
      (self/cls), if any, is dropped.
    """
    if isinstance(fn, functools.partial):
        args = fn_args(fn.func)
        args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
    else:
        if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
            fn = fn.__call__
        # Copy before mutating: the argspec returned for a decorated callable
        # may be an object stored on the decorator, so popping from its list
        # in place would corrupt it for every later inspection.
        args = list(tf_inspect.getfullargspec(fn).args)
        if is_bound_method(fn) and args:
            # If it's a bound method, it may or may not have a self/cls first
            # argument; for example, self could be captured in *args.
            # If it does have a positional argument, it is self/cls.
            args.pop(0)
    return tuple(args)
def __init__(self,
             featurewise_center=False,
             samplewise_center=False,
             featurewise_std_normalization=False,
             samplewise_std_normalization=False,
             zca_whitening=False,
             zca_epsilon=1e-6,
             rotation_range=0,
             width_shift_range=0.,
             height_shift_range=0.,
             brightness_range=None,
             shear_range=0.,
             zoom_range=0.,
             channel_shift_range=0.,
             fill_mode='nearest',
             cval=0.,
             horizontal_flip=False,
             vertical_flip=False,
             rescale=None,
             preprocessing_function=None,
             data_format=None,
             validation_split=0.0,
             dtype=None):
    """Initializes the generator by delegating to the base-class initializer.

    `data_format` falls back to `backend.image_data_format()` when None.
    `dtype` is forwarded only if `image.ImageDataGenerator.__init__` declares
    a `dtype` among its regular parameters; in that case a None `dtype` is
    replaced by `backend.floatx()`. All augmentation settings are passed
    through unchanged.
    """
    if data_format is None:
        data_format = backend.image_data_format()

    base_params = tf_inspect.getfullargspec(
        image.ImageDataGenerator.__init__).args
    dtype_kwargs = {}
    if 'dtype' in base_params:
        dtype_kwargs['dtype'] = backend.floatx() if dtype is None else dtype

    super(ImageDataGenerator, self).__init__(
        featurewise_center=featurewise_center,
        samplewise_center=samplewise_center,
        featurewise_std_normalization=featurewise_std_normalization,
        samplewise_std_normalization=samplewise_std_normalization,
        zca_whitening=zca_whitening,
        zca_epsilon=zca_epsilon,
        rotation_range=rotation_range,
        width_shift_range=width_shift_range,
        height_shift_range=height_shift_range,
        brightness_range=brightness_range,
        shear_range=shear_range,
        zoom_range=zoom_range,
        channel_shift_range=channel_shift_range,
        fill_mode=fill_mode,
        cval=cval,
        horizontal_flip=horizontal_flip,
        vertical_flip=vertical_flip,
        rescale=rescale,
        preprocessing_function=preprocessing_function,
        data_format=data_format,
        validation_split=validation_split,
        **dtype_kwargs)
def testTrainingDefaults(self):
    """Checks the `training` defaults exposed by a saved-and-reloaded model.

    Builds a model with three sublayers that treat `training` differently,
    saves it in the TF format, reloads it, and asserts on the default value
    of `training` in the argspec of each reloaded `__call__`.
    """

    def assert_training_default(fn, default_value):
        arg_spec = tf_inspect.getfullargspec(fn)
        # Distance of 'training' from the end of the regular parameters;
        # `defaults` lines up with the tail of `args`, so index from the right.
        index = len(arg_spec.args) - arg_spec.args.index('training')
        self.assertEqual(arg_spec.defaults[-index], default_value)

    # Sublayer whose `training` argument has no default.
    class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):

        def call(self, inputs, training):
            return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                                lambda: tf.identity(inputs))

    # Sublayer whose `training` argument defaults to True.
    class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):

        def call(self, inputs, training=True):
            return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                                lambda: tf.identity(inputs))

    class Model(keras.models.Model):

        def __init__(self):
            super(Model, self).__init__()
            # NOTE(review): LayerWithLearningPhase is defined elsewhere; by the
            # attribute name it presumably has a `training=None` default.
            self.layer_with_training_default_none = LayerWithLearningPhase()
            self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
            self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()

        def call(self, inputs):
            x = self.layer_with_training_default_none(inputs)
            x += self.layer_with_training_default_true(inputs)
            x += self.layer_with_required_training_arg(inputs, False)
            return x

    model = Model()
    # Build and set model inputs
    model.predict(np.ones([1, 3]).astype('float32'))
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    load = tf.saved_model.load(saved_model_dir)

    # Ensure that the Keras loader is able to load and build the model.
    _ = keras_load.load(saved_model_dir)

    # The top-level model and the default-None layer are expected to expose
    # False; the layer declaring `training=True` keeps True.
    assert_training_default(load.__call__, False)
    assert_training_default(
        load.layer_with_training_default_none.__call__, False)
    assert_training_default(
        load.layer_with_training_default_true.__call__, True)

    # Assert that there are no defaults for layer with required training arg
    arg_spec = tf_inspect.getfullargspec(
        load.layer_with_required_training_arg.__call__)
    self.assertFalse(arg_spec.defaults)  # defaults is None or empty
def has_arg(fn, name, accept_all=False):
    """Checks if a callable accepts a given keyword argument.

    Args:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: When `fn` has no parameter called `name` but does accept
        `**kwargs`, return True if this is truthy, False otherwise.

    Returns:
      bool, whether `fn` accepts a `name` keyword argument. A regular or
      keyword-only parameter called `name` always counts.
    """
    spec = tf_inspect.getfullargspec(fn)
    if name in spec.args or name in spec.kwonlyargs:
        return True
    return bool(accept_all and spec.varkw is not None)
def __init__(self, layer):
    """Sets up bookkeeping for the traced call functions of `layer`.

    Records the layer's call method, whether it uses a training argument and
    at which index, the layer inputs, a weak-valued function registry, and
    the name of the inputs argument.
    """
    self.layer = layer
    call_method = _get_layer_call_method(layer)
    self.layer_call_method = call_method
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(call_method)
    self._layer_inputs = self._get_layer_inputs(layer)
    self._functions = weakref.WeakValueDictionary()

    # The first regular parameter (skipping self for bound methods) names the
    # inputs argument; fall back to 'inputs' when there is none.
    all_params = tf_inspect.getfullargspec(call_method).args
    first = 1 if tf_inspect.ismethod(call_method) else 0
    input_params = all_params[first:]
    self._input_arg_name = input_params[0] if input_params else 'inputs'
def get_training_arg_index(call_fn):
    """Returns the index of 'training' in the layer call function arguments.

    Args:
      call_fn: Call function.

    Returns:
      The index of 'training' among the regular parameters of `call_fn`, not
      counting the leading bound argument when `call_fn` is a method.
      Otherwise -1. Note that -1 is returned whether or not `call_fn`
      accepts `**kwargs`, and this function never returns None.
    """
    params = tf_inspect.getfullargspec(call_fn).args
    offset = 1 if tf_inspect.ismethod(call_fn) else 0
    try:
        # Search only past the bound argument, then rebase the index.
        return params.index('training', offset) - offset
    except ValueError:
        return -1
def __init__(
    self, function, output_shape=None, mask=None, arguments=None, **kwargs
):
    """Wraps `function` as a layer.

    Args:
      function: The callable to wrap.
      output_shape: Stored as `self._output_shape`.
      mask: Stored as `self.mask`; when not None, masking support is enabled.
      arguments: Optional dict stored as `self.arguments` (an empty dict when
        falsy).
      **kwargs: Forwarded to the base-class initializer.
    """
    super().__init__(**kwargs)

    self.arguments = arguments or {}
    self.function = function

    if mask is not None:
        self.supports_masking = True
    self.mask = mask
    self._output_shape = output_shape

    # Warning on every invocation will be quite irksome in Eager mode.
    self._already_warned = False

    # Detect `training` / `mask` whether declared as regular parameters or as
    # keyword-only ones (e.g. `def fn(x, *, training=None)`), consistent with
    # `has_arg`. NOTE(review): assumes `call` forwards these by keyword.
    fn_spec = tf_inspect.getfullargspec(function)
    function_args = list(fn_spec.args) + list(fn_spec.kwonlyargs or [])
    self._fn_expects_training_arg = "training" in function_args
    self._fn_expects_mask_arg = "mask" in function_args
def __init__(self,
             directory,
             image_data_generator,
             target_size=(256, 256),
             color_mode='rgb',
             classes=None,
             class_mode='categorical',
             batch_size=32,
             shuffle=True,
             seed=None,
             data_format=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png',
             follow_links=False,
             subset=None,
             interpolation='nearest',
             dtype=None):
    """Initializes the iterator by delegating to the base-class initializer.

    `data_format` falls back to `backend.image_data_format()` when None.
    `dtype` is forwarded only if the base `image.DirectoryIterator.__init__`
    declares a `dtype` parameter; in that case a None `dtype` is replaced by
    `backend.floatx()`. All other arguments are passed through unchanged.
    """
    if data_format is None:
        data_format = backend.image_data_format()
    kwargs = {}
    # Probe the initializer that will actually receive `dtype` (previously
    # this inspected `image.ImageDataGenerator.__init__`, which could
    # disagree with the iterator's signature across library versions).
    if 'dtype' in tf_inspect.getfullargspec(
            image.DirectoryIterator.__init__)[0]:
        if dtype is None:
            dtype = backend.floatx()
        kwargs['dtype'] = dtype
    super(DirectoryIterator, self).__init__(
        directory, image_data_generator,
        target_size=target_size,
        color_mode=color_mode,
        classes=classes,
        class_mode=class_mode,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation,
        **kwargs)
def _has_kwargs(fn):
    """Returns whether the passed callable has **kwargs in its signature.

    A `functools.partial` is inspected through its underlying `func`, and a
    callable object (per `_is_callable_object`) through its `__call__`.

    Args:
      fn: Function, or function-like object (e.g., result of
        `functools.partial`).

    Returns:
      `bool`: if `fn` has **kwargs in its signature.

    Raises:
      `TypeError`: If fn is not a Function, or function-like object.
    """
    if isinstance(fn, functools.partial):
        target = fn.func
    elif _is_callable_object(fn):
        target = fn.__call__
    elif callable(fn):
        target = fn
    else:
        raise TypeError(
            "fn should be a function-like object, but is of type {}.".format(
                type(fn)))
    return tf_inspect.getfullargspec(target).varkw is not None
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D Numpy array to a PIL Image instance.

    Usage:

    ```python
    from PIL import Image
    img = np.random.random(size=(100, 100, 3))
    pil_img = tf.keras.preprocessing.image.array_to_img(img)
    ```

    Args:
        x: Input Numpy array.
        data_format: Image data format, either "channels_first" or
            "channels_last". When `None`, the global setting
            `tf.keras.backend.image_data_format()` is used.
        scale: Whether to rescale image values to be within `[0, 255]`.
            Defaults to `True`.
        dtype: Dtype to use. Only forwarded if `image.array_to_img` declares a
            `dtype` parameter; when `None`, `tf.keras.backend.floatx()` is
            used.

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: if PIL is not available (raised by `image.array_to_img`).
        ValueError: if invalid `x` or `data_format` is passed (raised by
            `image.array_to_img`).
    """
    if data_format is None:
        data_format = backend.image_data_format()
    dtype_kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(image.array_to_img).args:
        dtype_kwargs['dtype'] = backend.floatx() if dtype is None else dtype
    return image.array_to_img(
        x, data_format=data_format, scale=scale, **dtype_kwargs)
def _maybe_wrap_with_training_arg(self, call_fn, match_layer_training_arg):
    """Wraps call function with added training argument if necessary.

    A wrapper is produced only when this object expects a training argument
    but the underlying layer does not. The wrapper advertises an argspec
    with an extra trailing regular `training` parameter defaulting to False,
    and it updates `self._training_arg_index` to the position callers use
    for it. When `match_layer_training_arg` is true, the wrapper strips the
    training value (via `utils.remove_training_arg`) before calling
    `call_fn`. Otherwise `call_fn` is returned unchanged.
    """
    layer_expects = self.layer._expects_training_arg  # pylint: disable=protected-access
    if layer_expects or not self._expects_training_arg:
        return call_fn

    # Add training arg to wrapper function.
    base_spec = tf_inspect.getfullargspec(call_fn)
    extended_args = base_spec.args + ['training']
    extended_defaults = list(base_spec.defaults or []) + [False]
    new_arg_spec = tf_inspect.FullArgSpec(
        args=extended_args,
        varargs=base_spec.varargs,
        varkw=base_spec.varkw,
        defaults=extended_defaults,
        kwonlyargs=base_spec.kwonlyargs,
        kwonlydefaults=base_spec.kwonlydefaults,
        annotations=base_spec.annotations)

    # Set new training arg index; a bound method's first parameter is not
    # passed explicitly by callers.
    bound_offset = 1 if tf_inspect.ismethod(call_fn) else 0
    self._training_arg_index = len(extended_args) - 1 - bound_offset

    def wrap_with_training_arg(*args, **kwargs):
        if match_layer_training_arg:
            # Remove the training value, since the original call_fn does not
            # expect a training arg. Instead, the training value will be
            # propagated using the call context created in LayerCall.
            args = list(args)
            kwargs = kwargs.copy()
            utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

    return tf.__internal__.decorator.make_decorator(
        target=call_fn,
        decorator_func=wrap_with_training_arg,
        decorator_argspec=new_arg_spec)
def img_to_array(img, data_format=None, dtype=None):
    """Converts a PIL Image instance to a Numpy array.

    Usage:

    ```python
    from PIL import Image
    img_data = np.random.random(size=(100, 100, 3))
    img = tf.keras.preprocessing.image.array_to_img(img_data)
    array = tf.keras.preprocessing.image.img_to_array(img)
    ```

    Args:
        img: Input PIL Image instance.
        data_format: Image data format, either "channels_first" or
            "channels_last". When `None`, the global setting
            `tf.keras.backend.image_data_format()` is used.
        dtype: Dtype to use. Only forwarded if `image.img_to_array` declares a
            `dtype` parameter; when `None`, `tf.keras.backend.floatx()` is
            used.

    Returns:
        A 3D Numpy array.

    Raises:
        ValueError: if invalid `img` or `data_format` is passed (raised by
            `image.img_to_array`).
    """
    if data_format is None:
        data_format = backend.image_data_format()
    dtype_kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(image.img_to_array).args:
        dtype_kwargs['dtype'] = backend.floatx() if dtype is None else dtype
    return image.img_to_array(img, data_format=data_format, **dtype_kwargs)
def add(self, layer):
    """Adds a layer instance on top of the layer stack.

    A Keras tensor produced by `keras.Input()` is replaced by its originating
    `InputLayer` (with a warning). A `tf.Module` that is not a Keras `Layer`
    is wrapped in `functional.ModuleWrapper`. The model's `built` flag and
    graph network are updated as a side effect.

    Args:
        layer: layer instance.

    Raises:
        TypeError: If `layer` is not a layer instance.
        ValueError: In case the `layer` argument does not
            know its input shape.
        ValueError: In case the `layer` argument has
            multiple output tensors, or is already connected
            somewhere else (forbidden in `Sequential` models).
    """
    # If we are passed a Keras tensor created by keras.Input(), we can extract
    # the input layer from its keras history and use that without any loss of
    # generality.
    if hasattr(layer, '_keras_history'):
        origin_layer = layer._keras_history[0]
        if isinstance(origin_layer, input_layer.InputLayer):
            layer = origin_layer
            logging.warning(
                'Please add `keras.layers.InputLayer` instead of `keras.Input` to '
                'Sequential model. `keras.Input` is intended to be used by '
                'Functional model.')

    # Accept any tf.Module; non-Layer modules are wrapped so they behave as
    # layers.
    if isinstance(layer, tf.Module):
        if not isinstance(layer, base_layer.Layer):
            layer = functional.ModuleWrapper(layer)
    else:
        raise TypeError('The added layer must be '
                        'an instance of class Layer. '
                        'Found: ' + str(layer))

    tf_utils.assert_no_legacy_layers([layer])
    if not self._is_layer_name_unique(layer):
        raise ValueError('All layers added to a Sequential model '
                         'should have unique names. Name "%s" is already the name'
                         ' of a layer in this model. Update the `name` argument '
                         'to pass a unique name.' % (layer.name,))

    # Adding a layer invalidates any previous build until proven otherwise.
    self.built = False
    set_inputs = False
    self._maybe_create_attribute('_self_tracked_trackables', [])
    if not self._self_tracked_trackables:
        # First layer: try to establish the model inputs from it.
        if isinstance(layer, input_layer.InputLayer):
            # Case where the user passes an Input or InputLayer layer via `add`.
            set_inputs = True
        else:
            batch_shape, dtype = training_utils.get_input_shape_and_dtype(layer)
            if batch_shape:
                # Instantiate an input layer.
                x = input_layer.Input(
                    batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input')
                # This will build the current layer
                # and create the node connecting the current layer
                # to the input layer we just created.
                layer(x)
                set_inputs = True

        if set_inputs:
            # Sequential models require exactly one output tensor per layer.
            outputs = tf.nest.flatten(layer._inbound_nodes[-1].outputs)
            if len(outputs) != 1:
                raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
            self.outputs = outputs
            self.inputs = layer_utils.get_source_inputs(self.outputs[0])
            self.built = True
            self._has_explicit_input_shape = True

    elif self.outputs:
        # If the model is being built continuously on top of an input layer:
        # refresh its output.
        output_tensor = layer(self.outputs[0])
        if len(tf.nest.flatten(output_tensor)) != 1:
            raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)
        self.outputs = [output_tensor]
        self.built = True

    if set_inputs or self._graph_initialized:
        # Inputs are known: (re)build the graph network from inputs/outputs.
        self._init_graph_network(self.inputs, self.outputs)
        self._graph_initialized = True
    else:
        # Inputs not yet known: just track the layer for a deferred build.
        self._self_tracked_trackables.append(layer)
        self._handle_deferred_layer_dependencies([layer])

    # Remember the layer's call signature for later call-argument handling.
    self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
      identifier: the serialized form of the object.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers. May be `None`, in which case string names are resolved
        only against `custom_objects` and the global custom-object registry.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.
      printable_module_name: A human-readable string representing the type of
        the object. Printed in case of exception.

    Returns:
      The deserialized object.

    Raises:
      ValueError: If a string name cannot be resolved, or if `identifier` has
        an unsupported type.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
       return deserialize_keras_object(
         config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name="MyObjectType",
       )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
        if shared_object is not None:
            return shared_object

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                # `from_config` takes custom objects explicitly: pass the
                # global registry merged with the caller's dict (caller's
                # entries win on key collisions).
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            else:
                with CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

        return deserialized_obj

    elif isinstance(identifier, str):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to None; treat that as an empty lookup
            # so the informative ValueError below is raised instead of an
            # AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    'Unknown {}: {}. Please ensure this object is '
                    'passed to the `custom_objects` argument. See '
                    'https://www.tensorflow.org/guide/keras/save_and_serialize'
                    '#registering_the_custom_object for details.'.format(
                        printable_module_name, object_name))

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
def assert_training_default(fn, default_value):
    """Asserts that the `training` parameter of `fn` defaults to `default_value`.

    `self` is the enclosing test case, captured from the surrounding scope.
    """
    spec = tf_inspect.getfullargspec(fn)
    params = spec.args
    # Negative offset of 'training' from the end of the regular parameters;
    # `defaults` aligns with the tail of `params`.
    offset = params.index('training') - len(params)
    self.assertEqual(spec.defaults[offset], default_value)
def _call_full_argspec(self): # Argspec inspection is expensive and the call spec is used often, so it # makes sense to cache the result. return tf_inspect.getfullargspec(self.forward_pass)
def maybe_add_training_arg(original_call, wrapped_call, expects_training_arg,
                           default_training_value):
    """Decorate call and optionally adds training argument.

    If a layer expects a training argument, this function ensures that
    'training' is present in the layer args or kwonly args, with the default
    training value.

    Args:
      original_call: Original call function.
      wrapped_call: Wrapped call function.
      expects_training_arg: Whether to include 'training' argument.
      default_training_value: Default value of the training kwarg to include
        in the arg spec. If `None`, the default is
        `tf.keras.backend.learning_phase()`.

    Returns:
      Tuple of (
        function that calls `wrapped_call` and sets the training arg,
        Argspec of returned function or `None` if the argspec is unchanged)
    """
    if not expects_training_arg:
        return wrapped_call, None

    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        training_arg_index = get_training_arg_index(original_call)
        training = get_training_arg(training_arg_index, args, kwargs)
        if training is None:
            training = default_training_value or backend.learning_phase()

        args = list(args)
        kwargs = kwargs.copy()

        def replace_training_and_call(training):
            set_training_arg(training, training_arg_index, args, kwargs)
            return wrapped_call(*args, **kwargs)

        return control_flow_util.smart_cond(
            training, lambda: replace_training_and_call(True),
            lambda: replace_training_and_call(False))

    # Create arg spec for decorated function. If 'training' is not defined in
    # the args of the original arg spec, then add it to kwonlyargs.
    arg_spec = tf_inspect.getfullargspec(original_call)
    defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []
    # Copy the mutable parts rather than aliasing them: the argspec returned
    # for a decorated callable may be an object stored on the decorator, so
    # in-place edits would leak into later inspections of `original_call`.
    kwonlyargs = list(arg_spec.kwonlyargs or [])
    kwonlydefaults = dict(arg_spec.kwonlydefaults or {})

    # Add training arg if it does not exist, or set the default training value.
    if 'training' not in arg_spec.args:
        # Do not list 'training' twice when it is already keyword-only.
        if 'training' not in kwonlyargs:
            kwonlyargs.append('training')
        kwonlydefaults['training'] = default_training_value
    else:
        index = arg_spec.args.index('training')
        training_default_index = len(arg_spec.args) - index
        if (arg_spec.defaults and
                len(arg_spec.defaults) >= training_default_index and
                defaults[-training_default_index] is None):
            defaults[-training_default_index] = default_training_value

    decorator_argspec = tf_inspect.FullArgSpec(
        args=arg_spec.args,
        varargs=arg_spec.varargs,
        varkw=arg_spec.varkw,
        defaults=defaults,
        kwonlyargs=kwonlyargs,
        kwonlydefaults=kwonlydefaults,
        annotations=arg_spec.annotations)
    return wrap_with_training_arg, decorator_argspec