def test_train_multi_eval(self, *args):
    """Callback for train once and eval twice.

    Verifies that each evaluation writes its metric to the lineage summary
    and that the latest evaluation's accuracy is the one queried back.
    """
    # args are mocks injected by stacked patch decorators:
    # args[0] gets a fixed return value, args[1] patches the lineage
    # log file name generator.
    args[0].return_value = 10
    summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_multi_eval')
    make_directory(summary_dir)
    args[1].return_value = os.path.join(
        summary_dir,
        'train_out.events.summary.1590107366.ubuntu_lineage')
    train_callback = TrainLineage(summary_dir, True)
    train_callback.begin(RunContext(self.run_context))
    train_callback.end(RunContext(self.run_context))

    # First evaluation: accuracy 0.79.
    args[1].return_value = os.path.join(
        summary_dir,
        'eval_out.events.summary.1590107367.ubuntu_lineage')
    eval_callback = EvalLineage(summary_dir, True)
    # NOTE(review): this aliases self.run_context rather than copying it,
    # so the shared fixture dict is mutated — presumably intentional here.
    eval_run_context = self.run_context
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['metrics'] = {'accuracy': 0.79}
    eval_callback.end(RunContext(eval_run_context))
    res = get_summary_lineage(summary_dir)
    assert res.get('metric', {}).get('accuracy') == 0.79

    # Second evaluation: accuracy 0.80 must now be the recorded metric.
    args[1].return_value = os.path.join(
        summary_dir,
        'eval_out.events.summary.1590107368.ubuntu_lineage')
    eval_callback = EvalLineage(summary_dir, True)
    eval_run_context = self.run_context
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['metrics'] = {'accuracy': 0.80}
    eval_callback.end(RunContext(eval_run_context))
    res = get_summary_lineage(summary_dir)
    assert res.get('metric', {}).get('accuracy') == 0.80

    # Clean up the directory created for this test.
    if os.path.exists(summary_dir):
        shutil.rmtree(summary_dir)
def test_multiple_trains(self, *args):
    """
    Callback TrainLineage and EvalLineage for multiple times.

    Write TrainLineage and EvalLineage in different files under same
    directory. EvalLineage log file end with '_lineage'.
    """
    args[0].return_value = 10
    # Two rounds; distinct create_time values keep the summary files apart.
    for i in range(2):
        summary_record = SummaryRecord(
            SUMMARY_DIR_2,
            create_time=int(time.time()) + i)
        eval_record = SummaryRecord(
            SUMMARY_DIR_2,
            create_time=int(time.time() + 10) + i)
        # args[1] patches the lineage log file name for the train phase.
        args[1].return_value = os.path.join(
            SUMMARY_DIR_2,
            f'train_out.events.summary.{str(int(time.time()) + 2*i)}.ubuntu_lineage'
        )
        train_callback = TrainLineage(summary_record, True)
        train_callback.begin(RunContext(self.run_context))
        train_callback.end(RunContext(self.run_context))

        # New patched name for the eval-phase lineage file.
        args[1].return_value = os.path.join(
            SUMMARY_DIR_2,
            f'eval_out.events.summary.{str(int(time.time())+ 2*i + 1)}.ubuntu_lineage'
        )
        eval_callback = EvalLineage(eval_record, True)
        eval_run_context = self.run_context
        eval_run_context['metrics'] = {'accuracy': 0.78 + i + 1}
        eval_run_context['valid_dataset'] = self.run_context[
            'train_dataset']
        eval_run_context['step_num'] = 32
        eval_callback.end(RunContext(eval_run_context))

    # Presumably 4 files per round (2 summary records + 2 lineage logs),
    # giving 8 entries after two rounds.
    file_num = os.listdir(SUMMARY_DIR_2)
    assert len(file_num) == 8
