def __init__(
    self,
    beam_size: int,
    per_node_beam_size: int = None,
    initial_sequence: torch.Tensor = None,
    keep_beam_details: bool = False,
) -> None:
    """Record the beam-search settings on the instance.

    Parameters
    ----------
    beam_size : int
        Stored as ``self._beam_size``.
    per_node_beam_size : int, optional
        Stored as ``self._per_node_beam_size``. Any falsy value (the default
        ``None``, but also ``0``) is replaced by ``beam_size``.
    initial_sequence : torch.Tensor, optional
        When given, it is reshaped with ``view(1, 1, -1)`` and handed to
        ``util.construct_prefix_tree``; the first element of the result is
        stored as ``self._allowed_transitions``. When omitted,
        ``self._allowed_transitions`` is ``None``.
    keep_beam_details : bool
        When true, ``self.beam_snapshots`` starts as an empty dict; otherwise
        it is ``None``.
    """
    self._beam_size = beam_size
    self._per_node_beam_size = per_node_beam_size if per_node_beam_size else beam_size

    allowed_transitions = None
    if initial_sequence is not None:
        # construct_prefix_tree expects (batch_size, num_sequences, sequence_length),
        # so the two leading dimensions are added here. It hands back a list; a
        # batch size of 1 is assumed, hence only element 0 is kept.
        batched_sequence = initial_sequence.view(1, 1, -1)
        allowed_transitions = util.construct_prefix_tree(batched_sequence)[0]
    self._allowed_transitions = allowed_transitions

    if keep_beam_details:
        # batch_index -> list over timesteps -> list over beam elements
        # -> (score, action_history) pair
        self.beam_snapshots: Dict[int, List[List[Tuple[float, List[int]]]]] = {}
    else:
        self.beam_snapshots = None
def test_create_allowed_transitions(self):
    """Check the prefix trees ``util.construct_prefix_tree`` builds for a two-instance batch."""
    targets = torch.Tensor(
        [
            [[2, 3, 4], [1, 3, 4], [1, 2, 4]],
            [[3, 4, 0], [2, 3, 4], [0, 0, 0]],
        ]
    )
    target_mask = torch.Tensor(
        [
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
            [[1, 1, 0], [1, 1, 1], [0, 0, 0]],
        ]
    )
    prefix_tree = util.construct_prefix_tree(targets, target_mask)

    # One entry per batch instance: action-sequence prefix -> expected next actions.
    # The first instance has six valid prefixes, the second has four.
    expected_trees = [
        {
            (): {1, 2},
            (1,): {2, 3},
            (1, 2): {4},
            (1, 3): {4},
            (2,): {3},
            (2, 3): {4},
        },
        {
            (): {2, 3},
            (2,): {3},
            (2, 3): {4},
            (3,): {4},
        },
    ]

    # There were two instances in this batch.
    assert len(prefix_tree) == len(expected_trees)
    for instance_index, expected in enumerate(expected_trees):
        instance_tree = prefix_tree[instance_index]
        # Size is checked before any lookup, in the same order as the expectations.
        assert len(instance_tree) == len(expected)
        for prefix, next_actions in expected.items():
            assert instance_tree[prefix] == next_actions
def __init__(self,
             beam_size: Optional[int],
             allowed_sequences: torch.Tensor,
             allowed_sequence_mask: torch.Tensor,
             per_node_beam_size: int = None) -> None:
    """Store the beam sizes and build the allowed-transition prefix tree.

    Parameters
    ----------
    beam_size : Optional[int]
        Stored as ``self._beam_size``.
    allowed_sequences : torch.Tensor
        Passed as the first argument to ``util.construct_prefix_tree``.
    allowed_sequence_mask : torch.Tensor
        Passed as the second argument to ``util.construct_prefix_tree``.
    per_node_beam_size : int, optional
        Stored as ``self._per_node_beam_size``; a falsy value (the default
        ``None``, but also ``0``) is replaced by ``beam_size``.
    """
    self._beam_size = beam_size
    self._per_node_beam_size = per_node_beam_size if per_node_beam_size else beam_size

    prefix_tree = util.construct_prefix_tree(allowed_sequences, allowed_sequence_mask)
    self._allowed_transitions = prefix_tree
def __init__(
    self,
    beam_size: Optional[int],
    all_action_indices: List[List[int]],
    allowed_sequences: Union[torch.Tensor, List[List[List[int]]]],
    allowed_sequence_mask: Optional[Union[torch.Tensor, List[List[List[int]]]]] = None,
    per_node_beam_size: int = None,
) -> None:
    """Store the beam sizes, the action indices and the allowed-transition prefix tree.

    Parameters
    ----------
    beam_size : Optional[int]
        Stored as ``self._beam_size``.
    all_action_indices : List[List[int]]
        Stored unchanged as ``self._all_action_indices``.
    allowed_sequences : tensor or nested list of ints
        Passed as the first argument to ``util.construct_prefix_tree``.
    allowed_sequence_mask : tensor or nested list of ints, optional
        Passed as the second argument to ``util.construct_prefix_tree``.
    per_node_beam_size : int, optional
        Stored as ``self._per_node_beam_size``; a falsy value (the default
        ``None``, but also ``0``) is replaced by ``beam_size``.
    """
    self._beam_size = beam_size
    self._per_node_beam_size = per_node_beam_size if per_node_beam_size else beam_size

    # A list of defaultdicts, one per batch instance, each mapping an
    # action prefix to the actions allowed at the next step.
    prefix_trees = util.construct_prefix_tree(allowed_sequences, allowed_sequence_mask)
    self._allowed_transitions = prefix_trees

    self._all_action_indices = all_action_indices