def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)
    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

    reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)

    return dist_train_program, dist_startup_prog
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog
def partition(train_program, start_program, dist_context):
    # optimizer = paddle.optimizer.SGD(learning_rate=0.00001)
    rank = paddle.distributed.get_rank()
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        train_program, start_program, [])
    return dist_main_prog, dist_startup_prog
def test_allgather(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    process_mesh = auto.ProcessMesh(mesh=[0, 3])
    with static.program_guard(train_program, startup_program):
        x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [0, -1]
                              })

        w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
        w = auto.shard_tensor(w,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [-1, -1]
                              })

        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
        #     x.name: [-1, -1],
        #     w.name: [-1, -1]
        # }, **{"x": x,
        #       "y": w})[0]

        y = paddle.distributed.shard_op(paddle.matmul,
                                        dist_attr={
                                            "process_mesh": process_mesh,
                                            x: {
                                                "dims_mapping": [-1, -1]
                                            },
                                            w: {
                                                "dims_mapping": [-1, -1]
                                            }
                                        })(x, w)[0]

    rank_id = 0
    dist_context = DistributedContext()
    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_context, rank_id)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
        complete_train_program, startup_program, [])
    resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                          rank_id, dist_context, partitioned_params_grads)
    resharder.reshard()
    # x should not be sliced
    self.assertTrue(check_allgather(partitioned_main_prog))
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
def test_loss_and_grad_allreduce(self):
    dist_context = DistributedContext(self.main_program,
                                      self.startup_program)
    completer = Completer(dist_context)
    completer.complete_prim_annotation(self.main_program)
    dist_context.block_state.parse_forward_blocks(self.main_program)
    dist_context.block_state.parse_backward_blocks(self.main_program)
    dist_context.grads_params = dict()
    dist_context.grads_params[self.w_grad.name] = self.w.name
    dist_context.synced_gradient = set()
    dist_context.data_parallel_group = list(range(nranks))
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        self.main_program, self.startup_program, [(self.w, self.w_grad)])
    ops = dist_main_prog.global_block().ops
    self.assertTrue(ops[1].type == "c_allreduce_sum")
    self.assertTrue(ops[3].type == "c_allreduce_sum")
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    if change_process_mesh:
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  complete_train_program=None):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    ) if complete_train_program is None else complete_train_program

    dist_context.block_state.parse_forward_blocks(complete_train_program)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program
def test_gpt_dp_mp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp_mp"
    global _global_process_mesh

    _global_process_mesh = auto.ProcessMesh(
        mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

    train_program = static.Program()
    startup_program = static.Program()
    parallelizer = AutoParallelizer(FakeFleet())
    dist_context = parallelizer._dist_context

    dist_context.process_mesh = _global_process_mesh
    train_program, startup_program, loss = gpt_pretrain_forward(
        train_program, startup_program)
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)

    # serial backward pass
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    rank_id = 3
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    with open("./test_auto_parallel_partitioner_serial_main_new.txt",
              "w") as fw:
        fw.write(str(train_program))
    with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
              "w") as fw:
        fw.write(str(startup_program))

    from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
    set_default_distributed_context(dist_context)
    with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
        fw.write(str(auto_parallel_main_prog))
    with open("./test_auto_parallel_partitioner_startup_new.txt1",
              "w") as fw:
        fw.write(str(auto_parallel_startup_prog))
    # with open("./test_auto_parallel_partitioner_main_completed.txt",
    #           "w") as fw:
    #     from paddle.distributed.auto_parallel.completion import complete_backward_annotation
    #     complete_backward_annotation(auto_parallel_main_prog)
    #     fw.write(str(auto_parallel_main_prog))

    nrank = 4
    # col parallel
    weights = [
        'linear_0.w_0',
        'linear_6.w_0',
        'linear_10.w_0',
    ]
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 1, nrank))

    # row parallel
    weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 0, nrank))

    weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 0, 1))

    all_params = sorted(
        [param.name for param in startup_program.all_parameters()])
    allreduce_grads = [
        'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
        'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
        'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
    ]
    process_mesh = _global_process_mesh
    mp_parallel_axis = 1
    dp_parallel_axis = 0

    group_ranks = _get_comm_group(process_mesh.processes,
                                  process_mesh.topology, mp_parallel_axis, 3)
    mp_ring_id = new_process_group(group_ranks).id

    group_ranks = _get_comm_group(process_mesh.processes,
                                  process_mesh.topology, dp_parallel_axis, 3)
    dp_ring_id = new_process_group(group_ranks).id

    tensor_parallel_allreduce_vars = sorted([
        op.desc.output_arg_names()[0].split("@")[0]
        for op in auto_parallel_main_prog.global_block().ops
        if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
            and op.desc.attr("ring_id") == mp_ring_id)
    ])
    data_parallel_allreduce_vars = sorted([
        op.desc.output_arg_names()[0].split("@")[0]
        for op in auto_parallel_main_prog.global_block().ops
        if (op.type == "c_allreduce_sum"
            and op.desc.attr("ring_id") == dp_ring_id)
    ])

    self.assertTrue(all_params == data_parallel_allreduce_vars)
    self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

    self.assertTrue(
        is_valid_completed_program(dist_context, auto_parallel_main_prog))