Example #1
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        encoder_normalize_before: bool = False,
        use_bert_layer_norm: bool = False,
        use_gelu: bool = True,
        apply_bert_init: bool = False,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init

        self.embed_tokens = nn.Embedding(self.vocab_size, self.embedding_dim,
                                         self.padding_idx)

        self.segment_embeddings = (nn.Embedding(
            self.num_segments, self.embedding_dim, self.padding_idx)
                                   if self.num_segments > 0 else None)

        self.embed_positions = (PositionalEmbedding(
            self.max_seq_len,
            self.embedding_dim,
            self.padding_idx,
        ) if self.use_position_embeddings else None)

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                encoder_normalize_before=encoder_normalize_before,
                use_bert_layer_norm=use_bert_layer_norm,
                use_gelu=use_gelu,
            ) for _ in range(num_encoder_layers)
        ])

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)
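
A minimal usage sketch for Example #1, assuming the enclosing class is a variant of fairseq's TransformerSentenceEncoder and that, as in fairseq, its forward returns the per-layer inner states plus a sentence representation. The import path and all concrete values below are illustrative assumptions, not taken from the snippet:

import torch
from fairseq.modules import TransformerSentenceEncoder  # assumed import path

encoder = TransformerSentenceEncoder(
    padding_idx=1,            # illustrative values
    vocab_size=32000,
    num_encoder_layers=6,
    embedding_dim=768,
    ffn_embedding_dim=3072,
    num_attention_heads=8,
)

tokens = torch.randint(2, 32000, (4, 128))    # (batch, seq_len) of token ids
inner_states, sentence_rep = encoder(tokens)  # fairseq-style return values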
Example #2
    def build_transformer_sentence_encoder_layer(
        self,
        embedding_dim,
        ffn_embedding_dim,
        num_attention_heads,
        dropout,
        attention_dropout,
        activation_dropout,
        activation_fn,
        export,
        q_noise,
        qn_block_size,
    ):
        return TransformerSentenceEncoderLayer(
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            activation_fn=activation_fn,
            export=export,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
Example #3
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        # num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = True,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        use_residual: bool = True,
        use_norm: bool = True,
        use_pretrain: bool = False,
        pretrain_vectors=None,
        pretrain_dim: int = 200,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        # self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.use_pretrain = use_pretrain
        self.pretrain_dim = pretrain_dim

        assert self.embedding_dim % num_attention_heads == 0

        # word embedding:
        if self.use_pretrain:
            self.embed_tokens = nn.Embedding(self.vocab_size,
                                             self.pretrain_dim,
                                             self.padding_idx)
        else:
            self.embed_tokens = nn.Embedding(self.vocab_size,
                                             self.embedding_dim,
                                             self.padding_idx)

        if self.use_pretrain and self.pretrain_dim % num_attention_heads != 0:
            self.pretrain_emb_transfer = nn.Linear(self.pretrain_dim,
                                                   self.embedding_dim)
            print("Use the pre-trained embedding transfer layer")
        else:
            self.pretrain_emb_transfer = None

        self.embed_scale = embed_scale

        # position embedding
        self.embed_positions = (PositionalEmbedding(
            self.max_seq_len,
            self.embedding_dim,
            padding_idx=(
                self.padding_idx if offset_positions_by_padding else None),
            learned=self.learned_pos_embedding,
        ) if self.use_position_embeddings else None)
        print(f'position embedding: {self.embed_positions}')

        # Transformer Layers:
        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                export=export,
                use_residual=use_residual,
                use_norm=use_norm,
            ) for _ in range(num_encoder_layers)
        ])

        # Layer Norm:
        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)
            # self.embed_tokens.weight.data.normal_(mean=0.0, std=0.02)
            # if use_position_embeddings:
            #     self.embed_positions.weight.data.normal_(mean=0.0, std=0.02)

        if self.use_pretrain:
            # self.embed_tokens.from_pretrained(pretrain_vectors, freeze=freeze_embeddings)
            self.embed_tokens.weight.data.copy_(pretrain_vectors)
            self.embed_tokens.weight.requires_grad = not freeze_embeddings
            print(f'Using the pre-trained embeddings (freeze={freeze_embeddings}):')
            # print(pretrain_vectors)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings and use_pretrain:
            freeze_module_params(self.embed_tokens)
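
The pretrained-embedding branch in Example #3 copies pretrain_vectors into the token embedding weights, so it expects a float tensor of shape (vocab_size, pretrain_dim). A rough sketch of wiring that up; the class name Encoder and the concrete sizes are placeholders, not taken from the snippet:

import torch

vocab_size, pretrain_dim = 30000, 200
pretrain_vectors = torch.zeros(vocab_size, pretrain_dim)
# rows are filled from whatever pretrained lookup is available, e.g.
# pretrain_vectors[token_id] = torch.from_numpy(word_vec); rows left at
# zero (special symbols, OOV) are simply trained from scratch

encoder = Encoder(                  # placeholder name for the class above
    padding_idx=1,
    vocab_size=vocab_size,
    use_pretrain=True,
    pretrain_vectors=pretrain_vectors,
    pretrain_dim=pretrain_dim,
    freeze_embeddings=False,
)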
Example #4
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable

        self.embed_tokens = nn.Embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )
        self.embed_scale = embed_scale

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
            if self.num_segments > 0
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                learned=self.learned_pos_embedding,
            )
            if self.use_position_embeddings
            else None
        )

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    add_bias_kv=add_bias_kv,
                    add_zero_attn=add_zero_attn,
                    export=export,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
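
Example #4 takes a layerdrop probability, which (as in fairseq) is applied in the forward pass rather than in the constructor. A minimal sketch of the LayerDrop idea, assuming each layer maps a tensor to a tensor of the same shape:

import torch

def run_layers_with_layerdrop(x, layers, layerdrop, training):
    # LayerDrop: during training, skip each layer independently with
    # probability `layerdrop`; at inference, run every layer
    for layer in layers:
        if training and torch.rand(1).item() < layerdrop:
            continue
        x = layer(x)
    return x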
Example #5
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = 'relu',
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding

        self.embed_tokens = nn.Embedding(
            self.vocab_size,
            self.embedding_dim,
            self.padding_idx,
        )
        self.embed_scale = embed_scale

        self.segment_embeddings = (nn.Embedding(
            self.num_segments, self.embedding_dim, padding_idx=None)
                                   if self.num_segments > 0 else None)

        self.embed_positions = (PositionalEmbedding(
            self.max_seq_len,
            self.embedding_dim,
            self.padding_idx,
            learned=self.learned_pos_embedding,
        ) if self.use_position_embeddings else None)

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
            ) for _ in range(num_encoder_layers)
        ])

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim)
        else:
            self.emb_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)
Example #6
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        layerdrop: float = 0.0,
        max_seq_len: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
        emb_weights_path: str = None,
        count_bins: int = 100,
        use_counts: bool = True,
        input_format: str = 'embs'
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable
        self.count_bins = count_bins
        self.use_counts = use_counts
        self.input_format = input_format

        self.embed_tokens = nn.Embedding(
            self.vocab_size, self.embedding_dim, self.padding_idx
        )
        self.embed_scale = embed_scale

        self.segment_embeddings = (
            nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None)
            if self.num_segments > 0
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(self.padding_idx if offset_positions_by_padding else None),
                learned=self.learned_pos_embedding,
            )
            if self.use_position_embeddings
            else None
        )

        self.embed_counts = (
            nn.Embedding(self.count_bins+1, self.embedding_dim, padding_idx=None)
            if self.use_counts
            else None
        )

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    add_bias_kv=add_bias_kv,
                    add_zero_attn=add_zero_attn,
                    export=export,
                )
                for _ in range(num_encoder_layers)
            ]
        )

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        # if emb_weights_path:
        #     print('loading pretrained token embs.')
        #     with open(emb_weights_path, 'rb') as f:
        #         emb_weights = pickle.load(f)
        #     # QUICK HACK (4 special symbols in beginning and mask at the end)
        #     mycopy = copy.copy(self.embed_tokens.weight.data.detach().numpy())
        #     mycopy[4:-1] = emb_weights
        #     self.embed_tokens.weight.data.copy_(torch.from_numpy(mycopy*0))
        #      # set to zero for masked elements
        #     if self.embed_tokens.padding_idx is not None:
        #         self.embed_tokens.weight.data[self.embed_tokens.padding_idx].zero_()

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            print('freezing token embs.')
            freeze_module_params(self.embed_tokens)
            # freeze_module_params(self.segment_embeddings)
            # freeze_module_params(self.embed_positions)
            # freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
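
Example #6 adds an embed_counts table with count_bins + 1 rows. How raw counts get mapped onto those bins is not shown; one simple assumption is to clamp counts at count_bins before the lookup:

import torch

count_bins = 100
counts = torch.tensor([[0, 3, 17, 250]])   # raw per-token counts
bins = counts.clamp(max=count_bins)        # indices in [0, count_bins]
# count_embs = self.embed_counts(bins)     # (batch, seq_len, embedding_dim)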
Example #7
    def __init__(self,
                 padding_idx_dict: dict,
                 vocab_size_dict: dict,
                 num_encoder_layers: int = 6,
                 embedding_dim: int = 768,
                 ffn_embedding_dim: int = 3072,
                 num_attention_heads: int = 8,
                 dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 activation_dropout: float = 0.1,
                 layerdrop: float = 0.0,
                 max_seq_len: int = 256,
                 num_segments: int = 2,
                 encoder_normalize_before: bool = False,
                 apply_bert_init: bool = False,
                 activation_fn: str = "relu",
                 learned_pos_embedding: bool = True,
                 add_bias_kv: bool = False,
                 add_zero_attn: bool = False,
                 embed_scale: float = None,
                 freeze_embeddings: bool = False,
                 n_trans_layers_to_freeze: int = 0,
                 export: bool = False,
                 traceable: bool = False,
                 input_combine: str = 'sum') -> None:

        super().__init__()
        self.padding_idx_dict = padding_idx_dict
        self.vocab_size_dict = vocab_size_dict
        self.dropout = dropout
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable
        self.fields = configs.fields

        self.token_emb_dict = nn.ModuleDict({
            field:
            nn.Embedding(self.vocab_size_dict[field], self.embedding_dim,
                         self.padding_idx_dict[field])
            for field in configs.fields[:configs.byte_start_pos]
        })
        self.byte_emb = nn.Embedding(
            self.vocab_size_dict[configs.fields[configs.byte_start_pos]],
            self.embedding_dim,
            self.padding_idx_dict[configs.fields[configs.byte_start_pos]])
        self.input_combine = input_combine

        if input_combine == 'MLP':
            # concatenate all byte embeddings, project back to embedding_dim,
            # and apply a ReLU non-linearity
            self.input_combine_layer = nn.Sequential(
                nn.Linear(
                    self.embedding_dim *
                    len(self.fields[configs.byte_start_pos:]),
                    self.embedding_dim),
                nn.ReLU(inplace=True))

        elif input_combine == 'birnn':
            self.input_combine_layer = nn.LSTM(self.embedding_dim,
                                               self.embedding_dim,
                                               1,
                                               batch_first=True,
                                               bidirectional=True)
            self.lstm_fc = nn.Linear(self.embedding_dim * 2,
                                     self.embedding_dim)  # 2 for bidirection

        self.embed_scale = embed_scale

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn,
                export=export,
            ) for _ in range(num_encoder_layers)
        ])

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            # freeze the per-field token embeddings and the byte embedding
            # (this class stores them in token_emb_dict/byte_emb, not embed_tokens)
            for emb in self.token_emb_dict.values():
                freeze_module_params(emb)
            freeze_module_params(self.byte_emb)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
Example #8
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 256,
        encoder_normalize_before: bool = False,
        embedding_normalize: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        embed_scale: float = None,
        rel_pos: bool = False,
        rel_pos_bins: int = 32,
        max_rel_pos: int = 128,
        export: bool = False,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.apply_bert_init = apply_bert_init
        self.embed_tokens = nn.Embedding(self.vocab_size, self.embedding_dim,
                                         self.padding_idx)
        self.embed_scale = embed_scale

        self.attn_scale_factor = 2
        self.num_attention_heads = num_attention_heads
        self.pos = nn.Embedding(self.max_seq_len + 1, self.embedding_dim)
        self.pos_q_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.pos_k_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.pos_scaling = float(self.embedding_dim / num_attention_heads *
                                 self.attn_scale_factor)**-0.5
        self.pos_ln = LayerNorm(self.embedding_dim, export=export)
        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                attn_scale_factor=self.attn_scale_factor,
                export=export,
                encoder_normalize_before=encoder_normalize_before,
            ) for _ in range(num_encoder_layers)
        ])

        if embedding_normalize:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if encoder_normalize_before:
            self.emb_out_layer_norm = LayerNorm(self.embedding_dim,
                                                export=export)
        else:
            self.emb_out_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        self.rel_pos = rel_pos
        if self.rel_pos:
            assert rel_pos_bins % 2 == 0
            self.rel_pos_bins = rel_pos_bins
            self.max_rel_pos = max_rel_pos
            self.relative_attention_bias = nn.Embedding(
                self.rel_pos_bins + 1, self.num_attention_heads)
            seq_len = self.max_seq_len
            context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
            memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
            relative_position = memory_position - context_position
            self.rp_bucket = relative_position_bucket(
                relative_position,
                num_buckets=self.rel_pos_bins,
                max_distance=self.max_rel_pos)
            # others to [CLS]
            self.rp_bucket[:, 0] = self.rel_pos_bins
            # [CLS] to others, Note: self.rel_pos_bins // 2 is not used in relative_position_bucket
            self.rp_bucket[0, :] = self.rel_pos_bins // 2
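
Example #8 calls a relative_position_bucket helper that is not included in the snippet. Below is a sketch of the usual T5-style bidirectional bucketing it presumably refers to (one half of the buckets encodes the sign, small offsets get exact buckets, larger ones are log-spaced up to max_distance); treat this as an assumption about the helper, not code from the same source:

import math
import torch

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    # sign of the offset takes one half of the buckets
    ret = (relative_position > 0).long() * (num_buckets // 2)
    num_buckets //= 2
    n = torch.abs(relative_position)
    # exact buckets for small offsets, log-spaced buckets up to max_distance
    max_exact = num_buckets // 2
    is_small = n < max_exact
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) *
        (num_buckets - max_exact)).long()
    val_if_large = torch.min(
        val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    return ret + torch.where(is_small, n, val_if_large)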
Example #9
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 24,
        embedding_dim: int = 1024,
        ffn_embedding_dim: int = 4096,
        num_attention_heads: int = 16,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        layerdrop: float = 0.0,
        max_seq_len: int = 512,
        num_segments: int = 0,
        use_position_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = True,
        apply_bert_init: bool = True,
        activation_fn: str = "gelu",
        learned_pos_embedding: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        traceable: bool = False,
        q_noise: float = 0.0,
        qn_block_size: int = 8,
    ):

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.layerdrop = layerdrop
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding
        self.traceable = traceable
        self.num_encoder_layers = num_encoder_layers
        self.num_attention_heads = num_attention_heads

        self.embed_tokens = nn.Embedding(self.vocab_size, self.embedding_dim,
                                         self.padding_idx)

        self.embed_scale = embed_scale

        if q_noise > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
                q_noise,
                qn_block_size,
            )
        else:
            self.quant_noise = None

        self.segment_embeddings = (nn.Embedding(
            self.num_segments, self.embedding_dim, padding_idx=None)
                                   if self.num_segments > 0 else None)

        self.embed_positions = (
            PositionalEmbedding(
                self.max_seq_len,
                self.embedding_dim,
                padding_idx=(
                    self.padding_idx if offset_positions_by_padding else None),
                #padding_idx=None,
                learned=self.learned_pos_embedding,
            ) if self.use_position_embeddings else None)

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                # add_bias_kv=add_bias_kv,
                # add_zero_attn=add_zero_attn,
                q_noise=q_noise,
                qn_block_size=qn_block_size,
                export=export,
            ) for _ in range(self.num_encoder_layers)
        ])
        #self.roberta = torch.hub.load('pytorch/fairseq', load_model)
        # self.roberta = RobertaModel.from_pretrained('model/roberta.base/',checkpoint_file='model.pt')
        # self.roberta=RobertaModel()
        # print(self.roberta.encode('Hello world!'))

        #self.score = nn.Linear(embedding_dim*2, 1, bias=True)

        self.score2 = nn.Sequential(
            nn.Linear(embedding_dim * 2, 200, bias=True), nn.Tanh())

        self.score3 = nn.Linear(200, 1, bias=True)

        if encoder_normalize_before:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            freeze_module_params(self.embed_tokens)
            freeze_module_params(self.segment_embeddings)
            freeze_module_params(self.embed_positions)
            freeze_module_params(self.emb_layer_norm)

        for layer in range(n_trans_layers_to_freeze):
            freeze_module_params(self.layers[layer])
Example #10
    def __init__(self, args, task):
        super(BertRanker, self).__init__(args, task)

        init_model = getattr(args, "pretrained_model", "")
        self.joint_layers = nn.ModuleList()
        if os.path.isfile(init_model):
            print(f"initialize weight from {init_model}")

            from fairseq import hub_utils

            x = hub_utils.from_pretrained(
                os.path.dirname(init_model),
                checkpoint_file=os.path.basename(init_model),
            )

            in_state_dict = x["models"][0].state_dict()
            init_args = x["args"].model

            num_positional_emb = (init_args.max_positions +
                                  task.dictionary.pad() + 1)

            # follow the setup in roberta
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=getattr(args, "encoder_layers",
                                           init_args.encoder_layers),
                embedding_dim=init_args.encoder_embed_dim,
                ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
                num_attention_heads=init_args.encoder_attention_heads,
                dropout=init_args.dropout,
                attention_dropout=init_args.attention_dropout,
                activation_dropout=init_args.activation_dropout,
                num_segments=2,  # add language embeddings
                max_seq_len=num_positional_emb,
                offset_positions_by_padding=False,
                encoder_normalize_before=True,
                apply_bert_init=True,
                activation_fn=init_args.activation_fn,
                freeze_embeddings=args.freeze_embeddings,
                n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
            )

            # still need to learn segment embeddings as we added a second language embedding
            if args.freeze_embeddings:
                for p in self.model.segment_embeddings.parameters():
                    p.requires_grad = False

            update_init_roberta_model_state(in_state_dict)
            print("loading weights from the pretrained model")
            self.model.load_state_dict(
                in_state_dict,
                strict=False)  # ignore mismatch in language embeddings

            ffn_embedding_dim = init_args.encoder_ffn_embed_dim
            num_attention_heads = init_args.encoder_attention_heads
            dropout = init_args.dropout
            attention_dropout = init_args.attention_dropout
            activation_dropout = init_args.activation_dropout
            activation_fn = init_args.activation_fn

            classifier_embed_dim = getattr(args, "embed_dim",
                                           init_args.encoder_embed_dim)
            if classifier_embed_dim != init_args.encoder_embed_dim:
                self.transform_layer = nn.Linear(init_args.encoder_embed_dim,
                                                 classifier_embed_dim)
        else:
            self.model = TransformerSentenceEncoder(
                padding_idx=task.dictionary.pad(),
                vocab_size=len(task.dictionary),
                num_encoder_layers=args.encoder_layers,
                embedding_dim=args.embed_dim,
                ffn_embedding_dim=args.ffn_embed_dim,
                num_attention_heads=args.attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                max_seq_len=task.max_positions()
                if task.max_positions() else args.tokens_per_sample,
                num_segments=2,
                offset_positions_by_padding=False,
                encoder_normalize_before=args.encoder_normalize_before,
                apply_bert_init=args.apply_bert_init,
                activation_fn=args.activation_fn,
            )

            classifier_embed_dim = args.embed_dim
            ffn_embedding_dim = args.ffn_embed_dim
            num_attention_heads = args.attention_heads
            dropout = args.dropout
            attention_dropout = args.attention_dropout
            activation_dropout = args.activation_dropout
            activation_fn = args.activation_fn

        self.joint_classification = args.joint_classification
        if args.joint_classification == "sent":
            if args.joint_normalize_before:
                self.joint_layer_norm = LayerNorm(classifier_embed_dim)
            else:
                self.joint_layer_norm = None

            self.joint_layers = nn.ModuleList([
                TransformerSentenceEncoderLayer(
                    embedding_dim=classifier_embed_dim,
                    ffn_embedding_dim=ffn_embedding_dim,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                ) for _ in range(args.num_joint_layers)
            ])

        self.classifier = RobertaClassificationHead(
            classifier_embed_dim,
            classifier_embed_dim,
            1,  # num_classes
            "tanh",
            args.classifier_dropout,
        )
Example #11
    def __init__(
        self,
        padding_idx: int,
        vocab_size: int,
        num_encoder_layers: int = 6,
        num_encoder_layers_cross: int = 6,
        embedding_dim: int = 768,
        embedding_dim_text: int = 768,
        embedding_dim_audio: int = 768,
        embedding_dim_video: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len_text: int = 256,
        max_seq_len_audio: int = 256,
        max_seq_len_video: int = 256,
        num_segments: int = 2,
        use_position_embeddings: bool = True,
        is_start_AV_embeddings: bool = True,
        offset_positions_by_padding: bool = True,
        encoder_normalize_before: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        learned_pos_embedding: bool = True,
        is_self_attention: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        embed_scale: float = None,
        freeze_embeddings: bool = False,
        n_trans_layers_to_freeze: int = 0,
        export: bool = False,
        is_only_text: bool = False,
        is_only_audio: bool = False,
        is_only_video: bool = False,
        is_all_in: bool = False,
    ) -> None:

        super().__init__()
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.max_seq_len_t = max_seq_len_text  #text
        self.max_seq_len_a = max_seq_len_audio  #audio
        self.max_seq_len_v = max_seq_len_video  #video
        self.embedding_dim = embedding_dim
        self.embedding_dim_t = embedding_dim_text
        self.embedding_dim_a = embedding_dim_audio
        self.embedding_dim_v = embedding_dim_video
        self.num_segments = num_segments
        self.use_position_embeddings = use_position_embeddings
        self.is_start_AV_embeddings = is_start_AV_embeddings
        self.apply_bert_init = apply_bert_init
        self.learned_pos_embedding = learned_pos_embedding

        self.only_t = is_only_text
        self.only_a = is_only_audio
        self.only_v = is_only_video
        self.all_in = is_all_in

        self.embed_scale = embed_scale

        if self.only_v or self.all_in:

            self.SE_embeddings_v = (  # start/end embedding for video (only the start token is used, hence size 1)
                nn.Embedding(1, self.embedding_dim_v, padding_idx=None)
                if self.is_start_AV_embeddings else None)

            self.padding_idx_v = 1
            # Vid2vec: max of 5 positions and dimension of 256
            self.embed_positions_v = (  # a separate positional embedding matrix is needed for each modality (A, V)
                PositionalEmbeddingMul(
                    self.max_seq_len_v,
                    self.embedding_dim_v,
                    padding_idx=(self.padding_idx_v
                                 if offset_positions_by_padding else None),
                    learned=self.learned_pos_embedding,
                ) if self.use_position_embeddings else None)

            self.layers_v = nn.ModuleList(  # video self-attention layers
                [
                    TransformerSentenceEncoderLayer(
                        embedding_dim=self.embedding_dim_v,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers)
                ] if is_self_attention else None)

        if self.only_a or self.all_in:

            self.SE_embeddings_a = (  # start/end embedding for audio (only the start token is used, hence size 1)
                nn.Embedding(1, self.embedding_dim_a, padding_idx=None)
                if self.is_start_AV_embeddings else None)

            self.padding_idx_a = 1  # use 1 when a padding mask is passed to the forward function

            # Max of 310 positions and dimension of 512
            self.embed_positions_a = (  # a separate positional embedding matrix is needed for each modality
                PositionalEmbeddingMul(
                    self.max_seq_len_a,
                    self.embedding_dim_a,
                    padding_idx=(self.padding_idx_a
                                 if offset_positions_by_padding else None),
                    learned=self.learned_pos_embedding,
                ) if self.use_position_embeddings else None)

            self.layers_a = nn.ModuleList(  # audio self-attention layers
                [
                    TransformerSentenceEncoderLayer(
                        embedding_dim=self.embedding_dim_a,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers)
                ] if is_self_attention else None)

        if (self.all_in) or (self.only_a and self.only_t):

            self.layers_ta = nn.ModuleList(  #Text to Audio (The query vector comes from the Text and Key-Value from the Audio)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_t,  # self.embedding_dim,
                        qdim=self.embedding_dim_t,
                        kdim=self.embedding_dim_a,
                        vdim=self.embedding_dim_a,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

            self.layers_at = nn.ModuleList(  #Audio to Text  (The query vector comes from the Audio and Key-Value from the Text)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_a,  # self.embedding_dim,
                        qdim=self.embedding_dim_a,
                        kdim=self.embedding_dim_t,
                        vdim=self.embedding_dim_t,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

        if (self.all_in) or (self.only_a and self.only_v):

            self.layers_av = nn.ModuleList(  #Audio to Video  (The query vector comes from the Audio and Key-Value from the Video)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_a,  # self.embedding_dim,
                        qdim=self.embedding_dim_a,
                        kdim=self.embedding_dim_v,
                        vdim=self.embedding_dim_v,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

            self.layers_va = nn.ModuleList(  #Video to Audio (The query vector comes from the Video and Key-Value from the Audio)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_v,  # self.embedding_dim,
                        qdim=self.embedding_dim_v,
                        kdim=self.embedding_dim_a,
                        vdim=self.embedding_dim_a,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

        if (self.all_in) or (self.only_t and self.only_v):
            self.layers_vt = nn.ModuleList(  #Video to Text  (The query vector comes from the Video and Key-Value from the Text)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_v,  # self.embedding_dim,
                        qdim=self.embedding_dim_v,
                        kdim=self.embedding_dim_t,
                        vdim=self.embedding_dim_t,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

            self.layers_tv = nn.ModuleList(  #Text to Video  (The query vector comes from the Text and Key-Value from the video)
                [
                    TransformerMultiEncoderLayer(
                        embedding_dim=self.embedding_dim_t,  # self.embedding_dim,
                        qdim=self.embedding_dim_t,
                        kdim=self.embedding_dim_v,
                        vdim=self.embedding_dim_v,
                        self_attention=False,
                        encoder_decoder_attention=True,
                        ffn_embedding_dim=ffn_embedding_dim,
                        num_attention_heads=num_attention_heads,
                        dropout=self.dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        add_bias_kv=add_bias_kv,
                        add_zero_attn=add_zero_attn,
                        export=export,
                    ) for _ in range(num_encoder_layers_cross)
                ])

        if encoder_normalize_before:

            if self.only_a or self.all_in:
                self.emb_layer_norm_a = LayerNorm(self.embedding_dim_a,
                                                  export=export)

            else:
                self.emb_layer_norm_a = None

            if self.only_v or self.all_in:
                self.emb_layer_norm_v = LayerNorm(self.embedding_dim_v,
                                                  export=export)

            else:
                self.emb_layer_norm_v = None

        else:
            self.emb_layer_norm_a = None
            self.emb_layer_norm_v = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)

        def freeze_module_params(m):
            if m is not None:
                for p in m.parameters():
                    p.requires_grad = False

        if freeze_embeddings:
            # this class defines per-modality embeddings and layer norms rather
            # than the shared embed_tokens/segment_embeddings of the base encoder
            freeze_module_params(getattr(self, 'SE_embeddings_a', None))
            freeze_module_params(getattr(self, 'SE_embeddings_v', None))
            freeze_module_params(getattr(self, 'embed_positions_a', None))
            freeze_module_params(getattr(self, 'embed_positions_v', None))
            freeze_module_params(self.emb_layer_norm_a)
            freeze_module_params(self.emb_layer_norm_v)

        # the first few layers of each per-modality stack can be frozen this way
        for layer in range(n_trans_layers_to_freeze):
            for stack in (getattr(self, 'layers_a', None),
                          getattr(self, 'layers_v', None)):
                if stack is not None and layer < len(stack):
                    freeze_module_params(stack[layer])
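
The cross-modal blocks in Example #11 pass separate qdim/kdim/vdim so that, for instance, text queries can attend over audio keys and values of a different width. The same idea can be sketched with plain torch.nn.MultiheadAttention, which accepts kdim/vdim for keys/values of another dimension; this illustrates the mechanism only and is not the TransformerMultiEncoderLayer implementation:

import torch
from torch import nn

txt_dim, audio_dim, heads = 768, 256, 8
# text-to-audio cross-attention: queries from text, keys/values from audio
cross_attn = nn.MultiheadAttention(embed_dim=txt_dim, num_heads=heads,
                                    kdim=audio_dim, vdim=audio_dim)

text = torch.randn(50, 2, txt_dim)       # (seq_len_text, batch, txt_dim)
audio = torch.randn(120, 2, audio_dim)   # (seq_len_audio, batch, audio_dim)
out, attn_weights = cross_attn(query=text, key=audio, value=audio)
# out: (seq_len_text, batch, txt_dim)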
Example #12
    def __init__(
        self,
        inst_cls_idx: int,
        state_cls_idx: int,
        inst_padding_idx: int,
        state_padding_idx: int,
        inst_vocab_size: int,
        state_vocab_size: int,
        num_encoder_layers: int = 6,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_seq_len: int = 255,
        encoder_normalize_before: bool = False,
        embedding_normalize: bool = False,
        apply_bert_init: bool = False,
        activation_fn: str = "relu",
        embed_scale: float = None,
        smallbert_num_encoder_layers: int = 1,
        smallbert_num_attention_heads: int = 8,
        smallbert_insts_max_seq_len: int = 32,
        smallbert_states_max_seq_len: int = 16,
        smallbert_insts_per_input: int = 4,
        smallbert_states_per_input: int = 4,
        #        rel_pos: bool = False,
        #        rel_pos_bins: int = 32,
        #        max_rel_pos: int = 128,
        export: bool = False,
    ) -> None:

        super().__init__()
        self.inst_cls_idx = inst_cls_idx
        self.state_cls_idx = state_cls_idx
        self.inst_padding_idx = inst_padding_idx
        self.state_padding_idx = state_padding_idx
        self.inst_vocab_size = inst_vocab_size
        self.state_vocab_size = state_vocab_size
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.apply_bert_init = apply_bert_init
        self.smallbert_insts_max_seq_len = smallbert_insts_max_seq_len
        self.smallbert_states_max_seq_len = smallbert_states_max_seq_len
        self.smallbert_num_attention_heads = smallbert_num_attention_heads
        self.smallbert_num_encoder_layers = smallbert_num_encoder_layers
        self.smallbert_insts_per_input = smallbert_insts_per_input
        self.smallbert_states_per_input = smallbert_states_per_input
        self.embed_scale = embed_scale

        self.attn_scale_factor = 4
        self.num_attention_heads = num_attention_heads

        self.inst_bert = TransformerSentenceEncoder(
            padding_idx=inst_padding_idx,
            vocab_size=inst_vocab_size,
            num_encoder_layers=smallbert_num_encoder_layers,
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=smallbert_num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            max_seq_len=self.smallbert_insts_max_seq_len *
            self.smallbert_insts_per_input,
            encoder_normalize_before=encoder_normalize_before,
            embedding_normalize=embedding_normalize,
            apply_bert_init=apply_bert_init,
            activation_fn=activation_fn,
            embed_scale=embed_scale,
            export=export,
        )

        self.state_bert = TransformerSentenceEncoder(
            padding_idx=state_padding_idx,
            vocab_size=state_vocab_size,
            num_encoder_layers=smallbert_num_encoder_layers,
            embedding_dim=embedding_dim,
            ffn_embedding_dim=ffn_embedding_dim,
            num_attention_heads=smallbert_num_attention_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            activation_dropout=activation_dropout,
            max_seq_len=self.smallbert_states_max_seq_len *
            self.smallbert_states_per_input,
            encoder_normalize_before=encoder_normalize_before,
            embedding_normalize=embedding_normalize,
            apply_bert_init=apply_bert_init,
            activation_fn=activation_fn,
            embed_scale=embed_scale,
            export=export,
        )

        self.cpos = nn.Embedding(self.max_seq_len + 2, self.embedding_dim)
        self.cpos_q_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.cpos_k_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.cpos_scaling = float(self.embedding_dim / num_attention_heads *
                                  self.attn_scale_factor)**-0.5
        self.cpos_ln = LayerNorm(self.embedding_dim, export=export)

        self.tpos = nn.Embedding(self.max_seq_len + 3, self.embedding_dim)
        self.tpos_q_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.tpos_k_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.tpos_scaling = float(self.embedding_dim / num_attention_heads *
                                  self.attn_scale_factor)**-0.5
        self.tpos_ln = LayerNorm(self.embedding_dim, export=export)

        self.fpos = nn.Embedding(self.max_seq_len + 3, self.embedding_dim)
        self.fpos_q_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.fpos_k_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
        self.fpos_scaling = float(self.embedding_dim / num_attention_heads *
                                  self.attn_scale_factor)**-0.5
        self.fpos_ln = LayerNorm(self.embedding_dim, export=export)

        self.layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                attn_scale_factor=self.attn_scale_factor,
                export=export,
                encoder_normalize_before=encoder_normalize_before,
            ) for _ in range(num_encoder_layers)
        ])

        self.inst_bert_layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=smallbert_num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                attn_scale_factor=self.attn_scale_factor,
                export=export,
                encoder_normalize_before=encoder_normalize_before,
            ) for _ in range(smallbert_num_encoder_layers)
        ])

        self.state_bert_layers = nn.ModuleList([
            TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=ffn_embedding_dim,
                num_attention_heads=smallbert_num_attention_heads,
                dropout=self.dropout,
                attention_dropout=attention_dropout,
                activation_dropout=activation_dropout,
                activation_fn=activation_fn,
                attn_scale_factor=self.attn_scale_factor,
                export=export,
                encoder_normalize_before=encoder_normalize_before,
            ) for _ in range(smallbert_num_encoder_layers)
        ])

        self.inst_layer_norm = LayerNorm(self.embedding_dim, export=export)
        self.state_layer_norm = LayerNorm(self.embedding_dim, export=export)

        if embedding_normalize:
            self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export)
        else:
            self.emb_layer_norm = None

        if encoder_normalize_before:
            self.emb_out_layer_norm = LayerNorm(self.embedding_dim,
                                                export=export)
        else:
            self.emb_out_layer_norm = None

        # Apply initialization of model params after building the model
        if self.apply_bert_init:
            self.apply(init_bert_params)