def __init__(self, config: FastFormerConfig, block_index, is_last_layer_of_block, is_encoder_layer):
    super().__init__()
    self.config = config
    groups, layers = config.ffn_groups, config.ffn_layers
    d_model = config.block_channel_size[block_index]
    d_inner = config.block_channel_size[block_index] * config.ffn_width
    d_next = config.block_channel_size[block_index + 1] if (block_index + 1) < len(config.block_channel_size) else d_model
    self.n_blocks = config.block_sizes[block_index] - 1
    # Only the last encoder layer of a block needs to match its output width to
    # the next block's channel size.
    self.need_dim_match = d_model != d_next and is_encoder_layer and is_last_layer_of_block
    self.diff = d_next - d_model
    self.d_model = d_model
    self.activation_function = checkpoint_wrapper(ACT2FN[config.hidden_act](), offload_to_cpu=False)
    self.layer_norm = nn.LayerNorm(d_model, config.layer_norm_eps)
    if self.need_dim_match:
        self.layer_norm = nn.LayerNorm(d_next, config.layer_norm_eps)
        self.dim_match_stride = int(np.ceil(d_model / self.diff))
    if groups > 1:
        # Grouped (convolutional) FFN when ffn_groups > 1, dense FFN otherwise.
        assert d_model % groups == 0
        self.lin = nn.Linear(d_model, d_model)
        self.ffn = ConvFFN(config, d_model, d_inner, groups, layers)
    else:
        self.lin = nn.Identity()
        self.ffn = BertFFN(config, d_model, d_inner, layers)
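# Worked example of the dim-match bookkeeping above (values are illustrative,
# not taken from any real config): with block_channel_size = [512, 768], the
# last encoder layer of block 0 has d_model = 512 and d_next = 768, so
# diff = 256 and dim_match_stride = ceil(512 / 256) = 2, presumably used
# downstream to expand the hidden size from d_model to d_next.
import numpy as np

d_model, d_next = 512, 768
diff = d_next - d_model
dim_match_stride = int(np.ceil(d_model / diff))
print(dim_match_stride)  # 2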
def __init__(self, config: FastFormerConfig, d_model, d_inner, groups, layers=0, d_out=None):
    super().__init__()
    d_out = d_model if d_out is None else d_out
    cin, cout = d_model, d_out
    act = config.hidden_act
    self.conv1d_in = Conv1d(in_channels=cin, out_channels=d_inner, kernel_size=1, groups=groups, bias=True)
    self.conv1d_in.post_permute = False
    self.activation_dropout = Dropout(config.hidden_dropout)
    self.layers = nn.ModuleList() if layers > 0 else None
    for _ in range(layers):
        cnn = Conv1d(in_channels=d_inner, out_channels=d_inner, kernel_size=1, groups=groups)
        cnn.pre_permute = False
        cnn.post_permute = False
        self.layers.append(cnn)
    self.conv1d_out = Conv1d(in_channels=d_inner, out_channels=cout, kernel_size=1, groups=groups, bias=False)
    self.conv1d_out.pre_permute = False
    self.act = checkpoint_wrapper(ACT2FN[act](), offload_to_cpu=False)
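# Shape sketch for the kernel_size=1 grouped convolutions above: over a
# (batch, channels, seq_len) tensor they act as a per-position grouped linear
# map. Plain nn.Conv1d is used here for illustration only; ConvFFN itself uses
# a custom Conv1d wrapper with pre_permute/post_permute flags.
import torch
import torch.nn as nn

conv_in = nn.Conv1d(in_channels=256, out_channels=1024, kernel_size=1, groups=4, bias=True)
x = torch.randn(2, 256, 128)   # (batch, d_model, seq_len)
print(conv_in(x).shape)        # torch.Size([2, 1024, 128]) -> (batch, d_inner, seq_len)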
def wrap(module: nn.Module,
         cls: Callable = FullyShardedDataParallel,
         activation_checkpoint: bool = False,
         **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. An important
    use case is annotating large layers that should be sharded (in-place) during
    initialization, to avoid running out of system memory.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
        if activation_checkpoint:
            module = checkpoint_wrapper(module)
        return cls(module, **wrap_overrides)
    return module
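# Minimal behavior check for wrap() above (a sketch, not part of the original
# source): outside an enable_wrap context the module is returned unchanged,
# even when activation_checkpoint=True. Inside an enable_wrap context the same
# call would checkpoint and then shard the module with FSDP, which also
# requires torch.distributed to be initialized.
import torch.nn as nn

layer = wrap(nn.Linear(5, 5), activation_checkpoint=True)
assert isinstance(layer, nn.Linear)  # unchanged: not inside enable_wrap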
def fsdp_wrapper(module):
    """Custom wrapper that does FSDP + checkpoint at the same time.

    Currently not used. Will be used in the next commit. Included here to
    check the imports.
    """
    fsdp_config = {
        "wrapper_cls": fsdp_wrapper,
        "mixed_precision": True,
        "flatten_parameters": True,
    }
    # Placeholder that only exercises the enable_wrap/wrap API; wrap() must be
    # given a module once this path is actually used.
    with enable_wrap(**fsdp_config):
        wrap()
    return FSDP(checkpoint_wrapper(module))
def __init__(self, config: FastFormerConfig, d_model, d_inner, layers=0, d_out=None):
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_inner, bias=True)
    self.activation_function = checkpoint_wrapper(ACT2FN[config.hidden_act](), offload_to_cpu=False)
    self.activation_dropout = Dropout(config.hidden_dropout)
    d_out = d_model if d_out is None else d_out
    self.linear_2 = nn.Linear(d_inner, d_out, bias=False)
    self.layers = nn.ModuleList() if layers > 0 else None
    for _ in range(layers):
        self.layers.append(nn.Linear(d_inner, d_inner, bias=False))
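# Hypothetical forward sketch for the BertFFN module above; the repository's
# actual forward may differ. It follows the standard transformer FFN pattern:
# project up, activate, drop out, pass through the optional inner layers, then
# project back down.
def bert_ffn_forward(self, hidden_states):
    h = self.activation_dropout(self.activation_function(self.linear_1(hidden_states)))
    if self.layers is not None:
        for layer in self.layers:
            h = self.activation_dropout(self.activation_function(layer(h)))
    return self.linear_2(h)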
def __init__(self, config: FastFormerConfig, hidden_size, heads, head_size, kernel_size=9, stride=1):
    super().__init__()
    self.config = config
    self.cls_tokens = config.num_highway_cls_tokens + 1
    self.heads = heads
    self.kernel_size = kernel_size
    self.all_head_size = heads * head_size
    self.hidden_size = hidden_size
    self.stride = stride
    act = config.hidden_act
    self.act = checkpoint_wrapper(ACT2FN[act](), offload_to_cpu=False)
    assert hidden_size % heads == 0
    self.head_size = head_size
    self.conv_attn_kernel = nn.Conv1d(in_channels=hidden_size,
                                      out_channels=self.heads * self.kernel_size,
                                      kernel_size=kernel_size, groups=heads,
                                      bias=False, stride=stride,
                                      padding=(kernel_size - 1) // 2)
    # self.conv_attn_kernel = nn.Linear(self.all_head_size, self.heads * self.kernel_size)  # Multi-head?
    # if config.no_v_head:
    #     self.conv_attn_point = nn.Identity()
    # else:
    #     self.conv_attn_point = nn.Linear(hidden_size, hidden_size, bias=False)
    self.use_cuda_conv = config.use_cuda_conv
    if not self.use_cuda_conv or self.stride != 1:
        self.unfold1d = nn.Unfold(kernel_size=[kernel_size, 1],
                                  padding=[(kernel_size - 1) // 2, 0],
                                  stride=[stride, 1])
    else:
        self.padding_l = (self.kernel_size - 1) // 2
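# Shape sketch for the nn.Unfold fallback above: it extracts kernel_size-wide
# windows along the sequence axis, one window per (strided) position. Sizes
# here are illustrative only.
import torch
import torch.nn as nn

kernel_size, stride, hidden = 9, 1, 64
unfold1d = nn.Unfold(kernel_size=[kernel_size, 1],
                     padding=[(kernel_size - 1) // 2, 0],
                     stride=[stride, 1])
x = torch.randn(2, 16, hidden)                        # (batch, seq_len, hidden)
windows = unfold1d(x.permute(0, 2, 1).unsqueeze(-1))  # treat seq_len as the "height" dim
print(windows.shape)  # torch.Size([2, 576, 16]) = (batch, hidden * kernel_size, seq_len)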
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl", rank=local_rank,
                                         world_size=8, init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" % (get_time_string(), local_rank))
    device = torch.device(f'cuda:{local_rank}')  # Unique only on individual node.
    torch.cuda.set_device(device)

    fsdp_params = dict(mixed_precision=True, flatten_parameters=True,
                       bucket_cap_mb=25, reshard_after_forward=False,
                       fp32_reduce_scatter=False, cpu_offload=False,
                       move_grads_to_cpu=False,
                       process_group=torch.distributed.group.WORLD)
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        # Toy model mixing wrap() annotations (only effective inside this
        # enable_wrap context) with activation checkpointing.
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(checkpoint_wrapper(nn.Sequential(
                nn.Linear(200, 200),
                nn.Linear(200, 200),
                wrap(checkpoint_wrapper(nn.Linear(200, 200), offload_to_cpu=True)),
                checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                nn.Linear(200, 200)), offload_to_cpu=True)),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64)).cuda()
    model = FullyShardedDDP(nn_model, **fsdp_params)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-7,
                                  weight_decay=1e-2, betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)
    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels) ** 2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    # Rebuild a plain (unwrapped) copy of the model and load the trained
    # weights back into it.
    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                      nn.Linear(200, 200)),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7),
        nn.Linear(200, 64)).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), numel(nn_model) / 1_000_000))
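# Hedged launcher sketch for main() above, assuming a single node with the 8
# GPUs implied by the hard-coded world_size=8; get_time_string, numel and the
# fairscale imports must already be available in this module.
if __name__ == "__main__":
    import torch.multiprocessing as mp
    mp.spawn(main, args=(), nprocs=8, join=True)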