Beispiel #1
0
def _check_param_value(beta1, beta2, eps, prim_name):
    """Validate the `beta1`, `beta2` and `eps` hyper-parameters.

    All three values are first type-checked against `float`. `beta1` and
    `beta2` are then range-checked between 0.0 and 1.0 using
    `Rel.INC_NEITHER`, and `eps` is checked with `check_positive_float`.

    Args:
        beta1: First coefficient to validate.
        beta2: Second coefficient to validate.
        eps: Epsilon term to validate.
        prim_name: Name handed to every validator call (used by the validator
            when reporting a failure).
    """
    # Type checks run first, in the order beta1, beta2, eps.
    for arg_name, arg_value in (("beta1", beta1), ("beta2", beta2),
                                ("eps", eps)):
        validator.check_value_type(arg_name, arg_value, [float], prim_name)

    # Range checks for the two beta coefficients follow.
    for arg_name, arg_value in (("beta1", beta1), ("beta2", beta2)):
        validator.check_float_range(arg_value, 0.0, 1.0, Rel.INC_NEITHER,
                                    arg_name, prim_name)

    validator.check_positive_float(eps, "eps", prim_name)
Beispiel #2
0
def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    r"""
    Build a per-step learning rate schedule that follows a cosine decay curve.

    The learning rate of the i-th step is:

    .. math::
        decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
        (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`, capped at `decay_epoch`.

    Args:
        min_lr (float): Lower bound of the learning rate.
        max_lr (float): Upper bound of the learning rate; must be greater than `min_lr`.
        total_step (int): Number of steps to generate.
        step_per_epoch (int): Number of steps that make up one epoch.
        decay_epoch (int): Number of epochs over which the rate decays.

    Returns:
        list[float]. One learning rate per step, `total_step` entries in total.

    Raises:
        TypeError: If `min_lr` is not a float.
        ValueError: If `min_lr` is not smaller than `max_lr`.

    Examples:
        >>> min_lr = 0.01
        >>> max_lr = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> output = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        >>> print(output)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    if not isinstance(min_lr, float):
        raise TypeError("min_lr must be float.")
    validator.check_non_negative_float(min_lr, "min_lr", None)
    validator.check_positive_float(max_lr, 'max_lr')
    validator.check_is_float(max_lr, 'max_lr')
    for arg_name, arg_value in (('total_step', total_step),
                                ('step_per_epoch', step_per_epoch),
                                ('decay_epoch', decay_epoch)):
        validator.check_positive_int(arg_value, arg_name)
    if max_lr <= min_lr:
        raise ValueError('`max_lr` should be greater than `min_lr`.')

    half_range = 0.5 * (max_lr - min_lr)

    def _rate_at(step):
        # The epoch index stops growing once the decay window has elapsed.
        epoch = min(math.floor(step / step_per_epoch), decay_epoch)
        return min_lr + half_range * (
            1 + math.cos(math.pi * epoch / decay_epoch))

    return [_rate_at(step) for step in range(total_step)]
Beispiel #3
0
 def __init__(self,
              max_val=1.0,
              power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
              filter_size=11,
              filter_sigma=1.5,
              k1=0.01,
              k2=0.03):
     """Validate the configuration and create the per-level operators.

     Args:
         max_val (int or float): Must be greater than 0.0. Default: 1.0.
         power_factors (tuple or list): One weight per level; its length
             determines how many convolutions are created.
         filter_size (int): Must be at least 1. Default: 11.
         filter_sigma (float): Must be a positive float. Default: 1.5.
         k1 (float): Stored as `self.k1`. Default: 0.01.
         k2 (float): Stored as `self.k2`. Default: 0.03.
     """
     super(MSSSIM, self).__init__()
     owner = self.cls_name

     validator.check_value_type('max_val', max_val, [int, float], owner)
     validator.check_number('max_val', max_val, 0.0, Rel.GT, owner)
     self.max_val = max_val

     validator.check_value_type('power_factors', power_factors,
                                [tuple, list], owner)
     self.filter_size = validator.check_int(filter_size, 1, Rel.GE,
                                            'filter_size', owner)
     self.filter_sigma = validator.check_positive_float(
         filter_sigma, 'filter_sigma', owner)
     self.k1 = validator.check_value_type('k1', k1, [float], owner)
     self.k2 = validator.check_value_type('k2', k2, [float], owner)

     # Every level gets its own convolution built from the same window, with
     # the convolution weight excluded from gradient computation.
     window = _create_window(filter_size, filter_sigma)
     self.level = len(power_factors)
     self.conv = []
     for _ in range(self.level):
         level_conv = _conv2d(1, 1, filter_size, Tensor(window))
         self.conv.append(level_conv)
         level_conv.weight.requires_grad = False
     self.multi_convs_list = CellList(self.conv)

     self.weight_tensor = Tensor(power_factors, mstype.float32)
     self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
     self.relu = ReLU()
     self.reduce_mean = P.ReduceMean()
     self.prod = P.ReduceProd()
     self.pow = P.Pow()
     self.pack = P.Pack(axis=-1)
     self.concat = P.Concat(axis=1)
Beispiel #4
0
    def __init__(self, smooth=1e-5):
        """Create the metric, validate `smooth` and reset the accumulators.

        Args:
            smooth: Value checked by `check_positive_float` and stored as
                `self.smooth`. Default: 1e-5.
        """
        super(Dice, self).__init__()

        checked_smooth = validator.check_positive_float(smooth, "smooth")
        self.smooth = checked_smooth
        # Both running totals start at zero before clear() is invoked.
        self._dice_coeff_sum = self._samples_num = 0
        self.clear()
Beispiel #5
0
    def predict_outlier(self, sample_x, threshold=100.0):
        """
        Decide whether a sample counts as an outlier.

        The score returned by `predict_outlier_score` is compared against the
        validated threshold.

        Args:
            sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W).
            threshold (float): Score at or above which the sample is reported as
                an outlier. Default: 100.0.

        Returns:
            Bool, whether the sample is an outlier.
        """
        # Validate first so that an invalid threshold fails before any scoring.
        checked_threshold = Validator.check_positive_float(threshold)
        return self.predict_outlier_score(sample_x) >= checked_threshold
Beispiel #6
0
 def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
     """Validate the configuration and create the filtering operators.

     Args:
         max_val (int or float): Must be greater than 0.0. Default: 1.0.
         filter_size (int): Must be at least 1. Default: 11.
         filter_sigma (float): Must be a positive float. Default: 1.5.
         k1 (float): Stored as `self.k1`. Default: 0.01.
         k2 (float): Stored as `self.k2`. Default: 0.03.
     """
     super(SSIM, self).__init__()
     owner = self.cls_name

     validator.check_value_type('max_val', max_val, [int, float], owner)
     validator.check_number('max_val', max_val, 0.0, Rel.GT, owner)
     self.max_val = max_val

     self.filter_size = validator.check_int(
         filter_size, 1, Rel.GE, 'filter_size', owner)
     self.filter_sigma = validator.check_positive_float(
         filter_sigma, 'filter_sigma', owner)
     self.k1 = validator.check_value_type('k1', k1, [float], owner)
     self.k2 = validator.check_value_type('k2', k2, [float], owner)

     # The convolution uses a fixed window; its weight is excluded from
     # gradient computation.
     kernel = Tensor(_create_window(filter_size, filter_sigma))
     self.conv = _conv2d(1, 1, filter_size, kernel)
     self.conv.weight.requires_grad = False

     self.reduce_mean = P.ReduceMean()
     self.concat = P.Concat(axis=1)
Beispiel #7
0
    def __init__(self,
                 params,
                 learning_rate=0.1,
                 decay=0.9,
                 momentum=0.0,
                 epsilon=1e-10,
                 use_locking=False,
                 centered=False,
                 loss_scale=1.0,
                 weight_decay=0.0):
        """Validate the hyper-parameters and create the optimizer state.

        Args:
            params: Parameters forwarded to the base optimizer.
            learning_rate: Learning rate forwarded to the base optimizer.
                Default: 0.1.
            decay (float): Non-negative float. Default: 0.9.
            momentum (float): Non-negative float. Default: 0.0.
            epsilon (float): Positive float. Default: 1e-10.
            use_locking (bool): Passed to the underlying apply operator.
                Default: False.
            centered (bool): Selects `ApplyCenteredRMSProp` and an extra
                "mean_grad" state when True, `ApplyRMSProp` otherwise.
                Default: False.
            loss_scale: Forwarded to the base optimizer. Default: 1.0.
            weight_decay: Forwarded to the base optimizer. Default: 0.0.
        """
        super(RMSProp, self).__init__(learning_rate, params, weight_decay,
                                      loss_scale)
        owner = self.cls_name

        # decay and momentum share the same pair of checks, in that order.
        for hp_name, hp_value in (("decay", decay), ("momentum", momentum)):
            validator.check_value_type(hp_name, hp_value, [float], owner)
            validator.check_non_negative_float(hp_value, hp_name, owner)
        validator.check_value_type("epsilon", epsilon, [float], owner)
        validator.check_positive_float(epsilon, "epsilon", owner)
        for flag_name, flag_value in (("use_locking", use_locking),
                                      ("centered", centered)):
            validator.check_value_type(flag_name, flag_value, [bool], owner)

        self.centered = centered
        if not centered:
            self.opt = P.ApplyRMSProp(use_locking)
        else:
            self.opt = P.ApplyCenteredRMSProp(use_locking)
            self.mg = self.parameters.clone(prefix="mean_grad", init='zeros')

        self.momentum = momentum
        self.ms = self.parameters.clone(prefix="mean_square", init='ones')
        self.moment = self.parameters.clone(prefix="moment", init='zeros')
        self.hyper_map = C.HyperMap()
        self.epsilon = epsilon
        self.decay = decay
Beispiel #8
0
    def __init__(self,
                 learning_rate,
                 parameters,
                 weight_decay=0.0,
                 loss_scale=1.0):
        """Set up the parameters, learning rate, weight decay and loss scale of the optimizer.

        Args:
            learning_rate: Learning rate specification. It is normalised by
                `_preprocess_single_lr` and, unless a per-group learning rate is
                in effect, turned into `self.learning_rate` by `_build_single_lr`.
            parameters: Iterable of `Parameter` objects or of dicts describing
                parameter groups. Only the type of the first element is inspected.
            weight_decay: Weight decay specification, normalised by
                `_preprocess_weight_decay`. Default: 0.0.
            loss_scale (int or float): Positive loss scale; an int is converted
                to float before validation. Default: 1.0.

        Raises:
            ValueError: If `parameters` is None or empty.
            TypeError: If the first element of `parameters` is neither a dict
                nor a `Parameter`.
            RuntimeError: If the parallel optimizer is enabled but the parallel
                mode, device target, optimizer class or parameter count is not
                accepted by the checks at the end of this method.
        """
        super(Optimizer, self).__init__(auto_prefix=False)
        # Materialise any non-list iterable so it can be indexed below.
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        # Only the first element decides between plain and grouped parameters.
        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError(
                "Only a list of Parameter or dict can be supported.")

        # Integer loss scales are accepted by promoting them to float.
        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float],
                                   self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)

        # Defaults; some of these flags are expected to be updated by the
        # helper methods called below (their bodies are not shown here).
        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            # Grouped mode: the group_* lists are handed to _init_group_params,
            # which presumably fills them - confirm in that method.
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self._init_group_params(parameters, learning_rate, weight_decay)

        # The final value of dynamic_lr is only known once the single learning rate has been
        # preprocessed and the group parameters have been initialised.
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32),
                                         name='global_step')

        # Per-group learning rates are stored as a CellList when dynamic and as
        # a ParameterTuple otherwise; a single learning rate is built by helper.
        if self.is_group_lr:
            if self.dynamic_lr:
                self.learning_rate = CellList(self.group_lr)
            else:
                self.learning_rate = ParameterTuple(self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate,
                                                       'learning_rate')
        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(
                Tensor(x, mstype.float32) for x in self.group_weight_decay)
            # One flag per weight decay entry: decay applies where the value is positive.
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(
                decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
        else:
            self.parameters = ParameterTuple(parameters)
            # In the non-grouped case the stored weight decay is multiplied by the loss scale.
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay,
                                              mstype.float32)
            # Parameters whose name contains 'beta' or 'gamma' are flagged as not decayed.
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0
        # When any parameter is already marked unique, the optimizer does not need to apply unique again.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break
        # Per-parameter flags read from the `is_param_ps` and `cache_enable` attributes.
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        ps_cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        # Scaling is only flagged as needed when the loss scale differs from 1.0.
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        # Parallel optimizer: accepted only for data parallel mode on Ascend;
        # other data parallel targets, stand-alone and hybrid parallel modes raise.
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            if _get_parallel_mode(
            ) == ParallelMode.DATA_PARALLEL and context.get_context(
                    "device_target") == "Ascend":
                self.use_parallel = True
            elif _get_parallel_mode() == ParallelMode.DATA_PARALLEL \
                    and context.get_context("device_target") != "Ascend":
                raise RuntimeError(
                    "Parallel optimizer only supports Ascend in data parallel mode."
                )
            elif _get_parallel_mode() in (ParallelMode.STAND_ALONE,
                                          ParallelMode.HYBRID_PARALLEL):
                raise RuntimeError(
                    "Parallel optimizer is not supported in {}.".format(
                        _get_parallel_mode()))
            else:
                self.use_parallel = False
        else:
            self.use_parallel = False
        if self.use_parallel:
            # Only the Lamb and AdamWeightDecay optimizer classes are accepted here.
            if self.cls_name not in ["Lamb", "AdamWeightDecay"]:
                raise RuntimeError(
                    "Parallel optimizer does not support optimizer {}".format(
                        self.cls_name))
            self.dev_num = _get_device_num()
            if self.dev_num > self.param_length:
                raise RuntimeError(
                    "Parallel optimizer can not be applied when the number of parameters {} is"
                    " less than the number of devices {}".format(
                        self.param_length, self.dev_num))
            self.param_rank = self._get_parameter_group_id()
            # A parameter is selected when its group id equals the global rank.
            self.optim_filter = tuple(
                map(lambda x: x == _get_global_rank(), self.param_rank))
            self.param_names = []
            for param in self.parameters:
                self.param_names.append(param.name)

        else:
            # Without the parallel optimizer every parameter is selected.
            self.optim_filter = (True, ) * self.param_length
Beispiel #9
0
 def __init__(self,
              vocab_size,
              embedding_size,
              param_init='normal',
              target='CPU',
              slice_mode='batch_slice',
              manual_shapes=None,
              max_norm=None,
              sparse=True,
              vocab_cache_size=0):
     """Validate the configuration, create the embedding table and set up the lookup operators.

     Args:
         vocab_size (int): Positive int; first dimension of the embedding table.
         embedding_size (int): Positive int; second dimension of the embedding table.
         param_init: Initialisation spec handed to `initializer` for the table.
             Default: 'normal'.
         target (str): Either 'CPU' or 'DEVICE'. Default: 'CPU'.
         slice_mode (str): One of "field_slice", "table_row_slice",
             "table_column_slice" or "batch_slice"; only acted upon in
             semi-auto or auto parallel mode. Default: 'batch_slice'.
         manual_shapes (tuple): Tuple of positive ints, required when
             `slice_mode` is "field_slice" in (semi-)auto parallel mode.
             Default: None.
         max_norm (float): Optional positive float, stored as a float32 Tensor.
             Default: None.
         sparse (bool): Selects `SparseGatherV2` when True and `Gather`
             otherwise. Default: True.
         vocab_cache_size (int): Non-negative int; a value above zero enables
             the cache. Default: 0.

     Raises:
         ValueError: If `target` is not 'CPU' or 'DEVICE'; if `sparse` is False
             while `target` is 'CPU'; if `manual_shapes` is missing in field
             slice mode; if `slice_mode` is not supported in (semi-)auto
             parallel mode; or if the cache is enabled without parameter server
             outside stand-alone mode.
         TypeError: If `manual_shapes` is not a tuple in field slice mode.
     """
     super(EmbeddingLookup, self).__init__()
     validator.check_value_type('sparse', sparse, [bool], self.cls_name)
     self.vocab_size = validator.check_positive_int(vocab_size,
                                                    'vocab_size')
     self.vocab_cache_size = validator.check_non_negative_int(
         vocab_cache_size, 'vocab_cache_size')
     self.target = target
     self.sparse = sparse
     # The cache is considered enabled as soon as a positive cache size is given.
     self.cache_enable = self.vocab_cache_size > 0
     self.forward_unique = False
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     if not sparse and target == 'CPU':
         raise ValueError(
             'When target is CPU, embedding_lookup must be sparse.')
     if sparse:
         self.gatherv2 = P.SparseGatherV2()
     else:
         self.gatherv2 = P.Gather()
     # The EmbeddingLookup primitive is always tagged with the 'CPU' primitive target.
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
     enable_ps = _get_ps_context("enable_ps")
     if enable_ps:
         # NOTE(review): _process_vocab_cache is not shown here; it may adjust
         # vocab_size / vocab_cache_size before the table is created - confirm.
         self._process_vocab_cache(slice_mode)
     self.embedding_size = validator.check_positive_int(
         embedding_size, 'embedding_size')
     self.embedding_table = Parameter(initializer(
         param_init, [self.vocab_size, self.embedding_size]),
                                      name='embedding_table')
     parallel_mode = _get_parallel_mode()
     is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                          ParallelMode.AUTO_PARALLEL)
     self.gather_revert = P.Gather()
     self.reshape_first = P.Reshape()
     self.reshape = P.Reshape()
     self.unique = P.Unique()
     self.shape = P.Shape()
     if is_auto_parallel:
         # In (semi-)auto parallel mode Unique is replaced by a sharded instance.
         self.unique = P.Unique().shard(((1, ), ))
     if self.cache_enable and enable_ps:
         self._set_voacb_cache_enable_for_ps(vocab_cache_size,
                                             embedding_size, vocab_size)
         if is_auto_parallel:
             self.unique.add_prim_attr('cache_enable', True)
     # Length of the indices strategy tuple; reduced to 1 in the branches
     # below that switch on forward_unique.
     indices_shape_size = 2
     if slice_mode == "field_slice" and is_auto_parallel:
         if not manual_shapes:
             raise ValueError(
                 "in slice field mode, the manual_shapes should not be none"
             )
         if not isinstance(manual_shapes, tuple):
             raise TypeError(
                 "manual_shapes type must be tuple(int) cannot be {}!".
                 format(type(manual_shapes)))
         for dim in manual_shapes:
             validator.check_positive_int(dim, 'manual shape dim',
                                          self.cls_name)
         self.gatherv2.add_prim_attr("manual_split", manual_shapes)
         self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
         self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), (1, get_group_size())))
     elif slice_mode == "table_row_slice" and is_auto_parallel:
         full_batch = _get_full_batch()
         # forward_unique is switched on for a DEVICE target without full
         # batch, or for a sparse lookup with the parameter-server cache enabled.
         if (target == 'DEVICE'
                 and not full_batch) or (self.cache_enable and enable_ps
                                         and sparse):
             indices_shape_size = 1
             self.gather_revert.shard(((1, 1), (get_group_size(), )))
             self.forward_unique = True
         indices_strategy = (1, ) * indices_shape_size
         self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), indices_strategy))
     elif slice_mode == "table_column_slice" and is_auto_parallel:
         if target == 'DEVICE':
             indices_shape_size = 1
             self.gather_revert.shard(((1, get_group_size()), (1, )))
             self.forward_unique = True
         indices_strategy = (1, ) * indices_shape_size
         self.gatherv2.shard(((1, get_group_size()), indices_strategy))
         self.embeddinglookup.shard(
             ((1, get_group_size()), indices_strategy))
     elif slice_mode == "batch_slice" and is_auto_parallel:
         # The indices strategy is (group_size, 1, ...) with indices_shape_size entries.
         indices_strategy = [get_group_size()]
         indices_strategy.extend([1] * (indices_shape_size - 1))
         indices_strategy = tuple(indices_strategy)
         self.gatherv2.shard(((1, 1), indices_strategy))
         self.embeddinglookup.shard(((1, 1), indices_strategy))
     else:
         # Outside (semi-)auto parallel mode slice_mode is not acted upon.
         if is_auto_parallel:
             raise ValueError(
                 "slice_mode should support mode in nn.EmbeddingLookup, but get "
                 + str(slice_mode))
     if self.cache_enable and not enable_ps:
         # Without parameter server the cache is only accepted in stand-alone mode.
         if parallel_mode != ParallelMode.STAND_ALONE:
             raise ValueError(
                 "parallel mode haven't supported cache enable yet.")
         self._set_cache_enable()
     self.embedding_table.unique = self.forward_unique
     self.max_norm = max_norm
     if self.max_norm is not None:
         # A provided max_norm is validated and then stored as a float32 Tensor.
         self.max_norm = validator.check_positive_float(
             self.max_norm, 'max_norm', self.cls_name)
         self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
Beispiel #10
0
 def __init__(self,
              vocab_size,
              embedding_size,
              param_init='normal',
              target='CPU',
              slice_mode='batch_slice',
              manual_shapes=None,
              max_norm=None):
     """Validate the configuration, create the embedding table and set up the lookup operators.

     Args:
         vocab_size (int): First dimension of the embedding table.
         embedding_size (int): Second dimension of the embedding table.
         param_init: Initialisation spec handed to `initializer` for the table.
             Default: 'normal'.
         target (str): Either 'CPU' or 'DEVICE'. Default: 'CPU'.
         slice_mode (str): One of "field_slice", "table_row_slice",
             "table_column_slice" or "batch_slice"; only acted upon in
             semi-auto or auto parallel mode. Default: 'batch_slice'.
         manual_shapes (tuple): Tuple of positive ints, required when
             `slice_mode` is "field_slice" in (semi-)auto parallel mode.
             Default: None.
         max_norm (float): Optional positive float, stored as a float32 Tensor.
             Default: None.

     Raises:
         ValueError: If `target` is not 'CPU' or 'DEVICE'; if `manual_shapes`
             is missing in field slice mode; or if `slice_mode` is not
             supported in (semi-)auto parallel mode.
         TypeError: If `manual_shapes` is not a tuple in field slice mode.
     """
     super(EmbeddingLookup, self).__init__()
     self.target = target
     if target not in ('CPU', 'DEVICE'):
         raise ValueError(
             'Attr \'target\' of \'EmbeddingLookup\' Op passed ' +
             str(target) +
             ', should be one of values in \'CPU\', \'DEVICE\'.')
     self.gatherv2 = P.GatherV2()
     # The EmbeddingLookup primitive is always tagged with the 'CPU' primitive target.
     self.embeddinglookup = P.EmbeddingLookup().add_prim_attr(
         'primitive_target', 'CPU')
     # Only the type of both sizes is checked here, not their sign.
     self.vocab_size = validator.check_value_type('vocab_size', vocab_size,
                                                  [int], self.cls_name)
     self.embedding_size = validator.check_value_type(
         'embedding_size', embedding_size, [int], self.cls_name)
     self.embedding_table = Parameter(initializer(
         param_init, [self.vocab_size, self.embedding_size]),
                                      name='embedding_table')
     parallel_mode = _get_parallel_mode()
     is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                          ParallelMode.AUTO_PARALLEL)
     # Each slice mode applies its own shard strategy to both lookup
     # operators, and only in (semi-)auto parallel mode.
     if slice_mode == "field_slice" and is_auto_parallel:
         if not manual_shapes:
             raise ValueError(
                 "in slice field mode, the manual_shapes should not be none"
             )
         if not isinstance(manual_shapes, tuple):
             raise TypeError(
                 "manual_shapes type must be tuple(int) cannot be {}!".
                 format(type(manual_shapes)))
         for dim in manual_shapes:
             validator.check_positive_int(dim, 'manual shape dim',
                                          self.cls_name)
         self.gatherv2.add_prim_attr("manual_split", manual_shapes)
         self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
         self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
         self.embeddinglookup.shard(
             ((get_group_size(), 1), (1, get_group_size())))
     elif slice_mode == "table_row_slice" and is_auto_parallel:
         self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
         self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
     elif slice_mode == "table_column_slice" and is_auto_parallel:
         self.gatherv2.shard(((1, get_group_size()), (1, 1)))
         self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
     elif slice_mode == "batch_slice" and is_auto_parallel:
         self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
         self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
     else:
         # Outside (semi-)auto parallel mode slice_mode is not acted upon.
         if is_auto_parallel:
             raise ValueError(
                 "slice_mode should support mode in nn.EmbeddingLookup, but get "
                 + str(slice_mode))
     self.max_norm = max_norm
     if self.max_norm is not None:
         # A provided max_norm is validated and then stored as a float32 Tensor.
         self.max_norm = validator.check_positive_float(
             self.max_norm, 'max_norm', self.cls_name)
         self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
Beispiel #11
0
def polynomial_decay_lr(learning_rate,
                        end_learning_rate,
                        total_step,
                        step_per_epoch,
                        decay_epoch,
                        power,
                        update_decay_epoch=False):
    r"""
    Build a per-step learning rate schedule that follows a polynomial decay curve.

    The learning rate of the i-th step is:

    .. math::
        decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
        (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate

    Where:

    .. math::
        tmp\_epoch = min(current\_epoch, decay\_epoch)

    .. math::
        current\_epoch=floor(\frac{i}{step\_per\_epoch})

    .. math::
        tmp\_decay\_epoch = decay\_epoch

    When `update_decay_epoch` is true, `tmp_epoch` is the current epoch itself and
    `tmp_decay_epoch` is recomputed for every step as:

    .. math::
        tmp\_decay\_epoch = decay\_epoch * max(ceil(current\_epoch / decay\_epoch), 1)

    Args:
        learning_rate (float): The initial value of learning rate.
        end_learning_rate (float): The end value of learning rate.
        total_step (int): Number of steps to generate.
        step_per_epoch (int): Number of steps that make up one epoch.
        decay_epoch (int): Number of epochs over which the rate decays.
        power (float): Exponent of the polynomial. This parameter must be greater than 0.
        update_decay_epoch (bool): If true, update `decay_epoch`. Default: False.

    Returns:
        list[float]. One learning rate per step, `total_step` entries in total.

    Raises:
        TypeError: If `end_learning_rate` is not a float.

    Examples:
        >>> learning_rate = 0.1
        >>> end_learning_rate = 0.01
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> power = 0.5
        >>> r = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        >>> print(r)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_positive_float(learning_rate, 'learning_rate')
    validator.check_is_float(learning_rate, 'learning_rate')
    if not isinstance(end_learning_rate, float):
        raise TypeError("end_learning_rate must be float.")
    validator.check_non_negative_float(end_learning_rate, "end_learning_rate",
                                       None)
    validator.check_positive_float(power, 'power')
    validator.check_is_float(power, 'power')
    for arg_name, arg_value in (('total_step', total_step),
                                ('step_per_epoch', step_per_epoch),
                                ('decay_epoch', decay_epoch)):
        validator.check_positive_int(arg_value, arg_name)
    validator.check_value_type('update_decay_epoch', update_decay_epoch,
                               [bool])

    rate_span = learning_rate - end_learning_rate
    schedule = []
    for step in range(total_step):
        epoch = math.floor(step / step_per_epoch)
        if update_decay_epoch:
            # The decay window is stretched to the next multiple of
            # decay_epoch (at least one full window).
            window = decay_epoch * max(math.ceil(epoch / decay_epoch), 1)
            progress = epoch
        else:
            # Fixed window: progress stops growing once the window has elapsed.
            window = decay_epoch
            progress = min(decay_epoch, epoch)
        schedule.append(rate_span * (1 - progress / window)**power +
                        end_learning_rate)
    return schedule
Beispiel #12
0
 def __init__(self, smooth=1e-5):
     """Create the loss, validate `smooth` and set up the reshape operator.

     Args:
         smooth: Value checked by `check_positive_float` and stored as
             `self.smooth`. Default: 1e-5.
     """
     super(DiceLoss, self).__init__()
     checked_smooth = validator.check_positive_float(smooth, "smooth")
     self.smooth = checked_smooth
     self.reshape = P.Reshape()
Beispiel #13
0
    def __init__(self,
                 learning_rate,
                 parameters,
                 weight_decay=0.0,
                 loss_scale=1.0):
        """Set up the parameters, learning rate, weight decay and loss scale of the optimizer.

        Args:
            learning_rate: Learning rate specification. It is normalised by
                `_preprocess_single_lr` and, unless a per-group learning rate is
                in effect, turned into `self.learning_rate` by `_build_single_lr`.
            parameters: Iterable of `Parameter` objects or of dicts describing
                parameter groups. Only the type of the first element is inspected.
            weight_decay: Weight decay specification, normalised by
                `_preprocess_weight_decay`. Default: 0.0.
            loss_scale (int or float): Positive loss scale; an int is converted
                to float before validation. Default: 1.0.

        Raises:
            ValueError: If `parameters` is None or empty.
            TypeError: If the first element of `parameters` is neither a dict
                nor a `Parameter`.
        """
        super(Optimizer, self).__init__(auto_prefix=False)
        # Materialise any non-list iterable so it can be indexed below.
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        # Only the first element decides between plain and grouped parameters.
        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError(
                "Only a list of Parameter or dict can be supported.")

        # Integer loss scales are accepted by promoting them to float.
        if isinstance(loss_scale, int):
            loss_scale = float(loss_scale)
        validator.check_value_type("loss_scale", loss_scale, [float],
                                   self.cls_name)
        validator.check_positive_float(loss_scale, "loss_scale", self.cls_name)
        self.loss_scale = loss_scale

        weight_decay = self._preprocess_weight_decay(weight_decay)
        # Default gradient centralization setting passed to the group initialisation.
        self.grad_centralization = False

        # Defaults; some of these flags are expected to be updated by the
        # helper methods called below (their bodies are not shown here).
        self._unique = True
        self._target = context.get_context("device_target")
        self.dynamic_lr = False
        self.assignadd = None
        self.global_step = None
        self.is_group = False
        self.is_group_lr = False
        self.is_group_params_ordered = False
        learning_rate = self._preprocess_single_lr(learning_rate)
        if isinstance(parameters[0], dict):
            # Grouped mode: the group_* lists are handed to _init_group_params,
            # which presumably fills them - confirm in that method.
            self.is_group = True
            self.group_params = []
            self.group_lr = []
            self.group_weight_decay = []
            self.group_grad_centralization = []
            self._init_group_params(parameters, learning_rate, weight_decay,
                                    self.grad_centralization)

        # The final value of dynamic_lr is only known once the single learning rate has been
        # preprocessed and the group parameters have been initialised.
        if self.dynamic_lr:
            self.assignadd = P.AssignAdd()
            self.global_step = Parameter(initializer(0, [1], mindspore.int32),
                                         name='global_step')

        # Per-group learning rates are stored as a CellList when dynamic and as
        # a ParameterTuple otherwise; a single learning rate is built by helper.
        if self.is_group_lr:
            self.learning_rate = CellList(
                self.group_lr) if self.dynamic_lr else ParameterTuple(
                    self.group_lr)
        else:
            self.learning_rate = self._build_single_lr(learning_rate,
                                                       'learning_rate')

        if self.is_group:
            self.parameters = ParameterTuple(self.group_params)
            self.weight_decay = tuple(self.group_weight_decay)
            self.weight_decay_tensor_tuple = tuple(
                Tensor(x, mstype.float32) for x in self.group_weight_decay)
            # One flag per weight decay entry: decay applies where the value is positive.
            decay_filter = lambda x: x > 0
            self.decay_flags = tuple(
                decay_filter(x) for x in self.weight_decay)
            self.exec_weight_decay = any(self.decay_flags)
            self.grad_centralization_flags = tuple(
                self.group_grad_centralization)
        else:
            self.parameters = ParameterTuple(parameters)
            # In the non-grouped case the stored weight decay is multiplied by the loss scale.
            self.weight_decay = weight_decay * loss_scale
            self.weight_decay_tensor = Tensor(self.weight_decay,
                                              mstype.float32)
            # Parameters whose name contains 'beta' or 'gamma' are flagged as not decayed.
            decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
            self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
            self.exec_weight_decay = self.weight_decay > 0
        # When any parameter is already marked unique, the optimizer does not need to apply unique again.
        for param in self.parameters:
            if param.unique:
                self._unique = False
                break
        # Per-parameter flags read from the `is_param_ps` and `cache_enable` attributes.
        ps_filter = lambda x: x.is_param_ps
        self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
        cache_filter = lambda x: x.cache_enable
        self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
        # Scaling is only flagged as needed when the loss scale differs from 1.0.
        self.need_scale = loss_scale != 1.0
        self.global_step_increase_tensor = Tensor(1, mstype.int32)
        self.param_length = len(self.parameters)
        self.map_ = C.Map()
        # Parallel optimizer configuration is delegated to a helper (not shown here).
        self._use_parallel_optimizer()
Beispiel #14
0
    def __init__(self, smooth=1e-5, threshold=0.5):
        """Create the metric, validate its settings and reset its state.

        Args:
            smooth: Value checked by `check_positive_float` and stored as
                `self.smooth`. Default: 1e-5.
            threshold (float): Value type-checked against `float` and stored
                as `self.threshold`. Default: 0.5.
        """
        super(Dice, self).__init__()

        checked_smooth = validator.check_positive_float(smooth, "smooth")
        self.smooth = checked_smooth
        checked_threshold = validator.check_value_type(
            "threshold", threshold, [float])
        self.threshold = checked_threshold
        self.clear()