Example No. 1
def _get_dist_op_backward_implement(backward_op, dist_context,
                                    forward_op_id2forward_op):
    dist_op_context = dist_context.dist_op_context
    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
        forward_op_id = dist_op_context.grad_op_id_to_op_id[
            backward_op.desc.original_id()]
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op)
        dist_op_impl_container = get_distributed_operator_impl_container(
            forward_op_dist_attr.impl_type)
        dist_op_impl = dist_op_impl_container.get_impl(
            forward_op_dist_attr.impl_idx)
        return dist_op_impl

    # NOTE: trick for dist ops that only have a backward implementation
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
        assert op_dist_attr.impl_idx >= 0
        dist_op_impl = get_distributed_operator_impl_container(
            op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
        return dist_op_impl

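    # Neither case matched: fall back to the default distributed operator implementation.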
    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)
Example No. 2
    def test_softmax_compatible(self):
        valid_op_dist_attr_list = []
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        loss, program, start_program = mlp_forward(program, startup_program)
        ops = program.global_block().ops
        for idx, op in enumerate(ops):
            if op.type == 'softmax':
                dist_op_impl_container = get_distributed_operator_impl_container(
                    op.type)
                impls = dist_op_impl_container.impls
                op_dist_attr = OperatorDistributedAttribute()
                op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
                                                    [-1, -1])
                op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
                                                     [-1, -1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertTrue(impls[0].is_auto_compatible(dist_op))
                op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
                                                     [1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertFalse(impls[0].is_auto_compatible(dist_op))

                op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
                                                    [-1, 1])
                dist_op = DistributedOperator(op, op_dist_attr)
                self.assertFalse(impls[0].is_auto_compatible(dist_op))
                op.all_attrs()['axis'] = 2
                self.assertFalse(impls[0].is_auto_compatible(dist_op))
Example No. 3
    def test_update(self):
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        _, train_program, startup_program = mlp_forward(
            train_program, startup_program)
        global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
        dist_context = DistributedContext()
        set_default_dist_attr(train_program, dist_context, global_process_mesh)
        ops = train_program.global_block().ops
        vars = train_program.global_block().vars
        from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
        from paddle.distributed.auto_parallel.operators.common import is_elementwise_op
        from paddle.distributed.auto_parallel.dist_op import DistributedOperator

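        # Ops without a registered dist op impl container go through the generic
        # dims-mapping update helpers (elementwise-like or default).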
        for op in ops:
            dist_op_impl_container = get_distributed_operator_impl_container(
                op.type)
            if dist_op_impl_container is None:
                op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
                dist_op = DistributedOperator(op, op_dist_attr)
                if is_elementwise_op(op.type):
                    changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
                            dist_op)
                    except:
                        continue
                    self.assertTrue(changed)
                else:
                    changed = update_op_dims_mapping_by_default_dist_impl(
                        dist_op)
                    self.assertFalse(changed)

                    dist_op.dist_attr.set_output_dims_mapping(
                        op.output_arg_names[0], [0] + [
                            -1 for i in range(
                                1, len(vars[op.output_arg_names[0]].shape))
                        ])
                    try:
                        changed = update_op_dims_mapping_by_default_dist_impl(
                            dist_op)
                    except:
                        continue
                    self.assertTrue(changed)
Example No. 4
    def test_matmulv2_matmul_0_compatible(self):
        valid_op_dist_attr_list = []
        program = paddle.static.Program()
        startup_program = paddle.static.Program()
        loss, program, start_program = mlp_forward(program, startup_program)
        with static.program_guard(program,
                                  start_program), utils.unique_name.guard():
            matmulx3 = static.data(name="matmulx3",
                                   shape=[6, 2, 6],
                                   dtype='float32')
            matmuly3 = static.data(name="matmuly3",
                                   shape=[6, 6],
                                   dtype='float32')
            output1 = paddle.matmul(x=matmulx3, y=matmuly3)
            output_1 = layers.matmul(x=matmulx3, y=matmuly3)
            matmulx4 = static.data(name="matmulx4",
                                   shape=[6, 6, 2, 6],
                                   dtype='float32')
            matmuly4 = static.data(name="matmuly4",
                                   shape=[6, 6, 6, 6],
                                   dtype='float32')
            output2 = paddle.matmul(x=matmulx4, y=matmuly4)
            output_2 = layers.matmul(x=matmulx4, y=matmuly4)
        ops = program.global_block().ops
        vars = program.global_block().vars
        for idx, op in enumerate(ops):
            if op.type == 'matmul_v2' or op.type == 'matmul':
                dist_op_impl_container = get_distributed_operator_impl_container(
                    op.type)
                impls = dist_op_impl_container.impls
                op_dist_attr = OperatorDistributedAttribute()
                X = op.input_arg_names[0]
                Y = op.input_arg_names[1]
                out = op.output_arg_names[0]
                if len(vars[X].shape) == 2 and len(vars[Y].shape) == 2:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [0, 0])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [0, -1])
                    op_dist_attr.set_output_dims_mapping(out, [1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                if len(vars[X].shape) == 3 and len(vars[Y].shape) == 2:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 0, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                if len(vars[X].shape) == 4 and len(vars[Y].shape) == 4:
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
                    self.assertTrue(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, 1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(X, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, 1, 1, 1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_output_dims_mapping(out, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
                    op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
                    self.assertFalse(impls[0].is_auto_compatible(
                        DistributedOperator(op, op_dist_attr)))
Example No. 5
def _get_dist_op_forward_implement(forward_op, dist_context):
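    # Resolve the forward op's distributed implementation from its recorded
    # dist_attr (impl_type selects the container, impl_idx the concrete impl).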
    dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
    dist_op_impl_container = get_distributed_operator_impl_container(
        dist_attr.impl_type)
    dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
    return dist_op_impl
Example No. 6
    def partition_block(self, ref_block, target_block):

        dist_op_context = self._dist_context.dist_op_context
        serial_ops = ref_block.ops

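        # Locate the loss op; ops up to and including it form the forward section.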
        last_fwd_op_idx = -1
        for idx, op in enumerate(ref_block.ops):
            if is_loss_op(op):
                last_fwd_op_idx = idx
                break

        if last_fwd_op_idx == -1:
            last_fwd_op_idx = len(ref_block.ops)

        # init mapping
        forward_op_id2forward_op = {}
        for idx in range(len(serial_ops)):
            if idx <= last_fwd_op_idx:
                forward_op_id2forward_op[
                    serial_ops[idx].desc.original_id()] = serial_ops[idx]

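        # appended_grad_times tracks the current backward segment: it advances
        # whenever a backward op directly follows a forward/loss op, and is used
        # below to pick the matching grad_var_to_var mapping.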
        appended_grad_times = 0
        # partition
        for idx, op in enumerate(serial_ops):

            if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1])
                                       or is_loss_op(serial_ops[idx - 1])):
                appended_grad_times += 1

            # partition input variables
            for serial_input_varname in op.desc.input_arg_names():
                if serial_input_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_input_varname + self._dist_varname_suffix
                    if ref_block.has_var(serial_input_varname):
                        _partition_var(self._dist_context, ref_block,
                                       target_block, serial_input_varname,
                                       new_varname)
                    else:
                        assert serial_input_varname in __varname_not_in_block__

                    self._serial2dist_varname_mapping[
                        serial_input_varname] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if serial_output_varname not in self._serial2dist_varname_mapping:
                    new_varname = serial_output_varname + self._dist_varname_suffix
                    _partition_var(self._dist_context, ref_block, target_block,
                                   serial_output_varname, new_varname)
                    self._serial2dist_varname_mapping[
                        serial_output_varname] = new_varname

            # partition op
            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
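            # Dispatch by op role: forward/recompute ops use the forward impl,
            # backward ops use the impl resolved from their forward counterpart,
            # and optimize ops fall back to the default impl.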
            if is_forward_op(op) or op_dist_attr.is_recompute:
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context)
                dist_op_forward_impl.forward(self._dist_context, **kinputs,
                                             **koutputs)

            elif is_backward_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op)
                grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[
                    appended_grad_times]
                dist_op_backward_impl.backward(
                    self._dist_context, **kinputs, **koutputs,
                    **{"grad_var_to_var": grad_var_to_var})
            elif is_optimize_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_impl = get_distributed_operator_impl_container(
                    "default").get_impl(0)
                dist_op_impl.backward(self._dist_context, **kinputs,
                                      **koutputs)
            else:
                raise NotImplementedError(
                    "partitioner only supports forward, backward and optimize ops, but got {}"
                    .format(str(op)))