    def _get_mask_for_eval(self,
                           mask: torch.LongTensor,
                           pos_tags: torch.LongTensor) -> torch.LongTensor:
        """
        Dependency evaluation excludes words which are punctuation.
        Here, we create a new mask to exclude word indices which
        have a "punctuation-like" part of speech tag.

        Parameters
        ----------
        mask : ``torch.LongTensor``, required.
            The original mask.
        pos_tags : ``torch.LongTensor``, required.
            The pos tags for the sequence.

        Returns
        -------
        A new mask, where any indices equal to labels
        we should be ignoring are masked.
        """
        new_mask = mask.detach()
        for label in self._pos_to_ignore:
            label_mask = pos_tags.eq(label).long()
            new_mask = new_mask * (1 - label_mask)
        return new_mask
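
As an illustration of the masking loop above, here is a minimal standalone sketch with made-up pos-tag ids (2 and 5 stand in for "punctuation-like" tags); it is not part of the original model.

import torch

mask = torch.ones(1, 5, dtype=torch.long)
pos_tags = torch.tensor([[0, 2, 3, 5, 1]])
pos_to_ignore = {2, 5}  # hypothetical ids of punctuation-like pos tags

new_mask = mask.detach()
for label in pos_to_ignore:
    # Zero out positions whose pos tag matches an ignored label.
    new_mask = new_mask * (1 - pos_tags.eq(label).long())

print(new_mask)  # tensor([[1, 0, 1, 0, 1]])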
Example #2
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The `target` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
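
A toy usage sketch, assuming the ``flattened_index_select`` defined above is in scope; the sizes are arbitrary.

import torch

target = torch.randn(2, 4, 8)                    # (batch_size, sequence_length, embedding_size)
indices = torch.tensor([[0, 1, 2], [1, 2, 3]])   # (set_size, subset_size)
selected = flattened_index_select(target, indices)
print(selected.shape)                            # torch.Size([2, 2, 3, 8])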
Example #3
    def __call__(self,  # type: ignore
                 predictions: torch.LongTensor,
                 gold_targets: torch.LongTensor) -> None:
        """
        Update precision counts.

        Parameters
        ----------
        predictions : ``torch.LongTensor``, required
            Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
        gold_targets : ``torch.LongTensor``, required
            Batched reference (gold) translations with shape `(batch_size, max_gold_sequence_length)`.

        Returns
        -------
        None
        """
        predictions, gold_targets = self.unwrap_to_tensors(predictions, gold_targets)
        for ngram_size, _ in enumerate(self._ngram_weights, start=1):
            precision_matches, precision_totals = self._get_modified_precision_counts(
                    predictions, gold_targets, ngram_size)
            self._precision_matches[ngram_size] += precision_matches
            self._precision_totals[ngram_size] += precision_totals
        if not self._exclude_indices:
            self._prediction_lengths += predictions.size(0) * predictions.size(1)
            self._reference_lengths += gold_targets.size(0) * gold_targets.size(1)
        else:
            valid_predictions_mask = self._get_valid_tokens_mask(predictions)
            self._prediction_lengths += valid_predictions_mask.sum().item()
            valid_gold_targets_mask = self._get_valid_tokens_mask(gold_targets)
            self._reference_lengths += valid_gold_targets_mask.sum().item()
Example #4
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.Tensor:
        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP ``SpanField`` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(f"Adjusted span indices must lie inside the the sequence tensor, "
                                 f"but found: exclusive_span_starts: {exclusive_span_starts}.")

            start_embeddings = util.batched_index_select(sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor, span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()
        return combined_tensors
    def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
        # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
        # Check if target is big enough to cover prediction (including start/end symbols)
        if len(predicted) > targets.size(1):
            return 0
        predicted_tensor = targets.new_tensor(predicted)
        targets_trimmed = targets[:, :len(predicted)]
        # Return 1 if the predicted sequence is anywhere in the list of targets.
        return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
Example #6
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * sequence_length, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
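
A toy usage sketch, assuming the function above is in scope; the second sequence has one padded position whose weight is zero, so it contributes nothing to the loss.

import torch

logits = torch.randn(2, 3, 4)                         # (batch_size, sequence_length, num_classes)
targets = torch.tensor([[0, 1, 2], [3, 0, 0]])        # (batch_size, sequence_length)
weights = torch.tensor([[1., 1., 1.], [1., 1., 0.]])  # zero weight marks padding
loss = sequence_cross_entropy_with_logits(logits, targets, weights)
print(loss)  # a scalar, since batch_average defaults to True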
Example #7
    def _get_modified_precision_counts(self,
                                       predicted_tokens: torch.LongTensor,
                                       reference_tokens: torch.LongTensor,
                                       ngram_size: int) -> Tuple[int, int]:
        """
        Compare the predicted tokens to the reference (gold) tokens at the desired
        ngram size and calculate the numerator and denominator for a modified
        form of precision.

        The numerator is the number of ngrams in the predicted sentences that match
        with an ngram in the corresponding reference sentence, clipped by the total
        count of that ngram in the reference sentence. The denominator is just
        the total count of predicted ngrams.
        """
        clipped_matches = 0
        total_predicted = 0
        for batch_num in range(predicted_tokens.size(0)):
            predicted_row = predicted_tokens[batch_num, :]
            reference_row = reference_tokens[batch_num, :]
            predicted_ngram_counts = self._ngrams(predicted_row, ngram_size)
            reference_ngram_counts = self._ngrams(reference_row, ngram_size)
            for ngram, count in predicted_ngram_counts.items():
                clipped_matches += min(count, reference_ngram_counts[ngram])
                total_predicted += count
        return clipped_matches, total_predicted
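
The clipped-count idea in isolation, with plain Counter objects standing in for the per-sentence ngram counts that ``self._ngrams`` produces; the token tuples are made up.

from collections import Counter

predicted_ngram_counts = Counter({("the",): 3, ("cat",): 1})
reference_ngram_counts = Counter({("the",): 1, ("cat",): 1})
# Each predicted ngram is credited at most as many times as it appears in the reference.
clipped_matches = sum(min(count, reference_ngram_counts[ngram])
                      for ngram, count in predicted_ngram_counts.items())
total_predicted = sum(predicted_ngram_counts.values())
print(clipped_matches, total_predicted)  # 2 4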
    def _get_checklist_info(agenda: torch.LongTensor,
                            all_actions: List[ProductionRule],
                            terminal_productions: Set[str],
                            max_num_terminals: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Takes an agenda, a list of all actions, a set of terminal productions in the corresponding
        world, and a length to pad the checklist vectors to, and returns a target checklist against
        which the checklist at each state will be compared to compute a loss, indices of
        ``terminal_actions``, and a ``checklist_mask`` that indicates which of the terminal actions
        are relevant for checklist loss computation.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRule]``
            All actions for one instance.
        ``terminal_productions`` : ``Set[str]``
            String representations of terminal productions in the corresponding world.
        ``max_num_terminals`` : ``int``
            Length to which the checklist vectors will be padded. This is the maximum number of
            terminal productions among all the worlds in the batch.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRule, a tuple where the first item is the production
            # rule string.
            if action[0] in terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        while len(target_checklist_list) < max_num_terminals:
            target_checklist_list.append([0])
            terminal_indices.append([-1])
        # (max_num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (max_num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Example #9
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs : ``torch.Tensor``, required.
            A Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
        mask : ``torch.LongTensor``, required.
            A binary mask of shape ``(batch_size, sequence_length)`` representing the
            non-padded elements in each sequence in the batch.

        Returns
        -------
        A ``torch.Tensor`` of shape (num_layers, batch_size, sequence_length, hidden_size),
        where the num_layers dimension represents the LSTM output from that layer.
        """
        batch_size, total_sequence_length = mask.size()
        stacked_sequence_output, final_states, restoration_indices = \
            self.sort_and_run_forward(self._lstm_forward, inputs, mask)

        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        if num_valid < batch_size:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size - num_valid,
                                                     returned_timesteps,
                                                     encoder_dim).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.data.new(num_layers, batch_size - num_valid, state_dim).fill_(0)
                zeros = Variable(zeros)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size,
                                                     sequence_length_difference,
                                                     stacked_sequence_output[0].size(-1)).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)

        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)
    def _get_checklist_info(self,
                            agenda: torch.LongTensor,
                            all_actions: List[ProductionRuleArray]) -> Tuple[torch.Tensor,
                                                                             torch.Tensor,
                                                                             torch.Tensor]:
        """
        Takes an agenda and a list of all actions and returns a target checklist against which the
        checklist at each state will be compared to compute a loss, indices of ``terminal_actions``,
        and a ``checklist_mask`` that indicates which of the terminal actions are relevant for
        checklist loss computation. If ``self._penalize_non_agenda_actions`` is set to ``True``,
        ``checklist_mask`` will be all 1s (i.e., all terminal actions are relevant). If it is set to
        ``False``, indices of all terminals that are not in the agenda will be masked.

        Parameters
        ----------
        ``agenda`` : ``torch.LongTensor``
            Agenda of one instance of size ``(agenda_size, 1)``.
        ``all_actions`` : ``List[ProductionRuleArray]``
            All actions for one instance.
        """
        terminal_indices = []
        target_checklist_list = []
        agenda_indices_set = set([int(x) for x in agenda.squeeze(0).detach().cpu().numpy()])
        for index, action in enumerate(all_actions):
            # Each action is a ProductionRuleArray, a tuple where the first item is the production
            # rule string.
            if action[0] in self._terminal_productions:
                terminal_indices.append([index])
                if index in agenda_indices_set:
                    target_checklist_list.append([1])
                else:
                    target_checklist_list.append([0])
        # We want to return checklist target and terminal actions that are column vectors to make
        # computing softmax over the difference between checklist and target easier.
        # (num_terminals, 1)
        terminal_actions = agenda.new_tensor(terminal_indices)
        # (num_terminals, 1)
        target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
        if self._penalize_non_agenda_actions:
            # All terminal actions are relevant
            checklist_mask = torch.ones_like(target_checklist)
        else:
            checklist_mask = (target_checklist != 0).float()
        return target_checklist, terminal_actions, checklist_mask
Example #11
def batched_index_select(target: torch.Tensor,
                         indices: torch.LongTensor,
                         flattened_indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` indexes into the sequence
    dimension (dimension 2) of the target, which has size ``(batch_size, sequence_length,
    embedding_size)``.

    This function returns selected values in the target with respect to the provided indices, which
    have size ``(batch_size, d_1, ..., d_n, embedding_size)``. This can use the optionally
    precomputed ``flattened_indices`` with size ``(batch_size * d_1 * ... * d_n)`` if given.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor. This is used in the
    :class:`~allennlp.models.coreference_resolution.CoreferenceResolver` model to select
    contextual word representations corresponding to the start and end indices of mentions. The key
    reason this can't be done with basic torch functions is that we want to be able to use look-up
    tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
    a-priori how many spans we are looking up).

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling :func:`~flatten_and_batch_shift_indices`
        on ``indices``. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.view(-1, target.size(-1))

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets
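
A toy usage sketch: look up two positions per batch element. This assumes ``flatten_and_batch_shift_indices`` from the same utility module is in scope alongside the function above.

import torch

target = torch.randn(3, 7, 16)                    # (batch_size, sequence_length, embedding_size)
indices = torch.tensor([[0, 6], [1, 2], [5, 5]])  # (batch_size, 2), e.g. span starts and ends
selected = batched_index_select(target, indices)
print(selected.shape)                             # torch.Size([3, 2, 16])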
Example #12
    def greedy_predict(self,
                       final_encoder_output: torch.FloatTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.FloatTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]
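
A self-contained toy version of the same greedy loop, with made-up sizes and freshly initialised modules; it only illustrates the control flow and shapes, not a trained model.

import torch
import torch.nn.functional as F
from torch.nn import Embedding, GRUCell, Linear

batch_size, hidden_dim, num_classes, num_steps, start_index = 2, 16, 10, 5, 1
final_encoder_output = torch.randn(batch_size, hidden_dim)
target_embedder = Embedding(num_classes, hidden_dim)
decoder_cell = GRUCell(hidden_dim, hidden_dim)
output_projection_layer = Linear(hidden_dim, num_classes)

decoder_hidden = final_encoder_output
predictions = [final_encoder_output.new_full((batch_size,), start_index, dtype=torch.long)]
for _ in range(num_steps):
    # Feed the previous prediction back in and pick the most likely next class.
    decoder_input = target_embedder(predictions[-1])
    decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
    class_probabilities = F.softmax(output_projection_layer(decoder_hidden), dim=-1)
    predictions.append(torch.max(class_probabilities, 1)[1])
all_predictions = torch.cat([p.unsqueeze(1) for p in predictions], 1)
print(all_predictions[:, 1:].shape)  # torch.Size([2, 5]) after dropping the start symbol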
Example #13
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]


        if offsets is None:
            return mix
        else:
            batch_size = input_ids.size(0)
            range_vector = util.get_range_vector(batch_size,
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            return mix[range_vector, offsets]
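
The offset-selection step at the end, sketched in isolation with arbitrary sizes: from per-wordpiece vectors, keep one vector per original token.

import torch

mix = torch.randn(2, 5, 8)                # (batch_size, num_wordpieces, embedding_dim)
offsets = torch.tensor([[3, 4], [1, 2]])  # (batch_size, num_original_tokens)
range_vector = torch.arange(2).unsqueeze(1)
print(mix[range_vector, offsets].shape)   # torch.Size([2, 2, 8])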
Example #14
    def _ngrams(self,
                tensor: torch.LongTensor,
                ngram_size: int) -> Dict[Tuple[int, ...], int]:
        ngram_counts: Dict[Tuple[int, ...], int] = Counter()
        if ngram_size > tensor.size(-1):
            return ngram_counts
        for start_position in range(ngram_size):
            for tensor_slice in tensor[start_position:].split(ngram_size, dim=-1):
                if tensor_slice.size(-1) < ngram_size:
                    break
                ngram = tuple(x.item() for x in tensor_slice)
                if any(x in self._exclude_indices for x in ngram):
                    continue
                ngram_counts[ngram] += 1
        return ngram_counts
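
The same counting idea, sketched with a plain sliding window instead of the split-based traversal above; 0 stands in for an excluded (padding) id.

import torch
from collections import Counter

tensor = torch.tensor([3, 4, 3, 4, 0])
exclude_indices = {0}
ngram_size = 2
ngram_counts = Counter()
for start in range(tensor.size(-1) - ngram_size + 1):
    ngram = tuple(int(x) for x in tensor[start:start + ngram_size])
    if any(x in exclude_indices for x in ngram):
        continue  # skip ngrams that touch an excluded id
    ngram_counts[ngram] += 1
print(ngram_counts)  # Counter({(3, 4): 2, (4, 3): 1})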
Example #15
    def _prepare_decode_step_input(self,
                                   input_indices: torch.LongTensor,
                                   decoder_hidden_state: torch.LongTensor = None,
                                   encoder_outputs: torch.LongTensor = None,
                                   encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the previous time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)
        if self._attention_function:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
Example #16
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            # pylint: disable=unused-argument
            **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[int]]``
            The predicted tags, obtained by taking the argmax of the logits at each position.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)

        logits = torch.nn.functional.log_softmax(
            self.tag_projection_layer(encoded_text), dim=-1)
        # Just take the most likely tag at each position.
        predicted_tags = torch.argmax(logits, -1)
        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            active_loss = mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_tags)[active_loss]
            active_tags = tags.view(-1)[active_loss]
            loss = self._loss(active_logits, active_tags)
            output["loss"] = loss

            # Represent the predicted tags as "class probabilities" that we can
            # feed into the metrics.
            class_probabilities = logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            for key, metric in self.metrics.items():
                if 'accuracy' in key:
                    metric(logits.view(-1, self.num_tags), tags.view(-1),
                           mask.view(-1) == 1)

            if self.calculate_span_f1:
                tags_ = [[
                    self.vocab.get_token_from_index(
                        tag.data.item(), namespace=self.label_namespace)
                    for tag in instance_tags
                ] for instance_tags in tags]
                predicted_tags_ = [[
                    self.vocab.get_token_from_index(
                        tag.data.item(), namespace=self.label_namespace)
                    for tag in instance_tags
                ] for instance_tags in predicted_tags]

                prec, recall, f1 = evaluate(
                    [x for y in tags_ for x in y],
                    [x for y in predicted_tags_ for x in y])
                self.metrics['precision-measure-overall'](prec)
                self.metrics['recall-measure-overall'](recall)
                self.metrics['f1-measure-overall'](f1)

        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output
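
The active-loss masking pattern used above, sketched standalone with raw logits and a made-up mask: only unpadded positions contribute to the cross entropy.

import torch

num_tags = 3
logits = torch.randn(2, 4, num_tags)
tags = torch.tensor([[0, 1, 2, 0], [1, 0, 0, 0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # last two positions of row 2 are padding

active_loss = mask.view(-1) == 1
active_logits = logits.view(-1, num_tags)[active_loss]
active_tags = tags.view(-1)[active_loss]
loss = torch.nn.functional.cross_entropy(active_logits, active_tags)
print(loss)  # scalar loss over the six unpadded positions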
Example #17
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_starts: torch.LongTensor,
                span_ends: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.Tensor:
        sequence_tensor = sequence_tensor.unsqueeze(0)
        span_starts = span_starts.unsqueeze(0)
        span_ends = span_ends.unsqueeze(0)
        if sequence_mask is not None:
            sequence_mask = sequence_mask.unsqueeze(0)
        if span_indices_mask is not None:
            span_indices_mask = span_indices_mask.unsqueeze(0)

        if span_indices_mask is not None:
            # It's not strictly necessary to multiply the span indices by the mask here,
            # but it's possible that the span representation was padded with something other
            # than 0 (such as -1, which would be an invalid index), so we do so anyway to
            # be safe.
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask

        if not self._use_exclusive_start_indices:
            start_embeddings = util.batched_index_select(
                sequence_tensor, span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

        else:
            # We want `exclusive` span starts, so we remove 1 from the forward span starts
            # as the AllenNLP ``SpanField`` is inclusive.
            # shape (batch_size, num_spans)
            exclusive_span_starts = span_starts - 1
            # shape (batch_size, num_spans, 1)
            start_sentinel_mask = (
                exclusive_span_starts == -1).long().unsqueeze(-1)
            exclusive_span_starts = exclusive_span_starts * (
                1 - start_sentinel_mask.squeeze(-1))

            # We'll check the indices here at runtime, because it's difficult to debug
            # if this goes wrong and it's tricky to get right.
            if (exclusive_span_starts < 0).any():
                raise ValueError(
                    f"Adjusted span indices must lie inside the sequence tensor, "
                    f"but found: exclusive_span_starts: {exclusive_span_starts}."
                )

            start_embeddings = util.batched_index_select(
                sequence_tensor, exclusive_span_starts)
            end_embeddings = util.batched_index_select(sequence_tensor,
                                                       span_ends)

            # We're using sentinels, so we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with the start sentinel.
            float_start_sentinel_mask = start_sentinel_mask.float()
            start_embeddings = start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel

        combined_tensors = util.combine_tensors(
            self._combination, [start_embeddings, end_embeddings])
        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts,
                    num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([combined_tensors, span_width_embeddings],
                             -1).squeeze(0)

        if span_indices_mask is not None:
            return (combined_tensors *
                    span_indices_mask.unsqueeze(-1).float()).squeeze(0)
        return combined_tensors.squeeze(0)
def train_model(model, loss_fn, lr=0.001, batch_size=64, n_epochs=5):
    min_loss = float('inf')
    for i, (train_idx, val_idx) in enumerate(idx_splits):

        train_ds = TorchtextSubset(kfold_train_tabular_dataset, train_idx)
        val_ds = TorchtextSubset(kfold_train_tabular_dataset, val_idx)

        train_loader, val_loader = torchtext.data.BucketIterator.splits(
            [train_ds, val_ds],
            batch_sizes=[batch_size, batch_size],
            device=device,
            sort_key=lambda x: len(x.comment_text),
            sort_within_batch=True,
            repeat=False)

        print('Fold::::::::::', i)

        param_lrs = [{
            'params': param,
            'lr': lr
        } for param in model.parameters()]
        optimizer = torch.optim.Adam(param_lrs, lr=lr)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lambda epoch: 0.6**epoch)
        step = 0

        for epoch in range(n_epochs):
            train_loader.init_epoch()

            start_time = time.time()

            epoch_loss, epoch_acc = 0, 0

            for i, data in enumerate(train_loader):
                step += 1
                optimizer.zero_grad()

                data_len_comment_text = LongTensor(
                    list(map(len, data.comment_text)))

                question = data.comment_text
                question = question.to(device)
                my_target = data.our_target.to(device).float()

                x_batch = question
                y_batch = my_target

                y_pred = model(x_batch, data_len_comment_text).squeeze(1)
                loss = loss_fn(y_pred, y_batch)

                acc = binary_accuracy(y_pred, y_batch)

                loss.backward()
                optimizer.step()
                scheduler.step()

                epoch_loss += loss.item()
                epoch_acc += acc.item()

                if step % len(train_loader) == 0:

                    val_acc, val_loss = get_val_score(model, val_loader,
                                                      loss_fn)

                    print("current val_loss", val_loss, "last min_loss",
                          min_loss)

                    if val_loss < min_loss:
                        save(m=model,
                             info={
                                 'epoch': epoch,
                                 'val_loss': val_loss
                             })
                        min_loss = val_loss

                    print('val_acc', val_acc, 'val_loss', val_loss,
                          'train_acc', epoch_acc / len(train_loader),
                          'train_loss', epoch_loss / len(train_loader))

            elapsed_time = time.time() - start_time
            print(
                'Epoch {}/{} \t loss={:.4f} \t accuracy={} \t time={:.2f}s  '.
                format(epoch + 1, n_epochs, epoch_loss / len(train_loader),
                       epoch_acc / len(train_loader), elapsed_time))

    return model
Example #19
    def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
        valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.uint8)
        for index in self._exclude_indices:
            valid_tokens_mask = valid_tokens_mask & (tensor != index)
        return valid_tokens_mask
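
The same pattern in isolation, using a bool mask instead of the uint8 ByteTensor above; 0 and 2 stand in for excluded padding/end ids.

import torch

tensor = torch.tensor([[5, 6, 2, 0, 0]])
exclude_indices = {0, 2}
valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.bool)
for index in exclude_indices:
    valid_tokens_mask = valid_tokens_mask & (tensor != index)
print(valid_tokens_mask.sum().item())  # 2 valid tokens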
Example #20
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                mask_positions: torch.LongTensor,
                target_ids: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_tensor()`` for a batch of sentences.
        mask_positions : ``torch.LongTensor``
            The positions in ``tokens`` that correspond to [MASK] tokens that we should try to fill
            in.  Shape should be (batch_size, num_masks).
        target_ids : ``Dict[str, torch.LongTensor]``
            This is a list of token ids that correspond to the mask positions we're trying to fill.
            It is the output of a ``TextField``, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader.  We assume
            that there is exactly one entry in the dictionary, and that it has a shape identical to
            ``mask_positions`` - one target token per mask position.
        """
        # pylint: disable=arguments-differ
        targets = None
        if target_ids is not None:
            # A bit of a hack to get the right targets out of the TextField output...
            if len(target_ids) != 1:
                targets = target_ids['bert']
            else:
                targets = list(target_ids.values())[0]
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(f"Number of targets ({targets.size()}) and number of masks "
                             f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict.
        batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size, 5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        # Using the namespace here is a hack...
        output_dict["token_ids"] = tokens[self._target_namespace]

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks, vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict['loss'] = loss

        return output_dict
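
The advanced-indexing step above, sketched in isolation: pick one embedding per mask position for every element of the batch.

import torch

batch_size, num_tokens, embedding_dim = 2, 6, 4
contextual_embeddings = torch.randn(batch_size, num_tokens, embedding_dim)
mask_positions = torch.tensor([[1, 4], [0, 3]])  # (batch_size, num_masks)

batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
mask_embeddings = contextual_embeddings[batch_index, mask_positions]
print(mask_embeddings.shape)  # torch.Size([2, 2, 4])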
Example #21
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesWorld],
        actions: List[List[ProductionRuleArray]],
        example_lisp_string: List[str] = None,
        target_action_sequences: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """

        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(
            entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word cosine similarity. Need to add a small value to
        # to the table norm since there are padding values which cause a divide by 0.
        embedded_table = embedded_table / (
            embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (
            embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = Variable(
            encoder_outputs.data.new(batch_size,
                                     self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(
            embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state = WikiTablesDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if self._linking_params is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                        0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in best_action_indices
                    ]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(
                            action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form,
                                                  example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs
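
# A minimal standalone sketch (not part of the parser above) of the two-term linking score
# described in the comment block near the top of this forward method: a scalar projection of the
# question-entity similarity plus a scalar projection of the question-neighbor similarity.
# Only torch is assumed; `gather_neighbors` is a hypothetical stand-in for
# allennlp.nn.util.batched_index_select.
import torch
import torch.nn as nn

def gather_neighbors(scores, neighbor_indices):
    # scores: (batch, num_entities, num_question_tokens)
    # neighbor_indices: (batch, num_entities, num_neighbors)
    # returns: (batch, num_entities, num_neighbors, num_question_tokens)
    batch, num_entities, num_tokens = scores.size()
    flat = neighbor_indices.view(batch, -1)
    gathered = torch.stack([scores[b, flat[b]] for b in range(batch)])
    return gathered.view(batch, num_entities, -1, num_tokens)

batch, num_entities, num_neighbors, num_tokens = 2, 3, 2, 4
entity_token_similarity = torch.randn(batch, num_entities, num_tokens)
neighbor_indices = torch.randint(0, num_entities, (batch, num_entities, num_neighbors))

# Max similarity over each entity's neighbors, per question token.
neighbor_similarity = gather_neighbors(entity_token_similarity, neighbor_indices)
neighbor_similarity_max, _ = neighbor_similarity.max(dim=2)

# Two scalar (1 -> 1) linear projections, applied elementwise and summed.
question_entity_params = nn.Linear(1, 1)
question_neighbor_params = nn.Linear(1, 1)
linking_scores = (question_entity_params(entity_token_similarity.unsqueeze(-1)).squeeze(-1)
                  + question_neighbor_params(neighbor_similarity_max.unsqueeze(-1)).squeeze(-1))
print(linking_scores.shape)  # torch.Size([2, 3, 4])
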
Exemplo n.º 22
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                verb_span: torch.LongTensor,
                entity_span: torch.LongTensor,
                state_change_type_labels: torch.LongTensor = None,
                state_change_tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_span: torch.LongTensor, required.
            An integer ``SequenceLabelField`` representation of the position of the focus verb
            in the sentence. This should have shape (batch_size, num_tokens) and, importantly, can be
            all zeros, in case the pre-processing stage could not extract a verbal predicate.
        entity_span: torch.LongTensor, required.
            An integer ``SequenceLabelField`` representation of the position of the focus entity
            in the sentence. This should have shape (batch_size, num_tokens).
        state_change_type_labels: torch.LongTensor, optional (default = None)
            A torch tensor representing the state change type class labels of shape ``(batch_size, 1)``.
        state_change_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``.
            In the first implementation we focus only on state_change_types.

        Returns
        -------
        An output dictionary consisting of:
        type_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_state_change_types)`` representing
            a distribution of state change types per datapoint.
        tags_class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, num_tags)`` representing
            a distribution over state change tags per token in the sentence.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # Layer 1 = Word + Character embedding layer
        embedded_sentence = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        # Layer 2 = Add positional bit to encode position of focus verb and entity
        embedded_sentence_verb_entity = \
            torch.cat([embedded_sentence, verb_span.float().unsqueeze(-1), entity_span.float().unsqueeze(-1)], dim=-1)

        # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence
        contextual_embedding = self.seq2seq_encoder(embedded_sentence_verb_entity, mask)

        # Layer 4: Attention (Contextual embedding, BOW(verb span))
        verb_weight_matrix = verb_span.float() / (verb_span.float().sum(-1).unsqueeze(-1) + 1e-13)
        verb_vector = weighted_sum(contextual_embedding * verb_span.float().unsqueeze(-1), verb_weight_matrix)
        entity_weight_matrix = entity_span.float() / (entity_span.float().sum(-1).unsqueeze(-1) + 1e-13)
        entity_vector = weighted_sum(contextual_embedding * entity_span.float().unsqueeze(-1), entity_weight_matrix)
        verb_entity_vector = torch.cat([verb_vector, entity_vector], 1)
        batch_size, sequence_length, binary_feature_dim = verb_span.float().unsqueeze(-1).size()

        # attention weights for type prediction
        attention_weights_types = self.attention_layer(verb_entity_vector, contextual_embedding)
        attention_output_vector = weighted_sum(contextual_embedding, attention_weights_types)

        # contextual embedding + positional vectors for tag prediction
        context_positional_tags = torch.cat([contextual_embedding, verb_span.float().unsqueeze(-1), entity_span.float().unsqueeze(-1)], dim=-1)

        # Layer 5 = Dense softmax layer to pick one state change type per datapoint,
        # and one tag per word in the sentence
        type_logits = self.aggregate_feedforward(attention_output_vector)
        type_probs = torch.nn.functional.softmax(type_logits, dim=-1)

        tags_logits = self.tag_projection_layer(context_positional_tags)
        reshaped_log_probs = tags_logits.view(-1, self.num_tags)
        tags_class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size, sequence_length, self.num_tags])

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {'type_probs': type_probs}
        if state_change_type_labels is not None:
            state_change_type_labels_loss = self._loss(type_logits, state_change_type_labels.long().view(-1))
            for type_label in self.type_labels_vocab.values():
                metric = self.type_f1_metrics["type_" + type_label]
                metric(type_probs, state_change_type_labels.squeeze(-1))

            self._type_accuracy(type_probs, state_change_type_labels.squeeze(-1))

        if state_change_tags is not None:
            state_change_tags_loss = sequence_cross_entropy_with_logits(tags_logits, state_change_tags, mask)
            self.span_metric(tags_class_probabilities, state_change_tags, mask)
            output_dict["tags_class_probabilities"] = tags_class_probabilities

        # Only add a loss when both kinds of labels were provided, since each loss term above
        # is only defined in that case.
        if state_change_type_labels is not None and state_change_tags is not None:
            output_dict['loss'] = state_change_type_labels_loss + state_change_tags_loss

        return output_dict
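
# A minimal standalone sketch of the bag-of-words span vectors computed above: a binary span
# mask, normalized to sum to one, turns a weighted sum into the mean of the embeddings inside
# the span. `weighted_sum` here mirrors allennlp.nn.util.weighted_sum, under that assumption.
import torch

def weighted_sum(matrix, attention):
    # matrix: (batch, seq_len, dim), attention: (batch, seq_len) -> (batch, dim)
    return attention.unsqueeze(1).bmm(matrix).squeeze(1)

batch, seq_len, dim = 2, 5, 3
embeddings = torch.randn(batch, seq_len, dim)
span = torch.tensor([[0., 1., 1., 0., 0.],
                     [1., 0., 0., 0., 1.]])            # binary mask marking the focus span

weights = span / (span.sum(-1, keepdim=True) + 1e-13)  # normalize so the weights sum to 1
span_vector = weighted_sum(embeddings, weights)        # (batch, dim)

# Equivalent to taking the mean over the masked positions.
manual_mean = (embeddings * span.unsqueeze(-1)).sum(1) / span.sum(-1, keepdim=True)
print(torch.allclose(span_vector, manual_mean))        # True
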
Exemplo n.º 23
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        # Shape: (batch_size, 4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)

        (batch_size, q_length, embedding_dim) = embedded_question.size()
        passage_length = embedded_passage.size(2)

        # reshape: (batch_size*4, -1, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)

        embedded_question = embedded_question.unsqueeze(0).expand(
            4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)
        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(0).expand(
            4, -1, -1).contiguous().view(-1, q_length)

        passage_mask = util.get_text_field_mask(passage, 1).float()
        passage_mask = passage_mask.view(-1, passage_length)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        cuda_device = encoded_question.get_device()

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))

        # TODO: attach residual self-attention layer
        # Shape: (batch_size, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        mask = passage_mask.resize(batch_size, passage_length,
                                   1) * passage_mask.resize(
                                       batch_size, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).resize(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        #for i in range(passage_length):
        #    passage_self_similarity[:, i, i] = float('-Inf')
        # Shape: (batch_size, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ],
                                   dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))

        # Shape: (batch_size, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Shape: (batch_size, encoding_dim)
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)

        no_span_logits = self._no_answer_predictor(
            self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # Shape: (batch_size, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))
        # compute no-answer score

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }
        # create target tensor including no-answer label
        span_target = Variable(torch.ones(batch_size).long()).cuda(cuda_device)
        for b in range(batch_size):
            span_target[b].data[0] = span_start[
                b, 0].data[0] * passage_length + span_end[b, 0].data[0]
        span_target[span_target < 0] = passage_length**2

        # Shape: (batch_size, passage_length, passage_length)
        # Rows index span starts and columns index span ends, so that the row-major flattening
        # below lines up with the ``start * passage_length + end`` target encoding above.
        span_start_logits_tiled = span_start_logits.unsqueeze(-1).expand(
            batch_size, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(1).expand(
            batch_size, passage_length, passage_length)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(
            batch_size, -1)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                passage_mask.unsqueeze(1)).view(
                                    batch_size, -1)
        no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)
        loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask),
                        span_target)
        output_dict["loss"] = loss

        # Shape(batch_size, max_answers, num_span)
        #    max_answers = spans.size(1)
        #    span_logits = torch.bmm(span_start_logits.unsqueeze(-1), span_end_logits.unsqueeze(1)).view(batch_size, -1)
        #    answer_mask = torch.bmm(passage_mask.unsqueeze(-1), passage_mask.unsqueeze(1)).view(batch_size, -1)
        #    no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        #    combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        #    # Shape: (batch_size, passage_length**2 + 1)
        #    all_logits = torch.cat([span_logits, z_score], dim=-1)
        #    # Shape: (batch_size, max_answers)
        #    spans_combined = spans[:, :, 0] * passage_length + spans[:, :, 1]
        #    spans_combined[spans_combined < 0] = passage_length*passage_length
        #
        #    all_modified_logits = []
        #    for b in range(batch_size):
        #        idxs = Variable(torch.LongTensor(range(passage_length**2 + 1))).cuda(cuda_device)
        #        for i in range(max_answers):
        #            idxs[spans_combined[b, i].data[0]].data = idxs[spans_combined[b, 0].data[0]].data
        #        idxs[passage_length**2].data[0] = passage_length**2

        #        modified_logits = Variable(torch.zeros(all_logits.size(-1))).cuda(cuda_device)
        #        modified_logits.index_add_(0, idxs, all_logits[b])
        #        all_modified_logits.append(modified_logits)

        #    all_modified_logits = torch.stack(all_modified_logits, dim=0)
        #    loss = nll_loss(util.masked_log_softmax(all_modified_logits, combined_mask), spans_combined[:, 0])
        #    output_dict["loss"] = loss

        if span_start is not None:
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
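
# A minimal standalone sketch of the joint span loss used above: start and end logits are
# combined into a (start, end) score matrix, flattened, and concatenated with a no-answer
# score, so a single softmax covers every span plus the "no answer" option. Toy tensors only.
import torch
import torch.nn.functional as F

batch, passage_length = 2, 4
span_start_logits = torch.randn(batch, passage_length)
span_end_logits = torch.randn(batch, passage_length)
no_answer_score = torch.randn(batch, 1)

# score[b, s, e] = start_logit[b, s] + end_logit[b, e], flattened row-major.
span_logits = (span_start_logits.unsqueeze(-1) + span_end_logits.unsqueeze(1)).view(batch, -1)
all_logits = torch.cat([span_logits, no_answer_score], dim=-1)   # last slot = "no answer"

# Gold (start, end) pairs; a negative start marks an unanswerable question.
gold = torch.tensor([[1, 2], [-1, -1]])
targets = gold[:, 0] * passage_length + gold[:, 1]
targets[gold[:, 0] < 0] = passage_length ** 2                    # point at the no-answer slot

loss = F.nll_loss(F.log_softmax(all_logits, dim=-1), targets)
print(loss.item())
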
Exemplo n.º 24
def compute_class_AP(model,
                     dl,
                     n_classes,
                     show_progress,
                     iou_thresh=0.1,
                     detect_thresh=0.5,
                     num_keep=100):

    tps, clas, p_scores = [], [], []
    classes, n_gts = LongTensor(
        range(n_classes)), torch.zeros(n_classes).long()
    model.learn.model.eval()

    with torch.no_grad():
        for input, target in progress_bar(dl, display=show_progress):
            # input - 4(batch-size),3,256,256
            # target - 2(regression,classification), 4(batch-size), 3/4/2(max no of detections in the batch), 4/1(bbox,class)
            output = model.learn.pred_batch(batch=(input, target))

            for i in range(target[0].size(0)):  # range batch-size
                #output[0] - classpreds, output[1] - bbox preds
                op = model._data.y.analyze_pred((output[0][i], output[1][i]),
                                                thresh=detect_thresh,
                                                nms_overlap=iou_thresh,
                                                ssd=model,
                                                ret_scores=True,
                                                device=model._device)
                #op - bbox preds, class preds, scores

                # Unpad the targets
                tgt_bbox, tgt_clas = _get_y(target[0][i], target[1][i])

                try:
                    bbox_pred, preds, scores = op
                    if len(bbox_pred) != 0 and len(tgt_bbox) != 0:

                        bbox_pred = bbox_pred.to(model._device)
                        preds = preds.to(model._device)
                        tgt_bbox = tgt_bbox.to(model._device)

                        # Convert the bbox coordinates to center-height-width(cthw) before calculating Intersection Over Union
                        ious = IoU_values(tlbr2cthw(bbox_pred),
                                          tlbr2cthw(tgt_bbox))
                        max_iou, matches = ious.max(1)
                        detected = []

                        for i in range(len(preds)):
                            if max_iou[i] >= iou_thresh and matches[
                                    i] not in detected and tgt_clas[
                                        matches[i]] == preds[i]:
                                detected.append(matches[i])
                                tps.append(1)
                            else:
                                tps.append(0)
                        clas.append(preds.cpu())
                        p_scores.append(scores.cpu())
                except Exception:
                    # Unpacking or matching failed for this image; skip its predictions but
                    # still count its ground-truth boxes below.
                    pass
                n_gts += ((tgt_clas.cpu()[:, None] -
                           1) == classes[None, :]).sum(0)

    # If no true positives are found return an average precision score of 0.
    if len(tps) == 0: return [0. for cls in range(1, n_classes + 1)]

    tps, p_scores, clas = torch.tensor(tps), torch.cat(p_scores,
                                                       0), torch.cat(clas, 0)
    fps = 1 - tps
    idx = p_scores.argsort(descending=True)
    tps, fps, clas = tps[idx], fps[idx], clas[idx]
    aps = []

    for cls in range(1, n_classes + 1):
        tps_cls, fps_cls = tps[clas == cls].float().cumsum(0), fps[
            clas == cls].float().cumsum(0)
        if tps_cls.numel() != 0 and tps_cls[-1] != 0:
            precision = tps_cls / (tps_cls + fps_cls + 1e-8)
            recall = tps_cls / (n_gts[cls - 1] + 1e-8)
            aps.append(compute_ap(precision, recall))
        else:
            aps.append(0.)
    return aps
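
# A hedged sketch of the last step above: turning the cumulative true/false positives into an
# average-precision value. `compute_ap` itself is not shown in this snippet, so this uses one
# common definition (area under the precision-recall curve via trapezoidal integration) and is
# not necessarily the exact implementation being called.
import torch

def compute_ap_sketch(precision, recall):
    # Add (recall=0, precision=1) and (recall=1, precision=0) end points, then integrate.
    precision = torch.cat([precision.new_ones(1), precision, precision.new_zeros(1)])
    recall = torch.cat([recall.new_zeros(1), recall, recall.new_ones(1)])
    return torch.trapz(precision, recall).item()

# Detections of one class, already sorted by descending confidence; 1 marks a true positive.
tps = torch.tensor([1., 0., 1., 1., 0.])
fps = 1 - tps
n_gt = 4                                             # ground-truth boxes of this class
precision = tps.cumsum(0) / (tps.cumsum(0) + fps.cumsum(0) + 1e-8)
recall = tps.cumsum(0) / (n_gt + 1e-8)
print(compute_ap_sketch(precision, recall))
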
Exemplo n.º 25
def get_long_tensor(np_tensor):
    if torch.cuda.is_available():
        return LongTensor(from_numpy(np_tensor)).cuda()
    else:
        return LongTensor(from_numpy(np_tensor))
Exemplo n.º 26
def encode_class(idxs, n_classes):
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    i1s = LongTensor(list(range(len(idxs))))
    target[i1s[mask], idxs[mask] - 1] = 1
    return target
Exemplo n.º 27
def _dynamic_rnn_loop(cell: RNNCellBase[State],
                      inputs: torch.Tensor,
                      initial_state: State,
                      sequence_length: torch.LongTensor) \
        -> Tuple[torch.Tensor, State]:
    r"""Internal implementation of Dynamic RNN.

    Args:
        cell: An instance of RNNCell.
        inputs: A ``Tensor`` of shape ``[time, batch_size, input_size]``,
            or a nested tuple of such elements.
        initial_state: A ``Tensor`` of shape ``[batch_size, state_size]``,
            or if ``cell.state_size`` is a tuple, then this should be a tuple
            of tensors having shapes ``[batch_size, s]`` for ``s`` in
            ``cell.state_size``.
        sequence_length: (optional) An ``int32`` ``Tensor``
            of shape ``[batch_size]``.

    Returns:
        Tuple ``(final_outputs, final_state)``.
        final_outputs:
            A ``Tensor`` of shape ``[time, batch_size, cell.output_size]``. If
            ``cell.output_size`` is a (possibly nested) tuple of ints or
            ``TensorShape`` objects, then this returns a
            (possibly nested) tuple of Tensors matching the corresponding
            shapes.
        final_state:
            A ``Tensor``, or possibly nested tuple of Tensors, matching
            in length and shapes to ``initial_state``.
    """
    state = initial_state
    time_steps = inputs.shape[0]
    all_outputs = []

    all_state: MaybeTuple[List[torch.Tensor]]
    if isinstance(state, tuple):
        all_state = ([], [])
    else:
        all_state = []

    for i in range(time_steps):
        output, state = cell(inputs[i], state)
        all_outputs.append(output)
        if isinstance(state, tuple):
            all_state[0].append(state[0])
            all_state[1].append(state[1])
        else:
            all_state.append(state)  # type: ignore
    # TODO: Do not compute everything regardless of sequence_length

    final_outputs = torch.stack(all_outputs, dim=0)

    final_outputs = mask_sequences(final_outputs,
                                   sequence_length=sequence_length,
                                   time_major=True)

    final_state: MaybeTuple[List[torch.Tensor]]
    if isinstance(state, tuple):
        final_state = ([], [])
    else:
        final_state = []

    for batch_idx, time_idx in enumerate(sequence_length.tolist()):
        if time_idx > 0:
            if isinstance(state, tuple):
                final_state[0].append(all_state[0][time_idx - 1][batch_idx])
                final_state[1].append(all_state[1][time_idx - 1][batch_idx])
            else:
                final_state.append(  # type: ignore
                    all_state[time_idx - 1][batch_idx])
        else:
            if isinstance(initial_state, tuple):
                final_state[0].append(initial_state[0][batch_idx])
                final_state[1].append(initial_state[1][batch_idx])
            else:
                final_state.append(initial_state[batch_idx])  # type: ignore

    if isinstance(state, tuple):
        final_state = (torch.stack(final_state[0],
                                   dim=0), torch.stack(final_state[1], dim=0))
    else:
        final_state = torch.stack(final_state, dim=0)  # type: ignore

    return final_outputs, final_state
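
# A minimal standalone sketch of the final-state gathering in the loop above: for each batch
# element, keep the state from its last valid time step, or fall back to the initial state when
# the sequence length is zero. This vectorized version uses toy tensors and plain indexing
# instead of the per-example Python loop.
import torch

time_steps, batch_size, hidden = 5, 3, 2
all_states = torch.randn(time_steps, batch_size, hidden)   # state after each step, time-major
initial_state = torch.zeros(batch_size, hidden)
sequence_length = torch.tensor([3, 5, 0])

last_step = (sequence_length - 1).clamp(min=0)             # index of the last valid step
final_state = all_states[last_step, torch.arange(batch_size)]
final_state = torch.where((sequence_length > 0).unsqueeze(-1), final_state, initial_state)
print(final_state.shape)  # torch.Size([3, 2])
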
    def _get_linking_probabilities(
            self, worlds: List[WikiTablesVariableFreeWorld],
            linking_scores: torch.FloatTensor, question_mask: torch.LongTensor,
            entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesVariableFreeWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "date_column:", followed by "number_column:", "string:", and "string_column:".
            # This is not a great assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities +
                                        entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices,
                                                    dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(
                    1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores,
                                                                 dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(
                    num_question_tokens,
                    num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()
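
# A minimal standalone sketch of the per-type normalization above: scores for the entities of a
# single type are softmaxed together with an extra "null" score fixed at zero, and the null
# column is then dropped, so the remaining probabilities can sum to less than one (the "links
# to nothing" mass). Toy scores only.
import torch

num_question_tokens = 2
entity_scores = torch.tensor([[1.5, -0.3],     # two entities of the same type,
                              [0.2,  2.0]])    # one row per question token
null_scores = entity_scores.new_zeros(num_question_tokens, 1)
scores_with_null = torch.cat([null_scores, entity_scores], dim=1)

type_probabilities = torch.nn.functional.softmax(scores_with_null, dim=1)
linking_probabilities = type_probabilities[:, 1:]   # drop the null column
print(linking_probabilities.sum(dim=1))             # each row sums to less than 1
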
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRule]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=example_lisp_string,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             example_lisp_string,
                                             metadata,
                                             outputs)

            return outputs
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(int(self._input_dim / 2), dim=-1)
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP ``SpanField`` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP ``SpanField`` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of the sequence_tensor.
            sequence_lengths = util.ones_like(sequence_tensor[:, 0, 0]).long() * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (exclusive_span_ends == sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (exclusive_span_ends > sequence_lengths.unsqueeze(-1)).any():
            raise ValueError(f"Adjusted span indices must lie inside the length of the sequence tensor, "
                             f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                             f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                             f"{sequence_lengths}.")

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(forward_sequence, exclusive_span_starts)
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(backward_sequence, exclusive_span_ends)
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = forward_start_embeddings * (1 - float_start_sentinel_mask) \
                                        + float_start_sentinel_mask * self._start_sentinel
            backward_start_embeddings = backward_start_embeddings * (1 - float_end_sentinel_mask) \
                                        + float_end_sentinel_mask * self._end_sentinel

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(self._forward_combination,
                                             [forward_start_embeddings, forward_end_embeddings])
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(self._backward_combination,
                                              [backward_start_embeddings, backward_end_embeddings])
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
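
# A minimal standalone sketch of the exclusive-index bookkeeping above: inclusive span starts
# are shifted left by one for the forward direction, and starts that fall off the sequence
# (index -1) are flagged so a learned start sentinel can be substituted for them. Toy indices only.
import torch

span_starts = torch.tensor([[0, 3], [2, 5]])                    # inclusive starts, (batch, num_spans)
exclusive_span_starts = span_starts - 1
start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
# Zero out the out-of-bounds index; the mask records where the sentinel should be used instead.
exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))
print(exclusive_span_starts)             # tensor([[0, 2], [1, 4]])
print(start_sentinel_mask.squeeze(-1))   # tensor([[1, 0], [0, 0]])
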
Exemplo n.º 31
    def construct_trees(self,
                        predictions: torch.FloatTensor,
                        all_spans: torch.LongTensor,
                        num_spans: torch.LongTensor,
                        sentences: List[List[str]],
                        pos_tags: List[List[str]] = None) -> List[Tree]:
        """
        Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
        The trees use exclusive end indices, which contrasts with how spans are
        represented in the rest of the model.

        Parameters
        ----------
        predictions : ``torch.FloatTensor``, required.
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        all_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the span
            indices we scored.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``all_spans``.
        sentences : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, optional (default = None).
            A list of POS tags for each word in the sentence for each element
            in the batch.

        Returns
        -------
        A ``List[Tree]`` containing the decoded trees for each element in the batch.
        """
        # Switch to using exclusive end spans.
        exclusive_end_spans = all_spans.clone()
        exclusive_end_spans[:, :, -1] += 1
        no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

        trees: List[Tree] = []
        for batch_index, (scored_spans, spans, sentence) in enumerate(zip(predictions,
                                                                          exclusive_end_spans,
                                                                          sentences)):
            selected_spans = []
            for prediction, span in zip(scored_spans[:num_spans[batch_index]],
                                        spans[:num_spans[batch_index]]):
                start, end = span
                no_label_prob = prediction[no_label_id]
                label_prob, label_index = torch.max(prediction, -1)

                # Does the span have a label != NO-LABEL or is it the root node?
                # If so, include it in the spans that we consider.
                if int(label_index) != no_label_id or (start == 0 and end == len(sentence)):
                    # TODO(Mark): Remove this once pylint sorts out named tuples.
                    # https://github.com/PyCQA/pylint/issues/1418
                    selected_spans.append(SpanInformation(start=int(start), # pylint: disable=no-value-for-parameter
                                                          end=int(end),
                                                          label_prob=float(label_prob),
                                                          no_label_prob=float(no_label_prob),
                                                          label_index=int(label_index)))

            # The spans we've selected might overlap, which causes problems when we try
            # to construct the tree as they won't nest properly.
            consistent_spans = self.resolve_overlap_conflicts_greedily(selected_spans)

            spans_to_labels = {(span.start, span.end):
                               self.vocab.get_token_from_index(span.label_index, "labels")
                               for span in consistent_spans}
            sentence_pos = pos_tags[batch_index] if pos_tags is not None else None
            trees.append(self.construct_tree_from_spans(spans_to_labels, sentence, sentence_pos))

        return trees
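
# A minimal standalone sketch of the index convention above: the model scores spans with
# inclusive end indices, while tree construction expects exclusive ends, so each end index is
# bumped by one before nesting the spans. Toy spans only.
import torch

all_spans = torch.tensor([[[0, 0], [0, 2], [1, 2]]])   # (batch=1, num_spans=3, 2), inclusive ends
exclusive_end_spans = all_spans.clone()
exclusive_end_spans[:, :, -1] += 1
print(exclusive_end_spans[0].tolist())                 # [[0, 1], [0, 3], [1, 3]]
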
Exemplo n.º 32
    def forward(self,  # type: ignore
                start_tokens: torch.LongTensor,
                memory: Optional[State] = None,
                cache_len: int = 512,
                max_decoding_length: Optional[int] = 500,
                recompute_memory: bool = True,
                print_steps: bool = False,
                helper_type: Optional[Union[str, Type[Helper]]] = None,
                **helper_kwargs) \
            -> Tuple[Output, Optional[State]]:
        r"""Perform autoregressive decoding using XLNet. The algorithm is
        largely inspired by: https://github.com/rusiaaman/XLNet-gen.

        Args:
            start_tokens: A LongTensor of shape `[batch_size, prompt_len]`,
                representing the tokenized initial prompt.
            memory (optional): The initial memory.
            cache_len: Length of memory (number of tokens) to cache.
            max_decoding_length (int): Maximum number of tokens to decode.
            recompute_memory (bool): If `True`, the entire memory is recomputed
                for each token to generate. This leads to better performance
                because it enables every generated token to attend to each
                other, compared to reusing previous memory which is equivalent
                to using a causal attention mask. However, it is computationally
                more expensive. Defaults to `True`.
            print_steps (bool): If `True`, will print decoding progress.
            helper_type: Type (or name of the type) of any sub-class of
                :class:`~texar.modules.decoders.Helper`.
            helper_kwargs: The keyword arguments to pass to constructor of
                the specific helper type.

        :returns: A tuple of `(output, new_memory)`:
            - **`output`**: The sampled tokens as a list of integers.
            - **`new_memory`**: The memory of the sampled tokens.
        """

        start_tokens = start_tokens.t()
        self._state_recompute_memory = recompute_memory
        self._state_cache_len = cache_len
        self._state_previous_inputs = list(
            self.word_embed(start_tokens).unbind(dim=0))[:-1]

        if helper_type is None:
            helper_type = SampleEmbeddingHelper

        if not recompute_memory and start_tokens.size(0) > 1:
            _, memory = self._forward(memory=memory,
                                      cache_len=cache_len,
                                      **self._create_input(
                                          self._state_previous_inputs,
                                          initial=True))
        start_tokens = start_tokens[-1]

        helper_kwargs.update(start_tokens=start_tokens)

        if helper_kwargs.get("end_token") is None:
            raise ValueError("'end_token' must be specified.")

        helper = get_instance(
            helper_type,
            helper_kwargs,
            module_paths=['texar.modules.decoders.decoder_helpers'])

        step_hook = None
        if print_steps:
            step_hook = lambda step: print(f"\033[2K\rDecoding step: {step}",
                                           end='')
        output, new_memory, _ = self.dynamic_decode(
            helper,
            inputs=None,
            sequence_length=None,
            initial_state=memory,
            max_decoding_length=max_decoding_length,
            step_hook=step_hook)
        if print_steps:
            print("\033[2K\r", end='')

        return output, new_memory
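A standalone toy sketch (illustrative sizes only, with a plain `nn.Embedding` standing in for the decoder's `word_embed`) of the prompt handling at the top of this `forward`: the prompt is transposed to time-major, embedded, split into per-step "previous inputs" with the last step dropped, and only the final prompt token is handed to the sampling helper.

# Toy sketch of the prompt preprocessing above; names and sizes are illustrative.
import torch
import torch.nn as nn

vocab_size, hidden, batch, prompt_len = 50, 8, 2, 4
word_embed = nn.Embedding(vocab_size, hidden)

start_tokens = torch.randint(vocab_size, (batch, prompt_len))   # [batch_size, prompt_len]
start_tokens = start_tokens.t()                                 # [prompt_len, batch_size]

# All prompt steps except the last are kept as cached "previous inputs";
# the last prompt token becomes the first input to the decoding helper.
previous_inputs = list(word_embed(start_tokens).unbind(dim=0))[:-1]
first_decode_tokens = start_tokens[-1]                          # [batch_size]

print(len(previous_inputs), previous_inputs[0].shape, first_decode_tokens.shape)
# 3 torch.Size([2, 8]) torch.Size([2])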
Exemplo n.º 33
0
    def forward(
        self,
        src_tokens: torch.LongTensor,
        src_lengths: torch.LongTensor,
        return_encoder_out: bool = False,
        return_encoder_padding_mask: bool = False,
    ) -> EncoderOuts:
        """Encode a batch of sequences

        Arguments:
            src_tokens {torch.LongTensor} -- [batch_size, seq_len]
            src_lengths {torch.LongTensor} -- [batch_size]

        Keyword Arguments:
            return_encoder_out {bool} --
                Return output tensors? (default: {False})
            return_encoder_padding_mask {bool} --
                Return encoder padding mask? (default: {False})

        Returns:
            EncoderOuts -- A named tuple holding the max-pooled sentence embedding,
                the optional encoder outputs, and the optional encoder padding mask.
        """
        bsz, seqlen = src_tokens.size()

        x = self.embed_tokens(src_tokens)
        x = x.transpose(0, 1)  # BTC -> TBC

        # Pack then apply LSTM
        packed_x = nn.utils.rnn.pack_padded_sequence(x,
                                                     src_lengths,
                                                     batch_first=False,
                                                     enforce_sorted=True)
        packed_outs, (final_hiddens, final_cells) = \
            self.lstm.forward(packed_x)

        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value)
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        # Set padded outputs to -inf so they are not selected by max-pooling
        padding_mask = src_tokens.eq(self.padding_idx).t()
        if padding_mask.any():
            x = x.float().masked_fill_(
                mask=padding_mask.unsqueeze(-1),
                value=float('-inf'),
            ).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        encoder_out = None
        if return_encoder_out:
            final_hiddens = self._combine_outs(final_hiddens)
            final_cells = self._combine_outs(final_cells)
            encoder_out = (x, final_hiddens, final_cells)

        encoder_padding_mask = None
        if return_encoder_padding_mask:
            encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        return EncoderOuts(sentemb=sentemb,
                           encoder_out=encoder_out,
                           encoder_padding_mask=encoder_padding_mask)
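The max-pooled sentence embedding above relies on padded positions being filled with -inf so they can never win the max. A minimal self-contained sketch of that trick, with made-up sizes:

# Standalone sketch of masked max-pooling over time: x is [seq_len, batch, hidden],
# padding_mask is [seq_len, batch]; padded steps are set to -inf before the max.
import torch

seq_len, batch, hidden, padding_idx = 5, 2, 4, 0
src_tokens = torch.tensor([[7, 3, 9, 0, 0],
                           [4, 8, 0, 0, 0]])                  # [batch, seq_len]
x = torch.randn(seq_len, batch, hidden)                       # stand-in encoder outputs

padding_mask = src_tokens.eq(padding_idx).t()                 # [seq_len, batch]
x = x.masked_fill(padding_mask.unsqueeze(-1), float('-inf'))
sentemb = x.max(dim=0)[0]                                     # [batch, hidden]
assert torch.isfinite(sentemb).all()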
Exemplo n.º 34
0
def sequence_cross_entropy_with_logits(
    logits: torch.FloatTensor,
    targets: torch.LongTensor,
    weights: torch.FloatTensor,
    average: str = "batch",
    label_smoothing: float = None,
    gamma: float = None,
    alpha: Union[float, List[float], torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.
    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    average: str, optional (default = "batch")
        If "batch", average the loss across the batches. If "token", average
        the loss across each item in the input. If ``None``, return a vector
        of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.
    gamma : ``float``, optional (default = None)
        Focal loss[*] focusing parameter ``gamma``, which reduces the relative loss
        for well-classified examples and puts more focus on hard examples. The
        greater the value of ``gamma``, the more focus is placed on hard examples.
    alpha : ``float`` or ``List[float]``, optional (default = None)
        Focal loss[*] weighting factor ``alpha`` used to balance between classes. It
        can be used independently of ``gamma``. If a single ``float`` is provided,
        the binary case is assumed, with ``alpha`` and ``1 - alpha`` applied to the
        positive and negative class respectively. If a list of ``float`` with the
        same length as the number of classes is provided, the weights match the
        classes.
        [*] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár, "Focal Loss for
        Dense Object Detection," 2017 IEEE International Conference on Computer
        Vision (ICCV), Venice, 2017, pp. 2999-3007.
    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``average=="batch"`` or ``average=="token"``, the returned loss is a scalar.
    If ``average is None``, the returned loss is a vector of shape (batch_size,).
    """
    if average not in {None, "token", "batch"}:
        raise ValueError("Got average f{average}, expected one of "
                         "None, 'token', or 'batch'")

    # make sure weights are float
    weights = weights.float()
    # sum all dim except batch
    non_batch_dims = tuple(range(1, len(weights.shape)))
    # shape : (batch_size,)
    weights_batch_sum = weights.sum(dim=non_batch_dims)
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.reshape(-1, 1).long()
    # focal loss coefficient
    if gamma:
        # shape : (batch * sequence_length, num_classes)
        probs_flat = log_probs_flat.exp()
        # shape : (batch * sequence_length,)
        probs_flat = torch.gather(probs_flat, dim=1, index=targets_flat)
        # shape : (batch * sequence_length,)
        focal_factor = (1.0 - probs_flat)**gamma
        # shape : (batch, sequence_length)
        focal_factor = focal_factor.view(*targets.size())
        weights = weights * focal_factor

    if alpha is not None:
        # shape : () / (num_classes,)
        if isinstance(alpha, (float, int)):

            # shape : (2,)
            alpha_factor = torch.tensor(
                [1.0 - float(alpha), float(alpha)],
                dtype=weights.dtype,
                device=weights.device)

        elif isinstance(alpha, (list, numpy.ndarray, torch.Tensor)):

            # shape : (c,)
            alpha_factor = torch.tensor(alpha,
                                        dtype=weights.dtype,
                                        device=weights.device)

            if not alpha_factor.size():
                # shape : (1,)
                alpha_factor = alpha_factor.view(1)
                # shape : (2,)
                alpha_factor = torch.cat([1 - alpha_factor, alpha_factor])
        else:
            raise TypeError(
                ("alpha must be float, list of float, or torch.FloatTensor, "
                 "{} provided.").format(type(alpha)))
        # shape : (batch, max_len)
        alpha_factor = torch.gather(
            alpha_factor, dim=0,
            index=targets_flat.view(-1)).view(*targets.size())
        weights = weights * alpha_factor

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
            -1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(
            -1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = -torch.gather(
            log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(
        *targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights

    if average == "batch":
        # shape : (batch_size,)
        per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
            weights_batch_sum + 1e-13)
        num_non_empty_sequences = (weights_batch_sum > 0).float().sum() + 1e-13
        return per_batch_loss.sum() / num_non_empty_sequences
    elif average == "token":
        return negative_log_likelihood.sum() / (weights_batch_sum.sum() +
                                                1e-13)
    else:
        # shape : (batch_size,)
        per_batch_loss = negative_log_likelihood.sum(non_batch_dims) / (
            weights_batch_sum + 1e-13)
        return per_batch_loss
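A toy sketch of the focal-loss weighting path above, with made-up shapes: the per-token weight is scaled by (1 - p_correct) ** gamma before the masked negative log-likelihood is normalised, so confidently-correct tokens contribute less to the loss.

# Toy focal-loss weighting, mirroring the function above (batch-averaged case).
import torch
import torch.nn.functional as F

gamma = 2.0
batch_size, seq_len, num_classes = 2, 3, 5
logits = torch.randn(batch_size, seq_len, num_classes)
targets = torch.randint(num_classes, (batch_size, seq_len))
weights = torch.tensor([[1., 1., 1.], [1., 1., 0.]])       # last token of row 2 is padding

weights_batch_sum = weights.sum(dim=1)                     # computed before focal scaling, as above
log_probs_flat = F.log_softmax(logits.view(-1, num_classes), dim=-1)
targets_flat = targets.reshape(-1, 1)
p_correct = log_probs_flat.exp().gather(1, targets_flat)   # probability of the gold class
focal_factor = (1.0 - p_correct) ** gamma
weights = weights * focal_factor.view(*targets.size())

nll = -log_probs_flat.gather(1, targets_flat).view(*targets.size())
per_batch_loss = (nll * weights).sum(1) / (weights_batch_sum + 1e-13)
loss = per_batch_loss.sum() / ((weights_batch_sum > 0).float().sum() + 1e-13)
print(loss)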
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
                "heads": predicted_heads,
                "head_tags": predicted_head_tags,
                "arc_loss": arc_nll,
                "tag_loss": tag_nll,
                "loss": loss,
                "mask": mask,
                "words": [meta["words"] for meta in metadata],
                "pos": [meta["pos"] for meta in metadata]
                }

        return output_dict
Exemplo n.º 36
0
def triple_tensor_to_set(tensor: torch.LongTensor) -> Set[Tuple[int, ...]]:
    """Convert a tensor of triples to a set of int-tuples."""
    return set(map(tuple, tensor.tolist()))
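A quick usage check for the helper above (assuming `triple_tensor_to_set` is in scope): duplicate rows collapse, and membership tests become ordinary set operations.

import torch

triples = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
assert triple_tensor_to_set(triples) == {(0, 1, 2), (3, 4, 5)}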
Exemplo n.º 37
0
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True,
                                       label_smoothing: float = None) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
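A toy sketch of the label-smoothing branch above, reproducing the ``[0.05, 0.05, 0.85, 0.05]`` example from the docstring with made-up logits:

# Toy label smoothing: gold index gets (1 - smoothing), everything gets + smoothing / C.
import torch

label_smoothing, num_classes = 0.2, 4
targets_flat = torch.tensor([[2], [0]])                        # gold class per row
log_probs_flat = torch.randn(2, num_classes).log_softmax(-1)

smoothing_value = label_smoothing / num_classes
one_hot = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
smoothed_targets = one_hot + smoothing_value
print(smoothed_targets[0])   # tensor([0.0500, 0.0500, 0.8500, 0.0500])

nll = -(log_probs_flat * smoothed_targets).sum(-1, keepdim=True)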
Exemplo n.º 38
0
def transformer_sliding_window(
    transformer: PreTrainedModel,
    input_ids: torch.LongTensor,
    input_mask=None,
    offsets: torch.LongTensor = None,
    token_type_ids: torch.LongTensor = None,
    max_pieces=512,
    start_tokens: int = 1,
    end_tokens: int = 1,
    ret_cls=None,
) -> torch.Tensor:
    """

    Args:
      transformer:
      input_ids: torch.LongTensor: 
      input_mask:  (Default value = None)
      offsets: torch.LongTensor:  (Default value = None)
      token_type_ids: torch.LongTensor:  (Default value = None)
      max_pieces:  (Default value = 512)
      start_tokens: int:  (Default value = 1)
      end_tokens: int:  (Default value = 1)
      ret_cls:  (Default value = None)

    Returns:

    
    """
    # pylint: disable=arguments-differ
    batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
    initial_dims = list(input_ids.shape[:-1])

    # The embedder may receive an input tensor that has a sequence length longer than can
    # be fit. In that case, we should expect the wordpiece indexer to create padded windows
    # of length `max_pieces` for us, and have them concatenated into one long sequence.
    # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
    # We can then split the sequence into sub-sequences of that length, and concatenate them
    # along the batch dimension so we effectively have one huge batch of partial sentences.
    # This can then be fed into BERT without any sentence length issues. Keep in mind
    # that the memory consumption can dramatically increase for large batches with extremely
    # long sentences.
    needs_split = full_seq_len > max_pieces
    if needs_split:
        input_ids = split_to_sliding_window(input_ids, max_pieces)

    # if token_type_ids is None:
    #     token_type_ids = torch.zeros_like(input_ids)
    if input_mask is None:
        input_mask = (input_ids != 0).long()

    # input_ids may have extra dimensions, so we reshape down to 2-d
    # before calling the BERT model and then reshape back at the end.
    outputs = transformer(
        input_ids=util.combine_initial_dims_to_1d_or_2d(input_ids),
        # token_type_ids=util.combine_initial_dims_to_1d_or_2d(token_type_ids),
        attention_mask=util.combine_initial_dims_to_1d_or_2d(
            input_mask)).to_tuple()
    if len(outputs) == 3:
        # ``outputs`` is a plain tuple after ``.to_tuple()``, so the hidden states
        # (one tensor per layer) are accessed by position rather than by attribute.
        all_encoder_layers = torch.stack(outputs[-1])
    elif len(outputs) == 2:
        all_encoder_layers, _ = outputs[:2]
    else:
        all_encoder_layers = outputs[0]

    if needs_split:
        if ret_cls is not None:
            cls_mask = input_ids[:, 0] == input_ids[0][0]
            cls_hidden = all_encoder_layers[:, 0, :]
            if ret_cls == 'max':
                cls_hidden[~cls_mask] = -1e20
            else:
                cls_hidden[~cls_mask] = 0
            cls_mask = cls_mask.view(-1, batch_size).transpose(0, 1)
            cls_hidden = cls_hidden.reshape(cls_mask.size(1), batch_size,
                                            -1).transpose(0, 1)
            if ret_cls == 'max':
                cls_hidden = cls_hidden.max(1)[0]
            elif ret_cls == 'raw':
                return cls_hidden, cls_mask
            else:
                cls_hidden = torch.sum(cls_hidden, dim=1)
                cls_hidden /= torch.sum(cls_mask, dim=1, keepdim=True)
            return cls_hidden
        else:
            recombined_embeddings, select_indices = restore_from_sliding_window(
                all_encoder_layers, batch_size, max_pieces, full_seq_len,
                start_tokens, end_tokens)

            initial_dims.append(len(select_indices))
    else:
        recombined_embeddings = all_encoder_layers

    # Recombine the outputs of all layers
    # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
    # recombined = torch.cat(combined, dim=2)
    # input_mask = (recombined_embeddings != 0).long()

    # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

    if offsets is None:
        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
        dims = initial_dims if needs_split else input_ids.size()
        layers = util.uncombine_initial_dims(recombined_embeddings, dims)
    else:
        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims_to_1d_or_2d(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0),
            device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
        # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
        selected_embeddings = recombined_embeddings[:, range_vector, offsets2d]

        layers = util.uncombine_initial_dims(selected_embeddings,
                                             offsets.size())

    return layers
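A standalone sketch of the "split into windows, stack along the batch" idea that `split_to_sliding_window` is assumed to implement (the real helper may differ): windows are stacked window-major, matching the reshape used when the [CLS] vectors are recombined above.

# Splitting a pre-windowed long sequence into one row per window.
import torch

batch_size, max_pieces, num_windows = 2, 4, 3
input_ids = torch.arange(batch_size * num_windows * max_pieces).view(
    batch_size, num_windows * max_pieces)

# Split along the sequence dimension, then concatenate the windows along the
# batch dimension (window-major ordering).
windows = torch.cat(input_ids.split(max_pieces, dim=-1), dim=0)
print(windows.shape)   # torch.Size([6, 4])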
Exemplo n.º 39
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.LongTensor, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # `action_sequence` has shape (batch_size, target_sequence_length); we
            # unsqueeze it (and its mask) to (batch_size, 1, target_sequence_length)
            # for the MML trainer.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = 0
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
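A small sketch of the action-similarity metric recorded above: `difflib.SequenceMatcher.ratio()` from the standard library scores how closely the predicted action indices match the gold sequence. Toy indices only.

import difflib

best_action_indices = [4, 7, 7, 2, 9]
targets = [4, 7, 2, 9]
similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
print(round(similarity.ratio(), 3))   # 0.889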
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.LongTensor = None,
        span_indices_mask: torch.LongTensor = None,
    ) -> torch.FloatTensor:

        # Both of shape (batch_size, sequence_length, embedding_size / 2)
        forward_sequence, backward_sequence = sequence_tensor.split(
            int(self._input_dim / 2), dim=-1
        )
        forward_sequence = forward_sequence.contiguous()
        backward_sequence = backward_sequence.contiguous()

        # shape (batch_size, num_spans)
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]

        if span_indices_mask is not None:
            span_starts = span_starts * span_indices_mask
            span_ends = span_ends * span_indices_mask
        # We want `exclusive` span starts, so we remove 1 from the forward span starts
        # as the AllenNLP `SpanField` is inclusive.
        # shape (batch_size, num_spans)
        exclusive_span_starts = span_starts - 1
        # shape (batch_size, num_spans, 1)
        start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)

        # We want `exclusive` span ends for the backward direction
        # (so that the `start` of the span in that direction is exclusive), so
        # we add 1 to the span ends as the AllenNLP `SpanField` is inclusive.
        exclusive_span_ends = span_ends + 1

        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of the sequence_tensor.
            sequence_lengths = torch.ones_like(
                sequence_tensor[:, 0, 0], dtype=torch.long
            ) * sequence_tensor.size(1)

        # shape (batch_size, num_spans, 1)
        end_sentinel_mask = (
            (exclusive_span_ends >= sequence_lengths.unsqueeze(-1)).long().unsqueeze(-1)
        )

        # As we added 1 to the span_ends to make them exclusive, which might have caused indices
        # equal to the sequence_length to become out of bounds, we multiply by the inverse of the
        # end_sentinel mask to erase these indices (as we will replace them anyway in the block below).
        # The same argument follows for the exclusive span start indices.
        exclusive_span_ends = exclusive_span_ends * (1 - end_sentinel_mask.squeeze(-1))
        exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

        # We'll check the indices here at runtime, because it's difficult to debug
        # if this goes wrong and it's tricky to get right.
        if (exclusive_span_starts < 0).any() or (
            exclusive_span_ends > sequence_lengths.unsqueeze(-1)
        ).any():
            raise ValueError(
                f"Adjusted span indices must lie inside the length of the sequence tensor, "
                f"but found: exclusive_span_starts: {exclusive_span_starts}, "
                f"exclusive_span_ends: {exclusive_span_ends} for a sequence tensor with lengths "
                f"{sequence_lengths}."
            )

        # Forward Direction: start indices are exclusive. Shape (batch_size, num_spans, input_size / 2)
        forward_start_embeddings = util.batched_index_select(
            forward_sequence, exclusive_span_starts
        )
        # Forward Direction: end indices are inclusive, so we can just use span_ends.
        # Shape (batch_size, num_spans, input_size / 2)
        forward_end_embeddings = util.batched_index_select(forward_sequence, span_ends)

        # Backward Direction: The backward start embeddings use the `forward` end
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_start_embeddings = util.batched_index_select(
            backward_sequence, exclusive_span_ends
        )
        # Backward Direction: The backward end embeddings use the `forward` start
        # indices, because we are going backwards.
        # Shape (batch_size, num_spans, input_size / 2)
        backward_end_embeddings = util.batched_index_select(backward_sequence, span_starts)

        if self._use_sentinels:
            # If we're using sentinels, we need to replace all the elements which were
            # outside the dimensions of the sequence_tensor with either the start sentinel,
            # or the end sentinel.
            float_end_sentinel_mask = end_sentinel_mask.float()
            float_start_sentinel_mask = start_sentinel_mask.float()
            forward_start_embeddings = (
                forward_start_embeddings * (1 - float_start_sentinel_mask)
                + float_start_sentinel_mask * self._start_sentinel
            )
            backward_start_embeddings = (
                backward_start_embeddings * (1 - float_end_sentinel_mask)
                + float_end_sentinel_mask * self._end_sentinel
            )

        # Now we combine the forward and backward spans in the manner specified by the
        # respective combinations and concatenate these representations.
        # Shape (batch_size, num_spans, forward_combination_dim)
        forward_spans = util.combine_tensors(
            self._forward_combination, [forward_start_embeddings, forward_end_embeddings]
        )
        # Shape (batch_size, num_spans, backward_combination_dim)
        backward_spans = util.combine_tensors(
            self._backward_combination, [backward_start_embeddings, backward_end_embeddings]
        )
        # Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
        span_embeddings = torch.cat([forward_spans, backward_spans], -1)

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_ends - span_starts, num_total_buckets=self._num_width_embeddings
                )
            else:
                span_widths = span_ends - span_starts

            span_width_embeddings = self._span_width_embedding(span_widths)
            return torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            return span_embeddings * span_indices_mask.float().unsqueeze(-1)
        return span_embeddings
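A toy sketch of the exclusive-start bookkeeping above: inclusive starts are shifted left by one, starts that fall off the sequence are flagged for the start sentinel, and the out-of-bounds index is zeroed so the later gather stays in range.

import torch

span_starts = torch.tensor([[0, 2], [1, 3]])                  # inclusive starts, [batch, num_spans]
exclusive_span_starts = span_starts - 1
start_sentinel_mask = (exclusive_span_starts == -1).long().unsqueeze(-1)
exclusive_span_starts = exclusive_span_starts * (1 - start_sentinel_mask.squeeze(-1))

print(exclusive_span_starts)              # tensor([[0, 1], [0, 2]])
print(start_sentinel_mask.squeeze(-1))    # tensor([[1, 0], [0, 0]])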
Exemplo n.º 41
0
    def _joint_likelihood(self,
                          logits: torch.Tensor,
                          tags: torch.Tensor,
                          mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, num_tags = logits.data.shape

        # Transpose batch size and sequence dimensions:
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Broadcast the transition scores to one per batch element
        broadcast_transitions = self.transitions.view(1, num_tags, num_tags).expand(batch_size, num_tags, num_tags)

        # Add up the scores for the observed transitions and all the inputs but the last
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i+1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = (
                    broadcast_transitions
                    # Choose the current_tag-th row for each input
                    .gather(1, current_tag.view(batch_size, 1, 1).expand(batch_size, 1, num_tags))
                    # Squeeze down to (batch_size, num_tags)
                    .squeeze(1)
                    # Then choose the next_tag-th column for each of those
                    .gather(1, next_tag.view(batch_size, 1))
                    # And squeeze down to (batch_size,)
                    .squeeze(1)
            )

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1, batch_size).expand(sequence_length, batch_size))

        # Is (sequence_length, batch_size), but all the columns are the same, so take the first.
        last_tags = last_tags[0]

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked.
        last_inputs = logits[-1]                                         # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
        last_input_score = last_input_score.squeeze()                    # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score
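A toy sketch of the double-gather used above to look up transition scores: for each batch element it picks row `current_tag` and column `next_tag` from the transition matrix, which is equivalent to fancy indexing.

import torch

batch_size, num_tags = 3, 4
transitions = torch.randn(num_tags, num_tags)
current_tag = torch.tensor([1, 0, 3])
next_tag = torch.tensor([2, 2, 1])

broadcast = transitions.view(1, num_tags, num_tags).expand(batch_size, num_tags, num_tags)
transition_score = (
    broadcast
    # Choose the current_tag-th row for each batch element.
    .gather(1, current_tag.view(batch_size, 1, 1).expand(batch_size, 1, num_tags))
    .squeeze(1)
    # Then choose the next_tag-th column of that row.
    .gather(1, next_tag.view(batch_size, 1))
    .squeeze(1)
)
assert torch.allclose(transition_score, transitions[current_tag, next_tag])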
Exemplo n.º 42
0
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)

        span_representations = self.span_extractor(encoded_text, spans, mask,
                                                   span_mask)

        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            "tokens": [meta["tokens"] for meta in metadata],
            "pos_tags": [meta.get("pos_tags") for meta in metadata],
            "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels,
                                                      span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees
               ) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [
                list(zip(*tree.pos()))[1] for tree in batch_gold_trees
            ]
            predicted_trees = self.construct_trees(
                class_probabilities.cpu().data,
                spans.cpu().data, num_spans.data, output_dict["tokens"],
                gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
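A toy sketch of the span mask computed above: padded spans carry a start index of -1, so inspecting the start column yields both the mask and the per-sentence span counts (the sum here matches what `get_lengths_from_binary_sequence_mask` would return for a 0/1 mask).

import torch

spans = torch.tensor([[[0, 1], [1, 2], [-1, -1]],
                      [[0, 0], [-1, -1], [-1, -1]]])          # [batch, num_spans, 2]
span_mask = (spans[:, :, 0] >= 0).long()
num_spans = span_mask.sum(-1)
print(span_mask)    # tensor([[1, 1, 0], [1, 0, 0]])
print(num_spans)    # tensor([2, 1])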
Exemplo n.º 43
0
    def forward(
        self,
        token_ids: torch.LongTensor,
        mask: torch.BoolTensor,
        type_ids: Optional[torch.LongTensor] = None,
        segment_concat_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:  # type: ignore
        """
        # Parameters

        token_ids: `torch.LongTensor`
            Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
            num_segment_concat_wordpieces is num_wordpieces plus special tokens inserted in the
            middle, e.g. the length of: "[CLS] A B C [SEP] [CLS] D E F [SEP]" (see indexer logic).
        mask: `torch.BoolTensor`
            Shape: [batch_size, num_wordpieces].
        type_ids: `Optional[torch.LongTensor]`
            Shape: `[batch_size, num_wordpieces if max_length is None else num_segment_concat_wordpieces]`.
        segment_concat_mask: `Optional[torch.BoolTensor]`
            Shape: `[batch_size, num_segment_concat_wordpieces]`.

        # Returns

        `torch.Tensor`
            Shape: `[batch_size, num_wordpieces, embedding_size]`.

        """
        # Some of the huggingface transformers don't support type ids at all and crash when you supply
        # them. For others, you can supply a tensor of zeros, and if you don't, they act as if you did.
        # There is no practical difference to the caller, so here we pretend that one case is the same
        # as another case.
        if type_ids is not None:
            max_type_id = type_ids.max()
            if max_type_id == 0:
                type_ids = None
            else:
                if max_type_id >= self._number_of_token_type_embeddings():
                    raise ValueError("Found type ids too large for the chosen transformer model.")
                assert token_ids.shape == type_ids.shape

        fold_long_sequences = self._max_length is not None and token_ids.size(1) > self._max_length
        if fold_long_sequences:
            batch_size, num_segment_concat_wordpieces = token_ids.size()
            token_ids, segment_concat_mask, type_ids = self._fold_long_sequences(
                token_ids, segment_concat_mask, type_ids
            )

        transformer_mask = segment_concat_mask if self._max_length is not None else mask
        # Shape: [batch_size, num_wordpieces, embedding_size],
        # or if self._max_length is not None:
        # [batch_size * num_segments, self._max_length, embedding_size]

        # We call this with kwargs because some of the huggingface models don't have the
        # token_type_ids parameter and fail even when it's given as None.
        # Also, as of transformers v2.5.1, they are taking FloatTensor masks.
        parameters = {"input_ids": token_ids, "attention_mask": transformer_mask.float()}
        if type_ids is not None:
            parameters["token_type_ids"] = type_ids

        transformer_output = self.transformer_model(**parameters)
        if self._scalar_mix is not None:
            # As far as I can tell, the hidden states will always be the last element
            # in the output tuple as long as the model is not also configured to return
            # attention scores.
            # See, for example, the return value description for BERT:
            # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel.forward
            # These hidden states will also include the embedding layer, which we don't
            # include in the scalar mix. Hence the `[1:]` slicing.
            hidden_states = transformer_output[-1][1:]
            embeddings = self._scalar_mix(hidden_states)
        else:
            embeddings = transformer_output[0]

        if fold_long_sequences:
            embeddings = self._unfold_long_sequences(
                embeddings, segment_concat_mask, batch_size, num_segment_concat_wordpieces
            )

        return embeddings
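A standalone sketch of the scalar-mix fallback logic above, with the transformer replaced by random tensors and a plain mean standing in for the learned `ScalarMix`: when all hidden states are returned they arrive as the last tuple element, and the embedding layer is dropped before mixing.

import torch

batch, seq, dim, num_layers = 2, 5, 8, 4
last_hidden_state = torch.randn(batch, seq, dim)
hidden_states = tuple(torch.randn(batch, seq, dim) for _ in range(num_layers + 1))
transformer_output = (last_hidden_state, hidden_states)       # fake model output tuple

use_scalar_mix = True
if use_scalar_mix:
    layers = transformer_output[-1][1:]               # drop the embedding layer
    embeddings = torch.stack(layers, dim=0).mean(0)   # stand-in for a learned ScalarMix
else:
    embeddings = transformer_output[0]
print(embeddings.shape)   # torch.Size([2, 5, 8])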
Exemplo n.º 44
0
    def construct_trees(self,
                        predictions: torch.FloatTensor,
                        all_spans: torch.LongTensor,
                        num_spans: torch.LongTensor,
                        sentences: List[List[str]],
                        pos_tags: List[List[str]] = None) -> List[Tree]:
        """
        Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
        The trees use exclusive end indices, which contrasts with how spans are
        represented in the rest of the model.

        Parameters
        ----------
        predictions : ``torch.FloatTensor``, required.
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        all_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the span
            indices we scored.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        sentences : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, optional (default = None).
            A list of POS tags for each word in the sentence for each element
            in the batch.

        Returns
        -------
        A ``List[Tree]`` containing the decoded trees for each element in the batch.
        """
        # Switch to using exclusive end spans.
        exclusive_end_spans = all_spans.clone()
        exclusive_end_spans[:, :, -1] += 1
        no_label_id = self.vocab.get_token_index("NO-LABEL", "labels")

        trees: List[Tree] = []
        for batch_index, (scored_spans, spans, sentence) in enumerate(
                zip(predictions, exclusive_end_spans, sentences)):
            selected_spans = []
            for prediction, span in zip(scored_spans[:num_spans[batch_index]],
                                        spans[:num_spans[batch_index]]):
                start, end = span
                no_label_prob = prediction[no_label_id]
                label_prob, label_index = torch.max(prediction, -1)

                # Does the span have a label != NO-LABEL or is it the root node?
                # If so, include it in the spans that we consider.
                if int(label_index) != no_label_id or (start == 0 and end
                                                       == len(sentence)):
                    # TODO(Mark): Remove this once pylint sorts out named tuples.
                    # https://github.com/PyCQA/pylint/issues/1418
                    selected_spans.append(
                        SpanInformation(
                            start=int(start),  # pylint: disable=no-value-for-parameter
                            end=int(end),
                            label_prob=float(label_prob),
                            no_label_prob=float(no_label_prob),
                            label_index=int(label_index)))

            # The spans we've selected might overlap, which causes problems when we try
            # to construct the tree as they won't nest properly.
            consistent_spans = self.resolve_overlap_conflicts_greedily(
                selected_spans)

            spans_to_labels = {
                (span.start, span.end):
                self.vocab.get_token_from_index(span.label_index, "labels")
                for span in consistent_spans
            }
            sentence_pos = pos_tags[
                batch_index] if pos_tags is not None else None
            trees.append(
                self.construct_tree_from_spans(spans_to_labels, sentence,
                                               sentence_pos))

        return trees
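# Illustrative sketch (separate from the snippet above): the greedy span
# selection performed in the decoding loop, on toy tensors.  Names such as
# `toy_predictions`, and the assumption that label index 0 is "NO-LABEL",
# are for illustration only.
import torch

no_label_id = 0
# (num_spans, num_label_classes): fake label distributions for three spans.
toy_predictions = torch.tensor([[0.7, 0.2, 0.1],
                                [0.1, 0.8, 0.1],
                                [0.2, 0.1, 0.7]])
# Inclusive (start, end) word indices for each span.
toy_spans = torch.tensor([[0, 0], [0, 1], [1, 1]])
# Switch to exclusive end indices, as the decoder above does.
exclusive_spans = toy_spans.clone()
exclusive_spans[:, 1] += 1

selected_spans = []
for prediction, span in zip(toy_predictions, exclusive_spans):
    label_prob, label_index = prediction.max(-1)
    if int(label_index) != no_label_id:
        selected_spans.append((int(span[0]), int(span[1]), int(label_index)))
print(selected_spans)  # [(0, 2, 1), (1, 2, 2)]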
                    optimizer=optimizer,
                )
            # Validate the model on the validation set.
            if True:
                loss_avg_val, accu_avg_val = trainer_epoch(
                    model,
                    dataloader_val,
                    criterion,
                )
            # Test on the held-out (unseen-category) set.
            if params_LocalTest:
                loss_avg_tst, accu_avg_tst = trainer_epoch(
                    model,
                    dataloader_tst,
                    criterion,
                    idxCat=LongTensor(idxCatUnseen).cuda(),
                )
            # Predict on the unannotated data (disabled here).
            if False:
                arr_outputs_prd = trainer_epoch(
                    model,
                    dataloader_prd,
                    criterion,
                    predict=True,
                    predict_DataAug=False,
                    idxCat=LongTensor(idxCatUnannotd).cuda(),
                )

        time_elapsed = time.time() - time_start

        # Summarise the epoch.  The training values (loss_avg_trn, accu_avg_trn)
        # are assumed to come from the training call whose tail appears above.
        EpochResult = r'TRN_lss_{:.3g}_accu_{:.3g}_VAL_lss_{:.3g}_accu_{:.3g}_TST_lss_{:.3g}_accu_{:.3g}'.format(
            loss_avg_trn, accu_avg_trn, loss_avg_val, accu_avg_val,
            loss_avg_tst, accu_avg_tst)
Exemplo n.º 46
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, num_passage=4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)
        passage_mask = util.get_text_field_mask(passage, 1).float()
        # get some parameters
        cuda_device = embedded_passage.get_device()
        batch_size, num_passage, passage_length, embedding_dim = embedded_passage.size(
        )
        # When training, randomly select 2 of the 4 passages each epoch (passage 0 is always kept).
        if self.training:
            num_passage = 2
            probs = torch.Tensor([1, 1, 1]).unsqueeze(0).expand(batch_size, 3)
            indices = torch.multinomial(probs, 1) + 1
            zeros_tensor = torch.zeros(batch_size).long()
            # Shape: (batch_size, 2)
            indices = Variable(
                torch.cat([zeros_tensor.unsqueeze(-1), indices],
                          1).cuda(cuda_device))
            # Shape: (batch_size, num_passage, passage_length, embedding_dim)
            embedded_passage = torch.gather(
                embedded_passage, 1,
                indices.unsqueeze(-1).unsqueeze(-1).expand(
                    batch_size, num_passage, passage_length, embedding_dim))
            # Shape: (batch_size, num_passage, passage_length)
            passage_mask = torch.gather(
                passage_mask, 1,
                indices.unsqueeze(-1).expand(batch_size, num_passage,
                                             passage_length))

        # Shape: (batch_size*num_passage, passage_length, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)
        # Shape: (batch_size*num_passage, passage_length)
        passage_mask = passage_mask.view(-1, passage_length)
        # Shape: (batch_size, question_length, embedding_dim)
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        question_length = embedded_question.size(1)
        # Shape: (batch_size*num_passage, question_length, embedding_dim)
        embedded_question = embedded_question.unsqueeze(1).expand(
            -1, num_passage, -1, -1).contiguous().view(-1, question_length,
                                                       embedding_dim)
        # Shape: (batch_size, question_length)
        question_mask = util.get_text_field_mask(question).float()
        # Shape: (batch_size*num_passage, question_length)
        question_mask = question_mask.unsqueeze(1).expand(
            -1, num_passage, -1).contiguous().view(-1, question_length)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size*num_passage, passage_length, question_length)
        #passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (B, P, 1)
        p_similarity = torch.matmul(encoded_passage, self._w_p).unsqueeze(2)
        # Shape: (B, 1, Q)
        q_similarity = torch.matmul(encoded_question, self._w_q).unsqueeze(1)
        # Shape: (B, P, Q)
        pq_similarity = torch.bmm(encoded_passage * self._w_pq,
                                  encoded_question.transpose(1, 2))
        passage_question_similarity = pq_similarity + p_similarity + q_similarity
        # Shape: (batch_size*num_passage, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size*num_passage, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size * num_passage, passage_length, encoding_dim)

        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))

        # attach residual self-attention layer
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        # create mask for self-attention
        mask = passage_mask.view(
            batch_size * num_passage, passage_length, 1) * passage_mask.view(
                batch_size * num_passage, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).view(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ],
                                   dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))

        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Shape: (batch_size*num_passage, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            #"best_span": best_span,
        }

        if span_start is not None:
            if self.training:
                # merge logits of multiple passages in the same context
                # Shape: (batch_size, num_passage*passage_length)
                span_start_logits = span_start_logits.view(
                    batch_size, num_passage,
                    passage_length).view(batch_size, -1)
                span_end_logits = span_end_logits.view(batch_size, num_passage,
                                                       passage_length).view(
                                                           batch_size, -1)
                # Shape: (batch_size, num_passage*passage_length)
                passage_mask = passage_mask.view(batch_size, num_passage,
                                                 passage_length).view(
                                                     batch_size, -1)
                loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
                #self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
                loss += nll_loss(
                    util.masked_log_softmax(span_end_logits, passage_mask),
                    span_end.squeeze(-1))
                #self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                #self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            else:  # do not care about loss when validating
                loss = Variable(torch.Tensor([0]).cuda(cuda_device))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training and metadata is not None:
            # Shape: (batch_size*num_passage, 3)
            best_span = self.get_best_span(span_start_logits, span_end_logits)
            # Shape: (batch_size, num_passage, 3)
            best_span = best_span.view(batch_size, num_passage, 3)

            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                all_passages = metadata[i]['all_passages']
                passage_offsets = metadata[i]['passage_offsets']
                _, max_id = torch.max(best_span[i, :, 2], dim=0)
                max_id = int(max_id)
                predicted_span = tuple(best_span[i, max_id].data.cpu().numpy())
                start_offset = passage_offsets[max_id][int(
                    predicted_span[0])][0]
                end_offset = passage_offsets[max_id][int(predicted_span[1])][1]
                best_span_string = all_passages[max_id][
                    start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
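# A minimal, self-contained sketch of the decomposed passage-question similarity
# computed above.  `w_p`, `w_q` and `w_pq` stand in for the learned parameters
# self._w_p / self._w_q / self._w_pq; the sum is equivalent to a trilinear
# similarity w_p.p + w_q.q + w_pq.(p*q) between every passage/question pair.
import torch

batch, passage_len, question_len, dim = 2, 5, 4, 8
passage = torch.randn(batch, passage_len, dim)
question = torch.randn(batch, question_len, dim)
w_p, w_q, w_pq = torch.randn(dim), torch.randn(dim), torch.randn(dim)

p_similarity = torch.matmul(passage, w_p).unsqueeze(2)               # (B, P, 1)
q_similarity = torch.matmul(question, w_q).unsqueeze(1)              # (B, 1, Q)
pq_similarity = torch.bmm(passage * w_pq, question.transpose(1, 2))  # (B, P, Q)
similarity = pq_similarity + p_similarity + q_similarity             # (B, P, Q)
assert similarity.shape == (batch, passage_len, question_len)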
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
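# Sketch of the span-width masking trick used above, on toy spans, with plain
# PyTorch in place of the allennlp.nn.util helpers.
import torch

span_starts = torch.tensor([[0, 2]])                    # (batch=1, num_spans=2)
span_ends = torch.tensor([[2, 3]])                      # inclusive end indices
span_widths = (span_ends - span_starts).unsqueeze(-1)   # (1, 2, 1)
max_batch_span_width = int(span_widths.max()) + 1       # 3

# (1, 1, max_width) range vector, broadcast against each span's width.
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends.unsqueeze(-1) - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = raw_span_indices.clamp(min=0)

print(span_indices)  # tensor([[[2, 1, 0], [3, 2, 1]]])
print(span_mask)     # tensor([[[1., 1., 1.], [1., 1., 0.]]])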
Exemplo n.º 48
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                verb_indicator: torch.LongTensor,
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])
        output_dict = {"logits": logits, "class_probabilities": class_probabilities}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits,
                                                      tags,
                                                      mask,
                                                      label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [example_metadata["verb_index"] for example_metadata in metadata]
                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [convert_bio_tags_to_conll_format(tags) for
                                              tags in batch_bio_predicted_tags]
                batch_bio_gold_tags = [example_metadata["gold_tags"] for example_metadata in metadata]
                batch_conll_gold_tags = [convert_bio_tags_to_conll_format(tags) for
                                         tags in batch_bio_gold_tags]
                self.span_metric(batch_verb_indices,
                                 batch_sentences,
                                 batch_conll_predicted_tags,
                                 batch_conll_gold_tags)
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict
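# Rough plain-PyTorch sketch of the masked, per-token cross-entropy that
# `sequence_cross_entropy_with_logits` computes above (ignoring label smoothing
# and AllenNLP's exact per-batch averaging).
import torch
import torch.nn.functional as F

batch_size, sequence_length, num_classes = 2, 4, 5
logits = torch.randn(batch_size, sequence_length, num_classes)
tags = torch.randint(0, num_classes, (batch_size, sequence_length))
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.float)

log_probs = F.log_softmax(logits, dim=-1)
per_token_loss = -log_probs.gather(-1, tags.unsqueeze(-1)).squeeze(-1)
loss = (per_token_loss * mask).sum() / mask.sum()
print(float(loss))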
Exemplo n.º 49
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A list of metadata dictionaries, one per batch element, each with keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the number of non-padded
            spans in ``spans`` for each batch element.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
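# Sketch of the span-mask logic above: padded spans are encoded as (-1, -1), so
# checking the start index is enough to identify padding, and summing the mask
# gives the per-element number of real spans.
import torch

spans = torch.tensor([[[0, 1], [1, 2], [-1, -1]],
                      [[0, 2], [-1, -1], [-1, -1]]])   # (batch, num_spans, 2)
span_mask = (spans[:, :, 0] >= 0).long()               # (batch, num_spans)
num_spans = span_mask.sum(-1)
print(span_mask)   # tensor([[1, 1, 0], [1, 0, 0]])
print(num_spans)   # tensor([2, 1])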
Exemplo n.º 50
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        last_window_size = 0
        if needs_split:
            # Split the flattened list by the window size, `max_pieces`
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))

            # We want all sequences to be the same length, so pad the last sequence
            last_window_size = split_input_ids[-1].size(-1)
            padding_amount = self.max_pieces - last_window_size
            split_input_ids[-1] = F.pad(split_input_ids[-1],
                                        pad=[0, padding_amount],
                                        value=0)

            # Now combine the sequences along the batch dimension
            input_ids = torch.cat(split_input_ids, dim=0)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=util.combine_initial_dims(input_mask))
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers,
                                              batch_size,
                                              dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices of the sequence such that it will result in embeddings representing the original
            # sentence. To capture maximal context, the indices will be the middle part of each embedded window
            # sub-sequence (plus any leftover start and final edge windows), e.g.,
            #  0     1 2    3  4   5    6    7     8     9   10   11   12    13 14  15
            # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
            # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
            # and final windows with indices [0, 1] and [14, 15] respectively.

            # Find the stride as half the max pieces, ignoring the special start and end tokens
            # Calculate an offset to extract the centermost embeddings of each window
            stride = (self.max_pieces - self.start_tokens -
                      self.end_tokens) // 2
            stride_offset = stride // 2 + self.start_tokens

            first_window = list(range(stride_offset))

            max_context_windows = [
                i for i in range(full_seq_len) if stride_offset - 1 < i %
                self.max_pieces < stride_offset + stride
            ]

            final_window_start = full_seq_len - (
                full_seq_len % self.max_pieces) + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        input_mask = (recombined_embeddings != 0).long()

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            layers = util.uncombine_initial_dims(recombined_embeddings, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
                offsets2d.size(0),
                device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = recombined_embeddings[:, range_vector,
                                                        offsets2d]

            layers = util.uncombine_initial_dims(selected_embeddings,
                                                 offsets.size())

        if self._scalar_mix is not None:
            return self._scalar_mix(layers, input_mask)
        elif self.combine_layers == "last":
            return layers[-1]
        else:
            return layers
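# Plain-indexing sketch of the `offsets` branch above (ignoring the extra
# layers dimension): picking one wordpiece embedding per original token, using
# a batch range vector plus advanced indexing in place of the allennlp helpers.
import torch

batch_size, num_wordpieces, embedding_dim = 2, 6, 4
wordpiece_embeddings = torch.randn(batch_size, num_wordpieces, embedding_dim)
offsets = torch.tensor([[0, 2, 5],     # chosen wordpiece position per token
                        [1, 3, 4]])
range_vector = torch.arange(batch_size).unsqueeze(1)     # (batch_size, 1)
selected = wordpiece_embeddings[range_vector, offsets]   # (batch_size, num_tokens, dim)
assert selected.shape == (batch_size, 3, embedding_dim)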
Exemplo n.º 51
0
    def _get_linking_probabilities(self,
                                   worlds: List[WikiTablesWorld],
                                   linking_scores: torch.FloatTensor,
                                   question_mask: torch.LongTensor,
                                   entity_type_dict: Dict[int, int]) -> torch.FloatTensor:
        """
        Produces the probability of an entity given a question word and type. The logic below
        separates the entities by type since the softmax normalization term sums over entities
        of a single type.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        linking_scores : ``torch.FloatTensor``
            Has shape (batch_size, num_question_tokens, num_entities).
        question_mask: ``torch.LongTensor``
            Has shape (batch_size, num_question_tokens).
        entity_type_dict : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

        Returns
        -------
        batch_probabilities : ``torch.FloatTensor``
            Has shape ``(batch_size, num_question_tokens, num_entities)``.
            Contains all the probabilities for an entity given a question word.
        """
        _, num_question_tokens, num_entities = linking_scores.size()
        batch_probabilities = []

        for batch_index, world in enumerate(worlds):
            all_probabilities = []
            num_entities_in_instance = 0

            # NOTE: The way that we're doing this here relies on the fact that entities are
            # implicitly sorted by their types when we sort them by name, and that numbers come
            # before "fb:cell", and "fb:cell" comes before "fb:row".  This is not a great
            # assumption, and could easily break later, but it should work for now.
            for type_index in range(self._num_entity_types):
                # This index of 0 is for the null entity for each type, representing the case where a
                # word doesn't link to any entity.
                entity_indices = [0]
                entities = world.table_graph.entities
                for entity_index, _ in enumerate(entities):
                    if entity_type_dict[batch_index * num_entities + entity_index] == type_index:
                        entity_indices.append(entity_index)

                if len(entity_indices) == 1:
                    # No entities of this type; move along...
                    continue

                # We're subtracting one here because of the null entity we added above.
                num_entities_in_instance += len(entity_indices) - 1

                # We separate the scores by type, since normalization is done per type.  There's an
                # extra "null" entity per type, also, so we have `num_entities_per_type + 1`.  We're
                # selecting from a (num_question_tokens, num_entities) linking tensor on _dimension 1_,
                # so we get back something of shape (num_question_tokens,) for each index we're
                # selecting.  All of the selected indices together then make a tensor of shape
                # (num_question_tokens, num_entities_per_type + 1).
                indices = linking_scores.new_tensor(entity_indices, dtype=torch.long)
                entity_scores = linking_scores[batch_index].index_select(1, indices)

                # We used index 0 for the null entity, so this will actually have some values in it.
                # But we want the null entity's score to be 0, so we set that here.
                entity_scores[:, 0] = 0

                # No need for a mask here, as this is done per batch instance, with no padding.
                type_probabilities = torch.nn.functional.softmax(entity_scores, dim=1)
                all_probabilities.append(type_probabilities[:, 1:])

            # We need to add padding here if we don't have the right number of entities.
            if num_entities_in_instance != num_entities:
                zeros = linking_scores.new_zeros(num_question_tokens,
                                                 num_entities - num_entities_in_instance)
                all_probabilities.append(zeros)

            # (num_question_tokens, num_entities)
            probabilities = torch.cat(all_probabilities, dim=1)
            batch_probabilities.append(probabilities)
        batch_probabilities = torch.stack(batch_probabilities, dim=0)
        return batch_probabilities * question_mask.unsqueeze(-1).float()
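# Numeric sketch of the per-type normalization above: the entity scores of one
# type get a "null entity" with score 0 prepended, the whole row is softmaxed,
# and the null column is then dropped so its probability mass is not returned.
import torch

# (num_question_tokens, num_entities_of_this_type) linking scores.
entity_scores = torch.tensor([[2.0, 0.5],
                              [0.0, 1.0]])
null_scores = torch.zeros(entity_scores.size(0), 1)
scores_with_null = torch.cat([null_scores, entity_scores], dim=1)
type_probabilities = torch.softmax(scores_with_null, dim=1)
print(type_probabilities[:, 1:])   # probabilities of the real entities only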
Exemplo n.º 52
0
    def forward(self,  # type: ignore
                utterance: Dict[str, torch.LongTensor],
                world: List[AtisWorld],
                actions: List[List[ProductionRule]],
                linking_scores: torch.Tensor,
                target_action_sequence: torch.LongTensor = None,
                sql_queries: List[List[str]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        utterance : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the utterance ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        world : ``List[AtisWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[AtisWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        linking_scores: ``torch.Tensor``
            A matrix linking the utterance tokens and the entities. This is a binary matrix that
            is deterministically generated, where each entry indicates whether a token generated an entity.
            This tensor has shape ``(batch_size, num_entities, num_utterance_tokens)``.
        target_action_sequence : torch.Tensor, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
            trailing dimension.
        sql_queries : List[List[str]], optional (default=None)
            A list of the SQL queries that are given during training or validation.
        """
        initial_state = self._get_initial_state(utterance, world, actions, linking_scores)
        batch_size = linking_scores.shape[0]
        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we unsqueeze it for
            # the MML trainer.
            return self._decoder_trainer.decode(initial_state,
                                                self._transition_function,
                                                (target_action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
        else:
            # TODO(kevin) Move some of this functionality to a separate method for computing validation outputs.
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            outputs['linking_scores'] = linking_scores
            if target_action_sequence is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._transition_function,
                                                               (target_action_sequence.unsqueeze(1),
                                                                target_mask.unsqueeze(1)))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            outputs['utterance'] = []
            outputs['tokenized_utterance'] = []

            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[(i, action_index)]
                                  for action_index in best_action_indices]
                predicted_sql_query = action_sequence_to_sql(action_strings)

                if target_action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = target_action_sequence[i].data
                    sequence_in_targets = 0
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                if sql_queries and sql_queries[i]:
                    denotation_correct = self._executor.evaluate_sql_query(predicted_sql_query, sql_queries[i])
                    self._denotation_accuracy(denotation_correct)
                    outputs['sql_queries'].append(sql_queries[i])

                outputs['utterance'].append(world[i].utterances[-1])
                outputs['tokenized_utterance'].append([token.text
                                                       for token in world[i].tokenized_utterances[-1]])
                outputs['entities'].append(world[i].entities)
                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
            return outputs
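# Small sketch of the action-similarity metric above: difflib's SequenceMatcher
# ratio between predicted and gold action indices (toy values).
import difflib

predicted_actions = [3, 7, 7, 2, 9]
gold_actions = [3, 7, 2, 9]
similarity = difflib.SequenceMatcher(None, predicted_actions, gold_actions)
print(similarity.ratio())   # 2 * 4 matches / 9 total elements ~= 0.889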
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRule]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                                 initial_state,
                                                                 self._decoder_step,
                                                                 keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [best_final_states[i][0].action_history[0]]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings, worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs
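# Sketch of the target-mask construction used above: positions equal to the
# action padding index are masked out of the (squeezed) target sequences.
# The padding index of -1 is assumed here for illustration.
import torch

action_padding_index = -1
target_action_sequences = torch.tensor([[4, 9, 2, -1, -1],
                                        [5, 1, -1, -1, -1]])
target_mask = target_action_sequences != action_padding_index
print(target_mask.long())   # tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])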
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrWorld]],
            actions: List[List[ProductionRuleArray]],
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)
        action_embeddings, action_indices = self._embed_actions(actions)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            util.new_variable_with_data(
                list(sentence.values())[0], torch.Tensor([0.0]))
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]
        worlds_list = [worlds[i] for i in range(batch_size)]

        initial_state = NlvrDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            action_indices=action_indices,
            possible_actions=actions,
            worlds=worlds_list,
            label_strings=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        best_final_states = self._decoder_beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._decoder_step,
            keep_final_unfinished_states=False)
        best_action_sequences: Dict[int, List[List[int]]] = {}
        for i in range(batch_size):
            # Decoding may not have terminated with any completed logical forms, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i in best_final_states:
                best_action_indices = [
                    best_final_states[i][0].action_history[0]
                ]
                best_action_sequences[i] = best_action_indices
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if target_action_sequences is not None:
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings)
        else:
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
        return outputs
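
# A toy sketch (under invented rules) of how decoded action indices become
# production rule strings: ``best_action_sequences`` maps a batch index to a
# list of index sequences, and each index is looked up in that instance's
# ``actions`` list. This is roughly what a helper like ``_get_action_strings``
# is expected to do; the rules and indices here are made up for illustration.
best_action_sequences = {0: [[2, 0, 1]]}
actions = [[("s -> [np, vp]",), ("np -> noun",), ("vp -> verb",)]]
batch_action_strings = [
    [[actions[batch_index][i][0] for i in sequence] for sequence in sequences]
    for batch_index, sequences in best_action_sequences.items()
]
print(batch_action_strings)  # [[['vp -> verb', 's -> [np, vp]', 'np -> noun']]]
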
Exemplo n.º 55
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                verb_indicator: torch.LongTensor,
                tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, embedding_dim_with_binary_feature = embedded_text_with_verb_indicator.size()

        if self.stacked_encoder.get_input_dim() != embedding_dim_with_binary_feature:
            raise ConfigurationError("The SRL model concatenates a verb indicator feature onto "
                                     "the text embeddings, which makes the encoder input wider "
                                     "than the text embedding alone. Therefore, the 'input_dim' "
                                     "of the stacked_encoder must be equal to "
                                     "total_embedding_dim + binary_feature_dim.")

        encoded_text = self.stacked_encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])
        output_dict = {"logits": logits, "class_probabilities": class_probabilities}
        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits, tags, mask)
            self.span_metric(class_probabilities, tags, mask)
            output_dict["loss"] = loss

        return output_dict
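
# A shape-only sketch of the verb-indicator concatenation performed above,
# using made-up dimensions: the encoder input width ends up being
# embedding_dim + binary_feature_dim.
import torch

batch_size, sequence_length = 2, 7
embedding_dim, binary_feature_dim = 10, 2
embedded_text = torch.randn(batch_size, sequence_length, embedding_dim)
verb_indicator = torch.zeros(batch_size, sequence_length, dtype=torch.long)
verb_indicator[:, 3] = 1  # mark the predicate position
binary_feature_embedding = torch.nn.Embedding(2, binary_feature_dim)
embedded_verb = binary_feature_embedding(verb_indicator)       # (2, 7, 2)
encoder_input = torch.cat([embedded_text, embedded_verb], -1)  # (2, 7, 12)
print(encoder_input.shape)  # torch.Size([2, 7, 12])
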
Exemplo n.º 56
0
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.LongTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)
        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim/num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(scaled_similarities, mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps, int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
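
# A minimal, self-contained sketch of the head-splitting logic above, written
# in plain PyTorch. The sizes, the scaling factor, and ``num_heads`` here are
# illustrative assumptions, not values taken from the layer's configuration.
import torch
import torch.nn.functional as F

batch_size, timesteps, attention_dim, values_dim, num_heads = 2, 5, 8, 8, 2
queries = torch.randn(batch_size, timesteps, attention_dim)
keys = torch.randn(batch_size, timesteps, attention_dim)
values = torch.randn(batch_size, timesteps, values_dim)

def split_heads(tensor, heads):
    # (batch, timesteps, dim) -> (batch * heads, timesteps, dim / heads)
    batch, steps, dim = tensor.size()
    tensor = tensor.view(batch, steps, heads, dim // heads)
    return tensor.transpose(1, 2).contiguous().view(batch * heads, steps, dim // heads)

q, k, v = (split_heads(t, num_heads) for t in (queries, keys, values))
scale = (attention_dim // num_heads) ** 0.5
similarities = torch.bmm(q, k.transpose(1, 2)) / scale     # (batch * heads, steps, steps)
attention = F.softmax(similarities, dim=-1)
per_head_outputs = torch.bmm(attention, v)                 # (batch * heads, steps, values_dim / heads)

# Recombine the heads back into (batch, timesteps, values_dim).
outputs = (per_head_outputs
           .view(batch_size, num_heads, timesteps, values_dim // num_heads)
           .transpose(1, 2)
           .contiguous()
           .view(batch_size, timesteps, values_dim))
print(outputs.shape)  # torch.Size([2, 5, 8])
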
Exemplo n.º 57
0
    def forward(self, # pylint: disable=arguments-differ
                embeddings: torch.FloatTensor,
                mask: torch.LongTensor,
                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                                 torch.LongTensor, torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model).

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``int``, required.
            The number of items to keep when pruning.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, num_items_to_keep).
        top_indices : ``torch.LongTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, 1).
        """
        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape"
                             f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, num_items_to_keep, 1)
        _, top_indices = scores.topk(num_items_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size, num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Shape: (batch_size * num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

        # Shape: (batch_size, num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
        # Shape: (batch_size, num_items_to_keep)
        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

        # Shape: (batch_size, num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
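
# A minimal sketch of the pruning pattern above in plain PyTorch, with a
# stand-in scorer; the real module delegates the index handling to
# util.flatten_and_batch_shift_indices / util.batched_index_select. All sizes
# here are illustrative assumptions.
import torch

batch_size, num_items, embedding_size, k = 2, 6, 4, 3
embeddings = torch.randn(batch_size, num_items, embedding_size)
mask = torch.ones(batch_size, num_items)
mask[1, 4:] = 0  # pretend the last two items of the second example are padding

scores = embeddings.sum(-1, keepdim=True)           # (batch, num_items, 1), a toy scorer
scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e20)

_, top_indices = scores.topk(k, dim=1)              # (batch, k, 1)
top_indices, _ = torch.sort(top_indices, dim=1)     # restore original ordering
top_indices = top_indices.squeeze(-1)               # (batch, k)

# Gather the kept embeddings per batch element.
top_embeddings = torch.gather(
    embeddings, 1, top_indices.unsqueeze(-1).expand(-1, -1, embedding_size))
print(top_embeddings.shape)  # torch.Size([2, 3, 4])
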
Exemplo n.º 58
0
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True,
                                       label_smoothing: float = None) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
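
# A quick usage sketch for the function above, assuming it is defined in the
# same module: toy logits for a batch of two sequences, where the second
# sequence has one padded position that should not contribute to the loss.
# All values are made up for illustration.
import torch

logits = torch.randn(2, 4, 5)                       # (batch, sequence_length, num_classes)
targets = torch.randint(0, 5, (2, 4))               # gold class indices
weights = torch.tensor([[1., 1., 1., 1.],
                        [1., 1., 1., 0.]])          # zero out the padded step

loss = sequence_cross_entropy_with_logits(logits, targets, weights)
per_example = sequence_cross_entropy_with_logits(logits, targets, weights,
                                                 batch_average=False)
print(loss.shape, per_example.shape)                # torch.Size([]) torch.Size([2])
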
Exemplo n.º 59
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                verb_indicator: torch.LongTensor,
                tags: torch.LongTensor = None,
                training: bool = False,  # added by ph to keep the signature consistent with the other model; unused
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape ``(batch_size, num_tokens)``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.
        training : bool, optional (default = False)
            Unused; added by ph only to keep the signature consistent with the other model.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """

        # added by ph
        tokens['tokens'] = tokens['tokens'].cuda()
        verb_indicator = verb_indicator.cuda()
        if tags is not None:
            tags = tags.cuda()

        embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat([embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])
        output_dict = {"logits": logits, "class_probabilities": class_probabilities, "mask": mask}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits,
                                                      tags,
                                                      mask,
                                                      label_smoothing=self._label_smoothing)
            output_dict["loss"] = loss

        # added by ph
        output_dict['softmax_3d'] = class_probabilities.detach().cpu().numpy()
        return output_dict
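
# A standalone sketch (illustrative values only) of why the mask is kept in
# ``output_dict``: at decode time the per-token distributions can be cropped
# back to each sequence's unpadded length before running Viterbi or argmax.
import torch

class_probabilities = torch.rand(2, 5, 3)            # (batch, num_tokens, tag_vocab_size)
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])
lengths = mask.sum(-1).tolist()                      # [5, 3]
cropped = [probs[:length] for probs, length in zip(class_probabilities, lengths)]
print([tuple(c.shape) for c in cropped])             # [(5, 3), (3, 3)]
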
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:
        """KAIMARY"""
        # utterance
        # utterance:{'tokens': tensor([[   6,    8,    2,  149,   46,    3,   14,    2,  290,  149,   98,    4,
        #             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0], ...])}

        # valid_actions
        # formatted as: [ProductionRule(rule='arg_list -> [expr, ",", arg_list]', is_global_rule=True, rule_id=tensor([0]), nonterminal='arg_list'), ...]

        # schema
        # schema:{'text': {'tokens': tensor([[[ 519,   35,    0,    0],
        #                                                                      [ 149,   35,    0,    0],
        #                                                                      [ 519,   35,    0,    0],
        #                                                                      ...,
        #                                                                      [   0,    0,    0,    0],
        #                                                                      [   0,    0,    0,    0],
        #                                                                      [   0,    0,    0,    0]], ...]
        #                   'linking': tensor([[[[ 0.0000,  0.0000,  0.0000,  ..., -4.2500,  0.0000,  0.0000],
        #                                                      [ 0.0000,  0.0000,  0.0000,  ..., -5.6667,  0.0000,  0.0000],
        #                                                     [ 0.0000,  0.0000,  0.0000,  ..., -6.0000,  0.0000,  0.0000], ...]]]}

        # action_sequence
        # action_sequence: tensor([[[137],
        #                           [118],
        #                           [119],
        #                           [163],
        #                           [ 69],
        #                           [157],
        #                           [ 51],
        #                           [ 12],
        #                           [ 70],
        #                           [161],
        #                           [138],
        #                           [ -1],
        #                           [ -1],
        #                           [ -1]]])
        batch_size = len(world)
        device = utterance['tokens'].device

        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
            return outputs
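
# A tiny sketch of the action-mask construction used above: the trailing
# dimension left by ListField[ListField[IndexField]] is squeezed away and
# padded positions are masked out. The padding index of -1 is an assumption
# for illustration, as are the action indices.
import torch

action_padding_index = -1
action_sequence = torch.tensor([[[137], [118], [163], [-1], [-1]]])
action_sequence = action_sequence.squeeze(-1)         # (batch, sequence_length)
action_mask = action_sequence != action_padding_index
print(action_sequence.shape, action_mask)
# torch.Size([1, 5]) tensor([[ True,  True,  True, False, False]])
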