예제 #1
0
    def decodeNonDuplicatePredictions(
            self, context: TacticContext,
            all_idxs: List[Tuple[float, int, int]],
            k: int) -> List[Prediction]:
        """Decode candidate triples into at most `k` distinct predictions.

        Candidates are consumed in the order given by `all_idxs`; each one is
        decoded to a tactic string with `decode_fpa_result`, and a string that
        has already been produced is skipped.

        Args:
            context: Proof context supplying the hypotheses, relevant lemmas
                and goal used for decoding.
            all_idxs: Candidate triples. Element 1 and element 2 are passed to
                `decode_fpa_result` as the stem and argument indices; element
                0 is passed through `math.exp`, so it is presumably a
                log-probability (confirm against the producer).
            k: Maximum number of distinct predictions to return.

        Returns:
            Up to `k` `Prediction` objects, in candidate order.
        """
        assert self.training_args
        num_stem_poss = get_num_tokens(self.metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss)

        # Lemmas are only candidate arguments when lemma_args is enabled.
        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses

        num_valid_probs = (1 + len(all_hyps) +
                           len(get_fpa_words(context.goal))) * stem_width
        # The computed bound can exceed the number of candidates actually
        # supplied; never index past the end of all_idxs.
        num_candidates = min(num_valid_probs, len(all_idxs))

        prediction_strs: List[str] = []
        prediction_probs: List[float] = []
        seen_strs = set()
        for next_i in range(num_candidates):
            if len(prediction_strs) >= k:
                break
            candidate = all_idxs[next_i]
            next_pred_str = decode_fpa_result(
                extract_dataloader_args(self.training_args),
                self.metadata,
                all_hyps, context.goal,
                candidate[1],
                candidate[2])
            if next_pred_str in seen_strs:
                continue
            seen_strs.add(next_pred_str)
            prediction_strs.append(next_pred_str)
            prediction_probs.append(math.exp(candidate[0]))

        return [Prediction(s, prob) for s, prob in
                zip(prediction_strs, prediction_probs)]
예제 #2
0
    def getAllPredictionIdxs_batch(self, contexts: List[TacticContext],
                                   verbosity:int = 0) -> List[List[Tuple[float, int, int]]]:
        """Compute candidate triples for every context in a batch.

        Stems and goal-token scores are computed for the whole batch at once;
        premise scores and the final `predict_args` call are done per context.

        Args:
            contexts: Proof contexts to score.
            verbosity: The per-context progress bar is shown only when this
                is greater than 1.

        Returns:
            One list per context, each holding the zipped outputs of
            `predict_args` as 3-tuples.
        """
        assert self.training_args
        assert self._model

        token_count = get_num_tokens(self.metadata)
        beam_width = min(self.training_args.max_beam_width, token_count)

        sampled = sample_fpa_batch(
            extract_dataloader_args(self.training_args),
            self.metadata,
            [context_py2r(ctx) for ctx in contexts])
        (premises_by_ctx, premise_feats_by_ctx, _nhyps,
         goals_by_ctx, goal_mask, word_features, vec_features) = sampled

        stem_scores_by_ctx, stems_by_ctx = self.predict_stems(
            beam_width, word_features, vec_features)

        goal_scores_by_ctx = self.goal_token_scores(
            stems_by_ctx, goals_by_ctx, goal_mask)

        per_context = zip(stem_scores_by_ctx, stems_by_ctx,
                          goal_scores_by_ctx, goals_by_ctx,
                          premises_by_ctx, premise_feats_by_ctx)
        progress = tqdm(per_context,
                        desc="Assessing hyp args and decoding indices",
                        total=len(contexts),
                        disable=verbosity <= 1)

        results = []
        for (stem_scores, stems, goal_scores,
             goal_tokens, premise_tokens, premise_feats) in progress:
            if len(premise_tokens) == 0:
                # No premises: goal-token scores are the only argument scores.
                combined = goal_scores.unsqueeze(0)
            else:
                premise_scores = self.hyp_name_scores(
                    stems, goal_tokens, premise_tokens, premise_feats)
                combined = torch.cat(
                    (goal_scores.unsqueeze(0), premise_scores), dim=2)

            top_probs, top_stems, top_args = self.predict_args(
                combined, stem_scores, stems)
            triples = zip(list(top_probs), list(top_stems), list(top_args))
            results.append(list(triples))

        return results
