예제 #1
0
  def setUpClass(cls):
    """Build the class-level name-to-symbol maps used by the tests.

    Fills ``cls.v2_symbols`` and ``cls.v1_symbols`` with "tf."-prefixed API
    names mapped to the unwrapped objects found while traversing
    ``tf.compat.v2`` and ``tf.compat.v1``. A map is left empty when its
    ``tf.compat`` namespace does not exist.
    """
    cls.v2_symbols = {}
    cls.v1_symbols = {}

    if hasattr(tf.compat, "v2"):

      def collect_v2(unused_path, unused_parent, children):
        # Index 1 of each child entry is the object handed to unwrap().
        for entry in children:
          _, symbol = tf_decorator.unwrap(entry[1])
          cls.v2_symbols.update(
              ("tf." + v2_name, symbol)
              for v2_name in tf_export.get_v2_names(symbol))

      v2_visitor = public_api.PublicAPIVisitor(collect_v2)
      traverse.traverse(tf.compat.v2, v2_visitor)

    if hasattr(tf.compat, "v1"):

      def collect_v1(unused_path, unused_parent, children):
        for entry in children:
          _, symbol = tf_decorator.unwrap(entry[1])
          cls.v1_symbols.update(
              ("tf." + v1_name, symbol)
              for v1_name in tf_export.get_v1_names(symbol))

      v1_visitor = public_api.PublicAPIVisitor(collect_v1)
      traverse.traverse(tf.compat.v1, v1_visitor)
예제 #2
0
    def testAllAPI(self):
        """Check that upgrading each v1 API name yields a known v2 name.

        First collects every "tf."-prefixed v2 name by traversing
        ``tf.compat.v2``. Then, for every v1 name found under
        ``tf.compat.v1``, runs ``self._upgrade`` and fails when the resulting
        text is non-empty, does not start with "tf.compat.v1", and is not one
        of the collected v2 names. Does nothing when ``tf.compat.v2`` is
        missing.
        """
        if not hasattr(tf.compat, "v2"):
            return

        v2_symbols = set()

        def symbol_collector(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                v2_symbols.update(
                    "tf." + name for name in get_v2_names(attr))

        visitor = public_api.PublicAPIVisitor(symbol_collector)
        traverse.traverse(tf.compat.v2, visitor)

        # Converts all symbols in the v1 namespace to the v2 namespace, raising
        # an error if the target of the conversion is not in the v2 namespace.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                for name in get_v1_names(attr):
                    # Only the last element of the _upgrade result is used.
                    _, _, _, text = self._upgrade("tf." + name)
                    if (text and not text.startswith("tf.compat.v1")
                            and text not in v2_symbols):
                        # fail() states the intent directly instead of
                        # asserting a constant.
                        self.fail(
                            "Symbol %s generated from %s not in v2 API" %
                            (text, name))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
def collect_function_arg_names(function_names):
  """Collect argument names for the requested v1 API functions.

  Args:
    function_names: Functions to collect arguments for. Names are compared in
      'tf.'-prefixed form.

  Returns:
    Dictionary mapping function name to its arguments. When a symbol matches,
    all of its v1 names are added to the dictionary.
  """
  function_to_args = {}

  def visit(unused_path, unused_parent, children):
    """Record the argument list of each child whose v1 names match."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      prefixed_names = [
          'tf.%s' % short_name for short_name in tf_export.get_v1_names(attr)]
      if not any(name in function_names for name in prefixed_names):
        continue
      # Index 0 of the getargspec result holds the argument names.
      arg_list = tf_inspect.getargspec(attr)[0]
      function_to_args.update(dict.fromkeys(prefixed_names, arg_list))

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_visitor)

  return function_to_args
  def testAPIBackwardsCompatibility(self):
    """Compare protos built from traversing ``tf`` with the golden files."""
    # Build protos for the current API.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map[''].append('contrib')
    traverse.traverse(tf, api_visitor)
    proto_dict = proto_visitor.GetProtos()

    # Locate and parse every golden file.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_proto_dict = {}
    for filename in file_io.get_matching_files(golden_pattern):
      key = _FileNameToKey(filename)
      golden_proto = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), golden_proto)
      golden_proto_dict[key] = golden_proto

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
예제 #5
0
  def checkBackwardsCompatibility(self, root, golden_file_pattern, api_version):
    """Compare protos built from traversing ``root`` with golden files.

    Args:
      root: Object passed to ``traverse.traverse`` as the traversal root.
      golden_file_pattern: Pattern handed to ``file_io.get_matching_files``.
      api_version: Forwarded unchanged to ``self._AssertProtoDictEquals``.
    """
    # Build protos for the current API.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(root, api_visitor)
    proto_dict = proto_visitor.GetProtos()

    # Parse every golden file that matches the pattern.
    golden_proto_dict = {}
    for filename in file_io.get_matching_files(golden_file_pattern):
      key = _FileNameToKey(filename)
      golden_proto = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), golden_proto)
      golden_proto_dict[key] = golden_proto

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)
예제 #6
0
  def testAllAPI(self):
    """Check that upgrading each v1 API name yields a known v2 name.

    For every v1 name found under ``tf.compat.v1``, runs ``self._upgrade`` and
    fails when the resulting text is non-empty, is not under "tf.compat.v1" or
    "tf.estimator", and is not a key of ``self.v2_symbols``. Does nothing when
    ``tf.compat.v2`` is missing.

    Please regenerate the renames file or edit any manual renames if this
    test fails.
    """
    if not hasattr(tf.compat, "v2"):
      return

    # Converts all symbols in the v1 namespace to the v2 namespace, raising
    # an error if the target of the conversion is not in the v2 namespace.
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        for name in tf_export.get_v1_names(attr):
          # Only the last element of the _upgrade result is used.
          _, _, _, text = self._upgrade("tf." + name)
          if not text:
            continue
          # tf.compat.v1 targets are always accepted. tf.estimator is skipped
          # because builds currently install old version of estimator that
          # doesn't have some 2.0 symbols.
          if text.startswith(("tf.compat.v1", "tf.estimator")):
            continue
          if text not in self.v2_symbols:
            self.fail(
                "Symbol %s generated from %s not in v2 API" % (text, name))

    visitor = public_api.PublicAPIVisitor(conversion_visitor)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    traverse.traverse(tf.compat.v1, visitor)
예제 #7
0
def collect_function_renames():
  """Find v1-only API names and pair each with its canonical replacement.

  Returns:
    Set of tuples of the form (current name, new name).
  """
  # Pairs gathered during traversal, before the final filtering step.
  candidate_renames = set()

  def visit(unused_path, unused_parent, children):
    """Add a (deprecated name, canonical name) pair per v1-only name."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = tf_export.get_v1_names(attr)
      api_names_v2 = tf_export.get_v2_names(attr)
      candidate_renames.update(
          (old_name, get_canonical_name(api_names_v2, old_name))
          for old_name in set(api_names_v1) - set(api_names_v2))

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_visitor)

  # It is possible that a different function is exported with the
  # same name. For e.g. when creating a different function to
  # rename arguments. Exclude it from renames in this case.
  v2_names = get_all_v2_names()
  return {
      (old_name, new_name)
      for old_name, new_name in candidate_renames
      if old_name not in v2_names
  }
