Пример #1
0
def make_trident_stage(block_class, num_blocks, **kwargs):
    """
    Build a TridentNet ResNet stage made of ``num_blocks`` blocks.

    Every block but the last gets ``concat_output=False``; the last one gets
    ``True``. This is passed to ``ResNet.make_stage`` as the per-block list
    ``concat_output_per_block`` (overriding any value the caller supplied).
    All other arguments are forwarded to ``ResNet.make_stage`` unchanged.
    """
    # Leading blocks do not concatenate; a single trailing True is always added
    # (so the list is just [True] when num_blocks <= 1).
    concat_flags = [False for _ in range(num_blocks - 1)]
    concat_flags.append(True)
    stage_kwargs = dict(kwargs, concat_output_per_block=concat_flags)
    return ResNet.make_stage(block_class, num_blocks, **stage_kwargs)
Пример #2
0
def build_bam_resnet_backbone(cfg, input_shape):
    r"""Create a ResNet with BAM (Bottleneck Attention Module) blocks from config.

    After each of the res2, res3 and res4 stages that gets built, an extra
    single-block stage made of ``BAMBlock`` is appended. The ``stages`` list
    handed to :class:`ResNet` therefore interleaves the regular ResNet stages
    with these BAM stages.

    Args:
        cfg: config node; reads ``MODEL.RESNETS.*`` and
            ``MODEL.BACKBONE.FREEZE_AT``.
        input_shape: object whose ``channels`` attribute is the number of
            channels fed to the stem.

    Returns:
        ResNet+BAM: a :class:`ResNet` instance on which
        ``.freeze(MODEL.BACKBONE.FREEZE_AT)`` has been called.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    # fmt: off
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = cfg.MODEL.RESNETS.OUT_FEATURES
    depth = cfg.MODEL.RESNETS.DEPTH
    num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    if depth in [18, 34]:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    # Avoid creating variables without gradients
    # It consumes extra memory and may cause allreduce to fail
    # (so only stages up to the deepest requested output are built).
    # "stem" is not a stage; skip it instead of raising KeyError.
    stage_name_to_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [stage_name_to_idx[f] for f in out_features if f != "stem"]
    max_stage_idx = max(out_stage_idx)

    stages = []
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        num_blocks = num_blocks_per_stage[idx]
        dilation = res5_dilation if stage_idx == 5 else 1
        # The first built stage and a dilated res5 use stride 1 in their first
        # block; every other stage starts with stride 2.
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        # Use BasicBlock for R18 and R34.
        if depth in [18, 34]:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs["bottleneck_channels"] = bottleneck_channels
            stage_kargs["stride_in_1x1"] = stride_in_1x1
            stage_kargs["dilation"] = dilation
            stage_kargs["num_groups"] = num_groups
            stage_kargs["block_class"] = BottleneckBlock
        blocks = ResNet.make_stage(**stage_kargs)
        stages.append(blocks)

        if stage_idx in (2, 3, 4):
            # The BAM block is applied to the output of the stage just built,
            # so both its input and output width are that stage's
            # out_channels (previously the stage's *input* width was passed).
            bam_block = ResNet.make_stage(
                block_class=BAMBlock,
                num_blocks=1,
                stride_per_block=[1],
                in_channels=out_channels,
                out_channels=out_channels,
            )
            stages.append(bam_block)

        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)
Пример #3
0
def build_trident_resnet_backbone(cfg, input_shape):
    """
    Build a TridentNet ResNet backbone from ``cfg``.

    The stage named by ``MODEL.TRIDENT.TRIDENT_STAGE`` is built with
    ``make_trident_stage`` from ``TridentBottleneckBlock``. Every other stage
    uses ``DeformBottleneckBlock`` or ``BottleneckBlock``, as selected by
    ``MODEL.RESNETS.DEFORM_ON_PER_STAGE``. Freezing happens here rather than
    via ``ResNet.freeze``:

    - the stem is frozen when ``FREEZE_AT >= 1``;
    - stage ``k`` (res2 -> 2, ...) is frozen when ``FREEZE_AT >= k``.

    Args:
        cfg: config node; reads ``MODEL.RESNETS.*``, ``MODEL.TRIDENT.*`` and
            ``MODEL.BACKBONE.FREEZE_AT``.
        input_shape: object whose ``channels`` attribute is the number of
            channels fed to the stem.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    resnets_cfg = cfg.MODEL.RESNETS
    trident_cfg = cfg.MODEL.TRIDENT

    # need registration of new blocks/stems?
    norm = resnets_cfg.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets_cfg.STEM_OUT_CHANNELS,
        norm=norm,
    )
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT

    if freeze_at >= 1:
        # Disable gradients for the stem and swap in frozen BatchNorm layers.
        for param in stem.parameters():
            param.requires_grad = False
        stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem)

    # fmt: off
    out_features        = resnets_cfg.OUT_FEATURES
    depth               = resnets_cfg.DEPTH
    num_groups          = resnets_cfg.NUM_GROUPS
    bottleneck_channels = num_groups * resnets_cfg.WIDTH_PER_GROUP
    in_channels         = resnets_cfg.STEM_OUT_CHANNELS
    out_channels        = resnets_cfg.RES2_OUT_CHANNELS
    stride_in_1x1       = resnets_cfg.STRIDE_IN_1X1
    res5_dilation       = resnets_cfg.RES5_DILATION
    deform_on_per_stage = resnets_cfg.DEFORM_ON_PER_STAGE
    deform_modulated    = resnets_cfg.DEFORM_MODULATED
    deform_num_groups   = resnets_cfg.DEFORM_NUM_GROUPS
    num_branch          = trident_cfg.NUM_BRANCH
    branch_dilations    = trident_cfg.BRANCH_DILATIONS
    trident_stage       = trident_cfg.TRIDENT_STAGE
    test_branch_idx     = trident_cfg.TEST_BRANCH_IDX
    # fmt: on
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    stage_name_to_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    requested_stages = [stage_name_to_idx[name] for name in out_features]
    trident_stage_idx = stage_name_to_idx[trident_stage]
    last_stage_idx = max(requested_stages)

    stages = []
    for idx, stage_idx in enumerate(range(2, last_stage_idx + 1)):
        num_blocks = blocks_per_stage[idx]
        dilation = res5_dilation if stage_idx == 5 else 1
        # Stride 1 for the first built stage and for a dilated res5, else 2.
        keeps_stride = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_stride else 2

        stage_kargs = dict(
            num_blocks=num_blocks,
            stride_per_block=[first_stride] + [1] * (num_blocks - 1),
            in_channels=in_channels,
            bottleneck_channels=bottleneck_channels,
            out_channels=out_channels,
            num_groups=num_groups,
            norm=norm,
            stride_in_1x1=stride_in_1x1,
        )

        if stage_idx == trident_stage_idx:
            assert not deform_on_per_stage[
                idx], "Not support deformable conv in Trident blocks yet."
            # Trident blocks take per-branch ``dilations`` instead of a single
            # ``dilation``.
            stage_kargs.update(
                block_class=TridentBottleneckBlock,
                num_branch=num_branch,
                dilations=branch_dilations,
                test_branch_idx=test_branch_idx,
            )
            blocks = make_trident_stage(**stage_kargs)
        else:
            stage_kargs["dilation"] = dilation
            if deform_on_per_stage[idx]:
                stage_kargs.update(
                    block_class=DeformBottleneckBlock,
                    deform_modulated=deform_modulated,
                    deform_num_groups=deform_num_groups,
                )
            else:
                stage_kargs["block_class"] = BottleneckBlock
            blocks = ResNet.make_stage(**stage_kargs)

        if freeze_at >= stage_idx:
            for block in blocks:
                block.freeze()
        stages.append(blocks)

        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features)
