def construct(self, image, hm, reg_mask, ind, wh, kps, kps_mask, reg, hm_hp, hp_offset, hp_ind, hp_mask):
    """Defines the computation performed.

    One loss-scaled training step: forward pass, overflow-status bookkeeping,
    backward pass, gradient all-reduce/unscale, overflow check, then a
    conditional optimizer update.  Returns (loss, overflow_cond, scaling_sens).
    """
    image = self.image(image)
    weights = self.weights
    loss = self.network(image, hm, reg_mask, ind, wh, kps, kps_mask, reg, hm_hp, hp_offset, hp_ind, hp_mask)
    # NOTE(review): "* 2.0 / 2.0" is numerically a no-op — presumably it forces
    # scaling_sens to be a computed graph tensor; confirm before removing.
    scaling_sens = self.cast(self.loss_scale, mstype.float32) * 2.0 / 2.0
    # alloc status and clear should be right before gradoperation
    init = self.alloc_status()
    # The depend chain pins the status clear between the forward and backward
    # passes so the float-status registers only capture the backward pass.
    init = ops.depend(init, scaling_sens)
    clear_status = self.clear_status(init)
    scaling_sens = ops.depend(scaling_sens, clear_status)
    grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, kps, kps_mask, reg, hm_hp,
                                             hp_offset, hp_ind, hp_mask, scaling_sens)
    # All-reduce across devices, then undo the loss scale (times degree).
    grads = self.grad_reducer(grads)
    grads = self.grad_scale(scaling_sens * self.degree, grads)
    # Read the overflow status only after the gradients are materialized.
    init = ops.depend(init, grads)
    get_status = self.get_status(init)
    init = ops.depend(init, get_status)
    flag_sum = self.reduce_sum(init, (0, ))
    if self.is_distributed:
        # Sum the overflow flag over all devices before testing it.
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    succ = self.optimizer(grads)
    ret = (loss, cond, scaling_sens)
    # depend() makes the return value carry the optimizer update's side effect.
    return ops.depend(ret, succ)
def construct(self, *inputs):
    """Run one forward/backward pass and fold the gradients into the
    accumulation buffers, returning the loss."""
    params = self.weights
    loss = self.network(*inputs)
    # Build the backward sensitivity: a tensor of self.sens with the
    # loss's own dtype and shape.
    loss_dtype = ops.DType()(loss)
    loss_shape = ops.Shape()(loss)
    sens = ops.Fill()(loss_dtype, loss_shape, self.sens)
    grads = self.grad(self.network, params)(*inputs, sens)
    # Accumulate element-wise into self.grad_sum; depend() ties the update
    # to the returned loss so it is not pruned from the graph.
    accumulate = self.hyper_map(ops.partial(_sum_op), self.grad_sum, grads)
    return ops.depend(loss, accumulate)
def construct(self, input_ids, input_mask, token_type_id, next_sentence_labels,
              masked_lm_positions, masked_lm_ids, masked_lm_weights):
    """One BERT pretraining step: forward, backward, clip, (reduce), update."""
    params = self.weights
    loss = self.network(input_ids, input_mask, token_type_id,
                        next_sentence_labels, masked_lm_positions,
                        masked_lm_ids, masked_lm_weights)
    # Backward sensitivity: self.sens wrapped as a float32 scalar tensor.
    sens = self.cast(ops.tuple_to_array((self.sens,)), mstype.float32)
    grads = self.grad(self.network, params)(input_ids, input_mask,
                                            token_type_id,
                                            next_sentence_labels,
                                            masked_lm_positions,
                                            masked_lm_ids, masked_lm_weights,
                                            sens)
    # Clip every gradient before any cross-device reduction.
    clip_fn = ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE)
    grads = self.hyper_map(clip_fn, grads)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
    update = self.optimizer(grads)
    return ops.depend(loss, update)
def construct(self, image, hm, reg_mask, ind, wh, reg):
    """One plain training step: forward, backward, optimizer update."""
    image = self.image(image)
    params = self.weights
    loss = self.network(image, hm, reg_mask, ind, wh, reg)
    grads = self.grad(self.network, params)(image, hm, reg_mask, ind, wh, reg)
    # Tie the optimizer's side effect to the returned loss via depend().
    update = self.optimizer(grads)
    return ops.depend(loss, update)
