def calculate_nll(model, test_loader, sp: spm.SentencePieceProcessor, use_cuda=True, logger_fn=None):
    with Timer() as t:
        pad_id = sp.PieceToId("[PAD]")
        n_examples = 0
        test_nll = 0.0
        pbar = tqdm.tqdm(test_loader, desc="test")
        for X, Y, X_lengths, Y_lengths in pbar:
            B, L = X.shape
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            pred_y = model(X, Y[:, :-1].to(X.device), X_lengths, Y_lengths)
            B, T, D = pred_y.shape  # renamed from X to avoid shadowing the input tensor
            loss = F.cross_entropy(
                pred_y.reshape(B * T, D),
                Y[:, 1:].reshape(B * T),
                ignore_index=pad_id,
                reduction="sum",
            )
            n_examples += B
            test_nll += loss.item()
            if logger_fn is not None:
                logger_fn({"test_nll": loss.item() / B, "test_nll_avg": test_nll / n_examples})
    return test_nll / n_examples
def calculate_nll(model, test_loader, sp: spm.SentencePieceProcessor, use_cuda=True, logger_fn=None):
    pad_id = sp.PieceToId("[PAD]")
    n_examples = 0
    test_nll = 0.0
    with tqdm.tqdm(test_loader, desc="Test (NLL)") as pbar:
        for X, Y, X_lengths, Y_lengths in pbar:
            B, L = X.shape
            if use_cuda:
                X, Y = X.cuda(), Y.cuda()  # B, L
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            pred_y = model(X, Y[:, :-1].to(X.device), X_lengths, Y_lengths)
            B, T, D = pred_y.shape  # renamed from X to avoid shadowing the input tensor
            loss = F.cross_entropy(
                pred_y.reshape(B * T, D),
                Y[:, 1:].reshape(B * T),
                ignore_index=pad_id,
                reduction="sum",
            )
            n_examples += B
            test_nll += loss.item()
            metric_dict = {
                "test_nll": loss.item() / B,
                "test_nll_avg": test_nll / n_examples,
            }
            if logger_fn is not None:
                logger_fn(metric_dict)
            pbar.set_postfix(metric_dict)
    return test_nll / n_examples
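
# A minimal, self-contained sketch (toy tensors, not the real model) of the
# teacher-forcing shift used by calculate_nll above: the decoder consumes
# Y[:, :-1] and is scored against Y[:, 1:], with padding positions excluded
# via ignore_index. The pad_id value here is an arbitrary assumption.
import torch
import torch.nn.functional as F

pad_id = 0
B, L, V = 2, 5, 11
Y = torch.tensor([[1, 4, 7, 2, pad_id],
                  [1, 5, 2, pad_id, pad_id]])  # <s>=1, </s>=2, then padding
logits = torch.randn(B, L - 1, V)  # stand-in for model(X, Y[:, :-1], ...)
# reduction="sum" totals the log-loss over non-pad target tokens
loss = F.cross_entropy(logits.reshape(-1, V), Y[:, 1:].reshape(-1),
                       ignore_index=pad_id, reduction="sum")
print(loss.item())  # summed NLL; divide by batch size for per-sequence NLL
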
def encode_comment(sp_model: sentencepiece.SentencePieceProcessor, comment: str, max_len=None) -> List[int]:
    """Encode one comment with sentencepiece model."""
    # TODO: we can do sub-word augmentation here
    start = sp_model.PieceToId('<s>')
    end = sp_model.PieceToId('</s>')
    eol = sp_model.PieceToId(EOL)
    encoded = [start]
    for i, line in enumerate(comment.split('\n')):
        if i:
            encoded.append(eol)
        encoded.extend(sp_model.EncodeAsIds(line))
    encoded.append(end)
    if max_len is not None:
        encoded = encoded[:max_len]
    return encoded
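
# Hypothetical usage sketch for encode_comment. The model path "sp.model" is
# an assumption (not from the original code); it must be a trained
# SentencePiece model whose vocab contains <s>, </s>, and the EOL piece.
import sentencepiece
sp_model = sentencepiece.SentencePieceProcessor()
sp_model.Load("sp.model")
ids = encode_comment(sp_model, "first line\nsecond line", max_len=64)
# ids == [<s>] + pieces("first line") + [EOL] + pieces("second line") + [</s>]
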
def greedy_decode(model, X, sp: spm.SentencePieceProcessor, max_decode_len=20, sample=True):
    # NOTE: `sample` and `pad_token` are unused in this variant; decoding is always greedy.
    start_token = sp.PieceToId("<s>")
    pad_token = sp.PieceToId("<pad>")
    B = X.size(0)
    model.eval()
    with torch.no_grad():
        decoded_batch = torch.zeros((B, 1), device=X.device).long()
        decoded_batch[:, 0] = start_token
        for t in range(max_decode_len):
            logits = model(X, decoded_batch)
            _, topi = logits[:, -1, :].topk(1)
            decoded_batch = torch.cat((decoded_batch, topi.view(-1, 1)), -1)
    Y_hat = decoded_batch.cpu().numpy()
    Y_hat_str = ids_to_strs(Y_hat, sp)
    model.train()
    return Y_hat_str
def _evaluate(
    model, loader, sp: spm.SentencePieceProcessor, use_cuda=True,
    num_to_print=8, beam_search_k=5, max_decode_len=20, loss_type="nll_token"
):
    model.eval()
    pad_id = sp.PieceToId("[PAD]")
    with torch.no_grad():
        # with Timer() as t:
        #     # Decode a single batch by beam search for visualization
        #     X, Y, X_lengths, _ = next(iter(loader))
        #     X, Y = X[:num_to_print], Y[:num_to_print]
        #     if use_cuda:
        #         X = X.cuda()
        #         X_lengths = X.cuda()
        #     pred, scores = beam_search_decode(model, X, X_lengths, sp, k=beam_search_k, max_decode_len=max_decode_len)
        #     for i in range(X.size(0)):
        #         logger.info(f"Eval X: \t\t\t{ids_to_strs(X[i], sp)}")
        #         logger.info(f"Eval GT Y:\t\t\t{ids_to_strs(Y[i], sp)}")
        #         for b in range(scores.size(1)):
        #             logger.info(f"Eval beam (score={scores[i, b]:.3f}):\t{pred[i][b]}")
        # logger.debug(f"Decode time for {num_to_print} samples took {t.interval:.3f}")

        with Timer() as t:
            # Compute average loss
            total_loss = 0
            num_examples = 0
            pbar = tqdm.tqdm(loader, desc="evaluate")
            for X, Y, X_lengths, Y_lengths in pbar:
                if use_cuda:
                    X, Y = X.cuda(), Y.cuda()
                    X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
                # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
                logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
                if loss_type == "nll_sequence":
                    loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction="sum")
                    # Average over num sequences, not target sequence lengths.
                    # Thus, minimize bits per sequence.
                    loss = loss / X.size(0)
                elif loss_type == "nll_token":
                    loss = F.cross_entropy(
                        logits.transpose(1, 2),
                        Y[:, 1:],
                        ignore_index=pad_id,
                    )
                # TODO: Compute Precision/Recall/F1 and BLEU
                total_loss += loss.item() * X.size(0)
                num_examples += X.size(0)
                avg_loss = total_loss / num_examples
                pbar.set_description(f"evaluate average loss {avg_loss:.4f}")
        logger.debug(f"Loss calculation took {t.interval:.3f}s")
    return avg_loss
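
# Toy illustration (random tensors, no real model) of the two loss
# normalizations in _evaluate: "nll_token" averages over tokens, while
# "nll_sequence" sums the loss and divides by batch size (NLL per sequence).
import torch
import torch.nn.functional as F

pad_id, B, T, V = 0, 4, 7, 13
logits = torch.randn(B, V, T)          # cross_entropy expects [B, V, T] here
targets = torch.randint(1, V, (B, T))  # no pads, for simplicity
nll_token = F.cross_entropy(logits, targets, ignore_index=pad_id)
nll_sequence = F.cross_entropy(logits, targets, ignore_index=pad_id,
                               reduction="sum") / B
assert torch.isclose(nll_sequence, nll_token * T)  # equal when nothing is padded
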
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
class SentencePieceTokenizer:
    def __init__(self, spm_file, do_lower_case=True):
        if not os.path.exists(spm_file):
            raise ValueError(
                "Can't find spm_file \"%s\". "
                "Please pass the correct path of the sentence-piece model file, "
                "e.g. `spiece.model`." % spm_file
            )
        self.processor = SentencePieceProcessor()
        self.processor.Load(spm_file)
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        text = preprocess_text(text, lower=self.do_lower_case)
        pieces = encode_pieces(self.processor, text, sample=False)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        return [self.processor.PieceToId(piece) for piece in tokens]

    def convert_ids_to_tokens(self, ids):
        pieces = [self.processor.IdToPiece(_id) for _id in ids]
        return pieces
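
# Hypothetical round-trip through SentencePieceTokenizer, assuming a trained
# model file "spiece.model" exists and the preprocess_text / encode_pieces
# helpers that the class depends on are importable.
tokenizer = SentencePieceTokenizer("spiece.model", do_lower_case=True)
pieces = tokenizer.tokenize("Hello world")     # e.g. ['▁hello', '▁world']
ids = tokenizer.convert_tokens_to_ids(pieces)  # piece -> vocab id
assert tokenizer.convert_ids_to_tokens(ids) == pieces
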
def greedy_decode(model, X, sp: spm.SentencePieceProcessor, max_decode_len=20, sample=True):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    model.eval()
    with torch.no_grad():
        Y_hat = torch.zeros(B, max_decode_len, device=X.device).long()
        Y_hat.fill_(sp.PieceToId("<s>"))
        for t in range(max_decode_len - 1):
            logits = model(X, Y_hat)
            if sample:
                idx_t = torch.distributions.categorical.Categorical(logits=logits[:, t, :]).sample()
            else:
                idx_t = logits[:, t, :].argmax(dim=-1)
            Y_hat[:, t + 1] = idx_t
    Y_hat = Y_hat.cpu().numpy()
    Y_hat_str = ids_to_strs(Y_hat, sp)
    model.train()
    return Y_hat_str
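
# Self-contained toy contrasting the two branches above: Categorical sampling
# draws a token from the softmax distribution, argmax always takes the mode.
import torch

logits_t = torch.tensor([[2.0, 1.0, 0.1],
                         [0.1, 0.1, 3.0]])  # [B=2, V=3]
sampled = torch.distributions.categorical.Categorical(logits=logits_t).sample()
greedy = logits_t.argmax(dim=-1)  # tensor([0, 2])
print(sampled, greedy)  # sampled varies across runs; greedy is deterministic
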
def encode_ids(sp_model: spm.SentencePieceProcessor, text: str, sample: bool = False) -> List[int]:
    pieces = encode_pieces(sp_model, text, sample=sample)
    ids = [sp_model.PieceToId(piece) for piece in pieces]
    return ids
def beam_search_decode(
    model,
    X,
    sp: spm.SentencePieceProcessor,
    max_decode_len,
    k,
    per_node_k=None,
    constrain_decoding=False,
    sampler="deterministic",
    top_p_threshold=0.9,
    top_p_temperature=1.0,
):
    if sampler == "top_p":
        sampler = allennlp.nn.beam_search.TopPSampler(p=top_p_threshold, temperature=top_p_temperature)
    elif sampler == "deterministic":
        sampler = None
    else:
        raise ValueError("Unsupported sampler")

    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    pad_id = sp.PieceToId("[PAD]")
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    V_full = sp.GetPieceSize()  # Size of vocab
    invalid_vocab_mask = torch.zeros(V_full, dtype=torch.bool, device=X.device)
    if constrain_decoding:
        alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_ "
        for id in range(V_full):
            piece = sp.IdToPiece(id)
            if not (id in [pad_id, bos_id, eos_id] or all(c in alphabet for c in piece)):
                invalid_vocab_mask[id] = True
    V = V_full
    model.eval()

    # Encode X
    allen_bs = allennlp.nn.beam_search.BeamSearch(
        end_index=eos_id,
        max_steps=max_decode_len,
        beam_size=k,
        per_node_beam_size=per_node_k,
        sampler=sampler,
    )
    start_predictions = torch.tensor([bos_id] * B, dtype=torch.long, device=X.device)
    start_state = {
        "prev_tokens": torch.zeros(B, 0, dtype=torch.long, device=X.device),
        "memory": model.encode(X).transpose(0, 1),  # [B, T, d_model]
    }

    def step(last_tokens, current_state, t):
        """
        Args:
            last_tokens: (group_size,)
            current_state: {}
            t: int
        """
        group_size = last_tokens.size(0)
        prev_tokens = torch.cat([current_state["prev_tokens"], last_tokens.unsqueeze(1)], dim=-1)  # [B*k, t+1]
        all_log_probs = model.decode(current_state["memory"].transpose(0, 1), prev_tokens)
        next_log_probs = all_log_probs[:, -1, :]
        if constrain_decoding:
            next_log_probs = next_log_probs.masked_fill(invalid_vocab_mask, float("-inf"))
        next_log_probs = torch.nn.functional.log_softmax(next_log_probs, dim=-1)
        assert next_log_probs.shape == (group_size, V)
        return (next_log_probs, {"prev_tokens": prev_tokens, "memory": current_state["memory"]})

    predictions, log_probs = allen_bs.search(start_predictions=start_predictions, start_state=start_state, step=step)
    model.train()
    prediction = ids_to_strs(predictions, sp)
    return prediction, log_probs
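
# Minimal sketch of the allennlp BeamSearch step contract used above, with a
# toy step function instead of a real encoder/decoder. This assumes allennlp
# is installed at a version that accepts a three-argument step function; the
# fixed transition table is invented purely for illustration.
import torch
import allennlp.nn.beam_search

V, B, eos = 4, 2, 3
log_probs_table = torch.log_softmax(torch.randn(V, V), dim=-1)  # fixed toy "LM"

def toy_step(last_tokens, state, t):
    # Next-token distribution depends only on the last token; state passes through.
    return log_probs_table[last_tokens], state

bs = allennlp.nn.beam_search.BeamSearch(end_index=eos, max_steps=5, beam_size=3)
start = torch.zeros(B, dtype=torch.long)  # all beams start from token 0
predictions, log_probs = bs.search(start, {}, toy_step)
print(predictions.shape)  # [B, beam_size, <=max_steps]
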
def beam_search_decode_eos(model, X, X_lengths, sp: spm.SentencePieceProcessor, eos_id, max_decode_len=20, k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()
    with torch.no_grad():
        # initial Y_hat, length, score, and ended-flag tensors
        sequences = [(
            torch.zeros(B, max_decode_len, dtype=torch.long, device=X.device).fill_(bos_id),  # Y_hat
            torch.ones(B, dtype=torch.long, device=X.device),  # Y_hat_lengths
            torch.zeros(B, device=X.device),  # scores
            torch.zeros(B, dtype=torch.long, device=X.device),  # ended
        )]
        # walk over each item in the output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, Y_hat_lengths, scores, ended in sequences:
                logits = model(X, Y_hat[:, :-1], X_lengths, Y_hat_lengths)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1)  # [B, V] tensor
                for j in range(V):
                    # TODO: Only add probability if the sequence has not ended (generated </s>)
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t-1}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    # Assumed intent: only grow the lengths of sequences that have not ended
                    candidate_Y_hat_lengths = Y_hat_lengths + (1 - ended)
                    if j == eos_id:
                        candidate_ended = torch.ones_like(ended)
                    else:
                        candidate_ended = ended.clone()
                    all_candidates.append((candidate_Y_hat, candidate_Y_hat_lengths,
                                           scores + log_p_j, candidate_ended))
            # stack candidates
            beam_Y, beam_Y_lengths, beam_scores, beam_ended = zip(*all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, V, T]
            beam_Y_lengths = torch.stack(beam_Y_lengths, dim=1)  # [B, V]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, V]
            beam_ended = torch.stack(beam_ended, dim=1)  # [B, V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores, k, dim=1, sorted=True)
            topk_Y = torch.gather(beam_Y, 1, topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            topk_Y_lengths = torch.gather(beam_Y_lengths, 1, topk_idx)
            topk_ended = torch.gather(beam_ended, 1, topk_idx)
            # set beam
            sequences = [(topk_Y[:, j, :], topk_Y_lengths[:, j], topk_scores[:, j], topk_ended[:, j])
                         for j in range(k)]
            # TODO: exit early if all sentences in all beam sequences contain </s>
        # stack sequences
        beam_Y, beam_Y_lengths, beam_scores, beam_ended = zip(*sequences)
        beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
        beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
def get_javascript_collate(
    augmentations: List[dict],
    sp: spm.SentencePieceProcessor,
    program_mode: str,
    subword_regularization_alpha: float,
    max_length: int,
    max_target_length: int = 256,
):
    assert program_mode in ["contrastive", "augmentation", "identity"]
    bos_id = sp.PieceToId("<s>")
    eos_id = sp.PieceToId("</s>")
    pad_id = sp.PieceToId("[PAD]")

    def javascript_collate(examples: List[dict]):
        """Augments and batches a list of function dicts.

        Arguments:
            examples (List[dict[str, Any]]): The dicts must have key "function".
            augmentations (List[dict]): Augmentations to apply to the functions,
                e.g. [{"fn": "extract_methods"}].
            sp (SentencePieceProcessor): For tokenizing batch elements after augmentations.
        """
        B = len(examples)
        if program_mode in ["contrastive", "augmentation"]:
            # Set up transformation input
            transform_payload = []
            for example in examples:
                transform_payload.append(dict(src=example["function"], augmentations=augmentations))
            if program_mode == "contrastive":
                # Augment each input function twice
                transform_payload = transform_payload + transform_payload
            X = _augment_server(transform_payload)
        else:
            X = [prog["function"] for prog in examples]

        # Normalize programs
        X = [normalize_program(prog) for prog in X]

        # Encode as ids with sentencepiece
        if subword_regularization_alpha:
            # using subword regularization: https://arxiv.org/pdf/1804.10959.pdf
            # NOTE: the second argument is nbest_size; -1 samples from all
            # tokenization hypotheses rather than a fixed n-best list
            X = [sp.SampleEncodeAsIds(prog, -1, subword_regularization_alpha) for prog in X]
        else:
            # using the best decoding
            X = [sp.EncodeAsIds(prog) for prog in X]

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = [torch.tensor([bos_id] + ids[: (max_length - 2)] + [eos_id]) for ids in X]
        X_lengths = torch.tensor([len(x) for x in X], dtype=torch.long)
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        # Create padded tensor for labels (good for seq2seq tasks)
        if "label" in examples[0]:
            label = [sp.EncodeAsIds(ex["label"]) for ex in examples]
            label = [torch.tensor([bos_id] + ids[: (max_target_length - 2)] + [eos_id]) for ids in label]
            label_lengths = torch.tensor([len(l) for l in label], dtype=torch.long)
            label = pad_sequence(label, batch_first=True, padding_value=pad_id)
        else:
            label = None
            label_lengths = None

        if program_mode == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            X_lengths = torch.reshape(X_lengths, (2, B))
            assert label is None, "label should be None when using contrastive program dataloader"

        return (X, label, X_lengths, label_lengths)

    return javascript_collate
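
# Standalone toy of the padding step inside javascript_collate: variable-length
# id lists become one [B, T] tensor padded with pad_id (id values invented).
import torch
from torch.nn.utils.rnn import pad_sequence

bos_id, eos_id, pad_id = 1, 2, 0
seqs = [[5, 6, 7], [8]]
X = [torch.tensor([bos_id] + ids + [eos_id]) for ids in seqs]
X_lengths = torch.tensor([len(x) for x in X])  # tensor([5, 3])
X = pad_sequence(X, batch_first=True, padding_value=pad_id)
# X == [[1, 5, 6, 7, 2],
#       [1, 8, 2, 0, 0]]
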
def beam_search_decode(model, X, X_lengths, sp: spm.SentencePieceProcessor, max_decode_len=20, k=3):
    # TODO: Implement constrained decoding (e.g. only alphanumeric)
    B = X.size(0)
    bos_id = sp.PieceToId("<s>")
    V = sp.GetPieceSize()  # Size of vocab
    model.eval()
    with torch.no_grad():
        Y_hat_lengths = torch.ones(B, dtype=torch.long)
        # initial Y_hat and batchwise score tensors
        sequences = [(
            torch.zeros(B, max_decode_len).long().to(X.device) + bos_id,  # Y_hat
            torch.zeros(B).to(X.device),  # scores
        )]
        # walk over each item in the output sequence
        for t in range(max_decode_len - 1):
            all_candidates = []
            # expand each current candidate
            for Y_hat, scores in sequences:
                Y_hat = Y_hat.to(X.device)
                scores = scores.to(X.device)
                logits = model(X, Y_hat[:, :t + 1].to(X.device), src_lengths=X_lengths, tgt_lengths=Y_hat_lengths + 1)
                logits_t = logits[:, t, :]
                logprobs_t = F.log_softmax(logits_t, dim=-1).to(scores.device)  # [B, V] tensor
                for j in range(V):
                    log_p_j = logprobs_t[:, j]  # log p(Y_t=j | Y_{<t-1}, X)
                    candidate_Y_hat = Y_hat.clone()
                    candidate_Y_hat[:, t + 1] = j
                    all_candidates.append((candidate_Y_hat, scores + log_p_j))
            # stack candidates
            beam_Y, beam_scores = zip(*all_candidates)
            beam_Y = torch.stack(beam_Y, dim=1)  # [B, V, T]
            beam_scores = torch.stack(beam_scores, dim=1)  # [B, V]
            # select k best per batch item
            topk_scores, topk_idx = torch.topk(beam_scores, k, dim=1, sorted=True)
            topk_Y = torch.gather(beam_Y, 1, topk_idx.unsqueeze(-1).expand(B, k, max_decode_len))
            # set beam
            sequences = [(topk_Y[:, j, :], topk_scores[:, j]) for j in range(k)]
            # TODO: exit early if all sentences in all beam sequences contain </s>
            Y_hat_lengths = Y_hat_lengths + 1
        # stack sequences
        beam_Y, beam_scores = zip(*sequences)
        beam_Y = torch.stack(beam_Y, dim=1)  # [B, k, T]
        beam_scores = torch.stack(beam_scores, dim=1)  # [B, k]
    model.train()
    return ids_to_strs(beam_Y, sp), beam_scores
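
# Standalone illustration of the topk + gather idiom used by both hand-rolled
# beam searches above: pick the k best of V candidate sequences per batch item.
import torch

B, V, T, k = 2, 5, 4, 2
beam_scores = torch.randn(B, V)  # score for each candidate expansion
beam_Y = torch.randint(0, 10, (B, V, T))
topk_scores, topk_idx = torch.topk(beam_scores, k, dim=1, sorted=True)
# expand indices to [B, k, T] so gather selects whole sequences along dim 1
topk_Y = torch.gather(beam_Y, 1, topk_idx.unsqueeze(-1).expand(B, k, T))
assert topk_Y.shape == (B, k, T)
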
def convert_subword(transcript: str, sp: spm.SentencePieceProcessor):
    pieces = sp.EncodeAsPieces(transcript)
    text = " ".join(pieces)
    # Map each piece (not each character of the joined string) to its vocab id
    label = " ".join(str(sp.PieceToId(piece)) for piece in pieces)
    return text, label