import itertools
import logging

import torch

# MultiHeadAttention, XavierLinear, EmbeddingDropout, PrePostProcessing,
# MaskedFunction, get_feed_forward, IncrementalModule and IncrementalDecoder
# are assumed to be imported from elsewhere in this package (the imports are
# not part of this fragment).

logger = logging.getLogger(__name__)


class TransformerDecoderLayer(IncrementalModule):
    """
    Wraps multi-head self-attention, encoder-decoder attention and
    position-wise feed forward into one layer of the decoder.

    Layers:
        (1) Layer norm
            Multi-head self-attention
            Dropout
            Residual with (1)
        (2) Layer norm
            Multi-head query-context attention
            Dropout
            Residual with (2)
        (3) Layer norm
            Feed-forward
            Dropout
            Residual with (3)

    Feed-Forward:
        Configurable between linear -> ReLU -> linear and Maxout

    Args:
        model_dim: dimension of model
        num_heads: number of heads
        feed_forward_dim: dimension of feed forward
        feed_forward_dropout: dropout probability in the feed forward
        attention_dropout: dropout probability in attention
        residual_dropout: dropout probability for the residual layers
        weight_norm: whether to use weight normalization on the feed forward layers
        masked_layers: whether to use masking for layer norm and feed forward.
            Useful for sparse masks
        gated_residuals: whether to use gated residuals
        batch_first: whether input (and output) should be batch dimension first
            or sequence length dimension first
        feed_forward_type: which type of feed forward to use. Currently supports
            'linear_relu_linear' and 'maxout'
        ignore_context: if True, do not use the context input at all
        encoder_to_share: instance of TransformerEncoderLayer to share parameters with

    Input Shapes:
        inputs: len_query x batch_size x model_dim or batch_size x len_query x model_dim
        context: len_context x batch_size x model_dim or batch_size x len_context x model_dim
        input_mask: batch_size x len_query or len_query x batch_size
        context_mask: batch_size x len_context or len_context x batch_size
        self_attention_mask: batch_size x len_query x len_query or broadcastable,
            regardless of batch_first

    Output Shapes:
        out: len_query x batch_size x model_dim or batch_size x len_query x model_dim
    """
    _version = 2

    def __init__(self, *, model_dim=512, num_heads=8, feed_forward_dim=2048,
                 feed_forward_dropout=0.1, attention_dropout=0.1,
                 residual_dropout=0.1, weight_norm=False, masked_layers=False,
                 gated_residuals=False, batch_first=False,
                 feed_forward_type='linear_relu_linear', ignore_context=False,
                 encoder_to_share=None):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.feed_forward_dim = feed_forward_dim
        self.feed_forward_dropout = feed_forward_dropout
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.weight_norm = weight_norm
        self.masked_layers = masked_layers
        self.gated_residuals = gated_residuals
        self.batch_first = batch_first
        self.feed_forward_type = feed_forward_type
        self.ignore_context = ignore_context

        if encoder_to_share is None:
            self.build_self_attention()
            self.build_feed_forward()
        else:
            # share the self-attention layers between encoder and decoder
            self.share_feed_forward(encoder_to_share)
            self.share_self_attention(encoder_to_share)

        if not ignore_context:
            self.build_encoder_attention()

        self._register_load_state_dict_pre_hook(self._update_names)

    def get_preprocessing_module(self):
        return PrePostProcessing(self.model_dim, 'n', masking=self.masked_layers)

    def get_postprocessing_module(self):
        return PrePostProcessing(self.model_dim, 'da', self.residual_dropout,
                                 gated_residuals=self.gated_residuals)

    # noinspection PyAttributeOutsideInit
    def build_self_attention(self):
        self.preprocess_self_attn = self.get_preprocessing_module()
        self.self_attention = MultiHeadAttention(
            self.model_dim, self.num_heads, self.attention_dropout,
            masked_layers=self.masked_layers, batch_first=self.batch_first)
        self.postprocess_self_attn = self.get_postprocessing_module()

    # noinspection PyAttributeOutsideInit
    def share_self_attention(self, encoder):
        self.preprocess_self_attn = encoder.preprocess_attn
        self.postprocess_self_attn = encoder.postprocess_attn
        self.self_attention = encoder.attention

    def self_attention_layer(self, inputs, input_mask=None,
                             self_attention_bias=None):
        query = self.preprocess_self_attn(inputs, mask=input_mask)
        self_attention_out, _ = self.self_attention(
            query, query, query, self_attention_bias, input_mask)
        self_attention_out = self.postprocess_self_attn(
            self_attention_out, inputs)
        return self_attention_out

    def self_attention_step(self, inputs, incremental_state, input_mask=None,
                            self_attention_bias=None):
        query = self.preprocess_self_attn(inputs, mask=input_mask)
        self_attention_out, _ = self.self_attention.step(
            query, query, query, incremental_state, self_attention_bias,
            input_mask)
        self_attention_out = self.postprocess_self_attn(
            self_attention_out, inputs)
        return self_attention_out

    # noinspection PyAttributeOutsideInit
    def build_encoder_attention(self):
        self.preprocess_enc_attn = self.get_preprocessing_module()
        self.enc_attention = MultiHeadAttention(
            self.model_dim, self.num_heads, self.attention_dropout,
            masked_layers=self.masked_layers, batch_first=self.batch_first)
        self.postprocess_enc_attn = self.get_postprocessing_module()

    def encoder_attention_layer(self, inputs, encoder_outputs, input_mask=None,
                                context_mask=None, encoder_attention_bias=None):
        query = self.preprocess_enc_attn(inputs, mask=input_mask)
        enc_attention_out, attention_weights = self.enc_attention(
            query, encoder_outputs, encoder_outputs,
            encoder_attention_bias, input_mask, context_mask)
        enc_attention_out = self.postprocess_enc_attn(enc_attention_out, inputs)
        return enc_attention_out, attention_weights

    def encoder_attention_step(self, inputs, encoder_outputs, incremental_state,
                               input_mask=None, context_mask=None,
                               encoder_attention_bias=None):
        query = self.preprocess_enc_attn(inputs, mask=input_mask)
        enc_attention_out, attention_weights = self.enc_attention.step(
            query, encoder_outputs, encoder_outputs, incremental_state,
            encoder_attention_bias, input_mask, context_mask, static_kv=True)
        enc_attention_out = self.postprocess_enc_attn(enc_attention_out, inputs)
        return enc_attention_out, attention_weights

    # noinspection PyAttributeOutsideInit
    def build_feed_forward(self):
        self.preprocess_ffn = self.get_preprocessing_module()
        self.feed_forward = MaskedFunction(
            get_feed_forward(self.feed_forward_type, self.model_dim,
                             self.feed_forward_dim, self.feed_forward_dropout,
                             self.weight_norm))
        self.postprocess_ffn = self.get_postprocessing_module()

    # noinspection PyAttributeOutsideInit
    def share_feed_forward(self, encoder):
        self.preprocess_ffn = encoder.preprocess_ffn
        self.postprocess_ffn = encoder.postprocess_ffn
        self.feed_forward = encoder.feed_forward

    def feed_forward_layer(self, inputs, input_mask=None):
        out = self.preprocess_ffn(inputs, mask=input_mask)
        out = self.feed_forward(
            out, mask=input_mask if self.masked_layers else None)
        out = self.postprocess_ffn(out, inputs)
        return out

    def feed_forward_step(self, inputs, input_mask):
        return self.feed_forward_layer(inputs, input_mask)

    def forward(self, inputs, context, input_mask=None, context_mask=None,
                self_attention_bias=None, encoder_attention_bias=None):
        self_attention_out = self.self_attention_layer(
            inputs, input_mask, self_attention_bias)

        if not self.ignore_context:
            context_attention_out, attention_weights = self.encoder_attention_layer(
                self_attention_out, context, input_mask, context_mask,
                encoder_attention_bias)
        else:
            context_attention_out = self_attention_out
            attention_weights = None

        out = self.feed_forward_layer(context_attention_out, input_mask)
        return out, attention_weights

    def _step(self, inputs, encoder_outputs, incremental_state, input_mask=None,
              context_mask=None, self_attention_bias=None,
              encoder_attention_bias=None):
        self_attention_out = self.self_attention_step(
            inputs, incremental_state, input_mask, self_attention_bias)

        if not self.ignore_context:
            enc_attention_out, attention_weights = self.encoder_attention_step(
                self_attention_out, encoder_outputs, incremental_state,
                input_mask, context_mask, encoder_attention_bias)
        else:
            enc_attention_out = self_attention_out
            attention_weights = None

        out = self.feed_forward_step(enc_attention_out, input_mask)
        return out, attention_weights

    def _update_names(self, state_dict, prefix, local_metadata, strict,
                      missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', 1)
        if version == 1 and prefix + 'version' not in state_dict:
            # rename parameters saved with the version 1 layout
            for key in self.preprocess_self_attn.state_dict().keys():
                state_dict[prefix + 'preprocess_self_attn.' + key] = \
                    state_dict.pop(prefix + 'preprocess_attn.' + key)
            for key in self.preprocess_enc_attn.state_dict().keys():
                state_dict[prefix + 'preprocess_enc_attn.' + key] = \
                    state_dict.pop(prefix + 'preprocess_src_attn.' + key)
            for key in self.self_attention.state_dict().keys():
                state_dict[prefix + 'self_attention.' + key] = \
                    state_dict.pop(prefix + 'attention_tgt.' + key)
            for key in self.enc_attention.state_dict().keys():
                state_dict[prefix + 'enc_attention.' + key] = \
                    state_dict.pop(prefix + 'attention_src.' + key)
        elif version == 1:
            del state_dict[prefix + 'version']
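# Usage sketch (illustrative, not part of the original file): one decoder
# layer applied in teacher-forcing mode. Shapes assume batch_first=False as
# described in the docstring; the causal bias is the usual upper-triangular
# -inf matrix, broadcast over the batch dimension per the
# self_attention_mask contract above. Dimensions are made up for the example.
#
#   layer = TransformerDecoderLayer(model_dim=512, num_heads=8)
#   tgt_len, src_len, batch = 7, 11, 2
#   inputs = torch.randn(tgt_len, batch, 512)
#   context = torch.randn(src_len, batch, 512)
#   causal_bias = torch.triu(
#       inputs.new_full((tgt_len, tgt_len), float('-inf')), diagonal=1)
#   out, attn = layer(inputs, context, self_attention_bias=causal_bias)
#   # out: (tgt_len, batch, 512); attn: encoder-decoder attention weights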
class NMTDecoder(IncrementalDecoder):
    """Wraps a Decoder and adds embedding and projection"""

    def __init__(self, decoder, embedding, dropout, linear, *,
                 copy_decoder=False, batch_first=False, extra_attention=False,
                 masked_layers=False, attention_dropout=0.1,
                 language_embedding=None):
        super().__init__()
        self.decoder = decoder
        self.embedded_dropout = EmbeddingDropout(embedding, dropout)
        self.linear = linear
        self.copy_decoder = copy_decoder
        self.batch_first = batch_first
        self.extra_attention = extra_attention

        if self.copy_decoder:
            model_dim = linear.weight.size(1)
            self.gate_layer = XavierLinear(model_dim, 1)
            if extra_attention:
                self.attention = MultiHeadAttention(
                    model_dim, 1, attention_dropout, batch_first, masked_layers)
            self._register_load_state_dict_pre_hook(
                self._load_nmt_model_compatibility)

        if language_embedding is not None:
            self.language_embedding = language_embedding
            model_dim = self.embedded_dropout.embedding.weight.size(1)
            emb_dim = language_embedding.weight.size(1)
            self.merge_layer = XavierLinear(model_dim + emb_dim, model_dim)
        else:
            self.language_embedding = None

    def forward(self, decoder_inputs, encoder_outputs, decoder_mask=None,
                encoder_mask=None):
        if self.language_embedding is not None:
            indices, language_id = decoder_inputs
            emb = torch.cat((self.embedded_dropout(indices),
                             self.language_embedding(language_id)), dim=-1)
            emb = self.merge_layer(emb)
        else:
            emb = self.embedded_dropout(decoder_inputs)

        out, attention_weights = self.decoder(emb, encoder_outputs,
                                              decoder_mask, encoder_mask)

        if self.copy_decoder:
            if self.extra_attention:
                source_attention_bias = self.get_encoder_attention_bias(
                    encoder_outputs, self.batch_first, encoder_mask)
                _, attention_weights = self.attention(
                    out, encoder_outputs, encoder_outputs,
                    source_attention_bias, decoder_mask, encoder_mask)
            gates = torch.sigmoid(self.gate_layer(out)).squeeze(-1)

        if self.training and decoder_mask is not None:
            # Optimize the projection by calculating only those positions
            # where the input was not padding
            nonpad_indices = torch.nonzero(decoder_mask.view(-1)).squeeze(1)
            out = out.view(-1, out.size(-1))
            out = out.index_select(0, nonpad_indices)
            # For multihead attention, the batch size dimension will be
            # bigger. That means the results of this operation are garbage
            if attention_weights is not None:
                attention_weights = attention_weights.view(
                    -1, attention_weights.size(-1))
                attention_weights = attention_weights.index_select(
                    0, nonpad_indices)
            if self.copy_decoder:
                gates = gates.masked_select(decoder_mask)

        if self.copy_decoder:
            attention_weights = {'attn': attention_weights, 'gates': gates}

        return self.linear(out), attention_weights

    def _step(self, decoder_inputs, encoder_outputs, incremental_state,
              decoder_mask=None, encoder_mask=None):
        emb = self.embedded_dropout(decoder_inputs)
        out, attention_weights = self.decoder.step(
            emb, encoder_outputs, incremental_state, decoder_mask, encoder_mask)

        if self.copy_decoder:
            if self.extra_attention:
                source_attention_bias = self.get_encoder_attention_bias(
                    encoder_outputs, self.batch_first, encoder_mask)
                _, attention_weights = self.attention(
                    out, encoder_outputs, encoder_outputs,
                    source_attention_bias, decoder_mask, encoder_mask)
            gates = torch.sigmoid(self.gate_layer(out)).squeeze(-1)
            attention_weights = {'attn': attention_weights, 'gates': gates}

        return self.linear(out), attention_weights

    def get_normalized_probs(self, decoder_outputs, attention_weights,
                             encoder_inputs=None, encoder_mask=None,
                             decoder_mask=None, log_probs=False):
        decoder_probs = self.decoder.get_normalized_probs(
            decoder_outputs, attention_weights, encoder_inputs, encoder_mask,
            decoder_mask, log_probs)

        if not self.copy_decoder:
            return decoder_probs

        attention_weights, gates = (attention_weights['attn'],
                                    attention_weights['gates'])
        gates = gates.unsqueeze(-1)

        optimized = decoder_outputs.dim() == 2

        if not self.batch_first:
            encoder_inputs = encoder_inputs.transpose(0, 1).unsqueeze(0)  # (1, batch, src_len)
        if optimized:
            # (batch, tgt_len, src_len) | (tgt_len, batch, src_len)
            new_size = list(decoder_mask.size()) + [encoder_inputs.size(-1)]
            nonpad_indices = torch.nonzero(decoder_mask.view(-1)).squeeze(1)
            encoder_inputs = encoder_inputs.expand(new_size).contiguous() \
                .view(-1, encoder_inputs.size(-1)) \
                .index_select(0, nonpad_indices)
            # encoder_inputs is now (decoder_outputs.size(0), src_len)
        else:
            encoder_inputs = encoder_inputs.expand_as(attention_weights)

        assert encoder_inputs.size() == attention_weights.size()
        encoder_probs = decoder_probs.new_full(decoder_probs.size(), 1e-20)
        encoder_probs.scatter_add_(1 if optimized else 2, encoder_inputs,
                                   attention_weights)

        if log_probs:
            encoder_probs.log_()
            encoder_probs.add_(torch.log(gates))
            decoder_probs.add_(torch.log(1 - gates))
            # Very important to have it this way around, otherwise we will
            # add -inf + inf = NaN
            res = decoder_probs + torch.log1p(
                torch.exp(encoder_probs - decoder_probs))
            return res
        else:
            return gates * encoder_probs + (1 - gates) * decoder_probs

    def reorder_incremental_state(self, incremental_state, new_order):
        self.decoder.reorder_incremental_state(incremental_state, new_order)
        if self.extra_attention:
            self.attention.reorder_incremental_state(incremental_state,
                                                     new_order)

    def _load_nmt_model_compatibility(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs):
        if prefix + 'gate_layer.weight' in state_dict:
            return
        logger.info('Augmenting NMTModel with a copy decoder')
        items = self.gate_layer.state_dict(prefix=prefix + 'gate_layer.').items()
        if self.extra_attention:
            items = itertools.chain(
                items,
                self.attention.state_dict(prefix=prefix + 'attention.').items())
        for key, value in items:
            assert key not in state_dict
            state_dict[key] = value
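# Numerical note on the log_probs branch of get_normalized_probs above
# (illustrative): for a gate g, the mixture g * p_enc + (1 - g) * p_dec is
# evaluated in log space via the two-term log-sum-exp identity
# log(exp(a) + exp(b)) = a + log1p(exp(b - a)). Anchoring on the decoder term
# matters: encoder_probs starts from 1e-20 (log ~ -46) for tokens the source
# never contains, so exp(encoder - decoder) underflows harmlessly to 0,
# whereas anchoring the other way could produce -inf + inf = NaN.
#
#   a = decoder_probs + torch.log(1 - gates)   # log((1 - g) * p_dec)
#   b = encoder_probs + torch.log(gates)       # log(g * p_enc)
#   res = a + torch.log1p(torch.exp(b - a))    # log(exp(a) + exp(b))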
# Parity-test fragment: runs the same inputs through this module's
# MultiHeadAttention ("felix") with projection weights copied from a
# reference implementation ("quan_attention") so that outputs and gradients
# can be compared. `sa` selects self-attention, `nm` disables the attention
# bias, and `bf` selects batch-first inputs.
in_a_felix = in_tensor_A_Felix
in_b_felix = in_tensor_A_Felix if sa else in_tensor_B_Felix

if nm:
    bias_felix = None
elif sa:
    # combine the target padding mask with the causal (future) mask
    mask_felix = (tgt_mask_felix.unsqueeze(1) + future_mask_felix).gt_(1)
    bias_felix = in_b_felix.new_full(mask_felix.size(),
                                     float('-inf')).masked_fill(mask_felix, 0)
else:
    bias_felix = in_a_felix.new_full(src_mask_felix.size(),
                                     float('-inf')).masked_fill(
        src_mask_felix, 0).unsqueeze(1)

felix_attention = MultiHeadAttention(512, 8, 0.0, bf, False)
# copy projection weights from the reference implementation
felix_attention.query_projection.function.weight = quan_attention.fc_query.function.linear.weight
felix_attention.key_projection.function.weight = quan_attention.fc_key.function.linear.weight
felix_attention.value_projection.function.weight = quan_attention.fc_value.function.linear.weight
felix_attention.out_projection.function.weight = quan_attention.fc_concat.function.linear.weight

out_tensor_felix, _ = felix_attention(in_a_felix, in_b_felix, in_b_felix,
                                      bias_felix, tgt_mask_felix,
                                      src_mask_felix)
if bf:
    out_tensor_felix = out_tensor_felix.transpose(0, 1).contiguous()
out_tensor_felix.sum().backward()
grads_felix = felix_attention.query_projection.function.weight.grad \
    .clone().detach().cpu()
grads_felix2 = felix_attention.out_projection.function.weight.grad \
    .clone().detach().cpu()