def test_train_4k_8p_gpu(batch_size=32, num_classes=4096):
    """Train resnet50 under AUTO_PARALLEL on 8 GPU devices with a 4k-class head.

    Verifies the searched shard strategies: Conv2D and MatMul split the batch
    dimension across all devices, and ReduceSum follows the same split.
    """
    dev_num = 8
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=dev_num)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # Fixed: original called `resset_op_id()`, an undefined name (typo for
    # `reset_op_id`, which sibling tests in this file use).
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_shard_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
    """Train resnet50 under AUTO_PARALLEL on 8 devices with a 64k-class head.

    With the tuned cost-model beta/gamma, the classifier MatMul/ReduceSum are
    expected to be model-parallel (split on the class dimension).
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001,
                                              costmodel_beta=260.0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # Fixed: original called `resset_op_id()`, an undefined name (typo for
    # `reset_op_id`).
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_strategy(model._train_network)
    for (k, v) in strategies.items():
        # Fixed: original used re.match(k, 'Conv2D-op') etc. with the pattern
        # and string arguments swapped; the op path never matched the short
        # literal, so none of these assertions ever executed. Use
        # re.search(pattern, key) like the sibling tests.
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[1, 1], [dev_num, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[1, dev_num]]
def mindspore_auto_parallel_impl(self, dataset, epoch, device_num):
    """Run one training pass under AUTO_PARALLEL mode and keep its checkpoint.

    The resulting checkpoint is stored on ``self.parallel_ckpt`` for later
    comparison; the auto-parallel context is reset before and after the run.
    """
    set_algo_parameters(fully_use_devices=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=device_num)
    self.parallel_ckpt = self._model_train_and_save_ckpt(
        net=self.parallel_mode_net, dataset=dataset, epoch=epoch)
    context.reset_auto_parallel_context()
def test_train_32k_8p(batch_size=32, num_classes=32768):
    """Smoke-test resnet50 inference under AUTO_PARALLEL on 8 devices."""
    device_count = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=device_count)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    np.random.seed(6)
    images = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))
    model = Model(resnet50(num_classes))
    model.predict(images)
def mindspore_optimizer_auto_parallel_impl(self, dataset, epoch, device_num):
    """Train under AUTO_PARALLEL with optimizer (parameter) sharding enabled.

    Stores the produced checkpoint on ``self.optimizer_parallel_ckpt``; the
    auto-parallel context is reset around the run.
    """
    set_algo_parameters(fully_use_devices=False)
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.AUTO_PARALLEL,
        device_num=device_num,
        enable_parallel_optimizer=True)
    sharded_net = self.net(self.strategy_dict)
    self.optimizer_parallel_ckpt = self._model_train_and_save_ckpt(
        net=sharded_net, dataset=dataset, epoch=epoch)
    context.reset_auto_parallel_context()
def test_star_strategy_consistency2():
    """Compile a star-topology net mixing user-set and searched strategies."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    set_algo_parameters(fully_use_devices=False)
    x = Tensor(np.ones([128, 1000]), dtype=ms.float32)
    # Per-op shard strategies; `None` leaves mul1 to the strategy search.
    strategy_dict = {
        "mul1": None,
        "mul2": ((1, 4), (1, 4)),
        "relu1": ((2, 1),),
        "bias_add": ((4, 2), (2,)),
        "relu2": ((2, 2),),
        "add": ((8, 1), (8, 1)),
    }
    net = NetWithLoss(Net(strategy_dict))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, phase='train')
