def test_o_full_layer(self):
        ''' Ofmap ranges of all partition indexes cover the full layer. '''
        for wlkey in self.layers:
            layer = self.layers[wlkey]

            for dnkey in self.dim_nodes:
                parts = self._gen_partition(wlkey=wlkey, dnkey=dnkey,
                                            optkey='NOINPP')

                for part in parts:
                    # Collect the ofmap range of every partition index.
                    ofrmap = FmapRangeMap()
                    for pidx in part.gen_pidx():
                        ofrmap.add(
                            partition.part_layer_ofmap_range(
                                layer, self.batch_size, part, pidx),
                            0)

                    self.assertTrue(ofrmap.is_complete())

                    # The merged range must match the layer ofmap sizes.
                    whole = ofrmap.complete_fmap_range()
                    for dim, expected in (('b', self.batch_size),
                                          ('n', layer.nofm),
                                          ('h', layer.hofm),
                                          ('w', layer.wofm)):
                        self.assertEqual(whole.size(dim), expected)
    def test_complete_fmap_range(self):
        ''' complete_fmap_range for complete maps. '''
        # The map prepared on self.frm is expected to be complete and to
        # span 4 * 8 * 16 * 16 points.
        shared = self.frm
        self.assertTrue(shared.is_complete(), 'is_complete')
        self.assertEqual(shared.complete_fmap_range().size(),
                         4 * 8 * 16 * 16, 'complete_fmap_range')

        # A map holding one range that begins at the origin is expected to be
        # complete, with that very range as its complete range.
        lone_fr = FmapRange((0, 0, 0, 0), (3, 5, 7, 9))
        lone_frm = FmapRangeMap()
        lone_frm.add(lone_fr, 3.4)
        self.assertTrue(lone_frm.is_complete(), 'is_complete')
        self.assertEqual(lone_frm.complete_fmap_range(), lone_fr,
                         'complete_fmap_range')
    def test_is_complete_incomplete(self):
        ''' is_complete and complete_fmap_range for incomplete maps. '''
        incomplete_regex = 'FmapRangeMap: .*complete.*'

        # After one more range is added to self.frm, the map is expected to
        # be reported as incomplete.
        self.frm.add(FmapRange((4, 8, 16, 16), (5, 9, 17, 17)), 10)
        self.assertFalse(self.frm.is_complete(), 'is_complete: incomplete')
        with self.assertRaisesRegex(ValueError, incomplete_regex):
            self.frm.complete_fmap_range()

        # A map whose only range does not begin at the origin.
        offset_frm = FmapRangeMap()
        offset_frm.add(FmapRange((1, 0, 0, 0), (3, 5, 7, 9)), 3.4)
        self.assertFalse(offset_frm.is_complete(), 'is_complete: incomplete')
        with self.assertRaisesRegex(ValueError, incomplete_regex):
            offset_frm.complete_fmap_range()
    def test_io_full_layer(self):
        ''' Ifmap and ofmap ranges together cover the full layer. '''
        for wlkey in ('SM', 'POOL'):
            layer = self.layers[wlkey]

            for dnkey in self.dim_nodes:
                parts = self._gen_partition(wlkey=wlkey, dnkey=dnkey,
                                            optkey='NOINPP')

                for part in parts:
                    # Start from every ifmap point of the layer; points are
                    # dropped as they show up in the per-index ifmap ranges.
                    full_ifrng = FmapRange(
                        fp_beg=(0, 0, 0, 0),
                        fp_end=(self.batch_size, layer.nifm,
                                layer.hifm, layer.wifm))
                    remaining = set(full_ifrng.range())
                    # Ofmap ranges are gathered into a map.
                    ofrmap = FmapRangeMap()

                    for pidx in part.gen_pidx():
                        _, ifrng, ofrng = partition.proc_data_range(
                            layer, self.batch_size, part, pidx)
                        remaining.difference_update(ifrng.range())
                        ofrmap.add(ofrng, 0)

                    # No ifmap point may be left over.
                    self.assertFalse(remaining)

                    # The ofmap map must be complete and match the layer
                    # ofmap sizes.
                    self.assertTrue(ofrmap.is_complete())
                    whole = ofrmap.complete_fmap_range()
                    for dim, expected in (('b', self.batch_size),
                                          ('n', layer.nofm),
                                          ('h', layer.hofm),
                                          ('w', layer.wofm)):
                        self.assertEqual(whole.size(dim), expected)
