Esempio n. 1
0
def test_sub_scalar(scalar, device, shape, dtype):
    """Check ndarray/scalar subtraction in both operand orders against NumPy.

    Subtraction is exercised through the ``-`` operator and through
    ``chainerx.subtract``, with the scalar passed both as a plain Python
    scalar and as a ``chainerx.Scalar`` of ``dtype``.

    ``device`` is not referenced in the body; presumably it is a fixture
    that selects the default device for the test (TODO confirm).
    """
    if dtype == 'bool_':
        # Boolean subtract is deprecated.
        return chainerx.testing.ignore()
    x_np = array_utils.create_dummy_ndarray(numpy, shape, dtype)
    # Implicit casting in NumPy's subtract depends on the 'casting' argument,
    # which is not yet supported (ChainerX always casts).
    # Therefore, we explicitly cast the scalar to the dtype of the ndarray
    # before the subtraction for NumPy.
    scalar_np = numpy.dtype(dtype).type(scalar)
    expected = x_np - scalar_np
    expected_rev = scalar_np - x_np

    x = chainerx.array(x_np)
    scalar_chx = chainerx.Scalar(scalar, dtype)
    # Operator form: array on the left, then scalar on the left.
    chainerx.testing.assert_array_equal_ex(x - scalar, expected)
    chainerx.testing.assert_array_equal_ex(x - scalar_chx, expected)
    chainerx.testing.assert_array_equal_ex(scalar - x, expected_rev)
    chainerx.testing.assert_array_equal_ex(scalar_chx - x, expected_rev)
    # Module-function form, same four combinations.
    chainerx.testing.assert_array_equal_ex(chainerx.subtract(x, scalar),
                                           expected)
    chainerx.testing.assert_array_equal_ex(chainerx.subtract(x, scalar_chx),
                                           expected)
    chainerx.testing.assert_array_equal_ex(chainerx.subtract(scalar, x),
                                           expected_rev)
    chainerx.testing.assert_array_equal_ex(chainerx.subtract(scalar_chx, x),
                                           expected_rev)
Esempio n. 2
0
def test_sub_scalar(scalar, device, shape, dtype):
    """Check ndarray/scalar subtraction in both operand orders against NumPy.

    For each scalar form (plain Python scalar and ``chainerx.Scalar``), the
    result of ``x - s``, ``s - x``, ``chainerx.subtract(x, s)`` and
    ``chainerx.subtract(s, x)`` is compared with the NumPy reference.

    ``device`` is not referenced in the body; presumably it is a fixture
    that selects the default device for the test (TODO confirm).
    """
    if dtype == 'bool_':
        # Boolean subtract is deprecated.
        return chainerx.testing.ignore()
    x_np = array_utils.create_dummy_ndarray(numpy, shape, dtype)
    # Implicit casting in NumPy's subtract depends on the 'casting' argument,
    # which is not yet supported (ChainerX always casts).
    # Therefore, we explicitly cast the scalar to the dtype of the ndarray
    # before the subtraction for NumPy.
    scalar_np = numpy.dtype(dtype).type(scalar)
    expected = x_np - scalar_np
    expected_rev = scalar_np - x_np

    x = chainerx.array(x_np)
    for s in (scalar, chainerx.Scalar(scalar, dtype)):
        chainerx.testing.assert_array_equal_ex(x - s, expected)
        chainerx.testing.assert_array_equal_ex(s - x, expected_rev)
        chainerx.testing.assert_array_equal_ex(
            chainerx.subtract(x, s), expected)
        chainerx.testing.assert_array_equal_ex(
            chainerx.subtract(s, x), expected_rev)
Esempio n. 3
0
def test_sub_invalid_dtypes(device, dtypes, is_module):
    """Check that subtracting arrays of incompatible dtypes raises DtypeError.

    ``dtypes`` is unpacked as ``((in_dtype1, in_dtype2), _)``; only the two
    input dtypes are used (the second element is presumably the expected
    output dtype and is ignored here — TODO confirm).

    ``is_module`` selects the entry point under test: ``True`` calls the
    module-level ``chainerx.subtract``, ``False`` uses the ``-`` operator.
    """
    (in_dtype1, in_dtype2), _ = dtypes
    shape = (2, 3)
    a = chainerx.array(array_utils.uniform(shape, in_dtype1))
    b = chainerx.array(array_utils.uniform(shape, in_dtype2))
    with pytest.raises(chainerx.DtypeError):
        if is_module:
            # Module-function path.
            chainerx.subtract(a, b)
        else:
            # Operator path (ndarray.__sub__).
            a - b