def test_data_loops_all_lpe(self):
        ''' IFM data_loops covering every LoopEnum leave nothing to share. '''
        loops_by_dce = {
            de.FIL: DataDimLoops(le.IFM, le.OFM),
            de.IFM: DataDimLoops(le.IFM, le.OFM, le.BAT),
            de.OFM: DataDimLoops(le.OFM, le.BAT),
        }
        # Categories without an entry stay as None.
        data_loops = [loops_by_dce.get(dce) for dce in range(de.NUM)]

        scheme = BufShrScheme(self.nr1, self.ps1, data_loops)

        # IFM sharing dimensions collapse to 1x1 ...
        self.assertTupleEqual(scheme.dim(de.IFM), (1, 1))
        # ... and all of its neighbor distances are infinite.
        ifm_dists = scheme.nbr_dists[de.IFM]
        self.assertTrue(all(map(math.isinf, ifm_dists)))
    def test_bufshr_wide_fetch_example(self):
        ''' Hand-made bufshr schemes whose IFM uses a wide fetch. '''

        # Partitioning picked so that IFM can use buffer sharing.
        part = PartitionScheme(order=range(pe.NUM),
                               pdims=((2, 2), (1, 1), (2, 1), (1, 1)))
        bufshr = BufShrScheme(self.par_proc_region, part)
        ifm_shr_size = bufshr.size(de.IFM)
        part_msg = ('test_bufshr_wide_fetch_example: '
                    'made-up PartitionScheme is not expected: '
                    '{}, bufshr size for {} {}.')
        self.assertEqual(ifm_shr_size, 4,
                         part_msg.format(part, de.IFM, ifm_shr_size))

        # Shared prefix of the sanity-check messages on the blocking scheme.
        lbs_prefix = ('test_bufshr_wide_fetch_example: '
                      'made-up LoopBlockingScheme is not expected: ')

        factor_pairs = [((3, 3, 1), (1, 1, 2)), ((1, 3, 2), (3, 1, 1))]
        for mid_ts, low_ts in factor_pairs:
            # Build a LoopBlockingScheme with wide fetch for IFM; the top
            # factors are whatever is left after the two given levels.
            p_nld = self._part_nld(part)
            top_ts = tuple(
                util.idivc(p_nld.loopcnt[lpe], mid_ts[lpe] * low_ts[lpe])
                for lpe in range(le.NUM))
            bl_ts = (top_ts, mid_ts, low_ts)
            # At GBUF level, from inner to outer: le.BAT, le.IFM, le.OFM.
            bl_ords = (tuple(range(le.NUM)), (1, 2, 0))
            lbs = LoopBlockingScheme(p_nld, bl_ts, bl_ords,
                                     self.resource['PAR'], bufshr,
                                     self.options['BUFSHR'])
            self.assertTrue(lbs.is_valid())
            self.assertGreater(sum(lbs.get_noc_access()), 0)

            blocking = (bl_ts, bl_ords)

            subgrp_size = lbs.bufshr_subgrp_size[de.IFM]
            self.assertEqual(
                subgrp_size, 4,
                (lbs_prefix + '{}, bufshr subgrp size for {} {}.').format(
                    blocking, de.IFM, subgrp_size))

            fetch_width = lbs.bufshr_wide_fetch_width[de.IFM]
            self.assertGreater(
                fetch_width, 1,
                (lbs_prefix + '{}, bufshr wide fetch width for {} {}.').format(
                    blocking, de.IFM, fetch_width))

            rot_rounds = lbs.bufshr_rot_round_cnt[de.IFM]
            self.assertGreater(
                rot_rounds, 0,
                (lbs_prefix + '{}, bufshr rotation rounds for {} {}').format(
                    blocking, de.IFM, rot_rounds))

            # Simulate the accesses and cross-check the bufshr statistics.
            sim_result = self._sim_access_conv(lbs, get_bufshr=True)
            dram_access, gbuf_access, bufshr_stats = sim_result

            self._verify_bufshr_stats(dram_access, gbuf_access, bufshr_stats,
                                      lbs, bufshr,
                                      'test_bufshr_wide_fetch_example')
    def test_mismatch_node_region(self):
        ''' Mismatched node region and part in constructor.

        A node region smaller than the partition scheme must be rejected with
        a ValueError; a larger one is accepted.
        '''
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12. Prefer `assertRaisesRegex`, and only fall back to the
        # old spelling on interpreters that do not provide the new one.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None) \
            or self.assertRaisesRegexp

        # Smaller node region. Invalid.
        with assert_raises_regex(ValueError, 'BufShrScheme: .*region.*'):
            _ = BufShrScheme(
                NodeRegion(origin=PhyDim2(0, 0),
                           dim=PhyDim2(1, 1),
                           type=NodeRegion.PROC), self.ps1)

        # Larger node region. Valid.
        bufshr = BufShrScheme(
            NodeRegion(origin=PhyDim2(0, 0),
                       dim=PhyDim2(100, 100),
                       type=NodeRegion.PROC), self.ps1)
        # The sharing dimensions still follow the partition scheme.
        self.assertTupleEqual(bufshr.dim(de.IFM), self.ps1.dim(pe.OUTP))
    def test_data_loops(self):
        ''' Custom data_loops: IFM given the same loops as OFM. '''
        loops_by_dce = {
            de.FIL: DataDimLoops(le.IFM, le.OFM),
            de.IFM: DataDimLoops(le.OFM, le.BAT),
            de.OFM: DataDimLoops(le.OFM, le.BAT),
        }
        # Categories without an entry stay as None.
        data_loops = [loops_by_dce.get(dce) for dce in range(de.NUM)]

        fixtures = ((self.nr1, self.ps1),
                    (self.nr2, self.ps2),
                    (self.nr3, self.ps3))
        for region, part in fixtures:
            scheme = BufShrScheme(region, part, data_loops)

            # Identical loops must give identical sharing geometry.
            self.assertTupleEqual(scheme.dim(de.IFM), scheme.dim(de.OFM))
            self.assertTupleEqual(scheme.nbr_dists[de.IFM],
                                  scheme.nbr_dists[de.OFM])
    def test_bufshr_access_byp(self):
        ''' Access of scheme using bufshr with bypassing.

        For every generated partition, enumerate the loop blocking schemes
        under the 'BUFSHR-BYP' options, keep only the valid ones that both
        use buffer sharing and bypass the global buffer for some data
        category, simulate them, and verify the buffer sharing statistics.
        '''

        for part in self._gen_all_partition():

            p_nld = self._part_nld(part)

            bufshr = BufShrScheme(self.par_proc_region, part)

            for lbs in loop_blocking.gen_loopblocking(
                    p_nld, self.resource['PAR'], part, self.none_cstr,
                    self.cost, self.options['BUFSHR-BYP']):
                if not lbs.is_valid():
                    continue

                # Skip those without bufshr.
                if all(sgs <= 1 for sgs in lbs.bufshr_subgrp_size):
                    continue
                # Skip those without bypassing.
                if all(lbs.stored_in_gbuf):
                    continue

                # Sim.
                dram_access, gbuf_access, bufshr_stats = \
                        self._sim_access_conv(lbs, get_bufshr=True)

                # Pass this test's own name as the label, like the other
                # tests do, so a verification failure is attributed to the
                # right test.
                self._verify_bufshr_stats(dram_access, gbuf_access,
                                          bufshr_stats, lbs, bufshr,
                                          'test_bufshr_access_byp')
    def test_default_data_loops(self):
        ''' Explicit data_loops reproduce the schemes built without them. '''
        loops_by_dce = {
            de.FIL: DataDimLoops(le.IFM, le.OFM),
            de.IFM: DataDimLoops(le.IFM, le.BAT),
            de.OFM: DataDimLoops(le.OFM, le.BAT),
        }
        # Categories without an entry stay as None.
        data_loops = [loops_by_dce.get(dce) for dce in range(de.NUM)]

        fixtures = ((self.bufshr1, self.nr1, self.ps1),
                    (self.bufshr2, self.nr2, self.ps2),
                    (self.bufshr3, self.nr3, self.ps3))
        for expected, region, part in fixtures:
            rebuilt = BufShrScheme(region, part, data_loops)

            # Every data category must agree with the fixture scheme.
            for dce in range(de.NUM):
                self.assertTupleEqual(expected.dim(dce), rebuilt.dim(dce))
                self.assertTupleEqual(expected.nbr_dists[dce],
                                      rebuilt.nbr_dists[dce])
    def test_bufshr_multisubgrp_example(self):
        ''' Hand-made bufshr scheme with several subgroups per group. '''

        # Partitioning picked so that IFM can use buffer sharing.
        part = PartitionScheme(order=list(reversed(range(pe.NUM))),
                               pdims=((2, 2), (1, 1), (2, 1), (1, 1)))
        bufshr = BufShrScheme(self.par_proc_region, part)
        ifm_shr_size = bufshr.size(de.IFM)
        part_msg = ('test_bufshr_multisubgrp_example: '
                    'made-up PartitionScheme is not expected: '
                    '{}, bufshr size for {} {}.')
        self.assertEqual(ifm_shr_size, 4,
                         part_msg.format(part, de.IFM, ifm_shr_size))

        # Blocking factors giving IFM multiple subgroups within a group.
        p_nld = self._part_nld(part)
        top_ifm = util.idivc(p_nld.loopcnt[le.IFM], 1)
        top_ofm = util.idivc(p_nld.loopcnt[le.OFM], 3)
        top_bat = util.idivc(p_nld.loopcnt[le.BAT], 2)
        bl_ts = ((top_ifm, top_ofm, top_bat), (1, 3, 2), (1, 1, 1))
        # At GBUF level, from inner to outer: le.BAT, le.OFM, le.IFM.
        bl_ords = (tuple(range(le.NUM)), (2, 1, 0))
        lbs = LoopBlockingScheme(p_nld, bl_ts, bl_ords, self.resource['PAR'],
                                 bufshr, self.options['BUFSHR'])
        self.assertTrue(lbs.is_valid())
        self.assertGreater(sum(lbs.get_noc_access()), 0)

        blocking = (bl_ts, bl_ords)
        lbs_prefix = ('test_bufshr_multisubgrp_example: '
                      'made-up LoopBlockingScheme is not expected: ')

        # The group must be strictly larger than the subgroup.
        self.assertGreater(
            lbs.bufshr_grp_size[de.IFM], lbs.bufshr_subgrp_size[de.IFM],
            (lbs_prefix
             + '{}, bufshr grp size {}, bufshr subgrp size {}').format(
                 blocking, lbs.bufshr_grp_size, lbs.bufshr_subgrp_size))

        rot_rounds = lbs.bufshr_rot_round_cnt[de.IFM]
        self.assertGreater(
            rot_rounds, 0,
            (lbs_prefix + '{}, bufshr rotation rounds for {} {}').format(
                blocking, de.IFM, rot_rounds))

        # Simulate the accesses and cross-check the bufshr statistics.
        sim_result = self._sim_access_conv(lbs, get_bufshr=True)
        dram_access, gbuf_access, bufshr_stats = sim_result

        self._verify_bufshr_stats(dram_access, gbuf_access, bufshr_stats, lbs,
                                  bufshr, 'test_bufshr_multisubgrp_example')
    def test_accfwd(self):
        ''' Reductions and top-level accesses with access forwarding. '''

        for part in self._gen_all_partition():

            p_nld = self._part_nld(part)
            part_sizes = self._total_part_size(part)
            filter_size, ifmap_size, ofmap_size = part_sizes
            bufshr = BufShrScheme(self.par_proc_region, part)

            # Filter may still have redundant fetch.
            fil_fetch = part.size(pe.BATP, pe.OFMP) // bufshr.size(de.FIL)

            candidates = loop_blocking.gen_loopblocking(
                p_nld, self.resource['PAR'], part, self.none_cstr,
                self.cost, self.options['ACCFWD'])

            # Only valid schemes are checked.
            for lbs in (cand for cand in candidates if cand.is_valid()):

                # Ops.
                self.assertAlmostEqual(lbs.ops, self.total_ops)

                # Access forwarding reduction, per data category.
                reduction = lbs.accfwd_reduction
                fil_red = part.size(pe.BATP, pe.OFMP) // fil_fetch
                self.assertEqual(reduction[de.FIL], fil_red)
                self.assertEqual(reduction[de.OFM], part.size(pe.INPP))
                self.assertEqual(reduction[de.IFM], part.size(pe.OUTP))

                # Top fetch and access.
                fetch0 = lbs.fetch[0]
                access0 = lbs.access[0]
                fil_expected = fetch0[de.FIL] * filter_size * fil_fetch
                self.assertAlmostEqual(access0[de.FIL], fil_expected)
                ofm_expected = fetch0[de.OFM] * ofmap_size
                self.assertAlmostEqual(access0[de.OFM], ofm_expected)
                # IFM access is only bounded from below.
                ifm_lower = fetch0[de.IFM] * ifmap_size
                self.assertGreaterEqual(access0[de.IFM], ifm_lower)
    def test_bufshr(self):
        ''' Group and subgroup sizes of schemes using buffer sharing. '''

        for part in self._gen_all_partition():

            p_nld = self._part_nld(part)
            bufshr = BufShrScheme(self.par_proc_region, part)

            # Filter may still have redundant fetch.
            fil_fetch = part.size(pe.BATP, pe.OFMP) // bufshr.size(de.FIL)

            # Check both plain buffer sharing and sharing with bypassing.
            for optkey in ('BUFSHR', 'BUFSHR-BYP'):

                candidates = loop_blocking.gen_loopblocking(
                    p_nld, self.resource['PAR'], part, self.none_cstr,
                    self.cost, self.options[optkey])

                for lbs in candidates:
                    if not lbs.is_valid():
                        continue

                    # Ops.
                    self.assertAlmostEqual(lbs.ops, self.total_ops)

                    # Buffer sharing uses access forwarding reduction.
                    reduction = lbs.accfwd_reduction
                    fil_red = part.size(pe.BATP, pe.OFMP) // fil_fetch
                    self.assertEqual(reduction[de.FIL], fil_red)
                    self.assertEqual(reduction[de.OFM], part.size(pe.INPP))
                    self.assertEqual(reduction[de.IFM], part.size(pe.OUTP))

                    # Buffer sharing group size equals the reduction.
                    grp_sizes = lbs.bufshr_grp_size
                    self.assertSequenceEqual(grp_sizes, reduction)

                    # No subgroup may exceed its group.
                    subgrp_sizes = lbs.bufshr_subgrp_size
                    within_group = all(
                        sub <= grp
                        for sub, grp in zip(subgrp_sizes, grp_sizes))
                    self.assertTrue(within_group)
    def test_bufshr_rotation_example(self):
        ''' Hand-made bufshr scheme that rotates all data categories. '''

        # Partitioning picked so every data category can share buffers.
        part = PartitionScheme(order=range(pe.NUM),
                               pdims=((2, 1), (1, 2), (1, 1), (2, 1)))
        bufshr = BufShrScheme(self.par_proc_region, part)
        part_msg = ('test_bufshr_rotation_example: '
                    'made-up PartitionScheme is not expected: '
                    '{}, bufshr size {}')
        self.assertTrue(
            all(bufshr.size(dce) > 1 for dce in range(de.NUM)),
            part_msg.format(part,
                            [bufshr.size(dce) for dce in range(de.NUM)]))

        # Blocking factors making every data category use bufshr.
        p_nld = self._part_nld(part)
        top_ifm = util.idivc(p_nld.loopcnt[le.IFM], 6)
        top_ofm = util.idivc(p_nld.loopcnt[le.OFM], 9)
        top_bat = util.idivc(p_nld.loopcnt[le.BAT], 2)
        bl_ts = ((top_ifm, top_ofm, top_bat), (3, 3, 2), (2, 3, 1))
        natural_ord = tuple(range(le.NUM))
        bl_ords = (natural_ord, tuple(range(le.NUM)))
        lbs = LoopBlockingScheme(p_nld, bl_ts, bl_ords, self.resource['PAR'],
                                 bufshr, self.options['BUFSHR'])
        self.assertTrue(lbs.is_valid())
        self.assertGreater(sum(lbs.get_noc_access()), 0)

        # All subgroups and all top factors must be non-trivial.
        lbs_msg = ('test_bufshr_rotation_example: '
                   'made-up LoopBlockingScheme is not expected: '
                   '{}, top factors {}, bufshr subgrp size {}')
        subgrps_shared = all(sgs > 1 for sgs in lbs.bufshr_subgrp_size)
        tops_blocked = all(t > 1 for t in bl_ts[0])
        self.assertTrue(
            subgrps_shared and tops_blocked,
            lbs_msg.format((bl_ts, bl_ords), bl_ts[0],
                           lbs.bufshr_subgrp_size))

        # Simulate the accesses and cross-check the bufshr statistics.
        sim_result = self._sim_access_conv(lbs, get_bufshr=True)
        dram_access, gbuf_access, bufshr_stats = sim_result

        self._verify_bufshr_stats(dram_access, gbuf_access, bufshr_stats, lbs,
                                  bufshr, 'test_bufshr_rotation_example')
    def test_dim_fil(self):
        ''' FIL sharing dimensions for a range of partitioning layouts. '''
        # Each case: (partitioning order, pdims, expected side length of the
        # square FIL sharing dimensions).
        cases = [
            # Adjacent, BATP upon OFMP.
            ([pe.INPP, pe.OUTP, pe.BATP, pe.OFMP],
             [(2, 2), (5, 5), (3, 3), (7, 7)], 15),
            # Adjacent, OFMP upon BATP.
            ([pe.INPP, pe.OFMP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (3, 3), (7, 7)], 15),
            # Not adjacent, BATP upon OFMP.
            ([pe.OUTP, pe.BATP, pe.INPP, pe.OFMP],
             [(2, 2), (5, 5), (3, 3), (7, 7)], 5),
            # Not adjacent, OFMP upon BATP.
            ([pe.OFMP, pe.INPP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (3, 3), (7, 7)], 3),
            # Only BATP.
            ([pe.OUTP, pe.BATP, pe.INPP, pe.OFMP],
             [(2, 2), (1, 1), (3, 3), (7, 7)], 3),
            # Only OFMP.
            ([pe.OFMP, pe.INPP, pe.BATP, pe.OUTP],
             [(2, 2), (5, 5), (1, 1), (7, 7)], 5),
        ]

        for order, pdims, side in cases:
            part = PartitionScheme(order=order, pdims=pdims)
            # Node region sized exactly to the partition scheme.
            region = NodeRegion(origin=PhyDim2(0, 0),
                                dim=part.dim(),
                                type=NodeRegion.PROC)
            fil_dim = BufShrScheme(region, part).dim(de.FIL)
            self.assertTupleEqual(fil_dim, (side, side))
