コード例 #1
0
def test_auto_parallel():
    """Compile ``Full`` in auto_parallel mode and check the parallel op count.

    The context is configured for 16 devices with this process as rank 0 and
    graph saving enabled. After compiling the network for the 'train' phase,
    the value returned by ``_executor._get_num_parallel_ops`` must equal 16.
    """
    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",
        device_num=16,
        global_rank=0,
    )

    network = Full(_w1, 3)
    network.set_auto_parallel()
    network.set_train()

    _executor.compile(network, _x, phase='train')
    parallel_op_count = _executor._get_num_parallel_ops(network)
    assert parallel_op_count == 16
コード例 #2
0
def test_double_subgraphs():
    """Compile a wrapped train-step network and check the parallel op count.

    The context is configured for auto_parallel mode on 8 devices (rank 0)
    with graph saving disabled. ``Net`` is wrapped in ``NetWithLoss`` and then
    ``TrainStepWarp``, compiled for the 'train' phase with an all-ones
    float32 input of shape [8, 8, 8, 8], and the value returned by
    ``_executor._get_num_parallel_ops`` must equal 7.
    """
    context.set_context(save_graphs=False)
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel",
        device_num=8,
        global_rank=0,
    )

    # Build the network inside-out; same construction order as the nested call.
    backbone = Net()
    with_loss = NetWithLoss(backbone)
    train_step = TrainStepWarp(with_loss)
    _set_multi_subgraphs()
    train_step.set_auto_parallel()

    ones = np.ones([8, 8, 8, 8])
    inputs = Tensor(ones, dtype=ms.float32)
    reset_op_id()
    train_step.set_train()

    _executor.compile(train_step, inputs, phase='train')
    parallel_op_count = _executor._get_num_parallel_ops(train_step)
    expected_op_count = 7
    assert expected_op_count == parallel_op_count