def _update_loss_scaling(self, grads, found_inf):
    """Append an ``update_loss_scaling`` op to the default main program.

    The op is appended to the global block of
    ``paddle.static.default_main_program()`` and a distributed attribute
    for it is registered in ``self.dist_context``.

    Args:
        grads (list | tuple): Gradient variables. They are wired as both
            the ``X`` input and the ``Out`` output of the appended op.
        found_inf: Variable wired to the op's ``FoundInfinite`` input.

    Raises:
        AssertionError: If the dtype of ``self._loss_scaling`` is not
            compatible with the dtype of a gradient, or if a gradient has
            no tensor distributed attribute in ``self.dist_context``.
            ``check_type`` / ``check_variable_and_dtype`` may raise as
            well; their exception types are not visible from here.
    """
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
    for grad in grads:
        check_variable_and_dtype(grad, "x",
                                 ['float16', 'float32', 'float64'],
                                 'update_loss_scaling')
        # float16 gradients are accepted by the check above, but the loss
        # scaling variable is restricted to float32/float64, so a strict
        # dtype-equality assert could never pass for them. Require float32
        # loss scaling in that case instead.
        # NOTE(review): assumes the update_loss_scaling op accepts float16
        # X together with a float32 PrevLossScaling -- confirm.
        if grad.dtype == paddle.float16:
            assert self._loss_scaling.dtype == paddle.float32, \
                "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
        else:
            assert self._loss_scaling.dtype == grad.dtype, \
                "The dtype of prev_loss_scaling should be equal to the dtype of x."

    # The gradients, the loss scaling and both step counters appear on the
    # input and the output side of the op.
    inputs = {
        'X': grads,
        'FoundInfinite': found_inf,
        'PrevLossScaling': self._loss_scaling,
        'InGoodSteps': self._num_good_steps,
        'InBadSteps': self._num_bad_steps
    }

    outputs = {
        'Out': grads,
        'LossScaling': self._loss_scaling,
        'OutGoodSteps': self._num_good_steps,
        'OutBadSteps': self._num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
        'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
        'incr_ratio': self.get_attr("incr_ratio"),
        'decr_ratio': self.get_attr("decr_ratio"),
        'stop_update': self.get_attr("stop_update"),
        'op_role': OpRole.Backward
    }

    new_op = main_block.append_op(
        type='update_loss_scaling',
        inputs=inputs,
        outputs=outputs,
        attrs=attrs)

    # Build the distributed attribute of the new op: every gradient keeps
    # the dims mapping of its own tensor dist attr on both sides of the op.
    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = global_process_mesh
    # impl_idx is only set when the global process mesh has several entries.
    if len(global_process_mesh) > 1:
        new_op_dist_attr.impl_idx = 0
    for g in grads:
        g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(g.name,
                                                g_dist_attr.dims_mapping)
        new_op_dist_attr.set_output_dims_mapping(g.name,
                                                 g_dist_attr.dims_mapping)
    self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

    main_block._sync_with_cpp()
# Example 2
0
def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    """Append a ``check_finite_and_unscale`` op over all gradients.

    A boolean ``[1]``-shaped variable is created in the global block of the
    default main program to receive the op's ``FoundInfinite`` output, and
    distributed attributes for that variable and for the new op are
    registered in ``dist_context``.

    Args:
        params_grads: Iterable of ``(param, grad)`` pairs; only the second
            element of each pair is used.
        loss_scaling: Variable wired to the op's ``Scale`` input.
        dist_context: Distributed context used to look up each gradient's
            tensor dist attr and to store the new op's dist attr.

    Returns:
        tuple: ``(grads, found_inf)`` -- the list of gradients (wired as
        both ``X`` and ``Out`` of the op) and the newly created variable.
    """
    block = paddle.static.default_main_program().global_block()
    block._sync_with_cpp()

    grads = [grad for _, grad in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for grad in grads:
        check_variable_and_dtype(grad, "x",
                                 ['float16', 'float32', 'float64'],
                                 'check_finite_and_unscale')

    # Output flag of the op, created as a non-persistable bool tensor.
    flag_name = unique_name.generate_with_ignorable_key(
        ".".join(['find_infinite_scale', 'tmp']))
    found_inf = block.create_var(name=flag_name,
                                 shape=[1],
                                 dtype='bool',
                                 type=core.VarDesc.VarType.LOD_TENSOR,
                                 persistable=False,
                                 stop_gradient=False)
    world_ranks = world_process_group.ranks
    set_var_dist_attr(dist_context, found_inf, [-1], world_ranks)

    unscale_op = block.append_op(
        type='check_finite_and_unscale',
        inputs={'X': grads, 'Scale': loss_scaling},
        outputs={'Out': grads, 'FoundInfinite': found_inf},
        attrs={'op_role': OpRole.Backward})

    # Each gradient keeps the dims mapping of its own tensor dist attr on
    # both the input and the output side of the op.
    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.process_mesh = world_ranks
    op_dist_attr.impl_idx = 0
    # impl_type is only set when the world group has more than one rank.
    if len(world_ranks) > 1:
        op_dist_attr.impl_type = "check_finite_and_unscale"
    for grad in grads:
        grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(grad)
        assert grad_dist_attr is not None
        mapping = grad_dist_attr.dims_mapping
        op_dist_attr.set_input_dims_mapping(grad.name, mapping)
        op_dist_attr.set_output_dims_mapping(grad.name,
                                             grad_dist_attr.dims_mapping)
    dist_context.set_op_dist_attr_for_program(unscale_op, op_dist_attr)
    return grads, found_inf
# Example 3
0
 def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
     """Register a recompute dist attr for ``op`` derived from ``old_dist_attr``.

     A new ``OperatorDistributedAttribute`` is created with
     ``is_recompute`` set to ``True`` and with ``impl_idx`` and
     ``process_mesh`` copied from ``old_dist_attr``. Every input and output
     dist attr of ``old_dist_attr`` is carried over; an entry whose name is
     a key of ``var_name_dict`` is stored under the mapped name, all others
     keep their original name. The result is stored in
     ``self._dist_context`` for ``op``.

     Args:
         op: Operator the new distributed attribute is registered for.
         old_dist_attr: Distributed attribute to copy from.
         var_name_dict: Mapping from original variable names to the names
             to use in the new attribute.
     """
     recompute_attr = OperatorDistributedAttribute()
     recompute_attr.is_recompute = True
     recompute_attr.impl_idx = old_dist_attr.impl_idx
     recompute_attr.process_mesh = old_dist_attr.process_mesh

     # Fall back to the original name when no rename is recorded.
     for name, tensor_attr in old_dist_attr.inputs_dist_attrs.items():
         recompute_attr.set_input_dist_attr(var_name_dict.get(name, name),
                                            tensor_attr)
     for name, tensor_attr in old_dist_attr.outputs_dist_attrs.items():
         recompute_attr.set_output_dist_attr(var_name_dict.get(name, name),
                                             tensor_attr)

     self._dist_context.set_op_dist_attr_for_program(op, recompute_attr)