def __init__(
    self,
    embed_dim,
    num_heads,
    kdim=None,
    vdim=None,
    dropout=0.0,
    bias=True,
    self_attention=False,
    encoder_decoder_attention=False,
):
    super().__init__()
    if not has_megatron_submodule:
        raise ImportError(
            '\n\nPlease install the megatron submodule:'
            '\n\n  git submodule update --init '
            'fairseq/model_parallel/megatron'
        )
    self.embed_dim = embed_dim
    self.kdim = kdim if kdim is not None else embed_dim
    self.vdim = vdim if vdim is not None else embed_dim
    self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

    # Attention heads are split evenly across model-parallel workers.
    self.model_parallel_size = get_model_parallel_world_size()
    self.num_heads_partition = num_heads // self.model_parallel_size
    assert (
        self.num_heads_partition * self.model_parallel_size == num_heads
    ), "Number of heads must be divisible by model parallel size"

    self.dropout_module = FairseqDropout(
        dropout, module_name=self.__class__.__name__
    )
    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"
    self.scaling = self.head_dim ** -0.5

    self.self_attention = self_attention
    self.encoder_decoder_attention = encoder_decoder_attention
    assert (
        not self.self_attention or self.qkv_same_dim
    ), "Self-attention requires query, key and value to be of the same size"

    # Q/K/V projections are column-parallel and keep their outputs sharded
    # per worker (gather_output=False); the output projection is row-parallel
    # and reduces the partial results across workers.
    self.k_proj = ColumnParallelLinear(
        self.kdim, embed_dim, bias=bias, gather_output=False
    )
    self.v_proj = ColumnParallelLinear(
        self.vdim, embed_dim, bias=bias, gather_output=False
    )
    self.q_proj = ColumnParallelLinear(
        embed_dim, embed_dim, bias=bias, gather_output=False
    )
    self.out_proj = RowParallelLinear(
        embed_dim, embed_dim, bias=bias, input_is_parallel=True
    )

    self.tpu = False
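# A minimal sketch (not part of this module) of the partitioning arithmetic
# above, assuming embed_dim=1024, num_heads=16 and a model-parallel world size
# of 4. Each worker keeps num_heads_partition heads of size head_dim, so its
# column-parallel Q/K/V projections emit embed_dim // world_size features, and
# the row-parallel out_proj sums the partial outputs back to full width.
def _partition_example(embed_dim=1024, num_heads=16, world_size=4):
    head_dim = embed_dim // num_heads                 # 64
    num_heads_partition = num_heads // world_size     # 4 heads per worker
    local_features = num_heads_partition * head_dim   # 256 features per worker
    assert local_features == embed_dim // world_size
    return num_heads_partition, head_dim, local_features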
def __init__(
    self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
):
    super().__init__()
    # The pooler projection is column-parallel but gathers its output, so the
    # rest of the classification head needs no model-parallel awareness.
    self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
    self.activation_fn = utils.get_activation_fn(activation_fn)
    self.dropout = nn.Dropout(p=pooler_dropout)
    self.out_proj = nn.Linear(inner_dim, num_classes)
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
    super().__init__()
    self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
    self.activation_fn = utils.get_activation_fn(activation_fn)
    self.layer_norm = LayerNorm(embed_dim)

    # Optionally share the output projection weight (e.g. with the token
    # embedding); otherwise create a fresh, untied weight matrix.
    if weight is None:
        weight = nn.Linear(embed_dim, output_dim, bias=False).weight
    self.weight = weight
    self.bias = nn.Parameter(torch.zeros(output_dim))
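# A hedged sketch of how a head like the one above is typically applied. The
# forward pass is not shown in this snippet, so this is an assumption that it
# follows the usual dense -> activation -> layer_norm -> tied-output pattern.
import torch.nn.functional as F

def _lm_head_forward_sketch(head, features):
    x = head.dense(features)      # gather_output=True, so x is full-width
    x = head.activation_fn(x)
    x = head.layer_norm(x)
    # Project back to the vocabulary with the (possibly tied) weight matrix.
    return F.linear(x, head.weight) + head.bias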
def build_fc1(self, input_dim, output_dim, **unused):
    return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
    # Quantization noise is not supported with model-parallel linear layers.
    if q_noise > 0:
        raise NotImplementedError
    return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
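# A hypothetical companion sketch (not in this snippet): because fc1 keeps its
# output sharded (gather_output=False), the matching second FFN layer would
# typically be a RowParallelLinear with input_is_parallel=True, mirroring the
# attention projections above and avoiding an extra gather between the layers.
def build_fc2_sketch(self, input_dim, output_dim, q_noise, qn_block_size):
    if q_noise > 0:
        raise NotImplementedError
    return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)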