def get_meta_data(api: API=api):
    """Collect the session-wide metadata exposed by ``api`` and wrap it.

    The payload contains the serialized Rewrite, Group and Attribute
    registries, the serialized predictors held by ``api``, the error
    overlap requested with ``show_filtered_err_overlap=False``, the
    current anchor/compare predictor, the selected rewrite, and the
    number of entries in ``Instance.qid_hash``.

    Args:
        api: API object to query; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the metadata dict
        (None if anything raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        err_overlaps = api.get_err_overlap(show_filtered_err_overlap=False)
        rewrites = [r.serialize() for r in Rewrite.values()]
        groups = [g.serialize() for g in Group.values()]
        attrs = [a.serialize() for a in Attribute.values()]
        predictors = [p.serialize() for p in api.predictors.values()]
        output = {
            'total_size': len(Instance.qid_hash),
            'anchor_predictor': api.get_anchor_predictor(),
            'compare_predictor': api.get_compare_predictor(),
            'selected_rewrite': api.get_selected_rewrite(),
            'predictors': predictors,
            'attributes': attrs,
            'groups': groups,
            'err_overlaps': err_overlaps,
            'rewrites': rewrites,
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def set_anchor_predictor(model: str, api: API=api):
    """Make ``model`` the anchor predictor and return refreshed metadata.

    Args:
        model: Name handed to ``api.set_anchor_predictor``.
        api: API object to update; defaults to the module-level ``api``.

    Returns:
        The result of ``get_meta_data(api)`` on success. If setting the
        predictor (or gathering the metadata) raises, the exception is
        logged and ``wrap_output(None, <exception>)`` is returned instead.
    """
    try:
        api.set_anchor_predictor(model)
        return get_meta_data(api)
    except Exception as err:
        logger.error(err)
        traceback.print_exc()
        # No payload is available on failure; only the error is reported.
        return wrap_output(None, err)
def formalize_rewritten_examples(rid: str, api: API=api):
    """Formalize the previously tried rewrites for rule ``rid``.

    When ``Rewrite.exists(rid)`` is true, ``api.formalize_prev_tried_rewrites``
    is invoked and the rewrite is then fetched and serialized. When the rule
    does not exist nothing is done and the output stays None with no error.

    Args:
        rid: Identifier of the rewrite rule.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the serialized
        rewrite (or None) and ``msg`` is the caught exception (or None).
    """
    output, msg = None, None
    try:
        if Rewrite.exists(rid):
            api.formalize_prev_tried_rewrites(rid)
            rewrite = Rewrite.get(rid)
            output = rewrite.serialize()
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def set_compare_predictor(model: str, api: API=api):
    """Make ``model`` the compare predictor and report the new state.

    Args:
        model: Name handed to ``api.set_compare_predictor``.
        api: API object to update; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is a dict with the
        keys ``'compare_predictor'`` and ``'err_overlaps'`` (None if anything
        raised) and ``msg`` is the caught exception (None on success).
    """
    output, msg = None, None
    try:
        api.set_compare_predictor(model)
        output = {
            'compare_predictor': api.get_compare_predictor(),
            'err_overlaps': api.get_err_overlap(),
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def predict_formalize(qid: str, rid: str, q_rewrite: str,
                      groundtruths: List[str], c_rewrite: str = None,
                      api: API = api) -> List['Identifier']:
    """Run ``api.predict_formalize`` and serialize its result.

    NOTE(review): the return annotation says ``List['Identifier']`` but the
    function returns whatever ``wrap_output`` produces -- confirm and fix
    the annotation once the intended type is known.

    Args:
        qid: Forwarded to ``api.predict_formalize``.
        rid: Forwarded to ``api.predict_formalize``.
        q_rewrite: Forwarded to ``api.predict_formalize``.
        groundtruths: Forwarded to ``api.predict_formalize``.
        c_rewrite: Forwarded to ``api.predict_formalize``; may be None.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is a dict with the keys
        ``'key'``, ``'question'``, ``'context'``, ``'groundtruths'`` and
        ``'predictions'`` (None if anything raised). Falsy entries in the
        API result are reported as None rather than serialized. ``msg`` is
        the caught exception (None on success).
    """
    output, msg = None, None
    try:
        data = api.predict_formalize(
            qid, rid, q_rewrite, groundtruths, c_rewrite)
        output = {
            'key': data['key'],
            'question': data['question'].serialize()
                if data['question'] else None,
            'context': data['context'].serialize()
                if data['context'] else None,
            'groundtruths': [g.serialize() for g in data['groundtruths']]
                if data['groundtruths'] else None,
            'predictions': [p.serialize() for p in data['predictions']]
                if data['predictions'] else None,
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def detect_build_blocks(target: str, qid: str, vid: int,
                        start_idx: int, end_idx: int, api: API=api):
    """Forward to ``api.detect_build_blocks`` and wrap the result.

    Args:
        target: Forwarded to ``api.detect_build_blocks``.
        qid: Forwarded to ``api.detect_build_blocks``.
        vid: Forwarded to ``api.detect_build_blocks``.
        start_idx: Forwarded to ``api.detect_build_blocks``.
        end_idx: Forwarded to ``api.detect_build_blocks``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.detect_build_blocks(target, qid, vid, start_idx, end_idx)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def get_err_overlap(show_filtered_err_overlap: bool, api: API=api):
    """Forward to ``api.get_err_overlap`` and wrap the result.

    Args:
        show_filtered_err_overlap: Flag forwarded to ``api.get_err_overlap``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.get_err_overlap(show_filtered_err_overlap)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def get_rewrites_of_instances(instance_keys: List[InstanceKey], api: API=api):
    """Forward to ``api.get_rewrites_of_instances`` and wrap the result.

    Args:
        instance_keys: Instance keys forwarded to the API call.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.get_rewrites_of_instances(instance_keys)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def predict_on_manual_rewrite(qtext: str, groundtruths: List[str],
                              ctext: str, api: API=api):
    """Forward to ``api.predict_on_manual_rewrite`` and wrap the result.

    Args:
        qtext: Forwarded to ``api.predict_on_manual_rewrite``.
        groundtruths: Forwarded to ``api.predict_on_manual_rewrite``.
        ctext: Forwarded to ``api.predict_on_manual_rewrite``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.predict_on_manual_rewrite(qtext, groundtruths, ctext)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def export_built(file_name: str, built_type: str, api: API=api):
    """Forward to ``api.export_built`` and wrap the result.

    Args:
        file_name: Forwarded to ``api.export_built``.
        built_type: Forwarded to ``api.export_built``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.export_built(file_name, built_type)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def detect_rule_from_rewrite(atext: str, btext: str, target_cmd: str,
                             api: API=api) -> List['Identifier']:
    """Detect rules via ``api.detect_rule_from_rewrite`` and serialize them.

    NOTE(review): the return annotation says ``List['Identifier']`` but the
    function returns whatever ``wrap_output`` produces -- confirm and fix
    the annotation once the intended type is known.

    Args:
        atext: Forwarded to ``api.detect_rule_from_rewrite``.
        btext: Forwarded to ``api.detect_rule_from_rewrite``.
        target_cmd: Forwarded to ``api.detect_rule_from_rewrite``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the list obtained
        by calling ``serialize()`` on every value of the mapping returned by
        the API (None if anything raised) and ``msg`` is the caught
        exception (None on success).
    """
    output, msg = None, None
    try:
        data = api.detect_rule_from_rewrite(atext, btext, target_cmd)
        output = [rule.serialize() for rule in data.values()]
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def create_rewrite(from_cmd: str, to_cmd: str, target_cmd: str, api: API=api):
    """Create a rewrite through ``api.create_rewrite`` and serialize it.

    Args:
        from_cmd: Forwarded to ``api.create_rewrite``.
        to_cmd: Forwarded to ``api.create_rewrite``.
        target_cmd: Forwarded to ``api.create_rewrite``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the result of
        ``serialize()`` on the created rewrite (None if anything raised)
        and ``msg`` is the caught exception (None on success).
    """
    output, msg = None, None
    try:
        rewrite = api.create_rewrite(from_cmd, to_cmd, target_cmd)
        output = rewrite.serialize()
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def rewrite_instances_by_rid(
        rid: str,
        qids: List[str]=None,
        sample_size: int=10,
        save: bool=False,
        api: API=api):
    """Forward to ``api.rewrite_instances_by_rid`` and wrap the result.

    Args:
        rid: Forwarded to ``api.rewrite_instances_by_rid``.
        qids: Forwarded to the API call; defaults to None.
        sample_size: Forwarded to the API call; defaults to 10.
        save: Forwarded to the API call; defaults to False.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.rewrite_instances_by_rid(rid, qids, sample_size, save)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def get_more_samples(direction: int, sample_size: int=10, api: API=api):
    """Fetch more samples via ``api.get_more_samples`` and serialize them.

    Args:
        direction: Forwarded to ``api.get_more_samples``.
        sample_size: Forwarded to ``api.get_more_samples``; defaults to 10.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is a dict with the keys
        ``'sample_cache_idx'`` and ``'sampled_keys'`` (passed through) and
        ``'questions'``, ``'contexts'``, ``'answers'`` (each item
        serialized); None if anything raised. ``msg`` is the caught
        exception (None on success).
    """
    output, msg = None, None
    try:
        data = api.get_more_samples(direction, sample_size)
        output = {
            'sample_cache_idx': data['sample_cache_idx'],
            'sampled_keys': data['sampled_keys'],
            'questions': [q.serialize() for q in data['questions']],
            'contexts': [c.serialize() for c in data['contexts']],
            'answers': [a.serialize() for a in data['answers']],
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def delete_selected_rules(rids: List[str], api: API=api):
    """Delete rules via ``api.delete_selected_rules`` and serialize the result.

    NOTE(review): the payload built here has the same keys as the one in
    ``predict_formalize``; confirm ``api.delete_selected_rules`` really
    returns a mapping with these keys and that this is not a copy-paste.

    Args:
        rids: Rule identifiers forwarded to ``api.delete_selected_rules``.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is a dict with the keys
        ``'key'``, ``'question'``, ``'context'``, ``'groundtruths'`` and
        ``'predictions'`` (None if anything raised). Falsy entries in the
        API result are reported as None rather than serialized. ``msg`` is
        the caught exception (None on success).
    """
    output, msg = None, None
    try:
        data = api.delete_selected_rules(rids)
        output = {
            'key': data['key'],
            'question': data['question'].serialize()
                if data['question'] else None,
            'context': data['context'].serialize()
                if data['context'] else None,
            'groundtruths': [g.serialize() for g in data['groundtruths']]
                if data['groundtruths'] else None,
            'predictions': [p.serialize() for p in data['predictions']]
                if data['predictions'] else None,
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def sample_instances(
        selected_predictor: str=None,
        cmd: str='',
        sample_method: str="rand",
        sample_rewrite: str=None,
        sample_size: int=10,
        test_size: int=None,
        show_filtered_arr: bool=False,
        show_filtered_err_overlap: bool=False,
        show_filtered_group: bool=False,
        show_filtered_rewrite: bool=False,
        qids: List[str]=None,
        api: API=api):
    """Sample instances via ``api.sample_instances`` and serialize the result.

    Every argument except ``api`` is forwarded positionally, in declaration
    order, to ``api.sample_instances``.

    Args:
        selected_predictor: Forwarded to the API call; defaults to None.
        cmd: Forwarded to the API call; defaults to ``''``.
        sample_method: Forwarded to the API call; defaults to ``"rand"``.
        sample_rewrite: Forwarded to the API call; defaults to None.
        sample_size: Forwarded to the API call; defaults to 10.
        test_size: Forwarded to the API call; defaults to None.
        show_filtered_arr: Forwarded to the API call; defaults to False.
        show_filtered_err_overlap: Forwarded to the API call; defaults to
            False.
        show_filtered_group: Forwarded to the API call; defaults to False.
        show_filtered_rewrite: Forwarded to the API call; defaults to False.
        qids: Forwarded to the API call; defaults to None.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is a dict that passes
        through ``'attrs'``, ``'rewrites'``, ``'groups'``,
        ``'err_overlaps'``, ``'sample_cache_idx'``, ``'sampled_keys'`` and
        ``'info'`` and serializes each item of ``'questions'``,
        ``'contexts'`` and ``'answers'``; None if anything raised. ``msg``
        is the caught exception (None on success).
    """
    output, msg = None, None
    try:
        data = api.sample_instances(
            selected_predictor,
            cmd,
            sample_method,
            sample_rewrite,
            sample_size,
            test_size,
            show_filtered_arr,
            show_filtered_err_overlap,
            show_filtered_group,
            show_filtered_rewrite,
            qids)
        output = {
            'attrs': data['attrs'],
            'rewrites': data['rewrites'],
            'groups': data['groups'],
            'err_overlaps': data['err_overlaps'],
            'sample_cache_idx': data['sample_cache_idx'],
            'sampled_keys': data['sampled_keys'],
            'info': data['info'],
            'questions': [q.serialize() for q in data['questions']],
            'contexts': [c.serialize() for c in data['contexts']],
            'answers': [a.serialize() for a in data['answers']],
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
def get_attr_distribution(attr_names: List[str],
                          filter_cmd: str = '',
                          use_sampled_data: bool = False,
                          include_rewrite: str = None,
                          include_model: str = None,
                          test_size: int = None,
                          api: API = api):
    """Forward to ``api.get_attr_distribution`` and wrap the result.

    All arguments except ``api`` are passed to the API call by keyword.

    Args:
        attr_names: Forwarded as ``attr_names``.
        filter_cmd: Forwarded as ``filter_cmd``; defaults to ``''``.
        use_sampled_data: Forwarded as ``use_sampled_data``; defaults to
            False.
        include_rewrite: Forwarded as ``include_rewrite``; defaults to None.
        include_model: Forwarded as ``include_model``; defaults to None.
        test_size: Forwarded as ``test_size``; defaults to None.
        api: API object to use; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``, where ``output`` is the API result
        (None if the call raised) and ``msg`` is the caught exception
        (None on success).
    """
    output, msg = None, None
    try:
        output = api.get_attr_distribution(
            attr_names=attr_names,
            filter_cmd=filter_cmd,
            test_size=test_size,
            use_sampled_data=use_sampled_data,
            include_model=include_model,
            include_rewrite=include_rewrite)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return after the try/except, not from a `finally`: a `return` inside
    # `finally` would also discard KeyboardInterrupt/SystemExit.
    return wrap_output(output, msg)
"""Get the user arguments """ parser = argparse.ArgumentParser() parser.add_argument('--config_file', required=True, help='the configuration file required.') args = parser.parse_args() return args args = get_args() try: with open(args.config_file) as f: configs = yaml.safe_load(f) logger.info(configs) # construct the API API_CONSTRUCTOR = API.by_name(configs["task"]) api = API_CONSTRUCTOR( cache_path=configs["cache_path"], model_metas=configs["model_metas"], attr_file_name=configs["attr_file_name"], group_file_name=configs["group_file_name"], rewrite_file_name=configs["rewrite_file_name"] ) except Exception as e: api = None traceback.print_exc() logger.error(e) def wrap_output(o, msg=None): if msg: msg = f'ERR! {msg}'