def _insert_reduce_op(block,
                      insert_idx,
                      reduce_var,
                      ring_id,
                      root_id,
                      dist_context,
                      op_role=OpRole.Backward,
                      use_calc_stream=True):
    """
    Insert an in-place ``c_reduce_sum`` op on ``reduce_var`` at ``insert_idx``.

    The op is added through ``block._insert_op_without_sync`` and uses
    ``reduce_var`` as both its ``X`` input and its ``Out`` output. The dist
    attr of the new op is derived from the tensor dist attr that
    ``dist_context`` holds for ``block.var(reduce_var)``.

    Args:
        block: block that receives the new op.
        insert_idx: position handed to ``_insert_op_without_sync``.
        reduce_var: name of the variable to reduce (looked up via ``block.var``).
        ring_id: value stored in the op's ``ring_id`` attr.
        root_id: value stored in the op's ``root_id`` attr; must be >= 0.
        dist_context: source of the variable's dist attr and target of the
            new op's dist attr.
        op_role: value stored under ``OP_ROLE_KEY``.
        use_calc_stream: value stored in the op's ``use_calc_stream`` attr.
    """
    assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
        root_id)

    reduce_attrs = {
        'ring_id': ring_id,
        'root_id': root_id,
        'use_calc_stream': use_calc_stream,
        OP_ROLE_KEY: op_role
    }
    reduce_op = block._insert_op_without_sync(insert_idx,
                                              type='c_reduce_sum',
                                              inputs={'X': [reduce_var]},
                                              outputs={'Out': [reduce_var]},
                                              attrs=reduce_attrs)

    # The reduce op inherits mesh and dims mapping from the reduced tensor.
    var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
        block.var(reduce_var))
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        reduce_op, var_dist_attr.process_mesh, var_dist_attr.dims_mapping,
        dist_context)
def _insert_optimizer_broadcasts(self, main_block, startup_block):
    """
    Append one ``c_broadcast`` op per sharded parameter to ``main_block``.

    Nothing is done when ``self.stage`` is greater than 2. Otherwise, for
    every parameter of every entry in ``self.sharding_infos`` an in-place
    broadcast (role ``OpRole.Optimize``) is appended, rooted at the rank
    returned by ``sharding_info.get_var_rank`` and running on the sharding
    group's ring. The op's dist attr copies the parameter's process mesh and
    dims mapping. ``main_block`` is synced with its C++ desc at the end.

    Args:
        main_block: block that receives the broadcast ops.
        startup_block: only used to assert that each parameter exists there.
    """
    if self.stage > 2:
        return

    for shard in self.sharding_infos:
        for weight in shard.params:
            # The parameter must be known to both programs.
            assert main_block.has_var(weight.name)
            assert startup_block.has_var(weight.name)

            broadcast_attrs = {
                'ring_id': shard.group.id,
                'root': shard.get_var_rank(weight.name),
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize
            }
            broadcast_op = main_block.append_op(type='c_broadcast',
                                                inputs={'X': weight},
                                                outputs={'Out': weight},
                                                attrs=broadcast_attrs)

            weight_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                weight)
            assert weight_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                broadcast_op, weight_dist_attr.process_mesh,
                weight_dist_attr.dims_mapping, self._dist_context)

    main_block._sync_with_cpp()