예제 #3
0
 def load_saved_state(self, args: Namespace, unparsed_args: List[str],
                      metadata: Any, state: NeuralPredictorState) -> None:
     """Rebuild the model from a saved state and record its bookkeeping.

     A model is constructed with sizes derived from `metadata`, passed
     through `maybe_cuda`, and loaded with `state.weights`. The loss and
     epoch from `state`, together with `args`, `unparsed_args` and
     `metadata`, are then stored on the instance.
     """
     vocab_sizes = get_word_feature_vocab_sizes(metadata)
     vec_size = get_vec_features_size(metadata)
     index_count = get_num_indices(metadata)
     token_count = get_num_tokens(metadata)

     restored = maybe_cuda(
         self._get_model(args, vocab_sizes, vec_size,
                         index_count, token_count))
     restored.load_state_dict(state.weights)

     self._model = restored
     self.training_loss = state.loss
     self.num_epochs = state.epoch
     self.training_args = args
     self.unparsed_args = unparsed_args
     self._metadata = metadata
예제 #4
0
    def getAllPredictionIdxs(self, context: TacticContext
                             ) -> List[Tuple[float, int, int]]:
        """Compute candidate triples for a single proof context.

        The context is featurized with `sample_fpa`, stems are chosen with
        `predict_stems`, goal-token scores (and premise scores, when there
        are any premises) are combined, and `predict_args` produces the
        final ranking.

        Returns:
            The three outputs of `predict_args` zipped into a list of
            3-tuples.
        """
        assert self.training_args
        assert self._model

        token_count = get_num_tokens(self.metadata)
        beam_width = min(self.training_args.max_beam_width, token_count)

        sampled = sample_fpa(extract_dataloader_args(self.training_args),
                             self.metadata,
                             context.relevant_lemmas,
                             context.prev_tactics,
                             context.hypotheses,
                             context.goal)
        (premise_tokens, premise_feats, _nhyps,
         goal_tokens, goal_mask, word_feats, vec_feats) = sampled

        stem_scores, stem_ids = self.predict_stems(
            beam_width, word_feats, vec_feats)

        goal_scores = self.goal_token_scores(stem_ids, goal_tokens, goal_mask)

        # sample_fpa returns batch-shaped values; entry 0 is this context.
        if len(premise_tokens[0]) == 0:
            combined = goal_scores
        else:
            premise_scores = self.hyp_name_scores(
                stem_ids[0], goal_tokens[0],
                premise_tokens[0], premise_feats[0])
            combined = torch.cat((goal_scores, premise_scores), dim=2)

        top_probs, top_stems, top_args = self.predict_args(
            combined, stem_scores, stem_ids)

        return list(zip(list(top_probs), list(top_stems), list(top_args)))
