예제 #1
0
def test_auto_parallel_l2normalize():
    """Compile an L2Normalize/Mul network in auto-parallel mode with 8 devices.

    The test only checks that compilation succeeds; no strategy is inspected.
    """

    class Net(nn.Cell):
        """Applies L2Normalize to two inputs, multiplies them, then multiplies by ``b``."""

        def __init__(self):
            super(Net, self).__init__()
            self.norm1 = P.L2Normalize()
            self.norm2 = P.L2Normalize()
            self.mul1 = P.Mul()
            self.mul2 = P.Mul()

        def construct(self, x, y, b):
            # norm1 handles y and norm2 handles x, in that order.
            normed_y = self.norm1(y)
            normed_x = self.norm2(x)
            product = self.mul1(normed_x, normed_y)
            return self.mul2(product, b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    # All three inputs share the same shape and dtype.
    x, y, b = (Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
               for _ in range(3))
    _executor.compile(net, x, y, b, phase='train')
def test_auto_parallel_arithmetic_broadcast_both():
    """Check the shard strategies found for MatMul followed by FloorDiv on 8 devices.

    Strategies are matched by operator-name substring, so operators whose key
    contains neither marker are ignored.
    """

    class Net(nn.Cell):
        """MatMul of ``x`` and ``y`` whose result is floor-divided by ``b``."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            product = self.matmul(x, y)
            return self.floordiv(product, b)

    def _ones(shape):
        # Small local factory for the float32 all-ones inputs.
        return Tensor(np.ones(shape), dtype=ms.float32)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    x = _ones([64, 32])
    y = _ones([32, 1])
    b = _ones([1, 64])
    compile_net(net, x, y, b, phase='train')

    # Checked in order; the first marker contained in the key decides.
    expected_by_marker = (
        ('FloorDiv-op', [[8, 1], [1, 1]]),
        ('MatMul-op', [[8, 1], [1, 1]]),
    )
    for op_key, strategy in _executor._get_shard_strategy(net).items():
        for marker, expected in expected_by_marker:
            if marker in op_key:
                assert strategy == expected
                break
def test_auto_parallel_arithmetic_broadcast_both():
    """Compare the full strategy map of a MatMul + FloorDiv network on 8 devices."""

    class Net(nn.Cell):
        """MatMul of ``x`` and ``y`` whose result is floor-divided by ``b``."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.floordiv = P.FloorDiv()

        def construct(self, x, y, b):
            product = self.matmul(x, y)
            return self.floordiv(product, b)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    input_shapes = ([64, 32], [32, 1], [1, 64])
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32)
               for shape in input_shapes)
    compile_net(net, x, y, b, phase='train')

    strategies = _executor._get_strategy(net)
    # The whole mapping must match, including the operator ids in the keys.
    expected_strategies = dict([
        ('Default/network-Net/FloorDiv-op0', [[8, 1], [1, 1]]),
        ('Default/network-Net/MatMul-op1', [[8, 1], [1, 1]]),
    ])
    assert strategies == expected_strategies
