def parse(self, result: Result) -> Result:
    for qt in result.question_types:
        # Translate each question type into a query statement
        if qt == 'year_status':
            self.trans_year_status(result['year'])
        elif qt == 'catalog_status':
            self.trans_catalog_status(result['year'], result['catalog'])
        elif qt == 'exist_catalog':
            self.trans_exist_catalog(result['year'])
        elif qt in ('index_value', 'indexes_m_compare', 'indexes_n_compare'):
            self.trans_index_value(result['year'], result['index'])
        elif qt == 'index_overall':
            self.trans_index_overall(result['year'], result['index'])
        elif qt in ('index_2_overall', 'indexes_overall_trend'):
            self.trans_indexes_overall(result['year'], result['index'])
        elif qt == 'index_compose':
            self.trans_index_compose(result['year'], result['index'])
        elif qt in ('indexes_2m_compare', 'indexes_2n_compare'):
            self.trans_indexes_mn_compare(result['year'], result['index'])
        elif qt == 'indexes_g_compare':
            self.trans_indexes_g_compare(result['year'], result['index'])
        elif qt in ('area_value', 'areas_m_compare', 'areas_n_compare'):
            self.trans_area_value(result['year'], result['area'], result['index'])
        elif qt == 'area_overall':
            self.trans_area_overall(result['year'], result['area'], result['index'])
        elif qt in ('area_2_overall', 'areas_overall_trend'):
            self.trans_areas_overall(result['year'], result['area'], result['index'])
        elif qt == 'area_compose':
            self.trans_area_compose(result['year'], result['index'])
        elif qt in ('areas_2m_compare', 'areas_2n_compare'):
            self.trans_areas_mn_compare(result['year'], result['area'], result['index'])
        elif qt == 'areas_g_compare':
            self.trans_areas_g_compare(result['year'], result['area'], result['index'])
        elif qt in ('indexes_trend', 'indexes_max'):
            self.trans_indexes_value(result['year'], result['index'])
        elif qt in ('areas_trend', 'areas_max'):
            self.trans_areas_value(result['year'], result['area'], result['index'])
        elif qt in ('index_change', 'indexes_change'):
            self.trans_index_change(result['year'])
        elif qt in ('catalog_change', 'catalogs_change'):
            self.trans_catalog_change(result['year'])
        elif qt == 'begin_stats':
            self.trans_begin_stats(result['index'])
        result.add_sql(qt, deepcopy(self.chain))
        self.chain.reset()
    return result
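The if/elif ladder above maps every question type to a translator method and the Result fields it consumes. The same dispatch could be expressed as a lookup table; a minimal sketch, assuming Result supports dict-style access as above (the QT_DISPATCH table and parse_with_table name are hypothetical, only the translator names come from the code):

from copy import deepcopy

# Hypothetical dispatch table: question type -> (translator name, Result keys).
# Each entry mirrors one branch of the if/elif chain in parse().
QT_DISPATCH = {
    'year_status': ('trans_year_status', ('year',)),
    'catalog_status': ('trans_catalog_status', ('year', 'catalog')),
    'index_value': ('trans_index_value', ('year', 'index')),
    'indexes_m_compare': ('trans_index_value', ('year', 'index')),
    'area_value': ('trans_area_value', ('year', 'area', 'index')),
    # ... remaining question types follow the same pattern.
}

def parse_with_table(self, result):
    for qt in result.question_types:
        name, keys = QT_DISPATCH[qt]
        getattr(self, name)(*(result[k] for k in keys))
        result.add_sql(qt, deepcopy(self.chain))
        self.chain.reset()
    return result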
def extract_index(self, result: Result, len_threshold: int = 4, ratio_threshold: float = 0.5):
    """Extract indicators that were missed due to typos or variant phrasing."""
    new_word, old_word = index_complement(result.filtered_question, self.index_wds,
                                          len_threshold, ratio_threshold)
    if new_word:
        debug('||REPLACE FOUND||', new_word, '<=', old_word)
        result.add_word(new_word, self.word_type_dict.get(new_word))
        result.replace_words(old_word, new_word)
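index_complement itself is not shown here. A minimal sketch of the kind of typo-tolerant lookup it appears to perform, using only the standard library (the sliding-window strategy and the index_complement_sketch name are assumptions, not the project's actual implementation; the thresholds mirror the parameters above):

import difflib

def index_complement_sketch(question, index_wds, len_threshold=4, ratio_threshold=0.5):
    """Find a known indicator that closely matches a substring of the question.

    Returns (matched_indicator, original_substring) or (None, None).
    """
    # Try windows from the full question down to len_threshold characters.
    for size in range(len(question), len_threshold - 1, -1):
        for start in range(0, len(question) - size + 1):
            chunk = question[start:start + size]
            # get_close_matches ranks index_wds by SequenceMatcher ratio.
            hits = difflib.get_close_matches(chunk, index_wds, n=1, cutoff=ratio_threshold)
            if hits:
                return hits[0], chunk
    return None, None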
def case(self):
    # Build the test suite by loading the test case class
    suite = unittest.TestLoader().loadTestsFromTestCase(AppDemo)
    # Instantiate an empty result set
    r = Result()
    # Run the cases and update the result set, recording passed and failed cases
    res = suite.run(r)
def question_filter(self, question: str) -> Result:
    question = question.replace(' ', '')
    # Normalize year expressions
    filtered_question = year_complement(question)
    # Extract feature words
    region_wds = []
    for w in self.region_tree.iter(filtered_question):
        region_wds.append(w[1][1])
    region_dict = {w: self.word_type_dict.get(w) for w in region_wds}
    return Result(region_dict, question, filtered_question)
def case(self):
    suite = unittest.TestLoader().loadTestsFromTestCase(QQDemo)
    local.result = Result()
    res = suite.run(local.result)
    logger.debug('Current thread name: %s' % threading.current_thread().getName())
    results = {threading.current_thread().getName(): res}
    for deviceName, result in results.items():
        html = HTMLTestAppRunner.HTMLTestRunner(
            stream=open(APPREPORT.format('{}.html'.format(deviceName)), 'wb'),
            verbosity=2, title='Test')
        html.generateReport('', result)
def question_filter(self, question: str) -> Result:
    question = question.replace(' ', '')
    # Normalize year expressions
    filtered_question = year_complement(question)
    # Extract feature words
    region_wds = []
    for w in self.region_tree.iter(filtered_question):
        region_wds.append(w[1][1])
    region_dict = {w: self.word_type_dict.get(w) for w in region_wds}
    # Recover indicators when extraction failed or the question has no index value
    if not region_dict or 'index' not in region_dict.values():
        new_word, old_word = index_complement(filtered_question, self.index_wds)
        if new_word:
            filtered_question = filtered_question.replace(old_word, new_word)
            region_dict[new_word] = self.word_type_dict.get(new_word)
    return Result(region_dict, question, filtered_question)
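The `w[1][1]` indexing suggests that `region_tree.iter()` yields `(end_index, (insert_order, word))` tuples, which is exactly the shape produced by a pyahocorasick automaton. A sketch of how such a tree could be built under that assumption (the project may construct its trie differently):

import ahocorasick

def build_region_tree(words):
    # Automaton.iter(text) yields (end_index, value) for every match,
    # where value is whatever was stored with add_word().
    tree = ahocorasick.Automaton()
    for i, w in enumerate(words):
        tree.add_word(w, (i, w))  # value = (insert_order, word), hence w[1][1]
    tree.make_automaton()
    return tree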
def case(self):
    # Build the test suite by loading the test case class
    suite = unittest.TestLoader().loadTestsFromTestCase(DingDang_Login)
    # Create an empty result set
    local.result = Result()
    # Run the cases and update the results, recording passed and failed cases
    res = suite.run(local.result)
    # Key the results by test phone name
    # (the current thread's name is the name of the phone running the tests)
    logger.debug('Current thread name: %s' % threading.current_thread().getName())
    results = {threading.current_thread().getName(): res}
    for deviceName, result in results.items():
        html = HTMLTestAppRunner.HTMLTestRunner(
            stream=open(APPREPORT_PATH.format('{}.html'.format(deviceName)), 'wb'),
            verbosity=2, title='Test Report')
        # generateReport is the main report-generating entry point
        html.generateReport('', result)
def case(self):
    # Build the test suite by loading the test case class
    local.suite = unittest.TestLoader().loadTestsFromTestCase(ThreadDemo)
    # Create an empty result set to hold execution results
    local.result = Result()
    # Run the cases and update the result set (e.g. execution status)
    local.res = local.suite.run(local.result)
    # Key the results by test phone name
    # (the current thread's name is the name of the phone running the tests)
    logger.debug('Current thread name: %s' % threading.current_thread().getName())
    results = {threading.current_thread().getName(): local.res}
    for deviceName, result in results.items():
        report_name = deviceName + '-' + time.strftime('%Y%m%d%H%M%S')
        html = HTMLTestAppRunner.HTMLTestRunner(
            stream=open(APP_REPORT.format('{}.html'.format(report_name)), 'wb'),
            verbosity=2, title='Test')
        # generateReport is the main report-generating entry point
        html.generateReport('', result)
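All three `case()` variants above stash the Result in a `threading.local()` so that suites running concurrently for different devices do not overwrite each other's results. A minimal self-contained sketch of that pattern (device names and the Demo test are illustrative):

import threading
import unittest

local = threading.local()

class Demo(unittest.TestCase):
    def test_ok(self):
        self.assertTrue(True)

def run_suite():
    # Each worker thread gets its own result object via thread-local storage.
    suite = unittest.TestLoader().loadTestsFromTestCase(Demo)
    local.result = unittest.TestResult()
    suite.run(local.result)
    print(threading.current_thread().getName(), local.result.wasSuccessful())

threads = [threading.Thread(target=run_suite, name=n) for n in ('device-A', 'device-B')]
for t in threads:
    t.start()
for t in threads:
    t.join()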
def main(locality_file):
    if args.l_shape:
        localities = ShapeLocalities(args, locality_file)
    if args.localities:
        localities = TextLocalities(args, locality_file)
    from lib.result import Result
    # This next part is probably redundant, since assigning the 'polygons'
    # variable is now done in the main() function.
    if args.p_shape:
        polygons = ShapePolygons(args)
    if args.polygons:
        polygons = TextPolygons(args)
    result = Result(polygons, args)
    done = 0
    # Index the geotiff files if available.
    if args.tif:
        from lib.readGeoTiff import indexTiffs
        try:
            index = indexTiffs(args.tif)
        except AttributeError:
            sys.exit("[ Error ] No such file '%s'" % args.tif[0])
    # Read the locality data and test if the coordinates
    # are located in any of the polygons.
    if args.localities or args.l_shape:
        numLoc = localities.getNrLocalities()
        result.setSpeciesNames(localities)
        # For each locality record ...
        for locality in localities.getLocalities():
            done = print_progress(done, numLoc)
            # ... and for each polygon ...
            for polygon in polygons.getPolygons():
                # ... test if the locality record is found in the polygon.
                # locality[0] = species name, locality[1] = latitude, locality[2] = longitude
                if pointInPolygon(polygon[1], locality):
                    # Test if elevation files are available.
                    if args.tif:
                        if elevationTest(locality[1], locality[2], polygon, index):
                            # Store the result
                            result.setResult(locality, polygon[0])
                    else:
                        # Store the result
                        result.setResult(locality, polygon[0])
    if args.gbif:
        gbifData = GbifLocalities(args)
        result.setSpeciesNames(gbifData)
        numLoc = gbifData.getNrLocalities()
        # For each GBIF locality record ...
        for locality in gbifData.getLocalities():
            done = print_progress(done, numLoc)
            # ... and for each polygon ...
            for polygon in polygons.getPolygons():
                # ... test if the locality record is found in the polygon.
                if pointInPolygon(polygon[1], locality):
                    # Test if elevation files are available.
                    if args.tif:
                        if elevationTest(locality[1], locality[2], polygon, index):
                            # Store the result
                            result.setResult(locality, polygon[0])
                    else:
                        # Store the result
                        result.setResult(locality, polygon[0])
    # Clean up the temporary chunk file created for multiprocessing runs.
    if args.np > 1:
        try:
            os.remove(locality_file)
        except OSError:
            pass
    sys.stderr.write("\n")
    return result
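pointInPolygon is imported from elsewhere in the package. A standard ray-casting implementation of the test it performs might look like the following (a sketch; the project's own signature takes a polygon plus a locality record, so the argument shapes here are assumptions):

def point_in_polygon(vertices, x, y):
    """Ray casting: count crossings of a horizontal ray from (x, y).

    vertices is a sequence of (x, y) pairs describing the polygon.
    """
    inside = False
    j = len(vertices) - 1
    for i in range(len(vertices)):
        xi, yi = vertices[i]
        xj, yj = vertices[j]
        # Does the edge (j, i) straddle the ray, and is the crossing to the right of x?
        if (yi > y) != (yj > y) and x < (xj - xi) * (y - yi) / (yj - yi) + xi:
            inside = not inside
        j = i
    return inside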
# Multiprocessing
if args.np > 1:
    from lib.splitLocalityFile import split_file
    from multiprocessing import Pool
    from lib.result import Result
    from lib.joinResults import joinResults
    if args.localities:
        tmp_input_files = split_file(args.localities, args.np)
    if args.gbif:
        tmp_input_files = split_file(args.gbif, args.np)
    pool = Pool(processes=args.np)
    result_objects = pool.map(main, tmp_input_files)
    # Instantiate a Result object to join the results from the parallel processes.
    finalResult = Result(polygons, args)
    Result.joinResults(finalResult, result_objects)
    plottResult(finalResult)
else:
    if args.test:
        if args.localities:
            from lib.testData import testLocality
            localities = TextLocalities(args)
            testLocality(localities, args.localities)
        if args.polygons:
            from lib.testData import testPolygons
            testPolygons(polygons, args.polygons)
    else:
def main():
    from lib.result import Result
    # Container for the polygon data.
    polygons = Polygons()
    result = Result(polygons, args)
    done = 0
    # Index the geotiff files if available.
    if args.tif:
        from lib.readGeoTiff import indexTiffs
        try:
            index = indexTiffs(args.tif)
        except AttributeError:
            sys.exit("[ Error ] No such file '%s'" % args.tif[0])
    # Read the locality data and test if the coordinates
    # are located in any of the polygons.
    if args.localities:
        localities = MyLocalities()
        numLoc = localities.getQuant()
        result.setSpeciesNames(localities)
        # For each locality record ...
        for locality in localities.getLocalities():
            done = print_progress(done, numLoc)
            # ... and for each polygon ...
            for polygon in polygons.getPolygons():
                # ... test if the locality record is found in the polygon.
                if localities.getCoOrder() == "lat-long":
                    # locality[0] = species name, locality[1] = latitude, locality[2] = longitude
                    if pointInPolygon(polygon[1], locality[2], locality[1]):
                        # Test if elevation files are available.
                        if args.tif:
                            if elevationTest(locality[1], locality[2], polygon, index):
                                # Store the result
                                result.setResult(locality, polygon[0])
                        else:
                            # Store the result
                            result.setResult(locality, polygon[0])
                else:
                    # locality[0] = species name, locality[1] = longitude, locality[2] = latitude
                    if pointInPolygon(polygon[1], locality[1], locality[2]):
                        if args.tif:
                            if elevationTest(locality[2], locality[1], polygon, index):
                                result.setResult(locality, polygon[0])
                        else:
                            result.setResult(locality, polygon[0])
    if args.gbif:
        gbifData = GbifLocalities()
        result.setSpeciesNames(gbifData)
        numLoc = gbifData.getQuant()
        # For each GBIF locality record ...
        for locality in gbifData.getLocalities():
            done = print_progress(done, numLoc)
            # ... and for each polygon ...
            for polygon in polygons.getPolygons():
                # ... test if the locality record is found in the polygon.
                if pointInPolygon(polygon[1], locality[2], locality[1]):
                    # Test if elevation files are available.
                    if args.tif:
                        if elevationTest(locality[1], locality[2], polygon, index):
                            # Store the result
                            result.setResult(locality, polygon[0])
                    else:
                        # Store the result
                        result.setResult(locality, polygon[0])
    sys.stderr.write("\n")
    result.printNexus(args.out)
    if args.plot:
        import os
        from lib.plot import prepare_plots
        prepare_plots(result, polygons)
        # __ GUI STUFF
        dir_output = args.dir_output  # Working directory
        path_script = args.path_script
        cmd = "Rscript %s/R/graphical_output.R %s %s %s %s %s %s" \
            % (path_script, path_script, "coordinates.sgc.txt", "polygons.sgc.txt",
               "sampletable.sgc.txt", "speciestable.sgc.txt", dir_output)
        os.system(cmd)
    if args.stochastic_mapping:
        import lib.stochasticMapping as stochasticMapping
        # Run the stochastic mapping analysis
        stochasticMapping.main(args, result)
# Multiprocessing
if args.np > 1:
    from lib.splitLocalityFile import split_file
    from multiprocessing import Pool
    from lib.result import Result
    from lib.joinResults import joinResults
    if args.localities:
        tmp_input_files = split_file(args.localities, args.np)
    if args.gbif:
        tmp_input_files = split_file(args.gbif, args.np)
    pool = Pool(processes=args.np)
    result_objects = pool.map(main, tmp_input_files)
    # Instantiate a Result object to join the results from the parallel processes.
    finalResult = Result(polygons, args)
    Result.joinResults(finalResult, result_objects)
    plottResult(finalResult)
else:
    if args.test:
        if args.localities:
            from lib.testData import testLocality
            localities = TextLocalities(args)
            testLocality(localities, args.localities)
        if args.polygons:
            from lib.testData import testPolygons
            testPolygons(polygons, args.polygons)
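Both multiprocessing blocks follow the same scatter/gather shape: split the input file into args.np chunks, map a worker over the chunks in a process pool, then fold the partial Result objects together. Stripped of the project specifics, the pattern is (worker, merge, and the chunk data are stand-ins):

from multiprocessing import Pool

def worker(chunk):
    # Stand-in for main(locality_file): process one chunk, return a partial result.
    return sum(chunk)

def merge(partials):
    # Stand-in for Result.joinResults(): fold the partial results together.
    return sum(partials)

if __name__ == '__main__':
    chunks = [[1, 2], [3, 4], [5, 6]]       # stand-in for split_file() output
    with Pool(processes=len(chunks)) as pool:
        partials = pool.map(worker, chunks)  # one worker per chunk
    print(merge(partials))                   # 21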
def run_experiment():
    args = get_args()
    args.experiment = "results_physionet"
    N_exp = 10
    args.ckpt_dir = "paper"
    args.data_dir = "../data"
    args.dataset = "physionet_burst"
    args.n_epochs = 100
    args.z_dim = 35
    args.K = 10
    args.plot_every = 20
    args.kl_annealing_epochs = 20
    for n_exp in range(N_exp):
        # Shi-VAE
        args.model_name = '{}_{}_{}z_{}h_{}s_{}'.format(
            args.model, args.dataset, args.z_dim, args.h_dim, args.K, n_exp)
        args.result_dir = os.path.join(args.ckpt_dir, args.experiment, args.model_name)
        args.ckpt_file = os.path.join(args.result_dir, args.model_name + ".pth")
        args.best_ckpt_file = os.path.join(args.result_dir, args.model_name + "_best.pth")
        # Restore training
        if args.restore == 1:
            if not os.path.isfile(args.ckpt_file):
                print('Model not found at {}'.format(args.ckpt_file))
                sys.exit()
            model_dict = torch.load(args.ckpt_file)
            n = args.n_epochs
            # Restore args from the training args
            args = model_dict['params']
            args.n_epochs = n
            args.restore = 1
        # Print arguments
        print('ARGUMENTS')
        for arg in vars(args):
            print('{} = {}'.format(arg, getattr(args, arg)))
        # Create checkpoint directory
        if not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)
        # Create results directory
        if not os.path.exists(args.result_dir):
            os.makedirs(args.result_dir)

        # ============= LOAD DATA ============= #
        data = np.load(os.path.join(args.data_dir, args.dataset, args.dataset + ".npz"))
        types_csv = os.path.join(args.data_dir, args.dataset, "data_types_real.csv")
        types_list = utils.read_csv_types(types_csv)
        # Train
        x_train = data["x_train_miss"].astype(np.float32)
        x_train_full = data["x_train_full"].astype(np.float32)
        m_train = data["m_train_miss"].astype(bool)
        m_train_artificial = data["m_train_artificial"].astype(bool)
        y_train = data["y_train"]
        # Validation
        x_val = data["x_val_miss"].astype(np.float32)
        x_val_full = data["x_val_full"].astype(np.float32)
        m_val = data["m_val_miss"].astype(bool)
        m_val_artificial = data["m_val_artificial"].astype(bool)
        y_val = data["y_val"]
        # Test
        x_test = data["x_test_miss"].astype(np.float32)
        x_test_full = data["x_test_full"].astype(np.float32)
        m_test = data["m_test_miss"].astype(bool)
        m_test_artificial = data["m_test_artificial"].astype(bool)
        y_test = data["y_test"]

        # ===== Scaler ===== #
        scaler = HeterogeneousScaler(types_list)
        scaler.fit(x_train, m_train)

        data_train = dset.HeterDataset(x_train, m_train, x_train_full, m_train_artificial, types_list=types_list)
        data_valid = dset.HeterDataset(x_val, m_val, x_val_full, m_val_artificial, types_list=types_list)
        data_test = dset.HeterDataset(x_test, m_test, x_test_full, m_test_artificial, types_list=types_list)
        train_loader = torch.utils.data.DataLoader(data_train, batch_size=64, shuffle=True,
                                                   collate_fn=dset.standard_collate)
        valid_loader = torch.utils.data.DataLoader(data_valid, batch_size=64, shuffle=False,
                                                   collate_fn=dset.standard_collate)
        test_loader = torch.utils.data.DataLoader(data_test, batch_size=64, shuffle=False,
                                                  collate_fn=dset.standard_collate)

        # ============= MODEL ============= #
        # Shi-VAE
        from models.shivae import ShiVAE
        model = ShiVAE(h_dim=args.h_dim, z_dim=args.z_dim, s_dim=args.K,
                       types_list=types_list, n_layers=1, learn_std=False)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('Trainable params: {}'.format(total_params))

        # ============= TRAIN ============= #
        # Train model
        from models.trainers import Trainer
        if args.train == 1 or args.train == -1:
            trainer = Trainer(model, optimizer, args, scaler=scaler)
            # Train from a pretrained model
            if args.restore == 1 and os.path.isfile(args.ckpt_file):
                print('Model loaded at {}'.format(args.ckpt_file))
                trainer.load_checkpoint(model_dict)
            print('Training points: {}'.format(len(train_loader.dataset)))
            trainer.train(train_loader, test_loader)

        # ============= RESULTS ============= #
        if args.train == 0 or args.train == -1:
            from lib.result import Result
            result_dir = os.path.dirname(args.ckpt_file)
            print('Save images in: {}'.format(result_dir))
            # Load the pretrained model
            model_dict = torch.load(args.best_ckpt_file)
            model.load_state_dict(model_dict['state_dict'])
            # Create the test loader
            test_loader = torch.utils.data.DataLoader(data_test, batch_size=64, shuffle=False,
                                                      collate_fn=dset.standard_collate)
            # Reconstruction and generation
            result = Result(test_loader, scaler, model, result_dir, args)
            model_name = "ShiVAE"
            result.avg_error(model_name=model_name)
            result.reconstruction(types_list=types_list)
            result.generation(args.result_imgs, types_list=types_list)

        # ===== Save args ===== #
        args_path = os.path.join(args.result_dir, args.model_name) + args.model_name + '.json'
        save_args(args, args_path)
print('Training points: {}'.format(len(train_loader.dataset)))
trainer.train(train_loader, valid_loader)

# ============= RESULTS ============= #
if args.train == 0 or args.train == -1:
    from lib.result import Result
    result_dir = os.path.dirname(args.ckpt_file)
    print('Save images in: {}'.format(result_dir))
    # Load the pretrained model
    model_dict = torch.load(args.ckpt_file)
    model.load_state_dict(model_dict['state_dict'])
    # Create the test loader
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=64, shuffle=False,
                                              collate_fn=dset.standard_collate)
    # Reconstruction and generation
    result = Result(test_loader, scaler, model, result_dir, args)
    model_name = "ShiVAE"
    result.avg_error(model_name=model_name)
    result.reconstruction(types_list=types_list)
    result.generation(args.result_imgs, types_list=types_list)

# ===== Save args ===== #
args_path = os.path.join(args.result_dir, args.model_name) + args.model_name + '.json'
save_args(args, args_path)
def _classify_tree(self, result: Result):
    # Collect entity types
    question = result.filtered_question
    year_count = result.count('year')
    # Question involves a single year
    if year_count == 1:
        # Overall status for the whole year
        if check_contain(self.status_rwds, question) and 'year' in result and len(result) == 1:
            result.add_qtype('year_status')
        # Catalogs that exist in the year
        if check_contain(self.exist_qwds, question) and check_contain(self.catalog_rwds, question):
            result.add_qtype('exist_catalog')
        # Catalog
        if 'catalog' in result:
            # Overall status
            if check_contain(self.status_rwds, question):
                result.add_qtype('catalog_status')
        # Indicator
        if 'index' in result:
            # Value
            if check_contain(self.value_qwds, question) or check_endswith(self.is_twds, question):
                if not check_contain(self.child_index_rwds, question):
                    # Involves an area
                    if 'area' in result:
                        result.add_qtype('area_value')
                    else:
                        result.add_qtype('index_value')
            # Value comparison (against the parent indicator)
            if check_regexp(question, MultipleCmp1,
                            functions=[lambda x: check_contain(self.parent_index_rwds, x[0][-1])],
                            callback=lambda x: QuestionOrderError.check(x, self.parent_index_rwds)):
                # Involves an area
                if 'area' in result:
                    result.add_qtype('area_overall')
                else:
                    result.add_qtype('index_overall')
            # Value comparison (same kind and same unit)
            if result.count('index') == 2 and 'area' not in result:
                if check_regexp(question, MultipleCmp1, functions=[
                    lambda x: check_list_contain(result['index'], x[0], 0, -1)
                ]):
                    result.add_qtype('indexes_m_compare')  # multiple (ratio) comparison
                if check_regexp(question, NumberCmp1, functions=[
                    lambda x: (check_list_contain(result['index'], x[0], 0, -1)
                               or check_all_contain(result['index'], x[0][0]))
                ]):
                    result.add_qtype('indexes_n_compare')  # quantity (difference) comparison
            # Area value comparison (same indicator, different areas)
            if result.count('index') == 1 and result.count('area') == 2:
                if check_regexp(question, MultipleCmp1, functions=[
                    lambda x: (check_list_contain(result['area'], x[0], 0, -1)
                               and check_list_contain(result['index'], x[0], 0))
                ]):
                    result.add_qtype('areas_m_compare')  # multiple (ratio) comparison
                if check_regexp(question, NumberCmp1, functions=[
                    lambda x: ((check_list_contain(result['area'], x[0], 0, -1)
                                and check_contain(result['index'], x[0][0]))
                               or (check_all_contain(result['area'], x[0][0])
                                   and check_list_any_contain(result['index'], x[0], 0, -1)))
                ]):
                    result.add_qtype('areas_n_compare')  # quantity (difference) comparison
            # Year-over-year growth comparison
            if check_regexp(question, GrowthCmp, functions=[
                lambda x: check_all_contain(result['index'], x[0])
            ]):
                if 'area' in result:
                    # Single area, multiple indicators
                    if result.count('area') == 1:
                        result.add_qtype('areas_g_compare')
                else:
                    result.add_qtype('indexes_g_compare')
            if check_contain(self.location_rwds, question):
                # Composition across areas under an indicator
                if check_contain(self.status_rwds, question):
                    result.add_qtype('area_compose')
            else:
                # Sub-composition of an indicator
                if check_contain(self.child_index_rwds, question):
                    result.add_qtype('index_compose')
    # Question involves two years
    elif year_count == 2:
        # Changes of catalogs or indicators
        if result.count('year') == len(result):
            if check_contain(self.catalog_rwds, question):
                result.add_qtype('catalog_change')
            elif check_contain(self.index_rwds, question):
                result.add_qtype('index_change')
        # Indicator
        if 'index' in result:
            if check_contain(self.parent_index_rwds, question):
                # Change in the share of the parent indicator
                if check_regexp(question, NumberCmp2[0], NumberCmp2[1], functions=[
                    lambda x: check_contain(self.parent_index_rwds, x[0])
                ] * 2):
                    if 'area' not in result:
                        result.add_qtype('index_2_overall')
                    elif check_regexp(question, NumberCmp2[0], NumberCmp2[1], functions=[
                        lambda x: check_contain(result['area'], x[0])
                    ] * 2):
                        result.add_qtype('area_2_overall')
            else:
                # Compare values
                if check_regexp(question, *NumberCmp2, functions=[
                    lambda x: check_contain(result['index'], x[0]),
                    lambda x: check_contain(result['index'], x[0]),
                    lambda x: check_contain(result['index'], x[0][-1])
                ]):
                    if 'area' not in result:
                        # No area involved
                        result.add_qtype('indexes_2n_compare')
                    else:
                        # Involves areas
                        if result.count('index') == 1:
                            # Different areas under a single indicator
                            if check_regexp(question, *NumberCmp2, functions=[
                                lambda x: check_contain(result['area'], x[0]),
                                lambda x: check_contain(result['area'], x[0]),
                                lambda x: check_contain(result['area'], x[0][0]),
                            ]):
                                result.add_qtype('areas_2n_compare')
                # Compare multiples
                if check_regexp(question, MultipleCmp2, functions=[
                    lambda x: check_list_any_contain(result['index'], x[0], 0, -1)
                ]):
                    if 'area' not in result:
                        # No area involved
                        result.add_qtype('indexes_2m_compare')
                    else:
                        # Involves areas
                        if check_regexp(question, MultipleCmp2, functions=[
                            lambda x: check_list_any_contain(result['area'], x[0], 0, -1)
                        ]):
                            result.add_qtype('areas_2m_compare')
    # Question involves more than two years
    elif year_count > 2:
        # Trend of indicator/catalog changes
        if result.count('year') == len(result) and check_contain(self.status_rwds, question):
            if check_contain(self.catalog_rwds, question):
                result.add_qtype('catalogs_change')
            elif check_contain(self.index_rwds, question):
                result.add_qtype('indexes_change')
        # Trends for indicators
        if 'index' in result:
            # Share of the parent indicator
            if check_regexp(question, MultipleCmp1, functions=[
                lambda x: (check_contain(result['index'], x[0][0])
                           and check_contain(self.status_rwds, x[0][-1])
                           and check_contain(self.parent_index_rwds, x[0][-1]))
            ]):
                if 'area' in result:
                    result.add_qtype('areas_overall_trend')
                else:
                    result.add_qtype('indexes_overall_trend')
            # Value trend
            if check_contain(self.status_rwds, question) and not check_contain(self.parent_index_rwds, question):
                if 'area' in result:
                    result.add_qtype('areas_trend')
                else:
                    result.add_qtype('indexes_trend')
            # Maximum value
            if check_contain(self.max_rwds, question):
                if 'area' in result:
                    result.add_qtype('areas_max')
                else:
                    result.add_qtype('indexes_max')
    # Question involves no year
    else:
        if 'index' in result and check_contain(self.when_qwds, question):
            result.add_qtype('begin_stats')
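The classifier leans on a handful of string predicates (check_contain, check_all_contain, check_endswith) whose implementations are not shown. From the call sites they plausibly reduce to the following sketch of assumed semantics, not the project's actual code:

def check_contain(words, text):
    # True if any feature word occurs in the text.
    return any(w in text for w in words)

def check_all_contain(words, text):
    # True only if every feature word occurs in the text.
    return all(w in text for w in words)

def check_endswith(words, text):
    # True if the text ends with any of the feature words.
    return any(text.endswith(w) for w in words)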