예제 #1
0
 def predictKTactics(self, in_data : TacticContext, k : int) -> List[Prediction]:
     """Suggest `apply` tactics for the longest hypotheses in the context.

     With no hypotheses a single zero-certainty "eauto" prediction is
     returned.  Otherwise at most `k` predictions are produced, one per
     hypothesis in descending order of string length, with the certainty
     halving at each rank (1, 0.5, 0.25, ...).
     """
     hypotheses = in_data.hypotheses
     if len(hypotheses) == 0:
         return [Prediction("eauto", 0)]
     num_wanted = min(k, len(hypotheses))
     longest_first = sorted(hypotheses, key=len, reverse=True)
     predictions = []
     for rank, hyp in enumerate(longest_first[:num_wanted]):
         hyp_var = serapi_instance.get_first_var_in_hyp(hyp)
         predictions.append(Prediction("apply " + hyp_var + ".", .5 ** rank))
     return predictions
예제 #2
0
 def predictKTacticsWithLoss(
         self, in_data: TacticContext, k: int,
         correct: str) -> Tuple[List[Prediction], float]:
     """Return the top-`k` tactic predictions and the loss against `correct`.

     The loss is `self._criterion` applied to the predicted distribution and
     the encoded stem of `correct`; it is +inf when that stem is not in the
     embedding.  Stems for which `serapi_instance.tacticTakesHypArgs` is
     true receive the `hyp_var` returned by `_predictDistribution` as their
     argument; all other stems are emitted bare.
     """
     with self._lock:
         distribution, hyp_var = self._predictDistribution(in_data)
         correct_stem = serapi_instance.get_stem(correct)
         if not self._embedding.has_token(correct_stem):
             loss = float("+inf")
         else:
             model_output = distribution.view(1, -1)
             target = Variable(
                 LongTensor([self._embedding.encode_token(correct_stem)]))
             loss = self._criterion(model_output, target).item()

     indices, probabilities = list_topk(list(distribution), k)
     predictions: List[Prediction] = []
     for certainty, idx in zip(probabilities, indices):
         stem = self._embedding.decode_token(idx)
         if serapi_instance.tacticTakesHypArgs(stem):
             tactic = stem + " " + hyp_var + "."
         else:
             tactic = stem + "."
         predictions.append(Prediction(tactic, math.exp(certainty)))
     return predictions, loss
예제 #3
0
 def predictKTactics(self, in_data : TacticContext, k : int) -> List[Prediction]:
     """Suggest `apply` tactics for the hypotheses most similar to the goal.

     Similarity is the `SequenceMatcher` ratio between a hypothesis' type
     (via `serapi_instance.get_hyp_type`) and the goal string.  With no
     hypotheses a single zero-certainty "eauto" prediction is returned;
     otherwise at most `k` predictions whose certainty halves at each rank.
     """
     hypotheses = in_data.hypotheses
     if len(hypotheses) == 0:
         return [Prediction("eauto", 0)]

     def goal_similarity(hyp):
         # Ratio between the hypothesis' type and the goal text.
         hyp_type = serapi_instance.get_hyp_type(hyp)
         return SequenceMatcher(None, hyp_type, in_data.goal).ratio()

     num_wanted = min(k, len(hypotheses))
     most_similar_first = sorted(hypotheses, key=goal_similarity,
                                 reverse=True)
     predictions = []
     for rank, hyp in enumerate(most_similar_first[:num_wanted]):
         hyp_var = serapi_instance.get_first_var_in_hyp(hyp)
         predictions.append(Prediction("apply " + hyp_var + ".", .5 ** rank))
     return predictions
예제 #4
0
 def predictKTactics(self, in_data : TacticContext, k : int) \
     -> List[Prediction]:
     """Return the top-`k` tactic predictions for `in_data`.

     The stems "apply", "exploit" and "rewrite" are given the `hyp_var`
     returned by `_predictDistribution` as their argument; every other stem
     is emitted on its own.  Certainties are `exp` of the selected values.
     """
     with self._lock:
         distribution, hyp_var = self._predictDistribution(in_data)
     indices, probabilities = list_topk(list(distribution), k)

     def to_tactic(stem):
         # Only these stems are completed with the hypothesis variable.
         if stem in ("apply", "exploit", "rewrite"):
             return stem + " " + hyp_var + "."
         return stem + "."

     return [
         Prediction(to_tactic(self._embedding.decode_token(idx)),
                    math.exp(certainty))
         for certainty, idx in zip(probabilities, indices)
     ]
 def predictKTactics(self, in_data : TacticContext, k : int) \
     -> List[Prediction]:
     """Return the top-`k` stems of the predicted distribution as tactics.

     Each selected index is decoded through `self.embedding`, suffixed with
     ".", and paired with `exp` of its value from the distribution.
     """
     distribution = self.predictDistribution(in_data)
     indices, probabilities = list_topk(list(distribution), k)
     predictions = []
     for certainty, idx in zip(probabilities, indices):
         tactic = self.embedding.decode_token(idx) + "."
         predictions.append(Prediction(tactic, math.exp(certainty)))
     return predictions
예제 #6
0
    def decodeNonDuplicatePredictions(
            self, context: TacticContext,
            all_idxs: List[Tuple[float, int, int]],
            k: int) -> List[Prediction]:
        """Decode scored candidates into at most `k` distinct tactics.

        Args:
            context: proof context; its hypotheses (plus relevant lemmas when
                `training_args.lemma_args` is set) and goal are handed to
                `decode_fpa_result`.
            all_idxs: candidate triples, consumed in the given order.  The
                first element is exponentiated to give the certainty; the
                second and third are passed to `decode_fpa_result` as the
                stem and argument indices.
            k: maximum number of predictions to return.

        Returns:
            Predictions in candidate order.  When several candidates decode
            to the same tactic string, only the first one is kept.
        """
        assert self.training_args
        num_stem_poss = get_num_tokens(self.metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss)

        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses

        num_valid_probs = (1 + len(all_hyps) +
                           len(get_fpa_words(context.goal))) * stem_width
        # Never read past the candidates that were actually supplied; the
        # theoretical count above can exceed len(all_idxs).
        num_candidates = min(num_valid_probs, len(all_idxs))

        dataloader_args = extract_dataloader_args(self.training_args)
        predictions: List[Prediction] = []
        # A set gives O(1) duplicate detection instead of rescanning a list.
        seen_strs = set()
        for next_i in range(num_candidates):
            if len(predictions) >= k:
                break
            candidate = all_idxs[next_i]
            next_pred_str = decode_fpa_result(
                dataloader_args,
                self.metadata,
                all_hyps, context.goal,
                candidate[1],
                candidate[2])
            if next_pred_str in seen_strs:
                continue
            seen_strs.add(next_pred_str)
            predictions.append(
                Prediction(next_pred_str, math.exp(candidate[0])))

        return predictions
