def test_two_matmul():
    """Round-trip cost-model and algorithm parameters, then check shard strategies.

    The test runs in three stages:

    1. Set custom cost-model values, read each one back, reset, and check
       the values expected after the reset.
    2. Do the same set / get / reset / get cycle for the algorithm parameters.
    3. Compile a network of two chained MatMul ops in "auto_parallel" mode on
       16 devices and compare the reported shard strategies with the
       expected mapping.
    """

    class Net(nn.Cell):
        """Feeds the result of matmul1(x, y) into matmul2 together with b."""

        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            first = self.matmul1(x, y)
            second = self.matmul2(first, b)
            return second

    device_count = 16
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)

    gib = 1024.0 * 1024.0 * 1024.0

    # --- cost-model parameters: custom values must be readable back ---
    custom_cost_cfg = {
        "device_memory_capacity": 32.0 * gib,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 60.0,
        "costmodel_gamma": 0.1,
        "costmodel_communi_threshold": 1024.0,
        "costmodel_communi_const": 2222.0,
        "costmodel_communi_bias": 1111.0,
    }
    cost_model_context.set_cost_model_context(**custom_cost_cfg)
    for key, wanted in custom_cost_cfg.items():
        assert cost_model_context.get_cost_model_context(key) == wanted

    # --- cost-model parameters: values expected once the context is reset ---
    post_reset_cost_cfg = {
        "device_memory_capacity": 16.0 * gib,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 400.0,
        "costmodel_gamma": 0.001,
        "costmodel_communi_threshold": 2048.0,
        "costmodel_communi_const": 3072.0,
        "costmodel_communi_bias": 1024.0,
    }
    cost_model_context.reset_cost_model_context()
    for key, wanted in post_reset_cost_cfg.items():
        assert cost_model_context.get_cost_model_context(key) == wanted

    # --- algorithm parameters: custom values ---
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False)
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 32
    assert not get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")

    # --- algorithm parameters: values expected after the reset ---
    reset_algo_parameters()
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 16
    assert get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")

    # --- compile under auto-parallel and inspect the chosen strategies ---
    x_in = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y_in = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b_in = Tensor(np.ones([64, 64]), dtype=ms.float32)

    wrapped = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    wrapped.set_auto_parallel()
    # Reset the op-id counter before compiling; the expected keys below end
    # in "-op0" / "-op1".
    reset_op_id()

    _executor.compile(wrapped, x_in, y_in, b_in, phase='train')
    matmul_strategy = [[16, 1], [1, 1]]
    assert _executor._get_shard_strategy(wrapped) == {
        'Default/network-Net/MatMul-op0': matmul_strategy,
        'Default/network-Net/MatMul-op1': matmul_strategy,
    }
# Example no. 2
# 0
def test_two_matmul():
    """Round-trip cost-model and algorithm parameters, then check shard strategies.

    The test runs in four stages:

    1. Set custom cost-model values, read each one back, reset, and check
       the values expected after the reset.
    2. Set custom algorithm parameters (including the approximation options)
       and read them back.
    3. Check the single-loop flag getter/setter, then reset the algorithm
       parameters and check the values expected after the reset.
    4. Compile a network of two chained MatMul ops in "auto_parallel" mode on
       16 devices and check the strategy of every entry whose key contains
       'MatMul-op'.
    """

    class Net(nn.Cell):
        """Feeds the result of matmul1(x, y) into matmul2 together with b."""

        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            first = self.matmul1(x, y)
            second = self.matmul2(first, b)
            return second

    device_count = 16
    context.set_auto_parallel_context(device_num=device_count, global_rank=0)

    gib = 1024.0 * 1024.0 * 1024.0

    # --- cost-model parameters: custom values must be readable back ---
    custom_cost_cfg = {
        "device_memory_capacity": 32.0 * gib,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 60.0,
        "costmodel_gamma": 0.1,
        "costmodel_communi_threshold": 1024.0,
        "costmodel_communi_const": 2222.0,
        "costmodel_communi_bias": 1111.0,
    }
    cost_model_context.set_cost_model_context(**custom_cost_cfg)
    for key, wanted in custom_cost_cfg.items():
        assert cost_model_context.get_cost_model_context(key) == wanted

    # --- cost-model parameters: values expected once the context is reset ---
    post_reset_cost_cfg = {
        "device_memory_capacity": 16.0 * gib,
        "costmodel_alpha": 1.0,
        "costmodel_beta": 400.0,
        "costmodel_gamma": 0.001,
        "costmodel_communi_threshold": 2048.0,
        "costmodel_communi_const": 3072.0,
        "costmodel_communi_bias": 1024.0,
    }
    cost_model_context.reset_cost_model_context()
    for key, wanted in post_reset_cost_cfg.items():
        assert cost_model_context.get_cost_model_context(key) == wanted

    # --- algorithm parameters: custom values ---
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False,
                        enable_algo_approxi=True,
                        algo_approxi_epsilon=0.001)
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 32
    assert not get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")
    assert get_algo_parameters("enable_algo_approxi")
    assert get_algo_parameters("algo_approxi_epsilon") == 0.001

    # --- single-loop flag: expected to read True first, then False once set ---
    expected_single_loop = True
    assert expected_single_loop == _get_algo_single_loop()
    expected_single_loop = False
    _set_algo_single_loop(expected_single_loop)
    assert expected_single_loop == _get_algo_single_loop()

    # --- algorithm parameters: values expected after the reset ---
    reset_algo_parameters()
    assert not get_algo_parameters("tensor_slice_align_enable")
    assert get_algo_parameters("tensor_slice_align_size") == 16
    assert get_algo_parameters("fully_use_devices")
    assert not get_algo_parameters("elementwise_op_strategy_follow")
    assert not get_algo_parameters("enable_algo_approxi")
    assert get_algo_parameters("algo_approxi_epsilon") == 0.1

    # --- compile under auto-parallel and inspect the chosen strategies ---
    x_in = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y_in = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b_in = Tensor(np.ones([64, 64]), dtype=ms.float32)

    wrapped = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    wrapped.set_auto_parallel()
    reset_op_id()

    wrapped.set_train()
    _executor.compile(wrapped, x_in, y_in, b_in, phase='train')
    shard_strategies = _executor._get_shard_strategy(wrapped)
    # Only entries whose key mentions a MatMul op are checked; any other
    # entries in the mapping are ignored.
    for op_name, strategy in shard_strategies.items():
        if re.search('MatMul-op', op_name) is None:
            continue
        assert strategy == [[16, 1], [1, 1]]