コード例 #1
0
ファイル: tf_export_test.py プロジェクト: Wajih-O/tensorflow
  def testExportClasses(self):
    """Checks that exporting one class does not leak names onto another.

    TestClassA is decorated first; at that point TestClassB must not carry its
    own `_tf_api_names` attribute. After TestClassB is decorated as well, each
    class must report only the name it was exported under.
    """
    export_decorator_a = tf_export.tf_export('TestClassA1')
    export_decorator_a(TestClassA)
    # assertEqual is used instead of the deprecated assertEquals alias, which
    # was removed from unittest in Python 3.12.
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    # The class's own __dict__ is inspected, so only an attribute set directly
    # on TestClassB would count.
    self.assertNotIn('_tf_api_names', TestClassB.__dict__)

    export_decorator_b = tf_export.tf_export('TestClassB1')
    export_decorator_b(TestClassB)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    self.assertEqual(('TestClassB1',), TestClassB._tf_api_names)
    self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
    self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))
コード例 #2
0
  def testExportClasses(self):
    """Checks that exporting one class does not leak names onto another.

    TestClassA is decorated first; at that point TestClassB must not carry its
    own `_tf_api_names` attribute. After TestClassB is decorated as well, each
    class must report only the name it was exported under.
    """
    export_decorator_a = tf_export.tf_export('TestClassA1')
    export_decorator_a(TestClassA)
    # assertEqual is used instead of the deprecated assertEquals alias, which
    # was removed from unittest in Python 3.12.
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    # The class's own __dict__ is inspected, so only an attribute set directly
    # on TestClassB would count.
    self.assertNotIn('_tf_api_names', TestClassB.__dict__)

    export_decorator_b = tf_export.tf_export('TestClassB1')
    export_decorator_b(TestClassB)
    self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
    self.assertEqual(('TestClassB1',), TestClassB._tf_api_names)
    self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
    self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))
コード例 #3
0
    def testReorderFileNeedsUpdate(self):
        reordered_function_names = (
            tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
        function_reorders = (tf_upgrade_v2.TFAPIChangeSpec().function_reorders)

        added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
        removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
        self.assertTrue(reordered_function_names.issubset(function_reorders),
                        added_names_message)
        # function_reorders should contain reordered_function_names
        # and their TensorFlow V1 aliases.
        for name in function_reorders:
            # get other names for this function
            attr = get_symbol_for_name(tf.compat.v1, name)
            _, attr = tf_decorator.unwrap(attr)
            v1_names = tf_export.get_v1_names(attr)
            self.assertTrue(v1_names)
            v1_names = ["tf.%s" % n for n in v1_names]
            # check if any other name is in
            self.assertTrue(
                any(n in reordered_function_names for n in v1_names),
                removed_names_message % name)
コード例 #4
0
  def testReorderFileNeedsUpdate(self):
    reordered_function_names = (
        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
    function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)

    added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    self.assertTrue(
        reordered_function_names.issubset(function_reorders),
        added_names_message)
    # function_reorders should contain reordered_function_names
    # and their TensorFlow V1 aliases.
    for name in function_reorders:
      # get other names for this function
      attr = get_symbol_for_name(tf.compat.v1, name)
      _, attr = tf_decorator.unwrap(attr)
      v1_names = tf_export.get_v1_names(attr)
      self.assertTrue(v1_names)
      v1_names = ["tf.%s" % n for n in v1_names]
      # check if any other name is in
      self.assertTrue(
          any(n in reordered_function_names for n in v1_names),
          removed_names_message % name)
コード例 #5
0
ファイル: ragged_dispatch.py プロジェクト: Wajih-O/tensorflow
def _op_is_in_tf_version(op, version):
  if version == 1:
    return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
            op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
  elif version == 2:
    return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
  else:
    raise ValueError('Expected version 1 or 2.')
コード例 #6
0
def _op_is_in_tf_version(op, version):
    if version == 1:
        return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1])
                or op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
    elif version == 2:
        return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
    else:
        raise ValueError('Expected version 1 or 2.')
