def construct(self, input_ids, input_mask, token_type_id, next_sentence_labels, masked_lm_positions, masked_lm_ids, masked_lm_weights):
    """Run one optimizer step for the wrapped network (no loss scaling).

    The loss is back-propagated with a constant sensitivity: ``self.sens``
    packed into a one-element array and cast to float32. Every gradient is
    then clipped with ``clip_grad(GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, .)``,
    optionally passed through ``self.grad_reducer``, and handed to
    ``self.optimizer``.

    Returns:
        The forward-pass loss, made to depend on the optimizer update via
        ``ops.depend``.
    """
    params = self.weights
    loss = self.network(input_ids, input_mask, token_type_id,
                        next_sentence_labels, masked_lm_positions,
                        masked_lm_ids, masked_lm_weights)
    grad_fn = self.grad(self.network, params)
    sens_tensor = self.cast(ops.tuple_to_array((self.sens,)), mstype.float32)
    gradients = grad_fn(input_ids, input_mask, token_type_id,
                        next_sentence_labels, masked_lm_positions,
                        masked_lm_ids, masked_lm_weights, sens_tensor)
    clip_fn = ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    gradients = self.hyper_map(clip_fn, gradients)
    if self.reducer_flag:
        # apply grad reducer on grads
        gradients = self.grad_reducer(gradients)
    update = self.optimizer(gradients)
    return ops.depend(loss, update)
def construct(self, *inputs):
    """Compute gradients for one micro-batch and fold them into ``self.grad_sum``.

    The back-prop sensitivity is a tensor with the loss's dtype and shape,
    filled with ``self.sens``. Each gradient is combined with its buffer in
    ``self.grad_sum`` through ``_sum_op``. That helper is defined elsewhere;
    presumably it is an in-place add, which should be confirmed. This method
    does not call the optimizer.

    Returns:
        The loss, made to depend on the accumulation result via ``ops.depend``.
    """
    params = self.weights
    loss = self.network(*inputs)
    fill_op = ops.Fill()
    loss_dtype = ops.DType()(loss)
    loss_shape = ops.Shape()(loss)
    sens_tensor = fill_op(loss_dtype, loss_shape, self.sens)
    gradients = self.grad(self.network, params)(*inputs, sens_tensor)
    accumulated = self.hyper_map(ops.partial(_sum_op), self.grad_sum, gradients)
    return ops.depend(loss, accumulated)
def construct(self, input_ids, token_type_id, pad_mask, sens=None):
    """Run one loss-scaled training step of the BertPoetry network.

    Args:
        input_ids, token_type_id, pad_mask: forwarded unchanged to
            ``self.network`` and to its gradient function.
        sens: optional explicit loss scale. When ``None``, ``self.loss_scale``
            is used instead, and ``self.loss_scaling_manager`` decides whether
            the step counts as overflowed.

    Returns:
        ``(loss, overflow_status)``, where ``overflow_status`` comes from
        ``get_overflow_status``. The tuple is made to depend on the
        (possibly skipped) optimizer update.
    """
    params = self.weights
    loss = self.network(input_ids, token_type_id, pad_mask)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
    grad_fn = self.grad(self.network, params)
    gradients = grad_fn(input_ids, token_type_id, pad_mask,
                        self.cast(scaling_sens, mstype.float32))
    # Apply grad_scale with the loss scale (presumably un-scaling — confirm),
    # then clip each gradient.
    gradients = self.hyper_map(ops.partial(grad_scale, scaling_sens), gradients)
    gradients = self.hyper_map(
        ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
        gradients)
    if self.reducer_flag:
        gradients = self.grad_reducer(gradients)
    overflow_status = self.get_overflow_status(status, gradients)
    if sens is None:
        skip_update = self.loss_scaling_manager(self.loss_scale, overflow_status)
    else:
        skip_update = overflow_status
    if skip_update:
        step_result = False
    else:
        step_result = self.optimizer(gradients)
    return ops.depend((loss, overflow_status), step_result)
