Code example #1
0
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    """Build the distributed main/startup programs for ``rank_id``.

    Pipeline: serial forward -> annotation completion -> serial backward
    -> logical partition for the rank -> optimizer insertion.
    Returns the partitioned (main, startup) programs.
    """
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh

    # Serial forward pass.
    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # Complete the distributed attributes on the serial program.
    annotated_program = auto.complete_annotation(train_program, dist_context)

    # Serial backward pass.
    grads = parallelizer._generate_backward(
        annotated_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # Logical partition for this rank.
    rank_partitioner = Partitioner(dist_context, rank_id)
    dist_main, dist_startup, dist_grads = rank_partitioner.partition(
        annotated_program, startup_program, grads)

    # Insert optimizer ops; the returned op list is not needed here.
    parallelizer._apply_optimize(dist_main, dist_startup, dist_grads)
    return dist_main, dist_startup
Code example #2
0
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    """Build the resharded distributed programs for ``rank_id``.

    Pipeline: serial forward -> annotation completion -> serial backward
    -> logical partition for the rank -> optimizer insertion -> reshard.
    Returns the partitioned (main, startup) programs.
    """
    # Serial forward pass.
    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # Complete the distributed attributes on the serial program.
    annotated_program = auto.complete_annotation(train_program, dist_context)

    # Serial backward pass.
    grads = parallelizer._generate_backward(
        annotated_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # Logical partition for this rank.
    rank_partitioner = Partitioner(dist_context, rank_id)
    dist_main, dist_startup, dist_grads = rank_partitioner.partition(
        annotated_program, startup_program, grads)

    # Insert optimizer ops; the returned op list is not needed here.
    parallelizer._apply_optimize(dist_main, dist_startup, dist_grads)

    # Reshard the partitioned programs for this rank.
    reshard(dist_main, dist_startup, rank_id, dist_context)
    return dist_main, dist_startup
Code example #3
0
 def test_decoder_dp(self):
     """Complete dist attrs for the decoder pretrain program under a
     pure data-parallel strategy and validate them."""
     global _global_parallel_strategy
     _global_parallel_strategy = "dp"
     global _global_process_mesh
     _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

     main_prog = static.Program()
     start_prog = static.Program()
     ctx = DistributedContext()
     main_prog, start_prog = decoder_pretrain_forward(main_prog, start_prog)

     # Completion records dist attrs in `ctx`; the returned program is
     # not needed by this check.
     auto.complete_annotation(main_prog, ctx)
     self.assertTrue(ctx.validate_dist_attr_for_program())
Code example #4
0
    def test_allgather(self):
        """Partition and reshard a matmul whose input is dp-sharded but
        whose op requires a replicated input, then check that the input
        is gathered rather than sliced."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        mesh = auto.ProcessMesh(mesh=[0, 3])
        with static.program_guard(main_prog, startup_prog):
            # x is sharded along dim 0 of the mesh.
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            x = auto.shard_tensor(
                x,
                dist_attr={
                    "process_mesh": mesh,
                    "dims_mapping": [0, -1]
                })

            # w is fully replicated.
            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
            w = auto.shard_tensor(
                w,
                dist_attr={
                    "process_mesh": mesh,
                    "dims_mapping": [-1, -1]
                })

            # The matmul itself expects replicated operands, which
            # conflicts with x's sharding and forces a reshard.
            matmul_attr = {
                "process_mesh": mesh,
                x: {
                    "dims_mapping": [-1, -1]
                },
                w: {
                    "dims_mapping": [-1, -1]
                }
            }
            y = paddle.distributed.shard_op(paddle.matmul,
                                            dist_attr=matmul_attr)(x, w)[0]

        rank_id = 0
        ctx = DistributedContext()
        dist_strategy = fleet.DistributedStrategy()
        partitioner = Partitioner(ctx, rank_id)
        completed = auto.complete_annotation(main_prog, ctx)
        dist_main, dist_startup, _ = partitioner.partition(
            completed, startup_prog, [])
        reshard(dist_main, dist_startup, rank_id, ctx)
        # x must be gathered (allgather), not sliced.
        self.assertTrue(check_allgather(dist_main))
Code example #5
0
    def test_gpt_dp_mp(self):
        """Partition a GPT pretrain program on a 2D (dp x mp) mesh and
        check weight splits, dumped programs, and allreduce placement.

        NOTE(review): depends on module-level globals read by
        gpt_pretrain_forward; assumes 8 ranks laid out 2x4.
        """
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh

        # 2x4 mesh: axis 0 (size 2) is data-parallel, axis 1 (size 4) is
        # model-parallel (see dp_parallel_axis / mp_parallel_axis below).
        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

        train_program = static.Program()
        startup_program = static.Program()
        parallelizer = AutoParallelizer(FakeFleet())
        dist_context = parallelizer._dist_context

        dist_context.process_mesh = _global_process_mesh
        # Serial forward pass and annotation completion.
        train_program, startup_program, loss = gpt_pretrain_forward(
            train_program, startup_program)
        complete_train_program = auto.complete_annotation(
            train_program, dist_context)

        # serial forward & backward completion
        params_grads = parallelizer._generate_backward(complete_train_program,
                                                       startup_program,
                                                       loss,
                                                       parameter_list=None,
                                                       no_grad_set=None,
                                                       callbacks=None)

        # Logical partition for rank 3.
        rank_id = 3
        partitioner = Partitioner(dist_context, rank_id)
        auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        # Dump the serial programs for offline inspection.
        with open("./test_auto_parallel_partitioner_serial_main_new.txt",
                  "w") as fw:
            fw.write(str(train_program))
        with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
                  "w") as fw:
            fw.write(str(startup_program))

        # Dump the partitioned programs under the default dist context.
        from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
        set_default_distributed_context(dist_context)
        with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
            fw.write(str(auto_parallel_main_prog))
        with open("./test_auto_parallel_partitioner_startup_new.txt1",
                  "w") as fw:
            fw.write(str(auto_parallel_startup_prog))
        # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
        #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
        #     complete_backward_annotation(auto_parallel_main_prog)
        #     fw.write(str(auto_parallel_main_prog))
        nrank = 4
        # col parallel: these weights should be split along dim 1 into
        # nrank shards.
        weights = [
            'linear_0.w_0',
            'linear_6.w_0',
            'linear_10.w_0',
        ]
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 1, nrank))

        # row parallel: split along dim 0 into nrank shards.
        weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, nrank))

        # Split factor 1: these tensors must be left unchanged.
        weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, 1))

        all_params = sorted(
            [param.name for param in startup_program.all_parameters()])
        # Expected (sorted) output names of mp-ring allreduces; note the
        # repeated entries in the expected list.
        allreduce_grads = [
            'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
            'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
            'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
        ]
        process_mesh = _global_process_mesh
        mp_parallel_axis = 1
        dp_parallel_axis = 0

        # Ring ids for rank 3's mp and dp communication groups.
        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, mp_parallel_axis,
                                      3)
        mp_ring_id = new_process_group(group_ranks).id

        group_ranks = _get_comm_group(process_mesh.processes,
                                      process_mesh.topology, dp_parallel_axis,
                                      3)
        dp_ring_id = new_process_group(group_ranks).id

        # c_allreduce_sum ops on the mp ring with op_role == 1
        # (presumably the backward role — verify against Paddle's
        # OpRole enum), keyed by output var name with any "@..." suffix
        # stripped.
        tensor_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
                and op.desc.attr("ring_id") == mp_ring_id)
        ])
        # c_allreduce_sum ops on the dp ring (gradient synchronization).
        data_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum"
                and op.desc.attr("ring_id") == dp_ring_id)
        ])

        # Every parameter gets exactly one dp-ring grad allreduce, and the
        # mp-ring allreduces match the expected list.
        self.assertTrue(all_params == data_parallel_allreduce_vars)
        self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

        self.assertTrue(
            is_valid_completed_program(dist_context, auto_parallel_main_prog))