def test_train_eval(self, *args):
    """Callback for train once and eval once."""
    args[0].return_value = 10
    summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_eval')
    make_directory(summary_dir)
    # args[1] patches the generated lineage log file name.
    args[1].return_value = os.path.join(
        summary_dir,
        f'train_out.events.summary.{str(int(time.time()))}.ubuntu_lineage')
    train_callback = TrainLineage(summary_dir)
    train_callback.begin(RunContext(self.run_context))
    train_callback.end(RunContext(self.run_context))

    args[1].return_value = os.path.join(
        summary_dir,
        f'eval_out.events.summary.{str(int(time.time())+1)}.ubuntu_lineage'
    )
    eval_callback = EvalLineage(summary_dir)
    # NOTE(review): aliases (does not copy) the shared fixture dict.
    eval_run_context = self.run_context
    eval_run_context['metrics'] = {'accuracy': 0.78}
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['step_num'] = 32
    eval_callback.end(RunContext(eval_run_context))

    # The recorded lineage must reflect the fixture's network/loss names.
    res = get_summary_lineage(summary_dir)
    assert res.get('hyper_parameters', {}).get('loss_function') \
        == 'SoftmaxCrossEntropyWithLogits'
    assert res.get('algorithm', {}).get('network') == 'ResNet'
    if os.path.exists(summary_dir):
        shutil.rmtree(summary_dir)
def test_train_with_customized_network(self, *args):
    """Test train with customized network."""
    args[0].return_value = 64
    train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
    run_context_customized = self.run_context
    # Remove the fields TrainLineage would normally read directly, so the
    # information has to be recovered from the customized train network.
    del run_context_customized['optimizer']
    del run_context_customized['net_outputs']
    del run_context_customized['loss_fn']
    # Build a customized TrainOneStep network and hand-wire its private
    # _cells dicts so backbone, loss function and optimizer are exposed.
    net = WithLossCell(self.net, self.loss_fn)
    net_cap = net
    net_cap._cells = {'_backbone': self.net,
                      '_loss_fn': self.loss_fn}
    net = TrainOneStep(net, self.optimizer)
    net._cells = {'optimizer': self.optimizer,
                  'network': net_cap,
                  'backbone': self.net}
    run_context_customized['train_network'] = net
    train_callback.begin(RunContext(run_context_customized))
    train_callback.end(RunContext(run_context_customized))

    # The lineage extracted from the customized network must still report
    # the original loss function, backbone and optimizer names.
    LINEAGE_DATA_MANAGER.start_load_data().join()
    res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER,
                                 search_condition=self._search_condition)
    assert res.get('object')[0].get('model_lineage', {}).get('loss_function') \
        == 'SoftmaxCrossEntropyWithLogits'
    assert res.get('object')[0].get('model_lineage', {}).get('network') == 'ResNet'
    assert res.get('object')[0].get('model_lineage', {}).get('optimizer') == 'Momentum'
def test_train_with_customized_network(self, *args):
    """Test train with customized network."""
    args[0].return_value = 64
    train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
    run_context_customized = self.run_context
    # Remove the fields TrainLineage would normally read directly, so the
    # information has to be recovered from the customized train network.
    del run_context_customized['optimizer']
    del run_context_customized['net_outputs']
    del run_context_customized['loss_fn']
    # Build a customized TrainOneStep network and hand-wire its private
    # _cells dicts so backbone, loss function and optimizer are exposed.
    net = WithLossCell(self.net, self.loss_fn)
    net_cap = net
    net_cap._cells = {'_backbone': self.net,
                      '_loss_fn': self.loss_fn}
    net = TrainOneStep(net, self.optimizer)
    net._cells = {
        'optimizer': self.optimizer,
        'network': net_cap,
        'backbone': self.net
    }
    run_context_customized['train_network'] = net
    train_callback.begin(RunContext(run_context_customized))
    train_callback.end(RunContext(run_context_customized))

    # The recorded lineage must still report the original component names.
    res = get_summary_lineage(summary_dir=SUMMARY_DIR)
    assert res.get('hyper_parameters', {}).get('loss_function') \
        == 'SoftmaxCrossEntropyWithLogits'
    assert res.get('algorithm', {}).get('network') == 'ResNet'
    assert res.get('hyper_parameters', {}).get('optimizer') == 'Momentum'
