def testExportClasses(self):
  """Checks that tf_export records API names per class without leaking.

  Exporting TestClassA must not set `_tf_api_names` directly on TestClassB,
  and exporting TestClassB afterwards must leave TestClassA's recorded names
  unchanged. Both classes report their names as V1 names.
  """
  export_decorator_a = tf_export.tf_export('TestClassA1')
  export_decorator_a(TestClassA)
  self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
  # Look at TestClassB's own __dict__ rather than hasattr(). Presumably
  # TestClassB may inherit from TestClassA (not visible here -- confirm), in
  # which case hasattr would see A's attribute and mask a real leak.
  self.assertNotIn('_tf_api_names', TestClassB.__dict__)
  export_decorator_b = tf_export.tf_export('TestClassB1')
  export_decorator_b(TestClassB)
  self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
  self.assertEqual(('TestClassB1',), TestClassB._tf_api_names)
  self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
  self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))
def testExportClasses(self):
  """Checks that tf_export records API names per class without leaking.

  Exporting TestClassA must not set `_tf_api_names` directly on TestClassB,
  and exporting TestClassB afterwards must leave TestClassA's recorded names
  unchanged. Both classes report their names as V1 names.
  """
  export_decorator_a = tf_export.tf_export('TestClassA1')
  export_decorator_a(TestClassA)
  self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
  # Look at TestClassB's own __dict__ rather than hasattr(). Presumably
  # TestClassB may inherit from TestClassA (not visible here -- confirm), in
  # which case hasattr would see A's attribute and mask a real leak.
  self.assertNotIn('_tf_api_names', TestClassB.__dict__)
  export_decorator_b = tf_export.tf_export('TestClassB1')
  export_decorator_b(TestClassB)
  self.assertEqual(('TestClassA1',), TestClassA._tf_api_names)
  self.assertEqual(('TestClassB1',), TestClassB._tf_api_names)
  self.assertEqual(['TestClassA1'], tf_export.get_v1_names(TestClassA))
  self.assertEqual(['TestClassB1'], tf_export.get_v1_names(TestClassB))
def testReorderFileNeedsUpdate(self):
  """Checks that reorders_v2.py agrees with the upgrader's reordered names.

  Every name in the change spec's `reordered_function_names` must appear in
  its `function_reorders`. Conversely, every entry of `function_reorders`
  must have at least one ``tf.``-prefixed V1 alias that is a reordered
  function name.
  """
  # Build the change spec once; both attributes are read from this instance.
  api_change_spec = tf_upgrade_v2.TFAPIChangeSpec()
  reordered_function_names = api_change_spec.reordered_function_names
  function_reorders = api_change_spec.function_reorders

  added_names_message = (
      "Some function names in self.reordered_function_names are not in "
      "reorders_v2.py. Please run the following commands to update "
      "reorders_v2.py: bazel build "
      "tensorflow/tools/compatibility/update:generate_v2_reorders_map "
      "bazel-bin/tensorflow/tools/compatibility/update/"
      "generate_v2_reorders_map ")
  removed_names_message = (
      "%s in self.reorders_v2 does not match any name in "
      "self.reordered_function_names. Please run the following commands to "
      "update reorders_v2.py: bazel build "
      "tensorflow/tools/compatibility/update:generate_v2_reorders_map "
      "bazel-bin/tensorflow/tools/compatibility/update/"
      "generate_v2_reorders_map ")

  self.assertTrue(
      reordered_function_names.issubset(function_reorders),
      added_names_message)
  # function_reorders should contain reordered_function_names
  # and their TensorFlow V1 aliases.
  for name in function_reorders:
    # Resolve the symbol and collect every V1 name it is exported under.
    attr = get_symbol_for_name(tf.compat.v1, name)
    _, attr = tf_decorator.unwrap(attr)
    v1_names = tf_export.get_v1_names(attr)
    self.assertTrue(v1_names)
    v1_names = ["tf.%s" % n for n in v1_names]
    # At least one V1 alias of this symbol must be a reordered function name.
    self.assertTrue(
        any(n in reordered_function_names for n in v1_names),
        removed_names_message % name)