def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move param grad cast ops to the end of the backward segment
    in order to enable fp16 allreduce.

    For every ``(param, grad)`` pair whose grad is FP32 and is produced by a
    ``cast`` op, the cast op is re-created as the last op of the default main
    program's global block and the original op is removed. The new op's dist
    attr, and the dist attr of the cast output tensor, are set from the
    parameter's process mesh and dims mapping.

    Args:
        params_grads: iterable of ``(param, grad)`` pairs.
        dist_context: distributed context holding tensor/op dist attrs.

    Raises:
        ValueError: if ``find_true_post_op`` reports a consumer of the cast
            output, or if the cast op cannot be located in the block desc.
    """
    # TODO filter optimize ops in future

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        # Only FP32 grads that come straight out of a cast op are relocated.
        if g.dtype != core.VarDesc.VarType.FP32 or op.type != 'cast':
            continue

        if int(op.attr('op_role')) == int(
                OpRole.Backward) and op.has_attr('op_role_var'):
            op._remove_attr("op_role_var")

        post_ops = find_true_post_op(main_block.ops, op, g.name)
        if post_ops:
            # Adjacent string literals are joined verbatim, so each fragment
            # needs its own trailing space to keep the message readable.
            raise ValueError("The cast op {0}'s output should not be "
                             "used by a non-optimize op, however, it "
                             "is used by {1}".format(op, post_ops[0]))

        # Already the last op of the block: nothing to move.
        if op == main_block.ops[-1]:
            continue

        # add new op in the python and cpp at the same time
        new_op_desc = main_block.desc.append_op()
        new_op_desc.copy_from(op.desc)
        new_op = paddle.fluid.framework.Operator(block=main_block,
                                                 desc=new_op_desc,
                                                 type=None,
                                                 inputs=None,
                                                 outputs=None,
                                                 attrs=None)
        main_block.ops.append(new_op)

        # dist attr
        param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
        output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            main_block.var(op.output_arg_names[0]))
        assert param_dist_attr is not None
        assert output_dist_attr is not None
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_op, param_dist_attr.process_mesh,
            param_dist_attr.dims_mapping, dist_context)

        # The cast output follows the parameter's distribution.
        output_dist_attr.process_mesh = param_dist_attr.process_mesh
        output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

        # Drop the original cast op now that its copy sits at the end.
        op_idx = find_op_index(main_block.desc, op.desc)
        if op_idx == -1:
            raise ValueError("The op {0} is not in program".format(op))
        main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()
def _insert_init_and_broadcast_op(block, insert_idx, varname, local_rank,
                                  root_rank, ring_id, op_role, dist_context):
    """
    Insert a ``c_broadcast`` for ``varname`` and, on non-root ranks, an
    ``empty`` op that initializes the variable.

    Both ops are inserted at ``insert_idx`` via ``_insert_op_without_sync``
    (the ``empty`` op second, so it presumably ends up ahead of the
    broadcast) and both receive the process mesh and dims mapping of the
    broadcast variable as their dist attr.

    Args:
        block: block to modify.
        insert_idx: insertion index used for both ops.
        varname: name of the variable to broadcast in place.
        local_rank: rank of the current process.
        root_rank: rank used as the broadcast ``root``.
        ring_id: value stored in the broadcast op's ``ring_id`` attr.
        op_role: value stored under ``OP_ROLE_KEY`` for both ops.
        dist_context: distributed context holding the dist attrs.
    """
    target_var = block.var(varname)
    target_dist_attr = dist_context.get_tensor_dist_attr_for_program(
        target_var)

    def _copy_var_dist_attr_to(inserted_op):
        # Every inserted op shares the distribution of the broadcast variable.
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            inserted_op, target_dist_attr.process_mesh,
            target_dist_attr.dims_mapping, dist_context)

    broadcast_attrs = {
        'ring_id': ring_id,
        'root': root_rank,
        'use_calc_stream': True,
        OP_ROLE_KEY: op_role
    }
    broadcast_op = block._insert_op_without_sync(insert_idx,
                                                 type='c_broadcast',
                                                 inputs={'X': varname},
                                                 outputs={'Out': varname},
                                                 attrs=broadcast_attrs)
    _copy_var_dist_attr_to(broadcast_op)

    if local_rank == root_rank:
        return

    # Only ranks other than the root get the extra ``empty`` op.
    empty_attrs = {
        "shape": target_var.shape,
        "dtype": target_var.dtype,
        OP_ROLE_KEY: op_role
    }
    empty_op = block._insert_op_without_sync(insert_idx,
                                             type="empty",
                                             outputs={"Out": target_var.name},
                                             attrs=empty_attrs)
    _copy_var_dist_attr_to(empty_op)
    return
def _shard_gradient_clip(self, main_block):
    """
    Restrict gradient-clip ops to locally sharded parameters.

    Does nothing when ``self.stage`` is below 2. Otherwise it works in three
    passes over the ops flagged by ``_is_gradient_clip_op``:

    1. collect the ``elementwise_mul`` / ``squared_l2_norm`` /
       ``clip_by_norm`` ops whose parameter is not in the local shard,
       together with the outputs of the two norm ops;
    2. remove those ops (back to front) and the collected temp vars;
    3. on the first clip ``sum`` op, drop the removed vars from its ``X``
       input and insert one ``c_allreduce_sum`` per sharding info right
       after it.

    The block is synced with its C++ desc at the end.
    """
    if self.stage < 2:
        return

    # TODO (JZ-LIANG) support calculate global norm with tensor parallelism
    prunable_types = ['elementwise_mul', 'squared_l2_norm', 'clip_by_norm']
    stale_op_indices = set()
    stale_tmp_vars = set()

    # Pass 1: find clip ops that belong to parameters owned by other ranks.
    for idx, op in list(enumerate(main_block.ops)):
        if not _is_gradient_clip_op(op):
            continue
        if op.type not in prunable_types:
            continue

        grad_name = op.input("X")[0]
        owner_name = grad_name[:grad_name.find("@GRAD")]
        if self._is_parameter_in_local_shard(owner_name):
            continue

        stale_op_indices.add(idx)
        if op.type in ['squared_l2_norm', 'clip_by_norm']:
            stale_tmp_vars.update(op.output_arg_names)

    # Pass 2: remove from the back so the recorded indices stay valid.
    for idx, op in reversed(list(enumerate(main_block.ops))):
        if not _is_gradient_clip_op(op):
            continue
        if idx in stale_op_indices:
            main_block._remove_op(idx, sync=False)

    for tmp_name in stale_tmp_vars:
        main_block._remove_var(tmp_name, sync=False)

    # Pass 3: patch the first clip `sum` op and allreduce its output.
    for idx, op in list(enumerate(main_block.ops)):
        if not _is_gradient_clip_op(op):
            continue
        if op.type != 'sum':
            continue

        kept_inputs = [
            name for name in op.input_arg_names if name not in stale_tmp_vars
        ]
        op.desc.set_input("X", kept_inputs)

        sum_output = op.desc.output_arg_names()[0]
        for offset, shard in enumerate(self.sharding_infos):
            allreduce_attrs = {
                'ring_id': shard.group.id,
                'op_namescope': "/gradient_clip_model_parallelism",
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize,
            }
            allreduce_op = main_block._insert_op(
                idx + offset + 1,
                type='c_allreduce_sum',
                inputs={'X': [sum_output]},
                outputs={'Out': [sum_output]},
                attrs=allreduce_attrs)

            sum_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                main_block.var(sum_output))
            assert sum_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                allreduce_op, sum_dist_attr.process_mesh,
                sum_dist_attr.dims_mapping, self._dist_context)
        # Only the first clip `sum` op is handled.
        break

    main_block._sync_with_cpp()
def modify_forward_desc_for_recompute(self, dist_context):
    """
    Give every forward 'dropout' op an explicit seed input.

    If the program's forward part has a 'dropout' op without a ``Seed``
    input, a ``seed`` op is inserted right before it so that two dropout ops
    produce the same outputs. The walk over ``self._ops`` stops at the first
    op whose type contains "grad".

    Args:
        dist_context: distributed context; used to read the dropout op's
            dist attr and to register dist attrs for the new seed var/op.
    """
    if all(op.desc.type() != "dropout" for op in self._ops):
        return

    op_idx = 0
    while op_idx < len(self._ops):
        dropout_op = self._ops[op_idx]
        if "grad" in dropout_op.type:
            break

        if dropout_op.type != "dropout":
            op_idx += 1
            continue

        # A dropout op that already has a Seed input is left untouched.
        if dropout_op.input("Seed") is not None and len(
                dropout_op.input("Seed")):
            op_idx += 1
            continue

        dropout_dist_attr = dist_context.get_op_dist_attr_for_program(
            dropout_op)

        # insert seed op to guarantee that two dropout op have the same outputs
        seed_op_name = unique_name.generate("seed")
        seed_var_name = unique_name.generate_with_ignorable_key(
            ".".join([seed_op_name, 'tmp']))
        seed_var = self._block.create_var(
            name=seed_var_name,
            dtype='int32',
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)

        # set new seed_var's dist_attr
        seed_dims_mapping = [-1]
        seed_process_mesh = dropout_dist_attr.process_mesh
        seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var,
                                               seed_dims_mapping,
                                               seed_process_mesh)

        # A literal False `fix_seed` means seed 0; anything else reuses the
        # dropout op's own `seed` attr.
        if dropout_op.attr("fix_seed") is False:
            seed_value = 0
        else:
            seed_value = int(dropout_op.attr("seed"))

        seed_op = self._block._insert_op_without_sync(
            index=dropout_op.idx,
            type="seed",
            inputs={},
            outputs={"Out": seed_var},
            attrs={
                "seed": seed_value,
                "force_cpu": True
            })
        # set new seed op's dist_attr
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            seed_op, seed_process_mesh, seed_dims_mapping, dist_context)

        # modify dropout op's desc
        self._ops.insert(op_idx, seed_op)
        dropout_op.desc.set_input("Seed", [seed_var_name])
        dropout_op.desc.remove_attr("fix_seed")
        dropout_op.desc.remove_attr("seed")
        dropout_dist_attr.set_input_dist_attr(seed_var.name,
                                              seed_var_dist_attr)
        self._block._sync_with_cpp()

        # Skip both the freshly inserted seed op and the dropout op.
        op_idx += 2
def _get_gm_cond_var(main_program, k_steps, dist_context):
    """
    Build the gradient-merge step counter and return its condition variable.

    Creates three persistable int32 global vars (``gradient_merge_k``,
    ``gradient_merge_zero``, ``gradient_merge_step``) and a bool var
    ``gradient_merge_cond``. Under a CPU device guard it then appends ops
    computing ``step = (step + 1) % k`` and ``cond = (step == 0)``.

    Args:
        main_program: program whose global block receives the ops.
        k_steps: initial value of ``gradient_merge_k`` (converted with ``int``).
        dist_context: distributed context that receives the dist attrs.

    Returns:
        The ``gradient_merge_cond`` variable.
    """
    main_block = main_program.global_block()

    def _new_int32_global(var_name, init_value):
        # A fresh `[-1]` list and a fresh `ranks` read are used for every
        # variable, as each was written out separately before.
        new_var = layers.create_global_var(name=var_name,
                                           shape=[1],
                                           value=int(init_value),
                                           dtype='int32',
                                           persistable=True,
                                           force_cpu=True)
        set_var_dist_attr(dist_context, new_var, [-1],
                          world_process_group.ranks)
        return new_var

    # Add const var
    k_step_var = _new_int32_global("gradient_merge_k", k_steps)
    zero_var = _new_int32_global("gradient_merge_zero", 0)

    # Add step var & cond var
    step_var = _new_int32_global("gradient_merge_step", 0)

    cond_var = main_block.create_var(name="gradient_merge_cond",
                                     shape=[1],
                                     dtype='bool')
    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

    with device_guard("cpu"):
        # step_var = (step_var + 1) % k_step
        layers.increment(x=step_var, value=1.0, in_place=True)
        mod_op = main_block.append_op(type='elementwise_mod',
                                      inputs={
                                          'X': step_var,
                                          'Y': k_step_var
                                      },
                                      outputs={'Out': step_var},
                                      attrs={
                                          'axis': -1,
                                          'use_mkldnn': False
                                      })
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            mod_op, world_process_group.ranks, [-1], dist_context)

        # cond_var = (step_var == 0)
        eq_op = main_block.append_op(type='equal',
                                     inputs={
                                         'X': step_var,
                                         'Y': zero_var
                                     },
                                     outputs={'Out': cond_var})
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            eq_op, world_process_group.ranks, [-1], dist_context)

    return cond_var
def _append_gradient_merge_backward_op(
        main_program, startup_program, params_grads: List[Tuple[Any, Any]],
        cond_var_name: str,
        dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
    """
    Create one accumulator per parameter and add each gradient into it.

    For every ``(param, grad)`` pair a persistable
    ``<param>@GRAD@GradientMerge`` variable is created in both programs,
    zero-filled in the startup program, and updated in the main program with
    an ``elementwise_add`` of the gradient. The accumulator and the add op
    take the parameter's process mesh and dims mapping.

    Args:
        main_program: program that receives the accumulation ops.
        startup_program: program that receives the zero-initialization ops.
        params_grads: ``(param, grad)`` pairs; SELECTED_ROWS params are rejected.
        cond_var_name: not used by this function.
        dist_context: distributed context holding the dist attrs.

    Returns:
        A pair ``(new_params_to_grads, param_to_gradient_merge)``: the list of
        ``[param, accumulator]`` entries and a dict from parameter name to
        accumulator variable.
    """
    main_block = main_program.global_block()
    startup_block = startup_program.global_block()

    # step1: remove grad.op's op_role_var
    for param, grad in params_grads:
        assert (
            param.type != core.VarDesc.VarType.SELECTED_ROWS
        ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
        _remove_op_role_var(param, grad)

    accumulator_by_param = {}
    merged_pairs = []

    # step2: create gradient_merge var and init with 0
    for param, grad in params_grads:
        name = param.name
        main_param = main_block.var(name)
        assert (main_param is not None)

        param_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            main_param)
        assert param_dist_attr is not None

        accumulator = main_block.create_var(name=name + "@GRAD@GradientMerge",
                                            shape=main_param.shape,
                                            dtype=main_param.dtype,
                                            persistable=True)
        accumulator_by_param[name] = accumulator

        mesh = param_dist_attr.process_mesh
        mapping = param_dist_attr.dims_mapping
        set_var_dist_attr(dist_context, accumulator, mapping, mesh)

        # Mirror the accumulator in the startup program and zero it there.
        startup_accumulator = startup_block.create_var(
            name=name + "@GRAD@GradientMerge",
            shape=main_param.shape,
            dtype=main_param.dtype,
            persistable=True)
        startup_block.append_op(type="fill_constant",
                                outputs={"Out": startup_accumulator},
                                attrs={
                                    "shape": main_param.shape,
                                    "dtype": main_param.dtype,
                                    "value": float(0),
                                })

        # grad_merge += grad
        accumulate_op = main_block.append_op(type="elementwise_add",
                                             inputs={
                                                 'X': grad,
                                                 'Y': accumulator
                                             },
                                             outputs={'Out': accumulator},
                                             attrs={
                                                 'axis': -1,
                                                 'use_mkldnn': False
                                             })
        merged_pairs.append([param, accumulator])

        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            accumulate_op, mesh, mapping, dist_context)

    return merged_pairs, accumulator_by_param
def _scale_loss(self):
    """
    Insert loss-scaling ops around the loss op of the default main program.

    When ``use_dynamic_loss_scaling`` is set or ``init_loss_scaling`` differs
    from 1.0, an ``elementwise_mul`` producing ``self._scaled_loss`` is
    inserted right after the loss op, the first backward op's output is
    renamed to ``self._scaled_loss_grad``, and a hand-built
    ``elementwise_mul_grad`` op is inserted after it to produce the original
    loss gradient. Otherwise ``self._scaled_loss`` is simply the loss.

    Raises:
        NotImplementedError: if the loss variable is not FP32.
    """
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()
    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        # Casting the loss here would change the effective loss tensor of the
        # computation graph and therefore affect all following passes whose
        # logic is based on the loss tensor (Recompute & Gradient Merge), so
        # it is not allowed for now; to be fixed in the future.
        # NOTE: the cast-insertion code that used to follow this raise was
        # unreachable (and referenced an undefined `dtype`), so it was dropped.
        raise NotImplementedError(
            "Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
        )

    if self.get_attr("use_dynamic_loss_scaling"
                     ) or self.get_attr("init_loss_scaling") != 1.0:

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

        # forward
        ref_mesh = loss_op_dist_attr.process_mesh
        self._scaled_loss = main_block.create_var(
            name=unique_name.generate("scaled_loss"),
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                          ref_mesh)

        # The multiply takes over the loss op's role; the loss op itself is
        # then marked as a plain forward op.
        elementwise_mul_op = main_block._insert_op(
            loss_op_idx + 1,
            type='elementwise_mul',
            inputs={
                'X': [loss],
                'Y': [self._loss_scaling]
            },
            outputs={'Out': [self._scaled_loss]},
            attrs={
                'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
            })
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_op, ref_mesh, [-1], self.dist_context)

        # backward
        first_backward_op = main_block.ops[loss_op_idx + 2]
        # NOTE(review): 257 is presumably OpRole.Backward | OpRole.Loss --
        # confirm against the OpRole enum.
        assert first_backward_op.type == "fill_constant" and int(
            first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
        self._scaled_loss_grad = main_block.create_var(
            name=unique_name.generate("scaled_loss") + "@GRAD",
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                          ref_mesh)
        pre_grad_name = first_backward_op.output_arg_names[0]
        first_backward_op._rename_output(pre_grad_name,
                                         self._scaled_loss_grad.name)

        # FIXME(JZ-LIANG) a trick to insert backward op
        main_block._sync_with_cpp()
        elementwise_mul_grad_op_desc = main_block.desc._insert_op(
            loss_op_idx + 3)
        elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
        elementwise_mul_grad_op_desc.set_input(
            'Out@GRAD', [self._scaled_loss_grad.name])
        elementwise_mul_grad_op_desc.set_input('X', [loss.name])
        elementwise_mul_grad_op_desc.set_input('Y',
                                               [self._loss_scaling.name])
        # The grad op writes the gradient name the first backward op used to
        # produce; no gradient is emitted for the scaling factor.
        elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
        elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
        elementwise_mul_grad_op_desc._set_attr(
            OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
        elementwise_mul_grad_op_desc._set_attr('axis', -1)
        elementwise_mul_grad_op = paddle.fluid.framework.Operator(
            main_block, elementwise_mul_grad_op_desc)
        main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
        main_block._sync_with_cpp()
        elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
        assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)

    else:
        self._scaled_loss = loss

    main_block._sync_with_cpp()
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
                             dist_context):
    """
    Only for backward cast.

    Rewires ``grad_op`` so that it consumes and produces the cast variables
    recorded in ``self._var_name_dict`` for its forward op, and inserts a
    ``cast`` op after ``grad_op`` (at ``idx + 1``) for each renamed output to
    convert it back to the original gradient variable.

    Args:
        grad_op: the backward op to process.
        idx: index of ``grad_op``; cast ops are inserted at ``idx + 1``.
        src_dtype: dtype of the variables that are expected to have been cast.
        dst_dtype: dtype the cast variables have.
        dist_context: distributed context holding op/tensor dist attrs and
            the grad-op-id to forward-op-id mapping.

    Returns:
        Number of cast ops inserted.
    """

    def _keep_fp32_input(op, in_name):
        # For layer_norm_grad every input slot except 'X' and 'Y@GRAD' is
        # reported as "keep"; all other op types report False.
        op_type = op.type
        if op_type in ['layer_norm_grad']:
            return in_name not in {'X', 'Y@GRAD'}
        return False

    def _keep_fp32_output(op, out_name):
        # For layer_norm_grad every output slot except 'X@GRAD' is reported
        # as "keep"; all other op types report False.
        op_type = op.type
        if op_type in ['layer_norm_grad']:
            return out_name != 'X@GRAD'
        return False

    num_cast_ops = 0
    # Look up the forward op this grad op was generated from.
    original_id = grad_op.desc.original_id()
    dist_op_context = dist_context.dist_op_context
    fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

    for in_name in grad_op.input_names:
        # Kept slots are only sanity-checked to be FP32, never renamed.
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                grad_op, in_name):
            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                assert in_var.dtype == core.VarDesc.VarType.FP32
            continue

        for in_var_name in grad_op.input(in_name):
            in_var = self._block._find_var_recursive(in_var_name)
            if in_var.dtype == src_dtype:
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    grad_op)
                if in_var_name in self._var_name_dict[fwd_op_id]:
                    # NOTE: if in_var of consume grad_op has been casted before,
                    # it should be renamed and reset dist_attr.
                    cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                    grad_op.desc._rename_input(in_var_name, cast_name)
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var_name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)
                # NOTE(review): the flattened source does not show which `if`
                # this `else` belongs to; it is paired with the membership
                # test here -- confirm against the upstream file.
                else:
                    assert in_var.dtype == dst_dtype

    for out_name in grad_op.output_names:
        # Kept slots are only sanity-checked to be FP32, never renamed.
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                grad_op, out_name):
            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                assert out_var.dtype == core.VarDesc.VarType.FP32
            continue

        for out_var_name in grad_op.output(out_name):
            out_var = self._block._find_var_recursive(out_var_name)
            # Everything before the first "@" is taken as the forward var name.
            out_var_name_prefix = out_var_name[:out_var_name.find("@")]
            fwd_var = self._block._find_var_recursive(out_var_name_prefix)
            # NOTE: the out_var's dtype of consume grad_op should equal to the fwd_var's dtype
            if out_var.dtype != fwd_var.dtype:
                out_var.desc.set_dtype(fwd_var.dtype)

            if out_var.dtype == src_dtype:
                if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                    # NOTE: if out_var of consume grad_op has been casted before,
                    # it should be renamed and reset dist_attr, then we insert cast op to
                    # convert the cast_var to original dtype
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        grad_op)
                    fwd_cast_name = self._var_name_dict[fwd_op_id][
                        out_var_name_prefix]
                    cast_name = fwd_cast_name + "@GRAD"
                    cast_var = self._block.vars.get(cast_name)
                    # Only create the cast var/op when no usable one exists yet.
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        grad_op.desc._rename_output(
                            out_var_name, cast_name)
                        out_var_dist_attr = consume_op_attr.get_output_dist_attr(
                            out_var_name)
                        ref_mesh = out_var_dist_attr.process_mesh
                        ref_mapping = out_var_dist_attr.dims_mapping
                        consume_op_attr.set_output_dist_attr(
                            cast_name, out_var_dist_attr)
                        assert ref_mapping is not None
                        cast_var = self._block.create_var(
                            name=cast_name,
                            shape=out_var.shape,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=out_var.stop_gradient)
                        set_var_dist_attr(dist_context, cast_var,
                                          ref_mapping, ref_mesh)

                        # Cast the renamed output back into the original
                        # gradient variable right after grad_op.
                        cast_op = self._block._insert_op(
                            idx + 1,
                            type="cast",
                            inputs={"X": cast_var},
                            outputs={"Out": out_var},
                            attrs={
                                "in_dtype": cast_var.dtype,
                                "out_dtype": out_var.dtype,
                                "op_role": OpRole.Backward
                            })
                        cast_op._remove_attr("op_role_var")
                        cast_op._remove_attr("op_namescope")
                        cast_op._remove_attr("with_quant_attr")
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context)
                        num_cast_ops += 1
            else:
                assert out_var.dtype == dst_dtype

    return num_cast_ops
def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype,
                            dist_context):
    """
    only for forward cast
    modified from paddle.fluid.contrib.mixed_precision

    For every input variable of ``op`` whose dtype is ``src_dtype``, make the
    op read a ``<name>.cast_<dst>`` variable instead, inserting a ``cast`` op
    at ``idx`` when that variable does not exist yet (or has another dtype).
    The original-name to cast-name mapping of all input slots is stored in
    ``self._var_name_dict`` under the op's original id. For an FP32 to FP16
    conversion, FP32 outputs of the op are additionally retyped to FP16.

    Args:
        op: forward op whose inputs are to be cast.
        idx: index at which new cast ops are inserted.
        src_dtype: dtype to cast from.
        dst_dtype: dtype to cast to.
        dist_context: distributed context holding op/tensor dist attrs.

    Returns:
        Number of cast ops inserted.
    """
    num_cast_ops = 0
    # One mapping for the whole op. It used to be re-created for every input
    # slot, which dropped the renames of all but the last slot and left the
    # name undefined for ops without input slots.
    var_name_dict = {}

    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name):
            continue
        for in_var_name in op.input(in_name):
            in_var = self._block._find_var_recursive(in_var_name)
            if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                continue
            if in_var.dtype == src_dtype:
                cast_name = in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                out_var = self._block.vars.get(cast_name)
                var_name_dict[in_var.name] = cast_name
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    op)
                assert consume_op_attr is not None
                if out_var is None or out_var.dtype != dst_dtype:
                    # NOTE we make the cast op and var's dist attr as the op that consume the
                    # cast var instead of the op which generates the var
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name)
                    assert in_var_dist_attr is not None
                    ref_mesh = in_var_dist_attr.process_mesh
                    ref_mapping = in_var_dist_attr.dims_mapping
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)

                    out_var = self._block.create_var(
                        name=cast_name,
                        dtype=dst_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient)
                    set_var_dist_attr(dist_context, out_var, ref_mapping,
                                      ref_mesh)

                    cast_op = self._block._insert_op_without_sync(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype,
                        })
                    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                        cast_op, ref_mesh, ref_mapping, dist_context)
                    num_cast_ops += 1
                else:
                    # A usable cast var already exists: only re-key the
                    # consumer's input dist attr to the cast name.
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)
                _rename_arg(op, in_var.name, cast_name)
            else:
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dst_dtype)
    self._var_name_dict[op.desc.original_id()] = var_name_dict

    if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
        for out_name in op.output_names:
            if _keep_fp32_output(op, out_name):
                continue
            for out_var_name in op.output(out_name):
                out_var = self._block.var(out_var_name)
                if out_var.type not in _valid_types:
                    continue
                if out_var.dtype == core.VarDesc.VarType.FP32:
                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
    return num_cast_ops
def _scale_loss(self):
    """
    Insert loss-scaling ops around the loss op of the default main program.

    A non-FP32 loss is first converted with ``astype('float32')``. When
    ``use_dynamic_loss_scaling`` is set or ``init_loss_scaling`` differs from
    1.0, an ``elementwise_mul`` producing ``self._scaled_loss`` is inserted
    right after the loss op, the first backward op's output is renamed to
    ``self._scaled_loss_grad``, and a hand-built ``elementwise_mul_grad`` op
    is inserted after it to produce the original loss gradient. Otherwise
    ``self._scaled_loss`` is simply the loss.
    """
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        loss = loss.astype('float32')

    scaling_enabled = self.get_attr(
        "use_dynamic_loss_scaling") or self.get_attr(
            "init_loss_scaling") != 1.0
    if not scaling_enabled:
        # No scaling requested: the loss is used as is.
        self._scaled_loss = loss
        main_block._sync_with_cpp()
        return

    loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

    # forward
    ref_mesh = loss_op_dist_attr.process_mesh
    self._scaled_loss = main_block.create_var(
        name=unique_name.generate("scaled_loss"),
        shape=loss.shape,
        dtype=loss.dtype,
        persistable=loss.persistable)
    set_var_dist_attr(self.dist_context, self._scaled_loss, [-1], ref_mesh)

    role_key = core.op_proto_and_checker_maker.kOpRoleAttrName()

    # The multiply takes over the loss op's role; the loss op itself is then
    # marked as a plain forward op.
    scale_op = main_block._insert_op(
        loss_op_idx + 1,
        type='elementwise_mul',
        inputs={
            'X': [loss],
            'Y': [self._loss_scaling]
        },
        outputs={'Out': [self._scaled_loss]},
        attrs={
            'op_role': loss_op.all_attrs()[role_key],
        })
    loss_op._set_attr(role_key,
                      core.op_proto_and_checker_maker.OpRole.Forward)
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        scale_op, ref_mesh, [-1], self.dist_context)

    # backward
    first_backward_op = main_block.ops[loss_op_idx + 2]
    assert first_backward_op.type == "fill_constant" and int(
        first_backward_op.all_attrs()[role_key]) == 257
    self._scaled_loss_grad = main_block.create_var(
        name=unique_name.generate("scaled_loss") + "@GRAD",
        shape=loss.shape,
        dtype=loss.dtype,
        persistable=loss.persistable)
    set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                      ref_mesh)
    pre_grad_name = first_backward_op.output_arg_names[0]
    first_backward_op._rename_output(pre_grad_name,
                                     self._scaled_loss_grad.name)

    # FIXME(JZ-LIANG) a trick to insert backward op
    main_block._sync_with_cpp()
    grad_pos = loss_op_idx + 3
    grad_desc = main_block.desc._insert_op(grad_pos)
    grad_desc.set_type("elementwise_mul_grad")
    for slot, names in (
        ('Out@GRAD', [self._scaled_loss_grad.name]),
        ('X', [loss.name]),
        ('Y', [self._loss_scaling.name]),
    ):
        grad_desc.set_input(slot, names)
    # The grad op writes the gradient name the first backward op used to
    # produce; no gradient is emitted for the scaling factor.
    grad_desc.set_output('X@GRAD', [pre_grad_name])
    grad_desc.set_output('Y@GRAD', [])
    grad_desc._set_attr(role_key,
                        core.op_proto_and_checker_maker.OpRole.Backward)
    grad_desc._set_attr('axis', -1)

    grad_op = paddle.fluid.framework.Operator(main_block, grad_desc)
    main_block.ops.insert(grad_pos, grad_op)
    main_block._sync_with_cpp()

    # Re-read the op from the block after the sync before attaching dist attr.
    grad_op = main_block.ops[grad_pos]
    assert grad_op.type == "elementwise_mul_grad"
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
        grad_op, ref_mesh, [-1], self.dist_context)

    main_block._sync_with_cpp()