コード例 #7
0
ファイル: tf_export_test.py プロジェクト: Wajih-O/tensorflow
 def testExportSingleFunction(self):
   """Checks a function exported under two names with one decorator.

   The decorator must return the function itself, store both names on
   `_tf_api_names`, and report the same two names for both V1 and V2.
   """
   export_decorator = tf_export.tf_export('nameA', 'nameB')
   decorated_function = export_decorator(_test_function)
   # assertEqual is used instead of the deprecated assertEquals alias, which
   # was removed from unittest in Python 3.12.
   self.assertEqual(decorated_function, _test_function)
   self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v1_names(decorated_function))
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v2_names(decorated_function))
コード例 #8
0
 def testExportSingleFunction(self):
     """Checks a function exported under two names with one decorator.

     The decorator must return the function itself, store both names on
     `_tf_api_names`, and report the same two names for both V1 and V2.
     """
     export_decorator = tf_export.tf_export('nameA', 'nameB')
     decorated_function = export_decorator(_test_function)
     # assertEqual is used instead of the deprecated assertEquals alias,
     # which was removed from unittest in Python 3.12.
     self.assertEqual(decorated_function, _test_function)
     self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
     self.assertEqual(['nameA', 'nameB'],
                      tf_export.get_v1_names(decorated_function))
     self.assertEqual(['nameA', 'nameB'],
                      tf_export.get_v2_names(decorated_function))
コード例 #9
0
 def visit(unused_path, unused_parent, children):
   """Record (deprecated name, canonical name) pairs in `renames`.

   A name is treated as deprecated when it is among a symbol's V1 API names
   but not among its V2 API names. The canonical name comes from
   `get_canonical_name`, called with the symbol's V2 names.
   """
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     v1_names = tf_export.get_v1_names(symbol)
     v2_names = tf_export.get_v2_names(symbol)
     v1_only = set(v1_names) - set(v2_names)
     for old_name in v1_only:
       canonical = get_canonical_name(v2_names, old_name)
       renames.add((old_name, canonical))
コード例 #10
0
 def conversion_visitor(unused_path, unused_parent, children):
     """Fail when upgrading a V1 symbol yields a name missing from v2.

     Each V1 API name of each child is prefixed with "tf." and passed through
     `self._upgrade`. Empty output and output starting with "tf.compat.v1"
     are accepted; anything else must be a key of `self.v2_symbols`.
     """
     for child in children:
         _, attr = tf_decorator.unwrap(child[1])
         for name in tf_export.get_v1_names(attr):
             _, _, _, text = self._upgrade("tf." + name)
             if not text or text.startswith("tf.compat.v1"):
                 continue
             if text not in self.v2_symbols:
                 # self.fail states the intent directly; the previous
                 # assertFalse(True, ...) prefixed the report with
                 # "True is not false".
                 self.fail("Symbol %s generated from %s not in v2 API" %
                           (text, name))
コード例 #11
0
    def conversion_visitor(unused_path, unused_parent, children):
      """Check that keyword arguments stay valid after upgrading V1 calls.

      For each function among `children` and each of its V1 API names, a call
      passing every V1 argument by keyword is built and run through
      `self._upgrade`. The keyword names in the upgraded text must then be
      accepted by the upgraded function as found in `self.v2_symbols` and,
      apart from two listed exceptions, in `self.v1_symbols`.
      """
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual change
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              "%s=%d" % (from_name, from_index)
              for from_index, from_name in enumerate(arg_names_v1))
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, new function must be available in 2.0.
              # We should not be using compat.v1 in this case.
              # The message used to be passed as the *expression* of
              # assertFalse, which failed only because a non-empty string is
              # truthy and reported it as "'...' is not false".
              self.fail(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          # 3. Verify V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))
          # 4. Verify that the argument exists in v1 as well.
          if new_function_name in {"tf.nn.ctc_loss", "tf.saved_model.save"}:
            continue
          args_v1 = get_args(self.v1_symbols[new_function_name])
          args_v1.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v1,
                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v1)))
コード例 #12
0
 def conversion_visitor(unused_path, unused_parent, children):
   """Fail when upgrading a V1 symbol yields a name missing from v2.

   Each V1 API name of each child is prefixed with "tf." and passed through
   `self._upgrade`. Empty output and output starting with "tf.compat.v1" are
   accepted; anything else must be a key of `self.v2_symbols`.
   """
   for child in children:
     _, attr = tf_decorator.unwrap(child[1])
     for name in tf_export.get_v1_names(attr):
       _, _, _, text = self._upgrade("tf." + name)
       if not text or text.startswith("tf.compat.v1"):
         continue
       if text not in self.v2_symbols:
         # self.fail states the intent directly; the previous
         # assertFalse(True, ...) prefixed the report with
         # "True is not false".
         self.fail("Symbol %s generated from %s not in v2 API" % (
             text, name))