예제 #7
0
 def add_args(self, stem_predictions : List[Prediction],
              goal : str, hyps : List[str], max_length : int):
     """Expand stem predictions into full tactics and keep the best ones.

     A stem that takes hypothesis arguments is paired with every entry
     returned by `get_closest_hyps`, scored as hypothesis score times stem
     score.  Any other stem becomes "<stem>." at half its stem score.  The
     candidates are sorted by certainty, highest first, and cut down to
     `len(stem_predictions)` entries.
     """
     limit = len(stem_predictions)
     candidates : List[Prediction] = []
     for stem, stem_score in stem_predictions:
         if not serapi_instance.tacticTakesHypArgs(stem):
             candidates.append(Prediction(stem + ".", stem_score * 0.5))
             continue
         closest = get_closest_hyps(hyps, goal, limit, max_length)
         for hyp, hyp_score in closest:
             hyp_var = serapi_instance.get_first_var_in_hyp(hyp)
             candidates.append(Prediction(stem + " " + hyp_var + ".",
                                          hyp_score * stem_score))
     candidates.sort(key=lambda pred: pred.certainty, reverse=True)
     return candidates[:limit]
예제 #8
0
    def predictKTacticsWithLoss_batch(self,
                                      in_data : List[TacticContext],
                                      k : int, corrects : List[str]) -> \
                                      Tuple[List[List[Prediction]], float]:
        """Predict the top-`k` stems for each context and the batch loss.

        Only the goal of each context is fed to the model.  Correct tactics
        whose stem is missing from the embedding are scored against token 0.
        `k` is capped at the number of tokens in the embedding.  An empty
        batch yields `([], 0)`.
        """
        assert self.training_args
        if len(in_data) == 0:
            return [], 0
        with self._lock:
            goal_token_lists = [
                self._tokenizer.toTokenList(goal)
                for _lemmas, _prev_tactics, _hypotheses, goal in in_data
            ]
            model_input = LongTensor([
                inputFromSentence(goal_tokens, self.training_args.max_length)
                for goal_tokens in goal_token_lists
            ])
            distributions = self._model.run(model_input,
                                            batch_size=len(in_data))

            target_stems = [get_stem(correct) for correct in corrects]
            target_idxs = [
                self._embedding.encode_token(stem)
                if self._embedding.has_token(stem) else 0
                for stem in target_stems
            ]
            targets = maybe_cuda(Variable(torch.LongTensor(target_idxs)))
            loss = self._criterion(distributions, targets).item()

            k = min(k, self._embedding.num_tokens())

            top_per_context = [
                distribution.view(-1).topk(k)
                for distribution in list(distributions)
            ]
            results = []
            for certainties, stem_idxs in top_per_context:
                results.append([
                    Prediction(
                        self._embedding.decode_token(stem_idx.item()) + ".",
                        math.exp(certainty.item()))
                    for certainty, stem_idx in zip(certainties, stem_idxs)
                ])
        return results, loss
예제 #9
0
    def predictKTacticsWithLoss(
            self, in_data: TacticContext, k: int,
            correct: str) -> Tuple[List[Prediction], float]:
        """Return the top-`k` stem predictions and the loss against `correct`.

        Args:
            in_data: context passed to `self.predictDistribution`.
            k: number of predictions to take from the distribution.
            correct: the reference tactic; only its stem is scored.

        Returns:
            `(predictions, loss)`.  The loss is 0 when the stem of `correct`
            is not in the embedding.
        """
        # Hold the lock through a context manager so it is released even if
        # prediction, loss computation or decoding raises.
        with self.lock:
            prediction_distribution = self.predictDistribution(in_data)
            correct_stem = get_stem(correct)
            if self.embedding.has_token(correct_stem):
                output_var = maybe_cuda(
                    Variable(
                        torch.LongTensor(
                            [self.embedding.encode_token(correct_stem)])))
                loss = self.criterion(prediction_distribution,
                                      output_var).item()
            else:
                loss = 0

            certainties_and_idxs = prediction_distribution.view(-1).topk(k)
            results = [
                Prediction(
                    self.embedding.decode_token(stem_idx.item()) + ".",
                    math.exp(certainty.item()))
                for certainty, stem_idx in zip(*certainties_and_idxs)
            ]

        return results, loss
예제 #10
0
    def predictKTacticsWithLoss_batch(self, in_data: List[TacticContext],
                                      k: int, corrects: List[str]):
        """Predict the top-`k` stems for each context and the batch loss.

        Each goal is encoded with `encode_ngram_classify_input`, run through
        `self._model` and `self._lsoftmax`, and scored against the stems of
        `corrects` (token 0 stands in for stems missing from the embedding).
        `k` is capped at the number of tokens in the embedding.  Returns
        `(per-context prediction lists, loss)`.
        """
        assert self.training_args
        with self._lock:
            encoded_goals = [
                encode_ngram_classify_input(context.goal,
                                            self.training_args.num_grams,
                                            self._tokenizer)
                for context in in_data
            ]
            model_input = Variable(FloatTensor(encoded_goals))
            distributions = self._lsoftmax(self._model(model_input))

            target_stems = [get_stem(correct) for correct in corrects]
            target_idxs = [
                self._embedding.encode_token(stem)
                if self._embedding.has_token(stem) else 0
                for stem in target_stems
            ]
            targets = maybe_cuda(Variable(torch.LongTensor(target_idxs)))
            loss = self._criterion(distributions, targets).item()
            k = min(k, self._embedding.num_tokens())

            top_per_context = [
                distribution.view(-1).topk(k)
                for distribution in list(distributions)
            ]
            results = []
            for certainties, stem_idxs in top_per_context:
                results.append([
                    Prediction(
                        self._embedding.decode_token(stem_idx.item()) + ".",
                        math.exp(certainty.item()))
                    for certainty, stem_idx in zip(certainties, stem_idxs)
                ])
        return results, loss
 def predictKTactics(self, in_data : TacticContext, k : int) \
     -> List[Prediction]:
     """Return the top-`k` stems of the predicted distribution as tactics.

     Args:
         in_data: context passed to `self.predictDistribution`.
         k: number of predictions to take from the distribution.

     Returns:
         One prediction per selected stem: the decoded stem followed by "."
         and `exp` of its value from the distribution.
     """
     distribution = self.predictDistribution(in_data)
     # topk returns a (values, indices) pair; the two tensors must be zipped
     # to walk them in lockstep rather than unpacking the pair itself.
     certainties, idxs = distribution.squeeze().topk(k)
     return [Prediction(self.embedding.decode_token(idx.item()) + ".",
                        math.exp(certainty.item()))
             for certainty, idx in zip(certainties, idxs)]