예제 #5
0
    def _optimize_model(
            self, arg_values: Namespace) -> Iterable[FeaturesPolyargState]:
        """Load training data, build or restore the model, and start training.

        When `arg_values.start_from` is set, the checkpoint at that path is
        read with `torch.load` and its metadata and saved state are reused.
        Otherwise the metadata comes from `features_polyarg_tensors` and a
        new model is created with a starting epoch of 1.

        Side effect: the assembled training tensors are written to
        `tensors.pickle` in the current working directory.

        Returns:
            A lazy iterable of `(metadata, state)` pairs, one for each state
            produced by `optimize_checkpoints`.
        """
        with print_time("Loading data", guard=arg_values.verbose):
            if arg_values.start_from:
                _, saved = torch.load(arg_values.start_from)
                _old_arg_values, unparsed_args, metadata, state = saved
                loaded = features_polyarg_tensors_with_meta(
                    extract_dataloader_args(arg_values),
                    str(arg_values.scrape_file),
                    metadata)
                # Metadata from the checkpoint is kept; the returned one is
                # discarded.
                _, data_lists, feature_sizes = loaded
            else:
                loaded = features_polyarg_tensors(
                    extract_dataloader_args(arg_values),
                    str(arg_values.scrape_file))
                metadata, data_lists, feature_sizes = loaded
            word_features_size, vec_features_size = feature_sizes

        with print_time("Converting data to tensors",
                        guard=arg_values.verbose):
            (raw_hyp_types, raw_hyp_features, hyp_counts,
             goals, goal_masks, word_feats, vec_feats,
             stem_targets, arg_targets) = data_lists

            # Per-sample hypothesis lists vary in length, so they are padded.
            padded_hyp_types = pad_sequence(
                [torch.LongTensor(hyp_types) for hyp_types in raw_hyp_types],
                batch_first=True)
            padded_hyp_features = pad_sequence(
                [torch.FloatTensor(feats) for feats in raw_hyp_features],
                batch_first=True)

            tensors = [
                padded_hyp_types,
                padded_hyp_features,
                torch.LongTensor(hyp_counts),
                torch.LongTensor(goals),
                torch.ByteTensor(goal_masks),
                torch.LongTensor(word_feats),
                torch.FloatTensor(vec_feats),
                torch.LongTensor(stem_targets),
                torch.LongTensor(arg_targets),
            ]
            with open("tensors.pickle", 'wb') as tensor_file:
                torch.save(tensors, tensor_file)
            eprint(tensors, guard=arg_values.print_tensors)

        with print_time("Building the model", guard=arg_values.verbose):
            if not arg_values.start_from:
                model = self._get_model(arg_values, word_features_size,
                                        vec_features_size,
                                        get_num_indices(metadata),
                                        get_num_tokens(metadata))
                epoch_start = 1
            else:
                self.load_saved_state(arg_values, unparsed_args, metadata,
                                      state)
                model = self._model
                epoch_start = self.num_epochs

        assert model
        assert epoch_start

        def batch_loss(batch_tensors, model):
            return self._getBatchPredictionLoss(
                arg_values, batch_tensors, model)

        checkpoints = optimize_checkpoints(
            tensors, arg_values, model, batch_loss, epoch_start)
        # Returned as a generator expression (not a `yield`) so that all of
        # the work above runs when this method is called.
        return ((metadata, checkpoint_state)
                for checkpoint_state in checkpoints)
