Example #1
def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn,
                 fp32_reduce_scatter):
    model = Model()
    if with_fsdp:
        if wrap_bn:
            model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
            model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2,
                                              maintain_forward_counter=True)
        with enable_wrap(
                wrapper_cls=FSDP,
                flatten_parameters=flatten,
                mixed_precision=mixed_precision,
                compute_dtype=torch.float32,
                fp32_reduce_scatter=fp32_reduce_scatter,
        ):
            model.block1 = wrap(model.block1)
            model.block2 = wrap(model.block2)
            model.head = wrap(model.head)
    else:
        if with_checkpoint:
            model.block2 = checkpoint_wrapper(model.block2,
                                              maintain_forward_counter=False)
    return model
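
Example #1 assumes a small test Model exposing block1, block2, and head submodules, along with fairscale's wrapping utilities. The sketch below only illustrates that expected shape; the layer sizes are made up, and the fairscale import paths shown have moved between releases, so treat them as assumptions rather than the test's actual code.

import torch
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import auto_wrap_bn, enable_wrap, wrap


class Model(nn.Module):
    """Hypothetical stand-in for the model expected by create_model()."""

    def __init__(self):
        super().__init__()
        # Two conv+BN blocks (so auto_wrap_bn has BatchNorm layers to isolate)
        # followed by a small linear head.
        self.block1 = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
        self.head = nn.Linear(4, 10)

    def forward(self, x):
        x = self.block2(self.block1(x))
        return self.head(x.mean(dim=(2, 3)))
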
Example #2
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    model = Model2() if with_model2 else Model()
    fsdp_config = None
    if with_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        fsdp_config = {
            "mixed_precision": False,
            "flatten_parameters": False,
            "reshard_after_forward": False,
            "bucket_cap_mb": 0,
            "force_input_to_fp32": True,  # SyncBN needs this.
        }

    if with_fsdp and wrap_bn:
        model.block1 = auto_wrap_bn(model.block1,
                                    single_rank_pg=False,
                                    fsdp_config=fsdp_config)
        model.block2 = auto_wrap_bn(model.block2,
                                    single_rank_pg=False,
                                    fsdp_config=fsdp_config)
        if with_model2:
            model.block3 = auto_wrap_bn(model.block3,
                                        single_rank_pg=False,
                                        fsdp_config=fsdp_config)

    if with_checkpoint:
        model.block2 = checkpoint_wrapper(model.block2)
        if with_model2:
            model.block3 = checkpoint_wrapper(model.block3)

    if with_fsdp:
        with enable_wrap(
                wrapper_cls=FSDP,
                flatten_parameters=flatten,
                mixed_precision=mixed_precision,
                compute_dtype=torch.float32,
                fp32_reduce_scatter=fp32_reduce_scatter,
                bucket_cap_mb=bucket_cap_mb,
        ):
            model.block1 = wrap(model.block1)
            model.block2 = wrap(model.block2)
            if with_model2:
                model.block3 = wrap(model.block3)
            model.head = wrap(model.head)

    return model
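
_create_model additionally references a Model2 variant that adds a block3 submodule (and, when with_sync_bn is set, converts its BatchNorm layers to SyncBatchNorm). A hypothetical Model2, extending the Model sketch above:

class Model2(Model):
    """Hypothetical three-block variant referenced by with_model2."""

    def __init__(self):
        super().__init__()
        self.block3 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU())

    def forward(self, x):
        x = self.block3(self.block2(self.block1(x)))
        return self.head(x.mean(dim=(2, 3)))
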
Example #3
def fsdp_wrapper(module, **kwargs):
    """Custom wrapper that applies FSDP + activation checkpointing at the same time.

    Currently not used. Will be used in the next commit. Included here
    to check the imports.
    """
    # Checkpoint the module first, then shard it with FSDP. Keyword arguments
    # registered via enable_wrap() (e.g. mixed_precision=True,
    # flatten_parameters=True) are forwarded by wrap() and passed straight to
    # the FSDP constructor.
    return FSDP(checkpoint_wrapper(module), **kwargs)
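
When used as in Examples #4, #6, and #7, this wrapper is registered through enable_wrap() so that each wrap() call applies FSDP plus checkpointing to a block. A minimal sketch, assuming torch.distributed is already initialized and a GPU is available; the block and config below are placeholders, not the project's actual settings:

import torch.nn as nn
from fairscale.nn.wrap import enable_wrap, wrap

# Placeholder block and config; in the real code these come from the model factory.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
fsdp_config = {"mixed_precision": True, "flatten_parameters": True}

with enable_wrap(wrapper_cls=fsdp_wrapper, **fsdp_config):
    # wrap() calls fsdp_wrapper(block, **fsdp_config), i.e. it builds
    # FSDP(checkpoint_wrapper(block), mixed_precision=True, flatten_parameters=True).
    block = wrap(block)
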
Example #4
File: regnet_2.py, Project: wpc/vissl
    def create_block(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        params: Union[RegNetParams, AnyNetParams],
        bottleneck_multiplier: float,
        group_width: int = 1,
    ):
        block = super().create_block(
            width_in, width_out, stride, params, bottleneck_multiplier, group_width
        )
        block = auto_wrap_bn(block, single_rank_pg=False)
        with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
            block = wrap(block)
        return block
Example #5
    def __init__(
        self,
        model_config,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: nn.Module,
        activation: nn.Module,
        bot_mul: float,
        group_width: int,
        params: "RegNetParams",
        stage_index: int = 0,
    ):
        super().__init__()
        self.stage_depth = 0

        fsdp_config = {
            "wrapper_cls": fsdp_wrapper,
        }
        fsdp_config.update(model_config.FSDP_CONFIG)
        for i in range(depth):
            # Make a block and move it to cuda, since FSDP's shard-as-we-build
            # needs CUDA for its dist.all_gather() call.
            block = block_constructor(
                width_in if i == 0 else width_out,
                width_out,
                stride if i == 0 else 1,
                params.bn_epsilon,
                params.bn_momentum,
                activation,
                bot_mul,
                group_width,
                params.se_ratio,
            ).cuda()
            # Init weight before wrapping and sharding.
            init_weights(block)

            # Now, wrap it with fsdp+checkpoint, which will perform the sharding.
            block = auto_wrap_bn(block)
            with enable_wrap(**fsdp_config):
                block = wrap(block)

            self.stage_depth += block.depth
            self.add_module(f"block{stage_index}-{i}", block)
Example #6
    def create_block(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        params: Union[RegNetParams, AnyNetParams],
        bottleneck_multiplier: float,
        group_width: int = 1,
    ):
        block = super().create_block(
            width_in, width_out, stride, params, bottleneck_multiplier, group_width
        )
        block = auto_wrap_bn(block, single_rank_pg=False)
        if self.use_activation_checkpointing:
            # TODO - make this configurable
            block = checkpoint_wrapper(block, offload_to_cpu=False)
        with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
            block = wrap(block)
        return block
Example #7
def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    assert model_config.INPUT_TYPE in ["rgb",
                                       "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (trunk_config["name"] == "anynet"
                ), "Please use AnyNetParams or specify RegNetParams dictionary"

    if "name" in trunk_config and trunk_config["name"] == "anynet":
        params = AnyNetParams(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
            stem_type=StemType[trunk_config.get("stem_type",
                                                "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[trunk_config.get(
                "block_type", "res_bottleneck_block").upper()],
            activation=ActivationType[trunk_config.get("activation",
                                                       "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )
    else:
        params = RegNetParams(
            depth=trunk_config["depth"],
            w_0=trunk_config["w_0"],
            w_a=trunk_config["w_a"],
            w_m=trunk_config["w_m"],
            group_width=trunk_config["group_width"],
            bottleneck_multiplier=trunk_config.get("bottleneck_multiplier",
                                                   1.0),
            stem_type=StemType[trunk_config.get("stem_type",
                                                "simple_stem_in").upper()],
            stem_width=trunk_config.get("stem_width", 32),
            block_type=BlockType[trunk_config.get(
                "block_type", "res_bottleneck_block").upper()],
            activation=ActivationType[trunk_config.get("activation",
                                                       "relu").upper()],
            use_se=trunk_config.get("use_se", True),
            se_ratio=trunk_config.get("se_ratio", 0.25),
            bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
            bn_momentum=trunk_config.get("bn_momentum", 0.1),
        )

    # Ad hoc stem
    #
    # Important: do NOT retain modules in self.stem or self.trunk_output. It may
    # seem harmless, but it appears that autograd will then compute grads in a
    # different order. Different ordering can cause deterministic OOM,
    # even when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width,
            bottleneck_multiplier) in enumerate(params.get_expanded_params()):
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )
        if isinstance(factory, RegnetFSDPBlocksFactory):
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)

            with enable_wrap(wrapper_cls=fsdp_wrapper,
                             **model_config.FSDP_CONFIG):
                new_stage = wrap(new_stage)

        blocks.append((
            f"block{i + 1}",
            new_stage,
        ))
        trunk_depth += blocks[-1][1].stage_depth
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of:
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ...`
    # We're only interested in the stem and the successive blocks;
    # everything else is not picked up on purpose.
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))
    return nn.ModuleDict(feature_blocks), trunk_depth
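
Flatten(1) is imported from the surrounding project rather than defined in this snippet; a minimal stand-in is sketched below (nn.Flatten(start_dim=1) behaves the same on recent PyTorch):

import torch
import torch.nn as nn


class Flatten(nn.Module):
    """Hypothetical minimal Flatten: collapses all dimensions from dim onward,
    so Flatten(1) turns the (N, C, 1, 1) avgpool output into (N, C)."""

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flatten(x, start_dim=self.dim)
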
Example #8
    def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
        with enable_wrap(wrapper_cls=_FSDP, **self._fsdp_kwargs):
            wrapped_module = wrap(module)
        return wrapped_module
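
A sketch of how a method like this is typically hosted and invoked; the class name and the contents of _fsdp_kwargs are assumptions for illustration, and a process group must already be initialized before FSDP construction:

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as _FSDP
from fairscale.nn.wrap import enable_wrap, wrap


class HypotheticalFSDPStrategy:
    def __init__(self, **fsdp_kwargs):
        # e.g. mixed_precision=True, flatten_parameters=True
        self._fsdp_kwargs = fsdp_kwargs

    def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
        with enable_wrap(wrapper_cls=_FSDP, **self._fsdp_kwargs):
            return wrap(module)


# strategy = HypotheticalFSDPStrategy(mixed_precision=True, flatten_parameters=True)
# sharded_head = strategy.wrap_module(torch.nn.Linear(128, 10).cuda())
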
Example #9
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl",
                                         rank=local_rank,
                                         world_size=8,
                                         init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" %
          (get_time_string(), local_rank))
    device = torch.device(
        f'cuda:{local_rank}')  # Unique only on individual node.
    torch.cuda.set_device(device)
    fsdp_params = dict(mixed_precision=True,
                       flatten_parameters=True,
                       bucket_cap_mb=25,
                       reshard_after_forward=False,
                       fp32_reduce_scatter=False,
                       cpu_offload=False,
                       move_grads_to_cpu=False,
                       process_group=torch.distributed.group.WORLD)
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(200, 200),
                        nn.Linear(200, 200),
                        wrap(checkpoint_wrapper(nn.Linear(200, 200), offload_to_cpu=True)),
                        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                        nn.Linear(200, 200),
                    ),
                    offload_to_cpu=True,
                )
            ),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64),
        ).cuda()

        model = FullyShardedDDP(nn_model, **fsdp_params)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=1e-4,
                                  eps=1e-7,
                                  weight_decay=1e-2,
                                  betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)

    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels)**2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(nn.Linear(200, 200), nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                      nn.Linear(200, 200)),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7), nn.Linear(200, 64)).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" %
          (get_time_string(), numel(nn_model) / 1_000_000))
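
main() expects to be launched once per GPU with its rank as the first argument. A hedged sketch of a launcher that matches the hard-coded world_size=8 and the tcp://0.0.0.0:9999 rendezvous above:

if __name__ == "__main__":
    import torch.multiprocessing as mp

    # One worker per GPU on a single 8-GPU node; mp.spawn passes the process
    # index (0..7) as the first argument, which main() treats as local_rank.
    mp.spawn(main, args=(), nprocs=8, join=True)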