def check_invalid_rsp_pull_list(kv, key):
    """Check that ``row_sparse_pull`` into dense outputs fails for a key list.

    One dense array (``mx.nd.ones(shape) * 2``) is repeated once per entry
    of ``key`` and passed as ``out``; ``assert_exception`` is then used to
    require that ``kv.row_sparse_pull`` raises ``MXNetError``.

    ``shape`` is read from the enclosing scope, not from the arguments.
    """
    n_keys = len(key)
    # The same array object is repeated on purpose (list multiplication).
    dense_outs = [mx.nd.ones(shape) * 2] * n_keys
    requested_rows = [mx.nd.array([1])] * n_keys
    assert_exception(kv.row_sparse_pull, MXNetError, key,
                     out=dense_outs, row_ids=requested_rows)
 def check_invalid_rsp_pull_single(kv, key):
     """Check that ``row_sparse_pull`` into a dense output fails for one key.

     A dense array (``mx.nd.ones(shape) * 2``) is passed as ``out`` and
     ``assert_exception`` is used to require that ``kv.row_sparse_pull``
     raises ``MXNetError``.

     ``shape`` is read from the enclosing scope, not from the arguments.
     """
     dense_out = mx.nd.ones(shape) * 2
     requested_rows = mx.nd.array([1])
     assert_exception(kv.row_sparse_pull, MXNetError, key,
                      out=dense_out, row_ids=requested_rows)
Exemple #3
0
def test_hybrid_block_multiple_outputs():
    """Check the output types of HybridBlocks that return two arrays.

    Two blocks are exercised both hybridized and non-hybridized:

    * ``TestAllClassicOutputs`` must return two ``mx.nd.NDArray`` objects.
    * ``TestAllNumpyOutputs`` (decorated with ``use_np``) must return two
      ``np.ndarray`` objects.

    A third block that mixes both kinds of output under ``use_np`` must
    raise ``TypeError`` before and after ``hybridize()``.
    """
    @use_np
    class TestAllNumpyOutputs(HybridBlock):
        def hybrid_forward(self, F, x, *args, **kwargs):
            return F.np.add(x, x), F.np.multiply(x, x)

    class TestAllClassicOutputs(HybridBlock):
        def hybrid_forward(self, F, x, *args, **kwargs):
            classic = x.as_nd_ndarray()
            return classic + x.as_nd_ndarray(), classic * x.as_nd_ndarray()

    data_np = np.ones((2, 3))
    for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray),
                                     (TestAllNumpyOutputs, np.ndarray)]:
        for hybridize in [True, False]:
            # Build a fresh net for every mode. hybridize() is never undone
            # on an instance, so reusing one net after the hybridized pass
            # would silently skip the non-hybridized path.
            net = block()
            if hybridize:
                net.hybridize()
            out1, out2 = net(data_np)
            # Exact type identity is intended here, not isinstance().
            assert type(out1) is expected_out_type
            assert type(out2) is expected_out_type

    @use_np
    class TestMixedTypeOutputsFailure(HybridBlock):
        def hybrid_forward(self, F, x, *args, **kwargs):
            return x.as_nd_ndarray() + x.as_nd_ndarray(), F.np.multiply(x, x)

    net = TestMixedTypeOutputsFailure()
    # Non-hybridized first, then hybridized: both must reject mixed outputs.
    assert_exception(net, TypeError, data_np)
    net.hybridize()
    assert_exception(net, TypeError, data_np)
Exemple #4
0
def test_np_empty():
    """Check ``np.empty`` over a grid of dtypes, shapes, orders and contexts.

    For order ``'C'`` the returned array must have the expected dtype
    (a ``None`` dtype is expected to yield ``np.float32``), the expected
    shape (an int ``n`` is expected to yield ``(n,)``) and a context equal
    to ``npx.current_context()``. For orders ``'F'`` and ``'A'`` the call
    must raise ``NotImplementedError``.
    """
    dtypes = [np.int8, np.int32, np.float16, np.float32, np.float64, None]
    expected_dtypes = [
        np.int8, np.int32, np.float16, np.float32, np.float64, np.float32
    ]
    orders = ['C', 'F', 'A']
    shapes = [
        (),
        0,
        (0, ),
        (0, 0),
        2,
        (2, ),
        (3, 0),
        (4, 5),
        (1, 1, 1, 1),
    ]
    ctxes = [npx.current_context(), None]
    for dtype, expected_dtype in zip(dtypes, expected_dtypes):
        for shape in shapes:
            # Normalise first, then compare. Writing
            # ``assert a == b if cond else c`` parses as
            # ``assert ((a == b) if cond else c)``, which for int shapes
            # asserted a non-empty tuple and therefore could never fail.
            expected_shape = shape if isinstance(shape, tuple) else (shape, )
            for order in orders:
                for ctx in ctxes:
                    if order == 'C':
                        ret = np.empty(shape, dtype, order, ctx)
                        assert ret.dtype == expected_dtype
                        assert ret.shape == expected_shape
                        assert ret.ctx == npx.current_context()
                    else:
                        assert_exception(np.empty, NotImplementedError, shape,
                                         dtype, order, ctx)
