def _update_backward_cast_ops(params_grads, dist_context):
    """
    Move each param grad's cast op to the end of the backward segment
    in order to enable fp16 allreduce.
    """
    # TODO filter optimize ops in future
    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            if int(op.attr('op_role')) == int(OpRole.Backward) and op.has_attr(
                    'op_role_var'):
                op._remove_attr("op_role_var")

            post_ops = find_true_post_op(main_block.ops, op, g.name)
            if post_ops:
                raise ValueError("The cast op {0}'s output should not be "
                                 "used by a non-optimize op, however, it "
                                 "is used by {1}".format(op, post_ops[0]))

            if op == main_block.ops[-1]:
                continue

            # add the new op in python and cpp at the same time
            new_op_desc = main_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op = paddle.fluid.framework.Operator(block=main_block,
                                                     desc=new_op_desc,
                                                     type=None,
                                                     inputs=None,
                                                     outputs=None,
                                                     attrs=None)
            main_block.ops.append(new_op)

            # dist attr
            param_dist_attr = dist_context.get_tensor_dist_attr_for_program(p)
            output_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                main_block.var(op.output_arg_names[0]))
            assert param_dist_attr is not None
            assert output_dist_attr is not None
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                new_op, param_dist_attr.process_mesh,
                param_dist_attr.dims_mapping, dist_context)

            output_dist_attr.process_mesh = param_dist_attr.process_mesh
            output_dist_attr.dims_mapping = param_dist_attr.dims_mapping

            op_idx = find_op_index(main_block.desc, op.desc)
            if op_idx == -1:
                raise ValueError("The op {0} is not in the program".format(op))
            main_block._remove_op(op_idx, sync=False)

    main_block._sync_with_cpp()
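
# A minimal sketch of the find_true_post_op contract relied on above: return
# every op after `op` that reads `var_name` as an input. This is a hypothetical
# re-implementation over fake ops for illustration only, not Paddle's helper.
from collections import namedtuple

_FakeOp = namedtuple('_FakeOp', ['inputs'])  # stand-in op; only inputs matter


def _find_true_post_op_sketch(ops, op, var_name):
    idx = ops.index(op)
    return [post for post in ops[idx + 1:] if var_name in post.inputs]


_ops = [_FakeOp(['x']), _FakeOp(['y']), _FakeOp(['y', 'z'])]
assert _find_true_post_op_sketch(_ops, _ops[1], 'y') == [_ops[2]]  # reads 'y'
assert _find_true_post_op_sketch(_ops, _ops[2], 'z') == []  # nothing follows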
def cast_backward_program(self, params_grads, dist_context):
    self._block._sync_with_cpp()
    ops = self._block.ops

    loss_op = get_loss_op(self._block)
    loss_op_index = find_op_index(self._block.desc, loss_op.desc)

    idx = loss_op_index + 1
    while idx < len(ops):
        num_cast_ops = 0
        grad_op = ops[idx]
        grad_op_orig_id = grad_op.desc.original_id()
        dist_op_context = dist_context.dist_op_context
        if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
            if self._is_fp16_op(grad_op_orig_id) == False:  # fp32
                num_cast_ops = self._insert_cast_op_backward(
                    grad_op, idx, core.VarDesc.VarType.FP16,
                    core.VarDesc.VarType.FP32, dist_context)
            elif self._is_fp16_op(grad_op_orig_id) == True:  # fp16
                num_cast_ops = self._insert_cast_op_backward(
                    grad_op, idx, core.VarDesc.VarType.FP32,
                    core.VarDesc.VarType.FP16, dist_context)
        elif grad_op.type == "sum":
            in_var_name = grad_op.desc.input_arg_names()[0]
            src_dtype = self._block.var(in_var_name).dtype
            for in_var_name in grad_op.desc.input_arg_names():
                assert src_dtype == self._block.var(in_var_name).dtype
            out_var_name = grad_op.desc.output_arg_names()[0]
            out_var = self._block.var(out_var_name)
            if out_var.dtype != src_dtype:
                out_var.desc.set_dtype(src_dtype)
        elif int(grad_op.attr('op_role')) == 257:
            # 257 == OpRole.Backward | OpRole.Loss: the op that seeds the
            # loss gradient needs no cast.
            pass
        else:
            raise ValueError(
                "'{}' op is not supported in the complete amp pass.".format(
                    grad_op.type))
        idx += num_cast_ops + 1

    self._block._sync_with_cpp()
    _update_backward_cast_ops(params_grads, dist_context)
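
# Why the magic 257 above: op_role is a bitmask, and the fill_constant that
# seeds the loss gradient carries both roles. A quick check with Paddle's
# OpRole values (Backward == 0x0001, Loss == 0x0100):
_BACKWARD, _LOSS = 1, 256
assert _BACKWARD | _LOSS == 257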
def test_find_op_index(self):
    block = fluid.default_main_program().global_block()
    op_desc = core.OpDesc()
    idx = fp16_utils.find_op_index(block.desc, op_desc)
    assert idx == -1
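
# The behaviour the test above relies on: find_op_index scans the block's op
# descs linearly and returns -1 when the desc is absent. A hypothetical
# re-implementation for illustration, assuming BlockDesc's op_size()/op(i):
def _find_op_index_sketch(block_desc, op_desc):
    for i in range(block_desc.op_size()):
        if op_desc == block_desc.op(i):
            return i
    return -1  # a fresh OpDesc never matches, hence the expected -1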
def _scale_loss(self):

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()
    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        # Casting the loss here would change the effective loss tensor of the
        # computation graph and therefore affect all following passes whose
        # logic is based on the loss tensor (Recompute & Gradient Merge),
        # so it is not allowed for now. Fix it in the future.
        raise NotImplementedError(
            "Loss's generator op is not supported in FP16 in Auto Parallel "
            "by now, please put that op into your black-list.")

        # unreachable until the restriction above is lifted
        tmp_name = unique_name.generate(loss.name + ".cast_fp32")
        cast_loss = main_block.create_var(name=tmp_name,
                                          dtype=core.VarDesc.VarType.FP32)
        loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
            loss)
        ref_mesh = loss_op_dist_attr.process_mesh
        self.dist_context.set_tensor_dist_attr_for_program(
            cast_loss, loss_dist_attr)

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
        cast_op = main_block._insert_op(
            loss_op_idx + 1,
            type='cast',
            inputs={'X': [loss]},
            outputs={'Out': [cast_loss]},
            attrs={
                "in_dtype": loss.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
                'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
            })

        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            cast_op, ref_mesh, [-1], self.dist_context)
        loss = loss.astype('float32')

    if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
            "init_loss_scaling") != 1.0:

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

        # forward
        ref_mesh = loss_op_dist_attr.process_mesh
        self._scaled_loss = main_block.create_var(
            name=unique_name.generate("scaled_loss"),
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                          ref_mesh)

        elementwise_mul_op = main_block._insert_op(
            loss_op_idx + 1,
            type='elementwise_mul',
            inputs={
                'X': [loss],
                'Y': [self._loss_scaling]
            },
            outputs={'Out': [self._scaled_loss]},
            attrs={
                'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
            })
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_op, ref_mesh, [-1], self.dist_context)

        # backward
        first_backward_op = main_block.ops[loss_op_idx + 2]
        # 257 == OpRole.Backward | OpRole.Loss
        assert first_backward_op.type == "fill_constant" and int(
            first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
        self._scaled_loss_grad = main_block.create_var(
            name=unique_name.generate("scaled_loss") + "@GRAD",
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                          ref_mesh)
        pre_grad_name = first_backward_op.output_arg_names[0]
        first_backward_op._rename_output(pre_grad_name,
                                         self._scaled_loss_grad.name)
        # FIXME(JZ-LIANG) a trick to insert backward op
        main_block._sync_with_cpp()
        elementwise_mul_grad_op_desc = main_block.desc._insert_op(
            loss_op_idx + 3)
        elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
        elementwise_mul_grad_op_desc.set_input(
            'Out@GRAD', [self._scaled_loss_grad.name])
        elementwise_mul_grad_op_desc.set_input('X', [loss.name])
        elementwise_mul_grad_op_desc.set_input('Y', [self._loss_scaling.name])
        elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
        elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
        elementwise_mul_grad_op_desc._set_attr(
            OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
        elementwise_mul_grad_op_desc._set_attr('axis', -1)
        elementwise_mul_grad_op = paddle.fluid.framework.Operator(
            main_block, elementwise_mul_grad_op_desc)
        main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
        main_block._sync_with_cpp()
        elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
        assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)

    else:
        self._scaled_loss = loss

    main_block._sync_with_cpp()
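
# Sanity check for the hand-built elementwise_mul_grad above: with Y@GRAD
# dropped (the loss scaling factor needs no gradient), it only has to compute
# X@GRAD = Out@GRAD * Y. A scalar sketch of the values involved:
_loss_scaling = 128.0
_out_grad = 1.0  # the fill_constant seeds d(scaled_loss) with ones
_x_grad = _out_grad * _loss_scaling
assert _x_grad == 128.0  # every downstream gradient is scaled by the factor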
def _scale_loss(self):

    main_block = paddle.static.default_main_program().global_block()
    main_block._sync_with_cpp()

    loss = self.get_attr("loss")
    assert loss is not None
    loss_op = loss.op
    loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
        loss_op)

    if loss.dtype != core.VarDesc.VarType.FP32:
        loss = loss.astype('float32')

    if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
            "init_loss_scaling") != 1.0:

        loss_op_idx = find_op_index(main_block.desc, loss_op.desc)

        # forward
        ref_mesh = loss_op_dist_attr.process_mesh
        self._scaled_loss = main_block.create_var(
            name=unique_name.generate("scaled_loss"),
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
                          ref_mesh)

        OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
        elementwise_mul_op = main_block._insert_op(
            loss_op_idx + 1,
            type='elementwise_mul',
            inputs={
                'X': [loss],
                'Y': [self._loss_scaling]
            },
            outputs={'Out': [self._scaled_loss]},
            attrs={
                'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
            })
        loss_op._set_attr(OP_ROLE_KEY,
                          core.op_proto_and_checker_maker.OpRole.Forward)
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_op, ref_mesh, [-1], self.dist_context)

        # backward
        first_backward_op = main_block.ops[loss_op_idx + 2]
        # 257 == OpRole.Backward | OpRole.Loss
        assert first_backward_op.type == "fill_constant" and int(
            first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
        self._scaled_loss_grad = main_block.create_var(
            name=unique_name.generate("scaled_loss") + "@GRAD",
            shape=loss.shape,
            dtype=loss.dtype,
            persistable=loss.persistable)
        set_var_dist_attr(self.dist_context, self._scaled_loss_grad, [-1],
                          ref_mesh)
        pre_grad_name = first_backward_op.output_arg_names[0]
        first_backward_op._rename_output(pre_grad_name,
                                         self._scaled_loss_grad.name)
        # FIXME(JZ-LIANG) a trick to insert backward op
        main_block._sync_with_cpp()
        elementwise_mul_grad_op_desc = main_block.desc._insert_op(
            loss_op_idx + 3)
        elementwise_mul_grad_op_desc.set_type("elementwise_mul_grad")
        elementwise_mul_grad_op_desc.set_input(
            'Out@GRAD', [self._scaled_loss_grad.name])
        elementwise_mul_grad_op_desc.set_input('X', [loss.name])
        elementwise_mul_grad_op_desc.set_input('Y', [self._loss_scaling.name])
        elementwise_mul_grad_op_desc.set_output('X@GRAD', [pre_grad_name])
        elementwise_mul_grad_op_desc.set_output('Y@GRAD', [])
        elementwise_mul_grad_op_desc._set_attr(
            OP_ROLE_KEY, core.op_proto_and_checker_maker.OpRole.Backward)
        elementwise_mul_grad_op_desc._set_attr('axis', -1)
        elementwise_mul_grad_op = paddle.fluid.framework.Operator(
            main_block, elementwise_mul_grad_op_desc)
        main_block.ops.insert(loss_op_idx + 3, elementwise_mul_grad_op)
        main_block._sync_with_cpp()
        elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
        assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context)

    else:
        self._scaled_loss = loss

    main_block._sync_with_cpp()
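
# Consequence of the scaling above (a sketch, not part of this pass): before
# the optimizer consumes the gradients they must be unscaled by the same
# factor, g = g_scaled / loss_scaling, so parameter updates are unchanged.
_loss_scaling = 128.0
_true_grad = 0.5
_scaled_grad = _true_grad * _loss_scaling  # what the scaled backward produces
assert _scaled_grad / _loss_scaling == _true_grad  # round trip restores g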