def testReorderFileNeedsUpdate(self):
  """Checks that reorders_v2.py agrees with the upgrader's reordered names.

  Every name in the change spec's `reordered_function_names` must appear in
  its `function_reorders`. Conversely, every entry of `function_reorders`
  must have at least one ``tf.``-prefixed V1 alias that is a reordered
  function name.
  """
  # Build the change spec once; both attributes are read from this instance.
  api_change_spec = tf_upgrade_v2.TFAPIChangeSpec()
  reordered_function_names = api_change_spec.reordered_function_names
  function_reorders = api_change_spec.function_reorders

  added_names_message = (
      "Some function names in self.reordered_function_names are not in "
      "reorders_v2.py. Please run the following commands to update "
      "reorders_v2.py: bazel build "
      "tensorflow/tools/compatibility/update:generate_v2_reorders_map "
      "bazel-bin/tensorflow/tools/compatibility/update/"
      "generate_v2_reorders_map ")
  removed_names_message = (
      "%s in self.reorders_v2 does not match any name in "
      "self.reordered_function_names. Please run the following commands to "
      "update reorders_v2.py: bazel build "
      "tensorflow/tools/compatibility/update:generate_v2_reorders_map "
      "bazel-bin/tensorflow/tools/compatibility/update/"
      "generate_v2_reorders_map ")

  self.assertTrue(
      reordered_function_names.issubset(function_reorders),
      added_names_message)
  # function_reorders should contain reordered_function_names
  # and their TensorFlow V1 aliases.
  for name in function_reorders:
    # Resolve the symbol and collect every V1 name it is exported under.
    attr = get_symbol_for_name(tf.compat.v1, name)
    _, attr = tf_decorator.unwrap(attr)
    v1_names = tf_export.get_v1_names(attr)
    self.assertTrue(v1_names)
    v1_names = ["tf.%s" % n for n in v1_names]
    # At least one V1 alias of this symbol must be a reordered function name.
    self.assertTrue(
        any(n in reordered_function_names for n in v1_names),
        removed_names_message % name)
