def _build_train_network(self):
    """Build train network"""
    network = self._network
    if self._optimizer:
        if self._loss_scale_manager_set:
            # Optimizer plus a configured loss-scale manager: wrap forward net,
            # loss and optimizer into a mixed-precision training cell with loss scaling.
            network = amp.build_train_network(network,
                                               self._optimizer,
                                               self._loss_fn,
                                               level=self._amp_level,
                                               loss_scale_manager=self._loss_scale_manager,
                                               keep_batchnorm_fp32=self._keep_bn_fp32)
        else:
            # Optimizer but no loss-scale manager: same AMP wrapping without loss scaling.
            network = amp.build_train_network(network,
                                               self._optimizer,
                                               self._loss_fn,
                                               level=self._amp_level,
                                               keep_batchnorm_fp32=self._keep_bn_fp32)
    elif self._loss_fn:
        # No optimizer: only attach the loss function to the forward network.
        network = nn.WithLossCell(network, self._loss_fn)
    # If need to check if loss_fn is not None, but optimizer is None
    if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        network.set_auto_parallel()
    return network
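The two wrapping paths this internal method takes can be reproduced with the public amp and nn interfaces it calls. The sketch below is illustrative and not part of the source: the Dense network, loss, optimizer and the loss-scale value of 1024 are placeholder choices, and the exact import paths for amp and FixedLossScaleManager can vary across MindSpore versions.

    import mindspore.nn as nn
    from mindspore.train import amp
    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    # Placeholder network, loss and optimizer; any Cell/loss/optimizer combination works.
    net = nn.Dense(16, 10)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    # Branch 1: an optimizer is given and a loss-scale manager was configured, so the
    # forward net, loss and optimizer are wrapped into one mixed-precision training cell.
    train_net = amp.build_train_network(net, optimizer, loss_fn,
                                        level='O2',
                                        loss_scale_manager=FixedLossScaleManager(1024),
                                        keep_batchnorm_fp32=True)

    # Branch 2: no optimizer but a loss_fn, so the network is only wrapped with its loss;
    # a training wrapper such as nn.TrainOneStepCell still has to be added separately.
    net_with_loss = nn.WithLossCell(net, loss_fn)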