def create_model(with_fsdp, with_checkpoint, mixed_precision, flatten, wrap_bn, fp32_reduce_scatter):
    """Build a test ``Model``, optionally BN-wrapped, checkpointed and FSDP-wrapped.

    Args:
        with_fsdp: wrap ``block1``, ``block2`` and ``head`` with ``FSDP`` through
            ``enable_wrap``/``wrap``.
        with_checkpoint: wrap ``block2`` with ``checkpoint_wrapper``; the forward
            counter is maintained only when ``with_fsdp`` is truthy.
        mixed_precision: forwarded to the FSDP wrapper.
        flatten: forwarded to the FSDP wrapper as ``flatten_parameters``.
        wrap_bn: only with FSDP, pass ``block1``/``block2`` through
            ``auto_wrap_bn(..., single_rank_pg=False)`` first.
        fp32_reduce_scatter: forwarded to the FSDP wrapper.

    Returns:
        The (possibly wrapped) model.
    """
    model = Model()
    fsdp_enabled = bool(with_fsdp)

    # BN wrapping happens before checkpointing, so the checkpoint wrapper
    # encloses the BN-wrapped block2.
    if fsdp_enabled and wrap_bn:
        model.block1 = auto_wrap_bn(model.block1, single_rank_pg=False)
        model.block2 = auto_wrap_bn(model.block2, single_rank_pg=False)

    if with_checkpoint:
        model.block2 = checkpoint_wrapper(model.block2, maintain_forward_counter=fsdp_enabled)

    if not fsdp_enabled:
        return model

    wrap_kwargs = dict(
        wrapper_cls=FSDP,
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
        compute_dtype=torch.float32,
        fp32_reduce_scatter=fp32_reduce_scatter,
    )
    with enable_wrap(**wrap_kwargs):
        for child_name in ("block1", "block2", "head"):
            setattr(model, child_name, wrap(getattr(model, child_name)))
    return model
