Example #1
    def __init__(self, dictionary, embed_tokens, layers, encoder_config):
        super().__init__()
        self.dropout = encoder_config.dropout

        input_embed_dim = embed_tokens.embedding_dim
        self.padding_idx = dictionary.get_pad_index()
        self.max_source_positions = encoder_config.max_source_positions

        self.embed_scale = math.sqrt(
            input_embed_dim)  # todo: try with input_embed_dim
        self.no_token_positional_embeddings = (
            encoder_config.no_token_positional_embeddings)
        # Project only when the input and encoder embedding widths differ.
        self.project_in_dim = (
            Linear(input_embed_dim, encoder_config.encoder_embed_dim)
            if encoder_config.encoder_embed_dim != input_embed_dim else
            PlaceholderIdentity())
        self.embed_layer_norm = LayerNorm(encoder_config.encoder_embed_dim)

        self.combine_pos_embed = encoder_config.combine_pos_embed.value
        if encoder_config.combine_pos_embed == PostionalEmbedCombine.SUM:
            pos_embed_dim = encoder_config.encoder_embed_dim
        elif encoder_config.combine_pos_embed == PostionalEmbedCombine.CONCAT:
            pos_embed_dim = encoder_config.encoder_embed_dim - input_embed_dim
        else:
            raise NotImplementedError

        if not encoder_config.no_token_positional_embeddings:
            if encoder_config.positional_embedding_type == PostionalEmbedType.LEARNED:
                self.embed_positions = PositionalEmbedding(
                    encoder_config.max_source_positions,
                    pos_embed_dim,
                    self.padding_idx,
                )
            elif (encoder_config.positional_embedding_type
                  == PostionalEmbedType.SINUSOIDAL
                  or encoder_config.positional_embedding_type
                  == PostionalEmbedType.HYBRID):
                self.embed_positions = SinusoidalPositionalEmbedding(
                    pos_embed_dim,
                    self.padding_idx,
                    init_size=encoder_config.max_source_positions,
                    learned_embed=(encoder_config.positional_embedding_type
                                   == PostionalEmbedType.HYBRID),
                )
            else:
                raise NotImplementedError(
                    "Positional embedding type not supported")
        else:
            self.embed_positions = PlaceholderIdentity()

        self.layers = nn.ModuleList(layers)

        self.normalize = encoder_config.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(encoder_config.encoder_embed_dim)
        else:
            self.layer_norm = PlaceholderIdentity()
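
The project_in_dim pattern above only allocates a Linear projection when the token-embedding width differs from encoder_embed_dim; otherwise a pass-through module is used so the forward path can call it unconditionally. A minimal sketch of the same pattern in plain PyTorch (the sizes are illustrative, and nn.Identity stands in for PlaceholderIdentity):

import torch
from torch import nn

input_embed_dim, encoder_embed_dim = 256, 512  # illustrative sizes

# Project only when the widths differ; otherwise keep a no-op module so the
# forward pass can call project_in_dim(x) unconditionally.
project_in_dim = (
    nn.Linear(input_embed_dim, encoder_embed_dim)
    if encoder_embed_dim != input_embed_dim
    else nn.Identity()
)

x = torch.randn(2, 7, input_embed_dim)  # B x T x C
print(project_in_dim(x).shape)          # torch.Size([2, 7, 512])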
Example #2
    def __init__(self, embedding_dim, padding_idx, init_size=124, learned_embed=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        if not learned_embed:
            self.sinusoidal_embedding_dim = embedding_dim
            self.learned_embed = learned_embed
            self.learned_embedding = PlaceholderIdentity()
        else:
            assert embedding_dim % 2 == 0
            self.sinusoidal_embedding_dim = embedding_dim // 2
            self.learned_embedding = nn.Embedding(
                init_size, embedding_dim // 2, padding_idx
            )
            self.learned_embed = learned_embed
        self.weights = nn.Parameter(
            get_sinusoidal_embedding(
                init_size, self.sinusoidal_embedding_dim, padding_idx
            )
        )
        self.weights.requires_grad = False
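
In the learned_embed (hybrid) branch above, embedding_dim is split evenly: half of the channels come from a fixed sinusoidal table and half from a trainable nn.Embedding. The sketch below illustrates that split using the standard transformer sinusoid formula; the exact get_sinusoidal_embedding implementation and the order in which the two halves are combined are not shown in the example, so both are assumptions here.

import math
import torch
from torch import nn

def sinusoidal_table(num_positions, dim):
    # Fixed sinusoidal table of shape (num_positions, dim), standard formula.
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

embedding_dim, init_size, padding_idx = 128, 124, 1  # illustrative values
assert embedding_dim % 2 == 0

fixed_half = sinusoidal_table(init_size, embedding_dim // 2)             # frozen
learned_half = nn.Embedding(init_size, embedding_dim // 2, padding_idx)  # trained

positions = torch.arange(10).unsqueeze(0)  # 1 x T position ids
hybrid = torch.cat([fixed_half[positions], learned_half(positions)], dim=-1)
print(hybrid.shape)  # torch.Size([1, 10, 128])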
Example #3
def build_positional_embedding(
    positional_embedding_type: PostionalEmbedType,
    combine_pos_embed: PostionalEmbedCombine,
    max_target_positions: int,
    input_embed_dim: int,
    embed_dim: int,
    padding_idx: int,
    no_token_positional_embeddings: bool,
):
    if combine_pos_embed == PostionalEmbedCombine.SUM:
        pos_embed_dim = embed_dim
    elif combine_pos_embed == PostionalEmbedCombine.CONCAT:
        pos_embed_dim = embed_dim - input_embed_dim
    else:
        raise NotImplementedError
    if not no_token_positional_embeddings:
        if positional_embedding_type == PostionalEmbedType.LEARNED:
            return PositionalEmbedding(
                max_target_positions,
                pos_embed_dim,
                padding_idx,
            )
        elif (positional_embedding_type == PostionalEmbedType.SINUSOIDAL
              or positional_embedding_type == PostionalEmbedType.HYBRID):
            return SinusoidalPositionalEmbedding(
                pos_embed_dim,
                padding_idx,
                init_size=max_target_positions,
                learned_embed=(positional_embedding_type
                               == PostionalEmbedType.HYBRID),
            )
        else:
            raise NotImplementedError(
                "Positional embedding type not supported")
    else:
        return PlaceholderIdentity()
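
The pos_embed_dim arithmetic above is what keeps the two combination modes consistent with the target width: for SUM the positional table is embed_dim wide, while for CONCAT it only covers the embed_dim - input_embed_dim channels left over after the token embedding, so concatenation lands exactly on embed_dim. A small shape check of that arithmetic with plain tensors (the sizes are illustrative):

import torch

embed_dim, input_embed_dim, batch, seq_len = 512, 384, 2, 7  # illustrative

# CONCAT: positional features fill only the channels the tokens do not cover.
pos_embed_dim = embed_dim - input_embed_dim
token_emb = torch.randn(batch, seq_len, input_embed_dim)
pos_emb = torch.randn(batch, seq_len, pos_embed_dim)
assert torch.cat([token_emb, pos_emb], dim=2).shape[-1] == embed_dim

# SUM: both tensors must already be embed_dim wide, so the token embedding is
# projected first (the project_in_dim modules in the other examples).
token_proj = torch.randn(batch, seq_len, embed_dim)
pos_full = torch.randn(batch, seq_len, embed_dim)
assert (token_proj + pos_full).shape[-1] == embed_dim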
Example #4
    def __init__(self, config: Config, embed_dim: int,
                 padding_idx: Tensor) -> None:
        super().__init__(config)
        self.padding_idx = padding_idx
        self.pooling_type = config.pooling_type
        self.dropout = nn.Dropout(config.encoder_config.dropout)
        input_embed_dim = embed_dim
        self.embed_scale = math.sqrt(
            input_embed_dim)  # todo: try with input_embed_dim
        self.max_source_positions = config.encoder_config.max_source_positions
        self.no_token_positional_embeddings = (
            config.encoder_config.no_token_positional_embeddings)

        # Project only when the input and encoder embedding widths differ.
        self.project_in_dim = (
            Linear(input_embed_dim, config.encoder_config.encoder_embed_dim)
            if config.encoder_config.encoder_embed_dim != input_embed_dim else
            PlaceholderIdentity())

        layers = []
        # Overwrite config.layer_config.encoder_embed_dim so that it always
        # matches config.encoder_config.encoder_embed_dim.
        config.layer_config.encoder_embed_dim = config.encoder_config.encoder_embed_dim
        for size in config.encoder_kernel_size_list:
            layers.append(create_module(config.layer_config, kernel_size=size))

        self.layers = nn.ModuleList(layers)
        self.embed_layer_norm = LayerNorm(
            config.encoder_config.encoder_embed_dim)
        self.combine_pos_embed = config.encoder_config.combine_pos_embed.value

        if config.encoder_config.combine_pos_embed == PostionalEmbedCombine.SUM:
            pos_embed_dim = config.encoder_config.encoder_embed_dim
        elif config.encoder_config.combine_pos_embed == PostionalEmbedCombine.CONCAT:
            pos_embed_dim = config.encoder_config.encoder_embed_dim - input_embed_dim
        else:
            raise NotImplementedError

        if not config.encoder_config.no_token_positional_embeddings:
            if (config.encoder_config.positional_embedding_type ==
                    PostionalEmbedType.LEARNED):
                self.embed_positions = PositionalEmbedding(
                    config.encoder_config.max_source_positions,
                    pos_embed_dim,
                    self.padding_idx,
                )
            elif (config.encoder_config.positional_embedding_type
                  == PostionalEmbedType.SINUSOIDAL
                  or config.encoder_config.positional_embedding_type
                  == PostionalEmbedType.HYBRID):
                self.embed_positions = SinusoidalPositionalEmbedding(
                    pos_embed_dim,
                    self.padding_idx,
                    init_size=config.encoder_config.max_source_positions,
                    learned_embed=(config.encoder_config.positional_embedding_type
                                   == PostionalEmbedType.HYBRID),
                )
            else:
                raise NotImplementedError(
                    "Positional embedding type not supported")
        else:
            self.embed_positions = PlaceholderIdentity()

        self.normalize = config.encoder_config.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(
                config.encoder_config.encoder_embed_dim)
        else:
            self.layer_norm = PlaceholderIdentity()

        log_class_usage(__class__)
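
The layer stack in this example is built by instantiating one convolutional encoder layer per entry in encoder_kernel_size_list, after forcing layer_config.encoder_embed_dim to match the encoder width. The sketch below mirrors that per-kernel-size loop with plain depthwise nn.Conv1d blocks in place of create_module(config.layer_config, ...); the real layers are LightConv-style modules, so this only illustrates the construction pattern.

import torch
from torch import nn

encoder_embed_dim = 256                # illustrative width
encoder_kernel_size_list = [3, 7, 15]  # matches the default kernel-size list in these examples

layers = []
for size in encoder_kernel_size_list:
    # One block per kernel size; padding keeps the sequence length unchanged.
    layers.append(
        nn.Conv1d(encoder_embed_dim, encoder_embed_dim, kernel_size=size,
                  padding=size // 2, groups=encoder_embed_dim))
layers = nn.ModuleList(layers)

x = torch.randn(2, encoder_embed_dim, 9)  # Conv1d expects B x C x T
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([2, 256, 9])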
Example #5
class LightConvEncoder(PyTextSeq2SeqModule, NAREncoderUtility):
    class Config(ModuleConfig):
        encoder_config: ConvEncoderConfig = ConvEncoderConfig()
        layer_config: LightConvEncoderLayer.Config = LightConvEncoderLayer.Config()
        encoder_kernel_size_list: List[int] = [3, 7, 15]
        compression_dim: Optional[int] = 128

    @classmethod
    def from_config(cls, config, src_dict, src_embedding):
        kernel_size_list = config.encoder_kernel_size_list
        layers = []
        # Overwrite config.layer_config.encoder_embed_dim so that it always
        # matches config.encoder_config.encoder_embed_dim.
        config.layer_config.encoder_embed_dim = config.encoder_config.encoder_embed_dim
        for size in kernel_size_list:
            assert (config.encoder_config.encoder_embed_dim ==
                    config.layer_config.encoder_embed_dim)
            layers.append(create_module(config.layer_config, kernel_size=size))
        return cls(src_dict, src_embedding, layers, config.encoder_config)

    def __init__(self, dictionary, embed_tokens, layers, encoder_config):
        super().__init__()
        self.dropout = encoder_config.dropout

        input_embed_dim = embed_tokens.embedding_dim
        self.padding_idx = dictionary.get_pad_index()
        self.max_source_positions = encoder_config.max_source_positions

        self.embed_scale = math.sqrt(
            input_embed_dim)  # todo: try with input_embed_dim
        self.no_token_positional_embeddings = (
            encoder_config.no_token_positional_embeddings)
        # Project only when the input and encoder embedding widths differ.
        self.project_in_dim = (
            Linear(input_embed_dim, encoder_config.encoder_embed_dim)
            if encoder_config.encoder_embed_dim != input_embed_dim else
            PlaceholderIdentity())
        self.embed_layer_norm = LayerNorm(encoder_config.encoder_embed_dim)

        self.combine_pos_embed = encoder_config.combine_pos_embed.value
        if encoder_config.combine_pos_embed == PostionalEmbedCombine.SUM:
            pos_embed_dim = encoder_config.encoder_embed_dim
        elif encoder_config.combine_pos_embed == PostionalEmbedCombine.CONCAT:
            pos_embed_dim = encoder_config.encoder_embed_dim - input_embed_dim
        else:
            raise NotImplementedError

        if not encoder_config.no_token_positional_embeddings:
            if encoder_config.positional_embedding_type == PostionalEmbedType.LEARNED:
                self.embed_positions = PositionalEmbedding(
                    encoder_config.max_source_positions,
                    pos_embed_dim,
                    self.padding_idx,
                )
            elif (encoder_config.positional_embedding_type
                  == PostionalEmbedType.SINUSOIDAL
                  or encoder_config.positional_embedding_type
                  == PostionalEmbedType.HYBRID):
                self.embed_positions = SinusoidalPositionalEmbedding(
                    pos_embed_dim,
                    self.padding_idx,
                    init_size=encoder_config.max_source_positions,
                    learned_embed=(encoder_config.positional_embedding_type
                                   == PostionalEmbedType.HYBRID),
                )
            else:
                raise NotImplementedError(
                    "Positional embedding type not supported")
        else:
            self.embed_positions = PlaceholderIdentity()

        self.layers = nn.ModuleList(layers)

        self.normalize = encoder_config.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(encoder_config.encoder_embed_dim)
        else:
            self.layer_norm = PlaceholderIdentity()

    def forward(self, src_tokens: Tensor, src_embeddings: Tensor,
                src_lengths: Tensor) -> Dict[str, Tensor]:
        output_dict: Dict[str, Tensor] = {}

        # embed tokens and positions
        x = self.embed_scale * src_embeddings
        if not self.no_token_positional_embeddings:
            x = self.pos_embed(x, src_tokens)
        else:
            x = self.project_in_dim(x)

        x = self.embed_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        output_dict["encoder_layer_0"] = x.clone()

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # Compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)  # B x T
        # Setting to None helps us avoid some masking operations later.
        if not encoder_padding_mask.any():
            # Different name is used to avoid some torchscript type checks
            encoder_mask = None
        else:
            encoder_mask = encoder_padding_mask

        # Encoder layers
        for idx, layer in enumerate(self.layers):
            x = layer(x, encoder_mask)
            output_dict["encoder_layer_" + str(idx + 1)] = x.transpose(
                0, 1).clone()

        if self.normalize:
            x = self.layer_norm(x)

        output_dict["src_tokens"] = src_tokens  # B x T
        if src_lengths is not None:
            output_dict["src_lengths"] = src_lengths
        output_dict["encoder_out"] = x  # T x B x C
        if encoder_mask is not None:
            output_dict["encoder_mask"] = encoder_mask  # B x T
        return output_dict

    def reorder_encoder_out(self, encoder_out: Dict[str, Tensor],
                            new_order: Tensor):
        encoder = encoder_out["encoder_out"]
        encoder = encoder.index_select(1, new_order)
        output_dict = {"encoder_out": encoder}

        output_dict["src_tokens"] = encoder_out["src_tokens"].index_select(
            0, new_order)
        padding_mask = encoder_out.get("encoder_mask", None)
        if padding_mask is not None:
            padding_mask = padding_mask.index_select(0, new_order)
            output_dict["encoder_mask"] = padding_mask
        return output_dict

    def pos_embed(self, x, src_tokens):
        if self.combine_pos_embed == PostionalEmbedCombine.SUM.value:
            x = self.project_in_dim(x)
            return self._vanilla_transformer(x, src_tokens)
        elif self.combine_pos_embed == PostionalEmbedCombine.CONCAT.value:
            return self._concat_pos_embed(x, src_tokens)
        else:
            raise NotImplementedError("Method not supported")

    def _vanilla_transformer(self, x, src_tokens):
        x += self.embed_positions(src_tokens)
        return x

    def _concat_pos_embed(self, x, src_tokens):
        pos_embed = self.embed_positions(src_tokens)
        return torch.cat([x, pos_embed], dim=2)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.no_token_positional_embeddings:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())

    def tile_encoder_out(self, tile_size: int,
                         encoder_out: Dict[str, Tensor]) -> Dict[str, Tensor]:
        tiled_out = torch.jit.annotate(Dict[str, Tensor], {})

        x = encoder_out["encoder_out"]
        new_x = x.repeat(1, tile_size, 1)
        tiled_out["encoder_out"] = new_x

        if "encoder_mask" in encoder_out:
            new_encoder_mask = encoder_out["encoder_mask"].repeat(tile_size, 1)
            tiled_out["encoder_mask"] = new_encoder_mask
        if "src_tokens" in encoder_out:
            tiled_out["src_tokens"] = encoder_out["src_tokens"].repeat(
                tile_size, 1)
        if "src_lengths" in encoder_out:
            tiled_out["src_lengths"] = encoder_out["src_lengths"].repeat(
                tile_size, 1)

        return tiled_out

    def extra_repr(self):
        s = "dropout={}, embed_scale={}, normalize={}".format(
            self.dropout, self.embed_scale, self.normalize)
        return s
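
forward in this example returns encoder_out in T x B x C layout while src_tokens stay B x T, which is why reorder_encoder_out and tile_encoder_out operate on different dimensions of the two tensors. A short shape check with dummy tensors (sizes are illustrative):

import torch

T, B, C, tile_size = 5, 2, 8, 3  # illustrative sizes

encoder_out = torch.randn(T, B, C)         # T x B x C, as in output_dict["encoder_out"]
src_tokens = torch.randint(0, 10, (B, T))  # B x T

# reorder_encoder_out: the batch axis is dim 1 of encoder_out but dim 0 of
# src_tokens, hence the different index_select dimensions.
new_order = torch.tensor([1, 0])
reordered = encoder_out.index_select(1, new_order)        # T x B x C
reordered_tokens = src_tokens.index_select(0, new_order)  # B x T

# tile_encoder_out: repeating along the batch axis duplicates every source
# tile_size times, e.g. when each source sentence needs several hypotheses.
tiled = encoder_out.repeat(1, tile_size, 1)     # T x (B * tile_size) x C
tiled_tokens = src_tokens.repeat(tile_size, 1)  # (B * tile_size) x T
print(tiled.shape, tiled_tokens.shape)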
Example #6
class LightConvDecoderBase(PyTextIncrementalDecoderComponent):
    class Config(ModuleConfig):
        decoder_config: ConvDecoderConfig = ConvDecoderConfig()
        layer_config: LightConvDecoderLayer.Config = LightConvDecoderLayer.Config()
        decoder_kernel_size_list: List[int] = [3, 7, 15]

    @classmethod
    def from_config(cls, config, tgt_dict, tgt_embedding):
        kernel_size_list = config.decoder_kernel_size_list
        layers = []
        for size in kernel_size_list:
            assert (config.decoder_config.decoder_embed_dim ==
                    config.layer_config.decoder_embed_dim)
            layers.append(create_module(config.layer_config, kernel_size=size))
        return cls(tgt_dict, tgt_embedding, layers, config.decoder_config)

    def __init__(self, target_dict, embed_tokens, layers, decoder_config):
        super().__init__()
        self.dropout = decoder_config.dropout

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = decoder_config.decoder_embed_dim
        output_embed_dim = decoder_config.decoder_output_dim

        padding_idx = target_dict.get_pad_index()
        self.max_target_positions = decoder_config.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(
            embed_dim)  # todo: try with input_embed_dim
        self.padding_idx = padding_idx

        self.no_token_positional_embeddings = (
            decoder_config.no_token_positional_embeddings)
        # Project only when the input and decoder embedding widths differ.
        self.project_in_dim = (Linear(input_embed_dim, embed_dim)
                               if embed_dim != input_embed_dim else
                               PlaceholderIdentity())
        self.embed_layer_norm = LayerNorm(embed_dim)
        self.combine_pos_embed = decoder_config.combine_pos_embed.value
        if decoder_config.combine_pos_embed == PostionalEmbedCombine.SUM:
            pos_embed_dim = embed_dim
        elif decoder_config.combine_pos_embed == PostionalEmbedCombine.CONCAT:
            pos_embed_dim = embed_dim - input_embed_dim
        else:
            raise NotImplementedError
        if not decoder_config.no_token_positional_embeddings:
            if decoder_config.positional_embedding_type == PostionalEmbedType.LEARNED:
                self.embed_positions = PositionalEmbedding(
                    decoder_config.max_target_positions,
                    pos_embed_dim,
                    padding_idx,
                )
            elif (decoder_config.positional_embedding_type
                  == PostionalEmbedType.SINUSOIDAL
                  or decoder_config.positional_embedding_type
                  == PostionalEmbedType.HYBRID):
                self.embed_positions = SinusoidalPositionalEmbedding(
                    pos_embed_dim,
                    padding_idx,
                    init_size=decoder_config.max_target_positions,
                    learned_embed=(decoder_config.positional_embedding_type
                                   == PostionalEmbedType.HYBRID),
                )
            else:
                raise NotImplementedError(
                    "Positional embedding type not supported")
        else:
            self.embed_positions = PlaceholderIdentity()

        self.layers = nn.ModuleList(layers)

        self.project_out_dim = (Linear(embed_dim, output_embed_dim, bias=False)
                                if embed_dim != output_embed_dim else
                                PlaceholderIdentity())

        self.normalize = decoder_config.decoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = PlaceholderIdentity()

    def forward_unprojected(
        self,
        prev_output_tokens: Tensor,
        encoder_out: Dict[str, Tensor],
        incremental_state: Optional[Dict[str, Tensor]] = None,
        timestep: Optional[int] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        output_dict: Dict[str, Tensor] = {}
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens([[prev_output_tokens]])

        if not self.no_token_positional_embeddings:
            # TODO : Verify incremental generation for AR mode
            x = self.pos_embed(x, prev_output_tokens)
        else:
            x = self.project_in_dim(x)

        x = self.embed_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        output_dict["decoder_layer_0"] = x.clone()

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        last_layer_attn: Optional[Tensor] = None

        decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
        target_lengths = (~decoder_padding_mask).sum(dim=1)

        if not decoder_padding_mask.any():
            decoder_mask = None
        else:
            decoder_mask = decoder_padding_mask

        encoder = encoder_out["encoder_out"]
        encoder_mask: Optional[Tensor] = None
        if "encoder_mask" in encoder_out:
            encoder_mask = encoder_out["encoder_mask"]

        # decoder layers
        for idx, layer in enumerate(self.layers):
            encoder = encoder_out["encoder_out"]
            encoder_mask: Optional[Tensor] = None
            if "encoder_mask" in encoder_out:
                encoder_mask = encoder_out["encoder_mask"]
            x, last_layer_attn = layer(x, encoder, encoder_mask, decoder_mask,
                                       incremental_state)
            output_dict["decoder_layer_" + str(idx + 1)] = x.transpose(
                0, 1).clone()

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        x = self.project_out_dim(x)

        if last_layer_attn is not None:
            output_dict["attn_scores"] = last_layer_attn
        output_dict["target_lengths"] = target_lengths
        output_dict["decoder_mask"] = decoder_padding_mask

        for key in encoder_out.keys():
            output_dict[key] = encoder_out[key]

        return x, output_dict

    def pos_embed(self, x, src_tokens):
        # TODO : Positional embeddings needs to be tested in AR mode
        if self.combine_pos_embed == PostionalEmbedCombine.SUM.value:
            x = self.project_in_dim(x)
            return self._vanilla_transformer(x, src_tokens)
        elif self.combine_pos_embed == PostionalEmbedCombine.CONCAT.value:
            return self._concat_pos_embed(x, src_tokens)
        else:
            raise NotImplementedError("Method not supported")

    def _vanilla_transformer(self, x, src_tokens):
        x += self.embed_positions(src_tokens)
        return x

    def _concat_pos_embed(self, x, src_tokens):
        pos_embed = self.embed_positions(src_tokens)
        return torch.cat([x, pos_embed], dim=2)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.no_token_positional_embeddings:
            return self.max_target_positions
        return min(self.max_target_positions,
                   self.embed_positions.max_positions())

    def reorder_incremental_state(self, incremental_state: Dict[str, Tensor],
                                  new_order: Tensor):
        for layer in self.layers:
            layer.reorder_incremental_state(incremental_state, new_order)

    def get_probs(
        self, decoder_out: Tuple[Tensor, Dict[str, Tensor]]
    ) -> Tuple[Tensor, Tensor, Tensor]:
        return self.projection_layer.get_probs(decoder_out)
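
Two details of forward_unprojected above are easy to miss: target lengths are recovered by counting non-pad positions in prev_output_tokens, and when incremental_state is provided only the most recent target token is embedded (the earlier positions live in the per-layer cache, which is handled inside the conv layers and not shown here). A minimal illustration with dummy token ids (values are illustrative):

import torch

padding_idx = 1  # illustrative pad id
prev_output_tokens = torch.tensor([[2, 14, 7, 9, 1, 1],
                                   [2, 31, 5, 9, 6, 1]])  # B x T

# Non-pad counts per sequence, as computed in forward_unprojected.
decoder_padding_mask = prev_output_tokens.eq(padding_idx)
target_lengths = (~decoder_padding_mask).sum(dim=1)
print(target_lengths)  # tensor([4, 5])

# Incremental mode: only the newest token per sequence is fed at each step.
incremental_state = {}  # a non-None dict signals step-by-step generation
if incremental_state is not None:
    last_step = prev_output_tokens[:, -1:]
print(last_step.shape)  # torch.Size([2, 1])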
Example #7
class LightConvRepresentation(RepresentationBase):
    """CNN based representation of a document."""
    class Config(RepresentationBase.Config):
        encoder_config: ConvEncoderConfig = ConvEncoderConfig()
        layer_config: LightConvEncoderLayer.Config = LightConvEncoderLayer.Config()
        encoder_kernel_size_list: List[int] = [3, 7, 15]
        pooling_type: str = "mean"

    def __init__(self, config: Config, embed_dim: int,
                 padding_idx: Tensor) -> None:
        super().__init__(config)
        self.padding_idx = padding_idx
        self.pooling_type = config.pooling_type
        self.dropout = nn.Dropout(config.encoder_config.dropout)
        input_embed_dim = embed_dim
        self.embed_scale = math.sqrt(
            input_embed_dim)  # todo: try with input_embed_dim
        self.max_source_positions = config.encoder_config.max_source_positions
        self.no_token_positional_embeddings = (
            config.encoder_config.no_token_positional_embeddings)

        # Project only when the input and encoder embedding widths differ.
        self.project_in_dim = (
            Linear(input_embed_dim, config.encoder_config.encoder_embed_dim)
            if config.encoder_config.encoder_embed_dim != input_embed_dim else
            PlaceholderIdentity())

        layers = []
        # Overwrite config.layer_config.encoder_embed_dim so that it always
        # matches config.encoder_config.encoder_embed_dim.
        config.layer_config.encoder_embed_dim = config.encoder_config.encoder_embed_dim
        for size in config.encoder_kernel_size_list:
            layers.append(create_module(config.layer_config, kernel_size=size))

        self.layers = nn.ModuleList(layers)
        self.embed_layer_norm = LayerNorm(
            config.encoder_config.encoder_embed_dim)
        self.combine_pos_embed = config.encoder_config.combine_pos_embed.value

        if config.encoder_config.combine_pos_embed == PostionalEmbedCombine.SUM:
            pos_embed_dim = config.encoder_config.encoder_embed_dim
        elif config.encoder_config.combine_pos_embed == PostionalEmbedCombine.CONCAT:
            pos_embed_dim = config.encoder_config.encoder_embed_dim - input_embed_dim
        else:
            raise NotImplementedError

        if not config.encoder_config.no_token_positional_embeddings:
            if (config.encoder_config.positional_embedding_type ==
                    PostionalEmbedType.LEARNED):
                self.embed_positions = PositionalEmbedding(
                    config.encoder_config.max_source_positions,
                    pos_embed_dim,
                    self.padding_idx,
                )
            elif (config.encoder_config.positional_embedding_type
                  == PostionalEmbedType.SINUSOIDAL
                  or config.encoder_config.positional_embedding_type
                  == PostionalEmbedType.HYBRID):
                self.embed_positions = SinusoidalPositionalEmbedding(
                    pos_embed_dim,
                    self.padding_idx,
                    init_size=config.encoder_config.max_source_positions,
                    learned_embed=(config.encoder_config.positional_embedding_type
                                   == PostionalEmbedType.HYBRID),
                )
            else:
                raise NotImplementedError(
                    "Positional embedding type not supported")
        else:
            self.embed_positions = PlaceholderIdentity()

        self.normalize = config.encoder_config.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(
                config.encoder_config.encoder_embed_dim)
        else:
            self.layer_norm = PlaceholderIdentity()

        log_class_usage(__class__)

    def forward(self, embedded_tokens: torch.Tensor, src_tokens: Tensor,
                src_lengths: Tensor) -> torch.Tensor:

        x = self.embed_scale * embedded_tokens
        if not self.no_token_positional_embeddings:
            x = self.pos_embed(x, src_tokens)
        else:
            x = self.project_in_dim(x)

        x = self.embed_layer_norm(x)
        x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # Compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)  # B x T
        # Setting to None helps us avoid some masking operations later.
        if not encoder_padding_mask.any():
            # Different name is used to avoid some torchscript type checks
            encoder_mask = None
        else:
            encoder_mask = encoder_padding_mask

        # Encoder layers
        for _, layer in enumerate(self.layers):
            x = layer(x, encoder_mask)

        if self.normalize:
            x = self.layer_norm(x)

        x = pool(self.pooling_type, x.transpose(0, 1), encoder_padding_mask)
        return x

    def reorder_encoder_out(self, encoder_out: Dict[str, Tensor],
                            new_order: Tensor):
        encoder = encoder_out["encoder_out"]
        encoder = encoder.index_select(1, new_order)
        output_dict = {"encoder_out": encoder}

        output_dict["src_tokens"] = encoder_out["src_tokens"].index_select(
            0, new_order)
        padding_mask = encoder_out.get("encoder_mask", None)
        if padding_mask is not None:
            padding_mask = padding_mask.index_select(0, new_order)
            output_dict["encoder_mask"] = padding_mask
        return output_dict

    def pos_embed(self, x, src_tokens):
        if self.combine_pos_embed == PostionalEmbedCombine.SUM.value:
            x = self.project_in_dim(x)
            return self._vanilla_transformer(x, src_tokens)
        elif self.combine_pos_embed == PostionalEmbedCombine.CONCAT.value:
            return self._concat_pos_embed(x, src_tokens)
        else:
            raise NotImplementedError("Method not supported")

    def _vanilla_transformer(self, x, src_tokens):
        x += self.embed_positions(src_tokens)
        return x

    def _concat_pos_embed(self, x, src_tokens):
        pos_embed = self.embed_positions(src_tokens)
        return torch.cat([x, pos_embed], dim=2)

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.no_token_positional_embeddings:
            return self.max_source_positions
        return min(self.max_source_positions,
                   self.embed_positions.max_positions())

    def tile_encoder_out(self, tile_size: int,
                         encoder_out: Dict[str, Tensor]) -> Dict[str, Tensor]:
        tiled_out = torch.jit.annotate(Dict[str, Tensor], {})

        x = encoder_out["encoder_out"]
        new_x = x.repeat(1, tile_size, 1)
        tiled_out["encoder_out"] = new_x

        if "encoder_mask" in encoder_out:
            new_encoder_mask = encoder_out["encoder_mask"].repeat(tile_size, 1)
            tiled_out["encoder_mask"] = new_encoder_mask
        if "src_tokens" in encoder_out:
            tiled_out["src_tokens"] = encoder_out["src_tokens"].repeat(
                tile_size, 1)
        if "src_lengths" in encoder_out:
            tiled_out["src_lengths"] = encoder_out["src_lengths"].repeat(
                tile_size, 1)

        return tiled_out

    def extra_repr(self):
        s = "dropout={}, embed_scale={}, normalize={}".format(
            self.dropout, self.embed_scale, self.normalize)
        return s
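
Unlike the seq2seq encoder, this representation ends by pooling over the time dimension, returning one vector per document. The pool helper itself is not shown in the example; the sketch below is a rough stand-in that assumes the "mean" pooling type averages only the non-padded positions:

import torch

B, T, C, padding_idx = 2, 5, 8, 1  # illustrative sizes and pad id
x = torch.randn(B, T, C)           # B x T x C, i.e. x.transpose(0, 1) above
src_tokens = torch.tensor([[4, 9, 3, 1, 1],
                           [7, 2, 8, 6, 5]])
encoder_padding_mask = src_tokens.eq(padding_idx)  # B x T, True at pads

# Masked mean over time: zero padded positions, divide by the true lengths.
keep = (~encoder_padding_mask).unsqueeze(-1).float()             # B x T x 1
pooled = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)  # B x C
print(pooled.shape)  # torch.Size([2, 8])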