def construct(self, realA, realB):
    """Run one train step for both networks: discriminator first, then
    generator.  Returns the (loss_D, loss_G) pair, each carrying its
    optimizer update through depend()."""
    d_loss = self.loss_netD(realA, realB)
    g_loss = self.loss_netG(realA, realB)
    # --- discriminator update ---
    d_sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
    d_grads = self.grad(self.loss_netD, self.weights_D)(realA, realB, d_sens)
    d_update = self.optimizerD(d_grads)
    d_res = ops.depend(d_loss, d_update)
    # --- generator update ---
    g_sens = ops.Fill()(ops.DType()(g_loss), ops.Shape()(g_loss), self.sens)
    g_grads = self.grad(self.loss_netG, self.weights_G)(realA, realB, g_sens)
    g_update = self.optimizerG(g_grads)
    g_res = ops.depend(g_loss, g_update)
    return d_res, g_res
def construct(self, *inputs):
    """Defines the computation performed.

    Gradient-accumulation training step.  While self.accumulation is true the
    gradients are folded into self.accu_grads and the optimizer is skipped;
    on the boundary step the accumulated grads are reduced, the buffers are
    cleared, and the optimizer runs.  Returns the loss.
    """
    weights = self.weights
    loss = self.network(*inputs)
    # Backward sensitivity tensor matching the loss dtype/shape.
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*inputs, sens)
    if self.accumulation and self.accumulation_steps > 1:
        # Mid-accumulation: fold this step's grads into the buffers.
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        loss = ops.depend(loss, accu_succ)
    if self.accumulation:
        # Still accumulating — no parameter update this step.
        succ = False
    else:
        # Boundary step: all-reduce, clear the accumulation buffers
        # (depend() forces the clear to happen after grads are read),
        # then apply the optimizer.
        grads = self.grad_reducer(grads)
        accu_grads = ops.depend(self.accu_grads, grads)
        accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
        loss = ops.depend(loss, accu_succ)
        succ = self.optimizer(grads)
    return ops.depend(loss, succ)
def construct(self, img_A, img_B, fake_A, fake_B):
    """One discriminator training step; returns the discriminator loss."""
    params = self.weights
    d_loss = self.D(img_A, img_B, fake_A, fake_B)
    # Sensitivity tensor shaped/typed like the loss.
    sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
    grads = self.grad(self.D, params)(img_A, img_B, fake_A, fake_B, sens)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
    update = self.optimizer(grads)
    return ops.depend(d_loss, update)
def construct(self, input_ids, input_mask, token_type_id, next_sentence_labels,
              masked_lm_positions, masked_lm_ids, masked_lm_weights, sens=None):
    """Defines the computation performed.

    Loss-scaled BERT pretraining step with overflow detection.  When `sens`
    is None the internal dynamic loss scale is used (and updated by the
    loss-scaling manager); otherwise the caller-supplied scale is used.
    The optimizer runs only when no overflow occurred.
    Returns (loss, overflow_cond, scaling_sens).
    """
    weights = self.weights
    loss = self.network(input_ids, input_mask, token_type_id,
                        next_sentence_labels, masked_lm_positions,
                        masked_lm_ids, masked_lm_weights)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # alloc status and clear should be right before gradoperation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(input_ids, input_mask,
                                             token_type_id,
                                             next_sentence_labels,
                                             masked_lm_positions,
                                             masked_lm_ids, masked_lm_weights,
                                             self.cast(scaling_sens, mstype.float32))
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    # Undo the loss scale (times device degree), then clip.
    grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
    grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    # Read the float-status registers populated by the backward pass.
    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    if self.is_distributed:
        # sum overflow flag over devices
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        # Dynamic loss scale: let the manager adjust the scale and decide
        # whether this step counts as an overflow.
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        # Skip the update on overflow.
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond, scaling_sens)
    return ops.depend(ret, succ)
