def _encode_data(self, data: RawDataset, arg_values: Namespace) \
        -> Tuple[EncFeaturesDataset,
                 Tuple[Tokenizer, Embedding,
                       List[VecFeature], List[WordFeature]]]:
    # Strip the scraped output down to bare proof contexts before building
    # the feature extractors from the data.
    stripped_data = [strip_scraped_output(dat) for dat in data]
    self._vec_feature_functions = [
        feature_constructor(stripped_data, arg_values)  # type: ignore
        for feature_constructor in vec_feature_constructors]
    self._word_feature_functions = [
        feature_constructor(stripped_data, arg_values)  # type: ignore
        for feature_constructor in word_feature_constructors]
    # Embed the tactics and tokenize the goals.
    embedding, embedded_data = embed_data(data)
    tokenizer, tokenized_goals = tokenize_goals(embedded_data, arg_values)
    # Pair each sample's vector and word features with its length-normalized
    # goal tokens and its embedded tactic.
    result_data = EncFeaturesDataset([
        EncFeaturesSample(
            self._get_vec_features(
                TacticContext([], prev_tactics, hypotheses, goal)),
            self._get_word_features(
                TacticContext([], prev_tactics, hypotheses, goal)),
            normalizeSentenceLength(tokenized_goal, arg_values.max_length),
            tactic)
        for (relevant_lemmas, prev_tactics, hypotheses, goal, tactic),
            tokenized_goal
        in zip(embedded_data, tokenized_goals)])
    return result_data, (tokenizer, embedding,
                         self._vec_feature_functions,
                         self._word_feature_functions)
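# A minimal usage sketch for _encode_data (hypothetical: `predictor`,
# `raw_data`, and `args` are assumed to come from the surrounding training
# script and are not defined here):
#
#   dataset, (tokenizer, embedding, vec_fs, word_fs) = \
#       predictor._encode_data(raw_data, args)
#   sample = dataset[0]   # an EncFeaturesSample of (vec features,
#                         # word features, padded goal tokens, tactic index)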
def generate() -> Iterator[Tuple[TacticContext, str, float]]:
    # Batch-predict candidate tactics for the state after each transition.
    prediction_lists = cast(features_polyarg_predictor
                            .FeaturesPolyargPredictor, predictor) \
        .predictKTactics_batch(
            [transition.after_context for transition in transitions],
            num_predictions)
    for transition, predictions in zip(transitions, prediction_lists):
        tactic_ctxt = transition.after_context
        if len(transition.after.all_goals) == 0:
            # The proof is finished; the target is just the reward.
            new_q = transition.reward
        else:
            # One-step TD target: r + discount * max_a' Q(s', a') - Q(s, a).
            estimates = q_estimator(
                [(tactic_ctxt, prediction.prediction)
                 for prediction in predictions])
            estimated_future_q = discount * max(estimates)
            estimated_current_q = q_estimator(
                [(transition.before_context, transition.action)])[0]
            new_q = transition.reward + estimated_future_q \
                - estimated_current_q
        # NaN checks: NaN is the only float that doesn't equal itself.
        assert transition.reward == transition.reward
        assert discount == discount
        assert new_q == new_q
        if transition.graph_node:
            graph.setNodeApproxQScore(transition.graph_node, new_q)
        yield TacticContext(
            transition.relevant_lemmas, transition.prev_tactics,
            transition.before.focused_hyps,
            transition.before.focused_goal), transition.action, new_q
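# The target computed in generate() is the standard one-step temporal-
# difference form,
#     new_q = r + discount * max_a' Q(s', a') - Q(s, a),
# so it is zero exactly when the current estimate already satisfies the
# Bellman equation. A numeric sketch with assumed values (reward 1.0,
# discount 0.9, current estimate 0.4):
#
#   estimates = [0.2, 0.5, 0.3]               # Q(s', a') per candidate
#   new_q = 1.0 + 0.9 * max(estimates) - 0.4  # == 1.05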
def commandLinePredict(predictor: EncDecRNNPredictor, k: int) -> None:
    # Read the goal from stdin until the "+++++" sentinel line.
    sentence = ""
    next_line = sys.stdin.readline()
    while next_line != "+++++\n":
        sentence += next_line
        next_line = sys.stdin.readline()
    for result in predictor.predictKTactics(
            TacticContext([], [], [], sentence), k):
        print(result)
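# Input protocol sketch for commandLinePredict: the goal text is read from
# stdin up to (but not including) a sentinel line "+++++". A hypothetical
# invocation (the script name is assumed, not part of this module):
#
#   $ printf 'forall n : nat, n + 0 = n\n+++++\n' | python predict_tactic.py
#
# would print k predicted tactics for the goal "forall n : nat, n + 0 = n".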
def pre_train(args: argparse.Namespace, estimator: QEstimator,
              transitions: List[dataloader.ScrapedTransition]) -> None:
    # Build (context, tactic, target) samples from scraped transitions,
    # skipping states with no foreground goals.
    samples = [(TacticContext(transition.relevant_lemmas,
                              transition.prev_tactics,
                              transition.before.fg_goals[0].hypotheses,
                              transition.before.fg_goals[0].goal),
                transition.tactic, 0.0)
               for transition in transitions
               if len(transition.before.fg_goals) > 0]
    estimator.train(samples, args.batch_size, args.pretrain_epochs)
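# Note on the 0.0 targets above: every scraped (context, tactic) pair is
# labeled with a Q-value of zero, so after pre-training the estimator starts
# out predicting roughly zero everywhere (presumably a neutral initialization
# before any reinforcement signal is seen). Usage sketch, assuming `args`,
# `estimator`, and `transitions` come from the surrounding driver:
#
#   pre_train(args, estimator, transitions)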
def q_report(args: argparse.Namespace) -> None:
    num_originally_correct = 0
    num_correct = 0
    num_top3 = 0
    num_total = 0
    num_possible = 0

    # Load the tactic predictor and the saved Q-estimator state.
    predictor = predict_tactic.loadPredictorByFile(args.predictor_weights)
    q_estimator_name, *saved = torch.load(args.estimator_weights)
    q_estimator = FeaturesQEstimator(0, 0, 0)
    q_estimator.load_saved_state(*saved)

    for filename in args.test_files:
        points = dataloader.scraped_tactics_from_file(
            str(filename) + ".scrape", None)
        for point in points:
            context = TacticContext(point.relevant_lemmas,
                                    point.prev_tactics,
                                    point.prev_hyps,
                                    point.prev_goal)
            predictions = [p.prediction for p in
                           predictor.predictKTactics(
                               context, args.num_predictions)]
            # Re-rank the predictor's candidates by estimated Q-value,
            # highest first.
            q_choices = zip(q_estimator([(context, prediction)
                                         for prediction in predictions]),
                            predictions)
            ordered_actions = [p[1] for p in
                               sorted(q_choices, key=lambda q: q[0],
                                      reverse=True)]

            num_total += 1
            if point.tactic.strip() in predictions:
                num_possible += 1
            if ordered_actions[0] == point.tactic.strip():
                num_correct += 1
            if point.tactic.strip() in ordered_actions[:3]:
                num_top3 += 1
            if predictions[0] == point.tactic.strip():
                num_originally_correct += 1

    print(f"num_correct: {num_correct}")
    print(f"num_originally_correct: {num_originally_correct}")
    print(f"num_top3: {num_top3}")
    print(f"num_total: {num_total}")
    print(f"num_possible: {num_possible}")
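# Reading q_report's counters: num_possible counts how often the correct
# tactic appeared among the predictor's candidates at all,
# num_originally_correct how often the predictor ranked it first, and
# num_correct / num_top3 how often the Q-reranking ranked it first / in the
# top three; a reranking gain shows up as num_correct exceeding
# num_originally_correct. A hedged sketch of turning the counters into
# rates (hypothetical helper, not part of q_report):
#
#   def rerank_summary(num_correct: int, num_originally_correct: int,
#                      num_total: int) -> str:
#       return (f"reranked top-1: {num_correct / num_total:.1%}, "
#               f"original top-1: {num_originally_correct / num_total:.1%}")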
def mkHFSample(max_length: int,
               word_feature_functions: List[WordFeature],
               vec_feature_functions: List[VecFeature],
               zipped: Tuple[EmbeddedSample, List[int], List[int]]) \
        -> HypFeaturesSample:
    context, goal, best_hyp = zipped
    (relevant_lemmas, prev_tactic_list, hypotheses, goal_str, tactic) = \
        context
    tac_context = TacticContext(relevant_lemmas, prev_tactic_list,
                                hypotheses, goal_str)
    return HypFeaturesSample(
        [feature(tac_context) for feature in word_feature_functions],
        [feature_val for feature in vec_feature_functions
         for feature_val in feature(tac_context)],
        normalizeSentenceLength(goal, max_length),
        normalizeSentenceLength(best_hyp, max_length),
        tactic)
def get_should_filter(data: MixedDataset) \
        -> Iterable[Tuple[ScrapedCommand, bool]]:
    list_data: List[ScrapedCommand] = list(data)
    # Pair each point with its successor (None for the last point).
    extended_list: List[Optional[ScrapedCommand]] = \
        cast(List[Optional[ScrapedCommand]], list_data[1:]) + [None]
    for point, nextpoint in zip(list_data, extended_list):
        if isinstance(point, ScrapedTactic) \
                and not re.match(r"\s*[{}]\s*", point.tactic) \
                and point.context.focused_goal.strip() != "":
            if isinstance(nextpoint, ScrapedTactic):
                context_after = strip_scraped_output(nextpoint)
            else:
                context_after = TacticContext([], [], [], "")
            should_filter = not context_filter(strip_scraped_output(point),
                                               point.tactic,
                                               context_after,
                                               training_args)
            yield (point, should_filter)
        else:
            yield (point, True)
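# The brace regex above screens out structural proof commands like "{" and
# "}" before applying the context filter. A quick sanity check (hypothetical
# REPL session):
#
#   >>> import re
#   >>> bool(re.match(r"\s*[{}]\s*", "  { "))
#   True
#   >>> bool(re.match(r"\s*[{}]\s*", "intros."))
#   False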
def mkCopySample(max_length: int,
                 word_feature_functions: List[WordFeature],
                 vec_feature_functions: List[VecFeature],
                 zipped: Tuple[EmbeddedSample, List[int], int]) \
        -> CopyArgSample:
    context, goal, arg_idx = zipped
    (relevant_lemmas, prev_tactic_list, hypotheses, goal_str, tactic_idx) = \
        context
    tac_context = TacticContext(relevant_lemmas, prev_tactic_list,
                                hypotheses, goal_str)
    word_features = [feature(tac_context)
                     for feature in word_feature_functions]
    assert len(word_features) == 3
    return CopyArgSample(normalizeSentenceLength(goal, max_length),
                         word_features,
                         [feature_val for feature in vec_feature_functions
                          for feature_val in feature(tac_context)],
                         tactic_idx, arg_idx)
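# The nested comprehension in mkCopySample (and in mkHFSample above) flattens
# the per-feature vectors into a single list. An equivalent sketch with
# made-up values:
#
#   vec_outputs = [[0.0, 1.0], [0.5]]   # one list per VecFeature
#   flat = [v for vec in vec_outputs for v in vec]
#   assert flat == [0.0, 1.0, 0.5]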
@property
def before_context(self) -> TacticContext:
    # The proof context as it was before this transition's action; accessed
    # as an attribute (see generate() above), hence the property decorator.
    return TacticContext(self.relevant_lemmas, self.prev_tactics,
                         self.before.focused_hyps,
                         self.before.focused_goal)
@property
def after_context(self) -> TacticContext:
    # The proof context as it stood after this transition's action.
    return TacticContext(self.relevant_lemmas, self.prev_tactics,
                         self.after.focused_hyps,
                         self.after.focused_goal)
def predictKTacticsWithLoss_batch(self,
                                  in_data: List[TacticContext],
                                  k: int, correct: List[str]) \
        -> Tuple[List[List[Prediction]], float]:
    # Stub implementation: predict once from an empty context, replicate
    # that prediction list for every input, and report a loss of zero.
    return ([self.predictKTactics(TacticContext([], [], [], ""), k)]
            * len(in_data), 0.)