コード例 #1
0
ファイル: metrics_test.py プロジェクト: rouniuyizu/trax
 def test_weighted_sequence_mean_semantics(self):
     """Check `_WeightedSequenceMean` results for two weight vectors.

     The same 2x3 batch is fed twice: with weights [1, 1, 1] the layer is
     expected to return 0.5, and with weights [1, 1, 0] it is expected to
     return 1.0.
     """
     batch = np.array([[1, 1, 1], [1, 1, 0]], dtype=np.float32)
     uniform_weights = np.array([1, 1, 1], dtype=np.float32)

     mean_layer = metrics._WeightedSequenceMean()
     # Initialise from the signatures of the actual (inputs, weights) pair.
     mean_layer.init((signature(batch), signature(uniform_weights)))

     # Every position carries weight 1.
     np.testing.assert_allclose(mean_layer((batch, uniform_weights)), 0.5)

     # Last position carries weight 0.
     last_zeroed = np.array([1, 1, 0], dtype=np.float32)
     np.testing.assert_allclose(mean_layer((batch, last_zeroed)), 1.0)
コード例 #2
0
    def test_weighted_sequence_mean_semantics(self):
        """Check `_WeightedSequenceMean` results for two weight vectors.

        The layer is initialised from the signature of all-ones arrays of
        shapes (2, 3) and (3,), then called on one fixed batch with each
        weight vector in turn and compared against the expected mean.
        """
        mean_layer = metrics._WeightedSequenceMean()
        # Placeholder arrays are only used to derive the init signature.
        mean_layer.init(shapes.signature([np.ones((2, 3)), np.ones((3, ))]))

        batch = np.array([[1., 1., 1.], [1., 1., 0.]])
        # (weights, expected mean) pairs, evaluated in order.
        cases = (
            (np.array([1., 1., 1.]), 0.5),
            (np.array([1., 1., 0.]), 1.),
        )
        for weight_vector, expected in cases:
            result = mean_layer((batch, weight_vector))
            np.testing.assert_allclose(result, expected)