from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype


def construct(self, *args):
    """Defines the computation performed: forward, backward, and parameter update."""
    weights = self.weights
    loss = self.network(*args)
    # Build the sensitivity tensor (initial gradient) with the same dtype and shape as the loss.
    sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*args, sens)
    if self.reducer_flag:
        # Apply the grad reducer (AllReduce) on grads in distributed training.
        grads = self.grad_reducer(grads)
    if self.use_global_norm:
        # Undo the loss scaling before clipping by global norm.
        grads = self.hyper_map(F.partial(grad_scale, F.scalar_to_array(self.sens)), grads)
        grads = C.clip_by_global_norm(grads)
    return F.depend(loss, self.optimizer(grads))
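
# The construct above references a `grad_scale` multitype graph that divides each
# gradient by the loss-scale sensitivity before global-norm clipping. Its definition
# is not part of this excerpt; a minimal sketch following the usual MindSpore
# convention (the name and body here are assumptions, not the verbatim original):

grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    # Multiply by 1/scale so the gradient returns to its unscaled magnitude.
    return grad * reciprocal(scale)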
def construct(self,
              input_ids,
              input_mask,
              token_type_id,
              next_sentence_labels,
              masked_lm_positions,
              masked_lm_ids,
              masked_lm_weights,
              sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids,
                        input_mask,
                        token_type_id,
                        next_sentence_labels,
                        masked_lm_positions,
                        masked_lm_ids,
                        masked_lm_weights)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens

    # Update accumulation parameters: step counter and running loss reset once
    # local_step reaches accumulation_steps.
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
    self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
    self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
    mean_loss = self.accu_loss / self.local_step
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

    # Alloc status and clear should be right before the grad operation.
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             next_sentence_labels,
                                             masked_lm_positions,
                                             masked_lm_ids,
                                             masked_lm_weights,
                                             self.cast(scaling_sens, mstype.float32))

    # Accumulate gradients, then unscale by loss scale * device count * accumulation steps.
    accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
    scaling = scaling_sens * self.degree * self.accumulation_steps
    grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
    grads = self.grad_reducer(grads)

    # Get the overflow status and reduce it across devices.
    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    flag_reduce = self.overflow_reducer(flag_sum)
    overflow = self.less_equal(self.base, flag_reduce)
    overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
    accu_overflow = self.select(overflow, self.one, self.zero)
    self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
    overflow = self.reshape(overflow, (()))

    if is_accu_step:
        # Still accumulating: stash the summed gradients and skip the update.
        succ = False
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
        succ = F.depend(succ, accu_succ)
    else:
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, overflow)
        if overflow:
            # Skip the update on overflow.
            succ = False
        else:
            # Clip gradients, then apply the optimizer update.
            if self.enable_global_norm:
                grads = C.clip_by_global_norm(grads, 1.0, None)
            else:
                grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
            succ = self.optimizer(grads)

        # Reset the accumulated gradients for the next accumulation cycle.
        accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
        succ = F.depend(succ, accu_succ)

    ret = (mean_loss, overflow, scaling_sens)
    return F.depend(ret, succ)
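
# The accumulation construct above references `add_grads`, `update_accu_grads`,
# and `reset_accu_grads`, which are defined elsewhere in the module (as are
# `clip_grad`, `GRADIENT_CLIP_TYPE`, and `GRADIENT_CLIP_VALUE`, not shown here).
# A minimal sketch of the three accumulation helpers under the usual MindSpore
# multitype-graph convention (the exact bodies are assumptions, not the
# verbatim originals):

cast = P.Cast()
zeroslike = P.ZerosLike()

add_grads = C.MultitypeFuncGraph("add_grads")


@add_grads.register("Tensor", "Tensor")
def _add_grads(accu_grad, grad):
    # Add the freshly computed gradient into the float32 accumulator.
    return accu_grad + cast(grad, mstype.float32)


update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    # Overwrite the stored accumulator with the newly accumulated value.
    succ = True
    return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))


reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    # Zero the accumulator once its gradients have been applied.
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))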