def encode_seq_structural_data(data : RawDataset,
                               context_tokenizer_type : \
                               Callable[[List[str], int], Tokenizer],
                               num_keywords : int,
                               num_reserved_tokens: int) -> \
                               Tuple[StructDataset, Tokenizer, SimpleEmbedding]:
    """Tokenize hypotheses/goals and encode tactics into a structured form.

    Each item of ``data`` is unpacked as ``(prev_tactics, hyps, goal, tactic)``.
    The function makes two passes over ``data``:

    1. Every hypothesis and goal is paired with the embedding index of the
       *whole* tactic string; these pairs are handed to
       ``make_keyword_tokenizer_relevance`` to build the context tokenizer.
    2. Every sample is encoded as
       ``(tokenized_hyps, tokenized_goal, (stem_index, arg_hyp_indices))``.

    Returns:
        The list of encoded samples, the context tokenizer, and the
        embedding that was populated while encoding.
    """
    embedding = SimpleEmbedding()

    # Pass 1: label each hypothesis/goal term with its tactic's token index.
    labeled_terms = []
    for _, hyps, goal, tactic in data:
        terms = hyps + [goal]
        tactic_label = embedding.encode_token(tactic)
        labeled_terms.extend((term, tactic_label) for term in terms)

    context_tokenizer = make_keyword_tokenizer_relevance(
        labeled_terms, context_tokenizer_type, num_keywords,
        num_reserved_tokens)

    # Pass 2: encode every sample with the freshly built tokenizer.
    encoded_samples = []
    for _, hyps, goal, tactic in data:
        stem, rest = serapi_instance.split_tactic(tactic)
        hyp_tokens = [context_tokenizer.toTokenList(hyp) for hyp in hyps]
        goal_tokens = context_tokenizer.toTokenList(goal)
        stem_idx = embedding.encode_token(stem)
        # Each symbol in the tactic's argument string is mapped through
        # hyp_index against this sample's hypotheses.
        arg_hyp_idxs = [hyp_index(hyps, arg) for arg in get_symbols(rest)]
        encoded_samples.append(
            (hyp_tokens, goal_tokens, (stem_idx, arg_hyp_idxs)))

    return encoded_samples, context_tokenizer, embedding
Ejemplo n.º 2
0
def max_args(num_str : str,
             in_data: TacticContext, tactic : str,
             new_in_data : TacticContext,
             arg_values : argparse.Namespace) -> bool:
    """Return True when ``tactic`` has at most ``int(num_str)`` arguments.

    Arguments are counted as whitespace-separated words of the tactic's
    argument string, after stripping surrounding whitespace and dropping the
    final character (presumably the terminating period -- confirm against
    ``serapi_instance.split_tactic``).
    ``in_data``, ``new_in_data`` and ``arg_values`` are not used.
    """
    _, args_string = serapi_instance.split_tactic(tactic)
    arg_count = len(args_string.strip()[:-1].split())
    limit = int(num_str)
    return arg_count <= limit
Ejemplo n.º 3
0
 def _determine_relevance(self, inter: ScrapedTactic) -> List[bool]:
     """Flag, per hypothesis, whether the tactic names one of its variables.

     The tactic's argument string (minus its final character) is split on
     whitespace. A hypothesis is marked True when any of the comma-separated
     names returned by ``serapi_instance.get_var_term_in_hyp`` for it appears
     among those arguments. The result is in the order of
     ``inter.hypotheses``.
     """
     _, args_string = serapi_instance.split_tactic(inter.tactic)
     tactic_args = args_string[:-1].split()
     relevance = []
     for hyp in inter.hypotheses:
         hyp_var_names = serapi_instance.get_var_term_in_hyp(hyp).split(",")
         relevance.append(
             any(name.strip() in tactic_args for name in hyp_var_names))
     return relevance
Ejemplo n.º 4
0
def numeric_args(in_data : TacticContext, tactic : str,
                 next_in_data : TacticContext,
                 arg_values : argparse.Namespace) -> bool:
    """Return True when every argument of ``tactic`` is a run of digits.

    The tactic's argument string has leading/trailing periods stripped and is
    then split into sub-expressions by ``get_subexprs``. The result is True
    when each sub-expression consists solely of one or more digits, which
    includes the case of no arguments at all.
    ``in_data``, ``next_in_data`` and ``arg_values`` are not used.
    """
    _, rest = serapi_instance.split_tactic(tactic)
    args = get_subexprs(rest.strip("."))
    # Raw string: a plain "\d" is an invalid escape sequence and triggers a
    # warning on modern Python versions.
    return all(re.fullmatch(r"\d+", arg) for arg in args)
Ejemplo n.º 5
0
def args_vars_in_list(tactic : str,
                      context_list : List[str]) -> bool:
    """Check that every argument of ``tactic`` is a variable in ``context_list``.

    Returns False when the tactic has arguments but its stem does not take
    hypothesis arguments (per ``serapi_instance.tacticTakesHypArgs``), or when
    any whitespace-separated argument is not among the variable names that
    ``serapi_instance.get_vars_in_hyps`` extracts from ``context_list``.
    A tactic without arguments yields True.
    """
    stem, args_string = serapi_instance.split_tactic(tactic)
    # The final character of the argument string is dropped before splitting
    # (presumably the terminating period).
    tactic_args = args_string[:-1].split()
    takes_hyp_args = serapi_instance.tacticTakesHypArgs(stem)
    if tactic_args and not takes_hyp_args:
        return False
    known_vars = serapi_instance.get_vars_in_hyps(context_list)
    return all(arg in known_vars for arg in tactic_args)
