def max_args(num_str : str, in_data: TacticContext, tactic : str,
             new_in_data : TacticContext,
             arg_values : argparse.Namespace) -> bool:
    """Check that `tactic` is applied to no more than ``int(num_str)`` arguments.

    Only `tactic` and `num_str` are consulted; `in_data`, `new_in_data` and
    `arg_values` are not used.

    The argument text returned by ``serapi_instance.split_tactic`` is stripped
    of surrounding whitespace and its last character is dropped (presumably
    the tactic's terminating period) before being split on whitespace.
    """
    _, raw_args = serapi_instance.split_tactic(tactic)
    arg_words = raw_args.strip()[:-1].split()
    limit = int(num_str)
    return limit >= len(arg_words)
def encode_seq_structural_data(
        data : RawDataset,
        context_tokenizer_type : Callable[[List[str], int], Tokenizer],
        num_keywords : int,
        num_reserved_tokens: int) \
        -> Tuple[StructDataset, Tokenizer, SimpleEmbedding]:
    """Encode a raw dataset into structural (stem + hypothesis-argument) samples.

    1. Every hypothesis and goal is paired with ``embedding.encode_token`` of
       its sample's *full* tactic string.  These pairs are handed to
       ``make_keyword_tokenizer_relevance`` to build the context tokenizer.
    2. Each sample ``(prev_tactics, hyps, goal, tactic)`` then becomes
       ``(hyp_token_lists, goal_token_list, (stem_idx, arg_hyp_idxs))``.
       ``stem_idx`` is the embedding of the tactic stem, and ``arg_hyp_idxs``
       holds ``hyp_index(hyps, symbol)`` for each symbol of the argument text.

    ``data`` is iterated twice.  If it were a one-shot iterator, the second
    pass would see no samples.

    Returns the encoded samples, the context tokenizer and the embedding.

    NOTE(review): step 1 registers full tactic strings in the same embedding
    that later encodes stems, so the returned embedding holds both kinds of
    token.  Confirm this is intended.
    """
    embedding = SimpleEmbedding()

    # (term, tactic-label) pairs used to pick relevant keywords.
    labeled_terms = []
    for _, hyps, goal, tactic in data:
        tactic_label = embedding.encode_token(tactic)
        labeled_terms.extend((term, tactic_label) for term in hyps + [goal])

    context_tokenizer = make_keyword_tokenizer_relevance(
        labeled_terms, context_tokenizer_type,
        num_keywords, num_reserved_tokens)

    def encode_sample(hyps, goal, tactic):
        stem, rest = serapi_instance.split_tactic(tactic)
        hyp_tokens = [context_tokenizer.toTokenList(hyp) for hyp in hyps]
        goal_tokens = context_tokenizer.toTokenList(goal)
        stem_idx = embedding.encode_token(stem)
        arg_hyp_idxs = [hyp_index(hyps, symbol) for symbol in get_symbols(rest)]
        return (hyp_tokens, goal_tokens, (stem_idx, arg_hyp_idxs))

    encoded_samples = [encode_sample(hyps, goal, tactic)
                       for _, hyps, goal, tactic in data]
    return encoded_samples, context_tokenizer, embedding
def _determine_relevance(self, inter: ScrapedTactic) -> List[bool]:
    """Flag, per focused hypothesis, whether the tactic mentions one of its variables.

    The argument text of ``inter.tactic`` has its final character dropped and
    is then split on whitespace.  It is *not* whitespace-stripped first, and
    the dropped character is presumably the terminating period.

    A hypothesis is relevant when any comma-separated name from
    ``serapi_instance.get_var_term_in_hyp(hyp)``, after stripping whitespace,
    is one of those argument words.
    """
    _, arg_text = serapi_instance.split_tactic(inter.tactic)
    tactic_words = arg_text[:-1].split()

    def mentions_any_var(hyp) -> bool:
        names = serapi_instance.get_var_term_in_hyp(hyp).split(",")
        return any(name.strip() in tactic_words for name in names)

    return [mentions_any_var(hyp) for hyp in inter.context.focused_hyps]
