예제 #1
0
    def testGetTraverseShallowStructure(self):
        """Tests `nest.get_traverse_shallow_structure` and its error paths.

        The traverse function is applied to substructures of the input; the
        returned shallow structure is checked against an expected value and
        validated with `nest.assert_shallow_structure`. Three malformed
        traverse functions are checked to raise `TypeError`.
        """
        # A predicate returning a plain bool: tuples yield `False` and are not
        # expanded, while lists and dicts are traversed into.
        scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7, )}, []]
        scalar_traverse_r = nest.get_traverse_shallow_structure(
            lambda s: not isinstance(s, tuple), scalar_traverse_input)
        self.assertEqual(scalar_traverse_r,
                         [True, True, False, [True, True], {"a": False}, []])
        nest.assert_shallow_structure(scalar_traverse_r, scalar_traverse_input)

        # A predicate returning a depth-1 structure of bools decides, per
        # child of a tuple, whether that child is traversed further.
        structure_traverse_input = [(1, [2]), ([1], 2)]
        fn = lambda s: (True, False) if isinstance(s, tuple) else True
        structure_traverse_r = nest.get_traverse_shallow_structure(
            fn, structure_traverse_input)
        self.assertEqual(structure_traverse_r, [(True, False),
                                                ([True], False)])
        nest.assert_shallow_structure(structure_traverse_r,
                                      structure_traverse_input)

        # Returning a structure for a scalar (non-sequence) input is invalid.
        with self.assertRaisesRegex(TypeError, "returned structure"):
            nest.get_traverse_shallow_structure(lambda _: [True], 0)

        # Returning a scalar that is not a bool is invalid.
        with self.assertRaisesRegex(TypeError, "returned a non-bool scalar"):
            nest.get_traverse_shallow_structure(lambda _: 1, [1])

        # Returning a structure whose leaves are not bools is invalid.
        with self.assertRaisesRegex(
                TypeError, "didn't return a depth=1 structure of bools"):
            nest.get_traverse_shallow_structure(lambda _: [1], [1])