def test_matmul_prelu():
    """Check the strategies of a Mul + PReLU network compiled for 16 devices.

    The two operators are looked up by their exact keys in the strategy map.
    """

    class Net(nn.Cell):
        """Element-wise Mul whose output is passed to PReLU with weight ``b``."""

        def __init__(self):
            super(Net, self).__init__()
            self.mul1 = P.Mul()
            self.prelu = P.PReLU()

        def construct(self, x, y, b):
            product = self.mul1(x, y)
            return self.prelu(product, b)

    context.set_auto_parallel_context(device_num=16, global_rank=0)
    x, y = (Tensor(np.ones([16, 3, 128, 32]), dtype=ms.float32)
            for _ in range(2))
    b = Tensor(np.array([0.01, 0.02, 0.03]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_strategy(net)

    prelu_strategy = strategies['Default/network-Net/PReLU-op2']
    assert prelu_strategy == [[16, 1, 1, 1], [1]]
    mul_strategy = strategies['Default/network-Net/Mul-op3']
    assert mul_strategy == [[16, 1, 1, 1], [16, 1, 1, 1]]
예제 #5
0
def test_two_matmul_transpose():
    """Compare the full strategy map of two MatMuls and two Transposes on 16 devices."""

    class Net(nn.Cell):
        """Two chained MatMuls followed by two (1, 0) transposes."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.transpose1 = P.Transpose()
            self.transpose2 = P.Transpose()

        def construct(self, x, y, b):
            first = self.matmul1(x, y)
            second = self.matmul2(first, b)
            flipped = self.transpose1(second, (1, 0))
            return self.transpose2(flipped, (1, 0))

    context.set_auto_parallel_context(device_num=16, global_rank=0)
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32)
               for shape in ([128, 32], [32, 64], [64, 64]))

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_strategy(net)

    scope = 'Default/network-Net/'
    expected_strategies = {
        scope + 'Transpose-op0': [[1, 16]],
        scope + 'Transpose-op1': [[16, 1]],
        scope + 'MatMul-op2': [[16, 1], [1, 1]],
        scope + 'MatMul-op3': [[16, 1], [1, 1]],
    }
    assert strategies == expected_strategies
예제 #6
0
def test_triangle_strategy_consistency():
    """Compile a network mixing user-sharded and unsharded operators on 8 devices.

    ``mul1`` and ``add`` carry explicit shard strategies while ``mul2``,
    ``ba1`` and ``relu`` do not. The test passes when compilation succeeds.
    """

    class Net(nn.Cell):
        """Mul output feeds both a second Mul and a BiasAdd, which are then added."""

        def __init__(self):
            super().__init__()
            self.mul1 = P.Mul().shard(((2, 4), (2, 4)))
            self.mul2 = P.Mul()
            self.ba1 = P.BiasAdd()
            self.weight = Parameter(
                Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight")
            self.bias = Parameter(
                Tensor(np.ones([1000]), dtype=ms.float32), name="bias")
            self.add = P.TensorAdd().shard(((1, 8), (1, 8)))
            self.relu = P.ReLU()

        def construct(self, x):
            weighted = self.mul1(x, self.weight)
            # Both branches below consume the same mul1 output.
            mul_branch = self.mul2(weighted, self.weight)
            bias_branch = self.ba1(weighted, self.bias)
            merged = self.add(mul_branch, bias_branch)
            return self.relu(merged)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, phase='train')
예제 #7
0
def test_matmul_prelu():
    """Check the strategies of a Mul + PReLU network compiled for 16 devices.

    Operators are matched by name substring; keys containing neither marker
    are ignored.
    """

    class Net(nn.Cell):
        """Element-wise Mul whose output is passed to PReLU with weight ``b``."""

        def __init__(self):
            super(Net, self).__init__()
            self.mul1 = P.Mul()
            self.prelu = P.PReLU()

        def construct(self, x, y, b):
            product = self.mul1(x, y)
            return self.prelu(product, b)

    context.set_auto_parallel_context(device_num=16, global_rank=0)
    x, y = (Tensor(np.ones([16, 3, 128, 32]), dtype=ms.float32)
            for _ in range(2))
    b = Tensor(np.array([0.01, 0.02, 0.03]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, b, phase='train')

    # 'PReLU-op' is tested before 'Mul-op'; the first hit decides.
    expected_by_marker = (
        ('PReLU-op', [[16, 1, 1, 1], [1]]),
        ('Mul-op', [[16, 1, 1, 1], [16, 1, 1, 1]]),
    )
    for op_key, strategy in _executor._get_strategy(net).items():
        for marker, expected in expected_by_marker:
            if marker in op_key:
                assert strategy == expected
                break
def test_star_strategy_consistency2():
    """Compile the module-level ``Net`` with a partially specified strategy dict.

    ``mul1`` is left without a strategy (``None``); every other entry carries an
    explicit one. ``fully_use_devices`` is switched off before compiling for
    8 devices. The test passes when compilation succeeds.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)

    strategy_dict = {
        "mul1": None,
        "mul2": ((1, 4), (1, 4)),
        "relu1": ((2, 1),),
        "bias_add": ((4, 2), (2,)),
        "relu2": ((2, 2),),
        "add": ((8, 1), (8, 1)),
    }
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, phase='train')
예제 #9
0
def test_double_subgraphs():
    """Check how many parallel operators a multi-subgraph train step yields.

    The wrapped network is compiled for 8 devices in auto-parallel mode and is
    expected to report exactly 7 parallel operators.
    """
    context.set_context(save_graphs=False)
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = TrainStepWarp(NetWithLoss(Net()))
    _set_multi_subgraphs()
    net.set_auto_parallel()

    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, phase='train')

    expected_num = 7
    assert expected_num == _executor._get_num_parallel_ops(net)
