Example #1
    def _get_state_cost(self, state: NlvrDecoderState) -> torch.Tensor:
        """
        Return the cost of a finished state. Since it is a finished state, the group size will be 1,
        and hence we'll return just one cost.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        checklist_balance = state.checklist_state[0].get_balance()
        checklist_cost = torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions, this is not the
        # upper limit on the checklist cost. That would be the number of terminal actions.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # label_strings being None happens when we are testing. We do not care about the cost then.
        # TODO (pradeep): Make this cleaner.
        if state.label_strings is None or all(self._check_state_denotations(state)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost
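A minimal sketch of how the two cost terms above combine, using made-up tensors; ``weight``, ``balance``, ``target``, and ``denotation_correct`` are illustrative stand-ins for ``self._checklist_cost_weight``, the checklist balance, the checklist target, and the denotation check:

import torch

weight = 0.6                                # stand-in for self._checklist_cost_weight
balance = torch.tensor([0.0, 1.0, -1.0])    # checklist balance for three agenda items
target = torch.tensor([1.0, 1.0, 0.0])      # checklist target (items we want to see)

checklist_cost = weight * torch.sum(balance ** 2)           # 0.6 * 2.0 = 1.2
denotation_cost = torch.sum(target)                         # 2.0
denotation_correct = False                                  # pretend the denotation check failed
if denotation_correct:
    cost = checklist_cost
else:
    cost = checklist_cost + (1 - weight) * denotation_cost  # 1.2 + 0.4 * 2.0
print(cost)  # approximately tensor(2.)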
Example #3
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            util.new_variable_with_data(
                list(sentence.values())[0], torch.Tensor([0.0]))
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds_list,
            label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [
                    best_final_states[i][0].action_history[0]
                ]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
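The target handling in ``forward`` is compact; here is a small sketch of the squeeze-and-mask step on a toy tensor, assuming a hypothetical padding index of -1 (the real index comes from the model's vocabulary):

import torch

action_padding_index = -1  # hypothetical; the model reads this from its vocabulary
# One batch entry with one target sequence of length 3, plus the trailing
# dimension that ListField[ListField[IndexField]] adds:
targets = torch.tensor([[[[2], [5], [-1]]]])
targets = targets.squeeze(-1)                  # shape (1, 1, 3)
target_mask = targets != action_padding_index
print(target_mask)  # tensor([[[ True,  True, False]]])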
Example #4
    def take_step(
            self,  # type: ignore
            state: NlvrDecoderState,
            max_actions: int = None,
            allowed_actions: List[Set[int]] = None) -> List[NlvrDecoderState]:
        """
        Given a ``NlvrDecoderState``, returns a list of next states that are sorted by their scores.
        This method is very similar to ``WikiTablesDecoderStep._take_step``. The differences are
        that depending on the type of supervision being used, we may not have a notion of
        "allowed actions" here, and we do not perform entity linking here.
        """
        # Outline here: first we'll construct the input to the decoder, which is a concatenation of
        # an embedding of the decoder input (the last action taken) and an attention over the
        # sentence.  Then we'll update our decoder's hidden state given this input, and recompute
        # an attention over the sentence given our new hidden state.  We'll use a concatenation of
        # the new hidden state, the new attention, and optionally the checklist balance to predict an
        # output, then yield new states. We will compute and use a checklist balance when
        # ``allowed_actions`` is None, with the assumption that the ``DecoderTrainer`` that is
        # calling this method is trying to train a parser without logical form supervision.
        # TODO (pradeep): Make the distinction between the two kinds of trainers in the way they
        # call this method more explicit.

        # Each new state corresponds to one valid action that can be taken from the current state,
        # and they are ordered by model scores.
        attended_sentence = torch.stack(
            [rnn_state.attended_input for rnn_state in state.rnn_state])
        hidden_state = torch.stack(
            [rnn_state.hidden_state for rnn_state in state.rnn_state])
        memory_cell = torch.stack(
            [rnn_state.memory_cell for rnn_state in state.rnn_state])
        previous_action_embedding = torch.stack([
            rnn_state.previous_action_embedding
            for rnn_state in state.rnn_state
        ])

        # (group_size, decoder_input_dim)
        decoder_input = self._input_projection_layer(
            torch.cat([attended_sentence, previous_action_embedding], -1))
        decoder_input = torch.tanh(decoder_input)
        hidden_state, memory_cell = self._decoder_cell(
            decoder_input, (hidden_state, memory_cell))

        hidden_state = self._dropout(hidden_state)
        # (group_size, encoder_output_dim)
        encoder_outputs = torch.stack([
            state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices
        ])
        encoder_output_mask = torch.stack([
            state.rnn_state[0].encoder_output_mask[i]
            for i in state.batch_indices
        ])
        attended_sentence = self.attend_on_sentence(hidden_state,
                                                    encoder_outputs,
                                                    encoder_output_mask)

        # We get global indices of actions to embed here. The following logic is similar to
        # ``WikiTablesDecoderStep._get_actions_to_consider``, except that we do not have any actions
        # to link.
        valid_actions = state.get_valid_actions()
        global_valid_actions: List[List[Tuple[int, int]]] = []
        for batch_index, valid_action_list in zip(state.batch_indices,
                                                  valid_actions):
            global_valid_actions.append([])
            for action_index in valid_action_list:
                # state.action_indices is a dictionary that maps (batch_index, batch_action_index)
                # to global_action_index
                global_action_index = state.action_indices[(batch_index,
                                                            action_index)]
                global_valid_actions[-1].append(
                    (global_action_index, action_index))
        global_actions_to_embed: List[List[int]] = []
        local_actions: List[List[int]] = []
        for global_action_list in global_valid_actions:
            global_action_list.sort()
            global_actions_to_embed.append([])
            local_actions.append([])
            for global_action_index, action_index in global_action_list:
                global_actions_to_embed[-1].append(global_action_index)
                local_actions[-1].append(action_index)
        max_num_actions = max(
            [len(action_list) for action_list in global_actions_to_embed])
        # We pad local actions with -1 as padding to get considered actions.
        considered_actions = [
            common_util.pad_sequence_to_length(action_list,
                                               max_num_actions,
                                               default_value=lambda: -1)
            for action_list in local_actions
        ]

        # action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
        # action_mask: (group_size, num_embedded_actions)
        action_embeddings, embedded_action_mask = self._get_action_embeddings(
            state, global_actions_to_embed)
        action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
        # (group_size, action_embedding_dim)
        predicted_action_embedding = self._output_projection_layer(
            action_query)
        predicted_action_embedding = self._dropout(torch.tanh(predicted_action_embedding))
        if state.checklist_state[0] is not None:
            embedding_addition = self._get_predicted_embedding_addition(state)
            addition = embedding_addition * self._checklist_embedding_multiplier
            predicted_action_embedding = predicted_action_embedding + addition
        # We'll do a batch dot product here with `bmm`.  We want `dot(predicted_action_embedding,
        # action_embedding)` for each `action_embedding`, and we can get that efficiently with
        # `bmm` and some squeezing.
        # Shape: (group_size, num_embedded_actions)
        action_logits = action_embeddings.bmm(
            predicted_action_embedding.unsqueeze(-1)).squeeze(-1)

        action_mask = embedded_action_mask.float()
        if state.checklist_state[0] is not None:
            # We will compute the logprobs and the checklists of potential next states together for
            # efficiency.
            logprobs, new_checklist_states = self._get_next_state_info_with_agenda(
                state, considered_actions, action_logits, action_mask)
        else:
            logprobs = self._get_next_state_info_without_agenda(
                state, considered_actions, action_logits, action_mask)
            new_checklist_states = None
        return self._compute_new_states(state, logprobs, hidden_state,
                                        memory_cell, action_embeddings,
                                        attended_sentence, considered_actions,
                                        allowed_actions, new_checklist_states,
                                        max_actions)
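The scoring step above relies on the ``bmm`` trick the comment describes; this standalone sketch (with arbitrary shapes) shows that it computes one dot product per candidate action:

import torch

group_size, num_actions, dim = 2, 4, 3
action_embeddings = torch.randn(group_size, num_actions, dim)
predicted = torch.randn(group_size, dim)

# (group, actions, dim) x (group, dim, 1) -> (group, actions, 1) -> (group, actions)
logits = action_embeddings.bmm(predicted.unsqueeze(-1)).squeeze(-1)

# Equivalent elementwise computation, as a check:
reference = (action_embeddings * predicted.unsqueeze(1)).sum(-1)
print(logits.shape, torch.allclose(logits, reference))  # torch.Size([2, 4]) True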
Example #5
    @classmethod
    def _compute_new_states(cls,
                            state: NlvrDecoderState,
                            action_logprobs: List[List[Tuple[int,
                                                             torch.Tensor]]],
                            hidden_state: torch.Tensor,
                            memory_cell: torch.Tensor,
                            action_embeddings: torch.Tensor,
                            attended_sentence: torch.Tensor,
                            considered_actions: List[List[int]],
                            allowed_actions: List[Set[int]] = None,
                            new_checklist_states: List[
                                List[ChecklistState]] = None,
                            max_actions: int = None) -> List[NlvrDecoderState]:
        """
        This method is very similar to ``WikiTablesDecoderStep._compute_new_states``.
        The difference here is that we also keep track of checklists if they are passed to this
        method.
        """
        # batch_index -> group_index, action_index, checklist, score
        states_info: Dict[int, List[Tuple[int, int, torch.Tensor,
                                          torch.Tensor]]] = defaultdict(list)
        if new_checklist_states is None:
            # We do not have checklist states, so we make a list of lists of Nones of the
            # appropriate size for the zips below.
            new_checklist_states = [[None for logprob in instance_logprobs]
                                    for instance_logprobs in action_logprobs]
        for group_index, instance_info in enumerate(
                zip(state.batch_indices, action_logprobs,
                    new_checklist_states)):
            batch_index, instance_logprobs, instance_new_checklist_states = instance_info
            for (action_index,
                 score), checklist_state in zip(instance_logprobs,
                                                instance_new_checklist_states):
                states_info[batch_index].append(
                    (group_index, action_index, checklist_state, score))

        new_states = []
        for batch_index, instance_states_info in states_info.items():
            batch_scores = torch.cat(
                [info[-1] for info in instance_states_info])
            _, sorted_indices = batch_scores.sort(-1, descending=True)
            sorted_states_info = [
                instance_states_info[i]
                for i in sorted_indices.data.cpu().numpy()
            ]
            allowed_states_info = []
            for i, (group_index, action_index, _, _) in enumerate(sorted_states_info):
                action = considered_actions[group_index][action_index]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    continue
                allowed_states_info.append(sorted_states_info[i])
            sorted_states_info = allowed_states_info
            if max_actions is not None:
                sorted_states_info = sorted_states_info[:max_actions]
            for group_index, action_index, new_checklist_state, new_score in sorted_states_info:
                # This is the actual index of the action from the original list of actions.
                # We do not have to check whether it is the padding index because ``take_step``
                # already took care of that.
                action = considered_actions[group_index][action_index]
                action_embedding = action_embeddings[group_index, action_index, :]
                new_action_history = state.action_history[group_index] + [action]
                production_rule = state.possible_actions[batch_index][action][0]
                new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
                new_rnn_state = RnnState(
                    hidden_state[group_index], memory_cell[group_index],
                    action_embedding, attended_sentence[group_index],
                    state.rnn_state[group_index].encoder_outputs,
                    state.rnn_state[group_index].encoder_output_mask)
                new_state = NlvrDecoderState(
                    batch_indices=[batch_index],
                    action_history=[new_action_history],
                    score=[new_score],
                    rnn_state=[new_rnn_state],
                    grammar_state=[new_grammar_state],
                    action_embeddings=state.action_embeddings,
                    action_indices=state.action_indices,
                    possible_actions=state.possible_actions,
                    worlds=state.worlds,
                    label_strings=state.label_strings,
                    checklist_state=[new_checklist_state])
                new_states.append(new_state)
        return new_states
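The sort-then-prune pattern in ``_compute_new_states`` is easy to miss inside the loop; a toy version, with strings standing in for the per-state info tuples:

import torch

scores = torch.tensor([0.1, 0.9, 0.4])
states_info = ["state_a", "state_b", "state_c"]  # stand-ins for the info tuples

_, sorted_indices = scores.sort(-1, descending=True)
sorted_states = [states_info[i] for i in sorted_indices.tolist()]
max_actions = 2
print(sorted_states[:max_actions])  # ['state_b', 'state_c']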
Example #6
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRuleArray]],
                agenda: torch.LongTensor,
                identifier: List[str] = None,
                labels: torch.LongTensor = None,
                epoch_num: List[int] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError("If you want a dynamic cost weight, use the "
                                   "EpochTrackingBucketIterator!")
            if instance_epoch_num is not None and instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f", self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(ChecklistState(terminal_actions=terminal_actions,
                                                           checklist_target=checklist_target,
                                                           checklist_mask=checklist_mask,
                                                           checklist=initial_checklist))
        initial_state = NlvrDecoderState(batch_indices=list(range(batch_size)),
                                         action_history=[[] for _ in range(batch_size)],
                                         score=initial_score_list,
                                         rnn_state=initial_rnn_state,
                                         grammar_state=initial_grammar_state,
                                         action_embeddings=action_embeddings,
                                         action_indices=action_indices,
                                         possible_actions=actions,
                                         worlds=worlds,
                                         label_strings=label_strings,
                                         checklist_state=initial_checklist_states)

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(initial_state,
                                               self._decoder_step,
                                               self._get_state_cost)
        if identifier is not None:
            outputs['identifier'] = identifier
        best_action_sequences = outputs['best_action_sequences']
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
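The dynamic cost weight in ``forward`` decays multiplicatively once per epoch after a warm-up period. A toy schedule, with ``weight``, ``rate``, and ``wait_epochs`` standing in for ``self._checklist_cost_weight``, ``self._dynamic_cost_rate``, and ``self._dynamic_cost_wait_epochs`` (the once-per-epoch trigger is simplified away):

weight = 0.8
rate = 0.1
wait_epochs = 2

for epoch in range(5):
    if epoch >= wait_epochs:
        weight -= weight * rate  # same update as in forward()
    print(epoch, round(weight, 4))
# 0 0.8 / 1 0.8 / 2 0.72 / 3 0.648 / 4 0.5832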
    def take_step(self,  # type: ignore
                  state: NlvrDecoderState,
                  max_actions: int = None,
                  allowed_actions: List[Set[int]] = None) -> List[NlvrDecoderState]:
        """
        Given a ``NlvrDecoderState``, returns a list of next states that are sorted by their scores.
        This method is very similar to ``WikiTablesDecoderStep._take_step``. The differences are
        that depending on the type of supervision being used, we may not have a notion of
        "allowed actions" here, and we do not perform entity linking here.
        """
        # Outline here: first we'll construct the input to the decoder, which is a concatenation of
        # an embedding of the decoder input (the last action taken) and an attention over the
        # sentence.  Then we'll update our decoder's hidden state given this input, and recompute
        # an attention over the sentence given our new hidden state.  We'll use a concatenation of
        # the new hidden state, the new attention, and optionally the checklist balance to predict an
        # output, then yield new states. We will compute and use a checklist balance when
        # ``allowed_actions`` is None, with the assumption that the ``DecoderTrainer`` that is
        # calling this method is trying to train a parser without logical form supervision.
        # TODO (pradeep): Make the distinction between the two kinds of trainers in the way they
        # call this method more explicit.

        # Each new state corresponds to one valid action that can be taken from the current state,
        # and they are ordered by model scores.
        attended_sentence = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])
        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding
                                                 for rnn_state in state.rnn_state])

        # (group_size, decoder_input_dim)
        decoder_input = self._input_projection_layer(torch.cat([attended_sentence,
                                                                previous_action_embedding], -1))
        decoder_input = torch.tanh(decoder_input)
        hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))

        hidden_state = self._dropout(hidden_state)
        # (group_size, encoder_output_dim)
        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])
        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])
        attended_sentence = self.attend_on_sentence(hidden_state, encoder_outputs, encoder_output_mask)

        # We get global indices of actions to embed here. The following logic is similar to
        # ``WikiTablesDecoderStep._get_actions_to_consider``, except that we do not have any actions
        # to link.
        valid_actions = state.get_valid_actions()
        global_valid_actions: List[List[Tuple[int, int]]] = []
        for batch_index, valid_action_list in zip(state.batch_indices, valid_actions):
            global_valid_actions.append([])
            for action_index in valid_action_list:
                # state.action_indices is a dictionary that maps (batch_index, batch_action_index)
                # to global_action_index
                global_action_index = state.action_indices[(batch_index, action_index)]
                global_valid_actions[-1].append((global_action_index, action_index))
        global_actions_to_embed: List[List[int]] = []
        local_actions: List[List[int]] = []
        for global_action_list in global_valid_actions:
            global_action_list.sort()
            global_actions_to_embed.append([])
            local_actions.append([])
            for global_action_index, action_index in global_action_list:
                global_actions_to_embed[-1].append(global_action_index)
                local_actions[-1].append(action_index)
        max_num_actions = max([len(action_list) for action_list in global_actions_to_embed])
        # We pad local actions with -1 as padding to get considered actions.
        considered_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions,
                                                                 default_value=lambda: -1)
                              for action_list in local_actions]

        # action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
        # action_mask: (group_size, num_embedded_actions)
        action_embeddings, embedded_action_mask = self._get_action_embeddings(state,
                                                                              global_actions_to_embed)
        action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
        # (group_size, action_embedding_dim)
        predicted_action_embedding = self._output_projection_layer(action_query)
        predicted_action_embedding = self._dropout(torch.tanh(predicted_action_embedding))
        if state.checklist_state[0] is not None:
            embedding_addition = self._get_predicted_embedding_addition(state)
            addition = embedding_addition * self._checklist_embedding_multiplier
            predicted_action_embedding = predicted_action_embedding + addition
        # We'll do a batch dot product here with `bmm`.  We want `dot(predicted_action_embedding,
        # action_embedding)` for each `action_embedding`, and we can get that efficiently with
        # `bmm` and some squeezing.
        # Shape: (group_size, num_embedded_actions)
        action_logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)

        action_mask = embedded_action_mask.float()
        if state.checklist_state[0] is not None:
            # We will compute the logprobs and the checklists of potential next states together for
            # efficiency.
            logprobs, new_checklist_states = self._get_next_state_info_with_agenda(state,
                                                                                   considered_actions,
                                                                                   action_logits,
                                                                                   action_mask)
        else:
            logprobs = self._get_next_state_info_without_agenda(state,
                                                                considered_actions,
                                                                action_logits,
                                                                action_mask)
            new_checklist_states = None
        return self._compute_new_states(state,
                                        logprobs,
                                        hidden_state,
                                        memory_cell,
                                        action_embeddings,
                                        attended_sentence,
                                        considered_actions,
                                        allowed_actions,
                                        new_checklist_states,
                                        max_actions)
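Both ``_get_state_cost`` and ``take_step`` lean on ``ChecklistState.get_balance()``, whose definition is not shown in these examples. A plausible reconstruction, consistent with how the balance is used above (this is an assumption about the class, not its actual source):

import torch

checklist_target = torch.tensor([1.0, 1.0, 0.0])  # agenda actions we want to see
checklist = torch.tensor([1.0, 0.0, 1.0])         # actions produced so far
checklist_mask = torch.tensor([1.0, 1.0, 1.0])

# Assumed definition: masked difference between targets and actions taken.
balance = checklist_mask * (checklist_target - checklist)
print(balance)  # tensor([ 0.,  1., -1.])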