def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    """Build, partition and reshard the MLP programs for one rank.

    Steps, in order:
      1. ``mlp_forward`` builds the serial forward program and loss.
      2. The global ``fleet`` object gets a fresh ``DistributedStrategy`` and
         an ``AdamOptimizer``; an ``AutoParallelizer`` wrapping it is told to
         use ``dist_context``.
      3. ``Completer`` completes the forward annotations, the forward blocks
         are parsed into ``dist_context.block_state``, and the backward pass
         is generated.
      4. ``Partitioner`` partitions the programs and params/grads for
         ``rank_id``.
      5. ``_apply_optimize`` is run on the partitioned programs.
      6. ``Resharder`` reshards the partitioned programs for ``rank_id``.

    Args:
        train_program: serial main program passed to ``mlp_forward``.
        startup_program: serial startup program passed to ``mlp_forward``.
        dist_context: ``DistributedContext`` shared by completion,
            partitioning and resharding.
        rank_id: rank whose partitioned programs are produced.

    Returns:
        ``(dist_train_program, dist_startup_prog)``: the partitioned programs
        for ``rank_id`` after ``Resharder.reshard()`` has run on them.

    NOTE(review): test methods that appear alongside this helper unpack three
    values (``..., dist_params_grads``), one passes a fifth positional
    argument, and they construct a ``Resharder`` again themselves. They look
    like they target a different version of this helper, one that returns
    ``dist_params_grads`` and does not reshard. Confirm which contract the
    callers in this file actually expect.
    """
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    # Configure the global fleet object, then wrap it in a parallelizer that
    # shares the caller's distributed context.
    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # Serial forward completion, then backward generation.
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # Logical partition for this rank.
    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    # Kept for its side effects (presumably appending optimizer ops to the
    # partitioned programs). The returned ops were never used, so they are
    # no longer bound to a name.
    parallelizer._apply_optimize(dist_train_program, dist_startup_prog,
                                 dist_params_grads)

    # Reshard the partitioned programs for this rank before returning them.
    resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()
    return dist_train_program, dist_startup_prog
    def test_mlp_pp(self):
        """Pipeline-parallel MLP on mesh [0, 1], checked from rank 1.

        Installs the "pp" strategy and the per-stage meshes as module
        globals, builds the distributed programs, empties the process-group
        registry, reshards, and then checks the send/recv ops and the
        per-rank parameter initialization.
        """
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        context = DistributedContext()
        rank = 1
        dist_main, dist_startup, dist_pg = get_dist_prog(
            main_prog, start_prog, context, rank)

        # Start resharding from an empty process-group registry.
        for stale_key in list(_g_process_group_map.keys()):
            del _g_process_group_map[stale_key]
        Resharder(dist_main, dist_startup, rank, context, dist_pg).reshard()

        # Send/recv ops are expected in the pipeline program.
        self.assertTrue(check_send_recv_result(dist_main, rank))
        # In the pipeline setting, each rank's parameter initialization
        # should differ.
        self.assertTrue(check_initialization(dist_startup, rank))
    def test_allgather(self):
        """Resharding a dim-0-sharded matmul input on mesh [0, 3].

        ``x`` is annotated with dims_mapping ``[0, -1]``, but the matmul
        requests it with ``[-1, -1]``. After partitioning and resharding for
        rank 0, ``check_allgather`` verifies that ``x`` is not handled by
        slicing (the test name indicates an allgather is expected).
        """
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        process_mesh = auto.ProcessMesh(mesh=[0, 3])
        with static.program_guard(train_program, startup_program):
            # x: dim 0 mapped to mesh dim 0; w: no dim mapped (dims_mapping -1).
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            x = auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [0, -1]
                                  })

            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
            w = auto.shard_tensor(w,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [-1, -1]
                                  })

            # The matmul asks for both inputs with dims_mapping [-1, -1], so
            # x has to be resharded from [0, -1].
            y = paddle.distributed.shard_op(paddle.matmul,
                                            dist_attr={
                                                "process_mesh": process_mesh,
                                                x: {
                                                    "dims_mapping": [-1, -1]
                                                },
                                                w: {
                                                    "dims_mapping": [-1, -1]
                                                }
                                            })(x, w)[0]

        rank_id = 0
        dist_context = DistributedContext()
        partitioner = Partitioner(dist_context, rank_id)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        dist_context.block_state.parse_forward_blocks(complete_train_program)
        # No backward pass is generated here, so an empty params/grads list
        # is partitioned.
        partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
            complete_train_program, startup_program, [])
        resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                              rank_id, dist_context, partitioned_params_grads)
        resharder.reshard()
        # x should not be sliced; check_allgather verifies the resharded program.
        self.assertTrue(check_allgather(partitioned_main_prog))
    def test_mlp_dpmppp(self):
        """Reshard the MLP programs for rank 2 and check the dp+mp+pp results.

        Verifies the send/recv ops in the main program and the parameter
        initialization via ``check_initialization_for_dpmppp``.
        """
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        context = DistributedContext()
        rank = 2
        dist_main, dist_startup, dist_pg = get_dist_prog(
            main_prog, start_prog, context, rank)
        Resharder(dist_main, dist_startup, rank, context, dist_pg).reshard()

        # Send/recv insertion for this rank.
        self.assertTrue(check_send_recv_result(dist_main, rank))
        # Parameter initialization in the startup program.
        self.assertTrue(check_initialization_for_dpmppp(dist_startup))
    def test_mlp_pp_diff_process_mesh(self):
        """Pipeline MLP for rank 1 with the extra ``True`` flag to the builder.

        NOTE(review): this passes a fifth positional argument to
        ``get_dist_prog`` (presumably a "change process mesh" switch). Confirm
        that the helper in this file accepts it.
        """
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        context = DistributedContext()
        rank = 1
        dist_main, dist_startup, dist_pg = get_dist_prog(
            main_prog, start_prog, context, rank, True)

        # Start resharding from an empty process-group registry.
        for stale_key in list(_g_process_group_map.keys()):
            del _g_process_group_map[stale_key]
        Resharder(dist_main, dist_startup, rank, context, dist_pg).reshard()

        self.assertTrue(check_send_recv_result(dist_main, rank))
        self.assertTrue(check_initialization(dist_startup, rank))
    def test_mlp_mppp(self):
        """Reshard the MLP programs for rank 2 and check the mp+pp results.

        Verifies the send/recv ops, and that unsliced parameters are
        initialized identically in the model-parallel setting.
        """
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        context = DistributedContext()
        rank = 2
        dist_main, dist_startup, dist_pg = get_dist_prog(
            main_prog, start_prog, context, rank)
        Resharder(dist_main, dist_startup, rank, context, dist_pg).reshard()

        # Send/recv insertion for this rank.
        self.assertTrue(check_send_recv_result(dist_main, rank))
        # Parameters that were not sliced should match across mp ranks.
        self.assertTrue(
            check_initialization_for_mppp(dist_startup, rank))
    def test_mlp_dp(self):
        """Data-parallel MLP on mesh [0, 1], checked from rank 0.

        With the "dp" strategy no send/recv ops should be inserted, and every
        parameter should be initialized.
        """
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        context = DistributedContext()
        rank = 0
        dist_main, dist_startup, dist_pg = get_dist_prog(
            main_prog, start_prog, context, rank)
        Resharder(dist_main, dist_startup, rank, context, dist_pg).reshard()

        # Pure data parallelism: no send/recv expected.
        self.assertFalse(check_send_recv_result(dist_main, rank))
        # Every parameter is initialized on every rank.
        self.assertTrue(check_initialization_for_dp(dist_startup))
示例#8
0
 def test_auto_parallel_cost_model(self):
     """Estimate runtime and memory for the resharded MLP on all NUM_RANKS ranks.

     Builds and reshards one distributed main program per rank, then passes
     the list, in rank order, to ``estimate_cost`` with no cluster description.
     """
     standalone_cost_data = get_single_node_data()
     per_rank_programs = []
     for rank in range(NUM_RANKS):
         main_prog = paddle.static.Program()
         start_prog = paddle.static.Program()
         context = DistributedContext()
         rank_main, rank_startup, rank_pg = get_dist_prog(
             main_prog, start_prog, context, rank)
         Resharder(rank_main, rank_startup, rank, context, rank_pg).reshard()
         per_rank_programs.append(rank_main)

     cost = estimate_cost(per_rank_programs,
                          cluster=None,
                          pipeline_config=pp_cfg,
                          standalone_cost_data=standalone_cost_data,
                          batch_size=4)
     self.assertTrue(check_runtime_estimation(cost))
     self.assertTrue(check_memory_estimation(cost))