class TestFmapRangeMap(unittest.TestCase):
    ''' Unit tests for FmapRangeMap. '''

    def setUp(self):
        # Tile the range (0, 0, 0, 0) -- (4, 8, 16, 16) with eight blocks of
        # size 2 x 4 x 8 x 16, holding the values 0 to 7 in order.
        self.frm = FmapRangeMap()
        val = 0
        for b_beg in (0, 2):
            for n_beg in (0, 4):
                for h_beg in (0, 8):
                    self.frm.add(
                        FmapRange((b_beg, n_beg, h_beg, 0),
                                  (b_beg + 2, n_beg + 4, h_beg + 8, 16)),
                        val)
                    val += 1

    def test_add(self):
        ''' add: values of newly added ranges can be retrieved. '''
        # Each case: range to add, position inside it to probe, value.
        cases = [
            (FmapRange((4, 8, 16, 16), (5, 9, 17, 17)),
             FmapPosition(4, 8, 16, 16), 10),
            (FmapRange((10, 10, 20, 20), (15, 19, 27, 27)),
             FmapPosition(14, 15, 22, 24), 11),
        ]
        for fr, probe, val in cases:
            self.frm.add(fr, val)
            self.assertEqual(self.frm.get(probe), val, 'add')

    def test_add_zero_fr(self):
        ''' add: a zero-size FmapRange does not change the item count. '''
        count_before = len(list(self.frm.items()))
        empty_fr = FmapRange((5, 9, 17, 17), (5, 9, 17, 17))
        self.frm.add(empty_fr, 10)
        count_after = len(list(self.frm.items()))
        self.assertEqual(count_after, count_before)

    def test_add_overlap_fr(self):
        ''' add: an overlapping FmapRange is rejected. '''
        overlapping = FmapRange((3, 7, 15, 15), (5, 9, 17, 17))
        with self.assertRaisesRegex(ValueError, 'FmapRangeMap: .*overlap.*'):
            self.frm.add(overlapping, 10)

    def test_get(self):
        ''' get: positions map to the value of their block. '''
        expectations = (
            ((3, 5, 7, 9), 6),
            ((0, 0, 0, 0), 0),
            ((2, 1, 1, 12), 4),
            ((3, 7, 15, 15), 7),
        )
        for coords, val in expectations:
            self.assertEqual(self.frm.get(FmapPosition(*coords)), val, 'get')

    def test_get_not_in(self):
        ''' get: a position outside all ranges raises KeyError. '''
        outside = FmapPosition(4, 8, 16, 16)
        with self.assertRaisesRegex(KeyError, 'FmapRangeMap: .*key.*'):
            self.frm.get(outside)

    def test_complete_fmap_range(self):
        ''' complete_fmap_range for complete maps. '''
        # The fixture map covers 4 * 8 * 16 * 16 points.
        shared = self.frm
        self.assertTrue(shared.is_complete(), 'is_complete')
        self.assertEqual(shared.complete_fmap_range().size(),
                         4 * 8 * 16 * 16, 'complete_fmap_range')

        # A map holding one range that begins at the origin is expected to be
        # complete, with that very range as its complete range.
        lone_fr = FmapRange((0, 0, 0, 0), (3, 5, 7, 9))
        lone_frm = FmapRangeMap()
        lone_frm.add(lone_fr, 3.4)
        self.assertTrue(lone_frm.is_complete(), 'is_complete')
        self.assertEqual(lone_frm.complete_fmap_range(), lone_fr,
                         'complete_fmap_range')

    def test_is_complete_incomplete(self):
        ''' is_complete and complete_fmap_range for incomplete maps. '''
        incomplete_regex = 'FmapRangeMap: .*complete.*'

        # A range detached from the fixture blocks makes the map incomplete.
        self.frm.add(FmapRange((4, 8, 16, 16), (5, 9, 17, 17)), 10)
        self.assertFalse(self.frm.is_complete(), 'is_complete: incomplete')
        with self.assertRaisesRegex(ValueError, incomplete_regex):
            self.frm.complete_fmap_range()

        # A map whose only range does not begin at the origin.
        offset_frm = FmapRangeMap()
        offset_frm.add(FmapRange((1, 0, 0, 0), (3, 5, 7, 9)), 3.4)
        self.assertFalse(offset_frm.is_complete(), 'is_complete: incomplete')
        with self.assertRaisesRegex(ValueError, incomplete_regex):
            offset_frm.complete_fmap_range()

    def test_items(self):
        ''' items: keys and values agree with get; sizes sum to the total. '''
        total = 0
        for fr, val in self.frm.items():
            total += fr.size()
            self.assertEqual(self.frm.get(fr.fp_beg), val, 'items: keyval')
        self.assertEqual(total, 4 * 8 * 16 * 16, 'items: size')

    def test_copy(self):
        ''' copy: same items, but later additions are independent. '''
        missing_regex = 'FmapRangeMap: .*key.*'

        clone = self.frm.copy()
        self.assertListEqual(list(clone.items()), list(self.frm.items()),
                             'copy: equal')

        # Added to the copy only: must not appear in the original.
        clone_only = FmapRange((10, 10, 10, 10), (11, 11, 11, 11))
        clone.add(clone_only, 10)
        self.assertEqual(clone.get(clone_only.fp_beg), 10, 'copy: in')
        with self.assertRaisesRegex(KeyError, missing_regex):
            self.frm.get(clone_only.fp_beg)

        # Added to the original only: must not appear in the copy.
        orig_only = FmapRange((20, 20, 20, 20), (21, 21, 21, 21))
        self.frm.add(orig_only, 20)
        self.assertEqual(self.frm.get(orig_only.fp_beg), 20, 'copy: in')
        with self.assertRaisesRegex(KeyError, missing_regex):
            clone.get(orig_only.fp_beg)

    def test_rget_counter(self):
        ''' rget_counter: counts add up to the covered part of the range. '''
        # Range inside the fixture map: every point is counted.
        inner = FmapRange((1, 3, 9, 11), (3, 5, 13, 15))
        inner_total = sum(self.frm.rget_counter(inner).values())
        self.assertEqual(inner_total, inner.size(), 'rget_counter')

        # Zero-size range: nothing is counted.
        empty = FmapRange((0, 0, 0, 0), (0, 0, 0, 0))
        empty_total = sum(self.frm.rget_counter(empty).values())
        self.assertEqual(empty_total, 0, 'rget_counter')

        # Range reaching beyond the fixture map: only the overlap with the
        # complete range is counted.
        partial = FmapRange((1, 3, 9, 11), (3, 5, 13, 17))
        partial_total = sum(self.frm.rget_counter(partial).values())
        self.assertLess(partial_total, partial.size(), 'rget_counter')
        self.assertEqual(partial_total,
                         self.frm.complete_fmap_range().overlap(partial).size(),
                         'rget_counter')

    def test_rget_counter_same_vals(self):
        ''' rget_counter: several ranges in the map share the same value. '''
        # Value 2 is already used by one of the fixture blocks.
        self.frm.add(FmapRange((0, 0, 0, 16), (4, 8, 16, 32)), 2)
        query = FmapRange((1, 3, 9, 11), (3, 5, 13, 17))
        total = sum(self.frm.rget_counter(query).values())
        self.assertEqual(total, query.size())

    def test_rget_single(self):
        ''' rget_single: a range within one value returns that value. '''
        # Every stored range maps to its own value.
        for fr, val in self.frm.items():
            self.assertEqual(self.frm.rget_single(fr), val, 'rget_single')

        # Sub-ranges of a single block.
        sub_cases = (
            (FmapRange((0, 0, 0, 0), (1, 1, 1, 1)), 0),
            (FmapRange((3, 1, 10, 3), (4, 3, 13, 7)), 5),
        )
        for fr, expected in sub_cases:
            found = self.frm.rget_single(fr)
            self.assertEqual(found, expected, 'rget_single')

    def test_rget_single_multi(self):
        ''' rget_single: a range spanning multiple values is rejected. '''
        spanning = FmapRange((3, 1, 10, 3), (4, 6, 13, 7))
        with self.assertRaisesRegex(ValueError, 'FmapRangeMap: .*single.*'):
            self.frm.rget_single(spanning)

    def test_str(self):
        ''' str: the string contains every range and every value. '''
        text = str(self.frm)
        for fr, val in self.frm.items():
            self.assertIn(str(fr), text)
            self.assertIn(str(val), text)