예제 #10
0
def test_common_parameter():
    """Compare the full strategy map when one parameter feeds two Cast/MatMul pairs.

    The network is compiled for 8 devices with
    ``elementwise_op_strategy_follow`` enabled.
    """

    class Net(nn.Cell):
        """Two MatMuls sharing one float16 weight (cast to float32), joined by a third."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            initial_weight = Tensor(np.ones([64, 64]).astype(np.float16) * 0.01)
            self.weight1 = Parameter(initial_weight, "w", requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y, z, w):
            # y and w are accepted but take no part in the computation.
            weight_for_x = self.cast1(self.weight1, mstype.float32)
            m1_result = self.matmul1(x, weight_for_x)
            weight_for_z = self.cast2(self.weight1, mstype.float32)
            m2_result = self.matmul2(z, weight_for_z)
            return self.matmul3(m2_result, m1_result)

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    set_algo_parameters(elementwise_op_strategy_follow=True)
    x, y, z, w = (Tensor(np.ones([64, 64]), dtype=ms.float32)
                  for _ in range(4))

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, z, w, phase='train')
    strategies = _executor._get_strategy(net)

    scope = 'Default/network-Net/'
    expected_strategies = {
        scope + 'MatMul-op0': [[8, 1], [1, 1]],
        scope + 'MatMul-op1': [[8, 1], [1, 1]],
        scope + 'Cast-op2': [[1, 1]],
        scope + 'MatMul-op3': [[8, 1], [1, 1]],
        scope + 'Cast-op4': [[1, 1]],
    }
    assert strategies == expected_strategies
예제 #11
0
def test_double_subgraphs():
    """Compare the full strategy map of a multi-subgraph train step on 8 devices."""
    cost_model_context.set_cost_model_context(multi_subgraphs=True)
    # NOTE(review): save_graphs=True presumably makes compilation dump graph
    # files; confirm that this is intended for a unit test.
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = TrainStepWarp(NetWithLoss(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")

    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
    reset_op_id()
    _executor.compile(net, x, phase='train')
    strategies = _executor._get_strategy(net)

    loss_scope = 'Default/network-NetWithLoss/'
    net_scope = loss_scope + 'net-Net/'
    expected_strategies = {
        loss_scope + 'ReduceMean-op0': [[8, 1, 1, 1]],
        net_scope + 'ReLU-op1': [[8, 1, 1, 1]],
        net_scope + 'Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
        net_scope + 'Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
        loss_scope + 'ReduceSum-op4': [[8, 1, 1, 1]],
    }
    assert strategies == expected_strategies
예제 #12
0
def test_auto_parallel_assign_sub_with_ref_key():
    """Check the strategies found for ``nn.PReLU(4)`` compiled for 8 devices.

    Operators are matched by name substring; keys containing neither marker
    are ignored.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    x = Tensor(np.random.rand(4, 4, 32, 64), dtype=ms.float32)

    net = NetWithLoss(nn.PReLU(4))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    _executor.compile(net, x, phase="train")

    # Order matters: 'ReLU-op' is a substring of 'PReLU-op', so the PReLU
    # marker has to be tried first.
    expected_by_marker = (
        ('PReLU-op', [[1, 1, 1, 8], [1]]),
        ('ReLU-op', [[1]]),
    )
    for op_key, strategy in _executor._get_strategy(net).items():
        for marker, expected in expected_by_marker:
            if marker in op_key:
                assert strategy == expected
                break
