Code Example #1
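A test fixture exercising the trainer on a toy transition system: the supervision callable acts as the cost function (it counts odd actions in a state's history), and the oversized beam makes the search exhaustive within the 10-step decoding limit.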
    def setUp(self):
        super().setUp()
        self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])])
        self.decoder_step = SimpleTransitionFunction()
        # Cost is the number of odd elements in the action history.
        self.supervision = lambda state: torch.Tensor(
            [sum([x % 2 != 0 for x in state.action_history[0]])])
        # High beam size ensures exhaustive search.
        self.trainer = ExpectedRiskMinimization(beam_size=100,
                                                normalize_by_length=False,
                                                max_decoding_steps=10)
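With this fixture in place, a test would drive the trainer through the generic decode(initial_state, transition_function, supervision) interface. The sketch below is illustrative only; apart from the "loss" entry, the contents of the returned dict are an assumption.

    def test_decode(self):
        # decode() runs beam search with decoder_step and scores each finished
        # sequence with the supervision (cost) callable; the returned dict
        # carries the expected-risk "loss" used for training.
        results = self.trainer.decode(self.initial_state,
                                      self.decoder_step,
                                      self.supervision)
        assert "loss" in results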
Code Example #2
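The trainer embedded in a full model: this constructor (from an NLVR coverage-based semantic parser, per the class name in the super() call) pairs ExpectedRiskMinimization with a CoverageTransitionFunction, a checklist cost weight that can be annealed via dynamic_cost_weight, and optional weight initialization from a pre-trained MML model.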
    def __init__(self,
                 vocab: Vocabulary,
                 sentence_embedder: TextFieldEmbedder,
                 action_embedding_dim: int,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 max_num_finished_states: Optional[int] = None,
                 dropout: float = 0.0,
                 normalize_beam_score_by_length: bool = False,
                 checklist_cost_weight: float = 0.6,
                 dynamic_cost_weight: Optional[Dict[str, Union[int, float]]] = None,
                 penalize_non_agenda_actions: bool = False,
                 initial_mml_model_file: Optional[str] = None) -> None:
        super().__init__(vocab=vocab,
                         sentence_embedder=sentence_embedder,
                         action_embedding_dim=action_embedding_dim,
                         encoder=encoder,
                         dropout=dropout)
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = \
                ExpectedRiskMinimization(beam_size=beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=max_decoding_steps,
                                         max_num_finished_states=max_num_finished_states)

        # Instantiating an empty NlvrLanguage just to get the set of terminal productions.
        self._terminal_productions = set(
            NlvrLanguage(set()).terminal_productions.values())
        self._decoder_step = CoverageTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=Activation.by_name('tanh')(),
            add_action_bias=False,
            dropout=dropout)
        self._checklist_cost_weight = checklist_cost_weight
        self._dynamic_cost_wait_epochs = None
        self._dynamic_cost_rate = None
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight[
                "wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight["rate"]
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: Optional[int] = None
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file was passed, but it does not exist. This is expected to happen
                # when you're using a trained ERM model to decode, but it may also happen if
                # the path to the file is simply incorrect, so we throw a warning.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding.")
Code Example #3
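A second model constructor, this time for a question-answering parser with entity linking: the wiring mirrors Example #2, but the transition function is a LinkingCoverageTransitionFunction, and a separate BeamSearch is kept solely so the demo can run interactive beam search at test time.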
    def __init__(
        self,
        vocab: Vocabulary,
        question_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        entity_encoder: Seq2VecEncoder,
        attention: Attention,
        decoder_beam_size: int,
        decoder_num_finished_states: int,
        max_decoding_steps: int,
        mixture_feedforward: Optional[FeedForward] = None,
        add_action_bias: bool = True,
        normalize_beam_score_by_length: bool = False,
        checklist_cost_weight: float = 0.6,
        use_neighbor_similarity_for_linking: bool = False,
        dropout: float = 0.0,
        num_linking_features: int = 10,
        rule_namespace: str = "rule_labels",
        mml_model_file: Optional[str] = None,
    ) -> None:
        use_similarity = use_neighbor_similarity_for_linking
        super().__init__(
            vocab=vocab,
            question_embedder=question_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            entity_encoder=entity_encoder,
            max_decoding_steps=max_decoding_steps,
            add_action_bias=add_action_bias,
            use_neighbor_similarity_for_linking=use_similarity,
            dropout=dropout,
            num_linking_features=num_linking_features,
            rule_namespace=rule_namespace,
        )
        # Not sure why mypy needs a type annotation for this!
        self._decoder_trainer: ExpectedRiskMinimization = ExpectedRiskMinimization(
            beam_size=decoder_beam_size,
            normalize_by_length=normalize_beam_score_by_length,
            max_decoding_steps=self._max_decoding_steps,
            max_num_finished_states=decoder_num_finished_states,
        )
        self._decoder_step = LinkingCoverageTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            add_action_bias=self._add_action_bias,
            mixture_feedforward=mixture_feedforward,
            dropout=dropout,
        )
        self._checklist_cost_weight = checklist_cost_weight
        self._agenda_coverage = Average()
        # We don't need a separate beam search since the trainer does that already, but we
        # define one anyway to be able to use interactive beam search (a feature implemented
        # only in the ``BeamSearch`` class) in the demo. We use it only at test time.
        self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if mml_model_file is not None:
            if os.path.isfile(mml_model_file):
                archive = load_archive(mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file was passed, but it does not exist. This is expected to happen
                # when you're using a trained ERM model to decode, but it may also happen if
                # the path to the file is simply incorrect, so we throw a warning.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding.")