Esempio n. 1
0
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        """Model-parallel multi-head attention constructor.

        Splits the attention heads across model-parallel workers: each worker
        owns ``num_heads // model_parallel_size`` heads, with column-parallel
        input projections and a row-parallel output projection.

        Args:
            embed_dim: total embedding dimension of the queries.
            num_heads: total number of attention heads; must be divisible by
                both the model-parallel world size and ``embed_dim``.
            kdim: key dimension; defaults to ``embed_dim`` when ``None``.
            vdim: value dimension; defaults to ``embed_dim`` when ``None``.
            dropout: dropout probability passed to ``FairseqDropout``.
            bias: whether the four projection layers carry a bias term.
            self_attention: if True, requires ``kdim == vdim == embed_dim``.
            encoder_decoder_attention: marks this module as cross-attention.

        Raises:
            ImportError: if the megatron git submodule is not checked out.
        """
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        # Heads are sharded evenly across model-parallel workers.
        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (self.num_heads_partition *
                self.model_parallel_size == num_heads
                ), "Number of heads must be divisible by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__)
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), "embed_dim must be divisible by num_heads"
        # Standard scaled-dot-product attention scaling factor.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and value to be of the same size"
        )

        # Column-parallel projections keep outputs sharded (gather_output=False)
        # so the row-parallel out_proj can consume them without a gather.
        self.k_proj = ColumnParallelLinear(self.kdim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.v_proj = ColumnParallelLinear(self.vdim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.q_proj = ColumnParallelLinear(embed_dim,
                                           embed_dim,
                                           bias=bias,
                                           gather_output=False)
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True)

        self.tpu = False
Esempio n. 2
0
 def build_fc2(self, input_dim, output_dim, **unused):
     """Construct the second fully-connected layer as a row-parallel linear.

     Extra keyword arguments are accepted and ignored so callers that pass
     quantization options can still use this builder.
     """
     fc2 = RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
     return fc2
Esempio n. 3
0
 def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
     """Construct the second fully-connected layer as a row-parallel linear.

     Args:
         input_dim: input feature dimension.
         output_dim: output feature dimension.
         q_noise: quant-noise probability; must be 0, as quantization noise
             is not supported for model-parallel layers.
         qn_block_size: quant-noise block size (unused; kept for signature
             compatibility with the non-parallel builder).

     Raises:
         NotImplementedError: if ``q_noise > 0``.
     """
     if q_noise > 0:
         # Explain the failure instead of raising a bare NotImplementedError.
         raise NotImplementedError(
             "quant_noise is not supported with model-parallel RowParallelLinear"
         )
     return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)