예제 #1
0
    def init_distributed_data_parallel_model(self):
        """
        Wrap the complete base model with FSDP for distributed training.

        Overrides the ClassificationTask method of the same name from
        ClassyVision. When this is not a distributed training run, nothing
        is done.
        """
        if not is_distributed_training_run():
            return

        assert get_cuda_device_index() > -1, "Distributed training not setup correctly"

        # TODO (Min): A checkpoint can be loaded, but doing so sets the trunk's
        # _is_root flag to True, so it has to be reset to None here.
        # Also, the head's weights are currently only partially restored from
        # the checkpoint: the checkpoint is dumped after the head is wrapped,
        # but loaded before it is wrapped.
        # Very big models will need a reworked checkpoint flow, because a single
        # node does not have enough memory to load the entire model; checkpoint
        # shards should then be loaded through the local_state_dict() API.
        for sub_model in (self.base_model.trunk, self.base_model.heads):
            for child in sub_model.modules():
                if isinstance(child, FSDP):
                    child._is_root = None

        # Wrap the whole model. base_model itself is replaced because it is
        # the object used when a checkpoint is taken.
        self.base_model = fsdp_wrapper(
            self.base_model, **self.config["MODEL"]["FSDP_CONFIG"]
        )
        self.distributed_model = self.base_model
        assert is_valid_fsdp_model(
            self.distributed_model
        ), "FSDP is not setup correctly"
예제 #2
0
def RegNetFSDP(model_config: AttrDict, model_name: str):
    """
    Build a RegNet trunk on the GPU and wrap the trunk as a whole with FSDP.

    The whole trunk is wrapped here because the checkpoint has to be loaded
    before the wrapping performed in train_fsdp_task.py takes place.
    """
    trunk = _RegNetFSDP(model_config, model_name)
    trunk = trunk.cuda()
    fsdp_config = model_config.FSDP_CONFIG
    return fsdp_wrapper(trunk, **fsdp_config)
