def testFlattenAndPack(self):
  """Round-trips nested tuples and namedtuples through flatten/pack.

  Also checks how scalars and numpy arrays are handled, and the errors
  `nest.pack_sequence_as` raises for mismatched inputs.
  """
  # NOTE(review): another method named `testFlattenAndPack` is defined in this
  # file; if both live in the same class, the later definition replaces this
  # one and this body never runs — confirm which one is intended.
  structure = ((3, 4), 5, (6, 7, (9, 10), 8))
  flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
  # Leaves come out depth-first, left to right.
  self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
  self.assertEqual(
      nest.pack_sequence_as(structure, flat),
      (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

  structure = (NestTest.PointXY(x=4, y=2),
               ((NestTest.PointXY(x=1, y=0),),))
  flat = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(structure), flat)
  restructured_from_flat = nest.pack_sequence_as(structure, flat)
  self.assertEqual(restructured_from_flat, structure)
  self.assertEqual(restructured_from_flat[0].x, 4)
  self.assertEqual(restructured_from_flat[0].y, 2)
  self.assertEqual(restructured_from_flat[1][0][0].x, 1)
  self.assertEqual(restructured_from_flat[1][0][0].y, 0)

  # Scalars and numpy arrays are each a single leaf.
  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
    nest.pack_sequence_as("scalar", [4, 5])

  with self.assertRaisesRegex(TypeError, "flat_sequence"):
    nest.pack_sequence_as([4, 5], "bad_sequence")

  # Three flat values cannot fill a structure with four leaves.
  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenDictOrder(self, mapping_type):
  """`flatten` yields mapping values sorted by key, even for OrderedDicts.

  Args:
    mapping_type: a mapping class built from a list of (key, value) pairs;
      presumably supplied by a parameterized-test decorator (not visible here).
  """
  # Insertion order deliberately differs from key order.
  pairs = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
  expected = [0, 1, 2, 3]

  from_pairs = mapping_type(pairs)
  literal = {"d": 3, "b": 1, "a": 0, "c": 2}
  flattened_ordered, flattened_plain = (
      nest.flatten(from_pairs), nest.flatten(literal))

  self.assertEqual(expected, flattened_ordered)
  self.assertEqual(expected, flattened_plain)
def testAttrsFlattenAndPack(self):
  """Round-trips an attrs-decorated instance through flatten/pack.

  Skipped when the optional `attr` module is unavailable (`attr` is None).
  """
  if attr is None:
    self.skipTest("attr module is unavailable.")

  field_values = [1, 2]
  sample_attr = NestTest.SampleAttr(*field_values)
  # A plain list is not an attrs object; the SampleAttr instance is.
  self.assertFalse(nest._is_attrs(field_values))
  self.assertTrue(nest._is_attrs(sample_attr))

  flat = nest.flatten(sample_attr)
  self.assertEqual(field_values, flat)
  restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
  self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
  self.assertEqual(restructured_from_flat, sample_attr)

  # Check that flatten fails if attributes are not iterable. The call is
  # expected to raise, so its result is not bound to a name.
  # (`assertRaisesRegex` replaces the `assertRaisesRegexp` alias removed in
  # Python 3.12.)
  with self.assertRaisesRegex(TypeError, "object is not iterable"):
    nest.flatten(NestTest.BadAttr())
def testFlattenAndPack(self):
  """Flattens and re-packs tuples, namedtuples, scalars and numpy arrays.

  Also checks the errors `nest.pack_sequence_as` raises for mismatches. Each
  error regex accepts both the current message and an older one.
  """
  # Plain nested tuples.
  nested = ((3, 4), 5, (6, 7, (9, 10), 8))
  letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(nest.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
  packed = nest.pack_sequence_as(nested, letters)
  self.assertEqual(packed, (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

  # Namedtuples nested inside tuples.
  point_xy = NestTest.PointXY
  nested = (point_xy(x=4, y=2), ((point_xy(x=1, y=0),),))
  leaves = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(nested), leaves)
  rebuilt = nest.pack_sequence_as(nested, leaves)
  self.assertEqual(rebuilt, nested)
  outer_point = rebuilt[0]
  self.assertEqual(outer_point.x, 4)
  self.assertEqual(outer_point.y, 2)
  inner_point = rebuilt[1][0][0]
  self.assertEqual(inner_point.x, 1)
  self.assertEqual(inner_point.y, 0)

  # Non-sequences are single leaves.
  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

  # NOTE(taylorrobie): The second pattern is for version compatibility.
  scalar_pattern = (
      "(nest cannot guarantee that it is safe to map one to the other.)|"
      "(Structure is a scalar)")
  with self.assertRaisesRegex(ValueError, scalar_pattern):
    nest.pack_sequence_as("scalar", [4, 5])

  # NOTE(taylorrobie): The second pattern is for version compatibility.
  bad_flat_pattern = (
      "(Attempted to pack value:\n bad_sequence\ninto a sequence, but found "
      "incompatible type `<(type|class) 'str'>` instead.)|"
      "(flat_sequence must be a sequence)")
  with self.assertRaisesRegex(TypeError, bad_flat_pattern):
    nest.pack_sequence_as([4, 5], "bad_sequence")

  # Three flat values cannot fill a structure with four leaves.
  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenAndPack_withDicts(self):
  """Round-trips a mix of lists, namedtuples, dicts and custom mappings.

  Mapping values are flattened in key order. `pack_sequence_as` must rebuild
  each container with its original type and its own key order.
  """
  # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  mess = [
      "z",
      NestTest.Abc(3, 4),
      {
          "d": _CustomMapping({41: 4}),
          "c": [
              1,
              collections.OrderedDict([
                  ("b", 3),
                  ("a", 2),
              ]),
          ],
          "b": 5,
      },
      17,
  ]
  leaves = nest.flatten(mess)
  # Keys are visited in sorted order: "b", "c", "d" in the dict, and "a"
  # before "b" inside the OrderedDict.
  self.assertEqual(leaves, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

  template = [
      14,
      NestTest.Abc("a", True),
      {
          "d": _CustomMapping({41: 42}),
          "c": [
              0,
              collections.OrderedDict([
                  ("b", 9),
                  ("a", 8),
              ]),
          ],
          "b": 3,
      },
      "hi everybody",
  ]
  rebuilt = nest.pack_sequence_as(template, leaves)
  self.assertEqual(rebuilt, mess)

  # Check also that the OrderedDict was created, with the correct key order.
  rebuilt_ordered = rebuilt[2]["c"][1]
  self.assertIsInstance(rebuilt_ordered, collections.OrderedDict)
  self.assertEqual(list(rebuilt_ordered.keys()), ["b", "a"])

  rebuilt_custom = rebuilt[2]["d"]
  self.assertIsInstance(rebuilt_custom, _CustomMapping)
  self.assertEqual(list(rebuilt_custom.keys()), [41])
def testFlattenWithTuplePathsUpTo(self):
  """Tests `nest.flatten_with_tuple_paths_up_to` on nested structures.

  Covers:
    - shallow trees ending at scalars and at strings;
    - dicts, namedtuples, and nested OrderedDicts;
    - the non-sequence shallow/input edge cases;
    - the `check_subtrees_length` switch for mismatched dict sizes.
  """

  def get_paths_and_values(shallow_tree, input_tree,
                           check_subtrees_length=True):
    # Split the (path, value) pairs into two parallel lists.
    path_value_pairs = nest.flatten_with_tuple_paths_up_to(
        shallow_tree, input_tree, check_subtrees_length=check_subtrees_length)
    paths = [p for p, _ in path_value_pairs]
    values = [v for _, v in path_value_pairs]
    return paths, values

  def assert_whole_tree_is_leaf(shallow_tree, input_tree):
    # A non-sequence shallow tree maps the whole input tree, and itself, to a
    # single leaf at the empty path.
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  def assert_input_must_be_sequence(shallow_tree, input_tree,
                                    expected_shallow_paths):
    # A sequence shallow tree rejects a non-sequence input but still
    # flattens against itself. The failing call's result is not bound.
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
      get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, expected_shallow_paths)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  (flattened_input_tree_paths,
   flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  (flattened_shallow_tree_paths,
   flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree_paths,
                   [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree_paths,
                   [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  input_tree_flattened_paths = [
      p for p, _ in nest.flatten_with_tuple_paths(input_tree)
  ]
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened_paths,
                   [(0, 0, 0), (0, 0, 1),
                    (0, 1, 0, 0), (0, 1, 0, 1),
                    (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
                    (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("b",), ("d", 0), ("d", 1)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("b",)])
  self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict([
      ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
      ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
  ])
  shallow_tree = input_tree
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a", "a", 0),
                    ("a", "a", 1, "b"),
                    ("a", "b"),
                    ("c", "d"),
                    ("c", "e", "f")])
  self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("c", "d"), ("c", "e")])
  self.assertEqual(input_tree_flattened_as_shallow_tree, [
      ab_tuple(a=[0, {"b": 1}], b=2),
      3,
      collections.OrderedDict([("f", 4)]),
  ])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(
       shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("c",)])
  self.assertEqual(input_tree_flattened_as_shallow_tree, [
      ab_tuple(a=[0, {"b": 1}], b=2),
      {"d": 3, "e": collections.OrderedDict([("f", 4)])},
  ])

  ## Shallow non-list edge-case.
  # Using iterable elements.
  assert_whole_tree_is_leaf("shallow_tree", ["input_tree"])
  assert_whole_tree_is_leaf("shallow_tree", ["input_tree_0", "input_tree_1"])

  # Test case where len(shallow_tree) < len(input_tree): rejected by default,
  # but allowed when check_subtrees_length=False (extra keys are dropped).
  input_tree = {"a": "A", "b": "B", "c": "C"}
  shallow_tree = {"a": 1, "c": 2}
  with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
      ValueError,
      nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
          input_length=len(input_tree),
          shallow_length=len(shallow_tree))):
    get_paths_and_values(shallow_tree, input_tree)

  (flattened_input_tree_paths,
   flattened_input_tree) = get_paths_and_values(
       shallow_tree, input_tree, check_subtrees_length=False)
  (flattened_shallow_tree_paths,
   flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)])
  self.assertEqual(flattened_input_tree, ["A", "C"])
  self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)])
  self.assertEqual(flattened_shallow_tree, [1, 2])

  # Using non-iterable elements.
  assert_whole_tree_is_leaf(9, [0])
  assert_whole_tree_is_leaf(9, [0, 1])

  ## Both non-list edge-case.
  # Using iterable elements.
  assert_whole_tree_is_leaf("shallow_tree", "input_tree")
  # Using non-iterable elements.
  assert_whole_tree_is_leaf(0, 0)

  ## Input non-list edge-case.
  # Using iterable elements.
  assert_input_must_be_sequence(["shallow_tree"], "input_tree", [(0,)])
  assert_input_must_be_sequence(["shallow_tree_9", "shallow_tree_8"],
                                "input_tree", [(0,), (1,)])
  # Using non-iterable elements.
  assert_input_must_be_sequence([9], 0, [(0,)])
  assert_input_must_be_sequence([9, 8], 0, [(0,), (1,)])
