def test_update(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    _, train_program, startup_program = mlp_forward(train_program,
                                                    startup_program)
    global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
    dist_context = DistributedContext()
    set_default_dist_attr(train_program, dist_context, global_process_mesh)
    ops = train_program.global_block().ops
    vars = train_program.global_block().vars
    from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
    from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
    from paddle.distributed.auto_parallel.dist_op import DistributedOperator

    for op in ops:
        dist_op_impl_container = get_distributed_operator_impl_container(
            op.type)
        if dist_op_impl_container is None:
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            dist_op = DistributedOperator(op, op_dist_attr)
            if is_elementwise_op(op.type):
                changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                    dist_op)
                self.assertFalse(changed)

                # force a sharded output dims mapping and re-run the update
                dist_op.dist_attr.set_output_dims_mapping(
                    op.output_arg_names[0], [0] + [
                        -1 for i in range(
                            1, len(vars[op.output_arg_names[0]].shape))
                    ])
                try:
                    changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                        dist_op)
                except Exception:
                    continue
                self.assertTrue(changed)
            else:
                changed = update_op_dims_mapping_by_default_dist_impl(dist_op)
                self.assertFalse(changed)

                dist_op.dist_attr.set_output_dims_mapping(
                    op.output_arg_names[0], [0] + [
                        -1 for i in range(
                            1, len(vars[op.output_arg_names[0]].shape))
                    ])
                try:
                    changed = update_op_dims_mapping_by_default_dist_impl(
                        dist_op)
                except Exception:
                    continue
                self.assertTrue(changed)
def test_decoder_dp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    train_program, start_program = decoder_pretrain_forward(train_program,
                                                            start_program)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    # the completed program should carry valid dist attributes
    self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_new_local_tensor(self):
    test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
        mesh=[0, 1])
    test_auto_parallel_reshard._global_parallel_strategy = "dp"
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 0
    dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    dist_context.dist_main_programs[rank_id] = dist_main_prog
    dist_context.dist_startup_programs[rank_id] = dist_startup_prog

    name = "layer_norm_1.tmp_2"
    dist_tensor = dist_context.get_dist_tensor_for_program(
        complete_train_program.global_block().vars[name])
    dist_tensor._dist_context = dist_context
    intermediate_var_0 = dist_tensor.new_local_tensor(
        name="intermediate_var_0")
    self.assertEqual(intermediate_var_0.shape, (2, 1024))
    self.assertEqual(intermediate_var_0.name, "intermediate_var_0")

    rank_id = 1
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_main_prog, dist_startup_prog, _ = get_dist_prog(
        train_program, startup_program, dist_context, rank_id,
        complete_train_program)
    dist_context.dist_main_programs[rank_id] = dist_main_prog
    dist_context.dist_startup_programs[rank_id] = dist_startup_prog
    name = "layer_norm_1.tmp_2"
    dist_tensor = dist_context.get_dist_tensor_for_program(
        complete_train_program.global_block().vars[name])
    dist_tensor._dist_context = dist_context
    intermediate_var_1 = dist_tensor.new_local_tensor(
        rank=rank_id, name="intermediate_var_1")
    self.assertEqual(intermediate_var_1.shape, (2, 1024))
    self.assertEqual(intermediate_var_1.name, "intermediate_var_1")

    name = "linear_0.w_0"
    dist_tensor = dist_context.get_dist_tensor_for_program(
        complete_train_program.global_block().vars[name])
    dist_tensor._dist_context = dist_context
    intermediate_var_1 = dist_tensor.new_local_tensor(
        rank=rank_id, name="linear_0.w_0_intermediate")
    self.assertEqual(intermediate_var_1.shape, (1024, 4096))
    self.assertEqual(intermediate_var_1.name, "linear_0.w_0_intermediate")

    copied_dist_context = copy.deepcopy(dist_context)
    self.assertIsNotNone(copied_dist_context)
    self.assertEqual(
        id(copied_dist_context),
        id(
            copied_dist_context.get_dist_tensor_for_program(
                dist_tensor.serial_tensor).dist_context))
