Пример #1
0
    def test_docs_for_function_with_kwargs(self):
        """Checks the brief and signature produced for a *args/**kwargs function."""
        fn_name = 'test_function_with_args_kwargs'
        fn = test_function_with_args_kwargs
        index = {fn_name: fn}

        resolver = parser.ReferenceResolver.from_visitor(
            visitor=DummyVisitor(index=index, duplicate_of={}),
            doc_index={},
            py_module_names=['tf'])

        parser_config = parser.ParserConfig(
            reference_resolver=resolver,
            duplicates={},
            duplicate_of={},
            tree={'': [fn_name]},
            index=index,
            reverse_index={},
            guide_index={},
            base_dir='/')

        page_info = parser.docs_for_object(
            full_name=fn_name, py_object=fn, parser_config=parser_config)

        # The page brief is expected to be the first docstring line.
        expected_brief = tf_inspect.getdoc(fn).split('\n')[0]
        self.assertEqual(expected_brief, page_info.doc.brief)

        # Variadic parameters keep their `*` / `**` prefixes in the signature.
        expected_signature = ['unused_arg', '*unused_args', '**unused_kwargs']
        self.assertEqual(expected_signature, page_info.signature)
Пример #2
0
    def _generic_test(self,
                      f_raw,
                      examples,
                      input_signature=None,
                      skip_modes=None):
        """Test a function `f_raw` against all tests `examples`.

        The function is exercised in every `RunMode` not listed in
        `skip_modes`: called directly (`RAW`), wrapped in `tf.function`
        (`FUNCTION`), wrapped with `experimental_compile=True` (`XLA`), and
        round-tripped through a SavedModel (`SAVED`).

        Args:
          f_raw: a callable.
          examples: A list of `Example` named tuples.
          input_signature: Input signature to tf.function.
          skip_modes: A list of `RunMode` enums to entirely skip testing in the
            specified `RunMode`s. This is necessary when things fail in a
            certain `RunMode` even before executing the function (e.g. during
            saving or loading in `RunMode.SAVED` mode).
        """
        f_tf = None
        if not skip_modes:
            skip_modes = []

        # Record the function source, or its docstring for non-function
        # callables.
        if tf_inspect.isfunction(f_raw):
            self.recordProperty('f', tf_inspect.getsource(f_raw))
        else:
            self.recordProperty('f', tf_inspect.getdoc(f_raw))

        # The expected output is not recorded, only the per-input status.
        for arg, _, failure, bugs in examples:
            self.recordProperty('Input "{}"'.format(arg), {
                'not-working': failure,
                'bugs': bugs
            })

        # Run the function without tf.function
        if RunMode.RAW not in skip_modes:
            self._run_and_check(f_raw, RunMode.RAW, examples)

        # TF Function
        if RunMode.FUNCTION not in skip_modes:
            f_tf = tf.function(f_raw, input_signature=input_signature)
            self._run_and_check(f_tf, RunMode.FUNCTION, examples)

        # XLA Function
        if RunMode.XLA not in skip_modes:
            f_xla = tf.function(f_raw,
                                input_signature=input_signature,
                                experimental_compile=True)
            self._run_and_check(f_xla, RunMode.XLA, examples)

        # Write a saved model and try to run it
        if RunMode.SAVED not in skip_modes:
            module = tf.Module()
            # Reuse the FUNCTION-mode wrapper when it was built.
            if f_tf is not None:
                module.f = f_tf
            else:
                module.f = tf.function(f_raw, input_signature=input_signature)

            # Save into a fresh, private directory. Writing straight into the
            # shared system temp dir lets repeated or concurrent runs overwrite
            # each other's SavedModel files and never cleans them up.
            with tempfile.TemporaryDirectory() as saved_model_dir:
                tf.saved_model.save(module, saved_model_dir)
                module_loaded = tf.saved_model.load(saved_model_dir)
                self._run_and_check(module_loaded.f, RunMode.SAVED, examples)
