def setUpClass(cls):
  """Caches public TF symbols keyed by their `tf.`-prefixed API names.

  Fills `cls.v2_symbols` from `tf.compat.v2` (names from
  `tf_export.get_v2_names`) and `cls.v1_symbols` from `tf.compat.v1` (names
  from `tf_export.get_v1_names`). A namespace missing from `tf.compat` leaves
  its dictionary empty.
  """
  cls.v2_symbols = {}
  cls.v1_symbols = {}

  def make_collector(get_names, registry):
    """Builds a visitor that stores each unwrapped child under its names."""

    def collect(unused_path, unused_parent, children):
      for child in children:
        target = tf_decorator.unwrap(child[1])[1]
        for api_name in get_names(target):
          registry["tf." + api_name] = target

    return collect

  if hasattr(tf.compat, "v2"):
    v2_walker = public_api.PublicAPIVisitor(
        make_collector(tf_export.get_v2_names, cls.v2_symbols))
    traverse.traverse(tf.compat.v2, v2_walker)
  if hasattr(tf.compat, "v1"):
    v1_walker = public_api.PublicAPIVisitor(
        make_collector(tf_export.get_v1_names, cls.v1_symbols))
    traverse.traverse(tf.compat.v1, v1_walker)
def testAllAPI(self):
  """Checks that upgrading every v1 symbol yields a name present in v2.

  Returns early when `tf.compat.v2` is unavailable. Upgrade results that land
  in `tf.compat.v1` are accepted.
  """
  if not hasattr(tf.compat, "v2"):
    return

  # Gather every public `tf.`-prefixed name exported by the v2 API.
  v2_symbols = set()

  def symbol_collector(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      v2_symbols.update("tf." + name for name in get_v2_names(attr))

  visitor = public_api.PublicAPIVisitor(symbol_collector)
  traverse.traverse(tf.compat.v2, visitor)

  # Converts all symbols in the v1 namespace to the v2 namespace, raising
  # an error if the target of the conversion is not in the v2 namespace.
  def conversion_visitor(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      for name in get_v1_names(attr):
        _, _, _, text = self._upgrade("tf." + name)
        if (text and not text.startswith("tf.compat.v1") and
            text not in v2_symbols):
          # self.fail states the intent directly; assertFalse(True, ...) only
          # obscured it.
          self.fail("Symbol %s generated from %s not in v2 API" % (text, name))

  visitor = public_api.PublicAPIVisitor(conversion_visitor)
  visitor.do_not_descend_map["tf"].append("contrib")
  visitor.private_map["tf.compat"] = ["v1", "v2"]
  traverse.traverse(tf.compat.v1, visitor)
def collect_function_arg_names(function_names):
  """Determines argument names for reordered function signatures.

  Args:
    function_names: Functions to collect arguments for, as `tf.`-prefixed
      v1 names.

  Returns:
    Dictionary mapping function name to its positional argument names. For
    classes, the constructor's arguments are used without the leading `self`.
  """
  # Map from reordered function name to its arguments.
  function_to_args = {}

  def visit(unused_path, unused_parent, children):
    """Visitor that collects arguments for reordered functions."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      api_names_v1 = ['tf.%s' % name for name in tf_export.get_v1_names(attr)]
      if not any(name in function_names for name in api_names_v1):
        continue
      if tf_inspect.isclass(attr):
        # A class is called through its constructor; drop the implicit
        # 'self' so positional indices match what callers pass.
        arg_list = tf_inspect.getargspec(attr.__init__)[0][1:]
      else:
        # getargspec returns (args, varargs, keywords, defaults); keep args.
        arg_list = tf_inspect.getargspec(attr)[0]
      for name in api_names_v1:
        function_to_args[name] = arg_list

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)
  return function_to_args
def testAPIBackwardsCompatibility(self):
  """Diffs the live `tf` public API against the golden API proto files.

  With --update_goldens the diff is handed to `_AssertProtoDictEquals` in
  update mode (report only) instead of failing.
  """
  # Walk the public API, turning each visited object into a proto.
  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  # NOTE(review): root key is '' here rather than 'tf' as in other tests;
  # presumably matches this PublicAPIVisitor version — confirm.
  api_walker.do_not_descend_map[''].append('contrib')
  traverse.traverse(tf, api_walker)
  live_protos = proto_visitor.GetProtos()

  # Locate every golden file via the key-to-path wildcard.
  golden_glob = os.path.join(
      resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*'))
  golden_paths = file_io.get_matching_files(golden_glob)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {}
  for path in golden_paths:
    key = _FileNameToKey(path)
    golden_protos[key] = _ReadFileToProto(path)

  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=FLAGS.update_goldens)
def checkBackwardsCompatibility(self, root, golden_file_pattern, api_version):
  """Diffs the public API reachable from `root` against golden proto files.

  Args:
    root: Module whose public API is traversed.
    golden_file_pattern: Pattern handed to `file_io.get_matching_files` to
      locate the golden files.
    api_version: Forwarded unchanged to `_AssertProtoDictEquals`.
  """
  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  skip_map = api_walker.do_not_descend_map
  skip_map['tf'].append('contrib')
  skip_map['tf.GPUOptions'] = ['Experimental']
  traverse.traverse(root, api_walker)
  live_protos = proto_visitor.GetProtos()

  golden_paths = file_io.get_matching_files(golden_file_pattern)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {
      _FileNameToKey(path): _ReadFileToProto(path) for path in golden_paths
  }

  # When updating goldens, diffs are only reported, not treated as failures.
  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=FLAGS.update_goldens,
      api_version=api_version)
def testAllAPI(self):
  """Checks that upgrading every v1 symbol yields a name in `self.v2_symbols`.

  Returns early when `tf.compat.v2` is unavailable. Upgrade results inside
  `tf.compat.v1` or `tf.estimator` are accepted. Please regenerate the renames
  file or edit any manual renames if this test fails.
  """
  if not hasattr(tf.compat, "v2"):
    return

  # Prefixes that are allowed even when missing from the v2 symbol table.
  # Builds currently install an old version of estimator that doesn't have
  # some 2.0 symbols, hence tf.estimator.
  exempt_prefixes = ("tf.compat.v1", "tf.estimator")

  # Converts all symbols in the v1 namespace to the v2 namespace, raising
  # an error if the target of the conversion is not in the v2 namespace.
  def conversion_visitor(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      for name in tf_export.get_v1_names(attr):
        _, _, _, text = self._upgrade("tf." + name)
        if (text and not text.startswith(exempt_prefixes) and
            text not in self.v2_symbols):
          self.fail("Symbol %s generated from %s not in v2 API" % (text, name))

  visitor = public_api.PublicAPIVisitor(conversion_visitor)
  visitor.do_not_descend_map["tf"].append("contrib")
  visitor.private_map["tf.compat"] = ["v1", "v2"]
  traverse.traverse(tf.compat.v1, visitor)
def collect_function_renames():
  """Finds functions/classes whose v1 names are not exported in TF 2.0.

  Returns:
    Set of `(current_name, new_name)` tuples, where the new name comes from
    `get_canonical_name`. Names still exported in 2.0 (for example by a
    different function created to rename arguments) are excluded.
  """
  candidate_renames = set()

  def visit(unused_path, unused_parent, children):
    """Records a rename pair for every v1-only name of each child."""
    for child in children:
      target = tf_decorator.unwrap(child[1])[1]
      v1_names = tf_export.get_v1_names(target)
      v2_names_here = tf_export.get_v2_names(target)
      for old_name in set(v1_names).difference(v2_names_here):
        candidate_renames.add(
            (old_name, get_canonical_name(v2_names_here, old_name)))

  api_walker = public_api.PublicAPIVisitor(visit)
  api_walker.do_not_descend_map['tf'].append('contrib')
  api_walker.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_walker)

  # A different function may be exported under the same name in 2.0; such
  # names must not be renamed.
  all_v2_names = get_all_v2_names()
  return {(old_name, new_name)
          for old_name, new_name in candidate_renames
          if old_name not in all_v2_names}
def testV1KeywordArgNames(self):
  """Checks that keyword renames only mention arguments that exist in v1."""
  all_keyword_renames = (
      tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)

  def arg_test_visitor(unused_path, unused_parent, children):
    """Asserts each renamed-from argument is in the v1 signature."""
    for child in children:
      attr = tf_decorator.unwrap(child[1])[1]
      for short_name in tf_export.get_v1_names(attr):
        full_name = "tf.%s" % short_name
        if full_name not in all_keyword_renames:
          continue
        v1_args = tf_inspect.getargspec(attr)[0]
        renames_for_fn = all_keyword_renames[full_name]
        self.assertEqual(type(renames_for_fn), dict)
        # Every key of the rename mapping must be a real v1 argument.
        for old_arg in renames_for_fn:
          self.assertIn(
              old_arg, v1_args,
              "%s not found in %s arguments: %s" %
              (old_arg, full_name, str(v1_args)))

  api_walker = public_api.PublicAPIVisitor(arg_test_visitor)
  api_walker.do_not_descend_map["tf"].append("contrib")
  api_walker.private_map["tf.compat"] = ["v1", "v2"]
  traverse.traverse(tf.compat.v1, api_walker)
def testNoSubclassOfMessageV2(self):
  """Runs the proto-Message subclass check over `tf.compat.v2`, if present."""
  if hasattr(tf.compat, 'v2'):
    checker = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
    checker.do_not_descend_map['tf'].append('contrib')
    traverse.traverse(tf.compat.v2, checker)
def extract():
  """Collects documentation metadata for the `tf` and `tf_debug` modules.

  Traverses `tf` (skipping the contrib submodules listed below), then
  `tf_debug` under the root name 'tfdbg', with one shared visitor.

  Returns:
    The `DocGeneratorVisitor` that accumulated both traversals.
  """
  # Contrib libraries excluded from the documentation altogether.
  # TODO(wicke): Shrink this list.
  excluded_modules = {
      'contrib': [
          'compiler', 'factorization', 'grid_rnn', 'labeled_tensor', 'ndlstm',
          'quantization', 'session_bundle', 'slim', 'solvers', 'specs',
          'tensor_forest', 'tensorboard', 'testing', 'tfprof',
      ],
      'contrib.bayesflow': [
          'entropy', 'monte_carlo', 'special_math',
          'stochastic_gradient_estimators', 'stochastic_graph',
          'stochastic_tensor', 'stochastic_variables', 'variational_inference'
      ],
      'contrib.distributions': ['bijector'],
      'contrib.ffmpeg': ['ffmpeg_ops'],
      'contrib.graph_editor': [
          'edit', 'match', 'subgraph', 'transform', 'select', 'util'
      ],
      'contrib.layers': ['feature_column', 'summaries'],
      'contrib.learn': [
          'datasets', 'head', 'graph_actions', 'io', 'models', 'monitors',
          'ops', 'preprocessing', 'utils',
      ],
      'contrib.util': ['loader'],
  }

  doc_visitor = doc_generator_visitor.DocGeneratorVisitor('tf')
  api_visitor = public_api.PublicAPIVisitor(doc_visitor)
  # Touch tf.contrib so it is actually loaded; it is hidden behind lazy
  # loading.
  _ = tf.contrib.__name__
  api_visitor.do_not_descend_map.update(excluded_modules)
  traverse.traverse(tf, api_visitor)

  # tf_debug is not imported with tf; it is a separate module altogether.
  doc_visitor.set_root_name('tfdbg')
  traverse.traverse(tf_debug, api_visitor)
  return doc_visitor
def test_private_child_removal(self):
  """Underscore-prefixed children are stripped before the visitor runs."""
  recorder = self.TestVisitor()
  children = [('name1', 'thing1'), ('_name2', 'thing2')]
  api_visitor = public_api.PublicAPIVisitor(recorder)
  api_visitor('test', 'dummy', children)

  public_only = [('name1', 'thing1')]
  # The wrapped visitor must never see the private symbol, and the caller's
  # list is pruned in place.
  self.assertEqual(public_only, recorder.last_children)
  self.assertEqual(public_only, children)
def collect_function_renames():
  """Finds functions/classes that need to be renamed in TF 2.0.

  Export names are read straight from each object's `__dict__` under
  `_TENSORFLOW_API_ATTR_V1` (v1) and `_TENSORFLOW_API_ATTR` (v2).

  Returns:
    Set of `(current_name, new_name)` tuples.
  """
  renames = set()

  def visit(unused_path, unused_parent, children):
    """Adds a rename pair for each v1-only name of every visited child."""
    for child in children:
      attr = tf_decorator.unwrap(child[1])[1]
      if not hasattr(attr, '__dict__'):
        continue  # Objects without __dict__ carry no export metadata.
      exported = attr.__dict__
      v1_names = exported.get(_TENSORFLOW_API_ATTR_V1, [])
      v2_names = exported.get(_TENSORFLOW_API_ATTR, [])
      for old_name in set(v1_names).difference(v2_names):
        renames.add((old_name, get_canonical_name(v2_names, old_name)))

  api_walker = public_api.PublicAPIVisitor(visit)
  api_walker.do_not_descend_map['tf'].append('contrib')
  api_walker.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_walker)
  return renames
def testNoSubclassOfMessage(self):
  """Runs the proto-Message subclass check over the whole `tf` namespace."""
  checker = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
  checker.do_not_descend_map['tf'].append('contrib')
  # compat.v1 / compat.v2 have their own dedicated tests, so hide them here.
  checker.private_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, checker)
def testAllAPIV1(self):
  """Checks that upgrading every v1 symbol yields a name that exists in v1.

  Performs two traversals of `tf.compat.v1`: the first collects all public
  `tf.`-prefixed v1 names, the second upgrades each name and fails if the
  result is neither in v1, under `tf.compat.v1`/`tf.estimator`, nor in the
  small whitelist of v2-only symbols.
  """
  # Symbols which may be generated by the conversion script which do not exist
  # in TF 1.x. This should be a very short list of symbols which are
  # experimental in 1.x but stable for 2.x.
  whitelisted_v2_only_symbols = {"tf.saved_model.save"}
  exempt_prefixes = ("tf.compat.v1", "tf.estimator")

  def make_visitor(visit_fn):
    """Wraps `visit_fn` in a PublicAPIVisitor configured for tf.compat.v1."""
    visitor = public_api.PublicAPIVisitor(visit_fn)
    visitor.do_not_descend_map["tf"].append("contrib")
    visitor.private_map["tf.compat"] = ["v1", "v2"]
    return visitor

  # Pass 1: collect every public v1 symbol. A dedicated visitor replaces the
  # previous mutable `collect` flag toggled between traversals.
  v1_symbols = set()

  def collect_visitor(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      v1_symbols.update("tf." + name for name in tf_export.get_v1_names(attr))

  traverse.traverse(tf.compat.v1, make_visitor(collect_visitor))

  # Pass 2: convert all symbols in the v1 namespace to the v2 namespace,
  # raising an error if the target of the conversion is not in the v1
  # namespace.
  def conversion_visitor(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      for name in tf_export.get_v1_names(attr):
        _, _, _, text = self._upgrade("tf." + name)
        if (text and not text.startswith(exempt_prefixes) and
            text not in v1_symbols and
            text not in whitelisted_v2_only_symbols):
          self.fail("Symbol %s generated from %s not in v1 API" % (text, name))

  traverse.traverse(tf.compat.v1, make_visitor(conversion_visitor))
def _checkBackwardsCompatibility(self,
                                 root,
                                 golden_file_patterns,
                                 api_version,
                                 additional_private_map=None,
                                 omit_golden_symbols_map=None):
  """Diffs the public API under `root` against golden proto files.

  Args:
    root: Module whose public API is traversed.
    golden_file_patterns: Pattern(s) handed to `file_io.get_matching_files`.
    api_version: Forwarded to `_AssertProtoDictEquals`; when 2,
      `enable_v2_behavior` is additionally treated as private.
    additional_private_map: Optional dict merged into the visitor's
      `private_map` before traversal.
    omit_golden_symbols_map: Forwarded to `_FilterGoldenProtoDict`.
  """
  # Numpy classes are not descended into because their signatures may be
  # different between internal and OSS.
  numpy_classes_to_skip = [
      'bool_', 'complex_', 'complex128', 'complex64', 'float_', 'float16',
      'float32', 'float64', 'inexact', 'int_', 'int16', 'int32', 'int64',
      'int8', 'object_', 'string_', 'uint16', 'uint32', 'uint64', 'uint8',
      'unicode_', 'iinfo'
  ]

  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  api_walker.private_map['tf'].append('contrib')
  if api_version == 2:
    api_walker.private_map['tf'].append('enable_v2_behavior')
  api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
  api_walker.do_not_descend_map['tf.experimental.numpy'] = (
      numpy_classes_to_skip)
  if FLAGS.only_test_core_api:
    api_walker.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
  if additional_private_map:
    api_walker.private_map.update(additional_private_map)
  traverse.traverse(root, api_walker)
  live_protos = proto_visitor.GetProtos()

  golden_paths = file_io.get_matching_files(golden_file_patterns)
  if FLAGS.only_test_core_api:
    golden_paths = _FilterNonCoreGoldenFiles(golden_paths)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {}
  for path in golden_paths:
    key = _FileNameToKey(path)
    golden_protos[key] = _ReadFileToProto(path)
  golden_protos = _FilterGoldenProtoDict(golden_protos,
                                         omit_golden_symbols_map)

  # When updating goldens, diffs are only reported, not treated as failures.
  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=FLAGS.update_goldens,
      api_version=api_version)
def test_call_forward(self):
  """The wrapped visitor receives path, parent and children unchanged."""
  recorder = self.TestVisitor()
  children = [('name1', 'thing1'), ('name2', 'thing2')]
  api_visitor = public_api.PublicAPIVisitor(recorder)
  api_visitor('test', 'dummy', children)

  self.assertEqual({'test'}, recorder.symbols)
  self.assertEqual('dummy', recorder.last_parent)
  expected_children = [('name1', 'thing1'), ('name2', 'thing2')]
  self.assertEqual(expected_children, recorder.last_children)
def testNoSubclassOfMessageV2(self):
  """Runs the proto-Message subclass check over `tf.compat.v2`, if present.

  With --only_test_core_api, the non-core packages are skipped as well.
  """
  if not hasattr(tf.compat, 'v2'):
    return
  checker = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
  top_level_skips = checker.do_not_descend_map['tf']
  top_level_skips.append('contrib')
  if FLAGS.only_test_core_api:
    top_level_skips.extend(_NON_CORE_PACKAGES)
  traverse.traverse(tf.compat.v2, checker)
def testKeywordArgNames(self):
  """Checks keyword renames are valid in v1 and survive conversion to v2.

  For each v1 symbol listed in `function_keyword_renames`: asserts every
  renamed-from argument exists in its v1 signature, converts a synthetic
  keyword call to v2, and asserts the converted function is a v2 symbol
  whose signature accepts every converted argument name. Returns early when
  `tf.compat.v2` is unavailable.
  """
  if not hasattr(tf.compat, "v2"):
    return

  all_keyword_renames = (
      tf_upgrade_v2.TFAPIChangeSpec().function_keyword_renames)
  v2_name_exceptions = {"verify_shape_is_now_always_true"}

  def check_symbol(attr, name):
    """Runs the v1 and v2 argument checks for one `tf.`-prefixed name."""
    arg_names_v1 = tf_inspect.getargspec(attr)[0]
    keyword_renames = all_keyword_renames[name]
    self.assertEqual(type(keyword_renames), dict)
    # Every renamed-from argument must be a real v1 argument.
    for from_name in keyword_renames:
      self.assertIn(
          from_name, arg_names_v1,
          "%s not found in %s arguments: %s" %
          (from_name, name, str(arg_names_v1)))

    # 1. Build an input of the form tf.foo(arg1=0, arg2=1, ...).
    call_args = ",".join(
        "%s=%d" % (from_name, from_index)
        for from_index, from_name in enumerate(keyword_renames.keys()))
    text_input = "%s(%s)" % (name, call_args)
    # 2. Convert the input to V2.
    _, _, _, text = self._upgrade(text_input)
    new_function_name, new_args = get_func_and_args_from_str(text)
    # 3. Renamed arguments imply the new function must be available in 2.0
    # (not via compat.v1).
    self.assertIn(new_function_name, self.v2_symbols)
    args_v2 = tf_inspect.getargspec(self.v2_symbols[new_function_name])[0]
    args_v2.extend(v2_name_exceptions)
    for new_arg in new_args:
      self.assertIn(new_arg, args_v2)

  def conversion_visitor(unused_path, unused_parent, children):
    for child in children:
      attr = tf_decorator.unwrap(child[1])[1]
      for short_name in get_v1_names(attr):
        name = "tf.%s" % short_name
        if name in all_keyword_renames:
          check_symbol(attr, name)

  api_walker = public_api.PublicAPIVisitor(conversion_visitor)
  api_walker.do_not_descend_map["tf"].append("contrib")
  api_walker.private_map["tf.compat"] = ["v1", "v2"]
  traverse.traverse(tf.compat.v1, api_walker)
def test_no_descent_child_removal(self):
  """Do-not-descend children are visited, then pruned from the list."""
  recorder = self.TestVisitor()
  children = [('name1', 'thing1'), ('mock', 'thing2')]
  api_visitor = public_api.PublicAPIVisitor(recorder)
  api_visitor('test', 'dummy', children)

  # The wrapped visitor still sees 'mock' ...
  self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
                   recorder.last_children)
  # ... but it is removed afterwards so traversal does not descend into it.
  self.assertEqual([('name1', 'thing1')], children)
def testNewAPIBackwardsCompatibility(self):
  """Diffs the generated `api` module against the golden API files.

  The `user_ops` member is dropped from the golden `tensorflow` module first,
  and goldens are never updated here (`update_goldens=False`).
  """
  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  api_walker.do_not_descend_map['tf'].append('contrib')
  api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
  # TODO(annarev): Make slide_dataset available in API.
  api_walker.private_map['tf'] = ['slide_dataset']
  traverse.traverse(api, api_walker)
  live_protos = proto_visitor.GetProtos()

  golden_glob = os.path.join(
      resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*'))
  golden_paths = file_io.get_matching_files(golden_glob)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {
      _FileNameToKey(path): _ReadFileToProto(path) for path in golden_paths
  }

  # user_ops is an empty module. It is currently available in TensorFlow API
  # but empty modules are not kept in the new API, so remove it from the
  # goldens before diffing.
  # TODO(annarev): remove user_ops from goldens once we switch to new API.
  tf_module = golden_protos['tensorflow'].tf_module
  for index, member in enumerate(tf_module.member):
    if member.name == 'user_ops':
      del tf_module.member[index]
      break  # Stop immediately: the container was just mutated.

  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=False,
      additional_missing_object_message=
      'Check if tf_export decorator/call is missing for this symbol.')
def extract(py_modules, do_not_descend_map):
  """Collects doc metadata for a sequence of `(root_name, module)` pairs.

  Args:
    py_modules: Sequence of `(root_name, module)` pairs. The first pair's
      name seeds the visitor; later pairs switch the root name before their
      traversal.
    do_not_descend_map: Extra skip entries, passed together with the API
      visitor's own `do_not_descend_map` to `add_dict_to_dict`.

  Returns:
    The `DocGeneratorVisitor` that accumulated every traversal.
  """
  first_entry = py_modules[0]
  doc_visitor = doc_generator_visitor.DocGeneratorVisitor(first_entry[0])
  api_visitor = public_api.PublicAPIVisitor(doc_visitor)
  add_dict_to_dict(do_not_descend_map, api_visitor.do_not_descend_map)

  traverse.traverse(first_entry[1], api_visitor)
  # Re-root the same visitor for every remaining module.
  for root_name, py_module in py_modules[1:]:
    doc_visitor.set_root_name(root_name)
    traverse.traverse(py_module, api_visitor)
  return doc_visitor
def get_all_v2_names():
  """Returns the set of function/class names exported in TensorFlow 2.0.

  Names come from `get_v2_names` on every public object reachable from
  `tf.compat.v2` (contrib excluded).
  """
  v2_names = set()

  def visit(unused_path, unused_parent, children):
    """Adds every v2 export name of each visited child to `v2_names`."""
    for child in children:
      v2_names.update(get_v2_names(tf_decorator.unwrap(child[1])[1]))

  api_walker = public_api.PublicAPIVisitor(visit)
  api_walker.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf.compat.v2, api_walker)
  return v2_names
def _checkBackwardsCompatibility(
    self,
    root,
    golden_file_patterns,
    api_version,
    additional_private_map=None,
    omit_golden_symbols_map=None,
):
    """Diffs the public Keras API under `root` against golden proto files.

    Args:
        root: Module whose public API is traversed, rooted as "tf.keras".
        golden_file_patterns: Pattern(s) handed to `tf.compat.v1.gfile.Glob`.
        api_version: Forwarded to `_AssertProtoDictEquals`.
        additional_private_map: Optional dict merged into the visitor's
            `private_map` before traversal.
        omit_golden_symbols_map: Forwarded to `_FilterGoldenProtoDict`.
    """
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor(
        default_path="tensorflow.keras"
    )
    api_walker = public_api.PublicAPIVisitor(proto_visitor)
    if additional_private_map:
        api_walker.private_map.update(additional_private_map)
    api_walker.set_root_name("tf.keras")
    traverse.traverse(root, api_walker)
    live_protos = proto_visitor.GetProtos()

    golden_paths = tf.compat.v1.gfile.Glob(golden_file_patterns)

    def _ReadFileToProto(filename):
        """Parses one text-format golden file into a TFAPIObject proto."""
        parsed = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(filename), parsed)
        return parsed

    golden_protos = {}
    for path in golden_paths:
        key = _FileNameToKey(path)
        golden_protos[key] = _ReadFileToProto(path)
    golden_protos = _FilterGoldenProtoDict(
        golden_protos, omit_golden_symbols_map
    )

    # When updating goldens, diffs are only reported, not treated as
    # failures.
    self._AssertProtoDictEquals(
        golden_protos,
        live_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version,
    )
def testNoSubclassOfMessage(self):
  """Fails if any public `tf` object is an indirect proto Message subclass."""

  def Visit(path, parent, unused_children):
    """A Visitor that crashes on subclasses of generated proto classes."""
    is_message_class = (
        isinstance(parent, type) and issubclass(parent, message.Message))
    # Only proto Message classes (other than Message itself) are checked.
    if not is_message_class or parent is message.Message:
      return
    # Only direct subclasses of Message are supported by the API tools.
    if message.Message not in parent.__bases__:
      raise NotImplementedError(
          'Object tf.%s is a subclass of a generated proto Message. '
          'They are not yet supported by the API tools.' % path)

  checker = public_api.PublicAPIVisitor(Visit)
  checker.do_not_descend_map['tf'].append('contrib')
  traverse.traverse(tf, checker)
def testAPIBackwardsCompatibility(self):
  """Diffs the live `tf` API against goldens, ignoring the golden `math`.

  The `math` member is removed from the golden `tensorflow` module before
  diffing. With --update_goldens diffs are reported rather than failing.
  """
  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  api_walker.do_not_descend_map['tf'].append('contrib')
  api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
  traverse.traverse(tf, api_walker)
  live_protos = proto_visitor.GetProtos()

  golden_glob = os.path.join(
      resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*'))
  golden_paths = file_io.get_matching_files(golden_glob)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {
      _FileNameToKey(path): _ReadFileToProto(path) for path in golden_paths
  }

  # TODO(annarev): remove once we switch to using tf_export decorators.
  tf_module = golden_protos['tensorflow'].tf_module
  for index, member in enumerate(tf_module.member):
    if member.name == 'math':
      del tf_module.member[index]
      break  # Stop immediately: the container was just mutated.

  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=FLAGS.update_goldens)
def collect_function_arg_names(function_names):
  """Determines argument names for reordered function signatures.

  Args:
    function_names: Functions to collect arguments for, as `tf.`-prefixed
      v1 names.

  Returns:
    Dictionary mapping each `tf.`-prefixed v1 name of a matching object to
    its positional argument names (constructor args minus `self` for
    classes).
  """
  function_to_args = {}

  def positional_args(attr):
    """Returns positional argument names of a function or class constructor."""
    if tf_inspect.isclass(attr):
      # Constructor arguments, skipping the leading 'self'.
      return tf_inspect.getargspec(attr.__init__)[0][1:]
    # getargspec returns (args, varargs, keywords, defaults); keep args.
    return tf_inspect.getargspec(attr)[0]

  def visit(unused_path, unused_parent, children):
    """Visitor that collects arguments for reordered functions."""
    for child in children:
      attr = tf_decorator.unwrap(child[1])[1]
      qualified = ['tf.%s' % n for n in tf_export.get_v1_names(attr)]
      if not any(n in function_names for n in qualified):
        continue
      shared_args = positional_args(attr)
      for n in qualified:
        function_to_args[n] = shared_args

  api_walker = public_api.PublicAPIVisitor(visit)
  api_walker.do_not_descend_map['tf'].append('contrib')
  api_walker.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, api_walker)
  return function_to_args
def testNewAPIBackwardsCompatibility(self):
  """Diffs the generated `api` module against the golden API files.

  Goldens are never updated by this test (`update_goldens=False`).
  """
  proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
  api_walker = public_api.PublicAPIVisitor(proto_visitor)
  skip_map = api_walker.do_not_descend_map
  skip_map['tf'].append('contrib')
  skip_map['tf.GPUOptions'] = ['Experimental']
  # TODO(annarev): Make slide_dataset available in API.
  api_walker.private_map['tf'] = ['slide_dataset']
  traverse.traverse(api, api_walker)
  live_protos = proto_visitor.GetProtos()

  golden_glob = os.path.join(
      resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*'))
  golden_paths = file_io.get_matching_files(golden_glob)

  def _ReadFileToProto(filename):
    """Parses one text-format golden file into a TFAPIObject proto."""
    parsed = api_objects_pb2.TFAPIObject()
    text_format.Merge(file_io.read_file_to_string(filename), parsed)
    return parsed

  golden_protos = {}
  for path in golden_paths:
    key = _FileNameToKey(path)
    golden_protos[key] = _ReadFileToProto(path)

  self._AssertProtoDictEquals(
      golden_protos,
      live_protos,
      verbose=FLAGS.verbose_diffs,
      update_goldens=False,
      additional_missing_object_message=
      'Check if tf_export decorator/call is missing for this symbol.')
def collect_function_renames():
  """Looks for functions/classes that need to be renamed in TF 2.0.

  Returns:
    Set of tuples of the form (current name, new name). Names that are still
    exported in TF 2.0 (possibly by a different function) are excluded.
  """
  # Set of (deprecated name, canonical name) pairs.
  renames = set()
  v2_names = set()  # All op names in TensorFlow 2.0

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename pairs and all v2 export names."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not hasattr(attr, '__dict__'):
        continue
      api_names_v1 = attr.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
      api_names_v2 = attr.__dict__.get(_TENSORFLOW_API_ATTR, [])
      deprecated_api_names = set(api_names_v1) - set(api_names_v2)
      for name in deprecated_api_names:
        renames.add((name, get_canonical_name(api_names_v2, name)))
      v2_names.update(api_names_v2)

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  # It is possible that a different function is exported with the
  # same name. For e.g. when creating a different function to
  # rename arguments. Exclude it from renames in this case.
  # `renames` is a set of pairs (it has no `.items()`), so filter the pairs
  # directly and keep returning a set of tuples.
  return {(name, new_name) for name, new_name in renames
          if name not in v2_names}
def collect_constant_renames(output_file_path=None):
  """Looks for constants that need to be renamed in TF 2.0.

  Traverses `tf` (not descending into `tf.contrib`, `tf.compat.v1` or
  `tf.compat.v2`) and builds one rename line of the form
  `'tf.deprecated_name': 'tf.canonical_name'` for every name listed in a
  visited object's `_tf_deprecated_api_names`.

  Args:
    output_file_path: Optional file path to write a `renames = {...}` Python
      dictionary to. Any existing contents would be replaced. If `None`, no
      file is written. (Previously this name was used without being defined,
      which raised NameError.)

  Returns:
    Set of the formatted rename line strings.
  """
  # Set of rename lines to write to output file in the form:
  #   'tf.deprecated_name': 'tf.canonical_name'
  rename_line_set = set()
  # _tf_api_names attribute name
  tensorflow_api_attr = tf_export.API_ATTRS[
      tf_export.TENSORFLOW_API_NAME].names

  def visit(unused_path, unused_parent, children):
    """Visitor that collects rename strings to add to rename_line_set."""
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not hasattr(attr, '__dict__'):
        continue
      api_names = attr.__dict__.get(tensorflow_api_attr, [])
      deprecated_api_names = attr.__dict__.get('_tf_deprecated_api_names', [])
      canonical_name = tf_export.get_canonical_name(
          api_names, deprecated_api_names)
      for name in deprecated_api_names:
        rename_line_set.add(' \'tf.%s\': \'tf.%s\'' % (name, canonical_name))

  visitor = public_api.PublicAPIVisitor(visit)
  visitor.do_not_descend_map['tf'].append('contrib')
  visitor.do_not_descend_map['tf.compat'] = ['v1', 'v2']
  traverse.traverse(tf, visitor)

  if output_file_path is not None:
    renames_file_text = '%srenames = {\n%s\n}\n' % (
        _FILE_HEADER, ',\n'.join(sorted(rename_line_set)))
    file_io.write_string_to_file(output_file_path, renames_file_text)
  return rename_line_set
def testV2KeywordArgNames(self):
  """Checks that v1 calls converted to 2.0 use only valid v2 argument names.

  This test converts a call of the form:
    tf.foo(arg1=0, arg2=1, ...)
  to 2.0, then checks that the converted function accepts every argument
  name in the converted call. Returns early when `tf.compat.v2` is missing.
  """
  if not hasattr(tf.compat, "v2"):
    return

  v2_arg_exceptions = {
      "verify_shape_is_now_always_true",
      # These arguments should not be used, they just specify
      # that a function takes named arguments.
      "keyword_required",
      "_sentinel",
  }
  v1_name_exceptions = {
      "tf.print",  # requires print_function import
  }
  # Build the change spec once rather than once per attribute read.
  change_spec = tf_upgrade_v2.TFAPIChangeSpec()
  function_warnings = change_spec.function_warnings
  function_handles = change_spec.function_handle
  keyword_renames = change_spec.function_keyword_renames

  # Visitor that converts to V2 and checks V2 argument names.
  def conversion_visitor(unused_path, unused_parent, children):
    for child in children:
      _, attr = tf_decorator.unwrap(child[1])
      if not tf_inspect.isfunction(attr):
        continue
      names_v1 = tf_export.get_v1_names(attr)
      arg_names_v1 = get_args(attr)

      for name in names_v1:
        tf_name = "tf.%s" % name
        if tf_name in function_warnings or tf_name in function_handles:
          continue  # These require manual change
        if tf_name in v1_name_exceptions:
          continue
        # 1. Create an input of the form tf.foo(arg1=val1, arg2=val2, ...).
        args = ",".join(
            "%s=%d" % (from_name, from_index)
            for from_index, from_name in enumerate(arg_names_v1))
        text_input = "%s(%s)" % (tf_name, args)
        # 2. Convert the input to V2.
        _, _, _, text = self._upgrade(text_input)
        new_function_name, new_args = get_func_and_args_from_str(text)
        if new_function_name == "tf.compat.v1.%s" % name:
          if tf_name in keyword_renames:
            # If we rename arguments, new function must be available in 2.0.
            # We should not be using compat.v1 in this case. self.fail is
            # used because assertFalse(<non-empty string>) misused the
            # message as the asserted expression.
            self.fail(
                "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                (new_function_name, text_input, text))
          continue
        # 3. Verify V2 function and arguments. A missing target now produces
        # a descriptive assertion failure instead of a bare KeyError.
        self.assertIn(
            new_function_name, self.v2_symbols,
            "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
            (new_function_name, text_input, text))
        # Copy so the list returned by get_args is never mutated.
        args_v2 = list(get_args(self.v2_symbols[new_function_name]))
        args_v2.extend(v2_arg_exceptions)
        for new_arg in new_args:
          self.assertIn(
              new_arg, args_v2,
              "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
              "Supported arguments: %s" %
              (new_arg, text_input, text, str(args_v2)))

  visitor = public_api.PublicAPIVisitor(conversion_visitor)
  visitor.do_not_descend_map["tf"].append("contrib")
  visitor.private_map["tf.compat"] = ["v1", "v2"]
  traverse.traverse(tf.compat.v1, visitor)