def test_train_feed2(num_classes=1001):
    """Feed-mode training of resnet50 (1001 classes); check the loss curve.

    Each rank trains on its partition of the generated data; the per-epoch
    losses must match the recorded expected values.
    """
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    # Renamed `dataGen` -> `data_gen` and dropped the unused full-batch
    # returns, matching the sibling test_train_feed implementation.
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 2,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   10.0, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False,
                callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097]
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
def test_train_feed(num_classes=65536):
    """Feed-mode training of resnet50 (64k classes); check the loss curve."""
    set_algo_parameters(elementwise_op_strategy_follow=True)
    callback = ModelCallback()
    generator = DataGenerator()
    _, images = generator.input_data((32 * 8, 3, 224, 224))
    _, labels = generator.label_data((32 * 8,))
    dataset = Dataset(images, labels)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False, callbacks=callback)
    loss_value = np.array(callback.loss_list)
    expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
    print(loss_value)
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
def test_common_parameter():
    """Check shard strategies when two MatMuls share one cast weight.

    Both casts of the shared fp16 weight must stay replicated ([1, 1]) while
    every MatMul splits its left input across all 8 devices.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.weight1 = Parameter(
                Tensor(np.ones([64, 64]).astype(np.float16) * 0.01),
                "w", requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y, z, w):
            # The same fp16 parameter feeds both branches via separate casts.
            left = self.matmul1(x, self.cast1(self.weight1, mstype.float32))
            right = self.matmul2(z, self.cast2(self.weight1, mstype.float32))
            return self.matmul3(right, left)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
    z = Tensor(np.ones([64, 64]), dtype=ms.float32)
    w = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, x, y, z, w, phase='train')
    strategies = _executor._get_strategy(net)
    expected_strategies = {
        'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]],
        'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
        'Default/network-Net/Cast-op2': [[1, 1]],
        'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
        'Default/network-Net/Cast-op4': [[1, 1]],
    }
    assert strategies == expected_strategies
def test_flatten_reshape3(parallel_mode="auto_parallel"):
    """Train a reshape net whose strategy splits beyond the batch size.

    The Reshape strategy ((16, 1)) over a batch of 16 exercises strategy
    propagation through flatten/reshape under auto_parallel.
    """
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReshapeNet(dense_in_channel=2048,
                             dense_out_channel=1000,
                             shape=(128, 1000),
                             strategy=((16, 1),))
    predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)
    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    model = Model(net, loss_fn=CrossEntropyLoss(), optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
    """Train a conv + ReduceMean(keep_dims) net under semi_auto_parallel."""
    batch_size = 16
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReduceMeanNet(conv_in_channel=3,
                                conv_out_channel=64,
                                reducemean_keep_dims=True,
                                strategy=((4, 1, 1, 1),))
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)
    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    model = Model(net, loss_fn=CrossEntropyLoss2(), optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_train_feed(num_classes=8192):
    """Feed-mode training of resnet50 (8192 classes); check the loss curve.

    Each rank trains on its partition of the generated data; the per-epoch
    losses must match the recorded expected values.
    """
    set_algo_parameters(elementwise_op_strategy_follow=True)
    parallel_callback = ModelCallback()
    # Renamed `dataGen` -> `data_gen` and dropped the unused full-batch
    # returns, matching the sibling feed tests.
    data_gen = DataGenerator()
    _, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
    _, label_part = data_gen.label_data((32 * 2,))
    dataset = Dataset(input_part, label_part)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   10.0, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False,
                callbacks=parallel_callback)
    loss_value = np.array(parallel_callback.loss_list)
    expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489]
    # Fixed: original called bare `allclose`, which is not defined/imported;
    # every sibling test uses np.allclose.
    assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)
def test_common_parameter():
    """Check shard strategies for two MatMuls sharing one cast fp16 weight.

    Every MatMul must split its left operand across all 8 devices while the
    shared weight's casts stay replicated.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()
            self.matmul3 = P.MatMul()
            self.weight1 = Parameter(
                Tensor(np.ones([64, 64]).astype(np.float16) * 0.01),
                "w", requires_grad=True)
            self.cast1 = P.Cast()
            self.cast2 = P.Cast()

        def construct(self, x, y):
            # The same fp16 parameter feeds both branches via separate casts.
            left = self.matmul1(x, self.cast1(self.weight1, mstype.float32))
            right = self.matmul2(y, self.cast2(self.weight1, mstype.float32))
            return self.matmul3(right, left)

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, y, phase='train')
    strategies = _executor._get_shard_strategy(net)
    for name, strategy in strategies.items():
        if re.search('MatMul-op', name) is not None:
            assert strategy == [[8, 1], [1, 1]]
        elif re.search('Cast-op', name) is not None:
            assert strategy == [[1, 1]]