Пример #3
0
    def test_docs_for_function(self):
        """Checks brief, signature and source location for a plain function."""
        index = {'test_function': test_function}
        visitor = DummyVisitor(index=index, duplicate_of={})
        resolver = parser.ReferenceResolver.from_visitor(
            visitor=visitor, doc_index={}, py_module_names=['tf'])

        config_kwargs = dict(
            reference_resolver=resolver,
            duplicates={},
            duplicate_of={},
            tree={'': ['test_function']},
            index=index,
            reverse_index={},
            guide_index={},
            base_dir='/')
        parser_config = parser.ParserConfig(**config_kwargs)

        page_info = parser.docs_for_object(
            full_name='test_function',
            py_object=test_function,
            parser_config=parser_config)

        # The brief should be the first line of the function's docstring.
        first_doc_line = tf_inspect.getdoc(test_function).split('\n')[0]
        self.assertEqual(first_doc_line, page_info.doc.brief)

        # The keyword default is expected to appear quoted in the signature.
        self.assertEqual(['unused_arg', "unused_kwarg='default'"],
                         page_info.signature)

        # The definition location is this file, relative to base_dir '/'.
        expected_path = os.path.relpath(__file__, '/')
        self.assertEqual(expected_path, page_info.defined_in.path)
Пример #4
0
    def test_docs_for_class(self):
        """Checks brief, members and source location of the `TestClass` page."""
        index = {
            'TestClass': TestClass,
            'TestClass.a_method': TestClass.a_method,
            'TestClass.a_property': TestClass.a_property,
            'TestClass.ChildClass': TestClass.ChildClass,
            'TestClass.CLASS_MEMBER': TestClass.CLASS_MEMBER,
        }
        resolver = parser.ReferenceResolver.from_visitor(
            visitor=DummyVisitor(index=index, duplicate_of={}),
            doc_index={},
            py_module_names=['tf'])
        member_names = ['a_method', 'a_property', 'ChildClass', 'CLASS_MEMBER']
        parser_config = parser.ParserConfig(
            reference_resolver=resolver,
            duplicates={},
            duplicate_of={},
            tree={'TestClass': member_names},
            index=index,
            reverse_index={},
            guide_index={},
            base_dir='/')

        page_info = parser.docs_for_object(
            full_name='TestClass', py_object=TestClass,
            parser_config=parser_config)

        # The brief should be the first line of the class docstring.
        class_doc = six.ensure_str(tf_inspect.getdoc(TestClass))
        self.assertEqual(class_doc.split('\n')[0], page_info.doc.brief)

        # The method is documented, and its signature leaves out `self`.
        first_method = page_info.methods[0]
        self.assertEqual(TestClass.a_method, first_method.obj)
        self.assertEqual(["arg='default'"], first_method.signature)

        # The property and the nested class point at the right objects.
        self.assertIs(TestClass.a_property, page_info.properties[0].obj)
        self.assertIs(TestClass.ChildClass, page_info.classes[0].obj)

        # The class is reported as defined in this test file.
        self.assertEqual(os.path.relpath(__file__, '/'),
                         page_info.defined_in.path)
Пример #5
0
  def test_docs_for_module(self):
    """Checks the page generated for the standalone `test_module`."""
    index = {
        'TestModule': test_module,
        'TestModule.test_function': test_function,
        'TestModule.test_function_with_args_kwargs':
            test_function_with_args_kwargs,
        'TestModule.TestClass': TestClass,
    }
    resolver = parser.ReferenceResolver.from_visitor(
        visitor=DummyVisitor(index=index, duplicate_of={}),
        doc_index={},
        py_module_names=['tf'])

    member_names = ['TestClass', 'test_function',
                    'test_function_with_args_kwargs']
    parser_config = parser.ParserConfig(
        reference_resolver=resolver,
        duplicates={},
        duplicate_of={},
        tree={'TestModule': member_names},
        index=index,
        reverse_index={},
        guide_index={},
        base_dir='/')

    page_info = parser.docs_for_object(
        full_name='TestModule', py_object=test_module,
        parser_config=parser_config)

    # The brief should be the module docstring's first line.
    brief = tf_inspect.getdoc(test_module).split('\n')[0]
    self.assertEqual(brief, page_info.doc.brief)

    # Both functions and the class are listed as module members.
    documented_functions = set(info.obj for info in page_info.functions)
    self.assertEqual({test_function, test_function_with_args_kwargs},
                     documented_functions)
    documented_classes = set(info.obj for info in page_info.classes)
    self.assertEqual({TestClass}, documented_classes)

    # Stripping trailing 'c' maps a compiled `.pyc` path back to `.py`.
    source_path = test_module.__file__.rstrip('c')
    self.assertEqual(os.path.relpath(source_path, '/'),
                     page_info.defined_in.path)
