Example #1
    def test_sum_tensors(self):

        inputs = [torch.tensor(1), torch.tensor(2)]
        self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3))

        inputs = [torch.tensor(1), None, torch.tensor(2)]
        self.assertEqual(utils.sum_tensors(inputs), torch.tensor(3))

        inputs = [torch.tensor(1), None, None]
        self.assertEqual(utils.sum_tensors(inputs), torch.tensor(1))

        inputs = [None, None, None]
        self.assertEqual(utils.sum_tensors(inputs), None)
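For context, a minimal sketch of what `utils.sum_tensors` could look like, consistent with the behavior this test exercises (element-wise sum of the non-None entries, or None when every entry is None). This is an illustration only, not necessarily the actual library implementation:

    from typing import List, Optional

    import torch


    def sum_tensors(xs: List[Optional[torch.Tensor]]) -> Optional[torch.Tensor]:
        """Sum a list of optional tensors, skipping None entries.

        Returns None if every entry is None.
        """
        # Start accumulating from the first non-None tensor, if any.
        start = next((i for i, t in enumerate(xs) if t is not None), -1)
        if start == -1:
            return None
        total = xs[start]
        for t in xs[start + 1:]:
            if t is not None:
                total = total + t
        return total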
Example #2
    def _forward(self,
                 word_embed: torch.Tensor,
                 segment_ids: Optional[torch.LongTensor] = None,
                 input_mask: Optional[torch.Tensor] = None,
                 memory: Optional[List[torch.Tensor]] = None,
                 permute_mask: Optional[torch.Tensor] = None,
                 target_mapping: Optional[torch.Tensor] = None,
                 bi_data: bool = False,
                 clamp_len: Optional[int] = None,
                 cache_len: int = 0,
                 same_length: bool = False,
                 attn_type: str = 'bi',
                 two_stream: bool = False) \
            -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        r"""Compute XLNet representations for the input. This layer exists
        because :class:`XLNetDecoder` compute embeddings in the decoder helper.
        `word_embed` has shape `[batch_size, max_time, word_embed_dim]`.
        Please refer to :meth:`forward` for the detailed information of other
        arguments.
        """
        # seq_len == max_time
        # word_embed: [seq_len, batch_size, word_embed_dim]
        word_embed = word_embed.permute(1, 0, 2)
        # segment_ids: [seq_len, batch_size]
        if segment_ids is not None:
            segment_ids = segment_ids.permute(1, 0)
        # input_mask: [seq_len, batch_size]
        if input_mask is not None:
            input_mask = input_mask.permute(1, 0)
        # memory: A list of length num_layers
        # each tensor of shape [mem_len, batch_size, hidden_dim]
        if memory is not None:
            memory = [m.permute(1, 0, 2) for m in memory]
        # permute_mask: [seq_len, seq_len, batch_size]
        if permute_mask is not None:
            permute_mask = permute_mask.permute(1, 2, 0)
        # target_mapping: [num_targets, seq_len, batch_size]
        if target_mapping is not None:
            target_mapping = target_mapping.permute(1, 2, 0)

        seq_len, batch_size = word_embed.size()[:2]
        mem_len = memory[0].size(0) if memory is not None else 0
        tot_len = seq_len + mem_len
        reuse_len = self._hparams.reuse_len

        # Construct masks.
        masks: List[Optional[torch.Tensor]] = []

        # Causal attention mask.
        if attn_type == 'uni':
            causal_mask = self._create_causal_attn_mask(
                seq_len, mem_len, same_length)
            # causal_mask: (seq_len, tot_len, 1, 1)
            causal_mask = causal_mask.unsqueeze(2).unsqueeze(3)
            masks.append(causal_mask)
        elif attn_type == 'bi':
            pass
        else:
            raise ValueError(f"Unsupported attention type: {attn_type}")

        # Data mask: input mask & permutation mask.
        if input_mask is not None:
            input_mask = input_mask.expand(seq_len, -1, -1)
        data_mask = sum_tensors([input_mask, permute_mask])
        if data_mask is not None:
            # All positions in memory can be attended to.
            memory_mask = data_mask.new_zeros(seq_len, mem_len, batch_size)
            # data_mask: (seq_len, tot_len, batch_size, 1)
            data_mask = torch.cat([memory_mask, data_mask], dim=1).unsqueeze(3)
            masks.append(data_mask)

        # Exclude the main diagonal (target tokens) from the mask.
        attn_mask = sum_tensors(masks)
        if attn_mask is None:
            final_mask = None
        else:
            attn_mask = (attn_mask > 0)
            final_mask = -torch.eye(seq_len, device=attn_mask.device)
            final_mask = torch.cat(
                [final_mask.new_zeros(seq_len, mem_len), final_mask], dim=-1)
            final_mask = final_mask.unsqueeze(2).unsqueeze(3)
            # final_mask: (seq_len, tot_len, batch_size, 1)
            final_mask = ((attn_mask.float() + final_mask) > 0)

        # Construct segment embedding.
        if segment_ids is not None:
            concat_segment_ids = torch.cat(
                [segment_ids.new_zeros(mem_len, batch_size), segment_ids])
            segment_matrix = (segment_ids.unsqueeze(1) !=
                              concat_segment_ids.unsqueeze(0)).long()
            segment_matrix = F.one_hot(segment_matrix, num_classes=2).float()
        else:
            segment_matrix = None

        pos_embed = self.pos_embed(batch_size, seq_len, tot_len, clamp_len,
                                   attn_type, bi_data)
        pos_embed = self.dropout(pos_embed)

        states_h = self.dropout(word_embed)
        states_g = None
        if two_stream:
            if target_mapping is not None:
                word_embed_q = self.mask_emb.expand(target_mapping.size(0),
                                                    batch_size, -1)
            else:
                word_embed_q = word_embed
            states_g = self.dropout(word_embed_q)
        new_memory = []

        for idx in range(self._hparams.num_layers):
            cur_memory = memory[idx] if memory is not None else None
            if cache_len > 0:
                new_memory.append(
                    self._cache_mem(states_h, cur_memory, cache_len,
                                    reuse_len))
            attn_layer: RelativeMultiheadAttention
            attn_layer = self.attn_layers[idx]  # type: ignore
            states_h, states_g = attn_layer(states_h=states_h,
                                            states_g=states_g,
                                            pos_embed=pos_embed,
                                            segment_mat=segment_matrix,
                                            attn_mask_h=final_mask,
                                            attn_mask_g=attn_mask,
                                            target_mapping=target_mapping,
                                            memory=cur_memory)
            states_h = self.ff_layers[idx](states_h)
            if states_g is not None:
                states_g = self.ff_layers[idx](states_g)

        output = self.dropout(states_h if states_g is None else states_g)

        # At this point output is time-major: [seq_len, batch_size, hidden_dim];
        # new_memory is a list of length num_layers (empty if cache_len == 0),
        # each tensor of shape [cache_len, batch_size, hidden_dim].
        # Convert both back to batch-major before returning.
        output = output.permute(1, 0, 2)
        if new_memory:
            new_memory = [m.permute(1, 0, 2) for m in new_memory]

        if cache_len == 0:
            return output, None

        return output, new_memory
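The diagonal-exclusion step above (the `final_mask` computation) is easier to see on a toy example. The sketch below reproduces only that mask arithmetic from `_forward`, with made-up sizes and a hand-built `attn_mask` in which every current-segment position is blocked (an assumption chosen purely for illustration):

    import torch

    seq_len, mem_len, batch_size = 3, 2, 1
    tot_len = seq_len + mem_len

    # Toy attention mask: 1 = blocked, 0 = visible. Memory positions stay
    # visible; all current-segment positions are blocked.
    attn_mask = torch.ones(seq_len, tot_len, batch_size, 1)
    attn_mask[:, :mem_len] = 0
    attn_mask = attn_mask > 0

    # Same arithmetic as in `_forward`: subtracting an identity matrix before
    # thresholding re-opens each query position's attention to itself.
    final_mask = -torch.eye(seq_len)
    final_mask = torch.cat(
        [final_mask.new_zeros(seq_len, mem_len), final_mask], dim=-1)
    final_mask = final_mask.unsqueeze(2).unsqueeze(3)
    final_mask = (attn_mask.float() + final_mask) > 0

    print(final_mask[..., 0, 0].long())
    # tensor([[0, 0, 0, 1, 1],
    #         [0, 0, 1, 0, 1],
    #         [0, 0, 1, 1, 0]])

Each row corresponds to a query position: the two memory columns remain visible, every other position stays blocked, and only the query's own position is re-opened, which is exactly what `attn_mask_h` (the content-stream mask) needs relative to `attn_mask_g`.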