示例#1
0
 def check_general_stride(self, xp):
     """Check ``as_strided`` applied to an input that is itself a strided view.

     ``x`` is a (3, 3) view over ``arange(8)`` with strides ``(-1, 2)`` and
     offset 3. ``as_strided`` is then called with strides ``(1, 2)`` and
     offset 0. Per the expected values below, those are resolved against
     the underlying 8-element buffer, not against ``x``'s own layout.

     Args:
         xp: Array module used to build the inputs (presumably numpy or
             cupy -- supplied by the caller).
     """
     x = _stride_array(xp.arange(8, dtype=self.dtype), (3, 3), (-1, 2), 3)
     # [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     v = Variable(x)
     y = as_strided(v, (3, 3), (1, 2), 0)
     # Element [i, j] of y is buffer[i + 2 * j]. The expectation is written
     # out literally so it does not depend on the ``_stride_array`` helper.
     # It is compared with assert_allclose so that a wrong output shape
     # cannot be masked by broadcasting in ``==``.
     y_expected = xp.array([[0., 2., 4.],
                            [1., 3., 5.],
                            [2., 4., 6.]], dtype=self.dtype)
     testing.assert_allclose(y.array, y_expected)
示例#2
0
 def check_general_stride(self, xp):
     """Check ``F.as_strided`` applied to an input that is itself a strided view.

     ``x`` is a (3, 3) view over ``arange(8)`` with strides ``(-1, 2)`` and
     offset 3. ``F.as_strided`` is then called with strides ``(1, 2)`` and
     offset 0. Per the expected values below, those are resolved against
     the underlying 8-element buffer, not against ``x``'s own layout.

     Args:
         xp: Array module used to build the inputs (presumably numpy or
             cupy -- supplied by the caller).
     """
     x = _stride_array(xp.arange(8, dtype=self.dtype), (3, 3), (-1, 2), 3)
     # [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     v = chainer.Variable(x)
     y = F.as_strided(v, (3, 3), (1, 2), 0)
     # Element [i, j] of y is buffer[i + 2 * j]. The expectation is written
     # out literally so it does not depend on the ``_stride_array`` helper.
     # It is compared with assert_allclose so that a wrong output shape
     # cannot be masked by broadcasting in ``==``.
     y_expected = xp.array([[0., 2., 4.],
                            [1., 3., 5.],
                            [2., 4., 6.]], dtype=self.dtype)
     testing.assert_allclose(y.array, y_expected)
示例#3
0
 def check_general_stride_backward(self, xp):
     """Expect ``chainer.grad`` through ``F.as_strided`` to raise ``TypeError``.

     The input is a (3, 3) view over ``arange(8)`` with a negative row
     stride and a non-zero offset. ``F.as_strided`` re-views the same
     8-element buffer.

     Args:
         xp: Array module used to build the inputs (supplied by the caller).
     """
     grid = (3, 3)
     buffer = xp.arange(8, dtype=self.dtype)
     source = _stride_array(buffer, grid, (-1, 2), 3)
     # source == [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     var = chainer.Variable(source)
     out = F.as_strided(var, grid, (1, 2), 0)
     # out == [[0., 2., 4.], [1., 3., 5.], [2., 4., 6.]]
     out.grad = xp.ones(out.shape, dtype=self.dtype)
     with self.assertRaises(TypeError):
         gx, = chainer.grad((out,), (var,))
示例#4
0
 def check_general_stride_backward(self, xp):
     """Expect ``grad`` through ``as_strided`` to raise ``TypeError``.

     The input is a (3, 3) view over ``arange(8)`` with a negative row
     stride and a non-zero offset. ``as_strided`` re-views the same
     8-element buffer.

     Args:
         xp: Array module used to build the inputs (supplied by the caller).
     """
     grid = (3, 3)
     buffer = xp.arange(8, dtype=self.dtype)
     source = _stride_array(buffer, grid, (-1, 2), 3)
     # source == [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     var = Variable(source)
     out = as_strided(var, grid, (1, 2), 0)
     # out == [[0., 2., 4.], [1., 3., 5.], [2., 4., 6.]]
     out.grad = xp.ones(out.shape, dtype=self.dtype)
     with self.assertRaises(TypeError):
         gx, = grad((out,), (var,))
示例#5
0
 def check_general_stride(self, xp):
     """Check ``_stride_array`` with a negative stride and a non-zero offset.

     With strides ``(-1, 2)`` and offset 3, element ``[i, j]`` of the result
     reads ``buffer[3 - i + 2 * j]``.

     Args:
         xp: Array module used to build the inputs (supplied by the caller).
     """
     buf = xp.arange(8, dtype=self.dtype)
     actual = _stride_array(buf, (3, 3), (-1, 2), 3)
     expected = xp.array([[3, 5, 7], [2, 4, 6], [1, 3, 5]],
                         dtype=self.dtype)
     testing.assert_allclose(actual, expected)
示例#6
0
 def check_general_stride_backward(self, xp):
     """Check the gradient of ``F.as_strided`` over a general strided view.

     The input is a (3, 3) view over ``arange(8)`` with strides ``(-1, 2)``
     and offset 3. The output re-views that buffer with strides ``(1, 2)``
     and offset 0.

     Args:
         xp: Array module used to build the inputs (supplied by the caller).
     """
     buffer = xp.arange(8, dtype=self.dtype)
     source = _stride_array(buffer, (3, 3), (-1, 2), 3)
     # source == [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     var = chainer.Variable(source)
     out = F.as_strided(var, (3, 3), (1, 2), 0)
     # out == [[0., 2., 4.], [1., 3., 5.], [2., 4., 6.]]
     out.grad = xp.ones(out.shape, dtype=self.dtype)
     grads = chainer.grad((out,), (var,))
     gx, = grads
     # Buffer slots 3 and 5 each occur twice in ``source`` and once in
     # ``out``; each of their ``source`` entries is expected to get 0.5.
     # Slots 2 and 4 occur twice in ``out`` and so are expected to get 2.
     expected = xp.array(
         [[0.5, 0.5, 0.],
          [2., 2., 1.],
          [1., 0.5, 0.5]],
         dtype=self.dtype)
     testing.assert_allclose(gx.array, expected)