예제 #13
0
def test_auto_parallel_assign_sub_with_ref_key():
    """Compare the full strategy map of ``nn.PReLU(4)`` compiled for 8 devices."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    x = Tensor(np.random.rand(4, 4, 32, 64), dtype=ms.float32)

    net = NetWithLoss(nn.PReLU(4))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    reset_op_id()

    _executor.compile(net, x, phase="train")
    strategies = _executor._get_strategy(net)

    scope = 'Default/network-PReLU/'
    expected_strategies = {
        scope + 'PReLU-op2': [[1, 1, 1, 8], [1]],
        scope + 'ReLU-op3': [[1]],
    }
    assert strategies == expected_strategies
예제 #14
0
def test_common_parameter():
    """Check strategies when one parameter feeds two Cast/MatMul pairs on 8 devices.

    ``elementwise_op_strategy_follow`` is enabled. Operators are matched by
    name substring; keys containing neither marker are ignored.
    """

    class Net(nn.Cell):
        """Two MatMuls sharing one float16 weight (cast to float32), joined by a third."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            initial_weight = Tensor(np.ones([64, 64]).astype(np.float16) * 0.01)
            self.weight1 = Parameter(initial_weight, "w", requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y):
            weight_for_x = self.cast1(self.weight1, mstype.float32)
            m1_result = self.matmul1(x, weight_for_x)
            weight_for_y = self.cast2(self.weight1, mstype.float32)
            m2_result = self.matmul2(y, weight_for_y)
            return self.matmul3(m2_result, m1_result)

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    set_algo_parameters(elementwise_op_strategy_follow=True)
    x, y = (Tensor(np.ones([64, 64]), dtype=ms.float32) for _ in range(2))

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    net.set_train()
    _executor.compile(net, x, y, phase='train')

    expected_by_marker = (
        ('MatMul-op', [[8, 1], [1, 1]]),
        ('Cast-op', [[1, 1]]),
    )
    for op_key, strategy in _executor._get_shard_strategy(net).items():
        # The first marker contained in the key, if any, supplies the expectation.
        expected = next(
            (want for marker, want in expected_by_marker if marker in op_key),
            None)
        if expected is not None:
            assert strategy == expected
