def __init__(self, attention_heads, d_model, linear_units, slf_attn_dropout_rate,
             ffn_dropout_rate, residual_dropout_rate, normalize_before=False,
             concat_after=False, activation='relu'):
    super(TransformerEncoderLayer, self).__init__()
    # Self-attention and position-wise feed-forward sub-layers.
    self.self_attn = MultiHeadedAttention(attention_heads, d_model, slf_attn_dropout_rate)
    self.feed_forward = PositionwiseFeedForward(d_model, linear_units, ffn_dropout_rate, activation)
    # One layer norm and one residual dropout per sub-layer.
    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    self.dropout1 = nn.Dropout(residual_dropout_rate)
    self.dropout2 = nn.Dropout(residual_dropout_rate)
    # Pre-norm (normalize_before) vs. post-norm, and optional concat-after projection.
    self.normalize_before = normalize_before
    self.concat_after = concat_after
    if self.concat_after:
        # Projects the concatenation of attention input and output back to d_model.
        self.concat_linear = nn.Linear(d_model * 2, d_model)
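# Illustrative note (assumed usage; this layer's actual forward() is defined elsewhere
# in the class): with normalize_before=True the self-attention sub-layer typically
# computes
#     x = x + dropout1(self_attn(norm1(x), norm1(x), norm1(x), mask))
# whereas concat_after=True concatenates the attention input and output along the
# feature dimension and projects back to d_model before the residual addition:
#     x = x + concat_linear(torch.cat((norm1(x), attn_out), dim=-1))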
def __init__(self, output_size, d_model=256, attention_heads=4, linear_units=2048,
             num_blocks=6, pos_dropout_rate=0.0, slf_attn_dropout_rate=0.0,
             src_attn_dropout_rate=0.0, ffn_dropout_rate=0.0, residual_dropout_rate=0.1,
             activation='relu', normalize_before=True, concat_after=False,
             share_embedding=False, weight_sharing=False):
    super(TransformerDecoder, self).__init__()
    self.normalize_before = normalize_before
    self.weight_sharing = weight_sharing
    self.num_blocks = num_blocks
    # Token embedding followed by positional encoding.
    self.embedding = torch.nn.Embedding(output_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, pos_dropout_rate)
    # With weight sharing, only one decoder layer is instantiated; it is meant to be
    # applied self.num_blocks times in the forward pass.
    if weight_sharing:
        num_blocks = 1
    self.blocks = nn.ModuleList([
        TransformerDecoderLayer(attention_heads, d_model, linear_units,
                                slf_attn_dropout_rate, src_attn_dropout_rate,
                                ffn_dropout_rate, residual_dropout_rate,
                                normalize_before=normalize_before,
                                concat_after=concat_after, activation=activation)
        for _ in range(num_blocks)
    ])
    # A final layer norm is only needed in the pre-norm configuration.
    if self.normalize_before:
        self.after_norm = LayerNorm(d_model)
    self.output_layer = nn.Linear(d_model, output_size)
    # Optionally tie the output projection to the input embedding.
    if share_embedding:
        assert self.embedding.weight.size() == self.output_layer.weight.size()
        self.output_layer.weight = self.embedding.weight
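# Example usage (hypothetical hyper-parameters, not taken from a config in this repo):
#     decoder = TransformerDecoder(output_size=5000, d_model=256, attention_heads=4,
#                                  num_blocks=6, share_embedding=True)
# Tying the embedding and output projection is shape-compatible because
# nn.Linear(d_model, output_size) stores its weight as (output_size, d_model), the same
# shape as the embedding matrix, which is exactly what the assert above checks.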