def test_training_end(self, *args):
    """Test the end function in TrainLineage."""
    args[0].return_value = 64
    callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
    callback.initial_learning_rate = 0.12

    # First end(): the fixture run context records 10 epochs.
    callback.end(RunContext(self.run_context))
    lineage = get_summary_lineage(SUMMARY_DIR)
    assert lineage.get('hyper_parameters', {}).get('epoch') == 10

    # Second end(): a bumped epoch count must be re-recorded.
    updated_context = self.run_context
    updated_context['epoch_num'] = 14
    callback.end(RunContext(updated_context))
    lineage = get_summary_lineage(SUMMARY_DIR)
    assert lineage.get('hyper_parameters', {}).get('epoch') == 14
def test_raise_exception_init(self):
    """Test exception when error happened during the initialization process."""
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)
    summary_record = SummaryRecord(SUMMARY_DIR_3)

    # Feed both callbacks an invalid summary record on purpose.
    bad_train_cb = TrainLineage('fake_summary_record', False)
    bad_eval_cb = EvalLineage('fake_summary_record', False)
    bad_train_cb.begin(RunContext(self.run_context))
    bad_eval_cb.end(RunContext(self.run_context))

    # Only the regular summary file exists; no "_lineage" file was written.
    entries = os.listdir(SUMMARY_DIR_3)
    assert len(entries) == 1
    assert not os.path.isfile(summary_record.full_file_name + "_lineage")
def test_train_begin(self):
    """Test the begin function in TrainLineage."""
    callback = TrainLineage(self.summary_record, True)
    callback.begin(RunContext(self.run_context))
    # begin() captures the learning rate and creates the lineage log file.
    assert callback.initial_learning_rate == 0.12
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file)
def test_checkpoint_save_ckpt_seconds():
    """Test checkpoint save ckpt seconds."""
    config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=100,
        keep_checkpoint_max=0,
        keep_checkpoint_per_n_minutes=1)
    checkpoint_cb = ModelCheckpoint(config=config)

    # Build a minimal train network and the callback parameters it needs.
    params = _InternalCallbackParam()
    network = Net()
    criterion = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    with_loss = WithLossCell(network, criterion)
    params.train_network = TrainOneStepCell(with_loss, opt)
    params.epoch_num = 10
    params.cur_epoch_num = 4
    params.cur_step_num = 128
    params.batch_num = 32
    ctx = RunContext(params)

    checkpoint_cb.begin(ctx)
    checkpoint_cb.step_end(ctx)

    # A fresh callback with the same config, but earlier in training.
    second_cb = ModelCheckpoint(config=config)
    params.cur_epoch_num = 1
    params.cur_step_num = 16
    second_cb.begin(ctx)
    second_cb.step_end(ctx)
def test_checkpoint_save_ckpt_with_encryption():
    """Test checkpoint save ckpt with encryption."""
    config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0,
        enc_key=os.urandom(16),
        enc_mode="AES-GCM")
    first_cb = ModelCheckpoint(config=config)

    # Minimal train network plus the callback parameters it needs.
    params = _InternalCallbackParam()
    network = Net()
    criterion = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(network, criterion), opt)
    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 160
    params.batch_num = 32
    ctx = RunContext(params)

    first_cb.begin(ctx)
    first_cb.step_end(ctx)

    second_cb = ModelCheckpoint(config=config)
    params.cur_epoch_num = 1
    params.cur_step_num = 15
    if platform.system().lower() == "windows":
        # Encrypted checkpointing raises NotImplementedError on Windows.
        with pytest.raises(NotImplementedError):
            second_cb.begin(ctx)
            second_cb.step_end(ctx)
    else:
        second_cb.begin(ctx)
        second_cb.step_end(ctx)