예제 #8
0
    def testV1KeywordArgNames(self):
        """Check that renamed keyword arguments exist on the v1 functions.

        For every v1 API name that has an entry in the upgrade spec's
        ``function_keyword_renames``, asserts that the entry is a dict and
        that each of its keys is an argument name of the v1 function.
        """
        all_keyword_renames = (
            tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

        # Visitor that verifies V1 argument names.
        def arg_test_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])

                for short_name in tf_export.get_v1_names(attr):
                    name = "tf.%s" % short_name
                    if name not in all_keyword_renames:
                        continue
                    # Index 0 of the getargspec result holds argument names.
                    arg_names_v1 = tf_inspect.getargspec(attr)[0]
                    keyword_renames = all_keyword_renames[name]
                    # isinstance also accepts dict subclasses, which an exact
                    # type() comparison would reject.
                    self.assertIsInstance(keyword_renames, dict)

                    # Assert that v1 function has valid v1 argument names.
                    for from_name in keyword_renames:
                        self.assertIn(
                            from_name, arg_names_v1,
                            "%s not found in %s arguments: %s" %
                            (from_name, name, arg_names_v1))

        visitor = public_api.PublicAPIVisitor(arg_test_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
예제 #9
0
 def testNoSubclassOfMessageV2(self):
     """Run the Message-subclass check over tf.compat.v2 when it exists."""
     if hasattr(tf.compat, 'v2'):
         message_check = public_api.PublicAPIVisitor(
             _VerifyNoSubclassOfMessageVisitor)
         message_check.do_not_descend_map['tf'].append('contrib')
         traverse.traverse(tf.compat.v2, message_check)
예제 #10
0
def extract():
    """Traverse tf and tf_debug with a doc-generator visitor and return it.

    Returns:
      The ``DocGeneratorVisitor`` that was driven over both modules.
    """
    doc_visitor = doc_generator_visitor.DocGeneratorVisitor('tf')
    public_visitor = public_api.PublicAPIVisitor(doc_visitor)

    # Access something in contrib so tf.contrib is properly loaded (it's hidden
    # behind lazy loading)
    _ = tf.contrib.__name__

    # Exclude some libraries in contrib from the documentation altogether.
    # TODO(wicke): Shrink this list.
    excluded_members = {
        'contrib': [
            'compiler',
            'factorization',
            'grid_rnn',
            'labeled_tensor',
            'ndlstm',
            'quantization',
            'session_bundle',
            'slim',
            'solvers',
            'specs',
            'tensor_forest',
            'tensorboard',
            'testing',
            'tfprof',
        ],
        'contrib.bayesflow': [
            'entropy',
            'monte_carlo',
            'special_math',
            'stochastic_gradient_estimators',
            'stochastic_graph',
            'stochastic_tensor',
            'stochastic_variables',
            'variational_inference',
        ],
        'contrib.distributions': ['bijector'],
        'contrib.ffmpeg': ['ffmpeg_ops'],
        'contrib.graph_editor': [
            'edit',
            'match',
            'subgraph',
            'transform',
            'select',
            'util',
        ],
        'contrib.layers': ['feature_column', 'summaries'],
        'contrib.learn': [
            'datasets',
            'head',
            'graph_actions',
            'io',
            'models',
            'monitors',
            'ops',
            'preprocessing',
            'utils',
        ],
        'contrib.util': ['loader'],
    }
    public_visitor.do_not_descend_map.update(excluded_members)

    traverse.traverse(tf, public_visitor)

    # tf_debug is not imported with tf, it's a separate module altogether
    doc_visitor.set_root_name('tfdbg')
    traverse.traverse(tf_debug, public_visitor)

    return doc_visitor
예제 #11
0
 def test_private_child_removal(self):
     """Underscore-prefixed children are dropped before the wrapped call."""
     recorder = self.TestVisitor()
     children = [('name1', 'thing1'), ('_name2', 'thing2')]
     wrapped = public_api.PublicAPIVisitor(recorder)
     wrapped('test', 'dummy', children)
     # Make sure the private symbols are removed before the visitor is called.
     public_only = [('name1', 'thing1')]
     self.assertEqual(public_only, recorder.last_children)
     self.assertEqual(public_only, children)
예제 #12
0
def collect_function_renames():
    """Find v1-only API names and pair each with its canonical replacement.

    Returns:
      Set of tuples of the form (current name, new name).
    """
    # Each entry corresponds to one rename line of the form:
    #   'tf.deprecated_name': 'tf.canonical_name'
    renames = set()

    def visit(unused_path, unused_parent, children):
        """Add a (deprecated name, canonical name) pair per v1-only name."""
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            if hasattr(attr, '__dict__'):
                # Objects without either API attribute contribute nothing.
                api_names_v1 = attr.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
                api_names_v2 = attr.__dict__.get(_TENSORFLOW_API_ATTR, [])
                renames.update(
                    (old_name, get_canonical_name(api_names_v2, old_name))
                    for old_name in set(api_names_v1) - set(api_names_v2))

    api_visitor = public_api.PublicAPIVisitor(visit)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, api_visitor)
    return renames
