Example #1
def flattened_index_select(target: torch.Tensor,
                           indices: torch.LongTensor) -> torch.Tensor:
    """
    The given ``indices`` of size ``(set_size, subset_size)`` specifies subsets of the ``target``
    that each of the set_size rows should select. The ``target`` has size
    ``(batch_size, sequence_length, embedding_size)``, and the resulting selected tensor has size
    ``(batch_size, set_size, subset_size, embedding_size)``.

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, sequence_length, embedding_size).
    indices : ``torch.LongTensor``, required.
        A LongTensor of shape (set_size, subset_size). All indices must be < sequence_length
        as this tensor is an index into the sequence_length dimension of the target.

    Returns
    -------
    selected : ``torch.Tensor``, required.
        A Tensor of shape (batch_size, set_size, subset_size, embedding_size).
    """
    if indices.dim() != 2:
        raise ConfigurationError("Indices passed to flattened_index_select had shape {} but "
                                 "only 2 dimensional inputs are supported.".format(indices.size()))
    # Shape: (batch_size, set_size * subset_size, embedding_size)
    flattened_selected = target.index_select(1, indices.view(-1))

    # Shape: (batch_size, set_size, subset_size, embedding_size)
    selected = flattened_selected.view(target.size(0), indices.size(0), indices.size(1), -1)
    return selected
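A minimal usage sketch for the example above, assuming ``torch`` is imported and ``flattened_index_select`` is in scope (shapes chosen only for illustration):

target = torch.randn(2, 5, 4)                      # (batch_size=2, sequence_length=5, embedding_size=4)
indices = torch.tensor([[0, 1], [2, 3], [3, 4]])   # (set_size=3, subset_size=2)
selected = flattened_index_select(target, indices)
print(selected.shape)                              # torch.Size([2, 3, 2, 4])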
Example #2
    def __call__(self,  # type: ignore
                 predictions: torch.LongTensor,
                 gold_targets: torch.LongTensor) -> None:
        """
        Update precision counts.

        Parameters
        ----------
        predictions : ``torch.LongTensor``, required
            Batched predicted tokens of shape `(batch_size, max_sequence_length)`.
        gold_targets : ``torch.LongTensor``, required
            Batched reference (gold) translations with shape `(batch_size, max_gold_sequence_length)`.

        Returns
        -------
        None
        """
        predictions, gold_targets = self.unwrap_to_tensors(predictions, gold_targets)
        for ngram_size, _ in enumerate(self._ngram_weights, start=1):
            precision_matches, precision_totals = self._get_modified_precision_counts(
                    predictions, gold_targets, ngram_size)
            self._precision_matches[ngram_size] += precision_matches
            self._precision_totals[ngram_size] += precision_totals
        if not self._exclude_indices:
            self._prediction_lengths += predictions.size(0) * predictions.size(1)
            self._reference_lengths += gold_targets.size(0) * gold_targets.size(1)
        else:
            valid_predictions_mask = self._get_valid_tokens_mask(predictions)
            self._prediction_lengths += valid_predictions_mask.sum().item()
            valid_gold_targets_mask = self._get_valid_tokens_mask(gold_targets)
            self._reference_lengths += valid_gold_targets_mask.sum().item()
Example #3
    def _get_modified_precision_counts(self,
                                       predicted_tokens: torch.LongTensor,
                                       reference_tokens: torch.LongTensor,
                                       ngram_size: int) -> Tuple[int, int]:
        """
        Compare the predicted tokens to the reference (gold) tokens at the desired
        ngram size and calculate the numerator and denominator for a modified
        form of precision.

        The numerator is the number of ngrams in the predicted sentences that match
        with an ngram in the corresponding reference sentence, clipped by the total
        count of that ngram in the reference sentence. The denominator is just
        the total count of predicted ngrams.
        """
        clipped_matches = 0
        total_predicted = 0
        for batch_num in range(predicted_tokens.size(0)):
            predicted_row = predicted_tokens[batch_num, :]
            reference_row = reference_tokens[batch_num, :]
            predicted_ngram_counts = self._ngrams(predicted_row, ngram_size)
            reference_ngram_counts = self._ngrams(reference_row, ngram_size)
            for ngram, count in predicted_ngram_counts.items():
                clipped_matches += min(count, reference_ngram_counts[ngram])
                total_predicted += count
        return clipped_matches, total_predicted
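To make the clipping concrete, here is a standalone sketch of the same counting logic for unigrams using ``collections.Counter`` (toy token ids; this illustrates the idea and is not the class method above):

from collections import Counter

predicted = [7, 7, 7, 9]    # e.g. "the the the cat"
reference = [7, 9, 3]       # e.g. "the cat sat"

predicted_counts = Counter((t,) for t in predicted)
reference_counts = Counter((t,) for t in reference)

clipped_matches = sum(min(count, reference_counts[ngram])
                      for ngram, count in predicted_counts.items())
total_predicted = sum(predicted_counts.values())
print(clipped_matches, total_predicted)   # 2 4 -> modified unigram precision of 0.5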
Example #4
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs : ``torch.Tensor``, required.
            A Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
        mask : ``torch.LongTensor``, required.
            A binary mask of shape ``(batch_size, sequence_length)`` representing the
            non-padded elements in each sequence in the batch.

        Returns
        -------
        A ``torch.Tensor`` of shape (num_layers, batch_size, sequence_length, hidden_size),
        where the num_layers dimension represents the LSTM output from that layer.
        """
        batch_size, total_sequence_length = mask.size()
        stacked_sequence_output, final_states, restoration_indices = \
            self.sort_and_run_forward(self._lstm_forward, inputs, mask)

        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        if num_valid < batch_size:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size - num_valid,
                                                     returned_timesteps,
                                                     encoder_dim).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.data.new(num_layers, batch_size - num_valid, state_dim).fill_(0)
                zeros = Variable(zeros)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size,
                                                     sequence_length_difference,
                                                     stacked_sequence_output[0].size(-1)).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)

        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)
Example #5
 def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
     # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
     # Check if target is big enough to cover prediction (including start/end symbols)
     if len(predicted) > targets.size(1):
         return 0
     predicted_tensor = targets.new_tensor(predicted)
     targets_trimmed = targets[:, :len(predicted)]
     # Return 1 if the predicted sequence is anywhere in the list of targets.
     return torch.max(torch.min(targets_trimmed.eq(predicted_tensor), dim=1)[0]).item()
Example #6
 def remove_input_features(self, remaining_features: LongTensor,
                           input_index: Any,
                           log: GarbageCollectionLog) -> None:
     assert input_index == 0, "We are only aware of one parent"
     diff = int(np.array(self.input_features[0].additional_dims).prod())
     indices = arange(0, diff).long().unsqueeze(1)
     indices = indices.repeat(1, remaining_features.size(0))
     indices += remaining_features * diff
     indices = indices.transpose(0, 1).contiguous().view(-1)
     self.output_features.remove_features(self, indices, log)
Example #7
def get_subsequent_mask(seq: torch.LongTensor):
    ''' For masking out the subsequent info. '''
    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(torch.ones((len_s, len_s),
                                            device=seq.device,
                                            dtype=torch.bool),
                                 diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1,
                                                          -1)  # b x ls x ls
    return subsequent_mask
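A quick check of the mask shape and pattern (assuming ``torch`` is imported and ``get_subsequent_mask`` is in scope); ``True`` marks the future positions to be masked out:

seq = torch.zeros(1, 4, dtype=torch.long)   # values are irrelevant; only shape and device are used
mask = get_subsequent_mask(seq)
print(mask.shape)   # torch.Size([1, 4, 4])
print(mask[0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])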
Example #8
def _expand_with_new_candidates(
        model: SequentialRNN, context: LongTensor, n_predictions: int,
        current_candidates: FlattenedCandidateList) -> FlattenedCandidateList:
    assert len(context.size()) == 2

    loss, num_tokens = _get_topk_predictions(model, context, n_predictions)
    batch_size = context.size(0)

    return FlattenedCandidateList(
        scores=(-loss + current_candidates.scores[:, None]).view(-1),
        hidden_indices=torch.arange(0, batch_size,
                                    device=DEVICE)[:, None].expand(
                                        batch_size,
                                        n_predictions).contiguous().view(-1),
        subtokens=torch.cat([
            current_candidates.subtokens.repeat(n_predictions, 1),
            num_tokens.view(-1)[:, None]
        ],
                            dim=-1))
Example #9
    def _add_full_bio(self, correct_most_likely_predictions: torch.LongTensor,
                      full_bio: torch.LongTensor):
        predictions_count = correct_most_likely_predictions.size()[1]

        not_added = ((full_bio.unsqueeze(1) == correct_most_likely_predictions
                      ).prod(-1).sum(-1) == 0).long()

        return torch.cat((correct_most_likely_predictions,
                          (full_bio * not_added.unsqueeze(-1)).unsqueeze(1)),
                         dim=1)
Example #10
 def wrap(b: torch.LongTensor):
     if b is None:
         return b
     if len(b.size()) > 1 and isinstance(b, list):
         b = torch.stack(b, 0)
     b = b.contiguous()
     if self.cuda:
         b = b.cuda()
     b = Variable(b, volatile=self.volatile, requires_grad=False)
     return b
Example #11
 def forward(self, input: LongTensor) -> Tuple[Tensor, Tensor]:
     sl, bs = input.size()
     self.reset()
     raw_outputs, outputs = [], []
     for i in range(0, sl, self.bptt):
         r, o = super().forward(input[i:min(i + self.bptt, sl)])
         if i > (sl - self.max_seq):
             raw_outputs.append(r)
             outputs.append(o)
     return self.concat(raw_outputs), self.concat(outputs)
Example #12
    def decode_forced(
            self, encoder_states: Tuple[torch.Tensor, ...],
            ys: torch.LongTensor) -> Tuple[torch.Tensor, torch.LongTensor]:
        """
        Decode with a fixed, true sequence, computing loss.

        Override TGM.decode_forced to both:
        1) handle BART eos/bos issues, and
        2) appropriately get forced decoder input.

        :param encoder_states:
            encoder output states
        :param ys:
            teacher forced label

        :return logits, preds:
            logits: output token distribution (as logits, not probs)
            preds: tokens corresponding with max probs according to output distribution.
        """
        bsz = ys.size(0)
        seqlen = ys.size(1)
        inputs = ys.narrow(1, 0, seqlen - 1)
        if (ys[:, 0]
                == self.START_IDX).any() and self.generation_model != 'bart':
            raise AssertionError(
                "The Beginning of Sentence token is automatically added to the "
                "label in decode_forced, but you included it in the label. This means "
                "your model will have a double BOS token, which is probably not what "
                "you intended.")
        doc_scores = encoder_states[-1]

        inputs = self._rag_model_interface.get_initial_forced_decoder_input(
            bsz,
            inputs,
            n_docs=doc_scores.size(1) if doc_scores is not None else None,
            start_idx=self.START_IDX,
            end_idx=self.END_IDX,
            input_turns_cnt=encoder_states[2],
        )
        latent, _ = self.decoder(inputs, encoder_states)
        logits = self.output(latent)
        _, preds = logits.max(dim=-1)
        return logits, preds  # type: ignore
Example #13
 def _action_history_match(predicted: List[int],
                           targets: torch.LongTensor) -> int:
     # TODO(mattg): this could probably be moved into a FullSequenceMatch metric, or something.
     # Check if target is big enough to cover prediction (including start/end symbols)
     if len(predicted) > targets.size(0):
         return 0
     predicted_tensor = targets.new_tensor(predicted)
     targets_trimmed = targets[:len(predicted)]
      # Return 1 if the predicted sequence exactly matches the target sequence trimmed to the same length.
      return int(predicted_tensor.equal(targets_trimmed))
Example #14
    def get_rseq(self, rel: torch.LongTensor, tem: torch.LongTensor):

        r_e = self.embedding['rel'](rel)
        r_e = r_e.unsqueeze(0).transpose(0, 1)

        bs = tem.size(0)
        tem_len = tem.size(1)
        tem = tem.contiguous()
        tem = tem.view(bs * tem_len)

        token_e = self.embedding['tem'](tem)
        token_e = token_e.view(bs, tem_len, self.emb_dim)
        seq_e = torch.cat((r_e, token_e), 1)

        hidden_tem = self.lstm(seq_e)
        hidden_tem = hidden_tem[0, :, :]
        rseq_e = hidden_tem

        return rseq_e
Example #15
def span_to_position_ids(span: torch.LongTensor,
                         max_length: int = None) -> torch.LongTensor:
    batch_size = span.size(0)
    max_length = max_length or get_span_max_length(span)
    position_ids = span.new_full((batch_size, max_length), fill_value=-1)

    for i, (start, end) in enumerate(span):
        positions = torch.arange(start, end + 1)
        position_ids[i, :len(positions)] = positions
    return position_ids
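A small usage sketch, assuming the function above is in scope; ``max_length`` is passed explicitly so the ``get_span_max_length`` helper is not needed:

spans = torch.tensor([[1, 3], [0, 1]])   # inclusive (start, end) pairs
position_ids = span_to_position_ids(spans, max_length=4)
print(position_ids)
# tensor([[ 1,  2,  3, -1],
#         [ 0,  1, -1, -1]])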
Example #16
def pad_mask(lengths: torch.LongTensor) -> torch.ByteTensor:
    """
    Create a mask of seq x batch where seq = max(lengths), with 0 in padding locations and 1 otherwise. 
    """
    # lengths: bs. Ex: [2, 3, 1]
    max_seqlen = torch.max(lengths)
    expanded_lengths = lengths.unsqueeze(0).repeat((max_seqlen, 1))  # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
    indices = torch.arange(max_seqlen).unsqueeze(1).repeat((1, lengths.size(0))).to(lengths.device)  # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]

    return expanded_lengths > indices  # pad locations are 0. #[[1, 1, 1], [1, 1, 0], [0, 1, 0]]. seqlen x bs
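The docstring's running example, spelled out (assuming ``torch`` is imported and ``pad_mask`` above is in scope):

lengths = torch.tensor([2, 3, 1])
mask = pad_mask(lengths)   # shape (max_seqlen, batch_size) = (3, 3)
print(mask.long())
# tensor([[1, 1, 1],
#         [1, 1, 0],
#         [0, 1, 0]])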
Example #17
def _save_ply(
    f,
    verts: torch.Tensor,
    faces: torch.LongTensor,
    verts_normals: torch.Tensor,
    decimal_places: Optional[int] = None,
) -> None:
    """
    Internal implementation for saving 3D data to a .ply file.

    Args:
        f: File object to which the 3D data should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
        decimal_places: Number of decimal places for saving.
    """
    assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
    assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
    assert not len(verts_normals) or (
        verts_normals.dim() == 2 and verts_normals.size(1) == 3
    )

    print('ply\nformat ascii 1.0', file=f)
    print(f'element vertex {verts.shape[0]}', file=f)
    print('property float x', file=f)
    print('property float y', file=f)
    print('property float z', file=f)
    if verts_normals.numel() > 0:
        print('property float nx', file=f)
        print('property float ny', file=f)
        print('property float nz', file=f)
    print(f'element face {faces.shape[0]}', file=f)
    print('property list uchar int vertex_index', file=f)
    print('end_header', file=f)

    if not (len(verts) or len(faces)):
        warnings.warn("Empty 'verts' and 'faces' arguments provided")
        return

    if decimal_places is None:
        float_str = '%f'
    else:
        float_str = '%' + '.%df' % decimal_places

    vert_data = torch.cat((verts, verts_normals), dim=1)
    np.savetxt(f, vert_data.detach().numpy(), float_str)

    faces_array = faces.detach().numpy()

    if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
        warnings.warn('Faces have invalid indices')

    if len(faces_array):
        np.savetxt(f, faces_array, '3 %d %d %d')
Example #18
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
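A hedged usage sketch (assuming ``torch`` is imported and the function above is in scope); the weights zero out the padded final step of the second sequence:

logits = torch.randn(2, 3, 5)                    # (batch_size, sequence_length, num_classes)
targets = torch.tensor([[1, 2, 0], [3, 4, 0]])   # gold class index per step
weights = torch.tensor([[1., 1., 1.], [1., 1., 0.]])
loss = sequence_cross_entropy_with_logits(logits, targets, weights)
print(loss.shape)   # torch.Size([]) -- a scalar, since batch_average defaults to True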
Example #19
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs : ``torch.Tensor``, required.
            A Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
        mask : ``torch.LongTensor``, required.
            A binary mask of shape ``(batch_size, sequence_length)`` representing the
            non-padded elements in each sequence in the batch.

        Returns
        -------
        A ``torch.Tensor`` of shape (num_layers, batch_size, sequence_length, hidden_size),
        where the num_layers dimension represents the LSTM output from that layer.
        """
        batch_size, total_sequence_length = mask.size()
        stacked_sequence_output, final_states, restoration_indices = \
            self.sort_and_run_forward(self._lstm_forward, inputs, mask)

        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        if num_valid < batch_size:
            zeros = stacked_sequence_output.new_zeros(num_layers,
                                                      batch_size - num_valid,
                                                      returned_timesteps,
                                                      encoder_dim)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.new_zeros(num_layers, batch_size - num_valid, state_dim)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.new_zeros(num_layers,
                                                      batch_size,
                                                      sequence_length_difference,
                                                      stacked_sequence_output[0].size(-1))
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)

        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)
Example #20
def batched_index_select(
    target: torch.Tensor,
    indices: torch.LongTensor,
    flattened_indices: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    """
    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` indexes into the sequence
    dimension (dimension 2) of the target, which has size ``(batch_size, sequence_length,
    embedding_size)``.

    This function returns selected values in the target with respect to the provided indices, which
    have size ``(batch_size, d_1, ..., d_n, embedding_size)``. This can use the optionally
    precomputed :func:`~flattened_indices` with size ``(batch_size * d_1 * ... * d_n)`` if given.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor. This is used in the
    :class:`~allennlp.models.coreference_resolution.CoreferenceResolver`. Model to select
    contextual word representations corresponding to the start and end indices of mentions. The key
    reason this can't be done with basic torch functions is that we want to be able to use look-up
    tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
    a-priori how many spans we are looking up).

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling :func:~`flatten_and_batch_shift_indices`
        on ``indices``. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(
            indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.view(-1, target.size(-1))

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets
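A minimal usage sketch, assuming ``torch`` and the AllenNLP-style helpers above (including ``flatten_and_batch_shift_indices``, which is not shown here) are in scope:

target = torch.randn(2, 6, 8)                      # (batch_size, sequence_length, embedding_size)
span_indices = torch.tensor([[[0, 1], [2, 3]],
                             [[1, 2], [4, 5]]])    # (batch_size, d_1=2, d_2=2)
selected = batched_index_select(target, span_indices)
print(selected.shape)   # torch.Size([2, 2, 2, 8])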
Example #21
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    # Contribution to the negative log likelihood only comes from the exact indices
    # of the targets, as the target distributions are one-hot. Here we use torch.gather
    # to extract the indices of the num_classes dimension which contribute to the loss.
    # shape : (batch * sequence_length, 1)
    negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
Example #22
    def _embed_source(self, source_tokens: Dict[str, torch.Tensor],
                      source_entity_length: torch.LongTensor):
        """
        :param source_tokens
        :param source_entity_length: (batch_size, max_token_num)
        :return
            (batch_size, max_token_num, embedding_dim)
        """
        token_ids = source_tokens['tokens']
        embedded = self._source_embedding(token_ids)

        batched_embedded = list()
        embedding_dim = embedded.size(-1)
        batch_size, max_token_num = source_entity_length.size()

        for _embedded, _length in zip(embedded, source_entity_length.long()):
            merged_embedded_input = list()
            idx = 0
            for length in _length:
                if length > 0:
                    embedding = torch.mean(_embedded[idx:idx + length, :],
                                           dim=0)
                    merged_embedded_input.append(embedding)
                    idx += length
                else:
                    break
            merged_embedded_input = torch.stack(merged_embedded_input, dim=0)
            pad_num = max_token_num - merged_embedded_input.size(0)
            if pad_num > 0:
                merged_embedded_input = torch.cat(
                    (merged_embedded_input,
                     merged_embedded_input.new_zeros([pad_num, embedding_dim
                                                      ])),
                    dim=0)
            batched_embedded.append(merged_embedded_input)

        # shape: (batch_size, max_token_num, embedding_dim)
        batched_embedded = torch.stack(batched_embedded, dim=0)
        assert batched_embedded.size(0) == embedded.size(
            0) and batched_embedded.size(1) == source_entity_length.size(1)
        # TODO: Dropout
        return batched_embedded
Example #23
    def token_dropout(tokens: torch.LongTensor,
                      oov_token: int,
                      exclude_tokens: List[int],
                      p: float = 0.2,
                      training: bool = True) -> torch.LongTensor:
        """During training, randomly replaces some of the non-padding tokens with a mask token with probability ``p``.
        
        Adopted from https://github.com/Hyperparticle/udify

        Args:
          tokens: The current batch of padded sentences with word ids
          oov_token: The mask token
          exclude_tokens: The tokens for padding the input batch
          p: The probability a word gets mapped to the unknown token
          training: Applies the dropout if set to ``True``

        Returns:
          A copy of the input batch with token dropout applied

        """
        if training and p > 0:
            # This creates a mask that only considers unpadded tokens for mapping to oov
            padding_mask = tokens.new_ones(tokens.size(), dtype=torch.bool)
            for pad in exclude_tokens:
                padding_mask &= (tokens != pad)

            # Create a uniformly random mask selecting either the original words or OOV tokens
            dropout_mask = (tokens.new_empty(tokens.size(), dtype=torch.float).uniform_() < p)
            oov_mask = dropout_mask & padding_mask

            oov_fill = tokens.new_empty(tokens.size(), dtype=torch.long).fill_(oov_token)

            result = torch.where(oov_mask, oov_fill, tokens)

            return result
        else:
            return tokens
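A usage sketch, assuming ``torch`` is imported and ``token_dropout`` is available as a standalone function (it takes no ``self``, so it is presumably a static method); padding id 0 is excluded, so only real tokens can be replaced by the OOV id 1:

tokens = torch.tensor([[5, 6, 7, 0, 0],
                       [8, 9, 0, 0, 0]])
dropped = token_dropout(tokens, oov_token=1, exclude_tokens=[0], p=0.5, training=True)
print(dropped)   # non-padding positions become 1 with probability 0.5; padding stays 0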
Example #24
    def forward(
        self,
        items: torch.LongTensor,    # one batch of items
        labels: torch.LongTensor,   # one batch of labels
        memories_h: list,
        memories_r: list,
        memories_t: list,
    ):
        batch_size = items.size(0)
        if batch_size != self.batch_size:
            self.batch_size = batch_size
        # [batch size, dim]

        item_embeddings_ripple = self.entity_emb(items)
        h_emb_list = []
        r_emb_list = []
        t_emb_list = []
        for i in range(self.n_hop):
            # [batch size, n_memory, dim]
            h_emb_list.append(self.entity_emb(memories_h[i]))
            # [batch size, n_memory, dim, dim]
            r_emb_list.append(
                self.relation_emb(memories_r[i]).view(
                    -1, self.n_memory, self.dim, self.dim
                )
            )

            # [batch size, n_memory, dim]
            t_emb_list.append(self.entity_emb(memories_t[i]))


        o_list, item_embeddings_ripple = self._key_addressing(
            h_emb_list, r_emb_list, t_emb_list, item_embeddings_ripple
        )

        h_rep = self._history_extracting(h_emb_list)

        user_embedding_ripple = self.get_user_embedding(o_list)

        # entities: list of tensors with shapes [batch_size, 1], [batch_size, 8], [batch_size, 64], [batch_size, 512] (512 = 64 * 8)
        entities, relations = self._get_neighbors(items)

        # [batch_size, dim]
        item_embeddings = self._aggregate(h_rep, entities, relations)

        # o_list.append(u_history_embedding)
        scores = self.predict(item_embeddings_ripple, user_embedding_ripple, h_rep, item_embeddings)

        return_dict = self._compute_loss(
            scores, labels, h_emb_list, t_emb_list, r_emb_list
        )
        return_dict["scores"] = scores

        return return_dict
Example #25
    def forward(
            self,
            utterance: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            world: List[SpiderWorld],
            schema: Dict[str, torch.LongTensor],
            action_sequence: torch.LongTensor = None
    ) -> Dict[str, torch.Tensor]:

        device = utterance['tokens'].device
        initial_state = self._get_initial_state(utterance, world, schema,
                                                valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            action_mask = action_sequence != self._action_padding_index
        else:
            action_mask = None

        if self.training:
            decode_output = self._decoder_trainer.decode(
                initial_state, self._transition_function,
                (action_sequence.unsqueeze(1), action_mask.unsqueeze(1)))

            return {'loss': decode_output['loss']}
        else:
            loss = torch.tensor([0]).float().to(device)
            if action_sequence is not None and action_sequence.size(1) > 1:
                try:
                    loss = self._decoder_trainer.decode(
                        initial_state, self._transition_function,
                        (action_sequence.unsqueeze(1),
                         action_mask.unsqueeze(1)))['loss']
                except ZeroDivisionError:
                    # reached a dead-end during beam search
                    pass

            outputs: Dict[str, Any] = {'loss': loss}

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            # initial_state.debug_info = [[] for _ in range(batch_size)]

            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False)

            self._compute_validation_outputs(valid_actions, best_final_states,
                                             world, action_sequence, outputs)
        return outputs
Example #26
def _mix_batches(batch_1: torch.LongTensor, batch_2: torch.LongTensor, pad_token_id: int, device: str):
	if batch_1.size(1) > batch_2.size(1):
		# pad 2nd batch
		pad_1 = batch_1.to(device)
		pad_2 = torch.cat((batch_2.to(device), torch.full((batch_2.size(0), batch_1.size(1) - batch_2.size(1)), fill_value=pad_token_id, device=device)), dim=1)
	elif batch_1.size(1) < batch_2.size(1):
		# pad 1st batch
		pad_1 = torch.cat((batch_1.to(device), torch.full((batch_1.size(0), batch_2.size(1) - batch_1.size(1)), fill_value=pad_token_id, device=device)), dim=1)
		pad_2 = batch_2.to(device)
	else:
		pad_1 = batch_1.to(device)
		pad_2 = batch_2.to(device)
	return torch.cat((pad_1, pad_2), dim=0)
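A usage sketch for the padding-and-concatenation helper above (assuming ``torch`` is imported and a recent PyTorch where ``torch.full`` infers an integer dtype from an integer ``fill_value``); ``device='cpu'`` is used here only for illustration:

batch_1 = torch.tensor([[1, 2, 3]])
batch_2 = torch.tensor([[4, 5]])
mixed = _mix_batches(batch_1, batch_2, pad_token_id=0, device='cpu')
print(mixed)
# tensor([[1, 2, 3],
#         [4, 5, 0]])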
Example #27
 def smooth_one_hot(target: torch.LongTensor,
                    num_classes: int,
                    smoothing: float = 0.):
     assert 0 <= smoothing < 1
     confidence = 1. - smoothing
     label_shape = torch.Size((target.size(0), num_classes))
     with torch.no_grad():
         true_dist = torch.zeros(label_shape, device=target.device)
         true_dist.fill_(smoothing / (num_classes - 1))
         true_dist.scatter_(1, target.data.unsqueeze(1), confidence)
     return true_dist  # (B, C)
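A quick check of the smoothed targets (assuming ``torch`` is imported and the helper above is in scope); each row sums to 1 and the true class keeps most of the mass:

target = torch.tensor([2, 0])
smoothed = smooth_one_hot(target, num_classes=4, smoothing=0.3)
print(smoothed)
# tensor([[0.1000, 0.1000, 0.7000, 0.1000],
#         [0.7000, 0.1000, 0.1000, 0.1000]])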
Example #28
 def build_position_id(self, lengths: t.LongTensor):
     batch_size = lengths.size(0)
     max_length = lengths.max()
     device = lengths.device
     position_id = t.zeros(batch_size,
                           max_length,
                           device=device,
                           dtype=t.long)
     for index, value in enumerate(lengths):
         position_id[index][:value] = self.position[:value]
     return position_id
Example #29
    def forward(
        self,
        x: torch.LongTensor,
        instances: List[List[str]],
        c0: torch.FloatTensor = None,
        h0: torch.FloatTensor = None,
    ):
        assert x.size(0) == len(instances), self.msg_printer.fail(
            f"The batch size for tokens "
            f"and string instances should "
            f"be the same. You passed tokens of size {x.size()} and instances "
            f"have length {len(instances)}"
        )
        batch_size = x.size(0)
        elmo_embeddings = self.elmo_embedder(instances)
        token_embeddings = self.embedding(x)

        # concat both of them together
        embeddings = torch.cat([elmo_embeddings, token_embeddings], dim=2)

        if h0 is None or c0 is None:
            h0, c0 = self.get_initial_hidden(batch_size=batch_size)

        # output = batch_size, sequence_length, num_directions * hidden_dimension
        # h_n = num_layers * num_directions, batch_size, hidden_dimension
        # c_n = num_layers * num_directions, batch_size, hidden_dimension
        output, (h_n, c_n) = self.rnn(embeddings, (h0, c0))

        encoding = None
        if self.bidirectional:
            forward_hidden = h_n[0, :, :]
            backward_hidden = h_n[1, :, :]
            if self.combine_strategy == "concat":
                encoding = torch.cat([forward_hidden, backward_hidden], dim=1)
            elif self.combine_strategy == "sum":
                encoding = torch.add(forward_hidden, backward_hidden)
        else:
            encoding = h_n[0, :, :]

        # N * hidden_dim
        return encoding
Example #30
    def forward(self,  # type: ignore
                inputs: torch.LongTensor,
                input_mask: torch.LongTensor,
                extra_input_embedding: torch.LongTensor = None,
                answer_spans: torch.LongTensor = None,
                span_counts: torch.LongTensor = None,
                num_answers: torch.LongTensor = None,
                metadata = None,
                **kwargs):

        if self._extra_input_dim > 0 and extra_input_embedding is None:
            raise ConfigurationError("SpanSelector with extra input configured must receive extra input embeddings.")

        batch_size, num_tokens, _ = inputs.size()
        span_hidden, span_mask = self._span_hidden(inputs, inputs, input_mask, input_mask)

        if self._extra_input_dim > 0:
            full_hidden = self._extra_input_lin(extra_input_embedding).unsqueeze(1) + span_hidden
        else:
            full_hidden = span_hidden

        span_logits = self._span_scorer(full_hidden).squeeze(-1)

        # output_dict = {
        #     "span_mask": span_mask,
        #     "span_logits": span_logits
        # }

        if answer_spans is not None:
            num_gold_spans = answer_spans.size(1)
            span_counts_dist = torch.zeros_like(span_logits)
            for b in range(batch_size):
                for s in range(num_gold_spans):
                    span = answer_spans[b, s]
                    if span[0] > -1:
                        span_index = (2 * span[0] * num_tokens - span[0].float().pow(2).long() + span[0]) / 2 + (span[1] - span[0])
                        span_counts_dist[b, span_index] = span_counts[b, s]
        else:
            span_counts_dist = None

        return self._classifier(logits = span_logits, mask = span_mask, label_counts = span_counts_dist, num_labelers = num_answers)
Example #31
def pad_mask(lengths: torch.LongTensor,
             device='cuda',
             max_seqlen=None) -> torch.ByteTensor:
    # lengths: bs. Ex: [2, 3, 1]
    if max_seqlen is None:
        max_seqlen = torch.max(lengths)
    expanded_lengths = lengths.unsqueeze(0).repeat(
        (max_seqlen, 1))  # [[2, 3, 1], [2, 3, 1], [2, 3, 1]]
    indices = torch.arange(max_seqlen).unsqueeze(1).repeat(
        (1, lengths.size(0))).to(device)  # [[0, 0, 0], [1, 1, 1], [2, 2, 2]]

    return expanded_lengths > indices  # pad locations are 0. #[[1, 1, 1], [1, 1, 0], [0, 1, 0]]. seqlen x bs
Example #32
 def __compose_edge_index_and_weight(
     _edge_index: torch.LongTensor,
     _edge_weight: _typing.Optional[torch.Tensor] = None,
 ) -> _typing.Tuple[torch.LongTensor, _typing.Optional[torch.Tensor]]:
     if type(_edge_index) != torch.Tensor or _edge_index.dtype != torch.int64:
         raise TypeError
     if _edge_weight is not None and (
         type(_edge_weight) != torch.Tensor
         or _edge_index.size() != (2, _edge_weight.size(0))
     ):
         _edge_weight: _typing.Optional[torch.Tensor] = None
     return _edge_index, _edge_weight
Example #33
 def get_seq_pred_tags(
         self, batchseq_feats: torch.Tensor,
         batch_seq_tags: torch.LongTensor) -> List[torch.Tensor]:
     batch_size, seq_size = batch_seq_tags.size()
     batch_seq_pred_tags = batchseq_feats.argmax(dim=1).view(
         batch_size, seq_size)
     mask = (batch_seq_tags >= 0).int()
     batch_seq_pred_tags = [
         seq_pred_tags[:cur_seq_len] for seq_pred_tags, cur_seq_len in zip(
             batch_seq_pred_tags, mask.sum(dim=1))
     ]
     return batch_seq_pred_tags
Example #34
    def forward(self, goal_tokens: torch.LongTensor, prev_tactic: torch.LongTensor) \
        -> torch.FloatTensor:
        batch_size = goal_tokens.size()[0]
        prev_data = self.tactic_embedding(prev_tactic).view(
            batch_size, self.hidden_size)
        hidden = maybe_cuda(
            Variable(torch.zeros(1, batch_size, self.hidden_size)))
        for i in range(goal_tokens.size()[1]):
            goal_data = self.goal_embedding(goal_tokens[:,i])\
                            .view(1, batch_size, self.hidden_size)
            for _ in range(self.num_encoder_layers):
                goal_data = F.relu(goal_data)
                goal_data, hidden = self.gru(goal_data, hidden)

        goal_output = goal_data[0]

        full_data = self.squish(torch.cat((goal_output, prev_data), dim=1))
        for i in range(self.num_decoder_layers - 1):
            full_data = F.relu(full_data)
            full_data = self.decoder_layers[i](full_data)
        return self.softmax(self.decoder_out(F.relu(full_data)))
Example #35
 def build_span_to_idx_dict(self, spans: torch.LongTensor):
     """
 TODO: Push these operations to C++ if possible
 """
     span_to_idx_dict = {}
     for batch_idx in range(spans.size(0)):
         span_to_idx_dict[batch_idx] = {}
         for idx, span in enumerate(spans[batch_idx]):
             span_data = tuple(span.data.cpu().numpy())
             if span_data not in span_to_idx_dict[batch_idx]:
                 span_to_idx_dict[batch_idx][span_data] = idx
     return span_to_idx_dict
Example #36
 def extract_features(self,
                      tokens: torch.LongTensor,
                      return_all_hiddens: bool = False) -> torch.Tensor:
     if tokens.dim() == 1:
         tokens = tokens.unsqueeze(0)
     if tokens.size(-1) > self.model.max_positions():
         raise ValueError("tokens exceeds maximum length: {} > {}".format(
             tokens.size(-1), self.model.max_positions()))
     features, extra = self.model(
         tokens.to(device=self.device),
         features_only=True,
         return_all_hiddens=return_all_hiddens,
     )
     if return_all_hiddens:
         # convert from T x B x C -> B x T x C
         inner_states = extra["inner_states"]
         return [
             inner_state.transpose(0, 1) for inner_state in inner_states
         ]
     else:
         return features  # just the last layer's features
Example #37
def batched_index_select(target: torch.Tensor,
                         indices: torch.LongTensor,
                         flattened_indices: Optional[torch.LongTensor] = None) -> torch.Tensor:
    """
    The given ``indices`` of size ``(batch_size, d_1, ..., d_n)`` indexes into the sequence
    dimension (dimension 2) of the target, which has size ``(batch_size, sequence_length,
    embedding_size)``.

    This function returns selected values in the target with respect to the provided indices, which
    have size ``(batch_size, d_1, ..., d_n, embedding_size)``. This can use the optionally
    precomputed :func:`~flattened_indices` with size ``(batch_size * d_1 * ... * d_n)`` if given.

    An example use case of this function is looking up the start and end indices of spans in a
    sequence tensor. This is used in the
    :class:`~allennlp.models.coreference_resolution.CoreferenceResolver`. Model to select
    contextual word representations corresponding to the start and end indices of mentions. The key
    reason this can't be done with basic torch functions is that we want to be able to use look-up
    tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
    a-priori how many spans we are looking up).

    Parameters
    ----------
    target : ``torch.Tensor``, required.
        A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
        This is the tensor to be indexed.
    indices : ``torch.LongTensor``
        A tensor of shape (batch_size, ...), where each element is an index into the
        ``sequence_length`` dimension of the ``target`` tensor.
    flattened_indices : Optional[torch.Tensor], optional (default = None)
        An optional tensor representing the result of calling :func:~`flatten_and_batch_shift_indices`
        on ``indices``. This is helpful in the case that the indices can be flattened once and
        cached for many batch lookups.

    Returns
    -------
    selected_targets : ``torch.Tensor``
        A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
        extracted from the batch flattened target tensor.
    """
    if flattened_indices is None:
        # Shape: (batch_size * d_1 * ... * d_n)
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    # Shape: (batch_size * sequence_length, embedding_size)
    flattened_target = target.view(-1, target.size(-1))

    # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
    flattened_selected = flattened_target.index_select(0, flattened_indices)
    selected_shape = list(indices.size()) + [target.size(-1)]
    # Shape: (batch_size, d_1, ..., d_n, embedding_size)
    selected_targets = flattened_selected.view(*selected_shape)
    return selected_targets
Example #38
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        all_encoder_layers, _ = self.bert_model(input_ids, input_mask, token_type_ids)
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]


        if offsets is None:
            return mix
        else:
            batch_size = input_ids.size(0)
            range_vector = util.get_range_vector(batch_size,
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            return mix[range_vector, offsets]
Example #39
 def _ngrams(self,
             tensor: torch.LongTensor,
             ngram_size: int) -> Dict[Tuple[int, ...], int]:
     ngram_counts: Dict[Tuple[int, ...], int] = Counter()
     if ngram_size > tensor.size(-1):
         return ngram_counts
     for start_position in range(ngram_size):
         for tensor_slice in tensor[start_position:].split(ngram_size, dim=-1):
             if tensor_slice.size(-1) < ngram_size:
                 break
             ngram = tuple(x.item() for x in tensor_slice)
             if any(x in self._exclude_indices for x in ngram):
                 continue
             ngram_counts[ngram] += 1
     return ngram_counts
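For intuition, the bigram counts the method above produces for ``[1, 2, 2, 3]`` (with no excluded indices) are the same as a plain sliding-window count; a minimal standalone equivalent:

from collections import Counter

tokens = [1, 2, 2, 3]
bigrams = Counter(zip(tokens, tokens[1:]))
print(bigrams)   # Counter({(1, 2): 1, (2, 2): 1, (2, 3): 1})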
Example #40
    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]
Example #41
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       batch_average: bool = True,
                                       label_smoothing: float = None) -> torch.FloatTensor:
    """
    Computes the cross entropy loss of a sequence, weighted with respect to
    some user provided weights. Note that the weighting here is not the same as
    in the :func:`torch.nn.CrossEntropyLoss()` criterion, which is weighting
    classes; here we are weighting the loss contribution from particular elements
    in the sequence. This allows loss computations for models which use padding.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch_size, sequence_length, num_classes)
        which contains the unnormalized probability for each class.
    targets : ``torch.LongTensor``, required.
        A ``torch.LongTensor`` of size (batch, sequence_length) which contains the
        index of the true class for each corresponding step.
    weights : ``torch.FloatTensor``, required.
        A ``torch.FloatTensor`` of size (batch, sequence_length)
    batch_average : bool, optional, (default = True).
        A bool indicating whether the loss should be averaged across the batch,
        or returned as a vector of losses per batch element.
    label_smoothing : ``float``, optional (default = None)
        Whether or not to apply label smoothing to the cross-entropy loss.
        For example, with a label smoothing value of 0.2, a 4 class classification
        target would look like ``[0.05, 0.05, 0.85, 0.05]`` if the 3rd class was
        the correct label.

    Returns
    -------
    A torch.FloatTensor representing the cross entropy loss.
    If ``batch_average == True``, the returned loss is a scalar.
    If ``batch_average == False``, the returned loss is a vector of shape (batch_size,).

    """
    # shape : (batch * sequence_length, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # shape : (batch * sequence_length, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # shape : (batch * max_len, 1)
    targets_flat = targets.view(-1, 1).long()

    if label_smoothing is not None and label_smoothing > 0.0:
        num_classes = logits.size(-1)
        smoothing_value = label_smoothing / num_classes
        # Fill all the correct indices with 1 - smoothing value.
        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
        smoothed_targets = one_hot_targets + smoothing_value
        negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
        negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
    else:
        # Contribution to the negative log likelihood only comes from the exact indices
        # of the targets, as the target distributions are one-hot. Here we use torch.gather
        # to extract the indices of the num_classes dimension which contribute to the loss.
        # shape : (batch * sequence_length, 1)
        negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat)
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size())
    # shape : (batch, sequence_length)
    negative_log_likelihood = negative_log_likelihood * weights.float()
    # shape : (batch_size,)
    per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13)

    if batch_average:
        num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences
    return per_batch_loss
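A small check of the label-smoothing branch above (assuming ``torch`` is imported); with 4 classes and ``label_smoothing=0.2``, the smoothed target for class 2 comes out as ``[0.05, 0.05, 0.85, 0.05]``, matching the docstring:

label_smoothing = 0.2
num_classes = 4
targets_flat = torch.tensor([[2]])
one_hot_targets = torch.zeros(1, num_classes).scatter_(-1, targets_flat, 1.0 - label_smoothing)
smoothed_targets = one_hot_targets + label_smoothing / num_classes
print(smoothed_targets)   # tensor([[0.0500, 0.0500, 0.8500, 0.0500]])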
Example #42
 def _get_valid_tokens_mask(self, tensor: torch.LongTensor) -> torch.ByteTensor:
     valid_tokens_mask = torch.ones(tensor.size(), dtype=torch.uint8)
     for index in self._exclude_indices:
         valid_tokens_mask = valid_tokens_mask & (tensor != index)
     return valid_tokens_mask
Example #43
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                                token_type_ids=util.combine_initial_dims(token_type_ids),
                                                attention_mask=util.combine_initial_dims(input_mask))
        if self._scalar_mix is not None:
            mix = self._scalar_mix(all_encoder_layers, input_mask)
        else:
            mix = all_encoder_layers[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(offsets2d.size(0),
                                                 device=util.get_device_of(mix)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
            selected_embeddings = mix[range_vector, offsets2d]

            return util.uncombine_initial_dims(selected_embeddings, offsets.size())