Esempio n. 1
0
def test_zero_prop2():
    """Bind ``batch_take`` with and without ``stop_gradient`` on top of it.

    The graph ending in ``stop_gradient`` must bind and run forward/backward.
    Binding the ``batch_take`` output directly is expected to raise.

    Raises:
        AssertionError: if binding ``y`` directly does not raise.
    """
    x = mx.sym.Variable('x')
    idx = mx.sym.Variable('idx')
    y = mx.sym.batch_take(x, idx)
    z = mx.sym.stop_gradient(y)
    exe = z.simple_bind(ctx=mx.cpu(),
                        x=(10, 10),
                        idx=(10, ),
                        type_dict={
                            'x': np.float32,
                            'idx': np.int32
                        })
    exe.forward()
    exe.backward()

    # The following bind() should throw an exception. We discard the expected stderr
    # output for this operation only in order to keep the test logs clean.
    with discard_stderr():
        try:
            y.simple_bind(ctx=mx.cpu(),
                          x=(10, 10),
                          idx=(10, ),
                          type_dict={
                              'x': np.float32,
                              'idx': np.int32
                          })
        except Exception:
            # Expected failure. `Exception` (rather than a bare `except:`) keeps
            # KeyboardInterrupt/SystemExit from being reported as a passing test.
            return

    # Explicit raise instead of `assert False` so the check survives `python -O`.
    raise AssertionError(
        "simple_bind on the batch_take output was expected to raise, but it succeeded")
Esempio n. 2
0
def test_infershape_happens_for_all_ops_in_graph():
    """Check that binding a graph with an incompatible elementwise add raises.

    ``s`` is the transpose of ``V`` while ``s2`` is ``x + V``; only the shape
    of ``x`` is supplied at bind time, and adding ``s`` to ``s2`` is expected
    to be rejected.

    Raises:
        AssertionError: if ``simple_bind`` succeeds.
    """
    v = mx.sym.Variable('V')
    s = mx.sym.transpose(v)
    x = mx.sym.Variable('x')
    s2 = x + v
    s3 = s + s2
    with discard_stderr():
        try:
            # This should throw an exception as you cannot add arrays
            # with shapes [2,3] and [3,2]
            s3.simple_bind(ctx=mx.cpu(), x=(2, 3), grad_req='null')
        except Exception:
            # Expected failure. `Exception` (rather than a bare `except:`) keeps
            # KeyboardInterrupt/SystemExit from being reported as a passing test.
            return

    # Explicit raise instead of `assert False` so the check survives `python -O`.
    raise AssertionError(
        "simple_bind was expected to raise on mismatched shapes, but it succeeded")
Esempio n. 3
0
def test_zero_prop2():
    """Bind ``batch_take`` with and without ``stop_gradient`` on top of it.

    The graph ending in ``stop_gradient`` must bind and run forward/backward.
    Binding the ``batch_take`` output directly is expected to raise.

    Raises:
        AssertionError: if binding ``y`` directly does not raise.
    """
    x = mx.sym.Variable('x')
    idx = mx.sym.Variable('idx')
    y = mx.sym.batch_take(x, idx)
    z = mx.sym.stop_gradient(y)
    exe = z.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                        type_dict={'x': np.float32, 'idx': np.int32})
    exe.forward()
    exe.backward()

    # The following bind() should throw an exception. We discard the expected stderr
    # output for this operation only in order to keep the test logs clean.
    with discard_stderr():
        try:
            y.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                          type_dict={'x': np.float32, 'idx': np.int32})
        except Exception:
            # Expected failure. `Exception` (rather than a bare `except:`) keeps
            # KeyboardInterrupt/SystemExit from being reported as a passing test.
            return

    # Explicit raise instead of `assert False` so the check survives `python -O`.
    raise AssertionError(
        "simple_bind on the batch_take output was expected to raise, but it succeeded")
def test_invalid_pull():
    """Exercise kvstore pulls with mismatched storage types and key types.

    For both an int-keyed and a str-keyed store this checks that:
      * ``pull`` into row_sparse outputs leaves those outputs at their
        initial value of 2,
      * ``row_sparse_pull`` into dense outputs raises ``MXNetError``,
      * init/push/pull/row_sparse_pull with the other store's key type
        raise ``MXNetError``.
    """

    def _dense_twos():
        # Dense array of the module-level `shape`, filled with 2.
        return mx.nd.ones(shape) * 2

    def _one_row_id():
        return mx.nd.array([1])

    def ignored_pull_one(store, k):
        sparse_out = _dense_twos().tostype('row_sparse')
        store.pull(k, out=sparse_out)
        check_diff_to_scalar(sparse_out, 2)

    def ignored_pull_many(store, ks):
        # The same dense array object is repeated once per key.
        dense_vals = [_dense_twos()] * len(ks)
        sparse_outs = [d.tostype('row_sparse') for d in dense_vals]
        store.pull(ks, out=sparse_outs)
        for out in sparse_outs:
            check_diff_to_scalar(out, 2)

    def invalid_rsp_pull_one(store, k):
        dense_out = _dense_twos()
        assertRaises(MXNetError, store.row_sparse_pull, k,
                     out=dense_out, row_ids=_one_row_id())

    def invalid_rsp_pull_many(store, ks):
        dense_outs = [_dense_twos()] * len(ks)
        assertRaises(MXNetError, store.row_sparse_pull, ks,
                     out=dense_outs, row_ids=[_one_row_id()] * len(ks))

    def invalid_key_types_one(store, k):
        dense = _dense_twos()
        sparse = dense.tostype('row_sparse')
        for method in (store.init, store.push, store.pull):
            assertRaises(MXNetError, method, k, dense)
        assertRaises(MXNetError, store.row_sparse_pull, k, sparse,
                     row_ids=_one_row_id())

    def invalid_key_types_many(store, ks):
        dense_vals = [_dense_twos()] * len(ks)
        sparse_vals = [d.tostype('row_sparse') for d in dense_vals]
        for method in (store.init, store.push, store.pull):
            assertRaises(MXNetError, method, ks, dense_vals)
        assertRaises(MXNetError, store.row_sparse_pull, ks, sparse_vals,
                     row_ids=[_one_row_id()] * len(ks))

    int_kv = init_kv()
    str_kv = init_kv_with_str()

    stores = (int_kv, str_kv)
    scalar_keys = (3, 'a')
    key_lists = (keys, str_keys)

    # Keep this test from adding stack backtraces to the log file.
    with discard_stderr():
        for pos, store in enumerate(stores):
            other = 1 - pos  # index of the store with the opposite key type
            # pull with rsp outputs should be ignored with no values updated
            ignored_pull_one(store, scalar_keys[pos])
            ignored_pull_many(store, key_lists[pos])
            # row_sparse_pull should be aborted when vals.stype != row_sparse
            invalid_rsp_pull_one(store, scalar_keys[pos])
            invalid_rsp_pull_many(store, key_lists[pos])
            # kvstore should be restricted to only accept either int or str keys
            invalid_key_types_one(store, scalar_keys[other])
            invalid_key_types_many(store, key_lists[other])