예제 #13
0
 def testNoSubclassOfMessage(self):
     """Run the Message-subclass check over tf, excluding compat.v1/v2."""
     message_check = public_api.PublicAPIVisitor(
         _VerifyNoSubclassOfMessageVisitor)
     # Skip compat.v1 and compat.v2 since they are validated in separate tests.
     message_check.private_map['tf.compat'] = ['v1', 'v2']
     message_check.do_not_descend_map['tf'].append('contrib')
     traverse.traverse(tf, message_check)
예제 #14
0
    def testAllAPIV1(self):
        """Check that upgrading each v1 API name yields a name valid in v1.

        Traverses ``tf.compat.v1`` twice with the same visitor. The first
        pass collects every "tf."-prefixed v1 name. The second pass runs
        ``self._upgrade`` on each name and fails when the resulting text is
        non-empty, is not under "tf.compat.v1" or "tf.estimator", and is
        neither a collected v1 name nor whitelisted.
        """
        # Read by the visitor at call time, so flipping it between the two
        # traversals switches the visitor from collecting to checking.
        collect = True
        v1_symbols = set()

        # Symbols which may be generated by the conversion script which do not exist
        # in TF 1.x. This should be a very short list of symbols which are
        # experimental in 1.x but stable for 2.x.
        whitelisted_v2_only_symbols = {"tf.saved_model.save"}

        # Converts all symbols in the v1 namespace to the v2 namespace, raising
        # an error if the target of the conversion is not in the v1 namespace.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                for name in tf_export.get_v1_names(attr):
                    if collect:
                        v1_symbols.add("tf." + name)
                        continue
                    # Only the last element of the _upgrade result is used.
                    _, _, _, text = self._upgrade("tf." + name)
                    if not text:
                        continue
                    if text.startswith(("tf.compat.v1", "tf.estimator")):
                        continue
                    if (text not in v1_symbols
                            and text not in whitelisted_v2_only_symbols):
                        self.fail(
                            "Symbol %s generated from %s not in v1 API" %
                            (text, name))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
        collect = False
        traverse.traverse(tf.compat.v1, visitor)
