def test_mask_alphas(alpha):
    torch.manual_seed(42)
    x = torch.randn(2, 6)
    # mask out the last three columns with -inf
    x[:, 3:] = -float('inf')
    x0 = x[:, :3]
    y = tsallis_bisect(x, alpha)
    y0 = tsallis_bisect(x0, alpha)
    # unmasked columns must match the result on the unmasked slice,
    # and masked columns must receive zero probability
    y[:, :3] -= y0
    assert torch.allclose(y, torch.zeros_like(y))
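
# Hedged sketch: the test above takes `alpha` as an argument, so it is
# presumably driven by pytest parametrization. The decorator and the concrete
# alpha values below are illustrative assumptions, not taken from the original
# suite; `tsallis_bisect` is assumed to be in scope as in the test above.
import pytest
import torch


@pytest.mark.parametrize("alpha", (1.25, 1.5, 1.75, 2.25))
def test_mask_alphas_sketch(alpha):
    torch.manual_seed(42)
    x = torch.randn(2, 6)
    x[:, 3:] = -float('inf')  # mask the last three columns
    p = tsallis_bisect(x, alpha)
    # masked positions must receive (numerically) zero probability
    assert torch.allclose(p[:, 3:], torch.zeros(2, 3))
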
def forward(self, X):
    assert X.dim() == 2
    p_star = tsallis_bisect(X, self.alpha, self.n_iter)
    # renormalize so each row sums exactly to one before taking the log
    p_star /= p_star.sum(dim=1).unsqueeze(dim=1)
    return torch.log(p_star)
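
# Hedged usage sketch for the forward above. The module name `LogTsallisProb`
# and its constructor arguments are hypothetical stand-ins (only `forward` is
# shown in the original); the point is what the output looks like: rows of
# exp(log_p) sum to one, and entries zeroed out by the sparse mapping come back
# as -inf because of the final torch.log.
import torch

scores = torch.randn(4, 10)
log_p = LogTsallisProb(alpha=1.5, n_iter=50)(scores)  # hypothetical module name
print(log_p.exp().sum(dim=1))   # ~1.0 per row
print(torch.isinf(log_p).any())  # True whenever the mapping produced exact zeros
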
def attn_map(self, Z):
    if self.attn_func == "softmax":
        return F.softmax(Z, -1)
    elif self.attn_func == "esoftmax":
        return esoftmax(Z, -1)
    elif self.attn_func == "sparsemax":
        return sparsemax(Z, -1)
    elif self.attn_func == "tsallis15":
        return tsallis15(Z, -1)
    elif self.attn_func == "tsallis":
        if self.attn_alpha == 2:
            # slightly faster specialized impl
            return sparsemax_bisect(Z, self.bisect_iter)
        else:
            return tsallis_bisect(Z, self.attn_alpha, self.bisect_iter)
    raise ValueError("invalid combination of arguments")
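
# Hedged usage sketch: comparing the maps the branches above dispatch to, on
# the same scores, without assuming anything about the enclosing module's
# constructor. The mapping functions (sparsemax, tsallis15, tsallis_bisect,
# sparsemax_bisect) are assumed to be in scope, as in attn_map itself.
import torch
import torch.nn.functional as F

Z = torch.randn(2, 5)
dense = F.softmax(Z, -1)           # never exactly zero
sparse = sparsemax(Z, -1)          # may zero out low-scoring entries
mid = tsallis15(Z, -1)             # closed-form alpha = 1.5 mapping
generic = tsallis_bisect(Z, 1.7)   # arbitrary alpha via bisection
sparse2 = sparsemax_bisect(Z, 50)  # alpha = 2 specialization
print(dense.sum(-1), sparse.sum(-1), mid.sum(-1), generic.sum(-1), sparse2.sum(-1))
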
def forward(ctx, input, target, alpha=1.5, n_iter=50):
    """
    input (FloatTensor): n x num_classes
    target (LongTensor): n, the indices of the target classes
    """
    assert_equal(input.shape[0], target.shape[0])
    p_star = tsallis_bisect(input, alpha, n_iter)

    # this is now done directly in tsallis_bisect
    # p_star /= p_star.sum(dim=1).unsqueeze(dim=1)

    loss = _omega_tsallis(p_star, alpha)

    # subtract the one-hot target distribution in place: p_star <- p_star - e_y
    p_star.scatter_add_(1, target.unsqueeze(1), torch.full_like(p_star, -1))
    loss += torch.einsum("ij,ij->i", p_star, input)

    ctx.save_for_backward(p_star)

    # loss = torch.clamp(loss, min=0.0)  # needed?
    return loss
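
# Hedged sketch of a matching backward. After the in-place scatter_add_ above,
# the saved tensor holds p_star - e_y, which is exactly the gradient of this
# Fenchel-Young-style loss with respect to `input`; the trailing None values
# cover the non-tensor arguments (target, alpha, n_iter). This completion is an
# assumption, not code from the original file.
def backward(ctx, grad_output):
    p_star_minus_target, = ctx.saved_tensors
    grad_input = grad_output.unsqueeze(1) * p_star_minus_target
    return grad_input, None, None, None
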
def test_tsallis15():
    for _ in range(10):
        x = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        p1 = tsallis15(x, 1)
        p2 = tsallis_bisect(x, 1.5)
        assert torch.sum((p1 - p2) ** 2) < 1e-7
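
# Hedged companion check in the same spirit as the test above, but at
# alpha = 2, where the exact mapping is sparsemax. Whether such a test exists
# in the original suite is an assumption; the tolerance mirrors the
# alpha = 1.5 test.
def test_tsallis2_sketch():
    for _ in range(10):
        x = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        p1 = sparsemax(x, 1)
        p2 = tsallis_bisect(x, 2)
        assert torch.sum((p1 - p2) ** 2) < 1e-7
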