示例#1
0
    def forward(ctx, input, target):
        """
        Compute the per-sample Tsallis-1.5 loss.

        input (FloatTensor): n x num_classes
        target (LongTensor): n, the indices of the target classes
        """
        assert_equal(input.shape[0], target.shape[0])

        # Predicted distribution over classes (along dim 1) and its
        # generalized-entropy (omega) term.
        p_star = tsallis15(input, 1)
        loss = _omega_tsallis15(p_star)

        # Subtract the one-hot target in place: p_star <- p_star - e_y.
        rows = torch.arange(input.shape[0], device=input.device)
        p_star[rows, target] = p_star[rows, target] - 1

        # Add the row-wise inner product <p_star - e_y, input>.
        loss = loss + torch.einsum("ij,ij->i", p_star, input)

        # The mutated p_star (already shifted by -e_y) is what backward needs.
        ctx.save_for_backward(p_star)

        # loss = torch.clamp(loss, min=0.0)  # needed?
        return loss
示例#2
0
    def attn_map(self, Z):
        """Map raw attention scores Z to a normalized attention distribution.

        Dispatches on self.attn_func; normalization is over the last dim.
        """
        name = self.attn_func

        if name == "softmax":
            return F.softmax(Z, -1)

        if name == "esoftmax":
            return esoftmax(Z, -1)

        if name == "sparsemax":
            return sparsemax(Z, -1)

        if name == "tsallis15":
            return tsallis15(Z, -1)

        if name == "tsallis":
            # alpha == 2 coincides with sparsemax; use the slightly faster
            # specialized bisection solver in that case.
            if self.attn_alpha == 2:
                return sparsemax_bisect(Z, self.bisect_iter)
            return tsallis_bisect(Z, self.attn_alpha, self.bisect_iter)

        raise ValueError("invalid combination of arguments")
示例#3
0
def test_tsallis15():
    """The closed-form tsallis15 (dim=1) must agree with the generic
    bisection solver evaluated at alpha = 1.5."""
    for _trial in range(10):
        scores = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        fast = tsallis15(scores, 1)
        generic = tsallis_bisect(scores, 1.5)
        # Total squared discrepancy over the whole batch must be tiny.
        assert ((fast - generic) ** 2).sum() < 1e-7