# Example #12
# 0
    def setUp(self):
        ''' Build three partition schemes with matching regions and bufshr. '''

        def region_of(part):
            ''' PROC node region at the origin, sized to the given scheme. '''
            return NodeRegion(origin=PhyDim2(0, 0), dim=part.dim(),
                              type=NodeRegion.PROC)

        # Partition schemes: one with a custom order, two in natural order.
        self.ps1 = PartitionScheme(
            order=[pe.BATP, pe.OUTP, pe.OFMP, pe.INPP],
            pdims=[(2, 3), (3, 1), (1, 5), (5, 2)])
        self.ps2 = PartitionScheme(
            order=range(pe.NUM),
            pdims=[(2, 2), (5, 5), (3, 3), (1, 1)])
        self.ps3 = PartitionScheme(
            order=range(pe.NUM),
            pdims=[(1, 6), (1, 2), (4, 1), (3, 5)])

        # Node regions exactly covering each partition scheme.
        self.nr1 = region_of(self.ps1)
        self.nr2 = region_of(self.ps2)
        self.nr3 = region_of(self.ps3)

        # Buffer sharing schemes built without explicit data_loops.
        self.bufshr1 = BufShrScheme(self.nr1, self.ps1)
        self.bufshr2 = BufShrScheme(self.nr2, self.ps2)
        self.bufshr3 = BufShrScheme(self.nr3, self.ps3)
    def setUp(self):
        ''' Build the workloads, resources, loop nests and knobs for tests. '''

        # Workload.
        self.layer = {
            'BASE': ConvLayer(12, 10, 28, 3),
            'LGFIL': ConvLayer(2, 4, 28, 20),
            'POOL': PoolingLayer(32, 28, 2),
            'PAR': ConvLayer(24, 36, 56, 3),
        }
        self.batch_size = 4

        # Building blocks shared by the resources.
        dim_array = PhyDim2(16, 16)
        proc_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=PhyDim2(1, 1),
                                 type=NodeRegion.PROC)
        data_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=PhyDim2(1, 1),
                                 type=NodeRegion.DRAM)
        # Multi-node processing region for the parallel resource.
        par_proc_region = NodeRegion(origin=PhyDim2(0, 0),
                                     dim=PhyDim2(4, 2),
                                     type=NodeRegion.PROC)
        # PROC-typed region used in place of a source/destination data region.
        proc_data_region = NodeRegion(origin=PhyDim2(1, 1),
                                      dim=PhyDim2(1, 1),
                                      type=NodeRegion.PROC)

        def make_resource(size_gbuf, size_regf, **overrides):
            ''' Resource with the common settings unless overridden. '''
            kwargs = {
                'proc_region': proc_region,
                'dram_region': data_region,
                'src_data_region': data_region,
                'dst_data_region': data_region,
                'dim_array': dim_array,
                'size_gbuf': size_gbuf,
                'size_regf': size_regf,
                'array_bus_width': float('inf'),
                'dram_bandwidth': float('inf'),
                'no_time_mux': False,
            }
            kwargs.update(overrides)
            return Resource(**kwargs)

        huge = 1024**3

        # Resource.
        self.resource = {
            # Typical resource.
            'BASE': make_resource(65536, 64),
            # Larger resource with sufficient capacity, to make all schemes
            # valid.
            'LG': make_resource(huge, huge),
            # Small resource.
            'SM': make_resource(4096, 16),
            # Multi-node parallel resource.
            'PAR': make_resource(25000, 64, proc_region=par_proc_region),
            # Resources whose source and/or destination is not a data region.
            'SRCNOTDATA': make_resource(huge, huge,
                                        src_data_region=proc_data_region),
            'DSTNOTDATA': make_resource(huge, huge,
                                        dst_data_region=proc_data_region),
            'DATALOCAL': make_resource(huge, huge,
                                       src_data_region=proc_region,
                                       dst_data_region=proc_region),
            # Filter pinning.
            'FILPIN': make_resource(huge, huge, no_time_mux=True),
        }

        # Nested loop description after mapping; take the first one that the
        # mapping strategy generates for each layer.
        self.nld = {}
        for key in ('BASE', 'LGFIL', 'POOL'):
            mapper = MapStrategyEyeriss(self.layer[key], self.batch_size, 1,
                                        dim_array)
            self.nld[key] = next(mapper.gen_nested_loop_desc())

        def fake_nld(usize_gbuf, usize_regf):
            ''' Fake nested loop with the given unit buffer sizes. '''
            return NestedLoopDesc(
                loopcnt=(12, 10, 4),
                usize_gbuf=usize_gbuf,
                usize_regf=usize_regf,
                # The two outermost levels access the GBUF unit sizes.
                unit_access=(usize_gbuf, usize_gbuf, (3, 9, 7), (1, 1, 1)),
                data_loops=(DataDimLoops(le.IFM, le.OFM),
                            DataDimLoops(le.IFM, le.BAT),
                            DataDimLoops(le.OFM, le.BAT)),
                unit_ops=1,
                unit_time=1)

        # Fake nested loop, with zero filter size.
        self.nld['ZERO_FIL'] = fake_nld((0, 1000, 800), (0, 3, 1))
        # Fake nested loop, with zero ifmap size.
        self.nld['ZERO_IFM'] = fake_nld((9, 0, 800), (3, 0, 1))

        # Fake partition scheme.
        self.part = PartitionScheme(range(pe.NUM), ((1, 1), ) * pe.NUM)

        # Fake buffer sharing scheme.
        self.bufshr = BufShrScheme(proc_region, self.part)

        # Options.
        all_bypass = (True, ) * 3
        max_tops = 2**30
        self.options = {
            # Basic.
            'BASE': Option(ntops=max_tops),
            # Multiprocessing.
            'MP': Option(ntops=max_tops, nprocesses=8),
            # Limited top schemes.
            'NTOPS': Option(ntops=10),
            # Bypass.
            'BYP': Option(sw_gbuf_bypass=all_bypass, ntops=max_tops),
            # Bypass solver.
            'BYPSOL': Option(sw_gbuf_bypass=all_bypass,
                             sw_solve_loopblocking=True,
                             ntops=max_tops),
            # Access forwarding.
            'ACCFWD': Option(hw_access_forwarding=True, ntops=max_tops),
            # Buffer sharing.
            'BUFSHR': Option(hw_gbuf_sharing=True, ntops=max_tops),
            # Buffer sharing with bypassing.
            'BUFSHR-BYP': Option(sw_gbuf_bypass=all_bypass,
                                 hw_gbuf_sharing=True,
                                 ntops=max_tops),
        }

        # Constraint.
        self.none_cstr = SchedulingConstraint()
        self.cstr = SchedulingConstraint(topifm=1, topbat=1)

        # Cost.
        self.cost = Cost(mac_op=1,
                         mem_hier=(200, 6, 2, 1),
                         noc_hop=50,
                         idl_unit=50)
