def make_trident_stage(block_class, num_blocks, **kwargs):
    """Build one TridentNet ResNet stage through ``ResNet.make_stage``.

    A ``concat_output_per_block`` entry is written into ``kwargs`` (replacing
    any value the caller supplied under that key): ``False`` for every block
    except the last one, which gets ``True``.

    Args:
        block_class: block type forwarded to ``ResNet.make_stage``.
        num_blocks: number of blocks in the stage.
        **kwargs: additional keyword arguments forwarded unchanged to
            ``ResNet.make_stage``.

    Returns:
        Whatever ``ResNet.make_stage`` returns for the assembled arguments.
    """
    # One flag per block: only the final block is marked True.
    concat_flags = [False for _ in range(num_blocks - 1)]
    concat_flags.append(True)
    kwargs["concat_output_per_block"] = concat_flags
    return ResNet.make_stage(block_class, num_blocks, **kwargs)
def build_bam_resnet_backbone(cfg, input_shape):
    r"""Build a ResNet backbone from config with BAM stages interleaved.

    Every res2, res3 or res4 stage that gets built is immediately followed in
    the stage list by a one-block stage made of :class:`BAMBlock`; res5 is not
    followed by one.

    Args:
        cfg: config node; ``MODEL.RESNETS.*`` and ``MODEL.BACKBONE.FREEZE_AT``
            are read.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        The result of
        ``ResNet(stem, stages, out_features=...).freeze(freeze_at)``.

    Raises:
        AssertionError: if ``RES5_DILATION`` is not 1 or 2, or if depth 18/34
            is combined with unsupported channel, dilation or group settings.
        KeyError: if ``DEPTH`` is not one of 18, 34, 50, 101, 152, or an entry
            of ``OUT_FEATURES`` is not one of "res2", "res3", "res4", "res5".
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    resnets_cfg = cfg.MODEL.RESNETS
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnets_cfg.OUT_FEATURES
    depth = resnets_cfg.DEPTH
    num_groups = resnets_cfg.NUM_GROUPS
    width_per_group = resnets_cfg.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets_cfg.STEM_OUT_CHANNELS
    out_channels = resnets_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets_cfg.STRIDE_IN_1X1
    res5_dilation = resnets_cfg.RES5_DILATION

    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    blocks_by_depth = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }
    num_blocks_per_stage = blocks_by_depth[depth]

    # R18 and R34 are built from BasicBlock and only accept a fixed setup.
    uses_basic_block = depth in [18, 34]
    if uses_basic_block:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    # Only build stages up to the deepest requested output feature: unused
    # stages would hold parameters without gradients, which consumes extra
    # memory and may cause allreduce to fail.
    stage_number = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    max_stage_idx = max([stage_number[name] for name in out_features])

    stages = []
    for stage_idx in range(2, max_stage_idx + 1):
        idx = stage_idx - 2
        num_blocks = num_blocks_per_stage[idx]

        dilation = 1
        if stage_idx == 5:
            dilation = res5_dilation
        # The first stage and a dilated res5 do not downsample.
        keeps_resolution = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_resolution else 2

        stage_kwargs = dict(
            num_blocks=num_blocks,
            stride_per_block=[first_stride] + [1] * (num_blocks - 1),
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
        )
        if uses_basic_block:
            stage_kwargs["block_class"] = BasicBlock
        else:
            stage_kwargs.update(
                bottleneck_channels=bottleneck_channels,
                stride_in_1x1=stride_in_1x1,
                dilation=dilation,
                num_groups=num_groups,
                block_class=BottleneckBlock,
            )
        stages.append(ResNet.make_stage(**stage_kwargs))

        if stage_idx in (2, 3, 4):
            # NOTE(review): the BAM stage sits after the stage just built but
            # is given that stage's *input* channel count as in_channels --
            # confirm this is what BAMBlock expects.
            # NOTE(review): the BAM stages are extra entries in `stages`;
            # confirm how ResNet names its stages when resolving out_features.
            stages.append(
                ResNet.make_stage(
                    block_class=BAMBlock,
                    stride_per_block=[1],
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_blocks=1,
                )
            )

        # Channel widths double from one stage to the next.
        in_channels, out_channels = out_channels, out_channels * 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
def build_trident_resnet_backbone(cfg, input_shape):
    """Build a ResNet backbone for TridentNet from config.

    The stage named by ``MODEL.TRIDENT.TRIDENT_STAGE`` is built from
    :class:`TridentBottleneckBlock` via ``make_trident_stage``; every other
    stage uses :class:`DeformBottleneckBlock` or :class:`BottleneckBlock`
    depending on ``MODEL.RESNETS.DEFORM_ON_PER_STAGE``.

    Args:
        cfg: config node; ``MODEL.RESNETS.*``, ``MODEL.TRIDENT.*`` and
            ``MODEL.BACKBONE.FREEZE_AT`` are read.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        ResNet: the instance created by
        ``ResNet(stem, stages, out_features=...)``.

    Raises:
        AssertionError: if ``RES5_DILATION`` is not 1 or 2, or deformable conv
            is enabled for the trident stage.
        KeyError: if ``DEPTH`` is not one of 50, 101, 152, or a stage name is
            not one of "res2", "res3", "res4", "res5".
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    if freeze_at >= 1:
        # Freeze the stem: stop gradients, then swap in frozen batch norm.
        for param in stem.parameters():
            param.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    resnets_cfg = cfg.MODEL.RESNETS
    out_features = resnets_cfg.OUT_FEATURES
    depth = resnets_cfg.DEPTH
    num_groups = resnets_cfg.NUM_GROUPS
    width_per_group = resnets_cfg.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets_cfg.STEM_OUT_CHANNELS
    out_channels = resnets_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets_cfg.STRIDE_IN_1X1
    res5_dilation = resnets_cfg.RES5_DILATION
    deform_on_per_stage = resnets_cfg.DEFORM_ON_PER_STAGE
    deform_modulated = resnets_cfg.DEFORM_MODULATED
    deform_num_groups = resnets_cfg.DEFORM_NUM_GROUPS

    trident_cfg = cfg.MODEL.TRIDENT
    num_branch = trident_cfg.NUM_BRANCH
    branch_dilations = trident_cfg.BRANCH_DILATIONS
    trident_stage = trident_cfg.TRIDENT_STAGE
    test_branch_idx = trident_cfg.TEST_BRANCH_IDX

    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    blocks_by_depth = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }
    num_blocks_per_stage = blocks_by_depth[depth]

    res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [res_stage_idx[name] for name in out_features]
    trident_stage_idx = res_stage_idx[trident_stage]
    max_stage_idx = max(out_stage_idx)

    stages = []
    for stage_idx in range(2, max_stage_idx + 1):
        idx = stage_idx - 2
        num_blocks = num_blocks_per_stage[idx]
        is_trident = stage_idx == trident_stage_idx

        dilation = 1
        if stage_idx == 5:
            dilation = res5_dilation
        # The first stage and a dilated res5 do not downsample.
        keeps_resolution = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_resolution else 2

        stage_kwargs = dict(
            num_blocks=num_blocks,
            stride_per_block=[first_stride] + [1] * (num_blocks - 1),
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            norm=norm,
            stride_in_1x1=stride_in_1x1,
            dilation=dilation,
        )

        if is_trident:
            assert not deform_on_per_stage[idx], "Not support deformable conv in Trident blocks yet."
            stage_kwargs.update(
                block_class=TridentBottleneckBlock,
                num_branch=num_branch,
                dilations=branch_dilations,
                test_branch_idx=test_branch_idx,
            )
            # Trident blocks take per-branch `dilations`, not a single `dilation`.
            del stage_kwargs["dilation"]
        elif deform_on_per_stage[idx]:
            stage_kwargs.update(
                block_class=DeformBottleneckBlock,
                deform_modulated=deform_modulated,
                deform_num_groups=deform_num_groups,
            )
        else:
            stage_kwargs["block_class"] = BottleneckBlock

        if is_trident:
            blocks = make_trident_stage(**stage_kwargs)
        else:
            blocks = ResNet.make_stage(**stage_kwargs)

        # Channel widths double from one stage to the next.
        in_channels, out_channels = out_channels, out_channels * 2
        bottleneck_channels *= 2

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)

    return ResNet(stem, stages, out_features=out_features)
