Beispiel #1
0
    def search(
            self, initial_state: State,
            transition_function: TransitionFunction) -> Dict[int, List[State]]:
        """
        Run a constrained, batch-aware beam search until no unfinished states remain.

        At every step the actions a state may take are read from
        ``self._allowed_transitions[batch_index][tuple(action_history)]``. That lookup is
        unguarded, so every action prefix the search reaches must have an entry there.

        Parameters
        ----------
        initial_state : ``State``
            The starting state of our search.  This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        transition_function : ``TransitionFunction``
            The ``TransitionFunction`` object that defines and scores transitions from one state to the
            next.

        Returns
        -------
        best_states : ``Dict[int, List[State]]``
            This is a mapping from batch index to the top states for that instance, ordered by
            descending score and truncated to ``self._beam_size`` (no truncation when it is ``None``).
        """
        finished_states: Dict[int, List[State]] = defaultdict(list)
        states = [initial_state]
        while states:
            next_states: Dict[int, List[State]] = defaultdict(list)
            grouped_state = states[0].combine_states(states)
            # One allowed-action entry per group element, keyed on the action prefix taken so far.
            allowed_actions = [
                self._allowed_transitions[batch_index][tuple(action_history)]
                for batch_index, action_history in zip(
                    grouped_state.batch_indices, grouped_state.action_history)
            ]
            for next_state in transition_function.take_step(
                    grouped_state,
                    max_actions=self._per_node_beam_size,
                    allowed_actions=allowed_actions):
                # NOTE: we're doing state.batch_indices[0] here (and similar things below),
                # hard-coding a group size of 1.  But, our use of `next_state.is_finished()`
                # already checks for that, as it crashes if the group size is not 1.
                batch_index = next_state.batch_indices[0]
                if next_state.is_finished():
                    finished_states[batch_index].append(next_state)
                else:
                    next_states[batch_index].append(next_state)
            states = []
            for batch_states in next_states.values():
                # The states from the generator are already sorted, so we can just take the first
                # ones here, without an additional sort.
                if self._beam_size:
                    batch_states = batch_states[:self._beam_size]
                states.extend(batch_states)
        best_states: Dict[int, List[State]] = {}
        for batch_index, batch_states in finished_states.items():
            # Stable sort on the negated score == highest score first, ties keep generation order.
            ranked = sorted(batch_states, key=lambda state: -state.score[0].item())
            best_states[batch_index] = ranked[:self._beam_size]
        return best_states
    def _get_finished_states(
            self, initial_state: State,
            transition_function: TransitionFunction) -> List[StateType]:
        """
        Expand ``initial_state`` step by step and collect every state that finishes.

        The loop stops once no unfinished states remain or after ``self._max_decoding_steps``
        expansions. After each expansion the unfinished states are pruned to
        ``self._beam_size`` via ``self._prune_beam`` with ``sort_states=False``, because
        ``take_step`` output is treated as already sorted. When ``self._max_num_finished_states``
        is not ``None``, the collected finished states are passed through ``self._prune_beam``
        once more (with ``sort_states=True``) before being returned.
        """
        collected: List[StateType] = []
        frontier = [initial_state]
        steps_taken = 0
        while frontier and steps_taken < self._max_decoding_steps:
            unfinished = []
            combined = frontier[0].combine_states(frontier)
            # ``take_step`` yields its candidates already in sorted order.
            for candidate in transition_function.take_step(combined):
                bucket = collected if candidate.is_finished() else unfinished
                bucket.append(candidate)

            frontier = self._prune_beam(states=unfinished,
                                        beam_size=self._beam_size,
                                        sort_states=False)
            steps_taken += 1
        if self._max_num_finished_states is None:
            return collected
        return self._prune_beam(states=collected,
                                beam_size=self._max_num_finished_states,
                                sort_states=True)