예제 #12
0
 def predictKTacticsWithLoss_batch(self,
                                   in_data : List[TacticContext],
                                   k : int, corrects : List[str]) -> \
                                   Tuple[List[List[Prediction]], float]:
     """Predict the top-`k` tactics for each context and the batch loss.

     Correct tactics whose stem is missing from the embedding are scored
     against token 0.  `k` is capped at the number of tokens in the
     embedding.  For a context without hypotheses, stems that take
     hypothesis arguments are filtered out of the selection.  Each selected
     stem is completed by `self.add_arg`.
     """
     assert self._embedding
     assert self.training_args
     with self._lock:
         prediction_distributions = self._predictDistributions(in_data)

     target_stems = [serapi_instance.get_stem(correct) for correct in corrects]
     target_idxs = [
         self._embedding.encode_token(stem)
         if self._embedding.has_token(stem) else 0
         for stem in target_stems
     ]
     targets = maybe_cuda(Variable(LongTensor(target_idxs)))
     loss = self._criterion(prediction_distributions, targets).item()
     k = min(k, self._embedding.num_tokens())

     def takes_no_hyp_args(certainty, idx):
         # Filter used when the context has no hypothesis to supply.
         stem = cast(Embedding, self._embedding).decode_token(idx)
         return not serapi_instance.tacticTakesHypArgs(stem)

     top_per_context = []
     for distribution, context in zip(prediction_distributions, in_data):
         flat = distribution.view(-1)
         if len(context.hypotheses) > 0:
             top_per_context.append(flat.topk(k))
         else:
             top_per_context.append(
                 topk_with_filter(flat, k, takes_no_hyp_args))

     results = []
     for certainties_and_idxs, in_datum in zip(top_per_context, in_data):
         context_predictions = []
         for certainty, stem_idx in zip(*certainties_and_idxs):
             tactic = self.add_arg(
                 self._embedding.decode_token(stem_idx.item()),
                 in_datum.goal, in_datum.hypotheses,
                 self.training_args.max_length)
             context_predictions.append(
                 Prediction(tactic, math.exp(certainty.item())))
         results.append(context_predictions)
     return results, loss
예제 #13
0
    def _predictFromStemDistribution(self, beam_width : int,
                                     stem_distribution : torch.FloatTensor,
                                     in_datas : List[TacticContext],
                                     k : int) -> \
                                     List[List[Prediction]]:
        """Decode the best (stem, argument) pairs for each context.

        Args:
            beam_width: number of entries taken from each composite
                distribution.
            stem_distribution: stem scores handed to
                `_predictCompositeDistributionFromStemDistribution`.
            in_datas: the contexts, one per batch row.
            k: maximum number of predictions returned per context.

        Returns:
            One list per context with at most `k` predictions, each built
            from the decoded stem and the argument taken from the goal.
        """
        assert self.training_args
        assert self._embedding
        all_prob_batches, stem_mapping = \
            self._predictCompositeDistributionFromStemDistribution(
                beam_width, stem_distribution, in_datas)

        final_probs, final_idxs = all_prob_batches.topk(beam_width)
        row_length = self.training_args.max_length
        # Floor division keeps the row numbers integral.  `/` on integer
        # tensors is true division in current PyTorch and would hand float
        # indices to `index_select`, which rejects them.
        indices = final_idxs // row_length
        stem_idxs = [
            index_map.index_select(0, indices1)
            for index_map, indices1 in zip(stem_mapping, indices)
        ]
        arg_idxs = final_idxs % row_length
        # NOTE(review): the tactic string is built without a trailing ".";
        # confirm that callers expect it in this form.
        return [[
            Prediction(
                self._embedding.decode_token(stem_idx.item()) + " " +
                get_arg_from_token_idx(in_data.goal, arg_idx.item()),
                math.exp(final_prob))
            for stem_idx, arg_idx, final_prob in islice(
                zip(stem_list, arg_list, final_list), k)
        ] for in_data, stem_list, arg_list, final_list in zip(
            in_datas, stem_idxs, arg_idxs, final_probs)]
예제 #14
0
 def predictKTacticsWithLoss(self, in_data : TacticContext, k : int, correct : str) -> \
     Tuple[List[Prediction], float]:
     """Return the top-`k` tactic predictions and the loss against `correct`.

     `k` is capped at the number of tokens in the embedding.  The loss is 0
     when the stem of `correct` is not in the embedding.  When the context
     has no hypotheses, stems that take hypothesis arguments are filtered
     out of the selection.  Each selected stem is completed by
     `self.add_arg`.
     """
     assert self.training_args
     assert self._embedding
     with self._lock:
         prediction_distribution = self._predictDistributions([in_data])[0]
     k = min(k, self._embedding.num_tokens())

     correct_stem = serapi_instance.get_stem(correct)
     if not self._embedding.has_token(correct_stem):
         loss = 0
     else:
         target = maybe_cuda(
             Variable(
                 LongTensor([self._embedding.encode_token(correct_stem)])))
         loss = self._criterion(prediction_distribution.view(1, -1),
                                target).item()

     if len(in_data.hypotheses) > 0:
         certainties, idxs = prediction_distribution.view(-1).topk(k)
     else:
         def takes_no_hyp_args(certainty, idx):
             # There is no hypothesis to supply as an argument.
             stem = cast(Embedding, self._embedding).decode_token(idx)
             return not serapi_instance.tacticTakesHypArgs(stem)

         certainties, idxs = topk_with_filter(
             prediction_distribution.view(-1), k, takes_no_hyp_args)

     results = []
     for certainty, stem_idx in zip(certainties, idxs):
         tactic = self.add_arg(self._embedding.decode_token(stem_idx.item()),
                               in_data.goal, in_data.hypotheses,
                               self.training_args.max_length)
         results.append(Prediction(tactic, math.exp(certainty.item())))
     return results, loss