Exemple #5
0
def test_smooth_distribution():
    """Check ``mx.contrib.quant._smooth_distribution`` on two inputs.

    An all-zero input of length 2 must raise ``ValueError``. A length-5
    array whose only non-zero entry is a 1 at index 2 must come back
    (within ``assert_almost_equal`` tolerance) as that array with 1e-3
    added everywhere and 5e-3 then subtracted at index 2.
    """
    smooth = mx.contrib.quant._smooth_distribution

    assert_exception(lambda: smooth(np.zeros((2,)), eps=1e-3), ValueError)

    spike = np.zeros((5,))
    spike[2] = 1

    expected = spike.copy()
    expected += 1e-3
    expected[2] -= 5e-3

    actual = smooth(spike, eps=1e-3)
    assert_almost_equal(actual, expected)
def test_smooth_distribution():
    """Exercise ``_smooth_distribution`` with a failing and a passing input.

    * A length-2 zero array with ``eps=1e-3`` must raise ``ValueError``.
    * A length-5 array with a single 1 at index 2 is compared, via
      ``assert_almost_equal``, against a copy of itself shifted up by
      1e-3 and then lowered by 5e-3 at index 2.
    """
    def smooth_all_zero_input():
        return mx.contrib.quant._smooth_distribution(np.zeros((2,)), eps=1e-3)

    assert_exception(smooth_all_zero_input, ValueError)

    one_hot = np.zeros((5,))
    one_hot[2] = 1
    reference = one_hot.copy()
    reference += 1e-3
    reference[2] -= 5e-3
    smoothed = mx.contrib.quant._smooth_distribution(one_hot, eps=1e-3)
    assert_almost_equal(smoothed, reference)
 def check_invalid_pull():
     """Check that ``kv.pull`` on the key ``'invalid_key'`` raises.

     ``keys_invalid[0]`` is first initialised with a 2x2 row-sparse array
     of ones. ``assert_exception`` then requires ``kv.pull`` with
     ``ignore_sparse=False`` and a row-sparse ``out`` to raise
     ``mx.MXNetError``. ``kv``, ``keys_invalid`` and ``my_rank`` come from
     the enclosing scope.
     """
     init_value = mx.nd.ones((2, 2)).tostype('row_sparse')
     kv.init(keys_invalid[0], init_value)

     # A separate array from the one used for init, as in the original flow.
     out = mx.nd.ones((2, 2)).tostype('row_sparse')
     assert_exception(kv.pull, mx.MXNetError, 'invalid_key',
                      out=out, ignore_sparse=False)

     message = 'worker ' + str(my_rank) + ' passed check_invalid_pull'
     print(message)
