def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

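    # logical partition for the given rank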
    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

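    # insert optimizer ops on the partitioned program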
    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

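    # reshard: insert the communication ops needed so tensors match the
    # distributed attributes expected by their consumers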
    resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()
    return dist_train_program, dist_startup_prog
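

# A minimal driver sketch for the helper above; mlp_forward, the imports, and
# the distributed annotations are assumed to come from the surrounding test
# module, and rank 0 is only illustrative.
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
dist_main_prog, dist_startup_prog = get_dist_prog(train_program,
                                                  startup_program, dist_context,
                                                  rank_id=0)
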
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)
    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads

    def test_allgather(self):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        process_mesh = auto.ProcessMesh(mesh=[0, 3])
        with static.program_guard(train_program, startup_program):
            x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
            x = auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [0, -1]
                                  })

            w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
            w = auto.shard_tensor(w,
                                  dist_attr={
                                      "process_mesh": process_mesh,
                                      "dims_mapping": [-1, -1]
                                  })

            # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
            #     x.name: [-1, -1],
            #     w.name: [-1, -1]
            # }, **{"x": x,
            #       "y": w})[0]

            y = paddle.distributed.shard_op(paddle.matmul,
                                            dist_attr={
                                                "process_mesh": process_mesh,
                                                x: {
                                                    "dims_mapping": [-1, -1]
                                                },
                                                w: {
                                                    "dims_mapping": [-1, -1]
                                                }
                                            })(x, w)[0]

        rank_id = 0
        dist_context = DistributedContext()
        dist_strategy = fleet.DistributedStrategy()
        partitioner = Partitioner(dist_context, rank_id)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        dist_context.block_state.parse_forward_blocks(complete_train_program)
        partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
            complete_train_program, startup_program, [])
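        # x is sharded along dim 0 of the process mesh, but the matmul is
        # annotated as replicated, so reshard is expected to allgather x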
        resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                              rank_id, dist_context, partitioned_params_grads)
        resharder.reshard()
        # x should not be sliced
        self.assertTrue(check_allgather(partitioned_main_prog))
Example #4
    def test_decoder_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
        train_program = static.Program()
        start_program = static.Program()
        dist_context = DistributedContext()
        train_program, start_program = decoder_pretrain_forward(
            train_program, start_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        self.assertTrue(dist_context.validate_dist_attr_for_program())
Example #5
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)

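    # record the completed forward blocks, then partition for this rank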
    dist_context.block_state.parse_forward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
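
# A minimal sketch of a program_func that could be passed to parallelizer();
# the two-process mesh, the tensor shapes, and the matmul are illustrative
# assumptions rather than part of the original example.
def simple_program():
    main_program = paddle.static.Program()
    start_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": auto.ProcessMesh(mesh=[0, 1]),
                                  "dims_mapping": [0, -1]
                              })
        y = paddle.matmul(x, x)
    return main_program, start_program


dist_main_prog, dist_context = parallelizer(simple_program, rank=0)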
Example #6
    def test_loss_and_grad_allreduce(self):

        dist_context = DistributedContext(self.main_program,
                                          self.startup_program)
        completer = Completer(dist_context)
        completer.complete_prim_annotation(self.main_program)
        dist_context.block_state.parse_forward_blocks(self.main_program)
        dist_context.block_state.parse_backward_blocks(self.main_program)
        dist_context.grads_params = dict()
        dist_context.grads_params[self.w_grad.name] = self.w.name
        dist_context.synced_gradient = set()
        dist_context.data_parallel_group = list(range(nranks))
        partitioner = Partitioner(dist_context, rank)
        dist_main_prog, dist_startup_prog, _ = partitioner.partition(
            self.main_program, self.startup_program, [(self.w, self.w_grad)])
        ops = dist_main_prog.global_block().ops

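        # partitioning for data parallelism is expected to insert
        # c_allreduce_sum ops for the loss and for w's gradient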
        self.assertTrue(ops[1].type == "c_allreduce_sum")
        self.assertTrue(ops[3].type == "c_allreduce_sum")

def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
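    # optionally move the gelu output onto another process mesh (PP_MESH_1)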
    if change_process_mesh:
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
Example #8
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

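    # append the serial backward pass under dist_op_context so the generated
    # grad ops can be annotated by the completer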
    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)

    dist_context.block_state.parse_backward_blocks(main_program)
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
Example #9
def completion(train_program, start_program, dist_context):
    # blocks = train_program.blocks
    # # completion tensors
    # for block in blocks:
    #     for op in block.ops:
    #         if op.type == "layer_norm":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)

    #         elif op.type == "elementwise_sub":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)

    #         elif op.type == "matmul_v2":
    #             col = False
    #             for in_name in op.input_arg_names:
    #                 if ".w_" not in in_name:
    #                     continue
    #                 if in_name not in block.vars:
    #                     in_var = blocks[0].vars[in_name]
    #                 else:
    #                     in_var = block.vars[in_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 assert tensor_dist_attr is not None
    #                 if tensor_dist_attr.dims_mapping == [-1, 0]:
    #                     col = True
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 if col:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, 0]
    #                 else:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "while":
    #             out_name = op.desc.output("StepScopes")[0]
    #             out_var = block.vars[out_name]
    #             tensor_dist_attr = TensorDistributedAttribute()
    #             tensor_dist_attr.process_mesh = _g_process_mesh
    #             tensor_dist_attr.dims_mapping = [-1]
    #             dist_context.set_tensor_dist_attr_for_program(out_var,
    #                                                           tensor_dist_attr)

    # # completion ops
    # for block in blocks:
    #     for op in block.ops:
    #         op_dist_attr = OperatorDistributedAttribute()
    #         op_dist_attr.process_mesh = _g_process_mesh
    #         if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_input_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(out_name, [])
    #         elif op.type == "read":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
    #         elif op.type == "while":
    #             for in_name in op.input_arg_names:
    #                 in_var = block.vars[in_name]
    #                 in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
    #             for out_name in op.output_arg_names:
    #                 if out_name == op.desc.output("StepScopes")[0]:
    #                     op_dist_attr.set_output_dims_mapping(out_name, [])
    #                 else:
    #                     out_var = block.vars[out_name]
    #                     out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                         out_var)
    #                     op_dist_attr.set_output_dist_attr(out_name,
    #                                                       out_dist_attr)
    #         else:
    #             for in_name in op.input_arg_names:
    #                 if in_name == "lod_tensor_blocking_queue_0":
    #                     continue
    #                 if in_name not in block.vars:
    #                     in_var = blocks[0].vars[in_name]
    #                 else:
    #                     in_var = block.vars[in_name]
    #                 in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
    #             for out_name in op.output_arg_names:
    #                 if out_name not in block.vars:
    #                     out_var = blocks[0].vars[out_name]
    #                 else:
    #                     out_var = block.vars[out_name]
    #                 out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)

    #         if op.type == "matmul_v2":
    #             op_dist_attr.impl_type = "matmul_v2"
    #             for in_name in op_dist_attr.inputs_dist_attrs.keys():
    #                 in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
    #                 if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
    #                     op_dist_attr.impl_idx = 0
    #                 else:
    #                     op_dist_attr.impl_idx = 1
    #         elif op.type == "fill_constant_batch_size_like":
    #             op_dist_attr.impl_type = "fill_constant_batch_size_like"
    #             op_dist_attr.impl_idx = 0
    #         else:
    #             op_dist_attr.impl_type = "default"
    #             op_dist_attr.impl_idx = 0

    #         dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
    #         make_data_unshard(train_program, start_program, dist_context)

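    # the manual annotation sketched above is handled by the Completer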
    completer = Completer(dist_context)
    train_program = completer.complete_forward_annotation(train_program)
    make_data_unshard(train_program, start_program, dist_context)

    return train_program, start_program
Example #10
    def test_completer(self):
        train_program, start_program, dataloader, i, loss = get_program()
        dist_context = DistributedContext()
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)

    def test_gpt_dp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp_mp"
        global _global_process_mesh

        _global_process_mesh = auto.ProcessMesh(
            mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
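        # 2 x 4 mesh: axis 0 is data parallel, axis 1 is model parallel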

        train_program = static.Program()
        startup_program = static.Program()
        parallelizer = AutoParallelizer(FakeFleet())
        dist_context = parallelizer._dist_context

        dist_context.process_mesh = _global_process_mesh
        train_program, startup_program, loss = gpt_pretrain_forward(
            train_program, startup_program)
        completer = Completer(dist_context)
        complete_train_program = completer.complete_forward_annotation(
            train_program)
        dist_context.block_state.parse_forward_blocks(complete_train_program)

        # serial backward pass
        params_grads = parallelizer._generate_backward(
            complete_train_program,
            startup_program,
            loss,
            parameter_list=None,
            no_grad_set=None,
            callbacks=None)

        rank_id = 3
        partitioner = Partitioner(dist_context, rank_id)
        auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
            complete_train_program, startup_program, params_grads)

        with open("./test_auto_parallel_partitioner_serial_main_new.txt",
                  "w") as fw:
            fw.write(str(train_program))
        with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
                  "w") as fw:
            fw.write(str(startup_program))

        from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
        set_default_distributed_context(dist_context)
        with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
            fw.write(str(auto_parallel_main_prog))
        with open("./test_auto_parallel_partitioner_startup_new.txt1",
                  "w") as fw:
            fw.write(str(auto_parallel_startup_prog))
        # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
        #     from paddle.distributed.auto_parallel.completion import Completer
        #     completer = Completer()
        #     completer.complete_forward_annotation(auto_parallel_main_prog)
        #     fw.write(str(auto_parallel_main_prog))

        nrank = 4
        # col parallel
        weights = [
            'linear_0.w_0',
            'linear_6.w_0',
            'linear_10.w_0',
        ]
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 1, nrank))

        # row parallel
        weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, nrank))

        weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
        self.assertTrue(
            check_tensor_split(auto_parallel_main_prog, weights,
                               complete_train_program, weights, 0, 1))

        all_params = sorted(
            [param.name for param in startup_program.all_parameters()])
        allreduce_grads = [
            'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
            'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
            'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
        ]
        process_mesh = _global_process_mesh
        mp_parallel_axis = 1
        dp_parallel_axis = 0

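        # rebuild the communication groups that rank 3 joins along the
        # model-parallel and data-parallel mesh axes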
        group_ranks = _get_comm_group(
            process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
        mp_ring_id = new_process_group(group_ranks).id

        group_ranks = _get_comm_group(
            process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
        dp_ring_id = new_process_group(group_ranks).id

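        # tensor-parallel allreduces appear in the backward pass (op_role == 1)
        # on the mp ring; data-parallel gradient allreduces use the dp ring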
        tensor_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1 and
                op.desc.attr("ring_id") == mp_ring_id)
        ])
        data_parallel_allreduce_vars = sorted([
            op.desc.output_arg_names()[0].split("@")[0]
            for op in auto_parallel_main_prog.global_block().ops
            if (op.type == "c_allreduce_sum" and op.desc.attr("ring_id") ==
                dp_ring_id)
        ])

        self.assertTrue(all_params == data_parallel_allreduce_vars)
        self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

        self.assertTrue(
            is_valid_completed_program(dist_context, auto_parallel_main_prog))