예제 #15
0
    def _checkBackwardsCompatibility(self,
                                     root,
                                     golden_file_patterns,
                                     api_version,
                                     additional_private_map=None,
                                     omit_golden_symbols_map=None):
        """Compare protos built from traversing ``root`` with golden files.

        Args:
          root: Object passed to ``traverse.traverse`` as the traversal root.
          golden_file_patterns: Passed to ``file_io.get_matching_files``.
          api_version: When equal to 2, 'enable_v2_behavior' is also added to
            the private map for 'tf'. Forwarded to the final comparison.
          additional_private_map: Optional mapping merged into the visitor's
            ``private_map``.
          omit_golden_symbols_map: Passed to ``_FilterGoldenProtoDict``.
        """
        # Build protos for the current API.
        proto_visitor = (
            python_object_to_proto_visitor.PythonObjectToProtoVisitor())
        api_visitor = public_api.PublicAPIVisitor(proto_visitor)

        api_visitor.private_map['tf'].append('contrib')
        if api_version == 2:
            api_visitor.private_map['tf'].append('enable_v2_behavior')

        api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
        # Do not descend into these numpy classes because their signatures may be
        # different between internal and OSS.
        api_visitor.do_not_descend_map['tf.experimental.numpy'] = [
            'bool_',
            'complex_',
            'complex128',
            'complex64',
            'float_',
            'float16',
            'float32',
            'float64',
            'inexact',
            'int_',
            'int16',
            'int32',
            'int64',
            'int8',
            'object_',
            'string_',
            'uint16',
            'uint32',
            'uint64',
            'uint8',
            'unicode_',
            'iinfo',
        ]
        if FLAGS.only_test_core_api:
            api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
        if additional_private_map:
            api_visitor.private_map.update(additional_private_map)

        traverse.traverse(root, api_visitor)
        proto_dict = proto_visitor.GetProtos()

        # Locate the golden files, optionally restricted to core packages.
        golden_file_list = file_io.get_matching_files(golden_file_patterns)
        if FLAGS.only_test_core_api:
            golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

        # Parse each golden file into a proto keyed by its derived key.
        parsed_goldens = {}
        for filename in golden_file_list:
            key = _FileNameToKey(filename)
            golden_proto = api_objects_pb2.TFAPIObject()
            text_format.Merge(
                file_io.read_file_to_string(filename), golden_proto)
            parsed_goldens[key] = golden_proto
        golden_proto_dict = _FilterGoldenProtoDict(parsed_goldens,
                                                   omit_golden_symbols_map)

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not fail.
        self._AssertProtoDictEquals(golden_proto_dict,
                                    proto_dict,
                                    verbose=FLAGS.verbose_diffs,
                                    update_goldens=FLAGS.update_goldens,
                                    api_version=api_version)
예제 #16
0
 def test_call_forward(self):
     """Path, parent and children reach the wrapped visitor unchanged."""
     recorder = self.TestVisitor()
     expected_children = [('name1', 'thing1'), ('name2', 'thing2')]
     children = list(expected_children)
     wrapped = public_api.PublicAPIVisitor(recorder)
     wrapped('test', 'dummy', children)
     self.assertEqual({'test'}, recorder.symbols)
     self.assertEqual('dummy', recorder.last_parent)
     self.assertEqual(expected_children, recorder.last_children)
 def testNoSubclassOfMessageV2(self):
   """Run the Message-subclass check over tf.compat.v2 when it exists."""
   if hasattr(tf.compat, 'v2'):
     message_check = public_api.PublicAPIVisitor(
         _VerifyNoSubclassOfMessageVisitor)
     message_check.do_not_descend_map['tf'].append('contrib')
     # Optionally restrict the traversal to the core packages.
     if FLAGS.only_test_core_api:
       message_check.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
     traverse.traverse(tf.compat.v2, message_check)