def numeric_args(in_data : TacticContext, tactic : str,
                 next_in_data : TacticContext,
                 arg_values : argparse.Namespace) -> bool:
    """Return True when every argument of `tactic` consists only of digits.

    The argument text from ``serapi_instance.split_tactic`` has leading and
    trailing periods removed.  It is then split into subexpressions with
    ``get_subexprs``, and each subexpression must fully match ``\\d+``.
    A tactic with no arguments yields True.

    `in_data`, `next_in_data` and `arg_values` are not used.
    """
    _, rest = serapi_instance.split_tactic(tactic)
    args = get_subexprs(rest.strip("."))
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    return all(re.fullmatch(r"\d+", arg) for arg in args)
def args_vars_in_list(tactic : str, context_list : List[str]) -> bool:
    """Return True when every argument of `tactic` is a variable of `context_list`.

    The argument text has its final character dropped (presumably the
    terminating period) and is split on whitespace.

    Returns False in two cases:

    * the tactic stem does not take hypothesis arguments
      (``serapi_instance.tacticTakesHypArgs``) but has arguments anyway;
    * some argument is not among
      ``serapi_instance.get_vars_in_hyps(context_list)``.
    """
    stem, arg_text = serapi_instance.split_tactic(tactic)
    tactic_args = arg_text[:-1].split()
    takes_hyp_args = serapi_instance.tacticTakesHypArgs(stem)
    if not takes_hyp_args and len(tactic_args) > 0:
        return False
    known_vars = serapi_instance.get_vars_in_hyps(context_list)
    return all(arg in known_vars for arg in tactic_args)
def get_arg_idx(max_length: int, inter: ScrapedTactic) -> int:
    """Locate the tactic's first argument among the focused goal's symbols.

    The first whitespace-separated argument word, with periods stripped from
    both ends, is looked up in ``tokenizer.get_symbols(goal)``.

    Returns ``position + 1`` for its first occurrence when
    ``position < max_length``, and ``0`` otherwise.

    Raises ``AssertionError`` when the argument is not a goal symbol, and
    ``IndexError`` when the tactic has no argument text.
    """
    _, rest = serapi_instance.split_tactic(inter.tactic)
    goal = inter.context.focused_goal
    symbols = tokenizer.get_symbols(goal)
    first_arg = rest.split()[0].strip(".")
    assert first_arg in symbols, "tactic: {}, arg: {}, goal: {}, symbols: {}"\
        .format(inter.tactic, first_arg, goal, symbols)
    position = symbols.index(first_arg)
    return position + 1 if position < max_length else 0
def get_stem_and_arg_idx(max_length: int, embedding: Embedding,
                         inter: ScrapedTactic) -> Tuple[int, int]:
    """Return ``(stem_idx, arg_slot)`` for ``inter.tactic``.

    ``stem_idx`` is ``embedding.encode_token(stem)``.  It is computed before
    the argument check, so the stem is encoded even if that check fails.

    ``arg_slot`` is ``position + 1``, where ``position`` is the first
    occurrence of the tactic's first argument (periods stripped from both
    ends) in ``tokenizer.get_symbols`` of the focused goal.  It is ``0`` when
    ``position >= max_length``.

    Raises ``AssertionError`` when the argument is not a goal symbol, and
    ``IndexError`` when the tactic has no argument text.
    """
    stem, rest = serapi_instance.split_tactic(inter.tactic)
    stem_idx = embedding.encode_token(stem)
    goal = inter.context.focused_goal
    symbols = tokenizer.get_symbols(goal)
    first_arg = rest.split()[0].strip(".")
    assert first_arg in symbols, "tactic: {}, arg: {}, goal: {}, symbols: {}"\
        .format(inter.tactic, first_arg, goal, symbols)
    position = symbols.index(first_arg)
    arg_slot = 0 if position >= max_length else position + 1
    return stem_idx, arg_slot