コード例 #13
0
 def visit(unused_path, unused_parent, children):
   """Store the argument names of every reordered function.

   When any 'tf.'-prefixed V1 name of a symbol appears in `function_names`,
   the first element of `tf_inspect.getargspec` for that symbol is written to
   `function_to_args` under each of the symbol's prefixed V1 names.
   """
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     qualified = ['tf.%s' % n for n in tf_export.get_v1_names(symbol)]
     if not any(q in function_names for q in qualified):
       continue
     params = tf_inspect.getargspec(symbol)[0]
     for q in qualified:
       function_to_args[q] = params
コード例 #14
0
 def conversion_visitor(unused_path, unused_parent, children):
   """Fail when upgrading a V1 symbol yields a name missing from v2.

   Each V1 API name of each child is prefixed with "tf." and passed through
   `self._upgrade`. Empty output and output starting with "tf.compat.v1" or
   "tf.estimator" are accepted; anything else must be a key of
   `self.v2_symbols`.
   """
   for child in children:
     _, attr = tf_decorator.unwrap(child[1])
     for name in tf_export.get_v1_names(attr):
       _, _, _, text = self._upgrade("tf." + name)
       if not text or text.startswith("tf.compat.v1"):
         continue
       # Builds currently install old version of estimator that doesn't
       # have some 2.0 symbols.
       if text.startswith("tf.estimator"):
         continue
       if text not in self.v2_symbols:
         # self.fail states the intent directly; the previous
         # assertFalse(True, ...) prefixed the report with
         # "True is not false".
         self.fail("Symbol %s generated from %s not in v2 API" % (
             text, name))
コード例 #15
0
 def testExportSingleFunction(self):
   """Checks a function exported under two names with one decorator.

   The decorator must return the function itself, store both names on
   `_tf_api_names`, report them for V1 and V2, and make the function
   retrievable by either name and by its canonical name.
   """
   export_decorator = tf_export.tf_export('nameA', 'nameB')
   decorated_function = export_decorator(_test_function)
   # assertEqual is used throughout; the deprecated assertEquals alias was
   # removed from unittest in Python 3.12.
   self.assertEqual(decorated_function, _test_function)
   self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v1_names(decorated_function))
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v2_names(decorated_function))
   self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                    decorated_function)
   self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                    decorated_function)
   self.assertEqual(
       tf_export.get_symbol_from_name(
           tf_export.get_canonical_name_for_symbol(decorated_function)),
       decorated_function)
コード例 #16
0
 def testExportSingleFunctionV1Only(self):
   """Checks a function exported with `v1=` names only.

   The function must report both names for V1, no names for V2, and be
   retrievable through the 'compat.v1.'-prefixed names and through its
   canonical name.
   """
   exported = tf_export.tf_export(v1=['nameA', 'nameB'])(_test_function)
   self.assertEqual(exported, _test_function)

   # Names stored on the function and returned by the lookup helpers.
   self.assertAllEqual(('nameA', 'nameB'), exported._tf_api_names_v1)
   self.assertAllEqual(['nameA', 'nameB'], tf_export.get_v1_names(exported))
   self.assertEqual([], tf_export.get_v2_names(exported))

   # Reverse lookup by each explicit name, then by the canonical name.
   for qualified_name in ('compat.v1.nameA', 'compat.v1.nameB'):
     self.assertEqual(tf_export.get_symbol_from_name(qualified_name), exported)
   canonical = tf_export.get_canonical_name_for_symbol(
       exported, add_prefix_to_v1_names=True)
   self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