def build_conl_resnet_backbone(cfg, input_shape):
    """Build a ResNet backbone from config with ``NonLocal2dGc`` plugins.

    For every built stage that is listed in ``MODEL.CONL.STAGES``, the next
    entry of ``MODEL.CONL.BLOCKS`` is taken (entries are consumed in order of
    increasing stage index). That entry is a collection of block indices;
    negative indices count from the end of the stage. Each selected block gets
    a ``partial(NonLocal2dGc, **plugin_config)`` and every other block gets
    ``None``; the resulting list is passed to ``ResNet.make_stage`` as
    ``plugin_per_block``.

    The config is only read: ``MODEL.CONL.BLOCKS`` is not modified, so the
    same ``cfg`` can be used to build the backbone more than once.

    Args:
        cfg: config node; ``MODEL.RESNETS.*``, ``MODEL.CONL.*`` and
            ``MODEL.BACKBONE.FREEZE_AT`` are read.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        The result of
        ``ResNet(stem, stages, out_features=...).freeze(freeze_at)``.

    Raises:
        AssertionError: if ``RES5_DILATION`` is not 1 or 2, or if depth 18/34
            is combined with unsupported settings.
        KeyError: if ``DEPTH`` is not one of 18, 34, 50, 101, 152, or a stage
            name other than "stem" is not one of "res2".."res5".
        IndexError: if ``MODEL.CONL.BLOCKS`` has fewer entries than there are
            built stages listed in ``MODEL.CONL.STAGES``.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    # fmt: off
    freeze_at           = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features        = cfg.MODEL.RESNETS.OUT_FEATURES
    depth               = cfg.MODEL.RESNETS.DEPTH
    num_groups          = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group     = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels         = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels        = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1       = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation       = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated    = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups   = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    # fmt: on

    conl_cfg = cfg.MODEL.CONL
    conl_stages = conl_cfg.STAGES
    conl_blocks = conl_cfg.BLOCKS
    lr_mult = conl_cfg.LR_MULT
    plugin_config = {
        'ratio': conl_cfg.RATIO,
        'downsample': conl_cfg.DOWNSAMPLE,
        'use_gn': conl_cfg.USE_GN,
        # A non-positive LR_MULT is passed on as None.
        'lr_mult': lr_mult if lr_mult > 0 else None,
        'use_out': conl_cfg.USE_OUT,
        'out_bn': conl_cfg.OUT_BN,
        'whiten_type': conl_cfg.WHITEN_TYPE,
        'temp': conl_cfg.TEMP,
        'with_gc': conl_cfg.WITH_GC,
        'with_2fc': conl_cfg.WITH_2FC,
        'double_conv': conl_cfg.DOUBLE_CONV,
        'with_state': conl_cfg.WITH_STATE,
        'Ncls': conl_cfg.NCLS,
    }

    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    # R18 and R34 are built from BasicBlock and only accept a fixed setup.
    uses_basic_block = depth in [18, 34]
    if uses_basic_block:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    # Only build stages up to the deepest requested output feature: unused
    # stages would hold parameters without gradients, which consumes extra
    # memory and may cause allreduce to fail.
    res_stage_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [res_stage_idx[f] for f in out_features if f != "stem"]
    plugin_stage_idx = [res_stage_idx[f] for f in conl_stages if f != "stem"]
    max_stage_idx = max(out_stage_idx)

    stages = []
    # Cursor into conl_blocks. Entries are read by position instead of being
    # popped, so the list stored in the config stays intact across calls.
    next_plugin_entry = 0
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = res5_dilation if stage_idx == 5 else 1
        # The first stage and a dilated res5 do not downsample.
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        num_blocks = num_blocks_per_stage[idx]
        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }

        if stage_idx in plugin_stage_idx:
            plugin_idxs = conl_blocks[next_plugin_entry]
            next_plugin_entry += 1
            # A block is selected by its index or by the equivalent negative
            # index counted from the end of the stage.
            stage_kargs["plugin_per_block"] = [
                partial(NonLocal2dGc, **plugin_config)
                if (i in plugin_idxs or (i - num_blocks) in plugin_idxs)
                else None
                for i in range(num_blocks)
            ]

        # Use BasicBlock for R18 and R34.
        if uses_basic_block:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs["bottleneck_channels"] = bottleneck_channels
            stage_kargs["stride_in_1x1"] = stride_in_1x1
            stage_kargs["dilation"] = dilation
            stage_kargs["num_groups"] = num_groups
            if deform_on_per_stage[idx]:
                stage_kargs["block_class"] = DeformBottleneckBlock
                stage_kargs["deform_modulated"] = deform_modulated
                stage_kargs["deform_num_groups"] = deform_num_groups
            else:
                stage_kargs["block_class"] = BottleneckBlock

        blocks = ResNet.make_stage(**stage_kargs)
        # Channel widths double from one stage to the next.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
        stages.append(blocks)

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)