def test_gluon_trainer_reset():
    """Check that reloading parameters fails once the trainer owns the kvstore.

    A row-sparse parameter ``x`` is created, initialised to zeros on
    ``mx.cpu(0)``, attached to an SGD ``Trainer`` using the outer ``kv``
    and saved to a per-worker file. After pulling rows 0..3 the trainer
    must report ``_kv_initialized`` and ``_update_on_kvstore``, and
    ``params.load`` on the saved file must raise ``RuntimeError``.
    """
    params_file = 'test_gluon_trainer_reset_' + str(my_rank) + '.params'

    params = mx.gluon.ParameterDict()
    x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
    params.initialize(ctx=mx.cpu(0), init='zeros')
    trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
    params.save(params_file)

    # The pulled data is not used; presumably the pull is what initialises
    # the trainer's kvstore, which the assertion below relies on.
    x.row_sparse_data(mx.nd.arange(0, 4))
    assert trainer._kv_initialized and trainer._update_on_kvstore

    # load would fail to reset kvstore since update_on_kvstore is True
    assert_exception(params.load, RuntimeError, params_file)
    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_reset')
 def check_invalid_gluon_trainer_reset():
     """Check that ``params.load`` raises once the trainer updates on kvstore.

     Builds a row-sparse parameter ``x``, an SGD ``Trainer`` bound to the
     outer ``kv``, saves the parameters to a per-worker file and pulls
     rows 0..3. After ``mx.nd.waitall()`` returns, loading the saved file
     must raise ``RuntimeError``.
     """
     saved_path = 'test_gluon_trainer_reset_' + str(my_rank) + '.params'

     params = mx.gluon.ParameterDict()
     x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
     params.initialize(ctx=mx.cpu(0), init='zeros')
     trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
     params.save(saved_path)

     # The returned rows are unused; the call is kept because the
     # assertion below checks the trainer state right after it.
     all_rows = mx.nd.arange(0, 4)
     x.row_sparse_data(all_rows)
     assert trainer._kv_initialized and trainer._update_on_kvstore

     mx.nd.waitall()
     # load would fail to reset kvstore since update_on_kvstore is True
     assert_exception(params.load, RuntimeError, saved_path)
     print('worker ' + str(my_rank) + ' passed check_invalid_gluon_trainer_reset')
 def check_invalid_key_types_list(kv, key):
     """Check that init/push/pull/row_sparse_pull all raise for a key list.

     Each of ``kv.init``, ``kv.push`` and ``kv.pull`` is called through
     ``assert_exception`` with ``key`` and a list of dense values, and
     ``kv.row_sparse_pull`` with the row-sparse versions of those values;
     every call must raise ``MXNetError``.

     ``shape`` is read from the enclosing scope.
     """
     n_keys = len(key)
     dense_vals = [mx.nd.ones(shape) * 2] * n_keys
     sparse_vals = [dense.tostype('row_sparse') for dense in dense_vals]

     for dense_op in (kv.init, kv.push, kv.pull):
         assert_exception(dense_op, MXNetError, key, dense_vals)

     row_ids = [mx.nd.array([1])] * n_keys
     assert_exception(kv.row_sparse_pull, MXNetError, key, sparse_vals,
                      row_ids=row_ids)
 def check_invalid_key_types_single(kv, key):
     """Check that init/push/pull/row_sparse_pull all raise for a single key.

     ``kv.init``, ``kv.push`` and ``kv.pull`` are called through
     ``assert_exception`` with ``key`` and one dense value, and
     ``kv.row_sparse_pull`` with its row-sparse conversion; every call
     must raise ``MXNetError``.

     ``shape`` is read from the enclosing scope.
     """
     dense_val = mx.nd.ones(shape) * 2
     sparse_val = dense_val.tostype('row_sparse')

     for dense_op in (kv.init, kv.push, kv.pull):
         assert_exception(dense_op, MXNetError, key, dense_val)

     assert_exception(kv.row_sparse_pull, MXNetError, key, sparse_val,
                      row_ids=mx.nd.array([1]))
def test_np_empty():
    """Check ``np.empty`` over a grid of dtypes, shapes, orders and contexts.

    Dtypes are given as type objects, as strings, and as ``None`` (expected
    to yield ``np.float32``). For order ``'C'`` the returned array must
    have the expected dtype, the expected shape (an int ``n`` is expected
    to yield ``(n,)``) and a context equal to ``npx.current_context()``.
    For orders ``'F'`` and ``'A'`` the call must raise
    ``NotImplementedError``.
    """
    # (input dtype, expected output dtype)
    dtype_pairs = [
        (np.int8, np.int8),
        (np.int32, np.int32),
        (np.float16, np.float16),
        (np.float32, np.float32),
        (np.float64, np.float64),
        (np.bool_, np.bool_),
        (np.bool, np.bool_),
        ('int8', np.int8),
        ('int32', np.int32),
        ('float16', np.float16),
        ('float32', np.float32),
        ('float64', np.float64),
        ('bool', np.bool_),
        (None, np.float32),
    ]
    orders = ['C', 'F', 'A']
    shapes = [
        (),
        0,
        (0, ),
        (0, 0),
        2,
        (2, ),
        (3, 0),
        (4, 5),
        (1, 1, 1, 1),
    ]
    ctxes = [npx.current_context(), None]
    for dtype, expected_dtype in dtype_pairs:
        for shape in shapes:
            # Normalise first, then compare. Writing
            # ``assert a == b if cond else c`` parses as
            # ``assert ((a == b) if cond else c)``, which for int shapes
            # asserted a non-empty tuple and therefore could never fail.
            expected_shape = shape if isinstance(shape, tuple) else (shape, )
            for order in orders:
                for ctx in ctxes:
                    if order == 'C':
                        ret = np.empty(shape, dtype, order, ctx)
                        assert ret.dtype == expected_dtype
                        assert ret.shape == expected_shape
                        assert ret.ctx == npx.current_context()
                    else:
                        assert_exception(np.empty, NotImplementedError, shape,
                                         dtype, order, ctx)
