Example #1
def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
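    # ResNet-50 (4096 classes) trained on 8 GPUs in AUTO_PARALLEL mode; the shard
    # strategies searched for Conv2D, MatMul and ReduceSum are verified below.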
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]
def test_train_64k_8p(epoch_size=3,
                      batch_size=32,
                      num_classes=65536):  #1048576 #131072 #32768 #8192
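    # Same ResNet-50 auto-parallel search with a 65536-class head and a tuned
    # cost model (costmodel_gamma/costmodel_beta); expected strategies are checked below.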
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001,
                                              costmodel_beta=260.0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
Example #3
 def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
     parallel_mode_net = self.parallel_mode_net
     set_algo_parameters(fully_use_devices=False)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(
         parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=device_num)
     self.parallel_ckpt = self._model_train_and_save_ckpt(
         net=parallel_mode_net, dataset=dataset, epoch=epoch)
     context.reset_auto_parallel_context()
Example #4
def test_train_32k_8p(batch_size=32, num_classes=32768):
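    # Compile-only check: ResNet-50 with a 32768-class head under AUTO_PARALLEL,
    # exercising Model.predict instead of training.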
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    np.random.seed(6)
    input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))
    net = resnet50(num_classes)
    model = Model(net)
    model.predict(input_np)
 def mindspore_optimizer_auto_parallel_impl(self, dataset, epoch,
                                            device_num):
     set_algo_parameters(fully_use_devices=False)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(
         parallel_mode=ParallelMode.AUTO_PARALLEL,
         device_num=device_num,
         enable_parallel_optimizer=True)
     parallel_mode_net = self.net(self.strategy_dict)
     self.optimizer_parallel_ckpt = self._model_train_and_save_ckpt(
         net=parallel_mode_net, dataset=dataset, epoch=epoch)
     context.reset_auto_parallel_context()
def test_star_strategy_consistency2():
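    # Mixes user-specified strategies (strategy_dict) with searched ones while
    # fully_use_devices=False, and checks that compilation succeeds.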
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    strategy_dict = {"mul1": None, "mul2": ((1, 4), (1, 4)), "relu1": ((2, 1),), "bias_add": ((4, 2), (2,)),
                     "relu2": ((2, 2),), "add": ((8, 1), (8, 1))}
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, phase='train')
Example #7
def test_train_feed2(num_classes=1001):
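    # Loss-regression test: ResNet-50 (1001 classes) trained for 5 steps on fed
    # (non-sink) data; the per-step losses are compared against expected values.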
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    dataGen = DataGenerator()
    input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
    label_full, label_part = dataGen.label_data((32 * 2,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 10.0, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097]
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
def test_train_feed(num_classes=65536):
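    # Same loss-regression pattern with a 65536-class head.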
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 8,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
Example #9
def test_common_parameter():
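    # Three MatMuls share one weight Parameter through Cast ops; the full
    # strategy dictionary is compared against the expected per-op strategies.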
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.weight1 = Parameter(Tensor(
                np.ones([64, 64]).astype(np.float16) * 0.01),
                                     "w",
                                     requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y, z, w):
            m1_result = self.matmul1(x, self.cast1(self.weight1,
                                                   mstype.float32))
            m2_result = self.matmul2(z, self.cast2(self.weight1,
                                                   mstype.float32))
            m3_result = self.matmul3(m2_result, m1_result)

            return m3_result

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)

    set_algo_parameters(elementwise_op_strategy_follow=True)
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
    z = Tensor(np.ones([64, 64]), dtype=ms.float32)
    w = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, z, w, phase='train')
    strategies = _executor._get_strategy(net)
    expected_strategies = {
        'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]],
        'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
        'Default/network-Net/Cast-op2': [[1, 1]],
        'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
        'Default/network-Net/Cast-op4': [[1, 1]]
    }
    assert strategies == expected_strategies
Example #10
def test_flatten_reshape3(parallel_mode="auto_parallel"):
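    # Trains ParallelReshapeNet under auto_parallel with fully_use_devices=False
    # to exercise reshape/flatten redistribution.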
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),))
    loss = CrossEntropyLoss()
    predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
Example #11
def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
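    # Semi-auto-parallel variant using ParallelReduceMeanNet with an explicit
    # convolution strategy.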
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True,
                                strategy=((4, 1, 1, 1),))
    loss = CrossEntropyLoss2()
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
def test_train_feed(num_classes=8192):
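    # Loss-regression test with an 8192-class head and learning rate 10.0.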
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    dataGen = DataGenerator()
    input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
    label_full, label_part = dataGen.label_data((32 * 2, ))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   10.0, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5,
                dataset,
                dataset_sink_mode=False,
                callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489]
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
Example #13
def test_common_parameter():
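    # Variant of the shared-parameter test: strategies for MatMul and Cast are
    # checked with re.search rather than an exact strategy dictionary.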
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.weight1 = Parameter(Tensor(
                np.ones([64, 64]).astype(np.float16) * 0.01),
                                     "w",
                                     requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y):
            m1_result = self.matmul1(x, self.cast1(self.weight1,
                                                   mstype.float32))
            m2_result = self.matmul2(y, self.cast2(self.weight1,
                                                   mstype.float32))
            m3_result = self.matmul3(m2_result, m1_result)

            return m3_result

    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)

    set_algo_parameters(elementwise_op_strategy_follow=True)
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    net.set_train()
    _executor.compile(net, x, y, phase='train')
    strategies = _executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('MatMul-op', k) is not None:
            assert v == [[8, 1], [1, 1]]
        elif re.search('Cast-op', k) is not None:
            assert v == [[1, 1]]
Example #14
def test_two_bn():
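    # Two BatchNorm blocks joined by ReLU and Add; with
    # elementwise_op_strategy_follow=True each op is expected to be split 8-way
    # along the first dimension (4 strategies in total).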
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.block1 = get_block()
            self.block2 = get_block()
            self.relu = P.ReLU()
            self.add = P.Add()
            self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)

        def construct(self, x):
            out = self.block1(x)
            out = self.relu(out)
            out = self.add(out, self.bias)
            out = self.block2(out)
            return out

    context.set_context(save_graphs=False)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetWithLoss(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()

    _executor.compile(net, x, phase='train')
    strategies = _executor._get_shard_strategy(net)
    assert len(strategies) == 4

    for (k, v) in strategies.items():
        if re.search('BatchNorm-op', k) is not None:
            assert v == [[8, 1], [1], [1], [1], [1]]
        elif re.search('TensorAdd-op', k) is not None:
            assert v == [[8, 1], [8, 1]]
        elif re.search('ReLU-op', k) is not None:
            assert v == [[8, 1]]
def test_set_auto_parallel_context():
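    # Exercises the auto-parallel context setters/getters and verifies that
    # out-of-range values (device_num, global_rank, tensor_slice_align_size)
    # raise ValueError.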
    context.set_auto_parallel_context(device_num=4,
                                      global_rank=3,
                                      gradients_mean=True,
                                      gradient_fp32_sync=False,
                                      parallel_mode="auto_parallel",
                                      parameter_broadcast=False)
    device_num = context.get_auto_parallel_context("device_num")
    global_rank = context.get_auto_parallel_context("global_rank")
    gradients_mean = context.get_auto_parallel_context("gradients_mean")
    gradient_fp32_sync = context.get_auto_parallel_context(
        "gradient_fp32_sync")
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    parameter_broadcast = context.get_auto_parallel_context(
        "parameter_broadcast")
    assert device_num == 4
    assert global_rank == 3
    assert gradients_mean
    assert not gradient_fp32_sync
    assert parallel_mode == "auto_parallel"
    assert not parameter_broadcast

    auto_parallel_context().set_device_num(4)
    device_num = auto_parallel_context().get_device_num()
    device_num_is_set = auto_parallel_context().get_device_num_is_set()
    assert device_num == 4
    assert device_num_is_set

    auto_parallel_context().set_global_rank(4)
    global_rank = auto_parallel_context().get_global_rank()
    assert global_rank == 4

    auto_parallel_context().set_gradients_mean(True)
    gradients_mean = auto_parallel_context().get_gradients_mean()
    assert gradients_mean

    auto_parallel_context().set_gradient_fp32_sync(False)
    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
    assert not gradient_fp32_sync

    parameter_broadcast_is_set = auto_parallel_context(
    ).get_parameter_broadcast_is_set()
    assert parameter_broadcast_is_set

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=0)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=4097)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=-1)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(parallel_mode="wrong_mode")

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=4096)

    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=0)

    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=1025)

    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    assert context.get_auto_parallel_context("enable_parallel_optimizer")
    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
Example #16
def test_two_matmul():
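    # Sets and reads back the cost-model context and algorithm parameters
    # (including their reset defaults), then compiles two chained MatMuls and
    # expects both to be split 16-way along the batch dimension.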
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    cost_model_context.set_cost_model_context(
        device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
        costmodel_alpha=1.0,
        costmodel_beta=60.0,
        costmodel_gamma=0.1,
        costmodel_communi_threshold=1024.0,
        costmodel_communi_const=2222.0,
        costmodel_communi_bias=1111.0)
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 32.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 60.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.1
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 1024.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 2222.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1111.0

    cost_model_context.reset_cost_model_context()
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 16.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 400.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.001
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 2048.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 3072.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1024.0

    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False,
                        enable_algo_approxi=True,
                        algo_approxi_epsilon=0.001)
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 32
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert not fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    enable_approxi = get_algo_parameters("enable_algo_approxi")
    assert enable_approxi
    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
    assert algo_epsilon == 0.001

    expect_single_loop = True
    single_loop = _get_algo_single_loop()
    assert expect_single_loop == single_loop
    expect_single_loop = False
    _set_algo_single_loop(expect_single_loop)
    single_loop = _get_algo_single_loop()
    assert expect_single_loop == single_loop

    reset_algo_parameters()
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 16
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    enable_approxi = get_algo_parameters("enable_algo_approxi")
    assert not enable_approxi
    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
    assert algo_epsilon == 0.1

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    net.set_train()
    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('MatMul-op', k) is not None:
            assert v == [[16, 1], [1, 1]]
def run_predict_pipeline(args_opt):
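    # Pipeline-parallel PanguAlpha inference: configures SEMI_AUTO_PARALLEL with
    # pipeline stages, loads per-rank checkpoints and runs a single predict call.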
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id),
          flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()

    else:
        rank = 0
        device_num = 1

    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.0,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)

    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices

    # all cards will save ckpt
    train_stage_num = 16
    train_device_num = 1024
    train_mp = 16
    ckpt_name = args_opt.load_ckpt_name
    train_per_stage_num = train_device_num // train_stage_num
    if config.mp != train_mp:
        raise ValueError(
            "the model parallel num is not equal to training model parallel num"
        )
    concat_stage_num = train_stage_num // config.stage_num
    pangu_alpha = PANGUALPHAPipeline(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    inputs_np = Tensor(np.ones(shape=(1, config.seq_length)), mstype.int32)
    model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    for i in range(self_stage * concat_stage_num,
                   (self_stage + 1) * concat_stage_num):
        stage_position = local_rank % (config.mp * config.dp)
        ckpt_rank = i * train_per_stage_num + stage_position  # rank id used during training
        # checkpoint directories are still named by the training-time rank id
        ckpt_dir = os.path.join(args_opt.load_ckpt_path,
                                f"rank_{(ckpt_rank)}")
        local_ckpt_file = os.path.join(ckpt_dir, ckpt_name)
        if not os.path.exists(local_ckpt_file):
            raise ValueError("Ckpt file not exits,", local_ckpt_file)
        params_dict = load_checkpoint(local_ckpt_file, filter_prefix="adam")
        load_param_into_net(eval_net, params_dict)
    print("================load param ok=================", flush=True)
    # here predict with fake input
    model_predict.predict(inputs_np)
def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
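    # Besides the strategy checks, this test enables the allreduce-fusion cost
    # model and compares the resulting parameter-to-fusion-group mapping against
    # an expected dictionary.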
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]

    allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)

    print(allreduce_fusion_dict)
    expect_dict = {'end_point.bias': 2,
                   'end_point.weight': 2,
                   'layer4.2.bn3.beta': 2,
                   'layer4.2.bn3.gamma': 2,
                   'layer4.2.conv3.weight': 2,
                   'layer4.2.bn2.beta': 2,
                   'layer4.2.bn2.gamma': 2,
                   'layer4.2.conv2.weight': 2,
                   'layer4.2.bn1.beta': 2,
                   'layer4.2.bn1.gamma': 2,
                   'layer4.2.conv1.weight': 2,
                   'layer4.1.bn3.beta': 2,
                   'layer4.1.bn3.gamma': 2,
                   'layer4.1.conv3.weight': 2,
                   'layer4.1.bn2.beta': 2,
                   'layer4.1.bn2.gamma': 2,
                   'layer4.1.conv2.weight': 2,
                   'layer4.1.bn1.beta': 2,
                   'layer4.1.bn1.gamma': 2,
                   'layer4.1.conv1.weight': 2,
                   'layer4.0.bn_down_sample.beta': 2,
                   'layer4.0.bn_down_sample.gamma': 2,
                   'layer4.0.conv_down_sample.weight': 2,
                   'layer4.0.bn3.beta': 2,
                   'layer4.0.bn3.gamma': 2,
                   'layer4.0.conv3.weight': 2,
                   'layer4.0.bn2.beta': 2,
                   'layer4.0.bn2.gamma': 2,
                   'layer4.0.conv2.weight': 2,
                   'layer4.0.bn1.beta': 2,
                   'layer4.0.bn1.gamma': 2,
                   'layer4.0.conv1.weight': 2,
                   'layer3.5.bn3.beta': 2,
                   'layer3.5.bn3.gamma': 2,
                   'layer3.5.conv3.weight': 2,
                   'layer3.5.bn2.beta': 2,
                   'layer3.5.bn2.gamma': 2,
                   'layer3.5.conv2.weight': 2,
                   'layer3.5.bn1.beta': 2,
                   'layer3.5.bn1.gamma': 2,
                   'layer3.5.conv1.weight': 2,
                   'layer3.4.bn3.beta': 2,
                   'layer3.4.bn3.gamma': 2,
                   'layer3.4.conv3.weight': 2,
                   'layer3.4.bn2.beta': 2,
                   'layer3.4.bn2.gamma': 2,
                   'layer3.4.conv2.weight': 2,
                   'layer3.4.bn1.beta': 2,
                   'layer3.4.bn1.gamma': 2,
                   'layer3.4.conv1.weight': 2,
                   'layer3.3.bn3.beta': 2,
                   'layer3.3.bn3.gamma': 2,
                   'layer3.3.conv3.weight': 2,
                   'layer3.3.bn2.beta': 2,
                   'layer3.3.bn2.gamma': 2,
                   'layer3.3.conv2.weight': 2,
                   'layer3.3.bn1.beta': 2,
                   'layer3.3.bn1.gamma': 2,
                   'layer3.3.conv1.weight': 2,
                   'layer3.2.bn3.beta': 2,
                   'layer3.2.bn3.gamma': 2,
                   'layer3.2.conv3.weight': 2,
                   'layer3.2.bn2.beta': 2,
                   'layer3.2.bn2.gamma': 2,
                   'layer3.2.conv2.weight': 2,
                   'layer3.2.bn1.beta': 2,
                   'layer3.2.bn1.gamma': 2,
                   'layer3.2.conv1.weight': 2,
                   'layer3.1.bn3.beta': 2,
                   'layer3.1.bn3.gamma': 2,
                   'layer3.1.conv3.weight': 2,
                   'layer3.1.bn2.beta': 2,
                   'layer3.1.bn2.gamma': 2,
                   'layer3.1.conv2.weight': 2,
                   'layer3.1.bn1.beta': 2,
                   'layer3.1.bn1.gamma': 2,
                   'layer3.1.conv1.weight': 2,
                   'layer3.0.bn_down_sample.beta': 1,
                   'layer3.0.bn_down_sample.gamma': 1,
                   'layer3.0.conv_down_sample.weight': 2,
                   'layer3.0.bn3.beta': 1,
                   'layer3.0.bn3.gamma': 1,
                   'layer3.0.conv3.weight': 2,
                   'layer3.0.bn2.beta': 2,
                   'layer3.0.bn2.gamma': 2,
                   'layer3.0.conv2.weight': 2,
                   'layer3.0.bn1.beta': 2,
                   'layer3.0.bn1.gamma': 2,
                   'layer3.0.conv1.weight': 2,
                   'layer2.3.bn3.beta': 2,
                   'layer2.3.bn3.gamma': 2,
                   'layer2.3.conv3.weight': 2,
                   'layer2.3.bn2.beta': 2,
                   'layer2.3.bn2.gamma': 2,
                   'layer2.3.conv2.weight': 2,
                   'layer2.3.bn1.beta': 2,
                   'layer2.3.bn1.gamma': 2,
                   'layer2.3.conv1.weight': 2,
                   'layer2.2.bn3.beta': 2,
                   'layer2.2.bn3.gamma': 2,
                   'layer2.2.conv3.weight': 2,
                   'layer2.2.bn2.beta': 2,
                   'layer2.2.bn2.gamma': 2,
                   'layer2.2.conv2.weight': 2,
                   'layer2.2.bn1.beta': 2,
                   'layer2.2.bn1.gamma': 2,
                   'layer2.2.conv1.weight': 2,
                   'layer2.1.bn3.beta': 1,
                   'layer2.1.bn3.gamma': 1,
                   'layer2.1.conv3.weight': 2,
                   'layer2.1.bn2.beta': 2,
                   'layer2.1.bn2.gamma': 2,
                   'layer2.1.conv2.weight': 2,
                   'layer2.1.bn1.beta': 2,
                   'layer2.1.bn1.gamma': 2,
                   'layer2.1.conv1.weight': 2,
                   'layer2.0.bn_down_sample.beta': 1,
                   'layer2.0.bn_down_sample.gamma': 1,
                   'layer2.0.conv_down_sample.weight': 2,
                   'layer2.0.bn3.beta': 1,
                   'layer2.0.bn3.gamma': 1,
                   'layer2.0.conv3.weight': 2,
                   'layer2.0.bn2.beta': 2,
                   'layer2.0.bn2.gamma': 2,
                   'layer2.0.conv2.weight': 2,
                   'layer2.0.bn1.beta': 2,
                   'layer2.0.bn1.gamma': 2,
                   'layer2.0.conv1.weight': 2,
                   'layer1.2.bn3.beta': 2,
                   'layer1.2.bn3.gamma': 2,
                   'layer1.2.conv3.weight': 2,
                   'layer1.2.bn2.beta': 2,
                   'layer1.2.bn2.gamma': 2,
                   'layer1.2.conv2.weight': 2,
                   'layer1.2.bn1.beta': 2,
                   'layer1.2.bn1.gamma': 2,
                   'layer1.2.conv1.weight': 2,
                   'layer1.1.bn3.beta': 1,
                   'layer1.1.bn3.gamma': 1,
                   'layer1.1.conv3.weight': 2,
                   'layer1.1.bn2.beta': 2,
                   'layer1.1.bn2.gamma': 2,
                   'layer1.1.conv2.weight': 2,
                   'layer1.1.bn1.beta': 2,
                   'layer1.1.bn1.gamma': 2,
                   'layer1.1.conv1.weight': 2,
                   'layer1.0.bn_down_sample.beta': 1,
                   'layer1.0.bn_down_sample.gamma': 1,
                   'layer1.0.conv_down_sample.weight': 2,
                   'layer1.0.bn3.beta': 1,
                   'layer1.0.bn3.gamma': 1,
                   'layer1.0.conv3.weight': 2,
                   'layer1.0.bn2.beta': 2,
                   'layer1.0.bn2.gamma': 2,
                   'layer1.0.conv2.weight': 2,
                   'layer1.0.bn1.beta': 2,
                   'layer1.0.bn1.gamma': 2,
                   'layer1.0.conv1.weight': 2,
                   'bn1.beta': 1,
                   'bn1.gamma': 1,
                   'conv1.weight': 2}

    assert (allreduce_fusion_dict == expect_dict)
    cost_model_context.reset_cost_model_context()
def test_two_matmul():
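    # Variant of the cost-model/algorithm-parameter test above that compares the
    # searched strategies against an explicit strategy dictionary.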
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    cost_model_context.set_cost_model_context(
        device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
        costmodel_alpha=1.0,
        costmodel_beta=60.0,
        costmodel_gamma=0.1,
        costmodel_communi_threshold=1024.0,
        costmodel_communi_const=2222.0,
        costmodel_communi_bias=1111.0)
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 32.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 60.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.1
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 1024.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 2222.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1111.0

    cost_model_context.reset_cost_model_context()
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 16.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 400.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.001
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 2048.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 3072.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1024.0

    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False)
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 32
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert not fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow

    reset_algo_parameters()
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 16
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()

    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_shard_strategy(net)
    expected_strategies = {
        'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]],
        'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]
    }
    assert strategies == expected_strategies
Example #20
 # init context
 context.set_context(mode=context.GRAPH_MODE,
                     device_target=target,
                     save_graphs=False)
 if args_opt.parameter_server:
     context.set_ps_context(enable_ps=True)
 if args_opt.run_distribute:
     if target == "Ascend":
         device_id = int(os.getenv('DEVICE_ID'))
         context.set_context(device_id=device_id,
                             enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(
             device_num=args_opt.device_num,
             parallel_mode=ParallelMode.DATA_PARALLEL,
             gradients_mean=True)
         set_algo_parameters(elementwise_op_strategy_follow=True)
         if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
             context.set_auto_parallel_context(
                 all_reduce_fusion_config=[85, 160])
         else:
             context.set_auto_parallel_context(
                 all_reduce_fusion_config=[180, 313])
         init()
     # GPU target
     else:
         init()
         context.set_auto_parallel_context(
             device_num=get_group_size(),
             parallel_mode=ParallelMode.DATA_PARALLEL,
             gradients_mean=True)
         if args_opt.net == "resnet50":
def run_train_pipeline(args_opt):
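    # Pipeline-parallel PanguAlpha training: SEMI_AUTO_PARALLEL with pipeline
    # stages, optional optimizer sharding, and per-stage parameter groups for
    # the optimizer.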
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id),
          flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(
        local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1

    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.1,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)
    pangu_alpha = PANGUALPHAPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PANGUALPHAWithLossPipeline(config, pangu_alpha,
                                                       loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)

    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    range_max = range_min + per_stage_layers
    if self_stage == 0:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.pangu_alpha_embedding.
                      position_embedding.trainable_params())
    elif self_stage == config.stage_num - 1:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.layernorm.trainable_params())
        params.extend(
            pangu_alpha.backbone.top_query_embedding.trainable_params())
    else:
        params = []
    for i in range(range_min, range_max):
        params.extend(pangu_alpha.backbone.blocks[i].trainable_params())

    decay_filter = lambda x: 'layernorm' not in x.name.lower(
    ) and "bias" not in x.name.lower()

    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': args_opt.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params,
                                       learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.95,
                                       eps=1e-8)

    save_steps = args_opt.save_steps
    ckpt_dir = os.path.join(args_opt.ckpt_save_sir, f"rank_{str(local_rank)}")
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    ds = create_dataset(config.batch_size,
                        data_path=args_opt.data_url,
                        data_start_index=0)

    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, local_rank, config.stage_num)
    ]
    config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
                                 keep_checkpoint_max=1,
                                 integrated_save=False,
                                 filter_prefix="accu_grads")
    ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
                                 directory=ckpt_dir,
                                 config=config_ck)
    callback.append(ckpoint_cb)
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)

    pangu_alpha_with_grads = PANGUALPHATrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss,
        optimizer=optimizer,
        config=config,
        scale_update_cell=update_cell)

    model = Model(pangu_alpha_with_grads)
    de.config.set_sending_batches(2 * args_opt.sink_size)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)
Example #22
def test_set_auto_parallel_context():
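    # Older-API variant of the auto-parallel context test, using
    # mirror_mean/cast_before_mirror instead of gradients_mean/gradient_fp32_sync.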
    context.set_auto_parallel_context(device_num=4,
                                      global_rank=3,
                                      mirror_mean=True,
                                      cast_before_mirror=False,
                                      parallel_mode="auto_parallel",
                                      parameter_broadcast=False)
    device_num = context.get_auto_parallel_context("device_num")
    global_rank = context.get_auto_parallel_context("global_rank")
    mirror_mean = context.get_auto_parallel_context("mirror_mean")
    cast_before_mirror = context.get_auto_parallel_context(
        "cast_before_mirror")
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    parameter_broadcast = context.get_auto_parallel_context(
        "parameter_broadcast")
    assert device_num == 4
    assert global_rank == 3
    assert mirror_mean
    assert not cast_before_mirror
    assert parallel_mode == "auto_parallel"
    assert not parameter_broadcast

    auto_parallel_context().set_communication_backend("hccl")
    backend = auto_parallel_context().get_communication_backend()
    assert backend == "hccl"

    auto_parallel_context().set_device_num(4)
    device_num = auto_parallel_context().get_device_num()
    device_num_is_set = auto_parallel_context().get_device_num_is_set()
    assert device_num == 4
    assert device_num_is_set

    auto_parallel_context().set_global_rank(4)
    global_rank = auto_parallel_context().get_global_rank()
    assert global_rank == 4

    auto_parallel_context().set_mirror_mean(True)
    mirror_mean = auto_parallel_context().get_mirror_mean()
    assert mirror_mean

    auto_parallel_context().set_cast_before_mirror(False)
    cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
    assert not cast_before_mirror

    parameter_broadcast_is_set = auto_parallel_context(
    ).get_parameter_broadcast_is_set()
    assert parameter_broadcast_is_set

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=0)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=4097)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=-1)

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(parallel_mode="wrong_mode")

    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=4096)

    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=0)

    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=1025)

    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    assert context.get_auto_parallel_context("enable_parallel_optimizer")
    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
def run_predict_no_pipeline(args_opt):
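    # Non-pipeline PanguAlpha inference: loads a distributed checkpoint with the
    # inferred predict layout and generates text from a sample prompt.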
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id),
          flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()

    else:
        rank = 0
        device_num = 1

    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.0,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              eod_reset=False,
                              word_emb_dp=True,
                              load_ckpt_path=args_opt.load_ckpt_path)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)

    ckpt_name = args_opt.load_ckpt_name
    pangu_alpha = PANGUALPHA(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)),
                       mstype.int32)
    predict_layout = model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    # For 2.6B and 13B models, the number of ckpt files is 512.

    ckpt_name = 'filerted'
    ckpt_file_list = [
        os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt")
        for ckpt_rank in range(0, 512)
    ]
    print(f"Loading from path {ckpt_file_list[0]}", flush=True)
    load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
    print("================load param ok=================", flush=True)

    from tokenization_jieba import JIEBATokenizer
    from generate import generate
    tokenizer = JIEBATokenizer(
        os.path.join(args_opt.tokenizer_path, 'vocab.vocab'),
        os.path.join(args_opt.tokenizer_path, 'vocab.model'))

    sample = "今天是一个好天气"
    tokenized_token = tokenizer.tokenize(sample)
    start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
    input_ids = np.array(start_sentence).reshape(1, -1)
    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
    output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
    print('Output is:', output_samples, flush=True)