def test_data_dir():
    """Check data_dir() default resolution and the MXNET_HOME override."""
    saved_dir = data_dir()
    # With MXNET_HOME unset, the default location is platform-dependent.
    with environment('MXNET_HOME', None):
        expected = (op.join(os.environ.get('APPDATA'), 'mxnet')
                    if platform.system() == 'Windows'
                    else op.join(op.expanduser('~'), '.mxnet'))
        assert_equal(data_dir(), expected)
    # An explicit MXNET_HOME setting must win over the default.
    with environment('MXNET_HOME', '/tmp/mxnet_data'):
        assert_equal(data_dir(), '/tmp/mxnet_data')
    # The pre-test MXNET_HOME value must be restored afterwards.
    assert_equal(data_dir(), saved_dir)
def test_environment():
    """environment() must update both os.environ and the backend env."""
    var_a = 'MXNET_TEST_ENV_VAR_1'
    var_b = 'MXNET_TEST_ENV_VAR_2'

    # Scalar form: a single (name, value) pair.
    with environment(var_a, '42'):
        assert_equal(os.environ.get(var_a), '42')
        assert_equal(getenv(var_a), '42')

    # Dict form: several variables set at once.
    expected = {var_a: '1', var_b: '2'}
    with environment(expected):
        for var_name, var_value in expected.items():
            assert_equal(os.environ.get(var_name), var_value)
            assert_equal(getenv(var_name), var_value)
def test_engine_openmp_after_fork():
    """
    Test that the number of max threads in the child is 1. After forking we
    should not use a bigger OMP thread pool.

    With GOMP the child always has the same number when calling
    omp_get_max_threads, with LLVM OMP the child respects the number of max
    threads set in the parent.
    """
    with environment('OMP_NUM_THREADS', '42'):
        # Pipe used so the child blocks until the parent has queried OMP.
        r, w = os.pipe()
        pid = os.fork()
        if pid:  # parent process (pid is the child's pid)
            os.close(r)
            wfd = os.fdopen(w, 'w')
            wfd.write('a')
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Parent omp max threads: {}".format(omp_max_threads))
            try:
                wfd.close()
            except:
                # NOTE(review): bare except deliberately ignores close errors,
                # but it also hides any unexpected exception type.
                pass
            try:
                (cpid, status) = os.waitpid(pid, 0)
                assert cpid == pid
                # High byte of the waitpid status is the exit code
                # (equivalent to os.WEXITSTATUS on POSIX).
                exit_status = status >> 8
                assert exit_status == 0
            except:
                # NOTE(review): this bare except also swallows the asserts
                # above, so a failing child cannot fail this test — confirm
                # this is intentional.
                pass
        else:  # child process
            os.close(w)
            rfd = os.fdopen(r, 'r')
            # Block until the parent signals it has read its thread count.
            rfd.read(1)
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Child omp max threads: {}".format(omp_max_threads))
            assert omp_max_threads == 1
def test_device_pushpull():
    """Dense push/pull across GPUs for plain and tree kvstores under every
    combination of the USETREE / TREE_ARRAY_BOUND env vars."""
    def check_dense_pushpull(kv_type):
        # `shapes`, `keys` and `gpus` come from the enclosing module scope.
        for shape, key in zip(shapes, keys):
            for n_gpus in gpus:
                store = mx.kv.create(kv_type)
                init_val = mx.nd.ones(shape, mx.gpu(0))
                cur_key = str(key * max(gpus) + n_gpus)
                store.init(cur_key, init_val)
                pushed = [mx.nd.ones(shape, mx.gpu(g)) for g in range(n_gpus)]
                pulled = [mx.nd.zeros(shape, mx.gpu(g)) for g in range(n_gpus)]
                store.push(cur_key, pushed)
                store.pull(cur_key, pulled)
                # Summing n_gpus arrays of ones must yield n_gpus everywhere.
                for g in range(n_gpus):
                    assert np.sum(np.abs((pulled[g] - n_gpus).asnumpy())) == 0

    for bound_setting in [None, '1']:
        for usetree_setting in [None, '1']:
            with environment({'MXNET_KVSTORE_USETREE': usetree_setting,
                              'MXNET_KVSTORE_TREE_ARRAY_BOUND': bound_setting}):
                check_dense_pushpull('local')
                check_dense_pushpull('device')