Пример #6
0
    def test_docs_for_module(self):
        """Checks the page generated for this test module itself."""
        this_module = sys.modules[__name__]

        index = {
            'TestModule': this_module,
            'TestModule.test_function': test_function,
            'TestModule.test_function_with_args_kwargs':
                test_function_with_args_kwargs,
            'TestModule.TestClass': TestClass,
        }
        visitor = DummyVisitor(index=index, duplicate_of={})
        resolver = parser.ReferenceResolver.from_visitor(
            visitor=visitor, doc_index={}, py_module_names=['tf'])

        parser_config = parser.ParserConfig(
            reference_resolver=resolver,
            duplicates={},
            duplicate_of={},
            tree={
                'TestModule': [
                    'TestClass', 'test_function',
                    'test_function_with_args_kwargs'
                ]
            },
            index=index,
            reverse_index={},
            guide_index={},
            base_dir='/')

        page_info = parser.docs_for_object(
            full_name='TestModule', py_object=this_module,
            parser_config=parser_config)

        # The brief should be the first line of the module docstring.
        module_brief = tf_inspect.getdoc(this_module).split('\n')[0]
        self.assertEqual(module_brief, page_info.doc.brief)

        # Every indexed function and class shows up on the page.
        self.assertEqual({test_function, test_function_with_args_kwargs},
                         set(entry.obj for entry in page_info.functions))
        self.assertEqual({TestClass},
                         set(entry.obj for entry in page_info.classes))

        # The module is reported as defined in this very file.
        self.assertEqual(os.path.relpath(__file__, '/'),
                         page_info.defined_in.path)
Пример #7
0
  def test_docs_for_class(self):
    """Checks brief, members and source location of the `TestClass` page."""
    index = {
        'TestClass': TestClass,
        'TestClass.a_method': TestClass.a_method,
        'TestClass.a_property': TestClass.a_property,
        'TestClass.ChildClass': TestClass.ChildClass,
        'TestClass.CLASS_MEMBER': TestClass.CLASS_MEMBER,
    }
    visitor = DummyVisitor(index=index, duplicate_of={})
    resolver = parser.ReferenceResolver.from_visitor(
        visitor=visitor, doc_index={}, py_module_names=['tf'])

    children = ['a_method', 'a_property', 'ChildClass', 'CLASS_MEMBER']
    parser_config = parser.ParserConfig(
        reference_resolver=resolver,
        duplicates={},
        duplicate_of={},
        tree={'TestClass': children},
        index=index,
        reverse_index={},
        guide_index={},
        base_dir='/')

    page_info = parser.docs_for_object(
        full_name='TestClass',
        py_object=TestClass,
        parser_config=parser_config)

    # The brief should be the first docstring line of the class.
    class_brief = tf_inspect.getdoc(TestClass).split('\n')[0]
    self.assertEqual(class_brief, page_info.doc.brief)

    # The method is listed; its extracted signature excludes `self`.
    method_info = page_info.methods[0]
    self.assertEqual(TestClass.a_method, method_info.obj)
    self.assertEqual(["arg='default'"], method_info.signature)

    # The property entry and the nested-class link resolve to the originals.
    self.assertIs(TestClass.a_property, page_info.properties[0].obj)
    self.assertIs(TestClass.ChildClass, page_info.classes[0].obj)

    # The class is reported as defined in this test file.
    expected_path = os.path.relpath(__file__, '/')
    self.assertEqual(expected_path, page_info.defined_in.path)
