def get_dist_prog(train_program,
                  startup_program,
                  dist_context,
                  rank_id,
                  change_process_mesh=False):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # auto completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    if change_process_mesh:
        # used by test_mlp_pp_diff_process_mesh: move one intermediate
        # tensor (assumed here to be the gelu output of mlp_forward) onto
        # the second pipeline mesh so the send/recv crosses meshes
        global PP_MESH_1
        dist_context.get_tensor_dist_attr_for_program(
            train_program.global_block().vars[
                "gelu_0.tmp_0"]).process_mesh = PP_MESH_1

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    partitioner = Partitioner(dist_context, rank_id)
    dist_train_program, dist_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        dist_train_program, dist_startup_prog, dist_params_grads)

    # resharding is done by every caller, which also needs dist_params_grads,
    # so return all three partitioned artifacts instead of resharding twice
    return dist_train_program, dist_startup_prog, dist_params_grads

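# A minimal usage sketch (not itself a test) of the helper above: every test
# below drives it the same way -- build fresh programs, partition them for one
# rank, then reshard. `example_rank` is a hypothetical name used only here;
# real tests pick concrete ranks from the process mesh they set up.
#
#     train_program = paddle.static.Program()
#     startup_program = paddle.static.Program()
#     dist_context = DistributedContext()
#     example_rank = 0
#     main_prog, startup_prog, params_grads = get_dist_prog(
#         train_program, startup_program, dist_context, example_rank)
#     Resharder(main_prog, startup_prog, example_rank, dist_context,
#               params_grads).reshard()
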
def test_mlp_pp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "pp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
    global PP_MESH_0
    PP_MESH_0 = auto.ProcessMesh(mesh=[0])
    global PP_MESH_1
    PP_MESH_1 = auto.ProcessMesh(mesh=[1])

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 1
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    for key in list(_g_process_group_map.keys()):
        del _g_process_group_map[key]
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # parameter initialization of every rank should be different in the pipeline scene
    self.assertTrue(check_initialization(dist_startup_prog, rank_id))

def test_allgather(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    process_mesh = auto.ProcessMesh(mesh=[0, 3])
    with static.program_guard(train_program, startup_program):
        x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
        x = auto.shard_tensor(x,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [0, -1]
                              })

        w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
        w = auto.shard_tensor(w,
                              dist_attr={
                                  "process_mesh": process_mesh,
                                  "dims_mapping": [-1, -1]
                              })

        y = paddle.distributed.shard_op(paddle.matmul,
                                        dist_attr={
                                            "process_mesh": process_mesh,
                                            x: {
                                                "dims_mapping": [-1, -1]
                                            },
                                            w: {
                                                "dims_mapping": [-1, -1]
                                            }
                                        })(x, w)[0]

    rank_id = 0
    dist_context = DistributedContext()
    dist_strategy = fleet.DistributedStrategy()
    partitioner = Partitioner(dist_context, rank_id)
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
        complete_train_program, startup_program, [])
    resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
                          rank_id, dist_context, partitioned_params_grads)
    resharder.reshard()

    # x should not be sliced, so reshard must insert an allgather
    self.assertTrue(check_allgather(partitioned_main_prog))

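# Note on the dist_attr dictionaries in test_allgather above: each position in
# "dims_mapping" ties a tensor dimension to a process-mesh axis, and -1 marks
# that dimension as replicated. x is therefore sharded along its first
# dimension across mesh [0, 3], while the matmul is annotated to consume a
# fully replicated x ([-1, -1]), so resharding has to insert an allgather --
# which is exactly what check_allgather asserts.
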
def test_mlp_dpmppp(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 2
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # check parameter initialization
    self.assertTrue(check_initialization_for_dpmppp(dist_startup_prog))

def test_mlp_pp_diff_process_mesh(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 1
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id, True)
    for key in list(_g_process_group_map.keys()):
        del _g_process_group_map[key]
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
    self.assertTrue(check_initialization(dist_startup_prog, rank_id))

def test_mlp_mppp(self):
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 2
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # check send and recv result
    self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

    # parameters that are not sliced should be initialized identically in the mp scene
    self.assertTrue(
        check_initialization_for_mppp(dist_startup_prog, rank_id))

def test_mlp_dp(self):
    global _global_parallel_strategy
    _global_parallel_strategy = "dp"
    global _global_process_mesh
    _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    dist_context = DistributedContext()
    rank_id = 0
    dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
        train_program, startup_program, dist_context, rank_id)
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                          dist_context, dist_params_grads)
    resharder.reshard()

    # send and recv should not exist in the dp scene
    self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))

    # all parameters should be initialized in the dp scene
    self.assertTrue(check_initialization_for_dp(dist_startup_prog))

def test_auto_parallel_cost_model(self):
    standalone_cost_data = get_single_node_data()
    dist_program = []
    for rank_id in range(NUM_RANKS):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        dist_context = DistributedContext()
        distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
            train_program, startup_program, dist_context, rank_id)
        resharder = Resharder(distributed_program, dist_startup_prog,
                              rank_id, dist_context, dist_params_grads)
        resharder.reshard()
        dist_program.append(distributed_program)
    cluster = None
    cost = estimate_cost(dist_program,
                         cluster=cluster,
                         pipeline_config=pp_cfg,
                         standalone_cost_data=standalone_cost_data,
                         batch_size=4)
    self.assertTrue(check_runtime_estimation(cost))
    self.assertTrue(check_memory_estimation(cost))