def test_bind():
    """Exercise check_bind_with_uniform on elementwise binary ops, with
    executor bulking both disabled and enabled."""
    # (forward fn, gradient fn, optional symbolic op) — order matters, it
    # mirrors the original sequence of checks.
    op_cases = [
        (lambda x, y: x + y, lambda g, x, y: (g, g), None),
        (lambda x, y: x - y, lambda g, x, y: (g, -g), None),
        (lambda x, y: x * y, lambda g, x, y: (y * g, x * g), None),
        (lambda x, y: x / y, lambda g, x, y: (g / y, -x * g / (y**2)), None),
        (lambda x, y: np.maximum(x, y),
         lambda g, x, y: (g * (x >= y), g * (y > x)), mx.symbol.maximum),
        (lambda x, y: np.minimum(x, y),
         lambda g, x, y: (g * (x <= y), g * (y < x)), mx.symbol.minimum),
    ]
    for enable_bulking in ['0', '1']:
        with environment({'MXNET_EXEC_BULK_EXEC_INFERENCE': enable_bulking,
                          'MXNET_EXEC_BULK_EXEC_TRAIN': enable_bulking}):
            nrepeat = 10
            maxdim = 4
            for _ in range(nrepeat):
                for dim in range(1, maxdim):
                    for fwd, bwd, sym_op in op_cases:
                        if sym_op is None:
                            check_bind_with_uniform(fwd, bwd, dim)
                        else:
                            check_bind_with_uniform(fwd, bwd, dim, sf=sym_op)
def test_one_var(name, value, raise_exception=False):
    """Set one env var via environment(), verify the frontend and backend
    both see it, optionally raise inside the context, and always re-check
    the background values afterwards."""
    try:
        with environment(name, value):
            # Python's os.environ and the backend getenv must agree.
            for observed in (os.environ.get(name), getenv(name)):
                assert_equal(observed, value)
            if raise_exception:
                raise OnPurposeError
    except OnPurposeError:
        # Expected when raise_exception is set — the context manager must
        # still restore the environment (checked in `finally`).
        pass
    finally:
        check_background_values()
def test_engine_import():
    """Reload mxnet under each candidate MXNET_ENGINE_TYPE value."""
    import mxnet

    # Temporarily add an illegal entry (that is not caught) to show how the
    # test needs improving
    engine_types = [None, 'NaiveEngine', 'ThreadedEngine',
                    'ThreadedEnginePerDevice', 'BogusEngine']

    # Renamed loop variable: the original used `type`, shadowing the builtin.
    for engine_type in engine_types:
        with environment('MXNET_ENGINE_TYPE', engine_type):
            reload(mxnet)
def test_gemms_true_fp16():
    """Dense fp16 results with true-fp16 accumulation enabled must be close
    to the reference results with it disabled."""
    ctx = mx.gpu(0)
    # Renamed from `input`, which shadowed the builtin of the same name.
    input_data = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
    weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)
    net = nn.Dense(128, in_units=512, use_bias=False)
    net.cast('float16')
    net.initialize(ctx=ctx)
    net.weight.set_data(weights)

    # Reference run: true-fp16 accumulation disabled.
    with environment('MXNET_FC_TRUE_FP16', '0'):
        ref_results = net(input_data)

    # Run under test: true-fp16 accumulation enabled.
    with environment('MXNET_FC_TRUE_FP16', '1'):
        results_trueFP16 = net(input_data)

    # Loose tolerances: fp16 accumulation loses precision by design.
    atol = 1e-2
    rtol = 1e-2
    assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
                        atol=atol, rtol=rtol)
