def test_increment_train(self):
     """Test incremental training with ``mode`` set to ``runner1``.

     Runs the YAML config through ``self.run_yaml`` and asserts that the
     return code of ``self.pro`` is 0, ``self.err`` contains no
     ``Traceback``, the number of ``self.epoch_re`` matches in ``self.out``
     is 2, and the value logged for ``mode`` is ``runner1``.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content['mode'] = 'runner1'
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     built_in.regex_match_len(self.out, self.epoch_re, 2, self.err_msg)
     # Raw string: backslash-s / backslash-S are invalid escapes in a plain
     # literal (SyntaxWarning on 3.12+); the regex engine still reads the
     # raw backslash-n as a newline.
     built_in.regex_match_equal(
         self.out,
         r'\nmode\s+(\S+)\s+\n',
         'runner1',
         self.err_msg)
 def test_mode_list_single_selected_gpus_1card_c2(self):
     """Test one selected GPU; the trainer is expected to use ``single``.

     Sets the first runner to ``device='gpu'`` with ``selected_gpus='0'``,
     runs the YAML config, and asserts a zero return code, no
     ``Traceback`` in ``self.err``, 2 ``self.epoch_re`` matches in
     ``self.out``, and ``train.trainer.engine`` logged as ``single``.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content["runner"][0]["device"] = 'gpu'
     self.yaml_content["runner"][0]["selected_gpus"] = "0"
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     built_in.regex_match_len(self.out, self.epoch_re, 2, self.err_msg)
     # Raw string avoids invalid escape sequences (SyntaxWarning on 3.12+);
     # dots are escaped so the config key is matched literally.
     built_in.regex_match_equal(
         self.out,
         r'\ntrain\.trainer\.engine\s+(\S+)\s+\n',
         "single",
         self.err_msg)
示例#3
0
 def test_optimizer_lr_le(self):
     """Test that the configured optimizer learning rate is applied.

     Configures an ``SGD`` optimizer with ``learning_rate=2e-2`` and
     ``reg=0.1``, runs the YAML config, and asserts a zero return code, no
     ``Traceback`` in ``self.err``, 2 ``self.epoch_re`` matches in
     ``self.out``, and ``hyper_parameters.optimizer.learning_rate`` logged
     as ``0.02``.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content["hyper_parameters"]['optimizer']['class'] = 'SGD'
     self.yaml_content["hyper_parameters"]['optimizer']['learning_rate'] = 2e-2
     self.yaml_content["hyper_parameters"]['reg'] = 0.1
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     built_in.regex_match_len(self.out, self.epoch_re, 2, self.err_msg)
     # Raw string avoids invalid escape sequences (SyntaxWarning on 3.12+);
     # dots are escaped so the config key is matched literally.
     built_in.regex_match_equal(
         self.out,
         r'\nhyper_parameters\.optimizer\.learning_rate\s+(\S+)\s+\n',
         '0.02',
         self.err_msg)
示例#4
0
 def test_QueueDataset_train_c2(self):
     """Test training with the first dataset's type set to QueueDataset.

     Runs the YAML config and asserts a zero return code, no ``Traceback``
     in ``self.err``, 2 ``self.epoch_re`` matches in ``self.out``, and the
     logged ``dataset.dataset_train.type`` equal to the platform-dependent
     expected type.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content["dataset"][0]["type"] = "QueueDataset"
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     built_in.regex_match_len(self.out, self.epoch_re, 2, self.err_msg)
     # NOTE: Windows and macOS are forcibly switched to DataLoader. This test
     # also expects DataLoader on Python 3; only Linux + Python 2 is
     # expected to keep QueueDataset.
     if utils.get_platform() == "LINUX" and six.PY2:
         check_type = "QueueDataset"
     else:
         check_type = "DataLoader"
     # Dots are escaped so the config key is matched literally.
     built_in.regex_match_equal(
         self.out,
         r'\ndataset\.dataset_train\.type\s+(\S+)\s+\n',
         check_type,
         self.err_msg)
示例#5
0
 def test_thread_num(self):
     """Test that a per-phase ``thread_num`` setting is picked up.

     Appends a second phase ``phase2`` (model ``{workspace}/model.py``,
     dataset ``dataset_infer``, ``thread_num=2``) to the config, runs it,
     and asserts a zero return code, no ``Traceback`` in ``self.err``, 4
     ``epoch.+done`` matches in ``self.out``, and
     ``phase.phase2.thread_num`` logged as ``2``.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content['phase'].append({
         'name': 'phase2',
         'model': '{workspace}/model.py',  # user-defined model
         'dataset_name': 'dataset_infer',  # select dataset by name
         'thread_num': 2
     })
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     # Presumably the epoch count spans both phases -- TODO confirm.
     built_in.regex_match_len(self.out, 'epoch.+done', 4, self.err_msg)
     # Regex capture groups are str, so compare against '2' rather than the
     # int 2, consistent with the string expectations in the other tests.
     # Raw string avoids invalid escapes; dots are escaped for a literal key.
     built_in.regex_match_equal(
         self.out,
         r'\nphase\.phase2\.thread_num\s+(\S+)\s+\n',
         '2',
         self.err_msg)
示例#6
0
 def test_mode_list_ps_selected_gpus_2f_2card_c2(self):
     """Test two selected GPUs with two data files and no fleet mode set.

     The expectation is that ps mode is changed to collective and run with
     the local_cluster engine. Sets the first runner to ``device='gpu'``
     with ``selected_gpus='0,1'`` and the first dataset's ``data_path`` to
     ``criteo_data``, then asserts a zero return code, no ``Traceback`` in
     ``self.err``, that ``logs/server.0`` does not exist,
     ``train.trainer.engine`` logged as ``local_cluster``, and 6
     ``self.auc_re`` matches for ``logs/worker.1``.
     """
     # Name the YAML config after the currently executing test method.
     self.yaml_config_name = sys._getframe().f_code.co_name + '.yaml'
     self.yaml_content["runner"][0]["device"] = 'gpu'
     self.yaml_content["runner"][0]["selected_gpus"] = "0,1"
     self.yaml_content["dataset"][0]["data_path"] = "criteo_data"
     self.run_yaml()
     built_in.equals(self.pro.returncode, 0, self.err_msg)
     built_in.not_contains(self.err, 'Traceback', self.err_msg)
     # Collective mode should not start a parameter server.
     built_in.path_not_exist('logs/server.0', self.err_msg)
     # Raw string avoids invalid escape sequences (SyntaxWarning on 3.12+);
     # dots are escaped so the config key is matched literally.
     built_in.regex_match_equal(
         self.out,
         r'\ntrain\.trainer\.engine\s+(\S+)\s+\n',
         "local_cluster",
         self.err_msg)
     # NOTE(review): other regex_match_len calls pass captured text
     # (self.out); here a path string is passed. Confirm regex_match_len
     # reads the file when given a path; otherwise read the log first.
     built_in.regex_match_len('logs/worker.1', self.auc_re, 6, self.err_msg)