    def test_softmax_compatible(self):
        valid_op_dist_attr_list = []
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        loss, program, start_program = mlp_forward(program, startup_program)
        ops = program.global_block().ops
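        # dims_mapping convention: -1 means the corresponding tensor dim is
        # replicated; a value k >= 0 means it is sharded along mesh axis k.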
        for idx, op in enumerate(ops):
            if op.type == 'softmax':
                dist_op_impl_container = get_distributed_operator_impl_container(
                    op.type)
                impls = dist_op_impl_container.impls
                op_dist_attr = OperatorDistributedAttribute()
                op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
                                                    [-1, -1])
                op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
                                                     [-1, -1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertTrue(impls[0].is_auto_compatible(dist_op))
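                # A mapping whose length does not match the tensor's rank
                # ([1] for a 2-D output) must be rejected.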
                op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
                                                     [1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertFalse(impls[0].is_auto_compatible(dist_op))

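                # Sharding the input's last dim (the softmax axis) is
                # likewise incompatible.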
                op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
                                                    [-1, 1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertFalse(impls[0].is_auto_compatible(dist_op))
                # all_attrs() returns a copy of the attribute map, so item
                # assignment on it never reaches the op; use _set_attr instead.
                op._set_attr('axis', 2)
                self.assertFalse(impls[0].is_auto_compatible(dist_op))

    def test_update(self):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        _, train_program, startup_program = mlp_forward(
            train_program, startup_program)
        global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        dist_context = DistributedContext()
        set_default_dist_attr(train_program, dist_context, global_process_mesh)
        ops = train_program.global_block().ops
        vars = train_program.global_block().vars
        from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
        from paddle.distributed.auto_parallel.dist_op import DistributedOperator

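        # Ops without a dedicated distributed impl container fall back to
        # generic dims-mapping propagation: the elementwise rule for
        # elementwise ops, the default rule for everything else.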
        for op in ops:
            dist_op_impl_container = get_distributed_operator_impl_container(
                op.type)
            if dist_op_impl_container is None:
                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                dist_op = DistributedOperator(op, op_dist_attr)
                if is_elementwise_op(op.type):
                    changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

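                    # Force an inconsistent mapping by sharding output dim 0
                    # on mesh axis 0; propagation should then report a change
                    # (ops that cannot handle it raise, hence the try/except).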
                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                            dist_op)
                    except Exception:
                        continue
                    self.assertTrue(changed)
                else:
                    changed = update_op_dims_mapping_by_default_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

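                    # Same probe for the default rule: an inconsistent output
                    # mapping should make the update report a change.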
                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_default_dist_impl(
                            dist_op)
                    except Exception:
                        continue
                    self.assertTrue(changed)

    def test_matmulv2_matmul_0_compatible(self):
        valid_op_dist_attr_list = []
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        loss, program, start_program = mlp_forward(program, startup_program)
        with static.program_guard(program,
                                  start_program), utils.unique_name.guard():
            # paddle.matmul emits a matmul_v2 op, while the legacy
            # layers.matmul emits a matmul op; both op types are covered.
            matmulx3 = static.data(name="matmulx3",
                                   shape=[6, 2, 6],
                                   dtype='float32')
            matmuly3 = static.data(name="matmuly3",
                                   shape=[6, 6],
                                   dtype='float32')
            output1 = paddle.matmul(x=matmulx3, y=matmuly3)
            output_1 = layers.matmul(x=matmulx3, y=matmuly3)
            matmulx4 = static.data(name="matmulx4",
                                   shape=[6, 6, 2, 6],
                                   dtype='float32')
            matmuly4 = static.data(name="matmuly4",
                                   shape=[6, 6, 6, 6],
                                   dtype='float32')
            output2 = paddle.matmul(x=matmulx4, y=matmuly4)
            output_2 = layers.matmul(x=matmulx4, y=matmuly4)
        ops = program.global_block().ops
        vars = program.global_block().vars
        for idx, op in enumerate(ops):
            if op.type == 'matmul_v2' or op.type == 'matmul':
                dist_op_impl_container = get_distributed_operator_impl_container(
                    op.type)
                impls = dist_op_impl_container.impls
                op_dist_attr = OperatorDistributedAttribute()
                X = op.input_arg_names[0]
                Y = op.input_arg_names[1]
                out = op.output_arg_names[0]
                # 2-D X x 2-D Y: compatible only when Y's column axis and
                # the output's column axis share the same mesh axis.
                if len(vars[X].shape) == 2 and len(vars[Y].shape) == 2:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [0, 0])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [0, -1])
                    op_dist_attr.set_output_dims_mapping(out, [1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                # 3-D X x 2-D Y: compatible only when Y's and the output's
                # last axis are sharded; sharding X is rejected.
                if len(vars[X].shape) == 3 and len(vars[Y].shape) == 2:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 0, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                # 4-D X x 4-D Y: again only Y's and the output's last axis
                # may be sharded; sharded batch dims are rejected.
                if len(vars[X].shape) == 4 and len(vars[Y].shape) == 4:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1, 1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))