import math
from typing import Dict

import torch
from torch import Tensor

# _sqd, _traverse, _depth_first_search, tile, and pad_and_stack_hyps are
# assumed to be defined elsewhere in this module.


def single_queue_decode(model, encoder_output, masks, max_output_length,
                        max_hyps=1, labels: dict = None,
                        buffer_size: int = 1, gamma=float("-inf")):
    """Decode each sequence with single-queue decoding, keeping the best of
    up to max_hyps hypotheses produced by the lazy _sqd generator."""
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(
                1, 1, 1, dtype=torch.bool, device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None
        # as in full_traversal, we have a generator that lazily produces
        # hypotheses
        hypotheses = _sqd(model, enc_out, seq_masks, max_output_length,
                          labels=seq_labels, buffer_size=buffer_size,
                          gamma=gamma)
        total_p = 0.0
        current_best = float("-inf")
        generated = []
        for j in range(max_hyps):
            try:
                p, y = next(hypotheses)
            except StopIteration:
                break
            generated.append((p, y))
            total_p += math.exp(p)
            if p > current_best:
                current_best = p
        best_p, best_y = max(generated) if generated \
            else (float("-inf"), (model.bos_index,))
        ys.append(best_y)
        gammas.append(best_p)
    outputs = pad_and_stack_hyps(ys, model.pad_index)
    # any_finished = outputs.eq(model.eos_index).any(dim=-1)
    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # assert scores[~any_finished].eq(float("-inf")).all()
    return outputs, scores
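
# Illustrative sketch only (an assumption about the interface, not part of
# the original module): single_queue_decode consumes `hypotheses` as a
# generator of (log probability, token id tuple) pairs, best-first. A toy
# stand-in that satisfies the same contract:
def _toy_hypothesis_generator():
    """Yield (log_prob, token_ids) pairs the way _sqd is assumed to,
    using hypothetical token ids (2 = bos, 3 = eos)."""
    yield math.log(0.5), (2, 7, 3)
    yield math.log(0.3), (2, 8, 3)
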
def iterative_deepening(model, encoder_output, masks, max_output_length,
                        labels: dict = None, buffer_size: int = 1,
                        max_hyps=1, verbose=False):
    """Repeatedly run single-queue decoding with an increasing length limit,
    threading the best score found so far (gamma) through as a lower bound
    that deeper searches can prune against."""
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        best_y = None
        gamma = float("-inf")
        for max_len in range(max_output_length):
            if verbose:
                print("new max length", max_len)
            seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
            if model.is_transformer:
                seq_masks["trg"] = torch.ones(
                    1, 1, 1, dtype=torch.bool, device=masks["trg"].device)
            if labels is not None:
                seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
            else:
                seq_labels = None
            hypotheses = _sqd(model, enc_out, masks=seq_masks,
                              max_output_length=max_len, gamma=gamma,
                              labels=seq_labels, buffer_size=buffer_size)
            for j in range(max_hyps):
                try:
                    p, y = next(hypotheses)
                except StopIteration:
                    break
                if p > gamma:
                    gamma = p
                    best_y = y
                    if verbose:
                        y_seq = [model.trg_vocab.itos[y_i] for y_i in best_y]
                        print(gamma, y_seq)
            if verbose:
                print()
        ys.append(best_y)
        gammas.append(gamma)
    outputs = pad_and_stack_hyps(ys, model.pad_index)
    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    return outputs, scores
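
# A small restatement (assumed semantics, for illustration) of why gamma is
# threaded through to _sqd: log probabilities only decrease as a hypothesis
# grows, so any partial hypothesis already at or below the incumbent score
# can never improve on it and may be pruned.
def _can_prune(partial_log_prob: float, gamma: float) -> bool:
    """True if a partial hypothesis cannot beat the best complete one."""
    return partial_log_prob <= gamma
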
def depth_first_search(model, encoder_output, masks, max_output_length,
                       labels: dict = None):
    """Batch-level, non-recursive wrapper around the recursive
    _depth_first_search helper."""
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(
                1, 1, 1, dtype=torch.bool, device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None
        y, gamma = _depth_first_search(
            model, enc_out, seq_masks,
            max_output_length + 1,  # +1 so BOS does not count against the limit (fencepost)
            labels=seq_labels)
        y = y[1:] if y is not None else None  # strip the BOS token
        ys.append(y)
        gammas.append(gamma)
    outputs = pad_and_stack_hyps(ys, model.pad_index)
    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    return outputs, scores
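
# Toy illustration (an assumed, self-contained analogue; the real
# _depth_first_search is defined elsewhere) of the helper's contract:
# explore continuations depth-first and return the best-scoring complete
# sequence, BOS included, together with its log probability.
def _toy_dfs(tree, prefix=(0,), log_p=0.0):
    """`tree` maps a prefix tuple to {token: prob} continuations; a prefix
    missing from `tree` is a complete hypothesis. Token 0 plays BOS."""
    if prefix not in tree:
        return prefix, log_p
    best_y, best_p = None, float("-inf")
    for token, prob in tree[prefix].items():
        y, p = _toy_dfs(tree, prefix + (token,), log_p + math.log(prob))
        if p > best_p:
            best_y, best_p = y, p
    return best_y, best_p

# _toy_dfs({(0,): {1: 0.6, 2: 0.4}, (0, 1): {3: 1.0}})
# -> ((0, 1, 3), log(0.6)): the same shape of result the wrapper above
#    consumes before stripping BOS.
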
def full_traversal(model, encoder_output, masks, max_output_length,
                   max_hyps=1, mode="dfs", labels: dict = None,
                   break_at_argmax=False, break_at_p=1.0):
    """Exhaustively traverse the search space (depth- or breadth-first),
    optionally stopping early once the enumerated hypotheses cover enough
    probability mass or the argmax is provably found."""
    assert mode in ["dfs", "bfs"]
    enc_outs = encoder_output.split()
    ys = []
    gammas = []
    n_hyps = []
    for i, enc_out in enumerate(enc_outs):
        seq_masks = {k: v[i].unsqueeze(0) for k, v in masks.items()}
        if model.is_transformer:
            seq_masks["trg"] = torch.ones(
                1, 1, 1, dtype=torch.bool, device=masks["trg"].device)
        if labels is not None:
            seq_labels = {k: v[i].unsqueeze(0) for k, v in labels.items()}
        else:
            seq_labels = None
        hypotheses = _traverse(model, enc_out, seq_masks, max_output_length,
                               mode=mode, labels=seq_labels,
                               prune=break_at_argmax)
        total_p = 0.0
        current_best = float("-inf")
        generated = []
        for j in range(max_hyps):
            try:
                p, y = next(hypotheses)
            except StopIteration:
                break
            generated.append((p, y))
            total_p += math.exp(p)
            if p > current_best:
                current_best = p
            if total_p > break_at_p or (
                    break_at_argmax and math.exp(current_best) > 1 - total_p):
                # either the enumerated set covers enough of the probability
                # mass, or the best hypothesis found so far outscores all of
                # the mass that remains, so it is provably the argmax
                # (current_best is a log prob, hence the exp)
                break
        best_p, best_y = max(generated) if generated \
            else (float("-inf"), (model.bos_index,))
        ys.append(best_y)
        gammas.append(best_p)
        n_hyps.append(len(generated))
    outputs = pad_and_stack_hyps(ys, model.pad_index)
    scores = torch.tensor(gammas, dtype=torch.float, device=outputs.device)
    # idea: also return n_hyps, the size of the enumerated hypothesis set
    return outputs, scores
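
# Worked example of the early-exit test above (numbers are hypothetical):
# if the enumerated hypotheses have probabilities 0.40, 0.25, and 0.21,
# then total_p = 0.86 and the best one, 0.40, exceeds the remaining mass
# 1 - 0.86 = 0.14, so with break_at_argmax=True the traversal may stop:
# no unseen hypothesis can overtake it.
def _argmax_is_certain(best_p: float, total_p: float) -> bool:
    """True when the best complete hypothesis found so far provably
    outscores every hypothesis not yet enumerated."""
    return best_p > 1.0 - total_p
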
def beam_search(model, size: int, encoder_output, masks: Dict[str, Tensor],
                max_output_length: int, scorer, labels: dict = None,
                return_scores: bool = False):
    """
    Beam search with beam size k. In each decoding step, find the k most
    likely partial hypotheses.

    :param model: the model to decode with
    :param size: size of the beam
    :param encoder_output: encoder states
    :param masks: source (and, for transformers, target) masks
    :param max_output_length: maximum number of decoding steps
    :param scorer: optional function for rescoring hypotheses, such as a
        length penalty; called as scorer(raw_scores, step)
    :param labels: optional label tensors, passed through to the decoder
    :param return_scores: if True, also return hypothesis scores
    :return:
        - stacked_output: output hypotheses (2d array of indices)
        - None (attention scores are not tracked)
        - stacked_scores: hypothesis scores if return_scores, else None
    """
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)
    att_vectors = None  # not used for transformers
    device = encoder_output.device

    if model.is_ensemble:
        # each ensemble member bridges its own encoder output
        hidden = model.ensemble_bridge(encoder_output)
    else:
        if not transformer and model.decoder.bridge_layer is not None:
            hidden = model.decoder.bridge_layer(encoder_output.hidden)
        else:
            hidden = None

    # tile encoder states and decoder initial states beam_size times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        if isinstance(hidden, list):
            hidden = [tile(h, size, dim=1) if h is not None else None
                      for h in hidden]
        else:
            hidden = tile(hidden, size, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    encoder_output.tile(size, dim=0)
    masks = {k: tile(v, size, dim=0) for k, v in masks.items() if k != "trg"}
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # beam_size copies of each batch element
    beam_offset = torch.arange(0, batch_size * size, step=size,
                               dtype=torch.long, device=device)

    # keeps track of the top beam-size hypotheses to expand for each
    # element in the batch to be further decoded (that are still "alive")
    alive_seq = beam_offset.new_full((batch_size * size, 1), model.bos_index)
    prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # give full probability to the first beam on the first step
    # pylint: disable=not-callable
    current_beam = torch.tensor([0.0] + [float("-inf")] * (size - 1),
                                device=device).repeat(batch_size, 1)

    results = {"predictions": [[] for _ in range(batch_size)],
               "scores": [[] for _ in range(batch_size)],
               "gold_score": [0] * batch_size}

    for step in range(1, max_output_length + 1):
        # decode a single step
        log_probs, hidden, _, att_vectors = model.decode(
            trg_input=prev_y, encoder_output=encoder_output, masks=masks,
            decoder_hidden=hidden, prev_att_vector=att_vectors,
            unroll_steps=1, generate="log", labels=labels)
        log_probs = log_probs.squeeze(1)  # batch*k x trg_vocab

        # multiply probs by the beam probability (i.e. add log probs)
        raw_scores = log_probs + current_beam.view(-1).unsqueeze(1)

        # flatten log probs into a list of possibilities
        vocab_size = log_probs.size(-1)
        raw_scores = raw_scores.reshape(-1, size * vocab_size)

        # apply an additional scorer, such as a length penalty
        scores = scorer(raw_scores, step) if scorer is not None else raw_scores

        # pick the currently best top-k hypotheses (flattened order)
        topk_scores, topk_ids = scores.topk(size, dim=-1)

        # If using a length penalty, scores are distinct from log probs.
        # The beam keeps track of log probabilities regardless.
        current_beam = topk_scores if scorer is None \
            else raw_scores.gather(1, topk_ids)

        # reconstruct beam origin and true word ids from the flattened order
        topk_beam_index = topk_ids.div(vocab_size, rounding_mode="floor")
        topk_ids = topk_ids.fmod(vocab_size)

        # map beam_index to batch_index in the flat representation
        b_off = beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
        batch_index = topk_beam_index + b_off
        select_ix = batch_index.view(-1)

        # append the latest prediction (result: batch_size*k x hyp_len)
        selected_alive_seq = alive_seq.index_select(0, select_ix)
        alive_seq = torch.cat([selected_alive_seq, topk_ids.view(-1, 1)], -1)

        is_finished = topk_ids.eq(model.eos_index)  # batch x beam
        if step == max_output_length:
            is_finished.fill_(True)
        top_finished = is_finished[:, 0]  # batch

        # save finished hypotheses
        seq_len = alive_seq.size(-1)
        predictions = alive_seq.view(-1, size, seq_len)
        ix = top_finished.nonzero().view(-1)
        for i in ix:
            finished_scores = topk_scores[i]
            finished_preds = predictions[i, :, 1:]  # drop BOS
            b = batch_offset[i]
            # if you desire more hypotheses, you can use topk/sort here
            top_score, top_pred_ix = finished_scores.max(dim=0)
            top_pred = finished_preds[top_pred_ix]
            results["scores"][b].append(top_score)
            results["predictions"][b].append(top_pred)

        if top_finished.all():
            break

        # remove finished sequences for the next step
        unfinished = (~top_finished).nonzero().view(-1)
        current_beam = current_beam.index_select(0, unfinished)
        batch_index = batch_index.index_select(0, unfinished)
        batch_offset = batch_offset.index_select(0, unfinished)
        alive_seq = predictions.index_select(0, unfinished).view(-1, seq_len)

        # reorder indices, outputs and masks
        select_ix = batch_index.view(-1)
        encoder_output.index_select(select_ix)
        masks = {k: v.index_select(0, select_ix) if k != "trg" else v
                 for k, v in masks.items()}

        if model.is_ensemble:
            if not transformer:
                new_hidden = []
                for h_i in hidden:
                    if isinstance(h_i, tuple):
                        # for LSTMs, states are tuples of tensors
                        h, c = h_i
                        h = h.index_select(1, select_ix)
                        c = c.index_select(1, select_ix)
                        new_h_i = h, c
                    else:
                        # for GRUs, states are single tensors
                        new_h_i = h_i.index_select(1, select_ix)
                    new_hidden.append(new_h_i)
                hidden = new_hidden
        else:
            if hidden is not None and not transformer:
                if isinstance(hidden, tuple):
                    # for LSTMs, states are tuples of tensors
                    h, c = hidden
                    h = h.index_select(1, select_ix)
                    c = c.index_select(1, select_ix)
                    hidden = h, c
                else:
                    # for GRUs, states are single tensors
                    hidden = hidden.index_select(1, select_ix)

        if att_vectors is not None:
            if model.is_ensemble:
                att_vectors = [av.index_select(0, select_ix)
                               if av is not None else None
                               for av in att_vectors]
            else:
                att_vectors = att_vectors.index_select(0, select_ix)

        prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # is moving to cpu necessary/good?
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu() for r in results["predictions"]], model.pad_index)

    if return_scores:
        final_scores = torch.stack([s[0] for s in results["scores"]])
        return final_outputs, None, final_scores
    return final_outputs, None, None
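
# Example scorer for beam_search (a sketch, not part of the original module):
# the GNMT-style length penalty of Wu et al. (2016). beam_search calls
# scorer(raw_scores, step), where raw_scores holds summed log probabilities
# and step is the current hypothesis length, so a scorer only has to map
# those to the values used for ranking.
def gnmt_length_penalty(alpha: float = 0.6):
    """Build a scorer that divides cumulative log probs by
    ((5 + step) / 6) ** alpha, favoring longer hypotheses for alpha > 0."""
    def scorer(raw_scores: Tensor, step: int) -> Tensor:
        return raw_scores / (((5.0 + step) / 6.0) ** alpha)
    return scorer

# e.g.: outputs, _, scores = beam_search(model, 5, encoder_output, masks,
#                                        100, gnmt_length_penalty(0.6),
#                                        return_scores=True)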