Example #1
 def __init__(self, config: FastFormerConfig, block_index,
              is_last_layer_of_block, is_encoder_layer):
     super().__init__()
     self.config = config
     groups, layers = config.ffn_groups, config.ffn_layers
     d_model, d_inner = config.block_channel_size[
         block_index], config.block_channel_size[
             block_index] * config.ffn_width
     d_next = config.block_channel_size[block_index + 1] if (
         block_index + 1) < len(config.block_channel_size) else d_model
     self.n_blocks = config.block_sizes[block_index] - 1
     self.need_dim_match = d_model != d_next and is_encoder_layer and is_last_layer_of_block
     self.diff = d_next - d_model
     self.d_model = d_model
     self.activation_function = checkpoint_wrapper(
         ACT2FN[config.hidden_act](), offload_to_cpu=False)
     self.layer_norm = nn.LayerNorm(d_model, config.layer_norm_eps)
     if self.need_dim_match:
         self.layer_norm = nn.LayerNorm(d_next, config.layer_norm_eps)
         self.dim_match_stride = int(np.ceil(d_model / self.diff))
     if groups > 1:
         assert d_model % groups == 0
         self.lin = nn.Linear(d_model, d_model)
         self.ffn = ConvFFN(config, d_model, d_inner, groups, layers)
     else:
         self.lin = nn.Identity()
         self.ffn = BertFFN(config, d_model, d_inner, layers)
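The common thread in these examples is fairscale's checkpoint_wrapper, which the __init__ above applies to the activation module. A minimal standalone sketch of that pattern (the GELU and tensor sizes here are illustrative, not from the repo):

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper

# Wrap a parameter-free activation, as the constructor above does with ACT2FN.
act = checkpoint_wrapper(nn.GELU(), offload_to_cpu=False)

x = torch.randn(4, 16, requires_grad=True)
loss = act(x).sum()
loss.backward()  # the wrapped forward is recomputed here instead of storing activations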
Example #2
 def __init__(self,
              config: FastFormerConfig,
              d_model,
              d_inner,
              groups,
              layers=0,
              d_out=None):
     super().__init__()
     d_out = d_model if d_out is None else d_out
     cin, cout = d_model, d_out
     act = config.hidden_act
     self.conv1d_in = Conv1d(in_channels=cin,
                             out_channels=d_inner,
                             kernel_size=1,
                             groups=groups,
                             bias=True)
     self.conv1d_in.post_permute = False
     self.activation_dropout = Dropout(config.hidden_dropout)
     self.layers = nn.ModuleList() if layers > 0 else None
     for _ in range(layers):
         cnn = Conv1d(in_channels=d_inner,
                      out_channels=d_inner,
                      kernel_size=1,
                      groups=groups)
         cnn.pre_permute = False
         cnn.post_permute = False
         self.layers.append(cnn)
     self.conv1d_out = Conv1d(in_channels=d_inner,
                              out_channels=cout,
                              kernel_size=1,
                              groups=groups,
                              bias=False)
     self.conv1d_out.pre_permute = False
     self.act = checkpoint_wrapper(ACT2FN[act](), offload_to_cpu=False)
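ConvFFN builds its feed-forward network from grouped, kernel-size-1 convolutions; the repo's Conv1d appears to be a custom wrapper whose pre_permute/post_permute flags manage the (batch, seq, channels) layout. A rough standalone sketch of the same grouped pointwise-convolution idea with plain torch.nn.Conv1d and made-up sizes:

import torch
import torch.nn as nn

d_model, d_inner, groups = 256, 1024, 4   # illustrative sizes only
ffn = nn.Sequential(
    nn.Conv1d(d_model, d_inner, kernel_size=1, groups=groups, bias=True),
    nn.GELU(),
    nn.Conv1d(d_inner, d_model, kernel_size=1, groups=groups, bias=False),
)
x = torch.randn(2, d_model, 128)   # (batch, channels, sequence)
print(ffn(x).shape)                # torch.Size([2, 256, 128])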
Example #3
def wrap(module: nn.Module,
         cls: Callable = FullyShardedDataParallel,
         activation_checkpoint: bool = False,
         **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped. Annotated modules will only be
    wrapped if inside of an :func:`enable_wrap` context manager. An important
    use case is annotating large layers that should be sharded (in-place) during
    initialization, to avoid running out of system memory.

    Usage::

        with enable_wrap(**params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        cls (Callable): class wrapper to wrap the model with if in context
            (default: :class:`FullyShardedDataParallel`)
        activation_checkpoint (bool): use activation checkpointing wrapper
            (default: False)
        **wrap_overrides: configuration overrides that will take priority over
            the values provided by the :func:`enable_wrap` context
    """
    if ConfigAutoWrap.in_autowrap_context:
        wrap_overrides = {**ConfigAutoWrap.kwargs, **wrap_overrides}
        if activation_checkpoint:
            module = checkpoint_wrapper(module)
        return cls(module, **wrap_overrides)
    return module
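A short sketch of the flow this docstring describes, following the signature shown here (cls defaults to FullyShardedDataParallel, and activation_checkpoint=True adds checkpoint_wrapper before sharding). It assumes torch.distributed is already initialized, as in Example #7; newer fairscale releases pass wrapper_cls to enable_wrap instead:

import torch.nn as nn
from fairscale.nn.wrap import enable_wrap, wrap

layer = wrap(nn.Linear(5, 5))                 # outside enable_wrap: returned unchanged

fsdp_params = dict(flatten_parameters=True)   # forwarded to the wrapper class
with enable_wrap(**fsdp_params):
    # inside the context: checkpoint_wrapper is applied first, then FSDP
    sharded = wrap(nn.Linear(5, 5), activation_checkpoint=True)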
Example #4
def fsdp_wrapper(module):
    """Custom wrapper that does FSDP + checkpoint at the same time.

    Currently not used. Will be used in the next commit. Included here
    to check the imports.
    """
    # The same thing via the enable_wrap/wrap context manager would be:
    #     fsdp_config = {"wrapper_cls": FSDP, "mixed_precision": True,
    #                    "flatten_parameters": True}
    #     with enable_wrap(**fsdp_config):
    #         return wrap(checkpoint_wrapper(module))
    # (enable_wrap takes keyword arguments, and wrap() must be given the
    # module it should wrap.)
    return FSDP(checkpoint_wrapper(module))
Example #5
 def __init__(self,
              config: FastFormerConfig,
              d_model,
              d_inner,
              layers=0,
              d_out=None):
     super().__init__()
     self.linear_1 = nn.Linear(d_model, d_inner, bias=True)
     self.activation_function = checkpoint_wrapper(
         ACT2FN[config.hidden_act](), offload_to_cpu=False)
     self.activation_dropout = Dropout(config.hidden_dropout)
     d_out = d_model if d_out is None else d_out
     self.linear_2 = nn.Linear(d_inner, d_out, bias=False)
     self.layers = nn.ModuleList() if layers > 0 else None
     for _ in range(layers):
         self.layers.append(nn.Linear(d_inner, d_inner, bias=False))
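The forward pass is not part of this example, but the __init__ sets up a standard position-wise FFN (linear, activation, dropout, linear). An illustrative, self-contained version of that pattern with made-up sizes:

import torch
import torch.nn as nn

class TinyFFN(nn.Module):
    def __init__(self, d_model=256, d_inner=1024, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_inner, bias=True)
        self.act = nn.GELU()
        self.drop = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        return self.linear_2(self.drop(self.act(self.linear_1(x))))

print(TinyFFN()(torch.randn(2, 8, 256)).shape)   # torch.Size([2, 8, 256])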
Example #6
 def __init__(self,
              config: FastFormerConfig,
              hidden_size,
              heads,
              head_size,
              kernel_size=9,
              stride=1):
     super().__init__()
     self.config = config
     self.cls_tokens = config.num_highway_cls_tokens + 1
     self.heads = heads
     self.kernel_size = kernel_size
     self.all_head_size = heads * head_size
     self.hidden_size = hidden_size
     self.stride = stride
     act = config.hidden_act
     self.act = checkpoint_wrapper(ACT2FN[act](), offload_to_cpu=False)
     assert hidden_size % heads == 0
     self.head_size = head_size
     self.conv_attn_kernel = nn.Conv1d(in_channels=hidden_size,
                                       out_channels=self.heads *
                                       self.kernel_size,
                                       kernel_size=kernel_size,
                                       groups=heads,
                                       bias=False,
                                       stride=stride,
                                       padding=(kernel_size - 1) // 2)
     # self.conv_attn_kernel = nn.Linear(self.all_head_size, self.heads * self.kernel_size)  # Multi-head?
     # if config.no_v_head:
     #     self.conv_attn_point = nn.Identity()
     # else:
     #     self.conv_attn_point = nn.Linear(hidden_size, hidden_size, bias=False)
     self.use_cuda_conv = config.use_cuda_conv
     if not self.use_cuda_conv or self.stride != 1:
         self.unfold1d = nn.Unfold(kernel_size=[kernel_size, 1],
                                   padding=[(kernel_size - 1) // 2, 0],
                                   stride=[stride, 1])
     else:
         self.padding_l = (self.kernel_size - 1) // 2
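The nn.Unfold configured above is what extracts per-position sliding windows of length kernel_size for the convolutional attention. A small standalone sketch of that step, with sizes chosen for illustration rather than taken from the config:

import torch
import torch.nn as nn

batch, channels, seq_len, kernel_size, stride = 2, 16, 32, 9, 1
unfold1d = nn.Unfold(kernel_size=[kernel_size, 1],
                     padding=[(kernel_size - 1) // 2, 0],
                     stride=[stride, 1])
x = torch.randn(batch, channels, seq_len)
windows = unfold1d(x.unsqueeze(-1))   # add a dummy width dimension
print(windows.shape)                  # torch.Size([2, 144, 32]) = (batch, channels * kernel_size, seq_len)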
Example #7
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl",
                                         rank=local_rank,
                                         world_size=8,
                                         init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" %
          (get_time_string(), local_rank))
    device = torch.device(
        f'cuda:{local_rank}')  # Unique only on individual node.
    torch.cuda.set_device(device)
    fsdp_params = dict(mixed_precision=True,
                       flatten_parameters=True,
                       bucket_cap_mb=25,
                       reshard_after_forward=False,
                       fp32_reduce_scatter=False,
                       cpu_offload=False,
                       move_grads_to_cpu=False,
                       process_group=torch.distributed.group.WORLD)
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(200, 200),
                        nn.Linear(200, 200),
                        wrap(checkpoint_wrapper(nn.Linear(200, 200),
                                                offload_to_cpu=True)),
                        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                        nn.Linear(200, 200)),
                    offload_to_cpu=True)),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64)).cuda()

        model = FullyShardedDDP(nn_model, **fsdp_params)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=1e-4,
                                  eps=1e-7,
                                  weight_decay=1e-2,
                                  betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)

    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels)**2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(nn.Linear(200, 200), nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                      nn.Linear(200, 200)),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7), nn.Linear(200, 64)).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" %
          (get_time_string(), numel(nn_model) / 1_000_000))
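main(local_rank, *args) is written to run once per GPU; a hedged sketch of how such an entry point is typically launched with torch.multiprocessing (nprocs matches the hard-coded world_size of 8 above):

import torch.multiprocessing as mp

if __name__ == "__main__":
    # spawn() invokes main(rank, *args) once per process, passing the rank first
    mp.spawn(main, args=(), nprocs=8, join=True)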