コード例 #1
0
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn,
                 fp32_reduce_scatter):
    """Build a ``Model``, optionally wrapped with FSDP and activation checkpointing.

    Args:
        with_fsdp: when true, ``block1``, ``block2`` and ``head`` are wrapped
            with FSDP inside an ``enable_wrap`` context.
        with_checkpoint: when true, ``block2`` is wrapped with
            ``checkpoint_wrapper``; ``maintain_forward_counter`` is True in the
            FSDP configuration and False otherwise.
        mixed_precision: forwarded to the FSDP wrapper.
        flatten: forwarded to the FSDP wrapper as ``flatten_parameters``.
        wrap_bn: only used with FSDP; passes ``block1`` and ``block2`` through
            ``auto_wrap_bn`` before any other wrapping.
        fp32_reduce_scatter: forwarded to the FSDP wrapper.

    Returns:
        The (possibly wrapped) model.
    """
    net = Model()

    # Non-FSDP path: at most a checkpoint wrapper around block2.
    if not with_fsdp:
        if with_checkpoint:
            net.block2 = checkpoint_wrapper(net.block2,
                                            maintain_forward_counter=False)
        return net

    if wrap_bn:
        for name in ("block1", "block2"):
            setattr(net, name,
                    auto_wrap_bn(getattr(net, name), single_rank_pg=False))

    if with_checkpoint:
        net.block2 = checkpoint_wrapper(net.block2,
                                        maintain_forward_counter=True)

    wrap_options = dict(
        wrapper_cls=FSDP,
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
        compute_dtype=torch.float32,
        fp32_reduce_scatter=fp32_reduce_scatter,
    )
    with enable_wrap(**wrap_options):
        for name in ("block1", "block2", "head"):
            setattr(net, name, wrap(getattr(net, name)))
    return net
