def test_add_positional_features(self):
    # This is hard to test, so we check that we get the same result as the
    # original tensorflow implementation:
    # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py#L270
    tensor2tensor_result = numpy.asarray([[0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
                                          [8.41470957e-01, 9.99999902e-05, 5.40302277e-01, 1.00000000e+00],
                                          [9.09297407e-01, 1.99999980e-04, -4.16146845e-01, 1.00000000e+00]])

    tensor = torch.zeros([2, 3, 4])
    result = util.add_positional_features(tensor, min_timescale=1.0, max_timescale=1.0e4)
    numpy.testing.assert_almost_equal(result[0].detach().cpu().numpy(), tensor2tensor_result)
    numpy.testing.assert_almost_equal(result[1].detach().cpu().numpy(), tensor2tensor_result)

    # Check case with odd number of dimensions.
    tensor2tensor_result = numpy.asarray([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
                                           1.00000000e+00, 1.00000000e+00, 0.00000000e+00],
                                          [8.41470957e-01, 9.99983307e-03, 9.99999902e-05, 5.40302277e-01,
                                           9.99949992e-01, 1.00000000e+00, 0.00000000e+00],
                                          [9.09297407e-01, 1.99986659e-02, 1.99999980e-04, -4.16146815e-01,
                                           9.99800026e-01, 1.00000000e+00, 0.00000000e+00]])

    tensor = torch.zeros([2, 3, 7])
    result = util.add_positional_features(tensor, min_timescale=1.0, max_timescale=1.0e4)
    numpy.testing.assert_almost_equal(result[0].detach().cpu().numpy(), tensor2tensor_result)
    numpy.testing.assert_almost_equal(result[1].detach().cpu().numpy(), tensor2tensor_result)
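# For reference: the expected arrays in the test above are tensor2tensor's sinusoidal timing
# signal. A minimal NumPy sketch of that scheme, so the numbers can be sanity-checked without
# TensorFlow (the function name and the guard against a single timescale are my own, not part
# of the library):
import numpy as np

def sinusoidal_timing_signal(length: int, channels: int,
                             min_timescale: float = 1.0,
                             max_timescale: float = 1.0e4) -> np.ndarray:
    """Return a (length, channels) array of sin/cos positional features."""
    num_timescales = channels // 2
    # Geometrically spaced timescales between min_timescale and max_timescale.
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inverse_timescales = min_timescale * np.exp(-np.arange(num_timescales) * log_increment)
    # (length, num_timescales): position index times inverse timescale.
    scaled_time = np.arange(length)[:, None] * inverse_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # An odd channel count gets one zero column of padding on the right,
    # which is where the trailing 0.0 column in the 7-dim case comes from.
    if channels % 2 == 1:
        signal = np.concatenate([signal, np.zeros((length, 1))], axis=1)
    return signal

# sinusoidal_timing_signal(3, 4) reproduces the first expected array above, row for row.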
def _encode(self,
            source_tokens: Dict[str, torch.Tensor],
            segments: Dict[str, torch.Tensor],
            source_entity_length: torch.Tensor,
            edge_mask: torch.Tensor,
            ) -> Dict[str, torch.Tensor]:
    """
    :param source_tokens:
    :param segments:
    :param merge_indicators:
    :return:
    """
    # shape: (batch_size, encode_length, embedding_dim)
    source_embedded_input = self._embed_source(source_tokens, source_entity_length)
    # shape: (batch_size, encode_length, embedding_dim)
    segments_embedded_input = self._segment_embedder(segments)
    encode_length = segments_embedded_input.size(1)
    assert source_embedded_input.size(1) == segments_embedded_input.size(1)

    # token_mask = (segments['tokens'] == self._token_index).unsqueeze(-1).float()
    # valid_token_embedded_input = batched_embedded_input * token_mask
    # valid_token_embedded_input = util.add_positional_features(valid_token_embedded_input)
    # valid_token_embedded_input = batched_embedded_input * (1 - token_mask) + valid_token_embedded_input * token_mask

    if self._source_embedding_dim == self._encoder_d_model:
        batched_embedded_input = segments_embedded_input + source_embedded_input
        final_embedded_input = util.add_positional_features(batched_embedded_input)
    else:
        batched_embedded_input = torch.cat([source_embedded_input, segments_embedded_input], dim=-1)
        final_embedded_input = util.add_positional_features(batched_embedded_input)

    # shape: (encode_length, batch_size, d_model)
    final_embedded_input = final_embedded_input.permute(1, 0, 2)
    # shape: (batch_size, encode_length)
    source_mask = util.get_text_field_mask(segments)
    source_key_padding_mask = (1 - source_mask.byte()).bool()

    if not self._use_gnn_encoder:
        # shape: (encode_length, batch_size, d_model)
        encoder_outputs = self._encoder(final_embedded_input,
                                        src_key_padding_mask=source_key_padding_mask)
    else:
        # GNN encoders
        encoder_outputs = self._encoder(src=final_embedded_input,
                                        edge_mask=edge_mask.permute(0, 2, 3, 1),
                                        padding_mask=source_key_padding_mask)

    source_token_mask = (segments['tokens'] == self._token_index).float()
    return {
        "source_mask": source_mask,
        "source_key_padding_mask": source_key_padding_mask,
        "source_token_mask": source_token_mask,
        "encoder_outputs": encoder_outputs,
        "source_embedded": batched_embedded_input,
        "source_raw_embedded": source_embedded_input,
    }
def forward(self,  # type: ignore
            abstract_text: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    embedded_abstract_text = self.text_field_embedder(abstract_text, num_wrapping_dims=1)
    abstract_text_mask = util.get_text_field_mask(abstract_text, num_wrapping_dims=1)

    num_instance, num_sentence, num_word, _ = embedded_abstract_text.size()
    embedded_abstract_text = embedded_abstract_text.view(num_instance * num_sentence, num_word, -1)
    abstract_text_mask = abstract_text_mask.view(num_instance * num_sentence, -1)

    if self.use_positional_encoding:
        embedded_abstract_text = util.add_positional_features(embedded_abstract_text)
    encoded_abstract_text = self.word_encoder(embedded_abstract_text, abstract_text_mask)
    attended_sentences = self.word_level_attention(encoded_abstract_text, abstract_text_mask)

    attended_sentences = attended_sentences.view(num_instance, num_sentence, -1)
    abstract_text_mask = abstract_text_mask.view(num_instance, num_sentence, -1).sum(2).ge(1).long()

    if self.use_positional_encoding:
        attended_sentences = util.add_positional_features(attended_sentences)
    attended_sentences = self.sentence_encoder(attended_sentences, abstract_text_mask)
    attended_abstract_text = self.sentence_level_attention(attended_sentences, abstract_text_mask)

    outputs = self.classifier_feedforward(attended_abstract_text)
    logits = torch.sigmoid(outputs)
    logits = logits.unsqueeze(0) if logits.dim() < 2 else logits
    output_dict = {'logits': logits}

    if label is not None:
        outputs = outputs.unsqueeze(0) if outputs.dim() < 2 else outputs
        loss = self.loss(outputs, label.squeeze(-1))
        for metric in self.metrics.values():
            metric(logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self, **inputs):
    x = self.aa_embedder(inputs['aa']) + self.pssm_embedder(inputs['pssm'])
    if 'ss' in inputs and self.use_ss:
        x += self.ss_embedder(inputs['ss'])
    if self.use_positional_encoding:
        # add_positional_features is not in-place, so the result must be assigned back.
        x = add_positional_features(x)
    x = self.input_dropout(x)

    mask = get_text_field_mask(inputs['aa'])
    x = self.encoder(x, mask)
    if 'msa' in inputs and self.msa_encoder is not None:
        x += self.msa_encoder(inputs['msa'])
    x = self.decoder(self.feedforward(x))

    outputs = {
        'predictions': x,
        'protein_id': inputs['protein_id'],
        'length': inputs['length']
    }

    if 'dcalpha' in inputs:
        mask = mask.unsqueeze(-1).float()
        if self.target == 'dcalpha':
            target = inputs['dcalpha']
            mask = mask.matmul(mask.transpose(dim0=-2, dim1=-1))
            mask_triu = torch.triu(torch.ones_like(mask[0]), diagonal=1).unsqueeze(0).to(mask.device)
            mask *= mask_triu
        elif self.target == 'angles':
            target = torch.stack([inputs['psi'], inputs['phi']], dim=2)
        else:
            target = inputs['coords']
        mse = ((mask * (x - target))**2).sum() / mask.sum()

        if torch.isnan(mse):
            # Interactive debugging hook: drops into an eval loop when the loss goes NaN.
            while True:
                expr = input('\nInput = ')
                if expr == 'q':
                    exit(0)
                try:
                    print(eval(expr))
                except Exception as e:
                    print(e)

        self.metrics['rmse']((mse**0.5).detach().item())
        outputs['loss'] = mse

    return outputs
def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
    if self._use_positional_encoding:
        output = add_positional_features(inputs)
    else:
        output = inputs

    for i in range(len(self._attention_layers)):
        # It's necessary to use `getattr` here because the elements stored
        # in the lists are not replicated by torch.nn.parallel.replicate
        # when running on multiple GPUs. Please use `ModuleList` in new
        # code. It handles this issue transparently. We've kept `add_module`
        # (in conjunction with `getattr`) solely for backwards compatibility
        # with existing serialized models.
        attention = getattr(self, f"self_attention_{i}")
        feedforward = getattr(self, f"feedforward_{i}")
        feedforward_layer_norm = getattr(self, f"feedforward_layer_norm_{i}")
        layer_norm = getattr(self, f"layer_norm_{i}")

        cached_input = output
        # Project output of attention encoder through a feedforward
        # network and back to the input size for the next layer.
        # shape (batch_size, timesteps, input_size)
        feedforward_output = feedforward(output)
        feedforward_output = self.dropout(feedforward_output)
        if feedforward_output.size() == cached_input.size():
            # First layer might have the wrong size for highway
            # layers, so we exclude it here.
            feedforward_output = feedforward_layer_norm(feedforward_output + cached_input)
        # shape (batch_size, sequence_length, hidden_dim)
        attention_output = attention(feedforward_output, mask)
        output = layer_norm(self.dropout(attention_output) + feedforward_output)

    return output
def forward(self,  # type: ignore
            abstract_text: Dict[str, torch.LongTensor],
            local_label: torch.LongTensor = None,
            global_label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    embedded_abstract_text = self.text_field_embedder(abstract_text)
    abstract_text_mask = util.get_text_field_mask(abstract_text)

    if self.use_positional_encoding:
        embedded_abstract_text = util.add_positional_features(embedded_abstract_text)
    encoded_abstract_text = self.abstract_text_encoder(embedded_abstract_text, abstract_text_mask)
    attended_abstract_text = self.attention_encoder(encoded_abstract_text, abstract_text_mask)

    local_outputs, global_outputs = self.HMCN_recurrent(attended_abstract_text)
    logits = self.local_globel_tradeoff * global_outputs + (1 - self.local_globel_tradeoff) * local_outputs
    logits = torch.sigmoid(logits)
    logits = logits.unsqueeze(0) if logits.dim() < 2 else logits
    output_dict = {'logits': logits}

    if local_label is not None and global_label is not None:
        local_outputs = local_outputs.unsqueeze(0) if local_outputs.dim() < 2 else local_outputs
        global_outputs = global_outputs.unsqueeze(0) if global_outputs.dim() < 2 else global_outputs
        loss = self.loss(local_outputs, global_outputs, local_label.squeeze(-1), global_label.squeeze(-1))
        for metric in self.metrics.values():
            metric(logits, global_label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            abstract_text: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    embedded_abstract_text = self.text_field_embedder(abstract_text)
    abstract_text_mask = util.get_text_field_mask(abstract_text)
    encoded_abstract_text = self.abstract_text_encoder(embedded_abstract_text, abstract_text_mask)

    if self.use_positional_encoding:
        encoded_abstract_text = util.add_positional_features(encoded_abstract_text)
    attended_abstract_text = self.attention_encoder(encoded_abstract_text, abstract_text_mask)

    outputs = self.classifier_feedforward(attended_abstract_text)
    logits = torch.sigmoid(outputs)
    logits = logits.unsqueeze(0) if len(logits.size()) < 2 else logits
    output_dict = {'logits': logits}

    if label is not None:
        outputs = outputs.unsqueeze(0) if len(outputs.size()) < 2 else outputs
        loss = self.loss(outputs, label.squeeze(-1))
        for metric in self.metrics.values():
            metric(logits, label.squeeze(-1))
        output_dict["loss"] = loss

    return output_dict
def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
    output = inputs
    if self._sinusoidal_positional_encoding:
        output = add_positional_features(output)
    if self._positional_embedding is not None:
        position_ids = torch.arange(inputs.size(1), dtype=torch.long, device=output.device)
        position_ids = position_ids.unsqueeze(0).expand(inputs.shape[:-1])
        output = output + self._positional_embedding(position_ids)

    # print()
    # print(sum(output[0][4]), sum(output[0][100]))

    # For some reason the torch transformer expects the shape (sequence, batch, features), not the more
    # familiar (batch, sequence, features), so we have to fix it.
    output = output.permute(1, 0, 2)
    # For some other reason, the torch transformer takes the mask backwards.
    mask = ~mask
    output = self._transformer(output, src_key_padding_mask=mask)
    output = output.permute(1, 0, 2)

    # print(sum(inputs[0][4]), sum(inputs[0][100]))
    # print(sum(output[0][4]), sum(output[0][100]))
    # print()

    return output
def forward(self, inputs, mask):  # pylint: disable=arguments-differ
    if self._use_positional_encoding:
        output = add_positional_features(inputs)
    else:
        output = inputs

    for (attention, feedforward, feedforward_layer_norm, layer_norm) in zip(
            self._attention_layers,
            self._feedfoward_layers,
            self._feed_forward_layer_norm_layers,
            self._layer_norm_layers,
    ):
        cached_input = output
        # Project output of attention encoder through a feedforward
        # network and back to the input size for the next layer.
        # shape (batch_size, timesteps, input_size)
        feedforward_output = feedforward(feedforward_layer_norm(output))
        feedforward_output = self.dropout(feedforward_output)
        if feedforward_output.size() == cached_input.size():
            # First layer might have the wrong size for highway
            # layers, so we exclude it here.
            feedforward_output += cached_input
        # shape (batch_size, sequence_length, hidden_dim)
        attention_output = attention(layer_norm(feedforward_output), mask)
        output = self.dropout(attention_output) + feedforward_output

    return self._output_layer_norm(output)
def forward(self, inputs: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
    if self._use_positional_encoding:
        output = add_positional_features(inputs)
    else:
        output = inputs

    for (attention, feedforward, feedforward_layer_norm, layer_norm) in zip(
            self._attention_layers,
            self._feedfoward_layers,
            self._feed_forward_layer_norm_layers,
            self._layer_norm_layers):
        cached_input = output
        # Project output of attention encoder through a feedforward
        # network and back to the input size for the next layer.
        # shape (batch_size, timesteps, input_size)
        feedforward_output = feedforward(output)
        feedforward_output = self.dropout(feedforward_output)
        if feedforward_output.size() == cached_input.size():
            # First layer might have the wrong size for highway
            # layers, so we exclude it here.
            feedforward_output = feedforward_layer_norm(feedforward_output + cached_input)
        # shape (batch_size, sequence_length, hidden_dim)
        attention_output = attention(feedforward_output, mask)
        output = layer_norm(self.dropout(attention_output) + feedforward_output)

    return output
def forward(self, inputs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    if self._use_positional_encoding:
        output = add_positional_features(inputs)
    else:
        output = inputs

    total_sublayers = len(self._conv_layers) + 2
    sublayer_count = 0

    for conv_norm_layer, conv_layer in zip(self._conv_norm_layers, self._conv_layers):
        conv_norm_out = self.dropout(conv_norm_layer(output))
        conv_out = self.dropout(conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(
            output, conv_out, sublayer_count, total_sublayers
        )

    attention_norm_out = self.dropout(self.attention_norm_layer(output))
    attention_out = self.dropout(self.attention_layer(attention_norm_out, mask))
    sublayer_count += 1
    output = self.residual_with_layer_dropout(
        output, attention_out, sublayer_count, total_sublayers
    )

    feedforward_norm_out = self.dropout(self.feedforward_norm_layer(output))
    feedforward_out = self.dropout(self.feedforward(feedforward_norm_out))
    sublayer_count += 1
    output = self.residual_with_layer_dropout(
        output, feedforward_out, sublayer_count, total_sublayers
    )

    return output
def forward(self,
            inputs: torch.Tensor,
            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    output = add_positional_features(inputs)
    output = self.projection(output)
    for block in self.blocks:
        output = block(output, mask)
    return output
def _run_transformer_decoder(self, context_output: torch.Tensor,
                             query_output: torch.Tensor,
                             context_mask: torch.Tensor,
                             query_mask: torch.Tensor,
                             rewrite_embed: torch.Tensor,
                             rewrite_mask: torch.Tensor):
    """
    Run the decoding pass of the Transformer decoder.

    :param context_output / query_output: [B, _len, d_model]
    :param context_mask / query_mask: [B, _len]
    :param rewrite_embed: [B, cur_dec_len, d_model]
    :param rewrite_mask: [B, cur_dec_len]; this is only the padding mask, the
        upper-triangular (subsequent) mask is built inside the decoder.
    """
    if self._share_decoder_params:
        rewrite_embed = add_positional_features(rewrite_embed)

    previous_state = None
    encoder_outputs = {
        "context_output": context_output,
        "query_output": query_output
    }
    source_mask = {"context_mask": context_mask, "query_mask": query_mask}

    # dec_output: [B, dec_len, d_model]
    # context_attn: [B, num_heads, dec_len, context_len]
    # query_attn: [B, num_heads, dec_len, query_len]
    # x_context: [B, dec_len, d_model]
    # x_query: [B, dec_len, d_model]
    dec_output, context_attn, query_attn, x_context, x_query = self.decoder(
        previous_state, encoder_outputs, source_mask, rewrite_embed, rewrite_mask)

    # If the decoder parameters are shared, reuse the same layer for the remaining passes.
    if self._share_decoder_params:
        for _ in range(self.decoder_num_layers - 1):
            dec_output, context_attn, query_attn, x_context, x_query = self.decoder(
                previous_state, encoder_outputs, source_mask, dec_output, rewrite_mask)

    # sum the attention dists of different heads
    context_attn = torch.sum(context_attn, dim=1, keepdim=False)
    query_attn = torch.sum(query_attn, dim=1, keepdim=False)
    # mask softmax get the final attention dists
    context_attn = masked_softmax(context_attn, context_mask, dim=-1)
    query_attn = masked_softmax(query_attn, query_mask, dim=-1)

    # compute lambda, shape [B, dec_len, 2].
    # Note the difference from the LSTM decoder: the Transformer decoder emits the
    # whole sequence at once, so the length dimension has to be kept.
    lamb = self._compute_lambda(dec_output, dec_context=x_context, dec_query=x_query)
    return dec_output, context_attn, query_attn, lamb
def forward(self,
            src: torch.Tensor,
            kb: torch.Tensor,
            src_mask: Optional[torch.Tensor] = None,
            kb_mask: Optional[torch.Tensor] = None,
            ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    src = add_positional_features(src)
    src = self.src_projection(src)
    kb = self.kb_projection(kb)

    for i in range(self.num_layers):
        src = self.src_self_blocks[i](src, src_mask)
        kb = self.kb_self_blocks[i](kb, kb_mask)
        src, kb = self.rel_blocks[i](kb, src, kb_mask, src_mask)

    if self.return_kb:
        return src, kb
    return src
def _forward(self,
             embedding_sequence: torch.Tensor,
             state: Dict[str, torch.Tensor],
             apply_target_subsequent_mask: bool = True) -> torch.Tensor:
    if self.use_position_encoding:
        embedding_sequence = add_positional_features(embedding_sequence)

    output = embedding_sequence
    for layer in self.layers:
        output = layer(
            output,
            state['target_mask'] if 'target_mask' in state else None,
            apply_target_subsequent_mask=apply_target_subsequent_mask)

    if self.pre_norm:
        # We need to add an additional function of layer normalization to the top layer
        # to prevent the excessively increased value caused by the sum of unnormalized output
        # (https://arxiv.org/pdf/1906.01787.pdf)
        output = self.output_layer_norm(output)

    return output
def forward(self,
            previous_state: Dict[str, torch.Tensor],
            encoder_outputs: torch.Tensor,
            source_mask: torch.Tensor,
            previous_steps_predictions: torch.Tensor,
            previous_steps_mask: Optional[torch.Tensor] = None
            ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    seq_len = previous_steps_predictions.size(-2)
    future_mask = torch.triu(
        torch.ones(seq_len, seq_len, device=source_mask.device, dtype=torch.float)).transpose(0, 1)
    future_mask = future_mask.masked_fill(future_mask == 0, float('-inf')).masked_fill(
        future_mask == 1, float(0.0))
    future_mask = Variable(future_mask)

    if self._use_positional_encoding:
        previous_steps_predictions = add_positional_features(previous_steps_predictions)
    previous_steps_predictions = self._dropout(previous_steps_predictions)

    source_mask = ~(source_mask.bool())
    if previous_steps_mask is not None:
        previous_steps_mask = ~(previous_steps_mask.bool())

    previous_steps_predictions = previous_steps_predictions.permute(1, 0, 2)
    encoder_outputs = encoder_outputs.permute(1, 0, 2)
    output = self._decoder(previous_steps_predictions,
                           encoder_outputs,
                           tgt_mask=future_mask,
                           tgt_key_padding_mask=previous_steps_mask,
                           memory_key_padding_mask=source_mask)
    return {}, output.permute(1, 0, 2)
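# Several decoders in this collection build the same additive causal ("future") mask by hand
# via triu / transpose / masked_fill. A standalone sketch of that construction follows; the
# helper is illustrative only and not part of any of the models above. PyTorch's
# nn.Transformer.generate_square_subsequent_mask produces the same pattern.
import torch

def subsequent_mask(seq_len: int, device=None) -> torch.Tensor:
    """Additive causal mask: 0.0 on and below the diagonal, -inf above it."""
    allowed = (torch.triu(torch.ones(seq_len, seq_len, device=device)) == 1).transpose(0, 1)
    return allowed.float().masked_fill(allowed == 0, float('-inf')).masked_fill(allowed == 1, 0.0)

# subsequent_mask(4):
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])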
def forward(self, embedding_sequence: torch.LongTensor) -> torch.Tensor:
    """
    Parameters
    ----------
    embedding_sequence : (batch_size, sequence_length, embedding_size)

    Returns
    -------
    (batch_size, sequence_length, embedding_size)
    """
    if self.learned:
        batch_size, sequence_length, _ = embedding_sequence.size()
        position_indices = torch.arange(sequence_length).to(self.embedding.weight.device)
        position_embeddings = self.embedding(position_indices)
        position_embeddings = position_embeddings.unsqueeze(0).expand(batch_size, sequence_length, -1)
        return embedding_sequence + position_embeddings
    else:
        return util.add_positional_features(embedding_sequence)
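# The forward above assumes `self.learned` and `self.embedding` were set up in the module's
# constructor. A self-contained reconstruction of such a module is sketched below; the
# constructor, including the max_length default, is my own assumption rather than the
# original code.
import torch
from torch import nn
from allennlp.nn import util


class PositionalEncoding(nn.Module):
    """Adds either learned or fixed sinusoidal position information to an embedded sequence."""

    def __init__(self, embedding_size: int, max_length: int = 512, learned: bool = False):
        super().__init__()
        self.learned = learned
        if learned:
            # One trainable vector per position, up to max_length positions.
            self.embedding = nn.Embedding(max_length, embedding_size)

    def forward(self, embedding_sequence: torch.Tensor) -> torch.Tensor:
        if self.learned:
            batch_size, sequence_length, _ = embedding_sequence.size()
            position_indices = torch.arange(sequence_length, device=self.embedding.weight.device)
            position_embeddings = self.embedding(position_indices)
            position_embeddings = position_embeddings.unsqueeze(0).expand(batch_size, sequence_length, -1)
            return embedding_sequence + position_embeddings
        # Fixed sinusoidal features; no parameters and no length limit.
        return util.add_positional_features(embedding_sequence)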
def forward(self, query_embeddings: torch.Tensor, document_embeddings: torch.Tensor, query_pad_oov_mask: torch.Tensor, document_pad_oov_mask: torch.Tensor, output_secondary_output: bool = False) -> torch.Tensor: # pylint: disable=arguments-differ query_embeddings = query_embeddings * query_pad_oov_mask.unsqueeze(-1) document_embeddings = document_embeddings * document_pad_oov_mask.unsqueeze( -1) query_embeddings_context = self.contextualizer( add_positional_features(query_embeddings).transpose(1, 0), src_key_padding_mask=~query_pad_oov_mask.bool()).transpose(1, 0) document_embeddings_context = self.contextualizer( add_positional_features(document_embeddings).transpose(1, 0), src_key_padding_mask=~document_pad_oov_mask.bool()).transpose( 1, 0) query_embeddings = (self.mixer * query_embeddings + (1 - self.mixer) * query_embeddings_context ) * query_pad_oov_mask.unsqueeze(-1) document_embeddings = (self.mixer * document_embeddings + (1 - self.mixer) * document_embeddings_context ) * document_pad_oov_mask.unsqueeze(-1) # # prepare embedding tensors & paddings masks # ------------------------------------------------------- query_by_doc_mask = torch.bmm( query_pad_oov_mask.unsqueeze(-1), document_pad_oov_mask.unsqueeze(-1).transpose(-1, -2)) query_by_doc_mask_view = query_by_doc_mask.unsqueeze(-1) # # cosine matrix # ------------------------------------------------------- # shape: (batch, query_max, doc_max) cosine_matrix = self.cosine_module.forward(query_embeddings, document_embeddings) cosine_matrix_masked = cosine_matrix * query_by_doc_mask cosine_matrix_extradim = cosine_matrix_masked.unsqueeze(-1) # # gaussian kernels & soft-TF # # first run through kernel, then sum on doc dim then sum on query dim # ------------------------------------------------------- raw_kernel_results = torch.exp( -torch.pow(cosine_matrix_extradim - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) kernel_results_masked = raw_kernel_results * query_by_doc_mask_view # # mean kernels # #kernel_results_masked2 = kernel_results_masked.clone() doc_lengths = torch.sum(document_pad_oov_mask, 1) #kernel_results_masked2_mean = kernel_results_masked / doc_lengths.unsqueeze(-1) per_kernel_query = torch.sum(kernel_results_masked, 2) log_per_kernel_query = torch.log2( torch.clamp(per_kernel_query, min=1e-10)) * self.nn_scaler log_per_kernel_query_masked = log_per_kernel_query * query_pad_oov_mask.unsqueeze( -1) # make sure we mask out padding values per_kernel = torch.sum(log_per_kernel_query_masked, 1) #per_kernel_query_mean = torch.sum(kernel_results_masked2_mean, 2) per_kernel_query_mean = per_kernel_query / ( doc_lengths.view(-1, 1, 1) + 1 ) # well, that +1 needs an explanation, sometimes training data is just broken ... (and nans all the things!) 
log_per_kernel_query_mean = per_kernel_query_mean * self.nn_scaler log_per_kernel_query_masked_mean = log_per_kernel_query_mean * query_pad_oov_mask.unsqueeze( -1) # make sure we mask out padding values per_kernel_mean = torch.sum(log_per_kernel_query_masked_mean, 1) ## ## "Learning to rank" layer - connects kernels with learned weights ## ------------------------------------------------------- dense_out = self.dense(per_kernel) dense_mean_out = self.dense_mean(per_kernel_mean) dense_comb_out = self.dense_comb( torch.cat([dense_out, dense_mean_out], dim=1)) score = torch.squeeze(dense_comb_out, 1) #torch.tanh(dense_out), 1) if output_secondary_output: query_mean_vector = query_embeddings.sum( dim=1) / query_pad_oov_mask.sum(dim=1).unsqueeze(-1) return score, { "score": score, "dense_out": dense_out, "dense_mean_out": dense_mean_out, "per_kernel": per_kernel, "per_kernel_mean": per_kernel_mean, "query_mean_vector": query_mean_vector, "cosine_matrix_masked": cosine_matrix_masked } else: return score
def forward(self, source: torch.Tensor, target: torch.Tensor, metadata: dict, seq_labels: torch.Tensor = None, reg_labels: torch.Tensor = None, source_mask: Optional[torch.Tensor]=None, target_mask: Optional[torch.Tensor]=None, memory_mask: Optional[torch.Tensor]=None, src_key_padding_mask: Optional[torch.Tensor]=None, tgt_key_padding_mask: Optional[torch.Tensor]=None, memory_key_padding_mask: Optional[torch.Tensor]=None) -> Dict[str, torch.Tensor]: r"""Take in and process masked source/target sequences. Args: source: the sequence to the encoder (required). target: the sequence to the decoder (required). metadata: the metadata of the samples (required). seq_labels: the labels of each round (optional). reg_labels: the labels of the total future payoff (optional). source_mask: the additive mask for the src sequence (optional). target_mask: the additive mask for the tgt sequence (optional). memory_mask: the additive mask for the encoder output (optional). src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). Shape: - source: :math:`(S, N, E)`. - target: :math:`(T, N, E)`. - source_mask: :math:`(S, S)`. - target_mask: :math:`(T, T)`. - memory_mask: :math:`(T, S)`. - src_key_padding_mask: :math:`(N, S)`. - tgt_key_padding_mask: :math:`(N, T)`. - memory_key_padding_mask: :math:`(N, S)`. Note: [src/tgt/memory]_mask should be filled with float('-inf') for the masked positions and float(0.0) else. These masks ensure that predictions for position i depend only on the unmasked positions j and are applied identically for each sequence in a batch. [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions that should be masked with float('-inf') and False values will be unchanged. This mask ensures that no information will be taken from position i if it is masked, and has a separate mask for each sequence in a batch. - output: :math:`(T, N, E)`. Note: Due to the multi-head attention architecture in the transformer model, the output sequence length of a transformer is same as the input sequence (i.e. target) length of the decode. where S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number """ if self._first_pair is not None: if self._first_pair == metadata[0]['pair_id']: self._epoch += 1 else: self._first_pair = metadata[0]['pair_id'] output = dict() src_key_padding_mask = get_text_field_mask({'tokens': source}) tgt_key_padding_mask = get_text_field_mask({'tokens': target}) # The torch transformer takes the mask backwards. 
src_key_padding_mask_byte = ~src_key_padding_mask.bool() tgt_key_padding_mask_byte = ~tgt_key_padding_mask.bool() # create mask where only the first round is not masked --> need to be the same first round for all sequances if self.only_raisha: temp_mask = torch.ones(tgt_key_padding_mask_byte.shape, dtype=torch.bool) temp_mask[:, 0] = False tgt_key_padding_mask_byte = temp_mask if self._sinusoidal_positional_encoding: source = add_positional_features(source) target = add_positional_features(target) if torch.cuda.is_available(): # change to cuda source = source.cuda() target = target.cuda() tgt_key_padding_mask = tgt_key_padding_mask.cuda() src_key_padding_mask = src_key_padding_mask.cuda() tgt_key_padding_mask_byte = tgt_key_padding_mask_byte.cuda() src_key_padding_mask_byte = src_key_padding_mask_byte.cuda() if seq_labels is not None: seq_labels = seq_labels.cuda() if reg_labels is not None: reg_labels = reg_labels.cuda() # The torch transformer expects the shape (sequence, batch, features), not the more # familiar (batch, sequence, features), so we have to fix it. source = source.permute(1, 0, 2) target = target.permute(1, 0, 2) if source.size(1) != target.size(1): raise RuntimeError("the batch number of src and tgt must be equal") if source.size(2) != self._input_dim or target.size(2) != self._input_dim: raise RuntimeError("the feature number of src and tgt must be equal to d_model") encoder_out = self.encoder(source, src_key_padding_mask=src_key_padding_mask_byte) decoder_output = self.decoder(target, encoder_out, tgt_key_padding_mask=tgt_key_padding_mask_byte, memory_key_padding_mask=src_key_padding_mask_byte) decoder_output = decoder_output.permute(1, 0, 2) if self.predict_seq: if self.linear_layer is not None: decoder_output = self.linear_layer(decoder_output) # add linear layer before hidden2tag decision_logits = self.hidden2tag(decoder_output) output['decision_logits'] = masked_softmax(decision_logits, tgt_key_padding_mask.unsqueeze(2)) self.seq_predictions = save_predictions_seq_models(prediction_df=self.seq_predictions, mask=tgt_key_padding_mask, predictions=output['decision_logits'], gold_labels=seq_labels, metadata=metadata, epoch=self._epoch, is_train=self.training,) if self.predict_avg_total_payoff: # (batch_size, seq_len, dimensions) * (batch_size, dimensions, 1) -> (batch_size, seq_len) attention_output = self.attention(self.attention_vector, decoder_output, tgt_key_padding_mask) # (batch_size, 1, seq_len) * (batch_size, seq_len, dimensions) -> (batch_size, dimensions) attention_output = torch.bmm(attention_output.unsqueeze(1), decoder_output).squeeze() # (batch_size, dimensions) -> (batch_size, batch_size) linear_out = self.linear_after_attention_layer(attention_output) # (batch_size, batch_size) -> (batch_size, 1) regression_output = self.regressor(linear_out) output['regression_output'] = regression_output self.reg_predictions = save_predictions(prediction_df=self.reg_predictions, predictions=output['regression_output'], gold_labels=reg_labels, metadata=metadata, epoch=self._epoch, is_train=self.training, int_label=False) if seq_labels is not None or reg_labels is not None: temp_loss = 0 if self.predict_seq and seq_labels is not None: for metric_name, metric in self.metrics_dict_seq.items(): metric(decision_logits, seq_labels, tgt_key_padding_mask) output['seq_loss'] = sequence_cross_entropy_with_logits(decision_logits, seq_labels, tgt_key_padding_mask) temp_loss += self.seq_weight_loss * output['seq_loss'] if self.predict_avg_total_payoff and reg_labels is not None: for 
metric_name, metric in self.metrics_dict_reg.items(): metric(regression_output, reg_labels, tgt_key_padding_mask) output['reg_loss'] = self.mse_loss(regression_output, reg_labels.view(reg_labels.shape[0], -1)) temp_loss += self.reg_weight_loss * output['reg_loss'] output['loss'] = temp_loss return output
def forward(self, content_id: torch.LongTensor, bundle_id: torch.LongTensor, feature: torch.FloatTensor, user_id: torch.FloatTensor, mask: torch.Tensor, initial_state: Optional[TensorPair] = None, ans_prev_correctly: Optional[torch.Tensor] = None): batch_size, seq_len = content_id.shape # content_emb: (batch, seq, dim) content_emb = self.content_id_emb(content_id) if self.hparams["emb_dropout"] > 0: content_emb = self.emb_dropout(content_emb) # content_emb: (batch, seq, dim) feature = torch.cat([content_emb, feature], dim=-1) if not self.hparams.get("no_prev_ans", False): if ans_prev_correctly is None: ans_prev_correctly = torch.ones(batch_size, seq_len, 1, device=self.device, dtype=torch.float) feature = torch.cat([ans_prev_correctly, feature], dim=-1) if hasattr(self, "lstm_in_proj"): feature = self.lstm_in_proj(feature) feature = F.relu(feature) # Apply LSTM sequence_lengths = self.__class__.get_lengths_from_seq_mask(mask) clamped_sequence_lengths = sequence_lengths.clamp(min=1) if self.encoder_type == "augmented_lstm": sorted_feature, sorted_sequence_lengths, restoration_indices, sorting_indices = \ self.__class__.sort_batch_by_length(feature, clamped_sequence_lengths) packed_sequence_input = pack_padded_sequence( sorted_feature, sorted_sequence_lengths.data.tolist(), enforce_sorted=False, batch_first=True) # encoder_out: (batch, seq_len, num_directions * hidden_size): # h_t: (num_layers * num_directions, batch, hidden_size) # - this dimension is valid regardless of batch_first=True!! packed_lstm_out, (h_t, c_t) = self.encoder(packed_sequence_input) # lstm_out: (batch, seq, num_directions * hidden_size) lstm_out, _ = pad_packed_sequence(packed_lstm_out, batch_first=True) lstm_out = lstm_out.index_select(0, restoration_indices) h_t = h_t.index_select(1, restoration_indices) c_t = c_t.index_select(1, restoration_indices) state = (h_t, c_t) elif self.encoder_type == "attention": if initial_state is None: mask_seq_len = seq_len attention_mask = self.__class__ \ .get_attention_mask(src_seq_len=mask_seq_len) \ .to(self.device) permuted_feature = add_positional_features(feature, max_timescale=seq_len) \ .permute(1, 0, 2) query = permuted_feature else: mask_seq_len = seq_len + 1 # initial_state: (batch, 1, dim) initial_state = initial_state[0] feature = torch.cat([feature, initial_state], dim=1) # previous sequence summary vector attention_mask = self.__class__ \ .get_attention_mask(src_seq_len=mask_seq_len, target_seq_len=seq_len) \ .to(self.device) # (seq, N, dim) permuted_feature = add_positional_features(feature, max_timescale=mask_seq_len) \ .permute(1, 0, 2) query = permuted_feature[1:] # att_output: (seq, batch, dim) att_output, att_weight = self.encoder(query=query, key=permuted_feature, value=permuted_feature, attn_mask=attention_mask, need_weights=False) # (batch, seq, dim) lstm_out = att_output.permute(1, 0, 2) # (batch, 1, dim) summary_vec, _ = lstm_out.max(dim=1, keepdim=True) state = [summary_vec] else: packed_sequence_input = pack_padded_sequence( feature, clamped_sequence_lengths.data.tolist(), enforce_sorted=False, batch_first=True) # encoder_out: (batch, seq_len, num_directions * hidden_size): # h_t: (num_layers * num_directions, batch, hidden_size) # - this dimension is valid regardless of batch_first=True!! 
packed_lstm_out, state = self.encoder(packed_sequence_input, initial_state) # lstm_out: (batch, seq, num_directions * hidden_size) lstm_out, _ = pad_packed_sequence(packed_lstm_out, batch_first=True) if self.hparams.get("layer_norm", False): lstm_out = self.layer_norm(lstm_out) if self.hparams.get("highway_connection", False): c = torch.sigmoid(self.highway_C(lstm_out)) h = self.highway_H(lstm_out) lstm_out = (1 - c) * torch.relu(h) + c * torch.relu(feature) else: lstm_out = F.relu(lstm_out) if self.hparams["output_dropout"] > 0: lstm_out = self.output_dropout(lstm_out) y_pred = torch.squeeze(self.hidden2logit(lstm_out), dim=-1) # (batch, seq) return y_pred, state
def forward(self,
            context_ids: TextFieldTensors,
            query_ids: TextFieldTensors,
            extend_context_ids: torch.Tensor,
            extend_query_ids: torch.Tensor,
            context_turn: torch.Tensor,
            query_turn: torch.Tensor,
            context_len: torch.Tensor,
            query_len: torch.Tensor,
            oovs_len: torch.Tensor,
            rewrite_input_ids: Optional[TextFieldTensors] = None,
            rewrite_target_ids: Optional[TextFieldTensors] = None,
            extend_rewrite_ids: Optional[torch.Tensor] = None,
            rewrite_len: Optional[torch.Tensor] = None,
            metadata: Optional[List[Dict[str, Any]]] = None):
    """The forward pass."""
    context_token_ids = context_ids[self._index_name]["tokens"]
    query_token_ids = query_ids[self._index_name]["tokens"]
    # get the extended context and query ids
    extend_context_ids = context_token_ids + extend_context_ids.to(dtype=torch.long)
    extend_query_ids = query_token_ids + extend_query_ids.to(dtype=torch.long)

    # ---------------- encoder outputs ------------
    # compute the context and query embeddings
    context_embed = self._get_embeddings(context_ids, context_turn)
    query_embed = self._get_embeddings(query_ids, query_turn)
    # compute the context and query lengths
    max_context_len = context_embed.size(1)
    max_query_len = query_embed.size(1)
    # compute the masks
    context_mask = get_mask_from_sequence_lengths(context_len, max_length=max_context_len)
    query_mask = get_mask_from_sequence_lengths(query_len, max_length=max_query_len)

    # compute the encoder outputs
    dialogue_embed = torch.cat([context_embed, query_embed], dim=1)
    dialogue_mask = torch.cat([context_mask, query_mask], dim=1)
    # if the encoder parameters are shared, positional features have to be added up front
    if self._share_encoder_params:
        dialogue_embed = add_positional_features(dialogue_embed)
    # encoder output, [B, dialogue_len, encoder_output_dim]
    dialogue_output = self.encoder(dialogue_embed, dialogue_mask)
    if self._share_encoder_params:
        for _ in range(self.encoder_num_layers - 1):
            dialogue_output = self.encoder(dialogue_output, dialogue_mask)

    # split the encoding results:
    # [B, context_len, *] and [B, query_len, *]
    context_output, query_output, dec_init_state = self._run_encoder(
        dialogue_output, context_mask, query_mask)

    output_dict = {"metadata": metadata}
    if self.training:
        rewrite_input_token_ids = rewrite_input_ids[self._index_name]["tokens"]
        # compute the rewrite length
        max_rewrite_len = rewrite_input_token_ids.size(1)
        rewrite_input_mask = get_mask_from_sequence_lengths(rewrite_len, max_length=max_rewrite_len)

        rewrite_target_ids = rewrite_target_ids[self._index_name]["tokens"]
        # compute the target sequence indices for the rewrite
        rewrite_target_ids = rewrite_target_ids + extend_rewrite_ids.to(dtype=torch.long)
        # compute the embedding output, [B, rewrite_len, embedding_size]
        rewrite_embed = self._get_embeddings(rewrite_input_ids)

        # forward step to compute the loss
        new_output_dict = self._forward_step(
            context_output, query_output, context_mask, query_mask,
            rewrite_embed, rewrite_target_ids, rewrite_len, rewrite_input_mask,
            extend_context_ids, extend_query_ids, oovs_len, dec_init_state)
        output_dict.update(new_output_dict)
    else:
        batch_hyps = self._run_inference(context_output,
                                         query_output,
                                         context_mask,
                                         query_mask,
                                         extend_context_ids,
                                         extend_query_ids,
                                         oovs_len,
                                         dec_init_state=dec_init_state)
        # get the result of each instance
        output_dict['hypothesis'] = batch_hyps
        output_dict = self.get_rewrite_string(output_dict)
        output_dict["loss"] = torch.tensor(0)

    return output_dict
def _eval_decode( self, state: Dict[str, torch.Tensor], segments: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: encoder_outputs = state["encoder_outputs"] source_key_padding_mask = state["source_key_padding_mask"] source_embedded = state["source_raw_embedded"] source_token_mask = state["source_token_mask"] memory_key_padding_mask = (1 - source_token_mask).bool() # memory_key_padding_mask = source_key_padding_mask batch_size = source_key_padding_mask.size(0) encode_length = source_key_padding_mask.size(1) log_probs_after_end = encoder_outputs.new_full( (batch_size, self._num_classes + encode_length), fill_value=float("-inf")) log_probs_after_end[:, self._end_index] = 0. start_predictions = state["source_mask"].new_full( (batch_size, 1), fill_value=self._start_index) partial_generate_predictions = start_predictions partial_copy_predictions = state["source_mask"].new_zeros( (batch_size, 1)) basic_index = torch.arange(batch_size).to( source_embedded.device).unsqueeze(1).long() generate_mask = state["source_mask"].new_ones((batch_size, 1)).float() # shape: (batch_size) last_prediction = start_predictions.squeeze(1) for _ in range(self._max_decoding_step): # shape: (batch_size, partial_len, d_model) partial_source_embedded_input = source_embedded[ basic_index, partial_copy_predictions] partial_target_embedded_input = self._target_embedder( partial_generate_predictions) partial_embedded_input = partial_target_embedded_input * generate_mask.unsqueeze(-1) \ + partial_source_embedded_input * (1 - generate_mask).unsqueeze(-1) partial_embedded_input = util.add_positional_features( partial_embedded_input) partial_len = partial_embedded_input.size(1) partial_embedded_input = partial_embedded_input.permute(1, 0, 2) mask = (torch.triu(state["source_mask"].new_ones( partial_len, partial_len)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill( mask == 1, float(0.0)) if not self._decode_use_relative_position: # shape: (partial_len, batch_size, d_model) outputs = self._decoder( partial_embedded_input, memory=encoder_outputs, tgt_mask=mask, memory_key_padding_mask=memory_key_padding_mask) else: # gnn decoder edge_mask = get_decode_edge_mask( partial_embedded_input, max_decode_clip_range=self._max_decode_clip_range) tgt_padding_mask = torch.tril(edge_mask.new_ones( [partial_len, partial_len]), diagonal=0) tgt_padding_mask = (1 - tgt_padding_mask.unsqueeze(0).repeat( batch_size, 1, 1)).float() # shape: (partial_len, batch_size, d_model) outputs = self._decoder( partial_embedded_input, edge_mask=edge_mask.permute(0, 2, 3, 1), memory=encoder_outputs, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask) outputs = outputs.permute(1, 0, 2) # shape: (batch_size, d_model) curr_outputs = outputs[:, -1, :] # shape: (batch_size, num_classes) generate_scores = self.get_generate_scores(curr_outputs) # shape: (batch_size, encode_length) copy_scores = self.get_copy_scores( state, curr_outputs.unsqueeze(1)).squeeze(1) # Gate # shape: (batch_size, 1) generate_gate = F.sigmoid(self.gate_linear(curr_outputs)) copy_gate = 1 - generate_gate scores = torch.cat( (generate_scores * generate_gate, copy_scores * copy_gate), dim=-1) # scores = torch.cat((generate_scores, copy_scores), dim=-1) # shape: (batch_size, encode_length) entity_mask = 1 - ( (segments['tokens'] == self._token_index) | (segments['tokens'] == self._non_func_symbol_index) | (segments['tokens'] == self._segment_pad_index)).float() # shape: (batch_size, num_classes + encode_length) 
score_mask = torch.cat((entity_mask.new_ones( (batch_size, self._num_classes)), entity_mask), dim=-1) # shape: (batch_size, num_classes + encode_length) normalized_scores = util.masked_softmax(scores, mask=score_mask, dim=-1) last_prediction_expanded = last_prediction.unsqueeze(-1).expand( batch_size, self._num_classes + encode_length) # shape: (batch_size, num_classes + encode_length) cleaned_logits = torch.where( last_prediction_expanded == self._end_index, log_probs_after_end, normalized_scores) # shape: (batch_size) _, predicted = torch.max(input=cleaned_logits, dim=1, keepdim=False) copy_mask = (predicted >= self._num_classes).long() generate_predicted = predicted * (1 - copy_mask) copy_predicted = (predicted - self._num_classes) * copy_mask partial_copy_predictions = torch.cat( (partial_copy_predictions, copy_predicted.unsqueeze(1)), dim=1) partial_generate_predictions = torch.cat( (partial_generate_predictions, generate_predicted.unsqueeze(1)), dim=1) generate_mask = torch.cat( (generate_mask, (1 - copy_mask).unsqueeze(1).float()), dim=1) last_prediction = predicted if (last_prediction == self._end_index).sum() == batch_size: break predictions = partial_generate_predictions * generate_mask.long() + \ (1 - generate_mask).long() * (partial_copy_predictions + self._num_classes) # shape: (batch_size, partial_len) output_dict = {"predictions": predictions} return output_dict
def _train_decode( self, state: Dict[str, torch.Tensor], target_tokens: [str, torch.Tensor], generate_targets: torch.Tensor) -> Dict[str, torch.Tensor]: encoder_outputs = state["encoder_outputs"] source_key_padding_mask = state["source_key_padding_mask"] # shape: (batch_size, encode_length, d_model) source_embedded = state["source_raw_embedded"] batch_size, _, _ = source_embedded.size() basic_index = torch.arange(batch_size).to( source_embedded.device).long() generate_targets = generate_targets.long() retrieved_target_embedded_input = source_embedded[ basic_index.unsqueeze(1), generate_targets][:, :-1, :] target_embedded_input = self._target_embedder( target_tokens['tokens'])[:, :-1, :] # shape: (batch_size, max_decode_length) # where 1 indicates that the target token is generated rather than copied generate_mask = (generate_targets == 0).float() target_embedded_input = target_embedded_input * generate_mask[:, :-1].unsqueeze(-1) \ + retrieved_target_embedded_input * (1 - generate_mask)[:, :-1].unsqueeze(-1) target_embedded_input = util.add_positional_features( target_embedded_input) # shape: (max_target_sequence_length - 1, batch_size, d_model) target_embedded_input = target_embedded_input.permute(1, 0, 2) # shape: (batch_size, max_target_sequence_length - 1) """ key_padding_mask should be a ByteTensor where True values are positions that should be masked with float('-inf') and False values will be unchanged. """ target_mask = util.get_text_field_mask(target_tokens)[:, 1:] target_key_padding_mask = (1 - target_mask.byte()).bool() assert target_key_padding_mask.size(1) == target_embedded_input.size(0) and \ target_embedded_input.size(1) == target_key_padding_mask.size(0) max_target_seq_length = target_key_padding_mask.size(1) target_additive_mask = (torch.triu( target_mask.new_ones(max_target_seq_length, max_target_seq_length)) == 1).transpose(0, 1) target_additive_mask = target_additive_mask.float().masked_fill( target_additive_mask == 0, float('-inf')) target_additive_mask = target_additive_mask.masked_fill( target_additive_mask == 1, float(0.0)) assert target_embedded_input.size(1) == encoder_outputs.size(1) source_token_mask = state["source_token_mask"] memory_key_padding_mask = (1 - source_token_mask).bool() # memory_key_padding_mask = source_key_padding_mask if not self._decode_use_relative_position: # shape: (max_target_sequence_length, batch_size, d_model) decoder_outputs = self._decoder( target_embedded_input, memory=encoder_outputs, tgt_mask=target_additive_mask, tgt_key_padding_mask=None, memory_key_padding_mask=memory_key_padding_mask) else: # gnn decoder edge_mask = get_decode_edge_mask( target_embedded_input, max_decode_clip_range=self._max_decode_clip_range) batch_size = edge_mask.size(0) tgt_padding_mask = torch.tril(edge_mask.new_ones( [max_target_seq_length, max_target_seq_length]), diagonal=0) tgt_padding_mask = (1 - (tgt_padding_mask.unsqueeze(0).repeat( batch_size, 1, 1))).float() decoder_outputs = self._decoder( target_embedded_input, edge_mask=edge_mask.permute(0, 2, 3, 1), memory=encoder_outputs, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask) # shape: (batch_size, max_target_sequence_length, d_model) decoder_outputs = decoder_outputs.permute(1, 0, 2) state.update({ "decoder_outputs": decoder_outputs, "target_key_padding_mask": target_key_padding_mask, "target_mask": target_mask, "generate_mask": generate_mask }) return state
def forward( self, inputs: torch.Tensor, semantic_views_q: torch.Tensor, semantic_views_scope_mask: torch.Tensor, mask: torch.Tensor = None, return_output_metadata: bool = False ) -> torch.Tensor: # pylint: disable=arguments-differ if self._use_positional_encoding: output = add_positional_features(inputs) else: output = inputs total_sublayers = len(self._conv_layers) + 2 sublayer_count = 0 for conv_norm_layer, conv_layer in zip(self._conv_norm_layers, self._conv_layers): conv_norm_out = self.dropout(conv_norm_layer(output)) conv_out = self.dropout( conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2)) sublayer_count += 1 output = self.residual_with_layer_dropout(output, conv_out, sublayer_count, total_sublayers) attention_norm_out = self.dropout(self.attention_norm_layer(output)) bs, seq_len, _ = inputs.shape # pad semantic views if semantic_views_q.shape[1] < self.num_attention_heads: #logging.info("semantic_views_q.dtype:{0}".format(semantic_views_q.dtype)) # # # logging.info("semantic_views_q_pad.dtype:{0}".format(semantic_views_q_pad.dtype)) semantic_views_q = torch.cat([ semantic_views_q, long_fill( (bs, self.num_attention_heads - semantic_views_q.shape[1], seq_len), 0, semantic_views_q.is_cuda) ], dim=1) #logging.info("semantic_views_q.dtype after concat:{0}".format(semantic_views_q.dtype)) #logging.info("semantic_views_scope_mask.dtype:{0}".format(semantic_views_scope_mask.dtype)) semantic_views_scope_mask = torch.cat([ semantic_views_scope_mask, long_fill((bs, self.num_attention_heads - semantic_views_scope_mask.shape[1], seq_len), 0, semantic_views_scope_mask.is_cuda) ], dim=1) if self._replace_zero_semantic_labels_with_per_head_labels: heads_dim_id = 1 views_shape = list(semantic_views_q.shape) views_shape_single_head = views_shape[:] views_shape_single_head[heads_dim_id] = 1 attention_head_mask = torch.cat([ long_fill(views_shape_single_head, x, semantic_views_q.is_cuda) for x in range( self.num_semantic_labels - self.num_attention_heads, self.num_semantic_labels) ], dim=heads_dim_id).contiguous() # mask the values semantic_views_q = semantic_views_q + attention_head_mask * ( semantic_views_q < 1).long() if len(semantic_views_scope_mask.shape) == len( semantic_views_q.shape): semantic_views_scope_mask = semantic_views_scope_mask + attention_head_mask * ( semantic_views_scope_mask < 1).long() attention_output_meta = None if is_output_meta_supported(self.attention_layer): attention_out, attention_output_meta = self.attention_layer( attention_norm_out, semantic_views_q, semantic_views_scope_mask, mask, return_output_metadata=return_output_metadata) attention_out = self.dropout(attention_out) else: attention_out = self.dropout( self.attention_layer(attention_norm_out, semantic_views_q, semantic_views_scope_mask, mask)) sublayer_count += 1 output = self.residual_with_layer_dropout(output, attention_out, sublayer_count, total_sublayers) feedforward_norm_out = self.dropout( self.feedforward_norm_layer(output)) feedforward_out = self.dropout(self.feedforward(feedforward_norm_out)) sublayer_count += 1 output = self.residual_with_layer_dropout(output, feedforward_out, sublayer_count, total_sublayers) return output, attention_output_meta
def forward(self,
            inputs: torch.Tensor,
            semantic_views_q: torch.Tensor,
            semantic_views_k: torch.Tensor,
            mask: torch.Tensor = None) -> torch.Tensor:
    # pylint: disable=arguments-differ
    if self._use_positional_encoding:
        output = add_positional_features(inputs)
    else:
        output = inputs

    total_sublayers = len(self._conv_layers) + 2
    sublayer_count = 0

    for conv_norm_layer, conv_layer in zip(self._conv_norm_layers, self._conv_layers):
        conv_norm_out = self.dropout(conv_norm_layer(output))
        conv_out = self.dropout(conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(output, conv_out,
                                                  sublayer_count, total_sublayers)

    attention_norm_out = self.dropout(self.attention_norm_layer(output))

    if self._replace_zero_semantic_labels_with_per_head_labels:
        heads_dim_id = 1
        views_shape = list(semantic_views_q.shape)
        views_shape_single_head = views_shape[:]
        views_shape_single_head[heads_dim_id] = 1

        attention_head_mask = torch.cat([
            torch.full(views_shape_single_head, x).long()
            for x in range(self.num_semantic_labels - self.num_attention_heads,
                           self.num_semantic_labels)
        ], dim=heads_dim_id).contiguous()

        if semantic_views_q.is_cuda:
            attention_head_mask = attention_head_mask.cuda()

        # semantic_views_q_zeros = (semantic_views_q < 1).long()
        # logging.info("type(semantic_views_q):{0}".format(semantic_views_q.type()))
        # logging.info("type(attention_head_mask):{0}".format(attention_head_mask.type()))
        # logging.info("type(semantic_views_q_zeros):{0}".format(semantic_views_q_zeros.type()))

        # mask the values
        semantic_views_q = semantic_views_q + attention_head_mask * (semantic_views_q < 1).long()
        semantic_views_k = semantic_views_k + attention_head_mask * (semantic_views_k < 1).long()

    attention_out = self.dropout(
        self.attention_layer(attention_norm_out, semantic_views_q, semantic_views_k, mask))
    sublayer_count += 1
    output = self.residual_with_layer_dropout(output, attention_out,
                                              sublayer_count, total_sublayers)

    feedforward_norm_out = self.dropout(self.feedforward_norm_layer(output))
    feedforward_out = self.dropout(self.feedforward(feedforward_norm_out))
    sublayer_count += 1
    output = self.residual_with_layer_dropout(output, feedforward_out,
                                              sublayer_count, total_sublayers)

    return output
def forward(self, inputs: torch.Tensor, source_memory_bank: torch.Tensor, source_mask: torch.Tensor, target_mask: torch.Tensor, is_train: bool = True) -> Dict: batch_size, source_seq_length, _ = source_memory_bank.size() __, target_seq_length, __ = inputs.size() source_padding_mask = None target_padding_mask = None if source_mask is not None: source_padding_mask = ~source_mask.bool() if target_mask is not None: target_padding_mask = ~target_mask.bool() # project to correct dimensionality outputs = self.input_proj_layer(inputs) # add pos encoding feats outputs = add_positional_features(outputs) # swap to pytorch's batch-second convention outputs = outputs.permute(1, 0, 2) source_memory_bank = source_memory_bank.permute(1, 0, 2) # get a mask ar_mask = self.make_autoregressive_mask(outputs.shape[0]).to(source_memory_bank.device) for i in range(len(self.layers)): outputs , __, __ = self.layers[i](outputs, source_memory_bank, tgt_mask=ar_mask, #memory_mask=None, tgt_key_padding_mask=target_padding_mask, memory_key_padding_mask=source_padding_mask ) # do final norm here if self.prenorm: outputs = self.final_norm(outputs) # switch back from pytorch's absolutely moronic batch-second convention outputs = outputs.permute(1, 0, 2) source_memory_bank = source_memory_bank.permute(1, 0, 2) if not self.use_coverage: source_attention_output = self.source_attn_layer(outputs, source_memory_bank, source_mask, None) attentional_tensors = self.dropout(source_attention_output['attentional']) source_attention_weights = source_attention_output['attention_weights'] coverage_history = None else: # need to do step by step because running sum of coverage source_attention_weights = [] attentional_tensors = [] # init to zeros coverage = inputs.new_zeros(size=(batch_size, 1, source_seq_length)) coverage_history = [] for timestep in range(outputs.shape[1]): output = outputs[:,timestep,:].unsqueeze(1) source_attention_output = self.source_attn_layer(output, source_memory_bank, source_mask, coverage) attentional_tensors.append(source_attention_output['attentional']) source_attention_weights.append(source_attention_output['attention_weights']) coverage = source_attention_output['coverage'] coverage_history.append(coverage) # [batch_size, tgt_seq_len, hidden_dim] attentional_tensors = self.dropout(torch.cat(attentional_tensors, dim=1)) # [batch_size, tgt_seq_len, src_seq_len] source_attention_weights = torch.cat(source_attention_weights, dim=1) coverage_history = torch.cat(coverage_history, dim=1) if is_train: tgt_attn_list = [] for timestep in range(attentional_tensors.shape[1]): bsz, seq_len, __ = attentional_tensors.shape attn_mask = torch.ones((bsz, seq_len)) attn_mask[:,timestep:] = 0 attn_mask = attn_mask.to(attentional_tensors.device) target_attention_output = self.target_attn_layer(attentional_tensors[:,timestep,:].unsqueeze(1), attentional_tensors, mask = attn_mask) if timestep == 0: # zero out weights at 0, effectively banning target copy since there is nothing to copy tgt_attn_list.append(torch.zeros_like(target_attention_output["attention_weights"][:,-1,:].unsqueeze(1))) else: tgt_attn_list.append(target_attention_output["attention_weights"]) target_attention_weights = torch.cat(tgt_attn_list, dim=1) else: target_attention_output = self.target_attn_layer(attentional_tensors, attentional_tensors, mask = target_padding_mask) target_attention_weights = target_attention_output['attention_weights'] return dict( outputs=outputs, output=outputs[:,-1,:].unsqueeze(1), attentional_tensors=attentional_tensors, 
attentional_tensor=attentional_tensors[:,-1,:].unsqueeze(1), target_attention_weights = target_attention_weights, source_attention_weights = source_attention_weights, coverage_history = coverage_history, )
def forward(self,
            context: Dict[str, torch.LongTensor],
            length: torch.LongTensor = None,
            repeat: torch.FloatTensor = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # expected_dim = (self.final_classifier_feedforward.get_input_dim() + 1) / 2
    expected_dim = (self.final_classifier_feedforward.get_input_dim() + 2) / 2
    dia_len = context['tokens'].size()[1]
    if expected_dim - dia_len > 0:
        padding = torch.zeros([
            context['tokens'].size()[0],
            (expected_dim - dia_len),
            context['tokens'].size()[2]
        ]).long()
        context['tokens'] = torch.cat([context['tokens'], padding], dim=1)

    # context: batch_size * dials_len * sentences_len
    # embedded_context: batch_size * dials_len * sentences_len * emb_dim
    embedded_context = self.text_field_embedder(context)
    # utterances_mask: batch_size * dials_len * sentences_len
    utterances_mask = get_text_field_mask(context, 1).float()
    # encoded_utterances: batch_size * dials_len * emb_dim
    encoded_utterances = self.utterances_encoder(embedded_context, utterances_mask)
    encoded_utterances = add_positional_features(encoded_utterances)
    # projected_utterances = self.attend_feedforward(encoded_utterances)

    # similarity_matrix: batch_size * dials_len * output_dim
    mask = get_text_field_mask(context).float()
    similarity_matrix = self.matrix_attention(encoded_utterances, mask)

    # attended_context: batch * (dials_len - 1) * emb_dim
    attended_context = similarity_matrix[:, :-1, :]
    # attended_response: batch * (dials_len - 1) * emb_dim
    attended_response = similarity_matrix[:, 1:, :]
    # embedded_context: batch_size * (dials_len - 1) * emb_dim
    embedded_context = encoded_utterances[:, :-1, :]
    # embedded_response: batch_size * (dials_len - 1) * emb_dim
    embedded_response = encoded_utterances[:, 1:, :]

    # context_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
    context_compare_input = torch.cat([embedded_context, attended_response], dim=-1)
    # response_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
    response_compare_input = torch.cat([embedded_response, attended_context], dim=-1)

    # compared_context: batch * (dials_len - 1) * emb_dim
    compared_context = self.compare_feedforward(context_compare_input)
    compared_context = compared_context * mask[:, :-1].unsqueeze(-1)
    # compared_response: batch * (dials_len - 1) * emb_dim
    compared_response = self.compare_feedforward(response_compare_input)
    compared_response = compared_response * mask[:, 1:].unsqueeze(-1)

    # aggregate_input: batch * (dials_len - 1) * (compare_context_dim + compared_response_dim)
    aggregate_input = torch.cat([compared_context, compared_response], dim=-1)
    # class_logits & class_probs: batch * (dials_len - 1) * 2
    class_logits = self.classifier_feedforward(aggregate_input)
    class_probs = F.softmax(class_logits, dim=-1).reshape(class_logits.size()[0], -1)

    length_tensor = torch.FloatTensor(length).reshape(-1, 1)
    repeat_tensor = torch.FloatTensor(repeat).reshape(-1, 1)
    # class_probs = torch.cat([class_probs, length_tensor, repeat_tensor], dim=1)
    # class_probs = torch.cat([class_probs, repeat_tensor], dim=1)
    full_logits = self.final_classifier_feedforward(class_probs)
    full_probs = F.softmax(full_logits, dim=-1)
    output_dict = {
        "class_logits": full_logits,
        "class_probabilities": full_probs
    }

    if label is not None:
        loss = self.loss(full_logits, label.squeeze(-1))
        for metric in self.metrics.values():
            metric(full_logits, label.squeeze(-1))
        output_dict['loss'] = loss

    return output_dict