def test_step_end_save_graph():
    """Test save checkpoint."""
    config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    params = _InternalCallbackParam()
    network = LossNet()
    data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    network(data, label)
    params.train_network = network
    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 0
    params.batch_num = 32
    checkpoint_cb = ModelCheckpoint(prefix="test", directory='./test_files',
                                    config=config)
    ctx = RunContext(params)

    checkpoint_cb.begin(ctx)
    checkpoint_cb.step_end(ctx)
    graph_file = './test_files/test-graph.meta'
    # The first step_end should have dumped the graph file.
    assert os.path.exists(graph_file)
    if os.path.exists(graph_file):
        os.chmod(graph_file, stat.S_IWRITE)
        os.remove(graph_file)
    # A second step_end must not save the graph again.
    checkpoint_cb.step_end(ctx)
    assert not os.path.exists(graph_file)
def test_save_checkpoint():
    """Test save checkpoint."""
    config = CheckpointConfig(
        save_checkpoint_steps=16,
        save_checkpoint_seconds=0,
        keep_checkpoint_max=5,
        keep_checkpoint_per_n_minutes=0)
    params = _InternalCallbackParam()
    network = Net()
    criterion = nn.SoftmaxCrossEntropyWithLogits()
    opt = Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
    params.train_network = TrainOneStepCell(WithLossCell(network, criterion), opt)
    params.epoch_num = 10
    params.cur_epoch_num = 5
    params.cur_step_num = 0
    params.batch_num = 32
    checkpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files',
                                    config=config)
    ctx = RunContext(params)

    checkpoint_cb.begin(ctx)
    checkpoint_cb.step_end(ctx)

    # Remove the produced model file so later runs start clean.
    stale = './test_files/test_ckpt-model.pkl'
    if os.path.exists(stale):
        os.chmod(stale, stat.S_IWRITE)
        os.remove(stale)
def test_raise_exception_record_trainlineage(self, mock_analyze):
    """Test exception when error happened after recording training infos."""
    mock_analyze.return_value = 64
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)

    callback = TrainLineage(SUMMARY_DIR_3, True)
    callback.begin(RunContext(self.run_context))
    log_path = callback.lineage_summary.lineage_log_path

    # Training end appends lineage info, so the log file must grow.
    size_before = os.path.getsize(log_path)
    callback.end(RunContext(self.run_context))
    size_after_train = os.path.getsize(log_path)
    assert size_after_train > size_before

    # The eval callback must not append anything for this run context,
    # so the log size stays unchanged.
    eval_cb = EvalLineage(SUMMARY_DIR_3, False)
    eval_cb.end(RunContext(self.run_context))
    assert os.path.getsize(log_path) == size_after_train
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
    """
    Evaluation. The data would be passed to network directly.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Returns:
        Dict, returns the loss value & metrics values for the model in
        test mode.
    """
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)
    # Non-sink mode: preprocess the dataset for per-element feeding.
    dataset_helper, _ = self._exec_preprocess(self._eval_network,
                                              is_train=False,
                                              phase='eval',
                                              dataset=valid_dataset,
                                              dataset_sink_mode=False)
    for next_element in dataset_helper:
        cb_params.cur_step_num += 1
        list_callback.step_begin(run_context)
        # Feed each dataset element straight into the eval network.
        outputs = self._eval_network(*next_element)
        cb_params.net_outputs = outputs
        list_callback.step_end(run_context)
        # Accumulate metric state for this batch.
        self._update_metrics(outputs)
    # Rewind the dataset so it can be iterated again by a later call.
    valid_dataset.reset()
    metrics = self._get_metrics()
    cb_params.metrics = metrics
    list_callback.end(run_context)
    return metrics
