示例#1
0
def predictKTacticsWithLoss_batch(prediction_distributions : torch.FloatTensor,
                                  embedding : Embedding,
                                  k : int,
                                  correct_stems : List[str],
                                  criterion : nn.Module) -> \
                                  Tuple[List[List[Prediction]], float]:
    """Return the top-k predictions for every batch row plus the batch loss.

    Args:
        prediction_distributions: Batched model output; iterating over it
            yields one distribution per example.  The scores are passed
            through ``math.exp`` before being reported, so they are
            presumably log-probabilities -- TODO confirm against the model.
        embedding: Maps tactic stems to class indices and back.
        k: Number of predictions wanted per example; clamped to
            ``embedding.num_tokens()``.
        correct_stems: Expected stem for each example, in batch order.
        criterion: Loss module called as ``criterion(predictions, targets)``.

    Returns:
        A pair ``(results, loss)``.  ``results`` holds one list of
        ``Prediction`` objects per example (decoded stem with a trailing
        ``"."`` and the exponentiated score); ``loss`` is the scalar value
        of ``criterion`` over the whole batch.
    """
    # Build the target class indices.  Stems the embedding does not know
    # fall back to index 0, so they still contribute to the loss.
    target_indices = []
    for stem in correct_stems:
        if embedding.has_token(stem):
            target_indices.append(embedding.encode_token(stem))
        else:
            target_indices.append(0)
    targets = maybe_cuda(Variable(torch.LongTensor(target_indices)))
    loss = criterion(prediction_distributions, targets).item()

    # topk cannot return more entries than there are classes.
    k = min(k, embedding.num_tokens())

    results = []
    for distribution in prediction_distributions:
        certainties, stem_idxs = distribution.view(-1).topk(k)
        row_predictions = []
        for certainty, stem_idx in zip(certainties, stem_idxs):
            tactic = embedding.decode_token(stem_idx.item()) + "."
            row_predictions.append(
                Prediction(tactic, math.exp(certainty.item())))
        results.append(row_predictions)
    return results, loss
def predictKTacticsWithLoss(prediction_distribution : torch.FloatTensor,
                            embedding : Embedding,
                            k : int,
                            correct : str,
                            criterion : nn.Module) -> Tuple[List[Prediction], float]:
    """Return the top-k predictions for one example plus its loss.

    Args:
        prediction_distribution: Model output for a single example; it is
            flattened with ``view(-1)`` before ranking.  The scores are
            passed through ``math.exp`` before being reported, so they are
            presumably log-probabilities -- TODO confirm against the model.
        embedding: Maps tactic stems to class indices and back.
        k: Number of predictions wanted; clamped to
            ``embedding.num_tokens()``.
        correct: The expected tactic; its stem (via ``get_stem``) is the
            loss target.
        criterion: Loss module called as ``criterion(predictions, target)``
            with the prediction reshaped to a batch of one.

    Returns:
        A pair ``(results, loss)``.  ``results`` is a list of ``Prediction``
        objects (decoded stem with a trailing ``"."`` and the exponentiated
        score).  ``loss`` is the scalar criterion value, or ``0.0`` when the
        embedding does not know the correct stem.
    """
    # topk cannot return more entries than there are classes.
    if k > embedding.num_tokens():
        k = embedding.num_tokens()
    correct_stem = get_stem(correct)
    if embedding.has_token(correct_stem):
        output_var = maybe_cuda(Variable(
            torch.LongTensor([embedding.encode_token(correct_stem)])))
        # The criterion expects a batch dimension, hence view(1, -1).
        loss = criterion(prediction_distribution.view(1, -1), output_var).item()
    else:
        # No valid target class exists, so no loss is computed.  Use a float
        # so the return type matches the annotation and the other branch.
        loss = 0.0

    certainties_and_idxs = prediction_distribution.view(-1).topk(k)
    results = [Prediction(embedding.decode_token(stem_idx.item()) + ".",
                          math.exp(certainty.item()))
               for certainty, stem_idx in zip(*certainties_and_idxs)]

    return results, loss