def test_local_definitions_filter(self):
        """Checks which public children local_definitions_filter keeps per path.

        A fake ``tf`` module tree is built in which a single ``Dense`` object
        has its ``__module__`` pinned to ``'tf.keras.layers'`` and is also
        exposed as ``tf.keras.Dense`` and ``tf.layers.Dense``. The filter is
        then applied at three module paths. At each path, the surviving names
        are compared against the expected names, ignoring order.
        """
        tf = types.ModuleType('tf')
        tf.keras = types.ModuleType('tf.keras')
        tf.keras.layers = types.ModuleType('tf.keras.layers')
        tf.layers = types.ModuleType('tf.layers')

        # One object, reachable from three modules; its __module__ says it
        # belongs to tf.keras.layers.
        tf.keras.layers.Dense = lambda: None
        tf.keras.layers.Dense.__module__ = 'tf.keras.layers'
        tf.keras.Dense = tf.layers.Dense = tf.keras.layers.Dense

        def public_members(module):
            # inspect.getmembers(module) minus every '_'-prefixed name.
            return [(member_name, member)
                    for member_name, member in inspect.getmembers(module)
                    if not member_name.startswith('_')]

        expectations = (
            (('tf', 'keras', 'layers'), tf.keras.layers, ['Dense']),
            (('tf', 'keras'), tf.keras, ['layers', 'Dense']),
            (('tf', 'layers'), tf.layers, []),
        )
        for path, module, expected_names in expectations:
            kept = public_api.local_definitions_filter(
                path, module, public_members(module))
            self.assertCountEqual(expected_names, [name for name, _ in kept])
Example #2
0
  def test_local_definitions_filter(self):
    """Verifies local_definitions_filter on a fake, hand-built ``tf`` package.

    ``Dense`` is a single object whose ``__module__`` is set to
    ``'tf.keras.layers'``. It is re-exposed as ``tf.keras.Dense`` and
    ``tf.layers.Dense``. For each of the three modules, the public members
    are run through the filter and the kept names are checked without
    regard to order.
    """
    tf = types.ModuleType('tf')
    tf.keras = types.ModuleType('tf.keras')
    tf.keras.layers = types.ModuleType('tf.keras.layers')
    tf.layers = types.ModuleType('tf.layers')

    # Same object in all three modules; only tf.keras.layers "owns" it
    # according to __module__.
    tf.keras.layers.Dense = lambda: None
    tf.keras.layers.Dense.__module__ = 'tf.keras.layers'
    tf.keras.Dense = tf.layers.Dense = tf.keras.layers.Dense

    def public_members(module):
      """Returns inspect.getmembers(module) without '_'-prefixed names."""
      return [(member_name, member)
              for member_name, member in inspect.getmembers(module)
              if not member_name.startswith('_')]

    cases = (
        (('tf', 'keras', 'layers'), tf.keras.layers, ['Dense']),
        (('tf', 'keras'), tf.keras, ['layers', 'Dense']),
        (('tf', 'layers'), tf.layers, []),
    )
    for path, module, expected_names in cases:
      survivors = public_api.local_definitions_filter(
          path, module, public_members(module))
      self.assertCountEqual(expected_names,
                            [name for name, _ in survivors])
Example #3
0
def local_definitions_filter(path, parent, children):
    """Filter API children, exempting the ``tfl.layers`` module.

    When ``path`` is exactly ``('tfl', 'layers')``, ``children`` is returned
    unchanged (the very same object). Every other path is handed to
    ``public_api.local_definitions_filter``.

    Args:
      path: Tuple of name components identifying ``parent``.
      parent: The object whose children are being filtered.
      children: The ``(name, value)`` pairs to filter.

    Returns:
      ``children`` itself for ``('tfl', 'layers')``; otherwise whatever
      ``public_api.local_definitions_filter`` returns.
    """
    exempt = path == ('tfl', 'layers')
    if not exempt:
        return public_api.local_definitions_filter(path, parent, children)
    return children