def _encode_action(self, context: TacticContext, action: str) \
        -> Tuple[List[int], torch.FloatTensor]:
    """Encode a tactic `action` as ``([stem_idx, arg_type_idx], encoded_arg)``.

    ``arg_idx`` comes from ``encode_fpa_arg`` over hypotheses followed by
    relevant lemmas (``all_prems``).  It selects one of three cases:

    * ``0`` -- no argument.  ``arg_type_idx = 0`` and ``encoded_arg`` is all
      zeros of length ``128 + premise_features_size``.
    * ``1 .. max_length`` -- goal-token argument.  ``arg_type_idx = 1`` and
      ``encoded_arg`` is the goal-token encoder's output at position
      ``arg_idx`` (moved to CPU), followed by ``premise_features_size`` zeros.
    * ``> max_length`` -- premise argument.  ``arg_type_idx = 2`` and the
      premise is ``all_prems[arg_idx - (max_length + 1)]``.  ``encoded_arg``
      is the hypothesis encoder's output (moved to CPU), followed by that
      premise's ``get_premise_features``.

    NOTE(review): the zero vector's length suggests both encoders emit
    128-wide vectors, so all three cases yield the same length.  Confirm.
    """
    stem, argument = serapi_instance.split_tactic(action)
    stem_idx = encode_fpa_stem(self.dataloader_args, self.fpa_metadata, stem)
    # Premises are indexed as hypotheses first, then relevant lemmas.
    all_prems = context.hypotheses + context.relevant_lemmas
    arg_idx = encode_fpa_arg(self.dataloader_args, self.fpa_metadata,
                             all_prems, context.goal, argument.strip())

    tokenized_goal = tokenize(self.dataloader_args, self.fpa_metadata,
                              context.goal)
    premise_features_size = get_premise_features_size(
        self.dataloader_args, self.fpa_metadata)
    if arg_idx == 0:
        # No arg
        arg_type_idx = 0
        encoded_arg = torch.zeros(128 + premise_features_size)
    elif arg_idx <= self.dataloader_args.max_length:
        # Goal token arg.
        # NOTE(review): the encoder output is indexed with arg_idx itself, not
        # arg_idx - 1.  Presumably position 0 of the encoder output is the
        # "no argument" slot.  Confirm against goal_token_encoder.
        arg_type_idx = 1
        encoded_arg = torch.cat((self.predictor.goal_token_encoder(
            torch.LongTensor([stem_idx]),
            torch.LongTensor([tokenized_goal])).squeeze(0)[arg_idx].to(device=torch.device("cpu")),
            torch.zeros(premise_features_size)), dim=0)
    else:
        # Hyp arg: indices past max_length address all_prems.
        arg_type_idx = 2
        arg_hyp = all_prems[arg_idx - (self.dataloader_args.max_length + 1)]
        entire_encoded_goal = self.predictor.entire_goal_encoder(
            torch.LongTensor([tokenized_goal]))
        tokenized_arg_hyp = tokenize(self.dataloader_args, self.fpa_metadata,
                                     serapi_instance.get_hyp_type(arg_hyp))
        encoded_arg = torch.cat(
            (self.predictor.hyp_encoder(
                torch.LongTensor([stem_idx]),
                entire_encoded_goal,
                torch.LongTensor([tokenized_arg_hyp])).to(device=torch.device("cpu")),
             torch.FloatTensor(
                 get_premise_features(self.dataloader_args, self.fpa_metadata,
                                      context.goal, arg_hyp))),
            dim=0)
    return [stem_idx, arg_type_idx], encoded_arg