Beispiel #3
0
    def search(
        self,
        num_steps: int,
        initial_state: StateType,
        transition_function: TransitionFunction,
        keep_final_unfinished_states: bool = True
    ) -> Dict[int, List[StateType]]:
        """
        Run a batch-aware beam search for at most ``num_steps`` steps.

        If ``self._allowed_transitions`` is non-empty, the search is forced along an initial
        action sequence. While the action history of the first grouped state is a key of that
        mapping, the single allowed-action entry is the mapped value. Once past it, no
        constraint is applied. Only the first group element is consulted and a one-element
        list is passed, so this presumably assumes a group of size 1 while constrained
        (NOTE(review): confirm).

        If ``self.beam_snapshots`` is not ``None``, it is reset at the start. For each batch
        index it then records a list of ``(score, action_history)`` pairs for every candidate
        generated at each step, and adds the finished states at position
        ``len(action_history) - 1``.

        Parameters
        ----------
        num_steps : ``int``
            How many steps should we take in our search?  This is an upper bound, as it's possible
            for the search to run out of valid actions before hitting this number, or for all
            states on the beam to finish.
        initial_state : ``StateType``
            The starting state of our search.  This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        transition_function : ``TransitionFunction``
            The ``TransitionFunction`` object that defines and scores transitions from one state to the
            next.
        keep_final_unfinished_states : ``bool``, optional (default=True)
            If we run out of steps before a state is "finished", should we return that state in our
            search results?

        Returns
        -------
        best_states : ``Dict[int, List[StateType]]``
            This is a mapping from batch index to the top states for that instance, ordered by
            descending score.
        """
        finished_states: Dict[int, List[StateType]] = defaultdict(list)
        states = [initial_state]
        step_num = 1

        # Start from empty snapshots when tracking is enabled.
        if self.beam_snapshots is not None:
            self.beam_snapshots = defaultdict(list)

        while states and step_num <= num_steps:
            grouped_state = states[0].combine_states(states)

            # Default: unconstrained. Only narrow it while still inside the initial sequence.
            allowed_actions = None
            if self._allowed_transitions:
                prefix = tuple(grouped_state.action_history[0])
                if prefix in self._allowed_transitions:
                    allowed_actions = [self._allowed_transitions[prefix]]

            keep_unfinished_now = step_num == num_steps and keep_final_unfinished_states
            next_states: Dict[int, List[StateType]] = defaultdict(list)
            for candidate in transition_function.take_step(
                    grouped_state,
                    max_actions=self._per_node_beam_size,
                    allowed_actions=allowed_actions):
                # NOTE: indexing [0] hard-codes a group size of 1; `is_finished()` crashes
                # for any other group size, so that assumption is already enforced.
                batch_index = candidate.batch_indices[0]
                if candidate.is_finished():
                    finished_states[batch_index].append(candidate)
                    continue
                if keep_unfinished_now:
                    finished_states[batch_index].append(candidate)
                next_states[batch_index].append(candidate)

            states = []
            for batch_index, candidates in next_states.items():
                # Generator output is already sorted, so the best ones are simply the first ones.
                states.extend(candidates[:self._beam_size])
                if self.beam_snapshots is not None:
                    self.beam_snapshots[batch_index].append(
                        [(candidate.score[0].item(), candidate.action_history[0])
                         for candidate in candidates])
            step_num += 1

        # Also record finished states in the snapshot slot that matches their history length.
        if self.beam_snapshots is not None:
            for batch_index, batch_finished in finished_states.items():
                batch_snapshots = self.beam_snapshots[batch_index]
                for finished in batch_finished:
                    score = finished.score[0].item()
                    history = finished.action_history[0]
                    while len(batch_snapshots) < len(history):
                        batch_snapshots.append([])
                    batch_snapshots[len(history) - 1].append((score, history))

        best_states: Dict[int, List[StateType]] = {}
        for batch_index, batch_finished in finished_states.items():
            # Stable sort on the negated score: highest score first.
            ranked = sorted(batch_finished, key=lambda state: -state.score[0].item())
            best_states[batch_index] = ranked[:self._beam_size]
        return best_states
