Пример #1
0
    def testFlattenAndPack(self):
        """Round-trips nested tuples/namedtuples through flatten and pack.

        Also checks that scalars and numpy arrays flatten to single leaves, and
        that `pack_sequence_as` raises for a scalar structure given several
        values, for a non-sequence flat argument, and for a length mismatch.
        """
        structure = ((3, 4), 5, (6, 7, (9, 10), 8))
        flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
        self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
        self.assertEqual(nest.pack_sequence_as(structure, flat),
                         (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))
        structure = (NestTest.PointXY(x=4, y=2),
                     ((NestTest.PointXY(x=1, y=0), ), ))
        flat = [4, 2, 1, 0]
        self.assertEqual(nest.flatten(structure), flat)
        restructured_from_flat = nest.pack_sequence_as(structure, flat)
        self.assertEqual(restructured_from_flat, structure)
        # Field access proves the namedtuples were rebuilt, not plain tuples.
        self.assertEqual(restructured_from_flat[0].x, 4)
        self.assertEqual(restructured_from_flat[0].y, 2)
        self.assertEqual(restructured_from_flat[1][0][0].x, 1)
        self.assertEqual(restructured_from_flat[1][0][0].y, 0)

        self.assertEqual([5], nest.flatten(5))
        self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

        self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
        self.assertEqual(np.array([5]),
                         nest.pack_sequence_as("scalar", [np.array([5])]))

        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        # Each pattern accepts the newer error text as well as the older one,
        # so the test passes across library versions.
        with self.assertRaisesRegex(
                ValueError,
                "(nest cannot guarantee that it is safe to map one to the "
                "other.)|(Structure is a scalar)"):
            nest.pack_sequence_as("scalar", [4, 5])

        with self.assertRaisesRegex(
                TypeError,
                "(Attempted to pack value:\n  bad_sequence\ninto a sequence, "
                "but found incompatible type `<(type|class) 'str'>` instead.)|"
                "(flat_sequence)"):
            nest.pack_sequence_as([4, 5], "bad_sequence")

        with self.assertRaises(ValueError):
            nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Пример #2
0
 def testFlattenDictOrder(self, mapping_type):
     """Checks that `flatten` emits mapping values in sorted-key order.

     The same scrambled key/value pairs are loaded into the parameterized
     `mapping_type` (presumably OrderedDict-like, so insertion order is kept)
     and into a plain dict. Both must flatten to values ordered by key.
     """
     scrambled_items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
     sorted_values = [0, 1, 2, 3]
     custom_flat = nest.flatten(mapping_type(scrambled_items))
     builtin_flat = nest.flatten(dict(scrambled_items))
     self.assertEqual(sorted_values, custom_flat)
     self.assertEqual(sorted_values, builtin_flat)
Пример #3
0
    def testAttrsFlattenAndPack(self):
        """Flattens an attrs instance to its field values and packs it back.

        Skipped when `attr` is None (the attrs package could not be imported).
        Also checks that flattening an attrs instance whose attributes are not
        iterable raises a TypeError.
        """
        if attr is None:
            self.skipTest("attr module is unavailable.")

        field_values = [1, 2]
        sample_attr = NestTest.SampleAttr(*field_values)
        # `_is_attrs` must accept attrs instances and reject plain lists.
        self.assertFalse(nest._is_attrs(field_values))
        self.assertTrue(nest._is_attrs(sample_attr))
        flat = nest.flatten(sample_attr)
        self.assertEqual(field_values, flat)
        restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
        # Packing must rebuild the attrs class, not a generic sequence.
        self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
        self.assertEqual(restructured_from_flat, sample_attr)

        # Check that flatten fails if attributes are not iterable.
        # (`assertRaisesRegexp` is a deprecated alias removed in Python 3.12.)
        with self.assertRaisesRegex(TypeError, "object is not iterable"):
            nest.flatten(NestTest.BadAttr())
