Example #1
def collect_function_renames():
  """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # Set of rename lines to write to output file in the form:
  #   'tf.deprecated_name': 'tf.canonical_name'
  renames = set()

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename strings to add to rename_line_set."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = tf_export.get_v1_names(attr)
      api_names_v2 = tf_export.get_v2_names(attr)
      deprecated_api_names = set(api_names_v1) - set(api_names_v2)
      for name in deprecated_api_names:
        renames.add((name, get_canonical_name(api_names_v2, name)))

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  # It is possible that a different function is exported with the
  # same name, e.g. when creating a different function to
  # rename arguments. Exclude it from renames in this case.
  v2_names = get_all_v2_names()
  renames = set((name, new_name) for name, new_name in renames
                if name not in v2_names)
  return renames
Example #2
  def testAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
Example #3
  def testAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map[''].append('contrib')
    traverse.traverse(tf, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
Example #4
def collect_function_arg_names(function_names):
  """Determines argument names for reordered function signatures.

  Args:
    function_names: Functions to collect arguments for.

  Returns:
    Dictionary mapping function name to its arguments.
  """
  # Map from reordered function name to its arguments.
  function_to_args = {}

  def visit(unused_path, unused_parent, children):
    """Visitor that collects arguments for reordered functions."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = tf_export.get_v1_names(attr)
      api_names_v1 = ['tf.%s' % name for name in api_names_v1]
      matches_function_names = any(
          name in function_names for name in api_names_v1)
      if matches_function_names:
        arg_list = tf_inspect.getargspec(attr)[0]
        for name in api_names_v1:
          function_to_args[name] = arg_list

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  return function_to_args
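
A quick usage sketch for the collect_function_arg_names above. The function names passed in are purely illustrative (they are not taken from these examples); the point is only the shape of the returned mapping:

# Hypothetical usage: the names below are illustrative v1 endpoints.
function_to_args = collect_function_arg_names(['tf.argmax', 'tf.boolean_mask'])
for func_name in sorted(function_to_args):
  # Each value is the positional argument list reported by tf_inspect.getargspec.
  print('%s -> %s' % (func_name, function_to_args[func_name]))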
Example #5
  def testAllAPI(self):
    if not hasattr(tf.compat, "v2"):
      return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    # Please regenerate the renames file or edit any manual renames if this
    # test fails.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = get_v1_names(attr)
        for name in api_names:
          _, _, _, text = self._upgrade("tf." + name)
          if (text and
              not text.startswith("tf.compat.v1") and
              text not in self.v2_symbols):
            self.assertFalse(
                True, "Symbol %s generated from %s not in v2 API" % (
                    text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
Example #6
def extract():
    """Extract docs from tf namespace and write them to disk."""
    visitor = doc_generator_visitor.DocGeneratorVisitor('tf')
    api_visitor = public_api.PublicAPIVisitor(visitor)

    # Access something in contrib so tf.contrib is properly loaded (it's hidden
    # behind lazy loading)
    _ = tf.contrib.__name__

    # Exclude some libraries in contrib from the documentation altogether.
    # TODO(wicke): Shrink this list.
    api_visitor.do_not_descend_map.update({
        'contrib': [
            'compiler',
            'factorization',
            'grid_rnn',
            'labeled_tensor',
            'ndlstm',
            'quantization',
            'session_bundle',
            'slim',
            'solvers',
            'specs',
            'tensor_forest',
            'tensorboard',
            'testing',
            'tfprof',
        ],
        'contrib.bayesflow': [
            'entropy', 'monte_carlo', 'special_math',
            'stochastic_gradient_estimators', 'stochastic_graph',
            'stochastic_tensor', 'stochastic_variables',
            'variational_inference'
        ],
        'contrib.distributions': ['bijector'],
        'contrib.ffmpeg': ['ffmpeg_ops'],
        'contrib.graph_editor':
        ['edit', 'match', 'subgraph', 'transform', 'select', 'util'],
        'contrib.layers': ['feature_column', 'summaries'],
        'contrib.learn': [
            'datasets',
            'head',
            'graph_actions',
            'io',
            'models',
            'monitors',
            'ops',
            'preprocessing',
            'utils',
        ],
        'contrib.util': ['loader'],
    })

    traverse.traverse(tf, api_visitor)

    # tf_debug is not imported with tf, it's a separate module altogether
    visitor.set_root_name('tfdbg')
    traverse.traverse(tf_debug, api_visitor)

    return visitor
Example #7
 def testNoSubclassOfMessage(self):
     visitor = public_api.PublicAPIVisitor(
         _VerifyNoSubclassOfMessageVisitor)
     visitor.do_not_descend_map['tf'].append('contrib')
     # Skip compat.v1 and compat.v2 since they are validated in separate tests.
     visitor.private_map['tf.compat'] = ['v1', 'v2']
     traverse.traverse(tf, visitor)
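
The _VerifyNoSubclassOfMessageVisitor helper used above is not included in these snippets. A minimal sketch, assuming it performs the same check as the inline Visit function in a later testNoSubclassOfMessage example in this collection:

from google.protobuf import message

def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  """Sketch: fail if a traversed object subclasses a generated proto Message."""
  # Only classes deriving from protobuf's message.Message are of interest.
  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
    return
  if parent is message.Message:
    return
  # Generated proto subclasses are not yet supported by the API tools.
  if message.Message not in parent.__bases__:
    raise NotImplementedError(
        'Object tf.%s is a subclass of a generated proto Message. '
        'They are not yet supported by the API tools.' % path)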
Example #8
def collect_function_renames():
    """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name).
  """
    # Set of rename lines to write to output file in the form:
    #   'tf.deprecated_name': 'tf.canonical_name'
    renames = set()

    def visit(unused_path, unused_parent, children):
        """Visitor that collects rename strings to add to rename_line_set."""
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            if not hasattr(attr, '__dict__'):
                continue
            api_names_v1 = attr.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
            api_names_v2 = attr.__dict__.get(_TENSORFLOW_API_ATTR, [])
            deprecated_api_names = set(api_names_v1) - set(api_names_v2)
            for name in deprecated_api_names:
                renames.add((name, get_canonical_name(api_names_v2, name)))

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)
    return renames
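
get_canonical_name is called here but never defined in these snippets. A minimal sketch consistent with its use (pick a v2 replacement for a deprecated v1 name, or fall back to the compat.v1 path) could be:

def get_canonical_name(api_names_v2, deprecated_name):
  """Sketch: choose the replacement name for a deprecated v1 symbol."""
  if api_names_v2:
    # Prefer the first exported v2 endpoint, if one exists.
    return api_names_v2[0]
  # Otherwise keep the symbol reachable through the compat.v1 namespace.
  return 'compat.v1.%s' % deprecated_name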
Example #9
  def testAllAPI(self):
    if not hasattr(tf.compat, "v2"):
      return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    # Please regenerate the renames file or edit any manual renames if this
    # test fails.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          _, _, _, text = self._upgrade("tf." + name)
          if (text and
              not text.startswith("tf.compat.v1") and
              text not in self.v2_symbols):
            self.assertFalse(
                True, "Symbol %s generated from %s not in v2 API" % (
                    text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
Example #10
  def testAllAPIV1(self):
    collect = True
    v1_symbols = set([])

    # Symbols that may be generated by the conversion script but do not exist
    # in TF 1.x. This should be a very short list of symbols that are
    # experimental in 1.x but stable in 2.x.
    whitelisted_v2_only_symbols = set(["tf.saved_model.save"])

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v1 namespace.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        api_names = tf_export.get_v1_names(attr)
        for name in api_names:
          if collect:
            v1_symbols.add("tf." + name)
          else:
            _, _, _, text = self._upgrade("tf." + name)
            if (text and
                not text.startswith("tf.compat.v1") and
                not text.startswith("tf.estimator") and
                text not in v1_symbols and
                text not in whitelisted_v2_only_symbols):
              self.assertFalse(
                  True, "Symbol %s generated from %s not in v1 API" % (
                      text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
    collect = False
    traverse.traverse(tf.compat.v1, visitor)
Example #11
  def testV1KeywordArgNames(self):
    all_keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    # Visitor that verifies V1 argument names.
    def arg_test_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        names_v1 = tf_export.get_v1_names(attr)

        for name in names_v1:
          name = "tf.%s" % name
          if name not in all_keyword_renames:
            continue
          arg_names_v1 = tf_inspect.getargspec(attr)[0]
          keyword_renames = all_keyword_renames[name]
          self.assertEqual(type(keyword_renames), dict)

          # Assert that v1 function has valid v1 argument names.
          for from_name, _ in keyword_renames.items():
            self.assertIn(
                from_name, arg_names_v1,
                "%s not found in %s arguments: %s" %
                (from_name, name, str(arg_names_v1)))

    visitor = public_api.PublicAPIVisitor(arg_test_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
Example #12
def update_renames_v2(output_file_path):
  """Writes a Python dictionary mapping deprecated to canonical API names.

  Args:
    output_file_path: File path to write output to. Any existing contents
      will be replaced.
  """
  # Set of rename lines to write to output file in the form:
  #   'tf.deprecated_name': 'tf.canonical_name'
  rename_line_set = set()
  # _tf_api_names attribute name
  tensorflow_api_attr = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename strings to add to rename_line_set."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not hasattr(attr, '__dict__'):
        continue
      api_names = attr.__dict__.get(tensorflow_api_attr, [])
      deprecated_api_names = attr.__dict__.get('_tf_deprecated_api_names', [])
      canonical_name = tf_export.get_canonical_name(
          api_names, deprecated_api_names)
      for name in deprecated_api_names:
        rename_line_set.add('    \'tf.%s\': \'tf.%s\'' % (name, canonical_name))

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf, visitor)

  renames_file_text = '%srenames = {\n%s\n}\n' % (
      _FILE_HEADER, ',\n'.join(sorted(rename_line_set)))
  file_io.write_string_to_file(output_file_path, renames_file_text)
Example #13
  def checkBackwardsCompatibility(self, root, golden_file_pattern, api_version):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(root, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)
Example #14
  def setUpClass(cls):
    cls.v2_symbols = {}
    cls.v1_symbols = {}
    if hasattr(tf.compat, "v2"):

      def symbol_collector(unused_path, unused_parent, children):
        for child in children:
          _, attr = tf_decorator.unwrap(child[1])
          api_names_v2 = tf_export.get_v2_names(attr)
          for name in api_names_v2:
            cls.v2_symbols["tf." + name] = attr

      visitor = public_api.PublicAPIVisitor(symbol_collector)
      traverse.traverse(tf.compat.v2, visitor)

    if hasattr(tf.compat, "v1"):

      def symbol_collector_v1(unused_path, unused_parent, children):
        for child in children:
          _, attr = tf_decorator.unwrap(child[1])
          api_names_v1 = tf_export.get_v1_names(attr)
          for name in api_names_v1:
            cls.v1_symbols["tf." + name] = attr

      visitor = public_api.PublicAPIVisitor(symbol_collector_v1)
      traverse.traverse(tf.compat.v1, visitor)
Example #15
    def testV1KeywordArgNames(self):
        all_keyword_renames = (
            tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

        # Visitor that verifies V1 argument names.
        def arg_test_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                names_v1 = tf_export.get_v1_names(attr)

                for name in names_v1:
                    name = "tf.%s" % name
                    if name not in all_keyword_renames:
                        continue
                    arg_names_v1 = tf_inspect.getargspec(attr)[0]
                    keyword_renames = all_keyword_renames[name]
                    self.assertEqual(type(keyword_renames), dict)

                    # Assert that v1 function has valid v1 argument names.
                    for from_name, _ in keyword_renames.items():
                        self.assertIn(
                            from_name, arg_names_v1,
                            "%s not found in %s arguments: %s" %
                            (from_name, name, str(arg_names_v1)))

        visitor = public_api.PublicAPIVisitor(arg_test_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
Example #16
    def testAllAPIV1(self):
        collect = True
        v1_symbols = set([])

        # Symbols that may be generated by the conversion script but do not exist
        # in TF 1.x. This should be a very short list of symbols that are
        # experimental in 1.x but stable in 2.x.
        whitelisted_v2_only_symbols = set(["tf.saved_model.save"])

        # Converts all symbols in the v1 namespace to the v2 namespace, raising
        # an error if the target of the conversion is not in the v1 namespace.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                api_names = tf_export.get_v1_names(attr)
                for name in api_names:
                    if collect:
                        v1_symbols.add("tf." + name)
                    else:
                        _, _, _, text = self._upgrade("tf." + name)
                        if (text and not text.startswith("tf.compat.v1")
                                and not text.startswith("tf.estimator")
                                and text not in v1_symbols
                                and text not in whitelisted_v2_only_symbols):
                            self.assertFalse(
                                True,
                                "Symbol %s generated from %s not in v1 API" %
                                (text, name))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
        collect = False
        traverse.traverse(tf.compat.v1, visitor)
Example #17
 def testNoSubclassOfMessageV2(self):
     if not hasattr(tf.compat, 'v2'):
         return
     visitor = public_api.PublicAPIVisitor(
         _VerifyNoSubclassOfMessageVisitor)
     visitor.do_not_descend_map['tf'].append('contrib')
     traverse.traverse(tf.compat.v2, visitor)
Example #18
    def testAllAPIV1(self):
        collect = True
        v1_symbols = set([])

        # Converts all symbols in the v1 namespace to the v2 namespace, raising
        # an error if the target of the conversion is not in the v1 namespace.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                api_names = get_v1_names(attr)
                for name in api_names:
                    if collect:
                        v1_symbols.add("tf." + name)
                    else:
                        _, _, _, text = self._upgrade("tf." + name)
                        if (text and not text.startswith("tf.compat.v1")
                                and text not in v1_symbols):
                            self.assertFalse(
                                True,
                                "Symbol %s generated from %s not in v1 API" %
                                (text, name))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
        collect = False
        traverse.traverse(tf.compat.v1, visitor)
Example #19
    def _checkBackwardsCompatibility(self,
                                     root,
                                     golden_file_patterns,
                                     api_version,
                                     additional_private_map=None,
                                     omit_golden_symbols_map=None):
        # Extract all API stuff.
        visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

        public_api_visitor = public_api.PublicAPIVisitor(visitor)
        public_api_visitor.private_map['tf'].append('contrib')
        if api_version == 2:
            public_api_visitor.private_map['tf'].append('enable_v2_behavior')

        public_api_visitor.do_not_descend_map['tf.GPUOptions'] = [
            'Experimental'
        ]
        # Do not descend into these numpy classes because their signatures may be
        # different between internal and OSS.
        public_api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
            'bool_', 'complex_', 'complex128', 'complex64', 'float_',
            'float16', 'float32', 'float64', 'inexact', 'int_', 'int16',
            'int32', 'int64', 'int8', 'object_', 'string_', 'uint16', 'uint32',
            'uint64', 'uint8', 'unicode_', 'iinfo'
        ]
        if FLAGS.only_test_core_api:
            public_api_visitor.do_not_descend_map['tf'].extend(
                _NON_CORE_PACKAGES)
        if additional_private_map:
            public_api_visitor.private_map.update(additional_private_map)

        traverse.traverse(root, public_api_visitor)
        proto_dict = visitor.GetProtos()

        # Read all golden files.
        golden_file_list = file_io.get_matching_files(golden_file_patterns)
        if FLAGS.only_test_core_api:
            golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

        def _ReadFileToProto(filename):
            """Read a filename, create a protobuf from its contents."""
            ret_val = api_objects_pb2.TFAPIObject()
            text_format.Merge(file_io.read_file_to_string(filename), ret_val)
            return ret_val

        golden_proto_dict = {
            _FileNameToKey(filename): _ReadFileToProto(filename)
            for filename in golden_file_list
        }
        golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                                   omit_golden_symbols_map)

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not fail.
        self._AssertProtoDictEquals(golden_proto_dict,
                                    proto_dict,
                                    verbose=FLAGS.verbose_diffs,
                                    update_goldens=FLAGS.update_goldens,
                                    api_version=api_version)
Example #20
 def test_class(self):
     visitor = TestVisitor()
     traverse.traverse(TestVisitor, visitor)
     self.assertEqual(TestVisitor, visitor.call_log[0][1])
     # There are a bunch of other members, but make sure that the ones we know
     # about are there.
     self.assertIn('__init__', [name for name, _ in visitor.call_log[0][2]])
     self.assertIn('__call__', [name for name, _ in visitor.call_log[0][2]])
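
The TestVisitor used by these traverse tests is not shown. Based on the assertions (call_log holds (path, parent, children) triples, and children is a list of (name, object) pairs), a minimal sketch would be:

class TestVisitor(object):
  """Sketch of a visitor that records everything traverse reports."""

  def __init__(self):
    self.call_log = []  # Each entry is a (path, parent, children) tuple.

  def __call__(self, path, parent, children):
    # children is a list of (name, child_object) pairs found on parent.
    self.call_log.append((path, parent, children))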
Example #21
    def test_module(self):
        visitor = TestVisitor()
        traverse.traverse(test_module1, visitor)

        called = [parent for _, parent, _ in visitor.call_log]

        self.assertIn(test_module1.ModuleClass1, called)
        self.assertIn(test_module2.ModuleClass2, called)
Example #22
    def test_cycle(self):
        class Cyclist(object):
            pass

        Cyclist.cycle = Cyclist

        visitor = TestVisitor()
        traverse.traverse(Cyclist, visitor)
Example #23
 def testNoSubclassOfMessageV2(self):
   if not hasattr(tf.compat, 'v2'):
     return
   visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
   visitor.do_not_descend_map['tf'].append('contrib')
   if FLAGS.only_test_core_api:
     visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
   traverse.traverse(tf.compat.v2, visitor)
Example #24
 def testNoSubclassOfMessageV2(self):
   if not hasattr(tf.compat, 'v2'):
     return
   visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
   visitor.do_not_descend_map['tf'].append('contrib')
   if FLAGS.only_test_core_api:
     visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
   traverse.traverse(tf_v2, visitor)
Example #25
  def test_cycle(self):

    class Cyclist(object):
      pass
    Cyclist.cycle = Cyclist

    visitor = TestVisitor()
    traverse.traverse(Cyclist, visitor)
Example #26
  def test_module(self):
    visitor = TestVisitor()
    traverse.traverse(test_module1, visitor)

    called = [parent for _, parent, _ in visitor.call_log]

    self.assertIn(test_module1.ModuleClass1, called)
    self.assertIn(test_module2.ModuleClass2, called)
Example #27
  def test_module(self):
    visitor = TestVisitor()
    traverse.traverse(sys.modules[__name__], visitor)

    called = [parent for _, parent, _ in visitor.call_log]

    self.assertIn(TestVisitor, called)
    self.assertIn(TraverseTest, called)
    self.assertIn(traverse, called)
Example #28
 def test_class(self):
   visitor = TestVisitor()
   traverse.traverse(TestVisitor, visitor)
   self.assertEqual(TestVisitor,
                    visitor.call_log[0][1])
   # There are a bunch of other members, but make sure that the ones we know
   # about are there.
   self.assertIn('__init__', [name for name, _ in visitor.call_log[0][2]])
   self.assertIn('__call__', [name for name, _ in visitor.call_log[0][2]])
Example #29
    def testKeywordArgNames(self):
        if not hasattr(tf.compat, "v2"):
            return

        all_keyword_renames = (
            tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
        v2_name_exceptions = {"verify_shape_is_now_always_true"}

        # Visitor that verifies V1 argument names, converts to V2 and checks
        # V2 argument names.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                names_v1 = get_v1_names(attr)

                for name in names_v1:
                    name = "tf.%s" % name
                    if name not in all_keyword_renames:
                        continue
                    arg_names_v1 = tf_inspect.getargspec(attr)[0]
                    keyword_renames = all_keyword_renames[name]
                    self.assertEqual(type(keyword_renames), dict)

                    # Assert that v1 function has valid v1 argument names.
                    for from_name, _ in keyword_renames.items():
                        self.assertIn(
                            from_name, arg_names_v1,
                            "%s not found in %s arguments: %s" %
                            (from_name, name, str(arg_names_v1)))

                    # Assert that arg names after converting to v2 are present in
                    # v2 function.
                    # 1. First, create an input of the form:
                    #    tf.foo(arg1=val1, arg2=val2, ...)
                    args = ",".join([
                        "%s=%d" % (from_name, from_index) for from_index,
                        from_name in enumerate(keyword_renames.keys())
                    ])
                    text_input = "%s(%s)" % (name, args)
                    # 2. Convert the input to V2.
                    _, _, _, text = self._upgrade(text_input)
                    new_function_name, new_args = get_func_and_args_from_str(
                        text)
                    # 3. Verify V2 function and arguments.
                    # Note: If we rename arguments, new function must be available in 2.0.
                    # We should not be using compat.v1 in this case.
                    self.assertIn(new_function_name, self.v2_symbols)
                    args_v2 = tf_inspect.getargspec(
                        self.v2_symbols[new_function_name])[0]
                    args_v2.extend(v2_name_exceptions)
                    for new_arg in new_args:
                        self.assertIn(new_arg, args_v2)

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
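
get_func_and_args_from_str is only used, never defined, in these snippets. A minimal sketch consistent with inputs of the form "tf.foo(arg1=0,arg2=1)":

def get_func_and_args_from_str(call_str):
  """Sketch: split "tf.foo(a=1, b=2)" into ("tf.foo", ["a", "b"])."""
  open_paren = call_str.find("(")
  close_paren = call_str.rfind(")")
  function_name = call_str[:open_paren]
  args = call_str[open_paren + 1:close_paren].split(",")
  # Keep only the keyword names, dropping values and empty pieces.
  return function_name, [a.split("=")[0].strip() for a in args if a.strip()]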
Example #30
  def testKeywordArgNames(self):
    if not hasattr(tf.compat, "v2"):
      return

    all_keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
    v2_name_exceptions = {"verify_shape_is_now_always_true"}

    # Visitor that verifies V1 argument names, converts to V2 and checks
    # V2 argument names.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        names_v1 = get_v1_names(attr)

        for name in names_v1:
          name = "tf.%s" % name
          if name not in all_keyword_renames:
            continue
          arg_names_v1 = tf_inspect.getargspec(attr)[0]
          keyword_renames = all_keyword_renames[name]
          self.assertEqual(type(keyword_renames), dict)

          # Assert that v1 function has valid v1 argument names.
          for from_name, _ in keyword_renames.items():
            self.assertIn(
                from_name, arg_names_v1,
                "%s not found in %s arguments: %s" %
                (from_name, name, str(arg_names_v1)))

          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(keyword_renames.keys())])
          text_input = "%s(%s)" % (name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          # 3. Verify V2 function and arguments.
          # Note: If we rename arguments, new function must be available in 2.0.
          # We should not be using compat.v1 in this case.
          self.assertIn(new_function_name, self.v2_symbols)
          args_v2 = tf_inspect.getargspec(self.v2_symbols[new_function_name])[0]
          args_v2.extend(v2_name_exceptions)
          for new_arg in new_args:
            self.assertIn(new_arg, args_v2)

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
Example #31
    def testNewAPIBackwardsCompatibility(self):
        # Extract all API stuff.
        visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

        public_api_visitor = public_api.PublicAPIVisitor(visitor)
        public_api_visitor.do_not_descend_map['tf'].append('contrib')
        public_api_visitor.do_not_descend_map['tf.GPUOptions'] = [
            'Experimental'
        ]
        # TODO(annarev): Make slide_dataset available in API.
        public_api_visitor.private_map['tf'] = ['slide_dataset']
        traverse.traverse(api, public_api_visitor)

        proto_dict = visitor.GetProtos()

        # Read all golden files.
        expression = os.path.join(
            resource_loader.get_root_dir_with_all_resources(),
            _KeyToFilePath('*'))
        golden_file_list = file_io.get_matching_files(expression)

        def _ReadFileToProto(filename):
            """Read a filename, create a protobuf from its contents."""
            ret_val = api_objects_pb2.TFAPIObject()
            text_format.Merge(file_io.read_file_to_string(filename), ret_val)
            return ret_val

        golden_proto_dict = {
            _FileNameToKey(filename): _ReadFileToProto(filename)
            for filename in golden_file_list
        }

        # user_ops is an empty module. It is currently available in TensorFlow API
        # but we don't keep empty modules in the new API.
        # We delete user_ops from golden_proto_dict to make sure assert passes
        # when diffing new API against goldens.
        # TODO(annarev): remove user_ops from goldens once we switch to new API.
        tf_module = golden_proto_dict['tensorflow'].tf_module
        for i in range(len(tf_module.member)):
            if tf_module.member[i].name == 'user_ops':
                del tf_module.member[i]
                break

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not fail.
        self._AssertProtoDictEquals(
            golden_proto_dict,
            proto_dict,
            verbose=FLAGS.verbose_diffs,
            update_goldens=False,
            additional_missing_object_message=
            'Check if tf_export decorator/call is missing for this symbol.')
Example #32
  def testNewAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    public_api_visitor.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # user_ops is an empty module. It is currently available in TensorFlow API
    # but we don't keep empty modules in the new API.
    # We delete user_ops from golden_proto_dict to make sure assert passes
    # when diffing new API against goldens.
    # TODO(annarev): remove user_ops from goldens once we switch to new API.
    tf_module = golden_proto_dict['tensorflow'].tf_module
    for i in range(len(tf_module.member)):
      if tf_module.member[i].name == 'user_ops':
        del tf_module.member[i]
        break

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
Example #33
  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_pattern,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'] = ['contrib']
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)
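
A hypothetical driver for the _checkBackwardsCompatibility helper above. The golden-file pattern below is an assumption borrowed from the other tests in this collection; the real test may build it differently:

  def testAPIBackwardsCompatibilityV2(self):
    # Hypothetical: compare tf.compat.v2 against the matching golden files.
    golden_file_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    self._checkBackwardsCompatibility(
        tf.compat.v2, golden_file_pattern, api_version=2)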
Example #34
def extract(py_modules, do_not_descend_map):
  """Extract docs from tf namespace and write them to disk."""
  # Traverse the first module.
  visitor = doc_generator_visitor.DocGeneratorVisitor(py_modules[0][0])
  api_visitor = public_api.PublicAPIVisitor(visitor)
  add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)

  traverse.traverse(py_modules[0][1], api_visitor)

  # Traverse all py_modules after the first:
  for module_name, module in py_modules[1:]:
    visitor.set_root_name(module_name)
    traverse.traverse(module, api_visitor)

  return visitor
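
A hypothetical invocation of this extract variant, mirroring Example #6 (document tf first, then the separate tf_debug module under the root name 'tfdbg'):

# Hypothetical call; an empty do_not_descend_map means nothing is excluded
# beyond PublicAPIVisitor's defaults.
doc_visitor = extract(
    py_modules=[('tf', tf), ('tfdbg', tf_debug)],
    do_not_descend_map={})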
Example #35
def extract(py_modules, do_not_descend_map):
    """Extract docs from tf namespace and write them to disk."""
    # Traverse the first module.
    visitor = doc_generator_visitor.DocGeneratorVisitor(py_modules[0][0])
    api_visitor = public_api.PublicAPIVisitor(visitor)
    add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)

    traverse.traverse(py_modules[0][1], api_visitor)

    # Traverse all py_modules after the first:
    for module_name, module in py_modules[1:]:
        visitor.set_root_name(module_name)
        traverse.traverse(module, api_visitor)

    return visitor
Example #36
def get_all_v2_names():
  """Get a set of function/class names available in TensorFlow 2.0."""
  v2_names = set()  # All op names in TensorFlow 2.0

  def visit(unused_path, unused_parent, children):
    """Visitor that collects TF 2.0 names."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v2 = get_v2_names(attr)
      for name in api_names_v2:
        v2_names.add(name)

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf.compat.v2, visitor)
  return v2_names
Example #37
    def _checkBackwardsCompatibility(
        self,
        root,
        golden_file_patterns,
        api_version,
        additional_private_map=None,
        omit_golden_symbols_map=None,
    ):
        # Extract all API stuff.
        visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor(
            default_path="tensorflow.keras")

        public_api_visitor = public_api.PublicAPIVisitor(visitor)
        if additional_private_map:
            public_api_visitor.private_map.update(additional_private_map)
        public_api_visitor.set_root_name("tf.keras")

        traverse.traverse(root, public_api_visitor)
        proto_dict = visitor.GetProtos()

        # Read all golden files.
        golden_file_list = tf.compat.v1.gfile.Glob(golden_file_patterns)

        def _ReadFileToProto(filename):
            """Read a filename, create a protobuf from its contents."""
            ret_val = api_objects_pb2.TFAPIObject()
            text_format.Merge(file_io.read_file_to_string(filename), ret_val)
            return ret_val

        golden_proto_dict = {
            _FileNameToKey(filename): _ReadFileToProto(filename)
            for filename in golden_file_list
        }
        golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                                   omit_golden_symbols_map)

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not
        # fail.
        self._AssertProtoDictEquals(
            golden_proto_dict,
            proto_dict,
            verbose=FLAGS.verbose_diffs,
            update_goldens=FLAGS.update_goldens,
            api_version=api_version,
        )
Example #38
  def testNoSubclassOfMessage(self):

    def Visit(path, parent, unused_children):
      """A Visitor that crashes on subclasses of generated proto classes."""
      # If the traversed object is a proto Message class
      if not (isinstance(parent, type) and
              issubclass(parent, message.Message)):
        return
      if parent is message.Message:
        return
      # Check that it is a direct subclass of Message.
      if message.Message not in parent.__bases__:
        raise NotImplementedError(
            'Object tf.%s is a subclass of a generated proto Message. '
            'They are not yet supported by the API tools.' % path)
    visitor = public_api.PublicAPIVisitor(Visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, visitor)
Example #39
  def testAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # TODO(annarev): remove once we switch to using tf_export decorators.
    tf_module = golden_proto_dict['tensorflow'].tf_module
    for i in range(len(tf_module.member)):
      if tf_module.member[i].name == 'math':
        del tf_module.member[i]
        break

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
Example #40
    def testAllAPI(self):
        if not hasattr(tf.compat, "v2"):
            return

        v2_symbols = set([])
        attr_v2 = tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names

        def symbol_collector(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                if not hasattr(attr, "__dict__"):
                    continue
                api_names_v2 = attr.__dict__.get(attr_v2, [])
                for name in api_names_v2:
                    v2_symbols.add("tf." + name)

        visitor = public_api.PublicAPIVisitor(symbol_collector)
        traverse.traverse(tf.compat.v2, visitor)

        attr_v1 = (tf_export.API_ATTRS_V1[tf_export.TENSORFLOW_API_NAME].names)

        # Converts all symbols in the v1 namespace to the v2 namespace, raising
        # an error if the target of the conversion is not in the v2 namespace.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                if not hasattr(attr, "__dict__"):
                    continue
                api_names = attr.__dict__.get(attr_v1, [])
                for name in api_names:
                    _, _, _, text = self._upgrade("tf." + name)
                    if (text and not text.startswith("tf.compat.v1")
                            and text not in v2_symbols):
                        self.assertFalse(
                            True, "Symbol %s generated from %s not in v2 API" %
                            (text, name))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
Example #41
def extract(py_modules,
            private_map,
            do_not_descend_map,
            visitor_cls=doc_generator_visitor.DocGeneratorVisitor):
  """Extract docs from tf namespace and write them to disk."""
  # Traverse the first module.
  visitor = visitor_cls(py_modules[0][0])
  api_visitor = DocControlsAwareCrawler(visitor)
  api_visitor.set_root_name(py_modules[0][0])
  add_dict_to_dict(private_map, api_visitor.private_map)
  add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)

  traverse.traverse(py_modules[0][1], api_visitor)

  # Traverse all py_modules after the first:
  for module_name, module in py_modules[1:]:
    visitor.set_root_name(module_name)
    api_visitor.set_root_name(module_name)
    traverse.traverse(module, api_visitor)

  return visitor
Example #42
def extract(py_modules,
            private_map,
            do_not_descend_map,
            visitor_cls=doc_generator_visitor.DocGeneratorVisitor):
    """Extract docs from tf namespace and write them to disk."""
    # Traverse the first module.
    visitor = visitor_cls(py_modules[0][0])
    api_visitor = DocControlsAwareCrawler(visitor)
    api_visitor.set_root_name(py_modules[0][0])
    add_dict_to_dict(private_map, api_visitor.private_map)
    add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)

    traverse.traverse(py_modules[0][1], api_visitor)

    # Traverse all py_modules after the first:
    for module_name, module in py_modules[1:]:
        visitor.set_root_name(module_name)
        api_visitor.set_root_name(module_name)
        traverse.traverse(module, api_visitor)

    return visitor
Example #43
def collect_function_arg_names(function_names):
  """Determines argument names for reordered function signatures.

  Args:
    function_names: Functions to collect arguments for.

  Returns:
    Dictionary mapping function name to its arguments.
  """
  # Map from reordered function name to its arguments.
  function_to_args = {}

  def visit(unused_path, unused_parent, children):
    """Visitor that collects arguments for reordered functions."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = tf_export.get_v1_names(attr)
      api_names_v1 = ['tf.%s' % name for name in api_names_v1]
      matches_function_names = any(
          name in function_names for name in api_names_v1)
      if matches_function_names:
        if tf_inspect.isclass(attr):
          # Get constructor arguments if attr is a class
          arg_list = tf_inspect.getargspec(
              getattr(attr, '__init__'))[0]
          arg_list = arg_list[1:]  # skip 'self' argument
        else:
          # Get function arguments.
          # getargspec returns a tuple of (args, varargs, keywords, defaults)
          # we just look at args.
          arg_list = tf_inspect.getargspec(attr)[0]
        for name in api_names_v1:
          function_to_args[name] = arg_list

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  return function_to_args
Example #44
def collect_function_arg_names(function_names):
    """Determines argument names for reordered function signatures.

  Args:
    function_names: Functions to collect arguments for.

  Returns:
    Dictionary mapping function name to its arguments.
  """
    # Map from reordered function name to its arguments.
    function_to_args = {}

    def visit(unused_path, unused_parent, children):
        """Visitor that collects arguments for reordered functions."""
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            api_names_v1 = tf_export.get_v1_names(attr)
            api_names_v1 = ['tf.%s' % name for name in api_names_v1]
            matches_function_names = any(name in function_names
                                         for name in api_names_v1)
            if matches_function_names:
                if tf_inspect.isclass(attr):
                    # Get constructor arguments if attr is a class
                    arg_list = tf_inspect.getargspec(getattr(attr,
                                                             '__init__'))[0]
                    arg_list = arg_list[1:]  # skip 'self' argument
                else:
                    # Get function arguments.
                    # getargspec returns a tuple of (args, varargs, keywords, defaults)
                    # we just look at args.
                    arg_list = tf_inspect.getargspec(attr)[0]
                for name in api_names_v1:
                    function_to_args[name] = arg_list

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

    return function_to_args
Example #45
  def testNewAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    public_api_visitor.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
Example #46
def collect_function_renames():
  """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # Set of rename lines to write to output file in the form:
  #   'tf.deprecated_name': 'tf.canonical_name'
  renames = set()
  v2_names = set()  # All op names in TensorFlow 2.0

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename strings to add to rename_line_set."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not hasattr(attr, '__dict__'):
        continue
      api_names_v1 = attr.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
      api_names_v2 = attr.__dict__.get(_TENSORFLOW_API_ATTR, [])
      deprecated_api_names = set(api_names_v1) - set(api_names_v2)
      for name in deprecated_api_names:
        renames.add((name, get_canonical_name(api_names_v2, name)))
      for name in api_names_v2:
        v2_names.add(name)

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  # It is possible that a different function is exported with the
  # same name, e.g. when creating a different function to
  # rename arguments. Exclude it from renames in this case.
  renames = set((name, new_name) for name, new_name in renames
                if name not in v2_names)
  return renames
Example #47
  def testAPIDefCompatibility(self):
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from first character of op name to Python ApiDefs.
    api_def_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if 1 endpoint differs from base
    # ApiDef. So, we first create a map from an op to all its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
      module_obj = _GetSymbol(public_module)
      for sym in value.tf_module.member_method:
        obj = getattr(module_obj, sym.name)

        # Check if object is defined in gen_* module. That is,
        # the object has been generated from OpDef.
        if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
          if obj.__name__ not in name_to_base_api_def:
            # Symbol might be defined only in Python and not generated from
            # C++ api.
            continue
          relative_public_module = public_module[len('tensorflow.'):]
          full_name = (relative_public_module + '.' + sym.name
                       if relative_public_module else sym.name)
          op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
      api_def = self._CreatePythonApiDef(
          name_to_base_api_def[op.__name__], endpoint_names)
      if api_def:
        api_defs = api_def_map[op.__name__[0].upper()]
        api_defs.op.extend([api_def])

    for key in _ALPHABET:
      # Get new ApiDef for the given key.
      new_api_defs_str = ''
      if key in api_def_map:
        new_api_defs = api_def_map[key]
        new_api_defs.op.sort(key=attrgetter('graph_op_name'))
        new_api_defs_str = str(new_api_defs)

      # Get current ApiDef for the given key.
      api_defs_file_path = os.path.join(
          _PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
      old_api_defs_str = ''
      if file_io.file_exists(api_defs_file_path):
        old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)

      if old_api_defs_str == new_api_defs_str:
        continue

      if FLAGS.update_goldens:
        if not new_api_defs_str:
          logging.info('Deleting %s...' % api_defs_file_path)
          file_io.delete_file(api_defs_file_path)
        else:
          logging.info('Updating %s...' % api_defs_file_path)
          file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
      else:
        self.assertMultiLineEqual(
            old_api_defs_str, new_api_defs_str,
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.')
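
The _IsGenModule helper used in the test above is defined elsewhere in the test module; a plausible minimal stand-in (an assumption, not the verified TensorFlow implementation) just checks whether the defining module is one of the generated gen_* op wrapper modules.

# Hypothetical stand-in for _IsGenModule: generated op wrappers live in
# modules whose last path component starts with 'gen_', e.g.
# tensorflow.python.ops.gen_math_ops.
def _IsGenModule(module_name):
  if not module_name:
    return False
  return module_name.split('.')[-1].startswith('gen_')
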
Example #48
0
  def test_non_class(self):
    integer = 5
    visitor = TestVisitor()
    traverse.traverse(integer, visitor)
    self.assertEqual([], visitor.call_log)
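
The TestVisitor referenced above is defined elsewhere in the traverse test; a plausible minimal stand-in, consistent with how traverse.traverse invokes visitors, simply records every callback. The class body below is an assumption for illustration, not the original test fixture.

# Hypothetical minimal TestVisitor: traverse.traverse calls the visitor with
# (path, parent, children) for every object it descends into, so tests can
# assert on the recorded call_log.
class TestVisitor(object):

  def __init__(self):
    self.call_log = []

  def __call__(self, path, parent, children):
    self.call_log.append((path, parent, children))
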
  def testAPIDefCompatibility(self):
    # Get base ApiDef
    name_to_base_api_def = self._GetBaseApiMap()
    snake_to_camel_graph_op_names = {
        self._GenerateLowerCaseOpName(name): name
        for name in name_to_base_api_def.keys()}
    # Extract Python API
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Map from file path to Python ApiDefs.
    new_api_defs_map = defaultdict(api_def_pb2.ApiDefs)
    # We need to override all endpoints even if only one endpoint differs
    # from the base ApiDef. So, we first create a map from each op to all of
    # its endpoints.
    op_to_endpoint_name = defaultdict(list)

    # Generate map from generated python op to endpoint names.
    for public_module, value in proto_dict.items():
      module_obj = _GetSymbol(public_module)
      for sym in value.tf_module.member_method:
        obj = getattr(module_obj, sym.name)

        # Check if the object is defined in a gen_* module, i.e. it was
        # generated from an OpDef.
        if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
          if obj.__name__ not in snake_to_camel_graph_op_names:
            # The symbol might be defined only in Python and not generated
            # from the C++ API.
            continue
          relative_public_module = public_module[len('tensorflow.'):]
          full_name = (relative_public_module + '.' + sym.name
                       if relative_public_module else sym.name)
          op_to_endpoint_name[obj].append(full_name)

    # Generate Python ApiDef overrides.
    for op, endpoint_names in op_to_endpoint_name.items():
      graph_op_name = snake_to_camel_graph_op_names[op.__name__]
      api_def = self._CreatePythonApiDef(
          name_to_base_api_def[graph_op_name], endpoint_names)

      if api_def:
        file_path = _GetApiDefFilePath(graph_op_name)
        api_defs = new_api_defs_map[file_path]
        api_defs.op.extend([api_def])

    self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map)

    old_api_defs_map = _GetGoldenApiDefs()
    for file_path, new_api_defs in new_api_defs_map.items():
      # Get new ApiDef string.
      new_api_defs_str = str(new_api_defs)

      # Get current ApiDef for the given file.
      old_api_defs_str = (
          old_api_defs_map[file_path] if file_path in old_api_defs_map else '')

      if old_api_defs_str == new_api_defs_str:
        continue

      if FLAGS.update_goldens:
        logging.info('Updating %s...' % file_path)
        file_io.write_string_to_file(file_path, new_api_defs_str)
      else:
        self.assertMultiLineEqual(
            old_api_defs_str, new_api_defs_str,
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.')

    for file_path in set(old_api_defs_map) - set(new_api_defs_map):
      if FLAGS.update_goldens:
        logging.info('Deleting %s...' % file_path)
        file_io.delete_file(file_path)
      else:
        self.fail(
            '%s file is no longer needed and should be removed. '
            'To update golden API files, run api_compatibility_test locally '
            'with --update_goldens=True flag.' % file_path)
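
The snake_to_camel_graph_op_names map above relies on self._GenerateLowerCaseOpName to turn CamelCase graph op names into the snake_case names of the generated Python wrappers. A rough sketch of that conversion is shown below; the exact rules in the real helper (e.g. around digits or acronyms) may differ.

import re

# Rough sketch of a CamelCase -> snake_case conversion comparable to
# _GenerateLowerCaseOpName, e.g. 'MatMul' -> 'mat_mul'. Illustrative only.
def _generate_lower_case_op_name(op_name):
  return re.sub(r'(?<!^)(?=[A-Z])', '_', op_name).lower()
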
Example #50
0
  def testV2KeywordArgNames(self):
    # This test converts a call of the form:
    #   tf.foo(arg1=0, arg2=1, ...)
    # to 2.0 and then checks that the converted call uses valid argument
    # names for the v2 function.
    if not hasattr(tf.compat, "v2"):
      return
    v2_arg_exceptions = {
        "verify_shape_is_now_always_true",
        # These arguments should not be used; they only mark that the
        # function must be called with keyword arguments.
        "keyword_required",
        "_sentinel",
    }
    v1_name_exceptions = {
        "tf.print",  # requires print_function import
    }
    function_warnings = (
        tf_upgrade_v2.TFAPIChangeSpec().function_warnings)
    function_transformers = (
        tf_upgrade_v2.TFAPIChangeSpec().function_transformers)
    keyword_renames = (
        tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

    # Visitor that converts to V2 and checks V2 argument names.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual change
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(arg_names_v1)])
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, new function must be available in 2.0.
              # We should not be using compat.v1 in this case.
              self.fail(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          # 3. Verify V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
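
The helper get_func_and_args_from_str used in the test above parses a converted call string back into a function name and keyword-argument names. A simplified stand-in is sketched below; it is an approximation for illustration, not the exact test helper.

# Approximate stand-in for get_func_and_args_from_str: split a call string
# such as "tf.foo(arg1=0, arg2=1)" into ("tf.foo", ["arg1", "arg2"]).
def get_func_and_args_from_str(call_str):
  open_paren = call_str.find("(")
  close_paren = call_str.rfind(")")
  function_name = call_str[:open_paren]
  args = call_str[open_paren + 1:close_paren].split(",")
  args = [arg.split("=")[0].strip() for arg in args]
  return function_name, [arg for arg in args if arg]
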
Example #51
0
def extract():
  """Extract docs from tf namespace and write them to disk."""
  visitor = doc_generator_visitor.DocGeneratorVisitor('tf')
  api_visitor = public_api.PublicAPIVisitor(visitor)

  # Access something in contrib so tf.contrib is properly loaded (it's hidden
  # behind lazy loading)
  _ = tf.contrib.__name__

  # Exclude some libraries in contrib from the documentation altogether.
  # TODO(wicke): Shrink this list.
  api_visitor.do_not_descend_map.update({
      'contrib': [
          'compiler',
          'factorization',
          'grid_rnn',
          'labeled_tensor',
          'ndlstm',
          'quantization',
          'session_bundle',
          'slim',
          'solvers',
          'specs',
          'tensor_forest',
          'tensorboard',
          'testing',
          'tfprof',
          'training',
      ],
      'contrib.bayesflow': [
          'entropy', 'monte_carlo',
          'special_math', 'stochastic_gradient_estimators',
          'stochastic_graph', 'stochastic_tensor',
          'stochastic_variables', 'variational_inference'
      ],
      'contrib.distributions': ['bijector'],
      'contrib.graph_editor': [
          'edit',
          'match',
          'reroute',
          'subgraph',
          'transform',
          'select',
          'util'
      ],
      'contrib.layers': [
          'feature_column',
          'summaries'
      ],
      'contrib.learn': [
          'datasets',
          'head',
          'graph_actions',
          'io',
          'models',
          'monitors',
          'ops',
          'preprocessing',
          'utils',
      ],
      'contrib.util': ['loader'],
  })

  traverse.traverse(tf, api_visitor)

  return visitor
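
A hedged usage sketch of extract(): the returned DocGeneratorVisitor carries an index of the traversed symbols, which downstream doc-generation code consumes. The attribute name used below (index) is an assumption about the visitor's interface and may not hold for every TensorFlow version.

# Illustrative only: inspect what the documentation visitor collected.
# `index` is assumed to map full symbol names to the corresponding objects.
visitor = extract()
print('Collected %d documented symbols.' % len(visitor.index))
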
Example #52
0
  def testNoSubclassOfMessageV2(self):
    if not hasattr(tf.compat, 'v2'):
      return
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf_v2, visitor)

  def testNoSubclassOfMessage(self):
    visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    visitor.do_not_descend_map['tf'].append('contrib')
    # Skip compat.v1 and compat.v2 since they are validated in separate tests.
    visitor.private_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)
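
The _VerifyNoSubclassOfMessageVisitor used by both tests is defined elsewhere in the test module. A hedged sketch of the kind of check it performs is given below; this illustrates the idea, and is not the verified TensorFlow helper.

from google.protobuf import message

# Illustrative visitor: fail if any traversed public class subclasses a
# generated protobuf Message type instead of inheriting from Message directly.
def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
  if not (isinstance(parent, type) and issubclass(parent, message.Message)):
    return
  if parent is message.Message:
    return
  if message.Message not in parent.__bases__:
    raise NotImplementedError(
        '%s appears to subclass a generated proto type; proto classes '
        'should not be subclassed in the public API.' % path)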