def forward(ctx, input, target):
    """Compute the Tsallis-1.5 (1.5-entmax) loss for a batch of scores.

    input (FloatTensor): n x num_classes
    target (LongTensor): n, the indices of the target classes
    """
    assert_equal(input.shape[0], target.shape[0])
    # p is the optimal distribution under the 1.5-entmax mapping (over dim 1)
    p = tsallis15(input, 1)
    omega = _omega_tsallis15(p)
    # Turn p into (p - e_y) in place: subtract 1 at each row's target index.
    # The shifted tensor is what backward needs, so it is the one we save.
    p.scatter_add_(1, target.unsqueeze(1), torch.full_like(p, -1))
    # Row-wise inner product <p - e_y, input>, added to the omega term.
    loss = omega + (p * input).sum(dim=1)
    ctx.save_for_backward(p)
    return loss
def attn_map(self, Z):
    """Apply the configured attention normalizer along the last dim of Z."""
    fn = self.attn_func
    if fn == "softmax":
        return F.softmax(Z, -1)
    if fn == "esoftmax":
        return esoftmax(Z, -1)
    if fn == "sparsemax":
        return sparsemax(Z, -1)
    if fn == "tsallis15":
        return tsallis15(Z, -1)
    if fn == "tsallis":
        # alpha == 2 coincides with sparsemax, which has a slightly
        # faster specialized bisection implementation.
        if self.attn_alpha == 2:
            return sparsemax_bisect(Z, self.bisect_iter)
        return tsallis_bisect(Z, self.attn_alpha, self.bisect_iter)
    raise ValueError("invalid combination of arguments")
def test_tsallis15():
    """Check the closed-form 1.5-entmax against the generic bisection solver."""
    for _ in range(10):
        scores = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        exact = tsallis15(scores, 1)
        approx = tsallis_bisect(scores, 1.5)
        # the two solvers should agree to high precision
        assert torch.sum((exact - approx) ** 2) < 1e-7