def test_gen_loopblocking_all(self):
    '''
    gen_loopblocking cover all.

    Count the exhaustively enumerated (bl_ts, bl_ords) schemes that
    loop_blocking.skip_conv does not skip. Assert that the fixture helper
    _gen_loopblocking(rsrckey='LG') yields exactly that many schemes.
    '''
    # NOTE(review): another method named test_gen_loopblocking_all appears
    # in this file; if both live in the same class only the later one is
    # collected by unittest -- confirm.
    expected = sum(
        1 for bl_ts, bl_ords in self._gen_loopblocking_all()
        if not loop_blocking.skip_conv(bl_ts, bl_ords))
    generated = sum(1 for _ in self._gen_loopblocking(rsrckey='LG'))
    self.assertEqual(generated, expected)
def test_skip_not_reg(self):
    '''
    skip non-regularized.

    For each exhaustively enumerated scheme:
    - if skip_conv keeps it, the scheme must already equal its
      regularized form;
    - if skip_conv skips it, its regularized form must not be skipped and
      must give the same access cost and access list (LG resource). Every
      per-level size entry of the original must also be >= the matching
      entry of the regularized form.
    '''
    for scheme in self._gen_loopblocking_all():
        skipped = loop_blocking.skip_conv(*scheme)
        reg_scheme = self._regularized_scheme(*scheme)

        if skipped:
            orig_lbs = self._lbs(*scheme, rsrckey='LG')
            reg_lbs = self._lbs(*reg_scheme, rsrckey='LG')

            # The regularized representative itself must survive skipping.
            self.assertFalse(
                loop_blocking.skip_conv(*reg_scheme),
                'test_skip_not_reg: regularized {} is skipped.'
                .format(reg_scheme))

            orig_cost = orig_lbs.get_access_cost(self.cost)
            reg_cost = reg_lbs.get_access_cost(self.cost)
            self.assertAlmostEqual(
                orig_cost, reg_cost,
                msg=('test_skip_not_reg: cost mismatch. '
                     'orig {}, reg {}.'.format(scheme, reg_scheme)))

            orig_access = orig_lbs.get_access()
            reg_access = reg_lbs.get_access()
            self.assertListEqual(
                orig_access, reg_access,
                msg=('test_skip_not_reg: access mismatch. '
                     'orig {}, reg {}.'.format(scheme, reg_scheme)))

            orig_size = self._get_lbs_size(orig_lbs)
            reg_size = self._get_lbs_size(reg_lbs)
            # Element-wise: the regularized scheme never needs more storage.
            reg_not_larger = all(
                all(o_val >= r_val for o_val, r_val in zip(o_lvl, r_lvl))
                for o_lvl, r_lvl in zip(orig_size, reg_size))
            self.assertTrue(
                reg_not_larger,
                'test_skip_not_reg: reg size is larger than eqv.\n'
                'org {} has size {}\nreg {} has size {}'.format(
                    scheme, orig_size, reg_scheme, reg_size))
        else:
            self.assertEqual(
                reg_scheme, scheme,
                'test_skip_not_reg: non-skipped {} should be '
                'regularized to {}'.format(scheme, reg_scheme))
def test_skip_ratio(self):
    '''
    skip ratio.

    Tally skip_conv's verdict over all exhaustively enumerated schemes.
    Require the skipped fraction to exceed 0.95.
    '''
    # Slot 0 (False) counts kept schemes, slot 1 (True) counts skipped ones;
    # skip_conv's return value is used directly as the index.
    tally = [0, 0]
    for bl_ts, bl_ords in self._gen_loopblocking_all():
        tally[loop_blocking.skip_conv(bl_ts, bl_ords)] += 1

    total = sum(tally)
    skip_ratio = 1. * tally[True] / total
    self.assertGreater(
        skip_ratio, 0.95,
        'test_skip_ratio: skip ratio {} too low.'.format(skip_ratio))
def test_gen_loopblocking_all(self):
    '''
    gen_loopblocking cover all.

    Count the exhaustively enumerated (bl_ts, bl_ords) schemes that
    loop_blocking.skip_conv keeps. Assert that loop_blocking.gen_loopblocking
    yields the same number of schemes. It is called on the 'BASE' layer and
    options with the 'LG' resource and the positional argument 1.
    '''
    # NOTE(review): another method named test_gen_loopblocking_all appears
    # in this file; if both live in the same class only the later one is
    # collected by unittest -- confirm.
    expected = 0
    for bl_ts, bl_ords in self._gen_loopblocking_all():
        if not loop_blocking.skip_conv(bl_ts, bl_ords):
            expected += 1

    generated = 0
    for _ in loop_blocking.gen_loopblocking(self.nld['BASE'],
                                            self.resource['LG'],
                                            self.cost, 1,
                                            self.options['BASE']):
        generated += 1

    self.assertEqual(generated, expected)