Пример #4
0
  def testFlattenAndPack(self):
    """Round-trips tuples and namedtuples through `flatten`/`pack_sequence_as`.

    Also checks scalar and numpy-array leaves, and the errors raised for a
    scalar structure given several values, a non-sequence flat argument, and
    a length mismatch between structure and values.
    """
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(nested, letters),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples flatten field by field and can be read back by field name.
    points = (NestTest.PointXY(x=4, y=2), ((NestTest.PointXY(x=1, y=0),),))
    coords = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(points), coords)
    repacked = nest.pack_sequence_as(points, coords)
    self.assertEqual(repacked, points)
    outer_point = repacked[0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    inner_point = repacked[1][0][0]
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Scalars and numpy arrays are single leaves.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # In each pattern the second alternative matches an older error text,
    # kept for version compatibility.
    scalar_pattern = (
        "(nest cannot guarantee that it is safe to map one to the other.)|"
        "(Structure is a scalar)")
    with self.assertRaisesRegex(ValueError, scalar_pattern):
      nest.pack_sequence_as("scalar", [4, 5])

    non_sequence_pattern = (
        "(Attempted to pack value:\n  bad_sequence\ninto a sequence, but found "
        "incompatible type `<(type|class) 'str'>` instead.)|"
        "(flat_sequence must be a sequence)")
    with self.assertRaisesRegex(TypeError, non_sequence_pattern):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Пример #5
0
  def testFlattenAndPack_withDicts(self):
    """Flattens a mixed tuple/list/dict/mapping tree and packs it back.

    Mapping nodes are flattened in sorted-key order. Packing into a template
    must rebuild each mapping with the template's type and key order.
    """
    # Tuples, lists, a dict, an OrderedDict and a custom mapping, with keys
    # deliberately inserted out of sorted order.
    mess = [
        "z",
        NestTest.Abc(3, 4),
        {
            "d": _CustomMapping({41: 4}),
            "c": [1, collections.OrderedDict([("b", 3), ("a", 2)])],
            "b": 5,
        },
        17,
    ]
    flattened = nest.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

    # Same shape as `mess` but with different leaves; only its structure is
    # used when packing.
    template = [
        14,
        NestTest.Abc("a", True),
        {
            "d": _CustomMapping({41: 42}),
            "c": [0, collections.OrderedDict([("b", 9), ("a", 8)])],
            "b": 3,
        },
        "hi everybody",
    ]
    rebuilt = nest.pack_sequence_as(template, flattened)
    self.assertEqual(rebuilt, mess)

    rebuilt_mapping = rebuilt[2]

    # The OrderedDict must come back as an OrderedDict with keys b, a.
    rebuilt_ordered = rebuilt_mapping["c"][1]
    self.assertIsInstance(rebuilt_ordered, collections.OrderedDict)
    self.assertEqual(list(rebuilt_ordered.keys()), ["b", "a"])

    # The custom mapping type must also be preserved.
    rebuilt_custom = rebuilt_mapping["d"]
    self.assertIsInstance(rebuilt_custom, _CustomMapping)
    self.assertEqual(list(rebuilt_custom.keys()), [41])
Пример #6
0
    def testFlattenWithTuplePathsUpTo(self):
        """Checks the paths and values from `flatten_with_tuple_paths_up_to`.

        Covers shallow trees that end at scalars, strings, dicts, namedtuples
        and nested mappings, a shallow dict with fewer keys than the input, and
        the edge cases where the shallow tree, the input tree, or both are
        leaves.
        """

        def paths_and_values(shallow, deep, check_subtrees_length=True):
            # Split the (path, value) pairs into parallel path/value lists.
            pairs = nest.flatten_with_tuple_paths_up_to(
                shallow, deep, check_subtrees_length=check_subtrees_length)
            return [path for path, _ in pairs], [value for _, value in pairs]

        def check_leaf_shallow(shallow, deep):
            # A leaf shallow tree yields the whole input at the empty path.
            deep_paths, deep_values = paths_and_values(shallow, deep)
            shallow_paths, shallow_values = paths_and_values(shallow, shallow)
            self.assertEqual(deep_paths, [()])
            self.assertEqual(deep_values, [deep])
            self.assertEqual(shallow_paths, [()])
            self.assertEqual(shallow_values, [shallow])

        def check_leaf_input_rejected(shallow, deep):
            # A sequence shallow tree paired with a leaf input is a TypeError;
            # the shallow tree itself still flattens element by element.
            with self.assertRaisesWithLiteralMatch(
                    TypeError,
                    nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(deep))):
                paths_and_values(shallow, deep)
            shallow_paths, shallow_values = paths_and_values(shallow, shallow)
            self.assertEqual(shallow_paths,
                             [(index, ) for index in range(len(shallow))])
            self.assertEqual(shallow_values, shallow)

        # Shallow tree stops at scalars: each shallow leaf maps to a whole
        # sub-list of the input.
        deep = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow = [[True, True], [False, True]]
        deep_paths, deep_values = paths_and_values(shallow, deep)
        shallow_paths, shallow_values = paths_and_values(shallow, shallow)
        grid_paths = [(0, 0), (0, 1), (1, 0), (1, 1)]
        self.assertEqual(deep_paths, grid_paths)
        self.assertEqual(deep_values, [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(shallow_paths, grid_paths)
        self.assertEqual(shallow_values, [True, True, False, True])

        # Shallow tree stops at strings; compare against a full flatten.
        deep = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        deep_paths, deep_values = paths_and_values(shallow, deep)
        full_paths = [path for path, _ in nest.flatten_with_tuple_paths(deep)]
        full_values = nest.flatten(deep)
        self.assertEqual(deep_paths,
                         [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
        self.assertEqual(deep_values,
                         [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        self.assertEqual(full_paths, [
            (0, 0, 0), (0, 0, 1), (0, 1, 0, 0), (0, 1, 0, 1),
            (0, 1, 1, 0, 0), (0, 1, 1, 0, 1), (0, 1, 1, 1, 0, 0),
            (0, 1, 1, 1, 0, 1)
        ])
        self.assertEqual(full_values, ["a", 1, "b", 2, "c", 3, "d", 4])

        # Dicts must yield their values (not keys), addressed by key paths.
        deep = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        shallow = {"a": 0, "b": 0, "d": [0, 0]}
        deep_paths, deep_values = paths_and_values(shallow, deep)
        self.assertEqual(deep_paths, [("a", ), ("b", ), ("d", 0), ("d", 1)])
        self.assertEqual(deep_values, [1, {"c": 2}, 3, (4, 5)])

        # Namedtuples are addressed by field name.
        ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        deep = ab_tuple(a=[0, 1], b=2)
        shallow = ab_tuple(a=0, b=1)
        deep_paths, deep_values = paths_and_values(shallow, deep)
        self.assertEqual(deep_paths, [("a", ), ("b", )])
        self.assertEqual(deep_values, [[0, 1], 2])

        # Nested dicts, OrderedDicts and namedtuples, cut at three depths.
        deep = collections.OrderedDict([
            ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
            ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
        ])
        deep_paths, deep_values = paths_and_values(deep, deep)
        self.assertEqual(deep_paths,
                         [("a", "a", 0), ("a", "a", 1, "b"), ("a", "b"),
                          ("c", "d"), ("c", "e", "f")])
        self.assertEqual(deep_values, [0, 1, 2, 3, 4])

        shallow = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
        deep_paths, deep_values = paths_and_values(shallow, deep)
        self.assertEqual(deep_paths, [("a", ), ("c", "d"), ("c", "e")])
        self.assertEqual(deep_values, [
            ab_tuple(a=[0, {"b": 1}], b=2), 3,
            collections.OrderedDict([("f", 4)])
        ])

        shallow = collections.OrderedDict([("a", 0), ("c", 0)])
        deep_paths, deep_values = paths_and_values(shallow, deep)
        self.assertEqual(deep_paths, [("a", ), ("c", )])
        self.assertEqual(deep_values, [
            ab_tuple(a=[0, {"b": 1}], b=2),
            {"d": 3, "e": collections.OrderedDict([("f", 4)])}
        ])

        ## Shallow non-list edge-case, using iterable elements.
        check_leaf_shallow("shallow_tree", ["input_tree"])
        check_leaf_shallow("shallow_tree", ["input_tree_0", "input_tree_1"])

        # Shallow dict with fewer keys than the input: rejected by default,
        # accepted when subtree lengths are not checked.
        deep = {"a": "A", "b": "B", "c": "C"}
        shallow = {"a": 1, "c": 2}
        with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
                ValueError,
                nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                    input_length=len(deep), shallow_length=len(shallow))):
            paths_and_values(shallow, deep)

        deep_paths, deep_values = paths_and_values(
            shallow, deep, check_subtrees_length=False)
        shallow_paths, shallow_values = paths_and_values(shallow, shallow)
        self.assertEqual(deep_paths, [("a", ), ("c", )])
        self.assertEqual(deep_values, ["A", "C"])
        self.assertEqual(shallow_paths, [("a", ), ("c", )])
        self.assertEqual(shallow_values, [1, 2])

        # Shallow non-list edge-case, using non-iterable elements.
        check_leaf_shallow(9, [0])
        check_leaf_shallow(9, [0, 1])

        ## Both non-list edge-case: iterable, then non-iterable elements.
        check_leaf_shallow("shallow_tree", "input_tree")
        check_leaf_shallow(0, 0)

        ## Input non-list edge-case: iterable, then non-iterable elements.
        check_leaf_input_rejected(["shallow_tree"], "input_tree")
        check_leaf_input_rejected(["shallow_tree_9", "shallow_tree_8"],
                                  "input_tree")
        check_leaf_input_rejected([9], 0)
        check_leaf_input_rejected([9, 8], 0)
Пример #7
0
    def testFlattenUpTo(self):
        """Checks `flatten_up_to` across shallow-tree shapes and edge cases.

        The input tree is flattened only as deep as the shallow tree reaches:
        each shallow leaf yields the corresponding (possibly nested) input
        subtree. Also checks the TypeError raised when the shallow tree is a
        sequence but the input is not, and the ValueError raised by
        `assert_shallow_structure` when sequence lengths differ.
        """
        # Local import: only needed to escape a literal message for regex use.
        import re

        # Shallow tree ends at scalar.
        input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow_tree = [[True, True], [False, True]]
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree,
                         [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        # Shallow tree ends at string.
        input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        input_tree_flattened = nest.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        # Make sure dicts are correctly flattened, yielding values, not keys.
        input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [1, {"c": 2}, 3, (4, 5)])

        # Namedtuples.
        ab_tuple = NestTest.ABTuple
        input_tree = ab_tuple(a=[0, 1], b=2)
        shallow_tree = ab_tuple(a=0, b=1)
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

        # Nested dicts, OrderedDicts and namedtuples.
        input_tree = collections.OrderedDict([
            ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
            ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
        ])
        shallow_tree = input_tree
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
        shallow_tree = collections.OrderedDict([("a", 0),
                                                ("c", {"d": 3, "e": 1})])
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {"b": 1}], b=2), 3,
            collections.OrderedDict([("f", 4)])
        ])
        shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {"b": 1}], b=2),
            {"d": 3, "e": collections.OrderedDict([("f", 4)])}
        ])

        ## Shallow non-list edge-case.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ["input_tree_0", "input_tree_1"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = [0]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = [0, 1]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case.
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12, so
        # `assertRaisesRegex` is used. The messages below are regexes:
        # `(type|class)` accepts both the Python 2 and Python 3 type reprs.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = ["shallow_tree"]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'str'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = "input_tree"
        shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = [9]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'int'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = 0
        shallow_tree = [9, 8]
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Mismatched sequence lengths.
        input_tree = [(1, ), (2, ), 3]
        shallow_tree = [(1, ), (2, )]
        expected_message = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
            input_length=len(input_tree), shallow_length=len(shallow_tree))
        # The formatted message is literal text, not a pattern, so escape any
        # regex metacharacters it may contain.
        with self.assertRaisesRegex(ValueError, re.escape(expected_message)):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_shallow_structure(shallow_tree, input_tree)
Пример #8
0
 def testFlatten_stringIsNotFlattened(self):
     """Checks that a string is one leaf and packs back to the same string."""
     text = "lots of letters"
     leaves = nest.flatten(text)
     # A string must not be split into characters.
     self.assertLen(leaves, 1)
     self.assertEqual(text, nest.pack_sequence_as("goodbye", leaves))
Пример #9
0
 def testFlatten_numpyIsNotFlattened(self):
     """Checks that `flatten` treats a numpy array as a single leaf."""
     structure = np.array([1, 2, 3])
     flattened = nest.flatten(structure)
     # The array must not be split into its three elements.
     self.assertLen(flattened, 1)