예제 #15
0
def test_double_star_graph():
    """Compare the full strategy map of two MatMuls joined through Casts on 8 devices."""

    class Net(nn.Cell):
        """Two independent MatMuls, each cast to float16, multiplied by a third MatMul."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y, z, w):
            m1_result = self.matmul1(x, y)
            m2_result = self.matmul2(z, w)
            # cast1 takes the second product, cast2 the first.
            left = self.cast1(m2_result, mstype.float16)
            right = self.cast2(m1_result, mstype.float16)
            return self.matmul3(left, right)

    context.set_auto_parallel_context(device_num=8, global_rank=0)

    input_shapes = ([32, 8], [8, 16], [8, 16], [16, 32])
    x, y, z, w = (Tensor(np.ones(shape), dtype=ms.float32)
                  for shape in input_shapes)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    net.set_train()
    _executor.compile(net, x, y, z, w, phase='train')
    strategies = _executor._get_shard_strategy(net)

    scope = 'Default/network-Net/'
    expected_strategies = {
        scope + 'MatMul-op1': [[1, 8], [8, 1]],
        scope + 'Cast-op2': [[8, 1]],
        scope + 'MatMul-op3': [[8, 1], [1, 1]],
        scope + 'Cast-op4': [[1, 8]],
        scope + 'MatMul-op5': [[1, 1], [1, 8]],
    }
    assert strategies == expected_strategies
def test_double_subgraphs():
    """Check the shard strategies of a multi-subgraph train step on 8 devices.

    Operators are matched by name substring; keys containing none of the
    markers are ignored.
    """
    _set_multi_subgraphs()
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = TrainStepWarp(NetWithLoss(Net()))
    net.set_auto_parallel()

    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, phase='train')

    # Markers are tried top to bottom; the first one found in the key wins.
    expected_by_marker = (
        ('ReduceMean-op', [[8, 1, 1, 1]]),
        ('ReLU-op', [[8, 1, 1, 1]]),
        ('Mul-op', [[8, 1, 1, 1], [8, 1, 1, 1]]),
        ('ReduceSum-op', [[8, 1, 1, 1]]),
    )
    for op_key, strategy in _executor._get_shard_strategy(net).items():
        for marker, expected in expected_by_marker:
            if marker in op_key:
                assert strategy == expected
                break
예제 #17
0
def test_two_bn():
    """Check the shard strategies of two ``get_block()`` cells joined by ReLU and Add.

    The network is compiled for 8 devices with
    ``elementwise_op_strategy_follow`` enabled; exactly four strategies are
    expected, and the known operators are matched by name substring.
    """

    class Net(nn.Cell):
        """block1 -> ReLU -> Add(bias) -> block2."""

        def __init__(self):
            super(Net, self).__init__()
            self.block1 = get_block()
            self.block2 = get_block()
            self.relu = P.ReLU()
            self.add = P.Add()
            self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)

        def construct(self, x):
            first = self.block1(x)
            activated = self.relu(first)
            shifted = self.add(activated, self.bias)
            return self.block2(shifted)

    context.set_context(save_graphs=False)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetWithLoss(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()

    _executor.compile(net, x, phase='train')
    strategies = _executor._get_shard_strategy(net)
    assert len(strategies) == 4

    # NOTE(review): the network uses P.Add while the marker below is
    # 'TensorAdd-op'; confirm the operator key still contains that text,
    # otherwise the Add strategy is never checked.
    expected_by_marker = (
        ('BatchNorm-op', [[8, 1], [1], [1], [1], [1]]),
        ('TensorAdd-op', [[8, 1], [8, 1]]),
        ('ReLU-op', [[8, 1]]),
    )
    for op_key, strategy in strategies.items():
        for marker, expected in expected_by_marker:
            if marker in op_key:
                assert strategy == expected
                break
def test_two_matmul():
    """Round-trip cost-model and algorithm parameters, then compile two MatMuls.

    Steps, in order:
      1. write custom cost-model settings and read each one back;
      2. reset the cost model and check the values reported afterwards;
      3. write algorithm parameters, read them back, reset and re-check;
      4. compile the two-MatMul network for 16 devices and compare the
         complete strategy map.
    """

    class Net(nn.Cell):
        """Two chained MatMuls: (x @ y) @ b."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            first = self.matmul1(x, y)
            return self.matmul2(first, b)

    context.set_auto_parallel_context(device_num=16, global_rank=0)

    # Step 1: every value written must be read back unchanged.
    custom_cost_settings = {
        "device_memory_capacity": 32.0 * 1024.0 * 1024.0 * 1024.0,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 60.0,
        "costmodel_gamma": 0.1,
        "costmodel_communi_threshold": 1024.0,
        "costmodel_communi_const": 2222.0,
        "costmodel_communi_bias": 1111.0,
    }
    cost_model_context.set_cost_model_context(**custom_cost_settings)
    for key, written in custom_cost_settings.items():
        assert cost_model_context.get_cost_model_context(key) == written

    # Step 2: values expected once the cost model has been reset.
    cost_model_context.reset_cost_model_context()
    cost_values_after_reset = (
        ("device_memory_capacity", 16.0 * 1024.0 * 1024.0 * 1024.0),
        ("costmodel_alpha", 1.0),
        ("costmodel_beta", 400.0),
        ("costmodel_gamma", 0.001),
        ("costmodel_communi_threshold", 2048.0),
        ("costmodel_communi_const", 3072.0),
        ("costmodel_communi_bias", 1024.0),
    )
    for key, after_reset in cost_values_after_reset:
        assert cost_model_context.get_cost_model_context(key) == after_reset

    # Step 3: algorithm parameters, first the written values ...
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False)
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 32
    assert not get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")

    # ... then the values reported after a reset.
    reset_algo_parameters()
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 16
    assert get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")

    # Step 4: compile and compare the whole strategy map.
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32)
               for shape in ([128, 32], [32, 64], [64, 64]))

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_shard_strategy(net)
    expected_strategies = dict([
        ('Default/network-Net/MatMul-op0', [[16, 1], [1, 1]]),
        ('Default/network-Net/MatMul-op1', [[16, 1], [1, 1]]),
    ])
    assert strategies == expected_strategies
