class TransformerInterEncoder(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerInterEncoder, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        # Half of the embedding dimension carries token-level (local)
        # positions, the other half block-level (inter) positions.
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        # True for real tokens, False for padding.
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        # A block counts as present if it has at least one non-padding token.
        mask_block = torch.sum(
            mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0

        # Token-level and block-level sinusoidal positions, each covering half
        # of the embedding dimension, concatenated along the last axis.
        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)

        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)
        for i in range(self.num_layers):
            if self.transformer_types[i] == 'local':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, word_vec,
                                                      ~mask_local)
            elif self.transformer_types[i] == 'inter':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, ~mask_local,
                                                      ~mask_block, batch_size,
                                                      n_blocks)

        word_vec = self.layer_norm(word_vec)
        mask_hier = mask_local[:, :, None].float()
        src_features = word_vec * mask_hier
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        # src_len, batch_size, hidden_dim
        src_features = src_features.transpose(0, 1).contiguous()
        mask_hier = mask_hier.view(batch_size, n_blocks * n_tokens, -1)
        mask_hier = mask_hier.transpose(0, 1).contiguous()

        # Strip padded positions per example, then re-pad every example to the
        # longest unpadded length in the batch.
        unpadded = [
            torch.masked_select(src_features[:, i],
                                mask_hier[:, i].bool()).view(
                                    [-1, src_features.size(-1)])
            for i in range(src_features.size(1))
        ]
        max_l = max([p.size(0) for p in unpadded])
        mask_hier = sequence_mask(
            torch.tensor([p.size(0) for p in unpadded]), max_l)  # .to(self.device)
        mask_hier = ~mask_hier[:, None, :]  # .to(self.device)
        unpadded = torch.stack([
            torch.cat([
                p,
                torch.zeros(max_l - p.size(0), src_features.size(-1),
                            device=p.device)
            ]) for p in unpadded
        ], 1)
        return unpadded, mask_hier
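# A minimal, self-contained sketch (not part of the original model code) of the
# hierarchical positional encoding used in TransformerInterEncoder.forward: the
# first half of the embedding dimension encodes the token position inside a
# block, the second half encodes the block position. `_sinusoid_table` is a
# hypothetical stand-in for `PositionalEncoding.pe`; it assumes an even `dim`
# and relies on the module-level `torch` and `math` imports.
def _sinusoid_table(max_len, dim):
    # Standard sinusoidal table of shape (1, max_len, dim).
    pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) *
                    -(math.log(10000.0) / dim))
    table = torch.zeros(max_len, dim)
    table[:, 0::2] = torch.sin(pos * div)
    table[:, 1::2] = torch.cos(pos * div)
    return table.unsqueeze(0)


def _combined_pos_emb_sketch(batch_size, n_blocks, n_tokens, emb_dim):
    # Token positions fill the first half of the embedding, block positions
    # the second half, mirroring the torch.cat([...], -1) in forward().
    half = emb_dim // 2
    pe = _sinusoid_table(max(n_blocks, n_tokens), half)
    local = pe[:, :n_tokens].unsqueeze(1).expand(batch_size, n_blocks,
                                                 n_tokens, half)
    inter = pe[:, :n_blocks].unsqueeze(2).expand(batch_size, n_blocks,
                                                 n_tokens, half)
    # (batch_size, n_blocks, n_tokens, emb_dim)
    return torch.cat([local, inter], -1)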
class TransformerInterExtracter(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device, mem_args=None):
        super(TransformerInterExtracter, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        # Optionally replace the feed-forward sub-layer of selected layers with
        # a product-key memory (PKM), as configured through mem_args.
        if mem_args is not None and mem_args.mem_enc_positions:
            mem_flags = [
                str(i) in mem_args.mem_enc_positions
                for i in range(num_layers)
            ]
        else:
            mem_flags = [False] * num_layers
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout,
                                  mem_args=mem_args, use_mem=mem_flags[i])
            if i in inter_layers else
            TransformerEncoderLayer(d_model, heads, d_ff, dropout,
                                    mem_args=mem_args, use_mem=mem_flags[i])
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        self.mem_types = ['PKM' if flag else 'FFN' for flag in mem_flags]
        print(list(zip(self.transformer_types, self.mem_types)))
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_block = torch.sum(
            mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)
        for i in range(self.num_layers):
            if self.transformer_types[i] == 'local':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, word_vec,
                                                      ~mask_local)
            elif self.transformer_types[i] == 'inter':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, ~mask_local,
                                                      ~mask_block, batch_size,
                                                      n_blocks)

        word_vec = self.layer_norm(word_vec)
        mask_hier = mask_local[:, :, None].float()
        src_features = word_vec * mask_hier
        # Unlike TransformerInterEncoder, features are returned batch-first and
        # without unpadding/re-padding.
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        mask_hier = mask_hier.view(batch_size, n_blocks * n_tokens, -1)
        return src_features, mask_hier
class TransformerEncoderHEQ(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHEQ, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        # All non-inter layers are paragraph-local layers, except the last one,
        # which attends over the encoded query instead.
        para_layers = [i for i in range(num_layers)]
        for i in inter_layers:
            para_layers.remove(i)
        query_layer = para_layers[-1]
        para_layers = para_layers[:-1]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)

        ### Query Encoder
        self.transformer_query_encoder = TransformerEncoderLayer(
            d_model, heads, d_ff, dropout)
        self.query_layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.query_pos_emb = PositionalEncoding(dropout,
                                                self.embeddings.embedding_dim,
                                                buffer_name='qpe')
        ######

        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            if i in para_layers else
            TransformerQueryEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else
            'para_local' if i in para_layers else 'query_local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        # For concatenating local and global layer outputs.
        self.proj = nn.Linear(2 * d_model, d_model, bias=False)

    def forward(self, src, query):
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_query = ~query.data.eq(padding_idx)
        mask_block = torch.sum(
            mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0

        ### operations for query encoding
        _, qn_tokens = query.size()
        qpos_emb = self.query_pos_emb.qpe[:, :qn_tokens].expand(
            batch_size, qn_tokens, self.embeddings.embedding_dim)
        qemb = self.embeddings(query) * math.sqrt(
            self.embeddings.embedding_dim) + qpos_emb
        qemb = self.query_pos_emb.dropout(qemb)
        query_vec = self.transformer_query_encoder(qemb, qemb, ~mask_query)
        #####

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)
        for i in range(self.num_layers):
            if self.transformer_types[i] == 'para_local':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, word_vec,
                                                      ~mask_local)
            elif self.transformer_types[i] == 'query_local':
                word_vec = self.transformer_layers[i](word_vec, query_vec,
                                                      ~mask_local,
                                                      ~mask_query)
            elif self.transformer_types[i] == 'inter':
                # Keep the output of the preceding local layer so that local
                # and global representations can be fused afterwards.
                if 'local' in self.transformer_types[i - 1]:
                    local_vec = word_vec
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, ~mask_local,
                                                      ~mask_block, batch_size,
                                                      n_blocks)

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()
        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(
            -1, global_src_features.size(-1))
        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(
            -1, local_src_features.size(-1))

        # Concatenate local and global features and project back to d_model.
        src_features = torch.cat((global_src_features, local_src_features), 1)
        src_features = self.proj(src_features)
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        src_features = src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)
        return src_features, mask
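# Hedged sketch (not part of the original model) of the local/global fusion at
# the end of TransformerEncoderHEQ.forward and TransformerEncoderHE.forward:
# the output of the last local layer and the last inter (global) layer are
# concatenated feature-wise and projected back to d_model with a bias-free
# linear layer. The forward() code flattens to 2-D before the projection; a
# linear layer applied over the last dimension is equivalent. Names below are
# illustrative only.
def _fuse_local_global_sketch(local_vec, global_vec, proj):
    # local_vec, global_vec: (batch * n_blocks, n_tokens, d_model)
    # proj: nn.Linear(2 * d_model, d_model, bias=False)
    fused = torch.cat((global_vec, local_vec), -1)
    return proj(fused)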
class TransformerEncoderHERO(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHERO, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        # All non-inter layers are paragraph-local layers, except the last one,
        # which attends over the encoded query instead.
        para_layers = [i for i in range(num_layers)]
        for i in inter_layers:
            para_layers.remove(i)
        query_layer = para_layers[-1]
        para_layers = para_layers[:-1]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        # Full-dimension token-level positional encoding; paragraph positions
        # are added later via the attention-based embedding below.
        self.pos_emb = PositionalEncoding(dropout,
                                          self.embeddings.embedding_dim)
        self.dropout = nn.Dropout(dropout)

        ### Query Encoder
        self.transformer_query_encoder = TransformerEncoderLayer(
            d_model, heads, d_ff, dropout)
        self.query_layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.query_pos_emb = PositionalEncoding(dropout,
                                                self.embeddings.embedding_dim,
                                                buffer_name='qpe')
        ######

        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            if i in para_layers else
            TransformerQueryEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else
            'para_local' if i in para_layers else 'query_local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        ### self-attention based positional embeddings
        self.pos_attn = SelfAttention(d_model, dropout)
        self.final_proj = nn.Linear(3 * d_model, d_model)

    def forward(self, src, query):
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_query = ~query.data.eq(padding_idx)
        mask_block = torch.sum(
            mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0

        ### operations for query encoding
        _, qn_tokens = query.size()
        qpos_emb = self.query_pos_emb.qpe[:, :qn_tokens].expand(
            batch_size, qn_tokens, self.embeddings.embedding_dim)
        qemb = self.embeddings(query) * math.sqrt(
            self.embeddings.embedding_dim) + qpos_emb
        qemb = self.query_pos_emb.dropout(qemb)
        query_vec = self.transformer_query_encoder(qemb, qemb, ~mask_query)
        #####

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim))
        combined_pos_emb = local_pos_emb
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)
        for i in range(self.num_layers):
            if self.transformer_types[i] == 'para_local':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, word_vec,
                                                      ~mask_local)
            elif self.transformer_types[i] == 'query_local':
                word_vec = self.transformer_layers[i](word_vec, query_vec,
                                                      ~mask_local,
                                                      ~mask_query)
            elif self.transformer_types[i] == 'inter':
                # Keep the output of the preceding local layer so that local
                # and global representations can be fused afterwards.
                if 'local' in self.transformer_types[i - 1]:
                    local_vec = word_vec
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, ~mask_local,
                                                      ~mask_block, batch_size,
                                                      n_blocks)

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()
        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(
            batch_size, n_blocks * n_tokens, -1)
        global_src_features = global_src_features.transpose(0, 1).contiguous()
        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(
            batch_size, n_blocks * n_tokens, -1)
        local_src_features = local_src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)

        ### self attention for positional embedding
        d_model = self.d_model
        pos_features = global_src_features.transpose(0, 1).contiguous()
        _, attn = self.pos_attn(pos_features, mask.squeeze(1))
        attn = attn.view(batch_size, n_blocks, n_tokens)
        # batch_size x n_blocks: total attention mass per paragraph.
        para_attn = attn.sum(-1)
        pe = torch.zeros(batch_size, n_blocks, d_model, device=attn.device)
        multiplier_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float,
                         device=attn.device) *
            -(math.log(10000.0) / d_model))
        pe[:, :, 0::2] = torch.sin(para_attn.unsqueeze(-1) * multiplier_term)
        pe[:, :, 1::2] = torch.cos(para_attn.unsqueeze(-1) * multiplier_term)
        pe = pe.unsqueeze(-2).expand(batch_size, n_blocks, n_tokens,
                                     -1).contiguous()
        pe = pe.view(batch_size, n_blocks * n_tokens,
                     -1).contiguous().transpose(0, 1)

        feats = torch.cat((global_src_features, local_src_features, pe), -1)
        feats = self.final_proj(feats)
        return feats, mask
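# Hedged sketch (not part of the original model) of HERO's attention-based
# paragraph "position": the per-paragraph attention mass replaces the integer
# position in the standard sinusoidal formula. `attn` is assumed to be a
# (batch, n_blocks, n_tokens) attention distribution over tokens, as produced
# by the SelfAttention module above; `d_model` is assumed to be even.
def _para_attn_pos_emb_sketch(attn, d_model):
    para_attn = attn.sum(-1)  # (batch, n_blocks) attention mass per paragraph
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float,
                                 device=attn.device) *
                    -(math.log(10000.0) / d_model))
    pe = torch.zeros(*para_attn.size(), d_model, device=attn.device)
    pe[:, :, 0::2] = torch.sin(para_attn.unsqueeze(-1) * div)
    pe[:, :, 1::2] = torch.cos(para_attn.unsqueeze(-1) * div)
    return pe  # (batch, n_blocks, d_model)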
class TransformerEncoderHE(nn.Module):
    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings,
                 inter_layers, inter_heads, device):
        super(TransformerEncoderHE, self).__init__()
        inter_layers = [int(i) for i in inter_layers]
        self.device = device
        self.d_model = d_model
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(
            dropout, int(self.embeddings.embedding_dim / 2))
        self.dropout = nn.Dropout(dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerInterLayer(d_model, inter_heads, d_ff, dropout)
            if i in inter_layers else
            TransformerEncoderLayer(d_model, heads, d_ff, dropout)
            for i in range(num_layers)
        ])
        self.transformer_types = [
            'inter' if i in inter_layers else 'local'
            for i in range(num_layers)
        ]
        print(self.transformer_types)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        # For concatenating local and global layer outputs.
        self.proj = nn.Linear(2 * d_model, d_model, bias=False)

    def forward(self, src):
        """ See :obj:`EncoderBase.forward()`"""
        batch_size, n_blocks, n_tokens = src.size()
        emb = self.embeddings(src)
        padding_idx = self.embeddings.padding_idx
        mask_local = ~src.data.eq(padding_idx).view(batch_size * n_blocks,
                                                    n_tokens)
        mask_block = torch.sum(
            mask_local.view(batch_size, n_blocks, n_tokens), -1) > 0

        local_pos_emb = self.pos_emb.pe[:, :n_tokens].unsqueeze(1).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        inter_pos_emb = self.pos_emb.pe[:, :n_blocks].unsqueeze(2).expand(
            batch_size, n_blocks, n_tokens,
            int(self.embeddings.embedding_dim / 2))
        combined_pos_emb = torch.cat([local_pos_emb, inter_pos_emb], -1)
        emb = emb * math.sqrt(self.embeddings.embedding_dim)
        emb = emb + combined_pos_emb
        emb = self.pos_emb.dropout(emb)

        word_vec = emb.view(batch_size * n_blocks, n_tokens, -1)
        for i in range(self.num_layers):
            if self.transformer_types[i] == 'local':
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, word_vec,
                                                      ~mask_local)
            elif self.transformer_types[i] == 'inter':
                # Keep the output of the preceding local layer so that local
                # and global representations can be fused afterwards.
                if self.transformer_types[i - 1] == 'local':
                    local_vec = word_vec
                # all_sents * max_tokens * dim
                word_vec = self.transformer_layers[i](word_vec, ~mask_local,
                                                      ~mask_block, batch_size,
                                                      n_blocks)

        global_vec = self.layer_norm(word_vec)
        local_vec = self.layer_norm(local_vec)
        mask_hier = mask_local[:, :, None].float()
        global_src_features = global_vec * mask_hier
        global_src_features = global_src_features.view(
            -1, global_src_features.size(-1))
        local_src_features = local_vec * mask_hier
        local_src_features = local_src_features.view(
            -1, local_src_features.size(-1))

        # Concatenate local and global features and project back to d_model.
        src_features = torch.cat((global_src_features, local_src_features), 1)
        src_features = self.proj(src_features)
        src_features = src_features.view(batch_size, n_blocks * n_tokens, -1)
        src_features = src_features.transpose(0, 1).contiguous()

        mask = mask_local
        mask = mask.view(batch_size, n_blocks * n_tokens)
        mask = mask.unsqueeze(1)
        return src_features, mask