Example #1
 def _predictDistributions(
         self, in_datas: List[TacticContext]) -> torch.FloatTensor:
     assert self._tokenizer
     assert self._embedding
     assert self.training_args
     goals_batch = [
         normalizeSentenceLength(self._tokenizer.toTokenList(goal),
                                 self.training_args.max_length)
         for _, _, _, goal in in_datas
     ]
     hyps = [
         get_closest_hyp(hyps, goal, self.training_args.max_length)
         for _, _, hyps, goal in in_datas
     ]
     hyp_types = [serapi_instance.get_hyp_type(hyp) for hyp in hyps]
     hyps_batch = [
         normalizeSentenceLength(self._tokenizer.toTokenList(hyp_type),
                                 self.training_args.max_length)
         for hyp_type in hyp_types
     ]
     word_features_batch = [
         self._get_word_features(in_data) for in_data in in_datas
     ]
     vec_features_batch = [
         self._get_vec_features(in_data) for in_data in in_datas
     ]
     stem_distribution = self._model(LongTensor(goals_batch),
                                     LongTensor(hyps_batch),
                                     FloatTensor(vec_features_batch),
                                     LongTensor(word_features_batch))
     return stem_distribution
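For reference, normalizeSentenceLength is not shown in these examples; it pads or truncates a token list to training_args.max_length. A minimal sketch of such a helper, assuming a hypothetical pad_token of 0 (not the source implementation), could look like this:

# Hedged sketch, not the source implementation: pad or truncate a token list.
from typing import List

def pad_or_truncate(tokens: List[int], max_length: int, pad_token: int = 0) -> List[int]:
    if len(tokens) >= max_length:
        return tokens[:max_length]
    return tokens + [pad_token] * (max_length - len(tokens))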
Example #2
    def goal_token_scores(self, stem_idxs: torch.LongTensor,
                          tokenized_goals: List[List[int]],
                          goal_masks: List[List[bool]],
                          ) -> torch.FloatTensor:
        assert self._model
        assert self.training_args
        batch_size = stem_idxs.size()[0]
        stem_width = stem_idxs.size()[1]
        goal_len = self.training_args.max_length
        # The goal probabilities include the "no argument" probability
        num_goal_probs = goal_len + 1
        unmasked_probabilities = self._model.goal_args_model(
            stem_idxs.view(batch_size * stem_width),
            LongTensor(tokenized_goals).view(
                batch_size, 1, goal_len)
            .expand(-1, stem_width, -1).contiguous()
            .view(batch_size * stem_width, goal_len))\
            .view(batch_size, stem_width, num_goal_probs)

        masked_probabilities = torch.where(
            maybe_cuda(torch.ByteTensor(goal_masks))
            .view(batch_size, 1, num_goal_probs)
            .expand(-1, stem_width, -1),
            unmasked_probabilities,
            torch.full_like(unmasked_probabilities, -float("Inf")))

        assert masked_probabilities.size() == torch.Size(
            [batch_size, stem_width, num_goal_probs])
        return masked_probabilities
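The masking step above drives disallowed goal positions to -inf so that a later softmax or topk ignores them. A standalone illustration of the same torch.where pattern (not from the source; it uses a modern bool mask rather than ByteTensor):

import torch

scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[True, False, True]])
masked = torch.where(mask, scores, torch.full_like(scores, -float("Inf")))
# masked == tensor([[1., -inf, 3.]])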
Example #3
 def predictKTacticsWithLoss(self, in_data : TacticContext, k : int, correct : str) -> \
     Tuple[List[Prediction], float]:
     assert self.training_args
     assert self._embedding
     with self._lock:
         prediction_distribution = self._predictDistributions([in_data])[0]
     if k > self._embedding.num_tokens():
         k = self._embedding.num_tokens()
     correct_stem = serapi_instance.get_stem(correct)
     if self._embedding.has_token(correct_stem):
         output_var = maybe_cuda(
             Variable(
                 LongTensor([self._embedding.encode_token(correct_stem)])))
         loss = self._criterion(prediction_distribution.view(1, -1),
                                output_var).item()
     else:
         loss = 0
     if len(in_data.hypotheses) == 0:
         certainties, idxs = topk_with_filter(
             prediction_distribution.view(-1), k,
             lambda certainty, idx: not serapi_instance.tacticTakesHypArgs(
                 cast(Embedding, self._embedding).decode_token(idx)))
     else:
         certainties, idxs = prediction_distribution.view(-1).topk(k)
     results = [
         Prediction(
             self.add_arg(self._embedding.decode_token(stem_idx.item()),
                          in_data.goal, in_data.hypotheses,
                          self.training_args.max_length),
             math.exp(certainty.item()))
         for certainty, stem_idx in zip(certainties, idxs)
     ]
     return results, loss
Example #4
 def forward(self, goal_batch : torch.LongTensor, stem_batch : torch.LongTensor) \
     -> torch.FloatTensor:
     goal_var = maybe_cuda(Variable(goal_batch))
     stem_var = maybe_cuda(Variable(stem_batch))
     batch_size = goal_batch.size()[0]
     initial_hidden = self._stem_embedding(stem_var)\
                          .view(1, batch_size, self.hidden_size)
     hidden = initial_hidden
     copy_likelyhoods: List[torch.FloatTensor] = []
     for i in range(goal_batch.size()[1]):
         token_batch = self._word_embedding(goal_var[:,i])\
             .view(1, batch_size, self.hidden_size)
         token_batch = F.relu(token_batch)
         token_out, hidden = self._gru(token_batch, hidden)
         copy_likelyhood = self._likelyhood_layer(F.relu(token_out))
         copy_likelyhoods.append(copy_likelyhood[0])
     end_token_embedded = self._word_embedding(LongTensor([EOS_token])
                                                .expand(batch_size))\
                                                .view(1, batch_size, self.hidden_size)
     final_out, final_hidden = self._gru(F.relu(end_token_embedded), hidden)
     final_likelyhood = self._likelyhood_layer(F.relu(final_out))
     copy_likelyhoods.insert(0, final_likelyhood[0])
     catted = torch.cat(copy_likelyhoods, dim=1)
     result = self._softmax(catted)
     return result
Example #5
 def predictKTacticsWithLoss_batch(self,
                                   in_data : List[TacticContext],
                                   k : int, corrects : List[str]) -> \
                                   Tuple[List[List[Prediction]], float]:
     assert self._embedding
     assert self.training_args
     with self._lock:
         prediction_distributions = self._predictDistributions(in_data)
     correct_stems = [serapi_instance.get_stem(correct) for correct in corrects]
     output_var = maybe_cuda(Variable(
         LongTensor([self._embedding.encode_token(correct_stem)
                     if self._embedding.has_token(correct_stem)
                     else 0
                     for correct_stem in correct_stems])))
     loss = self._criterion(prediction_distributions, output_var).item()
     if k > self._embedding.num_tokens():
         k = self._embedding.num_tokens()
     certainties_and_idxs_list = \
         [single_distribution.view(-1).topk(k) if len(context.hypotheses) > 0 else
          topk_with_filter(single_distribution.view(-1), k,
                           lambda certainty, idx:
                           not serapi_instance.tacticTakesHypArgs(
                               cast(Embedding, self._embedding).decode_token(idx)))
          for single_distribution, context in
          zip(prediction_distributions, in_data)]
     results = [[Prediction(self.add_arg(self._embedding.decode_token(stem_idx.item()),
                                         in_datum.goal, in_datum.hypotheses,
                                         self.training_args.max_length),
                            math.exp(certainty.item()))
                 for certainty, stem_idx in zip(*certainties_and_idxs)]
                for certainties_and_idxs, in_datum in
                zip(certainties_and_idxs_list, in_data)]
     return results, loss
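topk_with_filter is called here but not defined in these examples. One plausible sketch, assuming it returns the top-k (score, index) pairs whose predicate holds (the real helper may differ), might be:

from typing import Callable, Tuple
import torch

def topk_filtered(scores: torch.Tensor, k: int,
                  keep: Callable[[float, int], bool]
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    # Sort descending, keep entries passing the predicate, take the first k.
    values, idxs = scores.sort(descending=True)
    kept = [(v, i) for v, i in zip(values.tolist(), idxs.tolist()) if keep(v, i)][:k]
    if not kept:
        return scores.new_empty(0), idxs.new_empty(0)
    vs, is_ = zip(*kept)
    return scores.new_tensor(vs), idxs.new_tensor(is_)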
Example #6
 def hyp_name_scores(self,
                     stem_idxs: torch.LongTensor,
                     tokenized_goal: List[int],
                     tokenized_premises: List[List[int]],
                     premise_features: List[List[float]]
                     ) -> torch.FloatTensor:
     assert self._model
     assert len(stem_idxs.size()) == 1
     stem_width = stem_idxs.size()[0]
     num_hyps = len(tokenized_premises)
     encoded_goals = self._model.goal_encoder(LongTensor([tokenized_goal]))
     hyp_arg_values = self.runHypModel(stem_idxs.unsqueeze(0),
                                       encoded_goals,
                                       LongTensor([tokenized_premises]),
                                       FloatTensor([premise_features]))
     assert hyp_arg_values.size() == torch.Size([1, stem_width, num_hyps])
     return hyp_arg_values
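As a standalone note (not from the source), unsqueeze(0) is what gives stem_idxs the batch dimension that runHypModel expects when scoring a single context:

import torch

stem_idxs = torch.tensor([5, 9, 2])          # stem_width = 3
assert stem_idxs.unsqueeze(0).size() == torch.Size([1, 3])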
Example #7
def train(dataset : SequenceSequenceDataset, hidden_size : int,
          learning_rate : float, num_encoder_layers : int,
          num_decoder_layers : int, max_length : int, num_epochs : int, batch_size : int,
          print_every : int, context_vocab_size : int, tactic_vocab_size : int) -> Iterable[Checkpoint]:
    print("Initializing PyTorch...")
    in_stream = [inputFromSentence(datum[0], max_length) for datum in dataset]
    out_stream = [inputFromSentence(datum[1], max_length) for datum in dataset]
    data_loader = data.DataLoader(data.TensorDataset(torch.LongTensor(out_stream),
                                                     torch.LongTensor(in_stream)),
                                  batch_size=batch_size, num_workers=0,
                                  shuffle=True, pin_memory=True,
                                  drop_last=True)

    encoder = EncoderRNN(context_vocab_size, hidden_size, num_encoder_layers,
                         batch_size=batch_size)
    decoder = DecoderRNN(hidden_size, tactic_vocab_size, num_decoder_layers,
                         batch_size=batch_size)
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    optimizers = [encoder_optimizer, decoder_optimizer]
    criterion = maybe_cuda(nn.NLLLoss())

    start = time.time()
    num_items = len(dataset) * num_epochs
    total_loss = 0

    print("Training...")
    for epoch in range(num_epochs):
        print("Epoch {}".format(epoch))
        adjustLearningRates(learning_rate, optimizers, epoch)
        for batch_num, (output_batch, input_batch) in enumerate(data_loader):
            target_length = output_batch.size()[1]

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            predictor_output = decoder.run_teach(encoder
                                                 .run(cast(SomeLongTensor, input_batch)),
                                                 cast(SomeLongTensor, output_batch))
            # accumulate the per-step NLL losses in a float zero scalar
            loss = maybe_cuda(Variable(torch.zeros(1)))
            output_var = maybe_cuda(Variable(output_batch))
            for i in range(target_length):
                loss += criterion(predictor_output[i], output_var[:,i])
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

            total_loss += (loss.data[0] / target_length) * batch_size

            if (batch_num + 1) % print_every == 0:
                items_processed = (batch_num + 1) * batch_size + epoch * len(dataset)
                progress = items_processed / num_items
                print("{} ({} {:.2f}%) {:.4f}".
                      format(timeSince(start, progress),
                             items_processed, progress * 100,
                             total_loss / items_processed))

        yield encoder.state_dict(), decoder.state_dict()
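For illustration only (not from the source), TensorDataset plus DataLoader is what pairs the (output, input) tensors above and yields shuffled fixed-size batches:

import torch
from torch.utils import data

ds = data.TensorDataset(torch.arange(8), torch.arange(8) * 10)
loader = data.DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
for out_batch, in_batch in loader:
    assert out_batch.size() == torch.Size([4])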
Example #8
 def _predictStemDistributions(self, in_datas : List[TacticContext]) \
     -> torch.FloatTensor:
     word_features_batch = LongTensor(
         [self._get_word_features(in_data) for in_data in in_datas])
     vec_features_batch = FloatTensor(
         [self._get_vec_features(in_data) for in_data in in_datas])
     encoded_word_features = self._model.word_features_encoder(
         word_features_batch)
     stem_distribution = \
         self._softmax(self._model.features_classifier(torch.cat((
             encoded_word_features, vec_features_batch), dim=1)))
     return stem_distribution
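A standalone illustration (not from the source) of the feature concatenation used above, where encoded word features and raw vector features are joined along dim=1 before the classifier:

import torch

enc_word = torch.randn(4, 8)   # (batch, encoded word-feature size)
vec_feat = torch.randn(4, 3)   # (batch, vec-feature size)
combined = torch.cat((enc_word, vec_feat), dim=1)
assert combined.shape == (4, 11)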
Example #9
 def predictKTactics(self, in_data : TacticContext, k : int) -> \
     List[Prediction]:
     in_sentence = LongTensor(inputFromSentence(
         self.context_tokenizer.toTokenList(in_data.goal),
         self.max_length)).view(1, -1)
     feature_vector = self.encoder.run(in_sentence)
     prediction_sentences = decodeKTactics(self.decoder,
                                           feature_vector,
                                           self.beam_width,
                                           self.max_length)[:k]
     return [Prediction(self.tactic_tokenizer.toString(sentence), 0.5 ** i)
             for sentence, i in zip(prediction_sentences, itertools.count())]
Example #10
 def predict_stems(self, k: int,
                   word_features: List[List[int]],
                   vec_features: List[List[float]]
                   ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
     assert self._model
     assert len(word_features) == len(vec_features)
     batch_size = len(word_features)
     stem_distribution = self._model.stem_classifier(
         LongTensor(word_features), FloatTensor(vec_features))
     stem_probs, stem_idxs = stem_distribution.topk(k)
     assert stem_probs.size() == torch.Size([batch_size, k])
     assert stem_idxs.size() == torch.Size([batch_size, k])
     return stem_probs, stem_idxs
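As the assertions above state, topk over a (batch, num_stems) distribution returns (batch, k) probabilities and indices. A standalone illustration (not from the source):

import torch

dist = torch.randn(4, 10)      # batch_size = 4, 10 candidate stems
probs, idxs = dist.topk(3)
assert probs.size() == torch.Size([4, 3])
assert idxs.size() == torch.Size([4, 3])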
Example #11
    def _predictFromStemDistributionWithLoss(
            self, beam_width: int, stem_distribution: torch.FloatTensor,
            in_datas: List[TacticContext], corrects: List[str],
            k: int) -> Tuple[List[List[Prediction]], float]:
        assert self.training_args
        assert self._embedding
        stem_probs, likely_correct_stems = stem_distribution.topk(
            min(beam_width,
                stem_distribution.size()[1]))

        all_prob_batches, index_mapping = \
            self._predictCompositeDistributionFromStemDistribution(
                beam_width, stem_distribution, in_datas)
        correct_idxs = LongTensor([[
            max(0, arg_idx - 1) + (stem_idx * self.training_args.max_length)
            for stem_idx, arg_idx in [
                get_stem_and_arg_idx(
                    self.training_args.max_length, self._embedding,
                    serapi_instance.normalizeNumericArgs(
                        ScrapedTactic(
                            [], in_data.prev_tactics,
                            ProofContext(
                                [Obligation(in_data.hypotheses, in_data.goal)],
                                [], [], []), correct)))
            ]
        ][0] for in_data, correct in zip(in_datas, corrects)])

        loss = self._criterion(all_prob_batches, correct_idxs)

        final_probs, final_idxs = all_prob_batches.topk(beam_width)
        row_length = self.training_args.max_length
        # integer division: each row of the flattened distribution spans one stem
        indices = final_idxs // row_length
        stem_idxs = [
            index_map.index_select(0, indices1)
            for index_map, indices1 in zip(index_mapping, indices)
        ]
        arg_idxs = final_idxs % row_length
        return [[
            Prediction(
                self._embedding.decode_token(stem_idx.item()) + " " +
                get_arg_from_token_idx(in_data.goal, arg_idx.item()) + ".",
                math.exp(final_prob))
            for stem_idx, arg_idx, final_prob in islice(
                zip(stem_list, arg_list, final_list), k)
        ] for in_data, stem_list, arg_list, final_list in zip(
            in_datas, stem_idxs, arg_idxs, final_probs)], loss
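A standalone illustration (not from the source) of the flat-index arithmetic used above: a combined (stem, argument) index is split back into its parts with // and %:

row_length = 128                    # e.g. training_args.max_length
flat_idx = 3 * row_length + 17      # stem slot 3, argument position 17
stem_slot, arg_pos = flat_idx // row_length, flat_idx % row_length
assert (stem_slot, arg_pos) == (3, 17)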
Example #12
 def predictKTacticsWithLoss(self, in_data : TacticContext, k : int,
                             correct : str) -> Tuple[List[Prediction], float]:
     with self._lock:
         distribution, hyp_var = self._predictDistribution(in_data)
         correct_stem = serapi_instance.get_stem(correct)
         if self._embedding.has_token(correct_stem):
             loss = self._criterion(distribution.view(1, -1),
                                    Variable(LongTensor([self._embedding.encode_token(correct_stem)]))).item()
         else:
             loss = float("+inf")
     indices, probabilities = list_topk(list(distribution), k)
     predictions : List[Prediction] = []
     for certainty, idx in zip(probabilities, indices):
         stem = self._embedding.decode_token(idx)
         if serapi_instance.tacticTakesHypArgs(stem):
             predictions.append(Prediction(stem + " " + hyp_var + ".",
                                           math.exp(certainty)))
         else:
             predictions.append(Prediction(stem + ".", math.exp(certainty)))
     return predictions, loss
Example #13
 def forward(self, stem_batch: torch.LongTensor, goal_batch: torch.LongTensor) \
         -> torch.FloatTensor:
     goal_var = maybe_cuda(Variable(goal_batch))
     stem_var = maybe_cuda(Variable(stem_batch))
     batch_size = goal_batch.size()[0]
     assert stem_batch.size()[0] == batch_size
     initial_hidden = self._stem_embedding(stem_var)\
                          .view(1, batch_size, self.hidden_size)
     hidden = initial_hidden
     encoded_tokens: List[torch.FloatTensor] = []
     for i in range(goal_batch.size()[1]):
         token_batch = self._token_embedding(goal_var[:, i])\
                           .view(1, batch_size, self.hidden_size)
         token_batch2 = F.relu(token_batch)
         token_out, hidden = self._gru(token_batch2, hidden)
         encoded_tokens.append(token_out.squeeze(dim=0).unsqueeze(1))
     end_token_embedded = self._token_embedding(LongTensor([EOS_token])
                                                .expand(batch_size))\
         .view(1, batch_size, self.hidden_size)
     final_out, _final_hidden = self._gru(F.relu(end_token_embedded), hidden)
     encoded_tokens.insert(0, final_out.squeeze(dim=0).unsqueeze(1))
     catted = torch.cat(encoded_tokens, dim=1)
     return catted
Example #14
 def forward(self, stem_batch : torch.LongTensor, goal_batch : torch.LongTensor) \
     -> torch.FloatTensor:
     goal_var = maybe_cuda(Variable(goal_batch))
     stem_var = maybe_cuda(Variable(stem_batch))
     batch_size = goal_batch.size()[0]
     assert stem_batch.size()[0] == batch_size
     initial_hidden = self._stem_embedding(stem_var)\
                          .view(1, batch_size, self.hidden_size)
     hidden = initial_hidden
     copy_likelyhoods: List[torch.FloatTensor] = []
     for i in range(goal_batch.size()[1]):
         try:
             token_batch = self._token_embedding(goal_var[:,i])\
                               .view(1, batch_size, self.hidden_size)
             token_batch2 = F.relu(token_batch)
             token_out, hidden = self._gru(token_batch2, hidden)
             copy_likelyhood = self._likelyhood_layer(F.relu(token_out))
             copy_likelyhoods.append(copy_likelyhood[0])
         except RuntimeError:
             eprint("Tokenized goal:")
             for j in range(goal_batch.size()[0]):
                 eprint(goal_batch[j, i].item(), end=" ")
                 assert goal_batch[j, i] < 123
             eprint()
             eprint(f"goal_var: {goal_var}")
             eprint("Token batch")
             eprint(token_batch)
             raise
     end_token_embedded = self._token_embedding(LongTensor([EOS_token])
                                                .expand(batch_size))\
                                                .view(1, batch_size, self.hidden_size)
     final_out, final_hidden = self._gru(F.relu(end_token_embedded), hidden)
     final_likelyhood = self._likelyhood_layer(F.relu(final_out))
     copy_likelyhoods.insert(0, final_likelyhood[0])
     catted = torch.cat(copy_likelyhoods, dim=1)
     return catted
Example #15
    def predictKTactics(self, context: TacticContext,
                        k: int) -> List[Prediction]:
        assert self.training_args
        assert self._model

        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)

        tokenized_premises, hyp_features, \
            nhyps_batch, tokenized_goal, \
            goal_mask, \
            word_features, vec_features = \
                sample_fpa(extract_dataloader_args(self.training_args),
                           self._metadata,
                           context.relevant_lemmas,
                           context.prev_tactics,
                           context.hypotheses,
                           context.goal)

        num_hyps = nhyps_batch[0]

        stem_distribution = self._model.stem_classifier(
            LongTensor(word_features), FloatTensor(vec_features))
        stem_certainties, stem_idxs = stem_distribution.topk(stem_width)

        goals_batch = LongTensor(tokenized_goal)
        goal_arg_values = self._model.goal_args_model(
            stem_idxs.view(1 * stem_width),
            goals_batch.view(1, 1, self.training_args.max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(1 * stem_width,
                  self.training_args.max_length))\
                                     .view(1, stem_width,
                                           self.training_args.max_length + 1)
        goal_arg_values = torch.where(
            torch.ByteTensor(goal_mask).view(1, 1,
                                             self.training_args.max_length +
                                             1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))
        assert goal_arg_values.size() == torch.Size([1, stem_width,
                                                     self.training_args.max_length + 1]),\
            "goal_arg_values.size(): {}; stem_width: {}".format(goal_arg_values.size(),
                                                                stem_width)

        num_probs = 1 + num_hyps + self.training_args.max_length
        goal_symbols = get_fpa_words(context.goal)
        num_valid_probs = (1 + num_hyps + len(goal_symbols)) * stem_width
        if num_hyps > 0:
            encoded_goals = self._model.goal_encoder(goals_batch)\
                                       .view(1, 1, self.training_args.hidden_size)

            hyps_batch = LongTensor(tokenized_premises)
            assert hyps_batch.size() == torch.Size([1, num_hyps,
                                                    self.training_args.max_length]), \
                                                    (hyps_batch.size(),
                                                     num_hyps,
                                                     self.training_args.max_length)
            hypfeatures_batch = FloatTensor(hyp_features)
            assert hypfeatures_batch.size() == \
                torch.Size([1, num_hyps, hypFeaturesSize()]), \
                (hypfeatures_batch.size(), num_hyps, hypFeaturesSize())
            hyp_arg_values = self.runHypModel(stem_idxs, encoded_goals,
                                              hyps_batch, hypfeatures_batch)
            assert hyp_arg_values.size() == \
                torch.Size([1, stem_width, num_hyps])
            total_values = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
        else:
            total_values = goal_arg_values
        all_prob_batches = self._softmax((total_values +
                                          stem_certainties.view(1, stem_width, 1)
                                          .expand(-1, -1, num_probs))
                                         .contiguous()
                                         .view(1, stem_width * num_probs))\
                               .view(stem_width * num_probs)

        final_probs, final_idxs = all_prob_batches.topk(k)
        assert not torch.isnan(final_probs).any()
        assert final_probs.size() == torch.Size([k])
        row_length = self.training_args.max_length + num_hyps + 1
        stem_keys = final_idxs // row_length
        assert stem_keys.size() == torch.Size([k])
        assert stem_idxs.size() == torch.Size([1,
                                               stem_width]), stem_idxs.size()
        prediction_stem_idxs = stem_idxs.view(stem_width).index_select(
            0, stem_keys)
        assert prediction_stem_idxs.size() == torch.Size([k]), \
            prediction_stem_idxs.size()
        arg_idxs = final_idxs % row_length
        assert arg_idxs.size() == torch.Size([k])

        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses
        return [
            Prediction(
                decode_fpa_result(extract_dataloader_args(self.training_args),
                                  self._metadata, all_hyps, context.goal,
                                  stem_idx.item(), arg_idx.item()),
                math.exp(prob)) for stem_idx, arg_idx, prob in islice(
                    zip(prediction_stem_idxs, arg_idxs, final_probs),
                    min(k, num_valid_probs))
        ]
Example #16
    def predictKTactics_batch(self, context_batch: List[TacticContext],
                              k: int) \
            -> List[List[Prediction]]:
        assert self.training_args
        assert self._model

        num_stem_poss = get_num_tokens(self._metadata)
        stem_width = min(self.training_args.max_beam_width, num_stem_poss,
                         k**2)
        batch_size = len(context_batch)

        tprems_batch, pfeat_batch, \
            nhyps_batch, tgoals_batch, \
            goal_masks_batch, \
            wfeats_batch, vfeats_batch = \
            sample_fpa_batch(extract_dataloader_args(self.training_args),
                             self._metadata,
                             [context_py2r(context)
                              for context in context_batch])
        for tprem, pfeat, nhyp, tgoal, masks, wfeat, vfeat, context \
            in zip(tprems_batch, pfeat_batch,
                   nhyps_batch, tgoals_batch,
                   goal_masks_batch, wfeats_batch,
                   vfeats_batch, context_batch):
            s_tprem, s_pfeat, s_nhyp, s_tgoal, s_masks, s_wfeat, s_vfeat = \
                sample_fpa(extract_dataloader_args(self.training_args),
                           self._metadata,
                           context.relevant_lemmas,
                           context.prev_tactics,
                           context.hypotheses,
                           context.goal)
            assert len(s_tprem) == 1
            assert len(tprem) == len(s_tprem[0])
            for p1, p2 in zip(tprem, s_tprem[0]):
                assert p1 == p2, (p1, p2)
            assert len(s_pfeat) == 1
            assert len(pfeat) == len(s_pfeat[0])
            for f1, f2 in zip(pfeat, s_pfeat[0]):
                assert f1 == f2, (f1, f2)
            assert s_nhyp[0] == nhyp, (s_nhyp[0], nhyp)
            assert s_tgoal[0] == tgoal, (s_tgoal[0], tgoal)
            assert s_masks[0] == masks, (s_masks[0], masks)
            assert s_wfeat[0] == wfeat, (s_wfeat[0], wfeat)
            assert s_vfeat[0] == vfeat, (s_vfeat[0], vfeat)

        stem_distribution = self._model.stem_classifier(
            LongTensor(wfeats_batch), FloatTensor(vfeats_batch))
        stem_certainties_batch, stem_idxs_batch = stem_distribution.topk(
            stem_width)

        goals_batch = LongTensor(tgoals_batch)

        goal_arg_values = self._model.goal_args_model(
            stem_idxs_batch.view(batch_size * stem_width),
            goals_batch.view(batch_size, 1, self.training_args.max_length)
            .expand(-1, stem_width, -1).contiguous()
            .view(batch_size * stem_width,
                  self.training_args.max_length))\
                                     .view(batch_size, stem_width,
                                           self.training_args.max_length + 1)

        goal_arg_values = torch.where(
            torch.ByteTensor(goal_masks_batch).view(
                batch_size, 1,
                self.training_args.max_length + 1).expand(-1, stem_width, -1),
            goal_arg_values, torch.full_like(goal_arg_values, -float("Inf")))

        encoded_goals_batch = self._model.goal_encoder(goals_batch)

        stems_expanded_batch = torch.cat([
            stem_idxs.view(1, stem_width).expand(
                num_hyps, stem_width).contiguous().view(num_hyps * stem_width)
            for stem_idxs, num_hyps, in zip(stem_idxs_batch, nhyps_batch)
        ])
        egoals_expanded_batch = torch.cat([
            encoded_goal.view(1, self.training_args.hidden_size).expand(
                num_hyps * stem_width, -1).contiguous()
            for encoded_goal, num_hyps in zip(encoded_goals_batch, nhyps_batch)
        ])
        tprems_expanded_batch = torch.cat([
            LongTensor(tpremises).view(
                1, -1, self.training_args.max_length).expand(
                    stem_width, -1,
                    -1).contiguous().view(-1, self.training_args.max_length)
            for tpremises in tprems_batch
        ])
        pfeat_expanded_batch = torch.cat([
            FloatTensor(premise_features).view(1, num_hyps, 2).expand(
                stem_width, -1, -1).contiguous().view(num_hyps * stem_width, 2)
            for premise_features, num_hyps in zip(pfeat_batch, nhyps_batch)
        ])

        prem_arg_values = self._model.hyp_model(stems_expanded_batch,
                                                egoals_expanded_batch,
                                                tprems_expanded_batch,
                                                pfeat_expanded_batch)

        prem_arg_values_split = prem_arg_values.split(
            [num_hyps * stem_width for num_hyps in nhyps_batch])
        total_values_list = [
            torch.cat((goal_values, prem_values.view(stem_width, -1)),
                      dim=1) for goal_values, prem_values in zip(
                          goal_arg_values, prem_arg_values_split)
        ]
        all_probs_list = [
            self._softmax((total_values + stem_certainties.view(
                stem_width, 1).expand_as(total_values)).contiguous().view(
                    1, -1)).view(-1) for total_values, stem_certainties in zip(
                        total_values_list, stem_certainties_batch)
        ]

        final_probs_list, final_idxs_list = zip(
            *[probs.topk(k) for probs in all_probs_list])
        stem_keys_list = [
            final_idxs // (self.training_args.max_length + num_hyps + 1)
            for final_idxs, num_hyps in zip(final_idxs_list, nhyps_batch)
        ]

        stem_idxs_list = [
            stem_idxs.view(stem_width).index_select(0, stem_keys)
            for stem_idxs, stem_keys in zip(stem_idxs_batch, stem_keys_list)
        ]

        arg_idxs_list = [
            final_idxs % (self.training_args.max_length + num_hyps + 1)
            for final_idxs, num_hyps in zip(final_idxs_list, nhyps_batch)
        ]

        predictions = [[
            Prediction(
                decode_fpa_result(extract_dataloader_args(self.training_args),
                                  self._metadata,
                                  context.hypotheses + context.relevant_lemmas,
                                  context.goal,
                                  stem_idx.item(), arg_idx.item()),
                math.exp(prob)) for stem_idx, arg_idx, prob in islice(
                    zip(stem_idxs, arg_idxs, final_probs),
                    min(k, 1 + num_hyps + len(get_fpa_words(context.goal))))
        ] for stem_idxs, arg_idxs, final_probs, context, num_hyps in zip(
            stem_idxs_list, arg_idxs_list, final_probs_list, context_batch,
            nhyps_batch)]

        for context, pred_list in zip(context_batch, predictions):
            for batch_pred, single_pred in zip(
                    pred_list, self.predictKTactics(context, k)):
                assert batch_pred.prediction == single_pred.prediction, \
                    (batch_pred, single_pred)

        return predictions
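A standalone illustration (not from the source) of Tensor.split with a list of sizes, the call used above to regroup the flat per-hypothesis scores by context:

import torch

flat = torch.arange(10)
chunks = flat.split([4, 6])
assert [c.numel() for c in chunks] == [4, 6]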
Example #17
 def initInput(self) -> SomeLongTensor:
     return Variable(LongTensor([[SOS_token] * self.batch_size]))
Example #18
def decodeKTactics(decoder : DecoderRNN, encoder_hidden : torch.FloatTensor,
                   beam_width : int, max_length : int):
    v = decoder.output_size
    pos_index = Variable(LongTensor([0]) * beam_width).view(-1, 1)

    hidden = _inflate(encoder_hidden, beam_width)

    sequence_scores = FloatTensor(beam_width, 1)
    sequence_scores.fill_(-float('Inf'))
    sequence_scores.index_fill_(0, LongTensor([0]), 0.0)
    sequence_scores = Variable(sequence_scores)

    input_var = Variable(LongTensor([[SOS_token] * beam_width]))

    stored_predecessors = list()
    stored_emitted_symbols = list()

    for j in range(max_length):
        decoder_output, hidden = decoder(input_var, hidden)

        sequence_scores = _inflate(sequence_scores, v)
        sequence_scores += decoder_output

        scores, candidates = sequence_scores.view(1, -1).topk(beam_width)

        input_var = (candidates % v).view(1, beam_width)
        sequence_scores = scores.view(beam_width, 1)

        # integer division recovers which beam each candidate token came from
        predecessors = (candidates // v +
                        pos_index.expand_as(candidates)).view(beam_width, 1)
        hidden = hidden.index_select(1, cast(torch.LongTensor, predecessors.squeeze()))

        eos_indices = input_var.data.eq(EOS_token)
        if eos_indices.nonzero().dim() > 0:
            sequence_scores.data.masked_fill_(torch.transpose(eos_indices, 0, 1),
                                              -float('inf'))

        stored_predecessors.append(predecessors)
        stored_emitted_symbols.append(torch.transpose(input_var, 0, 1))


    # Trace back from the final three highest scores
    _, next_idxs = sequence_scores.view(beam_width).sort(descending=True)
    seqs = [] # type: List[List[SomeLongTensor]]
    eos_found = 0
    for i in range(max_length - 1, -1, -1):
        # The next column of symbols from the end
        next_symbols = stored_emitted_symbols[i].view(beam_width) \
                                                .index_select(0, next_idxs).data
        # The predecessors of that column
        next_idxs = stored_predecessors[i].view(beam_width).index_select(0, next_idxs)

        # Handle sequences that ended early
        eos_indices = stored_emitted_symbols[i].data.squeeze(1).eq(EOS_token).nonzero()
        if eos_indices.dim() > 0:
            for j in range(eos_indices.size(0)-1, -1, -1):
                idx = eos_indices[j]

                res_k_idx = beam_width - (eos_found % beam_width) - 1
                eos_found += 1
                res_idx = res_k_idx

                next_idxs[res_idx] = stored_predecessors[i][idx[0]]
                next_symbols[res_idx] = stored_emitted_symbols[i][idx[0]].data[0]

        # Commit the result
        seqs.insert(0, next_symbols)

    # Transpose
    int_seqs = [[data[i][0] for data in seqs] for i in range(beam_width)]
    # Cut off EOS tokens
    int_seqs = [list(takewhile(lambda x: x != EOS_token, seq)) for seq in int_seqs]

    return int_seqs
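_inflate is used by this beam search but not shown. A hedged sketch, assuming it simply repeats a tensor `times` times along a given dimension (the exact signature in the source may differ):

import torch

def inflate(tensor: torch.Tensor, times: int, dim: int = 1) -> torch.Tensor:
    # Repeat the tensor `times` times along `dim`, leaving other dims intact.
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)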