Example #1
from sentencepiece import SentencePieceProcessor


def spm_initializer(spm_path: str):
    """Load a SentencePiece model into module-level globals."""
    global spm
    spm = SentencePieceProcessor()
    spm.Load(spm_path)
    global vocab
    # Map each token id to its piece string.
    vocab = {index: spm.IdToPiece(index) for index in range(spm.GetPieceSize())}
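The module-level globals suggest this function is meant as a multiprocessing worker initializer; here is a minimal usage sketch under that assumption (the path "tokenizer.model" and the encode_line worker are hypothetical):

from multiprocessing import Pool

def encode_line(line: str):
    # `spm` is the processor loaded into this worker by spm_initializer.
    return spm.EncodeAsIds(line)

if __name__ == "__main__":
    with Pool(processes=4,
              initializer=spm_initializer,
              initargs=("tokenizer.model",)) as pool:
        id_lists = pool.map(encode_line, ["hello world", "beam search"])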
Example #2
import torch
import torch.nn.functional as F
import sentencepiece as spm


def beam_search_decode_eos(model,
                           X,
                           X_lengths,
                           sp: spm.SentencePieceProcessor,
                           eos_id,
                           max_decode_len=20,
                           k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()

    with torch.no_grad():
        # Each hypothesis is a (tokens, lengths, scores, ended-flags) tuple.
        sequences = [(
            torch.full((B, max_decode_len), bos_id, dtype=torch.long,
                       device=X.device),  # Y_hat
            torch.ones(B, dtype=torch.long, device=X.device),  # Y_hat_lengths
            torch.zeros(B, device=X.device),  # scores
            torch.zeros(B, dtype=torch.long, device=X.device),  # ended
        )]
        # walk over each item in output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, Y_hat_lengths, scores, ended in sequences:
                # Score the next token given the prefix decoded so far.
                logits = model(X, Y_hat[:, :t + 1], X_lengths, Y_hat_lengths)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1)  # [B, V]
                for j in range(V):
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    # Grow the length only for hypotheses that have not ended.
                    candidate_Y_hat_lengths = Y_hat_lengths + (1 - ended)
                    if j == eos_id:
                        candidate_ended = torch.ones_like(ended)
                    else:
                        candidate_ended = ended.clone()
                    # Add probability only if the sequence has not already
                    # ended (generated </s>).
                    candidate_scores = scores + log_p_j * (1 - ended).float()
                    candidate = (candidate_Y_hat, candidate_Y_hat_lengths,
                                 candidate_scores, candidate_ended)
                    all_candidates.append(candidate)
            # stack candidates
            beam_Y, beam_Y_lengths, beam_scores, beam_ended = zip(
                *all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, k*V, T]
            beam_Y_lengths = torch.stack(beam_Y_lengths, dim=1)  # [B, k*V]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, k*V]
            beam_ended = torch.stack(beam_ended, dim=1)  # [B, k*V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores,
                                               k,
                                               dim=1,
                                               sorted=True)
            topk_Y = torch.gather(
                beam_Y, 1,
                topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            topk_Y_lengths = torch.gather(beam_Y_lengths, 1, topk_idx)
            topk_ended = torch.gather(beam_ended, 1, topk_idx)
            # set beam
            sequences = [(topk_Y[:, j, :], topk_Y_lengths[:, j],
                          topk_scores[:, j], topk_ended[:, j])
                         for j in range(k)]
            # exit early once every hypothesis in every beam contains </s>
            if all(ended.all() for _, _, _, ended in sequences):
                break

    # stack sequences
    beam_Y, _, beam_scores, _ = zip(*sequences)
    beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
    beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
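A call sketch for the decoder above, assuming `model`, `X`, and `X_lengths` are prepared elsewhere with the signature the code expects and `sp` is a loaded SentencePieceProcessor:

hyps, scores = beam_search_decode_eos(model, X, X_lengths, sp,
                                      eos_id=sp.PieceToId("</s>"),
                                      max_decode_len=20, k=3)
best_scores = scores[:, 0]  # topk used sorted=True, so beam 0 scores highest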
Example #3
import allennlp.nn.beam_search
import sentencepiece as spm
import torch


def beam_search_decode(
    model,
    X,
    sp: spm.SentencePieceProcessor,
    max_decode_len,
    k,
    per_node_k=None,
    constrain_decoding=False,
    sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
):
    if sampler == "top_p":
        sampler = allennlp.nn.beam_search.TopPSampler(
            p=top_p_threshold, temperature=top_p_temperature)
    elif sampler == "deterministic":
        sampler = None
    else:
        raise ValueError("Unsupported sampler")

    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    pad_id = sp.PieceToId("[PAD]")
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    V_full = sp.GetPieceSize()  # Size of vocab
    invalid_vocab_mask = torch.zeros(V_full, dtype=torch.bool, device=X.device)
    if constrain_decoding:
        alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_ "
        for token_id in range(V_full):
            piece = sp.IdToPiece(token_id)
            # Mask out pieces containing anything outside the allowed alphabet.
            if not (token_id in [pad_id, bos_id, eos_id]
                    or all(c in alphabet for c in piece)):
                invalid_vocab_mask[token_id] = True
    V = V_full
    model.eval()

    # Set up the AllenNLP beam search; X is encoded once into the start state below.
    allen_bs = allennlp.nn.beam_search.BeamSearch(
        end_index=eos_id,
        max_steps=max_decode_len,
        beam_size=k,
        per_node_beam_size=per_node_k,
        sampler=sampler,
    )

    start_predictions = torch.tensor([bos_id] * B,
                                     dtype=torch.long,
                                     device=X.device)
    start_state = {
        "prev_tokens": torch.zeros(B, 0, dtype=torch.long, device=X.device),
        "memory": model.encode(X).transpose(0, 1),  # [B, T, d_model]
    }

    def step(last_tokens, current_state, t):
        """
        Args:
            last_tokens: (group_size,)
            current_state: {}
            t: int
        """
        group_size = last_tokens.size(0)
        prev_tokens = torch.cat(
            [current_state["prev_tokens"],
             last_tokens.unsqueeze(1)], dim=-1)  # [B*k, t+1]

        all_log_probs = model.decode(current_state["memory"].transpose(0, 1),
                                     prev_tokens)
        next_log_probs = all_log_probs[:, -1, :]
        if constrain_decoding:
            next_log_probs = next_log_probs.masked_fill(
                invalid_vocab_mask, float("-inf"))
        next_log_probs = torch.nn.functional.log_softmax(next_log_probs,
                                                         dim=-1)
        assert next_log_probs.shape == (group_size, V)
        return (next_log_probs, {
            "prev_tokens": prev_tokens,
            "memory": current_state["memory"]
        })

    predictions, log_probs = allen_bs.search(
        start_predictions=start_predictions,
        start_state=start_state,
        step=step)

    model.train()
    prediction = ids_to_strs(predictions, sp)
    return prediction, log_probs
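Hypothetical invocations of the AllenNLP-backed decoder above, assuming `model` exposes the encode/decode methods the step function uses:

# Deterministic k-best decoding.
hyps, log_probs = beam_search_decode(model, X, sp, max_decode_len=30, k=5)
# Stochastic beam search with nucleus (top-p) sampling, restricted to
# alphanumeric pieces.
hyps, log_probs = beam_search_decode(model, X, sp, max_decode_len=30, k=5,
                                     constrain_decoding=True,
                                     sampler="top_p",
                                     top_p_threshold=0.9,
                                     top_p_temperature=1.0)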
Example #4
import torch
import torch.nn.functional as F
import sentencepiece as spm


def beam_search_decode(model,
                       X,
                       X_lengths,
                       sp: spm.SentencePieceProcessor,
                       max_decode_len=20,
                       k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()

    with torch.no_grad():
        Y_hat_lengths = torch.ones(B, dtype=torch.long)  # current prefix lengths
        # initial Y_hat and batchwise score tensors
        sequences = [(
            torch.full((B, max_decode_len), bos_id, dtype=torch.long,
                       device=X.device),  # Y_hat
            torch.zeros(B, device=X.device),  # scores
        )]
        # walk over each item in output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, scores in sequences:
                # Feed only the prefix decoded so far; its length is
                # Y_hat_lengths (= t + 1 at this step).
                logits = model(X,
                               Y_hat[:, :t + 1],
                               src_lengths=X_lengths,
                               tgt_lengths=Y_hat_lengths)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1)  # [B, V]
                for j in range(V):
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t-1}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    candidate = (candidate_Y_hat, scores + log_p_j)
                    all_candidates.append(candidate)
            # stack candidates
            beam_Y, beam_scores = zip(*all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, k*V, T]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, k*V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores,
                                               k,
                                               dim=1,
                                               sorted=True)
            topk_Y = torch.gather(
                beam_Y, 1,
                topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            # set beam
            sequences = [(topk_Y[:, j, :], topk_scores[:, j])
                         for j in range(k)]
            # TODO: exit early if all sentences in all beam sequences contain </s>
            Y_hat_lengths = Y_hat_lengths + 1

    # stack sequences
    beam_Y, beam_scores = zip(*sequences)
    beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
    beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
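The Python loop over all V extensions builds k*V candidate tensors per step; the standard vectorized alternative (not used above) adds the [B, k] beam scores to the [B, k, V] step log-probs and takes one topk over the flattened k*V axis. A sketch with hypothetical tensors:

# beam_scores: [B, k]; step_logprobs: [B, k, V] (assumed shapes).
cand = beam_scores.unsqueeze(-1) + step_logprobs  # [B, k, V]
topk_scores, flat_idx = torch.topk(cand.view(B, -1), k, dim=1)  # over k*V
beam_idx = torch.div(flat_idx, V, rounding_mode="floor")  # parent hypothesis
token_idx = flat_idx % V  # token that extends it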