예제 #15
0
    def predictKTacticsWithLoss(
            self, in_data: TacticContext, k: int,
            correct: str) -> Tuple[List[Prediction], float]:
        """Return the top-`k` stem predictions and the loss against `correct`.

        The loss is 0 when the stem of `correct` is not in the embedding.
        `k` is capped at the number of tokens in the embedding.  The whole
        computation runs while holding `self._lock`.
        """
        with self._lock:
            distribution = self.predictDistribution(in_data)
            stem = get_stem(correct)
            if not self._embedding.has_token(stem):
                loss = 0
            else:
                target = maybe_cuda(
                    Variable(
                        torch.LongTensor(
                            [self._embedding.encode_token(stem)])))
                loss = self._criterion(distribution, target).item()

            k = min(k, self._embedding.num_tokens())
            certainties, idxs = distribution.squeeze().topk(k)
            predictions = []
            for certainty, idx in zip(certainties, idxs):
                tactic = self._embedding.decode_token(idx.item()) + "."
                predictions.append(
                    Prediction(tactic, math.exp(certainty.item())))
        return predictions, loss
예제 #16
0
 def predictKTactics(self, in_data: TacticContext,
                     k: int) -> List[Prediction]:
     """Suggest `apply` tactics for the model's highest-scoring hypotheses.

     With no hypotheses a single zero-certainty "eauto" prediction is
     returned.  Otherwise `k` is capped at the number of hypotheses and the
     selected indices pick hypotheses from `in_data.hypotheses`.
     """
     hypotheses = in_data.hypotheses
     if len(hypotheses) == 0:
         return [Prediction("eauto", 0)]
     with self._lock:
         distribution = self._predictDistribution(in_data)
         k = min(k, len(hypotheses))
         probs, indices = distribution.squeeze().topk(k)
         if k == 1:
             # Re-wrap the single result so the loop below can iterate it
             # (presumably topk yields 0-dim tensors here; confirm).
             probs = FloatTensor([probs])
             indices = LongTensor([indices])
     predictions = []
     for certainty, idx in zip(probs, indices):
         hyp_var = serapi_instance.get_first_var_in_hyp(
             hypotheses[idx.item()])
         predictions.append(Prediction("apply " + hyp_var + ".",
                                       math.exp(certainty.item())))
     return predictions
예제 #17
0
def get_closest_hyps(hyps : List[str], goal : str, num_hyps : int, max_length : int)\
                        -> List[Tuple[str, float]]:
    """Score every hypothesis against the goal, best score first.

    Args:
        hyps: the hypotheses to rank.
        goal: the goal text each hypothesis' type is compared with.
        num_hyps: number of placeholder entries returned when `hyps` is
            empty; it does not limit the result otherwise.
        max_length: token limit applied to the goal and to each hypothesis
            type before scoring, and also passed to `score_hyp_type`.

    Returns:
        `(hypothesis, score)` pairs in descending score order, or `num_hyps`
        copies of a zero-score ":" placeholder when there are no hypotheses.
    """
    if len(hyps) == 0:
        return [Prediction(":", 0)] * num_hyps
    # The truncated goal is the same for every hypothesis; compute it once.
    truncated_goal = limitNumTokens(goal, max_length)
    scored_hyps = [
        (hyp, score_hyp_type(
            truncated_goal,
            limitNumTokens(serapi_instance.get_hyp_type(hyp), max_length),
            max_length))
        for hyp in hyps
    ]
    # Rank on the score (element 1); keying on element 0 would order the
    # hypotheses alphabetically by their text instead of by closeness.
    return sorted(scored_hyps,
                  key=lambda hyp_and_score: hyp_and_score[1],
                  reverse=True)
    def predictKTactics(self, in_data : TacticContext, k : int) -> \
        List[Prediction]:
        """Predict tactics from the `k` nearest neighbours of the goal.

        The goal is encoded with `encode_bag_classify_input` and looked up
        in `self.bst`.  Each neighbour's output token is decoded into a
        tactic, with the certainty halving at each rank (1, 0.5, 0.25, ...).
        """
        query = encode_bag_classify_input(in_data.goal, self.tokenizer)

        nearest = self.bst.findKNearest(query, k)
        assert nearest is not None
        for pair in nearest:
            assert pair is not None

        predictions = []
        for rank, (_neighbor, output) in enumerate(nearest):
            tactic = self.embedding.decode_token(output) + "."
            predictions.append(Prediction(tactic, .5**rank))
        return predictions
예제 #19
0
 def predictKTactics(self, in_data : TacticContext, k : int) -> \
     List[Prediction]:
     """Decode up to `k` tactic sentences for the goal of `in_data`.

     The tokenized goal is run through `self.encoder`, `decodeKTactics`
     produces candidate sentences, and the first `k` are rendered with
     `self.tactic_tokenizer`.  Certainty halves at each rank.
     """
     goal_tokens = self.context_tokenizer.toTokenList(in_data.goal)
     in_sentence = LongTensor(
         inputFromSentence(goal_tokens, self.max_length)).view(1, -1)
     feature_vector = self.encoder.run(in_sentence)
     candidates = decodeKTactics(self.decoder, feature_vector,
                                 self.beam_width, self.max_length)
     predictions = []
     for rank, sentence in enumerate(candidates[:k]):
         tactic = self.tactic_tokenizer.toString(sentence)
         predictions.append(Prediction(tactic, .5 ** rank))
     return predictions
 def predictKTacticsWithLoss(self, in_data : TacticContext, k : int,
                             correct : str) -> Tuple[List[Prediction], float]:
     """Return the top-`k` stem predictions and the loss against `correct`.

     The loss is +inf when the stem of `correct` is not in the embedding.
     Each prediction is the decoded stem followed by ".", paired with `exp`
     of its value from the distribution.
     """
     distribution = self.predictDistribution(in_data)
     correct_stem = get_stem(correct)
     if not self.embedding.has_token(correct_stem):
         loss = float("+inf")
     else:
         model_output = torch.FloatTensor(distribution).view(1, -1)
         target = Variable(
             torch.LongTensor([self.embedding.encode_token(correct_stem)]))
         loss = self.criterion(model_output, target).item()

     indices, probabilities = list_topk(list(distribution), k)
     predictions = []
     for certainty, idx in zip(probabilities, indices):
         tactic = self.embedding.decode_token(idx) + "."
         predictions.append(Prediction(tactic, math.exp(certainty)))
     return predictions, loss
