Example #1
def loss_compute(self, x: torch.Tensor, y: torch.Tensor, norm: int):
    # Project the model output through the generator, compute the
    # label-smoothed loss normalized by `norm`, then backprop and step.
    x = self.model.generator(x)
    loss = self.labelsmooth(x.contiguous().view(-1, x.size(-1)),
                            y.contiguous().view(-1)) / norm
    loss.backward()
    if self.model_opt is not None:
        self.model_opt.step()
        self.model_opt.optimizer.zero_grad()
    return loss.item() * norm
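
A note on the pattern in Example #1: `self.labelsmooth` is presumably a project-specific label-smoothing criterion, so the snippet below is only a self-contained sketch of the same flatten-then-score idea using the stock `torch.nn.CrossEntropyLoss(label_smoothing=...)`; the shapes and names are invented for illustration.

import torch
import torch.nn as nn

# Hypothetical stand-ins: 2 sequences of 5 steps over an 11-token vocabulary.
logits = torch.randn(2, 5, 11, requires_grad=True)    # plays the role of the generator output
targets = torch.randint(0, 11, (2, 5))                # plays the role of y
norm = targets.numel()                                # e.g. number of target tokens

criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction="sum")

# Merge batch and time so the criterion sees (tokens, vocab) vs. (tokens,),
# exactly what the .contiguous().view(...) calls do in Example #1.
loss = criterion(logits.contiguous().view(-1, logits.size(-1)),
                 targets.contiguous().view(-1)) / norm
loss.backward()
print(loss.item() * norm)                             # un-normalized total, as loss_compute returns
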
Example #2
def get_token(h: torch.Tensor, x: torch.Tensor, token: int):
    """ Get specific token embedding (e.g. [CLS]) """
    emb_size = h.shape[-1]
    token_h = h.view(-1, emb_size)
    flat = x.contiguous().view(-1)

    # get contextualized embedding of given token
    token_h = token_h[flat == token, :]
    return token_h
Example #3
def get_token(h: torch.Tensor, x: torch.Tensor, token: int):
    emb_size = h.shape[-1]

    token_h = h.view(-1, emb_size)
    flat = x.contiguous().view(-1)

    token_h = token_h[flat == token, :]

    return token_h
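
Examples #2 and #3 are near-identical helpers from different projects. A quick, self-contained way to see what they return (the tensor sizes and the [CLS] id below are made up for illustration):

import torch

CLS_ID = 101                            # hypothetical [CLS] token id
h = torch.randn(4, 16, 768)             # contextualized embeddings [batch, seq_len, emb_size]
x = torch.randint(999, 30000, (4, 16))  # matching token ids [batch, seq_len]
x[:, 0] = CLS_ID                        # place [CLS] at the first position of every sequence

cls_h = get_token(h, x, token=CLS_ID)
print(cls_h.shape)                      # torch.Size([4, 768]): one row per occurrence of CLS_ID
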
Example #4
    def backward(ctx,
                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """
        :param ctx: autograd context holding the (idx, N) saved by forward
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return: (B, C, N) gradient of the features (no gradient for idx)
        """

        idx, N = ctx.for_backwards

        grad_features = ppp_ops.group_points_grad_cuda(grad_out.contiguous(),
                                                       idx, N)

        return grad_features, None
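
Example #4 is the backward half of a custom torch.autograd.Function wrapping a CUDA kernel (ppp_ops.group_points_grad_cuda), so it cannot run on its own. The toy Function below, invented for illustration, shows the same bookkeeping: stash what backward needs on ctx during forward, and return None for non-differentiable inputs.

import torch
from typing import Tuple

class ScaleByN(torch.autograd.Function):
    @staticmethod
    def forward(ctx, features: torch.Tensor, n: int) -> torch.Tensor:
        ctx.for_backwards = (n,)          # stash non-tensor state, as in Example #4
        return features * n

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        (n,) = ctx.for_backwards
        return grad_out * n, None         # None: the int argument gets no gradient

x = torch.randn(3, requires_grad=True)
ScaleByN.apply(x, 4).sum().backward()
print(x.grad)                             # tensor([4., 4., 4.])
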
Example #5
def valid_loss_compute(self, x: torch.Tensor, y: torch.Tensor, norm: int):
    # Evaluation-time counterpart of loss_compute in Example #1:
    # same label-smoothed loss, but no backward pass or optimizer step.
    x = self.model.generator(x)
    loss = self.labelsmooth(x.contiguous().view(-1, x.size(-1)),
                            y.contiguous().view(-1)) / norm
    return loss.item() * norm
Example #6
    def forward(self,
                query: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                need_weights: bool = False
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        :param query: [sequence_length, batch_size, embed_dim]
        :param padding_mask: [batch_size, sequence_len]
        :param attention_mask: [batch_size, sequence_len, sequence_len]
        :param need_weights: if True, also return the attention scores
        :return: attention output [sequence_length, batch_size, embed_dim] and,
            when need_weights is set, the attention scores
            [batch_size, num_heads, sequence_len, sequence_len] (otherwise None)
        """

        sequence_len, batch_size, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        query, key, value = self.in_projection(query).chunk(3, dim=-1)

        query *= self.scaling

        # [batch_size * self.num_heads, sequence_len, self.head_dim]
        query = query.contiguous().view(sequence_len,
                                        batch_size * self.num_heads,
                                        self.head_dim).transpose(0, 1)
        key = key.contiguous().view(sequence_len, batch_size * self.num_heads,
                                    self.head_dim).transpose(0, 1)
        value = value.contiguous().view(sequence_len,
                                        batch_size * self.num_heads,
                                        self.head_dim).transpose(0, 1)

        # [batch_size * self.num_heads, sequence_len, sequence_len]
        attention_scores = torch.bmm(query, key.transpose(1, 2))

        # [batch_size, self.num_heads, sequence_len, sequence_len]
        attention_scores = self.split_heads(attention_scores, batch_size,
                                            sequence_len)

        # fp16 compatibility
        parameters_type = next(self.parameters()).dtype

        if attention_mask is not None:
            assert attention_mask.size(1) == sequence_len
            assert attention_mask.size(2) == sequence_len
            # [batch_size, 1, sequence_len, sequence_len]
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.to(dtype=parameters_type)
            attention_scores += attention_mask

        if padding_mask is not None:
            assert padding_mask.size(0) == batch_size
            assert padding_mask.size(1) == sequence_len
            # padding_mask = [batch_size, sequence_len]
            attention_scores = attention_scores.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2),
                float(-10000.),
            )

        # [batch_size * self.num_heads, sequence_len, sequence_len]
        attention_scores = self.join_heads(attention_scores, batch_size,
                                           sequence_len)

        if attention_scores.dtype == torch.float16:
            tensor_type = torch.float32
        else:
            tensor_type = attention_scores.dtype

        # [batch_size * self.num_heads, sequence_len, sequence_len]
        attention_scores = F.softmax(attention_scores.float(),
                                     dim=-1,
                                     dtype=tensor_type)

        attention_scores = self.dropout(attention_scores)

        # attention_scores = [batch_size * self.num_heads, sequence_len, sequence_len]
        # value = [batch_size * self.num_heads, sequence_len, self.head_dim]
        # [batch_size * self.num_heads, sequence_len, self.head_dim]
        attention_output = torch.bmm(attention_scores, value)

        # [sequence_len, batch_size, embed_dim]
        attention_output = attention_output.transpose(0, 1).contiguous().view(
            sequence_len, batch_size, embed_dim)
        attention_output = self.out_projection(attention_output)

        # for visualize attention scores
        if need_weights:
            # [batch_size, self.num_heads, sequence_len, sequence_len]
            attention_scores = self.split_heads(attention_scores, batch_size,
                                                sequence_len)
        else:
            attention_scores = None

        return attention_output, attention_scores
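
Example #6 depends on the module's projection layers and head-splitting helpers, so here is a minimal, self-contained shape walk-through of its bmm-based core; the sizes are arbitrary and the projections are replaced by random tensors purely for illustration.

import torch
import torch.nn.functional as F

sequence_len, batch_size, num_heads, head_dim = 7, 2, 4, 16

# already projected and reshaped to [batch_size * num_heads, sequence_len, head_dim]
q = torch.randn(batch_size * num_heads, sequence_len, head_dim)
k = torch.randn(batch_size * num_heads, sequence_len, head_dim)
v = torch.randn(batch_size * num_heads, sequence_len, head_dim)

q = q * head_dim ** -0.5                          # same role as self.scaling
scores = torch.bmm(q, k.transpose(1, 2))          # [batch*heads, seq, seq]
scores = F.softmax(scores, dim=-1)
out = torch.bmm(scores, v)                        # [batch*heads, seq, head_dim]

# fold heads back into the embedding dimension, as at the end of Example #6
out = out.transpose(0, 1).contiguous().view(sequence_len, batch_size,
                                            num_heads * head_dim)
print(out.shape)                                  # torch.Size([7, 2, 64])
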