def test_local_definitions_filter(self):

  def _public(obj):
    # Keep only (name, value) pairs whose name has no leading underscore.
    return [pair for pair in inspect.getmembers(obj)
            if not pair[0].startswith('_')]

  # Fake `tf` package: a single `Dense` callable whose `__module__` is
  # 'tf.keras.layers', re-exported as `tf.keras.Dense` and `tf.layers.Dense`.
  tf = types.ModuleType('tf')
  tf.keras = types.ModuleType('tf.keras')
  tf.keras.layers = types.ModuleType('tf.keras.layers')
  tf.keras.layers.Dense = lambda: None
  tf.keras.layers.Dense.__module__ = 'tf.keras.layers'
  tf.keras.Dense = tf.keras.layers.Dense
  tf.layers = types.ModuleType('tf.layers')
  tf.layers.Dense = tf.keras.layers.Dense

  # (path, module, names expected to survive the filter)
  cases = [
      (('tf', 'keras', 'layers'), tf.keras.layers, ['Dense']),
      (('tf', 'keras'), tf.keras, ['layers', 'Dense']),
      # `Dense` reports a module outside 'tf.layers'; expect it dropped.
      (('tf', 'layers'), tf.layers, []),
  ]
  for path, module, expected in cases:
    kept = public_api.local_definitions_filter(path, module, _public(module))
    self.assertCountEqual(expected, [name for name, _ in kept])
def test_local_definitions_filter(self):
  # Build a fake `tf` package. `Dense` claims to live in 'tf.keras.layers'
  # and is re-exported from both `tf.keras` and `tf.layers`.
  dense = lambda: None
  dense.__module__ = 'tf.keras.layers'

  tf = types.ModuleType('tf')
  keras = tf.keras = types.ModuleType('tf.keras')
  keras_layers = keras.layers = types.ModuleType('tf.keras.layers')
  keras_layers.Dense = dense
  keras.Dense = dense
  legacy_layers = tf.layers = types.ModuleType('tf.layers')
  legacy_layers.Dense = dense

  def assert_kept(path, module, expected_names):
    # Run the filter over `module`'s public members and compare the
    # surviving names, ignoring order.
    public = [(member_name, member)
              for member_name, member in inspect.getmembers(module)
              if not member_name.startswith('_')]
    survivors = public_api.local_definitions_filter(path, module, public)
    self.assertCountEqual(expected_names,
                          [member_name for member_name, _ in survivors])

  assert_kept(('tf', 'keras', 'layers'), keras_layers, ['Dense'])
  # Both the submodule and the re-exported `Dense` report modules that
  # start with 'tf.keras'; both are expected to be kept.
  assert_kept(('tf', 'keras'), keras, ['layers', 'Dense'])
  # `Dense` reports 'tf.keras.layers', which is outside 'tf.layers'.
  assert_kept(('tf', 'layers'), legacy_layers, [])
def local_definitions_filter(path, parent, children):
  """Delegates to `public_api.local_definitions_filter`, except for tfl.layers.

  Args:
    path: Name components locating `parent`; compared against the tuple
      ('tfl', 'layers').
    parent: The object whose children are being filtered.
    children: The children of `parent` to filter.

  Returns:
    `children` itself (unfiltered) when `path` equals ('tfl', 'layers');
    otherwise the result of
    `public_api.local_definitions_filter(path, parent, children)`.
  """
  is_exempt = path == ('tfl', 'layers')
  if not is_exempt:
    return public_api.local_definitions_filter(path, parent, children)
  return children