def construct(self, *inputs):
    """Defines the computation performed.

    Loss-scaled training step with gradient accumulation and cross-step
    overflow tracking (self.accu_overflow).  Mid-accumulation steps only
    fold gradients into the buffers; the boundary step reduces, unscales,
    clears the buffers and conditionally runs the optimizer.
    Returns (loss, overflow, scaling_sens).
    """
    weights = self.weights
    loss = self.network(*inputs)
    scaling_sens = self.scale_sense
    status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
    # Sensitivity tensor: the loss scale broadcast to the loss's shape/dtype.
    scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss))
    grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
    # accumulate gradients
    if self.accumulation and self.accumulation_steps > 1:
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        loss = ops.depend(loss, accu_succ)
    overflow = self.get_overflow_status(status, grads)
    # An overflow on any earlier step of this accumulation window sticks.
    overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
    accu_overflow = self.select(overflow, self.one, self.zero)
    if self.accumulation:
        # Mid-accumulation: remember the overflow flag, skip the optimizer.
        succ = False
        self.accu_overflow = accu_overflow
    else:
        self.accu_overflow = self.zero
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(ops.partial(_grad_scale, scaling_sens), grads)
        # Combine the overflow flags from all devices.
        accu_overflow = self.allreduce(accu_overflow)
        overflow = self.less_equal(self.base, accu_overflow)
        # Clear the accumulation buffers only after grads have been read.
        accu_grads = ops.depend(self.accu_grads, grads)
        accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
        overflow = ops.depend(overflow, accu_succ)
        overflow = self.reshape(overflow, (()))
        overflow = self.process_loss_scale(overflow)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
    ret = (loss, overflow, scaling_sens)
    return ops.depend(ret, succ)
def construct(self, img_A, img_B):
    """One generator training step; returns the fake images, the generator
    loss (carrying the optimizer update), and the individual loss terms."""
    params = self.weights
    fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B)
    # NOTE(review): grads are taken through self.net while the loss comes
    # from self.G — presumably self.net wraps self.G to return lg; confirm.
    sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
    grads_g = self.grad(self.net, params)(img_A, img_B, sens)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads_g = self.grad_reducer(grads_g)
    lg_out = ops.depend(lg, self.optimizer(grads_g))
    return fake_A, fake_B, lg_out, lga, lgb, lca, lcb, lia, lib
def construct(self, input_ids, token_type_id, pad_mask, sens=None):
    """construct BertPoetryCell

    Loss-scaled training step: forward, backward with the scale as
    sensitivity, unscale + clip gradients, overflow check, and a
    conditional optimizer update.  When `sens` is None the internal
    dynamic loss scale (and its manager) is used.  Returns (loss, cond).
    """
    weights = self.weights
    loss = self.network(input_ids, token_type_id, pad_mask)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
    grads = self.grad(self.network, weights)(input_ids, token_type_id,
                                             pad_mask,
                                             self.cast(scaling_sens, mstype.float32))
    # Undo the loss scale, then clip.
    grads = self.hyper_map(ops.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    if self.reducer_flag:
        grads = self.grad_reducer(grads)
    cond = self.get_overflow_status(status, grads)
    overflow = cond
    if sens is None:
        # Dynamic loss scale: the manager updates the scale and decides
        # whether this step is an overflow.
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        # Skip the parameter update on overflow.
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return ops.depend(ret, succ)
def _clear_grad_sum(grad_sum, zero):
    """Apply zero to clear grad_sum.

    depend() ties the assign's side effect to the returned flag so the
    clear is kept in the graph.
    """
    cleared = ops.assign(grad_sum, zero)
    return ops.depend(True, cleared)
def _update_accu_grads(accu_grad, grad):
    """Add `grad` (cast to float32) into the accumulation buffer `accu_grad`."""
    added = ops.assign_add(accu_grad, cast(grad, mstype.float32))
    # depend() keeps the in-place add alive in the graph.
    return ops.depend(True, added)
def _reset_accu_grads(accu_grad):
    """Zero out the accumulation buffer `accu_grad` in place."""
    cleared = ops.assign(accu_grad, zeroslike(accu_grad))
    # depend() keeps the in-place reset alive in the graph.
    return ops.depend(True, cleared)