Example no. 1
def max_args(num_str : str,
             in_data: TacticContext, tactic : str,
             new_in_data : TacticContext,
             arg_values : argparse.Namespace) -> bool:
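    # True when the tactic has at most int(num_str) arguments.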
    stem, args_string = serapi_instance.split_tactic(tactic)
    args = args_string.strip()[:-1].split()
    return len(args) <= int(num_str)
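Note: all of these examples use serapi_instance.split_tactic to separate a full tactic string into its stem and its argument string. A minimal sketch of the assumed behaviour (the trailing period staying attached to the argument string is inferred from the [:-1] slicing above; the exact return values are an assumption, not taken from the library's documentation):

# stem, args_string = serapi_instance.split_tactic("apply H.")
# stem == "apply"
# args_string.strip() == "H."   # so args_string.strip()[:-1].split() == ["H"]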
Example no. 2
def encode_seq_structural_data(data : RawDataset,
                               context_tokenizer_type : \
                               Callable[[List[str], int], Tokenizer],
                               num_keywords : int,
                               num_reserved_tokens: int) -> \
                               Tuple[StructDataset, Tokenizer, SimpleEmbedding]:
    embedding = SimpleEmbedding()

    # Pair each hypothesis and the goal with the encoded tactic used in that
    # context; these pairs drive keyword selection for the relevance tokenizer.
    hyps_and_goals = [
        hyp_or_goal for hyp_and_goal in [
            zip(hyps +
                [goal], itertools.repeat(embedding.encode_token(tactic)))
            for prev_tactics, hyps, goal, tactic in data
        ] for hyp_or_goal in hyp_and_goal
    ]
    context_tokenizer = make_keyword_tokenizer_relevance(
        hyps_and_goals, context_tokenizer_type, num_keywords,
        num_reserved_tokens)
    encodedData = []
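    # Encode each example: tokenized hypotheses and goal, plus the tactic stem
    # index and, for every argument symbol, its hypothesis index (via hyp_index).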
    for prev_tactics, hyps, goal, tactic in data:
        stem, rest = serapi_instance.split_tactic(tactic)
        encodedData.append(
            ([context_tokenizer.toTokenList(hyp)
              for hyp in hyps], context_tokenizer.toTokenList(goal),
             (embedding.encode_token(stem),
              [hyp_index(hyps, arg) for arg in get_symbols(rest)])))

    return encodedData, context_tokenizer, embedding
Example no. 3
    def _determine_relevance(self, inter: ScrapedTactic) -> List[bool]:
        # Mark each focused hypothesis as relevant when one of its variables
        # appears among the tactic's arguments.
        stem, args_string = serapi_instance.split_tactic(inter.tactic)
        args = args_string[:-1].split()
        return [
            any([
                var.strip() in args
                for var in serapi_instance.get_var_term_in_hyp(hyp).split(",")
            ]) for hyp in inter.context.focused_hyps
        ]
Example no. 4
def numeric_args(in_data : TacticContext, tactic : str,
                 next_in_data : TacticContext,
                 arg_values : argparse.Namespace) -> bool:
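    # True when every argument subexpression of the tactic is a numeric literal.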
    goal_words = get_symbols(in_data.goal)
    stem, rest = serapi_instance.split_tactic(tactic)
    args = get_subexprs(rest.strip("."))
    for arg in args:
        if not re.fullmatch(r"\d+", arg):
            return False
    return True
Example no. 5
def args_vars_in_list(tactic : str,
                      context_list : List[str]) -> bool:
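    # True when the tactic's arguments (if any) are all variables bound in
    # context_list, and the stem actually accepts hypothesis arguments.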
    stem, args_string = serapi_instance.split_tactic(tactic)
    args = args_string[:-1].split()
    if not serapi_instance.tacticTakesHypArgs(stem) and len(args) > 0:
        return False
    var_names = serapi_instance.get_vars_in_hyps(context_list)
    for arg in args:
        if arg not in var_names:
            return False
    return True
Example no. 6
def get_arg_idx(max_length: int, inter: ScrapedTactic) -> int:
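    # 1-based index of the tactic's first argument among the goal's symbols;
    # 0 when that index falls outside the first max_length positions.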
    tactic_stem, tactic_rest = serapi_instance.split_tactic(inter.tactic)
    symbols = tokenizer.get_symbols(inter.context.focused_goal)
    arg = tactic_rest.split()[0].strip(".")
    assert arg in symbols, "tactic: {}, arg: {}, goal: {}, symbols: {}"\
        .format(inter.tactic, arg, inter.context.focused_goal, symbols)
    idx = symbols.index(arg)
    if idx >= max_length:
        return 0
    else:
        return idx + 1
Example no. 7
def get_stem_and_arg_idx(max_length: int, embedding: Embedding,
                         inter: ScrapedTactic) -> Tuple[int, int]:
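    # Like get_arg_idx above, but also returns the embedded index of the stem.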
    tactic_stem, tactic_rest = serapi_instance.split_tactic(inter.tactic)
    stem_idx = embedding.encode_token(tactic_stem)
    symbols = tokenizer.get_symbols(inter.context.focused_goal)
    arg = tactic_rest.split()[0].strip(".")
    assert arg in symbols, "tactic: {}, arg: {}, goal: {}, symbols: {}"\
        .format(inter.tactic, arg, inter.context.focused_goal, symbols)
    idx = symbols.index(arg)
    if idx >= max_length:
        return stem_idx, 0
    else:
        return stem_idx, idx + 1
Example no. 8
    def _encode_action(self, context: TacticContext, action: str) \
            -> Tuple[List[int], torch.FloatTensor]:
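        # Encode an action as [stem index, argument-type index] plus a
        # fixed-size argument vector (no arg / goal token / hypothesis).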
        stem, argument = serapi_instance.split_tactic(action)
        stem_idx = encode_fpa_stem(self.dataloader_args, self.fpa_metadata,
                                   stem)
        all_prems = context.hypotheses + context.relevant_lemmas
        arg_idx = encode_fpa_arg(self.dataloader_args, self.fpa_metadata,
                                 all_prems, context.goal, argument.strip())

        tokenized_goal = tokenize(self.dataloader_args, self.fpa_metadata,
                                  context.goal)
        premise_features_size = get_premise_features_size(
            self.dataloader_args, self.fpa_metadata)
        if arg_idx == 0:
            # No arg
            arg_type_idx = 0
            encoded_arg = torch.zeros(128 + premise_features_size)
        elif arg_idx <= self.dataloader_args.max_length:
            # Goal token arg
            arg_type_idx = 1
            encoded_arg = torch.cat((self.predictor.goal_token_encoder(
                torch.LongTensor([stem_idx]), torch.LongTensor([
                    tokenized_goal
                ])).squeeze(0)[arg_idx].to(device=torch.device("cpu")),
                                     torch.zeros(premise_features_size)),
                                    dim=0)
        else:
            # Hyp arg
            arg_type_idx = 2
            arg_hyp = all_prems[arg_idx -
                                (self.dataloader_args.max_length + 1)]
            entire_encoded_goal = self.predictor.entire_goal_encoder(
                torch.LongTensor([tokenized_goal]))
            tokenized_arg_hyp = tokenize(self.dataloader_args,
                                         self.fpa_metadata,
                                         serapi_instance.get_hyp_type(arg_hyp))
            encoded_arg = torch.cat(
                (self.predictor.hyp_encoder(
                    torch.LongTensor([stem_idx]), entire_encoded_goal,
                    torch.LongTensor([tokenized_arg_hyp
                                      ])).to(device=torch.device("cpu")),
                 torch.FloatTensor(
                     get_premise_features(self.dataloader_args,
                                          self.fpa_metadata, context.goal,
                                          arg_hyp))),
                dim=0)

        return [stem_idx, arg_type_idx], encoded_arg
Example no. 9
def encode_tactic_structure(stem_embedding : SimpleEmbedding,
                            max_args : int,
                            hyps_and_tactic : Tuple[List[str], str]) \
    -> TacticStructure:
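    # A TacticStructure is the stem's embedding index plus a fixed-length list
    # of hypothesis indices for the arguments, padded with EOS_token.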
    hyps, tactic = hyps_and_tactic
    tactic_stem, args_str = serapi_instance.split_tactic(tactic)
    arg_strs = args_str.split()[:max_args]
    stem_idx = stem_embedding.encode_token(tactic_stem)
    arg_idxs = [get_arg_idx(hyps, arg.strip()) for arg in arg_strs]
    if len(arg_idxs) < max_args:
        arg_idxs += [EOS_token] * (max_args - len(arg_idxs))
    # If any arguments aren't hypotheses, ignore the arguments
    if not all(arg_idxs):
        arg_idxs = [EOS_token] * max_args

    return TacticStructure(stem_idx=stem_idx, hyp_idxs=arg_idxs)
Example no. 10
def args_token_in_goal(in_data : TacticContext, tactic : str,
                       next_in_data : TacticContext,
                       arg_values : argparse.Namespace) -> bool:
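    # True when every argument subexpression matches some symbol in the
    # (length-capped) goal.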
    goal_words = get_symbols(in_data.goal)[:arg_values.max_length]
    stem, rest = serapi_instance.split_tactic(tactic)
    if len(rest) > 0 and rest[-1] == '.':
        rest = rest[:-1]
    args = get_subexprs(rest)
    # While the arguments to an intro(s) might *look* like
    # goal arguments, they are actually fresh variables
    if (stem == "intros" or stem == "intro") and len(args) > 0:
        return False
    for arg in args:
        if not any([serapi_instance.symbol_matches(goal_word, arg)
                    for goal_word in goal_words]):
            return False
    return True
Example no. 11
    def _encode_action(self, context: TacticContext, action: str) \
            -> Tuple[int, int]:
        stem, argument = serapi_instance.split_tactic(action)
        stem_idx = emap_lookup(self.tactic_map, 32, stem)
        all_premises = context.hypotheses + context.relevant_lemmas
        stripped_arg = argument.strip(".").strip()
        if stripped_arg == "":
            # No argument.
            arg_idx = 0
        else:
            index_hyp_vars = dict(
                serapi_instance.get_indexed_vars_in_hyps(all_premises))
            if stripped_arg in index_hyp_vars:
                # Hypothesis argument: encode the first word of its type,
                # shifted past the reserved indices 0 (no arg) and 1 (unknown).
                hyp_var, _, rest = all_premises[index_hyp_vars[stripped_arg]]\
                    .partition(":")
                arg_idx = emap_lookup(self.token_map, 128,
                                      tokenizer.get_words(rest)[0]) + 2
            else:
                goal_symbols = tokenizer.get_symbols(context.goal)
                if stripped_arg in goal_symbols:
                    # Goal-symbol argument: shifted past the 128 hypothesis
                    # token indices.
                    arg_idx = emap_lookup(self.token_map, 128,
                                          stripped_arg) + 128 + 2
                else:
                    # Argument not recognized among hypotheses or goal symbols.
                    arg_idx = 1
        return stem_idx, arg_idx
Example no. 12
    def predictionCertainty(self, context: TacticContext, prediction: str) -> float:
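        # Compute the probability this model assigns to `prediction` in `context`.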

        assert self.training_args
        assert self._model

        num_stem_poss = get_num_tokens(self.metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss)

        tokenized_premises, hyp_features, \
            nhyps_batch, tokenized_goal, \
            goal_mask, \
            word_features, vec_features = \
            sample_fpa(extract_dataloader_args(self.training_args),
                       self.metadata,
                       context.relevant_lemmas,
                       context.prev_tactics,
                       context.hypotheses,
                       context.goal)

        prediction_stem, prediction_args = \
            serapi_instance.split_tactic(prediction)
        prediction_stem_idx = encode_fpa_stem(extract_dataloader_args(self.training_args),
                                              self.metadata, prediction_stem)
        stem_distributions = self._model.stem_classifier(
            maybe_cuda(torch.LongTensor(word_features)),
            maybe_cuda(torch.FloatTensor(vec_features)))
        stem_certainties, stem_idxs = stem_distributions.topk(stem_width)
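        # Make sure the prediction's stem appears among the candidate stems,
        # splicing it in if the beam missed it.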
        if prediction_stem_idx in stem_idxs[0]:
            merged_stem_idxs = stem_idxs
            merged_stem_certainties = stem_certainties
        else:
            merged_stem_idxs = torch.cat(
                (maybe_cuda(torch.LongTensor([[prediction_stem_idx]])),
                 stem_idxs[:, :stem_width-1]),
                dim=1)
            cother = stem_certainties[:, :stem_width-1]
            val = stem_distributions[0][prediction_stem_idx]
            merged_stem_certainties = \
                torch.cat((val.view(1, 1), cother), dim=1)

        prediction_stem_idx_idx = list(merged_stem_idxs[0]).index(
            prediction_stem_idx)
        prediction_arg_idx = encode_fpa_arg(
            extract_dataloader_args(self.training_args),
            self.metadata,
            context.hypotheses + context.relevant_lemmas,
            context.goal,
            prediction_args)

        goal_arg_values = self.goal_token_scores(
            merged_stem_idxs, tokenized_goal, goal_mask)

        if len(tokenized_premises[0]) > 0:
            hyp_arg_values = self.hyp_name_scores(
                merged_stem_idxs[0], tokenized_goal[0],
                tokenized_premises[0], hyp_features[0])

            total_scores = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
        else:
            total_scores = goal_arg_values

        final_probs, predicted_stem_idxs, predicted_arg_idxs = \
            self.predict_args(total_scores, merged_stem_certainties,
                              merged_stem_idxs)

        for prob, stem_idx_idx, arg_idx in zip(final_probs,
                                               predicted_stem_idxs,
                                               predicted_arg_idxs):
            if stem_idx_idx == prediction_stem_idx and \
               arg_idx == prediction_arg_idx:
                return math.exp(prob.item())

        assert False, "Shouldn't be able to get here"