def test_raise_exception_record_trainlineage(self, *args):
    """Test exception when error happened after recording training infos."""
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)
    # Make the patched dependency raise a MindInsightException when used.
    args[1].side_effect = MindInsightException(
        error=LineageErrors.PARAM_RUN_CONTEXT_ERROR,
        message="RunContext error.")

    callback = TrainLineage(SUMMARY_DIR_3, True)
    callback.begin(RunContext(self.run_context))
    log_path = callback.lineage_summary.lineage_log_path

    # Training end still appends lineage info: the log file must grow.
    size_initial = os.path.getsize(log_path)
    callback.end(RunContext(self.run_context))
    size_after_end = os.path.getsize(log_path)
    assert size_after_end > size_initial

    # The eval callback must not append anything for this run context.
    eval_cb = EvalLineage(SUMMARY_DIR_3, False)
    eval_cb.end(RunContext(self.run_context))
    assert os.path.getsize(log_path) == size_after_end
def test_eval_end(self):
    """Test the end function in EvalLineage."""
    callback = EvalLineage(self.summary_record, True,
                           {'eval_version': 'version2'})
    # Fill in the eval-specific fields the callback reads at end().
    context_args = self.run_context
    context_args['metrics'] = {'accuracy': 0.78}
    context_args['valid_dataset'] = self.run_context['train_dataset']
    context_args['step_num'] = 32
    callback.end(RunContext(context_args))
def test_train_begin_with_user_defined_info(self):
    """Test TrainLineage with nested user defined info."""
    nested_info = {"info": {"version": "v1"}}
    callback = TrainLineage(self.summary_record, False, nested_info)
    callback.begin(RunContext(self.run_context))
    # begin() captures the learning rate and creates the lineage log file.
    assert callback.initial_learning_rate == 0.12
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file)
def test_train_lineage_with_log_dir(self):
    """Test TrainLineage with log_dir."""
    log_dir = os.path.join(BASE_SUMMARY_DIR, 'log_dir')
    callback = TrainLineage(summary_record=log_dir)
    callback.begin(RunContext(self.run_context))
    # A plain directory path is accepted and used as the lineage log dir.
    assert log_dir == callback.lineage_log_dir
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file)
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
    """
    Evaluation. The data would be passed to network through dataset channel.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
        list_callback (ListCallback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Returns:
        Dict, returns the loss value & metrics values for the model in
        test mode.
    """
    _device_number_check(self._parallel_mode, self._device_number)

    run_context = RunContext(cb_params)

    # remove later to deal with loop sink
    # Wrap the network with DataWrapper only when loop sink is enabled,
    # GE is disabled, and the dataset has not been initialized yet.
    need_wrap = False
    if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
            and not context.get_context("enable_ge"):
        need_wrap = True
        valid_dataset.__loop_size__ = 1

    dataset_helper = DatasetHelper(valid_dataset)

    # remove later to deal with loop sink
    if need_wrap:
        self._eval_network = nn.DataWrapper(self._eval_network,
                                            *(dataset_helper.types_shapes()),
                                            valid_dataset.__ME_INITED__)
        self._eval_network.set_train(mode=False)
        self._eval_network.phase = 'eval'

    list_callback.begin(run_context)
    for inputs in dataset_helper:
        cb_params.cur_step_num += 1
        list_callback.step_begin(run_context)
        outputs = self._eval_network(*inputs)
        cb_params.net_outputs = outputs
        list_callback.step_end(run_context)
        # Accumulate metric state for this batch.
        self._update_metrics(outputs)

    metrics = self._get_metrics()
    cb_params.metrics = metrics
    list_callback.end(run_context)
    return metrics
def test_loss_monitor_normal_mode():
    """Test loss monitor normal(non-sink) mode."""
    params = _InternalCallbackParam()
    ctx = RunContext(params)
    monitor = LossMonitor(1)
    params.cur_epoch_num = 4
    params.cur_step_num = 1
    params.batch_num = 1
    params.net_outputs = Tensor(2.0)
    # Drive the callback through one full epoch/step lifecycle.
    for hook in (monitor.begin, monitor.epoch_begin, monitor.step_begin,
                 monitor.step_end, monitor.epoch_end, monitor.end):
        hook(ctx)
