def _get_checklist_info(agenda: torch.LongTensor,
                        all_actions: List[ProductionRule],
                        terminal_productions: Set[str],
                        max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
    world, and a length to pad the checklist vectors to, and returns a target checklist against
    which the checklist at each state will be compared to compute a loss, indices of
    ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
    are relevant for checklist loss computation.

    Parameters
    ----------
    ``agenda`` : ``torch.LongTensor``
        Agenda of one instance, of size ``(agenda_size, 1)``.
    ``all_actions`` : ``List[ProductionRule]``
        All actions for one instance.
    ``terminal_productions`` : ``Set[str]``
        String representations of terminal productions in the corresponding world.
    ``max_num_terminals`` : ``int``
        Length to which the checklist vectors will be padded. This is the maximum number of
        terminal productions among all the worlds in the batch.
    """
    terminal_indices = []
    target_checklist_list = []
    agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
    # We want to return checklist target and terminal actions that are column vectors to make
    # computing softmax over the difference between checklist and target easier.
    for index, action in enumerate(all_actions):
        # Each action is a ProductionRule, a tuple where the first item is the production
        # rule string.
        if action[0] in terminal_productions:
            terminal_indices.append([index])
            if index in agenda_indices_set:
                target_checklist_list.append([1])
            else:
                target_checklist_list.append([0])
    while len(target_checklist_list) < max_num_terminals:
        target_checklist_list.append([0])
        terminal_indices.append([-1])
    # (max_num_terminals, 1)
    terminal_actions = agenda.new_tensor(terminal_indices)
    # (max_num_terminals, 1)
    target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
    checklist_mask = (target_checklist != 0).float()
    return target_checklist, terminal_actions, checklist_mask
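A minimal usage sketch (not from the original source; plain tuples stand in for ``ProductionRule`` here, since the function only inspects ``action[0]``):

import torch

all_actions = [("S -> T", False), ("T -> 'foo'", True), ("T -> 'bar'", True)]
terminal_productions = {"T -> 'foo'", "T -> 'bar'"}
agenda = torch.LongTensor([[1]])  # the agenda holds only action index 1
target, terminals, mask = _get_checklist_info(agenda, all_actions,
                                              terminal_productions,
                                              max_num_terminals=4)
# target == [[1], [0], [0], [0]]       (only index 1 is on the agenda)
# terminals == [[1], [2], [-1], [-1]]  (terminal action indices, padded with -1)
# mask == [[1.], [0.], [0.], [0.]]     (only agenda terminals count toward the loss)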
def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    # Check if target is big enough to cover prediction (including start/end symbols)
    if len(predicted) > targets.size(1):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:, :len(predicted)]
    # Return 1 if the predicted sequence is anywhere in the list of targets.
    return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
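A hypothetical call on toy data (not from the source): three padded target sequences, the second of which begins with the predicted prefix.

targets = torch.LongTensor([[2, 5, 7, 0],
                            [2, 5, 9, 0],
                            [3, 1, 4, 0]])
_action_history_match([2, 5, 9], targets)  # -> 1 (matches a prefix of row 1)
_action_history_match([2, 6], targets)     # -> 0 (matches no row)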
    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self.penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Example #4
def _action_history_match(predicted: List[int],
                          targets: torch.LongTensor) -> int:
    # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
    # Check if the target is long enough to cover the prediction (including start/end symbols).
    if len(predicted) > targets.size(0):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    targets_trimmed = targets[:len(predicted)]
    # Return 1 if the predicted sequence matches a prefix of the single target sequence.
    return int(predicted_tensor.equal(targets_trimmed))
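Unlike the batched variant above, this version compares the prediction against a single unpadded target sequence. A hypothetical call on toy data:

targets = torch.LongTensor([2, 5, 9, 0])
_action_history_match([2, 5, 9], targets)        # -> 1 (prefix matches)
_action_history_match([2, 5, 9, 0, 7], targets)  # -> 0 (longer than the target)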
Example #5
    def _get_checklist_info(
        self, agenda: torch.LongTensor, all_actions: List[ProductionRule]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self.penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = {int(x) for x in agenda.squeeze(0).detach().cpu().numpy()}
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Example #6
    def __init__(self, start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor]):
        if start_tokens.dim() != 1:
            raise ValueError("start_tokens must be a vector")
        if not isinstance(end_token, int) and end_token.dim() != 0:
            raise ValueError("end_token must be a scalar")

        self._start_tokens = start_tokens
        self._batch_size = start_tokens.size(0)
        if isinstance(end_token, int):
            self._end_token = start_tokens.new_tensor(end_token)
        else:
            self._end_token = end_token
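A hypothetical construction (the enclosing helper class is not shown in this snippet, so ``Helper`` is a stand-in name): ``end_token`` may be passed as a plain ``int`` or as a 0-dim tensor, and the int form is normalized to a tensor on the same device as ``start_tokens``.

start_tokens = torch.LongTensor([1, 1, 1])  # a batch of 3 sequences, all starting with id 1
helper = Helper(start_tokens, end_token=2)                 # int form
helper = Helper(start_tokens, end_token=torch.tensor(2))   # 0-dim tensor form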
Example #7
def is_equal(self, predicted: List[int], targets: torch.LongTensor,
             target_mask: torch.LongTensor) -> int:
    """
    Judge whether the predicted SQL is equal to the ground truth under the given db_id.
    :return: 1 if equal; otherwise 0.
    """
    if len(predicted) > targets.size(0):
        return 0
    predicted_tensor = targets.new_tensor(predicted)
    # Remove the padding positions using the mask.
    actual_len = target_mask.sum()
    targets_trimmed = targets[:actual_len]
    # Return 1 if the predicted sequence exactly matches the unpadded target.
    is_correct = torch.equal(predicted_tensor, targets_trimmed)
    return 1 if is_correct else 0
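A hypothetical call on toy data (``parser`` stands in for an instance of the enclosing class): the target is padded to length 5, and the mask marks its 3 real tokens.

targets = torch.LongTensor([4, 8, 2, 0, 0])
target_mask = torch.LongTensor([1, 1, 1, 0, 0])
parser.is_equal([4, 8, 2], targets, target_mask)  # -> 1 (exact match after unpadding)
parser.is_equal([4, 8], targets, target_mask)     # -> 0 (length mismatch)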
Example #8
    def _get_candidates(self,
                        entity_ids: torch.LongTensor) -> torch.LongTensor:
        """
        Combines the unique ids from the current batch with the previous set of ids to form the
        collection of **all** relevant ids.

        Parameters
        ----------
        entity_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, sequence_length)`` whose elements are the ids
            of the corresponding token in the ``target`` sequence.

        Returns
        -------
        unique_entity_ids : ``torch.LongTensor``
            A tensor of shape ``(batch_size, max_num_parents)`` containing all of the unique
            candidate ids.
        """
        # Get the tensors of unique ids for each batch element and store them in a list
        all_unique: List[torch.LongTensor] = []
        for i, ids in enumerate(entity_ids):
            if self._remaining[i] is not None:
                previous_ids = list(self._remaining[i].keys())
                previous_ids = entity_ids.new_tensor(previous_ids)
                ids = torch.cat((ids.view(-1), previous_ids), dim=0)
            unique = torch.unique(ids, sorted=True)
            all_unique.append(unique)

        # Convert the list to a tensor by adding adequate padding.
        batch_size = entity_ids.shape[0]
        max_num_parents = max(unique.shape[0] for unique in all_unique)
        unique_entity_ids = entity_ids.new_zeros(
            size=(batch_size, max_num_parents))
        for i, unique in enumerate(all_unique):
            unique_entity_ids[i, :unique.shape[0]] = unique

        return unique_entity_ids
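The unique-then-pad pattern above can be run standalone on toy data (a sketch that ignores the ``self._remaining`` carry-over state):

entity_ids = torch.LongTensor([[3, 3, 5, 1],
                               [2, 2, 2, 2]])
uniques = [torch.unique(row, sorted=True) for row in entity_ids]
padded = entity_ids.new_zeros((len(uniques), max(u.numel() for u in uniques)))
for i, unique in enumerate(uniques):
    padded[i, :unique.numel()] = unique
# padded == [[1, 3, 5],
#            [2, 0, 0]]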
Example #9
import numpy as np
import torch


def segment_lengths_to_ids(
        segment_lengths: torch.LongTensor) -> torch.LongTensor:
    """
    Args:
        segment_lengths: Non-negative lengths of the tensor segments

    Returns:
        A tensor containing ids for every element in the tensor to be segmented

    Examples:
        >>> segments = torch.tensor([2, 4, 3, 1])
        >>> segment_lengths_to_ids(segments)
        tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3])
    """
    if segment_lengths.dim() != 1:
        raise ValueError(
            f'`segment_lengths` should have a single dimension, got shape {segment_lengths.shape}'
        )
    if (segment_lengths < 0).any():
        raise ValueError(
            'All entries in `segment_lengths` should be non-negative')

    return segment_lengths.new_tensor(
        np.arange(len(segment_lengths)).repeat(segment_lengths.cpu().numpy()))
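For reference, ``torch.repeat_interleave`` computes the same mapping without the NumPy round-trip (``new_tensor`` above also serves to put the result on ``segment_lengths``'s device, which the sketch preserves via the ``device`` argument):

segments = torch.tensor([2, 4, 3, 1])
ids = torch.repeat_interleave(
    torch.arange(segments.numel(), device=segments.device), segments)
# ids: tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3])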
Example #10
    def forward(self,
                word_ids: torch.LongTensor,
                word_segment_ids: torch.LongTensor,
                word_attention_mask: torch.LongTensor,
                entity_ids: torch.LongTensor,
                entity_position_ids: torch.LongTensor,
                entity_segment_ids: torch.LongTensor,
                entity_attention_mask: torch.LongTensor,
                masked_entity_labels: Optional[torch.LongTensor] = None,
                masked_lm_labels: Optional[torch.LongTensor] = None,
                **kwargs):
        model_dtype = next(self.parameters()).dtype  # for fp16 compatibility
        output = super().forward(
            word_ids,
            word_segment_ids,
            word_attention_mask,
            entity_ids,
            entity_position_ids,
            entity_segment_ids,
            entity_attention_mask,
        )
        word_sequence_output, entity_sequence_output = output[:2]

        loss_fn = CrossEntropyLoss(ignore_index=-1)
        ret = dict(loss=word_ids.new_tensor(0.0, dtype=model_dtype))

        if masked_entity_labels is not None:
            entity_mask = masked_entity_labels != -1
            if entity_mask.sum() > 0:
                target_entity_sequence_output = torch.masked_select(
                    entity_sequence_output, entity_mask.unsqueeze(-1))
                target_entity_sequence_output = target_entity_sequence_output.view(
                    -1, self.config.hidden_size)
                target_entity_labels = torch.masked_select(
                    masked_entity_labels, entity_mask)

                entity_scores = self.entity_predictions(
                    target_entity_sequence_output)
                entity_scores = entity_scores.view(
                    -1, self.config.entity_vocab_size)

                ret["masked_entity_loss"] = loss_fn(entity_scores,
                                                    target_entity_labels)
                ret["masked_entity_correct"] = (torch.argmax(
                    entity_scores, 1).data == target_entity_labels.data).sum()
                ret["masked_entity_total"] = target_entity_labels.ne(-1).sum()
                ret["loss"] += ret["masked_entity_loss"]
            else:
                ret["masked_entity_loss"] = word_ids.new_tensor(
                    0.0, dtype=model_dtype)
                ret["masked_entity_correct"] = word_ids.new_tensor(
                    0, dtype=torch.long)
                ret["masked_entity_total"] = word_ids.new_tensor(
                    0, dtype=torch.long)

        if masked_lm_labels is not None:
            masked_lm_mask = masked_lm_labels != -1
            if masked_lm_mask.sum() > 0:
                masked_word_sequence_output = torch.masked_select(
                    word_sequence_output, masked_lm_mask.unsqueeze(-1))
                masked_word_sequence_output = masked_word_sequence_output.view(
                    -1, self.config.hidden_size)

                if self.config.bert_model_name and "roberta" in self.config.bert_model_name:
                    masked_lm_scores = self.lm_head(
                        masked_word_sequence_output)
                else:
                    masked_lm_scores = self.cls.predictions(
                        masked_word_sequence_output)
                masked_lm_scores = masked_lm_scores.view(
                    -1, self.config.vocab_size)
                masked_lm_labels = torch.masked_select(masked_lm_labels,
                                                       masked_lm_mask)

                ret["masked_lm_loss"] = loss_fn(masked_lm_scores,
                                                masked_lm_labels)
                ret["masked_lm_correct"] = (torch.argmax(
                    masked_lm_scores, 1).data == masked_lm_labels.data).sum()
                ret["masked_lm_total"] = masked_lm_labels.ne(-1).sum()
                ret["loss"] += ret["masked_lm_loss"]
            else:
                ret["masked_lm_loss"] = word_ids.new_tensor(0.0,
                                                            dtype=model_dtype)
                ret["masked_lm_correct"] = word_ids.new_tensor(
                    0, dtype=torch.long)
                ret["masked_lm_total"] = word_ids.new_tensor(0,
                                                             dtype=torch.long)

        return ret
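The masked-select pattern used by both loss branches can be sketched standalone (toy shapes; the random ``scores`` stand in for the model's prediction heads):

import torch
from torch.nn import CrossEntropyLoss

hidden_size, vocab_size = 4, 6
sequence_output = torch.randn(2, 3, hidden_size)   # (batch, seq_len, hidden)
labels = torch.tensor([[-1, 2, -1], [5, -1, -1]])  # -1 marks unmasked positions
mask = labels != -1
# Gather only the hidden states at masked positions, then flatten to (num_masked, hidden).
selected = torch.masked_select(sequence_output, mask.unsqueeze(-1)).view(-1, hidden_size)
target = torch.masked_select(labels, mask)         # shape (num_masked,)
scores = torch.randn(selected.size(0), vocab_size) # stand-in for a prediction head
loss = CrossEntropyLoss(ignore_index=-1)(scores, target)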