def test_train_with_customized_network(self, *args):
    """Test train with customized network.

    Builds a run context without the top-level ``optimizer``,
    ``net_outputs`` and ``loss_fn`` entries, whose ``train_network`` is a
    hand-assembled ``TrainOneStep(WithLossCell(...))``. It then checks that
    the recorded lineage still reports the loss function, network and
    optimizer names.
    """
    # NOTE(review): args are injected by mock.patch decorators not visible
    # here; confirm which target args[0] patches.
    args[0].return_value = 64
    train_callback = TrainLineage(self.summary_record, True)
    # Work on a shallow copy. Deleting keys from self.run_context directly
    # would change it for any later test that shares it, and would raise
    # KeyError if the keys had already been removed.
    run_context_customized = dict(self.run_context)
    for key in ('optimizer', 'net_outputs', 'loss_fn'):
        del run_context_customized[key]

    loss_net = WithLossCell(self.net, self.loss_fn)
    # Set the child-cell mappings explicitly on both wrapper cells.
    loss_net._cells = {'_backbone': self.net, '_loss_fn': self.loss_fn}
    train_net = TrainOneStep(loss_net, self.optimizer)
    train_net._cells = {
        'optimizer': self.optimizer,
        'network': loss_net,
        'backbone': self.net,
    }
    run_context_customized['train_network'] = train_net

    train_callback.begin(RunContext(run_context_customized))
    train_callback.end(RunContext(run_context_customized))

    res = get_summary_lineage(SUMMARY_DIR)
    hyper_parameters = res.get('hyper_parameters', {})
    assert hyper_parameters.get('loss_function') \
        == 'SoftmaxCrossEntropyWithLogits'
    assert res.get('algorithm', {}).get('network') == 'ResNet'
    assert hyper_parameters.get('optimizer') == 'Momentum'
def test_train_multi_eval(self, *args):
    """Callback for train once and eval twice.

    Records one training lineage and then two evaluation lineages in the
    same directory. After each evaluation, checks that
    ``get_summary_lineage`` reports that evaluation's accuracy. The
    directory is removed afterwards, even if an assertion fails.
    """
    # NOTE(review): args are injected by mock.patch decorators not visible
    # here; args[1].return_value is set to the lineage file path for the
    # next callback -- confirm the patched targets.
    args[0].return_value = 10
    summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_multi_eval')
    make_directory(summary_dir)
    try:
        args[1].return_value = os.path.join(
            summary_dir,
            'train_out.events.summary.1590107366.ubuntu_lineage')
        train_callback = TrainLineage(summary_dir, True)
        train_callback.begin(RunContext(self.run_context))
        train_callback.end(RunContext(self.run_context))

        # Each evaluation writes to its own file name and reports its own
        # accuracy.
        for timestamp, accuracy in ((1590107367, 0.79), (1590107368, 0.80)):
            args[1].return_value = os.path.join(
                summary_dir,
                f'eval_out.events.summary.{timestamp}.ubuntu_lineage')
            eval_callback = EvalLineage(summary_dir, True)
            # Copy so the eval-only keys are not written into the
            # (possibly shared) self.run_context.
            eval_run_context = dict(self.run_context)
            eval_run_context['valid_dataset'] = \
                self.run_context['train_dataset']
            eval_run_context['metrics'] = {'accuracy': accuracy}
            eval_callback.end(RunContext(eval_run_context))
            res = get_summary_lineage(summary_dir)
            assert res.get('metric', {}).get('accuracy') == accuracy
    finally:
        # Always clean up: leftover files with these fixed names would be
        # read again by the next run.
        if os.path.exists(summary_dir):
            shutil.rmtree(summary_dir)
def test_train_eval(self, *args):
    """Callback for train once and eval once.

    Records one training lineage and one evaluation lineage in the same
    directory, then checks the loss function and network names reported by
    ``get_summary_lineage``. The directory is removed afterwards, even if an
    assertion fails.
    """
    # NOTE(review): args are injected by mock.patch decorators not visible
    # here; args[1].return_value is set to the lineage file path for the
    # next callback -- confirm the patched targets.
    args[0].return_value = 10
    summary_dir = os.path.join(BASE_SUMMARY_DIR, 'train_eval')
    make_directory(summary_dir)
    # Read the clock once so the eval file name is always exactly one second
    # after the train file name.
    timestamp = int(time.time())
    try:
        args[1].return_value = os.path.join(
            summary_dir,
            f'train_out.events.summary.{timestamp}.ubuntu_lineage')
        train_callback = TrainLineage(summary_dir)
        train_callback.begin(RunContext(self.run_context))
        train_callback.end(RunContext(self.run_context))

        args[1].return_value = os.path.join(
            summary_dir,
            f'eval_out.events.summary.{timestamp + 1}.ubuntu_lineage')
        eval_callback = EvalLineage(summary_dir)
        # Copy so the eval-only keys are not written into the
        # (possibly shared) self.run_context.
        eval_run_context = dict(self.run_context)
        eval_run_context['metrics'] = {'accuracy': 0.78}
        eval_run_context['valid_dataset'] = self.run_context['train_dataset']
        eval_run_context['step_num'] = 32
        eval_callback.end(RunContext(eval_run_context))

        res = get_summary_lineage(summary_dir)
        assert res.get('hyper_parameters', {}).get('loss_function') \
            == 'SoftmaxCrossEntropyWithLogits'
        assert res.get('algorithm', {}).get('network') == 'ResNet'
    finally:
        if os.path.exists(summary_dir):
            shutil.rmtree(summary_dir)