Beispiel #4
0
    def search(
        self,
        num_steps: int,
        initial_state: StateType,
        transition_function: TransitionFunction,
        firststep_allowed_actions: List[Set[int]] = None,
        keep_final_unfinished_states: bool = True,
    ) -> Mapping[int, Sequence[StateType]]:
        """
        Run a batch-aware beam search for at most ``num_steps`` steps, optionally restricting
        the actions of the first step only.

        Parameters
        ----------
        num_steps : ``int``
            How many steps should we take in our search?  This is an upper bound, as it's possible
            for the search to run out of valid actions before hitting this number, or for all
            states on the beam to finish.
        initial_state : ``StateType``
            The starting state of our search.  This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        transition_function : ``TransitionFunction``
            The ``TransitionFunction`` object that defines and scores transitions from one state to the
            next.
        firststep_allowed_actions : ``List[Set[int]]``, optional (default=None)
            For each instance in the initial_state, a set of allowed action indices for the first
            decoding step. This is useful when we have supervision for the `type` of the action
            sequence and only want first actions that lead to a sequence of that type. The value
            (including ``None``) is forwarded as ``allowed_actions`` on the first step only.
        keep_final_unfinished_states : ``bool``, optional (default=True)
            If we run out of steps before a state is "finished", should we return that state in our
            search results?

        Returns
        -------
        best_states : ``Dict[int, List[StateType]]``
            This is a mapping from batch index to the top states for that instance, ordered by
            descending score.
        """
        finished_states: Dict[int, List[StateType]] = defaultdict(list)
        states = [initial_state]
        step_num = 1

        while states and step_num <= num_steps:
            grouped_state = states[0].combine_states(states)
            # Only the very first expansion receives an ``allowed_actions`` argument.
            step_kwargs = {"max_actions": self._per_node_beam_size}
            if step_num == 1:
                step_kwargs["allowed_actions"] = firststep_allowed_actions
            is_last_step = step_num == num_steps

            next_states: Dict[int, List[StateType]] = defaultdict(list)
            for candidate in transition_function.take_step(grouped_state, **step_kwargs):
                # NOTE: indexing [0] hard-codes a group size of 1; `is_finished()` crashes
                # for any other group size, so that assumption is already enforced.
                batch_index = candidate.batch_indices[0]
                if candidate.is_finished():
                    finished_states[batch_index].append(candidate)
                    continue
                if is_last_step and keep_final_unfinished_states:
                    finished_states[batch_index].append(candidate)
                next_states[batch_index].append(candidate)

            # Generator output is already sorted, so each instance keeps its first ``beam_size``.
            states = [
                candidate
                for candidates in next_states.values()
                for candidate in candidates[: self._beam_size]
            ]
            step_num += 1

        # Stable sort on the negated score: highest score first, ties keep generation order.
        return {
            batch_index: sorted(batch_finished, key=lambda state: -state.score[0].item())[: self._beam_size]
            for batch_index, batch_finished in finished_states.items()
        }
    def search(
        self,
        initial_state: State,
        transition_function: TransitionFunction,
        num_steps: int = None,
        keep_final_unfinished_states: bool = True,
    ) -> Dict[int, List[State]]:
        """
        Run a prefix-constrained, batch-aware beam search.

        For each state, if its action history is a key of ``self._allowed_transitions[batch_index]``
        only the mapped actions are allowed. Otherwise ``self._all_action_indices[batch_index]`` is
        used.

        Parameters
        ----------
        initial_state : ``State``
            The starting state of our search.  This is assumed to be `batched`, and our beam search
            is batch-aware - we'll keep ``beam_size`` states around for each instance in the batch.
        transition_function : ``TransitionFunction``
            The ``TransitionFunction`` object that defines and scores transitions from one state to the
            next.
        num_steps : ``int``, optional (default=None)
            How many steps should we take in our search?  This is an upper bound, as it's possible
            for the search to run out of valid actions before hitting this number, or for all
            states on the beam to finish. ``None`` means no limit: search until no unfinished
            states remain. ``0`` (or a negative value) takes no steps.
        keep_final_unfinished_states : ``bool``, optional (default=True)
            If we run out of steps before a state is "finished", should we return that state in our
            search results?

        Returns
        -------
        best_states : ``Dict[int, List[State]]``
            This is a mapping from batch index to the top states for that instance, ordered by
            descending score and truncated to ``self._beam_size`` (no truncation when it is ``None``).
        """
        finished_states: Dict[int, List[State]] = defaultdict(list)
        states = [initial_state]
        step_num = 0

        # Compare against ``None`` explicitly: the previous truthiness check treated
        # ``num_steps=0`` as "unlimited", unlike the other step-bounded searches.
        while states and (num_steps is None or step_num < num_steps):
            step_num += 1
            is_last_step = num_steps is not None and step_num == num_steps
            next_states: Dict[int, List[State]] = defaultdict(list)
            grouped_state = states[0].combine_states(states)

            allowed_actions = []
            for batch_index, action_history in zip(
                    grouped_state.batch_indices, grouped_state.action_history):
                # A path that is still inside its constrained prefix may only follow that prefix;
                # any other path may take every action available to its instance.
                constrained = self._allowed_transitions[batch_index]
                prefix = tuple(action_history)
                if prefix in constrained:
                    allowed_actions.append(constrained[prefix])
                else:
                    allowed_actions.append(self._all_action_indices[batch_index])

            for next_state in transition_function.take_step(
                    grouped_state,
                    max_actions=self._per_node_beam_size,
                    allowed_actions=allowed_actions):
                # NOTE: we're doing state.batch_indices[0] here (and similar things below),
                # hard-coding a group size of 1.  But, our use of `next_state.is_finished()`
                # already checks for that, as it crashes if the group size is not 1.
                batch_index = next_state.batch_indices[0]
                if next_state.is_finished():
                    finished_states[batch_index].append(next_state)
                else:
                    if is_last_step and keep_final_unfinished_states:
                        finished_states[batch_index].append(next_state)
                    next_states[batch_index].append(next_state)

            states = []
            for batch_states in next_states.values():
                # The states from the generator are already sorted, so we can just take the first
                # ones here, without an additional sort.
                if self._beam_size:
                    batch_states = batch_states[:self._beam_size]
                states.extend(batch_states)

        best_states: Dict[int, List[State]] = {}
        for batch_index, batch_states in finished_states.items():
            # Stable sort on the negated score == highest score first, ties keep generation order.
            ranked = sorted(batch_states, key=lambda state: -state.score[0].item())
            best_states[batch_index] = ranked[:self._beam_size]
        return best_states