def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs):
    # Compare executions of `sym` with common-subexpression elimination (CSE)
    # disabled and enabled: outputs and gradients must match, and the
    # optimized graph must shrink by exactly `expected_savings` nodes.
    # kwargs maps each input name to a sample NDArray (supplies shapes/data).
    inputs = sym.list_inputs()
    shapes = {inp : kwargs[inp].shape for inp in inputs}
    # Per-dtype tolerances for the forward/backward comparisons.
    rtol = {'float16' : 1e-2,
            'float32' : 1.5e-6,
            'float64' : 1.5e-6,
            }
    atol = {'float16' : 1e-3,
            'float32' : 1e-7,
            'float64' : 1e-7,
            }
    for dtype in ['float16', 'float32', 'float64']:
        data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
        for grad_req in ['write', 'add']:
            type_dict = {inp : dtype for inp in inputs}
            # Baseline executor bound with CSE disabled.
            with environment({'MXNET_ELIMINATE_COMMON_EXPR': '0'}):
                orig_exec = sym._simple_bind(ctx=mx.cpu(0), grad_req=grad_req,
                                             type_dict=type_dict, **shapes)
            # Executor bound with CSE enabled.
            with environment({'MXNET_ELIMINATE_COMMON_EXPR': '1'}):
                cse_exec = sym._simple_bind(ctx=mx.cpu(0), grad_req=grad_req,
                                            type_dict=type_dict, **shapes)
            fwd_orig = orig_exec.forward(is_train=True, **data)
            out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
            orig_exec.backward(out_grads=out_grads)
            fwd_cse = cse_exec.forward(is_train=True, **data)
            cse_exec.backward(out_grads=out_grads)
            if check_data:
                # Forward outputs must agree between the two executors.
                for orig, cse in zip(fwd_orig, fwd_cse):
                    np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(),
                                               rtol=rtol[dtype], atol=atol[dtype])
                # Gradients must agree too; both must be present or absent.
                for orig, cse in zip(orig_exec.grad_arrays, cse_exec.grad_arrays):
                    if orig is None and cse is None:
                        continue
                    assert orig is not None
                    assert cse is not None
                    np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(),
                                               rtol=rtol[dtype], atol=atol[dtype])
            orig_sym_internals = orig_exec.get_optimized_symbol().get_internals()
            cse_sym_internals = cse_exec.get_optimized_symbol().get_internals()
            # test that the graph has been simplified as expected
            assert (len(cse_sym_internals) + expected_savings) == len(orig_sym_internals)
def test_unary_func():
    """Autograd gradients for unary ops across dense and sparse storage."""
    def verify_on(arr):
        # exp: d/dx exp(x) = exp(x)
        autograd_assert(arr, func=lambda x: nd.exp(x),
                        grad_func=lambda x: [nd.exp(x)])
        # halving: the gradient is the constant 0.5
        autograd_assert(arr, func=lambda x: x / 2,
                        grad_func=lambda x: [nd.ones(x.shape) * 0.5])
        # square: d/dx x^2 = 2x
        autograd_assert(arr, func=lambda x: x ** 2,
                        grad_func=lambda x: [2 * x])

    base = nd.uniform(shape=(4, 5))
    # Silence storage-fallback warnings while converting between stypes.
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype in ['default', 'row_sparse', 'csr']:
            verify_on(base.tostype(stype))
