def test_lr_layer(self):
        '''
        LR layers.

        The unit hop counts computed by
        `partition.unit_nhops_to_proc_region` for the 'LR' layer must equal
        the reference values from `self._true_unit_nhops`.
        '''
        lr_layer = self.layers['LR']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((2, 1), (2, 4), (2, 2), (1, 1)))

        # Processing region at the origin, sized to the partitioning.
        proc_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=scheme.dim(),
                                 type=NodeRegion.PROC)

        ifm_layout = self._make_data_layout(
            lr_layer.nifm, lr_layer.hifm, lr_layer.wifm,
            PhyDim2(-3, -3), (1, 2), (4, 1), PhyDim2(8, 4))

        ofm_layout = self._make_data_layout(
            lr_layer.nofm, lr_layer.hofm, lr_layer.wofm,
            PhyDim2(5, 5), (1, 1), (1, 2), PhyDim2(2, 4))

        # No filter source nodes are given for this layer.
        fil_nodes = frozenset()

        actual = partition.unit_nhops_to_proc_region(
            lr_layer, self.batch_size, proc_region, scheme, fil_nodes,
            ifm_layout, ofm_layout, self.options['BASE'])

        expected = self._true_unit_nhops(
            lr_layer, proc_region, scheme, fil_nodes, ifm_layout, ofm_layout)

        self.assertListEqual(actual, expected)
    def test_ofmap_local(self):
        '''
        With locally stored ofmaps.

        The output layout is placed on the processing region itself, using the
        same partitioning, so the OFM hop count is expected to be zero.
        '''
        base_layer = self.layers['BASE']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((4, 1), (1, 1), (1, 4), (1, 1)))

        proc_region = NodeRegion(origin=PhyDim2(3, 3),
                                 dim=scheme.dim(),
                                 type=NodeRegion.PROC)

        ifm_layout = self._make_data_layout(
            base_layer.nifm, base_layer.hifm, base_layer.wifm,
            PhyDim2(-3, -3), (1, 1), (1, 1), PhyDim2(1, 1))

        # A single range covering the complete ofmaps of the whole batch.
        whole_ofm_range = FmapRange(
            (0, ) * 4,
            (self.batch_size, base_layer.nofm, base_layer.hofm,
             base_layer.wofm))
        ofm_layout = DataLayout(frngs=(whole_ofm_range, ),
                                regions=(proc_region, ),
                                parts=(scheme, ))

        fil_nodes = frozenset([PhyDim2(3, -3)])

        actual = partition.unit_nhops_to_proc_region(
            base_layer, self.batch_size, proc_region, scheme, fil_nodes,
            ifm_layout, ofm_layout, self.options['BASE'])

        self.assertEqual(actual[de.OFM], 0)
    def test_fc_layer(self):
        '''
        FC layers.

        The unit hop counts from `partition.part_layer_unit_nhops` for the
        'FC' layer must equal the reference from `self._true_unit_nhops`.
        '''
        fc_layer = self.layers['FC']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((8, 1), (1, 1), (1, 2), (1, 4)))

        proc_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=scheme.dim(),
                                 type=NodeRegion.PROC)

        ifm_layout = self._make_data_layout(
            fc_layer.nifm, fc_layer.hifm, fc_layer.wifm,
            PhyDim2(-3, 10), (2, 4, 1, 2), PhyDim2(4, 4))

        ofm_layout = self._make_data_layout(
            fc_layer.nofm, fc_layer.hofm, fc_layer.wofm,
            PhyDim2(1, 1), (2, 2, 1, 1), PhyDim2(2, 2))

        fil_nodes = [PhyDim2(0, 0), PhyDim2(0, 7)]

        # Both calls take the partition scheme before the node region.
        actual = partition.part_layer_unit_nhops(
            fc_layer, self.batch_size, scheme, proc_region, fil_nodes,
            ifm_layout, ofm_layout, self.options['BASE'])

        expected = self._true_unit_nhops(
            fc_layer, scheme, proc_region, fil_nodes, ifm_layout, ofm_layout)

        self.assertListEqual(actual, expected)
    def test_origin(self):
        '''
        Origin.

        With a single-node partitioning, doubling every coordinate offset
        (3 -> 6) is expected to double each unit hop count.
        '''
        layer = self.layers['BASE']

        part = PartitionScheme(order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
                               pdims=((1, 1), (1, 1), (1, 1), (1, 1)))

        def _nhops_with_offset(dist):
            ''' Unit nhops with all coordinates of magnitude `dist`. '''
            region = NodeRegion(origin=PhyDim2(dist, dist),
                                dim=part.dim(),
                                type=NodeRegion.PROC)
            ifm_layout = self._make_data_layout(
                layer.nifm, layer.hifm, layer.wifm,
                PhyDim2(-dist, -dist), (1, 1), (1, 1), PhyDim2(1, 1))
            ofm_layout = self._make_data_layout(
                layer.nofm, layer.hofm, layer.wofm,
                PhyDim2(dist, dist), (1, 1), (1, 1), PhyDim2(1, 1))
            fil_nodes = frozenset([PhyDim2(dist, -dist)])
            return partition.unit_nhops_to_proc_region(
                layer, self.batch_size, region, part, fil_nodes,
                ifm_layout, ofm_layout, self.options['BASE'])

        nhops_near = _nhops_with_offset(3)
        nhops_far = _nhops_with_offset(6)

        self.assertListEqual(nhops_far, [n * 2 for n in nhops_near])
