def test_torch(self):
     """Check MeanSquaredError.forward on torch tensors against a hand-computed value.

     ``torch_true`` holds integer one-hot rows while ``torch_pred`` holds floats,
     so the op is also fed mixed-dtype inputs here. The expected value 0.014375
     is the mean of the 16 squared element-wise differences. The per-row sums
     of squared differences are 0.025, 0.14, 0.065 and 0.0, giving
     0.23 / 16 = 0.014375.
     """
     torch_true = torch.tensor([[0, 1, 0, 0],
                                [0, 0, 0, 1],
                                [0, 0, 1, 0],
                                [1, 0, 0, 0]])
     torch_pred = torch.tensor([[0.1, 0.9, 0.05, 0.05],
                                [0.1, 0.2, 0.0, 0.7],
                                [0.0, 0.15, 0.8, 0.05],
                                [1.0, 0.0, 0.0, 0.0]])
     mse = MeanSquaredError(inputs='x', outputs='x')
     output = mse.forward(data=[torch_pred, torch_true], state={})
     # assert_allclose reports actual vs. expected on failure, unlike
     # assertTrue(np.allclose(...)); rtol/atol mirror np.allclose's defaults.
     # .cpu() keeps the numpy conversion valid should the output live on a device.
     np.testing.assert_allclose(output.detach().cpu().numpy(), 0.014375, rtol=1e-5, atol=1e-8)
 def test_tf(self):
     """Check MeanSquaredError.forward on TensorFlow tensors against a hand-computed value.

     ``tf_true`` holds integer one-hot rows while ``tf_pred`` holds floats, so
     the op is also fed mixed-dtype inputs here. The expected value 0.014375
     is the mean of the 16 squared element-wise differences. The per-row sums
     of squared differences are 0.025, 0.14, 0.065 and 0.0, giving
     0.23 / 16 = 0.014375.
     """
     tf_true = tf.constant([[0, 1, 0, 0],
                            [0, 0, 0, 1],
                            [0, 0, 1, 0],
                            [1, 0, 0, 0]])
     tf_pred = tf.constant([[0.1, 0.9, 0.05, 0.05],
                            [0.1, 0.2, 0.0, 0.7],
                            [0.0, 0.15, 0.8, 0.05],
                            [1.0, 0.0, 0.0, 0.0]])
     mse = MeanSquaredError(inputs='x', outputs='x')
     output = mse.forward(data=[tf_pred, tf_true], state={})
     # assert_allclose reports actual vs. expected on failure, unlike
     # assertTrue(np.allclose(...)); rtol/atol mirror np.allclose's defaults.
     np.testing.assert_allclose(output.numpy(), 0.014375, rtol=1e-5, atol=1e-8)