def encode_tactic_structure(stem_embedding : SimpleEmbedding,
                            max_args : int,
                            hyps_and_tactic : Tuple[List[str], str]) \
        -> TacticStructure:
    """Encode a tactic as a stem index plus exactly `max_args` argument indices.

    Only the first `max_args` whitespace-separated arguments are encoded, each
    with ``get_arg_idx(hyps, arg)``.  The list is padded with ``EOS_token`` up
    to `max_args`.  If any encoded argument is falsy (presumably
    ``get_arg_idx``'s marker for "not a hypothesis" -- TODO confirm), every
    argument is replaced by ``EOS_token``.

    Previously the truncated ``arg_strs`` was computed but never used, so
    tactics with more than `max_args` arguments produced a ``hyp_idxs`` list
    longer than `max_args`.

    NOTE(review): if ``EOS_token`` were falsy, any padded result would also
    trigger the reset below.  Confirm that ``EOS_token`` is non-zero.
    """
    hyps, tactic = hyps_and_tactic
    tactic_stem, args_str = serapi_instance.split_tactic(tactic)
    arg_strs = args_str.split()[:max_args]
    stem_idx = stem_embedding.encode_token(tactic_stem)
    arg_idxs = [get_arg_idx(hyps, arg.strip()) for arg in arg_strs]
    # len(arg_idxs) <= max_args here, so this pads to exactly max_args
    # (multiplying a list by a non-positive count yields []).
    arg_idxs += [EOS_token] * (max_args - len(arg_idxs))
    # If any arguments aren't hypotheses, ignore the arguments
    if not all(arg_idxs):
        arg_idxs = [EOS_token] * max_args
    return TacticStructure(stem_idx=stem_idx, hyp_idxs=arg_idxs)
def args_token_in_goal(in_data : TacticContext, tactic : str,
                       next_in_data : TacticContext,
                       arg_values : argparse.Namespace) -> bool:
    """Return True when every argument of `tactic` matches a leading goal symbol.

    Only the first ``arg_values.max_length`` symbols of ``in_data.goal`` are
    considered.  A single trailing period is removed from the argument text
    before ``get_subexprs`` splits it.  Each subexpression must satisfy
    ``serapi_instance.symbol_matches`` against at least one of those symbols.

    ``intro``/``intros`` with any arguments always yields False.
    """
    goal_words = get_symbols(in_data.goal)[:arg_values.max_length]
    stem, rest = serapi_instance.split_tactic(tactic)
    if rest.endswith('.'):
        rest = rest[:-1]
    args = get_subexprs(rest)

    # Arguments to intro/intros name *fresh* variables, so even when they look
    # like goal symbols they are not goal arguments.
    if stem in ("intros", "intro") and len(args) > 0:
        return False

    def matches_goal(arg) -> bool:
        return any(serapi_instance.symbol_matches(word, arg)
                   for word in goal_words)

    return all(matches_goal(arg) for arg in args)
def _encode_action(self, context: TacticContext, action: str) \
        -> Tuple[int, int]:
    """Map a tactic `action` to a pair ``(stem_idx, arg_idx)``.

    ``stem_idx`` is ``emap_lookup(self.tactic_map, 32, stem)``.

    The argument is the text after the stem, with periods and then whitespace
    stripped from both ends.  ``arg_idx`` is:

    * ``0`` -- the argument is empty;
    * ``2 + k`` -- the argument is a key of
      ``get_indexed_vars_in_hyps(premises)``.  ``k`` is the
      ``self.token_map`` lookup of the first word after the first ``:`` of
      the premise that key maps to.  Premises are hypotheses followed by
      relevant lemmas.
    * ``2 + 128 + k`` -- otherwise, if the argument is a goal symbol.  ``k``
      is the ``self.token_map`` lookup of the argument itself.
    * ``1`` -- anything else.

    NOTE(review): these ranges only stay disjoint if ``emap_lookup`` returns
    values below the bound it is given (32 / 128).  Confirm.
    """
    stem, argument = serapi_instance.split_tactic(action)
    stem_idx = emap_lookup(self.tactic_map, 32, stem)
    premises = context.hypotheses + context.relevant_lemmas
    arg = argument.strip(".").strip()

    if not arg:
        return stem_idx, 0

    var_to_premise = dict(serapi_instance.get_indexed_vars_in_hyps(premises))
    if arg in var_to_premise:
        _, _, premise_type = premises[var_to_premise[arg]].partition(":")
        head_word = tokenizer.get_words(premise_type)[0]
        return stem_idx, emap_lookup(self.token_map, 128, head_word) + 2

    if arg in tokenizer.get_symbols(context.goal):
        return stem_idx, emap_lookup(self.token_map, 128, arg) + 128 + 2

    return stem_idx, 1