def test_two_bn():
    """Check strategies for a net with two BatchNorm blocks around ReLU/Add.

    With elementwise-follow enabled, all four sharded ops must split the
    batch dimension across the 8 devices.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.block1 = get_block()
            self.block2 = get_block()
            self.relu = P.ReLU()
            self.add = P.Add()
            self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)

        def construct(self, x):
            hidden = self.block1(x)
            hidden = self.relu(hidden)
            hidden = self.add(hidden, self.bias)
            return self.block2(hidden)

    context.set_context(save_graphs=False)
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = NetWithLoss(Net())
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net.set_auto_parallel()
    net.set_train()
    set_algo_parameters(elementwise_op_strategy_follow=True)
    reset_op_id()
    _executor.compile(net, x, phase='train')
    strategies = _executor._get_shard_strategy(net)
    assert len(strategies) == 4
    for name, strategy in strategies.items():
        if re.search('BatchNorm-op', name) is not None:
            assert strategy == [[8, 1], [1], [1], [1], [1]]
        elif re.search('TensorAdd-op', name) is not None:
            assert strategy == [[8, 1], [8, 1]]
        elif re.search('ReLU-op', name) is not None:
            assert strategy == [[8, 1]]
def test_set_auto_parallel_context():
    """Exercise the auto-parallel context API end to end.

    Covers: setting/reading values through the public ``context`` wrapper,
    direct access through ``auto_parallel_context()``, rejection of
    out-of-range values, and the parallel-optimizer flag.
    """
    # Set several options in one call, then read each back.
    context.set_auto_parallel_context(device_num=4, global_rank=3,
                                      gradients_mean=True,
                                      gradient_fp32_sync=False,
                                      parallel_mode="auto_parallel",
                                      parameter_broadcast=False)
    device_num = context.get_auto_parallel_context("device_num")
    global_rank = context.get_auto_parallel_context("global_rank")
    gradients_mean = context.get_auto_parallel_context("gradients_mean")
    gradient_fp32_sync = context.get_auto_parallel_context(
        "gradient_fp32_sync")
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    parameter_broadcast = context.get_auto_parallel_context(
        "parameter_broadcast")
    assert device_num == 4
    assert global_rank == 3
    assert gradients_mean
    assert not gradient_fp32_sync
    assert parallel_mode == "auto_parallel"
    assert not parameter_broadcast
    # The low-level singleton must reflect and track the same state.
    auto_parallel_context().set_device_num(4)
    device_num = auto_parallel_context().get_device_num()
    device_num_is_set = auto_parallel_context().get_device_num_is_set()
    assert device_num == 4
    assert device_num_is_set
    auto_parallel_context().set_global_rank(4)
    global_rank = auto_parallel_context().get_global_rank()
    assert global_rank == 4
    auto_parallel_context().set_gradients_mean(True)
    gradients_mean = auto_parallel_context().get_gradients_mean()
    assert gradients_mean
    auto_parallel_context().set_gradient_fp32_sync(False)
    gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
    assert not gradient_fp32_sync
    # parameter_broadcast was set above, so the "is set" flag must be true.
    parameter_broadcast_is_set = auto_parallel_context(
    ).get_parameter_broadcast_is_set()
    assert parameter_broadcast_is_set
    # Invalid values must be rejected (valid ranges appear to be
    # device_num in [1, 4096] and global_rank in [0, 4095] -- inferred from
    # the boundary probes below).
    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=0)
    with pytest.raises(ValueError):
        context.set_auto_parallel_context(device_num=4097)
    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=-1)
    with pytest.raises(ValueError):
        context.set_auto_parallel_context(parallel_mode="wrong_mode")
    with pytest.raises(ValueError):
        context.set_auto_parallel_context(global_rank=4096)
    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=0)
    with pytest.raises(ValueError):
        set_algo_parameters(tensor_slice_align_size=1025)
    # Parallel-optimizer flag round-trips; no fusion split indices are set.
    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    assert context.get_auto_parallel_context("enable_parallel_optimizer")
    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
def test_two_matmul():
    """Exercise cost-model context, algo parameters (incl. approximation and
    single-loop flags), and verify the searched MatMul strategies.

    Flow: set custom cost-model values and read them back; reset and verify
    defaults; same for the strategy-search algo parameters; then compile a
    two-MatMul net on 16 devices and check both MatMuls are batch-split.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    cost_model_context.set_cost_model_context(
        device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
        costmodel_alpha=1.0,
        costmodel_beta=60.0,
        costmodel_gamma=0.1,
        costmodel_communi_threshold=1024.0,
        costmodel_communi_const=2222.0,
        costmodel_communi_bias=1111.0)
    # Each setter value must be readable through the getter.
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 32.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 60.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.1
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 1024.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 2222.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1111.0
    # Reset must restore the documented defaults.
    cost_model_context.reset_cost_model_context()
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 16.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 400.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.001
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 2048.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 3072.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1024.0
    # Algo parameters: set, read back, flip the single-loop flag, then reset.
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False,
                        enable_algo_approxi=True,
                        algo_approxi_epsilon=0.001)
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 32
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert not fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    enable_approxi = get_algo_parameters("enable_algo_approxi")
    assert enable_approxi
    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
    assert algo_epsilon == 0.001
    # Fixed local-name typos: `expecte_single_loop` -> `expected_single_loop`,
    # `signle_loop` -> `single_loop` (behavior unchanged).
    expected_single_loop = True
    single_loop = _get_algo_single_loop()
    assert expected_single_loop == single_loop
    expected_single_loop = False
    _set_algo_single_loop(expected_single_loop)
    single_loop = _get_algo_single_loop()
    assert expected_single_loop == single_loop
    reset_algo_parameters()
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 16
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    enable_approxi = get_algo_parameters("enable_algo_approxi")
    assert not enable_approxi
    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
    assert algo_epsilon == 0.1
    # Compile the net and verify both MatMuls split the batch dimension.
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    net.set_train()
    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_shard_strategy(net)
    for (k, v) in strategies.items():
        if re.search('MatMul-op', k) is not None:
            assert v == [[16, 1], [1, 1]]
