def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
             decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
             lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
    """
    Wrap an existing optimizer and set up the state needed for the LARS update.

    Args:
        optimizer: The wrapped optimizer. Its `parameters` and `learning_rate`
            attributes are copied onto this instance.
        epsilon: Forwarded as the first argument of `P.LARSUpdate`.
        hyperpara: Forwarded as the second argument of `P.LARSUpdate`.
        weight_decay: Stored on the instance after being multiplied by `loss_scale`.
        use_clip: Forwarded as the third argument of `P.LARSUpdate`.
        decay_filter: Predicate called once per parameter; the results are kept
            in `self.decay_flag`. The default rejects parameters whose name
            contains 'LayerNorm' or 'bias'.
        lars_filter: Predicate called once per parameter; the results are kept
            in `self.lars_flag`. Same default as `decay_filter`.
        loss_scale: Its reciprocal is stored as `self.reciprocal_scale`.
    """
    # The base class is initialised with a zero learning rate and a single
    # placeholder parameter; the real ones are taken from `optimizer` below.
    super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])

    self.opt = optimizer
    self.parameters = optimizer.parameters
    self.learning_rate = optimizer.learning_rate
    self.lars = P.LARSUpdate(epsilon, hyperpara, use_clip)
    self.reciprocal_scale = 1.0 / loss_scale
    self.weight_decay = weight_decay * loss_scale
    self.cast = P.Cast()

    # One boolean per parameter, in the same order as `self.parameters`.
    self.decay_flag = tuple(map(decay_filter, self.parameters))
    self.lars_flag = tuple(map(lars_filter, self.parameters))
    self.hyper_map = C.HyperMap()

    # Defaults for a fixed learning rate; overridden below for a dynamic one.
    self.dynamic_lr = False
    self.gather = None
    self.global_step = None
    self.axis = None

    # The learning rate counts as dynamic when its underlying value is an
    # iterable or a one-dimensional tensor.
    lr_value = self.learning_rate.default_input
    lr_is_schedule = isinstance(lr_value, Iterable) or (isinstance(lr_value, Tensor) and lr_value.dim() == 1)
    if lr_is_schedule:
        self.dynamic_lr = True
        self.assignadd = P.AssignAdd()
        self.gather = P.GatherV2()
        self.global_step = Parameter(initializer(0, [1], mstype.int32), name="lars_global_step")
        self.axis = 0
def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
             lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
    """
    Wrap an existing optimizer and take over its loss scaling and weight decay.

    Args:
        optimizer: The wrapped optimizer. Several of its attributes are copied
            onto this instance, and `reciprocal_scale`, `exec_weight_decay` and
            `weight_decay` are overwritten on it (see the comments below).
        epsilon: Forwarded as the first argument of `P.LARSUpdate`.
        coefficient: Forwarded as the second argument of `P.LARSUpdate`.
        use_clip: Forwarded as the third argument of `P.LARSUpdate`. Only when it
            is exactly `True` are the optimizer's learning-rate attributes reused.
        lars_filter: Predicate called once per parameter; the results are kept
            in `self.lars_flag`. The default rejects parameters whose name
            contains 'LayerNorm' or 'bias'.
    """
    # The base class is initialised with a zero learning rate and a single
    # placeholder parameter; the real parameters come from `optimizer`.
    super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
    _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)

    self.opt = optimizer
    self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
    self.cast = P.Cast()
    self.parameters = optimizer.parameters

    if use_clip is not True:
        # No clipping: a placeholder learning rate is installed instead of
        # the optimizer's own.
        self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
    else:
        # Clipping: reuse the wrapped optimizer's learning-rate state.
        self.learning_rate = optimizer.learning_rate
        self.dynamic_lr = optimizer.dynamic_lr
        self.gather = optimizer.gather
        self.assignadd = optimizer.assignadd
        self.global_step = optimizer.global_step

    # Remember the optimizer's gradient scale, then neutralise it on the
    # optimizer itself.
    self.reciprocal_scale = optimizer.reciprocal_scale
    optimizer.reciprocal_scale = 1.0

    # Weight decay is kept here, divided by the optimizer's loss scale;
    # grouped optimizers carry one value per entry.
    self.is_group = optimizer.is_group
    if self.is_group:
        self.weight_decay = tuple(decay / optimizer.loss_scale for decay in optimizer.weight_decay)
    else:
        self.weight_decay = optimizer.weight_decay / optimizer.loss_scale

    # Switch weight decay off on the wrapped optimizer.
    optimizer.exec_weight_decay = False
    optimizer.weight_decay = 0.0

    self.decay_flags = optimizer.decay_flags
    # One boolean per parameter, in the same order as `self.parameters`.
    self.lars_flag = tuple(map(lars_filter, self.parameters))
    self.hyper_map = C.HyperMap()
def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
             lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
    """
    Wrap an existing optimizer and take over its loss scaling and weight decay.

    Args:
        optimizer: The wrapped optimizer. Several of its attributes are copied
            onto this instance, and `weight_decay`, `decay_flags`,
            `reciprocal_scale` and `exec_weight_decay` are overwritten on it.
        epsilon: Forwarded as the first argument of `P.LARSUpdate`.
        coefficient: Forwarded as the second argument of `P.LARSUpdate`.
        use_clip: Forwarded as the third argument of `P.LARSUpdate`. When truthy,
            the optimizer's learning-rate state is also recorded on this instance.
        lars_filter: Predicate called once per parameter; the results are kept
            in `self.lars_flag`. The default rejects parameters whose name
            contains 'LayerNorm' or 'bias'.

    Raises:
        ValueError: If `use_clip` is truthy and the optimizer reports both
            `is_group_lr` and `dynamic_lr`.
    """
    # The base class is initialised with a zero learning rate and a single
    # placeholder parameter; the real parameters come from `optimizer`.
    super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
    _check_param_value(optimizer, epsilon, coefficient, use_clip, self.cls_name)

    self.opt = optimizer
    self.parameters = optimizer.parameters
    self.use_clip = use_clip
    # One boolean per parameter, in the same order as `self.parameters`.
    self.lars_flag = tuple(map(lars_filter, self.parameters))
    self.is_group = optimizer.is_group
    # This instance always carries a placeholder learning rate; the optimizer's
    # own one is recorded separately below when clipping is enabled.
    self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
    self.decay_flags = optimizer.decay_flags
    self.reciprocal_scale = optimizer.reciprocal_scale
    self.hyper_map = C.HyperMap()
    self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
    self.cast = P.Cast()

    if use_clip:
        self.is_group_lr = optimizer.is_group_lr
        self.dynamic_lr = optimizer.dynamic_lr
        self.origin_learning_rate = optimizer.learning_rate
        self.global_step = optimizer.global_step
        if self.is_group_lr and self.dynamic_lr:
            raise ValueError('Grouped dynamic learning rate is currently not supported for the inputs optimizer '
                             'of lars.')

    # Weight decay is kept here, divided by the optimizer's loss scale, and
    # zeroed on the optimizer; grouped optimizers carry one value per entry.
    if self.is_group:
        self.weight_decay = tuple(decay / optimizer.loss_scale for decay in optimizer.weight_decay)
        optimizer.weight_decay = tuple(0.0 for _ in optimizer.weight_decay)
    else:
        self.weight_decay = optimizer.weight_decay / optimizer.loss_scale
        optimizer.weight_decay = 0.0

    # Switch decay and gradient scaling off on the wrapped optimizer.
    optimizer.decay_flags = tuple(False for _ in self.decay_flags)
    optimizer.reciprocal_scale = 1.0
    optimizer.exec_weight_decay = False