Exemple #1
0
    def test_prod(self):
        ''' Check prod. '''
        # Every 3-way factorization must multiply back to its source value.
        for target in (24, 1024):
            for factors in util.factorize(target, 3):
                self.assertEqual(util.prod(factors), target)
Exemple #2
0
    def test_len(self):
        ''' Length. '''
        # 2 * 3 * 5 * 7 has exactly four distinct prime factors, each used once,
        # so each prime independently picks one of the num_parts slots.
        squarefree = 2 * 3 * 5 * 7
        for num_parts in (2, 3):
            count = len(list(util.factorize(squarefree, num_parts)))
            self.assertEqual(count, num_parts ** 4)

        # No ordered factorization may be generated twice.
        for number in (24, 1024, (2 ** 4) * (3 ** 5) * (5 ** 2)):
            pairs = list(util.factorize(number, 2))
            self.assertEqual(len(pairs), len(set(pairs)))
Exemple #3
0
    def test_factors(self):
        ''' Factors. '''
        def union_of_factors(num_parts):
            # Every value that appears in any num_parts-way factorization of 24.
            return {f for fs in util.factorize(24, num_parts) for f in fs}

        two_way = union_of_factors(2)
        self.assertSetEqual(two_way, {1, 2, 3, 4, 6, 8, 12, 24})
        # Splitting into more parts must not introduce or lose any divisor.
        self.assertSetEqual(two_way, union_of_factors(3))
 def _gen_loopblocking_all(self, wlkey='BASE'):
     ''' Generate all combinations of loop blocking factors and orders.

     Yields (blocking, orders) pairs. blocking is a tuple whose i-th entry
     gathers the i-th factor of the 3-way factorizations of the IFM, OFM and
     BAT loop counts (positioned by le.IFM / le.OFM / le.BAT). orders is a
     pair of permutations of range(le.NUM).
     '''
     loopcnt = self.nld[wlkey].loopcnt
     # Every ordered pair of permutations of the loop dimensions.
     order_pairs = itertools.product(itertools.permutations(range(le.NUM)),
                                     itertools.permutations(range(le.NUM)))
     for ti, to, tb, orders in itertools.product(
             util.factorize(loopcnt[le.IFM], 3),
             util.factorize(loopcnt[le.OFM], 3),
             util.factorize(loopcnt[le.BAT], 3),
             order_pairs):
         per_loop = [None] * le.NUM
         for dim, factors in ((le.IFM, ti), (le.OFM, to), (le.BAT, tb)):
             per_loop[dim] = factors
         # Transpose so that each entry collects one factor from every loop.
         yield tuple(zip(*per_loop)), orders
Exemple #5
0
    def _gen_partition_full(self, wlkey='BASE', dnkey='BASE'):
        ''' Generate all PartitionScheme regardless of equivalence.

        Tries every way of factorizing the node array's height and width into
        pe.NUM parts, discards combinations rejected by the feasibility checks
        below, and yields one PartitionScheme per order in which batch
        partitioning (when it actually splits) comes first.
        '''
        layer = self.layers[wlkey]
        dim_nodes = self.dim_nodes[dnkey]

        def feasible(pdims):
            # True when the per-dimension split passes every filter.

            # BATP: the batch must split exactly.
            if self.batch_size % pdims[pe.BATP].size() != 0:
                return False

            # OUTP.
            if not util.approx_dividable(layer.nofm, pdims[pe.OUTP].size()):
                return False

            # OFMP: both output spatial dimensions must be dividable.
            ofmp = pdims[pe.OFMP]
            if not (util.approx_dividable(layer.hofm, ofmp.h)
                    and util.approx_dividable(layer.wofm, ofmp.w)):
                return False

            # INPP.
            if isinstance(layer, ConvLayer):
                inpp_size = pdims[pe.INPP].size()
                if not util.approx_dividable(layer.nifm, inpp_size):
                    return False
            elif isinstance(layer, LocalRegionLayer) \
                    and pdims[pe.INPP].size() > 1:
                return False

            # Fully utilize one dimension: apart from OFMP, each partition
            # must be either 1 or the full node count along h or along w.
            others = pdims[:pe.OFMP] + pdims[pe.OFMP + 1:]
            return all(pd.h in (1, dim_nodes.h) or pd.w in (1, dim_nodes.w)
                       for pd in others)

        for ph, pw in itertools.product(util.factorize(dim_nodes.h, pe.NUM),
                                        util.factorize(dim_nodes.w, pe.NUM)):
            pdims = [PhyDim2(h, w) for h, w in zip(ph, pw)]
            if not feasible(pdims):
                continue

            for order in itertools.permutations(range(pe.NUM)):
                # Batch parallelism, if it actually splits, should be at the
                # top among the partitions of size > 1.
                active = [pae for pae in order if pdims[pae].size() > 1]
                if pe.BATP not in active or active[0] == pe.BATP:
                    yield PartitionScheme(order=order, pdims=pdims)
