def beamsearch_hamcycle(pred, W, beam_size=2):
    """Run a beam search over ``pred`` and score the resulting paths on ``W``.

    ``W.size(0)`` is used as the batch size and ``W.size(-1)`` as the number
    of search steps (presumably the number of graph nodes -- confirm against
    the caller).

    Args:
        pred: tensor gathered along dim 1 with the beam's current state to
            obtain the scores fed to each ``BeamSearch.advance`` call.
        W: tensor that sizes the search and is passed to
            ``compute_cost_path`` together with the extracted paths.
        beam_size: beam width handed to ``BeamSearch``.

    Returns:
        Tuple ``(Costs, Paths)``: the output of ``compute_cost_path`` and the
        hypotheses returned by ``BeamSearch.get_hyp``.
    """
    n_batch, n_nodes = W.size(0), W.size(-1)
    searcher = BeamSearch(beam_size, n_batch, n_nodes)

    # Scores for the initial beam state, then one advance per remaining step;
    # the step counter handed to advance() is 1-based.
    step_probs = pred.gather(1, searcher.get_current_state())
    for step_no in range(1, n_nodes):
        searcher.advance(step_probs, step_no)
        step_probs = pred.gather(1, searcher.get_current_state())

    # Extract the paths, using an all-zero index column as the end marker.
    end_index = torch.zeros(n_batch, 1).type(dtype_l)
    Paths = searcher.get_hyp(end_index)

    # Cost of each extracted path under W.
    return compute_cost_path(Paths, W), Paths
def _beam_search_decoding(self, imgs, beam_size):
    """Decode a batch of images with beam search and return the formulas.

    The image batch is tiled ``beam_size`` times along dim 0 so that every
    beam is decoded as its own row of an enlarged batch. Decoding runs for
    at most ``self.max_len`` steps and stops early once ``beam.done`` is set.

    Args:
        imgs: image batch; ``imgs.size(0)`` is taken as the batch size.
        beam_size: number of beams kept per input.

    Returns:
        Whatever ``self._idx2formulas`` produces for the stacked second
        element of each entry in ``beam.hypotheses``.
    """
    B = imgs.size(0)
    # Use batch_size * beam_size as the new batch dimension.
    imgs = tile(imgs, beam_size, dim=0)
    enc_outs, hiddens = self.model.encode(imgs)
    dec_states, O_t = self.model.init_decoder(enc_outs, hiddens)

    # The input of every step, including the first, comes from
    # beam.current_predictions, so no separate start-token tensor is built.
    beam = BeamSearch(beam_size, B)
    for _ in range(self.max_len):
        tgt = beam.current_predictions.unsqueeze(1)
        dec_states, O_t, probs = self.step_decoding(
            dec_states, O_t, enc_outs, tgt)
        log_probs = torch.log(probs)
        beam.advance(log_probs)

        if beam.is_finished.any():
            beam.update_finished()
            if beam.done:
                break

            # Reorder the recurrent state so its rows follow the beams
            # selected by the search.
            # NOTE(review): this only happens on steps where a beam finished,
            # and enc_outs is never re-indexed -- confirm against BeamSearch
            # that both are intended.
            select_indices = beam.current_origin
            h, c = dec_states
            dec_states = (
                h.index_select(0, select_indices),
                c.index_select(0, select_indices),
            )
            O_t = O_t.index_select(0, select_indices)

    # Collect the token indices of each hypothesis and convert to formulas.
    formulas_idx = torch.stack(
        [hyps[1] for hyps in beam.hypotheses], dim=0)
    results = self._idx2formulas(formulas_idx)
    return results