def __init__(self, attention_mask_func, init_method,
             output_layer_init_method, layer_number):
    args = get_args()

    super(ParallelTransformerLayer, self).__init__()
    self.layer_number = layer_number

    self.apply_residual_connection_post_layernorm \
        = args.apply_residual_connection_post_layernorm

    # Layernorm on the input data.
    self.input_layernorm = LayerNorm(args.hidden_size,
                                     eps=args.layernorm_epsilon)

    # Self attention.
    self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                           output_layer_init_method,
                                           layer_number)
    self.hidden_dropout = args.hidden_dropout
    self.bias_dropout_fusion = args.bias_dropout_fusion

    # Layernorm on the attention output.
    self.post_attention_layernorm = LayerNorm(args.hidden_size,
                                              eps=args.layernorm_epsilon)

    # MLP.
    self.mlp = ParallelMLP(init_method, output_layer_init_method)
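
# --- Illustrative sketch (not part of the original file) ----------------------
# A minimal, self-contained example of the pre-LayerNorm residual pattern the
# constructor above sets up: LayerNorm -> self-attention -> dropout -> residual
# add, then LayerNorm -> MLP -> dropout -> residual add. Plain PyTorch modules
# stand in for ParallelSelfAttention / ParallelMLP, and the hidden size, head
# count and dropout probability are illustrative assumptions, not values taken
# from the excerpt. With apply_residual_connection_post_layernorm set, the
# residual would be taken from the layernorm output instead of the raw input.
import torch
import torch.nn.functional as F

hidden_size, num_heads, dropout_p = 1024, 16, 0.1
input_layernorm = torch.nn.LayerNorm(hidden_size, eps=1e-5)
post_attention_layernorm = torch.nn.LayerNorm(hidden_size, eps=1e-5)
attention = torch.nn.MultiheadAttention(hidden_size, num_heads)  # stand-in
mlp = torch.nn.Sequential(torch.nn.Linear(hidden_size, 4 * hidden_size),
                          torch.nn.GELU(),
                          torch.nn.Linear(4 * hidden_size, hidden_size))

x = torch.randn(128, 2, hidden_size)                # [sequence, batch, hidden]
normed = input_layernorm(x)
attn_out, _ = attention(normed, normed, normed)
x = x + F.dropout(attn_out, p=dropout_p)            # residual around attention
x = x + F.dropout(mlp(post_attention_layernorm(x)), p=dropout_p)  # residual around MLP
# ------------------------------------------------------------------------------
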
def __init__(self, attention_mask_func,
             init_method, output_layer_init_method):
    super(ParallelTransformer, self).__init__()
    args = get_args()

    # Store activation checkpointing flag.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers

    # Number of layers.
    self.num_layers = args.num_layers
    self.num_unique_layers = args.num_unique_layers
    if self.num_unique_layers is None:
        self.num_unique_layers = self.num_layers
    assert self.num_layers % self.num_unique_layers == 0, \
        'number of layers should be divisible by number of unique layers'
    self.param_sharing_style = args.param_sharing_style

    # Transformer layers.
    sparsity = args.sparsity

    def build_layer(layer_number):
        if sparsity == 'none':
            sparse = False
        elif sparsity == 'all':
            sparse = True
        elif sparsity == 'interspersed':
            sparse = not layer_number % 2 == 0
        else:
            # Guard against an unrecognized setting; `sparse` would otherwise
            # be undefined below.
            raise ValueError('unknown sparsity mode: {}'.format(sparsity))
        return ParallelTransformerLayer(
            attention_mask_func, init_method,
            output_layer_init_method, layer_number, sparse=sparse)

    self.layers = torch.nn.ModuleList(
        [build_layer(i + 1) for i in range(self.num_unique_layers)])

    # Print layer ordering.
    if self.num_layers != self.num_unique_layers:
        if torch.distributed.get_rank() == 0:
            print('> will be using the following layer ordering:')
            for i in range(self.num_layers):
                print('   layer id: {:3d} --> unique layer id: '
                      '{:3d}'.format(i, self._get_layer_index(i)),
                      flush=True)

    # Final layer norm before output.
    self.final_layernorm = LayerNorm(args.hidden_size,
                                     eps=args.layernorm_epsilon)

    if deepspeed.checkpointing.is_configured():
        global get_cuda_rng_tracker, checkpoint
        get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
        checkpoint = deepspeed.checkpointing.checkpoint
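
# --- Hypothetical helper (not shown in this excerpt) --------------------------
# The layer-ordering printout above calls self._get_layer_index(i), which maps
# a physical layer index onto one of the num_unique_layers shared parameter
# sets. The real implementation is not included here; the sketch below assumes
# param_sharing_style takes the values 'grouped' (cycle through the unique
# layers: 0, 1, 2, 0, 1, 2, ...) and 'spaced' (repeat each unique layer in a
# contiguous block: 0, 0, 1, 1, 2, 2, ...); both names are assumptions.
def _get_layer_index(self, layer_number):
    if self.param_sharing_style == 'grouped':
        return layer_number % self.num_unique_layers
    if self.param_sharing_style == 'spaced':
        return layer_number // (self.num_layers // self.num_unique_layers)
    raise ValueError(
        'unknown param_sharing_style: {}'.format(self.param_sharing_style))
# ------------------------------------------------------------------------------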