Example #1
def test_linear(device, x_shape, w_shape, b_shape, n_batch_axes, dtype):
    # TODO(imanishi): Remove the skip after supporting non-float dot on CUDA
    if device.name == 'cuda:0' and numpy.dtype(dtype).kind != 'f':
        return chainerx.testing.ignore()
    x = array_utils.create_dummy_ndarray(numpy, x_shape, dtype)
    w = array_utils.create_dummy_ndarray(numpy, w_shape, dtype)
    b = (None if b_shape in (None, Unspecified)
         else array_utils.create_dummy_ndarray(numpy, b_shape, dtype))

    # Calculate chainerx_out
    chainerx_x = chainerx.array(x)
    chainerx_w = chainerx.array(w)
    chainerx_b = chainerx.array(b) if b is not None else None
    if b_shape is Unspecified:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w)
    elif n_batch_axes is Unspecified:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w, chainerx_b)
    else:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w, chainerx_b,
                                       n_batch_axes)

    # Calculate numpy_out
    if n_batch_axes is Unspecified:
        n_batch_axes = 1
    out_shape = x_shape[:n_batch_axes] + (w_shape[0],)
    x = x.reshape(numpy.prod(x_shape[:n_batch_axes]),
                  numpy.prod(x_shape[n_batch_axes:]))
    numpy_out = x.dot(w.T).reshape(out_shape)
    if b is not None:
        numpy_out += b

    chainerx.testing.assert_allclose_ex(
        chainerx_out, numpy_out,
        float16_rtol=1e-2, float16_atol=1e-2, strides_check=False)
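
A hedged aside, not part of the scraped examples: the semantics this test checks can be exercised directly. chainerx.linear flattens everything after the first n_batch_axes axes of x and computes x.dot(W.T) + b; the shapes and dtype below are illustrative assumptions.

import chainerx

# Illustrative shapes: with n_batch_axes=2, the leading (2, 3) axes are batch
# axes and the trailing 4 features are multiplied against W of (out_size, in_size).
x = chainerx.ones((2, 3, 4), dtype=chainerx.float32)
W = chainerx.ones((5, 4), dtype=chainerx.float32)
b = chainerx.ones((5,), dtype=chainerx.float32)

y = chainerx.linear(x, W, b, n_batch_axes=2)
assert y.shape == (2, 3, 5)  # x_shape[:n_batch_axes] + (w_shape[0],)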
Example #2
def test_linear(device, x_shape, w_shape, b_shape, n_batch_axes, dtype):
    # TODO(imanishi): Remove the skip after supporting non-float dot on CUDA
    if device.name == 'cuda:0' and numpy.dtype(dtype).kind != 'f':
        return chainerx.testing.ignore()
    x = array_utils.create_dummy_ndarray(numpy, x_shape, dtype)
    w = array_utils.create_dummy_ndarray(numpy, w_shape, dtype)
    b = (None if b_shape in (None, Unspecified)
         else array_utils.create_dummy_ndarray(numpy, b_shape, dtype))

    # Calculate chainerx_out
    chainerx_x = chainerx.array(x)
    chainerx_w = chainerx.array(w)
    chainerx_b = chainerx.array(b) if b is not None else None
    if b_shape is Unspecified:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w)
    elif n_batch_axes is Unspecified:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w, chainerx_b)
    else:
        chainerx_out = chainerx.linear(chainerx_x, chainerx_w, chainerx_b,
                                       n_batch_axes)

    # Calculate numpy_out
    if n_batch_axes is Unspecified:
        n_batch_axes = 1
    out_shape = x_shape[:n_batch_axes] + (w_shape[0],)
    x = x.reshape(numpy.prod(x_shape[:n_batch_axes]),
                  numpy.prod(x_shape[n_batch_axes:]))
    numpy_out = x.dot(w.T).reshape(out_shape)
    if b is not None:
        numpy_out += b

    chainerx.testing.assert_array_equal(chainerx_out, numpy_out)
Example #3
File: linear.py Project: jnishi/chainer
    def forward_chainerx(self, inputs):
        # TODO(niboshi): Support dtype casting in ChainerX
        if inputs[0].dtype != inputs[1].dtype:
            return chainer.Fallback

        # Generic implementation
        if len(inputs) == 3:
            x, W, b = inputs
            if x.dtype != b.dtype:
                return chainer.Fallback
            return chainerx.linear(x, W, b),  # trailing comma: a 1-tuple of outputs
        else:
            x, W = inputs
            return chainerx.linear(x, W),
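
For context, a hedged sketch of how this hook is reached; the shapes below are illustrative assumptions. When chainer.functions.linear is called on ChainerX arrays, Chainer dispatches to forward_chainerx, and a return value of chainer.Fallback (e.g. on mixed dtypes) sends execution back to the default NumPy/CuPy implementation.

import chainer.functions as F
import chainerx

x = chainerx.ones((2, 4), dtype=chainerx.float32)
W = chainerx.ones((3, 4), dtype=chainerx.float32)

# With ChainerX inputs, F.linear routes through forward_chainerx above.
y = F.linear(x, W)
assert y.shape == (2, 3)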
Example #4
    def forward_chainerx(self, inputs):
        if len(inputs) == 3:
            x, w, b = inputs
        else:
            (x, w), b = inputs, None

        n_batch_axes = self.n_batch_axes

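        # Note (editorial comment, not in the source): with only (x, w) given
        # above, b is None rather than the Unspecified sentinel, so the first
        # branch below is skipped; chainerx.linear accepts b=None, matching
        # its default.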
        if b is Unspecified:
            y = chainerx.linear(x, w)
        elif n_batch_axes is Unspecified:
            y = chainerx.linear(x, w, b)
        else:
            y = chainerx.linear(x, w, b, n_batch_axes)
        return y,
Example #5
    def forward(self, x):
        h = chx.relu(chx.linear(x, self.W1, self.b1))
        h = chx.relu(chx.linear(h, self.W2, self.b2))
        return chx.linear(h, self.W3, self.b3)
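
To round this out, a hedged sketch of a complete setup around Example #5; the class name, layer sizes, and the NumPy-based initialization are assumptions, not taken from the scraped source.

import numpy as np
import chainerx as chx

class MLP:
    """Three-layer perceptron whose forward pass mirrors Example #5."""

    def __init__(self, n_in, n_hidden, n_out):
        def weight(n_out_, n_in_):
            # Gaussian init scaled by 1/sqrt(fan-in), converted to ChainerX.
            w = (np.random.randn(n_out_, n_in_) / np.sqrt(n_in_)).astype(np.float32)
            return chx.array(w)

        self.W1 = weight(n_hidden, n_in)
        self.b1 = chx.zeros((n_hidden,), dtype=chx.float32)
        self.W2 = weight(n_hidden, n_hidden)
        self.b2 = chx.zeros((n_hidden,), dtype=chx.float32)
        self.W3 = weight(n_out, n_hidden)
        self.b3 = chx.zeros((n_out,), dtype=chx.float32)

    def forward(self, x):
        h = chx.relu(chx.linear(x, self.W1, self.b1))
        h = chx.relu(chx.linear(h, self.W2, self.b2))
        return chx.linear(h, self.W3, self.b3)

model = MLP(784, 100, 10)
x = chx.array(np.zeros((32, 784), np.float32))
assert model.forward(x).shape == (32, 10)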