def test_softmax_compatible(self):
    """Check ``is_auto_compatible`` of the first registered softmax impl.

    Builds the forward program returned by ``mlp_forward`` and, for every
    ``softmax`` op in its global block, asserts that:

    * ``[-1, -1]`` dims mappings on both the first input and the first
      output are reported as compatible;
    * each of the subsequent modified mappings is reported as
      incompatible.
    """
    program = paddle.static.Program()
    startup_program = paddle.static.Program()
    # Only the program returned by mlp_forward is inspected here.
    _, program, _ = mlp_forward(program, startup_program)

    for op in program.global_block().ops:
        if op.type != 'softmax':
            continue

        impls = get_distributed_operator_impl_container(op.type).impls
        x_name = op.input_arg_names[0]
        out_name = op.output_arg_names[0]

        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.set_input_dims_mapping(x_name, [-1, -1])
        op_dist_attr.set_output_dims_mapping(out_name, [-1, -1])
        dist_op = DistributedOperator(op, op_dist_attr)
        self.assertTrue(impls[0].is_auto_compatible(dist_op))

        # Output mapping with a different length than the input mapping.
        op_dist_attr.set_output_dims_mapping(out_name, [1])
        dist_op = DistributedOperator(op, op_dist_attr)
        self.assertFalse(impls[0].is_auto_compatible(dist_op))

        # Additionally map the last input dimension to mesh axis 1.
        op_dist_attr.set_input_dims_mapping(x_name, [-1, 1])
        dist_op = DistributedOperator(op, op_dist_attr)
        self.assertFalse(impls[0].is_auto_compatible(dist_op))

        # NOTE(review): this mutates whatever mapping all_attrs() returns;
        # if that is a freshly built dict the op's real 'axis' attribute
        # is untouched and the assertion below only re-checks the previous
        # (already incompatible) dist_op -- TODO confirm intent.
        op.all_attrs()['axis'] = 2
        self.assertFalse(impls[0].is_auto_compatible(dist_op))
def test_update(self):
    """Check the dims-mapping update helpers on ops without a dist-op impl.

    For every op of the ``mlp_forward`` program for which
    ``get_distributed_operator_impl_container`` returns ``None``:

    1. run the matching update helper (the elementwise-like one when
       ``is_elementwise_op`` says so, the default one otherwise) and
       assert that it reports no change;
    2. set the first output's dims mapping to ``[0, -1, ..., -1]``;
    3. run the helper again and assert that it now reports a change.

    If the second helper call raises, the op is skipped (best-effort,
    as in the original test).
    """
    # Function-local imports are kept from the original test; this block
    # cannot see whether the module already imports these names.
    from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
    from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
    from paddle.distributed.auto_parallel.dist_op import DistributedOperator

    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    _, train_program, startup_program = mlp_forward(train_program,
                                                    startup_program)
    global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
    dist_context = DistributedContext()
    set_default_dist_attr(train_program, dist_context, global_process_mesh)

    global_block = train_program.global_block()
    block_vars = global_block.vars  # renamed: `vars` shadowed the builtin

    for op in global_block.ops:
        if get_distributed_operator_impl_container(op.type) is not None:
            continue

        op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
        dist_op = DistributedOperator(op, op_dist_attr)

        # Both original branches ran the same steps and differed only in
        # which update helper they called.
        if is_elementwise_op(op.type):
            update_dims_mapping = update_op_dims_mapping_by_elementwise_like_dist_impl
        else:
            update_dims_mapping = update_op_dims_mapping_by_default_dist_impl

        self.assertFalse(update_dims_mapping(dist_op))

        out_name = op.output_arg_names[0]
        out_rank = len(block_vars[out_name].shape)
        # First dimension mapped to mesh axis 0, all remaining dimensions
        # left unmapped (-1).
        dist_op.dist_attr.set_output_dims_mapping(
            out_name, [0] + [-1] * (out_rank - 1))

        try:
            changed = update_dims_mapping(dist_op)
        except Exception:
            # Deliberate best-effort skip, but no longer a bare `except:`
            # so KeyboardInterrupt / SystemExit still propagate.
            continue
        self.assertTrue(changed)
