def testGetTraverseShallowStructure(self):
    """Checks `nest.get_traverse_shallow_structure` results and its errors.

    Two successful traversals are verified against literal expected
    structures, and each result is additionally passed to
    `nest.assert_shallow_structure` together with its input. Three invalid
    traverse functions are then expected to raise `TypeError` with a
    matching message.
    """

    def is_not_tuple(item):
        # Scalar verdict: False only for tuples.
        return not isinstance(item, tuple)

    def verdict_per_tuple(item):
        # Structured verdict for tuples, plain True for everything else.
        if isinstance(item, tuple):
            return (True, False)
        return True

    # Traverse function that returns a single bool per visited item.
    mixed_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
    mixed_result = nest.get_traverse_shallow_structure(
        is_not_tuple, mixed_input)
    self.assertEqual(
        mixed_result,
        [True, True, False, [True, True], {"a": False}, []])
    nest.assert_shallow_structure(mixed_result, mixed_input)

    # Traverse function that returns a structure of bools for tuples.
    tuple_input = [(1, [2]), ([1], 2)]
    tuple_result = nest.get_traverse_shallow_structure(
        verdict_per_tuple, tuple_input)
    self.assertEqual(tuple_result, [(True, False), ([True], False)])
    nest.assert_shallow_structure(tuple_result, tuple_input)

    # (expected message regex, traverse function, structure) triples that
    # must each be rejected with a TypeError.
    rejected_calls = (
        ("returned structure", lambda _: [True], 0),
        ("returned a non-bool scalar", lambda _: 1, [1]),
        ("didn't return a depth=1 structure of bools", lambda _: [1], [1]),
    )
    for message_regex, traverse_fn, structure in rejected_calls:
        with self.assertRaisesRegexp(TypeError, message_regex):
            nest.get_traverse_shallow_structure(traverse_fn, structure)
def testFlattenUpTo(self):
    """Checks `nest.flatten_up_to` across container kinds and edge cases.

    Covers lists, dicts, namedtuples and nested OrderedDicts, then the cases
    where the shallow tree and/or the input tree is not a list. A sequence
    shallow tree paired with a non-sequence input is expected to raise
    `TypeError`. The method ends with a length-mismatch check on
    `nest.assert_shallow_structure`.
    """
    # Shallow tree ends at scalar.
    deep = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow = [[True, True], [False, True]]
    deep_flat = nest.flatten_up_to(shallow, deep)
    shallow_flat = nest.flatten_up_to(shallow, shallow)
    self.assertEqual(deep_flat, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(shallow_flat, [True, True, False, True])

    # Shallow tree ends at string.
    deep = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    partially_flat = nest.flatten_up_to(shallow, deep)
    fully_flat = nest.flatten(deep)
    self.assertEqual(partially_flat,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Make sure dicts are correctly flattened, yielding values, not keys.
    deep = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    shallow = {"a": 0, "b": 0, "d": [0, 0]}
    self.assertEqual(nest.flatten_up_to(shallow, deep),
                     [1, {"c": 2}, 3, (4, 5)])

    # Namedtuples.
    ab_tuple = NestTest.ABTuple
    deep = ab_tuple(a=[0, 1], b=2)
    shallow = ab_tuple(a=0, b=1)
    self.assertEqual(nest.flatten_up_to(shallow, deep), [[0, 1], 2])

    # Nested dicts, OrderedDicts and namedtuples.
    deep = collections.OrderedDict([
        ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
        ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
    ])
    # Each entry pairs a shallow tree with the flattening it should produce
    # for the same `deep` input; the first shallow tree is `deep` itself.
    nested_cases = (
        (deep, [0, 1, 2, 3, 4]),
        (collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]),
         [ab_tuple(a=[0, {"b": 1}], b=2),
          3,
          collections.OrderedDict([("f", 4)])]),
        (collections.OrderedDict([("a", 0), ("c", 0)]),
         [ab_tuple(a=[0, {"b": 1}], b=2),
          {"d": 3, "e": collections.OrderedDict([("f", 4)])}]),
    )
    for shallow, expected in nested_cases:
        self.assertEqual(nest.flatten_up_to(shallow, deep), expected)

    # (input tree, shallow tree) pairs where the shallow tree is not a list.
    # In every one of them both flattenings are expected to wrap their
    # second argument in a one-element list.
    non_list_shallow_cases = (
        ## Shallow non-list edge-case.
        # Using iterable elements.
        (["input_tree"], "shallow_tree"),
        (["input_tree_0", "input_tree_1"], "shallow_tree"),
        # Using non-iterable elements.
        ([0], 9),
        ([0, 1], 9),
        ## Both non-list edge-case.
        # Using iterable elements.
        ("input_tree", "shallow_tree"),
        # Using non-iterable elements.
        (0, 0),
    )
    for deep, shallow in non_list_shallow_cases:
        deep_flat = nest.flatten_up_to(shallow, deep)
        shallow_flat = nest.flatten_up_to(shallow, shallow)
        self.assertEqual(deep_flat, [deep])
        self.assertEqual(shallow_flat, [shallow])

    ## Input non-list edge-case.
    # The regexes accept both the `<type '...'>` and `<class '...'>` forms.
    str_input_regex = (
        "If shallow structure is a sequence, input must also "
        "be a sequence. Input has type: <(type|class) 'str'>.")
    int_input_regex = (
        "If shallow structure is a sequence, input must also "
        "be a sequence. Input has type: <(type|class) 'int'>.")
    non_list_input_cases = (
        # Using iterable elements.
        ("input_tree", ["shallow_tree"], str_input_regex),
        ("input_tree", ["shallow_tree_9", "shallow_tree_8"], str_input_regex),
        # Using non-iterable elements.
        (0, [9], int_input_regex),
        (0, [9, 8], int_input_regex),
    )
    for deep, shallow, message_regex in non_list_input_cases:
        with self.assertRaisesRegexp(TypeError, message_regex):
            nest.flatten_up_to(shallow, deep)
        # Flattening the shallow tree against itself still succeeds.
        self.assertEqual(nest.flatten_up_to(shallow, shallow), shallow)

    # An input with more elements than the shallow tree is rejected.
    deep = [(1,), (2,), 3]
    shallow = [(1,), (2,)]
    length_mismatch = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
        input_length=len(deep), shallow_length=len(shallow))
    with self.assertRaisesRegexp(ValueError, length_mismatch):  # pylint: disable=g-error-prone-assert-raises
        nest.assert_shallow_structure(shallow, deep)