def test_subgraph_exe4(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning
    in bind and compare results of the partitioned sym and the original sym."""
    def get_executor(sym, subgraph_backend=None, op_names=None,
                     original_exec=None):
        arg_shapes, _, aux_shapes = sym.infer_shape()
        if subgraph_backend is None:
            # Baseline run: fresh random arg/aux arrays.
            arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
            aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
        else:
            # Partitioned run reuses the baseline executor's arrays (below),
            # so both executors see identical inputs.
            arg_array = None
            aux_array = None
        exe = sym._bind(ctx=mx.current_context(),
                        args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
                        aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
                        grad_req='null')
        exe.forward()
        return exe

    # The fixture supplies (symbol, ..., ...); only the symbol is needed.
    sym, _, _ = sym
    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        # Restrict partitioning to the given ops for the duration of the bind.
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
                                                     mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names,
                                        original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    # Partitioning must not change any output value.
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(),
                            onp.zeros(shape=(1, )))
def test_binary_func():
    """Autograd gradients for binary ops across all storage-type pairs."""
    def verify_pair(a, b):
        # addition: gradient is ones for both operands
        autograd_assert(a, b, func=lambda x, y: x + y,
                        grad_func=lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)])
        # multiplication: gradients are the swapped operands
        autograd_assert(a, b, func=lambda x, y: x * y,
                        grad_func=lambda x, y: [y, x])
        # composition x + x*y
        autograd_assert(a, b, func=lambda x, y: x + x * y,
                        grad_func=lambda x, y: [nd.ones(x.shape) + y, x])

    base_x = nd.uniform(shape=(4, 5))
    base_y = nd.uniform(shape=(4, 5))
    storage_types = ['default', 'row_sparse', 'csr']
    # Silence storage-fallback warnings while converting between stypes.
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype_x in storage_types:
            for stype_y in storage_types:
                verify_pair(base_x.tostype(stype_x), base_y.tostype(stype_y))
def run_in_spawned_process(func, env, *args):
    """
    Helper function to run a test in its own process.

    Avoids issues with Singleton- or otherwise-cached environment variable
    lookups in the backend. Adds a seed as first arg to propagate determinism.

    Parameters
    ----------

    func : function to run in a spawned process.
    env : dict of additional environment values to set temporarily in the
          environment before exec.
    args : args to pass to the function.

    Returns
    -------
    Whether the python version supports running the function as a spawned
    process.

    This routine calculates a random seed and passes it into the test as a
    first argument.  If the test uses random values, it should include an
    outer 'with random_seed(seed):'.  If the test needs to return values to
    the caller, consider use of shared variable arguments.
    """
    try:
        mpctx = mp.get_context('spawn')
    except ValueError:
        # get_context raises ValueError when the 'spawn' start method is not
        # available; the original bare `except:` also hid unrelated errors.
        print(
            'SKIP: python%s.%s lacks the required process fork-exec support ... '
            % sys.version_info[0:2],
            file=sys.stderr, end='')
        return False
    else:
        seed = np.random.randint(0, 1024 * 1024 * 1024)
        # Set the extra env vars only around the child's creation.
        with environment(env):
            # Prepend seed as first arg
            p = mpctx.Process(target=func, args=(seed, ) + args)
            p.start()
            p.join()
            assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (
                p.exitcode, func.__name__)
        return True
def test_subgraph_exe2(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning
    in _simple_bind and compare results of the partitioned sym and the original sym."""
    def build_exec(sym, subgraph_backend=None, op_names=None, original_exec=None):
        executor = sym._simple_bind(ctx=mx.current_context(), grad_req='null')
        for input_name in sym.list_inputs():
            if input_name in executor.arg_dict:
                # First (baseline) build gets random data; the partitioned
                # build copies the baseline's inputs so results are comparable.
                if original_exec is None:
                    executor.arg_dict[input_name][:] = mx.nd.random.uniform(
                        shape=executor.arg_dict[input_name].shape)
                else:
                    executor.arg_dict[input_name][:] = original_exec.arg_dict[input_name]
            else:
                assert input_name in executor.aux_dict
                if original_exec is None:
                    executor.aux_dict[input_name][:] = mx.nd.random.uniform(
                        shape=executor.aux_dict[input_name].shape)
                else:
                    executor.aux_dict[input_name][:] = original_exec.aux_dict[input_name]
        executor.forward()
        return executor

    # The fixture supplies (symbol, ..., ...); only the symbol is needed.
    sym, _, _ = sym
    original_exec = build_exec(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
                                                     mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = build_exec(sym, subgraph_backend, op_names, original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    # Partitioning must not change any output value.
    for reference, partitioned in zip(outputs1, outputs2):
        assert_almost_equal((reference - partitioned).abs().sum().asnumpy(),
                            np.zeros(shape=(1, )))
def test_new(*args, **kwargs):
    # Wrapper (presumably produced by a decorator in the enclosing scope):
    # runs the wrapped test under the captured environment settings.
    # NOTE(review): `args_` and `orig_test` come from the enclosing closure,
    # not from this function's own *args — confirm against the decorator.
    with environment(*args_):
        orig_test(*args, **kwargs)
def test_rsp_push_pull():
    """Row-sparse push/pull across kvstore types, with and without tree
    reduction (MXNET_KVSTORE_USETREE)."""
    def check_rsp_push_pull(kv_type, sparse_pull, is_push_cpu=True):
        kv = init_kv_with_str('row_sparse', kv_type)
        kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
        push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
        # Two pushed one-arrays aggregate to 2 in every row.
        kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse')
                      for context in push_ctxs])

        def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False,
                           use_slice=False):
            count = len(ctxs)
            num_rows = shape[0]
            row_ids = []
            all_row_ids = np.arange(num_rows)
            vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i],
                                       stype='row_sparse') for i in range(count)]
            if is_same_rowid:
                # Every device pulls the same random set of rows.
                row_id = np.random.randint(num_rows, size=num_rows)
                row_ids = [mx.nd.array(row_id)] * count
            elif use_slice:
                # Row ids are slices of one shared NDArray.
                total_row_ids = mx.nd.array(
                    np.random.randint(num_rows, size=count * num_rows))
                row_ids = [total_row_ids[i * num_rows:(i + 1) * num_rows]
                           for i in range(count)]
            else:
                # Independent random row ids per device.
                for i in range(count):
                    row_id = np.random.randint(num_rows, size=num_rows)
                    row_ids.append(mx.nd.array(row_id))
            row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
            vals_to_pull = vals[0] if len(vals) == 1 else vals

            kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
            for val, row_id in zip(vals, row_ids):
                retained = val.asnumpy()
                excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
                for row in range(num_rows):
                    # Rows that were not requested stay zero; requested rows
                    # hold the aggregated value 2.
                    expected_val = np.zeros_like(retained[row])
                    expected_val += 0 if row in excluded_row_ids else 2
                    assert_almost_equal(retained[row], expected_val)

            if sparse_pull is True:
                # A full (dense) pull must return 2 everywhere.
                kv.pull('e', out=vals_to_pull, ignore_sparse=False)
                for val in vals:
                    retained = val.asnumpy()
                    expected_val = np.zeros_like(retained)
                    expected_val[:] = 2
                    assert_almost_equal(retained, expected_val)

        check_rsp_pull(kv, [mx.gpu(0)], sparse_pull)
        check_rsp_pull(kv, [mx.cpu(0)], sparse_pull)
        check_rsp_pull(kv, [mx.gpu(i // 2) for i in range(4)], sparse_pull)
        check_rsp_pull(kv, [mx.gpu(i // 2) for i in range(4)], sparse_pull,
                       is_same_rowid=True)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull,
                       is_same_rowid=True)
        check_rsp_pull(kv, [mx.gpu(i // 2) for i in range(4)], sparse_pull,
                       use_slice=True)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull,
                       use_slice=True)

    envs = [None, '1']
    key = 'MXNET_KVSTORE_USETREE'
    for val in envs:
        with environment(key, val):
            # Fixed: the original compared with `val is '1'`, an identity
            # check against a str literal that relies on interning (and emits
            # SyntaxWarning on modern Python). Use equality instead.
            sparse_pull = (val != '1')
            check_rsp_push_pull('local', sparse_pull)
            check_rsp_push_pull('device', sparse_pull)
            check_rsp_push_pull('device', sparse_pull, is_push_cpu=False)
def test_cuda_graphs():
    """With CUDA graphs enabled, a hybridized (static alloc/shape) block must
    produce the same outputs, input gradients and parameter gradients as an
    identical imperative block, in both inference and training mode."""
    class GraphTester(gluon.HybridBlock):
        def __init__(self, function_to_test, **kwargs):
            super(GraphTester, self).__init__(**kwargs)
            self.f = function_to_test()

        def forward(self, *args):
            # We need to isolate the operation to be fully inside the graph
            # in order for graphs usage to be possible
            copied_args = [mx.np.copy(a) for a in args]
            outputs = self.f(*copied_args)
            if isinstance(outputs, (list, tuple)):
                return [mx.np.copy(o) for o in outputs]
            else:
                return mx.np.copy(outputs)

    class TestDesc:
        # Describes one op under test: a display name, a factory producing
        # the callable/block, and the number/rank of random inputs.
        def __init__(self, name, f, num_inputs=1, input_dim=4):
            self.name = name
            self.f = f
            self.num_inputs = num_inputs
            self.input_dim = input_dim

        def generate_inputs(self):
            # Random shape with each dim in [4, 10]; inputs track gradients.
            shape = tuple(_np.random.randint(4, 11, size=self.input_dim))
            ret = [mx.np.random.uniform(size=shape)
                   for _ in range(self.num_inputs)]
            for r in ret:
                r.attach_grad()
            return ret

    tested_ops = [
        TestDesc('add', lambda: (lambda x, y: x + y), num_inputs = 2),
        TestDesc('add_scalar', lambda: (lambda x: x + 0.5)),
        TestDesc('Conv', lambda: mx.gluon.nn.Conv2D(channels=32, kernel_size=(1,1))),
        TestDesc('ConvTranspose', lambda: mx.gluon.nn.Conv2DTranspose(channels=32, kernel_size=(1,1))),
        TestDesc('Dense', lambda: mx.gluon.nn.Dense(units=128)),
        TestDesc('Activation', lambda: mx.gluon.nn.Activation('tanh')),
        TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)),
        TestDesc('Flatten', lambda: mx.gluon.nn.Flatten()),
        TestDesc('MaxPool', lambda: mx.gluon.nn.MaxPool2D()),
        TestDesc('AvgPool', lambda: mx.gluon.nn.AvgPool2D()),
        TestDesc('GlobalMaxPool', lambda: mx.gluon.nn.GlobalMaxPool2D()),
        TestDesc('GlobalAvgPool', lambda: mx.gluon.nn.GlobalAvgPool2D()),
        TestDesc('ReflectionPad2D', lambda: mx.gluon.nn.ReflectionPad2D()),
        TestDesc('BatchNorm', lambda: mx.gluon.nn.BatchNorm()),
        TestDesc('InstanceNorm', lambda: mx.gluon.nn.InstanceNorm()),
        TestDesc('LayerNorm', lambda: mx.gluon.nn.LayerNorm()),
        TestDesc('LeakyReLU', lambda: mx.gluon.nn.LeakyReLU(0.1)),
        TestDesc('PReLU', lambda: mx.gluon.nn.PReLU()),
        TestDesc('ELU', lambda: mx.gluon.nn.ELU()),
        TestDesc('SELU', lambda: mx.gluon.nn.SELU()),
        TestDesc('Swish', lambda: mx.gluon.nn.Swish()),
    ]

    # Number of repeated forward (and forward/backward) runs per op.
    N = 10

    # Fusion is disabled so ops are captured into the CUDA graph unfused.
    with environment({'MXNET_ENABLE_CUDA_GRAPHS': '1',
                      'MXNET_USE_FUSION': '0'}):
        device = mx.gpu(0)
        for test_desc in tested_ops:
            print("Testing ", test_desc.name)
            inputs = test_desc.generate_inputs()
            inputsg = [i.copy() for i in inputs]
            for i in inputsg:
                i.attach_grad()
            seed = random.randint(0, 10000)
            net = GraphTester(test_desc.f)
            netg = GraphTester(test_desc.f)

            # initialize parameters
            net.initialize(device=device)
            netg.initialize(device=device)

            net(*inputs)

            # Copy parameters so both nets start from identical state.
            for p1, p2 in zip(net.collect_params().values(),
                              netg.collect_params().values()):
                p2.set_data(p1.data())

            # Static alloc/shape is required for CUDA graph capture.
            netg.hybridize(static_alloc=True, static_shape=True)

            print("Testing inference mode")
            with random_seed(seed):
                for _ in range(N):
                    assert_almost_equal(net(*inputs), netg(*inputsg))

            mx.npx.waitall()

            print("Testing training mode")
            for _ in range(N):
                # Same seed for both runs so stochastic ops (Dropout) match.
                with random_seed(seed):
                    with mx.autograd.record():
                        out = net(*inputs)
                        out.backward()
                with random_seed(seed):
                    with mx.autograd.record():
                        outg = netg(*inputsg)
                        outg.backward()
                assert_almost_equal(out, outg)
                for i, ig in zip(inputs, inputsg):
                    assert_almost_equal(i.grad, ig.grad)

            for p1, p2 in zip(net.collect_params().values(),
                              netg.collect_params().values()):
                assert_almost_equal(p1.data(), p2.data())
                if p1.grad_req != 'null':
                    assert_almost_equal(p1.grad(), p2.grad())

            mx.npx.waitall()
def test_np_einsum():
    """Compare mx.np.einsum (imperative and hybridized, optimized and not,
    with and without the cuTensor cache) against onp.einsum for values and
    against analytically-derived expected gradients."""
    class TestEinsum(HybridBlock):
        def __init__(self, subscripts, optimize):
            super(TestEinsum, self).__init__()
            self.subscripts = subscripts
            self.optimize = optimize

        def forward(self, *operands):
            return mx.np.einsum(self.subscripts, *operands,
                                optimize=self.optimize)

    def dbg(name, data):
        # Debug helper (unused in the main flow).
        print('type of {} = {}'.format(name, type(data)))
        print('shape of {} = {}'.format(name, data.shape))
        print('{} = {}'.format(name, data))

    # Each entry: (subscripts, operand shapes, function returning the
    # expected gradient of sum(einsum(...)) w.r.t. each operand).
    configs = [
        ('ii', [(5, 5)], lambda *args: (onp.eye(5), )),
        ('ii->i', [(5, 5)], lambda *args: (onp.eye(5), )),
        ('ij->i', [(5, 5)], lambda *args: (onp.ones((5, 5)), )),
        ('...j->...', [(5, 5)], lambda *args: (onp.ones((5, 5)), )),
        ('ji', [(2, 3)], lambda *args: (onp.ones((2, 3)), )),
        ('ij->ji', [(2, 3)], lambda *args: (onp.ones((2, 3)), )),
        ('ij, jk', [(5, 0), (0, 4)], lambda *args: (onp.empty((5, 0)),
                                                    onp.empty((0, 4)))),
        ('i, i', [(5, ), (5, )], lambda *args: (args[1], args[0])),
        ('ij, j', [(5, 5), (5, )],
         lambda *args: (onp.tile(args[1][None, :], [5, 1]),
                        args[0].sum(axis=0))),
        ('...j, j', [(5, 5), (5, )],
         lambda *args: (onp.tile(args[1][None, :], [5, 1]),
                        onp.sum(args[0], axis=0))),
        ('..., ...', [(), (2, 3)],
         lambda *args: (onp.sum(args[1], axis=None),
                        args[0] * onp.ones((2, 3)))),
        (', ij', [(), (2, 3)],
         lambda *args: (onp.sum(args[1], axis=None),
                        args[0] * onp.ones((2, 3)))),
        ('i, j', [(2, ), (5, )],
         lambda *args: (onp.sum(args[1], axis=None) * onp.ones(2),
                        onp.sum(args[0], axis=None) * onp.ones(5))),
        ('ijk, jil->kl', [(3, 4, 5), (4, 3, 2)],
         lambda *args: (onp.tile(
             onp.transpose(onp.sum(args[1], axis=-1))[:, :, None], [1, 1, 5]),
                        onp.tile(
             onp.transpose(onp.sum(args[0], axis=-1))[:, :, None], [1, 1, 2]))),
        ('ijk, jil->kl', [(33, 44, 55), (44, 33, 22)],
         lambda *args: (onp.tile(
             onp.transpose(onp.sum(args[1], axis=-1))[:, :, None], [1, 1, 55]),
                        onp.tile(
             onp.transpose(onp.sum(args[0], axis=-1))[:, :, None], [1, 1, 22]))),
        ('ki, jk->ij', [(3, 2), (4, 3)],
         lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
                        onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        ('ki, ...k->i...', [(3, 2), (4, 3)],
         lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
                        onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        ('k..., jk', [(3, 2), (4, 3)],
         lambda *args: (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
                        onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        (('ij,jk'), [(2, 5), (5, 2)],
         lambda *args: (onp.dot(onp.ones((2, 2)), args[1].T),
                        onp.dot(args[0].T, onp.ones((2, 2))))),
        (('ij,jk,kl'), [(2, 2), (2, 5), (5, 2)],
         lambda *args: (onp.dot(onp.ones((2, 2)), onp.dot(args[1], args[2]).T),
                        onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
                        onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
        (('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)],
         lambda *args: (onp.dot(onp.ones((2, 2)), onp.dot(args[1], args[2]).T),
                        onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
                        onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
        (('ij,jk,kl->il'), [(67, 89), (89, 55), (55, 99)],
         lambda *args: (onp.dot(onp.ones((67, 99)), onp.dot(args[1], args[2]).T),
                        onp.dot(args[0].T, onp.dot(onp.ones((67, 99)), args[2].T)),
                        onp.dot(onp.dot(args[0], args[1]).T, onp.ones((67, 99))))),
        (('ij,jk,kl, lm->im'), [(12, 54), (54, 32), (32, 45), (45, 67)],
         lambda *args: (onp.dot(onp.ones((12, 67)),
                                onp.dot(args[1], onp.dot(args[2], args[3])).T),
                        onp.dot(args[0].T,
                                onp.dot(onp.ones((12, 67)),
                                        onp.dot(args[2], args[3]).T)),
                        onp.dot(onp.dot(args[0], args[1]).T,
                                onp.dot(onp.ones((12, 67)), args[3].T)),
                        onp.dot(onp.dot(args[0],
                                        onp.dot(args[1], args[2])).T,
                                onp.ones((12, 67))))),
        # broadcast axis
        ('ij, ij -> i', [(1, 4), (2, 4)],
         lambda *args: (onp.sum(args[1], axis=0)[None, :],
                        onp.tile(args[0], [2, 1]))),
        ('...ij, ...jk -> ...ik', [(1, 4), (4, 2)],
         lambda *args: (args[1].sum(axis=1)[None, :],
                        onp.tile(args[0].sum(axis=0)[:, None], [1, 2]))),
        ('...ij, ...jk -> ...ik', [(2, 4), (4, 2)],
         lambda *args: (onp.tile(args[1].sum(axis=1)[None, :], [2, 1]),
                        onp.tile(args[0].sum(axis=0)[:, None], [1, 2]))),
        ('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)],
         lambda *args: (args[1].sum(axis=3)[:, :, None, :],
                        onp.tile(args[0].sum(axis=2)[:, :, :, None],
                                 [1, 1, 1, 2]))),
        ('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)],
         lambda *args: (onp.tile(args[1].sum(axis=3)[:, :, :, None],
                                 [1, 1, 1, 4]),
                        onp.tile(args[0].sum(axis=3)[:, :, :, None],
                                 [1, 1, 1, 3]))),
        ('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)],
         lambda *args: (onp.tile(args[1].sum(axis=3)[:, :, None, :],
                                 [1, 1, 5, 1]),
                        onp.tile(args[0].sum(axis=2)[:, :, :, None],
                                 [1, 1, 1, 2]))),
        ('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)],
         lambda *args: (onp.tile(args[1].sum(axis=3)[:, :, None, :],
                                 [1, 1, 5, 1]),
                        onp.tile(args[0].sum(axis=2)[:, :, :, None],
                                 [1, 1, 1, 2]))),
        ('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)],
         lambda *args: (onp.tile(args[1].sum(axis=3)[:, :, None, :],
                                 [1, 1, 5, 1]),
                        onp.tile(args[0].sum(axis=2)[:, :, :, None],
                                 [1, 1, 1, 2]))),
        # test with cuTensor using workspace
        (('ij,jk,kl->il'), [(64, 200), (200, 64), (64, 64)],
         lambda *args: (onp.dot(onp.ones((64, 64)), onp.dot(args[1], args[2]).T),
                        onp.dot(args[0].T, onp.dot(onp.ones((64, 64)), args[2].T)),
                        onp.dot(onp.dot(args[0], args[1]).T, onp.ones((64, 64)))))
    ]
    dtypes = ['float16', 'float32', 'float64', 'int32']
    for hybridize in [False, True]:
        for cache_setting in ['0', '1', None]:
            for dtype in dtypes:
                for config in configs:
                    for optimize in [False, True]:
                        with environment('MXNET_CUTENSOR_CACHEFILE',
                                         cache_setting):
                            # Looser tolerances for half precision.
                            rtol = 1e-1 if dtype == 'float16' else 1e-3
                            atol = 1e-1 if dtype == 'float16' else 1e-4
                            (subscripts, operands, get_grad) = config
                            test_einsum = TestEinsum(subscripts, optimize)
                            if hybridize:
                                test_einsum.hybridize()
                            x = []
                            x_np = []
                            for shape in operands:
                                tmp = onp.array(onp.random.uniform(-0.3, 0.3,
                                                                   shape),
                                                dtype=dtype)
                                x_np.append(tmp)
                                x.append(np.array(tmp, dtype=dtype))
                                x[-1].attach_grad()
                            # Reference value from plain numpy einsum.
                            expected_np = onp.einsum(subscripts, *x_np,
                                                     optimize=False,
                                                     dtype=dtype).astype(dtype)
                            with mx.autograd.record():
                                out_mx = test_einsum(*x)
                            assert out_mx.shape == expected_np.shape
                            assert_almost_equal(out_mx.asnumpy(), expected_np,
                                                rtol=rtol, atol=atol)
                            out_mx.backward()
                            # Gradients must match the analytic expressions.
                            for (iop, op) in enumerate(x):
                                assert_almost_equal(op.grad.asnumpy(),
                                                    get_grad(*x_np)[iop],
                                                    rtol=rtol, atol=atol)