# NOTE: the imports below are an assumption; these helpers rely on Paddle's
# internal auto-parallel modules, whose exact paths differ between Paddle versions.
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute, TensorDistributedAttribute)
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.distributed.auto_parallel.utils import set_var_dist_attr
from paddle.distributed.fleet.meta_optimizers.common import OpRole

world_process_group = get_world_process_group()


def set_default_dist_attr(program, dist_context, process_mesh):
    ops = program.global_block().ops
    vars = program.global_block().vars
    for op in ops:
        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.process_mesh = process_mesh
        for var_name in op.input_arg_names:
            tensor_dist_attr = TensorDistributedAttribute()
            tensor_dist_attr.process_mesh = process_mesh
            # -1 for every dimension means the tensor is fully replicated.
            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
            dist_context.set_tensor_dist_attr_for_program(vars[var_name],
                                                          tensor_dist_attr)
            op_dist_attr.set_input_dims_mapping(var_name,
                                                tensor_dist_attr.dims_mapping)
        for var_name in op.output_arg_names:
            tensor_dist_attr = TensorDistributedAttribute()
            tensor_dist_attr.process_mesh = process_mesh
            tensor_dist_attr.dims_mapping = [-1 for i in vars[var_name].shape]
            dist_context.set_tensor_dist_attr_for_program(vars[var_name],
                                                          tensor_dist_attr)
            op_dist_attr.set_output_dims_mapping(var_name,
                                                 tensor_dist_attr.dims_mapping)
        dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
    dist_context.add_process_mesh(process_mesh)
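
# Usage sketch for set_default_dist_attr (illustrative only; DistributedContext
# and ProcessMesh live in Paddle's internal auto-parallel modules, and the import
# paths below are an assumption that varies by version):
#
#   from paddle.distributed.auto_parallel.dist_context import DistributedContext
#   from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
#
#   dist_context = DistributedContext()
#   mesh = ProcessMesh([0, 1, 2, 3])
#   set_default_dist_attr(paddle.static.default_main_program(), dist_context, mesh)
#
# After the call, every tensor and op in the program carries a fully replicated
# ([-1, ..., -1]) dims_mapping on the given mesh, a safe default annotation
# before any real sharding decision is made.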
def _update_loss_scaling(self, grads, found_inf):
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    check_variable_and_dtype(self._loss_scaling, "prev_loss_scaling",
                             ['float32', 'float64'], "update_loss_scaling")
    check_type(grads, 'x', (tuple, list), 'update_loss_scaling')
    for e in grads:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'update_loss_scaling')
        assert self._loss_scaling.dtype == e.dtype, \
            "The dtype of prev_loss_scaling should be equal to the dtype of x."

    inputs = {
        'X': grads,
        'FoundInfinite': found_inf,
        'PrevLossScaling': self._loss_scaling,
        'InGoodSteps': self._num_good_steps,
        'InBadSteps': self._num_bad_steps
    }

    outputs = {
        'Out': grads,
        'LossScaling': self._loss_scaling,
        'OutGoodSteps': self._num_good_steps,
        'OutBadSteps': self._num_bad_steps
    }

    attrs = {
        'incr_every_n_steps': self.get_attr("incr_every_n_steps"),
        'decr_every_n_nan_or_inf': self.get_attr("decr_every_n_nan_or_inf"),
        'incr_ratio': self.get_attr("incr_ratio"),
        'decr_ratio': self.get_attr("decr_ratio"),
        'stop_update': self.get_attr("stop_update"),
        'op_role': OpRole.Backward
    }

    new_op = main_block.append_op(type='update_loss_scaling',
                                  inputs=inputs,
                                  outputs=outputs,
                                  attrs=attrs)

    # Annotate the new op on the global mesh; each gradient keeps the
    # dims_mapping it already carries. `global_process_mesh` is assumed to be
    # defined elsewhere in the enclosing module.
    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = global_process_mesh
    if len(global_process_mesh) > 1:
        new_op_dist_attr.impl_idx = 0
    for g in grads:
        g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(g.name,
                                                g_dist_attr.dims_mapping)
        new_op_dist_attr.set_output_dims_mapping(g.name,
                                                 g_dist_attr.dims_mapping)
    self.dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

    main_block._sync_with_cpp()
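
# The update_loss_scaling op applies the usual dynamic loss-scaling rule. The
# pure-Python helper below is a hypothetical reference sketch of that rule (it
# is not part of the pass and omits the finiteness guard on the increased
# scale): the scale grows by incr_ratio after incr_every_n_steps consecutive
# clean steps and shrinks by decr_ratio, clamped to at least 1.0, after
# decr_every_n_nan_or_inf consecutive overflow steps.
def _loss_scaling_reference(found_inf, scale, good_steps, bad_steps,
                            incr_every_n_steps, decr_every_n_nan_or_inf,
                            incr_ratio, decr_ratio):
    if found_inf:
        good_steps, bad_steps = 0, bad_steps + 1
        if bad_steps == decr_every_n_nan_or_inf:
            scale, bad_steps = max(scale * decr_ratio, 1.0), 0
    else:
        good_steps, bad_steps = good_steps + 1, 0
        if good_steps == incr_every_n_steps:
            scale, good_steps = scale * incr_ratio, 0
    return scale, good_steps, bad_steps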
def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    grads = [g for _, g in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in grads:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'check_finite_and_unscale')

    # Boolean flag that reports whether any unscaled gradient overflowed.
    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(".".join(
            ['find_infinite_scale', 'tmp'])),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False)
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    inputs = {'X': grads, 'Scale': loss_scaling}
    outputs = {'Out': grads, 'FoundInfinite': found_inf}
    attrs = {'op_role': OpRole.Backward}
    new_op = main_block.append_op(type='check_finite_and_unscale',
                                  inputs=inputs,
                                  outputs=outputs,
                                  attrs=attrs)

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = world_process_group.ranks
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(g.name,
                                                g_dist_attr.dims_mapping)
        new_op_dist_attr.set_output_dims_mapping(g.name,
                                                 g_dist_attr.dims_mapping)
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)

    return grads, found_inf
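
# How the two AMP helpers are expected to be chained inside the pass
# (illustrative sketch; the pass object and the "use_dynamic_loss_scaling"
# attribute name are assumptions based on the methods above):
#
#   grads, found_inf = _check_and_update_gradient(
#       params_grads, self._loss_scaling, self.dist_context)
#   if self.get_attr("use_dynamic_loss_scaling"):
#       self._update_loss_scaling(grads, found_inf)
#
# check_finite_and_unscale divides every gradient by the current loss scale and
# sets found_inf if any value overflowed; update_loss_scaling then uses that
# flag to adjust the scale for the next iteration.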
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
    new_dist_attr = OperatorDistributedAttribute()
    new_dist_attr.is_recompute = True
    new_dist_attr.impl_idx = old_dist_attr.impl_idx
    new_dist_attr.process_mesh = old_dist_attr.process_mesh
    # Copy the per-tensor dist attrs from the original op, renaming any tensor
    # that appears in var_name_dict (i.e. has a recomputed copy).
    for in_name, in_dist_attr in old_dist_attr.inputs_dist_attrs.items():
        new_dist_attr.set_input_dist_attr(
            var_name_dict.get(in_name, in_name), in_dist_attr)
    for out_name, out_dist_attr in old_dist_attr.outputs_dist_attrs.items():
        new_dist_attr.set_output_dist_attr(
            var_name_dict.get(out_name, out_name), out_dist_attr)
    self._dist_context.set_op_dist_attr_for_program(op, new_dist_attr)
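
# Usage sketch for set_op_dist_attr (illustrative; the variable names and the
# renaming suffix are assumptions): when the recompute pass clones a forward op
# so it can be re-executed in the backward segment, the clone reads renamed
# tensors, and the original op's dist attr is carried over with those names
# remapped:
#
#   old_dist_attr = self._dist_context.get_op_dist_attr_for_program(fwd_op)
#   var_name_dict = {"hidden_0": "hidden_0.subprog_0"}  # original -> recomputed
#   self.set_op_dist_attr(recomputed_op, old_dist_attr, var_name_dict)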