Пример #1
0
def _insert_reduce_op(block,
                      insert_idx,
                      reduce_var,
                      ring_id,
                      root_id,
                      dist_context,
                      op_role=OpRole.Backward,
                      use_calc_stream=True):
    """
    Insert an in-place ``c_reduce_sum`` op for ``reduce_var`` at ``insert_idx``.

    The op reads and writes ``reduce_var``. It inherits the process mesh and
    dims mapping of ``reduce_var``'s tensor dist attr. It is added with
    ``_insert_op_without_sync``, so the block is not synced with C++ here.

    Args:
        block: block to insert the op into.
        insert_idx: position of the new op in ``block``.
        reduce_var: name of the variable to reduce (looked up with
            ``block.var``).
        ring_id: communication ring id stored in the op's ``ring_id`` attr.
        root_id: stored in the op's ``root_id`` attr; must be >= 0.
        dist_context: distributed context holding the dist attrs.
        op_role: op role stored under ``OP_ROLE_KEY``.
        use_calc_stream: stored in the op's ``use_calc_stream`` attr.

    Raises:
        AssertionError: if ``root_id`` is negative or ``reduce_var`` has no
            tensor dist attr.
    """
    # 0 is a valid root, so the requirement is "non-negative", not "positive".
    assert root_id >= 0, "root id should be a non-negative int, but now root id is {}".format(
        root_id)
    new_op = block._insert_op_without_sync(insert_idx,
                                           type='c_reduce_sum',
                                           inputs={'X': [reduce_var]},
                                           outputs={'Out': [reduce_var]},
                                           attrs={
                                               'ring_id': ring_id,
                                               'root_id': root_id,
                                               'use_calc_stream':
                                               use_calc_stream,
                                               OP_ROLE_KEY: op_role
                                           })

    dist_attr = dist_context.get_tensor_dist_attr_for_program(
        block.var(reduce_var))
    # Fail with a clear message instead of an AttributeError on None.
    assert dist_attr is not None, "dist attr of {} is not set".format(
        reduce_var)
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        new_op, dist_attr.process_mesh, dist_attr.dims_mapping, dist_context)
Пример #2
0
    def _insert_optimizer_broadcasts(self, main_block, startup_block):
        """
        Append one ``c_broadcast`` per sharded parameter after the optimizer.

        The broadcast root is the rank ``sharding_info.get_var_rank`` reports
        for the parameter, and the ring is the sharding group's id. Each op is
        tagged ``OpRole.Optimize`` and annotated with the parameter's process
        mesh and dims mapping.

        Nothing is done when ``self.stage > 2``. Every parameter must exist in
        both ``main_block`` and ``startup_block``. ``main_block`` is synced
        with C++ at the end.
        """
        if self.stage > 2:
            return

        for info in self.sharding_infos:
            for p in info.params:
                # The parameter has to be known to both programs.
                for blk in (main_block, startup_block):
                    assert blk.has_var(p.name)

                bcast_op = main_block.append_op(
                    type='c_broadcast',
                    inputs={'X': p},
                    outputs={'Out': p},
                    attrs={
                        'ring_id': info.group.id,
                        'root': info.get_var_rank(p.name),
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    })

                p_dist = self._dist_context.get_tensor_dist_attr_for_program(
                    p)
                assert p_dist is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    bcast_op, p_dist.process_mesh, p_dist.dims_mapping,
                    self._dist_context)

        main_block._sync_with_cpp()
