def test_ofmap_local(self):
        ''' Ofmaps stored in the processing region itself take zero hops. '''
        base_layer = self.layers['BASE']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((4, 1), (1, 1), (1, 4), (1, 1)))

        proc_region = NodeRegion(
            origin=PhyDim2(3, 3), dim=scheme.dim(), type=NodeRegion.PROC)

        ifm_layout = self._make_data_layout(
            base_layer.nifm, base_layer.hifm, base_layer.wifm,
            PhyDim2(-3, -3), (1, 1), (1, 1), PhyDim2(1, 1))

        # The ofmap layout is built directly from the processing region and
        # its partitioning scheme, covering the full ofmap range.
        full_ofm_range = FmapRange(
            (0, ) * 4,
            (self.batch_size,
             base_layer.nofm, base_layer.hofm, base_layer.wofm))
        ofm_layout = DataLayout(
            frngs=(full_ofm_range, ),
            regions=(proc_region, ),
            parts=(scheme, ))

        fil_nodes = frozenset([PhyDim2(3, -3)])

        unit_nhops = partition.unit_nhops_to_proc_region(
            base_layer, self.batch_size, proc_region, scheme, fil_nodes,
            ifm_layout, ofm_layout, self.options['BASE'])

        self.assertEqual(unit_nhops[de.OFM], 0)
    def test_lr_layer(self):
        ''' Unit nhops of an LR layer match the reference calculation. '''
        lr_layer = self.layers['LR']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((2, 1), (2, 4), (2, 2), (1, 1)))

        proc_region = NodeRegion(
            origin=PhyDim2(0, 0), dim=scheme.dim(), type=NodeRegion.PROC)

        ifm_layout = self._make_data_layout(
            lr_layer.nifm, lr_layer.hifm, lr_layer.wifm,
            PhyDim2(-3, -3), (1, 2), (4, 1), PhyDim2(8, 4))

        ofm_layout = self._make_data_layout(
            lr_layer.nofm, lr_layer.hofm, lr_layer.wofm,
            PhyDim2(5, 5), (1, 1), (1, 2), PhyDim2(2, 4))

        # This test passes an empty set of filter nodes.
        no_fil_nodes = frozenset()

        actual = partition.unit_nhops_to_proc_region(
            lr_layer, self.batch_size, proc_region, scheme, no_fil_nodes,
            ifm_layout, ofm_layout, self.options['BASE'])

        # Reference values from the test class's own helper.
        expected = self._true_unit_nhops(
            lr_layer, proc_region, scheme, no_fil_nodes,
            ifm_layout, ofm_layout)

        self.assertListEqual(actual, expected)
    def test_origin(self):
        ''' Doubling every node coordinate doubles the unit nhops. '''
        base_layer = self.layers['BASE']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((1, 1), (1, 1), (1, 1), (1, 1)))

        def nhops_at(dist):
            ''' Unit nhops when all coordinates have magnitude `dist`. '''
            proc_region = NodeRegion(
                origin=PhyDim2(dist, dist),
                dim=scheme.dim(),
                type=NodeRegion.PROC)

            ifm_layout = self._make_data_layout(
                base_layer.nifm, base_layer.hifm, base_layer.wifm,
                PhyDim2(-dist, -dist), (1, 1), (1, 1), PhyDim2(1, 1))

            ofm_layout = self._make_data_layout(
                base_layer.nofm, base_layer.hofm, base_layer.wofm,
                PhyDim2(dist, dist), (1, 1), (1, 1), PhyDim2(1, 1))

            fil_nodes = frozenset([PhyDim2(dist, -dist)])

            return partition.unit_nhops_to_proc_region(
                base_layer, self.batch_size, proc_region, scheme, fil_nodes,
                ifm_layout, ofm_layout, self.options['BASE'])

        nhops_near = nhops_at(3)
        nhops_far = nhops_at(6)

        self.assertListEqual(nhops_far, [hops * 2 for hops in nhops_near])
    def test_use_fwd(self):
        ''' Access forwarding compared against the basic access scheme. '''
        base_layer = self.layers['BASE']

        scheme = PartitionScheme(
            order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
            pdims=((2, 1), (2, 4), (1, 2), (2, 1)))

        proc_region = NodeRegion(
            origin=PhyDim2(0, 0), dim=scheme.dim(), type=NodeRegion.PROC)

        # All data sources are placed this far from the processing region.
        far_dist = 1000

        ifm_layout = self._make_data_layout(
            base_layer.nifm, base_layer.hifm, base_layer.wifm,
            PhyDim2(-far_dist, 0), (1, 1), (1, 1), PhyDim2(1, 1))

        ofm_layout = self._make_data_layout(
            base_layer.nofm, base_layer.hofm, base_layer.wofm,
            PhyDim2(0, -far_dist), (1, 1), (1, 1), PhyDim2(1, 1))

        fil_nodes = frozenset([PhyDim2(far_dist, 0), PhyDim2(0, far_dist)])

        def nhops_with(option_key):
            ''' Unit nhops for the same setup under the given options. '''
            return partition.unit_nhops_to_proc_region(
                base_layer, self.batch_size, proc_region, scheme, fil_nodes,
                ifm_layout, ofm_layout, self.options[option_key])

        base_nhops = nhops_with('BASE')
        accfwd_nhops = nhops_with('ACCFWD')
        bufshr_nhops = nhops_with('BUFSHR')

        # The ACCFWD and BUFSHR options must give the same nhops for every
        # data category.
        for category in range(de.NUM):
            self.assertEqual(accfwd_nhops[category], bufshr_nhops[category])

        # In the basic access scheme, FIL and IFM are independently fetched,
        # resulting in repeating remote fetch. OFM are merged locally and only
        # stored back remotely once.
        fil_lower_bound = (
            base_layer.total_filter_size() * far_dist *
            scheme.size(pe.BATP) * scheme.size(pe.OFMP) * 0.8)
        self.assertGreater(base_nhops[de.FIL], fil_lower_bound)

        ifm_lower_bound = (
            base_layer.total_ifmap_size(self.batch_size) * far_dist *
            scheme.size(pe.OUTP) * 0.8)
        self.assertGreater(base_nhops[de.IFM], ifm_lower_bound)

        part_layer, part_batch_size, _ = scheme.part_layer(
            base_layer, self.batch_size)

        # With forwarding, everyone is only remotely fetched once.
        fil_upper_bound = (
            part_layer.total_filter_size() *
            scheme.size(pe.INPP, pe.OUTP) *
            (far_dist + proc_region.dim.size()))
        self.assertLess(accfwd_nhops[de.FIL], fil_upper_bound)

        ifm_upper_bound = (
            part_layer.total_ifmap_size(part_batch_size) *
            scheme.size(pe.INPP, pe.OFMP, pe.BATP) *
            (far_dist + proc_region.dim.size()))
        self.assertLess(accfwd_nhops[de.IFM], ifm_upper_bound)

        ofm_upper_bound = (
            part_layer.total_ofmap_size(part_batch_size) *
            scheme.size(pe.OUTP, pe.OFMP, pe.BATP) *
            (far_dist + proc_region.dim.size()))
        self.assertLess(accfwd_nhops[de.OFM], ofm_upper_bound)
    def test_small(self):
        '''
        Small case with hand calculation.

        The expected FIL, IFM and OFM unit nhops are spelled out by hand in
        the comments below and compared against the values returned by
        `partition.unit_nhops_to_proc_region` with the BASE options.
        '''
        layer = ConvLayer(6, 8, 16, 3)
        # The hand-calculated numbers below assume a batch size of 8. Use a
        # unittest assertion rather than a bare `assert`, so that the check
        # is not stripped under `python -O` and reports both values on
        # failure.
        self.assertEqual(self.batch_size, 8)

        # Per-node fmap ranges in the processing region, written as
        # node coordinate: range begin -- range end.
        # i (0, 0), (2, 0): (0, 0, 0, 0) -- (4, 6, 10, 10)
        #   (0, 1), (2, 1): (0, 0, 0, 8) -- (4, 6, 10, 18)
        #   (0, 2), (2, 2): (4, 0, 0, 0) -- (8, 6, 10, 10)
        #   (0, 3), (2, 3): (4, 0, 0, 8) -- (8, 6, 10, 18)
        #   (1, 0), (3, 0): (0, 0, 8, 0) -- (4, 6, 18, 10)
        #   (1, 1), (3, 1): (0, 0, 8, 8) -- (4, 6, 18, 18)
        #   (1, 2), (3, 2): (4, 0, 8, 0) -- (8, 6, 18, 10)
        #   (1, 3), (3, 3): (4, 0, 8, 8) -- (8, 6, 18, 18)
        # o (0, 0): (0, 0, 0, 0) -- (4, 4, 8, 8)
        #   (0, 1): (0, 0, 0, 8) -- (4, 4, 8, 16)
        #   (0, 2): (4, 0, 0, 0) -- (8, 4, 8, 8)
        #   (0, 3): (4, 0, 0, 8) -- (8, 4, 8, 16)
        #   (1, 0): (0, 0, 8, 0) -- (4, 4, 16, 8)
        #   (1, 1): (0, 0, 8, 8) -- (4, 4, 16, 16)
        #   (1, 2): (4, 0, 8, 0) -- (8, 4, 16, 8)
        #   (1, 3): (4, 0, 8, 8) -- (8, 4, 16, 16)
        #   (2, 0): (0, 4, 0, 0) -- (4, 8, 8, 8)
        #   (2, 1): (0, 4, 0, 8) -- (4, 8, 8, 16)
        #   (2, 2): (4, 4, 0, 0) -- (8, 8, 8, 8)
        #   (2, 3): (4, 4, 0, 8) -- (8, 8, 8, 16)
        #   (3, 0): (0, 4, 8, 0) -- (4, 8, 16, 8)
        #   (3, 1): (0, 4, 8, 8) -- (4, 8, 16, 16)
        #   (3, 2): (4, 4, 8, 0) -- (8, 8, 16, 8)
        #   (3, 3): (4, 4, 8, 8) -- (8, 8, 16, 16)
        part = PartitionScheme(order=(pe.BATP, pe.INPP, pe.OUTP, pe.OFMP),
                               pdims=((2, 1), (2, 2), (1, 2), (1, 1)))

        nr = NodeRegion(origin=PhyDim2(0, 0),
                        dim=part.dim(),
                        type=NodeRegion.PROC)

        # Ifmap source layout, written as fmap range: source node.
        # (0, 0, 0, 0) -- (4, 6, 18, 9): (-2, -2)
        # (0, 0, 0, 9) -- (4, 6, 18, 18): (-2, -1)
        # (4, 0, 0, 0) -- (8, 6, 18, 9): (-1, -2)
        # (4, 0, 0, 9) -- (8, 6, 18, 18): (-1, -1)
        ilayout = self._make_data_layout(layer.nifm, layer.hifm, layer.wifm,
                                         PhyDim2(-2, -2), (2, 1), (1, 1),
                                         PhyDim2(2, 2))

        # Ofmap destination layout, written as fmap range: destination node.
        # (0, 0, 0, 0) -- (8, 4, 16, 8): (2, 2)
        # (0, 0, 0, 8) -- (8, 4, 16, 16): (2, 3)
        # (0, 4, 0, 0) -- (8, 8, 16, 8): (3, 2)
        # (0, 4, 0, 8) -- (8, 8, 16, 16): (3, 3)
        olayout = self._make_data_layout(layer.nofm, layer.hofm, layer.wofm,
                                         PhyDim2(2, 2), (1, 1), (2, 1),
                                         PhyDim2(2, 2))

        filter_nodes = frozenset([PhyDim2(0, 0)])

        # filter: (0, 0) -> all, 6 * 4 * 3 * 3

        # ifmap: (-2, -2) -> (0, 0), (2, 0): 4 * 6 * 10 * 9
        #                 -> (0, 1), (2, 1): 4 * 6 * 10 * (9 - 8)
        #                 -> (1, 0), (3, 0): 4 * 6 * (18 - 8) * 9
        #                 -> (1, 1), (3, 1): 4 * 6 * (18 - 8) * (9 - 8)
        #        (-2, -1) -> (0, 0), (2, 0): 4 * 6 * 10 * (10 - 9)
        #                 -> (0, 1), (2, 1): 4 * 6 * 10 * (18 - 9)
        #                 -> (1, 0), (3, 0): 4 * 6 * (18 - 8) * (10 - 9)
        #                 -> (1, 1), (3, 1): 4 * 6 * (18 - 8) * (18 - 9)
        #        (-1, -2) -> (0, 2), (2, 2): (8 - 4) * 6 * 10 * 9
        #                 -> (0, 3), (2, 3): (8 - 4) * 6 * 10 * (9 - 8)
        #                 -> (1, 2), (3, 2): (8 - 4) * 6 * (18 - 8) * 9
        #                 -> (1, 3), (3, 3): (8 - 4) * 6 * (18 - 8) * (9 - 8)
        #        (-1, -1) -> (0, 2), (2, 2): (8 - 4) * 6 * 10 * (10 - 9)
        #                 -> (0, 3), (2, 3): (8 - 4) * 6 * 10 * (18 - 9)
        #                 -> (1, 2), (3, 2): (8 - 4) * 6 * (18 - 8) * (10 - 9)
        #                 -> (1, 3), (3, 3): (8 - 4) * 6 * (18 - 8) * (18 - 9)

        # ofmap: (2, 2) -> (0, 0):
        #               -> (0, 2):
        #               -> (1, 0):
        #               -> (1, 2): 4 * 4 * 8 * 8
        #        (2, 3) -> (0/1, 1/3)
        #        (3, 2) -> (2/3, 0/2)
        #        (3, 3) -> (2/3, 1/3)

        nhops = partition.unit_nhops_to_proc_region(layer, self.batch_size, nr,
                                                    part, filter_nodes,
                                                    ilayout, olayout,
                                                    self.options['BASE'])

        # Filters: each node's share times its distance (h + w) from (0, 0).
        self.assertEqual(
            nhops[de.FIL],
            6 * 4 * 3 * 3 * sum(h + w for h in range(4) for w in range(4)))
        # Ifmaps: one row per source node, in the order of the table above.
        # Each term is (sum of distances to the two receiving nodes) times
        # the overlap width; the common factor 4 * 6 * 10 is pulled out.
        self.assertEqual(
            nhops[de.IFM], 4 * 6 * 10 *
            ((4 + 6) * 9 + (5 + 7) * 1 + (5 + 7) * 9 + (6 + 8) * 1 +
             (3 + 5) * 1 + (4 + 6) * 9 + (4 + 6) * 1 + (5 + 7) * 9 +
             (5 + 7) * 9 + (6 + 8) * 1 + (6 + 8) * 9 + (7 + 9) * 1 +
             (4 + 6) * 1 + (5 + 7) * 9 + (5 + 7) * 1 + (6 + 8) * 9))
        # Ofmaps: one group per destination node, each group summing the
        # distances from its four contributing processing nodes.
        self.assertEqual(
            nhops[de.OFM], 4 * 4 * 8 * 8 * ((4 + 2 + 3 + 1) + (4 + 2 + 3 + 1) +
                                            (3 + 1 + 2 + 0) + (3 + 1 + 2 + 0)))