示例#7
0
 def check_general_stride_backward(self, xp):
     """Check the gradient of ``as_strided`` over a general strided view.

     The input is a (3, 3) view over ``arange(8)`` with strides ``(-1, 2)``
     and offset 3. The output re-views that buffer with strides ``(1, 2)``
     and offset 0.

     Args:
         xp: Array module used to build the inputs (supplied by the caller).
     """
     buffer = xp.arange(8, dtype=self.dtype)
     source = _stride_array(buffer, (3, 3), (-1, 2), 3)
     # source == [[3., 5., 7.], [2., 4., 6.], [1., 3., 5.]]
     var = Variable(source)
     out = as_strided(var, (3, 3), (1, 2), 0)
     # out == [[0., 2., 4.], [1., 3., 5.], [2., 4., 6.]]
     out.grad = xp.ones(out.shape, dtype=self.dtype)
     grads = grad((out,), (var,))
     gx, = grads
     # Buffer slots 3 and 5 each occur twice in ``source`` and once in
     # ``out``; each of their ``source`` entries is expected to get 0.5.
     # Slots 2 and 4 occur twice in ``out`` and so are expected to get 2.
     expected = xp.array(
         [[0.5, 0.5, 0.],
          [2., 2., 1.],
          [1., 0.5, 0.5]],
         dtype=self.dtype)
     testing.assert_allclose(gx.array, expected)
示例#8
0
 def check_invalid_negative_index(self, xp):
     """Expect ``ValueError`` when the strided layout reaches before index 0.

     With offset 1 and a row stride of -1, element ``[2, 0]`` would read
     buffer index ``1 - 2 = -1``.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     data = xp.arange(8, dtype=self.dtype)
     shape, strides, offset = (3, 3), (-1, 2), 1
     with self.assertRaises(ValueError):
         _stride_array(data, shape, strides, offset)
示例#9
0
 def check_unstride(self, xp):
     """Check that ``_stride_array`` addresses the underlying buffer.

     ``src`` is a row-reversed view of a (3, 4) array. Reading 12 contiguous
     elements from offset 0 is expected to give plain ``arange(12)``, not
     the view's reversed row order.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     src = xp.arange(12, dtype=self.dtype).reshape((3, 4))
     src = src[::-1]
     flat = _stride_array(src, (12, ), (1, ), 0)
     testing.assert_allclose(flat, xp.arange(12, dtype=self.dtype))
示例#10
0
 def check_broadcast(self, xp):
     """Check that a zero stride makes ``_stride_array`` repeat data.

     Stride 0 on the new leading axis repeats the (3, 4) input twice. The
     result is compared against ``_broadcast_to`` to shape (2, 3, 4).

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     src = xp.arange(12, dtype=self.dtype).reshape((3, 4)).copy()
     # src == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
     target = (2, 3, 4)
     out = _stride_array(src, target, (0, 4, 1), 0)
     testing.assert_allclose(out, _broadcast_to(xp, src, target))
示例#11
0
 def check_flip(self, xp):
     """Check that stride -1 with offset 3 reverses a 4-element array.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     seq = xp.arange(4, dtype=self.dtype)
     flipped = _stride_array(seq, (4, ), (-1, ), 3)
     # flipped == [3, 2, 1, 0]
     testing.assert_allclose(flipped, seq[::-1])
示例#12
0
 def check_invalid_negative_index(self, xp):
     """Expect ``ValueError`` for a layout that indexes before element 0.

     Offset 1 combined with a row stride of -1 makes element ``[2, 0]``
     land on buffer index -1.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     arr = xp.arange(8, dtype=self.dtype)
     with self.assertRaises(ValueError):
         _stride_array(arr, (3, 3), (-1, 2), 1)
示例#13
0
 def check_unstride(self, xp):
     """Check that ``_stride_array`` reads the base buffer, not the view.

     The input view reverses the rows of a (3, 4) array. A contiguous
     12-element read from offset 0 should still yield ``arange(12)``.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     grid = xp.arange(12, dtype=self.dtype).reshape((3, 4))
     reversed_rows = grid[::-1]
     result = _stride_array(reversed_rows, (12,), (1,), 0)
     want = xp.arange(12, dtype=self.dtype)
     testing.assert_allclose(result, want)
示例#14
0
 def check_broadcast(self, xp):
     """Check that stride 0 on a new leading axis emulates broadcasting.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     base = xp.arange(12, dtype=self.dtype).reshape((3, 4))
     base = base.copy()
     # base == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
     repeated = _stride_array(base, (2, 3, 4), (0, 4, 1), 0)
     reference = _broadcast_to(xp, base, (2, 3, 4))
     testing.assert_allclose(repeated, reference)
示例#15
0
 def check_flip(self, xp):
     """Check that a single negative stride reverses the input order.

     Args:
         xp: Array module used to build the input (supplied by the caller).
     """
     values = xp.arange(4, dtype=self.dtype)
     # Start at the last element (offset 3) and walk backwards: [3, 2, 1, 0].
     out = _stride_array(values, (4,), (-1,), 3)
     testing.assert_allclose(out, values[::-1])