def test_attn_dp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])

    train_program = static.Program()
    start_program = static.Program()
    dist_context = DistributedContext()
    train_program, start_program = attn_pretrain_forward(train_program,
                                                         start_program)
    complete_train_program = auto.complete_annotation(train_program,
                                                      dist_context)
    # print_program_with_dist_attr(complete_train_program,
    #                              dist_context)
    self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_pp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "pp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
    global PP_MESH_0
    PP_MESH_0 = auto.ProcessMesh(mesh=[0])
    global PP_MESH_1
    PP_MESH_1 = auto.ProcessMesh(mesh=[1])

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 1
    dist_main_prog, dist_startup_prog = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    for key in list(_g_process_group_map.keys()):
        del _g_process_group_map[key]
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
    # print_program_with_dist_attr(dist_main_prog, dist_context)

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # parameter initialization of every rank should be different in the pipeline scene
    self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def test_loss_and_grad_allreduce(self):
    dist_context = DistributedContext(self.main_program,
                                      self.startup_program)
    completer = Completer(dist_context)
    completer.complete_prim_annotation(self.main_program)
    dist_context.block_state.parse_forward_blocks(self.main_program)
    dist_context.block_state.parse_backward_blocks(self.main_program)
    dist_context.grads_params = dict()
    dist_context.grads_params[self.w_grad.name] = self.w.name
    dist_context.synced_gradient = set()
    dist_context.data_parallel_group = list(range(nranks))

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, dist_startup_prog, _ = partitioner.partition(
        self.main_program, self.startup_program, [(self.w, self.w_grad)])

    # the loss and the gradient should each be followed by an allreduce
    ops = dist_main_prog.global_block().ops
    self.assertTrue(ops[1].type == "c_allreduce_sum")
    self.assertTrue(ops[3].type == "c_allreduce_sum")
def test_allgather(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    process_mesh = auto.ProcessMesh(mesh=[0, 3])
    with static.program_guard(train_program, startup_program):
        x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(
            x,
            dist_attr={
                "process_mesh": process_mesh,
                "dims_mapping": [0, -1]
            })

        w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
        w = auto.shard_tensor(
            w,
            dist_attr={
                "process_mesh": process_mesh,
                "dims_mapping": [-1, -1]
            })

        # y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
        #     x.name: [-1, -1],
        #     w.name: [-1, -1]
        # }, **{"x": x,
        #       "y": w})[0]

        y = paddle.distributed.shard_op(
            paddle.matmul,
            dist_attr={
                "process_mesh": process_mesh,
                x: {
                    "dims_mapping": [-1, -1]
                },
                w: {
                    "dims_mapping": [-1, -1]
                }
            })(x, w)[0]

    rank_id = 0
    dist_context = DistributedContext()
    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_context, rank_id)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
        complete_train_program, startup_program, [])
    resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                          rank_id, dist_context, partitioned_params_grads)
    resharder.reshard()

    # x should not be sliced; an allgather is expected instead
    self.assertTrue(check_allgather(partitioned_main_prog))
