def setUp(self):
    super().setUp()
    self.initial_state = SimpleState([0], [[0]], [torch.Tensor([0.0])])
    self.decoder_step = SimpleTransitionFunction()
    # Cost is the number of odd elements in the action history.
    self.supervision = lambda state: torch.Tensor(
        [sum([x % 2 != 0 for x in state.action_history[0]])]
    )
    # High beam size ensures exhaustive search.
    self.trainer = ExpectedRiskMinimization(
        beam_size=100, normalize_by_length=False, max_decoding_steps=10
    )
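# A minimal sketch of how this fixture might be exercised. The test name and the
# assertion are illustrative assumptions, not copied from the actual test suite;
# `decode` is the `DecoderTrainer` entry point, which returns a dict containing
# "loss" and "best_final_states" entries.
def test_decode_prefers_even_actions(self):
    decoded = self.trainer.decode(self.initial_state, self.decoder_step, self.supervision)
    # With the odd-counting cost above and an exhaustive beam, the lowest-cost
    # finished state for batch instance 0 should contain only even actions.
    best_state = decoded["best_final_states"][0][0]
    assert all(action % 2 == 0 for action in best_state.action_history[0])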
def __init__(self,
             vocab: Vocabulary,
             sentence_embedder: TextFieldEmbedder,
             action_embedding_dim: int,
             encoder: Seq2SeqEncoder,
             attention: Attention,
             beam_size: int,
             max_decoding_steps: int,
             max_num_finished_states: int = None,
             dropout: float = 0.0,
             normalize_beam_score_by_length: bool = False,
             checklist_cost_weight: float = 0.6,
             dynamic_cost_weight: Dict[str, Union[int, float]] = None,
             penalize_non_agenda_actions: bool = False,
             initial_mml_model_file: str = None) -> None:
    super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab,
                                                     sentence_embedder=sentence_embedder,
                                                     action_embedding_dim=action_embedding_dim,
                                                     encoder=encoder,
                                                     dropout=dropout)
    self._agenda_coverage = Average()
    self._decoder_trainer: DecoderTrainer[Callable[[CoverageState], torch.Tensor]] = \
        ExpectedRiskMinimization(beam_size=beam_size,
                                 normalize_by_length=normalize_beam_score_by_length,
                                 max_decoding_steps=max_decoding_steps,
                                 max_num_finished_states=max_num_finished_states)
    # Instantiating an empty NlvrLanguage just to get the number of terminals.
    self._terminal_productions = set(NlvrLanguage(set()).terminal_productions.values())
    self._decoder_step = CoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                    action_embedding_dim=action_embedding_dim,
                                                    input_attention=attention,
                                                    activation=Activation.by_name('tanh')(),
                                                    add_action_bias=False,
                                                    dropout=dropout)
    self._checklist_cost_weight = checklist_cost_weight
    self._dynamic_cost_wait_epochs = None
    self._dynamic_cost_rate = None
    if dynamic_cost_weight:
        self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"]
        self._dynamic_cost_rate = dynamic_cost_weight["rate"]
    self._penalize_non_agenda_actions = penalize_non_agenda_actions
    self._last_epoch_in_forward: int = None
    # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
    # copied a trained ERM model from a different machine and the original MML model that was
    # used to initialize it does not exist on the current machine. This may not be the best
    # solution for the problem.
    if initial_mml_model_file is not None:
        if os.path.isfile(initial_mml_model_file):
            archive = load_archive(initial_mml_model_file)
            self._initialize_weights_from_archive(archive)
        else:
            # A model file is passed, but it does not exist. This is expected to happen when
            # you're using a trained ERM model to decode. But it may also happen if the path to
            # the file is really just incorrect. So throwing a warning.
            logger.warning("MML model file for initializing weights is passed, but does not exist."
                           " This is fine if you're just decoding.")
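# The dynamic-cost fields set above are consumed during training, outside this
# snippet. A minimal sketch of the schedule they suggest; the method name is
# hypothetical and this is not the model's actual update code, only an
# illustration of how `wait_num_epochs` and `rate` interact.
def _sketch_dynamic_cost_update(self, epoch_num: int) -> None:
    # Illustrative only: after waiting `_dynamic_cost_wait_epochs` epochs, decay
    # the checklist cost weight once per epoch by `_dynamic_cost_rate`, shifting
    # the ERM cost away from agenda coverage and toward denotation accuracy.
    if self._dynamic_cost_rate is None:
        return
    if epoch_num >= self._dynamic_cost_wait_epochs and epoch_num != self._last_epoch_in_forward:
        self._checklist_cost_weight -= self._dynamic_cost_rate
        self._last_epoch_in_forward = epoch_num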
def __init__(
    self,
    vocab: Vocabulary,
    question_embedder: TextFieldEmbedder,
    action_embedding_dim: int,
    encoder: Seq2SeqEncoder,
    entity_encoder: Seq2VecEncoder,
    attention: Attention,
    decoder_beam_size: int,
    decoder_num_finished_states: int,
    max_decoding_steps: int,
    mixture_feedforward: FeedForward = None,
    add_action_bias: bool = True,
    normalize_beam_score_by_length: bool = False,
    checklist_cost_weight: float = 0.6,
    use_neighbor_similarity_for_linking: bool = False,
    dropout: float = 0.0,
    num_linking_features: int = 10,
    rule_namespace: str = "rule_labels",
    mml_model_file: str = None,
) -> None:
    use_similarity = use_neighbor_similarity_for_linking
    super().__init__(
        vocab=vocab,
        question_embedder=question_embedder,
        action_embedding_dim=action_embedding_dim,
        encoder=encoder,
        entity_encoder=entity_encoder,
        max_decoding_steps=max_decoding_steps,
        add_action_bias=add_action_bias,
        use_neighbor_similarity_for_linking=use_similarity,
        dropout=dropout,
        num_linking_features=num_linking_features,
        rule_namespace=rule_namespace,
    )
    # Not sure why mypy needs a type annotation for this!
    self._decoder_trainer: ExpectedRiskMinimization = ExpectedRiskMinimization(
        beam_size=decoder_beam_size,
        normalize_by_length=normalize_beam_score_by_length,
        max_decoding_steps=self._max_decoding_steps,
        max_num_finished_states=decoder_num_finished_states,
    )
    self._decoder_step = LinkingCoverageTransitionFunction(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        add_action_bias=self._add_action_bias,
        mixture_feedforward=mixture_feedforward,
        dropout=dropout,
    )
    self._checklist_cost_weight = checklist_cost_weight
    self._agenda_coverage = Average()
    # We don't need a separate beam search since the trainer does that already. But we're
    # defining one just to be able to use interactive beam search (a functionality that's
    # only implemented in the ``BeamSearch`` class) in the demo. We'll use this only at
    # test time.
    self._beam_search: BeamSearch = BeamSearch(beam_size=decoder_beam_size)
    # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
    # copied a trained ERM model from a different machine and the original MML model that was
    # used to initialize it does not exist on the current machine. This may not be the best
    # solution for the problem.
    if mml_model_file is not None:
        if os.path.isfile(mml_model_file):
            archive = load_archive(mml_model_file)
            self._initialize_weights_from_archive(archive)
        else:
            # A model file is passed, but it does not exist. This is expected to happen when
            # you're using a trained ERM model to decode. But it may also happen if the path to
            # the file is really just incorrect. So throwing a warning.
            logger.warning(
                "MML model file for initializing weights is passed, but does not exist."
                " This is fine if you're just decoding."
            )
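# A sketch of where the separate `_beam_search` defined above gets used. The real
# call lives in the model's `forward` method at evaluation time; the method name
# here is hypothetical, and `initial_state` is assumed to be a state built earlier
# in that method. `BeamSearch.search` is the standard state-machines API.
def _sketch_test_time_decoding(self, initial_state: CoverageState) -> Dict[int, List[CoverageState]]:
    # Illustrative only: during training the ERM trainer performs its own search,
    # so this separate beam search is used only when the model is not training.
    # `keep_final_unfinished_states=False` drops sequences that never finished
    # within the step budget.
    return self._beam_search.search(
        self._max_decoding_steps,
        initial_state,
        self._decoder_step,
        keep_final_unfinished_states=False,
    )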