Пример #4
0
def build_conl_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance with non-local (``NonLocal2dGc``) plugins from config.

    For every built stage listed in ``MODEL.CONL.STAGES``, a
    ``partial(NonLocal2dGc, **plugin_config)`` is attached to the blocks whose
    indices appear in the matching entry of ``MODEL.CONL.BLOCKS``. All other
    blocks get ``None``.

    How ``MODEL.CONL.BLOCKS`` is read:

    - Entries are consumed in stage order (res2 -> res5), one per plugin stage
      that is built.
    - Inside an entry, index ``i`` refers to block ``i``; a negative index
      ``-k`` refers to the ``k``-th block from the end.

    Args:
        cfg: config node; reads ``MODEL.RESNETS.*``, ``MODEL.CONL.*`` and
            ``MODEL.BACKBONE.FREEZE_AT``. It is not modified.
        input_shape: object whose ``channels`` attribute is the number of
            channels fed to the stem.

    Returns:
        ResNet: a :class:`ResNet` instance on which
        ``.freeze(MODEL.BACKBONE.FREEZE_AT)`` has been called.
    """
    # need registration of new blocks/stems?
    norm = cfg.MODEL.RESNETS.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS,
        norm=norm,
    )

    # fmt: off
    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = cfg.MODEL.RESNETS.OUT_FEATURES
    depth = cfg.MODEL.RESNETS.DEPTH
    num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
    width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS
    out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
    stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1
    res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION
    deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE
    deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED
    deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS
    # fmt: on

    conl_stages = cfg.MODEL.CONL.STAGES
    conl_blocks = cfg.MODEL.CONL.BLOCKS

    plugin_config = {
        'ratio': cfg.MODEL.CONL.RATIO,
        'downsample': cfg.MODEL.CONL.DOWNSAMPLE,
        'use_gn': cfg.MODEL.CONL.USE_GN,
        # A non-positive LR_MULT is passed on as None.
        'lr_mult':
        cfg.MODEL.CONL.LR_MULT if cfg.MODEL.CONL.LR_MULT > 0 else None,
        'use_out': cfg.MODEL.CONL.USE_OUT,
        'out_bn': cfg.MODEL.CONL.OUT_BN,
        'whiten_type': cfg.MODEL.CONL.WHITEN_TYPE,
        'temp': cfg.MODEL.CONL.TEMP,
        'with_gc': cfg.MODEL.CONL.WITH_GC,
        'with_2fc': cfg.MODEL.CONL.WITH_2FC,
        'double_conv': cfg.MODEL.CONL.DOUBLE_CONV,
        'with_state': cfg.MODEL.CONL.WITH_STATE,
        'Ncls': cfg.MODEL.CONL.NCLS,
    }
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    if depth in [18, 34]:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []

    # Avoid creating variables without gradients
    # It consumes extra memory and may cause allreduce to fail
    # (so only stages up to the deepest requested output are built).
    stage_name_to_idx = {"res2": 2, "res3": 3, "res4": 4, "res5": 5}
    out_stage_idx = [stage_name_to_idx[f] for f in out_features if f != "stem"]
    plugin_stage_idx = [stage_name_to_idx[f] for f in conl_stages if f != "stem"]
    max_stage_idx = max(out_stage_idx)

    # Position of the next MODEL.CONL.BLOCKS entry to use. The config list is
    # read by position instead of pop(0): popping mutated the shared cfg, so a
    # second build from the same cfg saw missing or shifted entries.
    next_conl_entry = 0
    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
        dilation = res5_dilation if stage_idx == 5 else 1
        # The first built stage and a dilated res5 use stride 1 in their first
        # block; every other stage starts with stride 2.
        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
        num_blocks = num_blocks_per_stage[idx]

        stage_kargs = {
            "num_blocks": num_blocks,
            "stride_per_block": [first_stride] + [1] * (num_blocks - 1),
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm": norm,
        }
        if stage_idx in plugin_stage_idx:
            # Raises IndexError if MODEL.CONL.BLOCKS has fewer entries than
            # built plugin stages (same as the former pop(0)).
            plugin_idxs = conl_blocks[next_conl_entry]
            next_conl_entry += 1
            # Block i gets a plugin when i (or its negative alias
            # i - num_blocks) is listed; otherwise it gets None.
            stage_kargs["plugin_per_block"] = [
                partial(NonLocal2dGc, **plugin_config)
                if (i in plugin_idxs or (i - num_blocks) in plugin_idxs)
                else None
                for i in range(num_blocks)
            ]

        # Use BasicBlock for R18 and R34.
        # NOTE(review): plugin_per_block is also passed to BasicBlock here;
        # confirm the project's BasicBlock accepts a ``plugin`` argument.
        if depth in [18, 34]:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs["bottleneck_channels"] = bottleneck_channels
            stage_kargs["stride_in_1x1"] = stride_in_1x1
            stage_kargs["dilation"] = dilation
            stage_kargs["num_groups"] = num_groups
            if deform_on_per_stage[idx]:
                stage_kargs["block_class"] = DeformBottleneckBlock
                stage_kargs["deform_modulated"] = deform_modulated
                stage_kargs["deform_num_groups"] = deform_num_groups
            else:
                stage_kargs["block_class"] = BottleneckBlock
        blocks = ResNet.make_stage(**stage_kargs)
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2
        stages.append(blocks)
    return ResNet(stem, stages, out_features=out_features).freeze(freeze_at)