def test_energy_mse(self):
    """Check ``losses.mean_square_error`` on a 1x2 target/predict pair.

    The expected value is worked out by hand from the two element-wise
    differences and compared with ``assertAlmostEqual``.
    """
    target = jnp.array([[0.2, 0.6]])
    predict = jnp.array([[0.4, 0.7]])

    actual = float(losses.mean_square_error(target=target, predict=predict))

    # Hand-computed reference:
    # ((0.4 - 0.2) ** 2 + (0.7 - 0.6) ** 2) / 2 = 0.025
    expected = 0.025
    self.assertAlmostEqual(actual, expected)
def test_density_mse(self):
    """Check ``losses.mean_square_error`` on a 2x4 target/predict pair.

    Only the first row of ``predict`` differs from ``target``; the second
    row matches exactly and therefore contributes zero error to the
    hand-computed expected value.
    """
    target = jnp.array([
        [0.2, 0.2, 0.2, 0.2],
        [0.6, 0.6, 0.6, 0.6],
    ])
    predict = jnp.array([
        [0.4, 0.5, 0.2, 0.3],
        [0.6, 0.6, 0.6, 0.6],
    ])

    actual = float(losses.mean_square_error(target=target, predict=predict))

    # Hand-computed reference:
    # ((
    #   (0.4 - 0.2) ** 2 + (0.5 - 0.2) ** 2
    #   + (0.2 - 0.2) ** 2 + (0.3 - 0.2) ** 2
    # ) / 4 + 0) / 2 = 0.0175
    expected = 0.0175
    self.assertAlmostEqual(actual, expected)