Exemple #6
0
    def test_perm(self):
        ''' Permutations. '''
        ordered = set()
        unordered = set()
        for fs in util.factorize(512, 3):
            ordered.add(fs)
            unordered.add(frozenset(fs))

        # Distinct orderings of a 3-factor multiset, keyed by how many distinct
        # values it holds: a, b, c -> 3!; a, a, b -> 3; a, a, a -> 1 (default).
        orderings = {3: math.factorial(3), 2: 3}
        expected = sum(orderings.get(len(fs), 1) for fs in unordered)
        self.assertEqual(len(ordered), expected)
    def test_filter_gen_ts(self):
        ''' Get filter_gen_ts. '''
        # IFM, OFM and BAT factor generators. Each round tees them three ways:
        # one branch goes through the filter, one is the unfiltered reference,
        # and one is carried forward to the next round.
        pending = [util.factorize(36, 3),
                   util.factorize(20, 3),
                   util.factorize(16, 3)]

        def run_filter(cstr):
            ''' Filter one tee branch of each pending generator with cstr.

            Returns the result of cstr.filter_gen_ts and the list of
            reference branches. '''
            inputs = []
            refs = []
            for pos, gen in enumerate(pending):
                to_filter, ref, pending[pos] = itertools.tee(gen, 3)
                inputs.append(to_filter)
                refs.append(ref)
            return cstr.filter_gen_ts(*inputs), refs

        # Round 1: constrain the top-level OFM and BAT factors.
        (fg_ifm, fg_ofm, fg_bat), (ref_ifm, ref_ofm, ref_bat) = run_filter(
            SchedulingConstraint(topbat=2, topofm=4))

        self.assertSetEqual(set(fg_ifm), set(ref_ifm))
        kept_ofm = set(fg_ofm)
        kept_bat = set(fg_bat)
        self.assertTrue(kept_ofm.issubset(set(ref_ofm)))
        self.assertTrue(kept_bat.issubset(set(ref_bat)))
        # Expect the first factor pinned to topofm / topbat respectively.
        self.assertSetEqual(kept_ofm,
                            {(4, ) + rest for rest in util.factorize(5, 2)})
        self.assertSetEqual(kept_bat,
                            {(2, ) + rest for rest in util.factorize(8, 2)})

        # Round 2: constrain only the top-level IFM factor.
        (fg_ifm, fg_ofm, fg_bat), (ref_ifm, ref_ofm, ref_bat) = run_filter(
            SchedulingConstraint(topifm=4))

        self.assertSetEqual(set(fg_ofm), set(ref_ofm))
        self.assertSetEqual(set(fg_bat), set(ref_bat))
        kept_ifm = set(fg_ifm)
        self.assertTrue(kept_ifm.issubset(set(ref_ifm)))
        self.assertSetEqual(kept_ifm,
                            {(4, ) + rest for rest in util.factorize(9, 2)})

        # Round 3: no constraint leaves every generator unchanged.
        (fg_ifm, fg_ofm, fg_bat), (ref_ifm, ref_ofm, ref_bat) = run_filter(
            SchedulingConstraint())

        self.assertSetEqual(set(fg_ifm), set(ref_ifm))
        self.assertSetEqual(set(fg_ofm), set(ref_ofm))
        self.assertSetEqual(set(fg_bat), set(ref_bat))
Exemple #8
0
 def test_limits(self):
     ''' Check limits. '''
     # limits bound the first two factors; the product must still be exact.
     count = 0
     for fs in util.factorize(1024, 3, limits=(10, 20)):
         self.assertLessEqual(fs[0], 10)
         self.assertLessEqual(fs[1], 20)
         self.assertEqual(util.prod(fs), 1024)
         count += 1
     # Guard against a vacuous pass: if the generator yielded nothing, every
     # assertion above would be skipped silently (e.g. (8, 16, 8) is valid).
     self.assertGreater(count, 0)