Ejemplo n.º 6
0
def args_token_in_goal(in_data: ContextData, tactic: str,
                       next_in_data: ContextData,
                       arg_values: argparse.Namespace) -> bool:
    """Return True when every tactic argument is one of the goal's symbols.

    Only the first ``arg_values.max_length`` symbols of ``in_data["goal"]``
    are considered. The tactic's argument string has leading/trailing periods
    stripped and is split into sub-expressions by ``get_subexprs``; each must
    appear among those symbols. A tactic without arguments yields True.
    ``next_in_data`` is not used.
    """
    goal_text = cast(str, in_data["goal"])
    visible_symbols = get_symbols(goal_text)[:arg_values.max_length]
    _, rest = serapi_instance.split_tactic(tactic)
    tactic_args = get_subexprs(rest.strip("."))
    return all(arg in visible_symbols for arg in tactic_args)
Ejemplo n.º 7
0
def get_arg_idx(max_length: int, inter: ScrapedTactic) -> int:
    """Encode the tactic's first argument as a position in the focused goal.

    The first whitespace-separated word of the tactic's argument string, with
    surrounding periods stripped, is looked up among the symbols of
    ``inter.context.focused_goal``.

    Returns:
        The 1-based position of the first occurrence of the argument, or 0
        when its 0-based position is ``max_length`` or greater.

    Raises:
        AssertionError: if the argument is not one of the goal's symbols.
        IndexError: if the tactic has no arguments.
    """
    _, tactic_rest = serapi_instance.split_tactic(inter.tactic)
    goal_symbols = tokenizer.get_symbols(inter.context.focused_goal)
    first_arg = tactic_rest.split()[0].strip(".")
    assert first_arg in goal_symbols, \
        "tactic: {}, arg: {}, goal: {}, symbols: {}".format(
            inter.tactic, first_arg, inter.context.focused_goal, goal_symbols)
    position = goal_symbols.index(first_arg)
    # 0 is reserved for "argument lies beyond the max_length window".
    return position + 1 if position < max_length else 0
Ejemplo n.º 8
0
def args_vars_in_context(in_data: ContextData, tactic: str,
                         next_in_data: ContextData,
                         arg_values: argparse.Namespace) -> bool:
    """Check that every argument of ``tactic`` is a variable of the hypotheses.

    Returns False when the tactic has arguments but its stem does not take
    hypothesis arguments (per ``serapi_instance.tacticTakesHypArgs``), or when
    any whitespace-separated argument is not among the variable names that
    ``serapi_instance.get_vars_in_hyps`` extracts from ``in_data["hyps"]``.
    A tactic without arguments yields True.
    ``next_in_data`` and ``arg_values`` are not used.
    """
    stem, args_string = serapi_instance.split_tactic(tactic)
    # The final character of the argument string is dropped before splitting
    # (presumably the terminating period).
    tactic_args = args_string[:-1].split()
    takes_hyp_args = serapi_instance.tacticTakesHypArgs(stem)
    if tactic_args and not takes_hyp_args:
        return False
    hypotheses = cast(List[str], in_data["hyps"])
    known_vars = serapi_instance.get_vars_in_hyps(hypotheses)
    return all(arg in known_vars for arg in tactic_args)
Ejemplo n.º 9
0
def get_stem_and_arg_idx(max_length: int, embedding: Embedding,
                         inter: ScrapedTactic) -> Tuple[int, int]:
    """Encode a tactic as ``(stem index, goal position of its first argument)``.

    The stem is encoded with ``embedding.encode_token``. The first
    whitespace-separated word of the argument string, with surrounding
    periods stripped, is looked up among the symbols of
    ``inter.context.focused_goal``.

    Returns:
        ``(stem_idx, arg_idx)`` where ``arg_idx`` is the 1-based position of
        the first occurrence of the argument, or 0 when its 0-based position
        is ``max_length`` or greater.

    Raises:
        AssertionError: if the argument is not one of the goal's symbols.
        IndexError: if the tactic has no arguments.
    """
    tactic_stem, tactic_rest = serapi_instance.split_tactic(inter.tactic)
    stem_idx = embedding.encode_token(tactic_stem)
    goal_symbols = tokenizer.get_symbols(inter.context.focused_goal)
    first_arg = tactic_rest.split()[0].strip(".")
    assert first_arg in goal_symbols, \
        "tactic: {}, arg: {}, goal: {}, symbols: {}".format(
            inter.tactic, first_arg, inter.context.focused_goal, goal_symbols)
    position = goal_symbols.index(first_arg)
    # 0 is reserved for "argument lies beyond the max_length window".
    arg_idx = position + 1 if position < max_length else 0
    return stem_idx, arg_idx
