Exemplo n.º 1
0
def is_whitelisted(o,
                   check_call_override=True,
                   allow_namedtuple_subclass=False):
    """Decides whether an entity is whitelisted for use in graph mode.

    The checks run in a fixed order and the first one that matches decides:

      1. The first rule in `config.CONVERSION_RULES` whose action for the
         entity's module is CONVERT (not whitelisted) or DO_NOT_CONVERT
         (whitelisted).
      2. Generator functions are whitelisted, with a warning.
      3. Callable objects that are not classes are whitelisted when their
         `__call__` method is.
      4. Methods are whitelisted when their class is a `unittest.TestCase`
         subclass, or when the class that defines them is whitelisted.
      5. Named tuples are whitelisted, subject to
         `allow_namedtuple_subclass`.

    Anything that matches none of the above is not whitelisted.

    Args:
      o: A Python entity.
      check_call_override: Reserved for internal use. When set to `False`,
        check 3 above is skipped.
      allow_namedtuple_subclass: Reserved for internal use. When `True`, a
        named tuple is whitelisted only if none of its direct bases is itself
        a named tuple. When `False`, every named tuple is whitelisted.

    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    # functools.partial objects do not have a __module__ attribute, so
    # tf_inspect.getmodule(functools.partial(...)) would return None; use the
    # functools module itself for them.
    module = (functools if isinstance(o, functools.partial) else
              tf_inspect.getmodule(o))

    # Some callables (builtins, for example) lack a __module__ property; the
    # module-based rules only apply when a named module was found.
    if hasattr(module, '__name__'):
        for rule in config.CONVERSION_RULES:
            verdict = rule.get_action(module)
            if verdict == config.Action.CONVERT:
                logging.log(2, 'Not whitelisted: %s: %s', o, rule)
                return False
            if verdict == config.Action.DO_NOT_CONVERT:
                logging.log(2, 'Whitelisted: %s: %s', o, rule)
                return True

    if tf_inspect.isgeneratorfunction(o):
        logging.warn(
            'Entity %s appears to be a generator function. It will not be converted'
            ' by AutoGraph.', o)
        logging.log(
            2, 'Whitelisted: %s: generator functions are not converted', o)
        return True

    # Callable objects: whitelisted if their __call__ method is.
    # Comparing the types avoids infinite recursion around the __call__ method
    # of function objects.
    if (check_call_override and not tf_inspect.isclass(o)
            and hasattr(o, '__call__')
            and type(o) != type(o.__call__)  # pylint: disable=unidiomatic-typecheck
            and is_whitelisted(o.__call__)):
        logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
        return True

    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        #
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # For the example above, if `Custom` did overload `bar`, then it would
        # no longer be whitelisted.
        bound_class = inspect_utils.getmethodclass(o)
        if bound_class is function.TfMethodTarget:
            # The real class is stored on the TfMethodTarget wrapper.
            bound_class = o.__self__.target_class

        if bound_class is not None:
            if issubclass(bound_class, unittest.TestCase):
                logging.log(
                    2, 'Whitelisted: %s: method of TestCase subclass', o)
                return True

            defining_class = inspect_utils.getdefiningclass(o, bound_class)
            # The class itself is examined with the __call__ rule turned off
            # and with the stricter named tuple rule turned on.
            if is_whitelisted(defining_class,
                              check_call_override=False,
                              allow_namedtuple_subclass=True):
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            defining_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be
        # converted because they don't expose source code. But we assume they
        # are safe for graph mode since they are just containers.
        if not allow_namedtuple_subclass:
            logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
            return True
        has_namedtuple_base = any(
            inspect_utils.isnamedtuple(base) for base in o.__bases__)
        if not has_namedtuple_base:
            logging.log(2, 'Whitelisted: %s: named tuple', o)
            return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
Exemplo n.º 2
0
def is_whitelisted_for_graph(o, check_call_override=True):
    """Checks whether an entity is whitelisted for use in graph mode.

    The checks below run in order and the first one that matches returns
    `True`: the entity's module matches an entry of
    `config.DEFAULT_UNCOMPILED_MODULES`; the entity carries an
    `autograph_info__` or `__ag_compiled` attribute; it is a generator
    function; it is a callable object whose `__call__` is whitelisted; it is a
    method of a `unittest.TestCase` subclass or of a whitelisted class; it is
    a named tuple. If none matches, the entity is not whitelisted.

    Args:
      o: A Python entity.
      check_call_override: Reserved for internal use. When set to `False`, it
        disables the rule according to which entities are whitelisted if
        their __call__ method is whitelisted.

    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    if isinstance(o, functools.partial):
        # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
        # functools.partial objects do not have a __module__ attribute.
        m = functools
    else:
        m = tf_inspect.getmodule(o)

    if hasattr(m, '__name__'):
        # Builtins typically have unnamed modules.
        # Each configured entry unpacks to a single prefix. A module matches
        # when it is the named module itself or one of its dotted submodules.
        for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
            if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
                logging.log(2, 'Whitelisted: %s: name starts with "%s"', o,
                            prefix)
                return True

    if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
        logging.log(2, 'Whitelisted: %s: already converted', o)
        return True

    if tf_inspect.isgeneratorfunction(o):
        # The entity is passed as a lazy %-style argument, like the
        # logging.log calls in this function. Pre-formatting the message and
        # passing a stray positional argument leaves an argument without a
        # matching placeholder.
        # NOTE(review): assumes logging.warn accepts %-style format arguments
        # the same way logging.log does -- confirm.
        logging.warn(
            'Entity %s appears to be a generator function. It will not be converted'
            ' by AutoGraph.', o)
        logging.log(2,
                    'Whitelisted: %s: generator functions are not converted',
                    o)
        return True

    if check_call_override and hasattr(o, '__call__'):
        # Callable objects: whitelisted if their __call__ method is.
        # The type check avoids infinite recursion around the __call__ method
        # of function objects.
        if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(
                o.__call__):  # pylint: disable=unidiomatic-typecheck
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        #
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # For the example above, if `Custom` did overload `bar`, then it would no
        # longer be whitelisted.

        owner_class = inspect_utils.getmethodclass(o)
        if owner_class is not None:
            if issubclass(owner_class, unittest.TestCase):
                logging.log(2, 'Whitelisted: %s: method of TestCase subclass',
                            o)
                return True

            owner_class = inspect_utils.getdefiningclass(o, owner_class)
            # When `o` is itself a __call__ method, the owner is checked with
            # the __call__ rule disabled -- presumably so that the check does
            # not bounce back to this same method.
            is_call_override = (o.__name__ == '__call__')
            if is_whitelisted_for_graph(
                    owner_class, check_call_override=not is_call_override):
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            owner_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be converted
        # because they don't expose source code. But we assume they are safe for
        # graph mode since they are just containers.
        if tf_inspect.isclass(o) and len(o.__bases__) > 1:
            # Lazy %-style argument, for the same reason as the generator
            # warning above.
            logging.warn(
                'Entity %s looks like a namedtuple subclass. Its constructor will'
                ' not be converted by AutoGraph, but if it has any custom methods,'
                ' those will be.', o)
        logging.log(2, 'Whitelisted: %s: named tuple', o)
        return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
Exemplo n.º 3
0
def is_whitelisted_for_graph(o):
  """Checks whether an entity is whitelisted for use in graph mode.

  The checks below run in order and the first one that matches returns
  `True`: the entity's module matches an entry of
  `config.DEFAULT_UNCOMPILED_MODULES` or is a tensorboard module; the entity
  carries an `autograph_info__` or `__ag_compiled` attribute; it is a
  generator function; it is a callable object whose `__call__` is
  whitelisted; it is a method of a `unittest.TestCase` subclass or of a
  whitelisted class; it is a named tuple. If none matches, the entity is not
  whitelisted.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    # Each configured entry unpacks to a single prefix. Match the named module
    # itself or one of its dotted submodules only; a bare startswith(prefix)
    # would also match unrelated modules whose name merely begins with the
    # same characters.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__ == prefix or m.__name__.startswith(prefix + '.'):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # The entity is passed as a lazy %-style argument, like the logging.log
    # calls in this function. Pre-formatting the message and passing a stray
    # positional argument leaves an argument without a matching placeholder.
    # NOTE(review): assumes logging.warn accepts %-style format arguments the
    # same way logging.log does -- confirm.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      # Lazy %-style argument, for the same reason as the generator warning
      # above.
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False