예제 #1
0
def test_add_invalid_dtypes(device, dtypes, is_module):
    """Check that adding arrays of the given dtype pair raises DtypeError.

    Args:
        device: Not referenced in the body; presumably a fixture that
            selects the default device for array creation -- TODO confirm.
        dtypes: Nested pair ``((in_dtype1, in_dtype2), _)``. Only the two
            input dtypes are used; the second element is ignored here.
        is_module: If true, exercise the module-level ``chainerx.add``;
            otherwise exercise the ``+`` operator on the arrays.
    """
    (in_dtype1, in_dtype2), _ = dtypes
    shape = (2, 3)
    a = chainerx.array(array_utils.uniform(shape, in_dtype1))
    b = chainerx.array(array_utils.uniform(shape, in_dtype2))
    with pytest.raises(chainerx.DtypeError):
        # ``is_module`` selects the module-level function form; the
        # operator form is covered by the other parametrization.
        if is_module:
            chainerx.add(a, b)
        else:
            a + b
예제 #2
0
def test_add_scalar(scalar, device, shape, dtype):
    """Check array-scalar addition against NumPy in both operand orders.

    Exercises both the ``+`` operator and ``chainerx.add``, with the scalar
    on the right and on the left of the array, and compares every result to
    the same NumPy reference.
    """
    x_np = array_utils.create_dummy_ndarray(numpy, shape, dtype)
    # Implicit casting in NumPy's add depends on the 'casting' argument,
    # which is not yet supported (ChainerX always casts).
    # Therefore, we explicitly cast the scalar to the dtype of the ndarray
    # before the addition for NumPy.
    expected = x_np + numpy.dtype(dtype).type(scalar)

    x = chainerx.array(x_np)
    chainerx.testing.assert_array_equal_ex(x + scalar, expected)
    chainerx.testing.assert_array_equal_ex(scalar + x, expected)
    chainerx.testing.assert_array_equal_ex(chainerx.add(x, scalar), expected)
    chainerx.testing.assert_array_equal_ex(chainerx.add(scalar, x), expected)
예제 #3
0
파일: test_math.py 프로젝트: jnishi/chainer
def test_add_scalar(scalar, device, shape, dtype):
    """Check array-scalar addition against NumPy for plain and Scalar operands.

    The scalar is tested both as a plain Python value and wrapped as
    ``chainerx.Scalar(scalar, dtype)``. Each form is placed on the right and
    on the left of the array, through both the ``+`` operator and
    ``chainerx.add``. Every result is compared to the same NumPy reference.
    """
    x_np = array_utils.create_dummy_ndarray(numpy, shape, dtype)
    # Implicit casting in NumPy's add depends on the 'casting' argument,
    # which is not yet supported (ChainerX always casts).
    # Therefore, we explicitly cast the scalar to the dtype of the ndarray
    # before the addition for NumPy.
    expected = x_np + numpy.dtype(dtype).type(scalar)

    x = chainerx.array(x_np)
    scalar_chx = chainerx.Scalar(scalar, dtype)
    chainerx.testing.assert_array_equal_ex(x + scalar, expected)
    chainerx.testing.assert_array_equal_ex(x + scalar_chx, expected)
    chainerx.testing.assert_array_equal_ex(scalar + x, expected)
    chainerx.testing.assert_array_equal_ex(scalar_chx + x, expected)
    chainerx.testing.assert_array_equal_ex(chainerx.add(x, scalar), expected)
    chainerx.testing.assert_array_equal_ex(
        chainerx.add(x, scalar_chx), expected)
    chainerx.testing.assert_array_equal_ex(chainerx.add(scalar, x), expected)
    chainerx.testing.assert_array_equal_ex(
        chainerx.add(scalar_chx, x), expected)