コード例 #1
0
def get_meta_data(api: API=api):
    """Collect the metadata summary for the current session.

    Serializes every registered rewrite, group and attribute, plus the
    predictors held by ``api``. It also gathers the current anchor/compare
    predictors, the selected rewrite and the error-overlap summary (queried
    with ``show_filtered_err_overlap=False``).

    Args:
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with the
        keys ``total_size``, ``anchor_predictor``, ``compare_predictor``,
        ``selected_rewrite``, ``predictors``, ``attributes``, ``groups``,
        ``err_overlaps`` and ``rewrites``, and ``msg`` is None. If any step
        raises an ``Exception``, it is logged, ``output`` is None and ``msg``
        is the exception.
    """
    output, msg = None, None
    try:
        err_overlaps = api.get_err_overlap(show_filtered_err_overlap=False)
        rewrites = [r.serialize() for r in Rewrite.values()]
        groups = [g.serialize() for g in Group.values()]
        attrs = [a.serialize() for a in Attribute.values()]
        predictors = [p.serialize() for p in api.predictors.values()]
        output = {
            'total_size': len(Instance.qid_hash),
            'anchor_predictor': api.get_anchor_predictor(),
            'compare_predictor': api.get_compare_predictor(),
            'selected_rewrite': api.get_selected_rewrite(),
            'predictors': predictors,
            'attributes': attrs,
            'groups': groups,
            'err_overlaps': err_overlaps,
            'rewrites': rewrites,
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally``: a return inside ``finally`` would silently
    # discard BaseExceptions such as KeyboardInterrupt or SystemExit.
    return wrap_output(output, msg)
コード例 #2
0
def set_anchor_predictor(model: str, api: API=api):
    """Select ``model`` as the anchor predictor and report the new metadata.

    Args:
        model: name of the predictor passed to ``api.set_anchor_predictor``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        The result of ``get_meta_data(api)`` when the switch succeeds.
        Otherwise the failure is logged and ``wrap_output(None, err)`` is
        returned, where ``err`` is the caught exception.
    """
    try:
        api.set_anchor_predictor(model)
        # Still inside the try block, so a failure while building the
        # metadata is reported the same way as a failure to switch models.
        return get_meta_data(api)
    except Exception as err:
        logger.error(err)
        traceback.print_exc()
        return wrap_output(None, err)
コード例 #3
0
def formalize_rewritten_examples(rid: str, api: API=api):
    """Formalize the previously tried rewrites of rewrite rule ``rid``.

    Args:
        rid: identifier of the rewrite, checked with ``Rewrite.exists``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. When ``rid`` exists, ``output`` is the
        serialized rewrite after ``api.formalize_prev_tried_rewrites(rid)``
        has run. When ``rid`` does not exist, both ``output`` and ``msg`` are
        None. On an ``Exception`` the error is logged and passed as ``msg``.
    """
    output, msg = None, None
    try:
        # NOTE(review): an unknown rid is not reported as an error (msg stays
        # None); confirm callers expect a silent empty result here.
        if Rewrite.exists(rid):
            api.formalize_prev_tried_rewrites(rid)
            rewrite = Rewrite.get(rid)
            output = rewrite.serialize()
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #4
0
def set_compare_predictor(model: str, api: API=api):
    """Select ``model`` as the comparison predictor.

    Args:
        model: name of the predictor passed to ``api.set_compare_predictor``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with
        ``compare_predictor`` (the value now reported by the API) and
        ``err_overlaps`` (from ``api.get_err_overlap()``). On an ``Exception``
        the error is logged, ``output`` is None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        api.set_compare_predictor(model)
        output = {
            'compare_predictor': api.get_compare_predictor(),
            'err_overlaps': api.get_err_overlap(),
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #5
0
ファイル: __main__.py プロジェクト: ysenarath/errudite
def predict_formalize(qid: str,
                      rid: str,
                      q_rewrite: str,
                      groundtruths: List[str],
                      c_rewrite: str = None,
                      api: API = api) -> List['Identifier']:
    """Run ``api.predict_formalize`` and serialize the resulting instance.

    Args:
        qid: question/instance id, forwarded unchanged.
        rid: rewrite id, forwarded unchanged.
        q_rewrite: rewritten question text.
        groundtruths: ground-truth answer strings.
        c_rewrite: optional rewritten context text.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with
        ``key``, ``question``, ``context``, ``groundtruths`` and
        ``predictions``. Each entry other than ``key`` is serialized, or None
        when the corresponding value is falsy. On an ``Exception`` the error
        is logged and passed as ``msg``.

    NOTE(review): the ``List['Identifier']`` annotation does not describe the
    value built here (a dict passed through ``wrap_output``); confirm and fix.
    """
    output, msg = None, None
    try:
        data = api.predict_formalize(qid, rid, q_rewrite, groundtruths,
                                     c_rewrite)
        output = {
            'key': data['key'],
            'question': (data['question'].serialize()
                         if data['question'] else None),
            'context': (data['context'].serialize()
                        if data['context'] else None),
            'groundtruths': ([g.serialize() for g in data['groundtruths']]
                             if data['groundtruths'] else None),
            'predictions': ([p.serialize() for p in data['predictions']]
                            if data['predictions'] else None),
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #6
0
def detect_build_blocks(target: str, qid: str, vid: int, start_idx: int, end_idx: int, api: API=api):
    """Forward a build-block detection request to the API.

    All arguments except ``api`` are passed unchanged to
    ``api.detect_build_blocks(target, qid, vid, start_idx, end_idx)``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the API result on success.
        On an ``Exception`` the error is logged, ``output`` is None and
        ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.detect_build_blocks(target, qid, vid, start_idx, end_idx)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #7
0
def get_err_overlap(show_filtered_err_overlap: bool, api: API=api):
    """Return the error-overlap summary from the API.

    Args:
        show_filtered_err_overlap: forwarded to ``api.get_err_overlap``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the API result on success.
        On an ``Exception`` the error is logged, ``output`` is None and
        ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.get_err_overlap(show_filtered_err_overlap)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #8
0
def get_rewrites_of_instances(instance_keys: List[InstanceKey], api: API=api):
    """Return the rewrites recorded for the given instances.

    Args:
        instance_keys: keys forwarded to ``api.get_rewrites_of_instances``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the API result on success.
        On an ``Exception`` the error is logged, ``output`` is None and
        ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.get_rewrites_of_instances(instance_keys)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #9
0
def predict_on_manual_rewrite(qtext: str, groundtruths: List[str], ctext: str, api: API=api):
    """Run a prediction on manually rewritten text.

    Args:
        qtext: question text, forwarded unchanged.
        groundtruths: ground-truth answer strings.
        ctext: context text, forwarded unchanged.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the result of
        ``api.predict_on_manual_rewrite`` on success. On an ``Exception`` the
        error is logged, ``output`` is None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.predict_on_manual_rewrite(qtext, groundtruths, ctext)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #10
0
def export_built(file_name: str, built_type: str, api: API=api):
    """Export built objects through the API.

    Args:
        file_name: forwarded to ``api.export_built``.
        built_type: forwarded to ``api.export_built``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the API result on success.
        On an ``Exception`` the error is logged, ``output`` is None and
        ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.export_built(file_name, built_type)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #11
0
def detect_rule_from_rewrite(atext: str, btext: str, target_cmd: str, api: API=api) -> List['Identifier']:
    """Detect rewrite rules that map ``atext`` to ``btext``.

    Args:
        atext: original text, forwarded unchanged.
        btext: rewritten text, forwarded unchanged.
        target_cmd: forwarded to ``api.detect_rule_from_rewrite``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a list with the
        serialization of every value in the mapping returned by the API. On
        an ``Exception`` the error is logged, ``output`` is None and ``msg``
        is the exception.
    """
    output, msg = None, None
    try:
        data = api.detect_rule_from_rewrite(atext, btext, target_cmd)
        output = [r.serialize() for r in data.values()]
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #12
0
def create_rewrite(from_cmd: str, to_cmd: str, target_cmd: str, api: API=api):
    """Create a rewrite rule through the API and serialize it.

    Args:
        from_cmd: forwarded to ``api.create_rewrite``.
        to_cmd: forwarded to ``api.create_rewrite``.
        target_cmd: forwarded to ``api.create_rewrite``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the serialized new rewrite
        on success. On an ``Exception`` the error is logged, ``output`` is
        None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        rewrite = api.create_rewrite(from_cmd, to_cmd, target_cmd)
        output = rewrite.serialize()
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #13
0
def rewrite_instances_by_rid(
    rid: str, qids: List[str]=None, sample_size: int=10, save: bool=False, api: API=api):
    """Apply rewrite ``rid`` to instances via the API.

    Args:
        rid: rewrite id.
        qids: optional instance ids; forwarded as-is (None by default).
        sample_size: forwarded to the API (default 10).
        save: forwarded to the API (default False).
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the result of
        ``api.rewrite_instances_by_rid`` on success. On an ``Exception`` the
        error is logged, ``output`` is None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.rewrite_instances_by_rid(rid, qids, sample_size, save)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #14
0
def get_more_samples(direction: int, sample_size: int=10, api: API=api):
    """Fetch another page of sampled instances.

    Args:
        direction: forwarded to ``api.get_more_samples``.
        sample_size: forwarded to the API (default 10).
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with
        ``sample_cache_idx`` and ``sampled_keys`` copied from the API result,
        and ``questions``/``contexts``/``answers`` as lists of serialized
        objects. On an ``Exception`` the error is logged, ``output`` is None
        and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        data = api.get_more_samples(direction, sample_size)
        output = {
            'sample_cache_idx': data['sample_cache_idx'],
            'sampled_keys': data['sampled_keys'],
            'questions': [q.serialize() for q in data['questions']],
            'contexts': [p.serialize() for p in data['contexts']],
            'answers': [a.serialize() for a in data['answers']],
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #15
0
def delete_selected_rules(rids: List[str], api: API=api):
    """Delete the rules ``rids`` via the API and serialize its result.

    Args:
        rids: rule ids forwarded to ``api.delete_selected_rules``.
        api: backend API instance; defaults to the module-level ``api``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with
        ``key``, ``question``, ``context``, ``groundtruths`` and
        ``predictions``. Each entry other than ``key`` is serialized, or None
        when the corresponding value is falsy. On an ``Exception`` the error
        is logged, ``output`` is None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        data = api.delete_selected_rules(rids)
        output = {
            'key': data['key'],
            'question': data['question'].serialize() if data['question'] else None,
            'context': data['context'].serialize() if data['context'] else None,
            'groundtruths': [g.serialize() for g in data['groundtruths']] if data['groundtruths'] else None,
            'predictions': [p.serialize() for p in data['predictions']] if data['predictions'] else None,
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #16
0
def sample_instances(
        selected_predictor: str=None, 
        cmd: str='', 
        sample_method: str="rand", 
        sample_rewrite: str=None, 
        sample_size: int=10, 
        test_size: int=None,
        show_filtered_arr: bool=False, 
        show_filtered_err_overlap: bool=False,
        show_filtered_group: bool=False,
        show_filtered_rewrite: bool=False,
        qids: List[str]=None,
        api: API=api):
    """Sample instances via the API and serialize the sample.

    All arguments except ``api`` are forwarded positionally, in declaration
    order, to ``api.sample_instances``.

    Returns:
        ``wrap_output(output, msg)``. On success ``output`` is a dict with
        ``attrs``, ``rewrites``, ``groups``, ``err_overlaps``,
        ``sample_cache_idx``, ``sampled_keys`` and ``info`` copied from the
        API result, and ``questions``/``contexts``/``answers`` as lists of
        serialized objects. On an ``Exception`` the error is logged,
        ``output`` is None and ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        data = api.sample_instances(
            selected_predictor, cmd, sample_method,
            sample_rewrite, sample_size, test_size,
            show_filtered_arr,
            show_filtered_err_overlap,
            show_filtered_group,
            show_filtered_rewrite,
            qids)
        output = {
            'attrs': data['attrs'],
            'rewrites': data['rewrites'],
            'groups': data['groups'],
            'err_overlaps': data['err_overlaps'],
            'sample_cache_idx': data['sample_cache_idx'],
            'sampled_keys': data['sampled_keys'],
            'info': data['info'],
            'questions': [q.serialize() for q in data['questions']],
            'contexts': [p.serialize() for p in data['contexts']],
            'answers': [a.serialize() for a in data['answers']],
        }
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #17
0
ファイル: __main__.py プロジェクト: ysenarath/errudite
def get_attr_distribution(attr_names: List[str],
                          filter_cmd: str = '',
                          use_sampled_data: bool = False,
                          include_rewrite: str = None,
                          include_model: str = None,
                          test_size: int = None,
                          api: API = api):
    """Compute attribute distributions via the API.

    All arguments except ``api`` are forwarded by keyword to
    ``api.get_attr_distribution``.

    Returns:
        ``wrap_output(output, msg)``. ``output`` is the API result on success.
        On an ``Exception`` the error is logged, ``output`` is None and
        ``msg`` is the exception.
    """
    output, msg = None, None
    try:
        output = api.get_attr_distribution(attr_names=attr_names,
                                           filter_cmd=filter_cmd,
                                           test_size=test_size,
                                           use_sampled_data=use_sampled_data,
                                           include_model=include_model,
                                           include_rewrite=include_rewrite)
    except Exception as e:
        msg = e
        logger.error(e)
        traceback.print_exc()
    # Return outside ``finally`` so BaseExceptions (e.g. KeyboardInterrupt)
    # are not silently swallowed.
    return wrap_output(output, msg)
コード例 #18
0
    """Parse the command-line arguments for this script.

    Returns:
        argparse.Namespace: the parsed arguments. ``config_file`` holds the
        path given via the required ``--config_file`` option.
    """
    parser = argparse.ArgumentParser()
    # The configuration file is mandatory; argparse exits with a usage
    # error if it is not supplied.
    parser.add_argument('--config_file', 
        required=True, 
        help='the configuration file required.')
    args = parser.parse_args()
    return args

args = get_args()
try:
    # Decode the config explicitly as UTF-8 rather than with the platform's
    # locale-dependent default encoding.
    with open(args.config_file, encoding="utf-8") as f:
        configs = yaml.safe_load(f)
    logger.info(configs)
    # Build the task-specific API class selected by ``configs["task"]``.
    API_CONSTRUCTOR = API.by_name(configs["task"])
    api = API_CONSTRUCTOR(
        cache_path=configs["cache_path"],
        model_metas=configs["model_metas"],
        attr_file_name=configs["attr_file_name"],
        group_file_name=configs["group_file_name"],
        rewrite_file_name=configs["rewrite_file_name"]
    )
except Exception as e:
    # On any failure (unreadable file, missing key, constructor error) fall
    # back to ``api = None`` and report the error instead of aborting.
    api = None
    traceback.print_exc()
    logger.error(e)

def wrap_output(o, msg=None):
    if msg:
        msg = f'ERR! {msg}'