def forward(self, embedding, target):
    """Compute a margin-augmented softmax loss and top-1 accuracy.

    The fc layer maps ``embedding`` to per-class logits. For the target
    class only, and only where the raw logit is non-negative, the logit
    is replaced by cos(acos(logit) + margin) (the "easy margin" scheme);
    everything is then scaled before cross-entropy.

    NOTE(review): acos is only defined on [-1, 1] — presumably self.fc
    emits normalized cosine similarities; confirm against the layer.

    :param embedding: feature tensor fed to ``self.fc``.
    :param target: integer class labels.
    :return: ``(loss, accuracy)`` tuple.
    """
    raw_logits = self.fc(embedding)
    # Boolean mask selecting each sample's target-class position.
    target_mask = F.one_hot(target, self.num_class).astype("bool")
    # Angular margin applied only where cos(theta) >= 0; negative
    # logits are left untouched (easy margin).
    shifted = F.cos(F.acos(raw_logits) + self.margin)
    candidate = F.where(raw_logits >= 0, shifted, raw_logits)
    final_logits = F.where(target_mask, candidate, raw_logits)
    final_logits = final_logits * self.scale
    loss = F.loss.cross_entropy(final_logits, target)
    # Accuracy is measured on the unmodified logits.
    accuracy = F.topk_accuracy(raw_logits, target, topk=1)
    return loss, accuracy
def test_grad_manager_group():
    """Two GradManagers joined with ``|`` both record; their backward
    passes accumulate into the same ``x.grad``."""
    data = np.random.rand(10).astype("float32")
    x = mge.tensor(data)
    manager_a = GradManager().attach([x])
    manager_b = GradManager().attach([x])
    with manager_a | manager_b:
        y = F.cos(x)
        manager_a.backward(y)
        manager_b.backward(y)
    # d/dx cos(x) = -sin(x); two backward passes accumulate to -2*sin(x).
    np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(data), decimal=5)
    x.grad = None
def test_grad_manager_visibility_by_order():
    """Check recording visibility when two GradManagers are nested by order.

    ``gm2`` is entered first (outer), ``gm`` second (inner); both attach
    the same tensor ``x``.
    """
    x_np = np.random.rand(10).astype("float32")
    x = mge.tensor(x_np)
    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])
    with gm2:
        with gm:
            y = F.cos(x)
            # First-order gradient via the outer manager: d/dx cos = -sin.
            gm2.backward(y)
            np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
            # Backward through the inner manager starting from x.grad.
            # The expectation below is that x.grad is unchanged, i.e. the
            # gradient tensor produced by gm2 is not visible to gm's
            # recording — presumably the point of the ordering; confirm
            # against GradManager semantics.
            gm.backward(x.grad)
            np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
def forward(self, x):
    """Stash the input on ``self`` and return cos of a detached copy.

    Rebuilding the tensor from its numpy buffer severs the autodiff
    link, so the default derivative cannot flow through this op.
    """
    self.inp = x
    detached = mge.Tensor(x.numpy())
    return F.cos(detached)
def backward(self, dy):
    """Custom backward: propagate ``cos(inp) * dy``.

    NOTE(review): the analytic derivative of cos is ``-sin(inp) * dy``;
    returning ``cos(inp) * dy`` looks deliberate for a test fixture
    (distinguishes the custom backward from the true gradient) — confirm
    against the test that uses this class before "fixing" the sign.
    """
    return F.cos(self.inp) * dy