def testAssertShallowStructure(self):
    """Checks `nest.assert_shallow_structure` failure and success cases.

    Expects a `ValueError` for mismatching lengths and for a shallow-tree
    dict key that is absent from the input, and a `TypeError` for
    mismatching sequence types unless `check_types=False` is passed. Also
    verifies that OrderedDicts with differing key order and two namedtuple
    classes from `NestTest` are accepted.
    """
    # Shallow tree longer than the input tree.
    two_items = ["a", "b"]
    three_items = ["a", "b", "c"]
    length_error = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
        input_length=len(two_items), shallow_length=len(three_items))
    with self.assertRaisesWithLiteralMatch(ValueError, length_error):  # pylint: disable=g-error-prone-assert-raises
        nest.assert_shallow_structure(three_items, two_items)

    # Tuples in the input versus lists in the shallow tree.
    tuple_rows = [(1, 1), (2, 2)]
    list_rows = [[1, 1], [2, 2]]
    type_error = nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
        shallow_type=type(list_rows[0]), input_type=type(tuple_rows[0]))
    with self.assertRaisesWithLiteralMatch(TypeError, type_error):  # pylint: disable=g-error-prone-assert-raises
        nest.assert_shallow_structure(list_rows, tuple_rows)
    # The same pair is accepted once type checking is switched off.
    nest.assert_shallow_structure(list_rows, tuple_rows, check_types=False)

    # Shallow tree uses key "d" where the input tree has "c".
    input_dict = {"a": (1, 1), "b": {"c": (2, 2)}}
    shallow_dict = {"a": (1, 1), "b": {"d": (2, 2)}}
    key_error = nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])
    with self.assertRaisesWithLiteralMatch(ValueError, key_error):  # pylint: disable=g-error-prone-assert-raises
        nest.assert_shallow_structure(shallow_dict, input_dict)

    # OrderedDicts holding the same items in a different order.
    ordered_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    ordered_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    nest.assert_shallow_structure(ordered_ab, ordered_ba)

    # This assertion is expected to pass: two namedtuples with the same
    # name and field names are considered to be identical.
    shallow_tuple = NestTest.SameNameab(1, 2)
    deep_tuple = NestTest.SameNameab2(1, [1, 2, 3])
    for check_types in (False, True):
        nest.assert_shallow_structure(
            shallow_tuple, deep_tuple, check_types=check_types)