예제 #21
0
 def predictKTactics(self, in_data : TacticContext, k : int) \
     -> List[Prediction]:
     """Return the top-`k` stems of the predicted distribution as tactics.

     Args:
         in_data: context passed to `self.predictDistribution`.
         k: number of predictions to take from the distribution.

     Returns:
         One prediction per selected stem: the decoded stem followed by "."
         and `exp` of its value from the distribution.
     """
     # Hold the lock through a context manager so it is released even if
     # prediction or decoding raises.
     with self.lock:
         distribution = self.predictDistribution(in_data)
         certainties, idxs = distribution.squeeze().topk(k)
         # `.item()` reads the scalar out of each element tensor; indexing
         # `.data[0]` fails on 0-dim tensors in current PyTorch.
         results = [
             Prediction(
                 self.embedding.decode_token(idx.item()) + ".",
                 math.exp(certainty.item()))
             for certainty, idx in zip(certainties, idxs)
         ]
     return results
예제 #22
0
    def _predictFromStemDistributionWithLoss(
            self, beam_width: int, stem_distribution: torch.FloatTensor,
            in_datas: List[TacticContext], corrects: List[str],
            k: int) -> Tuple[List[List[Prediction]], float]:
        """Decode the best (stem, argument) pairs and score the correct ones.

        Args:
            beam_width: number of entries taken from each composite
                distribution.
            stem_distribution: stem scores handed to
                `_predictCompositeDistributionFromStemDistribution`.
            in_datas: the contexts, one per batch row.
            corrects: the reference tactic for each context.
            k: maximum number of predictions returned per context.

        Returns:
            `(predictions, loss)`: one list of at most `k` predictions per
            context, and the value of `self._criterion` on the composite
            distribution against the flattened correct indices.
        """
        assert self.training_args
        assert self._embedding
        max_length = self.training_args.max_length

        all_prob_batches, index_mapping = \
            self._predictCompositeDistributionFromStemDistribution(
                beam_width, stem_distribution, in_datas)

        # Flatten each correct (stem, arg) pair into a single target index.
        correct_idx_list = []
        for in_data, correct in zip(in_datas, corrects):
            scraped = ScrapedTactic(
                [], in_data.prev_tactics,
                ProofContext(
                    [Obligation(in_data.hypotheses, in_data.goal)],
                    [], [], []), correct)
            stem_idx, arg_idx = get_stem_and_arg_idx(
                max_length, self._embedding,
                serapi_instance.normalizeNumericArgs(scraped))
            correct_idx_list.append(
                max(0, arg_idx - 1) + (stem_idx * max_length))
        correct_idxs = LongTensor(correct_idx_list)

        # NOTE(review): the loss is returned as produced by the criterion,
        # without `.item()`, although the annotation says float; confirm
        # what callers expect before changing it.
        loss = self._criterion(all_prob_batches, correct_idxs)

        final_probs, final_idxs = all_prob_batches.topk(beam_width)
        row_length = max_length
        # Floor division keeps the row numbers integral.  `/` on integer
        # tensors is true division in current PyTorch and would hand float
        # indices to `index_select`, which rejects them.
        indices = final_idxs // row_length
        stem_idxs = [
            index_map.index_select(0, indices1)
            for index_map, indices1 in zip(index_mapping, indices)
        ]
        arg_idxs = final_idxs % row_length
        return [[
            Prediction(
                self._embedding.decode_token(stem_idx.item()) + " " +
                get_arg_from_token_idx(in_data.goal, arg_idx.item()) + ".",
                math.exp(final_prob))
            for stem_idx, arg_idx, final_prob in islice(
                zip(stem_list, arg_list, final_list), k)
        ] for in_data, stem_list, arg_list, final_list in zip(
            in_datas, stem_idxs, arg_idxs, final_probs)], loss
 def predictKTactics(self, in_data : Dict[str, Union[List[str], str]], k : int) \
     -> List[Prediction]:
     """Decode up to `k` structured tactic predictions for a goal.

     Args:
         in_data: mapping read at the keys "goal" and "hyps".
         k: passed to `self.decodeKTactics` both as-is and as `k * k`.

     Returns:
         One prediction per decoded structure, rendered by
         `decode_tactic_structure` and paired with its certainty.
     """
     hyps = cast(List[str], in_data["hyps"])
     # Hold the lock through a context manager so it is released even if
     # tokenizing, encoding or decoding raises.
     with self.lock:
         in_sentence = LongTensor(inputFromSentence(
             self.tokenizer.toTokenList(in_data["goal"]),
             self.max_length))\
             .view(1, -1)
         encoded_vector = self.encoder.run(in_sentence)
         prediction_structures, certainties = \
             self.decodeKTactics(encoded_vector, k, hyps, k * k, 3)
     return [
         Prediction(
             decode_tactic_structure(self.tokenizer, self.embedding,
                                     structure, hyps),
             certainty)
         for structure, certainty in zip(prediction_structures, certainties)
     ]
 def predictKTactics(self, in_data: TacticContext, k: int) \
         -> List[Prediction]:
     """Re-rank the inner predictor's suggestions with the estimator.

     Sixteen candidates are requested from `self._fpa`, each is scored by
     `self._estimator` together with the context, and the `k` best-scoring
     ones are returned with the estimator score as their certainty.
     """
     assert self._fpa
     assert self._estimator
     inner_predictions = self._fpa.predictKTactics(in_data, 16)
     scores = self._estimator([(in_data, candidate.prediction)
                               for candidate in inner_predictions])
     scored = list(zip(scores, inner_predictions))
     scored.sort(key=lambda pair: pair[0], reverse=True)
     ordered_actions = [
         Prediction(candidate.prediction, score)
         for score, candidate in scored
     ]
     return ordered_actions[:k]