Exemple #13
0
    def test_boolean_catch_exception():
        """Check that ill-shaped boolean indices are rejected by ``__getitem__``.

        Against a ``(5, 4, 3)`` array of ones, three boolean index arrays
        (length 1, length 6, and shape ``(4, 4)``) must each raise
        ``IndexError``; the tuple ``(slice(None), <the (4, 4) mask>)`` must
        raise ``TypeError``.
        """
        # adapted from numpy's test_indexing.py
        arr = np.ones((5, 4, 3))
        getitem = arr.__getitem__

        short_mask = np.array([True], dtype=np.bool_)
        long_mask = np.array([False] * 6, dtype=np.bool_)
        square_mask = np.zeros((4, 4), dtype=bool)

        for bad_mask in (short_mask, long_mask, square_mask):
            assert_exception(getitem, IndexError, bad_mask)

        assert_exception(getitem, TypeError, (slice(None), square_mask))
Exemple #14
0
def test_np_ndarray_copy():
    """Check ``ndarray.copy`` with and without ``order='F'``.

    ``copy(order='F')`` must raise ``NotImplementedError``. A plain
    ``copy()`` converted with ``asnumpy()`` must be ``same`` as a copy of
    the source's ``asnumpy()`` result.
    """
    source = np.array([2, 3, 4, 5], dtype=_np.int32)

    assert_exception(source.copy, NotImplementedError, order='F')

    copied = source.copy()
    reference = source.asnumpy().copy()
    assert same(copied.asnumpy(), reference)
 def check_invalid_rsp_pull_list(kv, key):
     """Check that a list-keyed ``row_sparse_pull`` into dense arrays raises.

     ``out`` is one dense array (``mx.nd.ones(shape) * 2``) repeated
     ``len(key)`` times and ``row_ids`` is one ``[1]`` array repeated
     likewise; ``assert_exception`` requires ``MXNetError``.

     ``shape`` is read from the enclosing scope.
     """
     repeat = len(key)
     pull_kwargs = {
         'out': [mx.nd.ones(shape) * 2] * repeat,
         'row_ids': [mx.nd.array([1])] * repeat,
     }
     assert_exception(kv.row_sparse_pull, MXNetError, key, **pull_kwargs)
 def check_invalid_rsp_pull_single(kv, key):
     """Check that a single-key ``row_sparse_pull`` into a dense array raises.

     ``out`` is a dense array (``mx.nd.ones(shape) * 2``) and ``row_ids``
     is a ``[1]`` array; ``assert_exception`` requires ``MXNetError``.

     ``shape`` is read from the enclosing scope.
     """
     pull_kwargs = {
         'out': mx.nd.ones(shape) * 2,
         'row_ids': mx.nd.array([1]),
     }
     assert_exception(kv.row_sparse_pull, MXNetError, key, **pull_kwargs)
Exemple #17
0
 def check_invalid_pull():
     """Check that pulling the key ``'invalid_key'`` raises ``mx.MXNetError``.

     ``keys_invalid[0]`` is initialised with a 2x2 row-sparse array of
     ones, then ``kv.pull`` is called through ``assert_exception`` with a
     row-sparse ``out`` and ``ignore_sparse=False``. ``kv``,
     ``keys_invalid`` and ``my_rank`` come from the enclosing scope.
     """
     def sparse_ones():
         # Each call builds a new array; init and pull do not share one.
         return mx.nd.ones((2, 2)).tostype('row_sparse')

     kv.init(keys_invalid[0], sparse_ones())
     out = sparse_ones()
     assert_exception(kv.pull,
                      mx.MXNetError,
                      'invalid_key',
                      out=out,
                      ignore_sparse=False)
     print('worker ' + str(my_rank) + ' passed check_invalid_pull')