def visit(unused_path, unused_parent, children):
  """Visitor that collects TF 2.0 names.

  For each entry in `children`, the object stored at index 1 (presumably a
  (name, member) pair -- TODO confirm against the traversal caller) is
  stripped of TF decorators. Every TF 2.x API name it is exported under is
  then added to `v2_names`, a collection defined outside this function.
  """
  for entry in children:
    target = tf_decorator.unwrap(entry[1])[1]
    for exported_name in tf_export.get_v2_names(target):
      v2_names.add(exported_name)
def _op_is_in_tf_version(op, version):
  """Returns a truthy value if `op` belongs to the given TF API version.

  Args:
    op: The op to check; TF decorators are unwrapped before name lookup.
    version: Either 1 or 2.

  Returns:
    For version 1, the op's exported V1 names if any exist. Otherwise, the
    result of `op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS`.
    For version 2, the op's exported V2 names, which may be empty.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  # Validate first so an unsupported version never triggers the unwrap.
  if version not in (1, 2):
    raise ValueError('Expected version 1 or 2.')
  undecorated = tf_decorator.unwrap(op)[1]
  if version == 1:
    return (tf_export.get_v1_names(undecorated) or
            op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
  return tf_export.get_v2_names(undecorated)
def _op_is_in_tf_version(op, version):
  """Returns a truthy value if `op` belongs to the given TF API version.

  Args:
    op: The op to check; TF decorators are unwrapped before name lookup.
    version: Either 1 or 2.

  Returns:
    For version 1, the op's exported V1 names if any exist. Otherwise, the
    result of `op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS`.
    For version 2, the op's exported V2 names, which may be empty.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  # Validate first so an unsupported version never triggers the unwrap.
  if version not in (1, 2):
    raise ValueError('Expected version 1 or 2.')
  undecorated = tf_decorator.unwrap(op)[1]
  if version == 1:
    return (tf_export.get_v1_names(undecorated) or
            op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
  return tf_export.get_v2_names(undecorated)
def visit(unused_path, unused_parent, children):
  """Visitor that collects rename strings to add to rename_line_set.

  For each entry in `children`, the object at index 1 is unwrapped from any
  TF decorators. Every name it is exported under in V1 but not in V2 is
  recorded in the outer `renames` collection. The recorded value is the pair
  `(name, get_canonical_name(<v2 names>, name))`.
  """
  for entry in children:
    target = tf_decorator.unwrap(entry[1])[1]
    names_v1 = tf_export.get_v1_names(target)
    names_v2 = tf_export.get_v2_names(target)
    # Names that exist only in the V1 API need a rename entry.
    removed_in_v2 = set(names_v1) - set(names_v2)
    for old_name in removed_in_v2:
      replacement = get_canonical_name(names_v2, old_name)
      renames.add((old_name, replacement))
def testExportSingleFunction(self):
  """Exporting a function under two names records them for V1 and V2."""
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  # The decorator is expected to return the same function object, now
  # carrying the exported names in `_tf_api_names`.
  # Use assertEqual: the `assertEquals` alias was removed in Python 3.12.
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  # With no explicit v1= argument, the same names apply to both versions.
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
def testExportSingleFunction(self):
  """Exporting a function under two names records them for V1 and V2."""
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  # The decorator is expected to return the same function object, now
  # carrying the exported names in `_tf_api_names`.
  # Use assertEqual: the `assertEquals` alias was removed in Python 3.12.
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  # With no explicit v1= argument, the same names apply to both versions.
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
def testExportSingleFunction(self):
  """Exporting under two names registers them and allows symbol lookup."""
  export_decorator = tf_export.tf_export('nameA', 'nameB')
  decorated_function = export_decorator(_test_function)
  # The decorator is expected to return the same function object, now
  # carrying the exported names in `_tf_api_names`.
  # Use assertEqual: the `assertEquals` alias was removed in Python 3.12.
  self.assertEqual(decorated_function, _test_function)
  self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
  # With no explicit v1= argument, the same names apply to both versions.
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v1_names(decorated_function))
  self.assertEqual(['nameA', 'nameB'],
                   tf_export.get_v2_names(decorated_function))
  # Each exported name resolves back to the decorated function.
  self.assertEqual(tf_export.get_symbol_from_name('nameA'), decorated_function)
  self.assertEqual(tf_export.get_symbol_from_name('nameB'), decorated_function)
  # The canonical name round-trips through the symbol lookup.
  self.assertEqual(
      tf_export.get_symbol_from_name(
          tf_export.get_canonical_name_for_symbol(decorated_function)),
      decorated_function)
def testExportSingleFunctionV1Only(self):
  """A `v1=`-only export registers compat.v1 names and no V2 names."""
  exported = tf_export.tf_export(v1=['nameA', 'nameB'])(_test_function)
  self.assertEqual(exported, _test_function)
  self.assertAllEqual(('nameA', 'nameB'), exported._tf_api_names_v1)
  self.assertAllEqual(['nameA', 'nameB'], tf_export.get_v1_names(exported))
  self.assertEqual([], tf_export.get_v2_names(exported))
  # V1-only symbols are looked up through the compat.v1 prefix.
  for qualified_name in ('compat.v1.nameA', 'compat.v1.nameB'):
    self.assertEqual(tf_export.get_symbol_from_name(qualified_name), exported)
  # The prefixed canonical name must round-trip to the same symbol.
  canonical = tf_export.get_canonical_name_for_symbol(
      exported, add_prefix_to_v1_names=True)
  self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
def symbol_collector(unused_path, unused_parent, children):
  """Maps each TF 2.x API name found among `children` to its symbol.

  The object at index 1 of every entry is unwrapped from TF decorators.
  Each of its V2 export names is stored in `cls.v2_symbols` under the key
  "tf." + name. `cls` comes from the enclosing scope (presumably a test
  class -- TODO confirm).
  """
  for entry in children:
    symbol = tf_decorator.unwrap(entry[1])[1]
    for api_name in tf_export.get_v2_names(symbol):
      cls.v2_symbols["tf." + api_name] = symbol
def symbol_collector(unused_path, unused_parent, children):
  """Records every TF 2.x API symbol among `children` in `cls.v2_symbols`.

  Keys are "tf." + <v2 export name>. Values are the decorator-unwrapped
  objects taken from index 1 of each entry. `cls` is a free variable from
  the enclosing scope (presumably a test class -- TODO confirm).
  """
  registry = cls.v2_symbols
  for entry in children:
    unwrapped = tf_decorator.unwrap(entry[1])[1]
    v2_export_names = tf_export.get_v2_names(unwrapped)
    for api_name in v2_export_names:
      registry["tf." + api_name] = unwrapped