예제 #18
0
    def testKeywordArgNames(self):
        """Check keyword renames against both the v1 and the v2 signatures.

        For every v1 API name that has an entry in the upgrade spec's
        ``function_keyword_renames``: asserts the entry's keys are argument
        names of the v1 function, upgrades a synthetic call that passes all
        of those keywords, and asserts that the resulting function is a key
        of ``self.v2_symbols`` and that each resulting argument name is
        accepted by it. Does nothing when ``tf.compat.v2`` is missing.
        """
        if not hasattr(tf.compat, "v2"):
            return

        all_keyword_renames = (
            tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
        # Argument names accepted after conversion even though they are not
        # in the v2 signature.
        v2_name_exceptions = {"verify_shape_is_now_always_true"}

        # Visitor that verifies V1 argument names, converts to V2 and checks
        # V2 argument names.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])

                for short_name in get_v1_names(attr):
                    name = "tf.%s" % short_name
                    if name not in all_keyword_renames:
                        continue
                    # Index 0 of the getargspec result holds argument names.
                    arg_names_v1 = tf_inspect.getargspec(attr)[0]
                    keyword_renames = all_keyword_renames[name]
                    # isinstance also accepts dict subclasses, which an exact
                    # type() comparison would reject.
                    self.assertIsInstance(keyword_renames, dict)

                    # Assert that v1 function has valid v1 argument names.
                    for from_name in keyword_renames:
                        self.assertIn(
                            from_name, arg_names_v1,
                            "%s not found in %s arguments: %s" %
                            (from_name, name, arg_names_v1))

                    # Assert that arg names after converting to v2 are present in
                    # v2 function.
                    # 1. First, create an input of the form:
                    #    tf.foo(arg1=val1, arg2=val2, ...)
                    args = ",".join(
                        "%s=%d" % (from_name, from_index)
                        for from_index, from_name in enumerate(keyword_renames))
                    text_input = "%s(%s)" % (name, args)
                    # 2. Convert the input to V2.
                    _, _, _, text = self._upgrade(text_input)
                    new_function_name, new_args = get_func_and_args_from_str(
                        text)
                    # 3. Verify V2 function and arguments.
                    # Note: If we rename arguments, new function must be available in 2.0.
                    # We should not be using compat.v1 in this case.
                    self.assertIn(new_function_name, self.v2_symbols)
                    # Copy before extending so the list returned by
                    # getargspec is not mutated.
                    args_v2 = list(tf_inspect.getargspec(
                        self.v2_symbols[new_function_name])[0])
                    args_v2.extend(v2_name_exceptions)
                    for new_arg in new_args:
                        self.assertIn(new_arg, args_v2)

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)
예제 #19
0
 def test_no_descent_child_removal(self):
     """'mock' is seen by the wrapped visitor but removed from children."""
     recorder = self.TestVisitor()
     kept_child = ('name1', 'thing1')
     children = [kept_child, ('mock', 'thing2')]
     wrapped = public_api.PublicAPIVisitor(recorder)
     wrapped('test', 'dummy', children)
     # Make sure not-to-be-descended-into symbols are removed after the visitor
     # is called.
     self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
                      recorder.last_children)
     self.assertEqual([kept_child], children)
예제 #20
0
    def testNewAPIBackwardsCompatibility(self):
        """Compare protos built from traversing ``api`` with golden files.

        The 'user_ops' member is removed from the golden 'tensorflow' module
        before diffing, and goldens are never updated by this test.
        """
        # Build protos for the new API.
        proto_visitor = (
            python_object_to_proto_visitor.PythonObjectToProtoVisitor())
        api_visitor = public_api.PublicAPIVisitor(proto_visitor)
        api_visitor.do_not_descend_map['tf'].append('contrib')
        api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
        # TODO(annarev): Make slide_dataset available in API.
        api_visitor.private_map['tf'] = ['slide_dataset']
        traverse.traverse(api, api_visitor)
        proto_dict = proto_visitor.GetProtos()

        # Locate and parse every golden file.
        golden_pattern = os.path.join(
            resource_loader.get_root_dir_with_all_resources(),
            _KeyToFilePath('*'))
        golden_proto_dict = {}
        for filename in file_io.get_matching_files(golden_pattern):
            key = _FileNameToKey(filename)
            golden_proto = api_objects_pb2.TFAPIObject()
            text_format.Merge(
                file_io.read_file_to_string(filename), golden_proto)
            golden_proto_dict[key] = golden_proto

        # user_ops is an empty module. It is currently available in TensorFlow API
        # but we don't keep empty modules in the new API.
        # We delete user_ops from golden_proto_dict to make sure assert passes
        # when diffing new API against goldens.
        # TODO(annarev): remove user_ops from goldens once we switch to new API.
        tf_module = golden_proto_dict['tensorflow'].tf_module
        for index, member in enumerate(tf_module.member):
            if member.name == 'user_ops':
                # Only the first match is removed; stop right after deleting.
                del tf_module.member[index]
                break

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not fail.
        missing_object_hint = (
            'Check if tf_export decorator/call is missing for this symbol.')
        self._AssertProtoDictEquals(
            golden_proto_dict,
            proto_dict,
            verbose=FLAGS.verbose_diffs,
            update_goldens=False,
            additional_missing_object_message=missing_object_hint)
예제 #21
0
def extract(py_modules, do_not_descend_map):
    """Traverse each module with one doc-generator visitor and return it.

    Args:
      py_modules: Sequence of (name, module) pairs. The first name seeds the
        visitor; each later name is applied with ``set_root_name`` before its
        module is traversed.
      do_not_descend_map: Passed with the public-API visitor's own
        ``do_not_descend_map`` to ``add_dict_to_dict``.

    Returns:
      The ``DocGeneratorVisitor`` that was driven over all modules.
    """
    # Traverse the first module.
    doc_visitor = doc_generator_visitor.DocGeneratorVisitor(py_modules[0][0])
    public_visitor = public_api.PublicAPIVisitor(doc_visitor)
    add_dict_to_dict(do_not_descend_map, public_visitor.do_not_descend_map)
    traverse.traverse(py_modules[0][1], public_visitor)

    # Traverse all py_modules after the first:
    remaining_modules = py_modules[1:]
    for root_name, module in remaining_modules:
        doc_visitor.set_root_name(root_name)
        traverse.traverse(module, public_visitor)

    return doc_visitor
