def beam_search(decoder: Decoder, size: int, bos_index: int, eos_index: int,
                pad_index: int, encoder_output: Tensor, encoder_hidden: Tensor,
                src_mask: Tensor, max_output_length: int, alpha: float,
                embed: Embeddings, n_best: int = 1) -> (np.array, np.array):
    """
    Beam search with size k. Follows OpenNMT-py implementation.
    In each decoding step, find the k most likely partial hypotheses.

    :param decoder:
    :param size: size of the beam
    :param bos_index:
    :param eos_index:
    :param pad_index:
    :param encoder_output:
    :param encoder_hidden:
    :param src_mask:
    :param max_output_length:
    :param alpha: `alpha` factor for length penalty
    :param embed:
    :param n_best: return this many hypotheses, <= beam
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    # init
    batch_size = src_mask.size(0)

    # pylint: disable=protected-access
    hidden = decoder._init_hidden(encoder_hidden)

    # tile hidden decoder states and encoder output beam_size times
    hidden = tile(hidden, size, dim=1)  # layers x batch*k x dec_hidden_size
    att_vectors = None

    encoder_output = tile(encoder_output.contiguous(), size, dim=0)  # batch*k x src_len x enc_hidden_size
    src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

    batch_offset = torch.arange(batch_size, dtype=torch.long,
                                device=encoder_output.device)
    beam_offset = torch.arange(0, batch_size * size, step=size, dtype=torch.long,
                               device=encoder_output.device)
    alive_seq = torch.full([batch_size * size, 1], bos_index, dtype=torch.long,
                           device=encoder_output.device)

    # Give full probability to the first beam on the first step.
    # pylint: disable=not-callable
    topk_log_probs = (torch.tensor([0.0] + [float("-inf")] * (size - 1),
                                   device=encoder_output.device).repeat(batch_size))

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]

    results = {}
    results["predictions"] = [[] for _ in range(batch_size)]
    results["scores"] = [[] for _ in range(batch_size)]
    results["gold_score"] = [0] * batch_size

    for step in range(max_output_length):
        decoder_input = alive_seq[:, -1].view(-1, 1)

        # expand current hypotheses
        # decode one single step
        # out: logits for final softmax
        # pylint: disable=unused-variable
        out, hidden, att_scores, att_vectors = decoder(
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=src_mask,
            trg_embed=embed(decoder_input),
            hidden=hidden,
            prev_att_vector=att_vectors,
            unrol_steps=1)

        log_probs = F.log_softmax(out, dim=-1).squeeze(1)  # batch*k x trg_vocab

        # multiply probs by the beam probability (=add logprobs)
        log_probs += topk_log_probs.view(-1).unsqueeze(1)
        curr_scores = log_probs

        # compute length penalty
        if alpha > -1:
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores /= length_penalty

        # flatten log_probs into a list of possibilities
        curr_scores = curr_scores.reshape(-1, size * decoder.output_size)

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = curr_scores.topk(size, dim=-1)

        if alpha > -1:
            # recover original log probs
            topk_log_probs = topk_scores * length_penalty

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.div(decoder.output_size)
        topk_ids = topk_ids.fmod(decoder.output_size)

        # map beam_index to batch_index in the flat representation
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # append latest prediction
        alive_seq = torch.cat(
            [alive_seq.index_select(0, select_indices),
             topk_ids.view(-1, 1)], -1)  # batch_size*k x hyp_len

        is_finished = topk_ids.eq(eos_index)
        if step + 1 == max_output_length:
            is_finished.fill_(1)
        # end condition is whether the top beam is finished
        end_condition = is_finished[:, 0].eq(1)

        # save finished hypotheses
        if is_finished.any():
            predictions = alive_seq.view(-1, size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero().view(-1)
                # store finished hypotheses for this batch
                for j in finished_hyp:
                    hypotheses[b].append(
                        (topk_scores[i, j], predictions[i, j, 1:])  # ignore start_token
                    )
                # if the batch reached the end, save the n_best hypotheses
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                    for n, (score, pred) in enumerate(best_hyp):
                        if n >= n_best:
                            break
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
            non_finished = end_condition.eq(0).nonzero().view(-1)
            # if all sentences are translated, no need to go further
            # pylint: disable=len-as-condition
            if len(non_finished) == 0:
                break
            # remove finished batches for the next step
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))

        # reorder indices, outputs and masks
        select_indices = batch_index.view(-1)
        encoder_output = encoder_output.index_select(0, select_indices)
        src_mask = src_mask.index_select(0, select_indices)

        if isinstance(hidden, tuple):
            # for LSTMs, states are tuples of tensors
            h, c = hidden
            h = h.index_select(1, select_indices)
            c = c.index_select(1, select_indices)
            hidden = (h, c)
        else:
            # for GRUs, states are single tensors
            hidden = hidden.index_select(1, select_indices)

        att_vectors = att_vectors.index_select(0, select_indices)

    def pad_and_stack_hyps(hyps, pad_value):
        filled = np.ones((len(hyps), max([h.shape[0] for h in hyps])),
                         dtype=int) * pad_value
        for j, h in enumerate(hyps):
            for k, i in enumerate(h):
                filled[j, k] = i
        return filled

    # from results to stacked outputs
    assert n_best == 1  # only works for n_best=1 for now
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu().numpy() for r in results["predictions"]], pad_value=pad_index)

    # TODO also return attention scores and probabilities
    return final_outputs, None
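
# All beam_search variants in this file rely on a `tile` helper that is not
# shown here. The sketch below is only an assumption about its behaviour
# (repeat each batch element `count` times along `dim`, keeping the copies of
# one element adjacent, as in OpenNMT-py); the real helper may differ.
def tile(x, count: int, dim: int = 0):
    """Sketch: [a, b] tiled with count=2 along dim 0 becomes [a, a, b, b]."""
    if isinstance(x, tuple):
        # e.g. LSTM states (h, c): tile each tensor separately
        return tuple(tile(t, count, dim=dim) for t in x)
    return x.repeat_interleave(count, dim=dim)
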
def beam_search(model: Model, size: int, encoder_output: Tensor,
                encoder_hidden: Tensor, src_mask: Tensor,
                max_output_length: int, alpha: float,
                n_best: int = 1) -> (np.array, np.array):
    """
    Beam search with size k. Inspired by OpenNMT-py, adapted for Transformer.
    In each decoding step, find the k most likely partial hypotheses.

    :param model:
    :param size: size of the beam
    :param encoder_output:
    :param encoder_hidden:
    :param src_mask:
    :param max_output_length:
    :param alpha: `alpha` factor for length penalty
    :param n_best: return this many hypotheses, <= beam (currently only 1)
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    assert size > 0, 'Beam size must be >0.'
    assert n_best <= size, 'Can only return {} best hypotheses.'.format(size)

    # init
    bos_index = model.bos_index
    eos_index = model.eos_index
    pad_index = model.pad_index
    trg_vocab_size = model.decoder.output_size
    device = encoder_output.device
    transformer = isinstance(model.decoder, TransformerDecoder)
    batch_size = src_mask.size(0)
    att_vectors = None  # not used for Transformer
    hidden = None  # not used for Transformer
    trg_mask = None  # not used for RNN

    # Recurrent models only: initialize RNN hidden state
    # pylint: disable=protected-access
    if not transformer:
        # tile encoder states and decoder initial states beam_size times
        hidden = model.decoder._init_hidden(encoder_hidden)
        hidden = tile(hidden, size, dim=1)  # layers x batch*k x dec_hidden_size

        # DataParallel splits batch along the 0th dim.
        # Place back the batch_size to the 1st dim here.
        if isinstance(hidden, tuple):
            h, c = hidden
            hidden = (h.permute(1, 0, 2), c.permute(1, 0, 2))
        else:
            hidden = hidden.permute(1, 0, 2)  # batch*k x layers x dec_hidden_size

    encoder_output = tile(encoder_output.contiguous(), size, dim=0)  # batch*k x src_len x enc_hidden_size
    src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

    # Transformer only: create target mask
    if transformer:
        trg_mask = src_mask.new_ones([1, 1, 1])  # transformer only
        if isinstance(model, torch.nn.DataParallel):
            trg_mask = torch.stack(
                [src_mask.new_ones([1, 1]) for _ in model.device_ids])

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # numbering elements in the extended batch, i.e. beam size copies of each
    # batch element
    beam_offset = torch.arange(0, batch_size * size, step=size, dtype=torch.long,
                               device=device)

    # keeps track of the top beam size hypotheses to expand for each element
    # in the batch to be further decoded (that are still "alive")
    alive_seq = torch.full([batch_size * size, 1], bos_index, dtype=torch.long,
                           device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = torch.zeros(batch_size, size, device=device)
    topk_log_probs[:, 1:] = float("-inf")

    # Structure that holds finished hypotheses.
    hypotheses = [[] for _ in range(batch_size)]

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size,
    }

    for step in range(max_output_length):
        # This decides which part of the predicted sentence we feed to the
        # decoder to make the next prediction.
        # For Transformer, we feed the complete predicted sentence so far.
        # For Recurrent models, only feed the previous target word prediction.
        if transformer:  # Transformer
            decoder_input = alive_seq  # complete prediction so far
        else:  # Recurrent
            decoder_input = alive_seq[:, -1].view(-1, 1)  # only the last word

        # expand current hypotheses
        # decode one single step
        # logits: logits for final softmax
        # pylint: disable=unused-variable
        with torch.no_grad():
            logits, hidden, att_scores, att_vectors = model(
                return_type="decode",
                encoder_output=encoder_output,
                encoder_hidden=None,  # used to initialize decoder_hidden only
                src_mask=src_mask,
                trg_input=decoder_input,  # trg_embed = embed(decoder_input)
                decoder_hidden=hidden,
                att_vector=att_vectors,
                unroll_steps=1,
                trg_mask=trg_mask  # subsequent mask for Transformer only
            )

        # For the Transformer we made predictions for all time steps up to
        # this point, so we only want to know about the last time step.
        if transformer:
            logits = logits[:, -1]  # keep only the last time step
            hidden = None  # we don't need to keep it for transformer

        # batch*k x trg_vocab
        log_probs = F.log_softmax(logits, dim=-1).squeeze(1)

        # multiply probs by the beam probability (=add logprobs)
        log_probs += topk_log_probs.view(-1).unsqueeze(1)
        curr_scores = log_probs.clone()

        # compute length penalty
        if alpha > -1:
            length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
            curr_scores /= length_penalty

        # flatten log_probs into a list of possibilities
        curr_scores = curr_scores.reshape(-1, size * trg_vocab_size)

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = curr_scores.topk(size, dim=-1)

        if alpha > -1:
            # recover original log probs
            topk_log_probs = topk_scores * length_penalty
        else:
            topk_log_probs = topk_scores.clone()

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.floor_divide(trg_vocab_size)
        topk_ids = topk_ids.fmod(trg_vocab_size)

        # map beam_index to batch_index in the flat representation
        batch_index = (topk_beam_index +
                       beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
        select_indices = batch_index.view(-1)

        # append latest prediction
        alive_seq = torch.cat(
            [alive_seq.index_select(0, select_indices),
             topk_ids.view(-1, 1)], -1)  # batch_size*k x hyp_len

        is_finished = topk_ids.eq(eos_index)
        if step + 1 == max_output_length:
            is_finished.fill_(True)
        # end condition is whether the top beam is finished
        end_condition = is_finished[:, 0].eq(True)

        # save finished hypotheses
        if is_finished.any():
            predictions = alive_seq.view(-1, size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = batch_offset[i]
                if end_condition[i]:
                    is_finished[i].fill_(1)
                finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1)
                # store finished hypotheses for this batch
                for j in finished_hyp:
                    # Check if the prediction has more than one EOS.
                    # If it has more than one EOS, it means that the
                    # prediction should have already been added to
                    # the hypotheses, so you don't have to add them again.
                    if (predictions[i, j, 1:] == eos_index).nonzero(as_tuple=False).numel() < 2:
                        # ignore start_token
                        hypotheses[b].append(
                            (topk_scores[i, j], predictions[i, j, 1:]))

                # if the batch reached the end, save the n_best hypotheses
                if end_condition[i]:
                    best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                    for n, (score, pred) in enumerate(best_hyp):
                        if n >= n_best:
                            break
                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
            non_finished = end_condition.eq(False).nonzero(as_tuple=False).view(-1)
            # if all sentences are translated, no need to go further
            # pylint: disable=len-as-condition
            if len(non_finished) == 0:
                break
            # remove finished batches for the next step
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished) \
                .view(-1, alive_seq.size(-1))

        # reorder indices, outputs and masks
        select_indices = batch_index.view(-1)
        encoder_output = encoder_output.index_select(0, select_indices)
        src_mask = src_mask.index_select(0, select_indices)

        if hidden is not None and not transformer:
            if isinstance(hidden, tuple):
                # for LSTMs, states are tuples of tensors
                h, c = hidden
                h = h.index_select(0, select_indices)
                c = c.index_select(0, select_indices)
                hidden = (h, c)
            else:
                # for GRUs, states are single tensors
                hidden = hidden.index_select(0, select_indices)

        if att_vectors is not None:
            att_vectors = att_vectors.index_select(0, select_indices)

    def pad_and_stack_hyps(hyps, pad_value):
        filled = np.ones((len(hyps), max([h.shape[0] for h in hyps])),
                         dtype=int) * pad_value
        for j, h in enumerate(hyps):
            for k, i in enumerate(h):
                filled[j, k] = i
        return filled

    # from results to stacked outputs
    assert n_best == 1  # only works for n_best=1 for now
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu().numpy() for r in results["predictions"]], pad_value=pad_index)

    return final_outputs, None
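
# Every variant in this file recovers the originating beam and the proposed
# token id from the flattened (beam * vocab) top-k indices via integer
# division and modulo. A tiny, self-contained illustration with made-up
# numbers (not taken from any of the functions above):
#
#   beam size = 2, vocab size = 5, one source sentence; suppose topk over the
#   flattened scores returns ids [7, 3]:
#     7 -> beam 7 // 5 = 1, token 7 % 5 = 2
#     3 -> beam 3 // 5 = 0, token 3 % 5 = 3
#
# import torch
# vocab, flat_topk_ids = 5, torch.tensor([[7, 3]])
# beam_origin = flat_topk_ids // vocab      # tensor([[1, 0]])
# token_ids = flat_topk_ids.fmod(vocab)     # tensor([[2, 3]])
#
# beam_offset then shifts beam_origin into the flat batch*k dimension so that
# index_select can gather the surviving hypotheses.
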
def beam_search(decoder: Decoder, generator: Gen, size: int, bos_index: int,
                eos_index: int, pad_index: int, encoder_output: Tensor,
                encoder_hidden: Tensor, src_mask: Tensor,
                max_output_length: int, alpha: float, embed: Embeddings,
                n_best: int = 1,
                knowledgebase: Tuple = None) -> (np.array, np.array, np.array):
    """
    Beam search with size k. Inspired by OpenNMT-py, adapted for Transformer.
    In each decoding step, find the k most likely partial hypotheses.

    :param decoder:
    :param generator:
    :param size: size of the beam
    :param bos_index:
    :param eos_index:
    :param pad_index:
    :param encoder_output:
    :param encoder_hidden:
    :param src_mask:
    :param max_output_length:
    :param alpha: `alpha` factor for length penalty
    :param embed:
    :param n_best: return this many hypotheses, <= beam
    :param knowledgebase: knowledgebase tuple containing keys, values and
        true values for decoding
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
        - stacked_kb_att_scores: kb attention scores (3d array)
    """
    with torch.no_grad():
        # initializations and so on, this should keep weird cuda errors from happening

        # init
        transformer = isinstance(decoder, TransformerDecoder)
        batch_size = src_mask.size(0)
        att_vectors = None  # not used for Transformer

        # Recurrent models only: initialize RNN hidden state
        # pylint: disable=protected-access
        if not transformer:
            hidden = decoder._init_hidden(encoder_hidden)
        else:
            hidden = None

        # tile encoder states and decoder initial states beam_size times
        if hidden is not None:
            hidden = tile(hidden, size, dim=1)  # layers x batch*k x dec_hidden_size

        encoder_output = tile(encoder_output.contiguous(), size, dim=0)  # batch*k x src_len x enc_hidden_size
        src_mask = tile(src_mask, size, dim=0)  # batch*k x 1 x src_len

        # Transformer only: create target mask
        if transformer:
            trg_mask = src_mask.new_ones([1, 1, 1])  # transformer only
        else:
            trg_mask = None

        # numbering elements in the batch
        batch_offset = torch.arange(batch_size, dtype=torch.long,
                                    device=encoder_output.device)

        # numbering elements in the extended batch, i.e. beam size copies of each
        # batch element
        beam_offset = torch.arange(0, batch_size * size, step=size, dtype=torch.long,
                                   device=encoder_output.device)

        # keeps track of the top beam size hypotheses to expand for each element
        # in the batch to be further decoded (that are still "alive")
        alive_seq = torch.full([batch_size * size, 1], bos_index, dtype=torch.long,
                               device=encoder_output.device)

        # Give full probability to the first beam on the first step.
        # pylint: disable=not-callable
        topk_log_probs = torch.zeros(batch_size, size, device=encoder_output.device)
        topk_log_probs[:, 1:] = float("-inf")

        # Structure that holds finished hypotheses in order of completion.
        hypotheses = [[] for _ in range(batch_size)]

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]
        results["scores"] = [[] for _ in range(batch_size)]
        results["att_scores"] = [[] for _ in range(batch_size)]
        results["kb_att_scores"] = [[] for _ in range(batch_size)]

        # kb task: also tile kb tensors along batch dimension as done with other inputs above
        if knowledgebase is not None:
            kb_values = tile(knowledgebase[1], size, dim=0)
            kb_mask = tile(knowledgebase[-1], size, dim=0)
            kb_values_embed = tile(knowledgebase[2], size, dim=0)
            kb_size = kb_values.size(1)

            kb_keys = knowledgebase[0]
            if isinstance(kb_keys, tuple):
                kb_keys = tuple([tile(key_dim, size, dim=0) for key_dim in kb_keys])
            else:
                kb_keys = tile(kb_keys, size, dim=0)

            att_alive = torch.Tensor(  # batch * k x src x time
                [[[] for _ in range(encoder_output.size(1))]
                 for _ in range(batch_size * size)]
            ).to(dtype=torch.float32, device=encoder_output.device)
            kb_att_alive = torch.Tensor(  # batch*k x KB x time
                [[[] for _ in range(kb_size)]
                 for _ in range(batch_size * size)]
            ).to(dtype=torch.float32, device=encoder_output.device)

            debug_tnsrs = (kb_values, kb_mask, kb_values_embed,
                           (kb_keys if isinstance(kb_keys, torch.Tensor) else kb_keys[0]),
                           alive_seq)
            assert set([t.size(0) for t in debug_tnsrs]) == set([batch_size * size]), \
                [t.shape for t in debug_tnsrs]

            stacked_attention_scores = [[] for _ in range(batch_size)]
            stacked_kb_att_scores = [[] for _ in range(batch_size)]

            util_dims_cache = None
            kb_feed_hidden_cache = None
        else:
            kb_keys, kb_values, kb_mask = None, None, None
            kb_size = None
            att_alive = None
            kb_att_alive = None
            stacked_attention_scores, stacked_kb_att_scores = None, None
            # these kb-specific inputs are also passed to the decoder call below,
            # so they must be defined in the no-knowledgebase case as well
            kb_values_embed = None
            util_dims_cache = None
            kb_feed_hidden_cache = None

        for step in range(max_output_length):
            # This decides which part of the predicted sentence we feed to the
            # decoder to make the next prediction.
            # For Transformer, we feed the complete predicted sentence so far.
            # For Recurrent models, only feed the previous target word prediction.
            if transformer:  # Transformer
                decoder_input = alive_seq  # complete prediction so far
            else:  # Recurrent
                decoder_input = alive_seq[:, -1].view(-1, 1)  # only the last word

            # expand current hypotheses
            # decode one single step
            # pylint: disable=unused-variable
            trg_embed = embed(decoder_input)
            hidden, att_scores, att_vectors, kb_scores, util_dims_cache, kb_feed_hidden_cache = decoder(
                encoder_output=encoder_output,
                encoder_hidden=encoder_hidden,
                src_mask=src_mask,
                trg_embed=trg_embed,
                hidden=hidden,
                prev_att_vector=att_vectors,
                unroll_steps=1,
                trg_mask=trg_mask,  # subsequent mask for Transformer only
                kb_keys=kb_keys,  # None by default
                kb_mask=kb_mask,
                kb_values_embed=kb_values_embed,
                util_dims_cache=util_dims_cache,
                kb_feed_hidden_cache=kb_feed_hidden_cache)

            try:
                # generator applies output layer, biases towards KB values,
                # then applies log_softmax
                log_probs = generator(att_vectors, kb_values=kb_values, kb_probs=kb_scores)
            except Exception as e:
                print(kb_scores.shape)
                print(kb_mask_before_index)
                print(kb_mask_after_index)
                raise e

            # hidden = ?? x batch*k x dec hidden  # FIXME why 3 ??????
            # att_scores = batch*k x 1 x src_len  # TODO Find correct beam in dim 0 at every timestep.
            # att_vectors = batch*k x 1 x dec hidden
            # kb_scores = batch*k x 1 x KB  # TODO find correct beam in dim 0 at every timestep
            # log_probs = batch*k x 1 x trg_voc

            # For the Transformer we made predictions for all time steps up to
            # this point, so we only want to know about the last time step.
            if transformer:
                log_probs = log_probs[:, -1]  # keep only the last time step
                hidden = None  # we don't need to keep it for transformer

            # batch * k x trg_vocab
            log_probs = log_probs.squeeze(1)

            # multiply probs by the probability of each beam thus far (= add logprobs)
            try:
                log_probs += topk_log_probs.view(-1).unsqueeze(1)
            except Exception as e:
                dbg_tnsrs = [hidden, att_scores, att_vectors, kb_scores,
                             util_dims_cache, kb_feed_hidden_cache]
                print([t.shape for t in dbg_tnsrs if isinstance(t, torch.Tensor)])
                print([t.size(0) for t in dbg_tnsrs if isinstance(t, torch.Tensor)])
                print(step)
                print(encoder_output.shape)
                print(select_indices)
                print(batch_index)
                print(non_finished)
                print(non_finished.shape)
                print(batch_size * size)
                raise e

            curr_scores = log_probs

            # compute length penalty
            if alpha > -1:
                length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha
                curr_scores /= length_penalty

            # flatten log_probs into a list of possibilities
            curr_scores = curr_scores.reshape(-1, size * generator.output_size)  # batch x k * voc FIXME

            # pick currently best top k hypotheses (flattened order)
            topk_scores, topk_ids = curr_scores.topk(size, dim=-1)  # each: batch x k

            if alpha > -1:
                # recover original log probs
                topk_log_probs = topk_scores * length_penalty  # b x k

            # reconstruct beam origin and true word ids from flattened order
            topk_beam_index = (topk_ids // generator.output_size).to(
                dtype=torch.int64)  # NOTE why divide by voc size?? this should always be 0
            topk_ids = topk_ids.fmod(generator.output_size)  # NOTE why mod voc size? isnt every entry < voc size?

            # map beam_index to batch_index in the flat representation
            batch_index = (topk_beam_index +
                           beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)  # batch * k

            # append latest prediction
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),  # index first dim (batch * k) with the beams we want to continue this step
                 topk_ids.view(-1, 1)], -1)  # batch_size*k x hyp_len

            if knowledgebase is not None:
                # print(f"kb_att_alive.shape: {kb_att_alive.shape}")
                # print(f"kb_size: {kb_size}")
                # print(kb_att_alive.index_select(0, select_indices).shape)
                # print(kb_scores.transpose(1, 2).index_select(0, select_indices).shape)
                if att_scores is not None:
                    # FIXME sometimes this way sometimes the other idk
                    try:
                        att_alive = torch.cat(  # batch * k x src len x time
                            [att_alive.index_select(0, select_indices),
                             att_scores.transpose(1, 2).index_select(
                                 0, select_indices).contiguous()], -1)
                    except Exception as e:
                        print(f"step: {step}")
                        print(select_indices)
                        print(f"att_alive.shape: {att_alive.shape}")
                        print(f"encoder steps: {encoder_output.size(1)}")
                        print(att_scores.transpose(1, 2).index_select(0, select_indices).shape)
                        raise e

                kb_att_alive = torch.cat(  # batch * k x KB x time
                    [kb_att_alive.index_select(0, select_indices),
                     kb_scores.transpose(1, 2).index_select(
                         0, select_indices).contiguous()], -1)

            # which batches are finished?
            is_finished = topk_ids.eq(eos_index)  # batch x k
            if step + 1 == max_output_length:
                # force finish
                is_finished.fill_(True)

            # end condition is whether the top beam of given batch is finished
            end_condition = is_finished[:, 0].eq(True)

            # save finished hypotheses if any of the batches finished
            if is_finished.any():
                predictions = alive_seq.view(-1, size, alive_seq.size(-1))  # batch x k x time
                for i in range(is_finished.size(0)):  # iter over batches
                    b = batch_offset[i]
                    if end_condition[i]:  # this batch finished
                        is_finished[i].fill_(True)
                    finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1)  # k

                    # store finished hypotheses for this batch
                    # (that doesnt mean the batch is completely finished,
                    # hence the list 'hypotheses' is maintained outside the unroll loop)
                    for j in finished_hyp:  # iter over finished beams
                        # first time EOS appears in this beam, save it as hypothesis
                        # (also save attentions here)
                        if (predictions[i, j, 1:] == eos_index).nonzero(as_tuple=False).numel() < 2:
                            hypotheses[b].append(
                                (topk_scores[i, j],  # for sorting beams by prob (below)
                                 predictions[i, j, 1:])  # ignore BOS token
                            )

                            if knowledgebase is not None:
                                # batch x k x src len x time
                                if 0 not in att_alive.shape:
                                    # at least one attention matrix has been inserted
                                    attentions = att_alive.view(
                                        -1, size, att_alive.size(-2), att_alive.size(-1))
                                    stacked_attention_scores[b].append(
                                        attentions[i, j].cpu().numpy())
                                else:
                                    attentions = None

                                # batch x k x KB x time
                                kb_attentions = kb_att_alive.view(
                                    -1, size, kb_att_alive.size(-2), kb_att_alive.size(-1))
                                stacked_kb_att_scores[b].append(
                                    kb_attentions[i, j].cpu().numpy())

                    # if the batch reached the end, save the n best hypotheses
                    # (and their attentions and kb attentions)
                    if end_condition[i]:
                        # (hypotheses[b] is list of the completed hypotheses of this batch
                        # in order of completion => find out which is best)
                        # (stacked_attention_scores[b] and stacked_kb_att_scores[b] are
                        # also in order of completion)

                        # which beam is best?
                        best_hyps_descending = sorted(hypotheses[b],
                                                      key=lambda x: x[0],
                                                      reverse=True)

                        dbg = np.array([hyp[1].cpu().numpy() for hyp in best_hyps_descending])
                        print(dbg.shape, dbg[0])

                        if knowledgebase is not None:
                            print(hypotheses[b][0], type(hypotheses[b][0]))
                            scores, hyps = zip(*hypotheses[b])
                            sort_key = np.array(scores)
                            hyps = np.array([hyp.cpu().numpy() for hyp in hyps])

                            # indices that would sort hyp[b] in descending order of beam score
                            best_hyps_idx = np.argsort(sort_key)[::-1].copy()
                            best_hyps_d_ = hyps[best_hyps_idx]

                            # sanity check implementation
                            try:
                                assert set([(t1 == t2).all()
                                            for t1, t2 in zip(best_hyps_d_, dbg)]) == {True}
                            except Exception as e:
                                print(best_hyps_d_.dtype)
                                print(dbg.dtype)
                                print([[t.dtype for t in tup] for tup in (best_hyps_d_, dbg)])
                                raise e

                            assert n_best == 1, \
                                "This is a massive clutch: Currently indexing only top 1 beam while saving attentions"  # FIXME TODO NOTE XXX

                            if 0 not in att_alive.shape:
                                best_atts_d_ = [stacked_attention_scores[b][best_hyps_idx[0]]]
                            else:
                                best_atts_d_ = None
                            best_kb_atts_d_ = [stacked_kb_att_scores[b][best_hyps_idx[0]]]

                        # TODO replace best_hyps_descending with best_hyps_d_ FIXME XXX (after cluster beam test)
                        for n, (score, pred) in enumerate(best_hyps_descending):
                            if n >= n_best:
                                break
                            results["scores"][b].append(score)
                            results["predictions"][b].append(pred)
                            if knowledgebase is not None:
                                if best_atts_d_ is not None:
                                    results["att_scores"][b].append(best_atts_d_[n])
                                results["kb_att_scores"][b].append(best_kb_atts_d_[n])

                non_finished = end_condition.eq(False).nonzero(as_tuple=False).view(-1)  # batch

                # if all sentences are translated, no need to go further
                # pylint: disable=len-as-condition
                if len(non_finished) == 0:
                    break

                # remove finished batches for the next step
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))

                if knowledgebase is not None:
                    # briefly go to
                    # batch x k x time x att
                    # to easily index_select finished batches in batch dimension 0,
                    # afterwards reshape to
                    # batch * k x time x att
                    # where att = src_len for alive attentions, and att = kb_size
                    # for kb_attentions alive
                    if 0 not in att_alive.shape:
                        att_alive = att_alive.view(-1, size, att_alive.size(-2), att_alive.size(-1)) \
                            .index_select(0, non_finished)
                        att_alive = att_alive.view(-1, att_alive.size(-2), att_alive.size(-1))

                    kb_att_alive = kb_att_alive.view(-1, size, kb_att_alive.size(-2), kb_att_alive.size(-1)) \
                        .index_select(0, non_finished)
                    kb_att_alive = kb_att_alive.view(-1, kb_att_alive.size(-2), kb_att_alive.size(-1))

            # reorder indices, outputs and masks using this
            select_indices = batch_index.view(-1)

            encoder_output = encoder_output.index_select(0, select_indices)
            src_mask = src_mask.index_select(0, select_indices)

            # for transformer
            if hidden is not None and not transformer:
                # reshape hidden to correct shape for next step
                if isinstance(hidden, tuple):
                    # for LSTMs, states are tuples of tensors
                    h, c = hidden
                    h = h.index_select(1, select_indices)
                    c = c.index_select(1, select_indices)
                    hidden = (h, c)
                else:
                    # for GRUs, states are single tensors
                    hidden = hidden.index_select(1, select_indices)

            if att_vectors is not None:
                if isinstance(att_vectors, tuple):
                    att_vectors = tuple([att_v.index_select(0, select_indices)
                                         for att_v in att_vectors])
                else:
                    att_vectors = att_vectors.index_select(0, select_indices)

            if knowledgebase is not None:
                kb_values = kb_values.index_select(0, select_indices)
                if isinstance(kb_keys, tuple):
                    kb_keys = tuple([key_dim.index_select(0, select_indices)
                                     for key_dim in kb_keys])
                else:
                    kb_keys = kb_keys.index_select(0, select_indices)

                if util_dims_cache is not None:
                    util_dims_cache = [utils.index_select(0, select_indices)
                                       for utils in util_dims_cache
                                       if utils is not None]
                if kb_feed_hidden_cache is not None:
                    try:
                        kb_feed_hidden_cache = [kbf_hidden.index_select(0, select_indices)
                                                for kbf_hidden in kb_feed_hidden_cache
                                                if kbf_hidden is not None]
                    except IndexError as IE:
                        print(hidden[0].shape)
                        print([t.shape for t in kb_feed_hidden_cache])
                        print(select_indices)
                        print(select_indices.shape)
                        print(size)
                        print(generator.output_size)
                        raise IE

                kb_mask_before_index = kb_mask.shape
                kb_mask = kb_mask.index_select(0, select_indices)
                kb_mask_after_index = kb_mask.shape

        def pad_and_stack_hyps(hyps, pad_value):
            # hyps is arrays of hypotheses
            filled = np.ones((len(hyps), max([h.shape[0] for h in hyps])),
                             dtype=int) * pad_value
            for j, h in enumerate(hyps):
                for k, i in enumerate(h):
                    filled[j, k] = i
            return filled

        def pad_and_stack_attention_matrices(atts, pad_value=float("-inf")):
            assert len(list(set([att.shape[1] for att in atts]))) == 1, \
                f"attention matrices have differing attention key bag dimension: {[att.shape[1] for att in atts]}"
            # atts is array of attention matrices, each of dims time x att_dim,
            # where time dims may vary from matrix to matrix
            # NOTE pad_value is used in model.postprocess to recover original part of matrix
            try:
                filled = np.ones((len(atts),
                                  max([att.shape[-2] for att in atts]),
                                  atts[0].shape[-1]),
                                 dtype=atts[0].dtype)
                filled = filled * pad_value
            except Exception as e:
                print(atts[0].shape)
                raise e
            for batch_element_index, attention_matrix in enumerate(atts):
                for t, attentions_at_decoding_step in enumerate(attention_matrix):
                    for attention_key, score in enumerate(attentions_at_decoding_step):
                        filled[batch_element_index, t, attention_key] = score
            return filled  # b x time x attention keys

        # from results to stacked outputs
        assert n_best == 1  # only works for n_best=1 for now

        # final_outputs = batch x time
        final_outputs = pad_and_stack_hyps(
            [r[0].cpu().numpy() for r in results["predictions"]], pad_value=pad_index)

        if knowledgebase is not None:
            # TODO FIXME confirm this implementation
            # stacked_attention_scores: batch x max output len x src len
            if len(results["att_scores"][0]):
                stacked_attention_scores = pad_and_stack_attention_matrices(
                    [atts[0].T for atts in results["att_scores"]])
            else:
                stacked_attention_scores = None
            # stacked_kb_att_scores: batch x max output len x kb
            stacked_kb_att_scores = pad_and_stack_attention_matrices(
                [kb_atts[0].T for kb_atts in results["kb_att_scores"]])

        return final_outputs, stacked_attention_scores, stacked_kb_att_scores
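
# The nested pad_and_stack_hyps helpers above right-pad a ragged list of
# hypotheses into one rectangular matrix. The standalone snippet below only
# illustrates that behaviour with made-up token ids and pad value 0; it is not
# part of the original code.
import numpy as np  # assumed to be imported at the top of this module already

toy_hyps = [np.array([4, 5, 3]), np.array([7, 3])]
toy_padded = np.full((len(toy_hyps), max(h.shape[0] for h in toy_hyps)), 0, dtype=int)
for row, h in enumerate(toy_hyps):
    toy_padded[row, :h.shape[0]] = h
# toy_padded == [[4, 5, 3],
#                [7, 3, 0]]
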
def tile(self, size, dim):
    self.states = {
        k: tile(v.contiguous(), size, dim=dim)
        for k, v in self.states.items()
    }
def tile(self, size, dim):
    self.states = tile(self.states.contiguous(), size, dim=dim)
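
# The two tile methods above appear to belong to wrapper classes around the
# encoder output (the last beam_search variant below calls
# encoder_output.tile(size, dim=0)). The class below is only a sketch of such
# a wrapper; its name and attribute layout are assumptions for illustration,
# not the actual classes from the codebase.
class EncoderOutputSketch:
    def __init__(self, states):
        # states: either a single Tensor or a dict of name -> Tensor
        self.states = states

    def tile(self, size, dim):
        # repeat every stored state `size` times along `dim`, in place
        if isinstance(self.states, dict):
            self.states = {k: tile(v.contiguous(), size, dim=dim)
                           for k, v in self.states.items()}
        else:
            self.states = tile(self.states.contiguous(), size, dim=dim)
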
def beam_search(model, size: int, encoder_output, masks: Dict[str, Tensor],
                max_output_length: int, scorer, labels: dict = None,
                return_scores: bool = False):
    """
    Beam search with size k.
    In each decoding step, find the k most likely partial hypotheses.

    :param model:
    :param size: size of the beam
    :param encoder_output:
    :param masks:
    :param max_output_length:
    :param scorer: function for rescoring hypotheses
    :param labels:
    :param return_scores:
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    transformer = model.is_transformer
    any_mask = next(iter(masks.values()))
    batch_size = any_mask.size(0)
    att_vectors = None  # not used for Transformer
    device = encoder_output.device

    if model.is_ensemble:
        # run model.ensemble_bridge, I guess
        hidden = model.ensemble_bridge(encoder_output)
    else:
        if not transformer and model.decoder.bridge_layer is not None:
            hidden = model.decoder.bridge_layer(encoder_output.hidden)
        else:
            hidden = None

    # tile encoder states and decoder initial states beam_size times
    if hidden is not None:
        # layers x batch*k x dec_hidden_size
        if isinstance(hidden, list):
            hidden = [tile(h, size, dim=1) if h is not None else None
                      for h in hidden]
        else:
            hidden = tile(hidden, size, dim=1)

    # encoder_output: batch*k x src_len x enc_hidden_size
    encoder_output.tile(size, dim=0)
    masks = {k: tile(v, size, dim=0) for k, v in masks.items() if k != "trg"}
    masks["trg"] = any_mask.new_ones([1, 1, 1]) if transformer else None

    # numbering elements in the batch
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)

    # beam_size copies of each batch element
    beam_offset = torch.arange(0, batch_size * size, step=size, dtype=torch.long,
                               device=device)

    # keeps track of the top beam size hypotheses to expand for each
    # element in the batch to be further decoded (that are still "alive")
    alive_seq = beam_offset.new_full((batch_size * size, 1), model.bos_index)
    prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # Give full probability to the first beam on the first step.
    # pylint: disable=not-callable
    current_beam = torch.tensor([0.0] + [float("-inf")] * (size - 1),
                                device=device).repeat(batch_size, 1)

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size
    }

    for step in range(1, max_output_length + 1):
        # decode a single step
        log_probs, hidden, _, att_vectors = model.decode(
            trg_input=prev_y,
            encoder_output=encoder_output,
            masks=masks,
            decoder_hidden=hidden,
            prev_att_vector=att_vectors,
            unroll_steps=1,
            generate="log",
            labels=labels)
        log_probs = log_probs.squeeze(1)  # log_probs: batch*k x trg_vocab

        # multiply probs by the beam probability (=add logprobs)
        raw_scores = log_probs + current_beam.view(-1).unsqueeze(1)

        # flatten log_probs into a list of possibilities
        vocab_size = log_probs.size(-1)  # vocab size
        raw_scores = raw_scores.reshape(-1, size * vocab_size)

        # apply an additional scorer, such as a length penalty
        scores = scorer(raw_scores, step) if scorer is not None else raw_scores

        # pick currently best top k hypotheses (flattened order)
        topk_scores, topk_ids = scores.topk(size, dim=-1)

        # If using a length penalty, scores are distinct from log probs.
        # The beam keeps track of log probabilities regardless.
        current_beam = topk_scores if scorer is None \
            else raw_scores.gather(1, topk_ids)

        # reconstruct beam origin and true word ids from flattened order
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)

        # map beam_index to batch_index in the flat representation
        b_off = beam_offset[:topk_beam_index.size(0)].unsqueeze(1)
        batch_index = topk_beam_index + b_off
        select_ix = batch_index.view(-1)

        # append latest prediction (result: batch_size*k x hyp_len)
        selected_alive_seq = alive_seq.index_select(0, select_ix)
        alive_seq = torch.cat([selected_alive_seq, topk_ids.view(-1, 1)], -1)

        is_finished = topk_ids.eq(model.eos_index)  # batch x beam
        if step == max_output_length:
            is_finished.fill_(1)
        top_finished = is_finished[:, 0].eq(1)  # batch

        # save finished hypotheses
        seq_len = alive_seq.size(-1)
        predictions = alive_seq.view(-1, size, seq_len)
        ix = top_finished.nonzero().view(-1)
        for i in ix:
            finished_scores = topk_scores[i]
            finished_preds = predictions[i, :, 1:]
            b = batch_offset[i]
            # if you desire more hypotheses, you can use topk/sort
            top_score, top_pred_ix = finished_scores.max(dim=0)
            top_pred = finished_preds[top_pred_ix]
            results["scores"][b].append(top_score)
            results["predictions"][b].append(top_pred)

        if top_finished.all():
            break

        # remove finished batches for the next step
        unfinished = top_finished.eq(0).nonzero().view(-1)
        current_beam = current_beam.index_select(0, unfinished)
        batch_index = batch_index.index_select(0, unfinished)
        batch_offset = batch_offset.index_select(0, unfinished)
        alive_seq = predictions.index_select(0, unfinished).view(-1, seq_len)

        # reorder indices, outputs and masks
        select_ix = batch_index.view(-1)
        encoder_output.index_select(select_ix)
        masks = {k: v.index_select(0, select_ix) if k != "trg" else v
                 for k, v in masks.items()}

        if model.is_ensemble:
            if not transformer:
                new_hidden = []
                for h_i in hidden:
                    if isinstance(h_i, tuple):
                        # for LSTMs, states are tuples of tensors
                        h, c = h_i
                        h = h.index_select(1, select_ix)
                        c = c.index_select(1, select_ix)
                        new_h_i = h, c
                    else:
                        # for GRUs, states are single tensors
                        new_h_i = h_i.index_select(1, select_ix)
                    new_hidden.append(new_h_i)
                hidden = new_hidden
        else:
            if hidden is not None and not transformer:
                if isinstance(hidden, tuple):
                    # for LSTMs, states are tuples of tensors
                    h, c = hidden
                    h = h.index_select(1, select_ix)
                    c = c.index_select(1, select_ix)
                    hidden = h, c
                else:
                    # for GRUs, states are single tensors
                    hidden = hidden.index_select(1, select_ix)

        if att_vectors is not None:
            if model.is_ensemble:
                att_vectors = [av.index_select(0, select_ix) if av is not None else None
                               for av in att_vectors]
            else:
                att_vectors = att_vectors.index_select(0, select_ix)

        prev_y = alive_seq if transformer else alive_seq[:, -1].view(-1, 1)

    # is moving to cpu necessary/good?
    final_outputs = pad_and_stack_hyps(
        [r[0].cpu() for r in results["predictions"]], model.pad_index)

    if return_scores:
        final_scores = torch.stack([s[0] for s in results["scores"]])
        return final_outputs, None, final_scores
    else:
        return final_outputs, None, None
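
# The last variant delegates length normalization to a `scorer(raw_scores, step)`
# callable instead of hard-coding the penalty. Below is a minimal sketch of such
# a scorer that reuses the same ((5 + length) / 6) ** alpha penalty the earlier
# variants apply inline; the class name and constructor are assumptions for
# illustration, not the scorer actually used with this function.
class LengthPenaltyScorerSketch:
    def __init__(self, alpha: float):
        self.alpha = alpha

    def __call__(self, raw_scores, step: int):
        # divide the summed log probs by the GNMT-style length penalty;
        # `step` already counts from 1 in the loop above, matching the
        # (step + 1) used by the variants that start counting from 0
        length_penalty = ((5.0 + step) / 6.0) ** self.alpha
        return raw_scores / length_penalty

# usage sketch: beam_search(model, size=5, ..., scorer=LengthPenaltyScorerSketch(alpha=1.0))
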