예제 #2
0
    def testFlattenUpTo(self):
        """Tests `nest.flatten_up_to` on a range of shallow/input tree pairs.

        Covers shallow trees that end at scalars, strings, dicts, namedtuples
        and nested OrderedDicts; the edge cases where the shallow tree, the
        input, or both are not sequences; and the errors raised when the
        input does not match the shallow structure.
        """
        # Shallow tree ends at scalar.
        input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        shallow_tree = [[True, True], [False, True]]
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree,
                         [[2, 2], [3, 3], [4, 9], [5, 5]])
        self.assertEqual(flattened_shallow_tree, [True, True, False, True])

        # Shallow tree ends at string: the input's tuples at those positions
        # are kept whole, unlike a full `nest.flatten` of the same input.
        input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        input_tree_flattened = nest.flatten(input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        self.assertEqual(input_tree_flattened,
                         ["a", 1, "b", 2, "c", 3, "d", 4])

        # Make sure dicts are correctly flattened, yielding values, not keys.
        input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree,
                         [1, {"c": 2}, 3, (4, 5)])

        # Namedtuples.
        ab_tuple = NestTest.ABTuple
        input_tree = ab_tuple(a=[0, 1], b=2)
        shallow_tree = ab_tuple(a=0, b=1)
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2])

        # Nested dicts, OrderedDicts and namedtuples.
        input_tree = collections.OrderedDict([
            ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
            ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
        ])
        # Using the input as its own shallow tree flattens it completely.
        shallow_tree = input_tree
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
        # Shallow tree stops at "a" and at the children of "c".
        shallow_tree = collections.OrderedDict([("a", 0),
                                                ("c", {"d": 3, "e": 1})])
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {"b": 1}], b=2),
            3,
            collections.OrderedDict([("f", 4)]),
        ])
        # Shallow tree stops at the top-level keys.
        shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
        input_tree_flattened_as_shallow_tree = nest.flatten_up_to(
            shallow_tree, input_tree)
        self.assertEqual(input_tree_flattened_as_shallow_tree, [
            ab_tuple(a=[0, {"b": 1}], b=2),
            {"d": 3, "e": collections.OrderedDict([("f", 4)])},
        ])

        ## Shallow non-list edge-case: the whole input becomes a single leaf.
        # Using iterable elements.
        input_tree = ["input_tree"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = ["input_tree_0", "input_tree_1"]
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = [0]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        input_tree = [0, 1]
        shallow_tree = 9
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Both non-list edge-case.
        # Using iterable elements.
        input_tree = "input_tree"
        shallow_tree = "shallow_tree"
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = 0
        flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_input_tree, [input_tree])
        self.assertEqual(flattened_shallow_tree, [shallow_tree])

        ## Input non-list edge-case: a sequence shallow tree requires a
        ## sequence input, otherwise a TypeError is raised.
        # Using iterable elements. The "(type|class)" alternation accepts
        # both spellings of a type's repr.
        input_tree = "input_tree"
        shallow_tree = ["shallow_tree"]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'str'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = "input_tree"
        shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # Using non-iterable elements.
        input_tree = 0
        shallow_tree = [9]
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: <(type|class) 'int'>.")
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        input_tree = 0
        shallow_tree = [9, 8]
        with self.assertRaisesRegex(TypeError, expected_message):
            nest.flatten_up_to(shallow_tree, input_tree)
        flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        self.assertEqual(flattened_shallow_tree, shallow_tree)

        # An input longer than the shallow tree is a length mismatch. The
        # expected text is a formatted literal, not a regex, so match it
        # literally.
        input_tree = [(1, ), (2, ), 3]
        shallow_tree = [(1, ), (2, )]
        expected_message = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
            input_length=len(input_tree), shallow_length=len(shallow_tree))
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_shallow_structure(shallow_tree, input_tree)
예제 #3
0
    def testAssertShallowStructure(self):
        """Exercises the checks performed by `nest.assert_shallow_structure`.

        Covers a length mismatch, a sequence-type mismatch (and its bypass via
        `check_types=False`), a key present only in the shallow dict, key-order
        independence for OrderedDicts, and same-named namedtuple classes.
        """
        # The shallow structure (first argument) is longer than the input.
        two_items = ["a", "b"]
        three_items = ["a", "b", "c"]
        length_error = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
            input_length=len(two_items), shallow_length=len(three_items))
        with self.assertRaisesWithLiteralMatch(ValueError, length_error):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_shallow_structure(three_items, two_items)

        # Lists in the shallow tree vs tuples in the input: a TypeError when
        # types are checked, accepted when they are not.
        tuple_pairs = [(1, 1), (2, 2)]
        list_pairs = [[1, 1], [2, 2]]
        type_error = nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            shallow_type=type(list_pairs[0]), input_type=type(tuple_pairs[0]))
        with self.assertRaisesWithLiteralMatch(TypeError, type_error):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_shallow_structure(list_pairs, tuple_pairs)
        nest.assert_shallow_structure(list_pairs, tuple_pairs,
                                      check_types=False)

        # The nested shallow dict uses key "d", which the input lacks.
        deep_dict = {"a": (1, 1), "b": {"c": (2, 2)}}
        shallow_dict = {"a": (1, 1), "b": {"d": (2, 2)}}
        keys_error = nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])
        with self.assertRaisesWithLiteralMatch(ValueError, keys_error):  # pylint: disable=g-error-prone-assert-raises
            nest.assert_shallow_structure(shallow_dict, deep_dict)

        # OrderedDicts holding the same keys in a different order match.
        nest.assert_shallow_structure(
            collections.OrderedDict([("a", 1), ("b", (2, 3))]),
            collections.OrderedDict([("b", (2, 3)), ("a", 1)]))

        # Two namedtuples with the same name and field names are considered
        # identical, so the check passes with and without type checking.
        shallow_nt = NestTest.SameNameab(1, 2)
        deep_nt = NestTest.SameNameab2(1, [1, 2, 3])
        for check_types in (False, True):
            nest.assert_shallow_structure(shallow_nt, deep_nt,
                                          check_types=check_types)