Example #1
    def _amp_build_train_network(self,
                                 network,
                                 optimizer,
                                 loss_fn=None,
                                 level='O0',
                                 **kwargs):
        """
        Build the mixed precision training cell automatically.

        Args:
            network (Cell): Definition of the network.
            loss_fn (Union[None, Cell]): Definition of the loss function. If
                None, the `network` should have the loss inside. Default: None.
            optimizer (Optimizer): Optimizer to update the Parameters.
            level (str): Supports ["O0", "O2"]. Default: "O0".

                - O0: Do not change.
                - O2: Cast the network to float16, keep batchnorm and `loss_fn`
                  (if set) running in float32, and use dynamic loss scaling.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
                or `mstype.float32`. If set to `mstype.float16`, train in
                `float16` mode. If set, this overrides the `level` setting.
            keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If
                set, this overrides the `level` setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, the loss
                is not scaled; otherwise, the loss is scaled by the
                LossScaleManager. If set, this overrides the `level` setting.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn,
                                        config.cast_model_type)

        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
                                    ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            print("----model config have loss scale manager !")
        network = TrainOneStepCell(network, optimizer,
                                   sens=loss_scale).set_train()
        return network
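A hedged usage sketch of what the `O2` path above amounts to, built from public MindSpore cells rather than this class's private helpers; `net`, `loss`, and `opt` are placeholder names, not part of the example.

    import mindspore.nn as nn
    from mindspore import dtype as mstype

    net = nn.Dense(16, 10)                        # stand-in for a real network
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    net.to_float(mstype.float16)                  # cast the network to float16
    # attach the loss; the example's `_add_loss_network` additionally keeps it in float32
    net_with_loss = nn.WithLossCell(net, loss)
    train_net = nn.TrainOneStepCell(net_with_loss, opt).set_train()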
Example #2
def _get_confusion_matrix(y_pred, y, skip_channel=True):
    """
    The confusion matrix is calculated. An array of shape [BC4] is returned. The third dimension represents each channel
    of each sample in the input batch.Where B is the batch size and C is the number of classes to be calculated.

    Args:
        y_pred (ndarray): input data to compute. It must be one-hot format and first dim is batch.
                             The values should be binarized.
        y (ndarray): ground truth to compute the metric. It must be one-hot format and first dim is batch.
                    The values should be binarized.
        skip_channel (bool): whether to skip metric computation on the first channel of the predicted output.
                            Default: True.

    Raises:
        ValueError: when `y_pred` and `y` have different shapes.
    """

    if skip_channel:
        # Per the docstring, `skip_channel=True` drops the first channel
        # (typically the background) before computing the counts.
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    y = y.astype(float)
    y_pred = y_pred.astype(float)
    validator.check('y_shape', y.shape, 'y_pred_shape', y_pred.shape)
    batch_size, n_class = y_pred.shape[:2]
    y_pred = y_pred.reshape(batch_size, n_class, -1)
    y = y.reshape(batch_size, n_class, -1)
    # tp: predicted 1 and label 1; tn: predicted 0 and label 0.
    tp = ((y_pred + y) == 2).astype(float)
    tn = ((y_pred + y) == 0).astype(float)
    tp = tp.sum(axis=2)
    tn = tn.sum(axis=2)
    p = y.sum(axis=2)        # positives in the ground truth
    n = y.shape[-1] - p      # negatives in the ground truth
    fn = p - tp              # missed positives
    fp = n - tn              # negatives predicted as positive

    return np.stack([tp, fp, tn, fn], axis=-1)
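A hedged usage sketch for the function above, assuming the module-level `numpy` and `validator` imports from its source file; the arrays are made-up placeholders. With both channels kept (`skip_channel=False`), the result has shape (1, 2, 4), holding [tp, fp, tn, fn] per channel.

    import numpy as np

    y_pred = np.array([[[1, 0, 1, 0],
                        [0, 1, 0, 1]]])    # shape (1, 2, 4), binarized one-hot
    y = np.array([[[1, 1, 0, 0],
                   [0, 0, 1, 1]]])         # shape (1, 2, 4), binarized one-hot

    cm = _get_confusion_matrix(y_pred, y, skip_channel=False)
    print(cm.shape)    # (1, 2, 4)
    print(cm[0, 0])    # [1. 1. 1. 1.] -> tp, fp, tn, fn for the first channel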
Example #3
def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
    """Check that the input is 4D and that its last two dimensions are at least `filter_size`."""
    _check_input_4d(input_shape, param_name, func_name)
    validator.check(param_name + " shape[2]", input_shape[2], "filter_size",
                    filter_size, Rel.GE, func_name)
    validator.check(param_name + " shape[3]", input_shape[3], "filter_size",
                    filter_size, Rel.GE, func_name)
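For context, a rough plain-Python equivalent of the checks this helper performs (hypothetical name; the real `validator.check` raises a framework error that includes `func_name`):

    def check_input_filter_size_plain(input_shape, param_name, filter_size, func_name):
        # The input must be 4D and its last two dimensions at least `filter_size`.
        if len(input_shape) != 4:
            raise ValueError(f"{func_name}: {param_name} must be 4D, got {input_shape}")
        if input_shape[2] < filter_size or input_shape[3] < filter_size:
            raise ValueError(f"{func_name}: {param_name} dims {input_shape[2:]} "
                             f"must be >= filter_size {filter_size}")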
Example #4
    def _amp_build_train_network(self,
                                 network,
                                 optimizer,
                                 loss_fn=None,
                                 level='O0',
                                 **kwargs):
        """
        Build the mixed precision training cell automatically.

        Args:
            network (Cell): Definition of the network.
            loss_fn (Union[None, Cell]): Definition of the loss function. If
                None, the `network` should have the loss inside. Default: None.
            optimizer (Optimizer): Optimizer to update the Parameters.
            level (str): Supports ["O0", "O2"]. Default: "O0".

                - O0: Do not change.
                - O2: Cast the network to float16, keep batchnorm and `loss_fn`
                  (if set) running in float32, and use dynamic loss scaling.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
                or `mstype.float32`. If set to `mstype.float16`, train in
                `float16` mode. If set, this overrides the `level` setting.
            keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32`. If
                set, this overrides the `level` setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, the loss
                is not scaled; otherwise, the loss is scaled by the
                LossScaleManager. If set, this overrides the `level` setting.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn,
                                        config.cast_model_type)

        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
                                    ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            loss_scale_manager = config.loss_scale_manager
            loss_scale = loss_scale_manager.get_loss_scale()
            update_cell = loss_scale_manager.get_update_cell()
            if update_cell is not None:
                # CPU does not support `TrainOneStepWithLossScaleCell` because of its control flow.
                if not context.get_context(
                        "enable_ge") and context.get_context(
                            "device_target") == "CPU":
                    msg = "Only `loss_scale_manager=None` and " \
                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
                          "_update=False)` are supported in current version. " \
                          "If you use `O2` option, please use " \
                          "`loss_scale_manager=None` or `FixedLossScaleManager`"
                    LOGGER.error(TAG, msg)
                    raise ValueError(msg)
                network = _TrainOneStepWithLossScaleCell(
                    network,
                    optimizer,
                    scale_update_cell=update_cell,
                    micro_batches=self._micro_batches,
                    norm_bound=self._norm_bound,
                    clip_mech=self._clip_mech,
                    noise_mech=self._noise_mech).set_train()
                return network

        network = _TrainOneStepCell(network,
                                    optimizer,
                                    self._norm_bound,
                                    loss_scale,
                                    micro_batches=self._micro_batches,
                                    clip_mech=self._clip_mech,
                                    noise_mech=self._noise_mech).set_train()
        return network
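Which branch runs above depends on whether the configured loss scale manager supplies an update cell. A hedged, standalone sketch with MindSpore's public managers (the `_micro_batches`-style fields belong to the surrounding object and are not shown):

    from mindspore.train.loss_scale_manager import (FixedLossScaleManager,
                                                    DynamicLossScaleManager)

    # No update cell -> the fixed-scale `_TrainOneStepCell` branch is used.
    fixed = FixedLossScaleManager(loss_scale=128.0, drop_overflow_update=False)
    print(fixed.get_loss_scale())    # 128.0
    print(fixed.get_update_cell())   # None

    # An update cell is provided -> the `_TrainOneStepWithLossScaleCell`
    # branch is used instead (rejected on CPU, as the check above shows).
    dynamic = DynamicLossScaleManager()
    print(dynamic.get_update_cell()) # an update Cell -> dynamic loss scaling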
Example #5
def _check_shape(logits_shape, label_shape):
    """Check that `logits_shape` and `label_shape` are equal."""
    validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
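A quick illustration of the intent, assuming `validator.check` defaults to an equality comparison (placeholder shapes):

    _check_shape((32, 10), (32, 10))    # matching shapes: passes silently
    _check_shape((32, 10), (32, 1))     # mismatch: validator raises an error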