Пример #8
0
def _get_raw_docstring(py_object):
  """Return the docstring of `py_object`, or '' when it has none worth showing.

  Only classes, methods, functions, modules and properties are queried. For
  any other object (e.g. a plain instance), `tf_inspect.getdoc` would report
  the docstring of the object's type, which does not describe the object
  itself, so '' is returned instead.

  Args:
    py_object: A python object to retrieve the docs for (class, function/method,
      or module).

  Returns:
    The docstring, or the empty string if no docstring was found.
  """
  predicates = (tf_inspect.isclass, tf_inspect.ismethod,
                tf_inspect.isfunction, tf_inspect.ismodule)
  has_own_docs = (any(check(py_object) for check in predicates) or
                  isinstance(py_object, property))
  if not has_own_docs:
    return ''
  return tf_inspect.getdoc(py_object) or ''
Пример #9
0
def _get_raw_docstring(py_object):
    """Get the docs for a given python object.

    Args:
      py_object: A python object to retrieve the docs for (class,
        function/method, or module). Properties are also accepted.

    Returns:
      The docstring, or the empty string if no docstring was found or the
      object is not one of the kinds listed above.
    """
    # tf_inspect.getdoc on an object *instance* reports the docstring of its
    # type, which is not useful here, so only kinds of object that carry their
    # own documentation are introspected.
    documentable = (tf_inspect.isclass(py_object)
                    or tf_inspect.ismethod(py_object)
                    or tf_inspect.isfunction(py_object)
                    or tf_inspect.ismodule(py_object)
                    or isinstance(py_object, property))
    return (tf_inspect.getdoc(py_object) or '') if documentable else ''
Пример #10
0
  def test_docs_for_function(self):
    """Checks brief, signature and source location for `test_function`."""
    name = 'test_function'
    index = {name: test_function}

    resolver = parser.ReferenceResolver.from_visitor(
        visitor=DummyVisitor(index=index, duplicate_of={}),
        doc_index={},
        py_module_names=['tf'])
    parser_config = parser.ParserConfig(
        reference_resolver=resolver,
        duplicates={},
        duplicate_of={},
        tree={'': [name]},
        index=index,
        reverse_index={},
        guide_index={},
        base_dir='/')

    page_info = parser.docs_for_object(
        full_name=name, py_object=test_function, parser_config=parser_config)

    # The brief should be the first line of the docstring.
    doc_lines = tf_inspect.getdoc(test_function).split('\n')
    self.assertEqual(doc_lines[0], page_info.doc.brief)

    # The keyword default is expected to appear quoted in the signature.
    self.assertEqual(['unused_arg', "unused_kwarg='default'"],
                     page_info.signature)

    # The definition location is this file, relative to base_dir '/'.
    self.assertEqual(os.path.relpath(__file__, '/'),
                     page_info.defined_in.path)
Пример #11
0
  def test_docs_for_function_with_kwargs(self):
    """Checks the brief and signature of a function taking *args/**kwargs."""
    index = {'test_function_with_args_kwargs': test_function_with_args_kwargs}
    visitor = DummyVisitor(index=index, duplicate_of={})
    resolver = parser.ReferenceResolver.from_visitor(
        visitor=visitor, doc_index={}, py_module_names=['tf'])

    config_kwargs = dict(
        reference_resolver=resolver,
        duplicates={},
        duplicate_of={},
        tree={'': ['test_function_with_args_kwargs']},
        index=index,
        reverse_index={},
        guide_index={},
        base_dir='/')
    parser_config = parser.ParserConfig(**config_kwargs)

    page_info = parser.docs_for_object(
        full_name='test_function_with_args_kwargs',
        py_object=test_function_with_args_kwargs,
        parser_config=parser_config)

    # The brief should be the first docstring line.
    full_doc = tf_inspect.getdoc(test_function_with_args_kwargs)
    self.assertEqual(full_doc.split('\n')[0], page_info.doc.brief)

    # Variadic parameters are rendered with their star prefixes.
    self.assertEqual(['unused_arg', '*unused_args', '**unused_kwargs'],
                     page_info.signature)
