Example #1
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn,
                 fp32_reduce_scatter):
    model = Model()
    if with_fsdp:
        if wrap_bn:
            model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
            model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2,
                                              maintain_forward_counter=True)
        with enable_wrap(
                wrapper_cls=FSDP,
                flatten_parameters=flatten,
                mixed_precision=mixed_precision,
                compute_dtype=torch.float32,
                fp32_reduce_scatter=fp32_reduce_scatter,
        ):
            model.block1 = wrap(model.block1)
            model.block2 = wrap(model.block2)
            model.head = wrap(model.head)
    else:
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2,
                                              maintain_forward_counter=False)
    return model
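
The snippets in this collection generally follow the same wrapping order: wrap the BatchNorm layers with auto_wrap_bn first, optionally add activation checkpointing with checkpoint_wrapper, and only then shard with FSDP. A minimal self-contained sketch of that order is shown below; the import paths are assumptions that may differ between fairscale versions, and a distributed process group must already be initialized before FSDP is constructed.

import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper  # assumed import path
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import auto_wrap_bn  # assumed import path

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
block = auto_wrap_bn(block, single_rank_pg=False)  # BN layers get their own FSDP instance
block = checkpoint_wrapper(block)                  # optional: activation checkpointing
block = FSDP(block)                                # shard the remaining parameters
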
Example #2
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    model = Model2() if with_model2 else Model()
    fsdp_config = None
    if with_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        fsdp_config = {
            "mixed_precision": False,
            "flatten_parameters": False,
            "reshard_after_forward": False,
            "bucket_cap_mb": 0,
            "force_input_to_fp32": True,  # SyncBN needs this.
        }

    if with_fsdp and wrap_bn:
        model.block1 = auto_wrap_bn(model.block1,
                                    single_rank_pg=False,
                                    fsdp_config=fsdp_config)
        model.block2 = auto_wrap_bn(model.block2,
                                    single_rank_pg=False,
                                    fsdp_config=fsdp_config)
        if with_model2:
            model.block3 = auto_wrap_bn(model.block3,
                                        single_rank_pg=False,
                                        fsdp_config=fsdp_config)

    if with_checkpoint:
        model.block2 = checkpoint_wrapper(model.block2)
        if with_model2:
            model.block3 = checkpoint_wrapper(model.block3)

    if with_fsdp:
        with enable_wrap(
                wrapper_cls=FSDP,
                flatten_parameters=flatten,
                mixed_precision=mixed_precision,
                compute_dtype=torch.float32,
                fp32_reduce_scatter=fp32_reduce_scatter,
                bucket_cap_mb=bucket_cap_mb,
        ):
            model.block1 = wrap(model.block1)
            model.block2 = wrap(model.block2)
            if with_model2:
                model.block3 = wrap(model.block3)
            model.head = wrap(model.head)

    return model
Example #3
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    model = Model(model_hidden_dim)
    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
        model.stem = to_fsdp(model.stem, fsdp_config)
        model.blocks = to_fsdp(model.blocks, fsdp_config)
        model.head = to_fsdp(model.head, fsdp_config)
    else:
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
    return model
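
to_fsdp is a small helper defined elsewhere in the surrounding test code, not part of the fairscale API itself. A plausible minimal stand-in (an assumption, not the original helper) simply applies FSDP with the given config dict:

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def to_fsdp(module, fsdp_config=None):
    # Hypothetical stand-in: wrap `module` in FSDP using an optional config dict.
    return FSDP(module, **(fsdp_config or {}))
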
Example #4
def create_model(with_fsdp, with_checkpoint):
    model = Model()
    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
        model.stem = to_fsdp(model.stem)
        model.blocks = to_fsdp(model.blocks)
        model.head = to_fsdp(model.head)
    else:
        if with_checkpoint:
            model.blocks = checkpoint_wrapper(model.blocks)
    return model
Example #5
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap in a different order due to different
    # random seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
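
dist_init and teardown are helpers from the fairscale test suite. A rough stand-in that makes the precondition explicit is sketched below; the NCCL backend, file-based rendezvous, and per-rank device selection are assumptions, not the original helpers.

import torch
import torch.distributed as dist

def simple_dist_init(rank: int, world_size: int, init_file: str) -> bool:
    # One GPU per rank, file-based rendezvous.
    dist.init_process_group(
        backend="nccl",
        init_method=f"file://{init_file}",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    return dist.is_initialized()

def simple_teardown() -> None:
    dist.destroy_process_group()
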
Example #6
 def __init__(
     self,
     model_config: AttrDict,
     in_channels: int,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
 ):
     super().__init__()
     mlp = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
     mlp = auto_wrap_bn(mlp, single_rank_pg=False)
     self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
Example #7
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     block = super().create_block(width_in, width_out, stride, params,
                                  bottleneck_multiplier, group_width)
     block = auto_wrap_bn(block, single_rank_pg=False)
     block = fsdp_wrapper(module=block, **self.fsdp_config)
     return block
Example #8
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: "RegNetParams",
     bot_mul: float,
     group_width: int = 1,
 ):
     block = super().create_block(width_in, width_out, stride, params,
                                  bot_mul, group_width)
     block = auto_wrap_bn(block, single_rank_pg=False)
     with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
         block = wrap(block)
     return block
Example #9
 def __init__(
     self,
     model_config: AttrDict,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
     use_dropout: bool = False,
     use_bias: bool = True,
     skip_last_layer_relu_bn: bool = True,
 ):
     super().__init__()
     mlp = MLP(model_config, dims, use_bn, use_relu, use_dropout, use_bias,
               skip_last_layer_relu_bn)
     mlp = auto_wrap_bn(mlp, single_rank_pg=False)
     self.mlp = fsdp_wrapper(mlp, **model_config.FSDP_CONFIG)
Example #10
    def __init__(
        self,
        model_config,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: nn.Module,
        activation: nn.Module,
        bot_mul: float,
        group_width: int,
        params: "RegNetParams",
        stage_index: int = 0,
    ):
        super().__init__()
        self.stage_depth = 0

        fsdp_config = {
            "wrapper_cls": fsdp_wrapper,
        }
        fsdp_config.update(model_config.FSDP_CONFIG)
        for i in range(depth):
            # Make a block and move it to CUDA, since FSDP's shard-as-we-build
            # needs CUDA to perform the dist.all_gather() call.
            block = block_constructor(
                width_in if i == 0 else width_out,
                width_out,
                stride if i == 0 else 1,
                params.bn_epsilon,
                params.bn_momentum,
                activation,
                bot_mul,
                group_width,
                params.se_ratio,
            ).cuda()
            # Init weights before wrapping and sharding.
            init_weights(block)

            # Now, wrap it with fsdp+checkpoint, which will perform the sharding.
            block = auto_wrap_bn(block)
            with enable_wrap(**fsdp_config):
                block = wrap(block)

            self.stage_depth += block.depth
            self.add_module(f"block{stage_index}-{i}", block)
Example #11
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     block = super().create_block(
         width_in, width_out, stride, params, bottleneck_multiplier, group_width
     )
     block = auto_wrap_bn(block, single_rank_pg=False)
     if self.use_activation_checkpointing:
         # TODO - make this configurable
         block = checkpoint_wrapper(block, offload_to_cpu=False)
     with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
         block = wrap(block)
     return block
Example #12
 def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
     stem = super().create_stem(params)
     stem = auto_wrap_bn(stem, single_rank_pg=False)
     return stem
Example #13
    def __init__(self, model_config: AttrDict, model_name: str):
        super().__init__()
        self.model_config = model_config

        assert model_config.INPUT_TYPE in ["rgb",
                                           "bgr"], "Input type not supported"
        trunk_config = model_config.TRUNK.TRUNK_PARAMS.REGNET

        assert "name" not in trunk_config, "Please specify the RegNet Params dictionary"

        ################################################################################

        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_w=trunk_config["group_width"],
            stem_type=trunk_config.get("stem_type", "simple_stem_in").upper(),
            stem_width=trunk_config.get("stem_width", 32),
            block_type=trunk_config.get("block_type",
                                        "res_bottleneck_block").upper(),
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )

        # We need all workers (on all nodes) to have the same weights.
        # Unlike DDP, FSDP does not sync weights using rank 0 on start.
        # Therefore, we init stem and trunk_output below within the seed context.
        #
        # TODO (Min): we can make this seed come from the config or env.
        stem = None
        trunk_output = None
        with set_torch_seed(0):
            # Ad hoc stem
            #
            # Important: do NOT retain modules in self.stem or self.trunk_output. It may
            # seem harmless, but autograd then appears to compute grads in a different
            # order, and a different ordering can cause a deterministic OOM, even when
            # the peak memory otherwise is only 24GB out of 32GB.
            #
            # When debugging this, it is not enough to just dump the total module
            # params. You need to diff the module string representations.
            stem = {
                StemType.RES_STEM_CIFAR: ResStemCifar,
                StemType.RES_STEM_IN: ResStemIN,
                StemType.SIMPLE_STEM_IN: SimpleStemIN,
            }[params.stem_type](
                3,
                params.stem_width,
                params.bn_epsilon,
                params.bn_momentum,
                False,  # params.relu_in_place
            )
            init_weights(stem)
            stem = auto_wrap_bn(stem)

            # Instantiate all the AnyNet blocks in the trunk
            block_fun = {
                BlockType.VANILLA_BLOCK: VanillaBlock,
                BlockType.RES_BASIC_BLOCK: ResBasicBlock,
                BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
            }[params.block_type]

            current_width = params.stem_width

            self.trunk_depth = 0

            blocks = []

            for i, (width_out, stride, depth, bot_mul,
                    group_width) in enumerate(params.get_expanded_params()):
                blocks.append((
                    f"block{i+1}",
                    AnyStage(
                        model_config,
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_fun,
                        bot_mul,
                        group_width,
                        params,
                        stage_index=i + 1,
                    ),
                ))

                self.trunk_depth += blocks[-1][1].stage_depth

                current_width = width_out

            trunk_output = nn.Sequential(OrderedDict(blocks))

        ################################################################################

        # Now map the model to the structure we want to expose for SSL tasks.
        # The upstream RegNet model is made of:
        # - `stem`
        # - n x blocks in trunk_output, named `block1, block2, ...`

        # We are only interested in the stem and the successive blocks;
        # everything else is intentionally not picked up.
        feature_blocks: List[Tuple[str, nn.Module]] = []

        # - get the stem
        feature_blocks.append(("conv1", stem))

        # - get all the feature blocks
        for k, v in trunk_output.named_children():
            assert k.startswith("block"), f"Unexpected layer name {k}"
            block_index = len(feature_blocks) + 1
            feature_blocks.append((f"res{block_index}", v))

        # - finally, add avgpool and flatten.
        feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
        feature_blocks.append(("flatten", Flatten(1)))

        self._feature_blocks = nn.ModuleDict(feature_blocks)
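
set_torch_seed is a context manager from the surrounding codebase, not defined in this snippet. A hypothetical minimal version that only handles the CPU RNG state (the real helper may also manage CUDA RNG) could look like this:

import contextlib
import torch

@contextlib.contextmanager
def set_torch_seed(seed: int):
    # Hypothetical stand-in: seed deterministically inside the block,
    # then restore the previous RNG state on exit.
    state = torch.get_rng_state()
    torch.manual_seed(seed)
    try:
        yield
    finally:
        torch.set_rng_state(state)
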
Example #14
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different
        # random seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state matches state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare
        # against the rank 0 state on every rank. Otherwise, only rank 0 is compared
        # against the saved rank 0 state.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after,
                                     fsdp_state,
                                     raise_exception=True)

    teardown()
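
The training loop above switches between torch.cuda.amp.GradScaler (DDP with AMP), fairscale's ShardedGradScaler (FSDP mixed precision), and a no-op DummyScaler, but the per-step sequence is the same in all cases. A condensed sketch of that step follows; the ShardedGradScaler import path is an assumption based on common fairscale usage.

from torch.cuda.amp import GradScaler, autocast
from fairscale.optim.grad_scaler import ShardedGradScaler  # assumed import path

def make_scaler(use_fsdp: bool, mixed_precision: bool):
    # ShardedGradScaler coordinates inf/nan checks across shards for FSDP;
    # a plain GradScaler is enough for DDP.
    if not mixed_precision:
        return None
    return ShardedGradScaler() if use_fsdp else GradScaler()

def amp_step(model, optim, scaler, in_data, target, loss_func):
    # One forward/backward/optimizer step, optionally under autocast.
    with autocast(enabled=scaler is not None):
        loss = loss_func(model(in_data), target)
    if scaler is None:
        loss.backward()
        optim.step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
    optim.zero_grad()
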