def _get_state_cost(self, worlds: List[WikiTablesVariableFreeWorld], state: CoverageState) -> torch.Tensor:
    """Return the cost of a finished decoder state.

    The cost is a weighted checklist penalty (squared error against the agenda
    checklist), plus a denotation penalty added when the decoded logical form
    does not evaluate to the target values.

    Raises
    ------
    RuntimeError
        If ``state`` is not finished; the cost is only defined for complete states.
    """
    if not state.is_finished():
        raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
    world = worlds[state.batch_indices[0]]

    # Our checklist cost is a sum of squared error from where we want to be, making sure we
    # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
    # penalizing agenda actions produced multiple times.
    checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
    checklist_cost = self._checklist_cost_weight * torch.sum(checklist_balance ** 2)

    # This is the number of items on the agenda that we want to see in the decoded sequence.
    # We use this as the denotation cost if the path is incorrect.
    denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())

    batch_index = state.batch_indices[0]
    action_strings = [state.possible_actions[batch_index][i][0]
                      for i in state.action_history[0]]
    logical_form = world.get_logical_form(action_strings)
    target_values = state.extras[batch_index]

    # Silence the executor's logging; evaluation failures here are expected and noisy.
    # (Hoisted out of the try block below -- only the evaluation itself can raise
    # the IndexError we want to tolerate, so the try body is kept minimal.)
    executor_logger = \
            logging.getLogger('weak_supervision.semparse.executors.wikitables_variable_free_executor')
    executor_logger.setLevel(logging.ERROR)

    evaluation = False
    try:
        evaluation = world.evaluate_logical_form(logical_form, target_values)
    except IndexError:
        # TODO(pradeep): This happens due to a bug in "filter_in" and "filter_no_in" functions.
        # The value evaluation, if it is a list, could be an empty one. Fix it there!
        pass
    if evaluation:
        cost = checklist_cost
    else:
        cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    return cost
