    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):

        super(TransformerEncoderHEO, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.dropout = nn.Dropout(dropout)
        ######
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        ### self-attention based positional embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)
        self.pos_attn = SelfAttention(d_model, dropout)
        self.final_proj = nn.Linear(3 * d_model, d_model)
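        # pos_attn and final_proj are presumably used as in TransformerEncoderHERO.forward
        # below: global features, local features and an attention-derived positional
        # encoding are concatenated (3 * d_model) and projected back to d_model.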
Example #2
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_att_version, inter_layers, inter_heads, device):
        super(TransformerInterEncoder, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        self.inter_att_version = inter_att_version

        if (inter_att_version == 2):
            self.transformer_layers = nn.ModuleList([
                TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
                if i in inter_layers else TransformerEncoderLayer(
                    d_model, heads, d_ff, dropout) for i in range(num_layers)
            ])
        elif (inter_att_version == 3):
            self.transformer_layers = nn.ModuleList([
                TransformerNewInterLayer(d_model, inter_heads, d_ff, dropout)
                if i in inter_layers else TransformerEncoderLayer(
                    d_model, heads, d_ff, dropout) for i in range(num_layers)
            ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
    def __init__(self,
                 num_layers,
                 d_model,
                 heads,
                 d_ff,
                 dropout,
                 embeddings,
                 inter_layers,
                 inter_heads,
                 device,
                 mem_args=None):
        super(TransformerInterEncoder, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        if mem_args is not None and mem_args.mem_enc_positions:
            mem_flags = [
                True if str(i) in mem_args.mem_enc_positions else False
                for i in range(num_layers)
            ]
        else:
            mem_flags = [False] * num_layers
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model,
                                  inter_heads,
                                  d_ff,
                                  dropout,
                                  mem_args=mem_args,
                                  use_mem=mem_flags[i]) if i in inter_layers
            else TransformerEncoderLayer(d_model,
                                         heads,
                                         d_ff,
                                         dropout,
                                         mem_args=mem_args,
                                         use_mem=mem_flags[i])
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        self.mem_types = ['PKM' if flag else 'FFN' for flag in mem_flags]
        print(list(zip(self.transformer_types, self.mem_types)))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
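        # Only mem_args.mem_enc_positions (layer indices as strings) is read here; the
        # mem_args object itself is forwarded to the layer constructors. A hypothetical
        # argparse.Namespace(mem_enc_positions=['1', '3'], ...) would enable the PKM
        # memory in layers 1 and 3, while mem_args=None keeps plain FFN sub-layers.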
    def __init__(self,
                 num_layers,
                 d_model,
                 heads,
                 d_ff,
                 dropout,
                 embeddings,
                 mem_args=None):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)
        if mem_args and mem_args.mem_enc_positions:
            mem_flags = [
                True if str(i) in mem_args.mem_enc_positions else False
                for i in range(num_layers)
            ]
        else:
            mem_flags = [False] * num_layers
        self.transformer_local = nn.ModuleList([
            TransformerEncoderLayer(d_model, heads, d_ff, dropout, mem_args,
                                    mem_flags[i]) for i in range(num_layers)
        ])
        self.mem_types = ['PKM' if flag else 'FFN' for flag in mem_flags]
        print(self.mem_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHEQ, self).__init__()

        inter_layers = [int(i) for i in inter_layers]
        para_layers = [i for i in range(num_layers)]
        for i in inter_layers:
            para_layers.remove(i)
        query_layer = para_layers[-1]
        para_layers = para_layers[:-1]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)

        ### Query Encoder
        self.transformer_query_encoder = TransformerEncoderLayer(
            d_model, heads, d_ff, dropout)
        self.query_layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.query_pos_emb = PositionalEncoding(dropout,
                                                self.embeddings.embedding_dim,
                                                buffer_name='qpe')
        ######
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) if i in para_layers else
            TransformerQueryEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else
            'para_local' if i in para_layers else 'query_local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.proj = nn.Linear(
            2 * d_model, d_model,
            bias=False)  # for concatenating local and global layer outputs
Example #6
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
        super(TransformerEncoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)
        self.transformer_local = nn.ModuleList([
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
Example #7
class TransformerInterEncoder(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerInterEncoder, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)

        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        # src = src.view(batch_size * n_blocks, n_tokens)
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        #mask_local = 1 - src.data.eq(padding_idx).view(batch_size * n_blocks, n_tokens)
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens),
                               -1) > 0
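        # The block below builds a hierarchical positional encoding: the first half of
        # the embedding dimension encodes token position within a block, the second half
        # the block position within the document; the halves are concatenated back to
        # embedding_dim.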

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)

        for i in range(self.num_layers):
            if (self.transformer_types[i] == 'local'):
                word_vec = self.transformer_layers[i](
                    word_vec, word_vec,
                    ~mask_local)  # all_sents * max_tokens * dim
            elif (self.transformer_types[i] == 'inter'):
                word_vec = self.transformer_layers[i](
                    word_vec, ~mask_local, ~mask_block, batch_size,
                    n_blocks)  # all_sents * max_tokens * dim

        word_vec = self.layer_norm(word_vec)
        mask_hier = mask_local[:, :, None].float()
        src_features = word_vec * mask_hier
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        src_features = src_features.transpose(
            0, 1).contiguous()  # src_len, batch_size, hidden_dim
        mask_hier = mask_hier.view(batch_size, n_blocks * n_tokens, -1)
        mask_hier = mask_hier.transpose(0, 1).contiguous()

        unpadded = [
            torch.masked_select(src_features[:, i],
                                mask_hier[:, i].bool()).view(
                                    [-1, src_features.size(-1)])
            for i in range(src_features.size(1))
        ]
        max_l = max([p.size(0) for p in unpadded])
        mask_hier = sequence_mask(torch.tensor([p.size(0) for p in unpadded]),
                                  max_l)  # .to(self.device)
        mask_hier = ~mask_hier[:, None, :]

        # .to(self.device)
        unpadded = torch.stack([
            torch.cat(
                [p, torch.zeros(max_l - p.size(0), src_features.size(-1))])
            for p in unpadded
        ], 1)
        return unpadded, mask_hier
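# A minimal usage sketch for TransformerInterEncoder (not part of the original source).
# It assumes the repo's helper modules referenced above (PositionalEncoding,
# TransformerInterLayer, TransformerEncoderLayer, sequence_mask) are in scope, and that
# `embeddings` is any module exposing embedding_dim and padding_idx, e.g. nn.Embedding.
def _example_inter_encoder_usage():
    embeddings = nn.Embedding(30000, 256, padding_idx=0)
    encoder = TransformerInterEncoder(num_layers=6, d_model=256, heads=8,
                                      d_ff=1024, dropout=0.1,
                                      embeddings=embeddings,
                                      inter_layers=[1, 3], inter_heads=8,
                                      device='cpu')
    src = torch.randint(1, 30000, (2, 4, 30))  # (batch, n_blocks, n_tokens)
    # feats: (max_src_len, batch, d_model); mask: (batch, 1, max_src_len)
    feats, mask = encoder(src)
    return feats, mask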
class TransformerInterExtracter(nn.Module):
    def __init__(self,
                 num_layers,
                 d_model,
                 heads,
                 d_ff,
                 dropout,
                 embeddings,
                 inter_layers,
                 inter_heads,
                 device,
                 mem_args=None):
        super(TransformerInterExtracter, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        if mem_args is not None and mem_args.mem_enc_positions:
            mem_flags = [
                True if str(i) in mem_args.mem_enc_positions else False
                for i in range(num_layers)
            ]
        else:
            mem_flags = [False] * num_layers
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model,
                                  inter_heads,
                                  d_ff,
                                  dropout,
                                  mem_args=mem_args,
                                  use_mem=mem_flags[i]) if i in inter_layers
            else TransformerEncoderLayer(d_model,
                                         heads,
                                         d_ff,
                                         dropout,
                                         mem_args=mem_args,
                                         use_mem=mem_flags[i])
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        self.mem_types = ['PKM' if flag else 'FFN' for flag in mem_flags]
        print(list(zip(self.transformer_types, self.mem_types)))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        # src = src.view(batch_size * n_blocks, n_tokens)
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens),
                               -1) > 0

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)

        for i in range(self.num_layers):
            #print('about to process layer:', i, self.transformer_types[i])
            if (self.transformer_types[i] == 'local'):
                word_vec = self.transformer_layers[i](
                    word_vec, word_vec,
                    ~mask_local)  # all_sents * max_tokens * dim
            elif (self.transformer_types[i] == 'inter'):
                word_vec = self.transformer_layers[i](
                    word_vec, ~mask_local, ~mask_block, batch_size,
                    n_blocks)  # all_sents * max_tokens * dim

        word_vec = self.layer_norm(word_vec)
        mask_hier = mask_local[:, :, None].float()
        src_features = word_vec * mask_hier
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        #src_features = src_features.transpose(0, 1).contiguous()  # src_len, batch_size, hidden_dim
        mask_hier = mask_hier.view(batch_size, n_blocks * n_tokens, -1)
        #mask_hier = mask_hier.transpose(0, 1).contiguous()
        return src_features, mask_hier
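# Unlike TransformerInterEncoder above, the extracter keeps its output batch-first:
# src_features is (batch, n_blocks * n_tokens, d_model) and mask_hier is the matching
# (batch, n_blocks * n_tokens, 1) float padding mask, with no re-packing of unpadded
# tokens.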
class TransformerEncoderHERO(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHERO, self).__init__()

        inter_layers = [int(i) for i in inter_layers]
        para_layers = [i for i in range(num_layers)]
        for i in inter_layers:
            para_layers.remove(i)
        query_layer = para_layers[-1]
        para_layers = para_layers[:-1]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          int(self.embeddings.embedding_dim))
        self.dropout = nn.Dropout(dropout)

        ### Query Encoder
        self.transformer_query_encoder = TransformerEncoderLayer(
            d_model, heads, d_ff, dropout)
        self.query_layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.query_pos_emb = PositionalEncoding(dropout,
                                                self.embeddings.embedding_dim,
                                                buffer_name='qpe')
        ######
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) if i in para_layers else
            TransformerQueryEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else
            'para_local' if i in para_layers else 'query_local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        ### self-attention based positional embeddings
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)
        self.pos_attn = SelfAttention(d_model, dropout)
        self.final_proj = nn.Linear(3 * d_model, d_model)

    def forward(self, src, query):
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_query = ~query.data.eq(padding_idx)
        mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens),
                               -1) > 0

        ### operations for query encoding

        _, qn_tokens = query.size()
        qpos_emb = self.query_pos_emb.qpe[:, :qn_tokens].expand(
            batch_size, qn_tokens, self.embeddings.embedding_dim)
        qemb = self.embeddings(query) * math.sqrt(
            self.embeddings.embedding_dim) + qpos_emb
        qemb = self.query_pos_emb.dropout(qemb)
        query_vec = self.transformer_query_encoder(qemb, qemb, ~mask_query)

        #####
        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens, int(self.embeddings.embedding_dim))
        combined_pos_emb = local_pos_emb
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)

        for i in range(self.num_layers):
            if (self.transformer_types[i] == 'para_local'):
                word_vec = self.transformer_layers[i](
                    word_vec, word_vec,
                    ~mask_local)  # all_sents * max_tokens * dim
            elif (self.transformer_types[i] == 'query_local'):
                word_vec = self.transformer_layers[i](word_vec, query_vec,
                                                      ~mask_local,
                                                      ~mask_query)

            elif (self.transformer_types[i] == 'inter'):
                if 'local' in self.transformer_types[i - 1]:
                    local_vec = word_vec
                word_vec = self.transformer_layers[i](
                    word_vec, ~mask_local, ~mask_block, batch_size,
                    n_blocks)  # all_sents * max_tokens * dim

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()

        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(batch_size,
                                                       n_blocks * n_tokens, -1)
        global_src_features = global_src_features.transpose(0, 1).contiguous()

        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(batch_size,
                                                     n_blocks * n_tokens, -1)
        local_src_features = local_src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)

        ### self attention for positional embedding
        d_model = self.d_model
        pos_features = global_src_features.transpose(0, 1).contiguous()
        _, attn = self.pos_attn(pos_features, mask.squeeze(1))

        attn = attn.view(batch_size, n_blocks, n_tokens)
        para_attn = attn.sum(-1)  # batch_size x n_blocks
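        # The summed attention mass per paragraph (para_attn) stands in for the integer
        # position index in the standard sinusoidal encoding built below, giving a soft,
        # content-dependent ordering over the n_blocks paragraphs.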

        pe = torch.zeros(batch_size, n_blocks, d_model,
                         device=para_attn.device)
        multiplier_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float,
                         device=para_attn.device) *
            -(math.log(10000.0) / d_model))

        pe[:, :, 0::2] = torch.sin(para_attn.unsqueeze(-1) * multiplier_term)
        pe[:, :, 1::2] = torch.cos(para_attn.unsqueeze(-1) * multiplier_term)

        pe = pe.unsqueeze(-2).expand(batch_size, n_blocks, n_tokens,
                                     -1).contiguous()
        pe = pe.view(batch_size, n_blocks * n_tokens,
                     -1).contiguous().transpose(0, 1)

        feats = torch.cat((global_src_features, local_src_features, pe), -1)

        feats = self.final_proj(feats)

        return feats, mask
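# A minimal usage sketch for TransformerEncoderHERO (not part of the original source).
# Besides the modules above it assumes SelfAttention and TransformerQueryEncoderLayer
# are in scope; the query is a second tensor of token ids.
def _example_hero_usage():
    embeddings = nn.Embedding(30000, 256, padding_idx=0)
    encoder = TransformerEncoderHERO(num_layers=6, d_model=256, heads=8,
                                     d_ff=1024, dropout=0.1,
                                     embeddings=embeddings,
                                     inter_layers=[2, 5], inter_heads=8,
                                     device='cpu')
    src = torch.randint(1, 30000, (2, 4, 30))  # (batch, n_blocks, n_tokens)
    query = torch.randint(1, 30000, (2, 12))   # (batch, query_tokens)
    # feats: (n_blocks * n_tokens, batch, d_model); mask: (batch, 1, n_blocks * n_tokens)
    feats, mask = encoder(src, query)
    return feats, mask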
class TransformerEncoderHEQ(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHEQ, self).__init__()

        inter_layers = [int(i) for i in inter_layers]
        para_layers = [i for i in range(num_layers)]
        for i in inter_layers:
            para_layers.remove(i)
        query_layer = para_layers[-1]
        para_layers = para_layers[:-1]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)

        ### Query Encoder
        self.transformer_query_encoder = TransformerEncoderLayer(
            d_model, heads, d_ff, dropout)
        self.query_layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.query_pos_emb = PositionalEncoding(dropout,
                                                self.embeddings.embedding_dim,
                                                buffer_name='qpe')
        ######
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) if i in para_layers else
            TransformerQueryEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else
            'para_local' if i in para_layers else 'query_local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.proj = nn.Linear(
            2 * d_model, d_model,
            bias=False)  # for concatenating local and global layer outputs

    def forward(self, src, query):
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_query = ~query.data.eq(padding_idx)
        mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens),
                               -1) > 0

        ### operations for query encoding

        _, qn_tokens = query.size()
        qpos_emb = self.query_pos_emb.qpe[:, :qn_tokens].expand(
            batch_size, qn_tokens, self.embeddings.embedding_dim)
        qemb = self.embeddings(query) * math.sqrt(
            self.embeddings.embedding_dim) + qpos_emb
        qemb = self.query_pos_emb.dropout(qemb)
        query_vec = self.transformer_query_encoder(qemb, qemb, ~mask_query)

        #####
        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)

        for i in range(self.num_layers):
            if (self.transformer_types[i] == 'para_local'):
                word_vec = self.transformer_layers[i](
                    word_vec, word_vec,
                    ~mask_local)  # all_sents * max_tokens * dim
            elif (self.transformer_types[i] == 'query_local'):
                word_vec = self.transformer_layers[i](word_vec, query_vec,
                                                      ~mask_local,
                                                      ~mask_query)

            elif (self.transformer_types[i] == 'inter'):
                if 'local' in self.transformer_types[i - 1]:
                    local_vec = word_vec
                word_vec = self.transformer_layers[i](
                    word_vec, ~mask_local, ~mask_block, batch_size,
                    n_blocks)  # all_sents * max_tokens * dim

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()
        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(
            -1, global_src_features.size(-1))

        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(
            -1, local_src_features.size(-1))

        # concat and project
        src_features = torch.cat((global_src_features, local_src_features), 1)
        src_features = self.proj(src_features)
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        src_features = src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)
        return src_features, mask
class TransformerEncoderHE(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHE, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)

        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else TransformerEncoderLayer(
                d_model, heads, d_ff, dropout) for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.proj = nn.Linear(2 * d_model, d_model, bias=False)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        # src = src.view(batch_size * n_blocks, n_tokens)
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_block = torch.sum(mask_local.view(batch_size, n_blocks, n_tokens),
                               -1) > 0

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)

        for i in range(self.num_layers):
            if (self.transformer_types[i] == 'local'):
                word_vec = self.transformer_layers[i](
                    word_vec, word_vec,
                    ~mask_local)  # all_sents * max_tokens * dim
            elif (self.transformer_types[i] == 'inter'):
                if self.transformer_types[i - 1] == 'local':
                    local_vec = word_vec
                word_vec = self.transformer_layers[i](
                    word_vec, ~mask_local, ~mask_block, batch_size,
                    n_blocks)  # all_sents * max_tokens * dim

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()
        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(
            -1, global_src_features.size(-1))

        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(
            -1, local_src_features.size(-1))

        # concat and project
        src_features = torch.cat((global_src_features, local_src_features), 1)
        src_features = self.proj(src_features)
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        src_features = src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)

        return src_features, mask
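# In both TransformerEncoderHEQ and TransformerEncoderHE above, local_vec is the output
# captured just before the last 'inter' layer; it is layer-normed, concatenated with the
# final globally-attended output (2 * d_model) and projected back to d_model by self.proj.
# The returned features are (n_blocks * n_tokens, batch, d_model) with a
# (batch, 1, n_blocks * n_tokens) padding mask.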