예제 #19
0
def test_two_matmul():
    """Round-trip cost-model, algorithm and single-loop settings, then compile two MatMuls.

    Steps, in order:
      1. write custom cost-model settings and read each one back;
      2. reset the cost model and check the values reported afterwards;
      3. write algorithm parameters (including the approximation options)
         and read them back;
      4. read the single-loop flag, switch it off and read it again;
      5. reset the algorithm parameters and check the reported values;
      6. compile the two-MatMul network for 16 devices and check every
         MatMul strategy.
    """

    class Net(nn.Cell):
        """Two chained MatMuls: (x @ y) @ b."""

        def __init__(self):
            super(Net, self).__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            first = self.matmul1(x, y)
            return self.matmul2(first, b)

    context.set_auto_parallel_context(device_num=16, global_rank=0)

    # Step 1: every value written must be read back unchanged.
    custom_cost_settings = {
        "device_memory_capacity": 32.0 * 1024.0 * 1024.0 * 1024.0,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 60.0,
        "costmodel_gamma": 0.1,
        "costmodel_communi_threshold": 1024.0,
        "costmodel_communi_const": 2222.0,
        "costmodel_communi_bias": 1111.0,
    }
    cost_model_context.set_cost_model_context(**custom_cost_settings)
    for key, written in custom_cost_settings.items():
        assert cost_model_context.get_cost_model_context(key) == written

    # Step 2: values expected once the cost model has been reset.
    cost_model_context.reset_cost_model_context()
    cost_values_after_reset = (
        ("device_memory_capacity", 16.0 * 1024.0 * 1024.0 * 1024.0),
        ("costmodel_alpha", 1.0),
        ("costmodel_beta", 400.0),
        ("costmodel_gamma", 0.001),
        ("costmodel_communi_threshold", 2048.0),
        ("costmodel_communi_const", 3072.0),
        ("costmodel_communi_bias", 1024.0),
    )
    for key, after_reset in cost_values_after_reset:
        assert cost_model_context.get_cost_model_context(key) == after_reset

    # Step 3: algorithm parameters as written.
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False,
                        enable_algo_approxi=True,
                        algo_approxi_epsilon=0.001)
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 32
    assert not get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")
    assert get_algo_parameters("enable_algo_approxi")
    assert get_algo_parameters("algo_approxi_epsilon") == 0.001

    # Step 4: the flag is expected to read True first, and False after
    # being switched off.
    expected_single_loop = True
    single_loop = _get_algo_single_loop()
    assert expected_single_loop == single_loop
    expected_single_loop = False
    _set_algo_single_loop(expected_single_loop)
    single_loop = _get_algo_single_loop()
    assert expected_single_loop == single_loop

    # Step 5: algorithm parameters after a reset.
    reset_algo_parameters()
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 16
    assert get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")
    assert not get_algo_parameters("enable_algo_approxi")
    assert get_algo_parameters("algo_approxi_epsilon") == 0.1

    # Step 6: compile and check every MatMul strategy.
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32)
               for shape in ([128, 32], [32, 64], [64, 64]))

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    net.set_train()
    _executor.compile(net, x, y, b, phase='train')
    for op_key, strategy in _executor._get_shard_strategy(net).items():
        # Keys that do not name a MatMul operator are ignored.
        if 'MatMul-op' in op_key:
            assert strategy == [[16, 1], [1, 1]]