class TestBufShrScheme(unittest.TestCase):
    ''' Tests for BufShrScheme. '''
    def setUp(self):
        ''' Create three partition schemes with fitting regions and bufshr.

        For every partition scheme `psN`, a PROC node region `nrN` rooted at
        (0, 0) with exactly the total dimension of the scheme is built, and
        from the two a buffer sharing scheme `bufshrN` without explicit
        data_loops.
        '''

        def proc_region_of(part):
            ''' PROC region at the origin that exactly fits `part`. '''
            return NodeRegion(origin=PhyDim2(0, 0),
                              dim=part.dim(),
                              type=NodeRegion.PROC)

        # Partition schemes.
        self.ps1 = PartitionScheme(
            pdims=[(2, 3), (3, 1), (1, 5), (5, 2)],
            order=[pe.BATP, pe.OUTP, pe.OFMP, pe.INPP])
        self.ps2 = PartitionScheme(
            pdims=[(2, 2), (5, 5), (3, 3), (1, 1)],
            order=range(pe.NUM))
        self.ps3 = PartitionScheme(
            pdims=[(1, 6), (1, 2), (4, 1), (3, 5)],
            order=range(pe.NUM))

        # Node regions, one per partition scheme.
        self.nr1 = proc_region_of(self.ps1)
        self.nr2 = proc_region_of(self.ps2)
        self.nr3 = proc_region_of(self.ps3)

        # Buffer sharing schemes under test.
        self.bufshr1 = BufShrScheme(self.nr1, self.ps1)
        self.bufshr2 = BufShrScheme(self.nr2, self.ps2)
        self.bufshr3 = BufShrScheme(self.nr3, self.ps3)

    def test_dim(self):
        ''' Accessor dim.

        For all three fixtures, the IFM sharing dimension must equal the OUTP
        partitioning dimension and the OFM sharing dimension must equal the
        INPP one. The FIL sharing dimension is compared against OFMP alone for
        the first fixture and against OFMP combined with BATP for the others.
        '''
        fixtures = ((self.bufshr1, self.ps1),
                    (self.bufshr2, self.ps2),
                    (self.bufshr3, self.ps3))

        for scheme, part in fixtures:
            self.assertTupleEqual(scheme.dim(de.IFM), part.dim(pe.OUTP))
            self.assertTupleEqual(scheme.dim(de.OFM), part.dim(pe.INPP))

        # FIL, first fixture: only OFMP contributes.
        fil_dim_1 = self.bufshr1.dim(de.FIL)
        self.assertTupleEqual(fil_dim_1, self.ps1.dim(pe.OFMP))

        # FIL, remaining fixtures: OFMP and BATP are combined.
        fil_dim_2 = self.bufshr2.dim(de.FIL)
        self.assertTupleEqual(fil_dim_2, self.ps2.dim(pe.OFMP, pe.BATP))
        fil_dim_3 = self.bufshr3.dim(de.FIL)
        self.assertTupleEqual(fil_dim_3, self.ps3.dim(pe.OFMP, pe.BATP))

    def test_size(self):
        ''' Get size.

        For every fixture and every data category, `size(dce)` must agree with
        the size of the dimension returned by `dim(dce)`.
        '''
        for scheme in (self.bufshr1, self.bufshr2, self.bufshr3):
            for dce in range(de.NUM):
                size_from_dim = scheme.dim(dce).size()
                self.assertEqual(size_from_dim, scheme.size(dce))

    def test_dim_fil(self):
        ''' Accessor dim with different partitioning for FIL.

        Every case builds a partition scheme from an order and per-partition
        dimensions, wraps it in a PROC region that exactly fits it, and checks
        the FIL sharing dimension of the resulting buffer sharing scheme.
        '''
        all_non_trivial = ((2, 2), (5, 5), (3, 3), (7, 7))

        # Each entry: (partition order, partition dims, expected FIL dim).
        cases = (
            # Adjacent, BATP upon OFMP.
            ((pe.INPP, pe.OUTP, pe.BATP, pe.OFMP), all_non_trivial, (15, 15)),
            # Adjacent, OFMP upon BATP.
            ((pe.INPP, pe.OFMP, pe.BATP, pe.OUTP), all_non_trivial, (15, 15)),
            # Not adjacent, BATP upon OFMP.
            ((pe.OUTP, pe.BATP, pe.INPP, pe.OFMP), all_non_trivial, (5, 5)),
            # Not adjacent, OFMP upon BATP.
            ((pe.OFMP, pe.INPP, pe.BATP, pe.OUTP), all_non_trivial, (3, 3)),
            # Only BATP.
            ((pe.OUTP, pe.BATP, pe.INPP, pe.OFMP),
             ((2, 2), (1, 1), (3, 3), (7, 7)), (3, 3)),
            # Only OFMP.
            ((pe.OFMP, pe.INPP, pe.BATP, pe.OUTP),
             ((2, 2), (5, 5), (1, 1), (7, 7)), (5, 5)),
        )

        for order, pdims, expected in cases:
            # Hand the constructor fresh lists for every case.
            part = PartitionScheme(order=list(order), pdims=list(pdims))
            region = NodeRegion(origin=PhyDim2(0, 0),
                                dim=part.dim(),
                                type=NodeRegion.PROC)
            fil_dim = BufShrScheme(region, part).dim(de.FIL)
            self.assertTupleEqual(fil_dim, expected)

    def test_dim_invalid_index(self):
        ''' Accessor dim invalid index.

        `de.NUM` is used as the out-of-range data category; `dim` must raise
        IndexError for it.
        '''
        self.assertRaises(IndexError, self.bufshr1.dim, de.NUM)

    def test_size_invalid_index(self):
        ''' Get size invalid index.

        `de.NUM` is used as the out-of-range data category; `size` must raise
        IndexError for it.
        '''
        self.assertRaises(IndexError, self.bufshr1.size, de.NUM)

    def test_nbr_dists(self):
        ''' Accessor nbr_dists.

        Checks the neighbor distance pair of every data category for all three
        fixtures against fixed expected values. Some expected entries are
        infinite.
        '''
        inf = float('inf')

        # Each entry: (scheme, ((data category, expected distances), ...)).
        expectations = (
            (self.bufshr1, ((de.FIL, (5, inf)),
                            (de.IFM, (15, 2)),
                            (de.OFM, (1, 1)))),
            (self.bufshr2, ((de.FIL, (1, 1)),
                            (de.IFM, (15, 15)),
                            (de.OFM, (inf, inf)))),
            (self.bufshr3, ((de.FIL, (3, 5)),
                            (de.IFM, (inf, 10)),
                            (de.OFM, (1, 1)))),
        )

        for scheme, per_data in expectations:
            for dce, dists in per_data:
                self.assertTupleEqual(scheme.nbr_dists[dce], dists)

    def test_default_data_loops(self):
        ''' Default data_loops in constructor.

        A scheme built with the explicit data_loops below must have the same
        dims and neighbor distances as the fixture built without data_loops.
        '''
        data_loops = [None] * de.NUM
        for dce, lpes in ((de.FIL, (le.IFM, le.OFM)),
                          (de.IFM, (le.IFM, le.BAT)),
                          (de.OFM, (le.OFM, le.BAT))):
            data_loops[dce] = DataDimLoops(*lpes)

        fixtures = ((self.bufshr1, self.nr1, self.ps1),
                    (self.bufshr2, self.nr2, self.ps2),
                    (self.bufshr3, self.nr3, self.ps3))

        for implicit, region, part in fixtures:

            explicit = BufShrScheme(region, part, data_loops)

            for dce in range(de.NUM):
                self.assertTupleEqual(implicit.dim(dce), explicit.dim(dce))
                self.assertTupleEqual(implicit.nbr_dists[dce],
                                      explicit.nbr_dists[dce])

    def test_data_loops(self):
        ''' data_loops in constructor.

        IFM is given the same loops as OFM (OFM and BAT), so for every fixture
        region and partition scheme the IFM dim and neighbor distances must
        equal the OFM ones.
        '''
        data_loops = [None] * de.NUM
        for dce, lpes in ((de.FIL, (le.IFM, le.OFM)),
                          (de.IFM, (le.OFM, le.BAT)),
                          (de.OFM, (le.OFM, le.BAT))):
            data_loops[dce] = DataDimLoops(*lpes)

        fixtures = ((self.nr1, self.ps1),
                    (self.nr2, self.ps2),
                    (self.nr3, self.ps3))

        for region, part in fixtures:

            scheme = BufShrScheme(region, part, data_loops)

            self.assertTupleEqual(scheme.dim(de.IFM), scheme.dim(de.OFM))
            self.assertTupleEqual(scheme.nbr_dists[de.IFM],
                                  scheme.nbr_dists[de.OFM])

    def test_data_loops_all_lpe(self):
        ''' data_loops in constructor have all LoopEnum.

        IFM is given all three loops (IFM, OFM, BAT). The resulting IFM
        sharing dimension must be (1, 1) and all of its neighbor distances
        must be infinite.
        '''
        data_loops = [None] * de.NUM
        for dce, lpes in ((de.FIL, (le.IFM, le.OFM)),
                          (de.IFM, (le.IFM, le.OFM, le.BAT)),
                          (de.OFM, (le.OFM, le.BAT))):
            data_loops[dce] = DataDimLoops(*lpes)

        scheme = BufShrScheme(self.nr1, self.ps1, data_loops)

        self.assertTupleEqual(scheme.dim(de.IFM), (1, 1))
        ifm_dists = scheme.nbr_dists[de.IFM]
        self.assertTrue(all(math.isinf(dist) for dist in ifm_dists))

    def test_mismatch_node_region(self):
        ''' Mismatched node region and part in constructor.

        A 1 by 1 region used with `self.ps1` must be rejected with a
        ValueError whose message matches 'BufShrScheme: .*region.*'. A 100 by
        100 region must be accepted and must not change the IFM sharing
        dimension.
        '''
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12. Prefer the current name and only fall back to the old
        # one on interpreters (Python 2.7) that do not have it.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None) \
            or self.assertRaisesRegexp

        # Smaller node region. Invalid.
        with assert_raises_regex(ValueError, 'BufShrScheme: .*region.*'):
            _ = BufShrScheme(
                NodeRegion(origin=PhyDim2(0, 0),
                           dim=PhyDim2(1, 1),
                           type=NodeRegion.PROC), self.ps1)

        # Larger node region. Valid.
        bufshr = BufShrScheme(
            NodeRegion(origin=PhyDim2(0, 0),
                       dim=PhyDim2(100, 100),
                       type=NodeRegion.PROC), self.ps1)
        self.assertTupleEqual(bufshr.dim(de.IFM), self.ps1.dim(pe.OUTP))

    def test_nhops_rotate_all(self):
        ''' Get nhops_rotate_all.

        Uses the FIL sharing group of `self.bufshr3`, which is 4 by 2 with
        neighbor distances 3 and 5, and compares the result for each subgroup
        size from 1 to 8 with a hand-derived expected value.
        '''
        scheme = self.bufshr3
        data = de.FIL
        self.assertTupleEqual(scheme.dim(data), (4, 2))
        self.assertTupleEqual(scheme.nbr_dists[data], (3, 5))

        def check(subgrp_size, expected):
            ''' Compare nhops_rotate_all of one subgroup size. '''
            self.assertAlmostEqual(scheme.nhops_rotate_all(data, subgrp_size),
                                   expected)

        # Subgroup as 4 by 2. The whole circle is six hops of 3 and two hops
        # of 5, but only 7 of 8 steps.
        check(8, (3 * 6 + 5 * 2) * 7 / 8.)
        # Subgroup as 4 by 1. One node does three hops of 3, and the other
        # three nodes do two hops of 3 and one hop of 9 (looping back).
        check(4, ((3 * 3) + (3 * 2 + 9) * 3) / 4. * 2)
        # Subgroup as 2 by 1. All nodes do one hop of 3.
        check(2, (3 + 3) / 2. * 4)
        # Subgroup as 1. No rotation.
        check(1, 0)

        # Subgroup as 4 by 1. One node does two hops of 3 and two do one hop
        # of 3 and 6 each. The 3rd node also sends to the 4th one with two
        # hops of 3.
        check(3, ((3 * 2) + (3 + 6) * 2 + (3 * 2)) / 3. * 2)
        # Subgroup as 4 by 2. The 1st node does three hops of 3 and one hop of
        # 5. The 2nd, 3rd, and 4th nodes do two hops of 3, and one hop of 5,
        # and one looping back from the 5th node to the 1st node. The 5th node
        # does one looping back and three hops of 3. Finally, the 5th node
        # also sends to the 6th to 8th nodes.
        check(5, ((3 * 3 + 5) + (3 * 2 + 5 + (3 * 3 + 5)) * 3 +
                  ((3 * 3 + 5) + 3 * 3) + 3 * 3 * 4) / 5.)
        # The others are similar.
        check(6, ((3 * 4 + 5) + (3 * 3 + 5 + (3 * 2 + 5)) * 4 +
                  ((3 * 2 + 5) + 3 * 4) + 3 * 2 * 5) / 6.)
        check(7, ((3 * 5 + 5) + (3 * 4 + 5 + (3 * 1 + 5)) * 5 +
                  ((3 * 1 + 5) + 3 * 5) + 3 * 1 * 6) / 7.)

    def test_nhops_rotate_all_invalid(self):
        ''' Get nhops_rotate_all with invalid args.

        A subgroup size one larger than the FIL sharing group size of
        `self.bufshr3` must raise a ValueError whose message matches
        'BufShrScheme: .*subgroup.*'.
        '''
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12. Prefer the current name and only fall back to the old
        # one on interpreters (Python 2.7) that do not have it.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None) \
            or self.assertRaisesRegexp

        with assert_raises_regex(ValueError, 'BufShrScheme: .*subgroup.*'):
            _ = self.bufshr3.nhops_rotate_all(de.FIL,
                                              self.bufshr3.size(de.FIL) + 1)

    def test_nhops_rotate_all_rot_unit(self):
        ''' Get nhops_rotate_all with rotation unit count.

        For every subgroup size from 1 up to, but excluding, the FIL sharing
        group size of `self.bufshr3`: a rotation unit count of at least the
        subgroup size must give the same result as omitting it, and a smaller
        count must give a strictly smaller result.
        '''

        scheme = self.bufshr3
        data = de.FIL
        self.assertTupleEqual(scheme.dim(data), (4, 2))

        for size in range(1, scheme.size(data)):

            # Reference value without a rotation unit count.
            baseline = scheme.nhops_rotate_all(data, size)

            # Enough rotation units: identical to the reference.
            for unit_cnt in range(size, 32):
                with_units = scheme.nhops_rotate_all(data, size, unit_cnt)
                self.assertEqual(with_units, baseline)

            # Fewer rotation units than nodes: strictly fewer hops.
            for unit_cnt in range(1, size):
                with_units = scheme.nhops_rotate_all(data, size, unit_cnt)
                self.assertLess(with_units, baseline)

    def test_nhops_rotate_all_cache(self):
        ''' Get nhops_rotate_all using cache.

        Starting from an empty `nhops_cache`, each group of first-time calls
        must grow the cache by one entry per distinct argument tuple, and
        repeating the same calls must return equal values without growing the
        cache any further.
        '''

        scheme = self.bufshr3

        self.assertFalse(scheme.nhops_cache)

        def fill_then_hit(arg_tuples, cache_len):
            ''' Call once per argument tuple, then repeat and compare. '''
            first = [scheme.nhops_rotate_all(*args) for args in arg_tuples]
            self.assertEqual(len(scheme.nhops_cache), cache_len)
            for args, value in zip(arg_tuples, first):
                self.assertEqual(value, scheme.nhops_rotate_all(*args))
            self.assertEqual(len(scheme.nhops_cache), cache_len)

        # Three subgroup sizes for FIL.
        fill_then_hit([(de.FIL, 8), (de.FIL, 4), (de.FIL, 1)], 3)

        # Two more subgroup sizes for a different data category.
        fill_then_hit([(de.IFM, 3), (de.IFM, 2)], 5)

        # An explicit rotation unit count is cached as a separate entry.
        fill_then_hit([(de.IFM, 3, 2)], 6)

    def test_nhops_wide_fetch_once(self):
        ''' Get nhops_wide_fetch_once.

        Uses the FIL sharing group of `self.bufshr3`, which is 4 by 2 with
        neighbor distances 3 and 5. A fetch width of 1 must give zero hops;
        width 2 is checked against hand-derived values; width 1.5 is checked
        relative to width 2.
        '''
        scheme = self.bufshr3
        data = de.FIL
        self.assertTupleEqual(scheme.dim(data), (4, 2))
        self.assertTupleEqual(scheme.nbr_dists[data], (3, 5))

        # Width 1 never needs any hop.
        # NOTE(review): this range starts at subgroup size 0 and stops short
        # of the full group size - confirm that is intended.
        for size in range(scheme.size(data)):
            self.assertAlmostEqual(
                scheme.nhops_wide_fetch_once(data, size, 1), 0)

        # Each entry: (subgroup size, expected total hops with width 2).
        width_2_cases = (
            # Three nodes fetch one hop of 3, and the last node fetches one
            # hop of 9 (looping back).
            (4, (3 * 3 + 9) / 4. * 2),
            # Two nodes fetch one hop of 3, and the 3rd node fetches one hop
            # of 6 (looping back). The last node fetches one hop of 3 from the
            # 3rd.
            (3, (3 * 2 + 6 + 3) / 3. * 2),
            # All nodes do one hop of 3.
            (2, (3 + 3) / 2. * 4),
        )
        for size, expected in width_2_cases:
            total = scheme.nhops_wide_fetch_once(data, size, 2) * 2
            self.assertAlmostEqual(total, expected)

        # Width 1.5 must total half of what width 2 totals.
        for size in range(2, scheme.size(data)):
            total_narrow = scheme.nhops_wide_fetch_once(data, size, 1.5) * 1.5
            total_wide = scheme.nhops_wide_fetch_once(data, size, 2) * 2.
            self.assertAlmostEqual(total_narrow, total_wide / 2.)

    def test_nhops_wide_fetch_once_inv(self):
        ''' Get nhops_wide_fetch_once with invalid args.

        A subgroup size one larger than the FIL sharing group size of
        `self.bufshr3` must raise a ValueError whose message matches
        'BufShrScheme: .*subgroup.*'. A fetch width one larger than the
        subgroup size must raise a ValueError whose message matches
        'BufShrScheme: .*width.*'.
        '''
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12. Prefer the current name and only fall back to the old
        # one on interpreters (Python 2.7) that do not have it.
        assert_raises_regex = getattr(self, 'assertRaisesRegex', None) \
            or self.assertRaisesRegexp

        with assert_raises_regex(ValueError, 'BufShrScheme: .*subgroup.*'):
            _ = self.bufshr3.nhops_wide_fetch_once(
                de.FIL,
                self.bufshr3.size(de.FIL) + 1, 2)

        # Floor division keeps the subgroup size an integer on Python 3 as
        # well; true division would pass floats for a node count.
        half_size = self.bufshr3.size(de.FIL) // 2
        with assert_raises_regex(ValueError, 'BufShrScheme: .*width.*'):
            _ = self.bufshr3.nhops_wide_fetch_once(
                de.FIL,
                half_size,
                half_size + 1)

    def test_repr(self):
        ''' __repr__. '''
        self.assertIn(repr(self.ps1), repr(self.bufshr1))
        self.assertIn(repr(self.ps2), repr(self.bufshr2))
        self.assertIn(repr(self.ps3), repr(self.bufshr3))