예제 #25
0
    def predictOneTactic(self, in_data : TacticContext) \
        -> Prediction:
        """Predict a single tactic: one stem followed by hypothesis arguments.

        A stem is picked from the stem distribution, then arguments are
        decoded one at a time (at most `self.max_args`) until the decoder
        yields 0.  The running probability is multiplied by each accepted
        argument's value.  The result is rendered by
        `decode_tactic_structure` against `in_data.hypotheses`.
        """

        # Size: (1, self.features_size)
        general_features: torch.FloatTensor = self.encode_general_context(
            in_data)
        # Size: (1, num_hypotheses, self.features_size)
        hypothesis_features : torch.FloatTensor = \
            self.encode_hypotheses(in_data.hypotheses)
        # Size: (1, num_hypotheses)
        # NOTE(review): the size note above is inherited; a stem distribution
        # would presumably range over stems, not hypotheses -- confirm.
        stem_distribution: torch.FloatTensor = self.predict_stem(
            general_features)
        # Size(stem): (1, 1)
        # Size(probability): (1, 1)
        # NOTE(review): torch.topk returns (values, indices), so `[0]` selects
        # the values tensor before it is unpacked into (stem, probability);
        # confirm this is the intended pairing.
        stem, probability = stem_distribution.topk(1)[0]  # type: int, float
        # Size: (1, self.hidden_size)
        hidden_state: torch.FloatTensor = self.initial_hidden(
            general_features, stem)
        # Size: (1, 1)
        decoder_input: torch.LongTensor = self.initInput()
        arg_idxs: List[int] = []
        for idx in range(self.max_args):
            # Size(arg_distribution): (1, num_hypotheses)
            # Size(hidden_state): (1, self.hidden_size)
            arg_distribution, hidden_state = self.decode_arg(
                decoder_input, hidden_state, hypothesis_features)
            # Size(decoder_input): (1, 1)
            # Size(next_probability): (1, 1)
            # NOTE(review): with topk returning (values, indices), this binds
            # decoder_input to the first value and next_probability to the
            # first index; confirm the order is not swapped.
            decoder_input, next_probability = [
                lst[0] for lst in arg_distribution.topk(1)
            ]
            # A decoded 0 ends the argument list.
            if decoder_input.item() == 0:
                break
            probability *= next_probability
            arg_idxs.append(decoder_input.item())

        result_struct = TacticStructure(stem_idx=stem, hyp_idxs=arg_idxs)

        return Prediction(
            decode_tactic_structure(self.embedding, result_struct,
                                    in_data.hypotheses), probability)
예제 #26
0
 def predictKTactics(self, in_data : TacticContext, k : int) \
     -> List[Prediction]:
     """Return the top-`k` stems for `in_data`, completed by `add_args`.

     `k` is capped at the number of tokens in the embedding.  When the
     context has no hypotheses, stems that take hypothesis arguments are
     filtered out of the selection.
     """
     assert self.training_args
     assert self._embedding
     with self._lock:
         prediction_distribution = self._predictDistributions([in_data])[0]
     k = min(k, self._embedding.num_tokens())

     if len(in_data.hypotheses) > 0:
         certainties, idxs = prediction_distribution.view(-1).topk(k)
     else:
         def takes_no_hyp_args(certainty, idx):
             # There is no hypothesis to supply as an argument.
             stem = cast(Embedding, self._embedding).decode_token(idx)
             return not serapi_instance.tacticTakesHypArgs(stem)

         certainties, idxs = topk_with_filter(
             prediction_distribution.view(-1), k, takes_no_hyp_args)

     stem_predictions = []
     for certainty, stem_idx in zip(certainties, idxs):
         stem = self._embedding.decode_token(stem_idx.item())
         stem_predictions.append(
             Prediction(stem, math.exp(certainty.item())))
     return self.add_args(stem_predictions, in_data.goal,
                          in_data.hypotheses, self.training_args.max_length)
예제 #27
0
    def predictKTacticsWithLoss(
            self, in_data: TacticContext, k: int,
            correct: str) -> Tuple[List[Prediction], float]:
        """Return the top-`k` stem predictions and the loss against `correct`.

        Args:
            in_data: context passed to `self.predictDistribution`.
            k: number of predictions to take from the distribution.
            correct: the reference tactic; only its stem is scored.

        Returns:
            `(predictions, loss)`.  The loss is 0 when the stem of `correct`
            is not in the embedding.
        """
        # Hold the lock through a context manager so it is released even if
        # prediction, loss computation or decoding raises.
        with self.lock:
            distribution = self.predictDistribution(in_data)
            stem = get_stem(correct)
            if self.embedding.has_token(stem):
                output_var = maybe_cuda(
                    Variable(
                        torch.LongTensor([self.embedding.encode_token(stem)
                                          ])))
                loss = self.criterion(distribution.view(1, -1),
                                      output_var).item()
            else:
                loss = 0

            certainties, idxs = distribution.squeeze().topk(k)
            predictions_and_certainties = \
                [Prediction(self.embedding.decode_token(idx.item()) + ".",
                            math.exp(certainty.item()))
                 for certainty, idx in zip(list(certainties), list(idxs))]

        return predictions_and_certainties, loss