Пример #12
0
    def __new__(mcs, classname, baseclasses, attrs):
        """Control the creation of subclasses of the Distribution class.

        The main purpose of this method is to properly propagate docstrings
        from private Distribution methods, like `_log_prob`, into their
        public wrappers as inherited by the Distribution base class
        (e.g. `log_prob`).

        Args:
          classname: The name of the subclass being created.
          baseclasses: A tuple of parent classes.
          attrs: A dict mapping new attributes to their values.

        Returns:
          The class object.

        Raises:
          TypeError: If `Distribution` is not a subclass of `BaseDistribution`,
            or the new class is derived via multiple inheritance and the first
            parent class is not a subclass of `BaseDistribution`.
          AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
          ValueError:  If a `Distribution` public method lacks a docstring.
        """
        if not baseclasses:  # Nothing to be done for Distribution
            raise TypeError("Expected non-empty baseclass. Does Distribution "
                            "not subclass _BaseDistribution?")
        which_base = [
            base for base in baseclasses
            if base == _BaseDistribution or issubclass(base, Distribution)
        ]
        # If no parent is `_BaseDistribution` or a `Distribution` subclass there
        # are no docstrings to propagate. Build the class untouched rather than
        # failing with an IndexError on the empty list.
        base = which_base[0] if which_base else None
        if base is None or base == _BaseDistribution:
            # Nothing to be done for Distribution or an unrelated class.
            return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
        if not issubclass(base, Distribution):
            raise TypeError("First parent class declared for %s must be "
                            "Distribution, but saw '%s'" %
                            (classname, base.__name__))
        for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
            if attr in attrs:
                # The method is being overridden, do not update its docstring
                continue
            special_attr = "_%s" % attr
            base_attr_value = getattr(base, attr, None)
            if not base_attr_value:
                raise AttributeError(
                    "Internal error: expected base class '%s' to implement method '%s'"
                    % (base.__name__, attr))
            class_special_attr_value = attrs.get(special_attr, None)
            if class_special_attr_value is None:
                # No _special method available, no need to update the docstring.
                continue
            class_special_attr_docstring = tf_inspect.getdoc(
                class_special_attr_value)
            if not class_special_attr_docstring:
                # No docstring to append.
                continue
            # Copy the base-class wrapper so that editing its docstring does not
            # alter the base class or sibling subclasses.
            class_attr_value = _copy_fn(base_attr_value)
            class_attr_docstring = tf_inspect.getdoc(base_attr_value)
            if class_attr_docstring is None:
                raise ValueError(
                    "Expected base class fn to contain a docstring: %s.%s" %
                    (base.__name__, attr))
            class_attr_value.__doc__ = _update_docstring(
                class_attr_value.__doc__,
                ("Additional documentation from `%s`:\n\n%s" %
                 (classname, class_special_attr_docstring)))
            attrs[attr] = class_attr_value

        return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
Пример #13
0
  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
      ValueError:  If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    # If no parent is `_BaseDistribution` or a `Distribution` subclass there are
    # no docstrings to propagate. Build the class untouched rather than failing
    # with an IndexError on the empty list.
    base = which_base[0] if which_base else None
    if base is None or base == _BaseDistribution:
      # Nothing to be done for Distribution or an unrelated class.
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      if attr in attrs:
        # The method is being overridden, do not update its docstring
        continue
      special_attr = "_%s" % attr
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      # Copy the base-class wrapper so that editing its docstring does not alter
      # the base class or sibling subclasses.
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
Пример #14
0
 def testGetDoc(self):
     # The decorated function's own docstring is the expected result.
     expected_doc = 'Test Decorated Function With Defaults Docstring.'
     actual_doc = tf_inspect.getdoc(test_decorated_function_with_defaults)
     self.assertEqual(expected_doc, actual_doc)