Exemplo n.º 5
0
    def test_dim_fil(self):
        ''' Accessor dim with different partitioning for FIL. '''
        # Each case: (partition order, pdims, expected FIL dim).
        cases = [
            # Adjacent, BATP upon OFMP.
            ([pe.INPP, pe.OUTP, pe.BATP, pe.OFMP],
             [(2, 2), (5, 5), (3, 3), (7, 7)],
             (15, 15)),
            # Adjacent, OFMP upon BATP.
            ([pe.INPP, pe.OFMP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (3, 3), (7, 7)],
             (15, 15)),
            # Not adjacent, BATP upon OFMP.
            ([pe.OUTP, pe.BATP, pe.INPP, pe.OFMP],
             [(2, 2), (5, 5), (3, 3), (7, 7)],
             (5, 5)),
            # Not adjacent, OFMP upon BATP.
            ([pe.OFMP, pe.INPP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (3, 3), (7, 7)],
             (3, 3)),
            # Only BATP.
            ([pe.OUTP, pe.BATP, pe.INPP, pe.OFMP],
             [(2, 2), (1, 1), (3, 3), (7, 7)],
             (3, 3)),
            # Only OFMP.
            ([pe.OFMP, pe.INPP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (1, 1), (7, 7)],
             (5, 5)),
        ]

        for order, pdims, expected_dim in cases:
            scheme = PartitionScheme(order=order, pdims=pdims)
            region = NodeRegion(origin=PhyDim2(0, 0),
                                dim=scheme.dim(),
                                type=NodeRegion.PROC)
            fil_dim = BufShrScheme(region, scheme).dim(de.FIL)
            self.assertTupleEqual(fil_dim, expected_dim)
    def test_use_fwd(self):
        '''
        Use access forwarding.

        Runs the same configuration under the 'BASE', 'ACCFWD' and 'BUFSHR'
        options, with all data sources placed `far_dist` away from the
        processing region, and compares the resulting hop counts.
        '''
        layer = self.layers['BASE']

        part = PartitionScheme(order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
                               pdims=((2, 1), (2, 4), (1, 2), (2, 1)))

        nr = NodeRegion(origin=PhyDim2(0, 0),
                        dim=part.dim(),
                        type=NodeRegion.PROC)

        far_dist = 1000

        ilayout = self._make_data_layout(
            layer.nifm, layer.hifm, layer.wifm,
            PhyDim2(-far_dist, 0), (1, 1), (1, 1), PhyDim2(1, 1))

        olayout = self._make_data_layout(
            layer.nofm, layer.hofm, layer.wofm,
            PhyDim2(0, -far_dist), (1, 1), (1, 1), PhyDim2(1, 1))

        filter_nodes = frozenset([PhyDim2(far_dist, 0), PhyDim2(0, far_dist)])

        def _nhops_under(option_key):
            ''' Unit nhops of the fixed setup under the given options. '''
            return partition.unit_nhops_to_proc_region(
                layer, self.batch_size, nr, part, filter_nodes, ilayout,
                olayout, self.options[option_key])

        nhops_base = _nhops_under('BASE')
        nhops_accfwd = _nhops_under('ACCFWD')
        nhops_bufshr = _nhops_under('BUFSHR')

        # Forwarding and buffer sharing must agree for every data category.
        for category in range(de.NUM):
            self.assertEqual(nhops_accfwd[category], nhops_bufshr[category])

        # In the basic access scheme, FIL and IFM are independently fetched,
        # resulting in repeating remote fetch. OFM are merged locally and only
        # stored back remotely once.
        fil_lower_bound = (layer.total_filter_size() * far_dist
                           * part.size(pe.BATP) * part.size(pe.OFMP) * 0.8)
        self.assertGreater(nhops_base[de.FIL], fil_lower_bound)

        ifm_lower_bound = (layer.total_ifmap_size(self.batch_size) * far_dist
                           * part.size(pe.OUTP) * 0.8)
        self.assertGreater(nhops_base[de.IFM], ifm_lower_bound)

        p_layer, p_batch_size, _ = part.part_layer(layer, self.batch_size)

        # With forwarding, everyone is only remotely fetched once.
        fil_upper_bound = (p_layer.total_filter_size()
                           * part.size(pe.INPP, pe.OUTP)
                           * (far_dist + nr.dim.size()))
        self.assertLess(nhops_accfwd[de.FIL], fil_upper_bound)

        ifm_upper_bound = (p_layer.total_ifmap_size(p_batch_size)
                           * part.size(pe.INPP, pe.OFMP, pe.BATP)
                           * (far_dist + nr.dim.size()))
        self.assertLess(nhops_accfwd[de.IFM], ifm_upper_bound)

        ofm_upper_bound = (p_layer.total_ofmap_size(p_batch_size)
                           * part.size(pe.OUTP, pe.OFMP, pe.BATP)
                           * (far_dist + nr.dim.size()))
        self.assertLess(nhops_accfwd[de.OFM], ofm_upper_bound)
    def test_small(self):
        '''
        Small case with hand calculation.

        Uses a small ConvLayer and a 4 x 4 partitioning, and compares the FIL,
        IFM and OFM unit hop counts against the values worked out by hand in
        the comments below. The hand calculation assumes a batch size of 8.
        '''
        layer = ConvLayer(6, 8, 16, 3)
        # The hand-calculated numbers below only hold for this batch size.
        # Use a unittest assertion rather than a bare `assert`, so the check
        # is not stripped when running with `python -O`.
        self.assertEqual(self.batch_size, 8)

        # i (0, 0), (2, 0): (0, 0, 0, 0) -- (4, 6, 10, 10)
        #   (0, 1), (2, 1): (0, 0, 0, 8) -- (4, 6, 10, 18)
        #   (0, 2), (2, 2): (4, 0, 0, 0) -- (8, 6, 10, 10)
        #   (0, 3), (2, 3): (4, 0, 0, 8) -- (8, 6, 10, 18)
        #   (1, 0), (3, 0): (0, 0, 8, 0) -- (4, 6, 18, 10)
        #   (1, 1), (3, 1): (0, 0, 8, 8) -- (4, 6, 18, 18)
        #   (1, 2), (3, 2): (4, 0, 8, 0) -- (8, 6, 18, 10)
        #   (1, 3), (3, 3): (4, 0, 8, 8) -- (8, 6, 18, 18)
        # o (0, 0): (0, 0, 0, 0) -- (4, 4, 8, 8)
        #   (0, 1): (0, 0, 0, 8) -- (4, 4, 8, 16)
        #   (0, 2): (4, 0, 0, 0) -- (8, 4, 8, 8)
        #   (0, 3): (4, 0, 0, 8) -- (8, 4, 8, 16)
        #   (1, 0): (0, 0, 8, 0) -- (4, 4, 16, 8)
        #   (1, 1): (0, 0, 8, 8) -- (4, 4, 16, 16)
        #   (1, 2): (4, 0, 8, 0) -- (8, 4, 16, 8)
        #   (1, 3): (4, 0, 8, 8) -- (8, 4, 16, 16)
        #   (2, 0): (0, 4, 0, 0) -- (4, 8, 8, 8)
        #   (2, 1): (0, 4, 0, 8) -- (4, 8, 8, 16)
        #   (2, 2): (4, 4, 0, 0) -- (8, 8, 8, 8)
        #   (2, 3): (4, 4, 0, 8) -- (8, 8, 8, 16)
        #   (3, 0): (0, 4, 8, 0) -- (4, 8, 16, 8)
        #   (3, 1): (0, 4, 8, 8) -- (4, 8, 16, 16)
        #   (3, 2): (4, 4, 8, 0) -- (8, 8, 16, 8)
        #   (3, 3): (4, 4, 8, 8) -- (8, 8, 16, 16)
        part = PartitionScheme(order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
                               pdims=((2, 1), (2, 2), (1, 2), (1, 1)))

        nr = NodeRegion(origin=PhyDim2(0, 0),
                        dim=part.dim(),
                        type=NodeRegion.PROC)

        # (0, 0, 0, 0) -- (4, 6, 18, 9): (-2, -2)
        # (0, 0, 0, 9) -- (4, 6, 18, 18): (-2, -1)
        # (4, 0, 0, 0) -- (8, 6, 18, 9): (-1, -2)
        # (4, 0, 0, 9) -- (8, 6, 18, 18): (-1, -1)
        ilayout = self._make_data_layout(layer.nifm, layer.hifm, layer.wifm,
                                         PhyDim2(-2, -2), (2, 1), (1, 1),
                                         PhyDim2(2, 2))

        # (0, 0, 0, 0) -- (8, 4, 16, 8): (2, 2)
        # (0, 0, 0, 8) -- (8, 4, 16, 16): (2, 3)
        # (0, 4, 0, 0) -- (8, 8, 16, 8): (3, 2)
        # (0, 4, 0, 8) -- (8, 8, 16, 16): (3, 3)
        olayout = self._make_data_layout(layer.nofm, layer.hofm, layer.wofm,
                                         PhyDim2(2, 2), (1, 1), (2, 1),
                                         PhyDim2(2, 2))

        filter_nodes = frozenset([PhyDim2(0, 0)])

        # filter: (0, 0) -> all, 6 * 4 * 3 * 3

        # ifmap: (-2, -2) -> (0, 0), (2, 0): 4 * 6 * 10 * 9
        #                 -> (0, 1), (2, 1): 4 * 6 * 10 * (9 - 8)
        #                 -> (1, 0), (3, 0): 4 * 6 * (18 - 8) * 9
        #                 -> (1, 1), (3, 1): 4 * 6 * (18 - 8) * (9 - 8)
        #        (-2, -1) -> (0, 0), (2, 0): 4 * 6 * 10 * (10 - 9)
        #                 -> (0, 1), (2, 1): 4 * 6 * 10 * (18 - 9)
        #                 -> (1, 0), (3, 0): 4 * 6 * (18 - 8) * (10 - 9)
        #                 -> (1, 1), (3, 1): 4 * 6 * (18 - 8) * (18 - 9)
        #        (-1, -2) -> (0, 2), (2, 2): (8 - 4) * 6 * 10 * 9
        #                 -> (0, 3), (2, 3): (8 - 4) * 6 * 10 * (9 - 8)
        #                 -> (1, 2), (3, 2): (8 - 4) * 6 * (18 - 8) * 9
        #                 -> (1, 3), (3, 3): (8 - 4) * 6 * (18 - 8) * (9 - 8)
        #        (-1, -1) -> (0, 2), (2, 2): (8 - 4) * 6 * 10 * (10 - 9)
        #                 -> (0, 3), (2, 3): (8 - 4) * 6 * 10 * (18 - 9)
        #                 -> (1, 2), (3, 2): (8 - 4) * 6 * (18 - 8) * (10 - 9)
        #                 -> (1, 3), (3, 3): (8 - 4) * 6 * (18 - 8) * (18 - 9)

        # ofmap: (2, 2) -> (0, 0):
        #               -> (0, 2):
        #               -> (1, 0):
        #               -> (1, 2): 4 * 4 * 8 * 8
        #        (2, 3) -> (0/1, 1/3)
        #        (3, 2) -> (2/3, 0/2)
        #        (3, 3) -> (2/3, 1/3)

        nhops = partition.unit_nhops_to_proc_region(layer, self.batch_size, nr,
                                                    part, filter_nodes,
                                                    ilayout, olayout,
                                                    self.options['BASE'])

        # Filters: one part sent from (0, 0) to each of the 4 x 4 nodes,
        # weighted by the h + w distance of the node from (0, 0).
        self.assertEqual(
            nhops[de.FIL],
            6 * 4 * 3 * 3 * sum(h + w for h in range(4) for w in range(4)))
        # Ifmaps: the four groups of terms follow the four source nodes
        # listed in the ifmap table above.
        self.assertEqual(
            nhops[de.IFM], 4 * 6 * 10 *
            ((4 + 6) * 9 + (5 + 7) * 1 + (5 + 7) * 9 + (6 + 8) * 1 +
             (3 + 5) * 1 + (4 + 6) * 9 + (4 + 6) * 1 + (5 + 7) * 9 +
             (5 + 7) * 9 + (6 + 8) * 1 + (6 + 8) * 9 + (7 + 9) * 1 +
             (4 + 6) * 1 + (5 + 7) * 9 + (5 + 7) * 1 + (6 + 8) * 9))
        # Ofmaps: one group of four distances per destination node.
        self.assertEqual(
            nhops[de.OFM], 4 * 4 * 8 * 8 * ((4 + 2 + 3 + 1) + (4 + 2 + 3 + 1) +
                                            (3 + 1 + 2 + 0) + (3 + 1 + 2 + 0)))
