Code example #1
File: main.py Project: stevew00ds/contracode
def calculate_nll(
    model,
    test_loader,
    sp: spm.SentencePieceProcessor,
    use_cuda=True,
    logger_fn=None
):
    with Timer() as t:
        pad_id = sp.PieceToId("[PAD]")
        n_examples = 0
        test_nll = 0.
        pbar = tqdm.tqdm(test_loader, desc="test")
        for X, Y, X_lengths, Y_lengths in pbar:
            B, L = X.shape
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            pred_y = model(X, Y[:, :-1].to(X.device), X_lengths, Y_lengths)
            B, T, D = pred_y.shape  # T = decoded sequence length; avoid shadowing the input X
            loss = F.cross_entropy(pred_y.reshape(B * T, D), Y[:, 1:].reshape(B * T), ignore_index=pad_id, reduction='sum')
            
            n_examples += B
            test_nll += loss.item()
            if logger_fn is not None:
                logger_fn({'test_nll': loss.item() / B, 'test_nll_avg': test_nll / n_examples})
        return test_nll / n_examples
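Both variants of calculate_nll lean on F.cross_entropy ignoring padded targets while summing over tokens. A minimal, self-contained sketch of that behavior (PAD_ID stands in for sp.PieceToId("[PAD]")):

import torch
import torch.nn.functional as F

PAD_ID = 0  # stand-in for sp.PieceToId("[PAD]")
logits = torch.randn(6, 5)                       # flattened (B*T, V) predictions
targets = torch.tensor([1, 2, PAD_ID, 3, PAD_ID, PAD_ID])

# Padded positions contribute nothing to the summed loss
nll_sum = F.cross_entropy(logits, targets, ignore_index=PAD_ID, reduction="sum")
n_real = (targets != PAD_ID).sum().item()
print(nll_sum.item(), nll_sum.item() / n_real)   # total NLL and NLL per real token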
Code example #2
File: main.py Project: patilanup246/contracode
def calculate_nll(model,
                  test_loader,
                  sp: spm.SentencePieceProcessor,
                  use_cuda=True,
                  logger_fn=None):
    pad_id = sp.PieceToId("[PAD]")
    n_examples = 0
    test_nll = 0.0
    with tqdm.tqdm(test_loader, desc="Test (NLL)") as pbar:
        for X, Y, X_lengths, Y_lengths in pbar:
            B, L = X.shape
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            pred_y = model(X, Y[:, :-1].to(X.device), X_lengths, Y_lengths)
            B, T, D = pred_y.shape  # T = decoded sequence length; avoid shadowing the input X
            loss = F.cross_entropy(pred_y.reshape(B * T, D),
                                   Y[:, 1:].reshape(B * T),
                                   ignore_index=pad_id,
                                   reduction="sum")

            n_examples += B
            test_nll += loss.item()
            metric_dict = {
                "test_nll": loss.item() / B,
                "test_nll_avg": test_nll / n_examples
            }
            if logger_fn is not None:
                logger_fn(metric_dict)
            pbar.set_postfix(metric_dict)
    return test_nll / n_examples
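A hedged usage sketch: model and test_loader are placeholders for the project's seq2seq model and DataLoader, and "sp.model" is an assumed path.

import torch
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("sp.model")  # assumed path to a trained SentencePiece model
# test_loader must yield (X, Y, X_lengths, Y_lengths) batches as consumed above
avg_nll = calculate_nll(model, test_loader, sp, use_cuda=torch.cuda.is_available())
print(f"test NLL per example: {avg_nll:.4f}")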
Code example #3
def encode_comment(sp_model: sentencepiece.SentencePieceProcessor,
                   comment: str,
                   max_len=None) -> List[int]:
    """ Encode one comment with sentencepiece model.
    """
    # TODO we can do sub-word augmentation here
    start = sp_model.PieceToId('<s>')
    end = sp_model.PieceToId('</s>')
    eol = sp_model.PieceToId(EOL)
    encoded = [start]
    for i, line in enumerate(comment.split('\n')):
        if i:
            encoded.append(eol)
        encoded.extend(sp_model.EncodeAsIds(line))
    encoded.append(end)
    if max_len is not None:
        encoded = encoded[:max_len]
    return encoded
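The loop inserts an EOL piece between comment lines and brackets the result with <s>/</s>. A self-contained check with a stub processor (the stub and its fake ids are demo-only stand-ins, not the real SentencePiece API):

from typing import List

EOL = "<eol>"  # stand-in; the real EOL piece is defined by the host project

class _StubSP:
    """Implements only the two methods encode_comment uses."""
    _vocab = {"<s>": 1, "</s>": 2, EOL: 3}

    def PieceToId(self, piece: str) -> int:
        return self._vocab[piece]

    def EncodeAsIds(self, line: str) -> List[int]:
        return [100 + i for i, _ in enumerate(line.split())]  # fake ids

print(encode_comment(_StubSP(), "first line\nsecond line"))
# [1, 100, 101, 3, 100, 101, 2]: <s>, line 1 ids, <eol>, line 2 ids, </s>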
Code example #4
def greedy_decode(model,
                  X,
                  sp: spm.SentencePieceProcessor,
                  max_decode_len=20,
                  sample=True):
    start_token = sp.PieceToId("<s>")
    pad_token = sp.PieceToId("<pad>")
    B = X.size(0)
    model.eval()

    with torch.no_grad():
        decoded_batch = torch.zeros((B, 1), device=X.device).long()
        decoded_batch[:, 0] = start_token
        for t in range(max_decode_len):
            logits = model(X, decoded_batch)
            if sample:
                # Sample from the next-token distribution rather than taking the argmax
                topi = torch.distributions.Categorical(logits=logits[:, -1, :]).sample().view(-1, 1)
            else:
                _, topi = logits[:, -1, :].topk(1)
            decoded_batch = torch.cat((decoded_batch, topi.view(-1, 1)), -1)
    Y_hat = decoded_batch.cpu().numpy()
    Y_hat_str = ids_to_strs(Y_hat, sp)
    model.train()
    return Y_hat_str
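The loop re-runs the model on the growing prefix and appends one token per step. A toy stand-in (toy_model is hypothetical) makes the greedy-branch tensor bookkeeping visible:

import torch

def toy_model(X, decoded):
    # Uniform logits over a 5-token vocab, shaped [B, t, V] like the real model output
    B, t = decoded.shape
    return torch.zeros(B, t, 5)

decoded = torch.full((2, 1), 1, dtype=torch.long)  # start token id 1
for _ in range(3):
    logits = toy_model(None, decoded)
    _, topi = logits[:, -1, :].topk(1)
    decoded = torch.cat((decoded, topi.view(-1, 1)), -1)
print(decoded.shape)  # torch.Size([2, 4]): the prefix grows by one token per step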
Code example #5
def _evaluate(
    model, loader, sp: spm.SentencePieceProcessor, use_cuda=True, num_to_print=8, beam_search_k=5, max_decode_len=20, loss_type="nll_token"
):
    model.eval()
    pad_id = sp.PieceToId("[PAD]")

    with torch.no_grad():
        # with Timer() as t:
        #     # Decode a single batch by beam search for visualization
        #     X, Y, X_lengths, _ = next(iter(loader))
        #     X, Y = X[:num_to_print], Y[:num_to_print]
        #     if use_cuda:
        #         X = X.cuda()
        #         X_lengths = X_lengths.cuda()
        #     pred, scores = beam_search_decode(model, X, X_lengths, sp, k=beam_search_k, max_decode_len=max_decode_len)
        #     for i in range(X.size(0)):
        #         logger.info(f"Eval X:   \t\t\t{ids_to_strs(X[i], sp)}")
        #         logger.info(f"Eval GT Y:\t\t\t{ids_to_strs(Y[i], sp)}")
        #         for b in range(scores.size(1)):
        #             logger.info(f"Eval beam (score={scores[i, b]:.3f}):\t{pred[i][b]}")
        # logger.debug(f"Decode time for {num_to_print} samples took {t.interval:.3f}")

        with Timer() as t:
            # Compute average loss
            total_loss = 0
            num_examples = 0
            pbar = tqdm.tqdm(loader, desc="evaluate")
            for X, Y, X_lengths, Y_lengths in pbar:
                if use_cuda:
                    X, Y = X.cuda(), Y.cuda()
                    X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
                # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
                logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
                if loss_type == "nll_sequence":
                    loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction="sum")
                    loss = loss / X.size(0)  # Average over num sequences, not target sequence lengths
                    # Thus, minimize bits per sequence.
                elif loss_type == "nll_token":
                    loss = F.cross_entropy(
                        logits.transpose(1, 2),
                        Y[:, 1:],
                        ignore_index=pad_id,
                    )

                # TODO: Compute Precision/Recall/F1 and BLEU

                total_loss += loss.item() * X.size(0)
                num_examples += X.size(0)
                avg_loss = total_loss / num_examples
                pbar.set_description(f"evaluate average loss {avg_loss:.4f}")
        logger.debug(f"Loss calculation took {t.interval:.3f}s")
        return avg_loss
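The two loss_type branches differ only in normalization: nll_token averages over non-pad target tokens, while nll_sequence sums and divides by batch size (bits per sequence). A minimal sketch of the distinction:

import torch
import torch.nn.functional as F

pad_id = 0
logits = torch.randn(4, 7, 10)          # [B, vocab, T], as after transpose(1, 2)
targets = torch.randint(1, 7, (4, 10))  # [B, T]

nll_token = F.cross_entropy(logits, targets, ignore_index=pad_id)  # mean over tokens
nll_sequence = F.cross_entropy(logits, targets, ignore_index=pad_id,
                               reduction="sum") / targets.size(0)  # sum / num sequences
print(nll_token.item(), nll_sequence.item())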
Code example #6
File: sentence_piece.py Project: geyingli/unif
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
Code example #7
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        if not os.path.exists(spm_file):
            raise ValueError(
                "Can't find spm_file \"%s\". "
                "Please pass the correct path of sentence-piece model file, "
                "e.g.`spiece.model`." % spm_file
            )
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
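A hedged usage sketch; "spiece.model" is an assumed path, and preprocess_text / encode_pieces are helpers from the host project, not shown here:

tokenizer = SentencePieceTokenizer("spiece.model", do_lower_case=True)
pieces = tokenizer.tokenize("Hello world")
ids = tokenizer.convert_tokens_to_ids(pieces)
assert tokenizer.convert_ids_to_tokens(ids) == pieces  # round trip for in-vocab pieces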
Code example #8
File: decode.py Project: ncoop57/contracode
def greedy_decode(model,
                  X,
                  sp: spm.SentencePieceProcessor,
                  max_decode_len=20,
                  sample=True):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    model.eval()
    with torch.no_grad():
        Y_hat = torch.zeros(B, max_decode_len, device=X.device).long()
        Y_hat.fill_(sp.PieceToId("<s>"))
        for t in range(max_decode_len - 1):
            logits = model(X, Y_hat)
            if sample:
                idx_t = torch.distributions.categorical.Categorical(
                    logits=logits[:, t, :]).sample()
            else:
                idx_t = logits[:, t, :].argmax(dim=-1)
            Y_hat[:, t + 1] = idx_t
    Y_hat = Y_hat.cpu().numpy()
    Y_hat_str = ids_to_strs(Y_hat, sp)
    model.train()
    return Y_hat_str
Code example #9
File: utils.py Project: whs1111/texar-pytorch
def encode_ids(sp_model: spm.SentencePieceProcessor,
               text: str,
               sample: bool = False) -> List[int]:
    pieces = encode_pieces(sp_model, text, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids
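A hedged usage sketch; the model path is an assumption and encode_pieces is the host project's helper:

import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")  # assumed path to a trained model
ids = encode_ids(sp_model, "beam search decoding")
pieces = [sp_model.IdToPiece(i) for i in ids]
print(list(zip(pieces, ids)))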
Code example #10
def beam_search_decode(
    model,
    X,
    sp: spm.SentencePieceProcessor,
    max_decode_len,
    k,
    per_node_k=None,
    constrain_decoding=False,
    sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
):
    if sampler == "top_p":
        sampler = allennlp.nn.beam_search.TopPSampler(
            p=top_p_threshold, temperature=top_p_temperature)
    elif sampler == "deterministic":
        sampler = None
    else:
        raise ValueError("Unsupported sampler")

    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    pad_id = sp.PieceToId("[PAD]")
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    V_full = sp.GetPieceSize()  # Size of vocab
    invalid_vocab_mask = torch.zeros(V_full, dtype=torch.bool, device=X.device)
    if constrain_decoding:
        alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_ "
        for id in range(V_full):
            piece = sp.IdToPiece(id)
            if not (id in [pad_id, bos_id, eos_id] or all(c in alphabet
                                                          for c in piece)):
                invalid_vocab_mask[id] = True
    V = V_full
    model.eval()

    # Encode X
    allen_bs = allennlp.nn.beam_search.BeamSearch(
        end_index=eos_id,
        max_steps=max_decode_len,
        beam_size=k,
        per_node_beam_size=per_node_k,
        sampler=sampler,
    )

    start_predictions = torch.tensor([bos_id] * B,
                                     dtype=torch.long,
                                     device=X.device)
    start_state = {
        "prev_tokens": torch.zeros(B, 0, dtype=torch.long, device=X.device),
        "memory": model.encode(X).transpose(0, 1),  # [B, T, d_model]
    }

    def step(last_tokens, current_state, t):
        """
        Args:
            last_tokens: (group_size,)
            current_state: {}
            t: int
        """
        group_size = last_tokens.size(0)
        prev_tokens = torch.cat(
            [current_state["prev_tokens"],
             last_tokens.unsqueeze(1)], dim=-1)  # [B*k, t+1]

        all_log_probs = model.decode(current_state["memory"].transpose(0, 1),
                                     prev_tokens)
        next_log_probs = all_log_probs[:, -1, :]
        if constrain_decoding:
            next_log_probs = next_log_probs.masked_fill(
                invalid_vocab_mask, float("-inf"))
        next_log_probs = torch.nn.functional.log_softmax(next_log_probs,
                                                         dim=-1)
        assert next_log_probs.shape == (group_size, V)
        return (next_log_probs, {
            "prev_tokens": prev_tokens,
            "memory": current_state["memory"]
        })

    predictions, log_probs = allen_bs.search(
        start_predictions=start_predictions,
        start_state=start_state,
        step=step)

    model.train()
    prediction = ids_to_strs(predictions, sp)
    return prediction, log_probs
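A hedged call sketch; model must expose the encode()/decode() methods used inside step(), and X is a padded batch of source token ids:

preds, log_probs = beam_search_decode(model, X, sp, max_decode_len=20, k=5,
                                      sampler="top_p", top_p_threshold=0.9)
best_per_example = [beams[0] for beams in preds]  # beams come back sorted by score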
Code example #11
File: decode.py Project: ncoop57/contracode
def beam_search_decode_eos(model,
                           X,
                           X_lengths,
                           sp: spm.SentencePieceProcessor,
                           eos_id,
                           max_decode_len=20,
                           k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()

    with torch.no_grad():
        # initial Y_hat and batchwise score tensors
        sequences = [(
            torch.zeros(B, max_decode_len, dtype=torch.long,
                        device=X.device).fill_(bos_id),  # Y_hat
            torch.ones(B, dtype=torch.long),  # Y_hat_lengths
            torch.zeros(B, device=X.device),  # scores
            torch.zeros(B, dtype=torch.long, device=X.device),  # ended
        )]
        # walk over each item in output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, Y_hat_lengths, scores, ended in sequences:
                Y_hat = Y_hat.to(X.device)
                scores = scores.to(X.device)
                logits = model(X, Y_hat[:, :-1].to(X.device), X_lengths,
                               Y_hat_lengths)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1).to(
                    scores.device)  # [B, V] tensor
                for j in range(V):
                    # TODO: Only add probability if the sequence has not ended (generated </s>)
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t-1}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    candidate_Y_hat_lengths = Y_hat_lengths + 1  # one more decoded token
                    # candidate_ended = ended or j == eos_id
                    if j == eos_id:
                        candidate_ended = torch.ones_like(ended)
                    else:
                        candidate_ended = ended.clone()
                    candidate = (candidate_Y_hat, candidate_Y_hat_lengths,
                                 scores + log_p_j, candidate_ended)
                    all_candidates.append(candidate)
            # stack candidates (each candidate is a 4-tuple, matching the loop above)
            beam_Y, beam_Y_lengths, beam_scores, beam_ended = zip(*all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, V, T]
            beam_Y_lengths = torch.stack(beam_Y_lengths, dim=1)  # [B, V]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, V]
            beam_ended = torch.stack(beam_ended, dim=1)  # [B, V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores,
                                               k,
                                               dim=1,
                                               sorted=True)
            topk_Y = torch.gather(
                beam_Y, 1,
                topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            # lengths live on CPU, so gather them with a CPU copy of the indices
            topk_Y_lengths = torch.gather(beam_Y_lengths, 1, topk_idx.cpu())
            topk_ended = torch.gather(beam_ended, 1, topk_idx)
            # set beam
            sequences = [(topk_Y[:, j, :], topk_Y_lengths[:, j],
                          topk_scores[:, j], topk_ended[:, j])
                         for j in range(k)]
            # TODO: exit early if all sentences in all beam sequences contain </s>

    # stack sequences
    beam_Y, _, beam_scores, _ = zip(*sequences)
    beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
    beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
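Note that the inner loop above materializes all V candidate extensions per beam each step, cloning O(k*V) tensors; standard beam search prunes with a per-beam top-k over the next-token distribution first. A sketch of that pruning (pure torch, shapes only):

import torch

B, V, k = 4, 1000, 3
logprobs_t = torch.randn(B, V).log_softmax(dim=-1)  # [B, V] next-token log-probs
topk_logp, topk_tok = logprobs_t.topk(k, dim=-1)    # [B, k] each
# Extend each hypothesis with only these k tokens instead of all V candidates.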
Code example #12
def get_javascript_collate(
    augmentations: List[dict],
    sp: spm.SentencePieceProcessor,
    program_mode: str,
    subword_regularization_alpha: float,
    max_length: int,
    max_target_length: int = 256,
):
    assert program_mode in ["contrastive", "augmentation", "identity"]
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    pad_id = sp.PieceToId("[PAD]")

    def javascript_collate(examples: List[dict]):
        """Augments and batches a list of function dicts.

        Arguments:
            examples (List[dict[str, Any]]). The dicts must have key "function".
            augmentations (List[dict]). Augmentations to apply to the functions.
                example: [{"fn": "extract_methods"}]
            sp (SentencePieceProcessor): For tokenizing batch elements after augmentations
        """
        B = len(examples)
        if program_mode in ["contrastive", "augmentation"]:
            # Set up transformation input
            transform_payload = []
            for example in examples:
                transform_payload.append(dict(src=example["function"], augmentations=augmentations))
            if program_mode == "contrastive":
                # Augment each input function twice
                transform_payload = transform_payload + transform_payload
            X = _augment_server(transform_payload)
        else:
            X = [prog["function"] for prog in examples]

        # Normalize programs
        X = [normalize_program(prog) for prog in X]

        # Encode as ids with sentencepiece
        if subword_regularization_alpha:
            # using subword regularization: https://arxiv.org/pdf/1804.10959.pdf
            # The second argument is nbest_size; -1 samples from the full segmentation lattice
            X = [sp.SampleEncodeAsIds(prog, -1, subword_regularization_alpha) for prog in X]
        else:
            # using the best decoding
            X = [sp.EncodeAsIds(prog) for prog in X]

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = [torch.tensor([bos_id] + ids[: (max_length - 2)] + [eos_id]) for ids in X]
        X_lengths = torch.tensor([len(x) for x in X], dtype=torch.long)
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        # Create padded tensor for labels (good for seq2seq tasks)
        if "label" in examples[0]:
            label = [sp.EncodeAsIds(ex["label"]) for ex in examples]
            label = [torch.tensor([bos_id] + ids[: (max_target_length - 2)] + [eos_id]) for ids in label]
            label_lengths = torch.tensor([len(l) for l in label], dtype=torch.long)
            label = pad_sequence(label, batch_first=True, padding_value=pad_id)
        else:
            label = None
            label_lengths = None

        if program_mode == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            X_lengths = torch.reshape(X_lengths, (2, B))
            assert label is None, "label should be None when using contrastive program dataloader"
        return (X, label, X_lengths, label_lengths)

    return javascript_collate
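A hedged wiring sketch; dataset and sp are placeholders for the project's JavaScript function dataset and a loaded SentencePieceProcessor:

from torch.utils.data import DataLoader

collate_fn = get_javascript_collate(
    augmentations=[{"fn": "extract_methods"}],
    sp=sp,
    program_mode="identity",
    subword_regularization_alpha=0.0,
    max_length=1024,
)
loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
X, label, X_lengths, label_lengths = next(iter(loader))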
Code example #13
def beam_search_decode(model,
                       X,
                       X_lengths,
                       sp: spm.SentencePieceProcessor,
                       max_decode_len=20,
                       k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()

    with torch.no_grad():
        Y_hat_lengths = torch.ones(B, dtype=torch.long)  # Y_hat_lengths
        # initial Y_hat and batchwise score tensors
        sequences = [(
            torch.zeros(B, max_decode_len).long().to(X.device) + bos_id,
            # torch.ones(B, dtype=torch.long),  # Y_hat_lengths
            torch.zeros(B).to(X.device))]
        # walk over each item in output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, scores in sequences:
                Y_hat = Y_hat.to(X.device)
                scores = scores.to(X.device)
                logits = model(X,
                               Y_hat[:, :t + 1].to(X.device),
                               src_lengths=X_lengths,
                               tgt_lengths=Y_hat_lengths + 1)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1).to(
                    scores.device)  # [B, V] tensor
                for j in range(V):
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t-1}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    candidate = (candidate_Y_hat, scores + log_p_j)
                    all_candidates.append(candidate)
            # stack candidates
            beam_Y, beam_scores = zip(*all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, V, T]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores,
                                               k,
                                               dim=1,
                                               sorted=True)
            topk_Y = torch.gather(
                beam_Y, 1,
                topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            # set beam
            sequences = [(topk_Y[:, j, :], topk_scores[:, j])
                         for j in range(k)]
            # TODO: exit early if all sentences in all beam sequences contain </s>
            Y_hat_lengths = Y_hat_lengths + 1

    # stack sequences
    beam_Y, beam_scores = zip(*sequences)
    beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
    beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
Code example #14
def convert_subword(transcript: str, sp: spm.SentencePieceProcessor):
    pieces = sp.EncodeAsPieces(transcript)
    text = " ".join(pieces)
    # Map each piece (not each character of the joined string) to its vocabulary id
    label = " ".join(str(sp.PieceToId(piece)) for piece in pieces)
    return text, label
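A hedged usage sketch; the model path is an assumption and the printed ids depend entirely on the trained vocabulary:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("unigram.model")  # assumed path to a trained model
text, label = convert_subword("hello world", sp)
print(text)   # e.g. "▁hello ▁world"  (pieces joined by spaces)
print(label)  # e.g. "151 287"        (matching vocabulary ids)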