def test_mlp_dpmppp(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 2
    dist_main_prog, dist_startup_prog = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
    # print_program_with_dist_attr(dist_main_prog, dist_context)

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # check parameter initialization
    self.assertTrue(check_initialization_for_dpmppp(dist_startup_prog))
def test_mlp_mppp(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 2
    dist_main_prog, dist_startup_prog = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # parameters that have not been sliced should be identical across ranks in the mp scene
    self.assertTrue(
        check_initialization_for_mppp(dist_startup_prog, rank_id))
def test_mlp_pp_diff_process_mesh(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 1
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id, True)
    for key in list(_g_process_group_map.keys()):
        del _g_process_group_map[key]
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
    self.assertTrue(check_initialization(dist_startup_prog, rank_id))
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
def test_mapper_dp_mp_pp(self):
    cluster_json_file = ""
    cluster_json_object = json.loads(cluster_json)
    with open("./auto_parallel_cluster.json", "w") as cluster_json_file:
        json.dump(cluster_json_object, cluster_json_file)
    cluster = Cluster()
    cluster.build_from_file("./auto_parallel_cluster.json")
    os.remove("./auto_parallel_cluster.json")

    global _global_parallel_strategy
    _global_parallel_strategy = "dp_mp_pp"
    global _global_num_stages
    _global_num_stages = 2
    global _global_process_mesh
    _global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    processes = [0, 1, 2, 3, 4, 5, 6, 7]

    dist_programs = {}
    for rank_id in processes:
        train_program = static.Program()
        startup_program = static.Program()
        dist_context = DistributedContext()
        dist_train_program, dist_startup_prog = get_dist_prog(
            train_program, startup_program, dist_context, rank_id)
        # if rank_id == 0:
        #     print_program_with_dist_attr(dist_train_program, dist_context)
        dist_programs[rank_id] = [dist_train_program, None]

    rank_mapping = mapping(dist_programs, cluster)

    all_mapped_ranks = set()
    for machine_id, machine_mapping in rank_mapping.items():
        machine = cluster.machines[machine_id]
        machine_mapped_ranks = set()
        machine_mapped_device_local_ids = set()
        for rank, device_ids in machine_mapping["ranks"].items():
            # Only allow one process to one device mapping
            self.assertEqual(len(device_ids), 1)
            self.assertTrue(is_in_machine(device_ids[0], machine))
            machine_mapped_ranks.add(rank)
            machine_mapped_device_local_ids.add(device_ids[0])
        self.assertEqual(
            len(machine_mapped_ranks), len(machine_mapped_device_local_ids))
        all_mapped_ranks.update(machine_mapped_ranks)
    self.assertEqual(set(processes), all_mapped_ranks)
def test_mlp_dp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 0
    dist_main_prog, dist_startup_prog = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)

    # send and recv should not exist in dp scene.
    self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))

    # all parameters should be initialized in dp scene
    self.assertTrue(check_initialization_for_dp(dist_startup_prog))
def test_auto_parallel_cost_model(self):
    standalone_cost_data = get_single_node_data()
    dist_program = []
    for rank_id in range(NUM_RANKS):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        dist_context = DistributedContext()
        distributed_program, dist_startup_prog = get_dist_prog(
            train_program, startup_program, dist_context, rank_id)
        reshard(distributed_program, dist_startup_prog, rank_id,
                dist_context)
        dist_program.append(distributed_program)

    cluster = None
    cost = estimate_cost(
        dist_program,
        cluster=cluster,
        pipeline_config=pp_cfg,
        standalone_cost_data=standalone_cost_data,
        batch_size=4)
    self.assertTrue(check_runtime_estimation(cost))
    self.assertTrue(check_memory_estimation(cost))
def test_complete_backward_annotation(self):
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 0
    dist_main_prog, dist_startup_prog = get_dist_prog(
        train_program, startup_program, dist_context, 0)

    op_need_check = None
    for op in dist_main_prog.global_block().ops:
        if op.type == "gelu_grad":
            op_need_check = op
            break
    # print_program_with_dist_attr(dist_main_prog, dist_context)

    # grad op should have dist attr
    self.assertTrue(
        check_backward_dist_attr(dist_context, dist_main_prog,
                                 op_need_check))
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, start_program, loss = program_func()

    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    with program_guard(main_program, start_program):
        params_grads = append_backward(
            loss, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    partitioner = Partitioner(dist_context, rank)
    dist_main_prog, _, _ = partitioner.partition(main_program, start_program,
                                                 [])

    return dist_main_prog, dist_context
def test_completer(self):
    train_program, start_program, dataloader, i, loss = get_program()
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    # forward annotation should complete without raising
    complete_train_program = completer.complete_forward_annotation(
        train_program)
def test_backup_restore(self):
    train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program(
    )
    dist_context = DistributedContext(train_program, start_program, optimizer,
                                      loss, feed_vars, fetch_vars)
    dist_context.initialize()

    # cycle through every supported backup/restore mode
    dist_context._backup(serial=True, dist=True)
    dist_context._restore(
        serial=True,
        serial_mode="to_backup",
        dist=True,
        dist_mode="to_backup")

    dist_context._backup(serial=True, dist=True)
    dist_context._restore(
        serial=True,
        serial_mode="to_original",
        dist=True,
        dist_mode="to_original")

    dist_context._backup(serial=True, dist=True)
    dist_context._restore(serial=True, dist=True, dist_mode="to_default")

    dist_context._backup(serial=True, dist=True)
    dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")