Code example #1
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition for the given rank
    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

    # insert communication ops where producer/consumer placements mismatch
    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
    return dist_train_program, dist_startup_prog
Code example #2
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
    return auto_parallel_main_prog, auto_parallel_startup_prog
Code example #3
def partition(train_program, start_program, dist_context):

    # optimizer = paddle.optimizer.SGD(learning_rate=0.00001)
    rank = paddle.distributed.get_rank()
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        train_program, start_program, [])

    return dist_main_prog, dist_startup_prog
Code example #4
    def test_allgather(self):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        process_mesh = auto.ProcessMesh(mesh=[0, 3])
        with static.program_guard(train_program, startup_program):
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            x = auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [0, -1]
                                  })

            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
            w = auto.shard_tensor(w,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [-1, -1]
                                  })

            # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
            #     x.name: [-1, -1],
            #     w.name: [-1, -1]
            # }, **{"x": x,
            #       "y": w})[0]

            y = paddle.distributed.shard_op(paddle.matmul,
                                            dist_attr={
                                                "process_mesh": process_mesh,
                                                x: {
                                                    "dims_mapping": [-1, -1]
                                                },
                                                w: {
                                                    "dims_mapping": [-1, -1]
                                                }
                                            })(x, w)[0]

        rank_id = 0
        dist_context = DistributedContext()
        dist_strategy = fleet.DistributedStrategy()
        partitioner = Partitioner(dist_context, rank_id)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        dist_context.block_state.parse_forward_blocks(complete_train_program)
        partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
            complete_train_program, startup_program, [])
        resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                              rank_id, dist_context, partitioned_params_grads)
        resharder.reshard()
        # x should not be sliced
        self.assertTrue(check_allgather(partitioned_main_prog))
Code example #5
File: test_dist_slice.py  Project: sandyhouse/Paddle
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)

    dist_context.block_state.parse_forward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
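
A minimal usage sketch for the helper above (assumptions not in the original: make_matmul_program and the rank value are hypothetical, static graph mode must be enabled first, and the legacy paddle.distributed.auto_parallel API shown in these examples is available):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

def make_matmul_program():
    # hypothetical builder: a single matmul with one row-sharded input
    main_prog = static.Program()
    start_prog = static.Program()
    with static.program_guard(main_prog, start_prog):
        x = static.data(name="x", shape=[4, 4], dtype='float32')
        w = static.data(name="w", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": auto.ProcessMesh(mesh=[0, 1]),
                                  "dims_mapping": [0, -1]
                              })
        y = paddle.matmul(x, w)
    return main_prog, start_prog

# logical partition for rank 0; this is a program transformation only,
# so no multi-process launch is needed to inspect the result
dist_main_prog, dist_context = parallelizer(make_matmul_program, rank=0)
print(dist_main_prog)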
Code example #6
    def test_loss_and_grad_allreduce(self):

        dist_context = DistributedContext(self.main_program,
                                          self.startup_program)
        completer = Completer(dist_context)
        completer.complete_prim_annotation(self.main_program)
        dist_context.block_state.parse_forward_blocks(self.main_program)
        dist_context.block_state.parse_backward_blocks(self.main_program)
        dist_context.grads_params = dict()
        dist_context.grads_params[self.w_grad.name] = self.w.name
        dist_context.synced_gradient = set()
        dist_context.data_parallel_group = list(range(nranks))
        partitioner = Partitioner(dist_context, rank)
        dist_main_prog, dist_startup_prog, _ = partitioner.partition(
            self.main_program, self.startup_program, [(self.w, self.w_grad)])
        ops = dist_main_prog.global_block().ops

        self.assertTrue(ops[1].type == "c_allreduce_sum")
        self.assertTrue(ops[3].type == "c_allreduce_sum")
Code example #7
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    if change_process_mesh:
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
Code example #8
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)

    dist_context.block_state.parse_backward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
Code example #9
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  complete_train_program=None):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    ) if complete_train_program is None else complete_train_program
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program
Code example #10
    def test_gpt_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh

        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        startup_program = static.Program()
        parallelizer = AutoParallelizer(FakeFleet())
        dist_context = parallelizer._dist_context

        dist_context.process_mesh = _global_process_mesh
        train_program, startup_program, loss = gpt_pretrain_forward(
            train_program, startup_program)
        complete_train_program = auto.complete_annotation(
            train_program, dist_context)

        # serial backward pass
        params_grads = parallelizer._generate_backward(complete_train_program,
                                                       startup_program,
                                                       loss,
                                                       parameter_list=None,
                                                       no_grad_set=None,
                                                       callbacks=None)

        rank_id = 3
        partitioner = Partitioner(dist_context, rank_id)
        auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        with open("./test_auto_parallel_partitioner_serial_main_new.txt",
                  "w") as fw:
            fw.write(str(train_program))
        with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
                  "w") as fw:
            fw.write(str(startup_program))

        from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
        set_default_distributed_context(dist_context)
        with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
            fw.write(str(auto_parallel_main_prog))
        with open("./test_auto_parallel_partitioner_startup_new.txt1",
                  "w") as fw:
            fw.write(str(auto_parallel_startup_prog))
        # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
        #     complete_backward_annotation(auto_parallel_main_prog)
        #     fw.write(str(auto_parallel_main_prog))
        nrank = 4
        # col parallel
        weights = [
            'linear_0.w_0',
            'linear_6.w_0',
            'linear_10.w_0',
        ]
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 1, nrank))

        # row parallel
        weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, nrank))

        weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, 1))

        all_params = sorted(
            [param.name for param in startup_program.all_parameters()])
        allreduce_grads = [
            'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
            'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
            'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
        ]
        process_mesh = _global_process_mesh
        mp_parallel_axis = 1
        dp_parallel_axis = 0

        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, mp_parallel_axis,
                                      3)
        mp_ring_id = new_process_group(group_ranks).id

        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, dp_parallel_axis,
                                      3)
        dp_ring_id = new_process_group(group_ranks).id

        tensor_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
                and op.desc.attr("ring_id") == mp_ring_id)
        ])
        data_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum"
                and op.desc.attr("ring_id") == dp_ring_id)
        ])

        self.assertTrue(all_params == data_parallel_allreduce_vars)
        self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

        self.assertTrue(
            is_valid_completed_program(dist_context, auto_parallel_main_prog))