Exemplo n.º 8
0
class TestPartitionScheme(unittest.TestCase):
    '''
    Tests for PartitionScheme.

    Written for Python 3: `range` results are turned into lists before being
    concatenated or passed as `order`, and `assertRaisesRegex` is used in
    place of the deprecated `assertRaisesRegexp` alias.
    '''

    def setUp(self):
        ''' Create the two partition schemes shared by the tests. '''
        self.ps1 = PartitionScheme(order=[pe.BATP, pe.OUTP, pe.OFMP, pe.INPP],
                                   pdims=[(2, 3), (3, 1), (1, 5), (5, 2)])
        self.ps2 = PartitionScheme(order=list(range(pe.NUM)),
                                   pdims=[(2, 2), (5, 5), (3, 3), (1, 1)])

    def test_invalid_order(self):
        ''' Invalid order. '''
        # One entry short.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*order.*'):
            _ = PartitionScheme(order=list(range(pe.NUM - 1)),
                                pdims=[(1, 1)] * pe.NUM)

        # Index 1 is skipped. `range` must be a list before concatenation;
        # `[0] + range(...)` raises TypeError on Python 3.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*order.*'):
            _ = PartitionScheme(order=[0] + list(range(2, pe.NUM)),
                                pdims=[(1, 1)] * pe.NUM)

        # One entry too many, with index 1 repeated.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*order.*'):
            _ = PartitionScheme(order=[1] + list(range(pe.NUM)),
                                pdims=[(1, 1)] * pe.NUM)

        # Right length, but the values start from 4 rather than 0.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*order.*'):
            _ = PartitionScheme(order=list(range(4, 4 + pe.NUM)),
                                pdims=[(1, 1)] * pe.NUM)

    def test_invalid_pdims(self):
        ''' Invalid pdims. '''
        # One pdim short.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*pdims.*'):
            _ = PartitionScheme(order=list(range(pe.NUM)),
                                pdims=[(1, 1)] * (pe.NUM - 1))

        # One pdim has three components instead of two.
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*pdims.*'):
            _ = PartitionScheme(order=list(range(pe.NUM)),
                                pdims=[(1, 1), (1, 1), (2, 1, 1), (1, 1)])

    def test_dim(self):
        ''' Get dim. '''
        self.assertEqual(self.ps1.dim(0), PhyDim2(2, 3))
        self.assertEqual(self.ps1.dim(1), PhyDim2(3, 1))
        self.assertEqual(self.ps1.dim(2), PhyDim2(1, 5))
        self.assertEqual(self.ps1.dim(3), PhyDim2(5, 2))

        self.assertEqual(self.ps2.dim(0), PhyDim2(2, 2))
        self.assertEqual(self.ps2.dim(1), PhyDim2(5, 5))
        self.assertEqual(self.ps2.dim(2), PhyDim2(3, 3))
        self.assertEqual(self.ps2.dim(3), PhyDim2(1, 1))

        # Several indexes, or none at all, give the product of the dims.
        self.assertEqual(self.ps1.dim(0, 1, 2),
                         PhyDim2(2, 3) * PhyDim2(3, 1) * PhyDim2(1, 5))
        self.assertEqual(self.ps1.dim(),
                         PhyDim2(2, 3) * PhyDim2(3, 1)
                         * PhyDim2(1, 5) * PhyDim2(5, 2))

        # The argument order does not matter.
        self.assertEqual(self.ps1.dim(0, 1, 2), self.ps1.dim(1, 2, 0))

    def test_dim_invalid_index(self):
        ''' Get dim invalid index. '''
        with self.assertRaises(IndexError):
            _ = self.ps1.dim(pe.NUM + 1)

        with self.assertRaises(IndexError):
            _ = self.ps1.dim(0, 1, pe.NUM)

    def test_size(self):
        ''' Get size. '''
        # size(*args) must agree with dim(*args).size() for every
        # combination of 1 to pe.NUM - 1 indexes.
        for l in range(1, pe.NUM):
            for args in itertools.combinations(range(pe.NUM), l):
                self.assertEqual(self.ps1.dim(*args).size(),
                                 self.ps1.size(*args))

    def test_size_invalid_index(self):
        ''' Get size invalid index. '''
        with self.assertRaises(IndexError):
            _ = self.ps1.size(pe.NUM + 1)

        with self.assertRaises(IndexError):
            _ = self.ps1.size(0, 1, pe.NUM)

    def test_gen_pidx(self):
        ''' Generate pidx. '''
        for ps in [self.ps1, self.ps2]:

            pidx_list = list(ps.gen_pidx())

            # Num. of pidx == size.
            self.assertEqual(len(pidx_list), ps.size())
            self.assertEqual(len(set(pidx_list)), ps.size())

            # Transpose, so that each idx_list holds the i-th component of
            # every generated pidx.
            for i, idx_list in enumerate(zip(*pidx_list)):
                cnt = collections.Counter(idx_list)
                # Num. of different pidx == size.
                self.assertEqual(len(cnt), ps.size(i))
                # Num. of repeated pidx == other sizes.
                for c in cnt.values():
                    self.assertEqual(c, ps.size() // ps.size(i))

    def test_coordinate(self):
        ''' Get coordinate. '''
        nr1 = NodeRegion(origin=PhyDim2(0, 0), dim=self.ps1.dim(),
                         type=NodeRegion.PROC)
        nr2 = NodeRegion(origin=PhyDim2(0, 0), dim=self.ps2.dim(),
                         type=NodeRegion.PROC)

        for ps, nr in zip([self.ps1, self.ps2], [nr1, nr2]):

            coord_list = [ps.coordinate(nr, pidx) for pidx in ps.gen_pidx()]

            # One distinct coordinate per pidx.
            self.assertEqual(len(coord_list), ps.size())
            self.assertEqual(len(set(coord_list)), ps.size())

            # Every coordinate lies within [0, dim) on both axes.
            for coord in coord_list:
                self.assertGreaterEqual(coord.h, 0)
                self.assertGreaterEqual(coord.w, 0)
                self.assertLess(coord.h, ps.dim().h)
                self.assertLess(coord.w, ps.dim().w)

        # Index (1, 1) in OUTP and (0, 0) everywhere else.
        pidx = [PhyDim2(0, 0)] * pe.NUM
        pidx[pe.OUTP] = PhyDim2(1, 1)

        self.assertEqual(self.ps1.coordinate(nr1, pidx),
                         self.ps1.dim(pe.OFMP, pe.INPP)
                         * PhyDim2(1, 1))

        self.assertEqual(self.ps2.coordinate(nr2, pidx),
                         self.ps2.dim(pe.OFMP, pe.BATP, pe.INPP)
                         * PhyDim2(1, 1))

    def test_part_layer(self):
        ''' Get part_layer. '''
        batch_size = 16

        # The partitioned layer times the partitioning factor must cover the
        # original layer in every dimension.
        layer = ConvLayer(32, 128, 28, 3)
        p_layer, p_batch_size, p_occ = self.ps1.part_layer(layer, batch_size)
        self.assertGreaterEqual(p_layer.hofm * self.ps1.dim(pe.OFMP).h,
                                layer.hofm, 'part_layer: Conv: hofm')
        self.assertGreaterEqual(p_layer.wofm * self.ps1.dim(pe.OFMP).w,
                                layer.wofm, 'part_layer: Conv: wofm')
        self.assertGreaterEqual(p_layer.nofm * self.ps1.size(pe.OUTP),
                                layer.nofm, 'part_layer: Conv: nofm')
        self.assertGreaterEqual(p_layer.nifm * self.ps1.size(pe.INPP),
                                layer.nifm, 'part_layer: Conv: nifm')
        self.assertGreaterEqual(p_batch_size * self.ps1.size(pe.BATP),
                                16, 'part_layer: Conv: batch_size')
        self.assertAlmostEqual(p_occ, 1. * (32 * 128 * 28 * 28 * 16)
                               / (4 * 22 * 10 * 28 * 4 * self.ps1.size()))

        layer = PoolingLayer(128, 112, 2)
        p_layer, p_batch_size, p_occ = self.ps2.part_layer(layer, batch_size)
        self.assertGreaterEqual(p_layer.hofm * self.ps2.dim(pe.OFMP).h,
                                layer.hofm, 'part_layer: Pooling: hofm')
        self.assertGreaterEqual(p_layer.wofm * self.ps2.dim(pe.OFMP).w,
                                layer.wofm, 'part_layer: Pooling: wofm')
        self.assertGreaterEqual(p_layer.nofm * self.ps2.size(pe.OUTP),
                                layer.nofm, 'part_layer: Pooling: nofm')
        self.assertGreaterEqual(p_layer.nifm, p_layer.nofm,
                                'part_layer: Pooling: nifm')
        self.assertGreaterEqual(p_batch_size * self.ps2.size(pe.BATP),
                                16, 'part_layer: Pooling: batch_size')
        self.assertAlmostEqual(p_occ, 1. * (128 * 112 * 112 * 16)
                               / (32 * 23 * 23 * 2 * self.ps2.size()))

    def test_part_layer_invalid_inpart(self):
        ''' Get part_layer invalid INPP. '''
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*input.*'):
            _ = self.ps1.part_layer(PoolingLayer(self.ps1.size(pe.OUTP),
                                                 self.ps1.size(pe.OFMP), 2),
                                    self.ps1.size(pe.BATP))

    def test_part_layer_invalid_type(self):
        ''' Get part_layer invalid type. '''
        # A minimal Layer subclass, used as a layer type that part_layer is
        # expected to reject.
        class _Layer(Layer):
            def input_layer(self):
                return self
            def ops_per_neuron(self):
                return 0

        with self.assertRaisesRegex(TypeError, 'PartitionScheme: .*layer.*'):
            _ = self.ps1.part_layer(_Layer(self.ps1.size(pe.OUTP),
                                           self.ps1.size(pe.OFMP)),
                                    self.ps1.size(pe.BATP))

    def test_repr(self):
        ''' __repr__. '''
        # eval() is only applied to the repr of objects built in setUp.
        # pylint: disable=eval-used
        self.assertEqual(eval(repr(self.ps1)), self.ps1)
        self.assertEqual(eval(repr(self.ps2)), self.ps2)
Exemplo n.º 9
0
class TestPartitionScheme(unittest.TestCase):
    ''' Tests for PartitionScheme. '''
    def setUp(self):
        ''' Build the shared partition schemes and their PROC regions. '''
        self.ps1 = PartitionScheme(
            order=[pe.BATP, pe.OUTP, pe.OFMP, pe.INPP],
            pdims=[(2, 3), (3, 1), (1, 5), (5, 2)])
        self.ps2 = PartitionScheme(
            order=list(range(pe.NUM)),
            pdims=[(2, 2), (5, 5), (3, 3), (1, 1)])

        # One region per scheme, anchored at (0, 0) and sized to the scheme.
        self.nr1 = NodeRegion(
            origin=PhyDim2(0, 0), dim=self.ps1.dim(), type=NodeRegion.PROC)
        self.nr2 = NodeRegion(
            origin=PhyDim2(0, 0), dim=self.ps2.dim(), type=NodeRegion.PROC)

    def test_invalid_order(self):
        ''' Invalid order. '''
        bad_orders = [
            # One entry short.
            list(range(pe.NUM - 1)),
            # Index 1 is skipped.
            [0] + list(range(2, pe.NUM)),
            # One entry too many, with index 1 repeated.
            [1] + list(range(pe.NUM)),
            # Right length, but the values start from 4 rather than 0.
            list(range(4, 4 + pe.NUM)),
        ]

        for bad_order in bad_orders:
            with self.assertRaisesRegex(ValueError,
                                        'PartitionScheme: .*order.*'):
                _ = PartitionScheme(order=bad_order,
                                    pdims=[(1, 1)] * pe.NUM)

    def test_invalid_pdims(self):
        ''' Invalid pdims. '''
        bad_pdims_list = [
            # One pdim short.
            [(1, 1)] * (pe.NUM - 1),
            # One pdim has three components instead of two.
            [(1, 1), (1, 1), (2, 1, 1), (1, 1)],
        ]

        for bad_pdims in bad_pdims_list:
            with self.assertRaisesRegex(ValueError,
                                        'PartitionScheme: .*pdims.*'):
                _ = PartitionScheme(order=list(range(pe.NUM)),
                                    pdims=bad_pdims)

    def test_dim(self):
        ''' Get dim. '''
        # Expected per-index dims, in index order 0..3.
        ps1_dims = [PhyDim2(2, 3), PhyDim2(3, 1), PhyDim2(1, 5),
                    PhyDim2(5, 2)]
        ps2_dims = [PhyDim2(2, 2), PhyDim2(5, 5), PhyDim2(3, 3),
                    PhyDim2(1, 1)]

        for scheme, expected_dims in ((self.ps1, ps1_dims),
                                      (self.ps2, ps2_dims)):
            for index, expected in enumerate(expected_dims):
                self.assertEqual(scheme.dim(index), expected)

        # Several indexes, or none at all, give the product of the dims.
        first_three = ps1_dims[0] * ps1_dims[1] * ps1_dims[2]
        self.assertEqual(self.ps1.dim(0, 1, 2), first_three)
        self.assertEqual(self.ps1.dim(), first_three * ps1_dims[3])

        # The argument order does not matter.
        self.assertEqual(self.ps1.dim(0, 1, 2), self.ps1.dim(1, 2, 0))

    def test_dim_invalid_index(self):
        ''' Get dim invalid index. '''
        # A lone bad index, and a bad index following valid ones.
        for bad_args in ((pe.NUM + 1, ), (0, 1, pe.NUM)):
            with self.assertRaises(IndexError):
                _ = self.ps1.dim(*bad_args)

    def test_size(self):
        ''' Get size. '''
        # size(*idxs) must agree with dim(*idxs).size() for every combination
        # of 1 to pe.NUM - 1 indexes.
        for count in range(1, pe.NUM):
            for picked in itertools.combinations(range(pe.NUM), count):
                size_via_dim = self.ps1.dim(*picked).size()
                self.assertEqual(size_via_dim, self.ps1.size(*picked))

    def test_size_invalid_index(self):
        ''' Get size invalid index. '''
        # A lone bad index, and a bad index following valid ones.
        for bad_args in ((pe.NUM + 1, ), (0, 1, pe.NUM)):
            with self.assertRaises(IndexError):
                _ = self.ps1.size(*bad_args)

    def test_gen_pidx(self):
        ''' Generate pidx. '''
        for scheme in (self.ps1, self.ps2):

            all_pidx = list(scheme.gen_pidx())

            # Num. of pidx == size, and all of them are distinct.
            self.assertEqual(len(all_pidx), scheme.size())
            self.assertEqual(len(set(all_pidx)), scheme.size())

            # Transpose, so that each column holds the i-th component of
            # every generated pidx.
            for i, column in enumerate(zip(*all_pidx)):
                occurrences = collections.Counter(column)
                # Num. of different pidx == size.
                self.assertEqual(len(occurrences), scheme.size(i))
                # Num. of repeated pidx == other sizes.
                for repeats in occurrences.values():
                    self.assertEqual(repeats,
                                     scheme.size() // scheme.size(i))

    def test_coordinate(self):
        ''' Get coordinate. '''
        for scheme, region in ((self.ps1, self.nr1), (self.ps2, self.nr2)):

            coords = [scheme.coordinate(region, pidx)
                      for pidx in scheme.gen_pidx()]

            # One distinct coordinate per pidx.
            self.assertEqual(len(coords), scheme.size())
            self.assertEqual(len(set(coords)), scheme.size())

            # Every coordinate lies within [0, dim) on both axes.
            for coord in coords:
                self.assertGreaterEqual(coord.h, 0)
                self.assertGreaterEqual(coord.w, 0)
                self.assertLess(coord.h, scheme.dim().h)
                self.assertLess(coord.w, scheme.dim().w)

        # Index (1, 1) in OUTP and (0, 0) everywhere else.
        outp_only = [PhyDim2(0, 0)] * pe.NUM
        outp_only[pe.OUTP] = PhyDim2(1, 1)

        coord_ps1 = self.ps1.coordinate(self.nr1, outp_only)
        self.assertEqual(coord_ps1,
                         self.ps1.dim(pe.OFMP, pe.INPP) * PhyDim2(1, 1))

        coord_ps2 = self.ps2.coordinate(self.nr2, outp_only)
        self.assertEqual(
            coord_ps2,
            self.ps2.dim(pe.OFMP, pe.BATP, pe.INPP) * PhyDim2(1, 1))

    def test_fmap_range(self):
        ''' Get fmap_range. '''
        frng_regular = FmapRange(FmapPosition(b=0, n=0, h=0, w=0),
                                 FmapPosition(b=8, n=64, h=28, w=28))
        # Small ranges.
        frng_small = FmapRange(FmapPosition(b=0, n=0, h=0, w=0),
                               FmapPosition(b=1, n=1, h=1, w=1))
        # Irregular values.
        frng_irregular = FmapRange(FmapPosition(b=2, n=4, h=2, w=6),
                                   FmapPosition(b=5, n=11, h=13, w=13))

        scheme = self.ps2

        # No overlap: check each partitioned range against all earlier ones.
        for frng in (frng_regular, frng_small, frng_irregular):
            part_frngs = [scheme.fmap_range(frng, pidx)
                          for pidx in scheme.gen_pidx()]
            for pos, current in enumerate(part_frngs):
                for earlier in part_frngs[:pos]:
                    self.assertEqual(earlier.overlap_size(current), 0)

        # Exact partitioned ranges for one specific pidx.
        pidx = (PhyDim2(1, 0), PhyDim2(4, 3), PhyDim2(0, 2), PhyDim2(0, 0))

        expectations = (
            (frng_regular,
             FmapRange(FmapPosition(b=1, n=32, h=22, w=16),
                       FmapPosition(b=2, n=48, h=28, w=22))),
            (frng_small,
             FmapRange(FmapPosition(b=0, n=0, h=0, w=0),
                       FmapPosition(b=0, n=0, h=1, w=0))),
            (frng_irregular,
             FmapRange(FmapPosition(b=2, n=7, h=10, w=10),
                       FmapPosition(b=3, n=9, h=13, w=11))),
        )
        for frng, expected in expectations:
            self.assertEqual(scheme.fmap_range(frng, pidx), expected)

    def test_is_appl2frng(self):
        ''' Get is_applicable_to_fmap_range. '''
        # ps1 is expected not to be applicable; ps2 is expected to be.
        ps1_applicable = self.ps1.is_applicable_to_fmap_range()
        self.assertFalse(ps1_applicable)

        ps2_applicable = self.ps2.is_applicable_to_fmap_range()
        self.assertTrue(ps2_applicable)

    def test_part_layer(self):
        ''' Get part_layer. '''
        batch_size = 16

        # Conv layer with ps1: the partitioned layer times the partitioning
        # factor must cover the original layer in every dimension.
        conv = ConvLayer(32, 128, 28, 3)
        p_conv, p_conv_batch, conv_occ = self.ps1.part_layer(conv, batch_size)
        ofmp_dim = self.ps1.dim(pe.OFMP)
        self.assertGreaterEqual(p_conv.hofm * ofmp_dim.h, conv.hofm,
                                'part_layer: Conv: hofm')
        ofmp_dim = self.ps1.dim(pe.OFMP)
        self.assertGreaterEqual(p_conv.wofm * ofmp_dim.w, conv.wofm,
                                'part_layer: Conv: wofm')
        self.assertGreaterEqual(p_conv.nofm * self.ps1.size(pe.OUTP),
                                conv.nofm, 'part_layer: Conv: nofm')
        self.assertGreaterEqual(p_conv.nifm * self.ps1.size(pe.INPP),
                                conv.nifm, 'part_layer: Conv: nifm')
        self.assertGreaterEqual(p_conv_batch * self.ps1.size(pe.BATP), 16,
                                'part_layer: Conv: batch_size')
        expected_conv_occ = (1. * (32 * 128 * 28 * 28 * 16) /
                             (4 * 22 * 10 * 28 * 4 * self.ps1.size()))
        self.assertAlmostEqual(conv_occ, expected_conv_occ)

        # Pooling layer with ps2: same coverage checks, except that nifm is
        # compared against the partitioned nofm.
        pool = PoolingLayer(128, 112, 2)
        p_pool, p_pool_batch, pool_occ = self.ps2.part_layer(pool, batch_size)
        ofmp_dim = self.ps2.dim(pe.OFMP)
        self.assertGreaterEqual(p_pool.hofm * ofmp_dim.h, pool.hofm,
                                'part_layer: Pooling: hofm')
        ofmp_dim = self.ps2.dim(pe.OFMP)
        self.assertGreaterEqual(p_pool.wofm * ofmp_dim.w, pool.wofm,
                                'part_layer: Pooling: wofm')
        self.assertGreaterEqual(p_pool.nofm * self.ps2.size(pe.OUTP),
                                pool.nofm, 'part_layer: Pooling: nofm')
        self.assertGreaterEqual(p_pool.nifm, p_pool.nofm,
                                'part_layer: Pooling: nifm')
        self.assertGreaterEqual(p_pool_batch * self.ps2.size(pe.BATP), 16,
                                'part_layer: Pooling: batch_size')
        expected_pool_occ = (1. * (128 * 112 * 112 * 16) /
                             (32 * 23 * 23 * 2 * self.ps2.size()))
        self.assertAlmostEqual(pool_occ, expected_pool_occ)

    def test_part_layer_invalid_inpart(self):
        ''' part_layer of ps1 raises ValueError for a Pooling layer. '''
        ps = self.ps1
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*input.*'):
            # The layer is built inside the context, as the whole expression
            # is expected to raise.
            pool = PoolingLayer(ps.size(pe.OUTP), ps.size(pe.OFMP), 2)
            ps.part_layer(pool, ps.size(pe.BATP))

    def test_part_layer_invalid_type(self):
        ''' part_layer raises TypeError for a custom Layer subclass. '''
        # Minimal Layer subclass that exists only within this test.
        class _Layer(Layer):
            @staticmethod
            def data_loops():
                return None

            def ops_per_neuron(self):
                return 0

            def input_layer(self):
                return self

        ps = self.ps1
        dummy = _Layer(ps.size(pe.OUTP), ps.size(pe.OFMP))

        # Sanity-check the stub itself.
        self.assertEqual(dummy.total_ops(), 0)
        self.assertIsNone(_Layer.data_loops())

        with self.assertRaisesRegex(TypeError, 'PartitionScheme: .*layer.*'):
            ps.part_layer(dummy, ps.size(pe.BATP))

    def test_part_neighbor_dist(self):
        ''' part_neighbor_dist at every partitioning level of ps1 and ps2. '''
        for ps, nr in ((self.ps1, self.nr1), (self.ps2, self.nr2)):

            for level in range(pe.NUM):
                pae = ps.order[level]
                nbr_dist = ps.part_neighbor_dist(nr, pae)

                # Combined dimension of all levels below the current one; the
                # last level has nothing below it.
                if level + 1 < pe.NUM:
                    dim_below = ps.dim(*ps.order[level + 1:])
                else:
                    dim_below = PhyDim2(1, 1)
                dim_cur = ps.dim(pae)

                # Along each axis: infinite distance when the current level
                # has size 1, otherwise the dimension of the levels below.
                for axis in ('h', 'w'):
                    dist = getattr(nbr_dist, axis)
                    if getattr(dim_cur, axis) == 1:
                        self.assertTrue(math.isinf(dist))
                    else:
                        self.assertEqual(dist, getattr(dim_below, axis))

    def test_part_neighbor_dist_inv(self):
        ''' part_neighbor_dist gives all NaN for the out-of-range pe.NUM. '''
        result = self.ps1.part_neighbor_dist(self.nr1, pe.NUM)
        for component in result:
            self.assertTrue(math.isnan(component))

    def test_projection(self):
        ''' projection onto regions that shrink and/or extend the scheme. '''
        def _dram_region(height, width):
            # DRAM region at the origin with the given dimension.
            return NodeRegion(origin=PhyDim2(0, 0),
                              dim=PhyDim2(height, width),
                              type=NodeRegion.DRAM)

        def _verify(proj, ref, total_dim, *pae_dims):
            # The projected scheme keeps the order of the reference scheme,
            # has the given overall dimension, and has the given dimension
            # for each listed (partitioning, dimension) pair, in that order.
            self.assertTupleEqual(proj.order, ref.order)
            self.assertTupleEqual(proj.dim(), total_dim)
            for pae, expected in pae_dims:
                self.assertTupleEqual(proj.dim(pae), expected)

        order_bat_top = (pe.BATP, pe.INPP, pe.OUTP, pe.OFMP)
        order_inp_top = (pe.INPP, pe.BATP, pe.OUTP, pe.OFMP)

        # Shrink.
        part = PartitionScheme(order=order_bat_top,
                               pdims=((2, 3), (1, 5), (4, 4), (1, 1)))
        _verify(part.projection(_dram_region(4, 30)), part, (4, 30),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (2, 2)), (pe.INPP, (1, 1)))

        # Shrink multiple.
        _verify(part.projection(_dram_region(2, 2)), part, (2, 2),
                (pe.OUTP, (2, 1)), (pe.OFMP, (1, 2)),
                (pe.BATP, (1, 1)), (pe.INPP, (1, 1)))

        # Shrink non-dividable.
        # For BATP height, 3 // 2 = 1.
        # For BATP width, 54 // 5 = 10, 10 // 3 = 3.
        _verify(part.projection(_dram_region(3, 54)), part, (2, 45),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (1, 3)), (pe.INPP, (1, 1)))

        # Shrink with INPP, with and without appl2frng.
        part = PartitionScheme(order=order_bat_top,
                               pdims=((2, 3), (1, 5), (4, 4), (4, 4)))
        _verify(part.projection(_dram_region(4, 30), appl2frng=True),
                part, (4, 30),
                (pe.BATP, (2, 2)), (pe.INPP, (1, 1)))
        _verify(part.projection(_dram_region(4, 30)), part, (4, 30),
                (pe.BATP, (1, 1)), (pe.INPP, (2, 2)))

        # Shrink all.
        _verify(part.projection(_dram_region(1, 1)), part, (1, 1))

        # Extend.
        part = PartitionScheme(order=order_inp_top,
                               pdims=((2, 3), (1, 5), (1, 1), (1, 1)))
        _verify(part.projection(_dram_region(4, 30)), part, (4, 30),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (2, 2)), (pe.INPP, (1, 1)))

        # Extend non-dividable.
        # For BATP height, 5 // 2 = 2.
        # For BATP width, 40 // (3 * 5) == 2.
        _verify(part.projection(_dram_region(5, 40)), part, (4, 30),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (2, 2)), (pe.INPP, (1, 1)))

        # Extend with INPP.
        part = PartitionScheme(order=order_inp_top,
                               pdims=((2, 3), (1, 5), (1, 1), (4, 4)))
        _verify(part.projection(_dram_region(4, 30), appl2frng=True),
                part, (4, 30),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (2, 2)), (pe.INPP, (1, 1)))

        # Both shrink and extend.
        part = PartitionScheme(order=order_bat_top,
                               pdims=((2, 3), (1, 5), (4, 4), (1, 1)))
        _verify(part.projection(_dram_region(16, 16)), part, (16, 15),
                (pe.OUTP, (2, 3)), (pe.OFMP, (1, 5)),
                (pe.BATP, (8, 1)), (pe.INPP, (1, 1)))

    def test_projection_empty_region(self):
        ''' projection raises ValueError for a region of dimension (0, 0). '''
        with self.assertRaisesRegex(ValueError, 'PartitionScheme: .*region.*'):
            # The region is built inside the context, as the whole expression
            # is expected to raise.
            empty_region = NodeRegion(origin=PhyDim2(0, 0),
                                      dim=PhyDim2(0, 0),
                                      type=NodeRegion.DRAM)
            self.ps1.projection(empty_region)

    def test_repr(self):
        ''' Evaluating repr of a scheme gives back an equal scheme. '''
        # pylint: disable=eval-used
        for ps in (self.ps1, self.ps2):
            rebuilt = eval(repr(ps))
            self.assertEqual(rebuilt, ps)
