def test_search(self):
        # Four independent instances (batch indices 0-3) starting at -3, 1, -20 and 5.  The simple
        # transition system tries to reach 4, and each action adds one or two to the current value.
        beam_search = BeamSearch.from_params(Params({'beam_size': 4}))
        start_values = [-3, 1, -20, 5]
        num_instances = len(start_values)
        initial_state = SimpleState(list(range(num_instances)),
                                    [[] for _ in range(num_instances)],
                                    [torch.Tensor([0.0]) for _ in range(num_instances)],
                                    start_values)
        decoder_step = SimpleTransitionFunction(include_value_in_score=True)

        # First, discard states that are unfinished after 5 steps.  Batch index 2 (start -20)
        # needs too many steps to reach 4.  Batch index 3 (start 5) is already past 4 and can never
        # get there.  So only the first two instances end up with a best state.
        finished_only = beam_search.search(5,
                                           initial_state,
                                           decoder_step,
                                           keep_final_unfinished_states=False)
        expected_finished = [[-1, 1, 3, 4], [3, 4]]
        assert len(finished_only) == 2
        for batch_index, expected_actions in enumerate(expected_finished):
            assert finished_only[batch_index][0].action_history[0] == expected_actions

        # Now keep final unfinished states as well, which gives a "best state" to the two
        # instances that had none before.  The best states of the instances that do finish don't
        # change, because the score for taking another step is always negative at these values.
        with_unfinished = beam_search.search(5,
                                             initial_state,
                                             decoder_step,
                                             keep_final_unfinished_states=True)
        expected_all = expected_finished + [[-18, -16, -14, -12, -10], [7, 9, 11, 13, 15]]
        assert len(with_unfinished) == 4
        for batch_index, expected_actions in enumerate(expected_all):
            assert with_unfinished[batch_index][0].action_history[0] == expected_actions
    def test_constraints(self):
        # The simple transition system starts at some number, adds one or two at each step, and
        # tries to reach 4.  The highest-scoring path is the shortest one with the largest
        # numbers, so it always adds two unless it is at 3.  From -3 many sequences are possible,
        # e.g. [-2, -1, 0, 1, 2, 3, 4] or [-1, 1, 3, 4].  We force a fixed prefix as the initial
        # sequence and use that to test the constrained beam search implementation.
        initial_state = SimpleState([0], [[]], [torch.Tensor([0.0])], [-3])
        initial_sequence = torch.Tensor([-2, -1, 0, 1])
        # The same forced prefix as plain ints, used to build the expected action histories.
        prefix = [-2, -1, 0, 1]

        # Beam size 3, default handling of unfinished states.  After the constraint runs out we
        # generate [3], [2].  Then [3, 5], [3, 4], [2, 4], the last two of which are finished.
        # Then [3, 5, 7], [3, 5, 6], at which point we are out of steps and keep the former.
        narrow_search = BeamSearch(3, initial_sequence=initial_sequence)
        decoder_step = SimpleTransitionFunction(include_value_in_score=True)
        narrow_states = narrow_search.search(7, initial_state, decoder_step)

        assert len(narrow_states) == 1
        expected_narrow = [prefix + [3, 4], prefix + [2, 4], prefix + [3, 5, 7]]
        for rank, expected_actions in enumerate(expected_narrow):
            assert narrow_states[0][rank].action_history[0] == expected_actions

        # Beam size 6, finished states only.  We generate [3], [2].  Then [3, 5], [2, 3], [3, 4],
        # [2, 4], the last two of which are finished.  Then [3, 5, 6], [3, 5, 7], [2, 3, 5],
        # [2, 3, 4], the last of which is finished.
        wide_search = BeamSearch(6,
                                 initial_sequence=initial_sequence,
                                 keep_beam_details=True)
        decoder_step = SimpleTransitionFunction(include_value_in_score=True)
        wide_states = wide_search.search(7,
                                         initial_state,
                                         decoder_step,
                                         keep_final_unfinished_states=False)

        assert len(wide_states) == 1
        assert len(wide_states[0]) == 3
        expected_wide = [prefix + [3, 4], prefix + [2, 4], prefix + [2, 3, 4]]
        for rank, expected_actions in enumerate(expected_wide):
            assert wide_states[0][rank].action_history[0] == expected_actions

        # Verify the recorded beam snapshots against the best finished sequence.
        best_action_sequence = wide_states[0][0].action_history[0]

        snapshots = wide_search.beam_snapshots
        assert len(snapshots) == 1

        batch_snapshots = snapshots.get(0)
        assert batch_snapshots is not None

        for step, beam in enumerate(batch_snapshots):
            # Every sequence in the beam at step ``step`` has exactly ``step + 1`` actions.
            assert all(len(sequence) == step + 1 for _, sequence in beam)
            if step >= len(best_action_sequence):
                continue
            # While the best sequence is still being built, its action at this step must be the
            # last action of at least one beam entry.
            assert any(sequence[-1] == best_action_sequence[step]
                       for _, sequence in beam)
 def __init__(
     self,
     vocab: Vocabulary,
     question_embedder: TextFieldEmbedder,
     action_embedding_dim: int,
     encoder: Seq2SeqEncoder,
     entity_encoder: Seq2VecEncoder,
     attention: Attention,
     decoder_beam_size: int,
     decoder_num_finished_states: int,
     max_decoding_steps: int,
     mixture_feedforward: FeedForward = None,
     add_action_bias: bool = True,
     normalize_beam_score_by_length: bool = False,
     checklist_cost_weight: float = 0.6,
     use_neighbor_similarity_for_linking: bool = False,
     dropout: float = 0.0,
     num_linking_features: int = 10,
     rule_namespace: str = "rule_labels",
     mml_model_file: str = None,
 ) -> None:
     """
     Set up a parser whose decoder is trained with expected risk minimization (ERM).

     The base class builds the shared components.  This constructor then creates:

     - the ERM decoder trainer;
     - a ``LinkingCoverageTransitionFunction`` decoder step;
     - a test-time ``BeamSearch``.

     Optionally, it also initializes weights from an archived model file.

     Parameters
     ----------
     vocab, question_embedder, encoder, entity_encoder, max_decoding_steps, add_action_bias,
     use_neighbor_similarity_for_linking, num_linking_features, rule_namespace
         Forwarded unchanged to the base-class constructor.  The decoder step reads
         ``self._encoder`` and ``self._add_action_bias``, which the base class presumably sets
         from these arguments.
     action_embedding_dim, dropout
         Forwarded to the base class and also to the decoder step.
     attention
         Used as the decoder step's ``input_attention``.
     decoder_beam_size
         Beam size for both the ERM trainer and the test-time ``BeamSearch``.
     decoder_num_finished_states
         Passed to the ERM trainer as ``max_num_finished_states``.
     mixture_feedforward
         Passed to the decoder step; may be ``None``.
     normalize_beam_score_by_length
         Passed to the ERM trainer as ``normalize_by_length``.
     checklist_cost_weight
         Stored as ``self._checklist_cost_weight``.
     mml_model_file
         Optional path to an archived model (MML-trained, per the name).  Three cases:

         - The file exists: it is loaded with ``load_archive`` and passed to
           ``_initialize_weights_from_archive``.
         - A path is given but the file does not exist: a warning is logged.
         - ``None``: nothing happens.
     """
     super().__init__(
         vocab=vocab,
         question_embedder=question_embedder,
         action_embedding_dim=action_embedding_dim,
         encoder=encoder,
         entity_encoder=entity_encoder,
         max_decoding_steps=max_decoding_steps,
         add_action_bias=add_action_bias,
         use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
         dropout=dropout,
         num_linking_features=num_linking_features,
         rule_namespace=rule_namespace,
     )
     # mypy asks for an explicit annotation on this attribute; the reason is unclear.
     self._decoder_trainer: ExpectedRiskMinimization = ExpectedRiskMinimization(
         beam_size=decoder_beam_size,
         normalize_by_length=normalize_beam_score_by_length,
         max_decoding_steps=self._max_decoding_steps,
         max_num_finished_states=decoder_num_finished_states,
     )
     self._decoder_step = LinkingCoverageTransitionFunction(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=attention,
         add_action_bias=self._add_action_bias,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout,
     )
     self._checklist_cost_weight = checklist_cost_weight
     self._agenda_coverage = Average()
     # The trainer already performs beam search.  This separate instance exists only so the demo
     # can use interactive beam search, which only ``BeamSearch`` implements.  Test time only.
     self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)

     # TODO (pradeep): Checking whether the file exists avoids raising an error when a trained
     # ERM model is copied from another machine on which the original MML model (used to
     # initialize it) lived.  This may not be the best solution for the problem.
     mml_file_exists = mml_model_file is not None and os.path.isfile(mml_model_file)
     if mml_file_exists:
         self._initialize_weights_from_archive(load_archive(mml_model_file))
     elif mml_model_file is not None:
         # A path was given but the file is missing.  This is expected when decoding with a
         # trained ERM model, but it may also mean the path is simply wrong, so warn rather
         # than fail.
         logger.warning(
             "MML model file for initializing weights is passed, but does not exist."
             " This is fine if you're just decoding.")