예제 #6
0
    def predictKTactics(self, context: TacticContext,
                        k: int) -> List[Prediction]:
        """Predict the top tactics for a single proof context.

        Scores every (stem, argument) pair for the `stem_width` best stems,
        where an argument is either one of `max_length + 1` goal positions or
        one of the `num_hyps` premises, then takes the `k` highest-scoring
        pairs and decodes them to tactic strings.

        Args:
            context: Proof context (relevant lemmas, previous tactics,
                hypotheses and goal) to predict for.
            k: Number of top candidates to select.

        Returns:
            At most `min(k, num_valid_probs)` `Prediction` objects, each
            pairing a decoded tactic string with `math.exp` of its score
            (so `self._softmax` presumably yields log-probabilities —
            confirm).
        """
        assert self.training_args
        assert self._model

        # Beam width for stems: bounded by the configured maximum, by
        # get_num_tokens (presumably the number of known stems — confirm),
        # and by k**2.
        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)

        tokenized_premises, hyp_features, \
            nhyps_batch, tokenized_goal, \
            goal_mask, \
            word_features, vec_features = \
                sample_fpa(extract_dataloader_args(self.training_args),
                           self._metadata,
                           context.relevant_lemmas,
                           context.prev_tactics,
                           context.hypotheses,
                           context.goal)

        # sample_fpa returns batch-shaped values; entry 0 is this context.
        num_hyps = nhyps_batch[0]

        stem_distribution = self._model.stem_classifier(
            LongTensor(word_features), FloatTensor(vec_features))
        stem_certainties, stem_idxs = stem_distribution.topk(stem_width)

        # Score goal positions for every candidate stem: the goal is repeated
        # once per stem and the model output is viewed as
        # (1, stem_width, max_length + 1).
        goals_batch = LongTensor(tokenized_goal)
        goal_arg_values = self._model.goal_args_model(
            stem_idxs.view(1 * stem_width),
            goals_batch.view(1, 1, self.training_args.max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(1 * stem_width,
                  self.training_args.max_length))\
                                     .view(1, stem_width,
                                           self.training_args.max_length + 1)
        # Positions where goal_mask is false are replaced with -inf.
        goal_arg_values = torch.where(
            torch.ByteTensor(goal_mask).view(1, 1,
                                             self.training_args.max_length +
                                             1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))
        assert goal_arg_values.size() == torch.Size([1, stem_width,
                                                     self.training_args.max_length + 1]),\
            "goal_arg_values.size(): {}; stem_width: {}".format(goal_arg_values.size(),
                                                                stem_width)

        # num_probs is the width of one stem's row of argument scores
        # (goal positions plus premises). num_valid_probs is only used to cap
        # how many results are returned at the end.
        num_probs = 1 + num_hyps + self.training_args.max_length
        goal_symbols = get_fpa_words(context.goal)
        num_valid_probs = (1 + num_hyps + len(goal_symbols)) * stem_width
        if num_hyps > 0:
            encoded_goals = self._model.goal_encoder(goals_batch)\
                                       .view(1, 1, self.training_args.hidden_size)

            hyps_batch = LongTensor(tokenized_premises)
            assert hyps_batch.size() == torch.Size([1, num_hyps,
                                                    self.training_args.max_length]), \
                                                    (hyps_batch.size(),
                                                     num_hyps,
                                                     self.training_args.max_length)
            hypfeatures_batch = FloatTensor(hyp_features)
            assert hypfeatures_batch.size() == \
                torch.Size([1, num_hyps, hypFeaturesSize()]), \
                (hypfeatures_batch.size(), num_hyps, hypFeaturesSize())
            hyp_arg_values = self.runHypModel(stem_idxs, encoded_goals,
                                              hyps_batch, hypfeatures_batch)
            assert hyp_arg_values.size() == \
                torch.Size([1, stem_width, num_hyps])
            # Premise scores are appended after the goal-position scores.
            total_values = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
        else:
            total_values = goal_arg_values
        # Add each stem's certainty to every argument score in its row, then
        # normalize over the flattened (stem, argument) table.
        all_prob_batches = self._softmax((total_values +
                                          stem_certainties.view(1, stem_width, 1)
                                          .expand(-1, -1, num_probs))
                                         .contiguous()
                                         .view(1, stem_width * num_probs))\
                               .view(stem_width * num_probs)

        # NOTE(review): topk(k) presumably requires
        # k <= stem_width * num_probs — confirm callers respect this.
        final_probs, final_idxs = all_prob_batches.topk(k)
        assert not torch.isnan(final_probs).any()
        assert final_probs.size() == torch.Size([k])
        # row_length equals num_probs; dividing a flat index by it gives the
        # position within the stem beam, and the remainder gives the argument
        # index.
        row_length = self.training_args.max_length + num_hyps + 1
        stem_keys = final_idxs // row_length
        assert stem_keys.size() == torch.Size([k])
        assert stem_idxs.size() == torch.Size([1,
                                               stem_width]), stem_idxs.size()
        # Translate beam positions back into the stem indices chosen by topk.
        prediction_stem_idxs = stem_idxs.view(stem_width).index_select(
            0, stem_keys)
        assert prediction_stem_idxs.size() == torch.Size([k]), \
            prediction_stem_idxs.size()
        arg_idxs = final_idxs % row_length
        assert arg_idxs.size() == torch.Size([k])

        # Lemmas are only offered to the decoder when lemma_args is enabled.
        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses
        return [
            Prediction(
                decode_fpa_result(extract_dataloader_args(self.training_args),
                                  self._metadata, all_hyps, context.goal,
                                  stem_idx.item(), arg_idx.item()),
                math.exp(prob)) for stem_idx, arg_idx, prob in islice(
                    zip(prediction_stem_idxs, arg_idxs, final_probs),
                    min(k, num_valid_probs))
        ]
예제 #7
0
    def predictKTactics_batch(self, context_batch: List[TacticContext],
                              k: int) \
            -> List[List[Prediction]]:
        """Predict the top tactics for several proof contexts at once.

        Stem and goal-position scores are computed for the whole batch in one
        pass. Premise scores for all contexts are computed in a single
        `hyp_model` call over concatenated per-context slices and split back
        afterwards. Each context's scores are then normalized and its top `k`
        candidates decoded.

        The method also cross-checks itself: every context is re-sampled with
        `sample_fpa`, and the final predictions are compared against
        `predictKTactics`, with assertions on any mismatch.

        Args:
            context_batch: Proof contexts to predict for.
            k: Number of top candidates to select per context.

        Returns:
            One list of `Prediction` objects per context, in input order.
        """
        assert self.training_args
        assert self._model

        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)
        batch_size = len(context_batch)

        tprems_batch, pfeat_batch, \
            nhyps_batch, tgoals_batch, \
            goal_masks_batch, \
            wfeats_batch, vfeats_batch = \
            sample_fpa_batch(extract_dataloader_args(self.training_args),
                             self._metadata,
                             [context_py2r(context)
                              for context in context_batch])
        # Consistency check: re-sample each context individually and assert
        # the batched sampler produced the same values.
        # NOTE(review): this looks like debugging scaffolding — it repeats the
        # sampling work once per context; confirm it is meant to stay.
        for tprem, pfeat, nhyp, tgoal, masks, wfeat, vfeat, context \
            in zip(tprems_batch, pfeat_batch,
                   nhyps_batch, tgoals_batch,
                   goal_masks_batch, wfeats_batch,
                   vfeats_batch, context_batch):
            s_tprem, s_pfeat, s_nhyp, s_tgoal, s_masks, s_wfeat, s_vfeat = \
                sample_fpa(extract_dataloader_args(self.training_args),
                           self._metadata,
                           context.relevant_lemmas,
                           context.prev_tactics,
                           context.hypotheses,
                           context.goal)
            assert len(s_tprem) == 1
            assert len(tprem) == len(s_tprem[0])
            for p1, p2 in zip(tprem, s_tprem[0]):
                assert p1 == p2, (p1, p2)
            assert len(s_pfeat) == 1
            assert len(pfeat) == len(s_pfeat[0])
            for f1, f2 in zip(pfeat, s_pfeat[0]):
                assert f1 == f2, (f1, f2)
            assert s_nhyp[0] == nhyp, (s_nhyp[0], nhyp)
            assert s_tgoal[0] == tgoal, (s_tgoal[0], tgoal)
            assert s_masks[0] == masks, (s_masks[0], masks)
            assert s_wfeat[0] == wfeat, (s_wfeat[0], wfeat)
            assert s_vfeat[0] == vfeat, (s_vfeat[0], vfeat)

        stem_distribution = self._model.stem_classifier(
            LongTensor(wfeats_batch), FloatTensor(vfeats_batch))
        stem_certainties_batch, stem_idxs_batch = stem_distribution.topk(
            stem_width)

        goals_batch = LongTensor(tgoals_batch)

        # Score goal positions for every (context, stem) pair: each goal is
        # repeated once per stem and the output is viewed as
        # (batch_size, stem_width, max_length + 1).
        goal_arg_values = self._model.goal_args_model(
            stem_idxs_batch.view(batch_size * stem_width),
            goals_batch.view(batch_size, 1, self.training_args.max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(batch_size * stem_width,
                  self.training_args.max_length))\
                                     .view(batch_size, stem_width,
                                           self.training_args.max_length + 1)

        # Positions where the goal mask is false are replaced with -inf.
        goal_arg_values = torch.where(
            torch.ByteTensor(goal_masks_batch).view(
                batch_size, 1,
                self.training_args.max_length + 1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))

        encoded_goals_batch = self._model.goal_encoder(goals_batch)

        # Build flat inputs for one hyp_model call. Each context contributes
        # num_hyps * stem_width rows to each of the four tensors below.
        # NOTE(review): stems are expanded as (num_hyps, stem_width) and then
        # flattened, while premises and premise features are expanded as
        # (stem_width, num_hyps, ...) and flattened, and the output is later
        # viewed as (stem_width, -1). The stem ordering therefore looks
        # inconsistent with the premise ordering unless hyp_model expects it
        # this way — confirm.
        stems_expanded_batch = torch.cat([
            stem_idxs.view(1, stem_width).expand(
                num_hyps, stem_width).contiguous().view(num_hyps * stem_width)
            for stem_idxs, num_hyps, in zip(stem_idxs_batch, nhyps_batch)
        ])
        egoals_expanded_batch = torch.cat([
            encoded_goal.view(1, self.training_args.hidden_size).expand(
                num_hyps * stem_width, -1).contiguous()
            for encoded_goal, num_hyps in zip(encoded_goals_batch, nhyps_batch)
        ])
        tprems_expanded_batch = torch.cat([
            LongTensor(tpremises).view(
                1, -1, self.training_args.max_length).expand(
                    stem_width, -1,
                    -1).contiguous().view(-1, self.training_args.max_length)
            for tpremises in tprems_batch
        ])
        # NOTE(review): the literal 2 is presumably the per-premise feature
        # width; predictKTactics checks that width with hypFeaturesSize() —
        # confirm the two agree.
        pfeat_expanded_batch = torch.cat([
            FloatTensor(premise_features).view(1, num_hyps, 2).expand(
                stem_width, -1, -1).contiguous().view(num_hyps * stem_width, 2)
            for premise_features, num_hyps in zip(pfeat_batch, nhyps_batch)
        ])

        prem_arg_values = self._model.hyp_model(stems_expanded_batch,
                                                egoals_expanded_batch,
                                                tprems_expanded_batch,
                                                pfeat_expanded_batch)

        # Cut the flat premise scores back into one slice per context.
        prem_arg_values_split = prem_arg_values.split(
            [num_hyps * stem_width for num_hyps in nhyps_batch])
        # Per context: goal-position scores followed by premise scores, one
        # row per stem.
        total_values_list = [
            torch.cat((goal_values, prem_values.view(stem_width, -1)),
                      dim=1) for goal_values, prem_values in zip(
                          goal_arg_values, prem_arg_values_split)
        ]
        # Add each stem's certainty to its row, then normalize over the
        # flattened (stem, argument) table of that context.
        all_probs_list = [
            self._softmax((total_values + stem_certainties.view(
                stem_width, 1).expand_as(total_values)).contiguous().view(
                    1, -1)).view(-1) for total_values, stem_certainties in zip(
                        total_values_list, stem_certainties_batch)
        ]

        # NOTE(review): topk(k) presumably requires k to be no larger than
        # each context's flattened table — confirm.
        final_probs_list, final_idxs_list = zip(
            *[probs.topk(k) for probs in all_probs_list])
        # Row width is max_length + num_hyps + 1; the quotient of a flat index
        # is the position in the stem beam, the remainder the argument index.
        stem_keys_list = [
            final_idxs // (self.training_args.max_length + num_hyps + 1)
            for final_idxs, num_hyps in zip(final_idxs_list, nhyps_batch)
        ]

        # Translate beam positions back into the stem indices chosen by topk.
        stem_idxs_list = [
            stem_idxs.view(stem_width).index_select(0, stem_keys)
            for stem_idxs, stem_keys in zip(stem_idxs_batch, stem_keys_list)
        ]

        arg_idxs_list = [
            final_idxs % (self.training_args.max_length + num_hyps + 1)
            for final_idxs, num_hyps in zip(final_idxs_list, nhyps_batch)
        ]

        # NOTE(review): decoding always uses hypotheses + relevant_lemmas
        # (no lemma_args check), and the result cap here is not multiplied by
        # stem_width; predictKTactics differs on both points — confirm which
        # behaviour is intended.
        predictions = [[
            Prediction(
                decode_fpa_result(extract_dataloader_args(self.training_args),
                                  self._metadata,
                                  context.hypotheses + context.relevant_lemmas,
                                  context.goal,
                                  stem_idx.item(), arg_idx.item()),
                math.exp(prob)) for stem_idx, arg_idx, prob in islice(
                    zip(stem_idxs, arg_idxs, final_probs),
                    min(k, 1 + num_hyps + len(get_fpa_words(context.goal))))
        ] for stem_idxs, arg_idxs, final_probs, context, num_hyps in zip(
            stem_idxs_list, arg_idxs_list, final_probs_list, context_batch,
            nhyps_batch)]

        # Consistency check: the batched predictions must decode to the same
        # strings as the single-context path. This re-runs predictKTactics
        # once per context.
        for context, pred_list in zip(context_batch, predictions):
            for batch_pred, single_pred in zip(
                    pred_list, self.predictKTactics(context, k)):
                assert batch_pred.prediction == single_pred.prediction, \
                    (batch_pred, single_pred)

        return predictions
예제 #8
0
    def predictionCertainty(self, context: TacticContext, prediction: str) -> float:
        """Return the model's certainty for one specific tactic in `context`.

        The tactic is split into stem and arguments, both are encoded to
        indices, and the stem is forced into the stem beam if the classifier
        did not rank it there. All (stem, argument) candidates are then
        scored with `predict_args` and the entry matching the requested
        tactic is looked up.

        Args:
            context: Proof context to evaluate the tactic in.
            prediction: Tactic string to score.

        Returns:
            `math.exp` of the score `predict_args` assigned to the matching
            (stem, argument) candidate.

        Raises:
            AssertionError: If no candidate returned by `predict_args`
                matches the encoded stem and argument.
        """

        assert self.training_args
        assert self._model

        num_stem_poss = get_num_tokens(self.metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss)

        tokenized_premises, hyp_features, \
            nhyps_batch, tokenized_goal, \
            goal_mask, \
            word_features, vec_features = \
            sample_fpa(extract_dataloader_args(self.training_args),
                       self.metadata,
                       context.relevant_lemmas,
                       context.prev_tactics,
                       context.hypotheses,
                       context.goal)

        prediction_stem, prediction_args = \
            serapi_instance.split_tactic(prediction)
        prediction_stem_idx = encode_fpa_stem(
            extract_dataloader_args(self.training_args),
            self.metadata, prediction_stem)
        stem_distributions = self._model.stem_classifier(
            maybe_cuda(torch.LongTensor(word_features)),
            maybe_cuda(torch.FloatTensor(vec_features)))
        stem_certainties, stem_idxs = stem_distributions.topk(stem_width)
        if prediction_stem_idx in stem_idxs[0]:
            merged_stem_idxs = stem_idxs
            merged_stem_certainties = stem_certainties
        else:
            # The requested stem is outside the beam: put it in front and
            # drop the last beam entry so the width stays at stem_width.
            merged_stem_idxs = torch.cat(
                (maybe_cuda(torch.LongTensor([[prediction_stem_idx]])),
                 stem_idxs[:, :stem_width-1]),
                dim=1)
            cother = stem_certainties[:, :stem_width-1]
            val = stem_distributions[0][prediction_stem_idx]
            merged_stem_certainties = \
                torch.cat((val.view(1, 1), cother), dim=1)

        # NOTE(review): lemmas are always included here, regardless of
        # training_args.lemma_args — confirm this matches how argument
        # indices are decoded elsewhere.
        prediction_arg_idx = encode_fpa_arg(
            extract_dataloader_args(self.training_args),
            self.metadata,
            context.hypotheses + context.relevant_lemmas,
            context.goal,
            prediction_args)

        goal_arg_values = self.goal_token_scores(
            merged_stem_idxs, tokenized_goal, goal_mask)

        # sample_fpa returns batch-shaped values; entry 0 is this context.
        if len(tokenized_premises[0]) > 0:
            hyp_arg_values = self.hyp_name_scores(
                merged_stem_idxs[0], tokenized_goal[0],
                tokenized_premises[0], hyp_features[0])

            total_scores = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
        else:
            total_scores = goal_arg_values

        final_probs, predicted_stem_idxs, predicted_arg_idxs = \
            self.predict_args(total_scores, merged_stem_certainties,
                              merged_stem_idxs)

        # Each returned stem value is compared directly against the encoded
        # stem index (not against a position within the beam).
        for prob, stem_idx, arg_idx in zip(final_probs,
                                           predicted_stem_idxs,
                                           predicted_arg_idxs):
            if stem_idx == prediction_stem_idx and \
               arg_idx == prediction_arg_idx:
                return math.exp(prob.item())

        # Raised explicitly rather than via `assert False`, which is removed
        # under `python -O` and would let the function return None.
        raise AssertionError("Shouldn't be able to get here")