def setUp(self):
    """Create the toy initial state, transition function, cost function and ERM trainer.

    NOTE(review): this is a detached copy of ``TestExpectedRiskMinimization.setUp``; the
    zero-argument ``super()`` call only works when the function is defined inside a class body.
    """
    super().setUp()
    self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])])
    self.decoder_step = SimpleTransitionFunction()

    def odd_action_count(state):
        # Cost = how many odd actions appear in the first action history of the state.
        return torch.Tensor([sum(action % 2 != 0 for action in state.action_history[0])])

    self.supervision = odd_action_count
    # A beam this wide makes the search exhaustive for the toy problem.
    self.trainer = ExpectedRiskMinimization(
        beam_size=100,
        normalize_by_length=False,
        max_decoding_steps=10,
    )
class TestExpectedRiskMinimization(AllenNlpTestCase):
    """
    Exercises ``ExpectedRiskMinimization`` on the toy ``SimpleTransitionFunction``, where the
    cost of a finished path is the number of odd actions it contains.
    """

    @staticmethod
    def _odd_action_cost(state):
        # Cost of a state: how many odd actions occur in its first action history.
        return torch.Tensor([sum(action % 2 != 0 for action in state.action_history[0])])

    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])])
        self.decoder_step = SimpleTransitionFunction()
        self.supervision = self._odd_action_cost
        # A beam this wide makes the search exhaustive for this toy problem.
        self.trainer = ExpectedRiskMinimization(
            beam_size=100,
            normalize_by_length=False,
            max_decoding_steps=10,
        )

    def test_get_finished_states(self):
        finished = self.trainer._get_finished_states(self.initial_state, self.decoder_step)
        observed = [(state.action_history[0], state.score[0].item()) for state in finished]
        # Exactly these five paths finish (each ends in action 4). A path's score is minus
        # (number of actions in its history - 1).
        expected = [
            ([0, 2, 4], -2),
            ([0, 1, 2, 4], -3),
            ([0, 1, 3, 4], -3),
            ([0, 2, 3, 4], -3),
            ([0, 1, 2, 3, 4], -4),
        ]
        assert len(finished) == len(expected)
        for path_and_score in expected:
            assert path_and_score in observed

    def test_decode(self):
        decoded = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision)
        # The shortest path scores highest, so it is the best final state.
        top_state = decoded["best_final_states"][0][0]
        assert top_state.action_history[0] == [0, 2, 4]

        # (score, cost) of every finished path.
        score_cost_pairs = [
            (-2, 0),  # [0, 2, 4]
            (-3, 1),  # [0, 1, 2, 4]
            (-3, 2),  # [0, 1, 3, 4]
            (-3, 1),  # [0, 2, 3, 4]
            (-4, 2),  # [0, 1, 2, 3, 4]
        ]
        # Probabilities are re-normalized over the beam, so the expected risk is the
        # exp(score)-weighted average of the costs.
        partition = sum(np.exp(score) for score, _ in score_cost_pairs)
        weighted_cost = sum(np.exp(score) * cost for score, cost in score_cost_pairs)
        assert_almost_equal(decoded["loss"].data.numpy(), weighted_cost / partition)
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 max_num_finished_states: int = None,
                 dropout: float = 0.0,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 dynamic_cost_weight: Dict[str, Union[int, float]] = None,
                 penalize_non_agenda_actions: bool = False,
                 initial_mml_model_file: str = None) -> None:
        """
        Build the coverage-based NLVR parser that is trained with expected risk minimization.

        Parameters
        ----------
        vocab, sentence_embedder, action_embedding_dim, encoder, dropout
            Forwarded to the parent constructor. ``action_embedding_dim`` and ``dropout`` are
            also passed to the ``CoverageTransitionFunction``.
        attention : ``Attention``
            Used as the transition function's ``input_attention``.
        beam_size, max_decoding_steps, max_num_finished_states, normalize_beam_score_by_length
            Configure the ``ExpectedRiskMinimization`` decoder trainer
            (``normalize_beam_score_by_length`` is passed as its ``normalize_by_length``).
        checklist_cost_weight : ``float``
            Stored as ``self._checklist_cost_weight``.
        dynamic_cost_weight : ``Dict[str, Union[int, float]]``, optional
            When non-empty it must provide ``"wait_num_epochs"`` and ``"rate"`` (a ``KeyError``
            is raised otherwise); when ``None``/empty both dynamic-cost settings stay ``None``.
        penalize_non_agenda_actions : ``bool``
            Stored as ``self._penalize_non_agenda_actions``.
        initial_mml_model_file : ``str``, optional
            Archive whose weights initialize this model. If the path does not exist only a
            warning is logged.
        """
        super(NlvrCoverageSemanticParser, self).__init__(
                vocab=vocab,
                sentence_embedder=sentence_embedder,
                action_embedding_dim=action_embedding_dim,
                encoder=encoder,
                dropout=dropout)
        self._agenda_coverage = Average()
        erm_trainer = ExpectedRiskMinimization(
                beam_size=beam_size,
                normalize_by_length=normalize_beam_score_by_length,
                max_decoding_steps=max_decoding_steps,
                max_num_finished_states=max_num_finished_states)
        self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = erm_trainer

        # An empty NlvrWorld is created only to read off its terminal productions.
        self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
        self._decoder_step = CoverageTransitionFunction(
                encoder_output_dim=self._encoder.get_output_dim(),
                action_embedding_dim=action_embedding_dim,
                input_attention=attention,
                num_start_types=1,
                activation=Activation.by_name('tanh')(),
                predict_start_type_separately=False,
                add_action_bias=False,
                dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight

        wait_epochs = None
        cost_rate = None
        if dynamic_cost_weight:
            wait_epochs = dynamic_cost_weight["wait_num_epochs"]
            cost_rate = dynamic_cost_weight["rate"]
        self._dynamic_cost_wait_epochs = wait_epochs
        self._dynamic_cost_rate = cost_rate

        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: int = None

        # TODO (pradeep): The existence check avoids an error when a trained ERM model was copied
        # from another machine on which the original MML model lived. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is None:
            return
        if not os.path.isfile(initial_mml_model_file):
            # Expected when decoding with an already-trained ERM model, but the path may also
            # simply be wrong, hence a warning rather than silence.
            logger.warning("MML model file for initializing weights is passed, but does not exist."
                           " This is fine if you're just decoding.")
            return
        self._initialize_weights_from_archive(load_archive(initial_mml_model_file))
Exemplo n.º 4
0
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              attention: Attention,
              decoder_beam_size: int,
              decoder_num_finished_states: int,
              max_decoding_steps: int,
              mixture_feedforward: FeedForward = None,
              add_action_bias: bool = True,
              normalize_beam_score_by_length: bool = False,
              checklist_cost_weight: float = 0.6,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              mml_model_file: str = None) -> None:
     """
     Build the WikiTables parser trained with expected risk minimization.

     Parameters
     ----------
     vocab, question_embedder, action_embedding_dim, encoder, entity_encoder, max_decoding_steps,
     add_action_bias, use_neighbor_similarity_for_linking, dropout, num_linking_features,
     rule_namespace
         Forwarded to the parent constructor.
     attention : ``Attention``
         Used as the ``LinkingCoverageTransitionFunction``'s ``input_attention``.
     decoder_beam_size : ``int``
         Beam size for both the ``ExpectedRiskMinimization`` trainer and the ``BeamSearch``
         kept for interactive (demo) decoding.
     decoder_num_finished_states : ``int``
         Passed to the trainer as ``max_num_finished_states``.
     mixture_feedforward : ``FeedForward``, optional
         Passed to the transition function.
     normalize_beam_score_by_length : ``bool``
         Passed to the trainer as ``normalize_by_length``.
     checklist_cost_weight : ``float``
         Stored as ``self._checklist_cost_weight``.
     mml_model_file : ``str``, optional
         Archive whose weights initialize this model. If the path does not exist only a warning
         is logged.
     """
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      add_action_bias=add_action_bias,
                      use_neighbor_similarity_for_linking=use_neighbor_similarity_for_linking,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace)
     # mypy needs an explicit annotation on this attribute.
     self._decoder_trainer: ExpectedRiskMinimization = ExpectedRiskMinimization(
             beam_size=decoder_beam_size,
             normalize_by_length=normalize_beam_score_by_length,
             max_decoding_steps=self._max_decoding_steps,
             max_num_finished_states=decoder_num_finished_states)
     self._decoder_step = LinkingCoverageTransitionFunction(
             encoder_output_dim=self._encoder.get_output_dim(),
             action_embedding_dim=action_embedding_dim,
             input_attention=attention,
             add_action_bias=self._add_action_bias,
             mixture_feedforward=mixture_feedforward,
             dropout=dropout)
     self._checklist_cost_weight = checklist_cost_weight
     self._agenda_coverage = Average()
     # The trainer already performs beam search; this separate ``BeamSearch`` exists only for
     # interactive beam search in the demo and is used at test time.
     self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)

     # TODO (pradeep): The existence check avoids an error when a trained ERM model was copied
     # from another machine on which the original MML model lived. This may not be the best
     # solution for the problem.
     if mml_model_file is None:
         return
     if not os.path.isfile(mml_model_file):
         # Expected when decoding with an already-trained ERM model, but the path may also be
         # simply wrong, hence a warning rather than silence.
         logger.warning(
             "MML model file for initializing weights is passed, but does not exist."
             " This is fine if you're just decoding.")
         return
     self._initialize_weights_from_archive(load_archive(mml_model_file))
