Esempio n. 1
0
    def _build_train_network(self):
        """Wrap ``self._network`` for training and return the resulting cell.

        * Optimizer set: the network is passed to ``amp.build_train_network``
          with the optimizer, the loss function, the AMP level and the
          keep-BatchNorm-fp32 flag. ``self._loss_scale_manager`` is forwarded
          only when ``self._loss_scale_manager_set`` is truthy.
        * No optimizer, but a loss function: the network is wrapped in
          ``nn.WithLossCell``.
        * Neither: the network is used unchanged.

        When ``self._parallel_mode`` is semi-auto or auto parallel,
        ``set_auto_parallel()`` is called on the resulting network before it
        is returned.
        """
        train_net = self._network
        if not self._optimizer:
            # NOTE(review): a loss_fn without an optimizer is only wrapped in a
            # loss cell; the original left a TODO about whether this
            # combination should be validated -- no check is performed here.
            if self._loss_fn:
                train_net = nn.WithLossCell(train_net, self._loss_fn)
        else:
            with_loss_scale = self._loss_scale_manager_set
            optimizer, loss_fn = self._optimizer, self._loss_fn
            # Keyword arguments are collected in the same order the explicit
            # calls used; the loss scale manager is only passed when set.
            amp_options = {"level": self._amp_level}
            if with_loss_scale:
                amp_options["loss_scale_manager"] = self._loss_scale_manager
            amp_options["keep_batchnorm_fp32"] = self._keep_bn_fp32
            train_net = amp.build_train_network(train_net, optimizer, loss_fn,
                                                **amp_options)

        auto_parallel_modes = (ParallelMode.SEMI_AUTO_PARALLEL,
                               ParallelMode.AUTO_PARALLEL)
        if self._parallel_mode in auto_parallel_modes:
            train_net.set_auto_parallel()
        return train_net
Esempio n. 2
0
 def _build_train_network(self):
     """Build train network"""
     network = self._network
     if self._optimizer:
         if self._loss_scale_manager_set:
             network = amp.build_train_network(
                 network,
                 self._optimizer,
                 self._loss_fn,
                 level=self._amp_level,
                 loss_scale_manager=self._loss_scale_manager,
                 keep_batchnorm_fp32=self._keep_bn_fp32)
         else:
             network = amp.build_train_network(
                 network,
                 self._optimizer,
                 self._loss_fn,
                 level=self._amp_level,
                 keep_batchnorm_fp32=self._keep_bn_fp32)
     elif self._loss_fn:
         network = nn.WithLossCell(network, self._loss_fn)
     # If need to check if loss_fn is not None, but optimizer is None
     return network