def test_count_strides(self):
    """Verify per-block and whole-network stride counting on a unified arch_def.

    ``blocks1`` only uses positive strides. ``blocks2`` uses negative
    strides in stage 1, which the expected values below encode as a 0.5
    factor per such block. Per the expected values, only the first block
    produced by a repeated op (``n > 1``) carries the op's stride; the
    repeats contribute 1.
    """
    expansion6 = {"expansion": 6}
    arch_def = {
        # each entry: [op, c, s, n, ...]
        "blocks1": [
            # stage 0
            [("conv_k3", 32, 2, 1)],
            # stage 1
            [
                ("ir_k3", 64, 2, 2, expansion6),
                ("ir_k5", 96, 1, 1, expansion6),
            ],
        ],
        "blocks2": [
            # stage 0
            [("conv_k3", 32, 2, 1)],
            # stage 1: negative strides
            [
                ("ir_k3", 64, -2, 2, expansion6),
                ("ir_k5", 96, -2, 1, expansion6),
            ],
        ],
    }
    unified = mbuilder.unify_arch_def(arch_def, ["blocks1", "blocks2"])
    first_blocks = unified["blocks1"]
    second_blocks = unified["blocks2"]

    # Stride of every individual block, in block order.
    per_block_first = mbuilder.count_stride_each_block(first_blocks)
    per_block_second = mbuilder.count_stride_each_block(second_blocks)
    self.assertEqual([2, 2, 1, 1], per_block_first)
    self.assertEqual([2, 0.5, 1, 0.5], per_block_second)

    # Total stride of each block list (product of the per-block strides).
    total_first = mbuilder.count_strides(first_blocks)
    total_second = mbuilder.count_strides(second_blocks)
    self.assertEqual(total_first, 4)
    self.assertEqual(total_second, 0.5)
def _get_stride_per_stage(blocks): """ Count the accummulated stride per stage given a list of blocks. The mbuilder provides API for counting per-block accumulated stride, this function leverages it to count per-stage accumulated stride. Input: a list of blocks from the unified arch_def. Note that the stage_idx must be contiguous (not necessarily starting from 0), and can be non-ascending (not tested). Output: a list of accumulated stride per stage, starting from lowest stage_idx. """ stride_per_block = mbuilder.count_stride_each_block(blocks) assert len(stride_per_block) == len(blocks) stage_idx_set = {s["stage_idx"] for s in blocks} # assume stage idx are contiguous, eg. 1, 2, 3, ... assert max(stage_idx_set) - min(stage_idx_set) + 1 == len(stage_idx_set) start_stage_id = min(stage_idx_set) ids_per_stage = [ [i for i, s in enumerate(blocks) if s["stage_idx"] == stage_idx] for stage_idx in range(start_stage_id, start_stage_id + len(stage_idx_set)) ] # eg. [[0], [1, 2], [3, 4, 5, 6], ...] block_stride_per_stage = [ [stride_per_block[i] for i in ids] for ids in ids_per_stage ] # eg. [[1], [2, 1], [2, 1, 1, 1], ...] stride_per_stage = [ list(itertools.accumulate(s, lambda x, y: x * y))[-1] for s in block_stride_per_stage ] # eg. [1, 2, 2, ...] accum_stride_per_stage = list( itertools.accumulate(stride_per_stage, lambda x, y: x * y) ) # eg. [first*1, first*2, first*4, ...] assert accum_stride_per_stage[-1] == mbuilder.count_strides(blocks) return accum_stride_per_stage