Example 1
0
    def __init__(
        self,
        beam_size: int,
        per_node_beam_size: int = None,
        initial_sequence: torch.Tensor = None,
        keep_beam_details: bool = False,
    ) -> None:
        """
        Configure the beam search.

        Parameters
        ----------
        beam_size : ``int``
            Stored as ``self._beam_size``.
        per_node_beam_size : ``int``, optional
            Stored as ``self._per_node_beam_size``. Any falsy value (``None`` or ``0``)
            falls back to ``beam_size``.
        initial_sequence : ``torch.Tensor``, optional
            If given, all of its elements are flattened into one sequence. That sequence is
            passed to ``util.construct_prefix_tree`` as a batch of one instance holding one
            sequence, and the result for that single instance is stored as
            ``self._allowed_transitions``. If ``None``, ``self._allowed_transitions`` is ``None``.
        keep_beam_details : ``bool``, optional (default = False)
            If True, ``self.beam_snapshots`` starts as an empty dict; otherwise it is ``None``.
        """
        self._beam_size = beam_size
        self._per_node_beam_size = per_node_beam_size or beam_size

        if initial_sequence is not None:
            # construct_prefix_tree wants a tensor of shape (batch_size, num_sequences, sequence_length),
            # so we add the first two dimensions. ``reshape`` behaves exactly like ``view`` when a view is
            # possible, but it also accepts inputs whose memory layout cannot be viewed that way (e.g.
            # non-contiguous multi-dimensional slices), copying only when necessary. The helper returns a
            # list (one entry per batch instance); with batch size 1 we keep only the first element.
            self._allowed_transitions = util.construct_prefix_tree(
                initial_sequence.reshape(1, 1, -1))[0]
        else:
            self._allowed_transitions = None

        if keep_beam_details:
            # mapping from batch_index to a list (timesteps) of lists (beam elements)
            # of pairs (score, action_history)
            self.beam_snapshots: Dict[int, List[List[Tuple[float,
                                                           List[int]]]]] = {}
        else:
            self.beam_snapshots = None
Example 2
0
    def test_create_allowed_transitions(self):
        """Check the prefix trees built for a two-instance batch of masked target sequences."""
        targets = torch.Tensor([
            [[2, 3, 4], [1, 3, 4], [1, 2, 4]],
            [[3, 4, 0], [2, 3, 4], [0, 0, 0]],
        ])
        target_mask = torch.Tensor([
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
            [[1, 1, 0], [1, 1, 1], [0, 0, 0]],
        ])
        prefix_tree = util.construct_prefix_tree(targets, target_mask)

        # For each batch instance: every valid action prefix and the set of actions
        # allowed to follow it. The second instance's masked positions contribute nothing.
        expected_trees = [
            {
                (): {1, 2},
                (1, ): {2, 3},
                (1, 2): {4},
                (1, 3): {4},
                (2, ): {3},
                (2, 3): {4},
            },
            {
                (): {2, 3},
                (2, ): {3},
                (2, 3): {4},
                (3, ): {4},
            },
        ]

        # One tree per batch instance.
        assert len(prefix_tree) == len(expected_trees)
        for instance_tree, expected in zip(prefix_tree, expected_trees):
            # Exactly the expected number of prefixes, each mapping to the expected continuations.
            assert len(instance_tree) == len(expected)
            for prefix, allowed_next in expected.items():
                assert instance_tree[prefix] == allowed_next
Example 3
0
 def __init__(
     self,
     beam_size: Optional[int],
     allowed_sequences: torch.Tensor,
     allowed_sequence_mask: torch.Tensor,
     per_node_beam_size: int = None,
 ) -> None:
     """
     Record the beam sizes and precompute the allowed action transitions.

     ``per_node_beam_size`` falls back to ``beam_size`` whenever it is falsy
     (``None`` or ``0``). ``allowed_sequences`` and ``allowed_sequence_mask`` are
     passed unchanged to ``util.construct_prefix_tree``, and its result is kept
     as ``self._allowed_transitions``.
     """
     self._beam_size = beam_size
     self._per_node_beam_size = (per_node_beam_size
                                 if per_node_beam_size
                                 else beam_size)
     transitions = util.construct_prefix_tree(allowed_sequences,
                                              allowed_sequence_mask)
     self._allowed_transitions = transitions
Example 4
0
 def __init__(
     self,
     beam_size: Optional[int],
     all_action_indices: List[List[int]],
     allowed_sequences: Union[torch.Tensor, List[List[List[int]]]],
     allowed_sequence_mask: Optional[Union[torch.Tensor, List[List[List[int]]]]] = None,
     per_node_beam_size: int = None,
 ) -> None:
     """
     Record the beam configuration and precompute the allowed action transitions.

     Parameters
     ----------
     beam_size : ``Optional[int]``
         Stored as ``self._beam_size``.
     all_action_indices : ``List[List[int]]``
         Stored unchanged as ``self._all_action_indices``.
     allowed_sequences : ``Union[torch.Tensor, List[List[List[int]]]]``
         Passed to ``util.construct_prefix_tree``.
     allowed_sequence_mask : ``Optional[Union[torch.Tensor, List[List[List[int]]]]]``, optional
         Passed to ``util.construct_prefix_tree`` alongside ``allowed_sequences``.
     per_node_beam_size : ``int``, optional
         Stored as ``self._per_node_beam_size``. Any falsy value (``None`` or ``0``)
         falls back to ``beam_size``.
     """
     self._beam_size = beam_size
     self._per_node_beam_size = beam_size if not per_node_beam_size else per_node_beam_size
     self._all_action_indices = all_action_indices
     # This is a list of defaultdict (one for each batch instance) mapping an action
     # prefix to the actions allowed at the next step.
     self._allowed_transitions = util.construct_prefix_tree(
         allowed_sequences, allowed_sequence_mask)