def test_matmulv2_matmul_0_compatible(self):
    """Check ``is_auto_compatible`` of the first matmul / matmul_v2 impl.

    Extends the ``mlp_forward`` program with extra 3-D x 2-D and
    4-D x 4-D ``paddle.matmul`` / ``layers.matmul`` calls, then walks
    every ``matmul_v2`` / ``matmul`` op in the global block.  Depending
    on the ranks of the op's first two inputs, an ordered list of dims
    mapping updates is applied to one shared
    ``OperatorDistributedAttribute``; after each update
    ``impls[0].is_auto_compatible`` is asserted to return the expected
    verdict.  Ops whose rank pair is not listed are not checked.
    """
    # Per (rank of X, rank of Y): ordered steps of
    #   (X mapping, Y mapping, output mapping, expected verdict).
    # ``None`` keeps the mapping set by an earlier step.  The steps are
    # cumulative, so their order must not be changed.
    scenarios = {
        (2, 2): (
            ([-1, -1], [-1, 1], [-1, 1], True),
            ([-1, 1], None, None, False),
            (None, [1, 1], None, False),
            (None, None, [0, 0], False),
            ([0, -1], None, [1, 1], False),
            (None, [1, -1], None, False),
        ),
        (3, 2): (
            ([-1, -1, -1], [-1, 1], [-1, -1, 1], True),
            ([-1, 0, -1], None, None, False),
            ([-1, 1, -1], None, None, False),
            (None, [-1, -1], None, False),
            (None, None, [1, -1, 1], False),
            (None, None, [-1, -1, -1], False),
            (None, None, [-1, 1, -1], False),
        ),
        (4, 4): (
            ([-1, -1, -1, -1], [-1, -1, -1, 1], [-1, -1, -1, 1], True),
            (None, None, [0, -1, -1, 1], False),
            ([-1, 1, 1, -1], None, None, False),
            ([-1, 1, -1, -1], None, None, False),
            ([-1, -1, 1, -1], None, None, False),
            (None, [0, -1, -1, 1], None, False),
            (None, None, [-1, 1, 1, 1], False),
            (None, [-1, -1, -1, -1], None, False),
            (None, None, [-1, -1, 1, -1], False),
            (None, [-1, -1, 1, -1], None, False),
        ),
    }

    program = paddle.static.Program()
    startup_program = paddle.static.Program()
    _, program, start_program = mlp_forward(program, startup_program)

    with static.program_guard(program,
                              start_program), utils.unique_name.guard():
        # The matmul results are not used directly; the calls are kept
        # because the loop below looks for matmul ops with these ranks
        # (presumably the calls append them to `program`).
        matmulx3 = static.data(
            name="matmulx3", shape=[6, 2, 6], dtype='float32')
        matmuly3 = static.data(
            name="matmuly3", shape=[6, 6], dtype='float32')
        paddle.matmul(x=matmulx3, y=matmuly3)
        layers.matmul(x=matmulx3, y=matmuly3)

        matmulx4 = static.data(
            name="matmulx4", shape=[6, 6, 2, 6], dtype='float32')
        matmuly4 = static.data(
            name="matmuly4", shape=[6, 6, 6, 6], dtype='float32')
        paddle.matmul(x=matmulx4, y=matmuly4)
        layers.matmul(x=matmulx4, y=matmuly4)

    global_block = program.global_block()
    block_vars = global_block.vars  # renamed: `vars` shadowed the builtin

    for op in global_block.ops:
        if op.type not in ('matmul_v2', 'matmul'):
            continue

        impls = get_distributed_operator_impl_container(op.type).impls
        op_dist_attr = OperatorDistributedAttribute()
        x_name = op.input_arg_names[0]
        y_name = op.input_arg_names[1]
        out_name = op.output_arg_names[0]

        ranks = (len(block_vars[x_name].shape),
                 len(block_vars[y_name].shape))

        for x_map, y_map, out_map, expected in scenarios.get(ranks, ()):
            # Fresh list per call so no mapping object is shared between
            # ops or steps, matching the original list literals.
            if x_map is not None:
                op_dist_attr.set_input_dims_mapping(x_name, list(x_map))
            if y_map is not None:
                op_dist_attr.set_input_dims_mapping(y_name, list(y_map))
            if out_map is not None:
                op_dist_attr.set_output_dims_mapping(out_name,
                                                     list(out_map))

            verdict = impls[0].is_auto_compatible(
                DistributedOperator(op, op_dist_attr))
            if expected:
                self.assertTrue(verdict)
            else:
                self.assertFalse(verdict)