Пример #15
0
    def __new__(mcs, classname, baseclasses, attrs):
        """Control the creation of subclasses of the Distribution class.

        The main purpose of this method is to properly propagate docstrings
        from private Distribution methods, like `_log_prob`, into their
        public wrappers as inherited by the Distribution base class
        (e.g. `log_prob`). When the new class defines `__init__`, that
        `__init__` is also wrapped so `self._parameters` is reset before it
        runs and filled in afterwards if the subclass did not set it.

        Args:
          classname: The name of the subclass being created.
          baseclasses: A tuple of parent classes.
          attrs: A dict mapping new attributes to their values.

        Returns:
          The class object.

        Raises:
          TypeError: If `Distribution` is not a subclass of `BaseDistribution`,
            or the new class is derived via multiple inheritance and the first
            parent class is not a subclass of `BaseDistribution`.
          AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
          ValueError:  If a `Distribution` public method lacks a docstring.
        """
        if not baseclasses:  # Nothing to be done for Distribution
            raise TypeError("Expected non-empty baseclass. Does Distribution "
                            "not subclass _BaseDistribution?")
        which_base = [
            base for base in baseclasses
            if base == _BaseDistribution or issubclass(base, Distribution)
        ]
        base = which_base[0] if which_base else None
        if base is None or base == _BaseDistribution:
            # Nothing to be done for Distribution or unrelated subclass.
            return super(_DistributionMeta,
                         mcs).__new__(mcs, classname, baseclasses, attrs)
        if not issubclass(base, Distribution):
            raise TypeError("First parent class declared for {} must be "
                            "Distribution, but saw '{}'".format(
                                classname, base.__name__))
        for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
            if attr in attrs:
                # The method is being overridden, do not update its docstring.
                continue
            special_attr = "_{}".format(attr)
            base_attr_value = getattr(base, attr, None)
            if not base_attr_value:
                raise AttributeError(
                    "Internal error: expected base class '{}' to "
                    "implement method '{}'".format(base.__name__, attr))
            class_special_attr_value = attrs.get(special_attr, None)
            if class_special_attr_value is None:
                # No _special method available, no need to update the docstring.
                continue
            class_special_attr_docstring = tf_inspect.getdoc(
                class_special_attr_value)
            if not class_special_attr_docstring:
                # No docstring to append.
                continue
            # Copy the base-class wrapper so that editing its docstring does not
            # alter the base class or sibling subclasses.
            class_attr_value = _copy_fn(base_attr_value)
            class_attr_docstring = tf_inspect.getdoc(base_attr_value)
            if class_attr_docstring is None:
                raise ValueError(
                    "Expected base class fn to contain a docstring: {}.{}".
                    format(base.__name__, attr))
            class_attr_value.__doc__ = _update_docstring(
                class_attr_value.__doc__,
                "Additional documentation from `{}`:\n\n{}".format(
                    classname, class_special_attr_docstring))
            attrs[attr] = class_attr_value

        # Now we'll intercept the default __init__ if it exists.
        default_init = attrs.get("__init__", None)
        if default_init is None:
            # The class has no __init__ because it's abstract. (And we won't add
            # one.)
            return super(_DistributionMeta,
                         mcs).__new__(mcs, classname, baseclasses, attrs)

        # pylint: disable=protected-access
        @functools.wraps(default_init)
        def wrapped_init(self_, *args, **kwargs):
            """A "master `__init__`" which is always called."""
            # Note: if we ever want to have things set in `self` before `__init__` is
            # called, here is the place to do it.
            self_._parameters = None
            default_init(self_, *args, **kwargs)
            # Note: if we ever want to override things set in `self` by subclass
            # `__init__`, here is the place to do it.
            if self_._parameters is None:
                # We prefer subclasses will set `parameters = dict(locals())` because
                # this has nearly zero overhead. However, failing to do this, we will
                # resolve the input arguments dynamically and only when needed.
                #
                # `dummy_self` stands in for `self` in `getcallargs` and is then
                # passed to `_remove_dict_keys_with_value` to drop that entry.
                # It must be unique: `tuple()` is the shared empty tuple in
                # CPython and equals any `()`, so a real argument passed as `()`
                # (e.g. an empty shape) would be dropped along with it.
                dummy_self = object()
                self_._parameters = lambda: (  # pylint: disable=g-long-lambda
                    _remove_dict_keys_with_value(
                        inspect.getcallargs(default_init, dummy_self, *args, **
                                            kwargs), dummy_self))
            elif hasattr(self_._parameters, "pop"):
                self_._parameters = _remove_dict_keys_with_value(
                    self_._parameters, self_)

        # pylint: enable=protected-access

        attrs["__init__"] = wrapped_init
        return super(_DistributionMeta, mcs).__new__(mcs, classname,
                                                     baseclasses, attrs)
Пример #16
0
 def testGetDoc(self):
   # getdoc is expected to return the decorated function's own docstring.
   docstring = tf_inspect.getdoc(test_decorated_function_with_defaults)
   self.assertEqual('Test Decorated Function With Defaults Docstring.',
                    docstring)