Example #1
def append_backward(loss,
                    parameter_list=None,
                    no_grad_set=None,
                    callbacks=None,
                    checkpoints=None):
    """
    This function appends the backward part to `main_program`.
    A complete neural network training is made up of forward and backward
    propagation. However, when we configure a network, we only need to
    specify its forward part. This function uses the chain rule to automatically
    generate the backward part according to the forward part.
    In most cases, users do not need to invoke this function manually.
    It will be automatically invoked by the optimizer's `minimize` function.
    Parameters:
        loss( :ref:`api_guide_Variable_en` ): The loss variable of the network.
        parameter_list(list of str, optional): Names of parameters that need
                                           to be updated by optimizers.
                                           If it is None, all parameters
                                           will be updated.
                                           Default: None.
        no_grad_set(set of str, optional): Names of variables in block 0 ( :ref:`api_guide_Block_en` ) whose gradients
                               should be ignored. All variables with
                               `stop_gradient=True` from all blocks will
                               be automatically added into this set.
                               If this parameter is not None, the names in this set will be added to the default set.
                               Default: None.
        callbacks(list of callable objects, optional): List of callback functions.
                                               The callbacks are used for
                                               doing some custom jobs during
                                               backward part building. All
                                               callable objects in it will
                                               be invoked once each time a
                                               new gradient operator is added
                                               into the program. The callable
                                                object must have two input
                                               parameters: 'block' and 'context'.
                                               The 'block' is the :ref:`api_guide_Block_en` which
                                               the new gradient operator will
                                               be added to. The 'context' is a
                                               map, whose keys are gradient
                                               variable names and values are
                                               corresponding original :ref:`api_guide_Variable_en` .
                                               In addition to this, the 'context'
                                               has another special key-value pair:
                                               the key is string '__current_op_desc__'
                                               and the value is the op_desc of the
                                                gradient operator that has just
                                               triggered the callable object.
                                               Default: None.
    Returns:
        list of tuple ( :ref:`api_guide_Variable_en` , :ref:`api_guide_Variable_en` ): Pairs of parameters and their corresponding gradients.
        The first element of each pair is the parameter and the second is its gradient variable.
    Raises:
        AssertionError: If `loss` is not an instance of Variable.
    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x = fluid.data(name='x', shape=[None, 13], dtype='float32')
            y = fluid.data(name='y', shape=[None, 1], dtype='float32')
            y_predict = fluid.layers.fc(input=x, size=1, act=None)
            loss = fluid.layers.square_error_cost(input=y_predict, label=y)
            avg_loss = fluid.layers.mean(loss)

            param_grad_list = fluid.backward.append_backward(loss=avg_loss)
            p_g_list1 = fluid.backward.append_backward(loss=avg_loss)  # len(p_g_list1) == 2
            p_g_list2 = fluid.backward.append_backward(loss=avg_loss, parameter_list=[p_g_list1[0][0].name])  # len(p_g_list2) == 1
            p_g_list3 = fluid.backward.append_backward(loss=avg_loss, no_grad_set=set([p_g_list1[0][0].name]))  # len(p_g_list3) == 1
            p_g_list4 = fluid.backward.append_backward(loss=avg_loss, parameter_list=[p_g_list1[0][0].name], no_grad_set=set([p_g_list1[0][0].name]))  # len(p_g_list4) == 0
    """

    assert isinstance(loss, framework.Variable)

    if loss.op is None:
        # the loss is from a cloned program. Find loss op manually.
        backward._find_loss_op_(loss)

    loss.op._set_attr(
        core.op_proto_and_checker_maker.kOpRoleAttrName(),
        int(core.op_proto_and_checker_maker.OpRole.Forward)
        | int(core.op_proto_and_checker_maker.OpRole.Loss))

    if callbacks is not None:
        assert isinstance(callbacks, list)

    program = loss.block.program
    program._appending_grad_times += 1

    if no_grad_set is None:
        no_grad_set = set()
    no_grad_set = copy.copy(no_grad_set)
    no_grad_dict = backward._get_stop_gradients_(program)
    no_grad_dict[0].update(
        list(map(backward._append_grad_suffix_, no_grad_set)))

    grad_info_map = dict()
    root_block = program.block(0)

    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = dict()

    # Append the op that seeds backpropagation by setting the gradient of the
    # loss (loss@GRAD) to 1.0, before any gradient ops are generated.
    op_desc = _create_loss_op_desc_(loss)
    root_block.desc.append_op().copy_from(op_desc)

    block_no_grad_set = set(map(backward._strip_grad_suffix_, no_grad_dict[0]))
    op_path = backward._find_op_path_(root_block, [loss], [],
                                      block_no_grad_set)
    no_grad_vars = backward._find_no_grad_vars(root_block, op_path, [loss],
                                               block_no_grad_set)
    block_no_grad_set.update(no_grad_vars)
    no_grad_dict[0].update(
        list(map(backward._append_grad_suffix_, block_no_grad_set)))

    input_grad_names_set = None
    # For double backward, input_grad_names_set is used to filter out
    # gradient ops that are not actually needed.
    if program._appending_grad_times > 1:
        input_grad_names_set = set([backward._append_grad_suffix_(loss.name)])

    backward._append_backward_ops_(root_block,
                                   op_path,
                                   root_block,
                                   no_grad_dict,
                                   grad_to_var,
                                   callbacks,
                                   input_grad_names_set=input_grad_names_set)

    # Because calc_gradient may be called multiple times, we need to rename
    # the internal gradient variables so that they have different names.
    backward._rename_grad_(root_block, fwd_op_num, grad_to_var, {})

    backward._append_backward_vars_(root_block, fwd_op_num, grad_to_var,
                                    grad_info_map)

    program.current_block_idx = current_block_idx
    program._sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = list(filter(is_mpc_parameter, program.list_vars()))
        parameters = [param.name for param in params if param.trainable]

    params_and_grads = []
    for param in parameters:
        if cpt.to_text(param) not in grad_info_map:
            continue
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError(
                "grad block[{0}] did not have grad var {1}".format(
                    grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))

    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
    for p, g in params_and_grads:
        if g is None:
            continue
        for op in reversed(program.global_block().ops):
            assert isinstance(op, framework.Operator)
            if g.name in op.output_arg_names:
                g.op = op
                break

        if g.op is None:
            raise ValueError("Unexpected branch")
        attr_val = [p.name, g.name]
        if g.op.has_attr(op_role_var_attr_name):
            attr_val.extend(g.op.attr(op_role_var_attr_name))
        g.op._set_attr(op_role_var_attr_name, attr_val)

    return params_and_grads
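
A minimal sketch of a `callbacks` entry for `append_backward`, following the contract described in the docstring above: a callable that takes 'block' and 'context'. The function name `log_grad_op` and its print-based body are illustrative only, not part of the library.

def log_grad_op(block, context):
    # 'block' is the Block the new gradient operator was just added to;
    # 'context' maps gradient variable names to their original Variables,
    # plus the special key '__current_op_desc__' holding the new op's desc.
    op_desc = context['__current_op_desc__']
    print("added grad op '%s' to block %d" % (op_desc.type(), block.idx))

# Hypothetical usage, reusing avg_loss from the docstring example above:
# p_g_list = fluid.backward.append_backward(loss=avg_loss, callbacks=[log_grad_op])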
Example #2
    def _apply_single_impl(self, main_programs, startup_programs, context):
        checkpoints = self.get_attr("checkpoints")
        loss = self.get_attr("loss")
        no_grad_set = self.get_attr("no_grad_set")
        self._dist_context = self.get_attr("dist_context")

        main_block = main_programs.global_block()
        no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
        # get op_path which is related to loss
        op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)

        # step 1: build recompute state
        rc_state = RecomputeState(main_block, op_path)
        rc_state.modify_forward_desc_for_recompute(self._dist_context)
        rc_state.build_stats()
        checkpoints = rc_state.sort_checkpoints(checkpoints)
        segments = rc_state.get_recompute_segments(checkpoints)
        if segments == []:
            return

        # step 2: get vars_should_be_hold
        vars_should_be_hold = []
        for segment in segments:
            vars_should_be_hold.extend(
                rc_state.get_out_of_subgraph_vars(segment[0], segment[1]))
        cross_vars = set(vars_should_be_hold) - set(checkpoints)
        logging.info(
            "found [{}] vars that cross recompute segments: [{}], "
            "better checkpoints might be set to reduce those vars".format(
                len(cross_vars), cross_vars))
        vars_should_be_hold.extend(rc_state.get_reserved_vars())
        vars_should_be_hold.extend(rc_state.get_input_nodes())
        vars_should_be_hold = list(set(vars_should_be_hold))
        vars_in_memory = vars_should_be_hold + checkpoints

        # step 3: get recomputed fwd ops desc
        var_name_dict = {}
        ckpt_ops_dict = {}
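        # buffer_block is a staging block that temporarily holds copies of the
        # forward op descs before they are re-inserted as recompute ops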
        buffer_block = main_block.program._create_block()
        for i, segment in enumerate(segments[::-1]):
            fwd_ops = op_path[segment[0]:segment[1]]
            var_suffix = ".subprog_%d" % i
            for op in fwd_ops:
                input_and_output_names = []
                input_and_output_names.extend(op.desc.input_arg_names())
                input_and_output_names.extend(op.desc.output_arg_names())
                cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                    op)
                assert cur_op_dist_attr is not None
                for name in input_and_output_names:
                    if main_block.var(name).persistable or name in checkpoints:
                        continue
                    if name in vars_should_be_hold:
                        continue
                    if name not in var_name_dict:
                        ref_process_mesh = cur_op_dist_attr.process_mesh
                        if name in op.desc.input_arg_names():
                            ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping(
                                name)
                        else:
                            ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping(
                                name)
                        # record recomputed var's old_name and new_name (old_name.subprog_XXX)
                        # create new var with new name
                        var_name_dict[name] = name + var_suffix
                        ref_var = main_block.var(name)
                        rc_var = main_block.create_var(
                            name=var_name_dict[name],
                            shape=ref_var.shape,
                            dtype=ref_var.dtype,
                            type=ref_var.type,
                            persistable=ref_var.persistable,
                            stop_gradient=ref_var.stop_gradient)
                        # set new recomputed var's dist attr
                        set_var_dist_attr(self._dist_context, rc_var,
                                          ref_dims_mapping, ref_process_mesh)
            # get recomputed segment's descs
            segment_descs = _add_needed_descs_to_block(fwd_ops, buffer_block,
                                                       main_block,
                                                       vars_in_memory,
                                                       self._dist_context)
            # rename recomputed ops' input and output var name
            for key in var_name_dict:
                _rename_arg_(segment_descs, key, var_name_dict[key])

            # NOTE: one forward op can correspond to multiple xxx_grad ops.
            # When traversing all grad ops in reverse, we need a flag to indicate
            # whether the ckpt and its segment_descs can still be used.
            ckpt_op = op_path[segment[1] - 1]
            ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]

        # step 4: insert recomputed fwd ops
        ops = main_block.ops
        loss_op = get_loss_op(main_block)
        loss_op_idx = _find_op_index(main_block, loss_op)
        dist_op_context = self._dist_context.dist_op_context
        assert loss_op_idx != -1
        # Traverse all grad ops in reverse; if the fwd op corresponding to a grad op
        # is a checkpoint, insert the ops of the recomputed segment.
        for i in range(len(ops) - 1, loss_op_idx, -1):
            grad_op = ops[i]
            # remove some attrs of dropout_grad op's desc
            if grad_op.type == "dropout_grad":
                grad_op.desc.remove_attr("fix_seed")
                grad_op.desc.remove_attr("seed")
                main_block._sync_with_cpp()

            # rename the grad op's var names that are not in 'vars_in_memory'
            for key in var_name_dict:
                self.reset_op_dist_attr(grad_op, var_name_dict)
                _rename_arg_([grad_op.desc], key, var_name_dict[key])

            # insert recomputed ops
            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
                fwd_op_id = dist_op_context.grad_op_id_to_op_id[
                    grad_op.desc.id()]
                if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                    idx = grad_op.idx
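                    # back up over any preceding gradient-accumulation 'sum' ops
                    # so the recomputed segment is inserted before them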
                    while idx - 1 >= 0 and ops[idx - 1].type == "sum":
                        idx -= 1
                    segment_descs = ckpt_ops_dict[fwd_op_id][1]
                    for op_desc in reversed(segment_descs):
                        rc_desc = main_block.desc._insert_op(idx)
                        rc_desc.copy_from(op_desc)
                        rc_op = Operator(main_block, rc_desc)
                        main_block.ops.insert(idx, rc_op)
                        # set recomputed ops' dist attr
                        fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
                            rc_desc.original_id())
                        assert fwd_op_dist_attr is not None
                        self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
                                              var_name_dict)

                    ckpt_ops_dict[fwd_op_id][0] = False
                    main_block._sync_with_cpp()

        main_programs._sync_with_cpp()
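
For context, a hedged sketch of how a recompute pass built around this `_apply_single_impl` might be applied through Paddle's pass machinery. The pass name "auto_parallel_recompute" and the attribute keys are assumptions inferred from the `get_attr` calls above; the variable names are placeholders. Verify both against the Paddle version in use.

from paddle.distributed.passes import PassContext, new_pass

# Attribute keys mirror the get_attr calls in _apply_single_impl above (assumed).
config = {
    "checkpoints": checkpoint_var_names,  # hypothetical list of var names to keep in memory
    "loss": loss_var,                     # loss Variable of main_program
    "no_grad_set": None,
    "dist_context": dist_context,         # auto-parallel distributed context
}
recompute_pass = new_pass("auto_parallel_recompute", config)  # pass name is an assumption
recompute_pass.apply([main_program], [startup_program], PassContext())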