Exemplo n.º 1
0
def test_take_non_contiguous(device, shape, indices, axis):
    """Compare ``chainerx.take`` with ``numpy.take`` on a non-contiguous input.

    A random float32 array is built, mirrored into ChainerX, and then both
    copies are transposed with all axes reversed. The ChainerX view must
    report ``is_contiguous == False`` before ``take`` is exercised along
    ``axis``.

    Args:
        device: pytest fixture argument; not referenced in the body
            (presumably selects the default ChainerX device — confirm).
        shape: Shape of the random source array.
        indices: Index values, converted to an int32 array before use.
        axis: Axis passed to both ``take`` implementations.
    """
    np_src = numpy.random.uniform(-1, 1, shape).astype('float32')
    np_indices = numpy.array(indices).astype(numpy.int32)
    # Mirror the untransposed data into ChainerX before either side is permuted.
    chx_src = chainerx.array(np_src).astype('float32')

    # Reversing every axis makes the ChainerX view non-contiguous (asserted below).
    reversed_axes = range(chx_src.ndim)[::-1]
    np_src = numpy.transpose(np_src, axes=reversed_axes)
    chx_src = chainerx.transpose(chx_src, axes=reversed_axes)
    assert not chx_src.is_contiguous

    chx_indices = chainerx.array(np_indices).astype(numpy.int32)
    expected = numpy.take(np_src, np_indices, axis)
    actual = chainerx.take(chx_src, chx_indices, axis)
    numpy.testing.assert_array_equal(actual, expected)
Exemplo n.º 2
0
def test_transpose_invalid_axes(shape, axes):
    """Invalid ``axes`` must raise ``chainerx.DimensionError`` on every entry point.

    Both the module-level ``chainerx.transpose`` and the ``ndarray.transpose``
    method are checked, in that order, against the same dummy float32 array.

    Args:
        shape: Shape of the dummy array to transpose.
        axes: Axes permutation expected to be rejected (presumably supplied by
            pytest parametrization — confirm).
    """
    arr = array_utils.create_dummy_ndarray(chainerx, shape, 'float32')
    transpose_calls = (
        lambda: chainerx.transpose(arr, axes),  # function form
        lambda: arr.transpose(axes),  # method form
    )
    for call in transpose_calls:
        with pytest.raises(chainerx.DimensionError):
            call()