def test_shift_right(self):
  """ShiftRight moves integer inputs one position along axis 1.

  The first position of every batch entry is zero-filled and the last
  position of the input is dropped, so the output shape equals the input
  shape.
  """
  shift = attention.ShiftRight()
  # Two batch entries, each a 3x3 block: [[0..2], [3..5], [6..8]] and
  # [[9..11], [12..14], [15..17]].
  inputs = np.arange(2 * 3 * 3).reshape(2, 3, 3)
  shifted = shift(inputs)

  expected = np.array([
      [[0, 0, 0], [0, 1, 2], [3, 4, 5]],
      [[0, 0, 0], [9, 10, 11], [12, 13, 14]],
  ])
  self.assertEqual(inputs.shape, shifted.shape)
  self.assertAllEqual(expected, shifted)
def test_shift_right_float(self):
  """ShiftRight shifts float32 inputs along axis 1 and keeps their dtype.

  The input holds the values 0.0, 0.5, ..., 8.5. All of them are exactly
  representable in float32, so the result can be compared with exact
  equality rather than a tolerance.
  """
  layer = attention.ShiftRight()
  # Test on a float array; one cast to float32 is enough.
  input_np = np.arange(2 * 3 * 3).reshape(2, 3, 3).astype(np.float32)
  # In-place division keeps the float32 dtype of the array.
  input_np /= 2.0
  self.assertEqual(input_np.dtype, np.float32)

  output_np = layer(input_np)
  self.assertEqual(input_np.shape, output_np.shape)
  self.assertEqual(output_np.dtype, np.float32)
  self.assertAllEqual(
      np.array([[[0., 0., 0.], [0., 0.5, 1.], [1.5, 2., 2.5]],
                [[0., 0., 0.], [4.5, 5., 5.5], [6., 6.5, 7.]]]),
      output_np)