def test_count_strides(self):
        """Check per-block and overall stride counting on two small archs.

        "blocks1" only uses positive strides. "blocks2" uses negative strides
        in stage 1, which are expected to count as a stride of 0.5 per block.
        A tuple with n > 1 expands into n blocks, and only the first of them
        carries the tuple's stride.
        """
        expansion6 = {"expansion": 6}
        arch_def = {
            "blocks1": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 32, 2, 1)],
                # stage 1
                [
                    ("ir_k3", 64, 2, 2, expansion6),
                    ("ir_k5", 96, 1, 1, expansion6),
                ],
            ],
            "blocks2": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 32, 2, 1)],
                # stage 1
                [
                    ("ir_k3", 64, -2, 2, expansion6),
                    ("ir_k5", 96, -2, 1, expansion6),
                ],
            ],
        }
        block_names = ["blocks1", "blocks2"]
        unified_arch = mbuilder.unify_arch_def(arch_def, block_names)

        expected_per_block = {
            "blocks1": [2, 2, 1, 1],
            "blocks2": [2, 0.5, 1, 0.5],
        }
        expected_total = {"blocks1": 4, "blocks2": 0.5}

        # Per-block strides: compute all first, then compare.
        per_block = {
            name: mbuilder.count_stride_each_block(unified_arch[name])
            for name in block_names
        }
        for name in block_names:
            self.assertEqual(expected_per_block[name], per_block[name])

        # Overall (product) stride per block list.
        totals = {
            name: mbuilder.count_strides(unified_arch[name])
            for name in block_names
        }
        for name in block_names:
            self.assertEqual(totals[name], expected_total[name])
# Example #2
0
def _get_stride_per_stage(blocks):
    """
    Count the accumulated stride per stage given a list of blocks. The mbuilder
    provides an API for counting per-block stride; this function leverages it
    to count the per-stage accumulated stride.

    Input: a list of blocks from the unified arch_def; each block is read via
        its "stage_idx" key. Note that the stage_idx must be contiguous (not
        necessarily starting from 0), and can be non-ascending (not tested).
    Output: a list with one entry per stage, ordered from the lowest stage_idx.
        Entry k is the product of the strides of all blocks in stages up to
        and including the k-th one, e.g. per-stage strides [2, 2, 1] give
        [2, 4, 4].
    Raises: AssertionError (when assertions are enabled) if the per-block
        stride count does not match the number of blocks, the stage_idx values
        are not contiguous, or the result disagrees with
        mbuilder.count_strides(blocks). Empty `blocks` raises ValueError.
    """
    import functools
    import math
    import operator

    stride_per_block = mbuilder.count_stride_each_block(blocks)
    assert len(stride_per_block) == len(blocks)

    # Group the per-block strides by stage in a single pass, keeping block
    # order within each stage, eg. {0: [1], 1: [2, 1], 2: [2, 1, 1, 1], ...}.
    strides_by_stage = {}
    for block, stride in zip(blocks, stride_per_block):
        strides_by_stage.setdefault(block["stage_idx"], []).append(stride)

    # assume stage idx are contiguous, eg. 1, 2, 3, ...
    # (max()/min() of an empty dict raise ValueError, as before.)
    assert (
        max(strides_by_stage) - min(strides_by_stage) + 1 == len(strides_by_stage)
    )

    # Stride contributed by each stage is the product of its block strides.
    # Every stage has at least one block, so reduce() needs no initial value.
    stride_per_stage = [
        functools.reduce(operator.mul, strides_by_stage[stage_idx])
        for stage_idx in sorted(strides_by_stage)
    ]  # eg. [1, 2, 2, ...]

    # Running product across stages, eg. [first*1, first*2, first*4, ...]
    accum_stride_per_stage = list(
        itertools.accumulate(stride_per_stage, operator.mul)
    )

    # Compare with a tolerance rather than ==. Grouping the product by stage
    # changes the association order of the float multiplications compared to
    # a flat per-block product. For fractional strides that are not powers of
    # two, that can differ in the last bits.
    assert math.isclose(accum_stride_per_stage[-1], mbuilder.count_strides(blocks))
    return accum_stride_per_stage