def test_zero_prop2():
    """Check that ``stop_gradient`` lets a ``batch_take`` graph bind.

    ``stop_gradient(batch_take(x, idx))`` is bound with ``x`` as float32
    (10, 10) and ``idx`` as int32 (10,), then run forward and backward.
    Binding the bare ``batch_take`` output ``y`` with the same arguments is
    expected to fail, so that second ``simple_bind`` must raise.

    NOTE(review): another ``def test_zero_prop2`` with the same body appears
    later in this source. If both live in the same module, the later
    definition replaces this one and this copy never runs. Consider removing
    one of them.
    """
    x = mx.sym.Variable('x')
    idx = mx.sym.Variable('idx')
    y = mx.sym.batch_take(x, idx)
    z = mx.sym.stop_gradient(y)
    exe = z.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                        type_dict={'x': np.float32, 'idx': np.int32})
    exe.forward()
    exe.backward()

    # The following bind() should throw an exception. We discard the expected
    # stderr output for this operation only in order to keep the test logs clean.
    # NOTE(review): presumably it fails because a gradient would be requested
    # for the int32 ``idx`` input without stop_gradient -- confirm.
    with discard_stderr():
        try:
            y.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                          type_dict={'x': np.float32, 'idx': np.int32})
        except Exception:
            # ``Exception`` rather than a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit are not mistaken for the expected
            # bind failure. The concrete error type is deliberately not pinned.
            return
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make this test silently pass.
    raise AssertionError(
        "simple_bind of batch_take without stop_gradient was expected to fail")
def test_infershape_happens_for_all_ops_in_graph():
    """Check that shape inference rejects an inconsistent graph.

    The graph is ``transpose(V) + (x + V)``. Only ``x`` is given a shape,
    (2, 3), so ``V``'s shape must be inferred from the graph. The final add
    then combines arrays of shapes [2,3] and [3,2], which is invalid, so
    ``simple_bind`` must raise.
    """
    v = mx.sym.Variable('V')
    s = mx.sym.transpose(v)
    x = mx.sym.Variable('x')
    s2 = x + v
    s3 = s + s2
    with discard_stderr():
        try:
            # This should throw an exception as you cannot add arrays
            # with shapes [2,3] and [3,2]
            s3.simple_bind(ctx=mx.cpu(), x=(2, 3), grad_req='null')
        except Exception:
            # ``Exception`` rather than a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit are not treated as the expected
            # shape-inference failure.
            return
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make this test silently pass.
    raise AssertionError(
        "simple_bind was expected to fail on mismatched shapes [2,3] + [3,2]")
def test_zero_prop2():
    """Check that ``stop_gradient`` lets a ``batch_take`` graph bind.

    ``stop_gradient(batch_take(x, idx))`` is bound with ``x`` as float32
    (10, 10) and ``idx`` as int32 (10,), then run forward and backward.
    Binding the bare ``batch_take`` output ``y`` with the same arguments is
    expected to fail, so that second ``simple_bind`` must raise.
    """
    x = mx.sym.Variable('x')
    idx = mx.sym.Variable('idx')
    y = mx.sym.batch_take(x, idx)
    z = mx.sym.stop_gradient(y)
    exe = z.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                        type_dict={'x': np.float32, 'idx': np.int32})
    exe.forward()
    exe.backward()

    # The following bind() should throw an exception. We discard the expected
    # stderr output for this operation only in order to keep the test logs clean.
    # NOTE(review): presumably it fails because a gradient would be requested
    # for the int32 ``idx`` input without stop_gradient -- confirm.
    with discard_stderr():
        try:
            y.simple_bind(ctx=mx.cpu(), x=(10, 10), idx=(10,),
                          type_dict={'x': np.float32, 'idx': np.int32})
        except Exception:
            # ``Exception`` rather than a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit are not mistaken for the expected
            # bind failure. The concrete error type is deliberately not pinned.
            return
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make this test silently pass.
    raise AssertionError(
        "simple_bind of batch_take without stop_gradient was expected to fail")
def test_invalid_pull():
    """Exercise misuse of kvstore pulls and key types.

    Runs against an int-keyed store (from ``init_kv``) and a str-keyed store
    (from ``init_kv_with_str``), in both single-key and key-list form:

    * ``pull`` into row_sparse outputs is ignored. The outputs keep their
      initial value of 2, checked with ``check_diff_to_scalar``.
    * ``row_sparse_pull`` into dense outputs raises ``MXNetError``.
    * ``init`` / ``push`` / ``pull`` / ``row_sparse_pull`` using the *other*
      store's key type raise ``MXNetError``.
    """
    def dense_twos():
        # Dense NDArray of the module-level ``shape``, filled with 2.
        return mx.nd.ones(shape) * 2

    def expect_pull_ignored_single(kv, key):
        out = dense_twos().tostype('row_sparse')
        kv.pull(key, out=out)
        check_diff_to_scalar(out, 2)

    def expect_pull_ignored_list(kv, key):
        # One dense array repeated len(key) times; each tostype() call yields
        # a distinct row_sparse output.
        dense = [dense_twos()] * len(key)
        outs = [d.tostype('row_sparse') for d in dense]
        kv.pull(key, out=outs)
        for out in outs:
            check_diff_to_scalar(out, 2)

    def expect_rsp_pull_rejected_single(kv, key):
        dense = dense_twos()
        assertRaises(MXNetError, kv.row_sparse_pull, key, out=dense,
                     row_ids=mx.nd.array([1]))

    def expect_rsp_pull_rejected_list(kv, key):
        dense = [dense_twos()] * len(key)
        assertRaises(MXNetError, kv.row_sparse_pull, key, out=dense,
                     row_ids=[mx.nd.array([1])] * len(key))

    def expect_key_type_rejected_single(kv, key):
        dense = dense_twos()
        sparse = dense.tostype('row_sparse')
        assertRaises(MXNetError, kv.init, key, dense)
        assertRaises(MXNetError, kv.push, key, dense)
        assertRaises(MXNetError, kv.pull, key, dense)
        assertRaises(MXNetError, kv.row_sparse_pull, key, sparse,
                     row_ids=mx.nd.array([1]))

    def expect_key_type_rejected_list(kv, key):
        dense = [dense_twos()] * len(key)
        sparse = [d.tostype('row_sparse') for d in dense]
        assertRaises(MXNetError, kv.init, key, dense)
        assertRaises(MXNetError, kv.push, key, dense)
        assertRaises(MXNetError, kv.pull, key, dense)
        assertRaises(MXNetError, kv.row_sparse_pull, key, sparse,
                     row_ids=[mx.nd.array([1])] * len(key))

    int_kv = init_kv()
    str_kv = init_kv_with_str()
    # Each case: (store, own single key, own key list,
    #             foreign single key, foreign key list)
    cases = [
        (int_kv, 3, keys, 'a', str_keys),
        (str_kv, 'a', str_keys, 3, keys),
    ]

    # Keep this test from adding stack backtraces to the log file.
    with discard_stderr():
        for kv, own_key, own_keys, foreign_key, foreign_keys in cases:
            # pull with rsp outputs should be ignored with no values updated
            expect_pull_ignored_single(kv, own_key)
            expect_pull_ignored_list(kv, own_keys)
            # row_sparse_pull should be aborted when vals.stype != row_sparse
            expect_rsp_pull_rejected_single(kv, own_key)
            expect_rsp_pull_rejected_list(kv, own_keys)
            # kvstore should be restricted to only accept either int or str keys
            expect_key_type_rejected_single(kv, foreign_key)
            expect_key_type_rejected_list(kv, foreign_keys)