    def test_valid_args(self):
        ''' Valid arguments. '''
        cstr = SchedulingConstraint(topbat=2, topifm=1, topofm=4)
        self.assertEqual(cstr.topbat, 2)
        self.assertEqual(cstr.topifm, 1)
        self.assertEqual(cstr.topofm, 4)
        self.assertDictEqual(cstr.update_dict, {})

        cstr = SchedulingConstraint(topbat=2, topofm=4)
        self.assertEqual(cstr.topbat, 2)
        self.assertEqual(cstr.topifm, 0)
        self.assertEqual(cstr.topofm, 4)
        self.assertDictEqual(cstr.update_dict, {})

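        # With an update_dict, unspecified factors stay at 0 until
        # update_by_prev() applies the per-layer update functions.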
        cstr = SchedulingConstraint(
            topofm=4,
            update_dict={
                'l1': lambda s, _: setattr(s, 'topbat', 1),
                'l2': lambda s, r: setattr(s, 'topifm', r.topifm),
            })
        self.assertEqual(cstr.topbat, 0)
        self.assertEqual(cstr.topifm, 0)
        self.assertEqual(cstr.topofm, 4)
        self.assertEqual(len(cstr.update_dict), 2)
        self.assertIn('l1', cstr.update_dict)
        self.assertIn('l2', cstr.update_dict)

        cstr = SchedulingConstraint()
        self.assertEqual(cstr.topbat, 0)
        self.assertEqual(cstr.topifm, 0)
        self.assertEqual(cstr.topofm, 0)
        self.assertDictEqual(cstr.update_dict, {})

    def setUp(self):

        self.layers = {}
        self.layers['BASE'] = ConvLayer(8, 16, 28, 3)
        self.layers['POOL'] = PoolingLayer(16, 28, 2)
        self.layers['LR'] = LocalRegionLayer(16, 28, nreg=3, sreg=1)

        self.batch_size = 4

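        # Cost model: mem_hier lists the per-access cost of each memory
        # hierarchy level, from the outermost level down to the register file.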
        self.cost = Cost(mac_op=1,
                         mem_hier=(200, 6, 2, 1),
                         noc_hop=50,
                         idl_unit=50)

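        # A null (unconstrained) scheme, and one that pins the top OFM factor
        # to 1 and the top BAT factor to the batch size.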
        self.none_cstr = SchedulingConstraint()
        self.cstr = SchedulingConstraint(topofm=1, topbat=self.batch_size)

        self.resource = Resource(
            proc_region=NodeRegion(origin=PhyDim2(0, 0),
                                   dim=PhyDim2(4, 4),
                                   type=NodeRegion.PROC),
            dram_region=NodeRegion(origin=PhyDim2(0, 0),
                                   dim=PhyDim2(4, 1),
                                   type=NodeRegion.DRAM),
            src_data_region=NodeRegion(origin=PhyDim2(0, 0),
                                       dim=PhyDim2(4, 1),
                                       type=NodeRegion.DRAM),
            dst_data_region=NodeRegion(origin=PhyDim2(0, 0),
                                       dim=PhyDim2(4, 1),
                                       type=NodeRegion.DRAM),
            dim_array=PhyDim2(16, 16),
            size_gbuf=65536,
            size_regf=64,
            array_bus_width=float('inf'),
            dram_bandwidth=float('inf'),
            no_time_mux=False)

        self.options = Option(partition_hybrid=True,
                              partition_batch=True,
                              partition_ifmaps=True,
                              ntops=10)

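        # Ifmap data layout per workload: a single fmap range covering the
        # whole input, placed in the source data region with the partitioning
        # scheme projected onto that region.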
        self.ifmap_layouts = {}
        part = PartitionScheme(order=(pe.INPP, pe.BATP, pe.OUTP, pe.OFMP),
                               pdims=((1, 2), (2, 1), (1, 2), (2, 1)))
        for wlkey in self.layers:
            input_layer = self.layers[wlkey].input_layer()
            self.ifmap_layouts[wlkey] = DataLayout(
                frngs=(FmapRange((0, 0, 0, 0),
                                 FmapPosition(b=self.batch_size,
                                              n=input_layer.nofm,
                                              h=input_layer.hofm,
                                              w=input_layer.wofm)), ),
                regions=(self.resource.src_data_region, ),
                parts=(part.projection(self.resource.src_data_region,
                                       appl2frng=True), ))

        self.sched_seq = (2, 0, 1)

    def test_null_constraint(self):
        ''' Null constraint. '''
        cstr = SchedulingConstraint()

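        # A null constraint accepts any top blocking factors, loop orders, and
        # partitioning schemes.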
        self.assertTrue(cstr.is_valid_top_bl((1, 1, 2), (0, 1, 2)))
        self.assertTrue(cstr.is_valid_top_bl((3, 4, 5), (2, 1, 0)))
        self.assertTrue(cstr.is_valid_top_bl((1, 1, 1), (1, 2, 0)))

        self.assertTrue(
            cstr.is_valid_part(
                PartitionScheme(order=range(pe.NUM), pdims=[(2, 2)] * pe.NUM)))

    def test_invalid_update_dict(self):
        ''' Invalid argument update_dict. '''
        with self.assertRaisesRegexp(
                TypeError, 'SchedulingConstraint: '
                '.*update_dict.*'):
            _ = SchedulingConstraint(update_dict=['l1'])

        with self.assertRaisesRegexp(
                TypeError, 'SchedulingConstraint: '
                '.*update_dict.*'):
            _ = SchedulingConstraint(update_dict={'l1': 1})

    def setUp(self):

        self.resource = Resource(
            proc_region=NodeRegion(origin=PhyDim2(0, 0),
                                   dim=PhyDim2(1, 1),
                                   type=NodeRegion.PROC),
            dram_region=NodeRegion(origin=PhyDim2(0, 0),
                                   dim=PhyDim2(1, 1),
                                   type=NodeRegion.DRAM),
            src_data_region=NodeRegion(origin=PhyDim2(0, 0),
                                       dim=PhyDim2(1, 1),
                                       type=NodeRegion.DRAM),
            dst_data_region=NodeRegion(origin=PhyDim2(0, 0),
                                       dim=PhyDim2(1, 1),
                                       type=NodeRegion.DRAM),
            dim_array=PhyDim2(16, 16),
            size_gbuf=65536,
            size_regf=64,
            array_bus_width=float('inf'),
            dram_bandwidth=float('inf'),
            no_time_mux=False)

        self.none_cstr = SchedulingConstraint()

        part = PartitionScheme(order=range(pe.NUM), pdims=[(1, 1)] * pe.NUM)
        self.ifmap_layout = DataLayout(
            frngs=(FmapRange((0, 0, 0, 0), (2, 4, 16, 16)), ),
            regions=(self.resource.src_data_region, ),
            parts=(part, ))

        self.sched_seq = (2, 0, 0)

    def test_repr(self):
        ''' __repr__. '''
        cstr = SchedulingConstraint(topbat=2)
        self.assertIn('SchedulingConstraint(', repr(cstr))
        self.assertIn('topbat=2', repr(cstr))
        self.assertIn('topifm=0', repr(cstr))
        self.assertIn('topofm=0', repr(cstr))

        cstr = SchedulingConstraint(
            update_dict={
                'l1': lambda s, _: setattr(s, 'topbat', 1),
                'l2': lambda s, r: setattr(s, 'topifm', r.topifm),
            })
        self.assertIn('update_dict=', repr(cstr))
        self.assertIn('l1', repr(cstr))
        self.assertIn('l2', repr(cstr))

    def test_is_valid_top_bl(self):
        ''' Whether is_valid_top_bl. '''
        cstr = SchedulingConstraint(topbat=2, topofm=4)
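        # Valid only if the top BAT factor is 2 and the top OFM factor is 4;
        # the loop order does not matter for these checks.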
        for bl_t, bl_ord in self._gen_bl():
            valid = (bl_t[le.BAT] == 2 and bl_t[le.OFM] == 4)
            self.assertEqual(cstr.is_valid_top_bl(bl_t, bl_ord), valid)

        cstr = SchedulingConstraint(topifm=4)
        for bl_t, bl_ord in self._gen_bl():
            valid = (bl_t[le.IFM] == 4)
            self.assertEqual(cstr.is_valid_top_bl(bl_t, bl_ord), valid)

        cstr = SchedulingConstraint()
        for bl_t, bl_ord in self._gen_bl():
            self.assertTrue(cstr.is_valid_top_bl(bl_t, bl_ord))

    def test_filter_gen_ts(self):
        ''' Get filter_gen_ts. '''
        gen_tifm = util.factorize(36, 3)
        gen_tofm = util.factorize(20, 3)
        gen_tbat = util.factorize(16, 3)

        cstr = SchedulingConstraint(topbat=2, topofm=4)

        gifm, gifm0, gen_tifm = itertools.tee(gen_tifm, 3)
        gofm, gofm0, gen_tofm = itertools.tee(gen_tofm, 3)
        gbat, gbat0, gen_tbat = itertools.tee(gen_tbat, 3)
        fgifm, fgofm, fgbat = cstr.filter_gen_ts(gifm, gofm, gbat)

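        # The ifmap generator passes through unfiltered; the ofm and bat
        # generators keep only tuples whose top (first) factor matches the
        # constraint, i.e., 4 x factorize(20 // 4, 2) and
        # 2 x factorize(16 // 2, 2).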
        self.assertSetEqual(set(fgifm), set(gifm0))
        set_fgofm = set(fgofm)
        set_fgbat = set(fgbat)
        self.assertTrue(set_fgofm.issubset(set(gofm0)))
        self.assertTrue(set_fgbat.issubset(set(gbat0)))
        self.assertSetEqual(set_fgofm,
                            set([(4, ) + tpl for tpl in util.factorize(5, 2)]))
        self.assertSetEqual(set_fgbat,
                            set([(2, ) + tpl for tpl in util.factorize(8, 2)]))

        cstr = SchedulingConstraint(topifm=4)

        gifm, gifm0, gen_tifm = itertools.tee(gen_tifm, 3)
        gofm, gofm0, gen_tofm = itertools.tee(gen_tofm, 3)
        gbat, gbat0, gen_tbat = itertools.tee(gen_tbat, 3)
        fgifm, fgofm, fgbat = cstr.filter_gen_ts(gifm, gofm, gbat)

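        # Only the ifmap generator is filtered; ofm and bat pass through
        # unchanged.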
        self.assertSetEqual(set(fgofm), set(gofm0))
        self.assertSetEqual(set(fgbat), set(gbat0))
        set_fgifm = set(fgifm)
        self.assertTrue(set_fgifm.issubset(set(gifm0)))
        self.assertSetEqual(set_fgifm,
                            set([(4, ) + tpl for tpl in util.factorize(9, 2)]))

        cstr = SchedulingConstraint()

        gifm, gifm0, gen_tifm = itertools.tee(gen_tifm, 3)
        gofm, gofm0, gen_tofm = itertools.tee(gen_tofm, 3)
        gbat, gbat0, gen_tbat = itertools.tee(gen_tbat, 3)
        fgifm, fgofm, fgbat = cstr.filter_gen_ts(gifm, gofm, gbat)

        self.assertSetEqual(set(fgifm), set(gifm0))
        self.assertSetEqual(set(fgofm), set(gofm0))
        self.assertSetEqual(set(fgbat), set(gbat0))

    def test_content_hash(self):
        ''' Content-based hash. '''
        cstr1 = SchedulingConstraint(topbat=2)
        cstr2 = SchedulingConstraint(topbat=2)
        self.assertNotEqual(id(cstr1), id(cstr2))
        self.assertEqual(hash(cstr1), hash(cstr2))
        self.assertEqual(cstr1, cstr2)

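        # After update_by_prev() resolves the update_dict, cstr3 has topbat=1
        # and topifm=2 (copied from r), so it hashes and compares equal to
        # cstr4.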
        cstr3 = SchedulingConstraint(
            topbat=2,
            update_dict={
                'l1': lambda s, _: setattr(s, 'topbat', 1),
                'l2': lambda s, r: setattr(s, 'topifm', r.topifm),
            })
        r = SchedulingConstraint(topifm=2)
        cstr3.update_by_prev({'l1': None, 'l2': r})
        cstr4 = SchedulingConstraint(topifm=2, topbat=1)
        self.assertNotEqual(id(cstr3), id(cstr4))
        self.assertEqual(hash(cstr3), hash(cstr4))
        self.assertEqual(cstr3, cstr4)

    def test_update_by_prev(self):
        ''' Modifier update_by_prev. '''
        cstr = SchedulingConstraint(
            topofm=4,
            update_dict={
                'l1': lambda s, _: setattr(s, 'topbat', 1),
                'l2': lambda s, r: setattr(s, 'topifm', r.topifm),
            })
        self.assertEqual(cstr.topbat, 0)
        self.assertEqual(cstr.topifm, 0)
        self.assertEqual(cstr.topofm, 4)

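        # 'l1' forces topbat to 1; 'l2' copies topifm from the previous
        # constraint r.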
        r = SchedulingConstraint(topifm=2)
        cstr.update_by_prev({'l1': None, 'l2': r})

        self.assertEqual(cstr.topbat, 1)
        self.assertEqual(cstr.topifm, 2)
        self.assertEqual(cstr.topofm, 4)

        self.assertFalse(cstr.is_valid_top_bl([1, 4, 1], range(le.NUM)))
        self.assertTrue(cstr.is_valid_top_bl([2, 4, 1], range(le.NUM)))

    def test_is_valid_before_update(self):
        ''' is_valid_top_bl and is_valid_part called before update. '''
        cstr = SchedulingConstraint(
            topofm=4,
            update_dict={
                'l1': lambda s, _: setattr(s, 'topbat', 1),
                'l2': lambda s, r: setattr(s, 'topifm', r.topifm),
            })

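        # Both checks must raise until update_by_prev() has resolved the
        # pending update_dict.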
        with self.assertRaisesRegexp(
                ValueError, 'SchedulingConstraint: '
                '.*update_dict.*'):
            cstr.is_valid_top_bl([1] * le.NUM, range(le.NUM))

        with self.assertRaisesRegexp(
                ValueError, 'SchedulingConstraint: '
                '.*update_dict.*'):
            cstr.is_valid_part(
                PartitionScheme(order=range(pe.NUM), pdims=[(2, 2)] * pe.NUM))

    def setUp(self):

        # Workload.
        self.layer = {}
        self.layer['BASE'] = ConvLayer(12, 10, 28, 3)
        self.layer['LGFIL'] = ConvLayer(2, 4, 28, 20)
        self.layer['POOL'] = PoolingLayer(32, 28, 2)
        self.layer['PAR'] = ConvLayer(24, 36, 56, 3)
        self.batch_size = 4

        # Resource.
        self.resource = {}
        dim_array = PhyDim2(16, 16)
        proc_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=PhyDim2(1, 1),
                                 type=NodeRegion.PROC)
        data_region = NodeRegion(origin=PhyDim2(0, 0),
                                 dim=PhyDim2(1, 1),
                                 type=NodeRegion.DRAM)
        # Typical resource.
        self.resource['BASE'] = Resource(proc_region=proc_region,
                                         dram_region=data_region,
                                         src_data_region=data_region,
                                         dst_data_region=data_region,
                                         dim_array=dim_array,
                                         size_gbuf=65536,
                                         size_regf=64,
                                         array_bus_width=float('inf'),
                                         dram_bandwidth=float('inf'),
                                         no_time_mux=False)
        # Larger resource with sufficient capacity, to make all schemes valid.
        self.resource['LG'] = Resource(proc_region=proc_region,
                                       dram_region=data_region,
                                       src_data_region=data_region,
                                       dst_data_region=data_region,
                                       dim_array=dim_array,
                                       size_gbuf=1024**3,
                                       size_regf=1024**3,
                                       array_bus_width=float('inf'),
                                       dram_bandwidth=float('inf'),
                                       no_time_mux=False)
        # Small resource.
        self.resource['SM'] = Resource(proc_region=proc_region,
                                       dram_region=data_region,
                                       src_data_region=data_region,
                                       dst_data_region=data_region,
                                       dim_array=dim_array,
                                       size_gbuf=4096,
                                       size_regf=16,
                                       array_bus_width=float('inf'),
                                       dram_bandwidth=float('inf'),
                                       no_time_mux=False)
        # Multi-node parallel resource.
        self.resource['PAR'] = Resource(
            proc_region=NodeRegion(origin=PhyDim2(0, 0),
                                   dim=PhyDim2(4, 2),
                                   type=NodeRegion.PROC),
            dram_region=data_region,
            src_data_region=data_region,
            dst_data_region=data_region,
            dim_array=dim_array,
            size_gbuf=25000,
            size_regf=64,
            array_bus_width=float('inf'),
            dram_bandwidth=float('inf'),
            no_time_mux=False)
        # Resource with no data regions.
        proc_data_region = NodeRegion(origin=PhyDim2(1, 1),
                                      dim=PhyDim2(1, 1),
                                      type=NodeRegion.PROC)
        self.resource['SRCNOTDATA'] = Resource(
            proc_region=proc_region,
            dram_region=data_region,
            src_data_region=proc_data_region,
            dst_data_region=data_region,
            dim_array=dim_array,
            size_gbuf=1024**3,
            size_regf=1024**3,
            array_bus_width=float('inf'),
            dram_bandwidth=float('inf'),
            no_time_mux=False)
        self.resource['DSTNOTDATA'] = Resource(
            proc_region=proc_region,
            dram_region=data_region,
            src_data_region=data_region,
            dst_data_region=proc_data_region,
            dim_array=dim_array,
            size_gbuf=1024**3,
            size_regf=1024**3,
            array_bus_width=float('inf'),
            dram_bandwidth=float('inf'),
            no_time_mux=False)
        self.resource['DATALOCAL'] = Resource(proc_region=proc_region,
                                              dram_region=data_region,
                                              src_data_region=proc_region,
                                              dst_data_region=proc_region,
                                              dim_array=dim_array,
                                              size_gbuf=1024**3,
                                              size_regf=1024**3,
                                              array_bus_width=float('inf'),
                                              dram_bandwidth=float('inf'),
                                              no_time_mux=False)
        # Filter pinning.
        self.resource['FILPIN'] = Resource(proc_region=proc_region,
                                           dram_region=data_region,
                                           src_data_region=data_region,
                                           dst_data_region=data_region,
                                           dim_array=dim_array,
                                           size_gbuf=1024**3,
                                           size_regf=1024**3,
                                           array_bus_width=float('inf'),
                                           dram_bandwidth=float('inf'),
                                           no_time_mux=True)

        # Nested loop description after mapping.
        self.nld = {}
        self.nld['BASE'] = next(
            MapStrategyEyeriss(self.layer['BASE'], self.batch_size, 1,
                               dim_array).gen_nested_loop_desc())
        self.nld['LGFIL'] = next(
            MapStrategyEyeriss(self.layer['LGFIL'], self.batch_size, 1,
                               dim_array).gen_nested_loop_desc())
        self.nld['POOL'] = next(
            MapStrategyEyeriss(self.layer['POOL'], self.batch_size, 1,
                               dim_array).gen_nested_loop_desc())
        # Fake nested loop, with zero filter size.
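        # The usize tuples are indexed by data category (FIL, IFM, OFM), with
        # the FIL entries set to zero.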
        self.nld['ZERO_FIL'] = NestedLoopDesc(
            loopcnt=(12, 10, 4),
            usize_gbuf=(0, 1000, 800),
            usize_regf=(0, 3, 1),
            unit_access=((0, 1000, 800), (0, 1000, 800), (3, 9, 7), (1, 1, 1)),
            data_loops=(DataDimLoops(le.IFM, le.OFM),
                        DataDimLoops(le.IFM, le.BAT),
                        DataDimLoops(le.OFM, le.BAT)),
            unit_ops=1,
            unit_time=1)
        # Fake nested loop, with zero ifmap size.
        self.nld['ZERO_IFM'] = NestedLoopDesc(
            loopcnt=(12, 10, 4),
            usize_gbuf=(9, 0, 800),
            usize_regf=(3, 0, 1),
            unit_access=((9, 0, 800), (9, 0, 800), (3, 9, 7), (1, 1, 1)),
            data_loops=(DataDimLoops(le.IFM, le.OFM),
                        DataDimLoops(le.IFM, le.BAT),
                        DataDimLoops(le.OFM, le.BAT)),
            unit_ops=1,
            unit_time=1)

        # Fake partition scheme.
        self.part = PartitionScheme(range(pe.NUM), ((1, 1), ) * pe.NUM)

        # Fake buffer sharing scheme.
        self.bufshr = BufShrScheme(proc_region, self.part)

        # Options.
        self.options = {}
        # Basic.
        self.options['BASE'] = Option(ntops=2**30)
        # Multiprocessing.
        self.options['MP'] = Option(ntops=2**30, nprocesses=8)
        # Limited top schemes.
        self.options['NTOPS'] = Option(ntops=10)
        # Bypass.
        self.options['BYP'] = Option(sw_gbuf_bypass=(True, ) * 3, ntops=2**30)
        # Bypass solver.
        self.options['BYPSOL'] = Option(sw_gbuf_bypass=(True, ) * 3,
                                        sw_solve_loopblocking=True,
                                        ntops=2**30)
        # Access forwarding.
        self.options['ACCFWD'] = Option(hw_access_forwarding=True, ntops=2**30)
        # Buffer sharing.
        self.options['BUFSHR'] = Option(hw_gbuf_sharing=True, ntops=2**30)
        # Buffer sharing with bypassing.
        self.options['BUFSHR-BYP'] = Option(sw_gbuf_bypass=(True, ) * 3,
                                            hw_gbuf_sharing=True,
                                            ntops=2**30)

        # Constraint.
        self.none_cstr = SchedulingConstraint()
        self.cstr = SchedulingConstraint(topifm=1, topbat=1)

        # Cost.
        self.cost = Cost(mac_op=1,
                         mem_hier=(200, 6, 2, 1),
                         noc_hop=50,
                         idl_unit=50)

    def test_invalid_args(self):
        ''' Invalid arguments. '''
        with self.assertRaisesRegexp(
                ValueError, 'SchedulingConstraint: '
                '.*positive integers.*'):
            _ = SchedulingConstraint(topbat=-1, topofm=2.)