def __init__(self, d_model, d_inner_hid, n_head, dim_per_head, dropout=0.1):
    super(DecoderBlock, self).__init__()

    self.slf_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout,
                                         dim_per_head=dim_per_head)
    self.ctx_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout,
                                         dim_per_head=dim_per_head)
    self.pos_ffn = PositionwiseFeedForward(size=d_model, hidden_size=d_inner_hid)

    self.layer_norm_1 = nn.LayerNorm(d_model)
    self.layer_norm_2 = nn.LayerNorm(d_model)

    self.dropout = nn.Dropout(dropout)
class DecoderBlock(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model, d_inner_hid, n_head, dropout=0.1):
        super(DecoderBlock, self).__init__()

        self.slf_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout)
        self.ctx_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(size=d_model, hidden_size=d_inner_hid)

        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def compute_cache(self, enc_output):
        # Pre-compute the encoder-side attention cache so it can be reused across decoding steps.
        return self.ctx_attn.compute_cache(enc_output, enc_output)

    def forward(self, dec_input, enc_output,
                slf_attn_mask=None, dec_enc_attn_mask=None,
                enc_attn_cache=None, self_attn_cache=None):
        # Args Checks
        input_batch, input_len, _ = dec_input.size()
        contxt_batch, contxt_len, _ = enc_output.size()

        # Pre-norm masked self-attention with a residual connection.
        input_norm = self.layer_norm_1(dec_input)
        all_input = input_norm

        query, _, self_attn_cache = self.slf_attn(
            all_input, all_input, input_norm,
            mask=slf_attn_mask, self_attn_cache=self_attn_cache)

        query = self.dropout(query) + dec_input

        # Pre-norm encoder-decoder attention with a residual connection.
        query_norm = self.layer_norm_2(query)
        mid, attn, enc_attn_cache = self.ctx_attn(
            enc_output, enc_output, query_norm,
            mask=dec_enc_attn_mask, enc_attn_cache=enc_attn_cache)

        # Position-wise feed-forward on the residual sum.
        output = self.pos_ffn(self.dropout(mid) + query)

        return output, attn, self_attn_cache, enc_attn_cache
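# The block above is the plain Transformer decoder layer in pre-LayerNorm form:
# normalize, masked self-attention, residual; normalize, encoder-decoder attention,
# residual; position-wise feed-forward. The sketch below reproduces that data flow
# with standard torch modules only, as a hedged illustration: nn.MultiheadAttention
# stands in for the repo's MultiHeadedAttention, caches and masks are omitted, and
# the feed-forward here has no internal residual/norm, so it is not the repo's exact module.
import torch
import torch.nn as nn


class MiniDecoderBlock(nn.Module):
    def __init__(self, d_model=512, d_inner_hid=2048, n_head=8, dropout=0.1):
        super().__init__()
        self.slf_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.ctx_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.pos_ffn = nn.Sequential(nn.Linear(d_model, d_inner_hid), nn.ReLU(),
                                     nn.Linear(d_inner_hid, d_model))
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_input, enc_output):
        # Pre-norm self-attention + residual.
        x = self.layer_norm_1(dec_input)
        query, _ = self.slf_attn(x, x, x, need_weights=False)
        query = self.dropout(query) + dec_input
        # Pre-norm context attention + residual, then feed-forward.
        q = self.layer_norm_2(query)
        mid, attn = self.ctx_attn(q, enc_output, enc_output)
        return self.pos_ffn(self.dropout(mid) + query), attn


# Example shapes: decoder states [2, 7, 512] attending over encoder states [2, 11, 512].
out, attn = MiniDecoderBlock()(torch.randn(2, 7, 512), torch.randn(2, 11, 512))
print(out.shape, attn.shape)  # torch.Size([2, 7, 512]) torch.Size([2, 7, 11])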
def __init__(self, d_model, d_inner_hid, n_head, dim_per_head, dropout=0.1,
             dim_capsule=100, num_capsules=0, null_capsule=False):
    super(DecoderBlock, self).__init__()

    self.slf_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout,
                                         dim_per_head=dim_per_head)
    # self.ctx_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout,
    #                                      dim_per_head=dim_per_head)
    self.pos_ffn = PositionwiseFeedForward(size=d_model, hidden_size=d_inner_hid)

    self.layer_norm_1 = nn.LayerNorm(d_model)
    self.layer_norm_2 = nn.LayerNorm(d_model)

    self.dropout = nn.Dropout(dropout)

    # contextual capsule layer
    self.apply_capsule = True
    # self.pre_capsule_layer_norm = nn.LayerNorm(d_model)

    assert dim_capsule % num_capsules == 0
    self.dim_per_cap = dim_capsule // num_capsules
    dim_per_part = dim_capsule // 3

    total_num_capsules = num_capsules
    self.null_caps = null_capsule
    if null_capsule:
        INFO("Using Null Capsules to attract irrelevant routing.")
        total_num_capsules += num_capsules // 3

    self.capsule_layer = ContextualCapsuleLayer(
        num_out_caps=total_num_capsules, num_in_caps=None,
        dim_in_caps=d_model,
        dim_out_caps=self.dim_per_cap,
        dim_context=d_model,
        num_iterations=3,
        share_route_weights_for_in_caps=True)

    self.out_and_cap_ffn = MultiInputPositionwiseFeedForward(
        size=d_model, hidden_size=d_inner_hid, dropout=dropout,
        inp_sizes=[dim_per_part, dim_per_part, dim_per_part])
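# The variant above replaces the encoder-decoder attention with a ContextualCapsuleLayer
# that aggregates the encoder states by routing-by-agreement (num_iterations=3). The
# sketch below shows a generic, non-contextual routing step (Sabour et al., 2017 style)
# only to illustrate what the iterations do; the shapes here are assumptions and this is
# not the repo's ContextualCapsuleLayer implementation.
import torch
import torch.nn.functional as F


def squash(s, dim=-1, eps=1e-8):
    # Shrink short vectors toward zero and cap long vectors near unit length.
    norm_sq = (s * s).sum(dim=dim, keepdim=True)
    return (norm_sq / (1.0 + norm_sq)) * s / torch.sqrt(norm_sq + eps)


def route_by_agreement(u_hat, num_iterations=3):
    # u_hat: prediction vectors from input capsules, shape [batch, n_in, n_out, d_out].
    b = torch.zeros(u_hat.shape[:3], device=u_hat.device)       # routing logits
    for _ in range(num_iterations):
        c = F.softmax(b, dim=2)                                  # coupling of each input to outputs
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)                 # weighted sum -> [batch, n_out, d_out]
        v = squash(s)                                            # output capsules
        b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)             # raise logits where predictions agree
    return v


# Example: 11 input capsules routed to 8 output capsules of width 16.
v = route_by_agreement(torch.randn(2, 11, 8, 16))
print(v.shape)  # torch.Size([2, 8, 16])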
def __init__(self, d_model, n_head, feature_size=1024, hidden_size=512, dropout=0.0, **kwargs):
    super(QE_ATTENTION, self).__init__()

    self.ctx_attn = MultiHeadedAttention(head_count=n_head, model_dim=d_model, dropout=dropout,
                                         dim_per_head=None)

    # Use PAD
    self.gru = RNN(type="gru", batch_first=True, input_size=feature_size,
                   hidden_size=hidden_size, bidirectional=True)
    self.lstm = RNN(type="lstm", batch_first=True, input_size=feature_size,
                    hidden_size=hidden_size, bidirectional=True)

    self.w = nn.Linear(2 * hidden_size, 1)
    my_init.default_init(self.w.weight)

    self.dropout = nn.Dropout(dropout)
    self.sigmoid = nn.Sigmoid()
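# Only the constructor of QE_ATTENTION is shown above, so the wiring below is an
# assumption: a minimal quality-estimation style scoring head that runs a bidirectional
# GRU over the feature vectors, mean-pools the states, and squashes a linear score
# through a sigmoid. nn.GRU stands in for the repo's RNN wrapper; the attention step
# and padding masks are omitted.
import torch
import torch.nn as nn


class MiniQEHead(nn.Module):
    def __init__(self, feature_size=1024, hidden_size=512, dropout=0.0):
        super().__init__()
        self.gru = nn.GRU(input_size=feature_size, hidden_size=hidden_size,
                          batch_first=True, bidirectional=True)
        self.w = nn.Linear(2 * hidden_size, 1)   # bidirectional -> 2 * hidden_size features
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, features):
        # features: [batch, seq_len, feature_size]
        states, _ = self.gru(features)           # [batch, seq_len, 2 * hidden_size]
        pooled = states.mean(dim=1)              # assumed mean pooling over time
        return self.sigmoid(self.w(self.dropout(pooled))).squeeze(-1)  # one score per sentence


# Example: score 4 sentences, each with 20 feature vectors of width 1024.
score = MiniQEHead()(torch.randn(4, 20, 1024))
print(score.shape)  # torch.Size([4])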