# Example #2 (score: 0)
    def _get_state_cost(self, worlds: List[WikiTablesWorld],
                        state: CoverageState) -> torch.Tensor:
        """Compute the cost of a finished state: a weighted checklist error,
        plus a denotation penalty when the decoded logical form is incorrect.

        Raises ``RuntimeError`` when called on an unfinished state.
        """
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        batch_index = state.batch_indices[0]
        world = worlds[batch_index]
        checklist = state.checklist_state[0]

        # Squared distance from the desired checklist, clamped at zero so that
        # producing an agenda action more than once is not penalized.
        balance = torch.clamp(checklist.get_balance(), min=0.0)
        weighted_checklist_cost = self._checklist_cost_weight * torch.sum(balance ** 2)

        # Number of agenda items we hoped to see in the decoded sequence; used
        # as the denotation penalty when the logical form is wrong.
        denotation_cost = torch.sum(checklist.checklist_target.float())

        history = state.action_history[0]
        action_strings = [state.possible_actions[batch_index][index][0]
                          for index in history]
        logical_form = world.get_logical_form(action_strings)
        lisp_string = state.extras[batch_index]
        if self._executor.evaluate_logical_form(logical_form, lisp_string):
            return weighted_checklist_cost
        return weighted_checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
    def _get_state_cost(
        self, batch_worlds: List[List[NlvrLanguage]], state: CoverageState
    ) -> torch.Tensor:
        """
        Compute the cost of a finished state. A finished state has group size
        1, so a single cost is returned.

        ``batch_worlds`` is passed in (rather than stored on the ``State``
        object just for this method) because we need the worlds to check the
        denotation accuracy of the finished action sequence.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        instance_worlds = batch_worlds[state.batch_indices[0]]
        checklist = state.checklist_state[0]

        # Sum of squared error from where we want the checklist to be, with
        # the mask taken into account.
        balance = checklist.get_balance()
        checklist_cost = self._checklist_cost_weight * torch.sum(balance ** 2)

        # The number of agenda items we want to see in the decoded sequence;
        # used as the denotation cost when the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions,
        # this is not the upper limit on the checklist cost. That would be the
        # number of terminal actions.
        denotation_cost = torch.sum(checklist.checklist_target.float())

        # TODO (pradeep): The denotation based cost below is strict. May be define a cost based on
        # how many worlds the logical form is correct in?
        # ``extras`` being None happens when we are testing; the cost does not
        # matter then.  TODO (pradeep): Make this cleaner.
        denotations_correct = (state.extras is None
                               or all(self._check_state_denotations(state, instance_worlds)))
        if denotations_correct:
            return checklist_cost
        return checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
 def _get_state_info(self,
                     state: CoverageState,
                     batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
     """
     Collect debugging information about ``state``: its cost (only available
     for finished, single-element states), model scores, decoded action
     histories, and the agenda contents. Calling this while training on real
     data may be inefficient.
     """
     finished_single = len(state.batch_indices) == 1 and state.is_finished()
     if finished_single:
         costs = [float(self._get_state_cost(batch_worlds, state).detach().cpu().numpy())]
     else:
         costs = []
     model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
     all_actions = state.possible_actions[0]
     action_sequences = []
     for history in state.action_history:
         action_sequences.append([self._get_action_string(all_actions[action])
                                  for action in history])
     agenda_sequences = []
     all_agenda_indices = []
     for checklist_state in state.checklist_state:
         # Keep only the actions the checklist actually wants to see.
         agenda_indices = [
             int(action.detach().cpu().numpy())
             for action, is_wanted in zip(checklist_state.terminal_actions,
                                          checklist_state.checklist_target)
             if int(is_wanted.detach().cpu().numpy()) != 0
         ]
         agenda_sequences.append([self._get_action_string(all_actions[index])
                                  for index in agenda_indices])
         all_agenda_indices.append(agenda_indices)
     return {"agenda": agenda_sequences,
             "agenda_indices": all_agenda_indices,
             "history": action_sequences,
             "history_indices": state.action_history,
             "costs": costs,
             "scores": model_scores}
    def _get_state_cost(self, worlds: List[WikiTablesWorld], state: CoverageState) -> torch.Tensor:
        """Return the cost for a finished state: the weighted checklist
        penalty, plus a denotation penalty when the executed logical form does
        not match the target lisp string.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        batch_index = state.batch_indices[0]
        world = worlds[batch_index]
        checklist_state = state.checklist_state[0]

        # Squared error against the desired checklist; clamped at 0 so actions
        # produced more often than requested are not penalized.
        balance = torch.clamp(checklist_state.get_balance(), min=0.0)
        checklist_cost = self._checklist_cost_weight * torch.sum(balance * balance)

        # How many agenda items we wanted decoded; the penalty for a wrong path.
        denotation_cost = torch.sum(checklist_state.checklist_target.float())

        production_indices = state.action_history[0]
        action_strings = [state.possible_actions[batch_index][index][0]
                          for index in production_indices]
        logical_form = world.get_logical_form(action_strings)
        lisp_string = state.extras[batch_index]
        if self._executor.evaluate_logical_form(logical_form, lisp_string):
            return checklist_cost
        return checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
 def _get_state_info(self,
                     state: CoverageState,
                     batch_worlds: List[List[NlvrWorld]]) -> Dict[str, List]:
     """
     Gather debugging information from ``state`` so you can inspect what the
     model is learning: state costs, model scores, action histories, and the
     agenda. This may be inefficient to call while training on real data.
     """
     costs = []
     if len(state.batch_indices) == 1 and state.is_finished():
         state_cost = self._get_state_cost(batch_worlds, state)
         costs.append(float(state_cost.detach().cpu().numpy()))
     model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
     all_actions = state.possible_actions[0]
     action_sequences = [[self._get_action_string(all_actions[step]) for step in history]
                         for history in state.action_history]
     agenda_sequences = []
     all_agenda_indices = []
     for checklist_state in state.checklist_state:
         agenda_indices = []
         for action, is_wanted in zip(checklist_state.terminal_actions,
                                      checklist_state.checklist_target):
             # Record the action only if the checklist marks it as wanted.
             if int(is_wanted.detach().cpu().numpy()) != 0:
                 agenda_indices.append(int(action.detach().cpu().numpy()))
         agenda_sequences.append([self._get_action_string(all_actions[index])
                                  for index in agenda_indices])
         all_agenda_indices.append(agenda_indices)
     return {
         "agenda": agenda_sequences,
         "agenda_indices": all_agenda_indices,
         "history": action_sequences,
         "history_indices": state.action_history,
         "costs": costs,
         "scores": model_scores,
     }
    def _get_state_cost(self, batch_worlds: List[List[NlvrWorld]], state: CoverageState) -> torch.Tensor:
        """
        Return the cost of a finished state. Because the state is finished,
        its group size is 1, so exactly one cost is returned.

        ``batch_worlds`` is taken as a parameter (instead of adding a field to
        the ``State`` object just for this method) because the worlds are
        needed to check the denotation accuracy of the action sequence.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        instance_worlds = batch_worlds[state.batch_indices[0]]
        checklist = state.checklist_state[0]

        # Checklist cost: sum of squared error from the target checklist,
        # with the mask accounted for.
        balance = checklist.get_balance()
        checklist_cost = self._checklist_cost_weight * torch.sum(balance ** 2)

        # Number of agenda items we want decoded; this is the denotation cost
        # for an incorrect path.
        # Note: If we are penalizing the model for producing non-agenda actions,
        # this is not the upper limit on the checklist cost. That would be the
        # number of terminal actions.
        denotation_cost = torch.sum(checklist.checklist_target.float())

        # TODO (pradeep): The denotation based cost below is strict. May be define a cost based on
        # how many worlds the logical form is correct in?
        # ``extras`` is None at test time, when the cost does not matter.
        # TODO (pradeep): Make this cleaner.
        if state.extras is None or all(self._check_state_denotations(state, instance_worlds)):
            return checklist_cost
        return checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
# Example #8 (score: 0)
    def _get_state_cost(self, worlds: List[WikiTablesLanguage],
                        state: CoverageState) -> torch.Tensor:
        """Return the cost of a finished state: the weighted checklist error,
        plus a denotation penalty when the action sequence does not evaluate
        to the target values.

        Raises
        ------
        RuntimeError
            If ``state`` is not finished.
        """
        if not state.is_finished():
            raise RuntimeError(
                "_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making
        # sure we take into account the mask. We clamp the lower limit of the balance at 0
        # to avoid penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(),
                                        min=0.0)
        checklist_cost = self._checklist_cost_weight * torch.sum(checklist_balance ** 2)

        # This is the number of items on the agenda that we want to see in the decoded
        # sequence. We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(
            state.checklist_state[0].checklist_target.float())

        batch_index = state.batch_indices[0]
        action_strings = [state.possible_actions[batch_index][i][0]
                          for i in state.action_history[0]]
        target_values = state.extras[batch_index]

        # Silence the language executor; evaluation failures are expected and noisy.
        executor_logger = \
                logging.getLogger('allennlp.semparse.domain_languages.wikitables_language')
        executor_logger.setLevel(logging.ERROR)
        # Fix: the previous version initialized ``evaluation = False`` and then
        # unconditionally overwrote it on the next line -- that dead assignment
        # (a leftover from a removed try/except) has been removed.
        evaluation = world.evaluate_action_sequence(action_strings,
                                                    target_values)
        if evaluation:
            return checklist_cost
        return checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost