Example n. 1
def test_data_dir():
    prev_data_dir = data_dir()
    system = platform.system()
    # Test that data_dir() returns the proper default value when MXNET_HOME is not set
    with environment('MXNET_HOME', None):
        if system == 'Windows':
            assert_equal(data_dir(), op.join(os.environ.get('APPDATA'),
                                             'mxnet'))
        else:
            assert_equal(data_dir(), op.join(op.expanduser('~'), '.mxnet'))
    # Test that data_dir() responds to an explicit setting of MXNET_HOME
    with environment('MXNET_HOME', '/tmp/mxnet_data'):
        assert_equal(data_dir(), '/tmp/mxnet_data')
    # Test that this test has not disturbed the MXNET_HOME value existing before the test
    assert_equal(data_dir(), prev_data_dir)
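The assertions above pin down data_dir()'s contract: MXNET_HOME overrides everything, otherwise the default is %APPDATA%\mxnet on Windows and ~/.mxnet elsewhere. A minimal sketch with that behavior (derived from the test, not MXNet's actual source) would be:

import os
import os.path as op
import platform

def data_dir_sketch():
    # MXNET_HOME, when set, overrides the platform default.
    override = os.environ.get('MXNET_HOME')
    if override:
        return override
    if platform.system() == 'Windows':
        return op.join(os.environ.get('APPDATA'), 'mxnet')
    return op.join(op.expanduser('~'), '.mxnet')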
Example n. 2
def test_environment():
    name1 = 'MXNET_TEST_ENV_VAR_1'
    name2 = 'MXNET_TEST_ENV_VAR_2'

    # Test that a variable can be set in the python and backend environment
    with environment(name1, '42'):
        assert_equal(os.environ.get(name1), '42')
        assert_equal(getenv(name1), '42')

    # Test dict form of invocation
    env_var_dict = {name1: '1', name2: '2'}
    with environment(env_var_dict):
        for key, value in env_var_dict.items():
            assert_equal(os.environ.get(key), value)
            assert_equal(getenv(key), value)
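Note that environment() accepts either a (name, value) pair or a dict, and the test also checks getenv(), i.e. the value seen by the MXNet backend. A minimal Python-only sketch of such a context manager (it restores prior values on exit but, unlike MXNet's helper, does not touch the backend) could look like:

import os
from contextlib import contextmanager

@contextmanager
def simple_environment(name_or_dict, value=None):
    # Accept either a single (name, value) pair or a dict of updates;
    # a value of None means "unset the variable".
    updates = name_or_dict if isinstance(name_or_dict, dict) else {name_or_dict: value}
    saved = {k: os.environ.get(k) for k in updates}
    try:
        for k, v in updates.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
        yield
    finally:
        for k, old in saved.items():
            if old is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = old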
Example n. 3
def test_engine_openmp_after_fork():
    """
    Test that the number of max threads in the child is 1. After forking we should not use a bigger
    OMP thread pool.

    With GOMP the child always reports the same number from omp_get_max_threads; with LLVM OMP
    the child respects the number of max threads set in the parent.
    """
    with environment('OMP_NUM_THREADS', '42'):
        r, w = os.pipe()
        pid = os.fork()
        if pid:
            os.close(r)
            wfd = os.fdopen(w, 'w')
            wfd.write('a')
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Parent omp max threads: {}".format(omp_max_threads))
            try:
                wfd.close()
            except:
                pass
            try:
                (cpid, status) = os.waitpid(pid, 0)
                assert cpid == pid
                exit_status = status >> 8
                assert exit_status == 0
            except:
                pass
        else:
            os.close(w)
            rfd = os.fdopen(r, 'r')
            rfd.read(1)
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Child omp max threads: {}".format(omp_max_threads))
            assert omp_max_threads == 1
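As an aside, the 'status >> 8' idiom in the parent extracts the exit code from the raw status word returned by os.waitpid; the standard-library helpers make the same check more explicit. A small illustrative snippet (not part of the original test):

import os

def wait_and_get_exit_code(pid):
    # Equivalent to 'status >> 8' for a child that exited normally,
    # but with an explicit normal-exit check.
    _, status = os.waitpid(pid, 0)
    assert os.WIFEXITED(status), "child did not exit normally"
    return os.WEXITSTATUS(status)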
Example n. 4
def test_device_pushpull():
    def check_dense_pushpull(kv_type):
        for shape, key in zip(shapes, keys):
            for n_gpus in gpus:
                kv_device = mx.kv.create(kv_type)
                a = mx.nd.ones(shape, mx.gpu(0))
                cur_key = str(key * max(gpus) + n_gpus)
                kv_device.init(cur_key, a)
                arr_list = [
                    mx.nd.ones(shape, mx.gpu(x)) for x in range(n_gpus)
                ]
                res = [mx.nd.zeros(shape, mx.gpu(x)) for x in range(n_gpus)]
                kv_device.push(cur_key, arr_list)
                kv_device.pull(cur_key, res)
                for x in range(n_gpus):
                    assert (np.sum(np.abs((res[x] - n_gpus).asnumpy())) == 0)

    kvstore_tree_array_bound_values = [None, '1']
    kvstore_usetree_values = [None, '1']
    for y in kvstore_tree_array_bound_values:
        for x in kvstore_usetree_values:
            with environment({
                    'MXNET_KVSTORE_USETREE': x,
                    'MXNET_KVSTORE_TREE_ARRAY_BOUND': y
            }):
                check_dense_pushpull('local')
                check_dense_pushpull('device')
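The snippet refers to shapes, keys, and gpus defined at module scope in the original test file. Hypothetical values consistent with the loops above (illustrative only, not the project's actual fixtures):

shapes = [(4, 4), (100, 100)]   # one shape per key
keys = [3, 5]
gpus = [1, 2]                   # numbers of GPUs to spread each push/pull over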
Example n. 5
def test_bind():
    for enable_bulking in ['0', '1']:
        with environment({'MXNET_EXEC_BULK_EXEC_INFERENCE': enable_bulking,
                          'MXNET_EXEC_BULK_EXEC_TRAIN': enable_bulking}):
            nrepeat = 10
            maxdim = 4
            for _ in range(nrepeat):
                for dim in range(1, maxdim):
                    check_bind_with_uniform(lambda x, y: x + y,
                                            lambda g, x, y: (g, g),
                                            dim)
                    check_bind_with_uniform(lambda x, y: x - y,
                                            lambda g, x, y: (g, -g),
                                            dim)
                    check_bind_with_uniform(lambda x, y: x * y,
                                            lambda g, x, y: (y * g, x * g),
                                            dim)
                    check_bind_with_uniform(lambda x, y: x / y,
                                            lambda g, x, y: (g / y, -x * g/ (y**2)),
                                            dim)

                    check_bind_with_uniform(lambda x, y: np.maximum(x, y),
                                            lambda g, x, y: (g * (x>=y), g * (y>x)),
                                            dim,
                                            sf=mx.symbol.maximum)
                    check_bind_with_uniform(lambda x, y: np.minimum(x, y),
                                            lambda g, x, y: (g * (x<=y), g * (y<x)),
                                            dim,
                                            sf=mx.symbol.minimum)
Example n. 6
def test_one_var(name, value, raise_exception=False):
    try:
        with environment(name, value):
            assert_equal(os.environ.get(name), value)
            assert_equal(getenv(name), value)
            if raise_exception:
                raise OnPurposeError
    except OnPurposeError:
        pass
    finally:
        check_background_values()
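This snippet relies on two helpers defined elsewhere in the original test module. Hypothetical stand-ins that make it self-contained (the real test defines its own versions) might be:

class OnPurposeError(Exception):
    """Raised deliberately to verify that environment() restores state even when the body throws."""
    pass

def check_background_values():
    # Placeholder: the original helper re-checks that both os.environ and the
    # backend (via getenv) are back to the values they had before the test.
    pass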
Example n. 7
def test_engine_import():
    import mxnet
    # Temporarily add an illegal entry (that is not caught) to show how the test needs improving
    engine_types = [
        None, 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice',
        'BogusEngine'
    ]

    for type in engine_types:
        with environment('MXNET_ENGINE_TYPE', type):
            reload(mxnet)
Example n. 8
def test_gemms_true_fp16():
    ctx = mx.gpu(0)
    input = mx.nd.random.uniform(shape=(1, 512), dtype='float16', ctx=ctx)
    weights = mx.nd.random.uniform(shape=(128, 512), ctx=ctx)

    net = nn.Dense(128, in_units=512, use_bias=False)
    net.cast('float16')
    net.initialize(ctx=ctx)
    net.weight.set_data(weights)

    with environment('MXNET_FC_TRUE_FP16', '0'):
        ref_results = net(input)

    with environment('MXNET_FC_TRUE_FP16', '1'):
        results_trueFP16 = net(input)

    atol = 1e-2
    rtol = 1e-2
    assert_almost_equal(ref_results.asnumpy(), results_trueFP16.asnumpy(),
                        atol=atol, rtol=rtol)
Example n. 9
def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs):
    inputs = sym.list_inputs()
    shapes = {inp : kwargs[inp].shape for inp in inputs}
    rtol = {'float16' : 1e-2,
            'float32' : 1.5e-6,
            'float64' : 1.5e-6,
            }
    atol = {'float16' : 1e-3,
            'float32' : 1e-7,
            'float64' : 1e-7,
            }
    for dtype in ['float16', 'float32', 'float64']:
        data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
        for grad_req in ['write', 'add']:
            type_dict = {inp : dtype for inp in inputs}
            with environment({'MXNET_ELIMINATE_COMMON_EXPR': '0'}):
                orig_exec = sym._simple_bind(ctx=mx.cpu(0), grad_req=grad_req,
                                             type_dict=type_dict, **shapes)
            with environment({'MXNET_ELIMINATE_COMMON_EXPR': '1'}):
                cse_exec = sym._simple_bind(ctx=mx.cpu(0), grad_req=grad_req,
                                            type_dict=type_dict, **shapes)
            fwd_orig = orig_exec.forward(is_train=True, **data)
            out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
            orig_exec.backward(out_grads=out_grads)
            fwd_cse = cse_exec.forward(is_train=True, **data)
            cse_exec.backward(out_grads=out_grads)
            if check_data:
                for orig, cse in zip(fwd_orig, fwd_cse):
                    np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(),
                                               rtol=rtol[dtype], atol=atol[dtype])
                for orig, cse in zip(orig_exec.grad_arrays, cse_exec.grad_arrays):
                    if orig is None and cse is None:
                        continue
                    assert orig is not None
                    assert cse is not None
                    np.testing.assert_allclose(orig.asnumpy(), cse.asnumpy(),
                                               rtol=rtol[dtype], atol=atol[dtype])
            orig_sym_internals = orig_exec.get_optimized_symbol().get_internals()
            cse_sym_internals = cse_exec.get_optimized_symbol().get_internals()
            # test that the graph has been simplified as expected
            assert (len(cse_sym_internals) + expected_savings) == len(orig_sym_internals)
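A hypothetical invocation consistent with this signature: build a symbol that contains a duplicated subexpression, so common-expression elimination can drop a node, and pass the input array by name (the names and the expected_savings value are for illustration only):

a = mx.symbol.Variable('a')
dup = (a + 5) + (a + 5)          # '(a + 5)' appears twice and can be shared
arr = mx.nd.random.uniform(shape=(10, 10))
check_cse_on_symbol(dup, expected_savings=1, check_data=True, a=arr)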
Example n. 10
def test_unary_func():
    def check_unary_func(x):
        f_exp         = lambda x: nd.exp(x)
        f_exp_grad    = lambda x: [nd.exp(x)]
        autograd_assert(x, func=f_exp, grad_func=f_exp_grad)
        f_half        = lambda x: x/2
        f_half_grad   = lambda x: [nd.ones(x.shape) * 0.5]
        autograd_assert(x, func=f_half, grad_func=f_half_grad)
        f_square      = lambda x: x**2
        f_square_grad = lambda x: [2*x]
        autograd_assert(x, func=f_square, grad_func=f_square_grad)
    uniform = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype in stypes:
            check_unary_func(uniform.tostype(stype))
Example n. 11
def test_subgraph_exe4(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in bind
    and compare results of the partitioned sym and the original sym."""
    def get_executor(sym,
                     subgraph_backend=None,
                     op_names=None,
                     original_exec=None):
        arg_shapes, _, aux_shapes = sym.infer_shape()
        if subgraph_backend is None:
            arg_array = [
                mx.nd.random.uniform(shape=shape) for shape in arg_shapes
            ]
            aux_array = [
                mx.nd.random.uniform(shape=shape) for shape in aux_shapes
            ]
        else:
            arg_array = None
            aux_array = None
        exe = sym._bind(ctx=mx.current_context(),
                        args=arg_array if subgraph_backend is None else
                        original_exec.arg_arrays,
                        aux_states=aux_array if subgraph_backend is None else
                        original_exec.aux_arrays,
                        grad_req='null')
        exe.forward()
        return exe

    sym, _, _ = sym
    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(
            _LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
                                              mx_uint(len(op_names)),
                                              c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names,
                                        original_exec)
        check_call(
            _LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(),
                            onp.zeros(shape=(1, )))
Example n. 12
def test_binary_func():
    def check_binary_func(x, y):
        f_add      = lambda x, y: x+y
        f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)]
        autograd_assert(x, y, func=f_add, grad_func=f_add_grad)
        f_mul      = lambda x, y: x*y
        f_mul_grad = lambda x, y: [y, x]
        autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad)
        f_compose  = lambda x, y: x+x*y
        f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x]
        autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad)
    uniform_x = nd.uniform(shape=(4, 5))
    uniform_y = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype_x in stypes:
            for stype_y in stypes:
                x = uniform_x.tostype(stype_x)
                y = uniform_y.tostype(stype_y)
                check_binary_func(x, y)
Example n. 13
def run_in_spawned_process(func, env, *args):
    """
    Helper function to run a test in its own process.

    Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend.
    Adds a seed as first arg to propagate determinism.

    Parameters
    ----------

    func : function to run in a spawned process.
    env : dict of additional environment values to set temporarily in the environment before exec.
    args : args to pass to the function.

    Returns
    -------
    Whether the python version supports running the function as a spawned process.

    This routine calculates a random seed and passes it into the test as a first argument.  If the
    test uses random values, it should include an outer 'with random_seed(seed):'.  If the
    test needs to return values to the caller, consider using shared variable arguments.
    """
    try:
        mpctx = mp.get_context('spawn')
    except:
        print(
            'SKIP: python%s.%s lacks the required process fork-exec support ... '
            % sys.version_info[0:2],
            file=sys.stderr,
            end='')
        return False
    else:
        seed = np.random.randint(0, 1024 * 1024 * 1024)
        with environment(env):
            # Prepend seed as first arg
            p = mpctx.Process(target=func, args=(seed, ) + args)
            p.start()
            p.join()
            assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (
                p.exitcode, func.__name__)
    return True
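A hypothetical call site, consistent with the docstring: the worker must be a module-level function (spawn pickles the target), it receives the random seed as its first argument, and the env dict is applied only while the child process is being launched:

def _check_engine(seed, expected_engine):
    # Runs in a fresh process, so the backend reads MXNET_ENGINE_TYPE at
    # startup instead of from a cached value.
    import os
    assert os.environ.get('MXNET_ENGINE_TYPE') == expected_engine

def test_spawned_engine_env():
    run_in_spawned_process(_check_engine,
                           {'MXNET_ENGINE_TYPE': 'NaiveEngine'},
                           'NaiveEngine')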
Example n. 14
def test_subgraph_exe2(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in _simple_bind
    and compare results of the partitioned sym and the original sym."""
    def get_executor(sym,
                     subgraph_backend=None,
                     op_names=None,
                     original_exec=None):
        exe = sym._simple_bind(ctx=mx.current_context(), grad_req='null')
        input_names = sym.list_inputs()
        for name in input_names:
            if name in exe.arg_dict:
                exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
                    if original_exec is None else original_exec.arg_dict[name]
            else:
                assert name in exe.aux_dict
                exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
                    if original_exec is None else original_exec.aux_dict[name]
        exe.forward()
        return exe

    sym, _, _ = sym

    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(
            _LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend),
                                              mx_uint(len(op_names)),
                                              c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names,
                                        original_exec)
        check_call(
            _LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(),
                            np.zeros(shape=(1, )))
Example n. 15
File: common.py Project: tlby/mxnet
def test_new(*args, **kwargs):
    with environment(*args_):
        orig_test(*args, **kwargs)
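The wrapper above is produced by a decorator factory that closes over the environment() arguments (args_) and the wrapped test (orig_test). A plausible reconstruction of that factory (hedged; the original common.py may differ in details):

import functools

def with_environment(*args_):
    """Decorator factory: run the wrapped test with environment(*args_) active."""
    def test_helper(orig_test):
        @functools.wraps(orig_test)
        def test_new(*args, **kwargs):
            with environment(*args_):
                orig_test(*args, **kwargs)
        return test_new
    return test_helper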
Example n. 16
def test_rsp_push_pull():
    def check_rsp_push_pull(kv_type, sparse_pull, is_push_cpu=True):
        kv = init_kv_with_str('row_sparse', kv_type)
        kv.init('e', mx.nd.ones(shape).tostype('row_sparse'))
        push_ctxs = [mx.cpu(i) if is_push_cpu else mx.gpu(i) for i in range(2)]
        kv.push('e', [mx.nd.ones(shape, ctx=context).tostype('row_sparse') for context in push_ctxs])

        def check_rsp_pull(kv, ctxs, sparse_pull, is_same_rowid=False, use_slice=False):
            count = len(ctxs)
            num_rows = shape[0]
            row_ids = []
            all_row_ids = np.arange(num_rows)
            vals = [mx.nd.sparse.zeros(shape=shape, ctx=ctxs[i], stype='row_sparse') for i in range(count)]
            if is_same_rowid:
                row_id = np.random.randint(num_rows, size=num_rows)
                row_ids = [mx.nd.array(row_id)] * count
            elif use_slice:
                total_row_ids = mx.nd.array(np.random.randint(num_rows, size=count*num_rows))
                row_ids = [total_row_ids[i*num_rows : (i+1)*num_rows] for i in range(count)]
            else:
                for i in range(count):
                    row_id = np.random.randint(num_rows, size=num_rows)
                    row_ids.append(mx.nd.array(row_id))
            row_ids_to_pull = row_ids[0] if (len(row_ids) == 1 or is_same_rowid) else row_ids
            vals_to_pull = vals[0] if len(vals) == 1 else vals

            kv.row_sparse_pull('e', out=vals_to_pull, row_ids=row_ids_to_pull)
            for val, row_id in zip(vals, row_ids):
                retained = val.asnumpy()
                excluded_row_ids = np.setdiff1d(all_row_ids, row_id.asnumpy())
                for row in range(num_rows):
                    expected_val = np.zeros_like(retained[row])
                    expected_val += 0 if row in excluded_row_ids else 2
                    assert_almost_equal(retained[row], expected_val)

            if sparse_pull is True:
                kv.pull('e', out=vals_to_pull, ignore_sparse=False)
                for val in vals:
                    retained = val.asnumpy()
                    expected_val = np.zeros_like(retained)
                    expected_val[:] = 2
                    assert_almost_equal(retained, expected_val)

        check_rsp_pull(kv, [mx.gpu(0)], sparse_pull)
        check_rsp_pull(kv, [mx.cpu(0)], sparse_pull)
        check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull)
        check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, is_same_rowid=True)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, is_same_rowid=True)
        check_rsp_pull(kv, [mx.gpu(i//2) for i in range(4)], sparse_pull, use_slice=True)
        check_rsp_pull(kv, [mx.cpu(i) for i in range(4)], sparse_pull, use_slice=True)

    envs = [None, '1']
    key  = 'MXNET_KVSTORE_USETREE'
    for val in envs:
        with environment(key, val):
            if val == '1':
                sparse_pull = False
            else:
                sparse_pull = True
            check_rsp_push_pull('local', sparse_pull)
            check_rsp_push_pull('device', sparse_pull)
            check_rsp_push_pull('device', sparse_pull, is_push_cpu=False)
Example n. 17
def test_cuda_graphs():
    class GraphTester(gluon.HybridBlock):
        def __init__(self, function_to_test, **kwargs):
            super(GraphTester, self).__init__(**kwargs)
            self.f = function_to_test()

        def forward(self, *args):
            # We need to isolate the operation to be fully inside the graph
            # in order for graphs usage to be possible
            copied_args = [mx.np.copy(a) for a in args]
            outputs = self.f(*copied_args)
            if isinstance(outputs, (list, tuple)):
                return [mx.np.copy(o) for o in outputs]
            else:
                return mx.np.copy(outputs)

    class TestDesc:
        def __init__(self, name, f, num_inputs=1, input_dim=4):
            self.name = name
            self.f = f
            self.num_inputs = num_inputs
            self.input_dim = input_dim

        def generate_inputs(self):
            shape = tuple(_np.random.randint(4, 11, size=self.input_dim))
            ret = [mx.np.random.uniform(size=shape) for _ in range(self.num_inputs)]
            for r in ret:
                r.attach_grad()
            return ret

    tested_ops = [
            TestDesc('add', lambda: (lambda x, y: x + y), num_inputs = 2),
            TestDesc('add_scalar', lambda: (lambda x: x + 0.5)),
            TestDesc('Conv', lambda: mx.gluon.nn.Conv2D(channels=32, kernel_size=(1,1))),
            TestDesc('ConvTranspose', lambda: mx.gluon.nn.Conv2DTranspose(channels=32, kernel_size=(1,1))),
            TestDesc('Dense', lambda: mx.gluon.nn.Dense(units=128)),
            TestDesc('Activation', lambda: mx.gluon.nn.Activation('tanh')),
            TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)),
            TestDesc('Flatten', lambda: mx.gluon.nn.Flatten()),
            TestDesc('MaxPool', lambda: mx.gluon.nn.MaxPool2D()),
            TestDesc('AvgPool', lambda: mx.gluon.nn.AvgPool2D()),
            TestDesc('GlobalMaxPool', lambda: mx.gluon.nn.GlobalMaxPool2D()),
            TestDesc('GlobalAvgPool', lambda: mx.gluon.nn.GlobalAvgPool2D()),
            TestDesc('ReflectionPad2D', lambda: mx.gluon.nn.ReflectionPad2D()),
            TestDesc('BatchNorm', lambda: mx.gluon.nn.BatchNorm()),
            TestDesc('InstanceNorm', lambda: mx.gluon.nn.InstanceNorm()),
            TestDesc('LayerNorm', lambda: mx.gluon.nn.LayerNorm()),
            TestDesc('LeakyReLU', lambda: mx.gluon.nn.LeakyReLU(0.1)),
            TestDesc('PReLU', lambda: mx.gluon.nn.PReLU()),
            TestDesc('ELU', lambda: mx.gluon.nn.ELU()),
            TestDesc('SELU', lambda: mx.gluon.nn.SELU()),
            TestDesc('Swish', lambda: mx.gluon.nn.Swish()),
        ]

    N = 10

    with environment({'MXNET_ENABLE_CUDA_GRAPHS': '1',
                      'MXNET_USE_FUSION': '0'}):
        device = mx.gpu(0)
        for test_desc in tested_ops:
            print("Testing ", test_desc.name)
            inputs = test_desc.generate_inputs()
            inputsg = [i.copy() for i in inputs]
            for i in inputsg:
                i.attach_grad()
            seed = random.randint(0, 10000)
            net = GraphTester(test_desc.f)
            netg = GraphTester(test_desc.f)

            # initialize parameters
            net.initialize(device=device)
            netg.initialize(device=device)

            net(*inputs)

            for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()):
                p2.set_data(p1.data())

            netg.hybridize(static_alloc=True, static_shape=True)

            print("Testing inference mode")
            with random_seed(seed):
                for _ in range(N):
                    assert_almost_equal(net(*inputs), netg(*inputsg))

            mx.npx.waitall()
            print("Testing training mode")
            for _ in range(N):
                with random_seed(seed):
                    with mx.autograd.record():
                        out = net(*inputs)
                    out.backward()

                with random_seed(seed):
                    with mx.autograd.record():
                        outg = netg(*inputsg)
                    outg.backward()

                assert_almost_equal(out, outg)
                for i, ig in zip(inputs, inputsg):
                    assert_almost_equal(i.grad, ig.grad)

                for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()):
                    assert_almost_equal(p1.data(), p2.data())
                    if p1.grad_req != 'null':
                        assert_almost_equal(p1.grad(), p2.grad())
            mx.npx.waitall()
Example n. 18
def test_np_einsum():
    class TestEinsum(HybridBlock):
        def __init__(self, subscripts, optimize):
            super(TestEinsum, self).__init__()
            self.subscripts = subscripts
            self.optimize = optimize

        def forward(self, *operands):
            return mx.np.einsum(self.subscripts,
                                *operands,
                                optimize=self.optimize)

    def dbg(name, data):
        print('type of {} = {}'.format(name, type(data)))
        print('shape of {} = {}'.format(name, data.shape))
        print('{} = {}'.format(name, data))

    configs = [
        ('ii', [(5, 5)], lambda *args: (onp.eye(5), )),
        ('ii->i', [(5, 5)], lambda *args: (onp.eye(5), )),
        ('ij->i', [(5, 5)], lambda *args: (onp.ones((5, 5)), )),
        ('...j->...', [(5, 5)], lambda *args: (onp.ones((5, 5)), )),
        ('ji', [(2, 3)], lambda *args: (onp.ones((2, 3)), )),
        ('ij->ji', [(2, 3)], lambda *args: (onp.ones((2, 3)), )),
        ('ij, jk', [(5, 0), (0, 4)], lambda *args: (onp.empty(
            (5, 0)), onp.empty((0, 4)))),
        ('i, i', [(5, ), (5, )], lambda *args: (args[1], args[0])),
        ('ij, j', [(5, 5), (5, )], lambda *args:
         (onp.tile(args[1][None, :], [5, 1]), args[0].sum(axis=0))),
        ('...j, j', [(5, 5), (5, )], lambda *args:
         (onp.tile(args[1][None, :], [5, 1]), onp.sum(args[0], axis=0))),
        ('..., ...', [(), (2, 3)], lambda *args:
         (onp.sum(args[1], axis=None), args[0] * onp.ones((2, 3)))),
        (', ij', [(), (2, 3)], lambda *args:
         (onp.sum(args[1], axis=None), args[0] * onp.ones((2, 3)))),
        ('i, j', [(2, ), (5, )], lambda *args:
         (onp.sum(args[1], axis=None) * onp.ones(2), onp.sum(
             args[0], axis=None) * onp.ones(5))),
        ('ijk, jil->kl', [(3, 4, 5), (4, 3, 2)], lambda *args:
         (onp.tile(
             onp.transpose(onp.sum(args[1], axis=-1))[:, :, None], [1, 1, 5]),
          onp.tile(
              onp.transpose(onp.sum(args[0], axis=-1))[:, :, None], [1, 1, 2]))
         ),
        ('ijk, jil->kl', [(33, 44, 55), (44, 33, 22)], lambda *args:
         (onp.tile(
             onp.transpose(onp.sum(args[1], axis=-1))[:, :, None], [1, 1, 55]),
          onp.tile(
              onp.transpose(onp.sum(args[0], axis=-1))[:, :, None], [1, 1, 22])
          )),
        ('ki, jk->ij', [(3, 2), (4, 3)], lambda *args:
         (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
          onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        ('ki, ...k->i...', [(3, 2), (4, 3)], lambda *args:
         (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
          onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        ('k..., jk', [(3, 2), (4, 3)], lambda *args:
         (onp.tile(args[1].sum(axis=0)[:, None], [1, 2]),
          onp.tile(args[0].sum(axis=1)[None, :], [4, 1]))),
        (('ij,jk'), [(2, 5), (5, 2)], lambda *args:
         (onp.dot(onp.ones(
             (2, 2)), args[1].T), onp.dot(args[0].T, onp.ones((2, 2))))),
        (('ij,jk,kl'), [(2, 2), (2, 5), (5, 2)], lambda *args:
         (onp.dot(onp.ones((2, 2)),
                  onp.dot(args[1], args[2]).T),
          onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
          onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
        (('ij,jk,kl->il'), [(2, 2), (2, 5), (5, 2)], lambda *args:
         (onp.dot(onp.ones((2, 2)),
                  onp.dot(args[1], args[2]).T),
          onp.dot(args[0].T, onp.dot(onp.ones((2, 2)), args[2].T)),
          onp.dot(onp.dot(args[0], args[1]).T, onp.ones((2, 2))))),
        (('ij,jk,kl->il'), [(67, 89), (89, 55), (55, 99)], lambda *args:
         (onp.dot(onp.ones((67, 99)),
                  onp.dot(args[1], args[2]).T),
          onp.dot(args[0].T, onp.dot(onp.ones((67, 99)), args[2].T)),
          onp.dot(onp.dot(args[0], args[1]).T, onp.ones((67, 99))))),
        (('ij,jk,kl, lm->im'), [(12, 54), (54, 32), (32, 45),
                                (45, 67)], lambda *args:
         (onp.dot(onp.ones((12, 67)),
                  onp.dot(args[1], onp.dot(args[2], args[3])).T),
          onp.dot(args[0].T,
                  onp.dot(onp.ones((12, 67)),
                          onp.dot(args[2], args[3]).T)),
          onp.dot(
              onp.dot(args[0], args[1]).T,
              onp.dot(onp.ones((12, 67)), args[3].T)),
          onp.dot(
              onp.dot(args[0], onp.dot(args[1], args[2])).T, onp.ones(
                  (12, 67))))),

        # broadcast axis
        ('ij, ij -> i', [(1, 4), (2, 4)], lambda *args:
         (onp.sum(args[1], axis=0)[None, :], onp.tile(args[0], [2, 1]))),
        ('...ij, ...jk -> ...ik', [(1, 4), (4, 2)], lambda *args:
         (args[1].sum(axis=1)[None, :],
          onp.tile(args[0].sum(axis=0)[:, None], [1, 2]))),
        ('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args:
         (onp.tile(args[1].sum(axis=1)[None, :], [2, 1]),
          onp.tile(args[0].sum(axis=0)[:, None], [1, 2]))),
        ('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)], lambda *args:
         (args[1].sum(axis=3)[:, :, None, :],
          onp.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
        ('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)], lambda *args:
         (onp.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]),
          onp.tile(args[0].sum(axis=3)[:, :, :, None], [1, 1, 1, 3]))),
        ('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)], lambda *args:
         (onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
          onp.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
        ('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)], lambda *args:
         (onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
          onp.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
        ('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)], lambda *args:
         (onp.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
          onp.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
        # test with cuTensor using workspace
        (('ij,jk,kl->il'), [(64, 200), (200, 64), (64, 64)], lambda *args:
         (onp.dot(onp.ones((64, 64)),
                  onp.dot(args[1], args[2]).T),
          onp.dot(args[0].T, onp.dot(onp.ones((64, 64)), args[2].T)),
          onp.dot(onp.dot(args[0], args[1]).T, onp.ones((64, 64)))))
    ]

    dtypes = ['float16', 'float32', 'float64', 'int32']
    for hybridize in [False, True]:
        for cache_setting in ['0', '1', None]:
            for dtype in dtypes:
                for config in configs:
                    for optimize in [False, True]:
                        with environment('MXNET_CUTENSOR_CACHEFILE',
                                         cache_setting):
                            rtol = 1e-1 if dtype == 'float16' else 1e-3
                            atol = 1e-1 if dtype == 'float16' else 1e-4
                            (subscripts, operands, get_grad) = config
                            test_einsum = TestEinsum(subscripts, optimize)
                            if hybridize:
                                test_einsum.hybridize()
                            x = []
                            x_np = []
                            for shape in operands:
                                tmp = onp.array(onp.random.uniform(
                                    -0.3, 0.3, shape),
                                                dtype=dtype)
                                x_np.append(tmp)
                                x.append(np.array(tmp, dtype=dtype))
                                x[-1].attach_grad()
                            expected_np = onp.einsum(subscripts,
                                                     *x_np,
                                                     optimize=False,
                                                     dtype=dtype).astype(dtype)
                            with mx.autograd.record():
                                out_mx = test_einsum(*x)
                            assert out_mx.shape == expected_np.shape
                            assert_almost_equal(out_mx.asnumpy(),
                                                expected_np,
                                                rtol=rtol,
                                                atol=atol)
                            out_mx.backward()
                            for (iop, op) in enumerate(x):
                                assert_almost_equal(op.grad.asnumpy(),
                                                    get_grad(*x_np)[iop],
                                                    rtol=rtol,
                                                    atol=atol)