def test_training_end(self, *args):
    """Test the end function in TrainLineage."""
    # args are mocks injected by patch decorators; args[0] is given a
    # fixed return value.
    args[0].return_value = 64
    train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
    train_callback.initial_learning_rate = 0.12
    train_callback.end(RunContext(self.run_context))
    # Reload the lineage data and check the recorded epoch number.
    LINEAGE_DATA_MANAGER.start_load_data().join()
    res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER,
                                 search_condition=self._search_condition)
    assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 10

    # Ending again with an updated epoch count must overwrite the record.
    run_context = self.run_context
    run_context['epoch_num'] = 14
    train_callback.end(RunContext(run_context))
    LINEAGE_DATA_MANAGER.start_load_data().join()
    res = filter_summary_lineage(data_manager=LINEAGE_DATA_MANAGER,
                                 search_condition=self._search_condition)
    assert res.get('object')[0].get('model_lineage', {}).get('epoch') == 14

    # Also record an evaluation event against the same summary record.
    eval_callback = EvalLineage(self.summary_record, True,
                                self.user_defined_info)
    eval_run_context = self.run_context
    eval_run_context['metrics'] = {'accuracy': 0.78}
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    eval_run_context['step_num'] = 32
    eval_callback.end(RunContext(eval_run_context))
def test_Loss_Monitor_feed_feed_model():
    """Test Loss Monitor feed feed mode."""
    cb_args = _InternalCallbackParam()
    run_ctx = RunContext(cb_args)
    monitor = LossMonitor(1)
    cb_args.cur_epoch_num = 4
    cb_args.cur_step_num = 1
    cb_args.batch_num = 1
    cb_args.net_outputs = Tensor(2.0)
    # One complete lifecycle pass through the monitor.
    monitor.begin(run_ctx)
    monitor.epoch_begin(run_ctx)
    monitor.step_begin(run_ctx)
    monitor.step_end(run_ctx)
    monitor.epoch_end(run_ctx)
    monitor.end(run_ctx)
def test_RunContext():
    """Test RunContext."""
    # A plain int is not valid callback params and must raise TypeError.
    with pytest.raises(TypeError):
        RunContext(666)

    cb_args = _InternalCallbackParam()
    cb_args.member1 = 1
    cb_args.member2 = "abc"
    ctx = RunContext(cb_args)
    # original_args() must not disturb the stored members.
    ctx.original_args()
    assert cb_args.member1 == 1
    assert cb_args.member2 == "abc"

    # Requesting a stop must be observable through the getter.
    ctx.request_stop()
    assert ctx.get_stop_requested()
def test_loss_monitor_graph_model():
    """Test lossmonitor Graph model."""
    cb_args = _InternalCallbackParam()
    cb_args.cur_epoch_num = 4
    cb_args.cur_step_num = 2
    cb_args.batch_num = 2
    cb_args.net_outputs = Tensor(2.0)
    ctx = RunContext(cb_args)
    cb_list = _build_callbacks([LossMonitor(1)])
    # Run the whole callback-list lifecycle once.
    cb_list.begin(ctx)
    cb_list.epoch_begin(ctx)
    cb_list.step_begin(ctx)
    cb_list.step_end(ctx)
    cb_list.epoch_end(ctx)
    cb_list.end(ctx)
def test_train_begin_with_user_defined_key_in_lineage(self):
    """Test TrainLineage with nested user defined info."""
    defined_info = {
        "info": "info1",
        "version": "v1",
        "network": "LeNet"
    }
    # "network" is expected to be dropped from user_defined — presumably
    # because it collides with a lineage-reserved key.
    expected_user_defined = {"info": "info1", "version": "v1"}

    callback = TrainLineage(self.summary_record, False, defined_info)
    callback.begin(RunContext(self.run_context))
    assert callback.initial_learning_rate == 0.12
    log_file = callback.lineage_summary.lineage_log_path
    assert os.path.isfile(log_file)

    result = filter_summary_lineage(os.path.dirname(log_file))
    assert expected_user_defined == result['object'][0]['model_lineage'][
        'user_defined']
