def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    grads = [g for _, g in params_grads]
    check_type(grads, 'x', (tuple, list), 'check_finite_and_unscale')
    for e in grads:
        check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
                                 'check_finite_and_unscale')

    found_inf = main_block.create_var(
        name=unique_name.generate_with_ignorable_key(".".join(
            ['find_infinite_scale', 'tmp'])),
        shape=[1],
        dtype='bool',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False)
    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

    inputs = {'X': grads, 'Scale': loss_scaling}
    outputs = {'Out': grads, 'FoundInfinite': found_inf}
    attrs = {'op_role': OpRole.Backward}
    new_op = main_block.append_op(type='check_finite_and_unscale',
                                  inputs=inputs,
                                  outputs=outputs,
                                  attrs=attrs)

    new_op_dist_attr = OperatorDistributedAttribute()
    new_op_dist_attr.process_mesh = world_process_group.ranks
    new_op_dist_attr.impl_idx = 0
    if len(world_process_group.ranks) > 1:
        new_op_dist_attr.impl_type = "check_finite_and_unscale"
    for g in grads:
        g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
        assert g_dist_attr is not None
        new_op_dist_attr.set_input_dims_mapping(g.name,
                                                g_dist_attr.dims_mapping)
        new_op_dist_attr.set_output_dims_mapping(g.name,
                                                 g_dist_attr.dims_mapping)
    dist_context.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
    return grads, found_inf
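# Illustrative note (not part of the pass): the `check_finite_and_unscale` op
# appended above roughly unscales every gradient by `loss_scaling` and sets
# `found_inf` when any gradient contains Inf/NaN. A rough Python-level sketch
# of the semantics (names are illustrative, not the op's implementation):
#
#   found_inf = any(has_inf_or_nan(g) for g in grads)
#   if not found_inf:
#       for g in grads:
#           g /= loss_scaling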
def _init_amp_var(self):
    self._loss_scaling = paddle.static.create_global_var(
        name=unique_name.generate("loss_scaling"),
        shape=[1],
        value=self.get_attr("init_loss_scaling"),
        dtype='float32',
        persistable=True)
    set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
                      world_process_group.ranks)

    if self.get_attr("use_dynamic_loss_scaling"):
        self._num_good_steps = paddle.static.create_global_var(
            name=unique_name.generate("num_good_steps"),
            shape=[1],
            value=0,
            dtype='int32',
            persistable=True)
        set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
                          world_process_group.ranks)

        self._num_bad_steps = paddle.static.create_global_var(
            name=unique_name.generate("num_bad_steps"),
            shape=[1],
            value=0,
            dtype='int32',
            persistable=True)
        set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
                          world_process_group.ranks)
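# Illustrative note (not part of the pass): `num_good_steps` / `num_bad_steps`
# are the counters that the usual dynamic loss-scaling policy relies on,
# roughly (attribute names below are illustrative):
#
#   if found_inf:
#       num_bad_steps += 1; num_good_steps = 0
#       if num_bad_steps >= decr_every_n_nan_or_inf:
#           loss_scaling *= decr_ratio; num_bad_steps = 0
#   else:
#       num_good_steps += 1; num_bad_steps = 0
#       if num_good_steps >= incr_every_n_steps:
#           loss_scaling *= incr_ratio; num_good_steps = 0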
def init_prog(self):
    # block = self.main_program.global_block()
    self.w = self.layer_help.create_parameter(
        dtype="float", shape=[20], attr=None)
    self.w_grad = paddle.static.data(name='w_grad', shape=[20], dtype='float')
    self.tmp1 = paddle.static.data(name='tmp1', shape=[20], dtype='float')
    self.tmp2 = paddle.static.data(name='tmp2', shape=[20], dtype='float')
    self.batch_reduced = paddle.static.data(
        name='batch_reduced', shape=[1], dtype='float')
    self.attrs = {}

    default_dist_context = get_default_distributed_context()
    _global_process_mesh = auto.ProcessMesh(list(range(nranks)))
    tensor_dist_attr = set_var_dist_attr(default_dist_context,
                                         self.tmp1, [-1],
                                         _global_process_mesh,
                                         mark_annotated=True)
    tensor_dist_attr = set_var_dist_attr(default_dist_context,
                                         self.tmp2, [-1],
                                         _global_process_mesh,
                                         mark_annotated=True)

    op = self.layer_help.append_op(type="add_p",
                                   inputs={'X': self.tmp1, 'Y': self.w},
                                   outputs={'Z': self.w_grad},
                                   attrs=self.attrs)
    op = self.layer_help.append_op(type="reduce_p",
                                   inputs={'X': self.tmp2},
                                   outputs={'Y': self.batch_reduced},
                                   attrs={"axis": [0]})
def _shard_parameter(self, main_block, startup_block):
    if self.stage < 3:
        return

    dp_ring_ids = [group.id for group in self.dp_groups]
    for sharding_info in self.sharding_infos:
        need_broadcast_vars, param_usage = sharding_info.get_broadcast_vars_and_param_usage(
            main_block)
        not_used_param_names = []
        for param_name in param_usage:
            if param_usage[param_name] == 0 and sharding_info.get_var_rank(
                    param_name) != sharding_info.local_rank:
                not_used_param_names.append(param_name)

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if op.type == "cast":
                    continue
                if input_name not in need_broadcast_vars:
                    continue
                root_rank = sharding_info.get_var_rank(input_name)
                if root_rank == sharding_info.local_rank:
                    broadcast_varname = input_name
                else:
                    broadcast_varname = unique_name.generate(input_name +
                                                             "@BroadCast")
                    input_var = main_block.var(input_name)
                    new_var = main_block.create_var(name=broadcast_varname,
                                                    shape=input_var.shape,
                                                    dtype=input_var.dtype,
                                                    persistable=False)
                    ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                        input_var)
                    out_var_dist_attr = set_var_dist_attr(
                        self._dist_context, new_var,
                        ref_dist_attr.dims_mapping,
                        ref_dist_attr.process_mesh)
                    op._rename_input(input_name, broadcast_varname)

                _insert_init_and_broadcast_op(main_block, idx,
                                              broadcast_varname,
                                              sharding_info.local_rank,
                                              root_rank,
                                              sharding_info.group.id,
                                              op.attr('op_role'),
                                              self._dist_context)

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if op.type != "cast":
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            if input_name in not_used_param_names:
                main_block._remove_op(idx, sync=False)
                main_block._remove_var(output_name, sync=False)

        for idx, op in reversed(list(enumerate(startup_block.ops))):
            assert len(op.output_arg_names) == 1
            output_name = op.output_arg_names[0]

            if op.type == "c_broadcast" and op.attr("ring_id") in dp_ring_ids:
                if self.outer_dp_group and sharding_info.get_var_rank(
                        output_name) == sharding_info.local_rank:
                    op._set_attr("ring_id", self.outer_dp_group.id)
                else:
                    startup_block._remove_op(idx, sync=False)
                continue

            if op.type != "c_broadcast" and output_name in param_usage and sharding_info.get_var_rank(
                    output_name) != sharding_info.local_rank:
                startup_block._remove_op(idx, sync=False)

        for param_name in param_usage:
            if sharding_info.get_var_rank(
                    param_name) != sharding_info.local_rank:
                main_block._remove_var(param_name, sync=False)
                startup_block._remove_var(param_name, sync=False)

    main_block._sync_with_cpp()
    startup_block._sync_with_cpp()
def modify_forward_desc_for_recompute(self, dist_context):
    """
    If the program's forward part has a 'dropout' op, this function will insert
    a seed op before it to guarantee that the two dropout ops (original and
    recomputed) have the same outputs.
    """
    op_types = [op.desc.type() for op in self._ops]
    if "dropout" not in op_types:
        return

    op_idx = 0
    while op_idx < len(self._ops):
        cur_op = self._ops[op_idx]
        if "grad" in cur_op.type:
            break
        if cur_op.type != "dropout":
            op_idx += 1
            continue
        if cur_op.input("Seed") is not None and len(cur_op.input("Seed")):
            op_idx += 1
            continue

        cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
        # insert a seed op to guarantee that the two dropout ops have the same outputs
        op_unique_name = unique_name.generate("seed")
        var_unique_name = unique_name.generate_with_ignorable_key(".".join(
            [op_unique_name, 'tmp']))
        seed_var = self._block.create_var(
            name=var_unique_name,
            dtype='int32',
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)

        # set new seed_var's dist_attr
        ref_dims_mapping = [-1]
        ref_process_mesh = cur_op_dist_attr.process_mesh
        seed_var_dist_attr = set_var_dist_attr(dist_context, seed_var,
                                               ref_dims_mapping,
                                               ref_process_mesh)

        seed = 0 if cur_op.attr("fix_seed") is False else int(
            cur_op.attr("seed"))
        seed_op = self._block._insert_op_without_sync(
            index=cur_op.idx,
            type="seed",
            inputs={},
            outputs={"Out": seed_var},
            attrs={"seed": seed, "force_cpu": True})
        # set new seed op's dist_attr
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            seed_op, ref_process_mesh, ref_dims_mapping, dist_context)

        # modify dropout op's desc
        self._ops.insert(op_idx, seed_op)
        cur_op.desc.set_input("Seed", [var_unique_name])
        cur_op.desc.remove_attr("fix_seed")
        cur_op.desc.remove_attr("seed")
        cur_op_dist_attr.set_input_dist_attr(seed_var.name,
                                             seed_var_dist_attr)
        self._block._sync_with_cpp()
        # skip the inserted seed op and the dropout op itself
        op_idx += 2
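# Illustrative note (not part of the pass): feeding dropout its seed through an
# explicit `seed` op makes the dropout mask deterministic, so when the forward
# segment is replayed during recompute the regenerated activations match the
# originals. Conceptually:
#
#   seed = seed_op()            # fixed value shared by both executions
#   y  = dropout(x,  seed)      # original forward
#   y2 = dropout(x2, seed)      # recomputed forward -> same mask as y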
def _apply_single_impl(self, main_programs, startup_programs, context):
    checkpoints = self.get_attr("checkpoints")
    loss = self.get_attr("loss")
    no_grad_set = self.get_attr("no_grad_set")
    self._dist_context = self.get_attr("dist_context")
    main_block = main_programs.global_block()
    no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
    # get op_path which is related to loss
    op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)

    # step 1: build recompute state
    rc_state = RecomputeState(main_block, op_path)
    rc_state.modify_forward_desc_for_recompute(self._dist_context)
    rc_state.build_stats()
    checkpoints = rc_state.sort_checkpoints(checkpoints)
    segments = rc_state.get_recompute_segments(checkpoints)
    if segments == []:
        return

    # step 2: get vars_should_be_hold
    vars_should_be_hold = []
    for segment in segments:
        vars_should_be_hold.extend(
            rc_state.get_out_of_subgraph_vars(segment[0], segment[1]))
    cross_vars = set(vars_should_be_hold) - set(checkpoints)
    logging.info(
        "found [{}] vars which cross recompute segment: [{}], "
        "better checkpoints might be set to reduce those vars".format(
            len(cross_vars), cross_vars))
    vars_should_be_hold.extend(rc_state.get_reserved_vars())
    vars_should_be_hold.extend(rc_state.get_input_nodes())
    vars_should_be_hold = list(set(vars_should_be_hold))
    vars_in_memory = vars_should_be_hold + checkpoints

    # step 3: get recomputed fwd ops desc
    var_name_dict = {}
    ckpt_ops_dict = {}
    buffer_block = main_block.program._create_block()
    for i, segment in enumerate(segments[::-1]):
        fwd_ops = op_path[segment[0]:segment[1]]
        var_suffix = ".subprog_%d" % i
        for op in fwd_ops:
            input_and_output_names = []
            input_and_output_names.extend(op.desc.input_arg_names())
            input_and_output_names.extend(op.desc.output_arg_names())
            cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                op)
            assert cur_op_dist_attr is not None
            for name in input_and_output_names:
                if main_block.var(name).persistable or name in checkpoints:
                    continue
                if name in vars_should_be_hold:
                    continue
                if name not in var_name_dict:
                    ref_process_mesh = cur_op_dist_attr.process_mesh
                    if name in op.desc.input_arg_names():
                        ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping(
                            name)
                    else:
                        ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping(
                            name)
                    # record recomputed var's old_name and new_name (old_name.subprog_XXX)
                    # create new var with new name
                    var_name_dict[name] = name + var_suffix
                    ref_var = main_block.var(name)
                    rc_var = main_block.create_var(
                        name=var_name_dict[name],
                        shape=ref_var.shape,
                        dtype=ref_var.dtype,
                        type=ref_var.type,
                        persistable=ref_var.persistable,
                        stop_gradient=ref_var.stop_gradient)
                    # set new recomputed var's dist attr
                    set_var_dist_attr(self._dist_context, rc_var,
                                      ref_dims_mapping, ref_process_mesh)
        # get recomputed segment's descs
        segment_descs = _add_needed_descs_to_block(fwd_ops, buffer_block,
                                                   main_block, vars_in_memory,
                                                   self._dist_context)
        # rename recomputed ops' input and output var name
        for key in var_name_dict:
            _rename_arg_(segment_descs, key, var_name_dict[key])

        # NOTE: one forward op could correspond to multiple xxx_grad ops.
        # When traversing all grad_ops in reverse, we need a flag to indicate
        # whether the ckpt and its segment_descs can still be used.
        ckpt_op = op_path[segment[1] - 1]
        ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]

    # step 4: insert recomputed fwd ops
    ops = main_block.ops
    loss_op = get_loss_op(main_block)
    loss_op_idx = _find_op_index(main_block, loss_op)
    dist_op_context = self._dist_context.dist_op_context
    assert loss_op_idx != -1

    # Traverse all grad_ops in reverse; if the fwd op corresponding to a grad
    # op is a checkpoint, the segment's ops should be inserted.
    for i in range(len(ops) - 1, loss_op_idx, -1):
        grad_op = ops[i]
        # remove some attrs of dropout_grad op's desc
        if grad_op.type == "dropout_grad":
            grad_op.desc.remove_attr("fix_seed")
            grad_op.desc.remove_attr("seed")
            main_block._sync_with_cpp()

        # rename grad op's var_name which is not in 'vars_in_memory'
        for key in var_name_dict:
            self.reset_op_dist_attr(grad_op, var_name_dict)
            _rename_arg_([grad_op.desc], key, var_name_dict[key])

        # insert recomputed ops
        if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
            fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()]
            if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                idx = grad_op.idx
                while idx - 1 >= 0 and ops[idx - 1].type == "sum":
                    idx -= 1

                segment_descs = ckpt_ops_dict[fwd_op_id][1]
                for _, op_desc in reversed(list(enumerate(segment_descs))):
                    rc_desc = main_block.desc._insert_op(idx)
                    rc_desc.copy_from(op_desc)
                    rc_op = Operator(main_block, rc_desc)
                    main_block.ops.insert(idx, rc_op)
                    # set recomputed ops' dist attr
                    fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
                        rc_desc.original_id())
                    assert fwd_op_dist_attr is not None
                    self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
                                          var_name_dict)

                ckpt_ops_dict[fwd_op_id][0] = False
                main_block._sync_with_cpp()

    main_programs._sync_with_cpp()
def _get_gm_cond_var(main_program, k_steps, dist_context):
    main_block = main_program.global_block()
    # Add const var
    k_step_var = layers.create_global_var(name="gradient_merge_k",
                                          shape=[1],
                                          value=int(k_steps),
                                          dtype='int32',
                                          persistable=True,
                                          force_cpu=True)
    set_var_dist_attr(dist_context, k_step_var, [-1],
                      world_process_group.ranks)

    zero_var = layers.create_global_var(name="gradient_merge_zero",
                                        shape=[1],
                                        value=int(0),
                                        dtype='int32',
                                        persistable=True,
                                        force_cpu=True)
    set_var_dist_attr(dist_context, zero_var, [-1],
                      world_process_group.ranks)

    # Add step var & cond var
    step_var = layers.create_global_var(name="gradient_merge_step",
                                        shape=[1],
                                        value=int(0),
                                        dtype='int32',
                                        persistable=True,
                                        force_cpu=True)
    set_var_dist_attr(dist_context, step_var, [-1],
                      world_process_group.ranks)

    cond_var = main_block.create_var(name="gradient_merge_cond",
                                     shape=[1],
                                     dtype='bool')
    set_var_dist_attr(dist_context, cond_var, [-1],
                      world_process_group.ranks)

    with device_guard("cpu"):
        # step_var = (step_var + 1) % k_step
        layers.increment(x=step_var, value=1.0, in_place=True)
        elementwise_mod_op = main_block.append_op(type='elementwise_mod',
                                                  inputs={
                                                      'X': step_var,
                                                      'Y': k_step_var
                                                  },
                                                  outputs={'Out': step_var},
                                                  attrs={
                                                      'axis': -1,
                                                      'use_mkldnn': False
                                                  })
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mod_op, world_process_group.ranks, [-1], dist_context)

        # cond_var = (step_var == 0)
        equal_op = main_block.append_op(type='equal',
                                        inputs={
                                            'X': step_var,
                                            'Y': zero_var
                                        },
                                        outputs={'Out': cond_var})
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            equal_op, world_process_group.ranks, [-1], dist_context)

    return cond_var
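# Illustrative note (not part of the pass): the variables built above encode the
# gradient-merge schedule `step = (step + 1) % k_steps; cond = (step == 0)`,
# i.e. the merged gradient is applied once every k_steps iterations. A
# plain-Python sketch for k_steps = 4:
#
#   step = 0
#   for it in range(8):
#       step = (step + 1) % 4
#       cond = (step == 0)      # True at it = 3 and it = 7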
def _append_gradient_merge_backward_op(
        main_program, startup_program, params_grads: List[Tuple[Any, Any]],
        cond_var_name: str,
        dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
    main_block = main_program.global_block()
    startup_block = startup_program.global_block()

    # step1: remove grad.op's op_role_var
    for param, grad in params_grads:
        assert (
            param.type != core.VarDesc.VarType.SELECTED_ROWS
        ), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
        _remove_op_role_var(param, grad)

    param_to_gradient_merge = {}
    new_params_to_grads = []
    # step2: create gradient_merge var and init with 0
    for param, grad in params_grads:
        param_name = param.name
        param_var = main_block.var(param_name)
        assert (param_var is not None)
        ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(
            param_var)
        assert ref_dist_attr is not None
        gradient_merge_var = main_block.create_var(name=param_name +
                                                   "@GRAD@GradientMerge",
                                                   shape=param_var.shape,
                                                   dtype=param_var.dtype,
                                                   persistable=True)
        param_to_gradient_merge[param_name] = gradient_merge_var

        ref_process_mesh = ref_dist_attr.process_mesh
        ref_dims_mapping = ref_dist_attr.dims_mapping
        set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping,
                          ref_process_mesh)

        startup_gradient_merge_var = startup_block.create_var(
            name=param_name + "@GRAD@GradientMerge",
            shape=param_var.shape,
            dtype=param_var.dtype,
            persistable=True)
        startup_block.append_op(type="fill_constant",
                                outputs={"Out": startup_gradient_merge_var},
                                attrs={
                                    "shape": param_var.shape,
                                    "dtype": param_var.dtype,
                                    "value": float(0),
                                })

        # grad_merge += grad
        new_grad_op = main_block.append_op(type="elementwise_add",
                                           inputs={
                                               'X': grad,
                                               'Y': gradient_merge_var
                                           },
                                           outputs={'Out': gradient_merge_var},
                                           attrs={
                                               'axis': -1,
                                               'use_mkldnn': False
                                           })
        new_params_to_grads.append([param, gradient_merge_var])
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context)

    return new_params_to_grads, param_to_gradient_merge
def _scale_loss(self):
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()
    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        # Casting the loss here would change the effective loss tensor of the
        # computation graph and therefore affect all following passes whose
        # logic is based on the loss tensor (Recompute & Gradient Merge),
        # so it is not allowed for now. To be fixed in the future.
        raise NotImplementedError(
            "Loss's generator op is not supported in FP16 in Auto Parallel by now, "
            "please put that op into your black-list.")

        # NOTE: the code below is currently unreachable; it is kept for when an
        # FP16 loss is supported and the loss needs to be cast back to FP32.
        tmp_name = unique_name.generate(loss.name + ".cast_fp32")
        cast_loss = main_block.create_var(name=tmp_name,
                                          dtype=core.VarDesc.VarType.FP32)
        loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
            loss)
        ref_mesh = loss_op_dist_attr.process_mesh
        self.dist_context.set_tensor_dist_attr_for_program(
            cast_loss, loss_dist_attr)

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
        cast_op = main_block._insert_op(
            loss_op_idx + 1,
            type='cast',
            inputs={'X': [loss]},
            outputs={'Out': [cast_loss]},
            attrs={
                "in_dtype": loss.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
                'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
            })
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            cast_op, ref_mesh, [-1], self.dist_context)
        loss = loss.astype('float32')

    if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
            "init_loss_scaling") != 1.0:

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

        # forward
        ref_mesh = loss_op_dist_attr.process_mesh
        self._scaled_loss = main_block.create_var(
            name=unique_name.generate("scaled_loss"),
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                          ref_mesh)

        elementwise_mul_op = main_block._insert_op(
            loss_op_idx + 1,
            type='elementwise_mul',
            inputs={'X': [loss], 'Y': [self._loss_scaling]},
            outputs={'Out': [self._scaled_loss]},
            attrs={'op_role': loss_op.all_attrs()[OP_ROLE_KEY]})
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_op, ref_mesh, [-1], self.dist_context)

        # backward
        first_backward_op = main_block.ops[loss_op_idx + 2]
        # 257 == OpRole.Backward | OpRole.Loss
        assert first_backward_op.type == "fill_constant" and int(
            first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
        self._scaled_loss_grad = main_block.create_var(
            name=unique_name.generate("scaled_loss") + "@GRAD",
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                          ref_mesh)
        pre_grad_name = first_backward_op.output_arg_names[0]
        first_backward_op._rename_output(pre_grad_name,
                                         self._scaled_loss_grad.name)
        # FIXME(JZ-LIANG) a trick to insert backward op
        main_block._sync_with_cpp()
        elementwise_mul_grad_op_desc = main_block.desc._insert_op(
            loss_op_idx + 3)
        elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
        elementwise_mul_grad_op_desc.set_input('Out@GRAD',
                                               [self._scaled_loss_grad.name])
        elementwise_mul_grad_op_desc.set_input('X', [loss.name])
        elementwise_mul_grad_op_desc.set_input('Y', [self._loss_scaling.name])
        elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
        elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
        elementwise_mul_grad_op_desc._set_attr(
            OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
        elementwise_mul_grad_op_desc._set_attr('axis', -1)
        elementwise_mul_grad_op = paddle.fluid.framework.Operator(
            main_block, elementwise_mul_grad_op_desc)
        main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
        main_block._sync_with_cpp()
        elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
        assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)
    else:
        self._scaled_loss = loss

    main_block._sync_with_cpp()
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
                             dist_context):
    """ only for backward cast """

    def _keep_fp32_input(op, in_name):
        op_type = op.type
        if op_type in ['layer_norm_grad']:
            return in_name not in {'X', 'Y@GRAD'}
        return False

    def _keep_fp32_output(op, out_name):
        op_type = op.type
        if op_type in ['layer_norm_grad']:
            return out_name != 'X@GRAD'
        return False

    num_cast_ops = 0
    original_id = grad_op.desc.original_id()
    dist_op_context = dist_context.dist_op_context
    fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]

    for in_name in grad_op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                grad_op, in_name):
            for in_var_name in grad_op.input(in_name):
                in_var = self._block._find_var_recursive(in_var_name)
                assert in_var.dtype == core.VarDesc.VarType.FP32
            continue

        for in_var_name in grad_op.input(in_name):
            in_var = self._block._find_var_recursive(in_var_name)
            if in_var.dtype == src_dtype:
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    grad_op)
                if in_var_name in self._var_name_dict[fwd_op_id]:
                    # NOTE: if in_var of the consuming grad_op has been casted before,
                    # it should be renamed and its dist_attr reset.
                    cast_name = self._var_name_dict[fwd_op_id][in_var_name]
                    grad_op.desc._rename_input(in_var_name, cast_name)
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var_name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)
                else:
                    assert in_var.dtype == dst_dtype

    for out_name in grad_op.output_names:
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
                grad_op, out_name):
            for out_var_name in grad_op.output(out_name):
                out_var = self._block._find_var_recursive(out_var_name)
                assert out_var.dtype == core.VarDesc.VarType.FP32
            continue

        for out_var_name in grad_op.output(out_name):
            out_var = self._block._find_var_recursive(out_var_name)
            out_var_name_prefix = out_var_name[:out_var_name.find("@")]
            fwd_var = self._block._find_var_recursive(out_var_name_prefix)
            # NOTE: the out_var's dtype of the consuming grad_op should equal the fwd_var's dtype
            if out_var.dtype != fwd_var.dtype:
                out_var.desc.set_dtype(fwd_var.dtype)

            if out_var.dtype == src_dtype:
                if out_var_name_prefix in self._var_name_dict[fwd_op_id]:
                    # NOTE: if out_var of the consuming grad_op has been casted before,
                    # it should be renamed and its dist_attr reset, then we insert a cast op
                    # to convert the cast_var back to the original dtype
                    consume_op_attr = dist_context.get_op_dist_attr_for_program(
                        grad_op)
                    fwd_cast_name = self._var_name_dict[fwd_op_id][
                        out_var_name_prefix]
                    cast_name = fwd_cast_name + "@GRAD"
                    cast_var = self._block.vars.get(cast_name)
                    if cast_var is None or cast_var.dtype != dst_dtype:
                        grad_op.desc._rename_output(out_var_name, cast_name)
                        out_var_dist_attr = consume_op_attr.get_output_dist_attr(
                            out_var_name)
                        ref_mesh = out_var_dist_attr.process_mesh
                        ref_mapping = out_var_dist_attr.dims_mapping
                        consume_op_attr.set_output_dist_attr(
                            cast_name, out_var_dist_attr)
                        assert ref_mapping is not None
                        cast_var = self._block.create_var(
                            name=cast_name,
                            shape=out_var.shape,
                            dtype=dst_dtype,
                            persistable=False,
                            stop_gradient=out_var.stop_gradient)
                        set_var_dist_attr(dist_context, cast_var, ref_mapping,
                                          ref_mesh)

                        cast_op = self._block._insert_op(
                            idx + 1,
                            type="cast",
                            inputs={"X": cast_var},
                            outputs={"Out": out_var},
                            attrs={
                                "in_dtype": cast_var.dtype,
                                "out_dtype": out_var.dtype,
                                "op_role": OpRole.Backward
                            })
                        cast_op._remove_attr("op_role_var")
                        cast_op._remove_attr("op_namescope")
                        cast_op._remove_attr("with_quant_attr")
                        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                            cast_op, ref_mesh, ref_mapping, dist_context)
                        num_cast_ops += 1
                else:
                    assert out_var.dtype == dst_dtype

    return num_cast_ops
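# Illustrative note (not part of the pass): for a forward var that was casted to
# the low-precision dtype (recorded in self._var_name_dict), the grad op is made
# to produce `<fwd_cast_name>@GRAD` in that dtype and a `cast` op is then
# inserted to convert it back to the original-dtype gradient, conceptually:
#
#   x.cast_fp16@GRAD (float16) --cast--> x@GRAD (float32)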
def _insert_cast_op_forward(self, op, idx, src_dtype, dst_dtype,
                            dist_context):
    """
    only for forward cast
    modified from paddle.fluid.contrib.mixed_precision
    """
    num_cast_ops = 0
    var_name_dict = {}
    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                op, in_name):
            continue
        for in_var_name in op.input(in_name):
            in_var = self._block._find_var_recursive(in_var_name)
            if in_var.type not in _valid_types or in_var.dtype == dst_dtype:
                continue
            if in_var.dtype == src_dtype:
                cast_name = in_var.name + '.cast_' + _dtype_to_str(dst_dtype)
                out_var = self._block.vars.get(cast_name)
                var_name_dict[in_var.name] = cast_name
                consume_op_attr = dist_context.get_op_dist_attr_for_program(
                    op)
                assert consume_op_attr is not None
                if out_var is None or out_var.dtype != dst_dtype:
                    # NOTE we attach the cast op and var's dist attr to the op that consumes
                    # the cast var instead of the op which generates the var
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name)
                    assert in_var_dist_attr is not None
                    ref_mesh = in_var_dist_attr.process_mesh
                    ref_mapping = in_var_dist_attr.dims_mapping
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)

                    out_var = self._block.create_var(
                        name=cast_name,
                        dtype=dst_dtype,
                        persistable=False,
                        stop_gradient=in_var.stop_gradient)
                    set_var_dist_attr(dist_context, out_var, ref_mapping,
                                      ref_mesh)

                    cast_op = self._block._insert_op_without_sync(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype,
                        })
                    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                        cast_op, ref_mesh, ref_mapping, dist_context)
                    num_cast_ops += 1
                else:
                    in_var_dist_attr = consume_op_attr.get_input_dist_attr(
                        in_var.name)
                    consume_op_attr.set_input_dist_attr(
                        cast_name, in_var_dist_attr)
                _rename_arg(op, in_var.name, cast_name)
            else:
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dst_dtype)
    self._var_name_dict[op.desc.original_id()] = var_name_dict

    if src_dtype == core.VarDesc.VarType.FP32 and dst_dtype == core.VarDesc.VarType.FP16:
        for out_name in op.output_names:
            if _keep_fp32_output(op, out_name):
                continue
            for out_var_name in op.output(out_name):
                out_var = self._block.var(out_var_name)
                if out_var.type not in _valid_types:
                    continue
                if out_var.dtype == core.VarDesc.VarType.FP32:
                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
    return num_cast_ops
def _scale_loss(self):
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        loss = loss.astype('float32')

    if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
            "init_loss_scaling") != 1.0:

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

        # forward
        ref_mesh = loss_op_dist_attr.process_mesh
        self._scaled_loss = main_block.create_var(
            name=unique_name.generate("scaled_loss"),
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                          ref_mesh)

        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
        elementwise_mul_op = main_block._insert_op(
            loss_op_idx + 1,
            type='elementwise_mul',
            inputs={'X': [loss], 'Y': [self._loss_scaling]},
            outputs={'Out': [self._scaled_loss]},
            attrs={'op_role': loss_op.all_attrs()[OP_ROLE_KEY]})
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_op, ref_mesh, [-1], self.dist_context)

        # backward
        first_backward_op = main_block.ops[loss_op_idx + 2]
        # 257 == OpRole.Backward | OpRole.Loss
        assert first_backward_op.type == "fill_constant" and int(
            first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
        self._scaled_loss_grad = main_block.create_var(
            name=unique_name.generate("scaled_loss") + "@GRAD",
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                          ref_mesh)
        pre_grad_name = first_backward_op.output_arg_names[0]
        first_backward_op._rename_output(pre_grad_name,
                                         self._scaled_loss_grad.name)
        # FIXME(JZ-LIANG) a trick to insert backward op
        main_block._sync_with_cpp()
        elementwise_mul_grad_op_desc = main_block.desc._insert_op(
            loss_op_idx + 3)
        elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
        elementwise_mul_grad_op_desc.set_input('Out@GRAD',
                                               [self._scaled_loss_grad.name])
        elementwise_mul_grad_op_desc.set_input('X', [loss.name])
        elementwise_mul_grad_op_desc.set_input('Y', [self._loss_scaling.name])
        elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
        elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
        elementwise_mul_grad_op_desc._set_attr(
            OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
        elementwise_mul_grad_op_desc._set_attr('axis', -1)
        elementwise_mul_grad_op = paddle.fluid.framework.Operator(
            main_block, elementwise_mul_grad_op_desc)
        main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
        main_block._sync_with_cpp()
        elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
        assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)
    else:
        self._scaled_loss = loss

    main_block._sync_with_cpp()