# Example no. 10
# 0
class TestBufShrScheme(unittest.TestCase):
    ''' Tests for BufShrScheme. '''
    def setUp(self):
        ''' Create three partition schemes, their regions, and BufShrSchemes. '''
        def _proc_region(ps):
            # PROC region at the origin with exactly the scheme dimension.
            return NodeRegion(origin=PhyDim2(0, 0),
                              dim=ps.dim(),
                              type=NodeRegion.PROC)

        # Scheme 1 uses an explicit partitioning order.
        self.ps1 = PartitionScheme(order=[pe.BATP, pe.OUTP, pe.OFMP, pe.INPP],
                                   pdims=[(2, 3), (3, 1), (1, 5), (5, 2)])
        self.nr1 = _proc_region(self.ps1)
        self.bufshr1 = BufShrScheme(self.nr1, self.ps1)

        # Schemes 2 and 3 use the enum order.
        self.ps2 = PartitionScheme(order=range(pe.NUM),
                                   pdims=[(2, 2), (5, 5), (3, 3), (1, 1)])
        self.nr2 = _proc_region(self.ps2)
        self.bufshr2 = BufShrScheme(self.nr2, self.ps2)

        self.ps3 = PartitionScheme(order=range(pe.NUM),
                                   pdims=[(1, 6), (1, 2), (4, 1), (3, 5)])
        self.nr3 = _proc_region(self.ps3)
        self.bufshr3 = BufShrScheme(self.nr3, self.ps3)

    def test_dim(self):
        ''' dim of each data category against the partitioning dims. '''
        pairs = ((self.bufshr1, self.ps1),
                 (self.bufshr2, self.ps2),
                 (self.bufshr3, self.ps3))

        # IFM matches OUTP, and OFM matches INPP.
        for bufshr, ps in pairs:
            self.assertTupleEqual(bufshr.dim(de.IFM), ps.dim(pe.OUTP))
            self.assertTupleEqual(bufshr.dim(de.OFM), ps.dim(pe.INPP))

        # FIL matches OFMP alone for the first scheme, and OFMP combined with
        # BATP for the other two.
        fil_paes = ((pe.OFMP, ), (pe.OFMP, pe.BATP), (pe.OFMP, pe.BATP))
        for (bufshr, ps), paes in zip(pairs, fil_paes):
            self.assertTupleEqual(bufshr.dim(de.FIL), ps.dim(*paes))

    def test_size(self):
        ''' size equals the size of dim for every data category. '''
        for bufshr in (self.bufshr1, self.bufshr2, self.bufshr3):
            for dce in range(de.NUM):
                dim_size = bufshr.dim(dce).size()
                self.assertEqual(dim_size, bufshr.size(dce))

    def test_dim_fil(self):
        ''' dim of FIL under different orders and sizes of BATP and OFMP. '''
        full_pdims = ((2, 2), (5, 5), (3, 3), (7, 7))

        # Each case is (order, pdims, expected FIL dimension).
        cases = (
            # Adjacent, BATP upon OFMP.
            ((pe.INPP, pe.OUTP, pe.BATP, pe.OFMP), full_pdims, (15, ) * 2),
            # Adjacent, OFMP upon BATP.
            ((pe.INPP, pe.OFMP, pe.BATP, pe.OUTP), full_pdims, (15, ) * 2),
            # Not adjacent, BATP upon OFMP.
            ((pe.OUTP, pe.BATP, pe.INPP, pe.OFMP), full_pdims, (5, ) * 2),
            # Not adjacent, OFMP upon BATP.
            ((pe.OFMP, pe.INPP, pe.BATP, pe.OUTP), full_pdims, (3, ) * 2),
            # Only BATP.
            ((pe.OUTP, pe.BATP, pe.INPP, pe.OFMP),
             ((2, 2), (1, 1), (3, 3), (7, 7)), (3, ) * 2),
            # Only OFMP.
            ((pe.OFMP, pe.INPP, pe.BATP, pe.OUTP),
             ((2, 2), (5, 5), (1, 1), (7, 7)), (5, ) * 2),
        )

        for order, pdims, expected in cases:
            # Lists are passed to the constructor, one fresh copy per case.
            ps = PartitionScheme(order=list(order), pdims=list(pdims))
            nr = NodeRegion(origin=PhyDim2(0, 0),
                            dim=ps.dim(),
                            type=NodeRegion.PROC)
            self.assertTupleEqual(BufShrScheme(nr, ps).dim(de.FIL), expected)

    def test_dim_invalid_index(self):
        ''' dim raises IndexError for the out-of-range index de.NUM. '''
        invalid_dce = de.NUM
        with self.assertRaises(IndexError):
            self.bufshr1.dim(invalid_dce)

    def test_size_invalid_index(self):
        ''' size raises IndexError for the out-of-range index de.NUM. '''
        invalid_dce = de.NUM
        with self.assertRaises(IndexError):
            self.bufshr1.size(invalid_dce)

    def test_nbr_dists(self):
        ''' nbr_dists of every data category for the three schemes. '''
        inf = float('inf')

        # Expected (FIL, IFM, OFM) neighbor distances for each scheme.
        expected_dists = (
            (self.bufshr1, ((5, inf), (15, 2), (1, 1))),
            (self.bufshr2, ((1, 1), (15, 15), (inf, inf))),
            (self.bufshr3, ((3, 5), (inf, 10), (1, 1))),
        )

        for bufshr, dists in expected_dists:
            for dce, expected in zip((de.FIL, de.IFM, de.OFM), dists):
                self.assertTupleEqual(bufshr.nbr_dists[dce], expected)

    def test_default_data_loops(self):
        ''' Passing these data_loops explicitly matches omitting them. '''
        loops = [None] * de.NUM
        loops[de.FIL] = DataDimLoops(le.IFM, le.OFM)
        loops[de.IFM] = DataDimLoops(le.IFM, le.BAT)
        loops[de.OFM] = DataDimLoops(le.OFM, le.BAT)

        fixtures = ((self.bufshr1, self.nr1, self.ps1),
                    (self.bufshr2, self.nr2, self.ps2),
                    (self.bufshr3, self.nr3, self.ps3))

        for default_bufshr, nr, ps in fixtures:

            explicit_bufshr = BufShrScheme(nr, ps, loops)

            for dce in range(de.NUM):
                self.assertTupleEqual(default_bufshr.dim(dce),
                                      explicit_bufshr.dim(dce))
                self.assertTupleEqual(default_bufshr.nbr_dists[dce],
                                      explicit_bufshr.nbr_dists[dce])

    def test_data_loops(self):
        ''' Same data_loops for IFM and OFM give the same dim and dists. '''
        loops = [None] * de.NUM
        loops[de.FIL] = DataDimLoops(le.IFM, le.OFM)
        # IFM is deliberately given the same loops as OFM.
        loops[de.IFM] = DataDimLoops(le.OFM, le.BAT)
        loops[de.OFM] = DataDimLoops(le.OFM, le.BAT)

        fixtures = ((self.nr1, self.ps1),
                    (self.nr2, self.ps2),
                    (self.nr3, self.ps3))

        for nr, ps in fixtures:

            custom = BufShrScheme(nr, ps, loops)

            self.assertTupleEqual(custom.dim(de.IFM), custom.dim(de.OFM))
            self.assertTupleEqual(custom.nbr_dists[de.IFM],
                                  custom.nbr_dists[de.OFM])

    def test_data_loops_all_lpe(self):
        ''' IFM data_loops with all LoopEnum gives dim (1, 1), inf dists. '''
        loops = [None] * de.NUM
        loops[de.FIL] = DataDimLoops(le.IFM, le.OFM)
        loops[de.IFM] = DataDimLoops(le.IFM, le.OFM, le.BAT)
        loops[de.OFM] = DataDimLoops(le.OFM, le.BAT)

        custom = BufShrScheme(self.nr1, self.ps1, loops)

        self.assertTupleEqual(custom.dim(de.IFM), (1, 1))
        for dist in custom.nbr_dists[de.IFM]:
            self.assertTrue(math.isinf(dist))

    def test_mismatch_node_region(self):
        ''' Mismatched node region and partition scheme in constructor.

        A region smaller than the scheme is rejected with ValueError, while a
        larger region is accepted and leaves the IFM dimension unchanged.
        '''
        # Smaller node region. Invalid.
        with self.assertRaisesRegex(ValueError, 'BufShrScheme: .*region.*'):
            _ = BufShrScheme(
                NodeRegion(origin=PhyDim2(0, 0),
                           dim=PhyDim2(1, 1),
                           type=NodeRegion.PROC), self.ps1)

        # Larger node region. Valid.
        bufshr = BufShrScheme(
            NodeRegion(origin=PhyDim2(0, 0),
                       dim=PhyDim2(100, 100),
                       type=NodeRegion.PROC), self.ps1)
        self.assertTupleEqual(bufshr.dim(de.IFM), self.ps1.dim(pe.OUTP))

    def test_nhops_rotate_all(self):
        ''' nhops_rotate_all of FIL in bufshr3 for each subgroup size. '''
        bufshr = self.bufshr3
        dce = de.FIL

        # The expected values below rely on a 4 by 2 dimension with neighbor
        # distances 3 and 5.
        self.assertTupleEqual(bufshr.dim(dce), (4, 2))
        self.assertTupleEqual(bufshr.nbr_dists[dce], (3, 5))

        # (subgroup size, expected nhops), checked in the listed order.
        cases = (
            # Subgroup as 4 by 2. The whole circle is six hops of 3 and two
            # hops of 5, but only 7 of 8 steps.
            (8, (3 * 6 + 5 * 2) * 7 / 8.),
            # Subgroup as 4 by 1. One node does three hops of 3, and other
            # three nodes do two hops of 3 and one hop of 9 (looping back).
            (4, ((3 * 3) + (3 * 2 + 9) * 3) / 4. * 2),
            # Subgroup as 2 by 1. All nodes do one hop of 3.
            (2, (3 + 3) / 2. * 4),
            # Subgroup as 1. No rotation.
            (1, 0),
            # Subgroup size 3. One node does two hops of 3 and two do one hop
            # of 3 and 6 each. The 3rd node also sends to the 4th one with
            # two hops of 3.
            (3, ((3 * 2) + (3 + 6) * 2 + (3 * 2)) / 3. * 2),
            # Subgroup size 5. The 1st node does three hops of 3 and one hop
            # of 5. The 2nd, 3rd, and 4th nodes do two hops of 3, and one hop
            # of 5, and one looping back from the 5th node to the 1st node.
            # The 5th node does one looping back and three hops of 3.
            # Finally, the 5th node also sends to the 6th to 8th nodes.
            (5, ((3 * 3 + 5) + (3 * 2 + 5 + (3 * 3 + 5)) * 3 +
                 ((3 * 3 + 5) + 3 * 3) + 3 * 3 * 4) / 5.),
            # The others are similar.
            (6, ((3 * 4 + 5) + (3 * 3 + 5 + (3 * 2 + 5)) * 4 +
                 ((3 * 2 + 5) + 3 * 4) + 3 * 2 * 5) / 6.),
            (7, ((3 * 5 + 5) + (3 * 4 + 5 + (3 * 1 + 5)) * 5 +
                 ((3 * 1 + 5) + 3 * 5) + 3 * 1 * 6) / 7.),
        )

        for subgrp_size, expected in cases:
            self.assertAlmostEqual(bufshr.nhops_rotate_all(dce, subgrp_size),
                                   expected)

    def test_nhops_rotate_all_invalid(self):
        ''' nhops_rotate_all raises ValueError for an oversized subgroup. '''
        # One more than the FIL size of bufshr3.
        with self.assertRaisesRegex(ValueError, 'BufShrScheme: .*subgroup.*'):
            _ = self.bufshr3.nhops_rotate_all(de.FIL,
                                              self.bufshr3.size(de.FIL) + 1)

    def test_nhops_rotate_all_rot_unit(self):
        ''' nhops_rotate_all with an explicit rotation unit count. '''
        bufshr = self.bufshr3
        dce = de.FIL
        self.assertTupleEqual(bufshr.dim(dce), (4, 2))

        for size in range(1, bufshr.size(dce)):

            baseline = bufshr.nhops_rotate_all(dce, size)

            # A unit count of at least the subgroup size gives the same
            # nhops as not passing it.
            for unit_cnt in range(size, 32):
                with_units = bufshr.nhops_rotate_all(dce, size, unit_cnt)
                self.assertEqual(with_units, baseline)

            # A unit count below the subgroup size gives fewer nhops.
            for unit_cnt in range(1, size):
                with_units = bufshr.nhops_rotate_all(dce, size, unit_cnt)
                self.assertLess(with_units, baseline)

    def test_nhops_rotate_all_cache(self):
        ''' nhops_cache grows by one entry per new nhops_rotate_all call. '''
        bufshr = self.bufshr3

        # The cache starts out empty.
        self.assertFalse(bufshr.nhops_cache)

        def _check_cached(dce, subgrp_sizes, num_entries):
            # Query each size once, then repeat the queries; the results
            # must agree and the cache must hold `num_entries` throughout.
            first = [bufshr.nhops_rotate_all(dce, size)
                     for size in subgrp_sizes]
            self.assertEqual(len(bufshr.nhops_cache), num_entries)
            for size, nhops in zip(subgrp_sizes, first):
                self.assertEqual(nhops, bufshr.nhops_rotate_all(dce, size))
            self.assertEqual(len(bufshr.nhops_cache), num_entries)

        _check_cached(de.FIL, (8, 4, 1), 3)
        _check_cached(de.IFM, (3, 2), 5)

        # Passing a rotation unit count adds a separate cache entry.
        with_rot_unit = bufshr.nhops_rotate_all(de.IFM, 3, 2)

        self.assertEqual(len(bufshr.nhops_cache), 6)
        self.assertEqual(with_rot_unit, bufshr.nhops_rotate_all(de.IFM, 3, 2))
        self.assertEqual(len(bufshr.nhops_cache), 6)

    def test_nhops_wide_fetch_once(self):
        ''' nhops_wide_fetch_once of FIL in bufshr3 for several widths. '''
        bufshr = self.bufshr3
        dce = de.FIL

        # The expected values below rely on a 4 by 2 dimension with neighbor
        # distances 3 and 5.
        self.assertTupleEqual(bufshr.dim(dce), (4, 2))
        self.assertTupleEqual(bufshr.nbr_dists[dce], (3, 5))

        # Fetch width 1 gives zero nhops.
        for subgrp_size in range(bufshr.size(dce)):
            nhops = bufshr.nhops_wide_fetch_once(dce, subgrp_size, 1)
            self.assertAlmostEqual(nhops, 0)

        # Fetch width 2: (subgroup size, expected nhops times 2).
        cases = (
            # Three nodes fetch one hop of 3, and the last node fetches one
            # hop of 9 (looping back).
            (4, (3 * 3 + 9) / 4. * 2),
            # Two nodes fetch one hop of 3, and the 3rd node fetches one hop
            # of 6 (looping back). The last node fetches one hop of 3 from
            # the 3rd.
            (3, (3 * 2 + 6 + 3) / 3. * 2),
            # All nodes do one hop of 3.
            (2, (3 + 3) / 2. * 4),
        )
        for subgrp_size, expected in cases:
            nhops = bufshr.nhops_wide_fetch_once(dce, subgrp_size, 2)
            self.assertAlmostEqual(nhops * 2, expected)

        # Fetch width 1.5: the result times 1.5 equals the width-2 result.
        for subgrp_size in range(2, bufshr.size(dce)):
            scaled_1p5 = \
                bufshr.nhops_wide_fetch_once(dce, subgrp_size, 1.5) * 1.5
            ref_2 = bufshr.nhops_wide_fetch_once(dce, subgrp_size, 2) * 2. / 2.
            self.assertAlmostEqual(scaled_1p5, ref_2)

    def test_nhops_wide_fetch_once_inv(self):
        ''' nhops_wide_fetch_once raises ValueError for invalid arguments. '''
        # Subgroup size one more than the FIL size of bufshr3.
        with self.assertRaisesRegex(ValueError, 'BufShrScheme: .*subgroup.*'):
            _ = self.bufshr3.nhops_wide_fetch_once(
                de.FIL,
                self.bufshr3.size(de.FIL) + 1, 2)

        # Fetch width one more than the subgroup size.
        # NOTE(review): `/` yields floats on Python 3 -- confirm that float
        # subgroup size and width are intended here.
        with self.assertRaisesRegex(ValueError, 'BufShrScheme: .*width.*'):
            _ = self.bufshr3.nhops_wide_fetch_once(
                de.FIL,
                self.bufshr3.size(de.FIL) / 2,
                self.bufshr3.size(de.FIL) / 2 + 1)

    def test_repr(self):
        ''' repr of each BufShrScheme contains repr of its scheme. '''
        pairs = ((self.ps1, self.bufshr1),
                 (self.ps2, self.bufshr2),
                 (self.ps3, self.bufshr3))
        for ps, bufshr in pairs:
            self.assertIn(repr(ps), repr(bufshr))