def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition for the given rank
    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)
    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

    # insert the communication ops required by resharding
    resharder = Resharder(dist_train_program, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    return dist_train_program, dist_startup_prog
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def test_allgather(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    process_mesh = auto.ProcessMesh(mesh=[0, 3])
    with static.program_guard(train_program, startup_program):
        x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [0, -1]
                              })

        w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
        w = auto.shard_tensor(w,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [-1, -1]
                              })

        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
        #     x.name: [-1, -1],
        #     w.name: [-1, -1]
        # }, **{"x": x,
        #       "y": w})[0]

        y = paddle.distributed.shard_op(paddle.matmul,
                                        dist_attr={
                                            "process_mesh": process_mesh,
                                            x: {
                                                "dims_mapping": [-1, -1]
                                            },
                                            w: {
                                                "dims_mapping": [-1, -1]
                                            }
                                        })(x, w)[0]

    rank_id = 0
    dist_context = DistributedContext()
    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_context, rank_id)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
        complete_train_program, startup_program, [])
    resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                          rank_id, dist_context, partitioned_params_grads)
    resharder.reshard()
    # x should not be sliced
    self.assertTrue(check_allgather(partitioned_main_prog))
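# Hedged sketch (an assumption, not the suite's actual helper): `check_allgather`
# is only referenced above, so the version below is a guess at what it verifies,
# namely that resharding gathered the row-sharded `x` with a c_allgather op on
# this rank instead of inserting a slice.
def _check_allgather_sketch(dist_main_prog):
    ops = dist_main_prog.global_block().ops
    has_allgather = any(op.type == "c_allgather" for op in ops)
    has_slice = any(op.type == "slice" for op in ops)
    return has_allgather and not has_slice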
def test_decoder_dp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    train_program, start_program = decoder_pretrain_forward(train_program,
                                                            start_program)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    self.assertTrue(dist_context.validate_dist_attr_for_program())
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
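# Hedged usage sketch (not part of the original tests): one way the `parallelizer`
# helper above could be driven. The program builder `_example_program` is
# hypothetical; it only reuses APIs exercised elsewhere in this section
# (auto.ProcessMesh, auto.shard_tensor, paddle.static.data, paddle.matmul).
def _example_program():
    main_program = paddle.static.Program()
    start_program = paddle.static.Program()
    mesh = auto.ProcessMesh(mesh=[0, 1])
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.static.data(name="x", shape=[8, 8], dtype='float32')
        w = paddle.static.data(name="w", shape=[8, 8], dtype='float32')
        # shard x along the first mesh axis; w stays replicated
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": mesh,
                                  "dims_mapping": [0, -1]
                              })
        y = paddle.matmul(x, w)
    return main_program, start_program


# dist_main_prog, dist_context = parallelizer(_example_program, rank=0)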
def test_loss_and_grad_allreduce(self):
    dist_context = DistributedContext(self.main_program,
                                      self.startup_program)
    completer = Completer(dist_context)
    completer.complete_prim_annotation(self.main_program)
    dist_context.block_state.parse_forward_blocks(self.main_program)
    dist_context.block_state.parse_backward_blocks(self.main_program)
    dist_context.grads_params = dict()
    dist_context.grads_params[self.w_grad.name] = self.w.name
    dist_context.synced_gradient = set()
    dist_context.data_parallel_group = list(range(nranks))
    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        self.main_program, self.startup_program, [(self.w, self.w_grad)])
    ops = dist_main_prog.global_block().ops
    self.assertTrue(ops[1].type == "c_allreduce_sum")
    self.assertTrue(ops[3].type == "c_allreduce_sum")
def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    if change_process_mesh:
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    # serial backward pass; passing the dist op context lets the completer
    # annotate the generated gradient ops afterwards
    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
def completion(train_program, start_program, dist_context):
    # blocks = train_program.blocks
    # # completion tensors
    # for block in blocks:
    #     for op in block.ops:
    #         if op.type == "layer_norm":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "elementwise_sub":
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "matmul_v2":
    #             col = False
    #             for in_name in op.input_arg_names:
    #                 if ".w_" not in in_name:
    #                     continue
    #                 if in_name not in block.vars:
    #                     in_var = blocks[0].vars[in_name]
    #                 else:
    #                     in_var = block.vars[in_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 assert tensor_dist_attr is not None
    #                 if tensor_dist_attr.dims_mapping == [-1, 0]:
    #                     col = True
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 if tensor_dist_attr:
    #                     continue
    #                 tensor_dist_attr = TensorDistributedAttribute()
    #                 tensor_dist_attr.process_mesh = _g_process_mesh
    #                 if col:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, 0]
    #                 else:
    #                     tensor_dist_attr.dims_mapping = [-1, -1, -1]
    #                 dist_context.set_tensor_dist_attr_for_program(
    #                     out_var, tensor_dist_attr)
    #         elif op.type == "while":
    #             out_name = op.desc.output("StepScopes")[0]
    #             out_var = block.vars[out_name]
    #             tensor_dist_attr = TensorDistributedAttribute()
    #             tensor_dist_attr.process_mesh = _g_process_mesh
    #             tensor_dist_attr.dims_mapping = [-1]
    #             dist_context.set_tensor_dist_attr_for_program(out_var,
    #                                                           tensor_dist_attr)

    # # completion ops
    # for block in blocks:
    #     for op in block.ops:
    #         op_dist_attr = OperatorDistributedAttribute()
    #         op_dist_attr.process_mesh = _g_process_mesh
    #         if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_input_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(out_name, [])
    #         elif op.type == "read":
    #             for in_name in op.input_arg_names:
    #                 op_dist_attr.set_output_dims_mapping(in_name, [])
    #             for out_name in op.output_arg_names:
    #                 out_var = block.vars[out_name]
    #                 out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
    #         elif op.type == "while":
    #             for in_name in op.input_arg_names:
    #                 in_var = block.vars[in_name]
    #                 in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
    #             for out_name in op.output_arg_names:
    #                 if out_name == op.desc.output("StepScopes")[0]:
    #                     op_dist_attr.set_output_dims_mapping(out_name, [])
    #                 else:
    #                     out_var = block.vars[out_name]
    #                     out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                         out_var)
    #                     op_dist_attr.set_output_dist_attr(out_name,
    #                                                       out_dist_attr)
    #         else:
    #             for in_name in op.input_arg_names:
    #                 if in_name == "lod_tensor_blocking_queue_0":
    #                     continue
    #                 if in_name not in block.vars:
    #                     in_var = blocks[0].vars[in_name]
    #                 else:
    #                     in_var = block.vars[in_name]
    #                 in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     in_var)
    #                 op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
    #             for out_name in op.output_arg_names:
    #                 if out_name not in block.vars:
    #                     out_var = blocks[0].vars[out_name]
    #                 else:
    #                     out_var = block.vars[out_name]
    #                 out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
    #                     out_var)
    #                 op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)

    #         if op.type == "matmul_v2":
    #             op_dist_attr.impl_type = "matmul_v2"
    #             for in_name in op_dist_attr.inputs_dist_attrs.keys():
    #                 in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
    #                 if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
    #                     op_dist_attr.impl_idx = 0
    #                 else:
    #                     op_dist_attr.impl_idx = 1
    #         elif op.type == "fill_constant_batch_size_like":
    #             op_dist_attr.impl_type = "fill_constant_batch_size_like"
    #             op_dist_attr.impl_idx = 0
    #         else:
    #             op_dist_attr.impl_type = "default"
    #             op_dist_attr.impl_idx = 0
    #         dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
    # make_data_unshard(train_program, start_program, dist_context)

    completer = Completer(dist_context)
    train_program = completer.complete_forward_annotation(train_program)
    make_data_unshard(train_program, start_program, dist_context)

    return train_program, start_program
def test_completer(self):
    train_program, start_program, dataloader, i, loss = get_program()
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
def test_gpt_dp_mp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp_mp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(
        mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

    train_program = static.Program()
    startup_program = static.Program()
    parallelizer = AutoParallelizer(FakeFleet())
    dist_context = parallelizer._dist_context
    dist_context.process_mesh = _global_process_mesh

    train_program, startup_program, loss = gpt_pretrain_forward(
        train_program, startup_program)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    # serial backward pass
    params_grads = parallelizer._generate_backward(complete_train_program,
                                                   startup_program,
                                                   loss,
                                                   parameter_list=None,
                                                   no_grad_set=None,
                                                   callbacks=None)

    rank_id = 3
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    with open("./test_auto_parallel_partitioner_serial_main_new.txt",
              "w") as fw:
        fw.write(str(train_program))
    with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
              "w") as fw:
        fw.write(str(startup_program))

    from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
    set_default_distributed_context(dist_context)
    with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
        fw.write(str(auto_parallel_main_prog))
    with open("./test_auto_parallel_partitioner_startup_new.txt1",
              "w") as fw:
        fw.write(str(auto_parallel_startup_prog))
    # with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
    #     from paddle.distributed.auto_parallel.completion import Completer
    #     completer = Completer()
    #     completer.complete_forward_annotation(auto_parallel_main_prog)
    #     fw.write(str(auto_parallel_main_prog))

    nrank = 4
    # col parallel
    weights = [
        'linear_0.w_0',
        'linear_6.w_0',
        'linear_10.w_0',
    ]
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 1, nrank))

    # row parallel
    weights = ['word_embeddings', 'linear_9.w_0', 'linear_11.w_0']
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 0, nrank))

    weights = ['pos_embeddings', 'layer_norm_0.b_0', 'layer_norm_4.w_0']
    self.assertTrue(
        check_tensor_split(auto_parallel_main_prog, weights,
                           complete_train_program, weights, 0, 1))

    all_params = sorted(
        [param.name for param in startup_program.all_parameters()])
    allreduce_grads = [
        'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2', 'layer_norm_5.tmp_2',
        'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
        'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
    ]
    process_mesh = _global_process_mesh
    mp_parallel_axis = 1
    dp_parallel_axis = 0

    group_ranks = _get_comm_group(process_mesh.processes,
                                  process_mesh.topology, mp_parallel_axis, 3)
    mp_ring_id = new_process_group(group_ranks).id

    group_ranks = _get_comm_group(process_mesh.processes,
                                  process_mesh.topology, dp_parallel_axis, 3)
    dp_ring_id = new_process_group(group_ranks).id

    tensor_parallel_allreduce_vars = sorted([
        op.desc.output_arg_names()[0].split("@")[0]
        for op in auto_parallel_main_prog.global_block().ops
        if (op.type == "c_allreduce_sum" and op.attr('op_role') == 1
            and op.desc.attr("ring_id") == mp_ring_id)
    ])
    data_parallel_allreduce_vars = sorted([
        op.desc.output_arg_names()[0].split("@")[0]
        for op in auto_parallel_main_prog.global_block().ops
        if (op.type == "c_allreduce_sum"
            and op.desc.attr("ring_id") == dp_ring_id)
    ])

    self.assertTrue(all_params == data_parallel_allreduce_vars)
    self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

    self.assertTrue(
        is_valid_completed_program(dist_context, auto_parallel_main_prog))