Example #1
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (self.num_heads_partition *
                self.model_parallel_size == num_heads
                ), "Number of heads must be divisble by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__)
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        # The k/q/v projections are column-parallel, so each rank computes only
        # its own slice of the projected features (gather_output=False). The
        # output projection is row-parallel and consumes that sharded input
        # directly, reducing the partial results across ranks.
        self.k_proj = ColumnParallelLinear(self.kdim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.v_proj = ColumnParallelLinear(self.vdim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.q_proj = ColumnParallelLinear(embed_dim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True)

        self.tpu = False
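
The projections above shard their output dimension across model-parallel ranks. As a rough illustration of that partitioning (a sketch only, using a plain nn.Linear stand-in rather than the real ColumnParallelLinear, with hypothetical sizes):

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
embed_dim, num_heads, world_size = 16, 4, 2
heads_per_rank = num_heads // world_size   # corresponds to num_heads_partition
local_out = embed_dim // world_size        # output columns owned by one rank

# One rank's shard of q_proj: with gather_output=False, each rank keeps only
# its slice of the projected features and feeds it to its own attention heads.
q_proj_local = nn.Linear(embed_dim, local_out, bias=True)
x = torch.randn(3, embed_dim)
print(q_proj_local(x).shape)               # torch.Size([3, 8])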
Example #2
    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        super().__init__()
        # gather_output=True: callers see the full inner_dim output, so the
        # rest of the head can use ordinary (non-parallel) modules.
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)
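
The forward pass is not part of this snippet; the sketch below shows how such a pooler head is commonly applied. The first-token pooling and the tanh activation are assumptions for illustration, and a plain nn.Linear stands in for ColumnParallelLinear, since gather_output=True makes it behave like an ordinary linear layer from the caller's side:

import torch
import torch.nn as nn

input_dim, inner_dim, num_classes = 8, 8, 3    # hypothetical sizes
dense = nn.Linear(input_dim, inner_dim)
dropout = nn.Dropout(p=0.1)
out_proj = nn.Linear(inner_dim, num_classes)

features = torch.randn(2, 5, input_dim)        # (batch, seq_len, input_dim)
x = dropout(features[:, 0, :])                 # pool the first token's features
x = dropout(torch.tanh(dense(x)))              # activation assumed to be tanh
logits = out_proj(x)                           # (batch, num_classes)
print(logits.shape)                            # torch.Size([2, 3])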
Example #3
    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        # Reuse a provided weight matrix (e.g. tied to the input embeddings);
        # otherwise create an untied output projection weight.
        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))
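
The snippet stores the projection weight and a separate bias but does not show how they are applied. A sketch of the assumed output projection, where weight stands in for the (possibly tied) embedding matrix and the sizes are hypothetical:

import torch
import torch.nn.functional as F

embed_dim, output_dim = 8, 50
weight = torch.randn(output_dim, embed_dim)    # e.g. shared input embedding matrix
bias = torch.zeros(output_dim)

x = torch.randn(2, 4, embed_dim)               # features after dense/activation/layer_norm
logits = F.linear(x, weight) + bias            # project onto the output vocabulary
print(logits.shape)                            # torch.Size([2, 4, 50])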
Example #4
    def build_fc1(self, input_dim, output_dim, **unused):
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
Example #5
    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        # quant_noise is not implemented for the column-parallel layer.
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)
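
A fc1 built with gather_output=False leaves the intermediate activations split across ranks, so it is usually paired with a row-parallel fc2 that consumes them as-is. A sketch of what the matching override might look like (an assumption, not taken from this snippet):

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        # Assumed counterpart to build_fc1: fc2 takes the sharded fc1 output
        # (input_is_parallel=True) and reduces the partial results across ranks.
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)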