Example #1
    def _build_predict_network(self):
        """Build the network for prediction."""
        self._predict_network = self._network
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                   ParallelMode.AUTO_PARALLEL):
            self._predict_network = _VirtualDatasetCell(self._network)
            self._predict_network.set_auto_parallel()
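
The pattern above only takes effect under (semi-)auto parallel. A minimal stand-alone sketch of the same wrap-and-flag logic, assuming the private import paths `mindspore.nn.wrap.cell_wrapper._VirtualDatasetCell` and `mindspore.parallel._utils._get_parallel_mode` (both internal APIs that may move between MindSpore versions):

from mindspore.context import ParallelMode
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell  # assumed internal path
from mindspore.parallel._utils import _get_parallel_mode  # assumed internal path

def build_predict_network(network):
    """Wrap a cell with a virtual dataset for (semi-)auto parallel prediction."""
    predict_network = network
    if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
                                ParallelMode.AUTO_PARALLEL):
        predict_network = _VirtualDatasetCell(network)
        predict_network.set_auto_parallel()
    return predict_network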
Example #2
def auto_parallel_compile_net(mode,
                              dev_num,
                              net,
                              strategy1=None,
                              strategy2=None):
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=mode,
                                      device_num=dev_num,
                                      enable_parallel_optimizer=True)
    inputs = Tensor(np.ones([32, 48]).astype(np.float32))
    label = Tensor(np.zeros([32, 16]).astype(np.float32))
    net = net(strategy1, strategy2)
    net = _VirtualDatasetCell(net)
    optimizer = Momentum(net.trainable_params(),
                         learning_rate=0.1,
                         momentum=0.9)
    train_network = TrainOneStepCell(net, optimizer).set_comm_fusion(4)
    train_network.set_auto_parallel()
    train_network.set_train()
    _executor.compile(train_network,
                      inputs,
                      label,
                      phase="train",
                      auto_parallel_mode=True)
    context.reset_auto_parallel_context()
    return train_network
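
A hedged usage sketch for this helper; `MatMulNet` and the shard strategies below are hypothetical placeholders sized for 8 devices and the [32, 48] x [48, 16] shapes compiled above:

# Hypothetical call; the net class and strategies are illustrative only.
strategy1 = ((4, 2), (2, 1))  # assumed shard strategy for a first MatMul
strategy2 = ((2, 4), (4, 1))  # assumed shard strategy for a second MatMul
train_net = auto_parallel_compile_net("semi_auto_parallel", 8, MatMulNet,
                                      strategy1, strategy2)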
Example #3
    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network for evaluation."""
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(
                    eval_indexes, list) and len(eval_indexes) == 3):
                raise ValueError(
                    "eval_indexes must be a list or None. If eval_indexes is a "
                    "list, its length must be three. But got {}".format(eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn,
                                                 self._amp_level == "O2")
            self._eval_indexes = [0, 1, 2]

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                   ParallelMode.AUTO_PARALLEL):
            if self._optimizer:
                self._eval_network = _VirtualDatasetCell(self._eval_network)
            self._eval_network.set_auto_parallel()
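
When `eval_network` is supplied, `eval_indexes` tells the model which of the network's outputs are the loss, the predictions, and the labels. A hedged sketch of a custom eval cell honoring the same three-output contract as `nn.WithEvalCell` (the class name is hypothetical):

import mindspore.nn as nn

class MyEvalCell(nn.Cell):
    """Hypothetical eval cell returning (loss, logits, label) so that
    eval_indexes=[0, 1, 2] selects them in that order."""

    def __init__(self, network, loss_fn):
        super(MyEvalCell, self).__init__(auto_prefix=False)
        self._network = network
        self._loss_fn = loss_fn

    def construct(self, data, label):
        logits = self._network(data)
        loss = self._loss_fn(logits, label)
        return loss, logits, label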
Example #4
def test_reshape_net4_2():
    try:
        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2)))))
    except (ValueError, TypeError, RuntimeError):
        pass
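
If the intent is that compilation must fail for this shard configuration, a stricter pytest variant would assert the failure rather than merely tolerate it (this reading of the intent is an assumption; the original test also passes when no error is raised):

import pytest

def test_reshape_net4_2_strict():
    # Hypothetical stricter variant: requires one of the errors to occur.
    with pytest.raises((ValueError, TypeError, RuntimeError)):
        reshape_net2(_VirtualDatasetCell(ReshapeNet4(((1, 8), (8, 2)))))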
Example #5
def test_batchnorm_reshape_train():
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(_VirtualDatasetCell(BatchNormReshapeNet())))

    compile_net(net, input_)
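
Note that the input carries the full global batch (`batch_size * device_num`); under semi auto parallel the virtual dataset op models how that batch is split across devices. `compile_net` itself is not defined in these snippets; a minimal sketch consistent with how it is called here and with the compile pattern in Example #7 (the body is an assumption):

from mindspore.common.api import _executor  # assumed internal path

def compile_net(net, *inputs):
    # Assumed helper: flag the cell for auto-parallel, switch to train
    # mode, and trace/compile the graph with the given inputs.
    net.set_auto_parallel()
    net.set_train()
    _executor.compile(net, *inputs)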
Example #6
    def _amp_build_train_network(self,
                                 network,
                                 optimizer,
                                 loss_fn=None,
                                 level='O0',
                                 **kwargs):
        """
        Build the mixed precision training cell automatically.

        Args:
            network (Cell): Definition of the network.
            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
                the `network` should have the loss inside. Default: None.
            optimizer (Optimizer): Optimizer to update the Parameter.
            level (str): Supports [O0, O2]. Default: "O0".

                - O0: Do not change.
                - O2: Cast network to float16, keep batchnorm and `loss_fn`
                  (if set) run in float32, using dynamic loss scale.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
                or `mstype.float32`. If set to `mstype.float16`, use `float16`
                mode to train. If set, overwrite the level setting.
            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
                overwrite the level setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, not
                scale the loss, or else scale the loss by LossScaleManager.
                If set, overwrite the level setting.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn,
                                        config.cast_model_type)

        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
                                    ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            print("----model config have loss scale manager !")
        network = TrainOneStepCell(network, optimizer,
                                   sens=loss_scale).set_train()
        return network
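
A hedged usage sketch for the builder above, assuming `model` is an instance of the enclosing class; the network, loss, and optimizer are illustrative placeholders, not names from the source:

import mindspore.nn as nn

net = nn.Dense(32, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = model._amp_build_train_network(net, opt, loss_fn=loss_fn, level='O2')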
Example #7
def compile_graph(strategy1,
                  strategy2,
                  strategy3,
                  strategy4,
                  auto=False,
                  onehot_axis=-1):
    net = GradWrap(
        _VirtualDatasetCell(
            NetWithLoss(Net(strategy1, strategy2),
                        strategy3,
                        strategy4,
                        axis=onehot_axis)))
    net.set_auto_parallel()
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64]), dtype=ms.int32)
    net.set_train()
    _executor.compile(net, x, y, b)
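
A hedged invocation of `compile_graph`; the strategy tuples are hypothetical values for an 8-device layout and are meant only to show the argument shapes:

# Hypothetical strategies; the tuple values are illustrative only.
compile_graph(((4, 2), (2, 1)),   # strategy1
              ((2, 4), (4, 1)),   # strategy2
              ((8, 1),),          # strategy3
              ((8, 1), (8, 1)),   # strategy4
              auto=False)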
Example #8
    def _amp_build_train_network(self,
                                 network,
                                 optimizer,
                                 loss_fn=None,
                                 level='O0',
                                 **kwargs):
        """
        Build the mixed precision training cell automatically.

        Args:
            network (Cell): Definition of the network.
            loss_fn (Union[None, Cell]): Definition of the loss_fn. If None,
                the `network` should have the loss inside. Default: None.
            optimizer (Optimizer): Optimizer to update the Parameter.
            level (str): Supports [O0, O2]. Default: "O0".

                - O0: Do not change.
                - O2: Cast network to float16, keep batchnorm and `loss_fn`
                  (if set) run in float32, using dynamic loss scale.

            cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16`
                or `mstype.float32`. If set to `mstype.float16`, use `float16`
                mode to train. If set, overwrite the level setting.
            keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set,
                overwrite the level setting.
            loss_scale_manager (Union[None, LossScaleManager]): If None, not
                scale the loss, or else scale the loss by LossScaleManager.
                If set, overwrite the level setting.
        """
        validator.check_value_type('network', network, nn.Cell, None)
        validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
        validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
        self._check_kwargs(kwargs)
        config = dict(_config_level[level], **kwargs)
        config = edict(config)

        if config.cast_model_type == mstype.float16:
            network.to_float(mstype.float16)

            if config.keep_batchnorm_fp32:
                _do_keep_batchnorm_fp32(network)

        if loss_fn:
            network = _add_loss_network(network, loss_fn,
                                        config.cast_model_type)

        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL,
                                    ParallelMode.AUTO_PARALLEL):
            network = _VirtualDatasetCell(network)

        loss_scale = 1.0
        if config.loss_scale_manager is not None:
            loss_scale_manager = config.loss_scale_manager
            loss_scale = loss_scale_manager.get_loss_scale()
            update_cell = loss_scale_manager.get_update_cell()
            if update_cell is not None:
                # CPU (without GE) does not support
                # `_TrainOneStepWithLossScaleCell` because of control flow.
                if not context.get_context(
                        "enable_ge") and context.get_context(
                            "device_target") == "CPU":
                    msg = "Only `loss_scale_manager=None` and " \
                          "`loss_scale_manager=FixedLossScaleManager(drop_overflow" \
                          "_update=False)` are supported in current version. " \
                          "If you use `O2` option, please use " \
                          "`loss_scale_manager=None` or `FixedLossScaleManager`"
                    LOGGER.error(TAG, msg)
                    raise ValueError(msg)
                network = _TrainOneStepWithLossScaleCell(
                    network,
                    optimizer,
                    scale_update_cell=update_cell,
                    micro_batches=self._micro_batches,
                    norm_bound=self._norm_bound,
                    clip_mech=self._clip_mech,
                    noise_mech=self._noise_mech).set_train()
                return network

        network = _TrainOneStepCell(network,
                                    optimizer,
                                    self._norm_bound,
                                    loss_scale,
                                    micro_batches=self._micro_batches,
                                    clip_mech=self._clip_mech,
                                    noise_mech=self._noise_mech).set_train()
        return network
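
A hedged sketch of feeding a loss-scale manager through the builder above. `FixedLossScaleManager` is the manager named in the error message; with `drop_overflow_update=False` its `get_update_cell()` returns None, so the fixed scale value flows into `_TrainOneStepCell`. The surrounding names (`model`, `net`, `opt`, `loss_fn`) are placeholders:

from mindspore.train.loss_scale_manager import FixedLossScaleManager

# drop_overflow_update=False => get_update_cell() is None, so the scalar
# loss scale is passed down instead of a loss-scale update cell.
manager = FixedLossScaleManager(loss_scale=1024.0, drop_overflow_update=False)
train_net = model._amp_build_train_network(net, opt, loss_fn=loss_fn,
                                           level='O2',
                                           loss_scale_manager=manager)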
Example #9
def test_reshape_common2_5():
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL,
                    _VirtualDatasetCell(ReshapeNet3(((1, 8), (8, 2)))))
Example #10
def test_reshape_net6_2():
    reshape_net2(_VirtualDatasetCell(ReshapeNet6(((1, 8), (8, 2)))))
Example #11
def test_reshape_net5_1():
    reshape_net2(_VirtualDatasetCell(ReshapeNet5(((1, 8), (8, 1)))))