# Imports assumed by the snippets below (a sketch of PaddlePaddle's 2.x
# experimental auto-parallel layout; exact module paths vary across versions).
# Helpers such as mlp_forward, gpt_pretrain_forward, check_tensor_split,
# is_valid_completed_program, FakeFleet and the _global_* / PP_MESH_* globals
# are defined elsewhere in the original test files.
import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import _get_comm_group


def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

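    # AutoParallelizer reads its strategy and optimizer off the fleet object,
    # so set them directly instead of going through a full fleet.init() launch.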
    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
    return auto_parallel_main_prog, auto_parallel_startup_prog
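

# Hypothetical driver for the variant above; assumes _global_process_mesh was
# set beforehand (e.g. auto.ProcessMesh(mesh=[0, 1])), as in the original test:
#
#     train_program = static.Program()
#     startup_program = static.Program()
#     dist_context = DistributedContext()
#     dist_main, dist_startup = get_dist_prog(
#         train_program, startup_program, dist_context, rank_id=0)


# Reshard variant: the same pipeline, followed by a reshard pass that patches
# cross-rank communication into the partitioned programs.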
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

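    # The reshard pass inserts the send/recv/slice ops needed wherever an op
    # consumes a tensor whose distributed attributes differ from its producer's.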
    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
    return dist_train_program, dist_startup_prog
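

# Variant using the newer Completer API; the change_process_mesh flag lets a
# caller move one activation onto PP_MESH_1 so the model spans two pipeline
# stages.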
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
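    # Optionally move the "gelu_0.tmp_0" activation onto the second pipeline
    # mesh (PP_MESH_1), forcing cross-mesh communication downstream.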
    if change_process_mesh:
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
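

# Variant that accepts and returns the completed serial program, so a caller
# can run completion once and partition the same annotated program for
# several ranks.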
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  complete_train_program=None):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    ) if complete_train_program is None else complete_train_program
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program
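

# Excerpt from a partitioner TestCase (class wrapper omitted, method
# indentation preserved): partitions a GPT program annotated with the "dp_mp"
# strategy on rank 3 of a 2 x 4 process mesh and checks the resulting splits.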
    def test_gpt_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh

        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
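        # 2 x 4 mesh: axis 0 (size 2) is the data-parallel axis and axis 1
        # (size 4) is the model-parallel axis; rank 3 is at coordinate (0, 3).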

        train_program = static.Program()
        startup_program = static.Program()
        parallelizer = AutoParallelizer(FakeFleet())
        dist_context = parallelizer._dist_context

        dist_context.process_mesh = _global_process_mesh
        train_program, startup_program, loss = gpt_pretrain_forward(
            train_program, startup_program)
        complete_train_program = auto.complete_annotation(
            train_program, dist_context)

        # serial backward pass
        params_grads = parallelizer._generate_backward(complete_train_program,
                                                       startup_program,
                                                       loss,
                                                       parameter_list=None,
                                                       no_grad_set=None,
                                                       callbacks=None)

        rank_id = 3
        partitioner = Partitioner(dist_context, rank_id)
        auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        with open("./test_auto_parallel_partitioner_serial_main_new.txt",
                  "w") as fw:
            fw.write(str(train_program))
        with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
                  "w") as fw:
            fw.write(str(startup_program))

        from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
        set_default_distributed_context(dist_context)
        with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
            fw.write(str(auto_parallel_main_prog))
        with open("./test_auto_parallel_partitioner_startup_new.txt1",
                  "w") as fw:
            fw.write(str(auto_parallel_startup_prog))
        # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
        #     complete_backward_annotation(auto_parallel_main_prog)
        #     fw.write(str(auto_parallel_main_prog))
        nrank = 4
        # column-parallel weights: expected to be split along axis 1 into nrank shards
        weights = [
            'linear_0.w_0',
            'linear_6.w_0',
            'linear_10.w_0',
        ]
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 1, nrank))

        # row-parallel weights: expected to be split along axis 0 into nrank shards
        weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, nrank))

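        # replicated parameters: a split factor of 1 means no partitioning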
        weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, 1))

        all_params = sorted(
            [param.name for param in startup_program.all_parameters()])
        allreduce_grads = [
            'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
            'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
            'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
        ]
        process_mesh = _global_process_mesh
        mp_parallel_axis = 1
        dp_parallel_axis = 0

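        # Recover the ring ids of the mp and dp communicators that contain
        # rank 3, so the allreduce ops below can be matched by ring_id.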
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, mp_parallel_axis,
                                      3)
        mp_ring_id = new_process_group(group_ranks).id

        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, dp_parallel_axis,
                                      3)
        dp_ring_id = new_process_group(group_ranks).id

        tensor_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
                and op.desc.attr("ring_id") == mp_ring_id)
        ])
        data_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum"
                and op.desc.attr("ring_id") == dp_ring_id)
        ])

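        # Every parameter gradient should be allreduced over the dp ring, and
        # only the listed activation gradients over the mp ring.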
        self.assertTrue(all_params == data_parallel_allreduce_vars)
        self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

        self.assertTrue(
            is_valid_completed_program(dist_context, auto_parallel_main_prog))