def test_raise_exception_create_file(self):
    """Test exception when error happened after creating file."""
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)
    summary_record = SummaryRecord(SUMMARY_DIR_3)
    eval_callback = EvalLineage(summary_record, False)
    full_file_name = summary_record.full_file_name + "_lineage"
    eval_run_context = self.run_context
    eval_run_context['metrics'] = {'accuracy': 0.78}
    eval_run_context['step_num'] = 32
    eval_run_context['valid_dataset'] = self.run_context['train_dataset']
    # Pre-create the lineage file, then force every write made through the
    # (mocked) built-in open() to fail with IOError.
    with open(full_file_name, 'ab'):
        with mock.patch('builtins.open') as mock_handler:
            mock_handler.return_value.__enter__.return_value.write.side_effect = IOError
            eval_callback.end(RunContext(eval_run_context))
    # The file exists but nothing could be written to it.
    assert os.path.isfile(full_file_name) is True
    assert os.path.getsize(full_file_name) == 0
def test_loss_monitor_sink_mode():
    """Test loss monitor sink mode."""
    cb_args = _InternalCallbackParam()
    cb_args.cur_epoch_num = 4
    cb_args.epoch_num = 4
    cb_args.cur_step_num = 2
    cb_args.batch_num = 2
    cb_args.net_outputs = Tensor(2.0)
    ctx = RunContext(cb_args)
    # _CallbackManager is a context manager; exercise the full lifecycle.
    with _CallbackManager([LossMonitor(1)]) as manager:
        manager.begin(ctx)
        manager.epoch_begin(ctx)
        manager.step_begin(ctx)
        manager.step_end(ctx)
        manager.epoch_end(ctx)
        manager.end(ctx)