예제 #28
0
    def predictKTactics(self, context: TacticContext,
                        k: int) -> List[Prediction]:
        """Predict up to `k` tactics (stem plus one argument) for `context`.

        The stem classifier proposes `stem_width` stems.  For each stem the
        goal-argument model scores goal positions and, when the context has
        hypotheses, `runHypModel` scores each hypothesis.  Those argument
        scores are added to the stem scores, passed through `self._softmax`
        over the flattened (stem, argument) table, and the `k` best entries
        are decoded with `decode_fpa_result`.

        Returns at most `min(k, num_valid_probs)` predictions, whose
        certainty is `exp` of the corresponding softmax output.
        """
        assert self.training_args
        assert self._model

        # Beam of stems to consider: bounded by the configured beam width,
        # the number of known stems, and k squared.
        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)

        # Encode the single context as a batch of one.
        tokenized_premises, hyp_features, \
            nhyps_batch, tokenized_goal, \
            goal_mask, \
            word_features, vec_features = \
                sample_fpa(extract_dataloader_args(self.training_args),
                           self._metadata,
                           context.relevant_lemmas,
                           context.prev_tactics,
                           context.hypotheses,
                           context.goal)

        num_hyps = nhyps_batch[0]

        stem_distribution = self._model.stem_classifier(
            LongTensor(word_features), FloatTensor(vec_features))
        stem_certainties, stem_idxs = stem_distribution.topk(stem_width)

        # Score goal arguments for every candidate stem: the goal is repeated
        # once per stem and the result reshaped to
        # (1, stem_width, max_length + 1).
        goals_batch = LongTensor(tokenized_goal)
        goal_arg_values = self._model.goal_args_model(
            stem_idxs.view(1 * stem_width),
            goals_batch.view(1, 1, self.training_args.max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(1 * stem_width,
                  self.training_args.max_length))\
                                     .view(1, stem_width,
                                           self.training_args.max_length + 1)
        # Entries where goal_mask is false are set to -inf.
        goal_arg_values = torch.where(
            torch.ByteTensor(goal_mask).view(1, 1,
                                             self.training_args.max_length +
                                             1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))
        assert goal_arg_values.size() == torch.Size([1, stem_width,
                                                     self.training_args.max_length + 1]),\
            "goal_arg_values.size(): {}; stem_width: {}".format(goal_arg_values.size(),
                                                                stem_width)

        # num_probs is the width of one stem's row in the score table;
        # num_valid_probs later caps how many predictions are returned.
        num_probs = 1 + num_hyps + self.training_args.max_length
        goal_symbols = get_fpa_words(context.goal)
        num_valid_probs = (1 + num_hyps + len(goal_symbols)) * stem_width
        if num_hyps > 0:
            encoded_goals = self._model.goal_encoder(goals_batch)\
                                       .view(1, 1, self.training_args.hidden_size)

            hyps_batch = LongTensor(tokenized_premises)
            assert hyps_batch.size() == torch.Size([1, num_hyps,
                                                    self.training_args.max_length]), \
                                                    (hyps_batch.size(),
                                                     num_hyps,
                                                     self.training_args.max_length)
            hypfeatures_batch = FloatTensor(hyp_features)
            assert hypfeatures_batch.size() == \
                torch.Size([1, num_hyps, hypFeaturesSize()]), \
                (hypfeatures_batch.size(), num_hyps, hypFeaturesSize())
            hyp_arg_values = self.runHypModel(stem_idxs, encoded_goals,
                                              hyps_batch, hypfeatures_batch)
            assert hyp_arg_values.size() == \
                torch.Size([1, stem_width, num_hyps])
            # Goal-argument scores come first in each row, hypothesis scores
            # after them.
            total_values = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
        else:
            total_values = goal_arg_values
        # Add each stem's score to its whole row, then apply self._softmax
        # over the flattened (stem, argument) table.
        all_prob_batches = self._softmax((total_values +
                                          stem_certainties.view(1, stem_width, 1)
                                          .expand(-1, -1, num_probs))
                                         .contiguous()
                                         .view(1, stem_width * num_probs))\
                               .view(stem_width * num_probs)

        final_probs, final_idxs = all_prob_batches.topk(k)
        assert not torch.isnan(final_probs).any()
        assert final_probs.size() == torch.Size([k])
        # Split each flat index back into (row within the stem beam,
        # argument position within the row).
        row_length = self.training_args.max_length + num_hyps + 1
        stem_keys = final_idxs // row_length
        assert stem_keys.size() == torch.Size([k])
        assert stem_idxs.size() == torch.Size([1,
                                               stem_width]), stem_idxs.size()
        # Map beam rows back to the stem indices chosen by the classifier.
        prediction_stem_idxs = stem_idxs.view(stem_width).index_select(
            0, stem_keys)
        assert prediction_stem_idxs.size() == torch.Size([k]), \
            prediction_stem_idxs.size()
        arg_idxs = final_idxs % row_length
        assert arg_idxs.size() == torch.Size([k])

        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses
        return [
            Prediction(
                decode_fpa_result(extract_dataloader_args(self.training_args),
                                  self._metadata, all_hyps, context.goal,
                                  stem_idx.item(), arg_idx.item()),
                math.exp(prob)) for stem_idx, arg_idx, prob in islice(
                    zip(prediction_stem_idxs, arg_idxs, final_probs),
                    min(k, num_valid_probs))
        ]
예제 #29
0
    def predictKTactics_batch(self, context_batch: List[TacticContext],
                              k: int) \
            -> List[List[Prediction]]:
        """Predict up to `k` tactics for every context in `context_batch`.

        For each context the `stem_width` highest-scoring stems are taken
        from the stem classifier.  Every stem gets a row of scores made of
        the goal-argument model output (`max_length + 1` columns) followed
        by one score per premise.  The stem certainty is added to its row,
        the rows are flattened and passed through `self._softmax`, and the
        `k` best flat entries are decoded back into (stem, argument) pairs.

        Args:
            context_batch: Proof contexts to predict for.
            k: Number of predictions requested per context.

        Returns:
            One list per context, in the order of `context_batch`.  Each
            list holds at most `k` predictions in the order produced by
            `topk`; it is shorter when fewer slots are counted as valid for
            that context.  An empty batch yields an empty list.

        Raises:
            AssertionError: When `training_args` or `_model` is unset, or
                (only when assertions are enabled) when the batched results
                disagree with the single-context code paths.
        """
        assert self.training_args
        assert self._model

        # Nothing to predict.  Without this guard the `zip(*[...])` unpacking
        # of the per-context top-k results fails on an empty batch.
        if not context_batch:
            return []

        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)
        batch_size = len(context_batch)
        max_length = self.training_args.max_length
        hidden_size = self.training_args.hidden_size

        tprems_batch, pfeat_batch, \
            nhyps_batch, tgoals_batch, \
            goal_masks_batch, \
            wfeats_batch, vfeats_batch = \
            sample_fpa_batch(extract_dataloader_args(self.training_args),
                             self._metadata,
                             [context_py2r(context)
                              for context in context_batch])

        # Cross-check the batched sampler against the single-sample one.  The
        # loop exists only for its asserts, so it is skipped entirely when
        # assertions are disabled (python -O) instead of sampling for nothing.
        if __debug__:
            for tprem, pfeat, nhyp, tgoal, masks, wfeat, vfeat, context \
                    in zip(tprems_batch, pfeat_batch,
                           nhyps_batch, tgoals_batch,
                           goal_masks_batch, wfeats_batch,
                           vfeats_batch, context_batch):
                s_tprem, s_pfeat, s_nhyp, s_tgoal, \
                    s_masks, s_wfeat, s_vfeat = \
                    sample_fpa(extract_dataloader_args(self.training_args),
                               self._metadata,
                               context.relevant_lemmas,
                               context.prev_tactics,
                               context.hypotheses,
                               context.goal)
                assert len(s_tprem) == 1
                assert len(tprem) == len(s_tprem[0])
                for p1, p2 in zip(tprem, s_tprem[0]):
                    assert p1 == p2, (p1, p2)
                assert len(s_pfeat) == 1
                assert len(pfeat) == len(s_pfeat[0])
                for f1, f2 in zip(pfeat, s_pfeat[0]):
                    assert f1 == f2, (f1, f2)
                assert s_nhyp[0] == nhyp, (s_nhyp[0], nhyp)
                assert s_tgoal[0] == tgoal, (s_tgoal[0], tgoal)
                assert s_masks[0] == masks, (s_masks[0], masks)
                assert s_wfeat[0] == wfeat, (s_wfeat[0], wfeat)
                assert s_vfeat[0] == vfeat, (s_vfeat[0], vfeat)

        stem_distribution = self._model.stem_classifier(
            LongTensor(wfeats_batch), FloatTensor(vfeats_batch))
        stem_certainties_batch, stem_idxs_batch = stem_distribution.topk(
            stem_width)

        goals_batch = LongTensor(tgoals_batch)

        # Pair every candidate stem with (a copy of) its context's goal.
        goal_arg_values = self._model.goal_args_model(
            stem_idxs_batch.view(batch_size * stem_width),
            goals_batch.view(batch_size, 1, max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(batch_size * stem_width, max_length))\
            .view(batch_size, stem_width, max_length + 1)

        # Entries whose mask is off are forced to -inf.
        goal_arg_values = torch.where(
            torch.ByteTensor(goal_masks_batch).view(
                batch_size, 1, max_length + 1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))

        encoded_goals_batch = self._model.goal_encoder(goals_batch)

        # Contexts have different premise counts, so the premise-model inputs
        # are built per context and concatenated into one flat batch.
        stems_expanded_batch = torch.cat([
            stem_idxs.view(1, stem_width).expand(
                num_hyps, stem_width).contiguous().view(num_hyps * stem_width)
            for stem_idxs, num_hyps in zip(stem_idxs_batch, nhyps_batch)
        ])
        egoals_expanded_batch = torch.cat([
            encoded_goal.view(1, hidden_size).expand(
                num_hyps * stem_width, -1).contiguous()
            for encoded_goal, num_hyps in zip(encoded_goals_batch, nhyps_batch)
        ])
        tprems_expanded_batch = torch.cat([
            LongTensor(tpremises).view(1, -1, max_length).expand(
                stem_width, -1, -1).contiguous().view(-1, max_length)
            for tpremises in tprems_batch
        ])
        # NOTE(review): the literal 2 is presumably the per-premise feature
        # count; confirm it matches what the sampler emits.
        pfeat_expanded_batch = torch.cat([
            FloatTensor(premise_features).view(1, num_hyps, 2).expand(
                stem_width, -1, -1).contiguous().view(num_hyps * stem_width, 2)
            for premise_features, num_hyps in zip(pfeat_batch, nhyps_batch)
        ])

        prem_arg_values = self._model.hyp_model(stems_expanded_batch,
                                                egoals_expanded_batch,
                                                tprems_expanded_batch,
                                                pfeat_expanded_batch)

        # Undo the concatenation: one chunk of premise scores per context.
        prem_arg_values_split = prem_arg_values.split(
            [num_hyps * stem_width for num_hyps in nhyps_batch])
        total_values_list = [
            torch.cat((goal_values, prem_values.view(stem_width, -1)), dim=1)
            for goal_values, prem_values in zip(goal_arg_values,
                                                prem_arg_values_split)
        ]
        all_probs_list = [
            self._softmax(
                (total_values
                 + stem_certainties.view(stem_width, 1)
                 .expand_as(total_values)).contiguous().view(1, -1)).view(-1)
            for total_values, stem_certainties in zip(total_values_list,
                                                      stem_certainties_batch)
        ]

        final_probs_list, final_idxs_list = zip(
            *[probs.topk(k) for probs in all_probs_list])

        # A flat index encodes (stem row, argument column); the row width
        # depends on the context's premise count.
        row_lengths = [max_length + num_hyps + 1 for num_hyps in nhyps_batch]
        stem_keys_list = [
            final_idxs // row_length
            for final_idxs, row_length in zip(final_idxs_list, row_lengths)
        ]
        stem_idxs_list = [
            stem_idxs.view(stem_width).index_select(0, stem_keys)
            for stem_idxs, stem_keys in zip(stem_idxs_batch, stem_keys_list)
        ]
        arg_idxs_list = [
            final_idxs % row_length
            for final_idxs, row_length in zip(final_idxs_list, row_lengths)
        ]

        predictions: List[List[Prediction]] = []
        for stem_idxs, arg_idxs, final_probs, context, num_hyps in zip(
                stem_idxs_list, arg_idxs_list, final_probs_list,
                context_batch, nhyps_batch):
            # Cap on how many top-k entries get decoded.  The flattened
            # distribution holds one row per candidate stem, so the per-stem
            # count is multiplied by stem_width, as the single-context
            # predictKTactics does.
            num_valid_probs = \
                (1 + num_hyps + len(get_fpa_words(context.goal))) * stem_width
            # NOTE(review): relevant_lemmas are appended unconditionally
            # here; confirm decode_fpa_result only indexes into this list.
            all_hyps = context.hypotheses + context.relevant_lemmas
            predictions.append([
                Prediction(
                    decode_fpa_result(
                        extract_dataloader_args(self.training_args),
                        self._metadata, all_hyps, context.goal,
                        stem_idx.item(), arg_idx.item()),
                    math.exp(prob))
                for stem_idx, arg_idx, prob in islice(
                    zip(stem_idxs, arg_idxs, final_probs),
                    min(k, num_valid_probs))
            ])

        # Cross-check against the single-context predictor.  As above, this
        # only feeds asserts, so it is skipped when assertions are disabled.
        if __debug__:
            for context, pred_list in zip(context_batch, predictions):
                for batch_pred, single_pred in zip(
                        pred_list, self.predictKTactics(context, k)):
                    assert batch_pred.prediction == single_pred.prediction, \
                        (batch_pred, single_pred)

        return predictions
예제 #30
0
 def predictKTactics(self, in_data: TacticContext,
                     k: int) -> List[Prediction]:
     """Return `k` fixed `induction N.` predictions for N = 1..k.

     The proof context `in_data` is not consulted.  The N-th prediction is
     given certainty 0.5 ** N, so certainty halves with each rank.
     """
     predictions: List[Prediction] = []
     for rank in range(1, k + 1):
         tactic = "induction {}.".format(rank)
         predictions.append(Prediction(tactic, .5**rank))
     return predictions