Ejemplo n.º 10
0
def encode_tactic_structure(stem_embedding : SimpleEmbedding,
                            max_args : int,
                            hyps_and_tactic : Tuple[List[str], str]) \
    -> TacticStructure:
    """Encode a tactic as a stem index plus exactly ``max_args`` argument slots.

    Args:
        stem_embedding: embedding used to encode the tactic stem.
        max_args: number of argument slots in the result; extra tactic
            arguments beyond this count are dropped.
        hyps_and_tactic: ``(hypotheses, tactic string)`` pair.

    Returns:
        A ``TacticStructure`` whose ``hyp_idxs`` has length ``max_args``:
        the ``get_arg_idx`` value of each kept argument, padded with
        ``EOS_token``. If any slot is falsy, all slots are replaced by
        ``EOS_token``.
    """
    hyps, tactic = hyps_and_tactic
    tactic_stem, args_str = serapi_instance.split_tactic(tactic)
    # Keep at most max_args arguments so the result has a fixed width.
    arg_strs = args_str.split()[:max_args]
    stem_idx = stem_embedding.encode_token(tactic_stem)
    # Fixed: the indices are built from the truncated list. Previously the
    # untruncated split was used, so tactics with more than max_args
    # arguments produced more than max_args entries.
    arg_idxs = [get_arg_idx(hyps, arg.strip()) for arg in arg_strs]
    if len(arg_idxs) < max_args:
        arg_idxs += [EOS_token] * (max_args - len(arg_idxs))
    # If any arguments aren't hypotheses, ignore the arguments
    # NOTE(review): this check runs after padding, so it relies on EOS_token
    # being truthy; if EOS_token were 0, every padded tactic would be reset
    # here -- confirm against the definition of EOS_token.
    if not all(arg_idxs):
        arg_idxs = [EOS_token] * max_args

    return TacticStructure(stem_idx=stem_idx, hyp_idxs=arg_idxs)
Ejemplo n.º 11
0
def args_token_in_goal(in_data : TacticContext, tactic : str,
                       next_in_data : TacticContext,
                       arg_values : argparse.Namespace) -> bool:
    """Return True when every tactic argument matches a symbol of the goal.

    Only the first ``arg_values.max_length`` symbols of ``in_data.goal`` are
    considered. A single trailing period is removed from the argument string,
    which is then split by ``get_subexprs``; each sub-expression must match
    at least one goal symbol according to ``serapi_instance.symbol_matches``.
    ``intro``/``intros`` with any arguments always yield False. A tactic
    without arguments yields True. ``next_in_data`` is not used.
    """
    visible_symbols = get_symbols(in_data.goal)[:arg_values.max_length]
    stem, rest = serapi_instance.split_tactic(tactic)
    if rest.endswith('.'):
        rest = rest[:-1]
    tactic_args = get_subexprs(rest)
    # While the arguments to an intro(s) might *look* like
    # goal arguments, they are actually fresh variables
    if stem in ("intros", "intro") and len(tactic_args) > 0:
        return False
    return all(
        any(serapi_instance.symbol_matches(goal_word, arg)
            for goal_word in visible_symbols)
        for arg in tactic_args)
 def _encode_action(self, context: TacticContext, action: str) \
         -> Tuple[int, int]:
     """Encode a tactic as ``(stem index, argument index)``.

     The stem is looked up in ``self.tactic_map`` via ``emap_lookup``. The
     argument (with periods and whitespace stripped) is encoded as:

     * ``0`` -- the tactic has no argument;
     * ``emap_lookup(token_map, 128, w) + 2`` -- the argument names a
       variable of a hypothesis or relevant lemma, where ``w`` is the first
       word after the ``:`` of that premise;
     * ``emap_lookup(token_map, 128, arg) + 128 + 2`` -- the argument is a
       symbol of the goal;
     * ``1`` -- anything else.

     NOTE(review): the ``+ 128`` offset presumably equals the token map size
     so the two lookup ranges do not overlap -- confirm against
     ``emap_lookup``.
     """
     stem, argument = serapi_instance.split_tactic(action)
     stem_idx = emap_lookup(self.tactic_map, 32, stem)
     all_premises = context.hypotheses + context.relevant_lemmas
     stripped_arg = argument.strip(".").strip()

     # No argument at all.
     if stripped_arg == "":
         return stem_idx, 0

     # Argument names a variable bound by one of the premises.
     premise_index_by_var = dict(
         serapi_instance.get_indexed_vars_in_hyps(all_premises))
     if stripped_arg in premise_index_by_var:
         premise = all_premises[premise_index_by_var[stripped_arg]]
         premise_body = premise.partition(":")[2]
         head_word = tokenizer.get_words(premise_body)[0]
         return stem_idx, emap_lookup(self.token_map, 128, head_word) + 2

     # Argument is a symbol taken from the goal.
     if stripped_arg in tokenizer.get_symbols(context.goal):
         goal_token_idx = emap_lookup(self.token_map, 128, stripped_arg)
         return stem_idx, goal_token_idx + 128 + 2

     # Argument could not be resolved against premises or goal.
     return stem_idx, 1