def test_multiple_trains(self, *args):
    """
    Run TrainLineage and EvalLineage callbacks multiple times.

    TrainLineage and EvalLineage write to different files under the same
    directory; the lineage log file names end with '_lineage'.
    """
    args[0].return_value = 10
    # Read the clock once. Every create_time and file name below is offset
    # from this base, so crossing a second boundary mid-test cannot shift
    # individual values.
    base_time = int(time.time())
    for i in range(2):
        summary_record = SummaryRecord(SUMMARY_DIR_2,
                                       create_time=base_time + i)
        eval_record = SummaryRecord(SUMMARY_DIR_2,
                                    create_time=base_time + 10 + i)
        args[1].return_value = os.path.join(
            SUMMARY_DIR_2,
            f'train_out.events.summary.{base_time + 2 * i}.ubuntu_lineage')
        train_callback = TrainLineage(summary_record, True)
        train_callback.begin(RunContext(self.run_context))
        train_callback.end(RunContext(self.run_context))

        args[1].return_value = os.path.join(
            SUMMARY_DIR_2,
            f'eval_out.events.summary.{base_time + 2 * i + 1}.ubuntu_lineage')
        eval_callback = EvalLineage(eval_record, True)
        # Copy so the eval-only keys are not written into the
        # (possibly shared) self.run_context.
        eval_run_context = dict(self.run_context)
        eval_run_context['metrics'] = {'accuracy': 0.78 + i + 1}
        eval_run_context['valid_dataset'] = self.run_context['train_dataset']
        eval_run_context['step_num'] = 32
        eval_callback.end(RunContext(eval_run_context))

    # NOTE(review): the expected count assumes SUMMARY_DIR_2 holds no files
    # from earlier runs or other tests when this test starts.
    file_names = os.listdir(SUMMARY_DIR_2)
    assert len(file_names) == 8
def test_training_end(self, *args):
    """Test the end function in TrainLineage.

    Calls ``end`` twice on one callback. The first call uses
    ``self.run_context`` as-is and expects epoch 10. The second uses a copy
    with ``epoch_num`` set to 14 and expects epoch 14.
    """
    args[0].return_value = 64
    train_callback = TrainLineage(self.summary_record, True)
    train_callback.initial_learning_rate = 0.12
    train_callback.end(RunContext(self.run_context))
    res = get_summary_lineage(SUMMARY_DIR)
    assert res.get('hyper_parameters', {}).get('epoch') == 10

    # Copy before overriding epoch_num so the change does not leak into
    # self.run_context, which may be shared with other tests.
    run_context = dict(self.run_context)
    run_context['epoch_num'] = 14
    train_callback.end(RunContext(run_context))
    res = get_summary_lineage(SUMMARY_DIR)
    assert res.get('hyper_parameters', {}).get('epoch') == 14
def test_raise_exception_record_trainlineage(self, *args):
    """Test exception when error happened after recording training infos.

    Training lineage is recorded normally, and the log file must grow. The
    mocked ``args[1]`` is then made to raise ``MindInsightException``, and
    an ``EvalLineage`` created with ``False`` as its second argument
    (presumably ``raise_exception``) is run. The test expects no exception
    to escape and the log file size to stay unchanged.
    """
    # Start from an empty directory so file sizes reflect only this test.
    if os.path.exists(SUMMARY_DIR_3):
        shutil.rmtree(SUMMARY_DIR_3)

    failing_mock = args[1]
    failing_mock.side_effect = MindInsightException(
        error=LineageErrors.PARAM_RUN_CONTEXT_ERROR,
        message="RunContext error.")
    record = SummaryRecord(SUMMARY_DIR_3)

    lineage_cb = TrainLineage(record, True)
    lineage_cb.begin(RunContext(self.run_context))
    log_path = lineage_cb.lineage_summary.lineage_log_path
    size_after_begin = os.path.getsize(log_path)

    lineage_cb.end(RunContext(self.run_context))
    size_after_end = os.path.getsize(log_path)
    assert size_after_end > size_after_begin

    # The failing eval must not append anything to the lineage log.
    EvalLineage(record, False).end(RunContext(self.run_context))
    assert os.path.getsize(log_path) == size_after_end