コード例 #2
0
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Build ``Model`` or ``Model2`` with optional SyncBN, BN wrapping,
    activation checkpointing and FSDP.

    ``Model2`` is handled as having an extra ``block3`` that receives the same
    treatment as ``block2``. When ``with_sync_bn`` is set, the model is
    converted with ``nn.SyncBatchNorm.convert_sync_batchnorm`` and a dedicated
    FSDP config is built. That config is passed only to ``auto_wrap_bn``; the
    outer ``enable_wrap`` always uses the explicit arguments below.

    Args:
        with_model2: build ``Model2`` (with ``block3``) instead of ``Model``.
        with_sync_bn: convert BN layers to ``SyncBatchNorm``.
        with_fsdp: wrap the blocks and ``head`` with FSDP.
        with_checkpoint: wrap ``block2`` (and ``block3``) with
            ``checkpoint_wrapper``, independently of ``with_fsdp``.
        mixed_precision, flatten, fp32_reduce_scatter, bucket_cap_mb:
            forwarded to the FSDP wrapper.
        wrap_bn: together with ``with_fsdp``, pass every block through
            ``auto_wrap_bn``.

    Returns:
        The (possibly wrapped) model.
    """
    net = (Model2 if with_model2 else Model)()
    block_names = ["block1", "block2", "block3"] if with_model2 else [
        "block1", "block2"
    ]

    bn_fsdp_config = None
    if with_sync_bn:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
        bn_fsdp_config = {
            "mixed_precision": False,
            "flatten_parameters": False,
            "reshard_after_forward": False,
            "bucket_cap_mb": 0,
            "force_input_to_fp32": True,  # SyncBN needs this.
        }

    if with_fsdp and wrap_bn:
        for name in block_names:
            setattr(
                net, name,
                auto_wrap_bn(getattr(net, name),
                             single_rank_pg=False,
                             fsdp_config=bn_fsdp_config))

    if with_checkpoint:
        # block1 is never checkpointed; block2 (and block3 for Model2) are.
        for name in block_names[1:]:
            setattr(net, name, checkpoint_wrapper(getattr(net, name)))

    if not with_fsdp:
        return net

    with enable_wrap(
            wrapper_cls=FSDP,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
    ):
        for name in block_names + ["head"]:
            setattr(net, name, wrap(getattr(net, name)))
    return net
コード例 #3
0
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    """Build ``Model(model_hidden_dim)`` with optional FSDP and checkpointing.

    Args:
        with_fsdp: pass ``stem`` and ``blocks`` through ``auto_wrap_bn``, then
            wrap ``stem``, ``blocks`` and ``head`` with
            ``to_fsdp(..., fsdp_config)``.
        with_checkpoint: wrap ``blocks`` with ``checkpoint_wrapper``. With FSDP
            this happens after ``auto_wrap_bn`` and before ``to_fsdp``.
        model_hidden_dim: forwarded to ``Model``.
        fsdp_config: forwarded to every ``to_fsdp`` call.

    Returns:
        The (possibly wrapped) model.
    """
    net = Model(model_hidden_dim)
    if with_fsdp:
        for name in ("stem", "blocks"):
            setattr(net, name,
                    auto_wrap_bn(getattr(net, name), single_rank_pg=False))
    if with_checkpoint:
        net.blocks = checkpoint_wrapper(net.blocks)
    if with_fsdp:
        for name in ("stem", "blocks", "head"):
            setattr(net, name, to_fsdp(getattr(net, name), fsdp_config))
    return net
コード例 #4
0
def create_model(with_fsdp, with_checkpoint):
    """Build a ``Model`` with optional FSDP wrapping and checkpointing.

    Args:
        with_fsdp: pass ``stem`` and ``blocks`` through ``auto_wrap_bn``, then
            wrap ``stem``, ``blocks`` and ``head`` with ``to_fsdp``.
        with_checkpoint: wrap ``blocks`` with ``checkpoint_wrapper``. With FSDP
            this happens between ``auto_wrap_bn`` and ``to_fsdp``.

    Returns:
        The (possibly wrapped) model.
    """
    net = Model()
    if not with_fsdp:
        if with_checkpoint:
            net.blocks = checkpoint_wrapper(net.blocks)
        return net

    net.stem = auto_wrap_bn(net.stem, single_rank_pg=False)
    net.blocks = auto_wrap_bn(net.blocks, single_rank_pg=False)
    if with_checkpoint:
        net.blocks = checkpoint_wrapper(net.blocks)
    for name in ("stem", "blocks", "head"):
        setattr(net, name, to_fsdp(getattr(net, name)))
    return net
コード例 #5
0
ファイル: vision_transformer.py プロジェクト: zlapp/vissl
 def _build_block(self, dpr: float, norm_layer) -> nn.Module:
     """Create one transformer ``Block`` from the trunk's hyper-parameters.

     Args:
         dpr: drop-path rate for this particular block.
         norm_layer: normalization layer factory forwarded to ``Block``.

     Returns:
         The block. If ``trunk_config.CHECKPOINT_MLP`` is set, its ``mlp`` is
         wrapped with ``checkpoint_wrapper``. If
         ``trunk_config.CHECKPOINT_BLOCK`` is set, the whole block is wrapped
         too.
     """
     block_kwargs = dict(
         dim=self.embed_dim,
         num_heads=self.num_heads,
         mlp_ratio=self.mlp_ratio,
         qkv_bias=self.qkv_bias,
         qk_scale=self.qk_scale,
         drop=self.drop_rate,
         attn_drop=self.attn_drop_rate,
         drop_path=dpr,
         norm_layer=norm_layer,
     )
     new_block = Block(**block_kwargs)
     cfg = self.trunk_config
     if cfg.CHECKPOINT_MLP:
         new_block.mlp = checkpoint_wrapper(new_block.mlp)
     return checkpoint_wrapper(new_block) if cfg.CHECKPOINT_BLOCK else new_block
コード例 #6
0
ファイル: regnet_fsdp.py プロジェクト: captainfffsama/vissl
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     """Create a block via the parent factory, then add FSDP-related wrappers.

     The parent block is passed through ``auto_wrap_bn``, optionally through
     ``checkpoint_wrapper`` (when ``self.use_activation_checkpointing``), and
     is finally wrapped with ``fsdp_wrapper`` using ``self.fsdp_config``.
     """
     base_block = super().create_block(
         width_in, width_out, stride, params, bottleneck_multiplier, group_width
     )
     wrapped = auto_wrap_bn(base_block, single_rank_pg=False)
     if self.use_activation_checkpointing:
         # TODO: make offload_to_cpu configurable instead of hard-coding False.
         wrapped = checkpoint_wrapper(wrapped, offload_to_cpu=False)
     with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
         return wrap(wrapped)
コード例 #7
0
ファイル: regnet_fsdp.py プロジェクト: QuentinDuval/vissl
    def create_any_stage(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        group_width: int,
        bottleneck_multiplier: float,
        params: Union[RegNetParams, AnyNetParams],
        stage_index: int = 0,
        checkpoints: Union[List[int], None] = None,
    ):
        """Build one stage as a sequence of FSDP-wrapped groups of blocks.

        Blocks ``0 .. depth - 1`` are split into consecutive groups. Each group
        is an ``nn.Sequential`` wrapped with ``fsdp_wrapper``. When checkpoint
        indices are given, each group is first wrapped with
        ``checkpoint_wrapper``.

        Args:
            width_in: input width of the first block of the stage.
            width_out: output width of every block, and input width of every
                block after the first.
            stride: stride of the first block; later blocks use stride 1.
            depth: number of blocks in the stage.
            group_width: forwarded to ``create_block``.
            bottleneck_multiplier: forwarded to ``create_block``.
            params: forwarded to ``create_block``.
            stage_index: only used to name the sub-modules.
            checkpoints: sorted, exclusive end indices of the block groups.
                ``None`` or empty means a single group of ``depth`` blocks
                without activation checkpointing. If the last index is below
                ``depth``, a final group covering the remaining blocks is
                added so that no block is dropped.

        Returns:
            The ``AnyStage`` holding the wrapped groups. Its ``stage_depth``
            is incremented by the ``depth`` of every created block.
        """
        # ``None`` replaces the old default of ``0``, on which ``sorted()``
        # raised TypeError. Copying to a list lets tuples pass the sortedness
        # check and guarantees the caller's list is never mutated below.
        checkpoints = [] if checkpoints is None else list(checkpoints)
        assert sorted(checkpoints) == checkpoints, (
            "Checkpoint indices should be sorted")

        with_checkpointing = len(checkpoints) > 0
        block_delimiters = checkpoints if with_checkpointing else [depth]
        if block_delimiters[-1] < depth:
            # Close a final group so every block of the stage gets built.
            block_delimiters.append(depth)

        any_stage = AnyStage()
        prev_depth = 0
        for block_group_index, next_depth in enumerate(block_delimiters):
            block_group = nn.Sequential()
            for i in range(prev_depth, next_depth):
                block = self.create_block(
                    width_in=width_in if i == 0 else width_out,
                    width_out=width_out,
                    stride=stride if i == 0 else 1,
                    params=params,
                    group_width=group_width,
                    bottleneck_multiplier=bottleneck_multiplier,
                )
                any_stage.stage_depth += block.depth
                block_group.add_module(f"block{stage_index}-{i}", block)
            prev_depth = next_depth
            if with_checkpointing:
                block_group = checkpoint_wrapper(block_group)
            block_group = fsdp_wrapper(block_group, **self.fsdp_config)
            any_stage.add_module(f"block{stage_index}-part{block_group_index}",
                                 block_group)
        return any_stage
コード例 #8
0
ファイル: regnet_fsdp.py プロジェクト: iseessel/vissl
    def create_any_stage(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        group_width: int,
        bottleneck_multiplier: float,
        params: Union[RegNetParams, AnyNetParams],
        group_delimiters: List[int],
        group_checkpoint: List[bool],
        stage_index: int = 0,
    ):
        """Build one stage as FSDP-wrapped groups of consecutive blocks.

        Args:
            width_in: input width of the first block of the stage.
            width_out: output width of every block, and input width of every
                block after the first.
            stride: stride of the first block; later blocks use stride 1.
            depth: not read by this method. Only blocks with an index below
                ``group_delimiters[-1]`` are created, so callers must end the
                delimiters at the stage depth to build the whole stage.
            group_width, bottleneck_multiplier, params: forwarded to
                ``create_block``.
            group_delimiters: exclusive end index of each group, in order.
            group_checkpoint: per-group flag; when true, that group is wrapped
                with ``checkpoint_wrapper`` before ``fsdp_wrapper``.
            stage_index: only used to name the sub-modules.

        Returns:
            The ``AnyStage`` holding the wrapped groups.
        """
        assert len(group_delimiters) == len(group_checkpoint)

        stage = AnyStage()
        start = 0
        for group_index, end in enumerate(group_delimiters):
            group = nn.Sequential()
            for block_idx in range(start, end):
                is_first = block_idx == 0
                new_block = self.create_block(
                    width_in=width_in if is_first else width_out,
                    width_out=width_out,
                    stride=stride if is_first else 1,
                    params=params,
                    group_width=group_width,
                    bottleneck_multiplier=bottleneck_multiplier,
                )
                stage.stage_depth += new_block.depth
                group.add_module(f"block{stage_index}-{block_idx}", new_block)
            start = end

            if group_checkpoint[group_index]:
                group = checkpoint_wrapper(group)
            stage.add_module(f"block{stage_index}-part{group_index}",
                             fsdp_wrapper(group, **self.fsdp_config))
        return stage
コード例 #9
0
def _shared_regnet_param_kwargs(trunk_config):
    """Return the keyword arguments shared by ``AnyNetParams`` and ``RegNetParams``.

    Keeping them in one place guarantees both parameterisations read the same
    config keys with the same fallback defaults.
    """
    return dict(
        stem_type=StemType[trunk_config.get("stem_type",
                                            "simple_stem_in").upper()],
        stem_width=trunk_config.get("stem_width", 32),
        block_type=BlockType[trunk_config.get(
            "block_type", "res_bottleneck_block").upper()],
        activation=ActivationType[trunk_config.get("activation",
                                                   "relu").upper()],
        use_se=trunk_config.get("use_se", True),
        se_ratio=trunk_config.get("se_ratio", 0.25),
        bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
        bn_momentum=trunk_config.get("bn_momentum", 0.1),
    )


def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    """Build the RegNet stem and stages and expose them as SSL feature blocks.

    Args:
        factory: creates the stem and is passed to every ``AnyStage``. When it
            is a ``RegnetFSDPBlocksFactory``, each stage is additionally
            wrapped with ``fsdp_wrapper`` (and optionally ``checkpoint_wrapper``).
        model_config: reads ``INPUT_TYPE``, ``TRUNK.REGNET``,
            ``ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING`` and
            ``FSDP_CONFIG``.

    Returns:
        A tuple ``(feature_blocks, trunk_depth)``:
        - ``feature_blocks`` is an ``nn.ModuleDict`` containing ``conv1``
          (stem), ``res2``, ``res3``, ... (one per stage), ``avgpool`` and
          ``flatten``.
        - ``trunk_depth`` is the sum of ``stage_depth`` over all stages.
    """
    assert model_config.INPUT_TYPE in ["rgb",
                                       "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (trunk_config["name"] == "anynet"
                ), "Please use AnyNetParams or specify RegNetParams dictionary"

    # Architecture-specific keys are read before the shared ones, preserving
    # the original evaluation order (and thus which KeyError surfaces first).
    if trunk_config.get("name") == "anynet":
        params_cls = AnyNetParams
        specific_kwargs = dict(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
        )
    else:
        params_cls = RegNetParams
        specific_kwargs = dict(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier",
                                                   1.0),
        )
    params = params_cls(**specific_kwargs,
                        **_shared_regnet_param_kwargs(trunk_config))

    # Ad hoc stem
    #
    # Important: the model owning these modules should NOT additionally retain
    # them (e.g. as self.stem or self.trunk_output). It may seem to be
    # harmless, but it appears that autograd will result in computing grads in
    # a different order. Different ordering can cause deterministic OOM, even
    # when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width,
            bottleneck_multiplier) in enumerate(params.get_expanded_params()):
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )

        if isinstance(factory, RegnetFSDPBlocksFactory):
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
            new_stage = fsdp_wrapper(module=new_stage,
                                     **model_config.FSDP_CONFIG)

        blocks.append((f"block{i + 1}", new_stage))
        trunk_depth += new_stage.stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of :
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`
    # We're only interested in the stem and successive blocks
    # everything else is not picked up on purpose
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    # The stem counts as stage 1, so trunk stages are exposed as res2, res3, ...
    for block_index, (k, v) in enumerate(trunk_output.named_children(),
                                         start=2):
        assert k.startswith("block"), f"Unexpected layer name {k}"
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))
    return nn.ModuleDict(feature_blocks), trunk_depth