def get_all_v2_names():
  """Return the set of v2 API names found by traversing tf.compat.v2."""
  collected = set()

  def visit(unused_path, unused_parent, children):
    """Add the v2 names of every child to the result set."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      collected.update(get_v2_names(attr))

  api_visitor = public_api.PublicAPIVisitor(visit)
  api_visitor.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf.compat.v2, api_visitor)
  return collected
예제 #23
0
    def _checkBackwardsCompatibility(
        self,
        root,
        golden_file_patterns,
        api_version,
        additional_private_map=None,
        omit_golden_symbols_map=None,
    ):
        """Compare protos built from traversing ``root`` with golden files.

        Args:
          root: Object passed to ``traverse.traverse`` as the traversal root.
          golden_file_patterns: Passed to ``tf.compat.v1.gfile.Glob``.
          api_version: Forwarded to ``self._AssertProtoDictEquals``.
          additional_private_map: Optional mapping merged into the visitor's
            ``private_map``.
          omit_golden_symbols_map: Passed to ``_FilterGoldenProtoDict``.
        """
        # Build protos for the current API.
        proto_visitor = (
            python_object_to_proto_visitor.PythonObjectToProtoVisitor(
                default_path="tensorflow.keras"))
        api_visitor = public_api.PublicAPIVisitor(proto_visitor)
        if additional_private_map:
            api_visitor.private_map.update(additional_private_map)
        api_visitor.set_root_name("tf.keras")

        traverse.traverse(root, api_visitor)
        proto_dict = proto_visitor.GetProtos()

        # Parse each golden file into a proto keyed by its derived key.
        parsed_goldens = {}
        for filename in tf.compat.v1.gfile.Glob(golden_file_patterns):
            key = _FileNameToKey(filename)
            golden_proto = api_objects_pb2.TFAPIObject()
            text_format.Merge(
                file_io.read_file_to_string(filename), golden_proto)
            parsed_goldens[key] = golden_proto
        golden_proto_dict = _FilterGoldenProtoDict(parsed_goldens,
                                                   omit_golden_symbols_map)

        # Diff them. Do not fail if called with update.
        # If the test is run to update goldens, only report diffs but do not
        # fail.
        self._AssertProtoDictEquals(
            golden_proto_dict,
            proto_dict,
            verbose=FLAGS.verbose_diffs,
            update_goldens=FLAGS.update_goldens,
            api_version=api_version,
        )
  def testNoSubclassOfMessage(self):
    """Fail on any traversed class that subclasses a proto Message subclass."""

    def Visit(path, parent, unused_children):
      """Raise unless ``parent`` is a non-Message or a direct subclass."""
      is_message_class = (
          isinstance(parent, type) and issubclass(parent, message.Message))
      # Non-classes, unrelated classes and Message itself are all fine.
      if not is_message_class or parent is message.Message:
        return
      # Direct subclasses of Message are accepted.
      if message.Message in parent.__bases__:
        return
      raise NotImplementedError(
          'Object tf.%s is a subclass of a generated proto Message. '
          'They are not yet supported by the API tools.' % path)

    message_check = public_api.PublicAPIVisitor(Visit)
    message_check.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf, message_check)
  def testAPIBackwardsCompatibility(self):
    """Compare protos built from traversing ``tf`` with the golden files.

    The 'math' member is removed from the golden 'tensorflow' module before
    diffing.
    """
    # Build protos for the current API.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_visitor = public_api.PublicAPIVisitor(proto_visitor)
    api_visitor.do_not_descend_map['tf'].append('contrib')
    api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, api_visitor)
    proto_dict = proto_visitor.GetProtos()

    # Locate and parse every golden file.
    golden_pattern = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_proto_dict = {}
    for filename in file_io.get_matching_files(golden_pattern):
      key = _FileNameToKey(filename)
      golden_proto = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), golden_proto)
      golden_proto_dict[key] = golden_proto

    # TODO(annarev): remove once we switch to using tf_export decorators.
    tf_module = golden_proto_dict['tensorflow'].tf_module
    for index, member in enumerate(tf_module.member):
      if member.name == 'math':
        # Only the first match is removed; stop right after deleting.
        del tf_module.member[index]
        break

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
예제 #26
0
def collect_function_arg_names(function_names):
    """Collects argument names for the requested reordered functions.

    Traverses the public `tf` API (with `contrib` under `tf` and `v1`/`v2`
    under `tf.compat` in the do-not-descend map). Whenever a symbol has at
    least one v1 export name, prefixed with 'tf.', contained in
    `function_names`, its argument list is recorded under every one of its
    prefixed v1 export names.

    Args:
      function_names: Container of 'tf.'-prefixed names to collect arguments
        for; only membership tests (`in`) are performed on it.

    Returns:
      Dictionary mapping each 'tf.'-prefixed v1 name of a matching symbol to
      its list of argument names. For classes this is the `__init__` argument
      list without the leading `self`.
    """
    args_by_name = {}

    def visit(unused_path, unused_parent, children):
        """Records argument names for children exported under a wanted name."""
        for child in children:
            _, symbol = tf_decorator.unwrap(child[1])
            exported_names = [
                'tf.%s' % name for name in tf_export.get_v1_names(symbol)
            ]
            if not any(name in function_names for name in exported_names):
                continue

            if tf_inspect.isclass(symbol):
                # Classes: take the constructor's arguments and drop 'self'.
                arg_names = tf_inspect.getargspec(symbol.__init__)[0][1:]
            else:
                # getargspec returns (args, varargs, keywords, defaults);
                # only the positional argument names are of interest.
                arg_names = tf_inspect.getargspec(symbol)[0]

            # Every alias of the symbol shares the same argument list.
            for name in exported_names:
                args_by_name[name] = arg_names

    api_walker = public_api.PublicAPIVisitor(visit)
    api_walker.do_not_descend_map['tf'].append('contrib')
    api_walker.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, api_walker)

    return args_by_name
  def testNewAPIBackwardsCompatibility(self):
    """Compares the API reachable from `api` against the golden files.

    Traverses `api` with a proto-building visitor (with `contrib` under `tf`
    and `Experimental` under `tf.GPUOptions` in the do-not-descend map, and
    `slide_dataset` listed as private under `tf`), loads every golden file
    matching `_KeyToFilePath('*')`, and passes both dictionaries to
    `_AssertProtoDictEquals` with golden updating disabled.
    """
    # Build protos describing the API as it exists right now.
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_walker = public_api.PublicAPIVisitor(proto_visitor)
    api_walker.do_not_descend_map['tf'].append('contrib')
    api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    api_walker.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, api_walker)
    current_protos = proto_visitor.GetProtos()

    # Locate the golden files.
    golden_glob = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_paths = file_io.get_matching_files(golden_glob)

    def _LoadGolden(path):
      """Parses the text-format file at `path` into a TFAPIObject proto."""
      parsed = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(path), parsed)
      return parsed

    # Key each parsed golden proto by the key derived from its file name.
    golden_protos = {}
    for path in golden_paths:
      golden_protos[_FileNameToKey(path)] = _LoadGolden(path)

    # Compare golden vs. current. Goldens are never updated from this test
    # (update_goldens is hard-coded to False); a missing object gets an extra
    # hint pointing at the tf_export decorator.
    self._AssertProtoDictEquals(
        golden_protos,
        current_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
def collect_function_renames():
    """Looks for functions/classes that need to be renamed in TF 2.0.

    Traverses the public `tf` API (with `contrib` under `tf` and `v1`/`v2`
    under `tf.compat` in the do-not-descend map). For every visited symbol
    that has a `__dict__`, each name exported for v1 but not for v2 is paired
    with the canonical name computed by `get_canonical_name`.

    Returns:
      Set of tuples of the form (current name, new name). Pairs whose current
      name is also exported as a v2 name by some symbol are excluded.
    """
    # Set of (deprecated name, canonical name) tuples found during traversal.
    renames = set()
    v2_names = set()  # All op names in TensorFlow 2.0

    def visit(unused_path, unused_parent, children):
        """Visitor that collects rename tuples and all v2 export names."""
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            if not hasattr(attr, '__dict__'):
                continue
            api_names_v1 = attr.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
            api_names_v2 = attr.__dict__.get(_TENSORFLOW_API_ATTR, [])
            deprecated_api_names = set(api_names_v1) - set(api_names_v2)
            for name in deprecated_api_names:
                renames.add((name, get_canonical_name(api_names_v2, name)))
            v2_names.update(api_names_v2)

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

    # It is possible that a different function is exported with the
    # same name. For e.g. when creating a different function to
    # rename arguments. Exclude it from renames in this case.
    #
    # `renames` is a set of 2-tuples, so iterate it directly: a set has no
    # `.items()` method and calling it would raise AttributeError.
    return {
        (name, new_name)
        for name, new_name in renames
        if name not in v2_names
    }
예제 #29
0
def collect_constant_renames(output_file_path=None):
    """Looks for constants that need to be renamed in TF 2.0.

    Traverses the public `tf` API (with `contrib` under `tf` and `v1`/`v2`
    under `tf.compat` in the do-not-descend map). For every visited symbol
    that has a `__dict__`, one `'tf.deprecated': 'tf.canonical'` line is
    produced per entry in its `_tf_deprecated_api_names` attribute. The sorted
    lines are rendered as a Python `renames = {...}` dictionary preceded by
    `_FILE_HEADER`.

    Args:
      output_file_path: Optional file path to write output to. Any existing
        contents would be replaced. When None (the default) nothing is
        written. Previously this name was referenced without being a
        parameter, which raised NameError unless a module-level variable of
        that name happened to exist.

    Returns:
      The rendered renames file text.
    """
    # Set of rename lines to write to output file in the form:
    #   'tf.deprecated_name': 'tf.canonical_name'
    rename_line_set = set()
    # _tf_api_names attribute name
    tensorflow_api_attr = tf_export.API_ATTRS[
        tf_export.TENSORFLOW_API_NAME].names

    def visit(unused_path, unused_parent, children):
        """Visitor that collects rename strings to add to rename_line_set."""
        for child in children:
            _, attr = tf_decorator.unwrap(child[1])
            if not hasattr(attr, '__dict__'):
                continue
            api_names = attr.__dict__.get(tensorflow_api_attr, [])
            deprecated_api_names = attr.__dict__.get(
                '_tf_deprecated_api_names', [])
            canonical_name = tf_export.get_canonical_name(
                api_names, deprecated_api_names)
            for name in deprecated_api_names:
                rename_line_set.add('    \'tf.%s\': \'tf.%s\'' %
                                    (name, canonical_name))

    visitor = public_api.PublicAPIVisitor(visit)
    visitor.do_not_descend_map['tf'].append('contrib')
    visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
    traverse.traverse(tf, visitor)

    # Sort the lines so the generated text is deterministic.
    renames_file_text = '%srenames = {\n%s\n}\n' % (_FILE_HEADER, ',\n'.join(
        sorted(rename_line_set)))
    if output_file_path is not None:
        file_io.write_string_to_file(output_file_path, renames_file_text)
    return renames_file_text
예제 #30
0
    def testV2KeywordArgNames(self):
        """Checks that upgraded v1 calls use valid v2 keyword argument names.

        For each function reachable from `tf.compat.v1`, and each of its v1
        export names, this builds a call of the form
        `tf.foo(arg1=0, arg2=1, ...)`, converts it with `self._upgrade`, and
        asserts that every keyword in the converted call is an argument of
        the matching entry in `self.v2_symbols` (or one of a few known
        exceptions). Names that need manual changes are skipped.
        """
        if not hasattr(tf.compat, "v2"):
            return
        v2_arg_exceptions = {
            "verify_shape_is_now_always_true",
            # These arguments should not be used, they just specify
            # that a function takes named arguments.
            "keyword_required",
            "_sentinel",
        }
        v1_name_exceptions = {
            "tf.print",  # requires print_function import
        }
        # Build the change spec once and read all three tables from it.
        change_spec = tf_upgrade_v2.TFAPIChangeSpec()
        function_warnings = change_spec.function_warnings
        function_handles = change_spec.function_handle
        keyword_renames = change_spec.function_keyword_renames

        # Visitor that converts to V2 and checks V2 argument names.
        def conversion_visitor(unused_path, unused_parent, children):
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                if not tf_inspect.isfunction(attr):
                    continue
                names_v1 = tf_export.get_v1_names(attr)
                arg_names_v1 = get_args(attr)

                for name in names_v1:
                    tf_name = "tf.%s" % name
                    if tf_name in function_warnings or tf_name in function_handles:
                        continue  # These require manual change
                    if tf_name in v1_name_exceptions:
                        continue
                    # Assert that arg names after converting to v2 are present in
                    # v2 function.
                    # 1. First, create an input of the form:
                    #    tf.foo(arg1=val1, arg2=val2, ...)
                    args = ",".join([
                        "%s=%d" % (from_name, from_index)
                        for from_index, from_name in enumerate(arg_names_v1)
                    ])
                    text_input = "%s(%s)" % (tf_name, args)
                    # 2. Convert the input to V2.
                    _, _, _, text = self._upgrade(text_input)
                    new_function_name, new_args = get_func_and_args_from_str(
                        text)
                    if new_function_name == "tf.compat.v1.%s" % name:
                        if tf_name in keyword_renames:
                            # If we rename arguments, new function must be available in 2.0.
                            # We should not be using compat.v1 in this case.
                            # Use fail() so the text is reported as the failure
                            # message, rather than asserting that a non-empty
                            # string is false.
                            self.fail(
                                "Function '%s' is not in 2.0 when converting\n%s\nto\n%s"
                                % (new_function_name, text_input, text))
                        continue
                    # 3. Verify V2 function and arguments.
                    args_v2 = get_args(self.v2_symbols[new_function_name])
                    args_v2.extend(v2_arg_exceptions)
                    for new_arg in new_args:
                        self.assertIn(
                            new_arg, args_v2,
                            "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                            "Supported arguments: %s" %
                            (new_arg, text_input, text, str(args_v2)))

        visitor = public_api.PublicAPIVisitor(conversion_visitor)
        visitor.do_not_descend_map["tf"].append("contrib")
        visitor.private_map["tf.compat"] = ["v1", "v2"]
        traverse.traverse(tf.compat.v1, visitor)