Пример #3
0
def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move param-grad cast ops to the end of the backward segment in order to
    enable fp16 allreduce.

    For every ``(param, grad)`` pair whose FP32 gradient is produced by a
    ``cast`` op, that op is re-created at the end of the default main
    program's global block (unless it already is the last op). The original
    op is then removed. The new op and the cast output variable take the
    parameter's process mesh and dims mapping.

    Args:
        params_grads: iterable of ``(param, grad)`` variable pairs.
        dist_context: distributed context holding the dist attrs.

    Raises:
        ValueError: if ``find_true_post_op`` reports a consumer of the cast
            op's output, or if the cast op cannot be found in the program.
    """
    # TODO filter optimize ops in future

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            # Drop op_role_var from backward-role cast ops.
            if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
                    'op_role_var'):
                op._remove_attr("op_role_var")

            post_ops = find_true_post_op(main_block.ops, op, g.name)
            if post_ops:
                # Trailing spaces keep the concatenated literals readable.
                raise ValueError("The cast op {0}'s output should not be "
                                 "used by a non-optimize op, however, it "
                                 "is used by {1}".format(op, post_ops[0]))

            # Already the last op of the block: nothing to move.
            if op == main_block.ops[-1]:
                continue

            # add new op in the python and cpp at the same time
            new_op_desc = main_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op = paddle.fluid.framework.Operator(
                block=main_block,
                desc=new_op_desc,
                type=None,
                inputs=None,
                outputs=None,
                attrs=None)
            main_block.ops.append(new_op)

            # dist attr: both the moved op and its output follow the param
            param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
            output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                main_block.var(op.output_arg_names[0]))
            assert param_dist_attr is not None
            assert output_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                new_op, param_dist_attr.process_mesh,
                param_dist_attr.dims_mapping, dist_context)

            output_dist_attr.process_mesh = param_dist_attr.process_mesh
            output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

            # Remove the original op; the block is synced once at the end.
            op_idx = find_op_index(main_block.desc, op.desc)
            if op_idx == -1:
                raise ValueError("The op {0} is not in program".format(op))
            main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()
Пример #4
0
def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank,
                                  root_rank, ring_id, op_role, dist_context):
    """
    Insert a ``c_broadcast`` of ``varname`` (root ``root_rank``) at
    ``insert_idx``.

    On ranks other than the root, an ``empty`` op producing the variable
    (with the variable's shape and dtype) is also inserted at the same
    index, which places it in front of the broadcast.

    Both ops carry ``op_role`` and get the broadcast variable's process mesh
    and dims mapping. They are added with ``_insert_op_without_sync``, so
    the block is not synced with C++ here. Returns None.
    """
    target = block.var(varname)
    target_dist_attr = dist_context.get_tensor_dist_attr_for_program(target)

    def _annotate(op):
        # Give a freshly inserted op the broadcast variable's mesh/mapping.
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            op, target_dist_attr.process_mesh, target_dist_attr.dims_mapping,
            dist_context)

    broadcast_op = block._insert_op_without_sync(insert_idx,
                                                 type='c_broadcast',
                                                 inputs={'X': varname},
                                                 outputs={'Out': varname},
                                                 attrs={
                                                     'ring_id': ring_id,
                                                     'root': root_rank,
                                                     'use_calc_stream': True,
                                                     OP_ROLE_KEY: op_role
                                                 })
    _annotate(broadcast_op)

    # Only non-root ranks get the extra ``empty`` op.
    if local_rank == root_rank:
        return

    empty_op = block._insert_op_without_sync(
        insert_idx,
        type="empty",
        outputs={"Out": target.name},
        attrs={
            "shape": target.shape,
            "dtype": target.dtype,
            OP_ROLE_KEY: op_role
        })
    _annotate(empty_op)
Пример #5
0
    def _shard_gradient_clip(self, main_block):
        """
        Drop gradient-clip ops that belong to parameters outside this rank's
        shard, then allreduce the clip ``sum`` result over sharding groups.

        Only runs when ``self.stage >= 2``. The steps are:

        1. Find gradient-clip ops of type ``elementwise_mul``,
           ``squared_l2_norm`` or ``clip_by_norm``. The parameter name is the
           part of input ``X`` before ``"@GRAD"``. Ops whose parameter is not
           in the local shard are removed. For ``squared_l2_norm`` and
           ``clip_by_norm``, their output variables are removed too.
        2. On the first gradient-clip ``sum`` op:
           - its ``X`` inputs are pruned of those removed variables;
           - one ``c_allreduce_sum`` per sharding group is inserted right
             after it on that sum's output.

        ``main_block`` is synced with C++ at the end.
        """
        if self.stage < 2:
            return

        # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
        prunable_types = [
            'elementwise_mul', 'squared_l2_norm', 'clip_by_norm'
        ]
        pruned_indices = set()
        pruned_vars = set()

        # Pass 1: collect clip ops acting on gradients of non-local params.
        for i, clip_op in enumerate(main_block.ops):
            if not _is_gradient_clip_op(clip_op) or \
                    clip_op.type not in prunable_types:
                continue
            grad_name = clip_op.input("X")[0]
            owner = grad_name[:grad_name.find("@GRAD")]
            if self._is_parameter_in_local_shard(owner):
                continue
            pruned_indices.add(i)
            if clip_op.type in ['squared_l2_norm', 'clip_by_norm']:
                pruned_vars.update(clip_op.output_arg_names)

        # Pass 2: delete back-to-front so the remaining indices stay valid.
        for i in sorted(pruned_indices, reverse=True):
            main_block._remove_op(i, sync=False)

        for name in pruned_vars:
            main_block._remove_var(name, sync=False)

        # Pass 3: patch the first clip ``sum`` op and allreduce its output.
        sum_idx, sum_op = next(
            ((i, candidate) for i, candidate in enumerate(main_block.ops)
             if _is_gradient_clip_op(candidate) and candidate.type == 'sum'),
            (None, None))
        if sum_op is not None:
            kept_inputs = [
                name for name in sum_op.input_arg_names
                if name not in pruned_vars
            ]
            sum_op.desc.set_input("X", kept_inputs)

            total_name = sum_op.desc.output_arg_names()[0]
            for offset, info in enumerate(self.sharding_infos, start=1):
                allreduce_op = main_block._insert_op(
                    sum_idx + offset,
                    type='c_allreduce_sum',
                    inputs={'X': [total_name]},
                    outputs={'Out': [total_name]},
                    attrs={
                        'ring_id': info.group.id,
                        'op_namescope': "/gradient_clip_model_parallelism",
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    })
                total_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                    main_block.var(total_name))
                assert total_dist_attr is not None
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                    allreduce_op, total_dist_attr.process_mesh,
                    total_dist_attr.dims_mapping, self._dist_context)

        main_block._sync_with_cpp()
Пример #6
0
    def modify_forward_desc_for_recompute(self, dist_context):
        """
        Give every forward ``dropout`` op in ``self._ops`` an explicit seed.

        The scan of ``self._ops`` stops at the first op whose type contains
        "grad". For each ``dropout`` with no ``Seed`` input:

        - a ``seed`` op is inserted in front of it (seed 0 when its
          ``fix_seed`` attr is False, otherwise its ``seed`` attr);
        - its ``Seed`` input is pointed at that op's output;
        - its ``fix_seed`` / ``seed`` attrs are removed.

        This guarantees that two dropout ops have the same outputs. The new
        op and variable use dims mapping ``[-1]`` on the dropout's process
        mesh.
        """
        if all(op.desc.type() != "dropout" for op in self._ops):
            return

        pos = 0
        while pos < len(self._ops):
            op = self._ops[pos]
            if "grad" in op.type:
                break
            # Only dropouts that do not already consume a seed are rewired.
            needs_seed = op.type == "dropout" and not op.input("Seed")
            if not needs_seed:
                pos += 1
                continue

            dropout_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            # insert seed op to guarantee that two dropout op have the same outputs
            seed_op_name = unique_name.generate("seed")
            seed_var_name = unique_name.generate_with_ignorable_key(
                seed_op_name + ".tmp")
            seed_var = self._block.create_var(
                name=seed_var_name,
                dtype='int32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)

            # The same mapping list is shared by the var and the op attrs.
            mapping = [-1]
            mesh = dropout_dist_attr.process_mesh
            seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var,
                                                   mapping, mesh)

            if op.attr("fix_seed") is False:
                seed_value = 0
            else:
                seed_value = int(op.attr("seed"))
            seed_op = self._block._insert_op_without_sync(
                index=op.idx,
                type="seed",
                inputs={},
                outputs={"Out": seed_var},
                attrs={
                    "seed": seed_value,
                    "force_cpu": True
                })
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                seed_op, mesh, mapping, dist_context)

            # Rewire the dropout to consume the new seed variable.
            self._ops.insert(pos, seed_op)
            op.desc.set_input("Seed", [seed_var_name])
            op.desc.remove_attr("fix_seed")
            op.desc.remove_attr("seed")
            dropout_dist_attr.set_input_dist_attr(seed_var.name,
                                                  seed_var_dist_attr)
            self._block._sync_with_cpp()
            # Step over both the inserted seed op and the dropout.
            pos += 2
def _get_gm_cond_var(main_program, k_steps, dist_context):
    """
    Build the gradient-merge step counter and return its condition variable.

    Creates the persistable int32 globals below, with ``force_cpu=True``:

    - ``gradient_merge_k``, initialized to ``k_steps``;
    - ``gradient_merge_zero``, initialized to 0;
    - ``gradient_merge_step``, initialized to 0.

    It also creates the bool var ``gradient_merge_cond``. All of these get
    dims mapping ``[-1]`` over ``world_process_group.ranks``. Under
    ``device_guard("cpu")`` it then appends to ``main_program``'s global
    block::

        step = step + 1
        step = step % k
        cond = (step == 0)

    Every appended op is annotated with a dist attr over the world process
    group, so ``cond`` becomes true on every ``k_steps``-th execution.

    Returns:
        The ``gradient_merge_cond`` variable.
    """
    main_block = main_program.global_block()
    # Add const var
    k_step_var = layers.create_global_var(name="gradient_merge_k",
                                          shape=[1],
                                          value=int(k_steps),
                                          dtype='int32',
                                          persistable=True,
                                          force_cpu=True)
    set_var_dist_attr(dist_context, k_step_var, [-1],
                      world_process_group.ranks)

    zero_var = layers.create_global_var(name="gradient_merge_zero",
                                        shape=[1],
                                        value=int(0),
                                        dtype='int32',
                                        persistable=True,
                                        force_cpu=True)
    set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)

    # Add step var & cond var
    step_var = layers.create_global_var(name="gradient_merge_step",
                                        shape=[1],
                                        value=int(0),
                                        dtype='int32',
                                        persistable=True,
                                        force_cpu=True)
    set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)

    cond_var = main_block.create_var(name="gradient_merge_cond",
                                     shape=[1],
                                     dtype='bool')
    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

    with device_guard("cpu"):
        # step_var += 1
        # Append the increment op to main_block explicitly (mirroring what
        # ``layers.increment(..., in_place=True)`` emits). Then it lands in
        # the same block as the ops below and, like them, gets a dist attr.
        increment_op = main_block.append_op(type='increment',
                                            inputs={'X': [step_var]},
                                            outputs={'Out': [step_var]},
                                            attrs={'step': float(1.0)})
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            increment_op, world_process_group.ranks, [-1], dist_context)

        # step_var = step_var % k_step
        elementwise_mod_op = main_block.append_op(type='elementwise_mod',
                                                  inputs={
                                                      'X': step_var,
                                                      'Y': k_step_var
                                                  },
                                                  outputs={'Out': step_var},
                                                  attrs={
                                                      'axis': -1,
                                                      'use_mkldnn': False
                                                  })
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mod_op, world_process_group.ranks, [-1], dist_context)

        # cond_var = (step_var == 0)
        equal_op = main_block.append_op(type='equal',
                                        inputs={
                                            'X': step_var,
                                            'Y': zero_var
                                        },
                                        outputs={'Out': cond_var})
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            equal_op, world_process_group.ranks, [-1], dist_context)

    return cond_var
def _append_gradient_merge_backward_op(
        main_program, startup_program, params_grads: List[Tuple[Any, Any]],
        cond_var_name: str,
        dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
    """
    Accumulate every gradient into a persistable
    ``<param>@GRAD@GradientMerge`` buffer.

    The work is done in two passes over ``params_grads``:

    1. Assert that no param is ``SELECTED_ROWS`` and call
       ``_remove_op_role_var(param, grad)`` for every pair.
    2. For each pair:
       - create the buffer in the main program (param's shape, dtype,
         process mesh and dims mapping);
       - create a same-named var in the startup program, filled with 0 by a
         ``fill_constant`` op;
       - append ``elementwise_add(grad, buffer) -> buffer`` to the main
         block.

    ``cond_var_name`` is accepted but not used here.

    Returns:
        ``([[param, buffer], ...], {param_name: buffer})``. The pairs are
        lists, despite the annotation.
    """
    main_block = main_program.global_block()
    startup_block = startup_program.global_block()

    # Step 1: validate every pair and strip op_role_var before any rewrite.
    for param, grad in params_grads:
        # NOTE(review): this checks param.type; a SelectedRows *gradient*
        # would not be caught here - confirm which one is meant.
        assert (
            param.type != core.VarDesc.VarType.SELECTED_ROWS
        ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
        _remove_op_role_var(param, grad)

    merge_var_of = {}
    merged_pairs = []
    # Step 2: one zero-initialised accumulation buffer per parameter.
    for param, grad in params_grads:
        pname = param.name
        pvar = main_block.var(pname)
        assert pvar is not None
        pdist = dist_context.get_tensor_dist_attr_for_program(pvar)
        assert pdist is not None
        buffer_name = pname + "@GRAD@GradientMerge"

        buffer_var = main_block.create_var(name=buffer_name,
                                           shape=pvar.shape,
                                           dtype=pvar.dtype,
                                           persistable=True)
        merge_var_of[pname] = buffer_var
        mesh, mapping = pdist.process_mesh, pdist.dims_mapping
        set_var_dist_attr(dist_context, buffer_var, mapping, mesh)

        startup_buffer = startup_block.create_var(name=buffer_name,
                                                  shape=pvar.shape,
                                                  dtype=pvar.dtype,
                                                  persistable=True)
        startup_block.append_op(type="fill_constant",
                                outputs={"Out": startup_buffer},
                                attrs={
                                    "shape": pvar.shape,
                                    "dtype": pvar.dtype,
                                    "value": 0.0,
                                })

        # buffer += grad
        add_op = main_block.append_op(type="elementwise_add",
                                      inputs={
                                          'X': grad,
                                          'Y': buffer_var
                                      },
                                      outputs={'Out': buffer_var},
                                      attrs={
                                          'axis': -1,
                                          'use_mkldnn': False
                                      })
        merged_pairs.append([param, buffer_var])
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            add_op, mesh, mapping, dist_context)
    return merged_pairs, merge_var_of
Пример #9
0
    def _scale_loss(self):
        """
        Multiply the loss by ``self._loss_scaling`` when loss scaling is on.

        Only an FP32 loss is supported; any other dtype raises
        ``NotImplementedError``. Scaling is applied when
        ``use_dynamic_loss_scaling`` is set or ``init_loss_scaling != 1.0``.

        Forward:
            ``elementwise_mul(loss, loss_scaling) -> self._scaled_loss`` is
            inserted right after the loss op. It takes over the loss op's
            op_role, and the loss op is re-tagged as ``OpRole.Forward``.

        Backward:
            - The op two slots after the loss op must be a ``fill_constant``
              with op_role 257. Its output is redirected to
              ``self._scaled_loss_grad``.
            - An ``elementwise_mul_grad`` op is inserted after it. It writes
              the original output name (``X@GRAD``) from that scaled-loss
              gradient.

        Otherwise ``self._scaled_loss`` is the loss itself. The main block is
        synced with C++ before returning.

        Raises:
            NotImplementedError: if the loss is not FP32.
        """

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op)

        if loss.dtype != core.VarDesc.VarType.FP32:
            # Casting the loss here would change the effective loss tensor of
            # the computation graph. That would affect every later pass whose
            # logic is based on the loss tensor (Recompute & Gradient Merge),
            # so it is not allowed for now. Fix it in the future.
            raise NotImplementedError(
                "Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
            )

        if self.get_attr("use_dynamic_loss_scaling"
                         ) or self.get_attr("init_loss_scaling") != 1.0:

            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward: scaled_loss = loss * loss_scaling, right after loss op
            ref_mesh = loss_op_dist_attr.process_mesh
            self._scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable)
            set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                              ref_mesh)

            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={
                    'X': [loss],
                    'Y': [self._loss_scaling]
                },
                outputs={'Out': [self._scaled_loss]},
                attrs={
                    'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
                })
            # The scaling op inherits the loss role; the loss op becomes
            # plain Forward.
            loss_op._set_attr(OP_ROLE_KEY,
                              core.op_proto_and_checker_maker.OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op, ref_mesh, [-1], self.dist_context)

            # backward
            # Two slots after the loss op (past the new elementwise_mul) there
            # must be a fill_constant with role 257 (presumably
            # OpRole.Backward | OpRole.Loss - confirm).
            first_backward_op = main_block.ops[loss_op_idx + 2]
            assert first_backward_op.type == "fill_constant" and int(
                first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable)
            set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                              ref_mesh)
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name,
                                             self._scaled_loss_grad.name)
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3)
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [self._scaled_loss_grad.name])
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input('Y',
                                                   [self._loss_scaling.name])
            # Write the original gradient name so downstream backward ops are
            # unchanged.
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            elementwise_mul_grad_op = paddle.fluid.framework.Operator(
                main_block, elementwise_mul_grad_op_desc)
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)

        else:
            self._scaled_loss = loss

        main_block._sync_with_cpp()
Пример #10
0
    def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
                                 dist_context):
        """
        Only for backward cast: make ``grad_op`` consistent with the casts
        inserted for its forward op.

        The forward op id is looked up through
        ``dist_context.dist_op_context.grad_op_id_to_op_id``. Its recorded
        renames (original var name -> cast var name) come from
        ``self._var_name_dict[fwd_op_id]``.

        Inputs:
            A ``src_dtype`` input that was cast for the forward op is renamed
            to the cast var. ``grad_op``'s dist attr gets an entry for the
            new name.

        Outputs:
            - Each output's dtype is reset to its forward var's dtype.
            - A ``src_dtype`` output whose forward var was cast is renamed to
              ``<cast name>@GRAD`` when no var of ``dst_dtype`` with that
              name exists yet. That var is created with ``dst_dtype``.
            - A ``cast`` op back to the output's dtype is then inserted at
              ``idx + 1``.

        When ``src_dtype`` is FP32, the ``layer_norm_grad`` slots selected
        by the two helpers below are only asserted to be FP32 and left
        untouched.

        Returns:
            The number of cast ops inserted.
        """
        def _keep_fp32_input(op, in_name):
            # layer_norm_grad: every input slot except 'X' and 'Y@GRAD' stays
            # FP32. Other op types keep nothing in FP32.
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return in_name not in {'X', 'Y@GRAD'}
            return False

        def _keep_fp32_output(op, out_name):
            # layer_norm_grad: every output slot except 'X@GRAD' stays FP32.
            op_type = op.type
            if op_type in ['layer_norm_grad']:
                return out_name != 'X@GRAD'
            return False

        num_cast_ops = 0
        # Map the grad op back to the forward op whose casts are mirrored.
        original_id = grad_op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

        for in_name in grad_op.input_names:
            # FP32-pinned slots: only check the dtype, never rename.
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                    grad_op, in_name):
                for in_var_name in grad_op.input(in_name):
                    in_var = self._block._find_var_recursive(in_var_name)
                    assert in_var.dtype == core.VarDesc.VarType.FP32
                continue

            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.dtype == src_dtype:
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        grad_op)
                    if in_var_name in self._var_name_dict[fwd_op_id]:
                        # NOTE: if in_var of consume grad_op has been casted before,
                        # it should be renamed and reset dist_attr.
                        cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                        grad_op.desc._rename_input(in_var_name, cast_name)
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var_name)
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr)
                    else:
                        # NOTE(review): reached only when in_var.dtype ==
                        # src_dtype, so this fails unless src_dtype ==
                        # dst_dtype. In effect it demands that every
                        # src_dtype input was cast for the forward op. In
                        # the output loop below the matching `else` belongs
                        # to the outer dtype check instead - confirm which
                        # is intended.
                        assert in_var.dtype == dst_dtype

        for out_name in grad_op.output_names:
            # FP32-pinned slots: only check the dtype, never rename.
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                    grad_op, out_name):
                for out_var_name in grad_op.output(out_name):
                    out_var = self._block._find_var_recursive(out_var_name)
                    assert out_var.dtype == core.VarDesc.VarType.FP32
                continue

            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                # Forward var name = text before the first '@'.
                # NOTE(review): if the name contains no '@', find() returns -1
                # and this drops the last character instead; fwd_var may then
                # be None. Confirm grad output names always contain '@'.
                out_var_name_prefix = out_var_name[:out_var_name.find("@")]
                fwd_var = self._block._find_var_recursive(out_var_name_prefix)
                # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
                if out_var.dtype != fwd_var.dtype:
                    out_var.desc.set_dtype(fwd_var.dtype)

                # From here on out_var.dtype is the forward var's dtype.
                if out_var.dtype == src_dtype:
                    if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                        # NOTE: if out_var of consume grad_op has been casted before,
                        # it should be renamed and reset dist_attr, then we insert cast op to
                        # convert the cast_var to original dtype
                        consume_op_attr = dist_context.get_op_dist_attr_for_program(
                            grad_op)
                        fwd_cast_name = self._var_name_dict[fwd_op_id][
                            out_var_name_prefix]
                        cast_name = fwd_cast_name + "@GRAD"
                        cast_var = self._block.vars.get(cast_name)
                        # NOTE(review): if a var named cast_name already
                        # exists with dst_dtype, this output is neither
                        # renamed nor cast - confirm that case is handled
                        # elsewhere.
                        if cast_var is None or cast_var.dtype != dst_dtype:
                            grad_op.desc._rename_output(
                                out_var_name, cast_name)
                            # The dist attr is still keyed by the old name.
                            out_var_dist_attr = consume_op_attr.get_output_dist_attr(
                                out_var_name)
                            ref_mesh = out_var_dist_attr.process_mesh
                            ref_mapping = out_var_dist_attr.dims_mapping
                            consume_op_attr.set_output_dist_attr(
                                cast_name, out_var_dist_attr)
                            assert ref_mapping is not None
                            cast_var = self._block.create_var(
                                name=cast_name,
                                shape=out_var.shape,
                                dtype=dst_dtype,
                                persistable=False,
                                stop_gradient=out_var.stop_gradient)
                            set_var_dist_attr(dist_context, cast_var,
                                              ref_mapping, ref_mesh)

                            # Cast the dst_dtype result back into the original
                            # output var, at idx + 1 (presumably right after
                            # grad_op, assuming grad_op sits at idx).
                            cast_op = self._block._insert_op(
                                idx + 1,
                                type="cast",
                                inputs={"X": cast_var},
                                outputs={"Out": out_var},
                                attrs={
                                    "in_dtype": cast_var.dtype,
                                    "out_dtype": out_var.dtype,
                                    "op_role": OpRole.Backward
                                })
                            # Strip these attrs from the new cast op.
                            cast_op._remove_attr("op_role_var")
                            cast_op._remove_attr("op_namescope")
                            cast_op._remove_attr("with_quant_attr")
                            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                                cast_op, ref_mesh, ref_mapping, dist_context)
                            num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype

        return num_cast_ops
Пример #11
0
    def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype,
                                dist_context):
        """
        Insert cast ops so that the inputs of a forward op are in `dst_dtype`.

        Only for forward cast; modified from paddle.fluid.contrib.mixed_precision.

        Input selection:
        - Input slots kept in FP32 are skipped when casting from FP32.
        - Input vars are skipped if their type is not in `_valid_types` or
          they already have `dst_dtype`.

        Casting: each remaining input var whose dtype is `src_dtype` is
        renamed in `op` to ``<var>.cast_<dst_dtype>``.
        - If no var with that name and `dst_dtype` exists in ``self._block``,
          one is created and a ``cast`` op producing it is inserted at `idx`.
        - The new var/op dist attrs are copied from `op`'s input dist attr
          (the consumer), not from the op that produced the original var.
        - `op`'s dist attr is given an input dist attr for the cast name.

        Inputs whose dtype is neither `src_dtype` nor `dst_dtype` are not cast.
        Instead, `op`'s ``in_dtype`` attr (if present) is set to `dst_dtype`.

        All original-name -> cast-name renames of `op`, across every input
        slot, are stored in ``self._var_name_dict[op.desc.original_id()]``.

        When casting FP32 -> FP16, FP32 outputs of `op` of a valid type (except
        those kept in FP32) are retyped to FP16 in place. ``out_dtype`` is
        updated too if the op has it.

        Args:
            op: forward operator whose inputs may be cast.
            idx: position in ``self._block`` at which new cast ops are inserted.
            src_dtype: dtype of the inputs that need casting.
            dst_dtype: dtype the inputs are cast to.
            dist_context: distributed context used to read/write dist attrs.

        Returns:
            int: the number of cast ops inserted.
        """
        num_cast_ops = 0
        # Collect renames across *all* input slots of `op`. Re-creating this
        # dict per slot would keep only the last slot's renames, and would
        # leave it undefined when `op` has no inputs at all.
        var_name_dict = {}

        for in_name in op.input_names:
            if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                    op, in_name):
                continue
            for in_var_name in op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                    continue
                if in_var.dtype == src_dtype:
                    cast_name = in_var.name + '.cast_' + _dtype_to_str(
                        dst_dtype)
                    out_var = self._block.vars.get(cast_name)
                    var_name_dict[in_var.name] = cast_name
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        op)
                    assert consume_op_attr is not None
                    if out_var is None or out_var.dtype != dst_dtype:
                        # NOTE we make the cast op and var's dist attr as the op that consume the
                        # cast var instead of the op which generates the var
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name)
                        assert in_var_dist_attr is not None
                        ref_mesh = in_var_dist_attr.process_mesh
                        ref_mapping = in_var_dist_attr.dims_mapping
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr)

                        out_var = self._block.create_var(
                            name=cast_name,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=in_var.stop_gradient)
                        set_var_dist_attr(dist_context, out_var, ref_mapping,
                                          ref_mesh)

                        cast_op = self._block._insert_op_without_sync(
                            idx,
                            type="cast",
                            inputs={"X": in_var},
                            outputs={"Out": out_var},
                            attrs={
                                "in_dtype": in_var.dtype,
                                "out_dtype": out_var.dtype,
                            })
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context)
                        num_cast_ops += 1
                    else:
                        # A cast var with the right dtype already exists
                        # (created for an earlier consumer): reuse it and only
                        # register its dist attr on this consumer op.
                        in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                            in_var.name)
                        consume_op_attr.set_input_dist_attr(
                            cast_name, in_var_dist_attr)
                    _rename_arg(op, in_var.name, cast_name)
                else:
                    # Input is neither src_dtype nor dst_dtype: no cast op is
                    # inserted, only the op's declared input dtype is updated.
                    if op.has_attr('in_dtype'):
                        op._set_attr('in_dtype', dst_dtype)
        self._var_name_dict[op.desc.original_id()] = var_name_dict

        if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
            for out_name in op.output_names:
                if _keep_fp32_output(op, out_name):
                    continue
                for out_var_name in op.output(out_name):
                    out_var = self._block.var(out_var_name)
                    if out_var.type not in _valid_types:
                        continue
                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        # Outputs are retyped in place; no cast op is needed
                        # because the op itself now computes in FP16.
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                        if op.has_attr('out_dtype'):
                            op._set_attr('out_dtype',
                                         core.VarDesc.VarType.FP16)
        return num_cast_ops
Пример #12
0
    def _scale_loss(self):
        """
        Multiply the loss by ``self._loss_scaling`` and wire up its gradient.

        Operates on the global block of
        ``paddle.static.default_main_program()``. NOTE(review): this assumes
        that program is the one owning the loss op — confirm.

        Scaling happens only if ``use_dynamic_loss_scaling`` is set or
        ``init_loss_scaling`` differs from 1.0. In that case:

        * forward: creates ``self._scaled_loss`` and inserts
          ``elementwise_mul(loss, self._loss_scaling)`` directly after the loss
          op. The new op takes the loss op's op_role, and the loss op's role is
          reset to ``Forward``.
        * backward: the first backward op must be a ``fill_constant``. Its
          output is renamed to a new ``self._scaled_loss_grad`` var. An
          ``elementwise_mul_grad`` op is inserted right after it and writes
          ``X@GRAD`` under the fill_constant's original output name, so ops
          that read that name presumably need no change.

        Otherwise ``self._scaled_loss`` is simply the loss.

        Every var/op created here gets dims_mapping ``[-1]`` on the loss op's
        process mesh.

        ``self._loss_scaling`` is read but not created here; presumably it is
        created earlier by the pass — confirm.
        """

        main_block = paddle.static.default_main_program().global_block()
        main_block._sync_with_cpp()
        loss = self.get_attr("loss")
        assert loss is not None
        loss_op = loss.op
        loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
            loss_op)

        # NOTE(review): in static mode `astype` presumably appends a cast op at
        # the END of the block (not right after loss_op) and without a dist
        # attr, while the elementwise_mul below is inserted at loss_op_idx + 1
        # and would then read the cast output before it is produced. Also,
        # `loss_op` still refers to the original loss op. Verify before
        # relying on non-FP32 losses.
        if loss.dtype != core.VarDesc.VarType.FP32:
            loss = loss.astype('float32')

        if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
                "init_loss_scaling") != 1.0:

            # All insertions below are positioned relative to the loss op.
            loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

            # forward
            ref_mesh = loss_op_dist_attr.process_mesh
            self._scaled_loss = main_block.create_var(
                name=unique_name.generate("scaled_loss"),
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable)
            set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                              ref_mesh)

            OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
                type='elementwise_mul',
                inputs={'X': [loss],
                        'Y': [self._loss_scaling]},
                outputs={'Out': [self._scaled_loss]},
                attrs={'op_role': loss_op.all_attrs()[OP_ROLE_KEY], })
            # The scaling op inherits the loss op's role (presumably including
            # the Loss bit); the original loss op becomes a plain forward op.
            loss_op._set_attr(OP_ROLE_KEY,
                              core.op_proto_and_checker_maker.OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_op, ref_mesh, [-1], self.dist_context)

            # backward
            # The mul op now sits at loss_op_idx + 1, so the first backward op
            # is expected at loss_op_idx + 2.
            first_backward_op = main_block.ops[loss_op_idx + 2]
            # 257 == 0x101, presumably OpRole.Backward | OpRole.Loss — confirm
            # against the OpRole enum.
            assert first_backward_op.type == "fill_constant" and int(
                first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
            # NOTE(review): this is a separate unique_name.generate call, so the
            # name is presumably NOT `self._scaled_loss.name + "@GRAD"`.
            self._scaled_loss_grad = main_block.create_var(
                name=unique_name.generate("scaled_loss") + "@GRAD",
                shape=loss.shape,
                dtype=loss.dtype,
                persistable=loss.persistable)
            set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                              ref_mesh)
            # Redirect the fill_constant to seed the scaled-loss gradient. Its
            # old output name is reused below as elementwise_mul_grad's X@GRAD.
            pre_grad_name = first_backward_op.output_arg_names[0]
            first_backward_op._rename_output(pre_grad_name,
                                             self._scaled_loss_grad.name)
            # FIXME(JZ-LIANG) a trick to insert backward op
            main_block._sync_with_cpp()
            # Build the grad op at the desc level, right after the
            # fill_constant (loss_op_idx + 2).
            elementwise_mul_grad_op_desc = main_block.desc._insert_op(
                loss_op_idx + 3)
            elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
            elementwise_mul_grad_op_desc.set_input(
                'Out@GRAD', [self._scaled_loss_grad.name])
            elementwise_mul_grad_op_desc.set_input('X', [loss.name])
            elementwise_mul_grad_op_desc.set_input('Y',
                                                   [self._loss_scaling.name])
            elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
            # No gradient is produced for the loss-scaling factor.
            elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
            elementwise_mul_grad_op_desc._set_attr(
                OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
            elementwise_mul_grad_op_desc._set_attr('axis', -1)
            # Wrap the desc in a Python Operator and put it at the same index in
            # the Python-side op list so both views stay aligned.
            elementwise_mul_grad_op = paddle.fluid.framework.Operator(
                main_block, elementwise_mul_grad_op_desc)
            main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
            main_block._sync_with_cpp()
            # Re-fetch after the sync so the dist attr is attached to the op
            # object the block actually holds.
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)

        else:
            # No scaling configured: downstream code uses the loss as-is.
            self._scaled_loss = loss

        main_block._sync_with_cpp()