def from_config(cls, config, custom_objects=None):
    """Rebuilds an instance of `cls` from a serialized config dict.

    The wrapped function is recovered according to `config['function_type']`:
    a named function ('function') is looked up via
    `deserialize_keras_object`, while an anonymous function ('lambda') is
    reloaded from its serialized form via `func_load`. Entries of
    `config['arguments']` that were stored as `{'type': 'ndarray', 'value':
    ...}` dicts are converted back with `np.array`.

    Args:
        config: Dict of constructor keyword arguments plus a 'function_type'
            entry and the serialized 'function'. Neither this dict nor its
            nested 'arguments' dict is modified.
        custom_objects: Optional dict of additional names. It is forwarded to
            `deserialize_keras_object` for named functions and merged over
            this module's globals for lambdas.

    Returns:
        The result of `cls(**config)` with 'function' replaced by the
        recovered callable.

    Raises:
        TypeError: If 'function_type' is neither 'function' nor 'lambda'.
    """
    # Shallow copy so popping / replacing top-level keys leaves the caller's
    # dict untouched.
    config = config.copy()
    function_type = config.pop('function_type')

    if function_type == 'function':
        # Simple lookup in custom objects.
        function = deserialize_keras_object(
            config['function'],
            custom_objects=custom_objects,
            printable_module_name='function in Lambda layer')
    elif function_type == 'lambda':
        # Unsafe deserialization from bytecode: only use with trusted configs.
        globs = globals()
        if custom_objects:
            # Work on a copy so the module namespace is never modified;
            # custom objects take precedence over module globals.
            globs = dict(globs)
            globs.update(custom_objects)
        function = func_load(config['function'], globs=globs)
    else:
        # Build one message string; passing several positional arguments to
        # the exception would render as a tuple repr.
        raise TypeError(
            'Unknown function type: {}'.format(function_type))

    # If arguments were numpy arrays, they have been saved as lists and need
    # to be turned back into ndarrays. A new dict is built because the
    # shallow copy above still shares the nested 'arguments' dict with the
    # caller.
    if 'arguments' in config:
        restored = {}
        for key, value in config['arguments'].items():
            if isinstance(value, dict) and value.get('type') == 'ndarray':
                value = np.array(value['value'])
            restored[key] = value
        config['arguments'] = restored

    config['function'] = function
    return cls(**config)
def deserialize_function(serial, function_type):
    """Deserializes the Keras-serialized function.

    (De)serializing Python functions from/to bytecode is unsafe. Therefore we
    also use the function's type as an anonymous function ('lambda') or named
    function in the Python environment ('function'). In the latter case, this
    lets us use the Python scope to obtain the function rather than reload it
    from bytecode. (Note that both cases are brittle!)

    Keras-deserialized functions do not perform lexical scoping. Any modules
    that the function requires must be imported within the function itself.

    This serialization mimics the implementation in `tf.keras.layers.Lambda`.

    Args:
        serial: Serialized Keras object: typically a dict, string, or
            bytecode.
        function_type: Python string denoting 'function' or 'lambda'.

    Returns:
        function: Function the serialized Keras object represents.

    Raises:
        TypeError: If `function_type` is neither 'function' nor 'lambda'.

    #### Examples

    ```python
    serial, function_type = serialize_function(lambda x: x)
    function = deserialize_function(serial, function_type)
    assert function(2.3) == 2.3  # function is identity
    ```

    """
    if function_type == 'function':
        # Simple lookup in custom objects.
        return tf.keras.utils.deserialize_keras_object(serial)
    if function_type == 'lambda':
        # Unsafe deserialization from bytecode: only use with trusted input.
        return generic_utils.func_load(serial)
    # Build one message string; passing several positional arguments to the
    # exception would render as a tuple repr.
    raise TypeError('Unknown function type: {}'.format(function_type))