Example #1
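This method exercises the input validation of the distributed checkpoint APIs: every call is expected to fail. All three save/load examples appear to come from PaddlePaddle's auto-parallel unit tests and rely on module-level imports, globals, and a get_distributed_program() helper that the excerpt omits. A minimal sketch of that assumed context, with module paths inferred from the identifiers rather than taken from the excerpt:

import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
from paddle.distributed.auto_parallel.utils import (
    save_distributed_checkpoint, load_distributed_checkpoint,
    load_checkpoint_into_program, load_parameter_into_program,
    get_dist_attr, merge_and_slice_parameter)

_global_parallel_strategy = None
_global_process_mesh = None

# Assumed helper, not shown in the excerpt: builds the annotated MLP program
# under the current globals and returns
# (dist_main_prog, dist_startup_prog, loss).
# def get_distributed_program(): ...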
    def test_input_invalid(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])

        dist_main_prog, _, _ = get_distributed_program()
        # addition_info must be a dict, so a list raises TypeError.
        with self.assertRaises(TypeError):
            save_distributed_checkpoint(dist_main_prog, [""], [""],
                                        addition_info=[0])
        # An unsupported key ("step") raises ValueError ...
        with self.assertRaises(ValueError):
            save_distributed_checkpoint(dist_main_prog, [""], [""],
                                        addition_info={"step": 0})
        # ... as does a non-integer value for a supported key.
        with self.assertRaises(ValueError):
            save_distributed_checkpoint(dist_main_prog, [""], [""],
                                        addition_info={"batch": 0.0})
        # Paths that do not point at saved per-rank files raise ValueError.
        with self.assertRaises(ValueError):
            load_checkpoint_into_program(["./model_state_rank.pdmodel"],
                                         ["./dist_attr_rank.pdattr"],
                                         dist_main_prog)
        with self.assertRaises(ValueError):
            load_distributed_checkpoint(["./model_state_rank.pdmodel"],
                                        ["./dist_attr_rank.pdattr"])
        # Path arguments must be lists, not dicts, so this raises TypeError.
        with self.assertRaises(TypeError):
            load_distributed_checkpoint({"0": "./model_state_rank.pdmodel"},
                                        {"1": "./dist_attr_rank.pdattr"})
Example #2
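This method trains an MLP under model parallelism ("mp") on a two-card mesh, saves a distributed checkpoint at step 10, rebuilds the same model under pipeline parallelism ("pp"), loads the per-rank checkpoint with load_checkpoint_into_program, replays steps 10-19, and finally checks that rank 1 reaches the same loss as the uninterrupted run.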
    def test_mlp_mp2pp(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')

        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        for step in range(20):
            if step == 10:
                save_distributed_checkpoint(dist_main_prog,
                                            ".",
                                            dist_attr_path=".")

            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])
        last_res = res[0]

        set_default_distributed_context(None)
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1])
        global PP_MESH_0
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        global PP_MESH_1
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        dist_main_prog_load, dist_start_prog_load, loss_load = (
            get_distributed_program())
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog_load)

        ckpt_path = [
            "./model_state_rank0.pdmodel", "./model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./dist_attr_rank0.pdattr", "./dist_attr_rank1.pdattr"
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path,
                                     dist_main_prog_load)
        for step in range(10, 20):
            if paddle.distributed.get_rank() == 0:
                res = exe.run(dist_main_prog_load,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog_load,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss_load])
        if paddle.distributed.get_rank() == 1:
            self.assertEqual(last_res, res[0])
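Because both programs are built on a two-card ProcessMesh, the test presumably runs one process per GPU under a launcher such as python -m paddle.distributed.launch --gpus 0,1. Each rank then writes its own pair of files into ".", which is where the model_state_rank{0,1}.pdmodel and dist_attr_rank{0,1}.pdattr names in ckpt_path and dist_attr_path come from.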
Example #3
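This method goes the other way (pipeline to model parallelism) and, because the sharding changes, uses the lower-level flow: load_distributed_checkpoint returns the raw parameter dict, the dist_attr the checkpoint was saved with, and the addition_info dict; merge_and_slice_parameter reshards the parameters from that old layout to the one reported by get_dist_attr for the new program; load_parameter_into_program installs the result. The saved addition_info ("batch", "batch_size") is used to fast-forward the input to where training stopped.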
    def test_mlp_pp2mp(self):
        set_default_distributed_context(None)
        global _global_parallel_strategy
        _global_parallel_strategy = "pp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])
        global PP_MESH_0
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        global PP_MESH_1
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])
        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')

        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)
        for step in range(20):
            if step == 10:
                add_info = {"batch": step, "batch_size": 4}
                save_distributed_checkpoint(dist_main_prog, ".", ".", add_info)

            if paddle.distributed.get_rank() == 0:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss])
        if paddle.distributed.get_rank() == 1:
            last_res = res[0]

        set_default_distributed_context(None)
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1])

        dist_main_prog_load, dist_start_prog_load, loss_load = (
            get_distributed_program())
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog_load)
        ckpt_path = [
            "./model_state_rank0.pdmodel", "./model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./dist_attr_rank0.pdattr", "./dist_attr_rank1.pdattr"
        ]
        param_dict, pre_dist_attr, add_info = load_distributed_checkpoint(
            ckpt_path, dist_attr_path)
        batch = add_info["batch"]
        batch_size = add_info["batch_size"]
        start_index = batch * batch_size
        input = input[start_index:, :]
        label = label[start_index:, :]
        cur_dist_attr = get_dist_attr(dist_main_prog_load)
        sliced_param_dict = merge_and_slice_parameter(param_dict,
                                                      pre_dist_attr,
                                                      cur_dist_attr)
        load_parameter_into_program(sliced_param_dict, dist_main_prog_load)
        for step in range(10):
            res = exe.run(dist_main_prog_load,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss_load])
        if paddle.distributed.get_rank() == 1:
            self.assertEqual(last_res, res[0])
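Example #4

test_gpt_dp_mp belongs to a different suite, the auto-parallel partitioner tests for GPT pretraining, and likewise omits its module-level context. A sketch of the likely imports, again inferred from the identifiers used (FakeFleet, gpt_pretrain_forward, check_tensor_split, and is_valid_completed_program are helpers defined elsewhere in that test file):

import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.utils import _get_comm_group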
    def test_gpt_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh

        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        startup_program = static.Program()
        parallelizer = AutoParallelizer(FakeFleet())
        dist_context = parallelizer._dist_context

        dist_context.process_mesh = _global_process_mesh
        train_program, startup_program, loss = gpt_pretrain_forward(
            train_program, startup_program)
        complete_train_program = auto.complete_annotation(
            train_program, dist_context)

        # serial backward pass
        params_grads = parallelizer._generate_backward(complete_train_program,
                                                       startup_program,
                                                       loss,
                                                       parameter_list=None,
                                                       no_grad_set=None,
                                                       callbacks=None)

        # Partition the annotated program from the viewpoint of rank 3.
        rank_id = 3
        partitioner = Partitioner(dist_context, rank_id)
        auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        with open("./test_auto_parallel_partitioner_serial_main_new.txt",
                  "w") as fw:
            fw.write(str(train_program))
        with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
                  "w") as fw:
            fw.write(str(startup_program))

        from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
        set_default_distributed_context(dist_context)
        with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
            fw.write(str(auto_parallel_main_prog))
        with open("./test_auto_parallel_partitioner_startup_new.txt1",
                  "w") as fw:
            fw.write(str(auto_parallel_startup_prog))
        # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
        #     complete_backward_annotation(auto_parallel_main_prog)
        #     fw.write(str(auto_parallel_main_prog))
        nrank = 4
        # col parallel
        weights = [
            'linear_0.w_0',
            'linear_6.w_0',
            'linear_10.w_0',
        ]
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 1, nrank))

        # row parallel
        weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, nrank))

        weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, 1))

        all_params = sorted(
            [param.name for param in startup_program.all_parameters()])
        allreduce_grads = [
            'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
            'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
            'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
        ]
        process_mesh = _global_process_mesh
        mp_parallel_axis = 1
        dp_parallel_axis = 0

        # Ranks sharing rank 3's position on every axis except the mp axis.
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, mp_parallel_axis,
                                      3)
        mp_ring_id = new_process_group(group_ranks).id

        # Ranks sharing rank 3's position on every axis except the dp axis.
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, dp_parallel_axis,
                                      3)
        dp_ring_id = new_process_group(group_ranks).id

        # c_allreduce_sum ops in the backward pass (op_role == 1) on the mp
        # ring are tensor-parallel partial-sum reductions; all c_allreduce_sum
        # ops on the dp ring are data-parallel gradient reductions.
        tensor_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
                and op.desc.attr("ring_id") == mp_ring_id)
        ])
        data_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum"
                and op.desc.attr("ring_id") == dp_ring_id)
        ])

        self.assertEqual(all_params, data_parallel_allreduce_vars)
        self.assertEqual(allreduce_grads, tensor_parallel_allreduce_vars)

        self.assertTrue(
            is_valid_completed_program(dist_context, auto_parallel_main_prog))
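The two ring ids decide which c_allreduce_sum belongs to which kind of parallelism. For this 2 x 4 mesh and rank_id = 3, the mp group (vary axis 1, hold the dp coordinate) should be [0, 1, 2, 3] and the dp group (vary axis 0) should be [3, 7]. A small self-contained re-derivation of that grouping, sketching what _get_comm_group is expected to return rather than reproducing Paddle's implementation:

import numpy as np

def comm_group(processes, topology, axis, rank):
    # All ranks that differ from `rank` only along `axis` of the mesh.
    mesh = np.array(processes).reshape(topology)
    coord = [int(c[0]) for c in np.where(mesh == rank)]
    index = tuple(slice(None) if d == axis else c
                  for d, c in enumerate(coord))
    return [int(r) for r in mesh[index]]

procs, topo = list(range(8)), [2, 4]
print(comm_group(procs, topo, 1, 3))  # mp group -> [0, 1, 2, 3]
print(comm_group(procs, topo, 0, 3))  # dp group -> [3, 7]

Under that reading, the final assertions appear to check that every parameter receives a data-parallel gradient allreduce on the dp ring, while the backward-pass allreduces on the mp ring touch exactly the expected layer_norm output gradients (the "@GRAD" suffix is stripped by split("@")[0]).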