def test_CallbackManager_exit_called():
    """On a clean context exit, __exit__ must run once per managed callback."""
    with mock.patch.object(Callback, '__exit__', return_value=None) as exit_spy:
        managed = [Callback(), Callback()]
        with _CallbackManager(managed):
            pass
        # Every recorded call carries the standard (self, exc_type, exc, tb) shape.
        expected = mock.call(mock.ANY, None, None, None)
        for recorded in exit_spy.call_args_list:
            assert recorded == expected
        assert exit_spy.call_count == len(managed)
def test_CallbackManager_begin_called():
    """begin() on the manager must be forwarded to every managed callback."""
    run_ctx = dict()
    with mock.patch.object(Callback, 'begin', return_value=None) as begin_spy:
        managed = [Callback(), Callback()]
        with _CallbackManager(managed) as manager:
            manager.begin(run_ctx)
        # Each callback receives the exact same context object.
        for recorded in begin_spy.call_args_list:
            assert recorded == mock.call(run_ctx)
        assert begin_spy.call_count == len(managed)
def test_CallbackManager_exit_called_when_raises():
    """__exit__ must still reach every callback when the managed body raises."""
    with mock.patch.object(Callback, '__exit__', return_value=None) as exit_spy:
        managed = [Callback(), Callback()]
        with pytest.raises(ValueError):
            with _CallbackManager(managed):
                raise ValueError()
        # Exception info is propagated, so only the arity is pinned here.
        expected = mock.call(mock.ANY, mock.ANY, mock.ANY, mock.ANY)
        for recorded in exit_spy.call_args_list:
            assert recorded == expected
        assert exit_spy.call_count == len(managed)
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Training.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
            loss_fn, a tuple with multiple data (data1, data2, data3, ...)
            will be returned and passed to the network. Otherwise, a tuple
            (data, label) will be returned, and the data and label are passed
            to the network and loss function respectively.
        callbacks (list): List of callback objects which should be executed
            while training. Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through
            the dataset channel. Default: True. In pynative mode the training
            process is performed with the dataset not sinking.
    """
    epoch = check_int_positive(epoch)
    self._train_network.set_train()

    if self._parameter_broadcast:
        self._train_network.set_broadcast_flag()

    # Bundle everything the callbacks may inspect during training.
    cb_params = _InternalCallbackParam()
    cb_params.train_network = self._train_network
    cb_params.epoch_num = epoch
    cb_params.batch_num = train_dataset.get_dataset_size()
    cb_params.mode = "train"
    cb_params.loss_fn = self._loss_fn
    cb_params.optimizer = self._optimizer
    cb_params.parallel_mode = self._parallel_mode
    cb_params.device_number = self._device_number
    cb_params.train_dataset = train_dataset
    cb_params.list_callback = callbacks

    with _CallbackManager(callbacks) as list_callback:
        use_sink = dataset_sink_mode
        # Sink mode is unavailable in pynative mode; fall back with a warning.
        if use_sink and context.get_context("mode") == context.PYNATIVE_MODE:
            logger.warning("The pynative mode cannot support dataset sink mode currently."
                           "So the training process will be performed with dataset not sink.")
            use_sink = False
        if use_sink:
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
        else:
            self._train_process(epoch, train_dataset, list_callback, cb_params)
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
    """
    Evaluation API where the iteration is controlled by python front-end.

    Configure to pynative mode, the evaluation will be performed with dataset
    non-sink mode.

    Note:
        CPU is not supported when dataset_sink_mode is true.
        If dataset_sink_mode is True, data will be sent to device. If the
        device is Ascend, features of data will be transferred one by one.
        The limitation of data transmission per time is 256M.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        callbacks (list): List of callback objects which should be executed
            while evaluating. Default: None.
        dataset_sink_mode (bool): Determines whether to pass the data through
            the dataset channel. Default: True.

    Returns:
        Dict, returns the loss value & metrics values for the model in test mode.

    Examples:
        >>> dataset = get_dataset()
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
        >>> model.eval(dataset)
    """
    check_bool(dataset_sink_mode)
    _device_number_check(self._parallel_mode, self._device_number)
    if not self._metric_fns:
        raise ValueError("metric fn can not be None or empty.")

    # Bundle everything the callbacks may inspect during evaluation.
    cb_params = _InternalCallbackParam()
    cb_params.eval_network = self._eval_network
    cb_params.valid_dataset = valid_dataset
    cb_params.batch_num = valid_dataset.get_dataset_size()
    cb_params.mode = "eval"
    cb_params.cur_step_num = 0

    self._eval_network.set_train(mode=False)
    self._eval_network.phase = 'eval'
    self._clear_metrics()

    with _CallbackManager(callbacks) as list_callback:
        # Select the evaluation driver up front, then run it.
        runner = (self._eval_dataset_sink_process if dataset_sink_mode
                  else self._eval_process)
        return runner(valid_dataset, list_callback, cb_params)
def test_loss_monitor_sink_mode():
    """Test loss monitor sink mode."""
    cb_params = _InternalCallbackParam()
    cb_params.cur_epoch_num = 4
    cb_params.cur_step_num = 2
    cb_params.batch_num = 2
    cb_params.net_outputs = Tensor(2.0)
    run_context = RunContext(cb_params)
    with _CallbackManager([LossMonitor(1)]) as chain:
        # Drive one full single-step epoch lifecycle through the chain.
        for phase in ('begin', 'epoch_begin', 'step_begin',
                      'step_end', 'epoch_end', 'end'):
            getattr(chain, phase)(run_context)
def test_CallbackManager():
    """TestCallbackManager."""
    ck_obj = ModelCheckpoint()
    loss_cb_1 = LossMonitor(1)
    # Any non-Callback entry must make the manager reject the whole list,
    # even when it is mixed in with valid callbacks.
    for bad_callbacks in ([None],
                          ['Error'],
                          [ck_obj, loss_cb_1, 'Error', None]):
        with pytest.raises(TypeError):
            _CallbackManager(bad_callbacks)