def predictionCertainty(self, context: TacticContext, prediction: str) -> float:
    """Return the model's certainty for the specific tactic `prediction` in `context`.

    The stem classifier's top ``stem_width`` stems are taken, where
    ``stem_width = min(max_beam_width, get_num_tokens(self.metadata))``.  If
    the stem of `prediction` is not among them, it is placed first (with its
    own classifier score) and the last of the top-k stems is dropped, so it
    is always scored.

    Argument scores over goal tokens, and over premises when there are any,
    are combined by ``self.predict_args``.  The probability of the entry
    whose stem and argument match `prediction` is returned as
    ``math.exp(prob.item())`` (presumably ``prob`` is a log-probability --
    TODO confirm).

    Raises:
        AssertionError: if ``self.training_args`` or ``self._model`` is unset,
            or if no returned (stem, argument) pair matches `prediction`.
            Under ``python -O`` these asserts are stripped, and the final one
            would fall through to an implicit ``None`` return.
    """
    assert self.training_args
    assert self._model

    num_stem_poss = get_num_tokens(self.metadata)
    stem_width = min(self.training_args.max_beam_width, num_stem_poss)

    tokenized_premises, hyp_features, \
        nhyps_batch, tokenized_goal, \
        goal_mask, \
        word_features, vec_features = \
        sample_fpa(extract_dataloader_args(self.training_args),
                   self.metadata,
                   context.relevant_lemmas,
                   context.prev_tactics,
                   context.hypotheses,
                   context.goal)

    prediction_stem, prediction_args = \
        serapi_instance.split_tactic(prediction)
    prediction_stem_idx = encode_fpa_stem(extract_dataloader_args(self.training_args),
                                          self.metadata, prediction_stem)
    stem_distributions = self._model.stem_classifier(
        maybe_cuda(torch.LongTensor(word_features)),
        maybe_cuda(torch.FloatTensor(vec_features)))
    stem_certainties, stem_idxs = stem_distributions.topk(stem_width)
    # Ensure the predicted stem is scored.  If topk missed it, prepend it with
    # its own classifier score and drop the last top-k stem, keeping the width.
    if prediction_stem_idx in stem_idxs[0]:
        merged_stem_idxs = stem_idxs
        merged_stem_certainties = stem_certainties
    else:
        merged_stem_idxs = torch.cat(
            (maybe_cuda(torch.LongTensor([[prediction_stem_idx]])),
             stem_idxs[:, :stem_width-1]), dim=1)
        cother = stem_certainties[:, :stem_width-1]
        val = stem_distributions[0][prediction_stem_idx]
        merged_stem_certainties = \
            torch.cat((val.view(1, 1), cother), dim=1)
    # NOTE(review): prediction_stem_idx_idx is computed here but never used.
    # The loop below compares `stem_idx_idx` against prediction_stem_idx.  If
    # predict_args returns indices *into* merged_stem_idxs (as the loop
    # variable's name suggests), the comparison should use
    # prediction_stem_idx_idx instead.  Confirm against predict_args.
    prediction_stem_idx_idx = list(merged_stem_idxs[0]).index(
        prediction_stem_idx)
    prediction_arg_idx = encode_fpa_arg(
        extract_dataloader_args(self.training_args),
        self.metadata,
        context.hypotheses + context.relevant_lemmas,
        context.goal, prediction_args)
    goal_arg_values = self.goal_token_scores(
        merged_stem_idxs, tokenized_goal, goal_mask)
    # Premise scores exist only when there is at least one premise.  They are
    # appended after the goal-token scores along dim 2.
    if len(tokenized_premises[0]) > 0:
        hyp_arg_values = self.hyp_name_scores(
            merged_stem_idxs[0], tokenized_goal[0],
            tokenized_premises[0], hyp_features[0])
        total_scores = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
    else:
        total_scores = goal_arg_values
    final_probs, predicted_stem_idxs, predicted_arg_idxs = \
        self.predict_args(total_scores,
                          merged_stem_certainties, merged_stem_idxs)
    for prob, stem_idx_idx, arg_idx in zip(final_probs, predicted_stem_idxs,
                                           predicted_arg_idxs):
        if stem_idx_idx == prediction_stem_idx and \
           arg_idx == prediction_arg_idx:
            return math.exp(prob.item())
    assert False, "Shouldn't be able to get here"