def construct(self, *inputs):
    """Defines the computation performed.

    One loss-scaled training step with optional gradient accumulation.

    When ``self.accumulation`` is truthy, the optimizer is not called and the
    step's overflow flag is carried forward in ``self.accu_overflow``.
    Otherwise, the following happens in order:

    - the gradients are reduced and passed through ``_grad_scale``;
    - the overflow flag is all-reduced;
    - the accumulation buffers are reset;
    - the loss scale is processed;
    - the optimizer runs only when no overflow is reported.

    Args:
        *inputs: forwarded unchanged to ``self.network`` and to its gradient
            function. The loss-scale sensitivity is appended as the last
            gradient-function argument.

    Returns:
        ``(loss, overflow, scaling_sens)``, made to depend on ``succ``. ``succ``
        is the optimizer result, or ``False`` when no update was applied.
    """
    weights = self.weights
    loss = self.network(*inputs)
    scaling_sens = self.scale_sense
    status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
    # Back-prop sensitivity: a tensor shaped like the loss, filled with the
    # loss scale cast to the loss dtype.
    scaling_sens_filled = ops.ones_like(loss) * ops.cast(
        scaling_sens, ops.dtype(loss))
    grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
    # accumulate gradients
    if self.accumulation and self.accumulation_steps > 1:
        # Fold this step's grads into self.accu_grads and tie the returned
        # loss to that update. update_accu_grads is defined elsewhere;
        # presumably it is an in-place add, which should be confirmed.
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        loss = ops.depend(loss, accu_succ)
    # This step's overflow, OR-ed with any overflow already recorded in
    # self.accu_overflow (set on accumulation steps, cleared on update steps).
    overflow = self.get_overflow_status(status, grads)
    overflow = self.logical_or(
        self.not_equal(self.accu_overflow, self.zero), overflow)
    # Encode the boolean flag as self.one / self.zero.
    accu_overflow = self.select(overflow, self.one, self.zero)

    if self.accumulation:
        # Accumulation step: skip the optimizer and remember the overflow.
        succ = False
        self.accu_overflow = accu_overflow
    else:
        # Update step: clear the carried overflow flag.
        self.accu_overflow = self.zero
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # Apply _grad_scale with the loss scale to every gradient.
        # NOTE(review): presumably this removes the loss scale — confirm.
        grads = self.hyper_map(ops.partial(_grad_scale, scaling_sens), grads)
        # All-reduce the 0/1 overflow flag (presumably summed across
        # devices). Overflow is declared when the result reaches self.base.
        accu_overflow = self.allreduce(accu_overflow)
        overflow = self.less_equal(self.base, accu_overflow)
        # Make the buffer reset depend on grads so it is ordered after them.
        accu_grads = ops.depend(self.accu_grads, grads)
        accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
        overflow = ops.depend(overflow, accu_succ)
        # Reshape the flag to a scalar (shape ()).
        overflow = self.reshape(overflow, (()))
        # NOTE(review): process_loss_scale is defined elsewhere; presumably it
        # updates the loss scale and returns the final overflow decision.
        overflow = self.process_loss_scale(overflow)

        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

    ret = (loss, overflow, scaling_sens)
    return ops.depend(ret, succ)
def construct(self):
    """Apply ``_clear_op`` across ``self.grad_sum`` paired with ``self.zeros``.

    Presumably this resets the gradient-accumulation buffers to zero between
    optimizer steps; confirm against the definition of ``_clear_op``.

    Returns:
        Whatever ``self.hyper_map`` produces for the mapped ``_clear_op`` calls.
    """
    clear_fn = ops.partial(_clear_op)
    return self.hyper_map(clear_fn, self.grad_sum, self.zeros)
def construct(self, scale, grads):
    """Map ``grad_scale`` (with ``scale`` bound first) over every gradient.

    Args:
        scale: bound as the first argument of ``grad_scale``. Presumably this
            is a loss scale to divide out; confirm against ``grad_scale``.
        grads: gradient collection, mapped element by element.

    Returns:
        The mapped gradients produced by ``self.hyper_map``.
    """
    apply_scale = ops.partial(grad_scale, scale)
    return self.hyper_map(apply_scale, grads)
def construct(self, grads):
    """Map ``clip_grad`` (with ``self.clip_norm`` bound first) over every gradient.

    Args:
        grads: gradient collection, mapped element by element.

    Returns:
        The mapped (clipped) gradients produced by ``self.hyper_map``.
    """
    apply_clip = ops.partial(clip_grad, self.clip_norm)
    return self.hyper_map(apply_clip, grads)