コード例 #1
0
ファイル: metrics_test.py プロジェクト: rouniuyizu/trax
 def test_weighted_mean_semantics(self):
     """_WeightedMean averages [1, 2, 3] according to per-element weights.

     Uniform weights must yield the ordinary mean, and weights that put all
     mass on one element must yield that element.
     """
     values = np.array([1, 2, 3], dtype=np.float32)
     uniform = np.array([1, 1, 1], dtype=np.float32)
     weighted_mean = metrics._WeightedMean()
     weighted_mean.init((signature(values), signature(uniform)))
     # (weights, expected mean) pairs, evaluated in this order.
     cases = (
         (uniform, 2.0),
         (np.array([0, 0, 1], dtype=np.float32), 3.0),
         (np.array([1, 0, 0], dtype=np.float32), 1.0),
     )
     for weights, expected in cases:
         np.testing.assert_allclose(weighted_mean((values, weights)), expected)
コード例 #2
0
    def test_weighted_mean_semantics(self):
        """Weighted mean of [1., 2., 3.] under uniform and one-hot weights."""
        weighted_mean = metrics._WeightedMean()
        # Initialize from a signature of two length-3 arrays (values, weights).
        probe_values = np.ones((3, ))
        probe_weights = np.ones((3, ))
        weighted_mean.init(shapes.signature([probe_values, probe_weights]))

        values = np.array([1., 2., 3.])
        expectations = [
            ([1., 1., 1.], 2.),  # uniform weights -> ordinary mean
            ([0., 0., 1.], 3.),  # all weight on the last element
            ([1., 0., 0.], 1.),  # all weight on the first element
        ]
        for raw_weights, expected in expectations:
            result = weighted_mean((values, np.array(raw_weights)))
            np.testing.assert_allclose(result, expected)
コード例 #3
0
ファイル: metrics_test.py プロジェクト: rouniuyizu/trax
 def test_weighted_mean_shape(self):
     """_WeightedMean maps two inputs of one 4-D shape to a scalar (shape ()).

     NOTE(review): presumably `base.check_shape_agreement` checks that the
     layer's abstractly inferred output shape matches the shape produced by
     actually running it, and returns that shape; confirm in trax.layers.base.
     """
     shape = (29, 4, 4, 20)
     # Both inputs share the same shape.
     input_signature = (ShapeDtype(shape), ShapeDtype(shape))
     layer = metrics._WeightedMean()
     self.assertEqual(
         base.check_shape_agreement(layer, input_signature), ())
コード例 #4
0
 def test_weighted_mean_shape(self):
     """Calling _WeightedMean on two same-shaped arrays yields a scalar."""
     layer = metrics._WeightedMean()
     shape = (9, 4, 4, 20)
     # The two inputs are passed together as a list; both are all ones.
     inputs = [np.ones(shape), np.ones(shape)]
     # NOTE(review): unlike the semantics tests, no explicit layer.init(...)
     # is performed before calling the layer; confirm this is intentional.
     output = layer(inputs)
     self.assertEqual(output.shape, ())