def _op_is_in_tf_version(op, version):
  """Reports whether `op` belongs to the TF 1.x or TF 2.x API.

  Args:
    op: The op to look up; any decorators are unwrapped before the lookup.
    version: Either 1 or 2.

  Returns:
    For version 2: the V2 export names of the unwrapped op (truthy iff it
    has any). For version 1: the V1 export names if non-empty, otherwise
    whether `op` is in _V1_OPS_THAT_DELEGATE_TO_V2_OPS.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  if version == 1:
    _, undecorated = tf_decorator.unwrap(op)
    v1_export_names = tf_export.get_v1_names(undecorated)
    return v1_export_names or op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS
  if version == 2:
    _, undecorated = tf_decorator.unwrap(op)
    return tf_export.get_v2_names(undecorated)
  raise ValueError('Expected version 1 or 2.')
def _op_is_in_tf_version(op, version):
  """Reports whether `op` belongs to the TF 1.x or TF 2.x API.

  Args:
    op: The op to look up; any decorators are unwrapped before the lookup.
    version: Either 1 or 2.

  Returns:
    For version 2: the V2 export names of the unwrapped op (truthy iff it
    has any). For version 1: the V1 export names if non-empty, otherwise
    whether `op` is in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  if version == 1:
    _, undecorated = tf_decorator.unwrap(op)
    v1_export_names = tf_export.get_v1_names(undecorated)
    return (v1_export_names or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  if version == 2:
    _, undecorated = tf_decorator.unwrap(op)
    return tf_export.get_v2_names(undecorated)
  raise ValueError('Expected version 1 or 2.')
def testExportSingleFunction(self):
  """Checks that exporting a function under two names records both names.

  The decorator must return the original function, and both names must be
  reported as V1 and as V2 names.
  """
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
def testExportSingleFunction(self):
  """Checks that exporting a function under two names records both names.

  The decorator must return the original function, and both names must be
  reported as V1 and as V2 names.
  """
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
def visit(unused_path, unused_parent, children):
  """Visitor that collects rename strings to add to rename_line_set.

  For each child symbol, every name it is exported under in V1 but not in V2
  is added to `renames` (from the enclosing scope) as the pair
  (v1_only_name, get_canonical_name(v2_names, v1_only_name)).
  """
  for child in children:
    _, symbol = tf_decorator.unwrap(child[1])
    v1_export_names = tf_export.get_v1_names(symbol)
    v2_export_names = tf_export.get_v2_names(symbol)
    # Names that exist only in V1 are the ones needing a rename.
    v1_only_names = set(v1_export_names) - set(v2_export_names)
    for old_name in v1_only_names:
      new_name = get_canonical_name(v2_export_names, old_name)
      renames.add((old_name, new_name))
def conversion_visitor(unused_path, unused_parent, children):
  """Checks that upgrading each V1 symbol yields a compat.v1 or V2 symbol.

  Each V1 export name of each child is passed through `self._upgrade` as
  ``tf.<name>``. A non-empty result that neither starts with
  ``tf.compat.v1`` nor appears in `self.v2_symbols` fails the test. `self`
  is captured from the enclosing test method.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    api_names = tf_export.get_v1_names(attr)
    for name in api_names:
      _, _, _, text = self._upgrade("tf." + name)
      if (text and not text.startswith("tf.compat.v1") and
          text not in self.v2_symbols):
        # self.fail reports the message as-is, instead of the misleading
        # "True is not false" that assertFalse(True, ...) adds to it.
        self.fail("Symbol %s generated from %s not in v2 API" % (text, name))
def conversion_visitor(unused_path, unused_parent, children):
  """Checks that upgraded V1 calls only use arguments the targets accept.

  For each V1 function export ``tf.<name>``, the visitor does the following:

  1. Builds a call that passes every V1 argument by keyword.
  2. Runs the call through `self._upgrade`.
  3. Asserts that each keyword in the upgraded call is an argument of the
     resulting function, in both `self.v2_symbols` and `self.v1_symbols`
     (`v2_arg_exceptions` are also allowed).

  The helpers and tables come from the enclosing scope: function_warnings,
  function_transformers, v1_name_exceptions, keyword_renames,
  v2_arg_exceptions, get_args and get_func_and_args_from_str.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    if not tf_inspect.isfunction(attr):
      continue
    names_v1 = tf_export.get_v1_names(attr)
    arg_names_v1 = get_args(attr)

    for name in names_v1:
      tf_name = "tf.%s" % name
      if tf_name in function_warnings or tf_name in function_transformers:
        continue  # These require manual change
      if tf_name in v1_name_exceptions:
        continue
      # Assert that arg names after converting to v2 are present in
      # v2 function.
      # 1. First, create an input of the form:
      #    tf.foo(arg1=val1, arg2=val2, ...)
      args = ",".join(
          ["%s=%d" % (from_name, from_index)
           for from_index, from_name in enumerate(arg_names_v1)])
      text_input = "%s(%s)" % (tf_name, args)
      # 2. Convert the input to V2.
      _, _, _, text = self._upgrade(text_input)
      new_function_name, new_args = get_func_and_args_from_str(text)
      if new_function_name == "tf.compat.v1.%s" % name:
        if tf_name in keyword_renames:
          # If we rename arguments, new function must be available in 2.0.
          # We should not be using compat.v1 in this case. Use self.fail so
          # the text is reported as the failure message (assertFalse on a
          # non-empty string treated it as the asserted expression).
          self.fail(
              "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
              (new_function_name, text_input, text))
        continue
      # 3. Verify V2 function and arguments.
      args_v2 = get_args(self.v2_symbols[new_function_name])
      args_v2.extend(v2_arg_exceptions)
      for new_arg in new_args:
        self.assertIn(
            new_arg, args_v2,
            "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
            "Supported arguments: %s" % (
                new_arg, text_input, text, str(args_v2)))
      # 4. Verify that the argument exists in v1 as well.
      if new_function_name in ("tf.nn.ctc_loss", "tf.saved_model.save"):
        continue
      args_v1 = get_args(self.v1_symbols[new_function_name])
      args_v1.extend(v2_arg_exceptions)
      for new_arg in new_args:
        self.assertIn(
            new_arg, args_v1,
            "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
            "Supported arguments: %s" % (
                new_arg, text_input, text, str(args_v1)))
def conversion_visitor(unused_path, unused_parent, children):
  """Checks that upgrading each V1 symbol yields a compat.v1 or V2 symbol.

  Each V1 export name of each child is passed through `self._upgrade` as
  ``tf.<name>``. A non-empty result that neither starts with
  ``tf.compat.v1`` nor appears in `self.v2_symbols` fails the test. `self`
  is captured from the enclosing test method.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    api_names = tf_export.get_v1_names(attr)
    for name in api_names:
      _, _, _, text = self._upgrade("tf." + name)
      if (text and not text.startswith("tf.compat.v1") and
          text not in self.v2_symbols):
        # self.fail reports the message as-is, instead of the misleading
        # "True is not false" that assertFalse(True, ...) adds to it.
        self.fail("Symbol %s generated from %s not in v2 API" % (text, name))
def visit(unused_path, unused_parent, children):
  """Visitor that collects arguments for reordered functions.

  For each child whose ``tf.``-prefixed V1 names include any entry of
  `function_names`, records its argument names in `function_to_args`
  under every one of its prefixed V1 names. For classes, the constructor's
  arguments without the leading `self` are recorded.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    api_names_v1 = tf_export.get_v1_names(attr)
    api_names_v1 = ['tf.%s' % name for name in api_names_v1]
    matches_function_names = any(
        name in function_names for name in api_names_v1)
    if matches_function_names:
      if tf_inspect.isclass(attr):
        # Get constructor arguments if attr is a class, as the
        # class-aware version of this visitor does.
        arg_list = tf_inspect.getargspec(attr.__init__)[0]
        arg_list = arg_list[1:]  # skip 'self' argument
      else:
        # getargspec returns a tuple of (args, varargs, keywords, defaults);
        # we just look at args.
        arg_list = tf_inspect.getargspec(attr)[0]
      for name in api_names_v1:
        function_to_args[name] = arg_list
def conversion_visitor(unused_path, unused_parent, children):
  """Checks that upgrading each V1 symbol yields a compat.v1 or V2 symbol.

  Each V1 export name of each child is passed through `self._upgrade` as
  ``tf.<name>``. The test fails on a non-empty result that meets all of
  these conditions:

  * it does not start with ``tf.compat.v1``;
  * it is not in `self.v2_symbols`;
  * it does not start with ``tf.estimator``.

  `self` is captured from the enclosing test method.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    api_names = tf_export.get_v1_names(attr)
    for name in api_names:
      _, _, _, text = self._upgrade("tf." + name)
      if (text and not text.startswith("tf.compat.v1") and
          text not in self.v2_symbols and
          # Builds currently install old version of estimator that doesn't
          # have some 2.0 symbols.
          not text.startswith("tf.estimator")):
        # self.fail reports the message as-is, instead of the misleading
        # "True is not false" that assertFalse(True, ...) adds to it.
        self.fail("Symbol %s generated from %s not in v2 API" % (text, name))
def testExportSingleFunction(self):
  """Checks exporting a function under two names and looking it back up.

  The decorator must return the original function, and both names must be
  reported as V1 and V2 names. Each name must resolve back to the function
  via get_symbol_from_name. The function's canonical name must round-trip
  to the function as well.
  """
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
  self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                   decorated_function)
  self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                   decorated_function)
  self.assertEqual(
      tf_export.get_symbol_from_name(
          tf_export.get_canonical_name_for_symbol(decorated_function)),
      decorated_function)
def testExportSingleFunctionV1Only(self):
  """Checks a function exported only under V1 names.

  The names must be reported for V1 and not for V2, and must resolve via
  the ``compat.v1.`` prefix. The canonical name, requested with the V1
  prefix added, must round-trip back to the function.
  """
  decorated_function = tf_export.tf_export(v1=['nameA', 'nameB'])(
      _test_function)
  self.assertEqual(decorated_function, _test_function)
  self.assertAllEqual(('nameA', 'nameB'),
                      decorated_function._tf_api_names_v1)
  self.assertAllEqual(['nameA', 'nameB'],
                      tf_export.get_v1_names(decorated_function))
  self.assertEqual([], tf_export.get_v2_names(decorated_function))

  for prefixed_name in ('compat.v1.nameA', 'compat.v1.nameB'):
    self.assertEqual(tf_export.get_symbol_from_name(prefixed_name),
                     decorated_function)

  canonical_name = tf_export.get_canonical_name_for_symbol(
      decorated_function, add_prefix_to_v1_names=True)
  self.assertEqual(tf_export.get_symbol_from_name(canonical_name),
                   decorated_function)
def arg_test_visitor(unused_path, unused_parent, children):
  """Checks that keyword renames only refer to arguments V1 functions have.

  For each V1 export ``tf.<name>`` that has an entry in
  `all_keyword_renames` (from the enclosing scope), the entry must be a
  dict. Each of its keys must appear in the argument list that
  `tf_inspect.getargspec` reports for the V1 symbol.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    names_v1 = tf_export.get_v1_names(attr)
    # Computed lazily, at most once per symbol, so getargspec only runs for
    # symbols that actually have a rename entry (as before).
    arg_names_v1 = None
    for name in names_v1:
      tf_name = "tf.%s" % name
      if tf_name not in all_keyword_renames:
        continue
      if arg_names_v1 is None:
        arg_names_v1 = tf_inspect.getargspec(attr)[0]
      keyword_renames = all_keyword_renames[tf_name]
      self.assertIsInstance(keyword_renames, dict)
      # Assert that v1 function has valid v1 argument names.
      for from_name in keyword_renames:
        self.assertIn(
            from_name, arg_names_v1,
            "%s not found in %s arguments: %s" %
            (from_name, tf_name, str(arg_names_v1)))
def arg_test_visitor(unused_path, unused_parent, children):
  """Checks that keyword renames only refer to arguments V1 functions have.

  For each V1 export ``tf.<name>`` that has an entry in
  `all_keyword_renames` (from the enclosing scope), the entry must be a
  dict. Each of its keys must appear in the argument list that
  `tf_inspect.getargspec` reports for the V1 symbol.
  """
  for child in children:
    _, attr = tf_decorator.unwrap(child[1])
    names_v1 = tf_export.get_v1_names(attr)
    # Computed lazily, at most once per symbol, so getargspec only runs for
    # symbols that actually have a rename entry (as before).
    arg_names_v1 = None
    for name in names_v1:
      tf_name = "tf.%s" % name
      if tf_name not in all_keyword_renames:
        continue
      if arg_names_v1 is None:
        arg_names_v1 = tf_inspect.getargspec(attr)[0]
      keyword_renames = all_keyword_renames[tf_name]
      self.assertIsInstance(keyword_renames, dict)
      # Assert that v1 function has valid v1 argument names.
      for from_name in keyword_renames:
        self.assertIn(
            from_name, arg_names_v1,
            "%s not found in %s arguments: %s" %
            (from_name, tf_name, str(arg_names_v1)))
def visit(unused_path, unused_parent, children):
  """Visitor that collects arguments for reordered functions.

  For each child whose ``tf.``-prefixed V1 names include any entry of
  `function_names`, stores its argument names in `function_to_args` under
  all of its prefixed V1 names. For classes, the constructor's arguments
  minus the leading `self` are stored.
  """
  for child in children:
    _, target = tf_decorator.unwrap(child[1])
    prefixed_names = ['tf.%s' % api_name
                      for api_name in tf_export.get_v1_names(target)]
    if not any(prefixed in function_names for prefixed in prefixed_names):
      continue
    if tf_inspect.isclass(target):
      # Constructor arguments, dropping the leading 'self'.
      arguments = tf_inspect.getargspec(target.__init__)[0][1:]
    else:
      # getargspec yields (args, varargs, keywords, defaults); keep args.
      arguments = tf_inspect.getargspec(target)[0]
    for prefixed in prefixed_names:
      function_to_args[prefixed] = arguments
def visit(unused_path, unused_parent, children):
  """Visitor that collects arguments for reordered functions.

  For each child whose ``tf.``-prefixed V1 names include any entry of
  `function_names`, stores its argument names in `function_to_args` under
  all of its prefixed V1 names. For classes, the constructor's arguments
  minus the leading `self` are stored.
  """
  for child in children:
    _, target = tf_decorator.unwrap(child[1])
    prefixed_names = ['tf.%s' % api_name
                      for api_name in tf_export.get_v1_names(target)]
    if not any(prefixed in function_names for prefixed in prefixed_names):
      continue
    if tf_inspect.isclass(target):
      # Constructor arguments, dropping the leading 'self'.
      arguments = tf_inspect.getargspec(target.__init__)[0][1:]
    else:
      # getargspec yields (args, varargs, keywords, defaults); keep args.
      arguments = tf_inspect.getargspec(target)[0]
    for prefixed in prefixed_names:
      function_to_args[prefixed] = arguments
def symbol_collector_v1(unused_path, unused_parent, children):
  """Indexes each child's unwrapped symbol by its ``tf.``-prefixed V1 names.

  Entries are written into `cls.v1_symbols` (with `cls` from the enclosing
  scope). The keys are ``"tf." + v1_name`` and the values are the
  undecorated symbol.
  """
  for child in children:
    _, undecorated = tf_decorator.unwrap(child[1])
    for v1_name in tf_export.get_v1_names(undecorated):
      symbol_key = "tf." + v1_name
      cls.v1_symbols[symbol_key] = undecorated