コード例 #1
0
ファイル: attention_test.py プロジェクト: rouniuyizu/trax
 def test_shift_right(self):
     # Test shifts right on axis=1
     layer = attention.ShiftRight()
     input_np = np.arange(2 * 3 * 3).reshape(2, 3, 3)
     output_np = layer(input_np)
     self.assertEqual(input_np.shape, output_np.shape)
     self.assertAllEqual(
         np.array([[[0, 0, 0], [0, 1, 2], [3, 4, 5]],
                   [[0, 0, 0], [9, 10, 11], [12, 13, 14]]]), output_np)
コード例 #2
0
ファイル: attention_test.py プロジェクト: rouniuyizu/trax
    def test_shift_right_float(self):
        layer = attention.ShiftRight()
        input_np = np.arange(2 * 3 * 3).reshape(2, 3, 3).astype(np.float32)
        # Test on a float array.
        input_np = input_np.astype(np.float32)
        input_np /= 2.0
        self.assertEqual(input_np.dtype, np.float32)

        output_np = layer(input_np)
        self.assertEqual(input_np.shape, output_np.shape)
        self.assertEqual(output_np.dtype, np.float32)

        self.assertAllEqual(
            np.array([[[0., 0., 0.], [0., 0.5, 1.], [1.5, 2., 2.5]],
                      [[0., 0., 0.], [4.5, 5., 5.5], [6., 6.5, 7.]]]),
            output_np)