コード例 #1
0
ファイル: test_array.py プロジェクト: vermashresth/chainer
def _check_array(
        array, expected_dtype, expected_shape, expected_data_list=None,
        device=None):
    expected_dtype = chainerx.dtype(expected_dtype)

    assert isinstance(array.dtype, chainerx.dtype)
    assert isinstance(array.shape, tuple)
    assert array.dtype == expected_dtype
    assert array.shape == expected_shape
    assert array.itemsize == expected_dtype.itemsize
    assert array.size == array_utils.total_size(expected_shape)
    assert array.nbytes == expected_dtype.itemsize * \
        array_utils.total_size(expected_shape)
    if expected_data_list is not None:
        assert array._debug_flat_data == expected_data_list

    assert array.is_contiguous

    array_utils.check_device(array, device)
コード例 #2
0
ファイル: test_array.py プロジェクト: hvy/chainer
def _check_array(
        array, expected_dtype, expected_shape, expected_data_list=None,
        device=None):
    expected_dtype = chainerx.dtype(expected_dtype)

    assert isinstance(array.dtype, chainerx.dtype)
    assert isinstance(array.shape, tuple)
    assert array.dtype == expected_dtype
    assert array.shape == expected_shape
    assert array.itemsize == expected_dtype.itemsize
    assert array.size == array_utils.total_size(expected_shape)
    assert array.nbytes == expected_dtype.itemsize * \
        array_utils.total_size(expected_shape)
    if expected_data_list is not None:
        assert array._debug_flat_data == expected_data_list

    assert array.is_contiguous

    array_utils.check_device(array, device)
コード例 #3
0
ファイル: test_manipulation.py プロジェクト: pfnet/chainer
def _make_inputs(shapes, dtypes):
    # Generates input ndarrays.
    assert isinstance(shapes, (list, tuple))
    assert isinstance(dtypes, (list, tuple))
    assert len(shapes) == len(dtypes)

    inputs = []
    for i, (shape, dtype) in enumerate(zip(shapes, dtypes)):
        size = array_utils.total_size(shape)
        a = numpy.arange(i * 100, i * 100 + size)
        a = a.reshape(shape)
        a = a.astype(dtype)
        inputs.append(a)

    assert len(inputs) > 0
    return tuple(inputs)
コード例 #4
0
def _make_inputs(shapes, dtypes):
    # Generates input ndarrays.
    assert isinstance(shapes, (list, tuple))
    assert isinstance(dtypes, (list, tuple))
    assert len(shapes) == len(dtypes)

    inputs = []
    for i, (shape, dtype) in enumerate(zip(shapes, dtypes)):
        size = array_utils.total_size(shape)
        a = numpy.arange(i * 100, i * 100 + size)
        a = a.reshape(shape)
        a = a.astype(dtype)
        inputs.append(a)

    assert len(inputs) > 0
    return tuple(inputs)
コード例 #5
0
ファイル: test_creation.py プロジェクト: jnishi/chainer
def test_diagflat_invalid_ndim(xp, k, shape, device):
    v = xp.arange(array_utils.total_size(shape)).reshape(shape)
    return xp.diagflat(v, k)
コード例 #6
0
ファイル: test_creation.py プロジェクト: jnishi/chainer
def test_diag(xp, k, shape, transpose, device):
    v = xp.arange(array_utils.total_size(shape)).reshape(shape)
    if transpose:  # Test non-contiguous inputs for multi-dimensional shapes.
        v = v.T
    return xp.diag(v, k)
コード例 #7
0
ファイル: test_creation.py プロジェクト: dselivanov/chainer
def test_diagflat_invalid_ndim(xp, k, shape, device):
    v = xp.arange(array_utils.total_size(shape), dtype='int32').reshape(shape)
    return xp.diagflat(v, k)
コード例 #8
0
ファイル: test_creation.py プロジェクト: dselivanov/chainer
def test_diag(xp, k, shape, transpose, device):
    v = xp.arange(array_utils.total_size(shape), dtype='int32').reshape(shape)
    if transpose:  # Test non-contiguous inputs for multi-dimensional shapes.
        v = v.T
    return xp.diag(v, k)
コード例 #9
0
def test_total_size(expected, shape):
    assert expected == array_utils.total_size(shape)
コード例 #10
0
ファイル: test_creation.py プロジェクト: hvy/chainer
def test_diagflat(xp, k, shape, device):
    v = xp.arange(array_utils.total_size(shape), dtype='int32').reshape(shape)
    return xp.diagflat(v, k)
コード例 #11
0
def test_diagflat(xp, k, shape, device):
    v = xp.arange(array_utils.total_size(shape)).reshape(shape)
    return xp.diagflat(v, k)
コード例 #12
0
def test_diag_invalid_ndim(xp, k, shape, device):
    v = xp.arange(array_utils.total_size(shape)).reshape(shape)
    return xp.diag(v, k)
コード例 #13
0
ファイル: test_array_utils.py プロジェクト: asi1024/chainer
def test_total_size(expected, shape):
    assert expected == array_utils.total_size(shape)