コード例 #17
0
    def arg_test_visitor(unused_path, unused_parent, children):
      """Check that every keyword rename starts from a real V1 argument.

      For each 'tf.'-prefixed V1 name of each child that has an entry in
      `all_keyword_renames`, the entry must be a dict and each of its keys
      must be among the argument names reported by `tf_inspect.getargspec`.
      """
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        names_v1 = tf_export.get_v1_names(attr)

        for name in names_v1:
          # A separate variable keeps the loop variable from being rebound.
          tf_name = "tf.%s" % name
          if tf_name not in all_keyword_renames:
            continue
          arg_names_v1 = tf_inspect.getargspec(attr)[0]
          keyword_renames = all_keyword_renames[tf_name]
          self.assertIsInstance(keyword_renames, dict)

          # Assert that v1 function has valid v1 argument names.
          # Only the keys (the old argument names) are needed here.
          for from_name in keyword_renames:
            self.assertIn(
                from_name, arg_names_v1,
                "%s not found in %s arguments: %s" %
                (from_name, tf_name, str(arg_names_v1)))
コード例 #18
0
        def arg_test_visitor(unused_path, unused_parent, children):
            """Check that every keyword rename starts from a real V1 argument.

            For each 'tf.'-prefixed V1 name of each child that has an entry
            in `all_keyword_renames`, the entry must be a dict and each of
            its keys must be among the argument names reported by
            `tf_inspect.getargspec`.
            """
            for child in children:
                _, attr = tf_decorator.unwrap(child[1])
                names_v1 = tf_export.get_v1_names(attr)

                for name in names_v1:
                    # A separate variable keeps the loop variable from being
                    # rebound.
                    tf_name = "tf.%s" % name
                    if tf_name not in all_keyword_renames:
                        continue
                    arg_names_v1 = tf_inspect.getargspec(attr)[0]
                    keyword_renames = all_keyword_renames[tf_name]
                    self.assertIsInstance(keyword_renames, dict)

                    # Assert that v1 function has valid v1 argument names.
                    # Only the keys (the old argument names) are needed here.
                    for from_name in keyword_renames:
                        self.assertIn(
                            from_name, arg_names_v1,
                            "%s not found in %s arguments: %s" %
                            (from_name, tf_name, str(arg_names_v1)))
コード例 #19
0
 def visit(unused_path, unused_parent, children):
   """Store the argument names of every reordered function or class.

   When any 'tf.'-prefixed V1 name of a symbol appears in `function_names`,
   its argument names are written to `function_to_args` under each of the
   symbol's prefixed V1 names. For a class the names come from `__init__`
   with the first one dropped.
   """
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     qualified = ['tf.%s' % n for n in tf_export.get_v1_names(symbol)]
     if not any(q in function_names for q in qualified):
       continue
     # getargspec yields (args, varargs, keywords, defaults); only the args
     # element is kept.
     if tf_inspect.isclass(symbol):
       # Describe the constructor and leave out its 'self' argument.
       params = tf_inspect.getargspec(symbol.__init__)[0][1:]
     else:
       params = tf_inspect.getargspec(symbol)[0]
     for q in qualified:
       function_to_args[q] = params
コード例 #20
0
 def visit(unused_path, unused_parent, children):
     """Store the argument names of every reordered function or class.

     When any 'tf.'-prefixed V1 name of a symbol appears in
     `function_names`, its argument names are written to `function_to_args`
     under each of the symbol's prefixed V1 names. For a class the names come
     from `__init__` with the first one dropped.
     """
     for entry in children:
         symbol = tf_decorator.unwrap(entry[1])[1]
         qualified = ['tf.%s' % n for n in tf_export.get_v1_names(symbol)]
         if not any(q in function_names for q in qualified):
             continue
         # getargspec yields (args, varargs, keywords, defaults); only the
         # args element is kept.
         if tf_inspect.isclass(symbol):
             # Describe the constructor and leave out its 'self' argument.
             params = tf_inspect.getargspec(symbol.__init__)[0][1:]
         else:
             params = tf_inspect.getargspec(symbol)[0]
         for q in qualified:
             function_to_args[q] = params
コード例 #21
0
 def symbol_collector_v1(unused_path, unused_parent, children):
   """Register each unwrapped child in `cls.v1_symbols` by its V1 names.

   Every V1 API name of a child is prefixed with "tf." and used as the key
   under which the unwrapped object is stored.
   """
   for entry in children:
     symbol = tf_decorator.unwrap(entry[1])[1]
     for v1_name in tf_export.get_v1_names(symbol):
       cls.v1_symbols["tf." + v1_name] = symbol