예제 #3
0
def FSDPLinearEvalMLP(
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    """
    Build a LinearEvalMLP head sharded with FSDP.

    The BatchNorm layers are wrapped first with fsdp_auto_wrap_bn, then the
    whole head is wrapped using ``model_config.FSDP_CONFIG``.
    """
    head = fsdp_auto_wrap_bn(
        LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
    )
    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
예제 #4
0
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head with SwAV-specific FSDP wrapping: the prototypes stay in full
    precision.

    This matters for convergence. The prototype layer is "normalized" in the
    update hook, so its weights must be kept in full precision. Its output
    feeds the loss and is used for clustering, so it must be full precision
    as well.

    Each ``prototypes<j>`` sub-module is wrapped on its own with a copy of
    ``model_config.FSDP_CONFIG`` that disables parameter flattening, mixed
    precision and fp32 reduce-scatter and computes in torch.float32. The
    whole head is then wrapped with the unmodified FSDP_CONFIG.
    """
    head = fsdp_auto_wrap_bn(
        SwAVPrototypesHead(
            model_config=model_config,
            dims=dims,
            use_bn=use_bn,
            num_clusters=num_clusters,
            use_bias=use_bias,
            return_embeddings=return_embeddings,
            skip_last_bn=skip_last_bn,
            normalize_feats=normalize_feats,
        )
    )

    # Full-precision overrides, applied only to the prototype layers.
    fp32_overrides = {
        "flatten_parameters": False,
        "mixed_precision": False,
        "fp32_reduce_scatter": False,
        "compute_dtype": torch.float32,
    }
    fp32_config = model_config.FSDP_CONFIG.copy()
    for option, value in fp32_overrides.items():
        fp32_config[option] = value

    for head_index in range(head.nmb_heads):
        attr_name = f"prototypes{head_index}"
        wrapped = fsdp_wrapper(getattr(head, attr_name), **fp32_config)
        setattr(head, attr_name, wrapped)

    return fsdp_wrapper(head, **model_config.FSDP_CONFIG)
예제 #5
0
 def __init__(
     self,
     model_config: AttrDict,
     in_channels: int,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
 ):
     """
     Build a LinearEvalMLP, wrap its BatchNorm layers with fsdp_auto_wrap_bn,
     and store the FSDP-wrapped result (using model_config.FSDP_CONFIG) as
     ``self.mlp``.
     """
     super().__init__()
     linear_eval = LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu)
     self.mlp = fsdp_wrapper(
         fsdp_auto_wrap_bn(linear_eval),
         **model_config.FSDP_CONFIG,
     )
예제 #6
0
    def _pretraining_worker(
        gpu_id: int,
        with_fsdp: bool,
        with_activation_checkpointing: bool,
        with_larc: bool,
        sync_file: str,
        result_file: str,
    ):
        """
        Run 5 SwAV training iterations on one of 2 ranks and, on rank 0,
        pickle the per-iteration losses to ``result_file``.

        Unlike a plain SGD loop, the optimizer comes from build_optimizer on
        the pretraining config, so the configured optimizer settings (e.g.
        LARC via ``with_larc``) are exercised.

        Args:
            gpu_id: GPU index of this process; passed to
                init_distributed_on_file, used as the DDP device, and rank 0
                records/writes the losses.
            with_fsdp: wrap the model with fsdp_wrapper if True, otherwise
                with DistributedDataParallel.
            with_activation_checkpointing: forwarded to the config builder.
            with_larc: forwarded to the config builder.
            sync_file: rendezvous file for the 2-process group.
            result_file: where rank 0 writes the pickled list of losses.
        """
        init_distributed_on_file(world_size=2,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        # Same seed on every rank and in every compared run, so the random
        # batch below (and presumably the model initialization) is reproducible.
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True

        # Create the inputs
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()
        # Dummy target passed to SwAVLoss (presumably unused by the SwAV loss,
        # which derives its own targets — confirm).
        target = torch.tensor(0.0).cuda()

        # Create a fake model based on SWAV blocks
        config = TestRegnetFSDP._create_pretraining_config(
            with_fsdp, with_activation_checkpointing, with_larc=with_larc)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = build_optimizer(config["OPTIMIZER"])
        optimizer.set_param_groups(model.parameters())

        # Run a few iterations and collect the losses
        losses = []
        num_iterations = 5
        for iteration in range(num_iterations):
            out = model(batch)
            loss = criterion(out[0], target)
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            # For iterations 0..2, drop the prototype gradients so the
            # optimizer presumably leaves those weights untouched (mirrors
            # SwAV's early freezing of the prototypes — confirm).
            if iteration <= 2:
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            # `where` is the fraction of training completed so far; presumably
            # drives the optimizer's parameter schedules.
            optimizer.step(where=float(iteration / num_iterations))

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
예제 #7
0
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     """
     Create the parent class's block, wrap its BatchNorm layers with
     auto_wrap_bn (single_rank_pg=False), then wrap the block itself with
     FSDP using ``self.fsdp_config``.
     """
     base_block = super().create_block(
         width_in,
         width_out,
         stride,
         params,
         bottleneck_multiplier,
         group_width,
     )
     bn_wrapped = auto_wrap_bn(base_block, single_rank_pg=False)
     return fsdp_wrapper(module=bn_wrapped, **self.fsdp_config)
예제 #8
0
 def __init__(
     self,
     model_config: AttrDict,
     dims: List[int],
     use_bn: bool = False,
     use_relu: bool = False,
     use_dropout: bool = False,
     use_bias: bool = True,
     skip_last_layer_relu_bn: bool = True,
 ):
     """
     Build an MLP, wrap its BatchNorm layers with auto_wrap_bn
     (single_rank_pg=False), and keep the FSDP-wrapped result (using
     model_config.FSDP_CONFIG) in ``self.mlp``.
     """
     super().__init__()
     layers = MLP(
         model_config,
         dims,
         use_bn,
         use_relu,
         use_dropout,
         use_bias,
         skip_last_layer_relu_bn,
     )
     self.mlp = fsdp_wrapper(
         auto_wrap_bn(layers, single_rank_pg=False),
         **model_config.FSDP_CONFIG,
     )
예제 #9
0
 def create_block(
     self,
     width_in: int,
     width_out: int,
     stride: int,
     params: Union[RegNetParams, AnyNetParams],
     bottleneck_multiplier: float,
     group_width: int = 1,
 ):
     """
     Create the parent class's block and shard it with FSDP.

     BatchNorm layers are wrapped first with fsdp_auto_wrap_bn. When
     ``self.fsdp_config.AUTO_WRAP_THRESHOLD`` is positive, large layers are
     additionally wrapped with auto_wrap_big_layers. Finally the whole block
     is wrapped with ``self.fsdp_config``.
     """
     wrapped = fsdp_auto_wrap_bn(
         super().create_block(
             width_in,
             width_out,
             stride,
             params,
             bottleneck_multiplier,
             group_width,
         )
     )
     wants_big_layer_wrap = self.fsdp_config.AUTO_WRAP_THRESHOLD > 0
     if wants_big_layer_wrap:
         wrapped = auto_wrap_big_layers(wrapped, self.fsdp_config)
     return fsdp_wrapper(module=wrapped, **self.fsdp_config)
예제 #10
0
    def _distributed_worker(gpu_id: int, with_fsdp: bool, sync_file: str,
                            result_file: str):
        """
        Train a small SwAV-based model for 5 SGD iterations on one of 2 ranks
        and, on rank 0, pickle the per-iteration losses to ``result_file``.

        Args:
            gpu_id: CUDA device index, also used as the process rank.
            with_fsdp: wrap the model with fsdp_wrapper if True, otherwise
                with DistributedDataParallel.
            sync_file: file used for the ``file://`` NCCL rendezvous.
            result_file: where rank 0 writes the pickled list of losses.
        """
        torch.cuda.set_device(gpu_id)
        dist.init_process_group(backend="nccl",
                                init_method="file://" + sync_file,
                                world_size=2,
                                rank=gpu_id)

        # Create the inputs: the same seed on every rank (and in every run
        # being compared) yields the same random batch.
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()

        # Create a fake model based on SWAV blocks
        config = TestRegnetFSDP._create_config(with_fsdp)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = optim.SGD(model.parameters(), lr=1e-2)

        # Dummy loss target. It is loop-invariant, so allocate it once instead
        # of creating a new CUDA tensor on every iteration.
        target = torch.tensor(0.0).cuda()

        # Run a few iterations and collect the losses
        num_iterations = 5
        losses = []
        for iteration in range(num_iterations):
            out = model(batch)
            loss = criterion(out[0], target)
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            # For iterations 0..2, drop the prototype gradients; SGD skips
            # parameters whose grad is None, so those weights stay frozen
            # (presumably mirroring SwAV's early prototype freezing).
            if iteration <= 2:
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            optimizer.step()

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
예제 #11
0
def MLP_FSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    """
    Build an MLP sharded with FSDP.

    BatchNorm layers are wrapped with fsdp_auto_wrap_bn, then the whole MLP
    is wrapped using ``model_config.FSDP_CONFIG``.
    """
    layers = fsdp_auto_wrap_bn(
        MLP(model_config, dims, use_bn, use_relu, use_dropout, use_bias,
            skip_last_layer_relu_bn)
    )
    return fsdp_wrapper(layers, **model_config.FSDP_CONFIG)
예제 #12
0
    def create_any_stage(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        group_width: int,
        bottleneck_multiplier: float,
        params: Union[RegNetParams, AnyNetParams],
        stage_index: int = 0,
        checkpoints: Union[List[int], None] = None,
    ):
        """
        Build one AnyStage whose blocks are split into FSDP-wrapped groups.

        Args:
            width_in: input width of the first block of the stage.
            width_out: output width of every block (and input width of every
                block after the first).
            stride: stride of the first block; later blocks use stride 1.
            depth: number of blocks in the stage when ``checkpoints`` is empty.
            group_width: forwarded to ``create_block``.
            bottleneck_multiplier: forwarded to ``create_block``.
            params: RegNet / AnyNet parameters forwarded to ``create_block``.
            stage_index: only used to name the sub-modules.
            checkpoints: sorted, exclusive end indices of the block groups.
                When non-empty, each group is wrapped with
                ``checkpoint_wrapper`` (activation checkpointing) before FSDP
                wrapping. ``None`` (default) or an empty sequence means a
                single group of ``depth`` blocks without checkpointing.

        Returns:
            The AnyStage, with each group registered as
            ``block{stage_index}-part{group_index}``.

        Raises:
            AssertionError: if ``checkpoints`` is not sorted.
        """
        # None means "no checkpointing". Copying into a list also accepts
        # tuples and makes the sortedness check compare list against list.
        checkpoints = [] if checkpoints is None else list(checkpoints)
        assert sorted(
            checkpoints) == checkpoints, "Checkpoint indices should be sorted"

        with_checkpointing = bool(checkpoints)
        # NOTE(review): when checkpoints are given, blocks with index
        # >= checkpoints[-1] are never created, so callers presumably pass
        # ``depth`` as the last delimiter — confirm.
        block_delimiters = checkpoints if with_checkpointing else [depth]

        any_stage = AnyStage()
        prev_depth = 0
        for block_group_index, next_depth in enumerate(block_delimiters):
            block_group = nn.Sequential()
            for i in range(prev_depth, next_depth):
                # Only the very first block of the stage changes width/stride.
                block = self.create_block(
                    width_in=width_in if i == 0 else width_out,
                    width_out=width_out,
                    stride=stride if i == 0 else 1,
                    params=params,
                    group_width=group_width,
                    bottleneck_multiplier=bottleneck_multiplier,
                )
                any_stage.stage_depth += block.depth
                block_group.add_module(f"block{stage_index}-{i}", block)
            prev_depth = next_depth
            if with_checkpointing:
                block_group = checkpoint_wrapper(block_group)
            block_group = fsdp_wrapper(block_group, **self.fsdp_config)
            any_stage.add_module(f"block{stage_index}-part{block_group_index}",
                                 block_group)
        return any_stage
예제 #13
0
    def create_any_stage(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        group_width: int,
        bottleneck_multiplier: float,
        params: Union[RegNetParams, AnyNetParams],
        group_delimiters: List[int],
        group_checkpoint: List[bool],
        stage_index: int = 0,
    ):
        """
        Build one AnyStage whose blocks are split into FSDP-wrapped groups.

        ``group_delimiters[k]`` is the exclusive end index of group ``k``, and
        ``group_checkpoint[k]`` says whether that group is additionally wrapped
        with ``checkpoint_wrapper`` before FSDP wrapping. Both lists must have
        the same length. Only the very first block of the stage uses
        ``width_in`` and ``stride``; every other block maps ``width_out`` to
        ``width_out`` with stride 1.

        Returns:
            The AnyStage, with each group registered as
            ``block{stage_index}-part{group_index}``.
        """
        assert len(group_delimiters) == len(group_checkpoint)

        any_stage = AnyStage()
        start = 0
        for group_index, end in enumerate(group_delimiters):
            group = nn.Sequential()
            for block_index in range(start, end):
                is_first = block_index == 0
                block = self.create_block(
                    width_in=width_in if is_first else width_out,
                    width_out=width_out,
                    stride=stride if is_first else 1,
                    params=params,
                    group_width=group_width,
                    bottleneck_multiplier=bottleneck_multiplier,
                )
                any_stage.stage_depth += block.depth
                group.add_module(f"block{stage_index}-{block_index}", block)
            start = end

            if group_checkpoint[group_index]:
                group = checkpoint_wrapper(group)
            any_stage.add_module(
                f"block{stage_index}-part{group_index}",
                fsdp_wrapper(group, **self.fsdp_config),
            )
        return any_stage
예제 #14
0
def _regnet_common_kwargs(trunk_config):
    """Return the keyword arguments shared by AnyNetParams and RegNetParams."""
    return dict(
        stem_type=StemType[trunk_config.get("stem_type",
                                            "simple_stem_in").upper()],
        stem_width=trunk_config.get("stem_width", 32),
        block_type=BlockType[trunk_config.get(
            "block_type", "res_bottleneck_block").upper()],
        activation=ActivationType[trunk_config.get("activation",
                                                   "relu").upper()],
        use_se=trunk_config.get("use_se", True),
        se_ratio=trunk_config.get("se_ratio", 0.25),
        bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
        bn_momentum=trunk_config.get("bn_momentum", 0.1),
    )


def _build_regnet_params(trunk_config):
    """Build AnyNetParams when trunk_config["name"] == "anynet", else RegNetParams."""
    # The structure-specific keys are read before the shared ones, keeping the
    # same lookup order (and therefore the same first error on a bad config)
    # as a single keyword-argument call.
    if "name" in trunk_config and trunk_config["name"] == "anynet":
        specific = dict(
            depths=trunk_config["depths"],
            widths=trunk_config["widths"],
            group_widths=trunk_config["group_widths"],
            bottleneck_multipliers=trunk_config["bottleneck_multipliers"],
            strides=trunk_config["strides"],
        )
        return AnyNetParams(**specific, **_regnet_common_kwargs(trunk_config))

    specific = dict(
        depth=trunk_config["depth"],
        w_0=trunk_config["w_0"],
        w_a=trunk_config["w_a"],
        w_m=trunk_config["w_m"],
        group_width=trunk_config["group_width"],
        bottleneck_multiplier=trunk_config.get("bottleneck_multiplier", 1.0),
    )
    return RegNetParams(**specific, **_regnet_common_kwargs(trunk_config))


def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
    """
    Build the RegNet trunk as a ModuleDict of feature blocks for SSL tasks.

    Args:
        factory: creates the stem and is handed to every AnyStage. When it is
            a RegnetFSDPBlocksFactory, each stage is also wrapped with FSDP
            (and first with activation checkpointing when
            ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING is set).
        model_config: model config; reads INPUT_TYPE, TRUNK.REGNET,
            ACTIVATION_CHECKPOINTING and FSDP_CONFIG.

    Returns:
        (feature_blocks, trunk_depth): an nn.ModuleDict with the entries
        "conv1" (stem), "res2", "res3", ... (one per stage), "avgpool" and
        "flatten", plus the summed ``stage_depth`` of all stages.

    Raises:
        AssertionError: if INPUT_TYPE is not "rgb"/"bgr", or if the REGNET
            config has a "name" other than "anynet".
    """
    assert model_config.INPUT_TYPE in ["rgb",
                                       "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.REGNET
    if "name" in trunk_config:
        assert (trunk_config["name"] == "anynet"
                ), "Please use AnyNetParams or specify RegNetParams dictionary"
    params = _build_regnet_params(trunk_config)

    # Ad hoc stem
    #
    # Important: do NOT retain these modules anywhere else (e.g. as
    # self.stem / self.trunk_output on an owning module). It may seem to be
    # harmless, but it appears that autograd will then compute grads in a
    # different order. Different ordering can cause deterministic OOM, even
    # when the peak memory otherwise is only 24GB out of 32GB.
    #
    # When debugging this, it is not enough to just dump the total module
    # params. You need to diff the module string representations.
    stem = factory.create_stem(params)

    # Instantiate all the AnyNet blocks in the trunk
    use_fsdp = isinstance(factory, RegnetFSDPBlocksFactory)
    current_width, trunk_depth, blocks = params.stem_width, 0, []
    for i, (width_out, stride, depth, group_width,
            bottleneck_multiplier) in enumerate(params.get_expanded_params()):
        new_stage = AnyStage(
            factory=factory,
            width_in=current_width,
            width_out=width_out,
            stride=stride,
            depth=depth,
            group_width=group_width,
            bottleneck_multiplier=bottleneck_multiplier,
            params=params,
            stage_index=i + 1,
        )
        # Read the depth from the bare AnyStage, before any wrapper is
        # applied, rather than relying on wrappers forwarding the attribute.
        trunk_depth += new_stage.stage_depth

        if use_fsdp:
            if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
                logging.info("Using activation checkpointing")
                new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
            new_stage = fsdp_wrapper(module=new_stage,
                                     **model_config.FSDP_CONFIG)

        blocks.append((f"block{i + 1}", new_stage))
        current_width = width_out

    trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################

    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of :
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`
    # We're only interested in the stem and successive blocks
    # everything else is not picked up on purpose
    feature_blocks: List[Tuple[str, nn.Module]] = [("conv1", stem)]
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        # "conv1" already occupies slot 1, so stages are named res2, res3, ...
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))
    return nn.ModuleDict(feature_blocks), trunk_depth
    def _worker(gpu_id: int, sync_file: str, world_size: int):
        """
        Per-rank body of the checkpoint-conversion test.

        Trains an FSDP-wrapped model for a few iterations and saves one sharded
        checkpoint per rank. Rank 0 then converts it into consolidated and
        sliced formats. All three checkpoints are reloaded into fresh FSDP
        models, which are checked against each other and against the trained
        model.

        Args:
            gpu_id: rank of this process and index of its GPU.
            sync_file: rendezvous file passed to init_distributed_on_file.
            world_size: number of participating processes.

        Note:
            Checkpoint files are written to, and read from, the current
            working directory (".").
        """
        # Same seed on every rank before the model is built (presumably so
        # the initial weights match across ranks — confirm build_model uses
        # the torch RNG).
        torch.manual_seed(0)
        # Expose the rank via the environment; presumably read by helpers
        # that do not receive it explicitly.
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(
            with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
        # Per-rank seed: each rank trains on different fake data.
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            # model(...)[0] is unpacked into two outputs (presumably one per
            # head in the test config — confirm).
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target)
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a bunch of checkpoint, one by shard
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
        # Only the trunk is saved: this rank's local shard of its weights plus
        # its local metadata (presumably what the converters below need to
        # reassemble the shards).
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {
                        "trunk": model.trunk.local_state_dict()
                    },
                    "meta": {
                        "trunk": model.trunk.local_metadata_dict()
                    },
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(content,
                                                  shard_rank=gpu_id,
                                                  world_size=world_size)
        dist.barrier()
        # Debug output: list the checkpoint files written so far.
        print(os.listdir("."))

        # Convert the checkpoint to consolidated and sliced checkpoints
        # (rank 0 only; the barrier makes the converted files available to
        # every rank before they are loaded below).
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the previous checkpoint and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

        # Model restored from the original sharded checkpoint.
        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu"))
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        # Model restored from the consolidated checkpoint.
        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu"))
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        # Model restored from the sliced checkpoint.
        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu"))
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verifying that the models are equivalent: rank 0 compares the local
        # state dicts of the sliced and consolidated models key by key.
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        # Every restored model must reproduce the trained model's trunk output.
        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"