Example #1
    def test_add_positional_features(self):
        # This is hard to test, so we check that we get the same result as the
        # original tensorflow implementation:
        # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py#L270
        tensor2tensor_result = numpy.asarray([[0.00000000e+00, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
                                              [8.41470957e-01, 9.99999902e-05, 5.40302277e-01, 1.00000000e+00],
                                              [9.09297407e-01, 1.99999980e-04, -4.16146845e-01, 1.00000000e+00]])

        tensor = torch.zeros([2, 3, 4])
        result = util.add_positional_features(tensor, min_timescale=1.0, max_timescale=1.0e4)
        numpy.testing.assert_almost_equal(result[0].detach().cpu().numpy(), tensor2tensor_result)
        numpy.testing.assert_almost_equal(result[1].detach().cpu().numpy(), tensor2tensor_result)

        # Check case with odd number of dimensions.
        tensor2tensor_result = numpy.asarray([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
                                               1.00000000e+00, 1.00000000e+00, 0.00000000e+00],
                                              [8.41470957e-01, 9.99983307e-03, 9.99999902e-05, 5.40302277e-01,
                                               9.99949992e-01, 1.00000000e+00, 0.00000000e+00],
                                              [9.09297407e-01, 1.99986659e-02, 1.99999980e-04, -4.16146815e-01,
                                               9.99800026e-01, 1.00000000e+00, 0.00000000e+00]])

        tensor = torch.zeros([2, 3, 7])
        result = util.add_positional_features(tensor, min_timescale=1.0, max_timescale=1.0e4)
        numpy.testing.assert_almost_equal(result[0].detach().cpu().numpy(), tensor2tensor_result)
        numpy.testing.assert_almost_equal(result[1].detach().cpu().numpy(), tensor2tensor_result)
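The expected arrays above are just the standard sinusoidal position encoding in the tensor2tensor layout (all sines first, then all cosines, zero-padded for odd dimensions). A minimal sketch of that computation, which reproduces the rows checked in the test:

import torch

def sinusoidal_features(sequence_length, hidden_dim, min_timescale=1.0, max_timescale=1.0e4):
    # One geometric timescale per sin/cos pair.
    positions = torch.arange(sequence_length, dtype=torch.float)
    num_timescales = hidden_dim // 2
    log_increment = torch.log(torch.tensor(max_timescale / min_timescale)) / (num_timescales - 1)
    inverse_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * -log_increment)
    # shape: (sequence_length, num_timescales)
    scaled_time = positions.unsqueeze(1) * inverse_timescales.unsqueeze(0)
    # All sines first, then all cosines, matching the expected arrays above.
    features = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
    if hidden_dim % 2 != 0:
        # Odd dimensionality gets a final zero channel.
        features = torch.cat([features, features.new_zeros(sequence_length, 1)], dim=1)
    return features

# Position 1 with hidden_dim=4 gives [sin(1), sin(1e-4), cos(1), cos(1e-4)], as asserted above.
print(sinusoidal_features(3, 4))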
Example #3
    def _encode(self, source_tokens: Dict[str, torch.Tensor], segments: Dict[str, torch.Tensor],
                source_entity_length: torch.Tensor, edge_mask: torch.Tensor, ) -> Dict[str, torch.Tensor]:
        """
        :param source_tokens:
        :param segments:
        :param merge_indicators:
        :return:
        """
        # shape: (batch_size, encode_length, embedding_dim)
        source_embedded_input = self._embed_source(source_tokens, source_entity_length)

        # shape: (batch_size, encode_length, embedding_dim)
        segments_embedded_input = self._segment_embedder(segments)
        encode_length = segments_embedded_input.size(1)

        assert source_embedded_input.size(1) == segments_embedded_input.size(1)

        # token_mask = (segments['tokens'] == self._token_index).unsqueeze(-1).float()
        # valid_token_embedded_input = batched_embedded_input * token_mask
        # valid_token_embedded_input = util.add_positional_features(valid_token_embedded_input)
        # valid_token_embedded_input = batched_embedded_input * (1 - token_mask) + valid_token_embedded_input * token_mask

        if self._source_embedding_dim == self._encoder_d_model:
            batched_embedded_input = segments_embedded_input + source_embedded_input
            final_embedded_input = util.add_positional_features(batched_embedded_input)
        else:
            batched_embedded_input = torch.cat([source_embedded_input, segments_embedded_input], dim=-1)
            final_embedded_input = util.add_positional_features(batched_embedded_input)

        # shape: (encode_length, batch_size, d_model)
        final_embedded_input = final_embedded_input.permute(1, 0, 2)

        # shape: (batch_size, encode_length)
        source_mask = util.get_text_field_mask(segments)
        source_key_padding_mask = (1 - source_mask.byte()).bool()

        if not self._use_gnn_encoder:
            # shape: (encode_length, batch_size, d_model)
            encoder_outputs = self._encoder(final_embedded_input, src_key_padding_mask=source_key_padding_mask)
        else:
            # GNN encoders
            encoder_outputs = self._encoder(src=final_embedded_input, edge_mask=edge_mask.permute(0, 2, 3, 1),
                                            padding_mask=source_key_padding_mask)

        source_token_mask = (segments['tokens'] == self._token_index).float()

        return {
            "source_mask": source_mask,
            "source_key_padding_mask": source_key_padding_mask,
            "source_token_mask": source_token_mask,
            "encoder_outputs": encoder_outputs,
            "source_embedded": batched_embedded_input,
            "source_raw_embedded": source_embedded_input,
        }
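One detail worth calling out in the snippet above: AllenNLP text-field masks use 1/True for real tokens, while torch.nn.TransformerEncoder's src_key_padding_mask expects True at the positions to be ignored, hence the inversion before the encoder call. A toy illustration (the mask values are made up):

import torch

# 1 marks real tokens, 0 marks padding (the AllenNLP convention).
source_mask = torch.tensor([[1, 1, 1, 0],
                            [1, 1, 0, 0]])

# torch.nn.TransformerEncoder wants True at padded positions, so the mask is inverted.
source_key_padding_mask = ~source_mask.bool()
print(source_key_padding_mask)
# tensor([[False, False, False,  True],
#         [False, False,  True,  True]])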
Example #4
    def forward(
            self,  # type: ignore
            abstract_text: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_abstract_text = self.text_field_embedder(abstract_text,
                                                          num_wrapping_dims=1)
        abstract_text_mask = util.get_text_field_mask(abstract_text,
                                                      num_wrapping_dims=1)

        num_instance, num_sentence, num_word, _ = embedded_abstract_text.size()
        embedded_abstract_text = embedded_abstract_text.view(
            num_instance * num_sentence, num_word, -1)
        abstract_text_mask = abstract_text_mask.view(
            num_instance * num_sentence, -1)

        if self.use_positional_encoding:
            embedded_abstract_text = util.add_positional_features(
                embedded_abstract_text)
        encoded_abstract_text = self.word_encoder(embedded_abstract_text,
                                                  abstract_text_mask)

        attended_sentences = self.word_level_attention(encoded_abstract_text,
                                                       abstract_text_mask)

        attended_sentences = attended_sentences.view(num_instance,
                                                     num_sentence, -1)
        abstract_text_mask = abstract_text_mask.view(num_instance,
                                                     num_sentence,
                                                     -1).sum(2).ge(1).long()

        if self.use_positional_encoding:
            attended_sentences = util.add_positional_features(
                attended_sentences)
        attended_sentences = self.sentence_encoder(attended_sentences,
                                                   abstract_text_mask)

        attended_abstract_text = self.sentence_level_attention(
            attended_sentences, abstract_text_mask)

        outputs = self.classifier_feedforward(attended_abstract_text)
        logits = torch.sigmoid(outputs)
        logits = logits.unsqueeze(0) if logits.dim() < 2 else logits
        output_dict = {'logits': logits}

        if label is not None:
            outputs = outputs.unsqueeze(0) if outputs.dim() < 2 else outputs
            loss = self.loss(outputs, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
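The view calls above fold sentences into the batch dimension for word-level encoding and unfold them again for sentence-level encoding. A toy sketch of just the shape bookkeeping (the sizes and the mean pooling are stand-ins for the real encoder and attention):

import torch

num_instance, num_sentence, num_word, emb_dim = 2, 3, 5, 8
embedded = torch.randn(num_instance, num_sentence, num_word, emb_dim)

# Fold: each sentence becomes its own "batch" row, shape (6, 5, 8).
flat = embedded.view(num_instance * num_sentence, num_word, emb_dim)
# Stand-in for the word encoder + word-level attention, one vector per sentence.
sentence_repr = flat.mean(dim=1)
# Unfold back to (2, 3, 8) for the sentence-level encoder.
unfolded = sentence_repr.view(num_instance, num_sentence, -1)
assert unfolded.shape == (num_instance, num_sentence, emb_dim)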
Example #5
    def forward(self, **inputs):
        x = self.aa_embedder(inputs['aa']) + self.pssm_embedder(inputs['pssm'])
        if 'ss' in inputs and self.use_ss:
            x += self.ss_embedder(inputs['ss'])
        if self.use_positional_encoding:
            # add_positional_features returns a new tensor rather than modifying x in place.
            x = add_positional_features(x)
        x = self.input_dropout(x)

        mask = get_text_field_mask(inputs['aa'])
        x = self.encoder(x, mask)

        if 'msa' in inputs and self.msa_encoder is not None:
            x += self.msa_encoder(inputs['msa'])

        x = self.decoder(self.feedforward(x))

        outputs = {
            'predictions': x,
            'protein_id': inputs['protein_id'],
            'length': inputs['length']
        }
        if 'dcalpha' in inputs:
            mask = mask.unsqueeze(-1).float()
            if self.target == 'dcalpha':
                target = inputs['dcalpha']
                mask = mask.matmul(mask.transpose(dim0=-2, dim1=-1))
                mask_triu = torch.triu(torch.ones_like(mask[0]),
                                       diagonal=1).unsqueeze(0).to(mask.device)
                mask *= mask_triu
            elif self.target == 'angles':
                target = torch.stack([inputs['psi'], inputs['phi']], dim=2)
            else:
                target = inputs['coords']

            mse = ((mask * (x - target))**2).sum() / mask.sum()

            if torch.isnan(mse):
                # Debugging aid: drop into an interactive eval loop when the loss goes NaN.
                while True:
                    expr = input('\nInput = ')
                    if expr == 'q':
                        exit(0)
                    try:
                        print(eval(expr))
                    except Exception as e:
                        print(e)

            self.metrics['rmse']((mse**0.5).detach().item())
            outputs['loss'] = mse

        return outputs
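Note that add_positional_features returns a new tensor rather than modifying its argument, which is why the result is assigned back to x above. A minimal check, assuming AllenNLP's allennlp.nn.util.add_positional_features:

import torch
from allennlp.nn.util import add_positional_features

x = torch.zeros(1, 3, 4)
y = add_positional_features(x)
# The input is left untouched; the positional signal lives only in the return value.
assert torch.all(x == 0)
assert not torch.all(y == 0)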
Example #6
    def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs
        for i in range(len(self._attention_layers)):
            # It's necessary to use `getattr` here because the elements stored
            # in the lists are not replicated by torch.nn.parallel.replicate
            # when running on multiple GPUs. Please use `ModuleList` in new
            # code. It handles this issue transparently. We've kept `add_module`
            # (in conjunction with `getattr`) solely for backwards compatibility
            # with existing serialized models.
            attention = getattr(self, f"self_attention_{i}")
            feedforward = getattr(self, f"feedforward_{i}")
            feedforward_layer_norm = getattr(self,
                                             f"feedforward_layer_norm_{i}")
            layer_norm = getattr(self, f"layer_norm_{i}")
            cached_input = output
            # Project output of attention encoder through a feedforward
            # network and back to the input size for the next layer.
            # shape (batch_size, timesteps, input_size)
            feedforward_output = feedforward(output)
            feedforward_output = self.dropout(feedforward_output)
            if feedforward_output.size() == cached_input.size():
                # First layer might have the wrong size for highway
                # layers, so we exclude it here.
                feedforward_output = feedforward_layer_norm(
                    feedforward_output + cached_input)
            # shape (batch_size, sequence_length, hidden_dim)
            attention_output = attention(feedforward_output, mask)
            output = layer_norm(
                self.dropout(attention_output) + feedforward_output)

        return output
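The comment above recommends torch.nn.ModuleList for new code; a minimal sketch of that pattern (the class and layers here are hypothetical), which makes the add_module/getattr bookkeeping unnecessary because ModuleList-registered sub-modules are replicated correctly across GPUs:

import torch

class StackedEncoder(torch.nn.Module):
    def __init__(self, layers):
        super().__init__()
        # ModuleList registers each layer, so plain iteration replaces getattr(self, f"...{i}").
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, output):
        for layer in self.layers:
            output = layer(output)
        return output

encoder = StackedEncoder([torch.nn.Linear(8, 8) for _ in range(3)])
print(encoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])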
Example #7
    def forward(self,  # type: ignore
                abstract_text: Dict[str, torch.LongTensor],
                local_label: torch.LongTensor = None, 
                global_label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_abstract_text = self.text_field_embedder(abstract_text)
        abstract_text_mask = util.get_text_field_mask(abstract_text)
        if self.use_positional_encoding:
            embedded_abstract_text = util.add_positional_features(embedded_abstract_text)
        encoded_abstract_text = self.abstract_text_encoder(embedded_abstract_text, abstract_text_mask)
        
        attended_abstract_text = self.attention_encoder(encoded_abstract_text, abstract_text_mask)
        local_outputs, global_outputs = self.HMCN_recurrent(attended_abstract_text)
        logits = self.local_globel_tradeoff * global_outputs + (1 - self.local_globel_tradeoff) * local_outputs
        logits = torch.sigmoid(logits)
        logits = logits.unsqueeze(0) if logits.dim() < 2 else logits
        output_dict = {'logits': logits}

        if local_label is not None and global_label is not None:
            local_outputs = local_outputs.unsqueeze(0) if local_outputs.dim() < 2 else local_outputs
            global_outputs = global_outputs.unsqueeze(0) if global_outputs.dim() < 2 else global_outputs
            loss = self.loss(local_outputs, global_outputs, local_label.squeeze(-1), global_label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, global_label.squeeze(-1))
            output_dict["loss"] = loss
            
        return output_dict
Example #8
    def forward(
            self,  # type: ignore
            abstract_text: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_abstract_text = self.text_field_embedder(abstract_text)
        abstract_text_mask = util.get_text_field_mask(abstract_text)
        encoded_abstract_text = self.abstract_text_encoder(
            embedded_abstract_text, abstract_text_mask)

        if self.use_positional_encoding:
            encoded_abstract_text = util.add_positional_features(
                encoded_abstract_text)

        attended_abstract_text = self.attention_encoder(
            encoded_abstract_text, abstract_text_mask)
        outputs = self.classifier_feedforward(attended_abstract_text)
        logits = torch.sigmoid(outputs)
        logits = logits.unsqueeze(0) if len(logits.size()) < 2 else logits
        output_dict = {'logits': logits}

        if label is not None:
            outputs = outputs.unsqueeze(0) if len(
                outputs.size()) < 2 else outputs
            loss = self.loss(outputs, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
Example #9
    def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
        output = inputs
        if self._sinusoidal_positional_encoding:
            output = add_positional_features(output)
        if self._positional_embedding is not None:
            position_ids = torch.arange(inputs.size(1),
                                        dtype=torch.long,
                                        device=output.device)
            position_ids = position_ids.unsqueeze(0).expand(inputs.shape[:-1])
            output = output + self._positional_embedding(position_ids)

        # print()
        # print(sum(output[0][4]), sum(output[0][100]))

        # For some reason the torch transformer expects the shape (sequence, batch, features), not the more
        # familiar (batch, sequence, features), so we have to fix it.
        output = output.permute(1, 0, 2)
        # For some other reason, the torch transformer takes the mask backwards.
        mask = ~mask
        output = self._transformer(output, src_key_padding_mask=mask)
        output = output.permute(1, 0, 2)
        # print(sum(inputs[0][4]), sum(inputs[0][100]))
        # print(sum(output[0][4]), sum(output[0][100]))
        # print()

        return output
Example #10
    def forward(self, inputs, mask):  # pylint: disable=arguments-differ
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs
        for (attention, feedforward, feedforward_layer_norm,
             layer_norm) in zip(
                 self._attention_layers,
                 self._feedfoward_layers,
                 self._feed_forward_layer_norm_layers,
                 self._layer_norm_layers,
             ):
            cached_input = output
            # Project output of attention encoder through a feedforward
            # network and back to the input size for the next layer.
            # shape (batch_size, timesteps, input_size)
            feedforward_output = feedforward(feedforward_layer_norm(output))
            feedforward_output = self.dropout(feedforward_output)
            if feedforward_output.size() == cached_input.size():
                # First layer might have the wrong size for highway
                # layers, so we exclude it here.
                feedforward_output += cached_input
            # shape (batch_size, sequence_length, hidden_dim)
            attention_output = attention(layer_norm(feedforward_output), mask)
            output = self.dropout(attention_output) + feedforward_output
        return self._output_layer_norm(output)
Example #11
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor): # pylint: disable=arguments-differ
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs
        for (attention,
             feedforward,
             feedforward_layer_norm,
             layer_norm) in zip(self._attention_layers,
                                self._feedfoward_layers,
                                self._feed_forward_layer_norm_layers,
                                self._layer_norm_layers):
            cached_input = output
            # Project output of attention encoder through a feedforward
            # network and back to the input size for the next layer.
            # shape (batch_size, timesteps, input_size)
            feedforward_output = feedforward(output)
            feedforward_output = self.dropout(feedforward_output)
            if feedforward_output.size() == cached_input.size():
                # First layer might have the wrong size for highway
                # layers, so we exclude it here.
                feedforward_output = feedforward_layer_norm(feedforward_output + cached_input)
            # shape (batch_size, sequence_length, hidden_dim)
            attention_output = attention(feedforward_output, mask)
            output = layer_norm(self.dropout(attention_output) + feedforward_output)

        return output
Example #12
    def forward(self, inputs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs

        total_sublayers = len(self._conv_layers) + 2
        sublayer_count = 0

        for conv_norm_layer, conv_layer in zip(self._conv_norm_layers, self._conv_layers):
            conv_norm_out = self.dropout(conv_norm_layer(output))
            conv_out = self.dropout(conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2))
            sublayer_count += 1
            output = self.residual_with_layer_dropout(
                output, conv_out, sublayer_count, total_sublayers
            )

        attention_norm_out = self.dropout(self.attention_norm_layer(output))
        attention_out = self.dropout(self.attention_layer(attention_norm_out, mask))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(
            output, attention_out, sublayer_count, total_sublayers
        )

        feedforward_norm_out = self.dropout(self.feedforward_norm_layer(output))
        feedforward_out = self.dropout(self.feedforward(feedforward_norm_out))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(
            output, feedforward_out, sublayer_count, total_sublayers
        )
        return output
Example #13
    def forward(self,
                inputs: torch.Tensor,
                mask: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        output = add_positional_features(inputs)
        output = self.projection(output)

        for block in self.blocks:
            output = block(output, mask)

        return output
Example #14
    def _run_transformer_decoder(self, context_output: torch.Tensor,
                                 query_output: torch.Tensor,
                                 context_mask: torch.Tensor,
                                 query_mask: torch.Tensor,
                                 rewrite_embed: torch.Tensor,
                                 rewrite_mask: torch.Tensor):
        """
        实现Transformer解码器的decoder过程
        :param _output: [B, _len, d_model]
        :param _mask: [B, _len]
        :param rewrite_embed: [B, cur_dec_len, d_model]
        :param rewrite_mask: [B, cur_dec_len],这里只是pad的mask,上三角mask在decoder内部实现
        """
        if self._share_decoder_params:
            rewrite_embed = add_positional_features(rewrite_embed)

        previous_state = None
        encoder_outputs = {
            "context_output": context_output,
            "query_output": query_output
        }
        source_mask = {"context_mask": context_mask, "query_mask": query_mask}
        # dec_output: [B, dec_len, d_model]
        # context_attn: [B, num_heads, dec_len, context_len]
        # query_attn: [B, num_heads, dec_len, query_len]
        # x_context: [B, dec_len, d_model]
        # x_query: [B, dec_len, d_model]
        dec_output, context_attn, query_attn, x_context, x_query = self.decoder(
            previous_state, encoder_outputs, source_mask, rewrite_embed,
            rewrite_mask)
        # If the decoder parameters are shared, reuse the same decoder for the remaining layers
        if self._share_decoder_params:
            for _ in range(self.decoder_num_layers - 1):
                dec_output, context_attn, query_attn, x_context, x_query = self.decoder(
                    previous_state, encoder_outputs, source_mask, dec_output,
                    rewrite_mask)

        # sum the attention dists of different heads
        context_attn = torch.sum(context_attn, dim=1, keepdim=False)
        query_attn = torch.sum(query_attn, dim=1, keepdim=False)
        # mask softmax get the final attention dists
        context_attn = masked_softmax(context_attn, context_mask, dim=-1)
        query_attn = masked_softmax(query_attn, query_mask, dim=-1)

        # compute lambda
        # [B, dec_len, 2]
        # Note the difference from the LSTM decoder:
        # the Transformer decoder produces the whole output at once, so the len dimension is kept
        lamb = self._compute_lambda(dec_output,
                                    dec_context=x_context,
                                    dec_query=x_query)
        return dec_output, context_attn, query_attn, lamb
Example #15
    def forward(self,
                src: torch.Tensor,
                kb: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                kb_mask: Optional[torch.Tensor] = None,
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        src = add_positional_features(src)
        src = self.src_projection(src)
        kb = self.kb_projection(kb)

        for i in range(self.num_layers):
            src = self.src_self_blocks[i](src, src_mask)
            kb = self.kb_self_blocks[i](kb, kb_mask)

            src, kb = self.rel_blocks[i](kb, src, kb_mask, src_mask)

        if self.return_kb:
            return src, kb
        return src
Example #16
    def _forward(self,
                 embedding_sequence: torch.Tensor,
                 state: Dict[str, torch.Tensor],
                 apply_target_subsequent_mask: bool = True) -> torch.Tensor:

        if self.use_position_encoding:
            embedding_sequence = add_positional_features(embedding_sequence)

        output = embedding_sequence
        for layer in self.layers:
            output = layer(
                output,
                state['target_mask'] if 'target_mask' in state else None,
                apply_target_subsequent_mask=apply_target_subsequent_mask)

        if self.pre_norm:
            # We need to add an additional function of layer normalization to the top layer
            # to prevent the excessively increased value caused by the sum of unnormalized output
            # (https://arxiv.org/pdf/1906.01787.pdf)
            output = self.output_layer_norm(output)

        return output
Example #17
    def forward(
        self,
        previous_state: Dict[str, torch.Tensor],
        encoder_outputs: torch.Tensor,
        source_mask: torch.Tensor,
        previous_steps_predictions: torch.Tensor,
        previous_steps_mask: Optional[torch.Tensor] = None
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        seq_len = previous_steps_predictions.size(-2)
        future_mask = torch.triu(
            torch.ones(seq_len,
                       seq_len,
                       device=source_mask.device,
                       dtype=torch.float)).transpose(0, 1)
        future_mask = future_mask.masked_fill(future_mask == 0,
                                              float('-inf')).masked_fill(
                                                  future_mask == 1, float(0.0))
        future_mask = Variable(future_mask)

        if self._use_positional_encoding:
            previous_steps_predictions = add_positional_features(
                previous_steps_predictions)
        previous_steps_predictions = self._dropout(previous_steps_predictions)

        source_mask = ~(source_mask.bool())
        if previous_steps_mask is not None:
            previous_steps_mask = ~(previous_steps_mask.bool())
        previous_steps_predictions = previous_steps_predictions.permute(
            1, 0, 2)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        output = self._decoder(previous_steps_predictions,
                               encoder_outputs,
                               tgt_mask=future_mask,
                               tgt_key_padding_mask=previous_steps_mask,
                               memory_key_padding_mask=source_mask)
        return {}, output.permute(1, 0, 2)
Example #18
    def forward(self, embedding_sequence: torch.LongTensor) -> torch.Tensor:
        """

        Parameters
        ----------
        embedding_sequence : (batch_size, sequence_length, embedding_size)

        Returns
        -------
        (batch_size, sequence_length, embedding_size)
        """

        if self.learned:

            batch_size, sequence_length, _ = embedding_sequence.size()

            position_indices = torch.arange(sequence_length).to(
                self.embedding.weight.device)
            position_embeddings = self.embedding(position_indices)
            position_embeddings = position_embeddings.unsqueeze(0).expand(
                batch_size, sequence_length, -1)
            return embedding_sequence + position_embeddings
        else:
            return util.add_positional_features(embedding_sequence)
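Several decoder snippets here build an additive causal ("future") mask with torch.triu, as in the transformer decoder forward above. A worked example for seq_len=3, showing the -inf/0 pattern that torch's decoder expects:

import torch

seq_len = 3
future_mask = torch.triu(torch.ones(seq_len, seq_len)).transpose(0, 1)
future_mask = future_mask.masked_fill(future_mask == 0, float('-inf')).masked_fill(future_mask == 1, 0.0)
print(future_mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
# Row i attends only to positions <= i; the -inf entries zero out attention to future steps after softmax.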
Example #19
    def forward(self,
                query_embeddings: torch.Tensor,
                document_embeddings: torch.Tensor,
                query_pad_oov_mask: torch.Tensor,
                document_pad_oov_mask: torch.Tensor,
                output_secondary_output: bool = False) -> torch.Tensor:
        # pylint: disable=arguments-differ

        query_embeddings = query_embeddings * query_pad_oov_mask.unsqueeze(-1)
        document_embeddings = document_embeddings * document_pad_oov_mask.unsqueeze(
            -1)

        query_embeddings_context = self.contextualizer(
            add_positional_features(query_embeddings).transpose(1, 0),
            src_key_padding_mask=~query_pad_oov_mask.bool()).transpose(1, 0)
        document_embeddings_context = self.contextualizer(
            add_positional_features(document_embeddings).transpose(1, 0),
            src_key_padding_mask=~document_pad_oov_mask.bool()).transpose(
                1, 0)

        query_embeddings = (self.mixer * query_embeddings +
                            (1 - self.mixer) * query_embeddings_context
                            ) * query_pad_oov_mask.unsqueeze(-1)
        document_embeddings = (self.mixer * document_embeddings +
                               (1 - self.mixer) * document_embeddings_context
                               ) * document_pad_oov_mask.unsqueeze(-1)

        #
        # prepare embedding tensors & paddings masks
        # -------------------------------------------------------

        query_by_doc_mask = torch.bmm(
            query_pad_oov_mask.unsqueeze(-1),
            document_pad_oov_mask.unsqueeze(-1).transpose(-1, -2))
        query_by_doc_mask_view = query_by_doc_mask.unsqueeze(-1)

        #
        # cosine matrix
        # -------------------------------------------------------

        # shape: (batch, query_max, doc_max)
        cosine_matrix = self.cosine_module.forward(query_embeddings,
                                                   document_embeddings)
        cosine_matrix_masked = cosine_matrix * query_by_doc_mask
        cosine_matrix_extradim = cosine_matrix_masked.unsqueeze(-1)

        #
        # gaussian kernels & soft-TF
        #
        # first run through kernel, then sum on doc dim then sum on query dim
        # -------------------------------------------------------

        raw_kernel_results = torch.exp(
            -torch.pow(cosine_matrix_extradim - self.mu, 2) /
            (2 * torch.pow(self.sigma, 2)))
        kernel_results_masked = raw_kernel_results * query_by_doc_mask_view

        #
        # mean kernels
        #
        #kernel_results_masked2 = kernel_results_masked.clone()

        doc_lengths = torch.sum(document_pad_oov_mask, 1)

        #kernel_results_masked2_mean = kernel_results_masked / doc_lengths.unsqueeze(-1)

        per_kernel_query = torch.sum(kernel_results_masked, 2)
        log_per_kernel_query = torch.log2(
            torch.clamp(per_kernel_query, min=1e-10)) * self.nn_scaler
        log_per_kernel_query_masked = log_per_kernel_query * query_pad_oov_mask.unsqueeze(
            -1)  # make sure we mask out padding values
        per_kernel = torch.sum(log_per_kernel_query_masked, 1)

        #per_kernel_query_mean = torch.sum(kernel_results_masked2_mean, 2)

        per_kernel_query_mean = per_kernel_query / (
            doc_lengths.view(-1, 1, 1) + 1
        )  # well, that +1 needs an explanation, sometimes training data is just broken ... (and nans all the things!)

        log_per_kernel_query_mean = per_kernel_query_mean * self.nn_scaler
        log_per_kernel_query_masked_mean = log_per_kernel_query_mean * query_pad_oov_mask.unsqueeze(
            -1)  # make sure we mask out padding values
        per_kernel_mean = torch.sum(log_per_kernel_query_masked_mean, 1)

        ##
        ## "Learning to rank" layer - connects kernels with learned weights
        ## -------------------------------------------------------

        dense_out = self.dense(per_kernel)
        dense_mean_out = self.dense_mean(per_kernel_mean)
        dense_comb_out = self.dense_comb(
            torch.cat([dense_out, dense_mean_out], dim=1))
        score = torch.squeeze(dense_comb_out, 1)  #torch.tanh(dense_out), 1)

        if output_secondary_output:
            query_mean_vector = query_embeddings.sum(
                dim=1) / query_pad_oov_mask.sum(dim=1).unsqueeze(-1)
            return score, {
                "score": score,
                "dense_out": dense_out,
                "dense_mean_out": dense_mean_out,
                "per_kernel": per_kernel,
                "per_kernel_mean": per_kernel_mean,
                "query_mean_vector": query_mean_vector,
                "cosine_matrix_masked": cosine_matrix_masked
            }
        else:
            return score
Example #20
    def forward(self, source: torch.Tensor, target: torch.Tensor, metadata: dict, seq_labels: torch.Tensor = None,
                reg_labels: torch.Tensor = None, source_mask: Optional[torch.Tensor]=None,
                target_mask: Optional[torch.Tensor]=None, memory_mask: Optional[torch.Tensor]=None,
                src_key_padding_mask: Optional[torch.Tensor]=None, tgt_key_padding_mask: Optional[torch.Tensor]=None,
                memory_key_padding_mask: Optional[torch.Tensor]=None) -> Dict[str, torch.Tensor]:
        r"""Take in and process masked source/target sequences.
        Args:
            source: the sequence to the encoder (required).
            target: the sequence to the decoder (required).
            metadata: the metadata of the samples (required).
            seq_labels: the labels of each round (optional).
            reg_labels: the labels of the total future payoff (optional).
            source_mask: the additive mask for the src sequence (optional).
            target_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
        Shape:
            - source: :math:`(S, N, E)`.
            - target: :math:`(T, N, E)`.
            - source_mask: :math:`(S, S)`.
            - target_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.
            Note: [src/tgt/memory]_mask should be filled with
            float('-inf') for the masked positions and float(0.0) else. These masks
            ensure that predictions for position i depend only on the unmasked positions
            j and are applied identically for each sequence in a batch.
            [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
            that should be masked with float('-inf') and False values will be unchanged.
            This mask ensures that no information will be taken from position i if
            it is masked, and has a separate mask for each sequence in a batch.
            - output: :math:`(T, N, E)`.
            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is same as the input sequence
            (i.e. target) length of the decode.
            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number
        """

        if self._first_pair is not None:
            if self._first_pair == metadata[0]['pair_id']:
                self._epoch += 1
        else:
            self._first_pair = metadata[0]['pair_id']

        output = dict()
        src_key_padding_mask = get_text_field_mask({'tokens': source})
        tgt_key_padding_mask = get_text_field_mask({'tokens': target})
        # The torch transformer takes the mask backwards.
        src_key_padding_mask_byte = ~src_key_padding_mask.bool()
        tgt_key_padding_mask_byte = ~tgt_key_padding_mask.bool()
        # Create a mask where only the first round is unmasked; the first round must be the same for all sequences
        if self.only_raisha:
            temp_mask = torch.ones(tgt_key_padding_mask_byte.shape, dtype=torch.bool)
            temp_mask[:, 0] = False
            tgt_key_padding_mask_byte = temp_mask

        if self._sinusoidal_positional_encoding:
            source = add_positional_features(source)
            target = add_positional_features(target)

        if torch.cuda.is_available():  # change to cuda
            source = source.cuda()
            target = target.cuda()
            tgt_key_padding_mask = tgt_key_padding_mask.cuda()
            src_key_padding_mask = src_key_padding_mask.cuda()
            tgt_key_padding_mask_byte = tgt_key_padding_mask_byte.cuda()
            src_key_padding_mask_byte = src_key_padding_mask_byte.cuda()
            if seq_labels is not None:
                seq_labels = seq_labels.cuda()
            if reg_labels is not None:
                reg_labels = reg_labels.cuda()

        # The torch transformer expects the shape (sequence, batch, features), not the more
        # familiar (batch, sequence, features), so we have to fix it.
        source = source.permute(1, 0, 2)
        target = target.permute(1, 0, 2)

        if source.size(1) != target.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if source.size(2) != self._input_dim or target.size(2) != self._input_dim:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        encoder_out = self.encoder(source, src_key_padding_mask=src_key_padding_mask_byte)
        decoder_output = self.decoder(target, encoder_out, tgt_key_padding_mask=tgt_key_padding_mask_byte,
                                      memory_key_padding_mask=src_key_padding_mask_byte)
        decoder_output = decoder_output.permute(1, 0, 2)

        if self.predict_seq:
            if self.linear_layer is not None:
                decoder_output = self.linear_layer(decoder_output)  # add linear layer before hidden2tag
            decision_logits = self.hidden2tag(decoder_output)
            output['decision_logits'] = masked_softmax(decision_logits, tgt_key_padding_mask.unsqueeze(2))
            self.seq_predictions = save_predictions_seq_models(prediction_df=self.seq_predictions,
                                                               mask=tgt_key_padding_mask,
                                                               predictions=output['decision_logits'],
                                                               gold_labels=seq_labels, metadata=metadata,
                                                               epoch=self._epoch, is_train=self.training,)

        if self.predict_avg_total_payoff:
            # (batch_size, seq_len, dimensions) * (batch_size, dimensions, 1) -> (batch_size, seq_len)
            attention_output = self.attention(self.attention_vector, decoder_output, tgt_key_padding_mask)
            # (batch_size, 1, seq_len) * (batch_size, seq_len, dimensions) -> (batch_size, dimensions)
            attention_output = torch.bmm(attention_output.unsqueeze(1), decoder_output).squeeze()
            # (batch_size, dimensions) -> (batch_size, batch_size)
            linear_out = self.linear_after_attention_layer(attention_output)
            # (batch_size, batch_size) -> (batch_size, 1)
            regression_output = self.regressor(linear_out)
            output['regression_output'] = regression_output
            self.reg_predictions = save_predictions(prediction_df=self.reg_predictions,
                                                    predictions=output['regression_output'], gold_labels=reg_labels,
                                                    metadata=metadata, epoch=self._epoch, is_train=self.training,
                                                    int_label=False)

        if seq_labels is not None or reg_labels is not None:
            temp_loss = 0
            if self.predict_seq and seq_labels is not None:
                for metric_name, metric in self.metrics_dict_seq.items():
                    metric(decision_logits, seq_labels, tgt_key_padding_mask)
                output['seq_loss'] = sequence_cross_entropy_with_logits(decision_logits, seq_labels,
                                                                        tgt_key_padding_mask)
                temp_loss += self.seq_weight_loss * output['seq_loss']
            if self.predict_avg_total_payoff and reg_labels is not None:
                for metric_name, metric in self.metrics_dict_reg.items():
                    metric(regression_output, reg_labels, tgt_key_padding_mask)
                output['reg_loss'] = self.mse_loss(regression_output, reg_labels.view(reg_labels.shape[0], -1))
                temp_loss += self.reg_weight_loss * output['reg_loss']

            output['loss'] = temp_loss

        return output
Example #21
    def forward(self,
                content_id: torch.LongTensor,
                bundle_id: torch.LongTensor,
                feature: torch.FloatTensor,
                user_id: torch.FloatTensor,
                mask: torch.Tensor,
                initial_state: Optional[TensorPair] = None,
                ans_prev_correctly: Optional[torch.Tensor] = None):

        batch_size, seq_len = content_id.shape
        # content_emb: (batch, seq, dim)
        content_emb = self.content_id_emb(content_id)

        if self.hparams["emb_dropout"] > 0:
            content_emb = self.emb_dropout(content_emb)

        # content_emb: (batch, seq, dim)
        feature = torch.cat([content_emb, feature], dim=-1)

        if not self.hparams.get("no_prev_ans", False):
            if ans_prev_correctly is None:
                ans_prev_correctly = torch.ones(batch_size,
                                                seq_len,
                                                1,
                                                device=self.device,
                                                dtype=torch.float)
            feature = torch.cat([ans_prev_correctly, feature], dim=-1)

        if hasattr(self, "lstm_in_proj"):
            feature = self.lstm_in_proj(feature)
            feature = F.relu(feature)

        # Apply LSTM
        sequence_lengths = self.__class__.get_lengths_from_seq_mask(mask)
        clamped_sequence_lengths = sequence_lengths.clamp(min=1)

        if self.encoder_type == "augmented_lstm":
            sorted_feature, sorted_sequence_lengths, restoration_indices, sorting_indices = \
                self.__class__.sort_batch_by_length(feature, clamped_sequence_lengths)
            packed_sequence_input = pack_padded_sequence(
                sorted_feature,
                sorted_sequence_lengths.data.tolist(),
                enforce_sorted=False,
                batch_first=True)
            # encoder_out: (batch, seq_len, num_directions * hidden_size):
            # h_t: (num_layers * num_directions, batch, hidden_size)
            #    - this dimension is valid regardless of batch_first=True!!
            packed_lstm_out, (h_t, c_t) = self.encoder(packed_sequence_input)
            # lstm_out: (batch, seq, num_directions * hidden_size)
            lstm_out, _ = pad_packed_sequence(packed_lstm_out,
                                              batch_first=True)

            lstm_out = lstm_out.index_select(0, restoration_indices)
            h_t = h_t.index_select(1, restoration_indices)
            c_t = c_t.index_select(1, restoration_indices)
            state = (h_t, c_t)
        elif self.encoder_type == "attention":
            if initial_state is None:
                mask_seq_len = seq_len
                attention_mask = self.__class__ \
                    .get_attention_mask(src_seq_len=mask_seq_len) \
                    .to(self.device)
                permuted_feature = add_positional_features(feature, max_timescale=seq_len) \
                    .permute(1, 0, 2)
                query = permuted_feature
            else:
                mask_seq_len = seq_len + 1

                # initial_state: (batch, 1, dim)
                initial_state = initial_state[0]
                feature = torch.cat([feature, initial_state],
                                    dim=1)  # previous sequence summary vector

                attention_mask = self.__class__ \
                    .get_attention_mask(src_seq_len=mask_seq_len,
                                        target_seq_len=seq_len) \
                    .to(self.device)

                # (seq, N, dim)
                permuted_feature = add_positional_features(feature, max_timescale=mask_seq_len) \
                    .permute(1, 0, 2)
                query = permuted_feature[1:]

            # att_output: (seq, batch, dim)
            att_output, att_weight = self.encoder(query=query,
                                                  key=permuted_feature,
                                                  value=permuted_feature,
                                                  attn_mask=attention_mask,
                                                  need_weights=False)
            # (batch, seq, dim)
            lstm_out = att_output.permute(1, 0, 2)

            # (batch, 1, dim)
            summary_vec, _ = lstm_out.max(dim=1, keepdim=True)
            state = [summary_vec]
        else:
            packed_sequence_input = pack_padded_sequence(
                feature,
                clamped_sequence_lengths.data.tolist(),
                enforce_sorted=False,
                batch_first=True)

            # encoder_out: (batch, seq_len, num_directions * hidden_size):
            # h_t: (num_layers * num_directions, batch, hidden_size)
            #    - this dimension is valid regardless of batch_first=True!!
            packed_lstm_out, state = self.encoder(packed_sequence_input,
                                                  initial_state)

            # lstm_out: (batch, seq, num_directions * hidden_size)
            lstm_out, _ = pad_packed_sequence(packed_lstm_out,
                                              batch_first=True)

        if self.hparams.get("layer_norm", False):
            lstm_out = self.layer_norm(lstm_out)

        if self.hparams.get("highway_connection", False):
            c = torch.sigmoid(self.highway_C(lstm_out))
            h = self.highway_H(lstm_out)

            lstm_out = (1 - c) * torch.relu(h) + c * torch.relu(feature)
        else:
            lstm_out = F.relu(lstm_out)

        if self.hparams["output_dropout"] > 0:
            lstm_out = self.output_dropout(lstm_out)

        y_pred = torch.squeeze(self.hidden2logit(lstm_out),
                               dim=-1)  # (batch, seq)

        return y_pred, state
Example #22
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                extend_context_ids: torch.Tensor,
                extend_query_ids: torch.Tensor,
                context_turn: torch.Tensor,
                query_turn: torch.Tensor,
                context_len: torch.Tensor,
                query_len: torch.Tensor,
                oovs_len: torch.Tensor,
                rewrite_input_ids: Optional[TextFieldTensors] = None,
                rewrite_target_ids: Optional[TextFieldTensors] = None,
                extend_rewrite_ids: Optional[torch.Tensor] = None,
                rewrite_len: Optional[torch.Tensor] = None,
                metadata: Optional[List[Dict[str, Any]]] = None):
        """前向传播的过程"""
        context_token_ids = context_ids[self._index_name]["tokens"]
        query_token_ids = query_ids[self._index_name]["tokens"]

        # get the extended context and query ids
        extend_context_ids = context_token_ids + extend_context_ids.to(
            dtype=torch.long)
        extend_query_ids = query_token_ids + extend_query_ids.to(
            dtype=torch.long)

        # ---------------- Encoder outputs ------------
        # Compute the context and query embeddings
        context_embed = self._get_embeddings(context_ids, context_turn)
        query_embed = self._get_embeddings(query_ids, query_turn)
        # Compute the context and query lengths
        max_context_len = context_embed.size(1)
        max_query_len = query_embed.size(1)
        # Compute the masks
        context_mask = get_mask_from_sequence_lengths(
            context_len, max_length=max_context_len)
        query_mask = get_mask_from_sequence_lengths(query_len,
                                                    max_length=max_query_len)
        # Compute the encoder output
        dialogue_embed = torch.cat([context_embed, query_embed], dim=1)
        dialogue_mask = torch.cat([context_mask, query_mask], dim=1)
        # If the encoder parameters are shared, positional features have to be added beforehand
        if self._share_encoder_params:
            dialogue_embed = add_positional_features(dialogue_embed)
        # Encoder output: [B, dialogue_len, encoder_output_dim]
        dialogue_output = self.encoder(dialogue_embed, dialogue_mask)
        if self._share_encoder_params:
            for _ in range(self.encoder_num_layers - 1):
                dialogue_output = self.encoder(dialogue_output, dialogue_mask)
        # Compute the encoded representations
        # [B, context_len, *] and [B, query_len, *]
        context_output, query_output, dec_init_state = self._run_encoder(
            dialogue_output, context_mask, query_mask)
        output_dict = {"metadata": metadata}
        if self.training:
            rewrite_input_token_ids = rewrite_input_ids[
                self._index_name]["tokens"]
            # Compute the rewrite lengths
            max_rewrite_len = rewrite_input_token_ids.size(1)
            rewrite_input_mask = get_mask_from_sequence_lengths(
                rewrite_len, max_length=max_rewrite_len)
            rewrite_target_ids = rewrite_target_ids[self._index_name]["tokens"]
            # Compute the target sequence indices for the rewrite
            rewrite_target_ids = rewrite_target_ids + extend_rewrite_ids.to(
                dtype=torch.long)

            # Compute the embedding output, [B, rewrite_len, embedding_size]
            rewrite_embed = self._get_embeddings(rewrite_input_ids)
            # Forward step to compute the loss
            new_output_dict = self._forward_step(
                context_output, query_output, context_mask, query_mask,
                rewrite_embed, rewrite_target_ids, rewrite_len,
                rewrite_input_mask, extend_context_ids, extend_query_ids,
                oovs_len, dec_init_state)
            output_dict.update(new_output_dict)
        else:
            batch_hyps = self._run_inference(context_output,
                                             query_output,
                                             context_mask,
                                             query_mask,
                                             extend_context_ids,
                                             extend_query_ids,
                                             oovs_len,
                                             dec_init_state=dec_init_state)
            # get the result of each instance
            output_dict['hypothesis'] = batch_hyps
            output_dict = self.get_rewrite_string(output_dict)
            output_dict["loss"] = torch.tensor(0)

        return output_dict
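When _share_encoder_params is set, the model above applies the same encoder block repeatedly instead of stacking separately-parameterized layers. A minimal sketch of that weight-sharing idea with a plain torch layer (the layer and depth are stand-ins, not this model's actual encoder):

import torch

shared_layer = torch.nn.TransformerEncoderLayer(d_model=16, nhead=4)
num_layers = 3

x = torch.randn(7, 2, 16)  # (seq_len, batch, d_model)
for _ in range(num_layers):
    # Every "layer" reuses exactly the same parameters.
    x = shared_layer(x)
print(x.shape)  # torch.Size([7, 2, 16])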
Example #23
    def _eval_decode(
            self, state: Dict[str, torch.Tensor],
            segments: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        encoder_outputs = state["encoder_outputs"]
        source_key_padding_mask = state["source_key_padding_mask"]
        source_embedded = state["source_raw_embedded"]

        source_token_mask = state["source_token_mask"]
        memory_key_padding_mask = (1 - source_token_mask).bool()
        # memory_key_padding_mask = source_key_padding_mask

        batch_size = source_key_padding_mask.size(0)
        encode_length = source_key_padding_mask.size(1)
        log_probs_after_end = encoder_outputs.new_full(
            (batch_size, self._num_classes + encode_length),
            fill_value=float("-inf"))
        log_probs_after_end[:, self._end_index] = 0.

        start_predictions = state["source_mask"].new_full(
            (batch_size, 1), fill_value=self._start_index)
        partial_generate_predictions = start_predictions
        partial_copy_predictions = state["source_mask"].new_zeros(
            (batch_size, 1))
        basic_index = torch.arange(batch_size).to(
            source_embedded.device).unsqueeze(1).long()
        generate_mask = state["source_mask"].new_ones((batch_size, 1)).float()
        # shape: (batch_size)
        last_prediction = start_predictions.squeeze(1)
        for _ in range(self._max_decoding_step):
            # shape: (batch_size, partial_len, d_model)
            partial_source_embedded_input = source_embedded[
                basic_index, partial_copy_predictions]
            partial_target_embedded_input = self._target_embedder(
                partial_generate_predictions)
            partial_embedded_input = partial_target_embedded_input * generate_mask.unsqueeze(-1) \
                                     + partial_source_embedded_input * (1 - generate_mask).unsqueeze(-1)

            partial_embedded_input = util.add_positional_features(
                partial_embedded_input)
            partial_len = partial_embedded_input.size(1)

            partial_embedded_input = partial_embedded_input.permute(1, 0, 2)

            mask = (torch.triu(state["source_mask"].new_ones(
                partial_len, partial_len)) == 1).transpose(0, 1)
            mask = mask.float().masked_fill(mask == 0,
                                            float('-inf')).masked_fill(
                                                mask == 1, float(0.0))

            if not self._decode_use_relative_position:
                # shape: (partial_len, batch_size, d_model)
                outputs = self._decoder(
                    partial_embedded_input,
                    memory=encoder_outputs,
                    tgt_mask=mask,
                    memory_key_padding_mask=memory_key_padding_mask)
            else:
                # gnn decoder
                edge_mask = get_decode_edge_mask(
                    partial_embedded_input,
                    max_decode_clip_range=self._max_decode_clip_range)
                tgt_padding_mask = torch.tril(edge_mask.new_ones(
                    [partial_len, partial_len]),
                                              diagonal=0)
                tgt_padding_mask = (1 - tgt_padding_mask.unsqueeze(0).repeat(
                    batch_size, 1, 1)).float()
                # shape: (partial_len, batch_size, d_model)
                outputs = self._decoder(
                    partial_embedded_input,
                    edge_mask=edge_mask.permute(0, 2, 3, 1),
                    memory=encoder_outputs,
                    tgt_padding_mask=tgt_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)

            outputs = outputs.permute(1, 0, 2)

            # shape: (batch_size, d_model)
            curr_outputs = outputs[:, -1, :]

            # shape: (batch_size, num_classes)
            generate_scores = self.get_generate_scores(curr_outputs)
            # shape: (batch_size, encode_length)
            copy_scores = self.get_copy_scores(
                state, curr_outputs.unsqueeze(1)).squeeze(1)

            # Gate
            # shape: (batch_size, 1)
            generate_gate = torch.sigmoid(self.gate_linear(curr_outputs))
            copy_gate = 1 - generate_gate
            scores = torch.cat(
                (generate_scores * generate_gate, copy_scores * copy_gate),
                dim=-1)
            # scores = torch.cat((generate_scores, copy_scores), dim=-1)

            # shape: (batch_size, encode_length)
            entity_mask = 1 - (
                (segments['tokens'] == self._token_index) |
                (segments['tokens'] == self._non_func_symbol_index) |
                (segments['tokens'] == self._segment_pad_index)).float()
            # shape: (batch_size, num_classes + encode_length)
            score_mask = torch.cat((entity_mask.new_ones(
                (batch_size, self._num_classes)), entity_mask),
                                   dim=-1)

            # shape: (batch_size, num_classes + encode_length)
            normalized_scores = util.masked_softmax(scores,
                                                    mask=score_mask,
                                                    dim=-1)

            last_prediction_expanded = last_prediction.unsqueeze(-1).expand(
                batch_size, self._num_classes + encode_length)

            # shape: (batch_size, num_classes + encode_length)
            cleaned_logits = torch.where(
                last_prediction_expanded == self._end_index,
                log_probs_after_end, normalized_scores)

            # shape: (batch_size)
            _, predicted = torch.max(input=cleaned_logits,
                                     dim=1,
                                     keepdim=False)

            copy_mask = (predicted >= self._num_classes).long()
            generate_predicted = predicted * (1 - copy_mask)
            copy_predicted = (predicted - self._num_classes) * copy_mask
            partial_copy_predictions = torch.cat(
                (partial_copy_predictions, copy_predicted.unsqueeze(1)), dim=1)
            partial_generate_predictions = torch.cat(
                (partial_generate_predictions,
                 generate_predicted.unsqueeze(1)),
                dim=1)
            generate_mask = torch.cat(
                (generate_mask, (1 - copy_mask).unsqueeze(1).float()), dim=1)

            last_prediction = predicted

            if (last_prediction == self._end_index).sum() == batch_size:
                break

        predictions = partial_generate_predictions * generate_mask.long() + \
                      (1 - generate_mask).long() * (partial_copy_predictions + self._num_classes)

        # shape: (batch_size, partial_len)
        output_dict = {"predictions": predictions}
        return output_dict
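The loop above leans on two pieces worth seeing in isolation: the additive causal mask passed to the decoder as tgt_mask, and the sigmoid gate that blends vocabulary (generate) scores with copy scores before the masked softmax. Below is a minimal, self-contained sketch of both with toy shapes and random stand-ins for get_generate_scores / get_copy_scores; it is illustrative only, not the model's own code.

import torch

def subsequent_mask(size: int) -> torch.Tensor:
    # 0.0 at positions a step may attend to (itself and earlier), -inf above the diagonal
    mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
    return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

batch_size, num_classes, encode_length, d_model = 2, 5, 7, 16
curr_outputs = torch.randn(batch_size, d_model)         # last decoder state per example
generate_scores = torch.randn(batch_size, num_classes)  # stand-in for get_generate_scores
copy_scores = torch.randn(batch_size, encode_length)    # stand-in for get_copy_scores

gate_linear = torch.nn.Linear(d_model, 1)
generate_gate = torch.sigmoid(gate_linear(curr_outputs))            # (batch_size, 1)
scores = torch.cat((generate_scores * generate_gate,
                    copy_scores * (1 - generate_gate)), dim=-1)     # (batch_size, num_classes + encode_length)

print(subsequent_mask(3))
print(scores.shape)

The mask is added to the attention logits by nn.TransformerDecoder, so -inf entries remove future positions while 0.0 entries leave the rest unchanged.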
Example #24
0
    def _train_decode(
            self, state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.Tensor],
            generate_targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        encoder_outputs = state["encoder_outputs"]
        source_key_padding_mask = state["source_key_padding_mask"]

        # shape: (batch_size, encode_length, d_model)
        source_embedded = state["source_raw_embedded"]
        batch_size, _, _ = source_embedded.size()
        basic_index = torch.arange(batch_size).to(
            source_embedded.device).long()
        generate_targets = generate_targets.long()
        retrieved_target_embedded_input = source_embedded[
            basic_index.unsqueeze(1), generate_targets][:, :-1, :]
        target_embedded_input = self._target_embedder(
            target_tokens['tokens'])[:, :-1, :]

        # shape: (batch_size, max_decode_length)
        # where 1 indicates that the target token is generated rather than copied
        generate_mask = (generate_targets == 0).float()
        target_embedded_input = target_embedded_input * generate_mask[:, :-1].unsqueeze(-1) \
                                + retrieved_target_embedded_input * (1 - generate_mask)[:, :-1].unsqueeze(-1)
        target_embedded_input = util.add_positional_features(
            target_embedded_input)

        # shape: (max_target_sequence_length - 1, batch_size, d_model)
        target_embedded_input = target_embedded_input.permute(1, 0, 2)

        # shape: (batch_size, max_target_sequence_length - 1)
        # key_padding_mask should be a ByteTensor where True values are positions
        # that should be masked with float('-inf') and False values will be unchanged.
        target_mask = util.get_text_field_mask(target_tokens)[:, 1:]
        target_key_padding_mask = (1 - target_mask.byte()).bool()

        assert target_key_padding_mask.size(1) == target_embedded_input.size(0) and \
               target_embedded_input.size(1) == target_key_padding_mask.size(0)

        max_target_seq_length = target_key_padding_mask.size(1)
        target_additive_mask = (torch.triu(
            target_mask.new_ones(max_target_seq_length,
                                 max_target_seq_length)) == 1).transpose(0, 1)
        target_additive_mask = target_additive_mask.float().masked_fill(
            target_additive_mask == 0, float('-inf'))
        target_additive_mask = target_additive_mask.masked_fill(
            target_additive_mask == 1, float(0.0))

        assert target_embedded_input.size(1) == encoder_outputs.size(1)

        source_token_mask = state["source_token_mask"]
        memory_key_padding_mask = (1 - source_token_mask).bool()
        # memory_key_padding_mask = source_key_padding_mask
        if not self._decode_use_relative_position:
            # shape: (max_target_sequence_length, batch_size, d_model)
            decoder_outputs = self._decoder(
                target_embedded_input,
                memory=encoder_outputs,
                tgt_mask=target_additive_mask,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=memory_key_padding_mask)
        else:
            # gnn decoder
            edge_mask = get_decode_edge_mask(
                target_embedded_input,
                max_decode_clip_range=self._max_decode_clip_range)
            batch_size = edge_mask.size(0)
            tgt_padding_mask = torch.tril(edge_mask.new_ones(
                [max_target_seq_length, max_target_seq_length]),
                                          diagonal=0)
            tgt_padding_mask = (1 - (tgt_padding_mask.unsqueeze(0).repeat(
                batch_size, 1, 1))).float()
            decoder_outputs = self._decoder(
                target_embedded_input,
                edge_mask=edge_mask.permute(0, 2, 3, 1),
                memory=encoder_outputs,
                tgt_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
        # shape: (batch_size, max_target_sequence_length, d_model)
        decoder_outputs = decoder_outputs.permute(1, 0, 2)

        state.update({
            "decoder_outputs": decoder_outputs,
            "target_key_padding_mask": target_key_padding_mask,
            "target_mask": target_mask,
            "generate_mask": generate_mask
        })
        return state
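For reference, here is a toy sketch of how _train_decode mixes copied source embeddings with generated target embeddings via generate_mask. Shapes are made up and a fresh nn.Embedding stands in for the model's _target_embedder; it only illustrates the gather-and-blend step.

import torch

batch_size, src_len, tgt_len, d_model, vocab = 2, 6, 4, 8, 10
source_embedded = torch.randn(batch_size, src_len, d_model)
target_tokens = torch.randint(0, vocab, (batch_size, tgt_len))
# generate_targets[b, t] = 0 -> generate from the vocabulary;
# any other value is the index of the source token to copy
generate_targets = torch.randint(0, src_len, (batch_size, tgt_len))

target_embedder = torch.nn.Embedding(vocab, d_model)
basic_index = torch.arange(batch_size)

retrieved = source_embedded[basic_index.unsqueeze(1), generate_targets]  # copied embeddings
generated = target_embedder(target_tokens)

generate_mask = (generate_targets == 0).float()                          # 1 = generate
mixed = generated * generate_mask.unsqueeze(-1) + retrieved * (1 - generate_mask).unsqueeze(-1)
print(mixed.shape)  # (batch_size, tgt_len, d_model)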
Example #25
0
    def forward(
        self,
        inputs: torch.Tensor,
        semantic_views_q: torch.Tensor,
        semantic_views_scope_mask: torch.Tensor,
        mask: torch.Tensor = None,
        return_output_metadata: bool = False
    ) -> Tuple[torch.Tensor, Optional[Dict]]:  # pylint: disable=arguments-differ

        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs

        total_sublayers = len(self._conv_layers) + 2
        sublayer_count = 0

        for conv_norm_layer, conv_layer in zip(self._conv_norm_layers,
                                               self._conv_layers):
            conv_norm_out = self.dropout(conv_norm_layer(output))
            conv_out = self.dropout(
                conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2))
            sublayer_count += 1
            output = self.residual_with_layer_dropout(output, conv_out,
                                                      sublayer_count,
                                                      total_sublayers)

        attention_norm_out = self.dropout(self.attention_norm_layer(output))

        bs, seq_len, _ = inputs.shape
        # pad semantic views
        if semantic_views_q.shape[1] < self.num_attention_heads:
            #logging.info("semantic_views_q.dtype:{0}".format(semantic_views_q.dtype))
            # #
            # logging.info("semantic_views_q_pad.dtype:{0}".format(semantic_views_q_pad.dtype))
            semantic_views_q = torch.cat([
                semantic_views_q,
                long_fill(
                    (bs, self.num_attention_heads - semantic_views_q.shape[1],
                     seq_len), 0, semantic_views_q.is_cuda)
            ],
                                         dim=1)
            #logging.info("semantic_views_q.dtype after concat:{0}".format(semantic_views_q.dtype))

            #logging.info("semantic_views_scope_mask.dtype:{0}".format(semantic_views_scope_mask.dtype))
            semantic_views_scope_mask = torch.cat([
                semantic_views_scope_mask,
                long_fill((bs, self.num_attention_heads -
                           semantic_views_scope_mask.shape[1], seq_len), 0,
                          semantic_views_scope_mask.is_cuda)
            ],
                                                  dim=1)

        if self._replace_zero_semantic_labels_with_per_head_labels:
            heads_dim_id = 1
            views_shape = list(semantic_views_q.shape)
            views_shape_single_head = views_shape[:]
            views_shape_single_head[heads_dim_id] = 1

            attention_head_mask = torch.cat([
                long_fill(views_shape_single_head, x, semantic_views_q.is_cuda)
                for x in range(
                    self.num_semantic_labels -
                    self.num_attention_heads, self.num_semantic_labels)
            ],
                                            dim=heads_dim_id).contiguous()

            # mask the values
            semantic_views_q = semantic_views_q + attention_head_mask * (
                semantic_views_q < 1).long()
            if len(semantic_views_scope_mask.shape) == len(
                    semantic_views_q.shape):
                semantic_views_scope_mask = semantic_views_scope_mask + attention_head_mask * (
                    semantic_views_scope_mask < 1).long()

        attention_output_meta = None
        if is_output_meta_supported(self.attention_layer):
            attention_out, attention_output_meta = self.attention_layer(
                attention_norm_out,
                semantic_views_q,
                semantic_views_scope_mask,
                mask,
                return_output_metadata=return_output_metadata)
            attention_out = self.dropout(attention_out)
        else:
            attention_out = self.dropout(
                self.attention_layer(attention_norm_out, semantic_views_q,
                                     semantic_views_scope_mask, mask))

        sublayer_count += 1
        output = self.residual_with_layer_dropout(output, attention_out,
                                                  sublayer_count,
                                                  total_sublayers)

        feedforward_norm_out = self.dropout(
            self.feedforward_norm_layer(output))
        feedforward_out = self.dropout(self.feedforward(feedforward_norm_out))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(output, feedforward_out,
                                                  sublayer_count,
                                                  total_sublayers)

        return output, attention_output_meta
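A toy sketch of the "replace zero semantic labels with per-head labels" step used above, assuming semantic_views_q has shape (batch, num_heads, seq_len) and that the last num_attention_heads label ids are reserved as per-head defaults; that reservation scheme is an assumption here, since it is defined elsewhere in the codebase.

import torch

batch, num_heads, seq_len, num_semantic_labels = 2, 4, 5, 20
semantic_views_q = torch.randint(0, 3, (batch, num_heads, seq_len))

# one constant label id per head: num_semantic_labels - num_heads .. num_semantic_labels - 1
attention_head_mask = torch.cat(
    [torch.full((batch, 1, seq_len), x, dtype=torch.long)
     for x in range(num_semantic_labels - num_heads, num_semantic_labels)],
    dim=1)

# positions whose label is 0 fall back to their head's default label
relabelled = semantic_views_q + attention_head_mask * (semantic_views_q < 1).long()
print(relabelled.shape)  # (batch, num_heads, seq_len)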
    def forward(
        self,
        inputs: torch.Tensor,
        semantic_views_q: torch.Tensor,
        semantic_views_k: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:  # pylint: disable=arguments-differ
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs

        total_sublayers = len(self._conv_layers) + 2
        sublayer_count = 0

        for conv_norm_layer, conv_layer in zip(self._conv_norm_layers,
                                               self._conv_layers):
            conv_norm_out = self.dropout(conv_norm_layer(output))
            conv_out = self.dropout(
                conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 2))
            sublayer_count += 1
            output = self.residual_with_layer_dropout(output, conv_out,
                                                      sublayer_count,
                                                      total_sublayers)

        attention_norm_out = self.dropout(self.attention_norm_layer(output))

        if self._replace_zero_semantic_labels_with_per_head_labels:
            heads_dim_id = 1
            views_shape = list(semantic_views_q.shape)
            views_shape_single_head = views_shape[:]
            views_shape_single_head[heads_dim_id] = 1

            attention_head_mask = torch.cat([
                torch.full(views_shape_single_head, x).long() for x in range(
                    self.num_semantic_labels -
                    self.num_attention_heads, self.num_semantic_labels)
            ],
                                            dim=heads_dim_id).contiguous()
            if semantic_views_q.is_cuda:
                attention_head_mask = attention_head_mask.cuda()

            # mask the values
            semantic_views_q = semantic_views_q + attention_head_mask * (
                semantic_views_q < 1).long()
            semantic_views_k = semantic_views_k + attention_head_mask * (
                semantic_views_k < 1).long()

        attention_out = self.dropout(
            self.attention_layer(attention_norm_out, semantic_views_q,
                                 semantic_views_k, mask))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(output, attention_out,
                                                  sublayer_count,
                                                  total_sublayers)

        feedforward_norm_out = self.dropout(
            self.feedforward_norm_layer(output))
        feedforward_out = self.dropout(self.feedforward(feedforward_norm_out))
        sublayer_count += 1
        output = self.residual_with_layer_dropout(output, feedforward_out,
                                                  sublayer_count,
                                                  total_sublayers)
        return output
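Both forward variants begin by calling add_positional_features. The snippet below is a re-implementation sketch of tensor2tensor-style sinusoidal position features, written from the standard formula rather than copied from the library, so treat it as illustrative rather than authoritative.

import math
import torch

def sinusoidal_positional_features(tensor: torch.Tensor,
                                    min_timescale: float = 1.0,
                                    max_timescale: float = 1.0e4) -> torch.Tensor:
    # tensor: (batch, timesteps, hidden_dim); returns tensor plus sinusoid features
    _, timesteps, hidden_dim = tensor.size()
    positions = torch.arange(timesteps, dtype=torch.float, device=tensor.device)
    num_timescales = hidden_dim // 2
    log_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float, device=tensor.device) * -log_increment)
    scaled = positions.unsqueeze(1) * inv_timescales.unsqueeze(0)   # (timesteps, num_timescales)
    sinusoids = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
    if hidden_dim % 2 != 0:
        # pad the last channel with zeros when hidden_dim is odd
        sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], dim=1)
    return tensor + sinusoids.unsqueeze(0)

print(sinusoidal_positional_features(torch.zeros(2, 3, 7)).shape)  # torch.Size([2, 3, 7])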
    def forward(self,
                inputs: torch.Tensor,
                source_memory_bank: torch.Tensor,
                source_mask: torch.Tensor,
                target_mask: torch.Tensor,
                is_train: bool = True) -> Dict: 

        batch_size, source_seq_length, _ = source_memory_bank.size()
        __, target_seq_length, __ = inputs.size()

        source_padding_mask = None
        target_padding_mask = None
        if source_mask is not None:
            # key_padding_mask convention: True marks padded positions to be ignored
            source_padding_mask = ~source_mask.bool()
        if target_mask is not None:
            target_padding_mask = ~target_mask.bool()

        # project to correct dimensionality 
        outputs = self.input_proj_layer(inputs)
        # add pos encoding feats 
        outputs = add_positional_features(outputs) 

        # swap to pytorch's batch-second convention 
        outputs = outputs.permute(1, 0, 2)
        source_memory_bank = source_memory_bank.permute(1, 0, 2)

        # get a mask 
        ar_mask = self.make_autoregressive_mask(outputs.shape[0]).to(source_memory_bank.device)

        for i in range(len(self.layers)):
            outputs , __, __ = self.layers[i](outputs, 
                                    source_memory_bank, 
                                    tgt_mask=ar_mask,
                                    #memory_mask=None,
                                    tgt_key_padding_mask=target_padding_mask,
                                    memory_key_padding_mask=source_padding_mask
                                    )

        # do final norm here
        if self.prenorm:
            outputs = self.final_norm(outputs) 

        # switch back from pytorch's batch-second convention
        outputs = outputs.permute(1, 0, 2)
        source_memory_bank = source_memory_bank.permute(1, 0, 2) 

        if not self.use_coverage:
            source_attention_output = self.source_attn_layer(outputs, 
                                                             source_memory_bank,
                                                             source_mask,
                                                             None)
            attentional_tensors = self.dropout(source_attention_output['attentional']) 

            source_attention_weights = source_attention_output['attention_weights']
            coverage_history = None
        else:
            # need to do step by step because running sum of coverage
            source_attention_weights = []
            attentional_tensors = []

            # init to zeros 
            coverage = inputs.new_zeros(size=(batch_size, 1, source_seq_length))
            coverage_history = []

            for timestep in range(outputs.shape[1]):
                output = outputs[:,timestep,:].unsqueeze(1)
                source_attention_output = self.source_attn_layer(output,
                                                                 source_memory_bank,
                                                                 source_mask,
                                                                 coverage)

                attentional_tensors.append(source_attention_output['attentional'])
                source_attention_weights.append(source_attention_output['attention_weights'])
                coverage = source_attention_output['coverage'] 
                coverage_history.append(coverage) 

            # [batch_size, tgt_seq_len, hidden_dim]
            attentional_tensors = self.dropout(torch.cat(attentional_tensors, dim=1))
            # [batch_size, tgt_seq_len, src_seq_len]
            source_attention_weights = torch.cat(source_attention_weights, dim=1) 
            coverage_history = torch.cat(coverage_history, dim=1)
                
        if is_train:
            tgt_attn_list = []
            for timestep in range(attentional_tensors.shape[1]):
                bsz, seq_len, __ = attentional_tensors.shape 
                attn_mask = torch.ones((bsz, seq_len))    
                attn_mask[:,timestep:] = 0
                attn_mask = attn_mask.to(attentional_tensors.device)
                
                target_attention_output = self.target_attn_layer(attentional_tensors[:,timestep,:].unsqueeze(1),
                                                                 attentional_tensors,
                                                                 mask = attn_mask)
                if timestep == 0:
                    # zero out weights at 0, effectively banning target copy since there is nothing to copy 
                    tgt_attn_list.append(torch.zeros_like(target_attention_output["attention_weights"][:,-1,:].unsqueeze(1)))
                else:
                    tgt_attn_list.append(target_attention_output["attention_weights"])

            target_attention_weights = torch.cat(tgt_attn_list, dim=1) 
        else:
            target_attention_output = self.target_attn_layer(attentional_tensors,
                                                             attentional_tensors,
                                                             mask = target_padding_mask)
                                   
            target_attention_weights = target_attention_output['attention_weights']

        return dict(
                outputs=outputs,
                output=outputs[:,-1,:].unsqueeze(1),
                attentional_tensors=attentional_tensors,
                attentional_tensor=attentional_tensors[:,-1,:].unsqueeze(1),
                target_attention_weights = target_attention_weights,
                source_attention_weights = source_attention_weights,
                coverage_history = coverage_history,
                ) 
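In the coverage branch above, coverage is the running sum of source attention weights and is fed back into self.source_attn_layer at every step. The sketch below shows only that bookkeeping, with made-up shapes and a plain dot-product attention standing in for the model's attention layer.

import torch

batch, tgt_len, src_len, dim = 2, 3, 5, 8
decoder_states = torch.randn(batch, tgt_len, dim)
source_memory = torch.randn(batch, src_len, dim)

coverage = decoder_states.new_zeros(batch, 1, src_len)
weights_per_step, coverage_history = [], []
for t in range(tgt_len):
    query = decoder_states[:, t, :].unsqueeze(1)                 # (batch, 1, dim)
    scores = torch.bmm(query, source_memory.transpose(1, 2))     # (batch, 1, src_len)
    weights = torch.softmax(scores, dim=-1)
    weights_per_step.append(weights)
    coverage = coverage + weights                                # running sum of attention mass
    coverage_history.append(coverage)

source_attention_weights = torch.cat(weights_per_step, dim=1)    # (batch, tgt_len, src_len)
coverage_history = torch.cat(coverage_history, dim=1)            # (batch, tgt_len, src_len)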
    def forward(self,
                context: Dict[str, torch.LongTensor],
                length: torch.LongTensor = None,
                repeat: torch.FloatTensor = None,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        #expected_dim = (self.final_classifier_feedforward.get_input_dim() + 1) / 2
        expected_dim = (self.final_classifier_feedforward.get_input_dim() +
                        2) // 2  # integer division so the value can be used as a tensor size
        dia_len = context['tokens'].size()[1]
        if expected_dim - dia_len > 0:
            # pad the dialogue dimension up to the length the final classifier expects
            padding = torch.zeros(
                [context['tokens'].size()[0], expected_dim - dia_len,
                 context['tokens'].size()[2]],
                dtype=torch.long, device=context['tokens'].device)
            context['tokens'] = torch.cat([context['tokens'], padding], dim=1)

        # context: batch_size * dials_len * sentences_len
        # embedded_context: batch_size * dials_len * sentences_len * emb_dim
        embedded_context = self.text_field_embedder(context)
        # utterances_mask: batch_size * dials_len * sentences_len
        utterances_mask = get_text_field_mask(context, 1).float()
        # encoded_utterances: batch_size * dials_len * emb_dim
        encoded_utterances = self.utterances_encoder(embedded_context,
                                                     utterances_mask)
        encoded_utterances = add_positional_features(encoded_utterances)

        #projected_utterances = self.attend_feedforward(encoded_utterances)
        # similarity_matrix: batch_size * dials_len * output_dim
        mask = get_text_field_mask(context).float()
        similarity_matrix = self.matrix_attention(encoded_utterances, mask)

        # attended_context: batch * (dials_len - 1) * emb_dim
        attended_context = similarity_matrix[:, :-1, :]
        # attended_response: batch * (dials_len - 1) * emb_dim
        attended_response = similarity_matrix[:, 1:, :]

        # embedded_context: batch_size * (dials_len - 1) * emb_dim
        embedded_context = encoded_utterances[:, :-1, :]
        # embedded_response: batch_size * (dials_len - 1) * emb_dim
        embedded_response = encoded_utterances[:, 1:, :]

        # context_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
        context_compare_input = torch.cat(
            [embedded_context, attended_response], dim=-1)
        # response_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
        response_compare_input = torch.cat(
            [embedded_response, attended_context], dim=-1)

        # compared_context: batch * (dials_len - 1) * emb_dim
        compared_context = self.compare_feedforward(context_compare_input)
        compared_context = compared_context * mask[:, :-1].unsqueeze(-1)

        # compared_response: batch * (dials_len - 1) * emb_dim
        compared_response = self.compare_feedforward(response_compare_input)
        compared_response = compared_response * mask[:, 1:].unsqueeze(-1)

        # aggregate_input: batch * (dials_len - 1) * (compare_context_dim + compared_response_dim)
        aggregate_input = torch.cat([compared_context, compared_response],
                                    dim=-1)

        # class_logits & class_probs:  batch * (dials_len - 1) * 2
        class_logits = self.classifier_feedforward(aggregate_input)
        class_probs = F.softmax(class_logits,
                                dim=-1).reshape(class_logits.size()[0], -1)
        length_tensor = torch.FloatTensor(length).reshape(-1, 1)
        repeat_tensor = torch.FloatTensor(repeat).reshape(-1, 1)
        #class_probs = torch.cat([class_probs, length_tensor, repeat_tensor], dim=1)
        #class_probs = torch.cat([class_probs, repeat_tensor], dim=1)

        full_logits = self.final_classifier_feedforward(class_probs)
        full_probs = F.softmax(full_logits, dim=-1)
        output_dict = {
            "class_logits": full_logits,
            "class_probabilities": full_probs
        }

        if label is not None:
            loss = self.loss(full_logits, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(full_logits, label.squeeze(-1))
            output_dict['loss'] = loss

        return output_dict
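The slicing into attended_context / attended_response above pairs each utterance with the one that follows it. A toy sketch of that alignment, with random tensors in place of the encoded dialogue:

import torch

batch, dials_len, emb_dim = 2, 4, 6
encoded_utterances = torch.randn(batch, dials_len, emb_dim)

context = encoded_utterances[:, :-1, :]          # utterances 0 .. n-2
response = encoded_utterances[:, 1:, :]          # utterances 1 .. n-1
pairs = torch.cat([context, response], dim=-1)   # (batch, dials_len - 1, 2 * emb_dim)
print(pairs.shape)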