def test_eval_only(self):
    """Test record evaluation event only."""
    eval_dir = os.path.join(BASE_SUMMARY_DIR, 'eval_only_dir')
    record = SummaryRecord(eval_dir)

    ctx_args = self.run_context
    ctx_args['metrics'] = {'accuracy': 0.58}
    ctx_args['valid_dataset'] = self.run_context['train_dataset']
    ctx_args['step_num'] = 32
    EvalLineage(record).end(RunContext(ctx_args))

    # Only metric and dataset_graph are queried; no training lineage exists,
    # so dataset_graph comes back empty.
    result = get_summary_lineage(eval_dir, ['metric', 'dataset_graph'])
    assert result == {
        'summary_dir': eval_dir,
        'dataset_graph': {},
        'metric': {
            'accuracy': 0.58
        }
    }
    shutil.rmtree(eval_dir)
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
    """
    Training process. The data would be passed to network directly.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
            loss_fn, a tuple with multiply data (data1, data2, data3, ...)
            should be returned and passed to the network. Otherwise, a tuple
            (data, label) should be returned, and the data and label are
            passed to the network and loss function respectively.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.

    Raises:
        ValueError: If loss_fn is set but a dataset element does not hold
            exactly two items (data, label).
    """
    dataset_helper, _ = self._exec_preprocess(self._train_network,
                                              is_train=True,
                                              phase='train',
                                              dataset=train_dataset,
                                              dataset_sink_mode=False)
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)
    # used to stop training for early stop, such as stopAtTIme or stopATStep
    should_stop = False
    for i in range(epoch):
        cb_params.cur_epoch_num = i + 1
        list_callback.epoch_begin(run_context)
        for next_element in dataset_helper:
            len_element = len(next_element)
            if self._loss_fn and len_element != 2:
                # BUGFIX: the adjacent string literals previously concatenated
                # to "...train_dataset shouldreturn two elements..." (missing
                # space); add the space so the message reads correctly.
                raise ValueError(
                    "when loss_fn is not None, train_dataset should "
                    "return two elements, but got {}".format(len_element))
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)
            overflow = False
            # When dynamic loss scaling drops overflowed updates, append the
            # current scaling factor as an extra network input.
            if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                scaling_sens = self._get_scaling_sens()
                next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
            cb_params.train_dataset_element = next_element
            outputs = self._train_network(*next_element)
            cb_params.net_outputs = outputs
            if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                # The outputs carry an overflow flag; feed it back to the
                # loss scale manager so it can adjust the scale.
                _, overflow, _ = outputs
                overflow = np.all(overflow.asnumpy())
                self._loss_scale_manager.update_loss_scale(overflow)
            list_callback.step_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break
        # Rewind the dataset for the next epoch.
        train_dataset.reset()
        list_callback.epoch_end(run_context)
        should_stop = should_stop or run_context.get_stop_requested()
        if should_stop:
            break
    list_callback.end(run_context)
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None,
                                cb_params=None, sink_size=-1):
    """
    Training process. The data would be passed to network through dataset channel.

    Args:
        epoch (int): Total number of iterations on the data.
        train_dataset (Dataset): A training dataset iterator. If there is no
            loss_fn, a tuple with multiply data (data1, data2, data3, ...)
            should be returned and passed to the network. Otherwise, a tuple
            (data, label) should be returned, and the data and label are
            passed to the network and loss function respectively.
        list_callback (Callback): Executor of callback list. Default: None.
        cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        sink_size (int): Control the amount of data each sink. Default: -1.
    """
    # With an explicit sink_size, scale the epoch count to cover the same
    # total number of samples.
    if sink_size == -1:
        epoch_num = epoch
    else:
        epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())

    # NOTE(review): the loop below alternates between a single
    # "second-order" sink step ('train0') and (self._frequency - 1)
    # "first-order" steps ('train1') — presumably the THOR optimizer's
    # update schedule; confirm against the caller.
    iter_first_order = self._frequency - 1
    iter_second_order = 1
    train_dataset.__loop_size__ = iter_second_order
    dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                          is_train=True,
                                                          phase='train',
                                                          dataset=train_dataset,
                                                          dataset_sink_mode=True,
                                                          sink_size=sink_size,
                                                          epoch_num=epoch_num,
                                                          iter_first_order=iter_first_order)
    self._train_network = train_network
    cb_params.train_network = self._train_network
    cb_params.cur_step_num = 0
    run_context = RunContext(cb_params)
    list_callback.begin(run_context)

    # used to stop training for early stop, such as stopAtTIme or stopATStep
    should_stop = False
    has_do_dataset_init = False
    switch_branch_one = True
    train_network_init_flag = True
    for i in range(epoch):
        cb_params.cur_epoch_num = i + 1
        list_callback.epoch_begin(run_context)
        # for data sink dataset_helper only iter once, other wise iter epoch_size times.
        for inputs in dataset_helper:
            if _need_to_full():
                inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
            list_callback.step_begin(run_context)
            if switch_branch_one:
                # Second-order phase: one full sink of steps under 'train0'.
                cb_params.cur_step_num += dataset_helper.sink_size()
                if train_network_init_flag:
                    self._train_network.add_flags_recursive(thor=True)
                self._train_network.phase = 'train0'
            else:
                # First-order phase: iter_first_order steps under 'train1'.
                cb_params.cur_step_num += iter_first_order
                if train_network_init_flag:
                    self._train_network.add_flags_recursive(thor=False)
                    # Both phases are initialized after the first full cycle.
                    train_network_init_flag = False
                self._train_network.phase = 'train1'
                if not has_do_dataset_init:
                    # Build the first-order data graph exactly once.
                    _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
                    has_do_dataset_init = True
            # Alternate between the two phases on every sink iteration.
            switch_branch_one = not switch_branch_one
            outputs = self._train_network(*inputs)
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)

        list_callback.epoch_end(run_context)
        should_stop = should_stop or run_context.get_stop_requested()
        if should_stop:
            break
    # Stop the dataset send thread before finishing.
    dataset_helper.stop_send()

    list_callback.end(run_context)