def construct(self, x):
    out_conv = self.conv(x, self.weight)
    # BN fold1
    batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(
        out_conv, self.moving_mean, self.moving_variance, self.step)
    # fake weight
    weight = self.correct_mul(self.weight, self.gamma, running_std)
    if self.fake:
        weight = self.fake_quant_weight(weight)
    out = self.conv(x, weight)
    # BN fold2
    if self.is_gpu:
        if self.training:
            out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                             batch_std, batch_mean,
                                             running_std, running_mean, self.step)
            F.control_depend(out, self.assignadd(self.step, self.one))
        else:
            out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                             batch_std, batch_mean,
                                             running_std, running_mean, self.step)
    else:
        if self.training:
            out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                             batch_std, batch_mean, running_std)
            F.control_depend(out, self.assignadd(self.step, self.one))
        else:
            out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                             running_std, running_mean, running_std)
    return out
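# A minimal NumPy sketch of batch-norm folding, the transformation the BN-fold ops
# above apply during quantization-aware training: the BN scale is folded into the
# convolution weight (roughly what correct_mul computes from gamma and running_std)
# and the BN shift becomes a bias. `fold_batchnorm` is a hypothetical helper shown
# for orientation only; the training path above additionally corrects the result
# with per-batch statistics.
import numpy as np

def fold_batchnorm(weight, gamma, beta, running_mean, running_var, eps=1e-5):
    std = np.sqrt(running_var + eps)
    scale = gamma / std                               # per-output-channel factor
    folded_weight = weight * scale.reshape(-1, 1, 1, 1)
    folded_bias = beta - running_mean * scale
    return folded_weight, folded_bias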
def broadcast_params(self, optim_result):
    """
    Apply Broadcast operations in the sequential order of parameter groups.

    Returns:
        bool, the status flag.
    """
    param_group = []
    key_group = []
    for _ in range(self.dev_num):
        param_group.append(F.make_tuple())
        key_group.append(F.make_tuple())
    for i in range(self.param_length):
        param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (self.parameters[i],)
        key = P.MakeRefKey(self.param_names[i])()
        key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
    new_param_group = []
    for root in range(self.dev_num):
        ops = P.Broadcast(root)
        next_params = ops(param_group[root])
        new_param_group.append(next_params)
        for i in range(F.tuple_len(next_params)):
            F.assign(key_group[root][i], next_params[i])
    status = F.control_depend(optim_result, new_param_group[0][0])
    for i in range(self.dev_num - 1):
        status = F.depend(
            F.control_depend(new_param_group[i], new_param_group[i + 1][0]),
            status)
    return status
def construct(self, gradients):
    params = self.parameters
    if self.weight_decay > 0:
        gradients = self.hyper_map(
            F.partial(apply_decay, self.weight_decay), self.decay_tf,
            params, gradients)
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(
            F.partial(grad_scale, self.reciprocal_scale), gradients)
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, self.one))
    else:
        lr = self.learning_rate
    if self.centered:
        success = self.hyper_map(
            F.partial(centered_rmsprop_opt, self.opt, lr, self.decay,
                      self.epsilon, self.momentum),
            params, self.mg, self.ms, self.moment, gradients)
    else:
        success = self.hyper_map(
            F.partial(rmsprop_opt, self.opt, lr, self.decay,
                      self.epsilon, self.momentum),
            params, self.ms, self.moment, gradients)
    return success
def construct(self, gradients):
    step = self.min(self.global_step, self.decay_steps)
    p = step / self.decay_steps
    lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
    if self.warmup_flag:
        warmup_percent = self.global_step / self.warmup_steps
        warmup_lr = self.start_learning_rate * warmup_percent
        is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step),
                              mstype.float32)
        lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
    if self.enable_graph_kernel:
        optim_result = self.hyper_map(
            F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, lr,
                      self.weight_decay_tensor, self.global_step),
            self.params, self.moments1, self.moments2, gradients, self.decay_flag)
    else:
        optim_result = self.hyper_map(
            F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, lr,
                      self.weight_decay_tensor, self.global_step),
            self.params, self.moments1, self.moments2, gradients,
            self.decay_flag, self.optim_filter)
    if self.use_parallel:
        optim_result = self.broadcast_params(optim_result)
    added_global_step = self.global_step + self.one
    F.control_depend(lr, added_global_step)
    self.global_step = added_global_step
    return optim_result
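# A minimal sketch (plain Python, hypothetical helper name `polynomial_warmup_lr`)
# of the schedule the construct above builds with tensor ops: polynomial decay
# toward end_learning_rate, overridden by a linear warmup while
# global_step < warmup_steps. Parameter names mirror the attributes used above
# (diff_learning_rate == start_lr - end_lr); for orientation only, not the
# framework implementation.
def polynomial_warmup_lr(global_step, start_lr, end_lr, decay_steps, power,
                         warmup_steps=0):
    p = min(global_step, decay_steps) / decay_steps
    lr = (start_lr - end_lr) * (1.0 - p) ** power + end_lr
    if warmup_steps and global_step < warmup_steps:
        # linear warmup: ramp from 0 up to start_lr
        lr = start_lr * global_step / warmup_steps
    return lr

# Example: warm up for 1000 steps, then decay to 1e-6 over 10000 steps.
# print(polynomial_warmup_lr(500, 1e-4, 1e-6, 10000, 1.0, warmup_steps=1000))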
def construct(self, gradients):
    params = self.parameters
    moment1 = self.moment1
    moment2 = self.moment2
    if self.weight_decay > 0:
        gradients = self.hyper_map(
            F.partial(apply_decay, self.weight_decay), self.decay_tf,
            params, gradients)
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(
            F.partial(grad_scale, self.reciprocal_scale), gradients)
    lr = self.learning_rate
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, self.one))
    beta1_power = self.beta1_power * self.beta1
    self.beta1_power = beta1_power
    beta2_power = self.beta2_power * self.beta2
    self.beta2_power = beta2_power
    success = self.hyper_map(
        F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power,
                  self.beta1, self.beta2, self.eps),
        gradients, params, moment1, moment2)
    return success
def construct(self, x):
    if self.training:
        beta = self.beta
        gamma = self.gamma
        gmean = self.moving_mean
        gvar = self.moving_variance
        step = self.step
        out_conv = self.conv(x, self.weight)
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train(
            out_conv, gmean, gvar, step)  # BN fold1
        weight = self.correct_mul(self.weight, gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        # BN fold2
        out = self.batchnorm_fold2(out, beta, gamma, batch_std, batch_mean,
                                   running_std, running_mean, step)
        F.control_depend(out, self.assignadd(self.step, self.one))
    else:
        step = self.step
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
            x, self.moving_mean, self.moving_variance, step)
        weight = self.correct_mul(self.weight, self.gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std,
                                         batch_mean, running_std, running_mean, step)
    return out
def tensor_grad_scale(scale, grad, accu_grad):
    new_grad = accu_grad * reciprocal(scale)
    zeros = F.tensor_mul(accu_grad, 0.0)
    clear = F.assign(accu_grad, zeros)
    F.control_depend(new_grad, clear)
    F.control_depend(grad, new_grad)
    return new_grad
def construct(self, x, b, sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(x, b)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # update accumulation parameters
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
    self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
    self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
    mean_loss = self.accu_loss / self.local_step
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

    # alloc status and clear should be right before grad operation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))

    accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
    mean_loss = F.depend(mean_loss, accu_succ)

    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    overflow = self.less_equal(self.base, flag_sum)
    overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
    accu_overflow = self.select(overflow, self.one, self.zero)
    self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
    is_accu_step = self.reshape(is_accu_step, (()))

    if is_accu_step:
        succ = False
    else:
        # apply grad reducer on grads
        grads = self.grad_reducer(self.accu_grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
        grads = ClipByGlobalNorm()(grads)

        accu_overflow = self.overflow_reducer(accu_overflow)
        F.control_depend(grads, accu_overflow)
        overflow = self.less_equal(self.base, accu_overflow)
        accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
        overflow = F.depend(overflow, accu_succ)
        overflow = self.reshape(overflow, (()))
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, overflow)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

    ret = (mean_loss, overflow, scaling_sens)
    return F.depend(ret, succ)
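# A minimal sketch (plain Python, hypothetical class name `GradAccumulator`) of the
# accumulation bookkeeping used above: gradients are summed for accumulation_steps
# mini-batches, the optimizer runs only on the boundary step, and any overflow seen
# during the window poisons the whole update. It mirrors the control flow only; the
# cell above expresses the same logic with tensor ops so it stays inside one
# compiled graph.
class GradAccumulator:
    def __init__(self, accumulation_steps):
        self.accumulation_steps = accumulation_steps
        self.local_step = 0
        self.accu_grads = None
        self.accu_overflow = False

    def step(self, grads, overflow, apply_update):
        self.local_step = self.local_step % self.accumulation_steps + 1
        self.accu_grads = grads if self.accu_grads is None else [
            a + g for a, g in zip(self.accu_grads, grads)]
        self.accu_overflow = self.accu_overflow or overflow
        if self.local_step == self.accumulation_steps:
            if not self.accu_overflow:
                apply_update(self.accu_grads)   # e.g. optimizer(grads)
            self.accu_grads, self.accu_overflow = None, False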
def get_lr(self):
    """
    Get the learning rate of the current step.

    Returns:
        float, the learning rate of the current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, 0)
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    return lr
def construct(self, gradients):
    step = self.min(self.global_step, self.decay_steps)
    p = step / self.decay_steps
    lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
    updated_velocity = self.hyper_map(
        F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                  self.weight_decay_tensor),
        self.params, self.moments1, self.moments2, gradients, self.decay_flag)
    added_global_step = self.global_step + self.one
    F.control_depend(lr, added_global_step)
    self.global_step = added_global_step
    return updated_velocity
def construct(self, input_ids, input_position, attention_mask, past=None, sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids, input_position, attention_mask)
    if sens is None:
        scaling_sens = self.loss_scale
        scaling_sens = self.reshape(scaling_sens, (1,))
    else:
        scaling_sens = sens
    # alloc status and clear should be right before grad operation
    init = self.alloc_status()
    status_clear = self.clear_before_grad(init)
    #clear_depend = self.control(status_clear, self.weights)
    grads = self.grad(self.network, weights)(input_ids, input_position, attention_mask,
                                             self.cast(scaling_sens / self.micro_size,
                                                       mstype.float32))
    get_status = self.get_status(init)
    get_status_depend = F.control_depend(grads, get_status)
    flag_sum = self.reduce_sum(init, (0,))
    flag_sum_depend = F.control_depend(get_status, flag_sum)
    loss = F.depend(loss, status_clear)
    loss = F.depend(loss, get_status_depend)
    loss = F.depend(loss, flag_sum_depend)
    # apply grad reducer on grads
    accu_grads = self.grad_reducer(self.accu_grads)
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree),
                           grads, accu_grads)
    grads, global_norms = self.clip(grads)
    global_norm = P.Reshape()(global_norms, (()))
    if self.is_distributed:
        # sum overflow flag over devices
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, overflow, scaling_sens, global_norm)
    return F.depend(ret, succ)
def construct(self, gradients):
    lr = self.get_lr()
    if self.enable_graph_kernel:
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(
                    F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2,
                              self.eps, self.global_step),
                    lr, self.weight_decay, self.params, self.moments1,
                    self.moments2, gradients, self.decay_flags)
            else:
                optim_result = self.hyper_map(
                    F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2,
                              self.eps, self.global_step, lr),
                    self.weight_decay, self.params, self.moments1,
                    self.moments2, gradients, self.decay_flags)
        else:
            optim_result = self.hyper_map(
                F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2,
                          self.eps, self.global_step, lr, self.weight_decay),
                self.params, self.moments1, self.moments2, gradients,
                self.decay_flags)
    else:
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(
                    F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                              self.global_step),
                    lr, self.weight_decay, self.params, self.moments1,
                    self.moments2, gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(
                    F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                              self.global_step, lr),
                    self.weight_decay, self.params, self.moments1,
                    self.moments2, gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(
                F.partial(_lamb_opt, self.beta1, self.beta2, self.eps,
                          self.global_step, lr, self.weight_decay),
                self.params, self.moments1, self.moments2, gradients,
                self.decay_flags, self.optim_filter)
    if self.use_parallel:
        self.broadcast_params(optim_result)
    if not self.dynamic_lr:
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    return optim_result
def construct(self, gradients):
    params = self.parameters
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    else:
        lr = self.learning_rate
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
    grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
                            gradients, params, self.decay_flag, self.lars_flag)
    success = self.opt(grad_t)
    return success
def construct(self, gradients):
    params = self.params
    accum = self.accum
    stat = self.stat
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(
            F.partial(grad_scale, self.reciprocal_scale), gradients)
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    else:
        lr = self.learning_rate
    success = self.hyper_map(
        F.partial(sgd_opt, self.opt, lr, self.momentum),
        gradients, params, accum, stat)
    return success
def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay,
                                param, m, v, gradient, decay_flag):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations.
            Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations.
            Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability.
            Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        global_step (Tensor): Global step.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether the parameter is updated with weight decay.

    Returns:
        Tensor, the new value of v after updating.
    """
    op_mul = P.Mul()
    op_square = P.Square()
    op_cast = P.Cast()
    op_shape = P.Shape()
    op_pow = P.Pow()
    op_norm = layer.Norm()
    op_fill = P.Fill()
    op_dtype = P.DType()

    param_fp32 = op_cast(param, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    i6_ex = op_cast(global_step + num_one, mstype.float32)
    i9 = op_cast(num_one, mstype.float32) - beta1
    x1 = op_cast(num_one, mstype.float32) - beta2
    i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex)
    i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex)
    i1 = op_square(gradient_fp32)
    add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param,
                                  beta1, i9, beta2, x1, weight_decay, eps)

    if decay_flag:
        update = update + op_mul(weight_decay, param_fp32)

    w_norm = op_norm(param_fp32)
    g_norm = op_norm(gradient_fp32)
    g_norm_hat = op_norm(add3)
    zeros = F.zeros_like(w_norm)
    ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
    tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0)

    next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update,
                                      param, zeros, ones, tens)
    next_v = F.control_depend(add3, next_param)
    return next_v
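# A minimal NumPy sketch of the per-parameter step that the fused LambNextMV /
# LambUpdateWithLR kernels above compute in graph-kernel form: Adam-style moment
# estimates with bias correction, then a trust-ratio scaling of the learning rate
# by the ratio of weight norm to update norm. This is the standard LAMB
# formulation, shown for orientation only; the fused kernels may differ in
# details such as the bounds applied to the trust ratio.
import numpy as np

def lamb_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay, step):
    m = beta1 * m + (1.0 - beta1) * grad             # 1st-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad      # 2nd-moment estimate
    m_hat = m / (1.0 - beta1 ** step)                # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param
    w_norm = np.linalg.norm(param)
    u_norm = np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return param - lr * trust_ratio * update, m, v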
def construct(self, gradients):
    step = self.min(self.global_step, self.decay_steps)
    p = step / self.decay_steps
    lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
    if self.warmup_flag:
        warmup_percent = self.global_step / self.warmup_steps
        warmup_lr = self.start_learning_rate * warmup_percent
        is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step),
                              mstype.float32)
        lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
    updated_velocity = self.hyper_map(
        F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                  self.weight_decay_tensor),
        self.params, self.moments1, self.moments2, gradients, self.decay_flag)
    added_global_step = self.global_step + self.one
    F.control_depend(lr, added_global_step)
    self.global_step = added_global_step
    return updated_velocity
def get_lr(self):
    """
    Get the learning rate of the current step.

    Returns:
        float, the learning rate of the current step.
    """
    lr = self.learning_rate
    if self.dynamic_lr:
        if self.is_group_lr:
            lr = ()
            for learning_rate in self.learning_rate:
                current_dynamic_lr = learning_rate(self.global_step)
                lr += (current_dynamic_lr,)
        else:
            lr = self.learning_rate(self.global_step)
        F.control_depend(lr, self.assignadd(self.global_step,
                                            self.global_step_increase_tensor))
    return lr
def get_lr(self):
    """
    Get the learning rate of the current step.

    Returns:
        float, the learning rate of the current step.
    """
    if self.is_group_lr:
        lr = self.learning_rate
        if self.dynamic_lr:
            lr = ()
            for i in range(self.param_length):
                current_dynamic_lr = self.gather(self.learning_rate[i],
                                                 self.global_step, 0)
                lr += (current_dynamic_lr,)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
    else:
        lr = self.learning_rate
        if self.dynamic_lr:
            lr = self.gather(self.learning_rate, self.global_step, 0)
            F.control_depend(lr, self.assignadd(self.global_step, 1))
    return lr
def construct(self, input_ids, input_mask, token_type_id, start_position,
              end_position, unique_id, is_impossible, sens=None):
    """BertSquad"""
    weights = self.weights
    init = self.alloc_status()
    loss = self.network(input_ids, input_mask, token_type_id, start_position,
                        end_position, unique_id, is_impossible)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id,
                                             start_position, end_position,
                                             unique_id, is_impossible,
                                             self.cast(scaling_sens, mstype.float32))
    clear_before_grad = self.clear_before_grad(init)
    F.control_depend(loss, init)
    self.depend_parameter_use(clear_before_grad, scaling_sens)
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                           grads)
    if self.reducer_flag:
        grads = self.grad_reducer(grads)
    flag = self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    if self.is_distributed:
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    F.control_depend(grads, flag)
    F.control_depend(flag, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return F.depend(ret, succ)
def construct(self, input_ids, input_mask, label_ids, sens=None):
    """Bert Finetune"""
    weights = self.weights
    init = False
    loss = self.network(input_ids, input_mask, label_ids)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens

    if not self.gpu_target:
        init = self.alloc_status()
        clear_before_grad = self.clear_before_grad(init)
        F.control_depend(loss, init)
        self.depend_parameter_use(clear_before_grad, scaling_sens)
    grads = self.grad(self.network, weights)(input_ids, input_mask, label_ids,
                                             self.cast(scaling_sens, mstype.float32))
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                           grads)
    if self.reducer_flag:
        grads = self.grad_reducer(grads)
    if not self.gpu_target:
        flag = self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        F.control_depend(grads, flag)
        F.control_depend(flag, flag_sum)
    else:
        flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
        flag_sum = self.addn(flag_sum)
        flag_sum = self.reshape(flag_sum, (()))
    if self.is_distributed:
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return F.depend(ret, succ)
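# A minimal sketch (plain Python, hypothetical class name `DynamicLossScaler`) of
# the dynamic loss-scale policy driven by the overflow flag computed above: on
# overflow the step is skipped and the scale is reduced; after a run of clean
# steps the scale is increased. The cells above delegate this decision to
# self.loss_scaling_manager; the factor and window values below are common
# defaults, not taken from this code.
class DynamicLossScaler:
    def __init__(self, init_scale=2.0 ** 16, factor=2.0, window=1000):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.window == 0:
                self.scale *= self.factor
        return overflow  # True means: skip the optimizer step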
def construct(self, input_ids, input_mask, token_type_id, label_ids, sens=None):
    """Defines the computation performed."""
    weights = self.weights
    saved = ()
    for i in range(self.length):
        saved = saved + (F.assign(self.saved_params[i], weights[i]),)

    assign_embedding = ()
    for i in range(self.quant_embedding_list_length):
        quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
        assign_embedding = assign_embedding + (F.assign(
            weights[self.quant_embedding_list[i]], quant_embedding),)
        F.control_depend(saved, assign_embedding[i])

    assign_weight = ()
    for i in range(self.quant_weight_list_length):
        quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
        assign_weight = assign_weight + (F.assign(
            weights[self.quant_weight_list[i]], quant_weight),)
        F.control_depend(saved, assign_weight[i])

    for i in range(self.quant_embedding_list_length):
        F.control_depend(assign_embedding[i], input_ids)
    for i in range(self.quant_weight_list_length):
        F.control_depend(assign_weight[i], input_ids)

    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # alloc status and clear should be right before grad operation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id,
                                             label_ids,
                                             self.cast(scaling_sens, mstype.float32))
    F.control_depend(input_ids, grads)
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(
        F.partial(grad_scale, scaling_sens * self.degree), grads)
    grads = self.hyper_map(
        F.partial(clip_grad, gradient_cfg.clip_type, gradient_cfg.clip_value), grads)

    restore = ()
    for i in range(self.length):
        restore = restore + (F.assign(weights[i], self.saved_params[i]),)
        F.control_depend(grads, restore[i])

    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    if self.is_distributed:
        # sum overflow flag over devices
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    for i in range(self.length):
        F.control_depend(restore[i], succ)
    return succ
def construct(self, input_ids, input_mask, token_type_id, next_sentence_labels,
              masked_lm_positions, masked_lm_ids, masked_lm_weights, sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids, input_mask, token_type_id, next_sentence_labels,
                        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # update accumulation parameters
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
    self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
    self.loss = self.select(is_accu_step, self.loss + loss, loss)
    mean_loss = self.loss / self.local_step
    is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

    # alloc status and clear should be right before grad operation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id,
                                             next_sentence_labels, masked_lm_positions,
                                             masked_lm_ids, masked_lm_weights,
                                             self.cast(scaling_sens, mstype.float32))

    accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
    mean_loss = F.depend(mean_loss, accu_succ)

    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    overflow = self.less_equal(self.base, flag_sum)
    overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
    accu_overflow = self.select(overflow, self.one, self.zero)
    self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)

    if is_accu_step:
        succ = False
    else:
        # apply grad reducer on grads
        grads = self.grad_reducer(self.accu_grads)
        scaling = scaling_sens * self.degree * self.accumulation_steps
        grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
        if self.enable_global_norm:
            grads = C.clip_by_global_norm(grads, 1.0, None)
        else:
            grads = self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        accu_overflow = self.overflow_reducer(accu_overflow)
        F.control_depend(grads, accu_overflow)
        overflow = self.less_equal(self.base, accu_overflow)
        accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
        overflow = F.depend(overflow, accu_succ)
        overflow = self.reshape(overflow, (()))
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, overflow)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)

    ret = (mean_loss, overflow, scaling_sens)
    return F.depend(ret, succ)
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
                         target, beta1_power, beta2_power, beta1, beta2, eps, lr,
                         gradient, param, m, v, ps_parameter):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr),
                  op_shape(beta1), op_shape(beta2), op_shape(eps),
                  op_shape(values), op_shape(indices))
        success = F.depend(
            success,
            pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, values, indices),
                      shapes), param))
        return success

    if not target:
        success = F.depend(
            success,
            sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                       eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m, grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
        next_v = scatter_add(v, grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2,
                                    op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m, op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)
            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
        next_param = param - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))
    return success
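# A minimal NumPy sketch of the host-side branch above (the non-nesterov,
# non-parameter-server path): the moments are first decayed in place, only the
# rows named by the sparse gradient's indices receive scatter-adds, and the bias
# correction is folded into the learning rate (lr_t). `sparse_adam_step` is a
# hypothetical helper for orientation only, not the framework implementation.
import numpy as np

def sparse_adam_step(param, m, v, indices, values, lr, beta1, beta2, eps,
                     beta1_power, beta2_power):
    m *= beta1
    v *= beta2
    np.add.at(m, indices, (1.0 - beta1) * values)            # scatter-add into m
    np.add.at(v, indices, (1.0 - beta2) * values * values)   # scatter-add into v
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    param -= lr_t * m / (np.sqrt(v) + eps)
    return param, m, v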
def construct(self, input_ids, input_mask, token_type_id, label_ids):
    """Defines the computation performed."""
    weights = self.weights
    saved = ()
    for i in range(self.length):
        saved = saved + (F.assign(self.saved_params[i], weights[i]),)

    assign_embedding = ()
    for i in range(self.quant_embedding_list_length):
        quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
        assign_embedding = assign_embedding + (F.assign(
            weights[self.quant_embedding_list[i]], quant_embedding),)
        F.control_depend(saved, assign_embedding[i])

    assign_weight = ()
    for i in range(self.quant_weight_list_length):
        quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
        assign_weight = assign_weight + (F.assign(
            weights[self.quant_weight_list[i]], quant_weight),)
        F.control_depend(saved, assign_weight[i])

    for i in range(self.quant_embedding_list_length):
        F.control_depend(assign_embedding[i], input_ids)
    for i in range(self.quant_weight_list_length):
        F.control_depend(assign_weight[i], input_ids)

    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id,
                                             label_ids,
                                             self.cast(F.tuple_to_array((self.sens,)),
                                                       mstype.float32))
    F.control_depend(input_ids, grads)
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(
        F.partial(clip_grad, gradient_cfg.clip_type, gradient_cfg.clip_value), grads)

    restore = ()
    for i in range(self.length):
        restore = restore + (F.assign(weights[i], self.saved_params[i]),)
        F.control_depend(grads, restore[i])

    succ = self.optimizer(grads)
    for i in range(self.length):
        F.control_depend(restore[i], succ)
    return succ