def update_loss_scale(grads):
    state = mixed_precision_global_state()
    if state is None or not state.dynamic_scaling:
        return
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
    grad_valid = layers.isfinite(per_grad_check)

    layers.cond(grad_valid,
                lambda: state.increment(),
                lambda: state.decrement())
    return grad_valid
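A minimal sketch of how the returned grad_valid flag could be consumed, assuming dynamic loss scaling is enabled and that a list of (param, grad) pairs named scaled_params_grads exists; this is an illustration, not code from the original:

# Hypothetical caller-side use of grad_valid (names are assumptions):
grad_valid = update_loss_scale([g for _, g in scaled_params_grads])

def keep_grads():
    # all gradients are finite: leave them untouched
    pass

def zero_grads():
    # at least one gradient overflowed: zero everything so the
    # following apply_gradients is effectively a skipped step
    for _, g in scaled_params_grads:
        layers.assign(layers.zeros_like(g), g)

layers.cond(grad_valid, keep_grads, zero_grads)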
def maybe_update():
    new_scale = self.scale * self.factor
    scale_valid = layers.isfinite(new_scale)

    def update_scale_and_step():
        layers.assign(new_scale, self.scale)
        layers.assign(
            layers.zeros_like(self.good_steps), self.good_steps)

    layers.cond(scale_valid, update_scale_and_step)
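A sketch of how maybe_update could be gated on having accumulated enough consecutive good steps, mirroring the Switch-based increment() shown below; the helper name count_step and the cond-based gating are assumptions, not code from the original:

def increment(self):
    enough_steps = layers.less_than(self.increment_every,
                                    self.good_steps + 1)

    def maybe_update():
        new_scale = self.scale * self.factor
        scale_valid = layers.isfinite(new_scale)

        def update_scale_and_step():
            layers.assign(new_scale, self.scale)
            layers.assign(
                layers.zeros_like(self.good_steps), self.good_steps)

        layers.cond(scale_valid, update_scale_and_step)

    def count_step():
        # not enough good steps yet: just keep counting
        layers.increment(self.good_steps)

    layers.cond(enough_steps, maybe_update, count_step)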
def update_loss_scale(grads):
    state = mixed_precision_global_state()
    if state is None or not state.dynamic_scaling:
        return
    per_grad_check = layers.stack([layers.reduce_sum(g) for g in grads])
    grad_valid = layers.isfinite(per_grad_check)

    with layers.Switch() as switch:
        with switch.case(grad_valid):
            state.increment()
        with switch.default():
            state.decrement()
    return grad_valid
def increment(self):
    enough_steps = layers.less_than(self.increment_every,
                                    self.good_steps + 1)
    with layers.Switch() as switch:
        with switch.case(enough_steps):
            new_scale = self.scale * self.factor
            scale_valid = layers.isfinite(new_scale)

            with layers.Switch() as switch2:
                with switch2.case(scale_valid):
                    layers.assign(new_scale, self.scale)
                    layers.assign(layers.zeros_like(self.good_steps),
                                  self.good_steps)
                with switch2.default():
                    layers.increment(self.good_steps)
        with switch.default():
            layers.increment(self.good_steps)
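For completeness, a hedged sketch of the matching decrement() path, which the original does not show; the attribute names bad_steps, decrement_every and decr_factor are assumptions. The logic mirrors the decrease branch of update_loss_scaling() below, including clamping the scale at 1.0:

def decrement(self):
    enough_bad_steps = layers.less_than(self.decrement_every,
                                        self.bad_steps + 1)
    with layers.Switch() as switch:
        with switch.case(enough_bad_steps):
            new_scale = self.scale * self.decr_factor
            one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
            less_than_one = layers.less_than(new_scale, one)

            with layers.Switch() as switch2:
                with switch2.case(less_than_one):
                    layers.assign(one, self.scale)
                with switch2.default():
                    layers.assign(new_scale, self.scale)

            layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
            layers.assign(layers.zeros_like(self.bad_steps), self.bad_steps)
        with switch.default():
            layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
            layers.increment(self.bad_steps)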
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
                        num_bad_steps, incr_every_n_steps,
                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    """
    Update loss scaling according to the overall gradients. If all gradients
    are finite for incr_every_n_steps consecutive steps, loss scaling is
    increased by incr_ratio. Otherwise, loss scaling is decreased by
    decr_ratio after decr_every_n_nan_or_inf accumulated steps in which some
    gradients are infinite.

    Args:
        is_overall_finite (Variable): A boolean variable indicating whether
                                      all gradients are finite.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable accumulating good steps in
                                   which all gradients are finite.
        num_bad_steps (Variable): A variable accumulating bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (Variable): A variable representing increasing
                                       loss scaling every n consecutive
                                       steps with finite gradients.
        decr_every_n_nan_or_inf (Variable): A variable representing
                                            decreasing loss scaling every n
                                            accumulated steps with nan or
                                            inf gradients.
        incr_ratio (float): The multiplier to use when increasing the loss
                            scaling.
        decr_ratio (float): The less-than-one multiplier to use when
                            decreasing the loss scaling.
    """
    zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
    with layers.Switch() as switch:
        with switch.case(is_overall_finite):
            should_incr_loss_scaling = layers.less_than(
                incr_every_n_steps, num_good_steps + 1)
            with layers.Switch() as switch1:
                with switch1.case(should_incr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * incr_ratio
                    loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
                    with layers.Switch() as switch2:
                        with switch2.case(loss_scaling_is_finite):
                            layers.assign(new_loss_scaling, prev_loss_scaling)
                        with switch2.default():
                            pass
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
                with switch1.default():
                    layers.increment(num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
        with switch.default():
            should_decr_loss_scaling = layers.less_than(
                decr_every_n_nan_or_inf, num_bad_steps + 1)
            with layers.Switch() as switch3:
                with switch3.case(should_decr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * decr_ratio
                    static_loss_scaling = \
                        layers.fill_constant(shape=[1],
                                             dtype='float32',
                                             value=1.0)
                    less_than_one = layers.less_than(new_loss_scaling,
                                                     static_loss_scaling)
                    with layers.Switch() as switch4:
                        with switch4.case(less_than_one):
                            layers.assign(static_loss_scaling,
                                          prev_loss_scaling)
                        with switch4.default():
                            layers.assign(new_loss_scaling,
                                          prev_loss_scaling)
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
                with switch3.default():
                    layers.assign(zero_steps, num_good_steps)
                    layers.increment(num_bad_steps)
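A minimal usage sketch, not from the original: the variable names, the placeholder gradient, and the hyper-parameter values are assumptions. It creates the persistable loss-scaling state once, derives is_overall_finite from the scaled gradients the same way minimize() does below, and then calls update_loss_scaling after the backward pass.

import paddle.fluid as fluid
from paddle.fluid import layers

# persistable state shared across iterations (initial values are assumptions)
loss_scaling = layers.create_global_var(
    name='loss_scaling', shape=[1], value=2.**15,
    dtype='float32', persistable=True)
num_good_steps = layers.create_global_var(
    name='num_good_steps', shape=[1], value=0,
    dtype='int32', persistable=True)
num_bad_steps = layers.create_global_var(
    name='num_bad_steps', shape=[1], value=0,
    dtype='int32', persistable=True)
incr_every_n_steps = layers.fill_constant(shape=[1], dtype='int32', value=1000)
decr_every_n_nan_or_inf = layers.fill_constant(shape=[1], dtype='int32', value=2)

# placeholder standing in for the scaled gradients produced by backward()
grads = [layers.fill_constant(shape=[2, 2], dtype='float32', value=0.0)]

# collapse all gradients into one scalar finiteness check
all_grads_sum = layers.reduce_sum(
    layers.concat([layers.reduce_sum(g) for g in grads]))
is_overall_finite = layers.isfinite(all_grads_sum)

update_loss_scaling(is_overall_finite, loss_scaling, num_good_steps,
                    num_bad_steps, incr_every_n_steps,
                    decr_every_n_nan_or_inf,
                    incr_ratio=2.0, decr_ratio=0.5)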
def minimize(self,
             loss,
             startup_program=None,
             parameter_list=None,
             no_grad_set=None,
             callbacks=None):
    assert loss._get_info('shard_logit')

    shard_logit = loss._get_info('shard_logit')
    shard_prob = loss._get_info('shard_prob')
    shard_label = loss._get_info('shard_label')
    shard_dim = loss._get_info('shard_dim')

    op_maker = fluid.core.op_proto_and_checker_maker
    op_role_key = op_maker.kOpRoleAttrName()
    op_role_var_key = op_maker.kOpRoleVarAttrName()
    backward_role = int(op_maker.OpRole.Backward)
    loss_backward_role = int(op_maker.OpRole.Loss) | int(
        op_maker.OpRole.Backward)

    # minimize a scalar of reduce_sum to generate the backward network
    scalar = fluid.layers.reduce_sum(shard_logit)
    block = loss.block

    if not self._use_fp16:
        ret = self._optimizer.minimize(scalar)

        # remove the unnecessary ops
        index = 0
        for i, op in enumerate(block.ops):
            if op.all_attrs()[op_role_key] == loss_backward_role:
                index = i
                break

        assert block.ops[index - 1].type == 'reduce_sum'
        assert block.ops[index].type == 'fill_constant'
        assert block.ops[index + 1].type == 'reduce_sum_grad'
        block._remove_op(index + 1)
        block._remove_op(index)
        block._remove_op(index - 1)

        self.insert_commom_backward_op(block, index, shard_logit, shard_prob,
                                       shard_label, shard_dim, op_role_key,
                                       backward_role, loss_backward_role)
        return ret
    else:
        scaled_params_grads = self.fp16_backward(block, scalar,
                                                 startup_program,
                                                 parameter_list, no_grad_set,
                                                 callbacks)
        index = 0
        for i, op in enumerate(block.ops):
            if op.all_attrs()[op_role_key] == loss_backward_role:
                index = i
                break

        if self._loss_type == 'dist_arcface':
            assert block.ops[index - 2].type == 'fill_constant'
            assert block.ops[index - 1].type == 'reduce_sum'
            assert block.ops[index].type == 'fill_constant'
            assert block.ops[index + 1].type == 'reduce_sum_grad'
            assert block.ops[index + 2].type == 'scale'
            assert block.ops[index + 3].type == 'elementwise_add_grad'
            block._remove_op(index + 2)
            block._remove_op(index + 1)
            block._remove_op(index)
            block._remove_op(index - 1)

            self.insert_dist_arcface_backward_op(
                block, index, shard_logit, shard_prob, shard_label,
                shard_dim, op_role_key, backward_role, loss_backward_role)

        elif self._loss_type == 'dist_softmax':
            assert block.ops[index - 1].type == 'reduce_sum'
            assert block.ops[index].type == 'fill_constant'
            assert block.ops[index + 1].type == 'reduce_sum_grad'
            assert block.ops[index + 2].type == 'cast'
            assert block.ops[index + 3].type == 'elementwise_add_grad'
            block._remove_op(index + 1)
            block._remove_op(index)
            block._remove_op(index - 1)

            self.insert_dist_softmax_backward_op(
                block, index, shard_logit, shard_prob, shard_label,
                shard_dim, op_role_key, backward_role, loss_backward_role)

        if self._use_dynamic_loss_scaling:
            grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
            all_grads = layers.concat(grads)
            all_grads_sum = layers.reduce_sum(all_grads)
            is_overall_finite = layers.isfinite(all_grads_sum)

            update_loss_scaling(is_overall_finite, self._loss_scaling,
                                self._num_good_steps, self._num_bad_steps,
                                self._incr_every_n_steps,
                                self._decr_every_n_nan_or_inf,
                                self._incr_ratio, self._decr_ratio)

            # if any gradient is nan or inf, zero all gradients so that
            # apply_gradients effectively skips this step
            with layers.Switch() as switch:
                with switch.case(is_overall_finite):
                    pass
                with switch.default():
                    for _, g in scaled_params_grads:
                        layers.assign(layers.zeros_like(g), g)

        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
        ret = optimize_ops, scaled_params_grads
        return ret
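A hedged sketch of how the loss-scaling state referenced above (self._loss_scaling, self._num_good_steps, self._num_bad_steps, and the ratio/interval settings) might be created when dynamic loss scaling is enabled; the helper name and the default values are assumptions, not the original implementation.

def _init_dynamic_loss_scaling(self,
                               init_loss_scaling=2.**15,
                               incr_every_n_steps=1000,
                               decr_every_n_nan_or_inf=2,
                               incr_ratio=2.0,
                               decr_ratio=0.5):
    # persistable counters and the current scale, shared across iterations
    self._loss_scaling = layers.create_global_var(
        name='loss_scaling', shape=[1], value=init_loss_scaling,
        dtype='float32', persistable=True)
    self._num_good_steps = layers.create_global_var(
        name='num_good_steps', shape=[1], value=0,
        dtype='int32', persistable=True)
    self._num_bad_steps = layers.create_global_var(
        name='num_bad_steps', shape=[1], value=0,
        dtype='int32', persistable=True)

    # thresholds and ratios consumed by update_loss_scaling()
    self._incr_every_n_steps = layers.fill_constant(
        shape=[1], dtype='int32', value=incr_every_n_steps)
    self._decr_every_n_nan_or_inf = layers.fill_constant(
        shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
    self._incr_ratio = incr_ratio
    self._decr_ratio = decr_ratio
    self._use_dynamic_loss_scaling = True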