예제 #1
0
    def test_union_2x_1d(self):
        """Check the two-way union of a pair of 1-D fibers.

        The union is computed both with the ``|`` operator and with
        ``Fiber.union``; each result is checked against the same expected
        fiber, payload wrapping, default value and shape.
        """

        # Each payload is a 3-tuple; the leading string appears to record
        # which operand(s) hold the coordinate (NOTE(review): inferred from
        # the data, e.g. 'B' entries carry 0 in the first-operand slot).
        expected = Fiber([0, 2, 3, 4, 6],
                         [('AB', Payload(1), Payload(2)),
                          ('AB', Payload(3), Payload(4)),
                          ('B', Payload(0), Payload(5)),
                          ('A', Payload(5), Payload(0)),
                          ('A', Payload(7), Payload(0))])

        left = self.input["a1_m"]
        right = self.input["b1_m"]

        # Operator form first, then the explicit method form.
        results = [left | right, Fiber.union(left, right)]

        # (element position, index within the union tuple) pairs whose
        # value must still be wrapped in a Payload.
        payload_slots = ((0, 1), (2, 1), (3, 2))

        for idx, result in enumerate(results):
            with self.subTest(test=idx):
                # The union must equal the expected fiber exactly
                self.assertEqual(result, expected)

                # Selected tuple members keep their Payload wrapping
                for pos, slot in payload_slots:
                    self.assertIsInstance(result[pos].payload.value[slot],
                                          Payload)

                # The default is a Payload wrapping ('', 0, 0)
                default = result.getDefault()
                self.assertEqual(default, Payload(('', 0, 0)))
                self.assertIsInstance(default, Payload)

                # The resulting shape is [7]
                self.assertEqual(result.getShape(), [7])
예제 #2
0
    def test_union_2x_1d2d(self):
        """Check the two-way union of a 1-D fiber with a 2-D fiber.

        Both the ``|`` operator and ``Fiber.union`` are exercised. Each
        result is compared with the expected fiber, and its payload element
        types, default value and (top-level, 1-D) shape are verified.
        """

        # Second tuple slot holds the 1-D operand's value, third slot holds
        # the 2-D operand's sub-fiber (an empty Fiber where it is absent).
        expected = Fiber([0, 2, 4, 6],
                         [('AB', 1, Fiber([0, 2, 3], [2, 4, 5])),
                          ('AB', 3, Fiber([0, 1, 2], [3, 4, 6])),
                          ('AB', 5, Fiber([0, 1, 2, 3], [1, 2, 3, 4])),
                          ('A', 7, Fiber([], []))])

        flat = self.input["a1_m"]
        nested = self.input["b2_m"]

        # Operator form first, then the explicit method form.
        results = [flat | nested, Fiber.union(flat, nested)]

        # (element position, index within the union tuple, expected type)
        typed_slots = ((0, 1, Payload),
                       (0, 2, Fiber),
                       (2, 1, Payload),
                       (3, 2, Fiber))

        for idx, result in enumerate(results):
            with self.subTest(test=idx):
                # The union must equal the expected fiber exactly
                self.assertEqual(result, expected)

                # Scalar slots are Payloads, sub-fiber slots are Fibers
                for pos, slot, kind in typed_slots:
                    self.assertIsInstance(result[pos].payload.value[slot],
                                          kind)

                # The default uses the Fiber class itself for the sub-fiber
                # slot
                default = result.getDefault()
                self.assertEqual(default, Payload(('', 0, Fiber)))
                self.assertIsInstance(default, Payload)

                # Only the top level is measured, so the shape is 1-D
                self.assertEqual(result.getShape(), [7])