예제 #1
0
def predictKTacticsWithLoss_batch(prediction_distributions : torch.FloatTensor,
                                  embedding : Embedding,
                                  k : int,
                                  correct_stems : List[str],
                                  criterion : nn.Module) -> \
                                  Tuple[List[List[Prediction]], float]:
    """Return the top-``k`` stem predictions for every sample plus the batch loss.

    Args:
        prediction_distributions: one score distribution per sample, iterated
            along the first dimension; each row is flattened with ``view(-1)``
            before ``topk``. Scores are passed through ``math.exp`` to produce
            certainties, so they are presumably log-probabilities
            (TODO confirm against the model's output layer).
        embedding: maps tactic stems to token indices and back.
        k: number of predictions per sample; clamped to
            ``embedding.num_tokens()``.
        correct_stems: the expected stem for each sample, in batch order.
        criterion: loss module, called as
            ``criterion(prediction_distributions, targets)``.

    Returns:
        ``(predictions, loss)``. ``predictions[i]`` lists up to ``k``
        ``Prediction`` objects for sample ``i``, highest score first, each
        stem suffixed with ``"."``. ``loss`` is the criterion value as a
        Python float.
    """
    # Stems the embedding does not know are scored against index 0.
    # NOTE(review): predictKTacticsWithLoss reports a loss of 0 for unknown
    # stems instead; confirm index 0 is an intended "unknown" target here.
    # (The former Variable(...) wrapper is a no-op since PyTorch 0.4.)
    targets = maybe_cuda(torch.LongTensor([
        embedding.encode_token(stem) if embedding.has_token(stem) else 0
        for stem in correct_stems
    ]))
    loss = criterion(prediction_distributions, targets).item()

    # torch.topk rejects a k larger than its input; presumably each
    # distribution has embedding.num_tokens() entries, so clamp to that.
    k = min(k, embedding.num_tokens())

    results = []
    # Iterating a tensor yields its rows along dim 0; no list() copy needed.
    for distribution in prediction_distributions:
        certainties, stem_idxs = distribution.view(-1).topk(k)
        results.append([
            Prediction(embedding.decode_token(stem_idx.item()) + ".",
                       math.exp(certainty.item()))
            for certainty, stem_idx in zip(certainties, stem_idxs)
        ])
    return results, loss
예제 #2
0
 def _encode_tokenized_data(self, data : TokenizedDataset, arg_values : Namespace,
                            tokenizer : Tokenizer, embedding : Embedding) \
     -> PECDataset:
     """Encode every ``(prev_tactics, goal, tactic)`` triple as a ``PECSample``.

     The first field of each sample is the embedding index of the stem of the
     last entry in ``prev_tactics``. When ``prev_tactics`` has at most one
     entry, the token ``"Proof"`` is encoded instead. ``goal`` and ``tactic``
     are passed through unchanged. ``arg_values`` and ``tokenizer`` are
     accepted but not used by this method.
     """
     encoded_samples = []
     for prev_tactics, goal, tactic in data:
         # NOTE(review): a single entry presumably means no real tactic has
         # run yet (e.g. only the lemma statement) -- confirm with the scraper.
         if len(prev_tactics) > 1:
             previous_stem = get_stem(prev_tactics[-1])
         else:
             previous_stem = "Proof"
         encoded_samples.append(
             PECSample(embedding.encode_token(previous_stem), goal, tactic))
     return PECDataset(encoded_samples)
예제 #3
0
def get_stem_and_arg_idx(max_length: int, embedding: Embedding,
                         inter: ScrapedTactic) -> Tuple[int, int]:
    """Encode a tactic's stem and locate its first argument among the goal's symbols.

    Args:
        max_length: argument positions at or beyond this value are mapped to 0.
        embedding: used to encode the tactic stem.
        inter: scraped tactic whose ``tactic`` text and
            ``context.focused_goal`` are inspected.

    Returns:
        ``(stem_idx, arg_idx)``, where ``arg_idx`` is the argument's position
        in the goal's symbol list plus one, or ``0`` when that position is
        ``max_length`` or larger.

    Raises:
        IndexError: if the text after the stem contains no whitespace-separated
            token.
        AssertionError: if the argument is not among the goal's symbols. Under
            ``python -O`` the assert is stripped and ``symbols.index`` raises
            ``ValueError`` instead.
    """
    stem, remainder = serapi_instance.split_tactic(inter.tactic)
    # Encoded before anything else so an unknown stem fails first, as before.
    encoded_stem = embedding.encode_token(stem)
    goal = inter.context.focused_goal
    # NOTE(review): ``tokenizer`` is not a parameter; it is looked up in the
    # enclosing module scope.
    goal_symbols = tokenizer.get_symbols(goal)
    first_arg = remainder.split()[0].strip(".")
    assert first_arg in goal_symbols, \
        "tactic: {}, arg: {}, goal: {}, symbols: {}".format(
            inter.tactic, first_arg, goal, goal_symbols)
    position = goal_symbols.index(first_arg)
    # Index 0 is reserved for "out of range"; real positions are shifted by one.
    arg_idx = position + 1 if position < max_length else 0
    return encoded_stem, arg_idx
예제 #4
0
def predictKTacticsWithLoss(prediction_distribution : torch.FloatTensor,
                            embedding : Embedding,
                            k : int,
                            correct : str,
                            criterion : nn.Module) -> Tuple[List[Prediction], float]:
    """Return the top-``k`` stem predictions for one sample and its loss.

    Args:
        prediction_distribution: scores for a single sample; flattened with
            ``view(-1)`` before ``topk``. Scores are passed through
            ``math.exp`` to produce certainties, so they are presumably
            log-probabilities (TODO confirm against the model's output layer).
        embedding: maps tactic stems to token indices and back.
        k: number of predictions; clamped to ``embedding.num_tokens()``.
        correct: the expected tactic; only its stem (``get_stem``) is scored.
        criterion: loss module applied to a batch of one.

    Returns:
        ``(predictions, loss)``. ``predictions`` holds up to ``k``
        ``Prediction`` objects, highest score first, each stem suffixed with
        ``"."``. ``loss`` is the criterion value as a float, or ``0.0`` when
        the correct stem is not in the embedding.
    """
    # torch.topk rejects a k larger than its input; presumably the
    # distribution has embedding.num_tokens() entries, so clamp to that.
    k = min(k, embedding.num_tokens())

    correct_stem = get_stem(correct)
    if embedding.has_token(correct_stem):
        # The former Variable(...) wrapper is a no-op since PyTorch 0.4.
        target = maybe_cuda(
            torch.LongTensor([embedding.encode_token(correct_stem)]))
        # Add a batch dimension: scores become (1, N) against a (1,) target.
        loss = criterion(prediction_distribution.view(1, -1), target).item()
    else:
        # No index to score against; report zero loss as a float, matching
        # the declared return type.
        loss = 0.0

    certainties, stem_idxs = prediction_distribution.view(-1).topk(k)
    results = [Prediction(embedding.decode_token(stem_idx.item()) + ".",
                          math.exp(certainty.item()))
               for certainty, stem_idx in zip(certainties, stem_idxs)]
    return results, loss