def run_predict_pipeline(args_opt):
    """Run PanguAlpha pipeline-parallel inference from a training checkpoint.

    Sets up the Ascend/semi-auto-parallel context, builds the pipeline model,
    loads the per-stage checkpoint shards that this rank's stage needs, and
    runs one prediction on a fake input.
    """
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    # RANK_ID may look like "job-0-7"; keep the trailing number.
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    # Removed a duplicate `device_id = int(os.getenv('DEVICE_ID'))` here.
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.0,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    # Removed unused local `per_stage_layers` (computed but never read).
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    # All cards will save ckpt; the checkpoint was produced by a 16-stage,
    # 1024-device, mp=16 training run.
    train_stage_num = 16
    train_device_num = 1024
    train_mp = 16
    ckpt_name = args_opt.load_ckpt_name
    train_per_stage_num = train_device_num // train_stage_num
    if config.mp != train_mp:
        raise ValueError(
            "the model parallel num is not equal to training model parallel num"
        )
    # Each inference stage covers this many training stages' checkpoints.
    concat_stage_num = train_stage_num // config.stage_num
    pangu_alpha = PANGUALPHAPipeline(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    inputs_np = Tensor(np.ones(shape=(1, config.seq_length)), mstype.int32)
    model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    for i in range(self_stage * concat_stage_num,
                   (self_stage + 1) * concat_stage_num):
        stage_position = local_rank % (config.mp * config.dp)
        # Rank number used at training time.
        ckpt_rank = i * train_per_stage_num + stage_position
        # Checkpoint dirs are still named by the training-time rank.
        ckpt_dir = os.path.join(args_opt.load_ckpt_path,
                                f"rank_{(ckpt_rank)}")
        local_ckpt_file = os.path.join(ckpt_dir, ckpt_name)
        if not os.path.exists(local_ckpt_file):
            # Fixed typo in the error message ("exits" -> "exists").
            raise ValueError("Ckpt file not exists,", local_ckpt_file)
        params_dict = load_checkpoint(local_ckpt_file, filter_prefix="adam")
        load_param_into_net(eval_net, params_dict)
    print("================load param ok=================", flush=True)
    # Here predict with fake input.
    model_predict.predict(inputs_np)
def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
    """Train resnet50 (32k classes) under AUTO_PARALLEL on 8 devices and
    verify both the searched shard strategies and the allreduce fusion plan.

    The cost model is configured with fusion algorithm 1, 2 fusion groups and
    a 0.5 tail percent; the resulting per-parameter fusion indices must match
    the recorded ``expect_dict``.
    """
    dev_num = 8
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                      device_num=dev_num)
    cost_model_context.set_cost_model_context(costmodel_gamma=0.001,
                                              costmodel_beta=260.0)
    cost_model_context.set_cost_model_context(
        costmodel_allreduce_fusion_algorithm=1)
    cost_model_context.set_cost_model_context(
        costmodel_allreduce_fusion_times=2)
    cost_model_context.set_cost_model_context(
        costmodel_allreduce_fusion_tail_percent=0.5)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    # Fixed: original called `resset_op_id()`, an undefined name (typo for
    # `reset_op_id`).
    reset_op_id()
    np.random.seed(6)
    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
    label_np = np.zeros([batch_size]).astype(np.int32)
    for i in range(0, batch_size):
        label_np[i] = i % num_classes
    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
    net = resnet50(num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                   0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(5, dataset, dataset_sink_mode=False)
    strategies = _executor._get_strategy(model._train_network)
    for (k, v) in strategies.items():
        if re.search('Conv2D-op', k) is not None:
            assert v[0][0] == dev_num
        elif re.search('MatMul-op', k) is not None:
            assert v == [[dev_num, 1], [1, 1]]
        elif re.search('ReduceSum-op', k) is not None:
            assert v == [[dev_num, 1]]
    allreduce_fusion_dict = _executor._get_allreduce_fusion(
        model._train_network)
    print(allreduce_fusion_dict)
    # Expected fusion-group index for every trainable parameter.
    expect_dict = {'end_point.bias': 2,
                   'end_point.weight': 2,
                   'layer4.2.bn3.beta': 2,
                   'layer4.2.bn3.gamma': 2,
                   'layer4.2.conv3.weight': 2,
                   'layer4.2.bn2.beta': 2,
                   'layer4.2.bn2.gamma': 2,
                   'layer4.2.conv2.weight': 2,
                   'layer4.2.bn1.beta': 2,
                   'layer4.2.bn1.gamma': 2,
                   'layer4.2.conv1.weight': 2,
                   'layer4.1.bn3.beta': 2,
                   'layer4.1.bn3.gamma': 2,
                   'layer4.1.conv3.weight': 2,
                   'layer4.1.bn2.beta': 2,
                   'layer4.1.bn2.gamma': 2,
                   'layer4.1.conv2.weight': 2,
                   'layer4.1.bn1.beta': 2,
                   'layer4.1.bn1.gamma': 2,
                   'layer4.1.conv1.weight': 2,
                   'layer4.0.bn_down_sample.beta': 2,
                   'layer4.0.bn_down_sample.gamma': 2,
                   'layer4.0.conv_down_sample.weight': 2,
                   'layer4.0.bn3.beta': 2,
                   'layer4.0.bn3.gamma': 2,
                   'layer4.0.conv3.weight': 2,
                   'layer4.0.bn2.beta': 2,
                   'layer4.0.bn2.gamma': 2,
                   'layer4.0.conv2.weight': 2,
                   'layer4.0.bn1.beta': 2,
                   'layer4.0.bn1.gamma': 2,
                   'layer4.0.conv1.weight': 2,
                   'layer3.5.bn3.beta': 2,
                   'layer3.5.bn3.gamma': 2,
                   'layer3.5.conv3.weight': 2,
                   'layer3.5.bn2.beta': 2,
                   'layer3.5.bn2.gamma': 2,
                   'layer3.5.conv2.weight': 2,
                   'layer3.5.bn1.beta': 2,
                   'layer3.5.bn1.gamma': 2,
                   'layer3.5.conv1.weight': 2,
                   'layer3.4.bn3.beta': 2,
                   'layer3.4.bn3.gamma': 2,
                   'layer3.4.conv3.weight': 2,
                   'layer3.4.bn2.beta': 2,
                   'layer3.4.bn2.gamma': 2,
                   'layer3.4.conv2.weight': 2,
                   'layer3.4.bn1.beta': 2,
                   'layer3.4.bn1.gamma': 2,
                   'layer3.4.conv1.weight': 2,
                   'layer3.3.bn3.beta': 2,
                   'layer3.3.bn3.gamma': 2,
                   'layer3.3.conv3.weight': 2,
                   'layer3.3.bn2.beta': 2,
                   'layer3.3.bn2.gamma': 2,
                   'layer3.3.conv2.weight': 2,
                   'layer3.3.bn1.beta': 2,
                   'layer3.3.bn1.gamma': 2,
                   'layer3.3.conv1.weight': 2,
                   'layer3.2.bn3.beta': 2,
                   'layer3.2.bn3.gamma': 2,
                   'layer3.2.conv3.weight': 2,
                   'layer3.2.bn2.beta': 2,
                   'layer3.2.bn2.gamma': 2,
                   'layer3.2.conv2.weight': 2,
                   'layer3.2.bn1.beta': 2,
                   'layer3.2.bn1.gamma': 2,
                   'layer3.2.conv1.weight': 2,
                   'layer3.1.bn3.beta': 2,
                   'layer3.1.bn3.gamma': 2,
                   'layer3.1.conv3.weight': 2,
                   'layer3.1.bn2.beta': 2,
                   'layer3.1.bn2.gamma': 2,
                   'layer3.1.conv2.weight': 2,
                   'layer3.1.bn1.beta': 2,
                   'layer3.1.bn1.gamma': 2,
                   'layer3.1.conv1.weight': 2,
                   'layer3.0.bn_down_sample.beta': 1,
                   'layer3.0.bn_down_sample.gamma': 1,
                   'layer3.0.conv_down_sample.weight': 2,
                   'layer3.0.bn3.beta': 1,
                   'layer3.0.bn3.gamma': 1,
                   'layer3.0.conv3.weight': 2,
                   'layer3.0.bn2.beta': 2,
                   'layer3.0.bn2.gamma': 2,
                   'layer3.0.conv2.weight': 2,
                   'layer3.0.bn1.beta': 2,
                   'layer3.0.bn1.gamma': 2,
                   'layer3.0.conv1.weight': 2,
                   'layer2.3.bn3.beta': 2,
                   'layer2.3.bn3.gamma': 2,
                   'layer2.3.conv3.weight': 2,
                   'layer2.3.bn2.beta': 2,
                   'layer2.3.bn2.gamma': 2,
                   'layer2.3.conv2.weight': 2,
                   'layer2.3.bn1.beta': 2,
                   'layer2.3.bn1.gamma': 2,
                   'layer2.3.conv1.weight': 2,
                   'layer2.2.bn3.beta': 2,
                   'layer2.2.bn3.gamma': 2,
                   'layer2.2.conv3.weight': 2,
                   'layer2.2.bn2.beta': 2,
                   'layer2.2.bn2.gamma': 2,
                   'layer2.2.conv2.weight': 2,
                   'layer2.2.bn1.beta': 2,
                   'layer2.2.bn1.gamma': 2,
                   'layer2.2.conv1.weight': 2,
                   'layer2.1.bn3.beta': 1,
                   'layer2.1.bn3.gamma': 1,
                   'layer2.1.conv3.weight': 2,
                   'layer2.1.bn2.beta': 2,
                   'layer2.1.bn2.gamma': 2,
                   'layer2.1.conv2.weight': 2,
                   'layer2.1.bn1.beta': 2,
                   'layer2.1.bn1.gamma': 2,
                   'layer2.1.conv1.weight': 2,
                   'layer2.0.bn_down_sample.beta': 1,
                   'layer2.0.bn_down_sample.gamma': 1,
                   'layer2.0.conv_down_sample.weight': 2,
                   'layer2.0.bn3.beta': 1,
                   'layer2.0.bn3.gamma': 1,
                   'layer2.0.conv3.weight': 2,
                   'layer2.0.bn2.beta': 2,
                   'layer2.0.bn2.gamma': 2,
                   'layer2.0.conv2.weight': 2,
                   'layer2.0.bn1.beta': 2,
                   'layer2.0.bn1.gamma': 2,
                   'layer2.0.conv1.weight': 2,
                   'layer1.2.bn3.beta': 2,
                   'layer1.2.bn3.gamma': 2,
                   'layer1.2.conv3.weight': 2,
                   'layer1.2.bn2.beta': 2,
                   'layer1.2.bn2.gamma': 2,
                   'layer1.2.conv2.weight': 2,
                   'layer1.2.bn1.beta': 2,
                   'layer1.2.bn1.gamma': 2,
                   'layer1.2.conv1.weight': 2,
                   'layer1.1.bn3.beta': 1,
                   'layer1.1.bn3.gamma': 1,
                   'layer1.1.conv3.weight': 2,
                   'layer1.1.bn2.beta': 2,
                   'layer1.1.bn2.gamma': 2,
                   'layer1.1.conv2.weight': 2,
                   'layer1.1.bn1.beta': 2,
                   'layer1.1.bn1.gamma': 2,
                   'layer1.1.conv1.weight': 2,
                   'layer1.0.bn_down_sample.beta': 1,
                   'layer1.0.bn_down_sample.gamma': 1,
                   'layer1.0.conv_down_sample.weight': 2,
                   'layer1.0.bn3.beta': 1,
                   'layer1.0.bn3.gamma': 1,
                   'layer1.0.conv3.weight': 2,
                   'layer1.0.bn2.beta': 2,
                   'layer1.0.bn2.gamma': 2,
                   'layer1.0.conv2.weight': 2,
                   'layer1.0.bn1.beta': 2,
                   'layer1.0.bn1.gamma': 2,
                   'layer1.0.conv1.weight': 2,
                   'bn1.beta': 1,
                   'bn1.gamma': 1,
                   'conv1.weight': 2}
    assert (allreduce_fusion_dict == expect_dict)
    cost_model_context.reset_cost_model_context()
def test_two_matmul():
    """Exercise the cost-model context and algo-parameter API, then verify
    the searched strategies for a two-MatMul net on 16 devices.

    Flow: set custom cost-model values and read them back; reset and verify
    the defaults; same round-trip for the strategy-search algo parameters;
    finally compile the net and check both MatMuls are batch-split 16 ways.
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.matmul1 = P.MatMul()
            self.matmul2 = P.MatMul()

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.matmul2(out, b)
            return out

    size = 16
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    cost_model_context.set_cost_model_context(
        device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
        costmodel_alpha=1.0,
        costmodel_beta=60.0,
        costmodel_gamma=0.1,
        costmodel_communi_threshold=1024.0,
        costmodel_communi_const=2222.0,
        costmodel_communi_bias=1111.0)
    # Each setter value must be readable through the getter.
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 32.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 60.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.1
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 1024.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 2222.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1111.0
    # Reset must restore the defaults.
    cost_model_context.reset_cost_model_context()
    dev_mem_cap = cost_model_context.get_cost_model_context(
        "device_memory_capacity")
    assert dev_mem_cap == 16.0 * 1024.0 * 1024.0 * 1024.0
    costmodel_alpha = cost_model_context.get_cost_model_context(
        "costmodel_alpha")
    assert costmodel_alpha == 1.0
    costmodel_beta = cost_model_context.get_cost_model_context(
        "costmodel_beta")
    assert costmodel_beta == 400.0
    costmodel_gamma = cost_model_context.get_cost_model_context(
        "costmodel_gamma")
    assert costmodel_gamma == 0.001
    costmodel_communi_threshold = cost_model_context.get_cost_model_context(
        "costmodel_communi_threshold")
    assert costmodel_communi_threshold == 2048.0
    costmodel_communi_const = cost_model_context.get_cost_model_context(
        "costmodel_communi_const")
    assert costmodel_communi_const == 3072.0
    costmodel_communi_bias = cost_model_context.get_cost_model_context(
        "costmodel_communi_bias")
    assert costmodel_communi_bias == 1024.0
    # Algo parameters: set, read back, then reset and verify defaults.
    set_algo_parameters(tensor_slice_align_enable=False,
                        tensor_slice_align_size=32,
                        fully_use_devices=False,
                        elementwise_op_strategy_follow=False)
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 32
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert not fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    reset_algo_parameters()
    para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
    assert not para_slice_align_enable
    para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
    assert para_slice_align_size == 16
    fully_use_devices = get_algo_parameters("fully_use_devices")
    assert fully_use_devices
    elementwise_op_strategy_follow = get_algo_parameters(
        "elementwise_op_strategy_follow")
    assert not elementwise_op_strategy_follow
    # Compile and check the searched shard strategies for both MatMuls.
    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = NetWithLoss(Net())
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    reset_op_id()
    _executor.compile(net, x, y, b, phase='train')
    strategies = _executor._get_shard_strategy(net)
    expected_strategies = {
        'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]],
        'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]
    }
    assert strategies == expected_strategies
# init context context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) if args_opt.parameter_server: context.set_ps_context(enable_ps=True) if args_opt.run_distribute: if target == "Ascend": device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context( device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) set_algo_parameters(elementwise_op_strategy_follow=True) if args_opt.net == "resnet50" or args_opt.net == "se-resnet50": context.set_auto_parallel_context( all_reduce_fusion_config=[85, 160]) else: context.set_auto_parallel_context( all_reduce_fusion_config=[180, 313]) init() # GPU target else: init() context.set_auto_parallel_context( device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) if args_opt.net == "resnet50":
def run_train_pipeline(args_opt):
    """Train a pipeline-parallel PanguAlpha model on Ascend.

    Reads DEVICE_ID / RANK_ID from the environment, configures the MindSpore
    semi-auto-parallel context (distributed when ``args_opt.distribute ==
    "true"``), builds the pipelined network wrapped with dynamic loss
    scaling, groups parameters for weight decay, and runs ``Model.train``
    with time/loss monitoring and checkpoint callbacks.

    Args:
        args_opt: parsed command-line options (distribute, stage_num,
            micro_size, tensor_model_parallel_num, optimizer, data_url,
            sink_size, save_steps, learning-rate schedule fields, ...).
    """
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    # Each rank writes its own parallel-strategy checkpoint under /cache.
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(
        local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        # Initialize HCCL collectives and switch to semi-auto parallel with
        # pipeline stages; optimizer sharding is controlled by the CLI flag.
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        # Single-device fallback.
        rank = 0
        device_num = 1
    # Derive data-parallel width from the devices left in one pipeline stage
    # after tensor-model parallelism is taken out.
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    # Global batch covers all data-parallel copies times the micro-batches
    # used for pipeline interleaving.
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.1,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)
    pangu_alpha = PANGUALPHAPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PANGUALPHAWithLossPipeline(config, pangu_alpha,
                                                       loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)
    # Figure out which contiguous span of transformer layers belongs to this
    # rank's pipeline stage.
    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    range_max = range_min + per_stage_layers
    if self_stage == 0:
        # First stage also owns the word- and position-embedding parameters.
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.pangu_alpha_embedding.
                      position_embedding.trainable_params())
    elif self_stage == config.stage_num - 1:
        # Last stage owns the shared embedding table (tied output head),
        # the final layernorm and the top-query embedding.
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.layernorm.trainable_params())
        params.extend(
            pangu_alpha.backbone.top_query_embedding.trainable_params())
    else:
        params = []
    # Every stage additionally trains its own slice of transformer blocks.
    for i in range(range_min, range_max):
        params.extend(pangu_alpha.backbone.blocks[i].trainable_params())
    # Weight decay is applied to everything except layernorm and bias params.
    decay_filter = lambda x: 'layernorm' not in x.name.lower(
    ) and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': args_opt.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params,
                                       learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.95,
                                       eps=1e-8)
    save_steps = args_opt.save_steps
    # NOTE(review): attribute name 'ckpt_save_sir' looks like a typo for
    # 'ckpt_save_dir' — confirm against the argument parser before renaming.
    ckpt_dir = os.path.join(args_opt.ckpt_save_sir, f"rank_{str(local_rank)}")
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    ds = create_dataset(config.batch_size,
                        data_path=args_opt.data_url,
                        data_start_index=0)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    # Training is driven in sink-size chunks, so the "epoch" count passed to
    # Model.train is total steps divided by the sink size.
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, local_rank, config.stage_num)
    ]
    config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
                                 keep_checkpoint_max=1,
                                 integrated_save=False,
                                 filter_prefix="accu_grads")
    ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
                                 directory=ckpt_dir,
                                 config=config_ck)
    callback.append(ckpoint_cb)
    # Initial dynamic loss scale of 2**32; halves on overflow, doubles every
    # 1000 clean steps (DynamicLossScaleUpdateCell defaults below).
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    pangu_alpha_with_grads = PANGUALPHATrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss,
        optimizer=optimizer,
        config=config,
        scale_update_cell=update_cell)
    model = Model(pangu_alpha_with_grads)
    de.config.set_sending_batches(2 * args_opt.sink_size)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)
def test_set_auto_parallel_context():
    """Exercise the auto-parallel context: batch setters/getters via
    ``context``, the singleton ``auto_parallel_context()`` accessors, and
    validation errors for out-of-range or unknown values."""
    context.set_auto_parallel_context(device_num=4,
                                      global_rank=3,
                                      mirror_mean=True,
                                      cast_before_mirror=False,
                                      parallel_mode="auto_parallel",
                                      parameter_broadcast=False)
    # Every value written above must read back unchanged.
    assert context.get_auto_parallel_context("device_num") == 4
    assert context.get_auto_parallel_context("global_rank") == 3
    assert context.get_auto_parallel_context("mirror_mean")
    assert not context.get_auto_parallel_context("cast_before_mirror")
    assert context.get_auto_parallel_context("parallel_mode") == "auto_parallel"
    assert not context.get_auto_parallel_context("parameter_broadcast")

    # The low-level singleton exposes the same state plus "is set" flags.
    auto_parallel_context().set_communication_backend("hccl")
    assert auto_parallel_context().get_communication_backend() == "hccl"
    auto_parallel_context().set_device_num(4)
    assert auto_parallel_context().get_device_num() == 4
    assert auto_parallel_context().get_device_num_is_set()
    auto_parallel_context().set_global_rank(4)
    assert auto_parallel_context().get_global_rank() == 4
    auto_parallel_context().set_mirror_mean(True)
    assert auto_parallel_context().get_mirror_mean()
    auto_parallel_context().set_cast_before_mirror(False)
    assert not auto_parallel_context().get_cast_before_mirror()
    assert auto_parallel_context().get_parameter_broadcast_is_set()

    # Out-of-range or unknown settings must be rejected with ValueError.
    for bad in ({"device_num": 0}, {"device_num": 4097}, {"global_rank": -1},
                {"parallel_mode": "wrong_mode"}, {"global_rank": 4096}):
        with pytest.raises(ValueError):
            context.set_auto_parallel_context(**bad)
    for bad_size in (0, 1025):
        with pytest.raises(ValueError):
            set_algo_parameters(tensor_slice_align_size=bad_size)

    context.set_auto_parallel_context(enable_parallel_optimizer=True)
    assert context.get_auto_parallel_context("enable_parallel_optimizer")
    assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
def run_predict_no_pipeline(args_opt):
    """Run single-sample PanguAlpha text generation without pipeline stages.

    Configures the semi-auto-parallel context (distributed when
    ``args_opt.distribute == "true"``), builds the eval network, loads a
    sharded distributed checkpoint, tokenizes a fixed Chinese prompt with the
    jieba tokenizer and prints the generated continuation.

    Args:
        args_opt: parsed command-line options (distribute, stage_num,
            tensor_model_parallel_num, load_ckpt_path, tokenizer_path,
            strategy_load_ckpt_path, model hyper-parameters, ...).

    Fix vs. original: DEVICE_ID was read from the environment twice (the
    first read was immediately overwritten), and ``ckpt_name`` was assigned
    from ``args_opt.load_ckpt_name`` only to be clobbered by the hard-coded
    prefix before any use; both dead statements are removed.
    """
    rank_id_str = os.getenv('RANK_ID', '0')
    # RANK_ID may be of the form "job-...-<n>"; keep only the trailing int.
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    device_id = int(os.getenv('DEVICE_ID'))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="30GB")
    if args_opt.distribute == "true":
        # Initialize collectives and restore the parallel strategy used at
        # training time so the checkpoint shards map onto this topology.
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    model_parallel_num = args_opt.tensor_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num
    # Inference config: dropout disabled, eod_reset off, fp16 compute.
    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.0,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              eod_reset=False,
                              word_emb_dp=True,
                              load_ckpt_path=args_opt.load_ckpt_path)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)
    pangu_alpha = PANGUALPHA(config)
    eval_net = EvalNet(pangu_alpha)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    # A dummy all-ones batch is enough to derive the sharded weight layout.
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)),
                       mstype.int32)
    predict_layout = model_predict.infer_predict_layout(inputs_np)
    print("======start load_distributed checkpoint", flush=True)
    # For 2.6B and 13B models, the number of ckpt files is 512.
    # NOTE(review): 'filerted' is kept verbatim — it must match the prefix of
    # the checkpoint files actually on disk; confirm before "fixing" it.
    ckpt_name = 'filerted'
    ckpt_file_list = [
        os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt")
        for ckpt_rank in range(0, 512)
    ]
    print(f"Loading from path {ckpt_file_list[0]}", flush=True)
    load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
    print("================load param ok=================", flush=True)
    from tokenization_jieba import JIEBATokenizer
    from generate import generate
    tokenizer = JIEBATokenizer(
        os.path.join(args_opt.tokenizer_path, 'vocab.vocab'),
        os.path.join(args_opt.tokenizer_path, 'vocab.model'))
    sample = "今天是一个好天气"
    tokenized_token = tokenizer.tokenize(sample)
    start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
    input_ids = np.array(start_sentence).reshape(1, -1)
    output_ids = generate(model_predict, input_ids, config.seq_length, 9)
    output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
    print('Output is:', output_samples, flush=True)