Exemplo n.º 5
0
 def __init__(self,
              vocab: Vocabulary,
              question_embedder: TextFieldEmbedder,
              action_embedding_dim: int,
              encoder: Seq2SeqEncoder,
              entity_encoder: Seq2VecEncoder,
              attention: Attention,
              decoder_beam_size: int,
              decoder_num_finished_states: int,
              max_decoding_steps: int,
              mixture_feedforward: FeedForward = None,
              add_action_bias: bool = True,
              normalize_beam_score_by_length: bool = False,
              checklist_cost_weight: float = 0.6,
              use_neighbor_similarity_for_linking: bool = False,
              dropout: float = 0.0,
              num_linking_features: int = 10,
              rule_namespace: str = 'rule_labels',
              tables_directory: str = '/wikitables/',
              mml_model_file: str = None) -> None:
     """
     Build the WikiTables parser trained with expected risk minimization.

     Parameters
     ----------
     vocab, question_embedder, action_embedding_dim, encoder, entity_encoder, max_decoding_steps,
     add_action_bias, use_neighbor_similarity_for_linking, dropout, num_linking_features,
     rule_namespace, tables_directory
         Forwarded to the parent constructor. ``rule_namespace`` is also the vocabulary
         namespace scanned for terminal productions.
     attention : ``Attention``
         Used as the ``LinkingCoverageTransitionFunction``'s ``input_attention``.
     decoder_beam_size : ``int``
         Beam size of the ``ExpectedRiskMinimization`` trainer.
     decoder_num_finished_states : ``int``
         Passed to the trainer as ``max_num_finished_states``.
     mixture_feedforward : ``FeedForward``, optional
         Passed to the transition function.
     normalize_beam_score_by_length : ``bool``
         Passed to the trainer as ``normalize_by_length``.
     checklist_cost_weight : ``float``
         Stored as ``self._checklist_cost_weight``.
     mml_model_file : ``str``, optional
         Archive whose weights initialize this model. If the path does not exist only a warning
         is logged.
     """
     use_similarity = use_neighbor_similarity_for_linking
     super().__init__(vocab=vocab,
                      question_embedder=question_embedder,
                      action_embedding_dim=action_embedding_dim,
                      encoder=encoder,
                      entity_encoder=entity_encoder,
                      max_decoding_steps=max_decoding_steps,
                      add_action_bias=add_action_bias,
                      use_neighbor_similarity_for_linking=use_similarity,
                      dropout=dropout,
                      num_linking_features=num_linking_features,
                      rule_namespace=rule_namespace,
                      tables_directory=tables_directory)
     # Not sure why mypy needs a type annotation for this!
     self._decoder_trainer: ExpectedRiskMinimization = \
             ExpectedRiskMinimization(beam_size=decoder_beam_size,
                                      normalize_by_length=normalize_beam_score_by_length,
                                      max_decoding_steps=self._max_decoding_steps,
                                      max_num_finished_states=decoder_num_finished_states)
     # Count the terminal productions in the global rule vocabulary: those whose right-hand
     # side is a key of ``types.COMMON_NAME_MAPPING``. Entries without a " -> " separator are
     # not productions and are skipped instead of raising an IndexError. This matters when
     # ``rule_namespace`` is a padded namespace, because AllenNLP then adds padding/OOV tokens
     # to it.
     global_vocab = self.vocab.get_token_to_index_vocabulary(rule_namespace)
     num_unlinked_terminals = 0
     for production in global_vocab:
         sides = production.split(" -> ")
         if len(sides) > 1 and sides[1] in types.COMMON_NAME_MAPPING:
             # This is a terminal production.
             num_unlinked_terminals += 1
     self._num_unlinked_terminals = num_unlinked_terminals
     self._decoder_step = LinkingCoverageTransitionFunction(
         encoder_output_dim=self._encoder.get_output_dim(),
         action_embedding_dim=action_embedding_dim,
         input_attention=attention,
         num_start_types=self._num_start_types,
         predict_start_type_separately=True,
         add_action_bias=self._add_action_bias,
         mixture_feedforward=mixture_feedforward,
         dropout=dropout)
     self._checklist_cost_weight = checklist_cost_weight
     self._agenda_coverage = Average()
     # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
     # copied a trained ERM model from a different machine and the original MML model that was
     # used to initialize it does not exist on the current machine. This may not be the best
     # solution for the problem.
     if mml_model_file is not None:
         if os.path.isfile(mml_model_file):
             archive = load_archive(mml_model_file)
             self._initialize_weights_from_archive(archive)
         else:
             # A model file is passed, but it does not exist. This is expected to happen when
             # you're using a trained ERM model to decode. But it may also happen if the path to
             # the file is really just incorrect. So throwing a warning.
             logger.warning(
                 "MML model file for initializing weights is passed, but does not exist."
                 " This is fine if you're just decoding.")