def _create_model(
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Build ``Model``/``Model2`` with optional SyncBN, BN wrapping, checkpointing and FSDP.

    Args:
        with_model2: build ``Model2`` (which also has ``block3``) instead of ``Model``.
        with_sync_bn: convert BN layers with ``nn.SyncBatchNorm.convert_sync_batchnorm``
            and give ``auto_wrap_bn`` a dedicated inner FSDP config.
        with_fsdp: enable ``auto_wrap_bn`` (together with ``wrap_bn``) and wrap
            the blocks and ``head`` with ``FSDP``.
        with_checkpoint: wrap ``block2`` (and ``block3`` for ``Model2``) with
            ``checkpoint_wrapper``.
        mixed_precision, flatten, fp32_reduce_scatter, bucket_cap_mb: forwarded
            to the outer FSDP wrapper (``flatten`` as ``flatten_parameters``).
        wrap_bn: together with ``with_fsdp``, pass the blocks through ``auto_wrap_bn``.

    Returns:
        The (possibly wrapped) model.
    """
    model = Model2() if with_model2 else Model()

    # Model2 has an extra block3 that gets the same treatment as block1/block2.
    block_names = ["block1", "block2"]
    if with_model2:
        block_names.append("block3")

    bn_fsdp_config = None
    if with_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        bn_fsdp_config = {
            "mixed_precision": False,
            "flatten_parameters": False,
            "reshard_after_forward": False,
            "bucket_cap_mb": 0,
            "force_input_to_fp32": True,  # SyncBN needs this.
        }

    if with_fsdp and wrap_bn:
        for name in block_names:
            wrapped = auto_wrap_bn(getattr(model, name), single_rank_pg=False, fsdp_config=bn_fsdp_config)
            setattr(model, name, wrapped)

    if with_checkpoint:
        # block1 is never checkpointed.
        for name in block_names[1:]:
            setattr(model, name, checkpoint_wrapper(getattr(model, name)))

    if with_fsdp:
        with enable_wrap(
            wrapper_cls=FSDP,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        ):
            for name in block_names + ["head"]:
                setattr(model, name, wrap(getattr(model, name)))
    return model
def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
    """Build ``Model(model_hidden_dim)``, optionally checkpointed and FSDP-wrapped.

    Args:
        with_fsdp: pass ``stem``/``blocks`` through ``auto_wrap_bn`` and then
            wrap ``stem``, ``blocks`` and ``head`` with ``to_fsdp``.
        with_checkpoint: wrap ``blocks`` with ``checkpoint_wrapper``; with FSDP
            this happens after ``auto_wrap_bn`` and before ``to_fsdp``.
        model_hidden_dim: forwarded to ``Model``.
        fsdp_config: forwarded to every ``to_fsdp`` call.

    Returns:
        The (possibly wrapped) model.
    """
    model = Model(model_hidden_dim)

    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)

    if with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)

    if with_fsdp:
        for part in ("stem", "blocks", "head"):
            setattr(model, part, to_fsdp(getattr(model, part), fsdp_config))

    return model
def create_model(with_fsdp, with_checkpoint):
    """Build a ``Model``, optionally checkpointed and FSDP-wrapped.

    Args:
        with_fsdp: pass ``stem``/``blocks`` through ``auto_wrap_bn`` and then
            wrap ``stem``, ``blocks`` and ``head`` with ``to_fsdp``.
        with_checkpoint: wrap ``blocks`` with ``checkpoint_wrapper``; with FSDP
            this happens after ``auto_wrap_bn`` and before ``to_fsdp``.

    Returns:
        The (possibly wrapped) model.
    """
    model = Model()

    if with_fsdp:
        model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
        model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)

    if with_checkpoint:
        model.blocks = checkpoint_wrapper(model.blocks)

    if with_fsdp:
        for part in ("stem", "blocks", "head"):
            setattr(model, part, to_fsdp(getattr(model, part)))

    return model
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    """Smoke-test FSDP + ``auto_wrap_bn`` + PyTorch SyncBN on a tiny conv/BN model.

    Calls ``dist_init``, builds the model and applies ``auto_wrap_bn`` and
    ``SyncBatchNorm.convert_sync_batchnorm`` in a randomly chosen order. It then
    wraps the result in ``FSDP(**fsdp_config)``, runs three SGD steps on random
    CUDA inputs, checks the wrapper is in ``TrainingState.IDLE`` and calls
    ``teardown``.

    Args:
        rank, world_size, tempfile_name, unused: forwarded to ``dist_init``.
        fsdp_config: keyword arguments for ``FSDP``; must be a dict.
    """
    init_ok = dist_init(rank, world_size, tempfile_name, unused)
    assert init_ok, "Dist init failed"
    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        """1x1 ``Conv2d`` (2 -> 2 channels) followed by ``BatchNorm2d(2)``."""

        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            # this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            return self.bn(self.conv(x))

    # TODO (Min): check DDP equivalency.
    net = Model()

    # Ranks may pick different wrapping orders because their random state can
    # differ; the results are expected to be the same either way.
    wrap_bn_first = random.randint(0, 1) == 0
    if wrap_bn_first:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        net = SyncBatchNorm.convert_sync_batchnorm(auto_wrap_bn(net))
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        net = auto_wrap_bn(SyncBatchNorm.convert_sync_batchnorm(net))

    net = FSDP(net, **fsdp_config).cuda()
    optimizer = SGD(net.parameters(), lr=0.1)

    for _step in range(3):
        batch = torch.rand(2, 2, 2, 2).cuda()
        batch.requires_grad = True
        output = net(batch)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    net.assert_state(TrainingState.IDLE)
    teardown()
def __init__(
    self,
    model_config: AttrDict,
    in_channels: int,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
):
    """Build a ``LinearEvalMLP`` head and wrap it for FSDP.

    The head is passed through ``auto_wrap_bn(..., single_rank_pg=False)`` and
    then through ``fsdp_wrapper`` with ``model_config.FSDP_CONFIG``; the result
    is stored as ``self.mlp``.

    Args:
        model_config: model configuration; also supplies ``FSDP_CONFIG``.
        in_channels, dims, use_bn, use_relu: forwarded to ``LinearEvalMLP``.
    """
    super().__init__()
    bn_wrapped = auto_wrap_bn(
        LinearEvalMLP(model_config, in_channels, dims, use_bn, use_relu),
        single_rank_pg=False,
    )
    self.mlp = fsdp_wrapper(bn_wrapped, **model_config.FSDP_CONFIG)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """Create a block with the parent class and wrap it for FSDP.

    The parent's block is passed through ``auto_wrap_bn(..., single_rank_pg=False)``
    and then through ``fsdp_wrapper`` with ``self.fsdp_config``.

    Returns:
        The wrapped block.
    """
    base_block = super().create_block(
        width_in, width_out, stride, params, bottleneck_multiplier, group_width
    )
    return fsdp_wrapper(
        module=auto_wrap_bn(base_block, single_rank_pg=False),
        **self.fsdp_config,
    )
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: "RegNetParams",
    bot_mul: float,
    group_width: int = 1,
):
    """Create a block with the parent class and wrap it with ``fsdp_wrapper``.

    The parent's block is passed through ``auto_wrap_bn(..., single_rank_pg=False)``.
    It is then wrapped via ``wrap`` inside an ``enable_wrap`` context configured
    with ``fsdp_wrapper`` and ``self.fsdp_config``.

    Returns:
        The wrapped block.
    """
    bn_wrapped = auto_wrap_bn(
        super().create_block(width_in, width_out, stride, params, bot_mul, group_width),
        single_rank_pg=False,
    )
    with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
        fsdp_block = wrap(bn_wrapped)
    return fsdp_block
def __init__(
    self,
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool = False,
    use_relu: bool = False,
    use_dropout: bool = False,
    use_bias: bool = True,
    skip_last_layer_relu_bn: bool = True,
):
    """Build an ``MLP`` head and wrap it for FSDP.

    The head is passed through ``auto_wrap_bn(..., single_rank_pg=False)`` and
    then through ``fsdp_wrapper`` with ``model_config.FSDP_CONFIG``; the result
    is stored as ``self.mlp``.

    Args:
        model_config: model configuration; also supplies ``FSDP_CONFIG``.
        dims, use_bn, use_relu, use_dropout, use_bias, skip_last_layer_relu_bn:
            forwarded positionally to ``MLP``.
    """
    super().__init__()
    mlp_args = (model_config, dims, use_bn, use_relu, use_dropout, use_bias, skip_last_layer_relu_bn)
    bn_wrapped = auto_wrap_bn(MLP(*mlp_args), single_rank_pg=False)
    self.mlp = fsdp_wrapper(bn_wrapped, **model_config.FSDP_CONFIG)
def __init__(
    self,
    model_config,
    width_in: int,
    width_out: int,
    stride: int,
    depth: int,
    block_constructor: nn.Module,
    activation: nn.Module,
    bot_mul: float,
    group_width: int,
    params: "RegNetParams",
    stage_index: int = 0,
):
    """Build a stage of ``depth`` blocks, each individually wrapped for FSDP.

    Each block goes through these steps:
      1. constructed with ``block_constructor`` and moved to CUDA;
      2. weight-initialized with ``init_weights``;
      3. passed through ``auto_wrap_bn``;
      4. wrapped via ``enable_wrap``/``wrap`` with ``fsdp_wrapper``.

    The wrap keyword arguments are ``{"wrapper_cls": fsdp_wrapper}`` updated with
    ``model_config.FSDP_CONFIG``, so entries there take precedence. Blocks are
    registered as ``block{stage_index}-{i}``, and ``self.stage_depth`` sums each
    wrapped block's ``depth``.
    """
    super().__init__()
    self.stage_depth = 0

    wrap_kwargs = {"wrapper_cls": fsdp_wrapper}
    wrap_kwargs.update(model_config.FSDP_CONFIG)

    for idx in range(depth):
        # Only the first block of a stage changes width and applies the stride.
        is_first = idx == 0
        block_width_in = width_in if is_first else width_out
        block_stride = stride if is_first else 1

        # Built directly on CUDA: shard-as-we-build in FSDP needs CUDA for its
        # dist.all_gather() call.
        block = block_constructor(
            block_width_in,
            width_out,
            block_stride,
            params.bn_epsilon,
            params.bn_momentum,
            activation,
            bot_mul,
            group_width,
            params.se_ratio,
        ).cuda()

        # Weights are initialized before wrapping, i.e. before sharding happens.
        init_weights(block)
        block = auto_wrap_bn(block)
        with enable_wrap(**wrap_kwargs):
            block = wrap(block)

        self.stage_depth += block.depth
        self.add_module(f"block{stage_index}-{idx}", block)
def create_block(
    self,
    width_in: int,
    width_out: int,
    stride: int,
    params: Union[RegNetParams, AnyNetParams],
    bottleneck_multiplier: float,
    group_width: int = 1,
):
    """Create a block with the parent class, optionally checkpoint it, and wrap it for FSDP.

    The parent's block is passed through ``auto_wrap_bn(..., single_rank_pg=False)``.
    When ``self.use_activation_checkpointing`` is set, it is then wrapped with
    ``checkpoint_wrapper(..., offload_to_cpu=False)``. Finally it is wrapped via
    ``wrap`` under ``enable_wrap`` with ``fsdp_wrapper`` and ``self.fsdp_config``.

    Returns:
        The wrapped block.
    """
    block = auto_wrap_bn(
        super().create_block(
            width_in, width_out, stride, params, bottleneck_multiplier, group_width
        ),
        single_rank_pg=False,
    )
    if self.use_activation_checkpointing:
        # TODO - make this configurable
        block = checkpoint_wrapper(block, offload_to_cpu=False)
    with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
        return wrap(block)
def create_stem(self, params: Union[RegNetParams, AnyNetParams]):
    """Return the parent class's stem passed through ``auto_wrap_bn(..., single_rank_pg=False)``."""
    return auto_wrap_bn(super().create_stem(params), single_rank_pg=False)
def __init__(self, model_config: AttrDict, model_name: str):
    """Build a RegNet trunk (stem + AnyStage blocks) and expose it as feature blocks.

    The stem and all trunk stages are created inside ``set_torch_seed(0)``, so every
    worker gets the same initial weights (FSDP does not broadcast weights from
    rank 0 at start). The modules are then registered only through
    ``self._feature_blocks``. Its order is ``conv1`` (stem), then ``res2``,
    ``res3``, ... (trunk stages), then ``avgpool`` and ``flatten``.

    Args:
        model_config: model configuration; ``INPUT_TYPE`` must be "rgb" or "bgr"
            and ``TRUNK.TRUNK_PARAMS.REGNET`` must hold explicit RegNet params.
        model_name: not used in this method.
    """
    super().__init__()
    self.model_config = model_config
    assert model_config.INPUT_TYPE in ["rgb", "bgr"], "Input type not supported"
    trunk_config = model_config.TRUNK.TRUNK_PARAMS.REGNET
    # Explicit params are required: a "name" key is rejected.
    assert "name" not in trunk_config, "Please specify the RegNet Params dictionary"

    ################################################################################
    # Required keys: depth, w_0, w_a, w_m, group_width; the rest have defaults.
    params = RegNetParams(
        depth=trunk_config["depth"],
        w_0=trunk_config["w_0"],
        w_a=trunk_config["w_a"],
        w_m=trunk_config["w_m"],
        group_w=trunk_config["group_width"],
        stem_type=trunk_config.get("stem_type", "simple_stem_in").upper(),
        stem_width=trunk_config.get("stem_width", 32),
        block_type=trunk_config.get("block_type", "res_bottleneck_block").upper(),
        use_se=trunk_config.get("use_se", True),
        se_ratio=trunk_config.get("se_ratio", 0.25),
        bn_epsilon=trunk_config.get("bn_epsilon", 1e-05),
        bn_momentum=trunk_config.get("bn_momentum", 0.1),
    )

    # We need all workers (on all nodes) to have the same weights.
    # Unlike DDP, FSDP does not sync weights using rank 0 on start.
    # Therefore, we init stem and trunk_output below within the seed context.
    #
    # TODO (Min): we can make this seed coming from the config or env.
    stem = None
    trunk_output = None
    with set_torch_seed(0):
        # Ad hoc stem
        #
        # Important: do NOT retain modules in self.stem or self.trunk_output. It may
        # seem to be harmless, but it appears that autograd will result in computing
        # grads in different order. Different ordering can cause deterministic OOM,
        # even when the peak memory otherwise is only 24GB out of 32GB.
        #
        # When debugging this, it is not enough to just dump the total module
        # params. You need to diff the module string representations.
        stem = {
            StemType.RES_STEM_CIFAR: ResStemCifar,
            StemType.RES_STEM_IN: ResStemIN,
            StemType.SIMPLE_STEM_IN: SimpleStemIN,
        }[params.stem_type](
            3,
            params.stem_width,
            params.bn_epsilon,
            params.bn_momentum,
            False,  # params.relu_in_place
        )
        # Initialize the weights before auto_wrap_bn wraps the stem.
        init_weights(stem)
        stem = auto_wrap_bn(stem)

        # Instantiate all the AnyNet blocks in the trunk
        block_fun = {
            BlockType.VANILLA_BLOCK: VanillaBlock,
            BlockType.RES_BASIC_BLOCK: ResBasicBlock,
            BlockType.RES_BOTTLENECK_BLOCK: ResBottleneckBlock,
        }[params.block_type]
        current_width = params.stem_width
        self.trunk_depth = 0
        blocks = []
        for i, (width_out, stride, depth, bot_mul, group_width) in enumerate(params.get_expanded_params()):
            # NOTE(review): AnyStage arguments are passed positionally; confirm they
            # line up with AnyStage's signature in this file (e.g. whether it also
            # expects an `activation` argument after the block constructor).
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        model_config,
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_fun,
                        bot_mul,
                        group_width,
                        params,
                        stage_index=i + 1,
                    ),
                )
            )
            # Accumulate the stage's depth into the trunk total.
            self.trunk_depth += blocks[-1][1].stage_depth
            # The next stage's input width is this stage's output width.
            current_width = width_out
        trunk_output = nn.Sequential(OrderedDict(blocks))

    ################################################################################
    # Now map the models to the structure we want to expose for SSL tasks
    # The upstream RegNet model is made of :
    # - `stem`
    # - n x blocks in trunk_output, named `block1, block2, ..`

    # We're only interested in the stem and successive blocks
    # everything else is not picked up on purpose
    feature_blocks: List[Tuple[str, nn.Module]] = []

    # - get the stem
    feature_blocks.append(("conv1", stem))

    # - get all the feature blocks
    for k, v in trunk_output.named_children():
        assert k.startswith("block"), f"Unexpected layer name {k}"
        # The stem already occupies index 1, so trunk stages become res2, res3, ...
        block_index = len(feature_blocks) + 1
        feature_blocks.append((f"res{block_index}", v))

    # - finally, add avgpool and flatten.
    feature_blocks.append(("avgpool", nn.AdaptiveAvgPool2d((1, 1))))
    feature_blocks.append(("flatten", Flatten(1)))

    self._feature_blocks = nn.ModuleDict(feature_blocks)
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    """Train ``Model`` on one rank with DDP or FSDP and save/check the final state.

    A falsy ``fsdp_config`` selects the DDP path, where rank 0 saves its final
    ``state_dict`` to ``rank_0_output``. Otherwise the model is wrapped in FSDP
    (optionally after ``auto_wrap_bn`` and/or PyTorch SyncBN conversion, in a
    random order). Its final ``state_dict`` is then compared to ``state_after``.

    Args:
        rank, world_size, tempfile_name, unused: forwarded to ``dist_init``.
        fsdp_config: FSDP keyword arguments (a dict), or a falsy value for DDP.
            The caller's dict is not modified; a private copy is used.
        fsdp_wrap_bn: on the FSDP path, apply ``auto_wrap_bn`` with ``_single_rank_pg``.
        ddp_mixed_precision: on the DDP path, feed half inputs under autocast and
            use ``GradScaler``.
        state_before: ``state_dict`` loaded into the model before training.
        inputs: indexed by rank; ``inputs[rank]`` yields the input tensors.
        rank_0_output: file path where the DDP path saves rank 0's final state.
        state_after: expected final state for the FSDP comparison.
        sync_bn: ``"pytorch"`` converts BN with ``pytorch_bn_converter``; any value
            other than ``"none"`` makes every rank run the state comparison.
        conv_bias, linear_bias: forwarded to ``Model``.
    """
    torch.backends.cudnn.deterministic = True
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        # Work on a private copy so the caller's config is not mutated below.
        fsdp_config = dict(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        """No-op scaler with the same interface as ``GradScaler``."""

        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note, different rank may wrap in different order due to different random
        # seeds. But results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)

    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.nullcontext()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure final state equals to the state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()