def testFlattenUpTo(self):
  """Tests `nest.flatten_up_to` on nested structures and edge cases.

  Covers:
    - shallow trees ending at scalars and at strings;
    - dicts, namedtuples, and nested OrderedDicts;
    - non-sequence shallow and/or input trees;
    - the length-mismatch error from `nest.assert_shallow_structure`.
  """

  def assert_whole_tree_is_leaf(shallow_tree, input_tree):
    # A non-sequence shallow tree maps the whole input tree, and itself, to a
    # single leaf.
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  def assert_input_must_be_sequence(shallow_tree, input_tree,
                                    expected_message):
    # A sequence shallow tree rejects a non-sequence input but still
    # flattens against itself. The failing call's result is not bound.
    # (`assertRaisesRegex` replaces the `assertRaisesRegexp` alias removed in
    # Python 3.12.)
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = NestTest.ABTuple
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict([
      ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
      ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
  ])
  shallow_tree = input_tree
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [
      ab_tuple(a=[0, {"b": 1}], b=2),
      3,
      collections.OrderedDict([("f", 4)]),
  ])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
      shallow_tree, input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [
      ab_tuple(a=[0, {"b": 1}], b=2),
      {"d": 3, "e": collections.OrderedDict([("f", 4)])},
  ])

  ## Shallow non-list edge-case.
  # Using iterable elements.
  assert_whole_tree_is_leaf("shallow_tree", ["input_tree"])
  assert_whole_tree_is_leaf("shallow_tree", ["input_tree_0", "input_tree_1"])
  # Using non-iterable elements.
  assert_whole_tree_is_leaf(9, [0])
  assert_whole_tree_is_leaf(9, [0, 1])

  ## Both non-list edge-case.
  # Using iterable elements.
  assert_whole_tree_is_leaf("shallow_tree", "input_tree")
  # Using non-iterable elements.
  assert_whole_tree_is_leaf(0, 0)

  ## Input non-list edge-case.
  # Using iterable elements.
  expected_message = (
      "If shallow structure is a sequence, input must also "
      "be a sequence. Input has type: <(type|class) 'str'>.")
  assert_input_must_be_sequence(["shallow_tree"], "input_tree",
                                expected_message)
  assert_input_must_be_sequence(["shallow_tree_9", "shallow_tree_8"],
                                "input_tree", expected_message)

  # Using non-iterable elements.
  expected_message = (
      "If shallow structure is a sequence, input must also "
      "be a sequence. Input has type: <(type|class) 'int'>.")
  assert_input_must_be_sequence([9], 0, expected_message)
  assert_input_must_be_sequence([9, 8], 0, expected_message)

  # Input sequence longer than the shallow sequence.
  input_tree = [(1,), (2,), 3]
  shallow_tree = [(1,), (2,)]
  expected_message = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
      input_length=len(input_tree), shallow_length=len(shallow_tree))
  with self.assertRaisesRegex(ValueError, expected_message):  # pylint: disable=g-error-prone-assert-raises
    nest.assert_shallow_structure(shallow_tree, input_tree)
def testFlatten_stringIsNotFlattened(self):
  """A string is one leaf; it is not split into its characters."""
  text = "lots of letters"
  leaves = nest.flatten(text)
  self.assertLen(leaves, 1)
  # Packing into any string-shaped (single-leaf) structure restores it.
  self.assertEqual(text, nest.pack_sequence_as("goodbye", leaves))
def testFlatten_numpyIsNotFlattened(self):
  """A numpy array is one leaf; its elements are not flattened."""
  leaves = nest.flatten(np.array([1, 2, 3]))
  self.assertLen(leaves, 1)