def construct(self, gradients):
    """Apply the momentum update to each parameter."""
    params = self.params
    moments = self.moments
    if self.weight_decay > 0:
        gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf,
                                   params, gradients)
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, self.one))
    else:
        lr = self.learning_rate
    success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum),
                             gradients, params, moments)
    return success
def construct(self, grads):
    """Apply the FTRL update to each parameter."""
    params = self.parameters
    moments = self.moments
    linear = self.linear
    lr = self.learning_rate
    if self.weight_decay > 0.0:
        grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf,
                               params, grads)
    grads = self.scale_grad(grads)
    success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2,
                                  self.lr_power),
                        linear, grads, params, moments)
    return success
def construct(self,
              input_ids,
              input_mask,
              token_type_id,
              next_sentence_labels,
              masked_lm_positions,
              masked_lm_ids,
              masked_lm_weights,
              sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids,
                        input_mask,
                        token_type_id,
                        next_sentence_labels,
                        masked_lm_positions,
                        masked_lm_ids,
                        masked_lm_weights)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # alloc status and clear should be right before gradoperation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             next_sentence_labels,
                                             masked_lm_positions,
                                             masked_lm_ids,
                                             masked_lm_weights,
                                             self.cast(scaling_sens, mstype.float32))
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    if self.is_distributed:
        # sum overflow flag over devices
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond, scaling_sens)
    return F.depend(ret, succ)
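# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the cells above): the core idea of a
# loss-scaled training step in plain NumPy. The helper names
# `unscale_and_check` and `apply_update` are hypothetical; the real cells rely
# on device status registers and `loss_scaling_manager` rather than
# `np.isfinite`.
import numpy as np

def unscale_and_check(grads, scaling_sens):
    """Divide the scaled gradients by the loss scale and flag any inf/NaN."""
    unscaled = [g / scaling_sens for g in grads]
    overflow = any(not np.all(np.isfinite(g)) for g in unscaled)
    return unscaled, overflow

def loss_scaled_step(grads, scaling_sens, apply_update):
    """Skip the parameter update when the unscaled gradients overflowed."""
    unscaled, overflow = unscale_and_check(grads, scaling_sens)
    if not overflow:
        apply_update(unscaled)
    return overflow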
def construct(self, grads):
    """Apply the Adagrad update to each parameter."""
    params = self.parameters
    accum = self.accum
    grads = self.decay_weight(grads)
    grads = self.scale_grad(grads)
    grads = self.gradients_centralization(grads)
    lr = self.get_lr()
    if self.is_group_lr:
        success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum, grads)
    else:
        success = self.map_(F.partial(_ada_grad_opt, self.opt, lr), params, accum, grads)
    return success
def construct(self, input_ids, input_mask, token_type_id, label_ids, sens=None):
    """Run one fine-tuning step with loss scaling and overflow detection."""
    weights = self.weights
    init = self.alloc_status()
    loss = self.network(input_ids, input_mask, token_type_id, label_ids)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, label_ids,
                                             self.cast(scaling_sens, mstype.float32))
    clear_before_grad = self.clear_before_grad(init)
    F.control_depend(loss, init)
    self.depend_parameter_use(clear_before_grad, scaling_sens)
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    if self.reducer_flag:
        grads = self.grad_reducer(grads)
    flag = self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    if self.is_distributed:
        flag_reduce = self.allreduce(flag_sum)
        cond = self.less_equal(self.base, flag_reduce)
    else:
        cond = self.less_equal(self.base, flag_sum)
    F.control_depend(grads, flag)
    F.control_depend(flag, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond)
    return F.depend(ret, succ)
def construct(self, gradients):
    """Precondition gradients with the matrix_A/matrix_G factors, then apply the momentum update."""
    params = self.params
    moments = self.moments
    gradients = self.scale_grad(gradients)
    new_grads = ()
    if self.skfac:
        for i in range(54):
            g = gradients[i * 3]
            g_shape = self.shape(g)
            g = self.reshape(g, (g_shape[0], -1))
            matrix_A = self.matrix_A[i]
            matrix_G = self.matrix_G[i]
            g = self.matmul(self.matmul(matrix_G, g), matrix_A)
            fake_A = self.assign(self.matrix_A[i], matrix_A)
            fake_G = self.assign(self.matrix_G[i], matrix_G)
            g = F.depend(g, fake_A)
            g = F.depend(g, fake_G)
            if i == 53:
                new_grads = new_grads + (g,)
            else:
                g = self.reshape(g, g_shape)
                new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
    else:
        for i in range(54):
            g = gradients[i * 3]
            g_shape = self.shape(g)
            g = self.reshape(g, (g_shape[0], -1))
            matrix_A = self.matrix_A[i]
            matrix_G = self.matrix_G[i]
            matrix_A = F.depend(matrix_A, g)
            matrix_G = F.depend(matrix_G, g)
            g = self.matmul(self.matmul(matrix_G, g), matrix_A)
            if i == 53:
                new_grads = new_grads + (g,)
            else:
                g = self.reshape(g, g_shape)
                new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
    gradients = new_grads
    if self.weight_decay > 0:
        gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
                                   params, gradients)
    lr = self.get_lr()
    success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr),
                             gradients, params, moments)
    return success
def construct(self, gradients):
    """Apply LARS scaling to the gradients and run the wrapped optimizer."""
    params = self.parameters
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    else:
        lr = self.learning_rate
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
    grad_t = self.hyper_map(F.partial(lars_opt, self.lars, self.weight_decay, lr),
                            gradients, params, self.decay_flag, self.lars_flag)
    success = self.opt(grad_t)
    return success
def construct(self, grads):
    """Apply the ProximalAdagrad update to each parameter."""
    params = self.parameters
    accum = self.accum
    grads = self.decay_weight(grads)
    grads = self.scale_grad(grads)
    lr = self.get_lr()
    if self.is_group_lr:
        success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt,
                                      self.l1, self.l2),
                            lr, grads, params, accum)
    else:
        success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt,
                                      self.l1, self.l2, lr),
                            grads, params, accum)
    return success
def construct(self, grads):
    """Apply the FTRL update to each parameter."""
    params = self.parameters
    moments = self.moments
    linear = self.linear
    if self.weight_decay > 0.0:
        grads = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_tf,
                               params, grads)
    if self.reciprocal_scale != 1.0:
        grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
    lr = self.learning_rate
    success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
                             linear, grads, params, moments)
    return success
def construct(self, data, label):
    """Run one training step with gradient norm-based scaling before the optimizer update."""
    weights = self.weights
    loss = self.network(data, label)
    sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(data, label, sens)
    norm = self.hyper_map(F.partial(compute_norm), grads)
    norm = self.concat(norm)
    norm = self.norm(norm)
    cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
    clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
    grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
    return F.depend(loss, self.optimizer(grads))
def construct(self, gradients):
    """Apply the momentum update to each parameter."""
    params = self.params
    moments = self.moments
    gradients = self.decay_weight(gradients)
    gradients = self.scale_grad(gradients)
    lr = self.get_lr()
    if self.is_group_lr:
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum),
                                 lr, gradients, params, moments, self.ps_parameters)
    else:
        success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr),
                                 gradients, params, moments, self.ps_parameters)
    return success
def construct(self, grads):
    """construct of DistributedGradReducerThor"""
    # In some circumstances, the data precision of grads could be mixed with float16 and
    # float32. Thus, the result of AllReduce is unreliable. To solve the problem, grads should
    # be cast to float32 before AllReduce, and cast back after the operation.
    datatypes = self.hyper_map(F.partial(_get_datatype), grads)
    grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
    if self.mean:
        new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
    else:
        new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
    new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
    return new_grad
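# ---------------------------------------------------------------------------
# Illustrative sketch only: the cast-to-float32-before-reduce pattern used by
# DistributedGradReducerThor, shown with NumPy arrays in place of device
# tensors. `fake_allreduce` is a hypothetical stand-in for the collective.
import numpy as np

def fake_allreduce(grad):
    # Stand-in for the collective; a real AllReduce sums the tensor across devices.
    return grad

def reduce_preserving_dtype(grads, mean=True, degree=2):
    dtypes = [g.dtype for g in grads]                         # remember original dtypes
    grads32 = [g.astype(np.float32) for g in grads]           # cast up for a reliable reduce
    reduced = [fake_allreduce(g) for g in grads32]
    if mean:
        reduced = [g / degree for g in reduced]               # average over devices
    return [g.astype(dt) for g, dt in zip(reduced, dtypes)]   # cast back afterwards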
def construct(self, gradients):
    """Apply the SGD update to each parameter."""
    params = self.parameters
    accum = self.accum
    stat = self.stat
    gradients = self.scale_grad(gradients)
    lr = self.get_lr()
    if self.is_group_lr:
        success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum),
                                 lr, gradients, params, accum, stat)
    else:
        success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum, lr),
                                 gradients, params, accum, stat)
    return success
def construct(self, x):
    """Clip the input tensors by global norm."""
    square_sum = self.hyper_map(get_square_sum, x)
    global_norm = F.sqrt(F.addn(square_sum))
    cond = self.greater_equal(global_norm, self.clip_norm)
    global_norm = F.select(cond, global_norm, self.clip_norm)
    clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
    return clip_x
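# ---------------------------------------------------------------------------
# Illustrative sketch only: the arithmetic that global-norm clipping performs,
# in NumPy. Each tensor is scaled by clip_norm / max(global_norm, clip_norm),
# so the gradients are untouched when the global norm is already below the
# threshold. `clip_by_global_norm_np` is a hypothetical helper for this sketch.
import numpy as np

def clip_by_global_norm_np(tensors, clip_norm=1.0):
    global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in tensors))
    scale = clip_norm / max(global_norm, clip_norm)
    return [t * scale for t in tensors]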
def construct(self, input_ids, input_mask, token_type_id, label_ids):
    """Defines the computation performed."""
    weights = self.weights
    for i in range(self.length):
        F.assign(self.saved_params[i], weights[i])

    for i in range(self.quant_embedding_list_length):
        quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
        F.assign(weights[self.quant_embedding_list[i]], quant_embedding)

    for i in range(self.quant_weight_list_length):
        quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
        F.assign(weights[self.quant_weight_list[i]], quant_weight)

    grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, label_ids,
                                             self.cast(F.tuple_to_array((self.sens,)),
                                                       mstype.float32))
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)

    for i in range(self.length):
        param = F.depend(self.saved_params[i], grads)
        F.assign(weights[i], param)

    succ = self.optimizer(grads)
    return succ
def construct(self, gradients):
    """Accumulate gradients and apply the averaged update every apply_period steps."""
    # TODO: perform all_reduce
    # gradients = self._map(self._all_reduce, gradients)
    self.acc_step = self.acc_step + 1
    q = self.mod_op(self.acc_step, self.apply_period)
    # log_gradient = True
    log_gradient = False
    if log_gradient:
        gradients = self.hyper_map(log_tensor, gradients)
    accu_grads = self.hyper_map(add_grads, self.accu_grads, gradients)
    accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
    if q == 0:
        mean_grads = self.hyper_map(F.partial(grad_scale, self.apply_period), accu_grads)
        apply_succ = super(CumulativeSGDOptimizer, self).construct(mean_grads)
        reset_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
        succ = F.depend(reset_succ, apply_succ)
    else:
        succ = True
    succ = F.depend(succ, accu_succ)
    return F.depend(gradients, succ)
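# ---------------------------------------------------------------------------
# Illustrative sketch only: the accumulate-then-apply logic of the cumulative
# optimizer above, in plain Python. `GradAccumulator` and `apply_fn` are
# hypothetical names for this sketch; the real cell stores the running sums in
# `self.accu_grads` Parameters.
import numpy as np

class GradAccumulator:
    def __init__(self, apply_period, apply_fn):
        self.apply_period = apply_period
        self.apply_fn = apply_fn          # callback standing in for the wrapped optimizer
        self.acc_step = 0
        self.accu_grads = None

    def step(self, gradients):
        self.acc_step += 1
        if self.accu_grads is None:
            self.accu_grads = [np.zeros_like(g) for g in gradients]
        self.accu_grads = [a + g for a, g in zip(self.accu_grads, gradients)]
        if self.acc_step % self.apply_period == 0:
            mean_grads = [a / self.apply_period for a in self.accu_grads]
            self.apply_fn(mean_grads)                             # apply the averaged update
            self.accu_grads = [np.zeros_like(a) for a in self.accu_grads]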
def construct(self, gradients):
    """Apply the AdamWeightDecay update to each parameter."""
    lr = self.get_lr()
    updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps,
                                          lr, self.weight_decay_tensor),
                                self.params, self.moments1, self.moments2, gradients,
                                self.decay_flag)
    return updated_velocity
def construct(self, data, label):
    """
    construct a compute flow.
    """
    weights = self.weights
    record_datas = self._split(data)
    record_labels = self._split(label)
    loss = self.network(record_datas[0], record_labels[0])
    sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
    record_grad = self.grad(self.network, weights)(record_datas[0], record_labels[0], sens)
    record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
    grads = record_grad
    total_loss = loss
    for i in range(1, self._micro_batches):
        loss = self.network(record_datas[i], record_labels[i])
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        record_grad = self.grad(self.network, weights)(record_datas[i], record_labels[i], sens)
        record_grad = self._clip_by_global_norm(record_grad, GRADIENT_CLIP_TYPE, self._l2_norm)
        grads = self._tuple_add(grads, record_grad)
        total_loss = P.TensorAdd()(total_loss, loss)
    loss = P.Div()(total_loss, self._micro_float)

    if self._mech is not None:
        grad_noise = self._hyper_map(self._mech, grads)
        grads = self._tuple_add(grads, grad_noise)
        grads = self._hyper_map(F.partial(_grad_scale, self._micro_float), grads)

    if self.reducer_flag:
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
    return F.depend(loss, self.optimizer(grads))
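# ---------------------------------------------------------------------------
# Illustrative sketch only: the clip / sum / add-noise / average pattern of the
# differentially private step above, in NumPy. `grad_fn`, `noise_std`, and
# `dp_micro_batch_grads` are hypothetical for this sketch; the real cell draws
# noise from `self._mech` and clips via `self._clip_by_global_norm`.
import numpy as np

def dp_micro_batch_grads(grad_fn, micro_batches, l2_norm=1.0, noise_std=0.1, seed=0):
    rng = np.random.default_rng(seed)
    summed = None
    for batch in micro_batches:
        grads = grad_fn(batch)                                # per-micro-batch gradients
        norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
        scale = min(1.0, l2_norm / (norm + 1e-12))            # bound each micro-batch's norm
        grads = [g * scale for g in grads]
        summed = grads if summed is None else [s + g for s, g in zip(summed, grads)]
    noised = [g + rng.normal(0.0, noise_std, size=g.shape) for g in summed]
    return [g / len(micro_batches) for g in noised]           # average after adding noise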
def construct(self, gradients):
    """Apply the RMSProp update to each parameter."""
    params = self.parameters
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, self.one))
    else:
        lr = self.learning_rate
    if self.centered:
        success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay,
                                           self.epsilon, self.momentum),
                                 params, self.mg, self.ms, self.moment, gradients)
    else:
        success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay,
                                           self.epsilon, self.momentum),
                                 params, self.ms, self.moment, gradients)
    return success
def construct(self, gradients):
    """Apply the SGD update to each parameter."""
    params = self.params
    accum = self.accum
    stat = self.stat
    if self.reciprocal_scale != 1.0:
        gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
    if self.dynamic_lr:
        lr = self.gather(self.learning_rate, self.global_step, self.axis)
        F.control_depend(lr, self.assignadd(self.global_step, 1))
    else:
        lr = self.learning_rate
    success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum),
                             gradients, params, accum, stat)
    return success
def construct(self, gradients):
    """Apply the RMSProp update to each parameter."""
    params = self.parameters
    gradients = self.decay_weight(gradients)
    gradients = self.scale_grad(gradients)
    lr = self.get_lr()
    if self.centered:
        success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay,
                                           self.epsilon, self.momentum),
                                 params, self.mg, self.ms, self.moment, gradients)
    else:
        success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay,
                                           self.epsilon, self.momentum),
                                 params, self.ms, self.moment, gradients)
    return success
def construct(self, img, label_indices, text, sequence_length, lab_len):
    """
    Cell's forward.

    Args:
        img: input image
        label_indices: got from the data generator
        text: label got from the data generator
        sequence_length: got from the data generator
        lab_len: got from the data generator

    Returns:
        loss: loss value
    """
    weights = self.weights
    loss = self.network(img, label_indices, text, sequence_length)
    scaling_sens = self.scale_sense
    grads = self.grad(self.network, weights)(img, label_indices, text, sequence_length,
                                             self.cast(scaling_sens, mstype.float32))
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
    grads = self.clip_gradients(grads, GRADIENT_CLIP_MIN, GRADIENT_CLIP_MAX)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
    success = self.optimizer(grads)
    ret = (loss, scaling_sens)
    return F.depend(ret, success)
def construct(self, *inputs):
    """Run one forward/backward step and add the gradients into grad_sum."""
    weights = self.weights
    loss = self.network(*inputs)
    sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*inputs, sens)
    return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
def construct(self,
              input_ids,
              input_mask,
              token_type_id,
              next_sentence_labels,
              masked_lm_positions,
              masked_lm_ids,
              masked_lm_weights):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids,
                        input_mask,
                        token_type_id,
                        next_sentence_labels,
                        masked_lm_positions,
                        masked_lm_ids,
                        masked_lm_weights)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             next_sentence_labels,
                                             masked_lm_positions,
                                             masked_lm_ids,
                                             masked_lm_weights,
                                             self.cast(F.tuple_to_array((self.sens,)),
                                                       mstype.float32))
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    grads = self.grad_reducer(grads)
    succ = self.optimizer(grads)
    return F.depend(loss, succ)
def construct(self, gradients):
    """Apply the AdamWeightDecay update to each parameter."""
    updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps,
                                                self.lr, self.weight_decay_tensor),
                                      self.params, self.moments1, self.moments2, gradients)
    return updated_velocity
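# ---------------------------------------------------------------------------
# Illustrative sketch only: the decoupled-weight-decay Adam step that an
# `adam_opt`-style helper applies per parameter, written in NumPy without bias
# correction. `adam_weight_decay_step` is a hypothetical name and the exact
# epsilon placement is an assumption for this sketch.
import numpy as np

def adam_weight_decay_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay):
    m = beta1 * m + (1.0 - beta1) * grad                      # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * np.square(grad)           # second-moment estimate
    update = m / (np.sqrt(v) + eps) + weight_decay * param    # decay added to the update
    return param - lr * update, m, v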
def construct(self, grads):
    """Clip gradients by global norm."""
    global_norm = self.global_norm(grads)
    cond = P.GreaterEqual()(global_norm, self.clip_norm)
    global_norm = F.select(cond, global_norm, self.clip_norm)
    grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
    return grads
def construct(self, x, sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(x)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    # alloc status and clear should be right before gradoperation
    init = self.alloc_status()
    self.clear_before_grad(init)
    grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    self.get_status(init)
    flag_sum = self.reduce_sum(init, (0,))
    cond = self.less_equal(self.base, flag_sum)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond, scaling_sens)
    return F.depend(ret, succ)
def construct(self,
              input_ids,
              input_mask,
              token_type_id,
              next_sentence_labels,
              masked_lm_positions,
              masked_lm_ids,
              masked_lm_weights,
              sens=None):
    """Defines the computation performed."""
    weights = self.weights
    loss = self.network(input_ids,
                        input_mask,
                        token_type_id,
                        next_sentence_labels,
                        masked_lm_positions,
                        masked_lm_ids,
                        masked_lm_weights)
    if sens is None:
        scaling_sens = self.loss_scale
    else:
        scaling_sens = sens
    status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
    grads = self.grad(self.network, weights)(input_ids,
                                             input_mask,
                                             token_type_id,
                                             next_sentence_labels,
                                             masked_lm_positions,
                                             masked_lm_ids,
                                             masked_lm_weights,
                                             self.cast(scaling_sens, mstype.float32))
    # apply grad reducer on grads
    grads = self.grad_reducer(grads)
    grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
    grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
    cond = self.get_overflow_status(status, grads)
    overflow = cond
    if sens is None:
        overflow = self.loss_scaling_manager(self.loss_scale, cond)
    if overflow:
        succ = False
    else:
        succ = self.optimizer(grads)
    ret = (loss, cond, scaling_sens)
    return F.depend(ret, succ)
def construct(self, grads):
    """Apply the momentum update to each parameter."""
    success = True
    weights = self.weights
    moments = self.moments
    success = self.hyper_map(F.partial(run_opt, self.opt, self.iter,
                                       self.learning_rate, self.momentum),
                             grads, weights, moments)
    return success
def construct(self, gradients):
    """Apply the LazyAdam update to each parameter."""
    gradients = self.decay_weight(gradients)
    gradients = self.scale_grad(gradients)
    lr = self.get_lr()

    self.beta1_power = self.beta1_power * self.beta1
    self.beta2_power = self.beta2_power * self.beta2

    if self.is_group_lr:
        success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                      self._ps_pull, self.beta1_power, self.beta2_power,
                                      self.beta1, self.beta2, self.eps),
                            lr, gradients, self.parameters, self.moment1, self.moment2,
                            self.ps_parameters)
    else:
        success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                      self._ps_pull, self.beta1_power, self.beta2_power,
                                      self.beta1, self.beta2, self.eps, lr),
                            gradients, self.parameters, self.moment1, self.moment2,
                            self.ps_parameters)
    return success
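# ---------------------------------------------------------------------------
# Illustrative sketch only: why the cell above keeps multiplying
# beta1_power/beta2_power each step. The running powers carry Adam's bias
# correction for the zero-initialized moment estimates; a common way to fold
# that correction into the step size is shown below. `bias_corrected_lr` is a
# hypothetical helper, not the fused operator's exact formula.
def bias_corrected_lr(lr, beta1_power, beta2_power):
    # Effective step size after correcting the first and second moment estimates.
    return lr * (1.0 - beta2_power) ** 0.5 / (1.0 - beta1_power)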